# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import collections.abc
import functools
import math
import numbers
import typing
from typing import TYPE_CHECKING, Any, TypeVar

import google.protobuf.message
import numpy as np
import numpy.typing as npt
import typing_extensions

import onnx
from onnx import _mapping, defs
from onnx.onnx_data_pb import MapProto, OptionalProto, SequenceProto
from onnx.onnx_pb import (
    AttributeProto,
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    OperatorSetIdProto,
    TensorProto,
    TensorShapeProto,
    TrainingInfoProto,
    TypeProto,
    ValueInfoProto,
)

if TYPE_CHECKING:
    from collections.abc import Callable, KeysView, Sequence

    from google.protobuf.internal.containers import RepeatedCompositeFieldContainer

# One row of the release table: (release version, IR version, ai.onnx opset
# version, ai.onnx.ml opset version), optionally followed by the
# ai.onnx.training opset version.
VersionRowType = tuple[str, int, int, int] | tuple[str, int, int, int, int]
# The whole release table, one row per ONNX release.
VersionTableType = list[VersionRowType]
# NOTE(review): not used in the visible part of this file; presumably a list
# of (name, name) pairs used by code further down -- confirm against usages.
AssignmentBindingType = list[tuple[str, str]]

# This is a copy of the documented version in https://github.com/onnx/onnx/blob/main/docs/Versioning.md#released-versions
# Both must be updated whenever a new version of ONNX is released.
VERSION_TABLE: VersionTableType = [
    # Release-version, IR version, ai.onnx version, ai.onnx.ml version, (optional) ai.onnx.training version
    ("1.0", 3, 1, 1),
    ("1.1", 3, 5, 1),
    ("1.1.2", 3, 6, 1),
    ("1.2", 3, 7, 1),
    ("1.3", 3, 8, 1),
    ("1.4.1", 4, 9, 1),
    ("1.5.0", 5, 10, 1),
    ("1.6.0", 6, 11, 2),
    ("1.7.0", 7, 12, 2, 1),
    ("1.8.0", 7, 13, 2, 1),
    ("1.8.1", 7, 13, 2, 1),
    ("1.9.0", 7, 14, 2, 1),
    ("1.10.0", 8, 15, 2, 1),
    ("1.10.1", 8, 15, 2, 1),
    ("1.10.2", 8, 15, 2, 1),
    ("1.11.0", 8, 16, 3, 1),
    ("1.12.0", 8, 17, 3, 1),
    ("1.13.0", 8, 18, 3, 1),
    ("1.13.1", 8, 18, 3, 1),
    ("1.14.0", 9, 19, 3, 1),
    ("1.14.1", 9, 19, 3, 1),
    ("1.15.0", 9, 20, 4, 1),
    ("1.16.0", 10, 21, 5, 1),
    ("1.16.1", 10, 21, 5, 1),
    ("1.16.2", 10, 21, 5, 1),
    ("1.17.0", 10, 22, 5, 1),
    ("1.18.0", 11, 23, 5, 1),
    ("1.19.0", 12, 24, 5, 1),
    ("1.19.1", 12, 24, 5, 1),
    ("1.20.0", 13, 25, 5, 1),
    ("1.20.1", 13, 25, 5, 1),
]

# Maps an (opset domain, opset version) pair to an IR version.
VersionMapType = dict[tuple[str, int], int]


def _create_op_set_id_version_map(table: VersionTableType) -> VersionMapType:
    """Create a map from (opset-domain, opset-version) to ir-version from above table."""
    result: VersionMapType = {}

    def process(release_version: str, ir_version: int, *args: Any) -> None:
        del release_version  # Unused
        for pair in zip(
            ["ai.onnx", "ai.onnx.ml", "ai.onnx.training"], args, strict=False
        ):
            if pair not in result:
                result[pair] = ir_version
                if pair[0] == "ai.onnx.training":
                    result["ai.onnx.preview.training", pair[1]] = ir_version

    for row in table:
        process(*row)
    return result


# Lookup from (opset domain, opset version) to the IR version of the first
# VERSION_TABLE row that lists that opset version; built once at import time.
OP_SET_ID_VERSION_MAP = _create_op_set_id_version_map(VERSION_TABLE)


def find_min_ir_version_for(
    opsetidlist: Sequence[OperatorSetIdProto], ignore_unknown: bool = False
) -> int:
    """Compute the smallest IR version able to host all the given opset imports.

    Args:
        opsetidlist: A sequence of OperatorSetIdProto. An empty domain is
            treated as ``ai.onnx``.
        ignore_unknown: If True, an opset that is missing from
            ``OP_SET_ID_VERSION_MAP`` contributes the default minimum version
            instead of raising.

    Returns:
        The largest IR version required by any listed opset, or the default
        minimum (3) when the list is empty.

    Raises:
        ValueError: If an opset is unknown and ``ignore_unknown`` is False.
    """
    fallback_ir_version = 3

    if not opsetidlist:
        return fallback_ir_version  # if no opsets specified

    required_versions = []
    for opset in opsetidlist:
        lookup_key = (opset.domain or "ai.onnx", opset.version)
        if lookup_key in OP_SET_ID_VERSION_MAP:
            required_versions.append(OP_SET_ID_VERSION_MAP[lookup_key])
        elif ignore_unknown:
            required_versions.append(fallback_ir_version)
        else:
            raise ValueError("Unsupported opset-version.")
    return max(required_versions)


def make_node(
    op_type: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    name: str | None = None,
    doc_string: str | None = None,
    domain: str | None = None,
    overload: str | None = None,
    **kwargs: Any,
) -> NodeProto:
    """Build a NodeProto for one operator invocation.

    Args:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto;
            skipped when empty
        doc_string (string, default None): optional documentation string for
            NodeProto; skipped when empty
        domain (string, default None): optional domain for NodeProto.
            If it's None, the field is left at its default (empty) value
        overload (string, default None): optional field, used to
            resolve calls to model-local functions
        **kwargs (dict): the attributes of the node, added in sorted name
            order. Entries whose value is None are skipped. The acceptable
            values are documented in :func:`make_attribute`.

    Returns:
        NodeProto
    """
    proto = NodeProto()
    proto.op_type = op_type
    proto.input.extend(inputs)
    proto.output.extend(outputs)

    # name / doc_string: falsy values (None or "") are not written.
    if name:
        proto.name = name
    if doc_string:
        proto.doc_string = doc_string

    # domain / overload: written whenever provided, even when empty.
    if domain is not None:
        proto.domain = domain
    if overload is not None:
        proto.overload = overload

    if kwargs:
        # Attribute names are unique, so sorting the names gives a
        # deterministic attribute order.
        proto.attribute.extend(
            make_attribute(attr_name, kwargs[attr_name])
            for attr_name in sorted(kwargs)
            if kwargs[attr_name] is not None
        )
    return proto


def make_operatorsetid(
    domain: str,
    version: int,
) -> OperatorSetIdProto:
    """Build an OperatorSetIdProto from a domain and a version.

    Args:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id
    Returns:
        OperatorSetIdProto with both fields assigned
    """
    proto = OperatorSetIdProto()
    proto.domain, proto.version = domain, version
    return proto


def make_graph(
    nodes: Sequence[NodeProto],
    name: str,
    inputs: Sequence[ValueInfoProto],
    outputs: Sequence[ValueInfoProto],
    initializer: Sequence[TensorProto] | None = None,
    doc_string: str | None = None,
    value_info: Sequence[ValueInfoProto] | None = None,
    sparse_initializer: Sequence[onnx.SparseTensorProto] | None = None,
) -> GraphProto:
    """Assemble a GraphProto from its nodes, inputs, outputs and optional extras.

    Args:
        nodes: list of NodeProto
        name (string): graph name
        inputs: list of ValueInfoProto
        outputs: list of ValueInfoProto
        initializer: list of TensorProto (None means no initializers)
        doc_string (string): graph documentation; not written when empty
        value_info: list of ValueInfoProto (None means none)
        sparse_initializer: list of onnx.SparseTensorProto (None means none)
    Returns:
        GraphProto
    """
    graph_proto = GraphProto()
    graph_proto.node.extend(nodes)
    graph_proto.name = name
    graph_proto.input.extend(inputs)
    graph_proto.output.extend(outputs)
    # Optional collections default to "nothing to add".
    graph_proto.initializer.extend([] if initializer is None else initializer)
    graph_proto.sparse_initializer.extend(
        [] if sparse_initializer is None else sparse_initializer
    )
    graph_proto.value_info.extend([] if value_info is None else value_info)
    if doc_string:
        graph_proto.doc_string = doc_string
    return graph_proto


def make_opsetid(domain: str, version: int) -> OperatorSetIdProto:
    """Build an OperatorSetIdProto from a domain and a version.

    Args:
        domain (string): The domain of the operator set id
        version (integer): Version of operator set id
    Returns:
        OperatorSetIdProto with both fields assigned
    """
    opset_proto = OperatorSetIdProto()
    opset_proto.domain, opset_proto.version = domain, version
    return opset_proto


def make_function(
    domain: str,
    fname: str,
    inputs: Sequence[str],
    outputs: Sequence[str],
    nodes: Sequence[NodeProto],
    opset_imports: Sequence[OperatorSetIdProto],
    attributes: Sequence[str] | None = None,
    attribute_protos: Sequence[AttributeProto] | None = None,
    doc_string: str | None = None,
    overload: str | None = None,
    value_info: Sequence[ValueInfoProto] | None = None,
) -> FunctionProto:
    """Assemble a FunctionProto.

    Args:
        domain: value stored in the function's ``domain`` field
        fname: value stored in the function's ``name`` field
        inputs: input names
        outputs: output names
        nodes: the NodeProto entries forming the function body
        opset_imports: OperatorSetIdProto entries stored in ``opset_import``
        attributes: names stored in the ``attribute`` field (None means none)
        attribute_protos: AttributeProto entries stored in ``attribute_proto``
            (None means none)
        doc_string: documentation string; not written when empty
        overload: stored in ``overload`` whenever it is not None
        value_info: ValueInfoProto entries (None means none)

    Returns:
        FunctionProto
    """
    function_proto = FunctionProto()
    function_proto.domain = domain
    function_proto.name = fname
    function_proto.input.extend(inputs)
    function_proto.output.extend(outputs)
    function_proto.node.extend(nodes)
    function_proto.opset_import.extend(opset_imports)
    # Optional collections default to "nothing to add".
    function_proto.attribute.extend([] if attributes is None else attributes)
    function_proto.attribute_proto.extend(
        [] if attribute_protos is None else attribute_protos
    )
    if doc_string:
        function_proto.doc_string = doc_string
    if overload is not None:
        function_proto.overload = overload
    function_proto.value_info.extend([] if value_info is None else value_info)
    return function_proto


def make_model(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Wrap a graph into a ModelProto.

    Args:
        graph (GraphProto): *make_graph* returns
        **kwargs: ``opset_imports`` (sequence of OperatorSetIdProto) and
            ``functions`` (sequence of FunctionProto) are handled specially;
            every other entry is assigned onto the model with ``setattr``.
    Returns:
        ModelProto
    """
    model_proto = ModelProto()
    # Assign ir_version explicitly so the model records the version it was
    # generated from.
    model_proto.ir_version = onnx.IR_VERSION
    model_proto.graph.CopyFrom(graph)

    opset_imports: Sequence[OperatorSetIdProto] | None = kwargs.pop(
        "opset_imports", None
    )
    if opset_imports is None:
        # Default import: a single entry carrying the current ONNX opset version.
        default_import = model_proto.opset_import.add()
        default_import.version = defs.onnx_opset_version()
    else:
        model_proto.opset_import.extend(opset_imports)

    functions: Sequence[FunctionProto] | None = kwargs.pop("functions", None)
    if functions is not None:
        model_proto.functions.extend(functions)

    # TODO: Does this work with repeated fields?
    for field_name, field_value in kwargs.items():
        setattr(model_proto, field_name, field_value)
    return model_proto


def make_model_gen_version(graph: GraphProto, **kwargs: Any) -> ModelProto:
    """Extension of :func:`make_model` that infers an IR version when none is given.

    If ``ir_version`` is absent from ``kwargs``, it is derived on a best-effort
    basis by passing ``opset_imports`` (or an empty list when that is absent
    too) to :func:`find_min_ir_version_for`. The model is then built by
    :func:`make_model`.
    """
    if "ir_version" not in kwargs:
        declared_imports = kwargs.get("opset_imports", [])
        kwargs["ir_version"] = find_min_ir_version_for(declared_imports)
    return make_model(graph, **kwargs)


def set_metadata_props(
    proto: (
        ModelProto
        | GraphProto
        | FunctionProto
        | NodeProto
        | TensorProto
        | ValueInfoProto
    ),
    dict_value: dict[str, str],
) -> None:
    """Replace ``proto.metadata_props`` with the key/value pairs of ``dict_value``.

    Existing entries are removed first; new entries are added in the
    dictionary's iteration order.
    """
    props = proto.metadata_props
    del props[:]
    for prop_key, prop_value in dict_value.items():
        item = props.add()
        item.key = prop_key
        item.value = prop_value


def set_model_props(model: ModelProto, dict_value: dict[str, str]) -> None:
    """Replace the model's metadata properties; thin wrapper over :func:`set_metadata_props`."""
    set_metadata_props(proto=model, dict_value=dict_value)


def _pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
    """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
    # Create a 1D copy
    array_flat = array.ravel().view(np.uint8).copy()
    size = array.size
    odd_sized = size % 2 == 1
    if odd_sized:
        array_flat.resize([size + 1], refcheck=False)
    array_flat &= 0x0F
    array_flat[1::2] <<= 4
    return array_flat[0::2] | array_flat[1::2]  # type: ignore[return-type]


def _pack_2bitx4(array: np.ndarray) -> npt.NDArray[np.uint8]:
    """Convert a numpy array to flatten, packed int2/uint2. Elements must be in the correct range."""
    # Create a 1D copy
    array_flat = array.ravel().view(np.uint8).copy()
    size = array.size
    pad_len = size % 4
    if pad_len:
        array_flat.resize([size + (4 - pad_len)], refcheck=False)
    array_flat &= 0x03
    array_flat[1::4] <<= 2
    array_flat[2::4] <<= 4
    array_flat[3::4] <<= 6
    return array_flat[0::4] | array_flat[1::4] | array_flat[2::4] | array_flat[3::4]  # type: ignore[return-type]


def make_tensor(
    name: str,
    data_type: int,
    dims: Sequence[int],
    vals: Sequence[int | float] | bytes | np.ndarray,
    raw: bool = False,
) -> TensorProto:
    """Make a TensorProto with specified arguments.  If raw is False, this
    function will choose the corresponding proto field to store the
    values based on data_type. If raw is True, use "raw_data" proto
    field to store the values, and values should be of type bytes
    (or a numpy array, which is serialized to little-endian bytes) in
    this case.

    Args:
        name: tensor name
        data_type: a value such as onnx.TensorProto.FLOAT
        dims: shape
        vals: values
        raw: if True, vals contains the serialized content of the tensor,
            otherwise, vals should be a list of values of the type defined by ``data_type``.

    Returns:
        TensorProto

    Raises:
        TypeError: if ``raw`` is True and ``data_type`` is STRING, or if
            ``raw`` is True and ``vals`` is neither bytes nor a numpy array.
        ValueError: if ``raw`` is True and the byte length of the data does
            not match the size implied by ``data_type`` and ``dims``.
    """
    tensor = TensorProto()
    tensor.data_type = data_type
    tensor.name = name
    tensor.dims.extend(dims)

    if data_type == TensorProto.STRING and raw:
        raise TypeError("Can not use raw_data to store string type.")

    np_dtype = tensor_dtype_to_np_dtype(data_type)

    if raw:
        # NumPy doesn't have INT2/INT4/FP4. It is packed in couples to UINT8 buffers.
        # Sub-byte types therefore count as a fraction of a byte per element.
        if data_type in {TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1}:
            expected_size_bytes = 0.5
        elif data_type in {TensorProto.UINT2, TensorProto.INT2}:
            expected_size_bytes = 0.25
        else:
            expected_size_bytes = np_dtype.itemsize
        expected_size_bytes *= math.prod(dims)
        # Round up: a trailing partially-filled byte still occupies a whole byte.
        expected_size_bytes = math.ceil(expected_size_bytes)
        if isinstance(vals, np.ndarray):
            # NOTE(review): the raw path uses the packers from onnx.numpy_helper,
            # while the non-raw path below uses the module-local _pack_4bitx2 /
            # _pack_2bitx4 -- confirm both pairs are meant to be equivalent.
            if data_type in {
                TensorProto.INT4,
                TensorProto.UINT4,
                TensorProto.FLOAT4E2M1,
            }:
                vals = onnx.numpy_helper._pack_4bitx2(vals)
            elif data_type in {TensorProto.UINT2, TensorProto.INT2}:
                vals = onnx.numpy_helper._pack_2bitx4(vals)

            raw_data = onnx.numpy_helper.tobytes_little_endian(vals)
        elif isinstance(vals, bytes):
            raw_data = vals
        else:
            raise TypeError(
                f"Raw data must be bytes or numpy.ndarray, but got {type(vals)}."
            )
        if len(raw_data) != expected_size_bytes:
            raise ValueError(
                f"Raw data size does not match tensor's size. Expected {expected_size_bytes} bytes, but got {len(raw_data)} bytes."
            )
        tensor.raw_data = raw_data
        return tensor

    assert not raw, "Bug: raw should be False at this point."

    # Step 1: normalize vals into a flat numpy array of the target dtype.
    if data_type == TensorProto.STRING:
        vals = np.array(vals).flatten()
        if len(vals) != 0:
            vals = np.vectorize(_to_bytes)(vals)  # Convert to bytes
    elif data_type in {
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
    }:
        # Float8 values are by default casted using saturating cast.
        vals = onnx.numpy_helper.saturate_cast(np.asarray(vals), np_dtype).flatten()
    elif data_type == TensorProto.FLOAT8E8M0:
        vals = onnx.numpy_helper.to_float8e8m0(
            np.asarray(vals), saturate=True, round_mode="up"
        ).flatten()
    else:
        vals = np.asarray(vals, dtype=np_dtype).flatten()

    # Step 2: reinterpret or pack the array into the representation that is
    # appended to the proto field selected below.
    if data_type == TensorProto.COMPLEX128:
        vals = vals.view(np.float64)  # type: ignore[union-attr]
    elif data_type == TensorProto.COMPLEX64:
        vals = vals.view(np.float32)  # type: ignore[union-attr]
    elif data_type in {TensorProto.BFLOAT16, TensorProto.FLOAT16}:
        # 16-bit floats are reinterpreted as uint16.
        vals = vals.view(np.uint16)  # type: ignore[union-attr]
    elif data_type in {
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
        TensorProto.FLOAT8E8M0,
    }:
        # 8-bit floats are reinterpreted as uint8.
        vals = vals.view(np.uint8)  # type: ignore[union-attr]
    elif data_type in {TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1}:
        # Convert to packed 4-bit representation
        vals = _pack_4bitx2(vals)  # type: ignore[union-attr,arg-type]
    elif data_type in {TensorProto.UINT2, TensorProto.INT2}:
        # Convert to packed 2-bit representation
        vals = _pack_2bitx4(vals)  # type: ignore[union-attr,arg-type]
    elif data_type == TensorProto.BOOL:
        vals = vals.astype(np.uint8)  # type: ignore[union-attr]

    # Step 3: append the values to the repeated field matching data_type.
    field = tensor_dtype_to_field(data_type)
    getattr(tensor, field).extend(vals)
    return tensor


def make_sparse_tensor(
    values: TensorProto, indices: TensorProto, dims: Sequence[int]
) -> onnx.SparseTensorProto:
    """Build a SparseTensorProto from a values tensor, an indices tensor and a shape.

    Args:
        values (TensorProto): the values; copied into the result
        indices (TensorProto): the indices; copied into the result
        dims: the shape

    Returns:
        SparseTensorProto
    """
    sparse_proto = onnx.SparseTensorProto()
    # Both tensors are copied, so the arguments stay independent of the result.
    for target, source in (
        (sparse_proto.values, values),
        (sparse_proto.indices, indices),
    ):
        target.CopyFrom(source)
    sparse_proto.dims.extend(dims)
    return sparse_proto


def make_sequence(
    name: str,
    elem_type: SequenceProto.DataType,
    values: Sequence[Any],
) -> SequenceProto:
    """Make a Sequence with specified value arguments.

    Args:
        name: sequence name
        elem_type: a ``SequenceProto.DataType`` value selecting which repeated
            field receives ``values``
        values: the elements; ignored when ``elem_type`` is UNDEFINED

    Returns:
        SequenceProto

    Raises:
        TypeError: if ``elem_type`` is not a supported element type.
    """
    sequence = SequenceProto()
    sequence.name = name
    sequence.elem_type = elem_type

    if elem_type == SequenceProto.UNDEFINED:
        return sequence

    if elem_type == SequenceProto.TENSOR:
        container = sequence.tensor_values
    elif elem_type == SequenceProto.SPARSE_TENSOR:
        container = sequence.sparse_tensor_values
    elif elem_type == SequenceProto.SEQUENCE:
        container = sequence.sequence_values
    elif elem_type == SequenceProto.MAP:
        container = sequence.map_values
    # Compare against the SequenceProto enum like every other branch; the
    # OptionalProto enum only matched because both enums share the number.
    elif elem_type == SequenceProto.OPTIONAL:
        container = sequence.optional_values
    else:
        raise TypeError("The element type in the input sequence is not supported.")

    container.extend(values)
    return sequence


def make_map(
    name: str, key_type: int, keys: list[Any], values: SequenceProto
) -> MapProto:
    """Make a Map with specified key-value pair arguments.

    Criteria for conversion:
    - Keys and Values must have the same number of elements
    - Every key in keys must be of the same type
    - Every value in values must be of the same type

    String keys are stored in ``string_keys`` and integer keys in ``keys``.
    For any other ``key_type`` no keys are stored and no error is raised.
    """
    integer_key_types = (
        TensorProto.INT8,
        TensorProto.INT16,
        TensorProto.INT32,
        TensorProto.INT64,
        TensorProto.UINT8,
        TensorProto.UINT16,
        TensorProto.UINT32,
        TensorProto.UINT64,
    )
    result = MapProto()
    result.name = name
    result.key_type = key_type

    if key_type == TensorProto.STRING:
        result.string_keys.extend(keys)
    elif key_type in integer_key_types:
        result.keys.extend(keys)

    result.values.CopyFrom(values)
    return result


def make_optional(
    name: str,
    elem_type: OptionalProto.DataType,
    value: google.protobuf.message.Message | None,
) -> OptionalProto:
    """Make an Optional with specified value arguments.

    For an UNDEFINED ``elem_type`` the result carries no value and ``value``
    is ignored. Otherwise ``value`` is copied into the field matching
    ``elem_type``.

    Raises:
        TypeError: if ``elem_type`` is not a supported element type.
    """
    result = OptionalProto()
    result.name = name
    result.elem_type = elem_type

    if elem_type == OptionalProto.UNDEFINED:
        return result

    # Element type -> name of the field that holds the value.
    value_fields = (
        (OptionalProto.TENSOR, "tensor_value"),
        (OptionalProto.SPARSE_TENSOR, "sparse_tensor_value"),
        (OptionalProto.SEQUENCE, "sequence_value"),
        (OptionalProto.MAP, "map_value"),
        (OptionalProto.OPTIONAL, "optional_value"),
    )
    for kind, field_name in value_fields:
        if elem_type == kind:
            destination = getattr(result, field_name)
            break
    else:
        raise TypeError("The element type in the input optional is not supported.")

    assert value is not None
    destination.CopyFrom(value)  # type: ignore[arg-type]
    return result


def _to_bytes(value: str | bytes) -> bytes:
    """Coerce a string (or bytes) value into UTF-8 bytes."""
    if isinstance(value, str):
        return value.encode("utf-8")
    return value


def make_attribute(
    key: str,
    value: Any,
    doc_string: str | None = None,
    attr_type: int | None = None,
) -> AttributeProto:
    """Makes an AttributeProto based on the value type.

    Args:
        key: attribute name
        value: a single int, float, str/bytes, TensorProto, SparseTensorProto,
            GraphProto or TypeProto, or an iterable of such values
        doc_string: optional documentation; not written when empty
        attr_type: expected ``AttributeProto`` type. For iterables it selects
            the list type directly (required when the iterable is empty); for
            single values it is only checked against the inferred type.

    Returns:
        AttributeProto

    Raises:
        ValueError: if the list type cannot be inferred from an iterable.
        TypeError: if ``value`` is not an accepted kind of value, or the
            inferred type differs from ``attr_type``.
        AssertionError: if ``value`` is iterable and ``attr_type`` is not one
            of the list types.
    """
    proto = AttributeProto()
    proto.name = key
    if doc_string:
        proto.doc_string = doc_string

    # The order of the checks matters: Integral is tested before Real, and
    # str/bytes before the generic Iterable case.
    if isinstance(value, numbers.Integral):
        proto.i = int(value)
        proto.type = AttributeProto.INT
    elif isinstance(value, numbers.Real):
        proto.f = float(value)
        proto.type = AttributeProto.FLOAT
    elif isinstance(value, (str, bytes)):
        # Strings are stored UTF-8 encoded.
        proto.s = _to_bytes(value)
        proto.type = AttributeProto.STRING
    elif isinstance(value, TensorProto):
        proto.t.CopyFrom(value)
        proto.type = AttributeProto.TENSOR
    elif isinstance(value, onnx.SparseTensorProto):
        proto.sparse_tensor.CopyFrom(value)
        proto.type = AttributeProto.SPARSE_TENSOR
    elif isinstance(value, GraphProto):
        proto.g.CopyFrom(value)
        proto.type = AttributeProto.GRAPH
    elif isinstance(value, TypeProto):
        proto.tp.CopyFrom(value)
        proto.type = AttributeProto.TYPE_PROTO
    elif isinstance(value, collections.abc.Iterable):
        elements = list(value)

        if attr_type is None:
            if not elements:
                raise ValueError(
                    f"Could not infer attribute `{key}` type from empty iterator"
                )
            element_types = {type(element) for element in elements}
            # First candidate whose base type covers every element wins, so a
            # mix of ints and floats is inferred as FLOATS.
            inference_order = (
                (numbers.Integral, AttributeProto.INTS),
                (numbers.Real, AttributeProto.FLOATS),
                ((str, bytes), AttributeProto.STRINGS),
                (TensorProto, AttributeProto.TENSORS),
                (onnx.SparseTensorProto, AttributeProto.SPARSE_TENSORS),
                (GraphProto, AttributeProto.GRAPHS),
                (TypeProto, AttributeProto.TYPE_PROTOS),
            )
            attr_type = next(
                (
                    list_enum
                    for base_type, list_enum in inference_order
                    if all(issubclass(t, base_type) for t in element_types)  # type: ignore[arg-type]
                ),
                None,
            )
            if attr_type is None:
                raise ValueError(
                    "Could not infer the attribute type from the elements of the passed Iterable value."
                )

        # List type -> name of the repeated field receiving the elements.
        repeated_fields = (
            (AttributeProto.INTS, "ints"),
            (AttributeProto.FLOATS, "floats"),
            (AttributeProto.STRINGS, "strings"),
            (AttributeProto.TENSORS, "tensors"),
            (AttributeProto.SPARSE_TENSORS, "sparse_tensors"),
            (AttributeProto.GRAPHS, "graphs"),
            (AttributeProto.TYPE_PROTOS, "type_protos"),
        )
        for list_enum, field_name in repeated_fields:
            if attr_type != list_enum:
                continue
            if list_enum == AttributeProto.STRINGS:
                getattr(proto, field_name).extend(_to_bytes(v) for v in elements)
            else:
                getattr(proto, field_name).extend(elements)
            proto.type = list_enum
            break
        else:
            # Inference always yields a list type, so this is only reached
            # when the caller supplied a non-list attr_type.
            raise AssertionError()
    else:
        raise TypeError(f"'{value}' is not an accepted attribute value.")

    type_conflict = attr_type is not None and proto.type != attr_type
    if type_conflict:
        raise TypeError(
            f"Inferred attribute type '{_attr_type_to_str(proto.type)}'({proto.type}) mismatched with specified type '{_attr_type_to_str(attr_type)}'({attr_type})"
        )
    return proto


def make_attribute_ref(
    name: str, attr_type: AttributeProto.AttributeType, doc_string: str | None = None
) -> AttributeProto:
    """Make an AttributeProto holding a reference to the parent function's attribute of given name and type.

    Only ``name``, ``type`` and (when non-empty) ``doc_string`` are assigned.
    """
    # NOTE(review): ``ref_attr_name`` is not set here -- confirm callers do not
    # rely on it to recognise reference attributes.
    ref_proto = AttributeProto()
    ref_proto.name, ref_proto.type = name, attr_type
    if doc_string:
        ref_proto.doc_string = doc_string
    return ref_proto


def get_attribute_value(attr: AttributeProto) -> Any:
    """Return the Python value stored in ``attr`` according to ``attr.type``.

    Single-valued types return the field itself, list types return a new
    ``list`` built from the repeated field, and UNDEFINED returns None.

    Raises:
        ValueError: if ``attr`` is a reference attribute (``ref_attr_name`` is
            set) or has an unsupported type.
    """
    if attr.ref_attr_name:
        raise ValueError(f"Cannot get value of reference attribute: {attr}")

    kind = attr.type

    # Attribute type -> field holding a single value.
    single_fields = {
        AttributeProto.FLOAT: "f",
        AttributeProto.INT: "i",
        AttributeProto.STRING: "s",
        AttributeProto.TENSOR: "t",
        AttributeProto.SPARSE_TENSOR: "sparse_tensor",
        AttributeProto.GRAPH: "g",
        AttributeProto.TYPE_PROTO: "tp",
    }
    if kind in single_fields:
        return getattr(attr, single_fields[kind])

    # Attribute type -> repeated field, returned as a plain list.
    list_fields = {
        AttributeProto.FLOATS: "floats",
        AttributeProto.INTS: "ints",
        AttributeProto.STRINGS: "strings",
        AttributeProto.TENSORS: "tensors",
        AttributeProto.SPARSE_TENSORS: "sparse_tensors",
        AttributeProto.GRAPHS: "graphs",
        AttributeProto.TYPE_PROTOS: "type_protos",
    }
    if kind in list_fields:
        return list(getattr(attr, list_fields[kind]))

    if kind == AttributeProto.UNDEFINED:
        return None
    raise ValueError(f"Unsupported ONNX attribute: {attr}")


def get_node_attr_value(node: NodeProto, attr_name: str) -> Any:
    """Return the value of the node attribute called ``attr_name``.

    Raises:
        ValueError: if the node has no attribute, or more than one attribute,
            with that name.
    """
    found = [candidate for candidate in node.attribute if candidate.name == attr_name]
    if not found:
        raise ValueError(f"Node has no attribute with name {attr_name}")
    if len(found) > 1:
        raise ValueError(f"Node has multiple attributes with name {attr_name}")
    (only_match,) = found
    return get_attribute_value(only_match)


def make_empty_tensor_value_info(name: str) -> ValueInfoProto:
    """Make a ValueInfoProto that carries only a name (no type information)."""
    named_only = ValueInfoProto()
    named_only.name = name
    return named_only


def make_tensor_type_proto(
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    shape_denotation: list[str] | None = None,
) -> TypeProto:
    """Makes a Tensor TypeProto based on the data type and shape.

    Args:
        elem_type: element data type
        shape: one entry per dimension -- an int sets ``dim_value``, a str
            sets ``dim_param`` and None leaves the dimension unspecified.
            Passing None for the whole shape leaves the shape unset.
        shape_denotation: optional per-dimension denotations; must have the
            same length as ``shape``

    Raises:
        ValueError: on a length mismatch between ``shape_denotation`` and
            ``shape``, or on a shape entry that is not int, str or None.
    """
    result = TypeProto()
    tensor_type = result.tensor_type
    tensor_type.elem_type = elem_type
    shape_proto = tensor_type.shape

    if shape is None:
        return result

    # Extending by an empty list is deliberate: a protobuf field that is never
    # set is omitted from the output, whereas one explicitly set to empty is
    # emitted as an (empty) entry. Consumers can tell the two apart, so a
    # rank-0 shape must still be emitted.
    shape_proto.dim.extend([])

    if shape_denotation and len(shape_denotation) != len(shape):
        raise ValueError(
            "Invalid shape_denotation. Must be of the same length as shape."
        )

    for index, extent in enumerate(shape):
        dim_proto = shape_proto.dim.add()
        if isinstance(extent, int):
            dim_proto.dim_value = extent
        elif isinstance(extent, str):
            dim_proto.dim_param = extent
        elif extent is not None:
            raise ValueError(
                f"Invalid item in shape: {extent}. Needs to be of int or str."
            )

        if shape_denotation:
            dim_proto.denotation = shape_denotation[index]

    return result


def make_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a ValueInfoProto based on the data type and shape.

    The type is built by :func:`make_tensor_type_proto`; ``doc_string`` is
    only written when non-empty.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string

    info.type.CopyFrom(make_tensor_type_proto(elem_type, shape, shape_denotation))
    return info


def make_sparse_tensor_type_proto(
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    shape_denotation: list[str] | None = None,
) -> TypeProto:
    """Makes a SparseTensor TypeProto based on the data type and shape.

    Args:
        elem_type: element data type
        shape: one entry per dimension -- an int sets ``dim_value``, a str
            sets ``dim_param`` and None leaves the dimension unspecified.
            Passing None for the whole shape leaves the shape unset.
        shape_denotation: optional per-dimension denotations; must have the
            same length as ``shape``

    Raises:
        ValueError: on a length mismatch between ``shape_denotation`` and
            ``shape``, or on a shape entry that is not int, str or None.
    """
    result = TypeProto()
    sparse_type = result.sparse_tensor_type
    sparse_type.elem_type = elem_type
    shape_proto = sparse_type.shape

    if shape is None:
        return result

    # Extending by an empty list is deliberate: a protobuf field that is never
    # set is omitted from the output, whereas one explicitly set to empty is
    # emitted as an (empty) entry. Consumers can tell the two apart, so a
    # rank-0 shape must still be emitted.
    shape_proto.dim.extend([])

    if shape_denotation and len(shape_denotation) != len(shape):
        raise ValueError(
            "Invalid shape_denotation. Must be of the same length as shape."
        )

    for index, extent in enumerate(shape):
        dim_proto = shape_proto.dim.add()
        if isinstance(extent, int):
            dim_proto.dim_value = extent
        elif isinstance(extent, str):
            dim_proto.dim_param = extent
        elif extent is not None:
            raise ValueError(
                f"Invalid item in shape: {extent}. Needs to be of int or text."
            )

        if shape_denotation:
            dim_proto.denotation = shape_denotation[index]

    return result


def make_sparse_tensor_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Makes a SparseTensor ValueInfoProto based on the data type and shape.

    The type is built by :func:`make_sparse_tensor_type_proto`;
    ``doc_string`` is only written when non-empty.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string

    built_type = make_sparse_tensor_type_proto(elem_type, shape, shape_denotation)
    # Only the sparse_tensor_type member is copied over.
    info.type.sparse_tensor_type.CopyFrom(built_type.sparse_tensor_type)
    return info


def make_sequence_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes a sequence TypeProto whose element type is a copy of ``inner_type_proto``."""
    wrapper = TypeProto()
    element_slot = wrapper.sequence_type.elem_type
    element_slot.CopyFrom(inner_type_proto)
    return wrapper


def make_optional_type_proto(
    inner_type_proto: TypeProto,
) -> TypeProto:
    """Makes an optional TypeProto whose element type is a copy of ``inner_type_proto``."""
    wrapper = TypeProto()
    element_slot = wrapper.optional_type.elem_type
    element_slot.CopyFrom(inner_type_proto)
    return wrapper


def make_map_type_proto(
    key_type: int,
    value_type: TypeProto,
) -> TypeProto:
    """Makes a map TypeProto with the given key type and a copy of ``value_type``."""
    wrapper = TypeProto()
    map_type = wrapper.map_type
    map_type.key_type = key_type
    map_type.value_type.CopyFrom(value_type)
    return wrapper


def make_value_info(
    name: str,
    type_proto: TypeProto,
    doc_string: str = "",
) -> ValueInfoProto:
    """Makes a ValueInfoProto with the given type_proto.

    ``type_proto`` is copied into the result; ``doc_string`` is only written
    when non-empty.
    """
    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string

    info.type.CopyFrom(type_proto)
    return info


def _sanitize_str(s: str | bytes) -> str:
    """Convert ``s`` to text and shorten it to at most 64 characters for display.

    Args:
        s: Value to render. ``str`` is used as-is, ``bytes`` is decoded as
            UTF-8 with undecodable bytes dropped, and anything else goes
            through ``str()``.

    Returns:
        The text unchanged when it has 64 characters or fewer; otherwise its
        first 64 characters followed by ``...<+len=N>``, where ``N`` is the
        number of characters that were cut off.
    """
    if isinstance(s, str):
        sanitized = s
    elif isinstance(s, bytes):
        sanitized = s.decode("utf-8", errors="ignore")
    else:
        sanitized = str(s)
    # `<=` rather than `<`: a string of exactly 64 characters loses nothing,
    # so it must not receive a misleading "...<+len=0>" suffix.
    if len(sanitized) <= 64:  # noqa: PLR2004
        return sanitized
    return sanitized[:64] + f"...<+len={(len(sanitized) - 64)}>"


def make_tensor_sequence_value_info(
    name: str,
    elem_type: int,
    shape: Sequence[str | int | None] | None,
    doc_string: str = "",
    elem_shape_denotation: list[str] | None = None,
) -> ValueInfoProto:
    """Build a ValueInfoProto whose type is a sequence of tensors.

    Args:
        name: Name stored on the value info.
        elem_type: Tensor element data type, forwarded to ``make_tensor_type_proto``.
        shape: Tensor shape description, forwarded to ``make_tensor_type_proto``.
        doc_string: Documentation text; only stored when non-empty.
        elem_shape_denotation: Optional denotations, forwarded to
            ``make_tensor_type_proto``.

    Returns:
        The populated ValueInfoProto.
    """
    element_type = make_tensor_type_proto(elem_type, shape, elem_shape_denotation)
    sequence_type = make_sequence_type_proto(element_type).sequence_type

    info = ValueInfoProto()
    info.name = name
    if doc_string:
        info.doc_string = doc_string
    info.type.sequence_type.CopyFrom(sequence_type)

    return info


def printable_attribute(
    attr: AttributeProto, subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
    """Render an attribute as ``<name> = <value>`` text.

    The first populated field, checked in the order f, i, s, t, g, tp, floats,
    ints, strings, tensors, type_protos, graphs, decides how the value is
    rendered. An attribute with none of these set is rendered as ``<Unknown>``.

    Args:
        attr: The attribute to render.
        subgraphs: When true, the graphs held by the attribute are returned
            alongside the text so the caller can print them later.

    Returns:
        The rendered text, or a ``(text, graphs)`` tuple when ``subgraphs`` is true.
    """

    def fmt_float(value: float) -> str:
        # NB: Different Python versions print different numbers of trailing
        # decimals; an explicit precision keeps the output consistent.
        return f"{value:.15g}"

    _Elem = TypeVar("_Elem")

    def fmt_list(fmt_item: Callable[[_Elem], str], items: Sequence[_Elem]) -> str:
        return "[" + ", ".join(fmt_item(item) for item in items) + "]"

    # The HasField checks on scalar fields rely on explicit field presence
    # (proto2 semantics). Under proto3 semantics attr.type would have to be
    # consulted instead.

    # Graph attributes are only named in the text; the graphs themselves are
    # handed back to the caller for later printing.
    nested_graphs = []
    if attr.HasField("f"):
        value_text = fmt_float(attr.f)
    elif attr.HasField("i"):
        value_text = str(attr.i)
    elif attr.HasField("s"):
        value_text = repr(_sanitize_str(attr.s))
    elif attr.HasField("t"):
        if len(attr.t.dims) > 0:
            value_text = "<Tensor>"
        else:
            # A tensor without dims is a scalar: show its stored value.
            field = tensor_dtype_to_field(attr.t.data_type)
            value_text = f"<Scalar Tensor {getattr(attr.t, field)}>"
    elif attr.HasField("g"):
        value_text = f"<graph {attr.g.name}>"
        nested_graphs.append(attr.g)
    elif attr.HasField("tp"):
        value_text = f"<Type Proto {attr.tp}>"
    elif attr.floats:
        value_text = fmt_list(fmt_float, attr.floats)
    elif attr.ints:
        value_text = fmt_list(str, attr.ints)
    elif attr.strings:
        value_text = str([_sanitize_str(item) for item in attr.strings])
    elif attr.tensors:
        value_text = "[<Tensor>, ...]"
    elif attr.type_protos:
        # This branch is only reached with at least one entry, so the
        # brackets are always separated from the items by a single space.
        joined = ", ".join(f"<Type Proto {tp}>" for tp in attr.type_protos)
        value_text = "[ " + joined + " ]"
    elif attr.graphs:
        joined = ", ".join(f"<graph {g.name}>" for g in attr.graphs)
        value_text = "[ " + joined + " ]"
        nested_graphs.extend(attr.graphs)
    else:
        value_text = "<Unknown>"

    text = " ".join([attr.name, "=", value_text])
    if subgraphs:
        return text, nested_graphs
    return text


def printable_dim(dim: TensorShapeProto.Dimension) -> str:
    """Return the text of whichever ``value`` field is set on ``dim``, or ``"?"`` if none is."""
    populated = dim.WhichOneof("value")
    return "?" if populated is None else str(getattr(dim, populated))


def printable_type(t: TypeProto) -> str:
    """Render a TypeProto as text.

    Tensor types are rendered as the element type name, followed by
    ``, <d0>x<d1>...`` when a shape with dimensions is present or ``, scalar``
    when the shape is present but has no dimensions. A TypeProto with no
    value set yields an empty string; any other kind yields
    ``Unknown type <kind>``.
    """
    kind = t.WhichOneof("value")
    if kind is None:
        return ""
    if kind != "tensor_type":
        return f"Unknown type {kind}"

    tensor = t.tensor_type
    text: str = TensorProto.DataType.Name(tensor.elem_type)  # type: ignore[attr-defined]
    if not tensor.HasField("shape"):
        return text
    dims = tensor.shape.dim
    if len(dims):
        return text + ", " + "x".join(map(printable_dim, dims))
    return text + ", scalar"


def printable_value_info(v: ValueInfoProto) -> str:
    """Render a value info as ``%<name>``, followed by ``[<type>]`` when ``v.type`` is truthy."""
    label = f"%{v.name}"
    if not v.type:
        return label
    return f"{label}[{printable_type(v.type)}]"


def printable_tensor_proto(t: TensorProto) -> str:
    """Render a tensor as ``%<name>[<DTYPE>, <d0>x<d1>...]``.

    A tensor whose ``dims`` is empty is rendered with ``scalar`` in place of
    the dimensions.
    """
    pieces = [
        f"%{t.name}[",
        TensorProto.DataType.Name(t.data_type),  # type: ignore[attr-defined]
    ]
    if t.dims is not None:
        shape_text = (", " + "x".join(map(str, t.dims))) if len(t.dims) else ", scalar"
        pieces.append(shape_text)
    pieces.append("]")
    return "".join(pieces)


def printable_node(
    node: NodeProto, prefix: str = "", subgraphs: bool = False
) -> str | tuple[str, list[GraphProto]]:
    """Render a node as ``%out, ... = OpType[attrs](%in, ...)``.

    Args:
        node: The node to render.
        prefix: Text placed in front of the rendered line.
        subgraphs: When true, graphs referenced by the node's attributes are
            collected and returned with the text.

    Returns:
        The rendered line, or a ``(line, graphs)`` tuple when ``subgraphs`` is true.

    Raises:
        TypeError: If ``printable_attribute`` returns a value of an unexpected shape.
    """
    # To deal with nested graphs
    nested_graphs: list[GraphProto] = []
    attr_texts = []
    for attr in node.attribute:
        if not subgraphs:
            printed = printable_attribute(attr)
            if not isinstance(printed, str):
                raise TypeError(f"printed must be an instance of {str}.")
            attr_texts.append(printed)
            continue
        rendered = printable_attribute(attr, subgraphs)
        if not isinstance(rendered[1], list):
            raise TypeError(
                f"printed_attr_subgraphs[1] must be an instance of {list}."
            )
        nested_graphs.extend(rendered[1])
        attr_texts.append(rendered[0])

    inputs_text = ", ".join(f"%{name}" for name in node.input)
    if node.attribute:
        # Attributes are listed in sorted order of their rendered text.
        attrs_text = ", ".join(sorted(attr_texts))
        call_text = f"{node.op_type}[{attrs_text}]({inputs_text})"
    else:
        call_text = f"{node.op_type}({inputs_text})"

    if len(node.output):
        outputs_text = ", ".join(f"%{name}" for name in node.output)
        line = prefix + " ".join([outputs_text, "=", call_text])
    else:
        line = prefix + call_text

    if subgraphs:
        return line, nested_graphs
    return line


@typing_extensions.deprecated(
    "Deprecated since 1.19. Consider using onnx.printer.to_text() instead."
)
def printable_graph(graph: GraphProto, prefix: str = "") -> str:
    """Display a GraphProto as a string.

    The output is a header listing the graph inputs (and, only when the graph
    has at least one input, its initializers), one line per node, a ``return``
    line naming the graph outputs, and a closing bracket. Graphs referenced by
    node attributes are rendered after the closing bracket.

    .. deprecated:: 1.19
        Consider using :func:`onnx.printer.to_text` instead.

    Args:
        graph (GraphProto): the graph to display
        prefix (string): prefix of every line of this graph; graphs referenced
            by node attributes are rendered without it

    Returns:
        string
    """
    content = []
    indent = prefix + "  "
    # header
    # `header` accumulates the words of the header line currently being built;
    # it is flushed into `content` and reset each time a new section starts.
    header = ["graph", graph.name]
    initializers = {t.name for t in graph.initializer}
    # NOTE: the initializer sections below live inside this branch, so they
    # are only emitted for graphs that declare at least one input.
    if len(graph.input):
        header.append("(")
        in_strs = []  # required inputs
        in_with_init_strs: list = []  # optional inputs with initializer providing default value
        for inp in graph.input:
            if inp.name not in initializers:
                in_strs.append(printable_value_info(inp))
            else:
                in_with_init_strs.append(printable_value_info(inp))
        if in_strs:
            content.append(prefix + " ".join(header))
            header = []
            for line in in_strs:
                content.append(prefix + "  " + line)  # noqa: PERF401
        header.append(")")

        if in_with_init_strs:
            header.append("optional inputs with matching initializers (")
            content.append(prefix + " ".join(header))
            header = []
            for line in in_with_init_strs:
                content.append(prefix + "  " + line)  # noqa: PERF401
            header.append(")")

        # from IR 4 onwards an initializer is not required to have a matching graph input
        # so output the name, type and shape of those as well
        if len(in_with_init_strs) < len(initializers):
            graph_inputs = {i.name for i in graph.input}
            init_strs = [
                printable_tensor_proto(i)
                for i in graph.initializer
                if i.name not in graph_inputs
            ]
            header.append("initializers (")
            content.append(prefix + " ".join(header))
            header = []
            for line in init_strs:
                content.append(prefix + "  " + line)  # noqa: PERF401
            header.append(")")

    header.append("{")
    content.append(prefix + " ".join(header))
    # Graphs referenced by node attributes, collected while printing the body.
    graphs: list[GraphProto] = []
    # body
    for node in graph.node:
        # With subgraphs=True, printable_node returns (line, graphs).
        contents_subgraphs = printable_node(node, indent, subgraphs=True)
        if not isinstance(contents_subgraphs[1], list):
            raise TypeError(f"contents_subgraphs[1] must be an instance of {list}.")
        content.append(contents_subgraphs[0])
        graphs.extend(contents_subgraphs[1])
    # tail
    tail = ["return"]
    if len(graph.output):
        tail.append(", ".join([f"%{out.name}" for out in graph.output]))
    content.append(indent + " ".join(tail))
    # closing bracket
    content.append(prefix + "}")
    # Collected graphs follow the closing bracket, each preceded by a blank
    # line and rendered with the default (empty) prefix.
    for g in graphs:
        content.append("\n" + printable_graph(g))  # noqa: PERF401
    return "\n".join(content)


def strip_doc_string(proto: google.protobuf.message.Message) -> None:
    """Clear the `doc_string` field of `proto` and of every nested message, in place.

    Args:
        proto: The protobuf message to clean.

    Raises:
        TypeError: If `proto` is not a protobuf message.
    """
    if not isinstance(proto, google.protobuf.message.Message):
        raise TypeError(
            f"proto must be an instance of {google.protobuf.message.Message}."
        )
    for field in proto.DESCRIPTOR.fields:
        if field.name == "doc_string":
            proto.ClearField(field.name)
            continue
        if field.type != field.TYPE_MESSAGE:
            continue
        # Recurse into every element of a repeated message field, and into a
        # singular message field only when it is actually set.
        if field.label == field.LABEL_REPEATED:
            for item in getattr(proto, field.name):
                strip_doc_string(item)
        elif proto.HasField(field.name):
            strip_doc_string(getattr(proto, field.name))


def make_training_info(
    algorithm: GraphProto,
    algorithm_bindings: AssignmentBindingType,
    initialization: GraphProto | None,
    initialization_bindings: AssignmentBindingType | None,
) -> TrainingInfoProto:
    """Build a TrainingInfoProto from graphs and their key/value bindings.

    Args:
        algorithm: Graph copied into the ``algorithm`` field.
        algorithm_bindings: ``(key, value)`` pairs added to ``update_binding``.
        initialization: Graph copied into the ``initialization`` field; skipped
            when falsy.
        initialization_bindings: ``(key, value)`` pairs added to
            ``initialization_binding``; skipped when falsy.

    Returns:
        The populated TrainingInfoProto.
    """

    def add_bindings(target, pairs) -> None:
        # Append one key/value entry to the repeated field per pair.
        for key, value in pairs:
            entry = target.add()
            entry.key = key
            entry.value = value

    info = TrainingInfoProto()
    info.algorithm.CopyFrom(algorithm)
    add_bindings(info.update_binding, algorithm_bindings)

    if initialization:
        info.initialization.CopyFrom(initialization)
    if initialization_bindings:
        add_bindings(info.initialization_binding, initialization_bindings)

    return info


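# The following functions convert between TensorProto data types and related
# representations (numpy dtypes, storage data types, names, storage fields).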
def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
    """Look up the numpy dtype registered for a TensorProto data_type.

    It can be used while making tensors.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        numpy's data_type
    """
    type_entry = _mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return type_entry.np_dtype


def tensor_dtype_to_storage_tensor_dtype(tensor_dtype: int) -> int:
    """Look up the data_type used to store a given TensorProto data_type.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        data_type for storage
    """
    type_entry = _mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return type_entry.storage_dtype


def tensor_dtype_to_string(tensor_dtype: int) -> str:
    """Look up the name registered for a TensorProto data_type.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        the name of data_type
    """
    type_entry = _mapping.TENSOR_TYPE_MAP[tensor_dtype]
    return type_entry.name


@functools.lru_cache(None)
def tensor_dtype_to_field(tensor_dtype: int) -> str:
    """Return the TensorProto field name that stores values of a data_type.

    The data_type is first mapped to its storage data_type, which then selects
    the field. It can be used while making tensors.

    Args:
        tensor_dtype: TensorProto's data_type

    Returns:
        field name
    """
    # Both unsigned storage types (UINT32 and UINT64) share "uint64_data".
    field_by_storage_dtype = dict(
        [
            (int(TensorProto.FLOAT), "float_data"),
            (int(TensorProto.INT32), "int32_data"),
            (int(TensorProto.INT64), "int64_data"),
            (int(TensorProto.DOUBLE), "double_data"),
            (int(TensorProto.UINT32), "uint64_data"),
            (int(TensorProto.UINT64), "uint64_data"),
            (int(TensorProto.STRING), "string_data"),
        ]
    )
    storage_dtype = _mapping.TENSOR_TYPE_MAP[tensor_dtype].storage_dtype
    return field_by_storage_dtype[storage_dtype]


@functools.lru_cache(None)
def np_dtype_to_tensor_dtype(np_dtype: np.dtype) -> TensorProto.DataType:
    """Find the TensorProto data_type matching a numpy dtype.

    It can be used while converting numpy arrays to tensors.

    Args:
        np_dtype: numpy's data_type

    Returns:
        TensorsProto's data_type

    Raises:
        ValueError: If the dtype is neither registered nor a string dtype.
    """
    # Reverse lookup table; when several tensor types share a numpy dtype,
    # the one appearing last in TENSOR_TYPE_MAP wins.
    tensor_dtype_by_np_dtype = {}
    for tensor_dtype, type_entry in _mapping.TENSOR_TYPE_MAP.items():
        tensor_dtype_by_np_dtype[type_entry.np_dtype] = tensor_dtype

    if np_dtype not in tensor_dtype_by_np_dtype:
        if np.issubdtype(np_dtype, np.str_):
            return TensorProto.STRING  # type: ignore[no-any-return]
        raise ValueError(
            f"Unable to convert type {np_dtype!r} into TensorProto element type."
        )
    return typing.cast("TensorProto.DataType", tensor_dtype_by_np_dtype[np_dtype])


def get_all_tensor_dtypes() -> KeysView[int]:
    """Return a view of every tensor data type registered in the type map.

    Returns:
        all tensor types from TensorProto
    """
    type_map = _mapping.TENSOR_TYPE_MAP
    return type_map.keys()


# Reverse lookup from an AttributeProto.AttributeType number to its name.
_ATTRIBUTE_TYPE_TO_STR: dict[int, str] = {
    type_number: type_name
    for type_name, type_number in AttributeProto.AttributeType.items()  # type: ignore[attr-defined]
}


def _attr_type_to_str(attr_type: int) -> str:
    """Convert AttributeProto type to string.

    Args:
        attr_type: AttributeProto type.

    Returns:
        The name of the supplied attr_type, or the name of the first
        AttributeType enum entry when attr_type is not a known value.
    """
    attribute_types = AttributeProto.AttributeType
    if attr_type not in attribute_types.values():  # type: ignore[attr-defined]
        return attribute_types.keys()[0]  # type: ignore
    return _ATTRIBUTE_TYPE_TO_STR[attr_type]
