# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import math
import sys
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, get_args

import numpy as np
import torch
from typing_extensions import assert_never

if TYPE_CHECKING:
    from vllm import PoolingRequestOutput
else:
    PoolingRequestOutput = Any

sys_byteorder = sys.byteorder


EMBED_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    # I'm not sure if other platforms' CPUs support the fp8 data format.
    # EMBED_DTYPE only uses the fp8 data representation,
    # does not use fp8 computation, and only occurs on the CPU.
    # Apologize for any possible break.
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}

EMBED_DTYPE_TO_N_BYTES = {
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
    "fp8_e4m3": 1,
    "fp8_e5m2": 1,
}


EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = {
    "float32": torch.float32,
    "float16": torch.float16,
    # numpy does not support bfloat16 and fp8
    "bfloat16": torch.float16,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}

EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = {
    "float32": np.float32,
    "float16": np.float16,
    # numpy does not support bfloat16 and fp8
    "bfloat16": np.float16,
    "fp8_e4m3": np.uint8,
    "fp8_e5m2": np.uint8,
}

ENDIANNESS = ["native", "big", "little"]

EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]


def tensor2base64(x: torch.Tensor) -> str:
    """Serialize ``x`` with ``torch.save`` and return it as base64 text.

    The payload is torch's own serialization format. To recover the tensor,
    base64-decode it and pass the bytes to ``torch.load``. Only load payloads
    that come from trusted sources.
    """
    with io.BytesIO() as stream:
        torch.save(x, stream)
        serialized = stream.getvalue()
    encoded = base64.b64encode(serialized)
    return encoded.decode("utf-8")


def tensor2binary(
    tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness
) -> bytes:
    """Serialize a tensor to raw element bytes in the requested dtype/order.

    The tensor is cast to ``embed_dtype`` and flattened to 1-D (row-major).
    It is then reinterpreted as a same-width dtype that numpy understands, so
    its raw bytes can be emitted. No shape information is included in the
    output.

    Args:
        tensor: Tensor to serialize. It is detached and moved to CPU first,
            so tensors on an accelerator or with autograd history are
            accepted.
        embed_dtype: Target element dtype (a key of
            ``EMBED_DTYPE_TO_TORCH_DTYPE``).
        endianness: ``"native"`` keeps the host byte order. ``"big"`` or
            ``"little"`` byte-swaps when it differs from ``sys.byteorder``.

    Returns:
        ``numel * itemsize`` bytes holding the flattened elements.

    Raises:
        AssertionError: On a non-tensor input or an unsupported
            dtype/endianness (these checks are skipped under ``python -O``).
    """
    assert isinstance(tensor, torch.Tensor)
    assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
    assert endianness in ENDIANNESS

    torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
    torch_view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype]

    # ``.numpy()`` only works on CPU tensors that do not require grad, so
    # detach and move to CPU. Both are no-ops for plain CPU tensors.
    host_tensor = tensor.detach().to(device="cpu", dtype=torch_dtype)
    # Reinterpret as a same-width dtype numpy supports (bfloat16 -> float16,
    # fp8 -> uint8). The underlying bytes are unchanged.
    np_array = host_tensor.flatten().contiguous().view(torch_view_dtype).numpy()

    if endianness != "native" and endianness != sys_byteorder:
        np_array = np_array.byteswap()

    return np_array.tobytes()


def binary2tensor(
    binary: bytes,
    shape: tuple[int, ...],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> torch.Tensor:
    """Decode raw element bytes (as produced by ``tensor2binary``) to a tensor.

    Args:
        binary: Buffer holding exactly ``prod(shape)`` elements of
            ``embed_dtype``.
        shape: Shape to give the decoded tensor.
        embed_dtype: Element dtype the bytes were encoded with.
        endianness: Byte order the bytes were encoded with. ``"native"``
            means host order. ``"big"``/``"little"`` are byte-swapped when
            they differ from ``sys.byteorder``.

    Returns:
        A tensor of ``shape`` with the torch dtype for ``embed_dtype``.
        When ``binary`` is read-only (e.g. ``bytes``) or a byte swap was
        needed, the tensor owns its memory. Otherwise it shares memory with
        the writable ``binary`` buffer.

    Raises:
        AssertionError: On an unsupported dtype/endianness.
        ValueError: From numpy, when the buffer size does not match
            ``shape``/``embed_dtype``.
    """
    assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
    assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
    assert endianness in ENDIANNESS

    torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
    np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]

    # Decode via a same-width numpy dtype, then reinterpret the bits as the
    # real torch dtype below.
    np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape)

    if endianness != "native" and endianness != sys_byteorder:
        # byteswap() returns a new, writable array.
        np_array = np_array.byteswap()

    # np.frombuffer over immutable bytes yields a read-only array.
    # torch.from_numpy warns on non-writable arrays and does not support
    # them, so take an owned copy in that case.
    if not np_array.flags.writeable:
        np_array = np_array.copy()

    return torch.from_numpy(np_array).view(torch_dtype)


