# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from typing import NamedTuple

import ml_dtypes
import numpy as np

from onnx.onnx_pb import TensorProto


class TensorDtypeMap(NamedTuple):
    np_dtype: np.dtype
    storage_dtype: int
    name: str


# tensor_dtype: (numpy type, storage type, string name)
# The storage type is the type used to store the tensor in the *_data field of
# a TensorProto. All available fields are float_data, int32_data, int64_data,
# string_data, uint64_data and double_data.
TENSOR_TYPE_MAP: dict[int, TensorDtypeMap] = {
    int(TensorProto.FLOAT): TensorDtypeMap(
        np.dtype("float32"), int(TensorProto.FLOAT), "TensorProto.FLOAT"
    ),
    int(TensorProto.UINT8): TensorDtypeMap(
        np.dtype("uint8"), int(TensorProto.INT32), "TensorProto.UINT8"
    ),
    int(TensorProto.INT8): TensorDtypeMap(
        np.dtype("int8"), int(TensorProto.INT32), "TensorProto.INT8"
    ),
    int(TensorProto.UINT16): TensorDtypeMap(
        np.dtype("uint16"), int(TensorProto.INT32), "TensorProto.UINT16"
    ),
    int(TensorProto.INT16): TensorDtypeMap(
        np.dtype("int16"), int(TensorProto.INT32), "TensorProto.INT16"
    ),
    int(TensorProto.INT32): TensorDtypeMap(
        np.dtype("int32"), int(TensorProto.INT32), "TensorProto.INT32"
    ),
    int(TensorProto.INT64): TensorDtypeMap(
        np.dtype("int64"), int(TensorProto.INT64), "TensorProto.INT64"
    ),
    int(TensorProto.BOOL): TensorDtypeMap(
        np.dtype("bool"), int(TensorProto.INT32), "TensorProto.BOOL"
    ),
    int(TensorProto.FLOAT16): TensorDtypeMap(
        np.dtype("float16"), int(TensorProto.INT32), "TensorProto.FLOAT16"
    ),
    int(TensorProto.BFLOAT16): TensorDtypeMap(
        np.dtype(ml_dtypes.bfloat16),
        int(TensorProto.INT32),
        "TensorProto.BFLOAT16",
    ),
    int(TensorProto.DOUBLE): TensorDtypeMap(
        np.dtype("float64"), int(TensorProto.DOUBLE), "TensorProto.DOUBLE"
    ),
    int(TensorProto.COMPLEX64): TensorDtypeMap(
        np.dtype("complex64"), int(TensorProto.FLOAT), "TensorProto.COMPLEX64"
    ),
    int(TensorProto.COMPLEX128): TensorDtypeMap(
        np.dtype("complex128"),
        int(TensorProto.DOUBLE),
        "TensorProto.COMPLEX128",
    ),
    int(TensorProto.UINT32): TensorDtypeMap(
        np.dtype("uint32"), int(TensorProto.UINT64), "TensorProto.UINT32"
    ),
    int(TensorProto.UINT64): TensorDtypeMap(
        np.dtype("uint64"), int(TensorProto.UINT64), "TensorProto.UINT64"
    ),
    int(TensorProto.STRING): TensorDtypeMap(
        np.dtype("object"), int(TensorProto.STRING), "TensorProto.STRING"
    ),
    int(TensorProto.FLOAT8E4M3FN): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e4m3fn),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E4M3FN",
    ),
    int(TensorProto.FLOAT8E4M3FNUZ): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e4m3fnuz),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E4M3FNUZ",
    ),
    int(TensorProto.FLOAT8E5M2): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e5m2),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E5M2",
    ),
    int(TensorProto.FLOAT8E5M2FNUZ): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e5m2fnuz),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E5M2FNUZ",
    ),
    int(TensorProto.UINT4): TensorDtypeMap(
        np.dtype(ml_dtypes.uint4), int(TensorProto.INT32), "TensorProto.UINT4"
    ),
    int(TensorProto.INT4): TensorDtypeMap(
        np.dtype(ml_dtypes.int4), int(TensorProto.INT32), "TensorProto.INT4"
    ),
    int(TensorProto.FLOAT4E2M1): TensorDtypeMap(
        np.dtype(ml_dtypes.float4_e2m1fn),
        int(TensorProto.INT32),
        "TensorProto.FLOAT4E2M1",
    ),
    int(TensorProto.FLOAT8E8M0): TensorDtypeMap(
        np.dtype(ml_dtypes.float8_e8m0fnu),
        int(TensorProto.INT32),
        "TensorProto.FLOAT8E8M0",
    ),
    int(TensorProto.UINT2): TensorDtypeMap(
        np.dtype(ml_dtypes.uint2), int(TensorProto.INT32), "TensorProto.UINT2"
    ),
    int(TensorProto.INT2): TensorDtypeMap(
        np.dtype(ml_dtypes.int2), int(TensorProto.INT32), "TensorProto.INT2"
    ),
}
