# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import os
import platform
import sys
import unittest
from typing import Any

import numpy
import version_utils

import onnx.backend.base
import onnx.backend.test
from onnx import ModelProto
from onnx.backend.base import Device, DeviceType
from onnx.reference import ReferenceEvaluator

# Everything below runs the ONNX backend test suite against a backend built on
# ReferenceEvaluator. The VERBOSE environment variable (an integer, "0" when
# unset) is passed to the evaluator as its verbosity level.
VERBOSE = int(os.getenv("VERBOSE", "0"))


class ReferenceEvaluatorBackendRep(onnx.backend.base.BackendRep):
    """Backend representation that delegates execution to an evaluator session.

    The session is the object handed to the constructor;
    ``ReferenceEvaluatorBackend.prepare`` passes a ``ReferenceEvaluator``.
    """

    def __init__(self, session):
        """Store the session whose ``run`` method executes the model."""
        self._session = session

    def run(self, inputs, **kwargs):  # noqa: ARG002
        """Run the session on ``inputs`` and return its outputs.

        Args:
            inputs: One of

                * a single ``numpy.ndarray`` (treated as a one-element list),
                * a list or tuple of arrays, matched to the session's input
                  names by position,
                * a dict mapping input names to arrays, used as-is.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The value returned by ``self._session.run(None, feeds)``.

        Raises:
            TypeError: If ``inputs`` is none of the supported types.
        """
        if isinstance(inputs, numpy.ndarray):
            inputs = [inputs]

        if isinstance(inputs, dict):
            feeds = inputs
        elif isinstance(inputs, (list, tuple)):
            input_names = self._session.input_names
            if len(inputs) == len(input_names):
                feeds = dict(zip(input_names, inputs, strict=True))
            else:
                # Fewer (or more) arrays than declared inputs: walk the declared
                # inputs in order and give the next pending array to the first
                # declared input whose static shape equals the array's shape.
                # Declared inputs that do not match are left out of the feeds.
                feeds = {}
                pos_inputs = 0
                for name, tshape in zip(
                    input_names, self._session.input_types, strict=True
                ):
                    # Checked before indexing so that an empty ``inputs`` yields
                    # empty feeds instead of raising IndexError.
                    if pos_inputs >= len(inputs):
                        break
                    shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
                    if shape == inputs[pos_inputs].shape:
                        feeds[name] = inputs[pos_inputs]
                        pos_inputs += 1
        else:
            raise TypeError(f"Unexpected input type {type(inputs)!r}.")
        return self._session.run(None, feeds)


class ReferenceEvaluatorBackend(onnx.backend.base.Backend):
    """ONNX backend whose models are executed by ``ReferenceEvaluator``."""

    @classmethod
    def is_opset_supported(cls, model):  # noqa: ARG003
        """Report every model as supported, with an empty explanation string."""
        return (True, "")

    @classmethod
    def supports_device(cls, device: str) -> bool:
        """Return True only when ``device`` parses to a CPU device."""
        return Device(device).type == DeviceType.CPU

    @classmethod
    def create_inference_session(cls, model):
        """Build a ``ReferenceEvaluator`` for ``model`` with the module verbosity."""
        return ReferenceEvaluator(model, verbose=VERBOSE)

    @classmethod
    def prepare(
        cls, model: Any, device: str = "CPU", **kwargs: Any
    ) -> ReferenceEvaluatorBackendRep:
        """Wrap ``model`` in a ``ReferenceEvaluatorBackendRep``.

        A ``str``, ``bytes`` or ``ModelProto`` is first turned into a session by
        :meth:`create_inference_session` and then prepared again; an existing
        ``ReferenceEvaluator`` is wrapped directly.

        Raises:
            TypeError: If ``model`` has any other type.
        """
        if isinstance(model, (str, bytes, ModelProto)):
            session = cls.create_inference_session(model)
            return cls.prepare(session, device, **kwargs)
        if not isinstance(model, ReferenceEvaluator):
            raise TypeError(f"Unexpected type {type(model)} for model.")
        return ReferenceEvaluatorBackendRep(model)

    @classmethod
    def run_model(cls, model, inputs, device=None, **kwargs):
        """Prepare ``model`` and run it on ``inputs`` in one call.

        ``kwargs`` are forwarded both to :meth:`prepare` and to the ``run``
        method of the prepared representation.
        """
        return cls.prepare(model, device, **kwargs).run(inputs, **kwargs)

    @classmethod
    def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
        """Always raise: this backend does not run individual nodes."""
        raise NotImplementedError("Unable to run the model node by node.")


# The DFT tests are compared with a looser absolute tolerance on every platform
# other than Linux.
dft_atol = 1e-6 if sys.platform == "linux" else 1e-3
backend_test = onnx.backend.test.BackendTest(
    ReferenceEvaluatorBackend,
    __name__,
    test_kwargs={
        test_name: {"atol": dft_atol}
        for test_name in (
            "test_dft",
            "test_dft_axis",
            "test_dft_axis_opset19",
            "test_dft_inverse",
            "test_dft_inverse_opset19",
            "test_dft_opset19",
        )
    },
)