def encode_pooling_output(
    output: PoolingRequestOutput,
    encoding_format: EncodingFormat,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> list[float] | str | bytes:
    if encoding_format == "float":
        return output.outputs.data.tolist()
    elif encoding_format == "base64":
        embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness)
        return base64.b64encode(embedding_bytes).decode("utf-8")
    elif encoding_format == "bytes" or encoding_format == "bytes_only":
        return tensor2binary(output.outputs.data, embed_dtype, endianness)
    assert_never(encoding_format)


@dataclass
class MetadataItem:
    """Location and decoding parameters for one tensor inside a byte body.

    ``decode_pooling_output`` slices ``body[start:end]`` for each item and
    decodes it with ``binary2tensor`` using the fields below.
    """

    # Ordinal of this tensor; decode_pooling_output returns tensors sorted by it.
    index: int
    # Element dtype the bytes were encoded with.
    embed_dtype: EmbedDType
    # Byte order the elements were encoded with.
    endianness: Endianness
    # Byte offset where this tensor's data begins (inclusive).
    start: int
    # Byte offset where this tensor's data ends (exclusive, slice semantics).
    end: int
    # Shape restored after decoding the flat element buffer.
    shape: tuple[int, ...]


def build_metadata_items(
    embed_dtype: EmbedDType,
    endianness: Endianness,
    shape: tuple[int, ...],
    n_request: int,
):
    """Describe ``n_request`` equally shaped tensors packed back to back.

    Every item shares ``embed_dtype``, ``endianness`` and ``shape``. Item
    ``i`` covers bytes ``[i * stride, (i + 1) * stride)``, where ``stride``
    is the element count of ``shape`` times the byte width of
    ``embed_dtype``.

    Returns:
        A list of ``MetadataItem`` with indices ``0 .. n_request - 1``.
    """
    bytes_per_element = EMBED_DTYPE_TO_N_BYTES[embed_dtype]
    stride = math.prod(shape) * bytes_per_element

    return [
        MetadataItem(
            index=position,
            embed_dtype=embed_dtype,
            endianness=endianness,
            start=position * stride,
            end=position * stride + stride,
            shape=shape,
        )
        for position in range(n_request)
    ]


def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> tuple[list[bytes], list[dict[str, Any]], dict[str, int]]:
    """Serialize several pooling outputs into byte chunks plus metadata.

    Args:
        pooling_outputs: Outputs whose ``outputs.data`` tensors are encoded,
            in list order.
        embed_dtype: Element dtype used for every tensor.
        endianness: Byte order used for every tensor.

    Returns:
        A ``(body, items, usage)`` tuple:

        - ``body``: one ``tensor2binary`` chunk per output, in order. Offsets
          in ``items`` refer to the concatenation ``b"".join(body)``.
        - ``items``: one dict per output with the same keys as
          ``MetadataItem`` (``index``, ``embed_dtype``, ``endianness``,
          ``start``, ``end``, ``shape``).
        - ``usage``: ``prompt_tokens`` and ``total_tokens``. Both equal the
          summed ``len(prompt_token_ids)`` over all outputs.
    """
    num_prompt_tokens = 0
    # The dict values are ints, strings and shapes, not MetadataItem objects.
    items: list[dict[str, Any]] = []
    body: list[bytes] = []
    offset = 0
    for idx, output in enumerate(pooling_outputs):
        data = output.outputs.data
        binary = tensor2binary(
            tensor=data,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
        size = len(binary)

        body.append(binary)
        items.append(
            {
                "index": idx,
                "embed_dtype": embed_dtype,
                "endianness": endianness,
                "start": offset,
                "end": offset + size,
                # Plain tuple (not torch.Size) to match MetadataItem.shape.
                "shape": tuple(data.shape),
            }
        )
        num_prompt_tokens += len(output.prompt_token_ids)
        offset += size

    usage = {
        "prompt_tokens": num_prompt_tokens,
        "total_tokens": num_prompt_tokens,
    }
    return body, items, usage


def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    """Decode every tensor described by ``items`` from the shared ``body``.

    Each item's bytes are ``body[item.start:item.end]``. They are decoded with
    ``binary2tensor`` using the item's shape, dtype and byte order.

    Args:
        items: Per-tensor metadata. This list is not modified.
        body: Concatenated encoded tensors.

    Returns:
        Decoded tensors in ascending ``item.index`` order, regardless of the
        order of ``items``.
    """
    # sorted() rather than items.sort(): sorting in place would silently
    # reorder the caller's list as a side effect.
    return [
        binary2tensor(
            body[item.start : item.end],
            item.shape,
            item.embed_dtype,
            item.endianness,
        )
        for item in sorted(items, key=lambda item: item.index)
    ]
