# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import numpy as np

from onnx.reference.op_run import OpRun


def rotary_embedding(
    input: np.ndarray,
    cos_cache: np.ndarray,
    sin_cache: np.ndarray,
    position_ids: np.ndarray | None = None,
    interleaved=None,
    rotary_embedding_dim=None,
    num_heads=None,
) -> np.ndarray:
    original_input_shape = input.shape
    # First ensure input to be processed has shape [batch_size, seq_len, num_heads, head_size]
    if len(input.shape) == 4:
        input = np.transpose(input, (0, 2, 1, 3))
    batch_size = input.shape[0]
    sequence_length = input.shape[1]
    if len(input.shape) == 3:
        hidden_size = input.shape[2]
        assert num_heads != 0
        head_size = int(hidden_size / num_heads)
        new_shape = [batch_size, sequence_length, num_heads, head_size]
        input = np.reshape(input, new_shape)
    assert len(input.shape) == 4
    head_size = input.shape[3]

    # Fully or partially perform rotation on input based on rotary_embedding_dim attribute
    if rotary_embedding_dim is None or rotary_embedding_dim == 0:
        # If rotary_embedding_dim not provided, perform full rotation by using head_size
        rotary_embedding_dim = head_size
    x_rotate = input[:, :, :, :rotary_embedding_dim]
    x_not_rotate = input[:, :, :, rotary_embedding_dim:]
    rotary_embedding_dim_half = int(rotary_embedding_dim / 2)

    # Retrieve sin and cos caches using position ids
    if position_ids is not None:
        cos_cache = cos_cache[
            position_ids
        ]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
        sin_cache = sin_cache[
            position_ids
        ]  # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]

    # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
    if cos_cache.shape[-1] != rotary_embedding_dim_half:
        raise ValueError(
            f"Last dimension of cos cache ({cos_cache.shape[-1]}) does not match rotary_embedding_dim/2 ({rotary_embedding_dim_half})."
        )
    if sin_cache.shape[-1] != rotary_embedding_dim_half:
        raise ValueError(
            f"Last dimension of sin cache ({sin_cache.shape[-1]}) does not match rotary_embedding_dim/2 ({rotary_embedding_dim_half})."
        )

    cos_cache = np.expand_dims(
        cos_cache, axis=2
    )  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
    sin_cache = np.expand_dims(
        sin_cache, axis=2
    )  # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]

    # Either divide the input in halves or interleave (based on interleaved attribute)
    if interleaved:
        x1 = x_rotate[:, :, :, 0::2]
        x2 = x_rotate[:, :, :, 1::2]
    else:
        x1, x2 = np.split(x_rotate, 2, axis=-1)

    # Calculate real and imaginary values
    real = (cos_cache * x1) - (sin_cache * x2)
    imag = (sin_cache * x1) + (cos_cache * x2)

    # Inserted rotated embeddings back to the original input
    if interleaved:
        # x_rotate[:, :, :, 0::2] = real
        # x_rotate[:, :, :, 1::2] = imag
        real = np.expand_dims(real, axis=-1)
        imag = np.expand_dims(imag, axis=-1)
        x_rotate_concat = np.concatenate((real, imag), axis=-1)
        x_rotate = np.reshape(x_rotate_concat, x_rotate.shape)
    else:
        x_rotate = np.concatenate((real, imag), axis=-1)
    output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
    if len(original_input_shape) == 3:
        output = np.reshape(output, original_input_shape)
    else:
        output = np.transpose(output, (0, 2, 1, 3))
    return output


class RotaryEmbedding(OpRun):
    """Reference runtime kernel for the RotaryEmbedding operator.

    Delegates all computation to :func:`rotary_embedding` and wraps the
    single result in the output tuple expected by the runtime.
    """

    def _run(
        self,
        input: np.ndarray,
        cos_cache: np.ndarray,
        sin_cache: np.ndarray,
        position_ids: np.ndarray | None = None,
        interleaved=None,
        rotary_embedding_dim=None,
        num_heads=None,
    ) -> np.ndarray:
        # OpRun kernels return a tuple of outputs; this op has exactly one.
        output = rotary_embedding(
            input,
            cos_cache,
            sin_cache,
            position_ids=position_ids,
            interleaved=interleaved,
            rotary_embedding_dim=rotary_embedding_dim,
            num_heads=num_heads,
        )
        return (output,)