# Exclusions that depend on the machine running the tests: the APPVEYOR
# environment variable, a 32-bit interpreter, or Windows.
if os.environ.get("APPVEYOR"):
    backend_test.exclude("(" + "|".join(("test_vgg19", "test_zfnet")) + ")")
if platform.architecture()[0] == "32bit":
    backend_test.exclude(
        "(" + "|".join(("test_vgg19", "test_zfnet", "test_bvlc_alexnet")) + ")"
    )
if platform.system() == "Windows":
    backend_test.exclude("test_sequence_model")

# Tests that are not supported.
backend_test.exclude(
    "("
    + "|".join(
        (
            "test_gradient",
            "test_if_opt",
            "test_loop16_seq_none",
            "test_range_float_type_positive_delta_expanded",
            "test_range_int32_type_negative_delta_expanded",
            "test_scan_sum",
        )
    )
    + ")"
)

# Tests covering deprecated operators.
backend_test.exclude(
    "(" + "|".join(("test_scatter_with_axis", "test_scatter_without")) + ")"
)

# Tests that are too slow with the reference implementation (Conv).
backend_test.exclude(
    "("
    + "|".join(
        (
            "test_bvlc_alexnet",
            "test_densenet121",
            "test_inception_v1",
            "test_inception_v2",
            "test_resnet50",
            "test_shufflenet",
            "test_squeezenet",
            "test_vgg19",
            "test_zfnet512",
        )
    )
    + ")"
)

# Tests that cannot pass because they generate random numbers.
backend_test.exclude("(" + "test_bernoulli" + ")")

# Tests that fail due to discrepancies (small, but still higher than 1e-7).
backend_test.exclude("test_adam_multiple")  # 1e-2

# Exclusions tied to the operating system (the two checks are mutually exclusive).
if sys.platform == "win32":
    # Pillow is currently not supported on Win32 and is required by the
    # reference implementation, so the RegexFullMatch and ImageDecoder tests
    # are skipped there.
    # NOTE(review): the original note names Pillow as the RegexFullMatch
    # requirement; that operator presumably needs a regex package instead --
    # confirm.
    backend_test.exclude("test_regex_full_match_basic_cpu")
    backend_test.exclude("test_regex_full_match_email_domain_cpu")
    backend_test.exclude("test_regex_full_match_empty_cpu")
    backend_test.exclude("test_image_decoder_decode_")
elif sys.platform == "darwin":
    # FIXME: https://github.com/onnx/onnx/issues/5792
    backend_test.exclude("test_qlinearmatmul_3D_int8_float16_cpu")
    backend_test.exclude("test_qlinearmatmul_3D_int8_float32_cpu")

# Exclusions gated on the version_utils checks for Pillow and NumPy.
if version_utils.pillow_older_than("10.0"):
    backend_test.exclude("test_image_decoder_decode_webp_rgb")
    backend_test.exclude("test_image_decoder_decode_jpeg2k_rgb")

if version_utils.numpy_older_than("2.0"):
    # assert_allclose does not support ml_dtypes types in numpy < 2.0
    backend_test.exclude(r"test_cast.*(FLOAT8|BFLOAT16|FLOAT4|INT4)")
    backend_test.exclude(r"test_quantizelinear_e4m3fn")
    backend_test.exclude(r"test_quantizelinear_float4e2m1")

# The documentation does not explicitly forbid combining is_causal=1 with a
# non-None attn_mask. The expansion (based on the function definition in ONNX)
# assumes this case never happens and behaves like is_causal=0 even when it is 1.
# The reference implementation and the backend tests behave differently in
# that case, so the affected expanded tests are excluded.
backend_test.exclude(
    "("
    + "|".join(
        (
            "test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded",
            "test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded",
            "test_attention_4d_attn_mask_4d_causal_expanded",
            "test_attention_4d_attn_mask_3d_causal_expanded",
        )
    )
    + ")"
)

# Publish every test case held by backend_test as a module-level name so that
# python.unittest can see them. This must run after all the exclude() calls above.
globals().update(backend_test.test_cases)

if __name__ == "__main__":
    # exit=False keeps the interpreter running after the tests so the summary
    # below can be printed.
    res = unittest.main(verbosity=2, exit=False)
    result = res.result
    print("---------------------------------")
    print(
        f"tests_run={result.testsRun} errors={len(result.errors)} "
        f"failures={len(result.failures)} skipped={len(result.skipped)} "
        f"unexpected_successes={len(result.unexpectedSuccesses)} "
        f"expected_failures={len(result.expectedFailures)}"
    )
    # With exit=False unittest.main does not set the process exit status, so
    # report an unsuccessful run explicitly for callers that check it.
    sys.exit(0 if result.wasSuccessful() else 1)
