# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This is basically a copy from perception_models/core/vision_encoder/pe.py"""

from collections.abc import Callable
from functools import partial

import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F

from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig

from .step3_vl import Step3VLForConditionalGeneration
from .utils import WeightsMapper, init_vllm_registered_model, maybe_prefix
from .vision import run_dp_sharded_vision_model

# Normalization factory (nn.LayerNorm with eps pinned to 1e-5); it is the default
# value of PerceptionEncoder's `norm_layer` argument and is called with the
# feature width to build each norm module.
_DEFAULT_NORM_LAYER = partial(nn.LayerNorm, eps=1e-5)


def rotate_half(x):
    """Rotate adjacent feature pairs of ``x`` along its last dimension.

    The last dimension is read as consecutive ``(a, b)`` pairs and every pair is
    replaced by ``(-b, a)``. The output has the same shape as ``x``.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, "... d r -> ... (d r)")


def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
    """Apply rotary position embedding to a window of the last dimension of ``t``.

    With ``rot_dim = freqs.shape[-1]``, the features
    ``t[..., start_index:start_index + rot_dim]`` are replaced by
    ``t * cos(freqs) * scale + rotate_half(t) * sin(freqs) * scale``. Features
    outside that window are passed through untouched.

    Args:
        freqs: Rotation angles; must broadcast against the rotated window of ``t``.
        t: Tensor to rotate.
        start_index: First feature index of the rotated window.
        scale: Multiplier applied to both the cosine and the sine term.
        seq_dim: Dimension of ``t`` read as the sequence length. Only consulted
            when ``t`` is 3-D, in which case ``freqs`` is cut down to its last
            ``seq_len`` entries along dim 0.

    Returns:
        A tensor shaped like ``t``, cast back to the dtype ``t`` had on entry.

    Raises:
        AssertionError: If the window ``[start_index, start_index + rot_dim)``
            does not fit inside the last dimension of ``t``.
    """
    dtype = t.dtype

    if t.ndim == 3:
        seq_len = t.shape[seq_dim]
        freqs = freqs[-seq_len:]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    # Validate the END of the window, not just its width: with a non-zero
    # start_index a window can overrun the feature dimension even though
    # rot_dim alone would fit, which otherwise surfaces as an opaque
    # broadcasting error further down.
    assert end_index <= t.shape[-1], (
        f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in "
        f"all the positions {rot_dim} starting at index {start_index}"
    )

    t_left, t, t_right = (
        t[..., :start_index],
        t[..., start_index:end_index],
        t[..., end_index:],
    )
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    out = torch.cat((t_left, t, t_right), dim=-1)

    # The arithmetic above may have been promoted (e.g. by float32 freqs);
    # hand the caller back the dtype it passed in.
    return out.type(dtype)


class PerceptionEncoderRope2D(nn.Module):
    def __init__(
        self,
        dim: int,
        max_grid_height: int,
        max_grid_width: int,
        use_cls_token: bool = False,
        theta=10000,
        max_freq=10,
        num_freqs=1,
        theta_rescale_factor=1.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_grid_height = max_grid_height
        self.max_grid_width = max_grid_width
        self.use_cls_token = use_cls_token
        self.theta = theta * theta_rescale_factor ** (dim / (dim - 2))
        self.max_freq = max_freq
        self.num_freqs = num_freqs
        cache = self._compute_2d_freqs()
        self.register_buffer("freqs_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: int | float, dim: int) -> torch.Tensor:
        freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        return freqs

    def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
        freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype), inv_freq)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def _compute_2d_freqs(self) -> torch.Tensor:
        grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
        grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
        if self.use_cls_token:
            grid_h_range += 1
            grid_w_range += 1
        inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
        freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(
            self.max_grid_height, self.max_grid_width, -1
        )
        freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(
            self.max_grid_height, self.max_grid_width, -1
        )
        freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
            self.max_grid_height * self.max_grid_width, -1
        )
        if self.use_cls_token:
            freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
        freqs = freqs[None, None, ...]
        return freqs

    def forward(self, q: torch.Tensor, k: torch.Tensor, grid_hw: tuple[int, int]):
        if grid_hw[0] != self.max_grid_height or grid_hw[1] != self.max_grid_width:
            rows = torch.arange(grid_hw[0], device=q.device).view(-1, 1)
            cols = torch.arange(grid_hw[1], device=q.device).view(1, -1)
            positions = (rows * self.max_grid_width + cols).reshape(-1).to(torch.long)
            if self.use_cls_token:
                positions = torch.cat(
                    [torch.zeros(1, device=q.device), positions + 1], dim=0
                )
                positions = positions.to(torch.long)
            freqs = self.freqs_cache.index_select(2, positions)
        else:
            freqs = self.freqs_cache
        q = apply_rotary_emb(freqs, q)
        k = apply_rotary_emb(freqs, k)
        return q, k


class PerceptionEncoderLayerScale(nn.Module):
    """Learnable per-feature scaling of the last dimension.

    ``gamma`` is a parameter of length ``dim`` initialised to ``init_values``.
    With ``inplace=True`` the input tensor itself is scaled and returned;
    otherwise a new tensor is produced.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        initial_gamma = torch.ones(dim) * init_values
        self.gamma = nn.Parameter(initial_gamma)

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma


class PerceptionEncoderMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``.

    ``fc1`` is a column-parallel projection from ``input_dim`` to
    ``hidden_dim`` and ``fc2`` a row-parallel projection back to ``input_dim``;
    both carry a bias.

    Args:
        input_dim: Input and output feature size.
        hidden_dim: Feature size between the two projections.
        act_layer: The activation. Either something that is called directly on
            the hidden tensor (an ``nn.Module`` instance or a function), or an
            ``nn.Module`` subclass such as ``nn.GELU``, which is instantiated
            with no arguments.
        quant_config: Forwarded to both linear layers.
        prefix: Name prefix; the layers are named ``{prefix}.fc1`` and
            ``{prefix}.fc2``.
        use_data_parallel: Forwarded to both linear layers as ``disable_tp``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        act_layer: Callable[..., torch.Tensor] | type[nn.Module],
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            input_dim,
            hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        # `forward` calls the activation on the tensor, so a module CLASS must
        # be instantiated first; calling the class itself would construct a
        # module from the tensor instead of activating it. Instances and plain
        # functions are stored unchanged.
        if isinstance(act_layer, type) and issubclass(act_layer, nn.Module):
            act_layer = act_layer()
        self.activation = act_layer
        self.fc2 = RowParallelLinear(
            hidden_dim,
            input_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on ``x`` and return the projected tensor."""
        # Both linear layers return a 2-tuple; only the first item is used.
        x, _ = self.fc1(x)
        x = self.activation(x)
        x, _ = self.fc2(x)
        return x


class PerceptionEncoderVisionAttention(nn.Module):
    """Multi-head self-attention with 2D rotary embedding on queries and keys.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Total number of attention heads; must be divisible by the
            tensor-parallel size in use.
        max_grid_height: Forwarded to the rotary embedding.
        max_grid_width: Forwarded to the rotary embedding.
        use_cls_token: Forwarded to the rotary embedding.
        quant_config: Forwarded to the QKV and output projections.
        prefix: Name prefix; the projections are named ``{prefix}.qkv_proj``
            and ``{prefix}.out_proj``.
        use_data_parallel: When true the tensor-parallel size is treated as 1
            and the projections are built with ``disable_tp=True``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        max_grid_height: int,
        max_grid_width: int,
        use_cls_token: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        # head_dim below is a floor division; without this check a
        # non-divisible width would only surface later as a shape mismatch.
        assert embed_dim % num_heads == 0, (
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        )
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5

        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by the "
            f"tensor-parallel size ({tp_size})"
        )
        # Heads handled by this rank.
        self.num_heads = self.total_num_heads // tp_size

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )
        self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
        self.rope = PerceptionEncoderRope2D(
            dim=self.head_dim,
            max_grid_height=max_grid_height,
            max_grid_width=max_grid_width,
            use_cls_token=use_cls_token,
        )

    def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
        """Attend over ``x`` of shape ``(batch, seq_len, features)``.

        Args:
            x: Input tokens.
            grid_hw: ``(rows, cols)`` of the patch grid, used to pick the
                rotary frequencies.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        bsz, seq_len, _ = x.shape
        qkv, _ = self.qkv_proj(x)
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        # RoPE is applied per head with the sequence on dim -2, so q/k are
        # temporarily laid out as (batch, heads, seq_len, head_dim).
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(bsz, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        q, k = self.rope(q, k, grid_hw=grid_hw)
        # Back to (batch, seq_len, heads * head_dim), the same layout v has.
        q = q.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads * self.head_dim)
        k = k.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.num_heads * self.head_dim)

        attn_output = self.attn(q, k, v)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class PerceptionEncoderVisionBlock(nn.Module):
    """Transformer block with pre-normalised attention and MLP residual branches.

    Each branch is ``x + layer_scale(sublayer(norm(x)))``. The layer scale is
    an identity when ``ls_init_value`` is ``None``.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        max_grid_height: int,
        max_grid_width: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        use_cls_token: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()

        def build_layer_scale() -> nn.Module:
            # One independent scale (or identity) per residual branch.
            if ls_init_value is None:
                return nn.Identity()
            return PerceptionEncoderLayerScale(d_model, ls_init_value)

        self.attn = PerceptionEncoderVisionAttention(
            embed_dim=d_model,
            num_heads=n_head,
            max_grid_height=max_grid_height,
            max_grid_width=max_grid_width,
            use_cls_token=use_cls_token,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
        )
        self.ls_1 = build_layer_scale()
        self.ls_2 = build_layer_scale()
        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.mlp = PerceptionEncoderMLP(
            input_dim=d_model,
            hidden_dim=int(d_model * mlp_ratio),
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
        """Apply the attention branch, then the MLP branch, to ``x``."""
        attn_branch = self.attn(self.ln_1(x), grid_hw=grid_hw)
        x = x + self.ls_1(attn_branch)
        mlp_branch = self.mlp(self.ln_2(x))
        return x + self.ls_2(mlp_branch)


class PerceptionEncoderVisionTransformer(nn.Module):
    """Stack of ``layers`` identical vision blocks applied one after another.

    All constructor arguments other than ``layers`` and ``prefix`` are passed
    to every block; block ``i`` is named ``{prefix}.resblocks.{i}``.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        max_grid_height: int,
        max_grid_width: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        use_cls_token: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList()
        for layer_idx in range(layers):
            self.resblocks.append(
                PerceptionEncoderVisionBlock(
                    d_model=width,
                    n_head=heads,
                    max_grid_height=max_grid_height,
                    max_grid_width=max_grid_width,
                    mlp_ratio=mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    use_cls_token=use_cls_token,
                    quant_config=quant_config,
                    prefix=f"{prefix}.resblocks.{layer_idx}",
                    use_data_parallel=use_data_parallel,
                )
            )

    def forward(self, x: torch.Tensor, grid_hw: tuple[int, int]):
        """Feed ``x`` through every block in order, passing ``grid_hw`` to each."""
        hidden = x
        for resblock in self.resblocks:
            hidden = resblock(hidden, grid_hw=grid_hw)
        return hidden


class PerceptionEncoder(nn.Module):
    """Vision tower: patch embedding, RoPE-2D transformer, two conv downsamplers.

    ``forward_features`` turns an image batch into one token per patch;
    ``forward`` additionally folds the tokens back into a 2D feature map and
    runs two stride-2 convolutions over it before flattening again.

    Args:
        config: Vision configuration. The attributes read here are
            ``patch_size``, ``output_dim``, ``width``, ``heads``, ``layers``,
            ``use_abs_posemb``, ``use_cls_token``, ``use_rope2d``,
            ``image_size``, ``use_ln_pre``, ``use_ln_post``, ``mlp_ratio`` and
            ``ls_init_value``.
        act_layer: Activation handed to the transformer blocks.
        norm_layer: Factory called with the width to build each norm layer.
        quant_config: Forwarded to the transformer.
        prefix: Name prefix; the transformer is named ``{prefix}.transformer``.
        use_data_parallel: Forwarded to the transformer.

    Raises:
        ValueError: If ``config.use_rope2d`` is false.
    """

    def __init__(
        self,
        config,
        act_layer: Callable,
        norm_layer: Callable = _DEFAULT_NORM_LAYER,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.patch_size = config.patch_size

        self.output_dim = config.output_dim or config.width
        self.heads = config.heads
        self.width = config.width
        self.layers = config.layers

        self.use_abs_posemb = config.use_abs_posemb
        self.use_cls_token = config.use_cls_token
        self.use_rope2d = config.use_rope2d
        if not self.use_rope2d:
            raise ValueError("use_rope2d must be True")
        self.image_size = config.image_size

        # Non-overlapping patch embedding: kernel and stride both equal the
        # patch size.
        self.conv1 = Conv2dLayer(
            in_channels=3,
            out_channels=config.width,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )

        self.ln_pre = norm_layer(config.width) if config.use_ln_pre else nn.Identity()
        self.ln_post = norm_layer(self.width) if config.use_ln_post else nn.Identity()

        self.transformer = PerceptionEncoderVisionTransformer(
            config.width,
            config.layers,
            config.heads,
            max_grid_height=self.image_size // self.patch_size,
            max_grid_width=self.image_size // self.patch_size,
            mlp_ratio=config.mlp_ratio,
            ls_init_value=config.ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            use_cls_token=self.use_cls_token,
            quant_config=quant_config,
            prefix=f"{prefix}.transformer",
            use_data_parallel=use_data_parallel,
        )

        self.vit_downsampler1 = Conv2dLayer(
            config.width, config.width * 2, kernel_size=3, stride=2, padding=1
        )
        self.vit_downsampler2 = Conv2dLayer(
            config.width * 2, config.width * 4, kernel_size=3, stride=2, padding=1
        )

        if self.use_cls_token:
            self.class_embedding = nn.Parameter(
                (self.width**-0.5) * torch.randn(self.width)
            )

        if self.use_abs_posemb:
            self.posemb_grid_size = self.image_size // self.patch_size
            self.positional_embedding = nn.Parameter(
                (self.width**-0.5)
                * torch.randn(
                    int(self.use_cls_token) + self.posemb_grid_size**2,
                    self.width,
                )
            )

    def sample_abs_posemb(self, grid_h: int, grid_w: int):
        """Return the absolute position embedding for a ``grid_h x grid_w`` grid.

        The stored embedding is returned as-is when the grid matches the size
        it was created for; otherwise its grid part is bilinearly resized and
        the class-token entry (if any) is re-attached in front. The result has
        a leading dimension of size 1.
        """
        if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
            return self.positional_embedding[None, ...]

        pos_embed = self.positional_embedding
        if self.use_cls_token:
            # The class-token entry is not part of the grid; keep it aside.
            cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]

        # (grid*grid, width) -> (1, width, grid, grid) for F.interpolate.
        pos_embed = (
            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        pos_embed = F.interpolate(
            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width)

        if self.use_cls_token:
            pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)

        return pos_embed[None, ...]

    def forward_features(self, x: torch.Tensor):
        """Encode images ``(batch, channels, height, width)`` into patch tokens.

        Returns:
            Tensor of shape ``(batch, num_patches, width)`` with patches in
            row-major grid order; a class token, if used, is dropped again
            before returning.
        """
        batch, _, h, w = x.shape
        grid_h, grid_w = h // self.patch_size, w // self.patch_size

        x = self.conv1(x)
        # Channels-last, then flatten the grid row-major into a token axis.
        x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)

        if self.use_cls_token:
            x = torch.cat(
                [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x], dim=1
            )

        if self.use_abs_posemb:
            x = x + self.sample_abs_posemb(grid_h, grid_w)

        x = self.ln_pre(x)
        x = self.transformer(x, grid_hw=(grid_h, grid_w))
        x = self.ln_post(x)

        if self.use_cls_token:
            x = x[:, 1:, :]

        return x

    def forward(self, x: torch.Tensor):
        """Encode images and downsample the resulting patch grid twice.

        Returns:
            Tensor of shape ``(batch, out_h * out_w, channels)``, where
            ``out_h x out_w`` is the patch grid after both stride-2
            convolutions.
        """
        # Take the grid from the input itself instead of assuming a square
        # number of patches: int(P ** 0.5) breaks (or silently mis-shapes, when
        # P happens to be a perfect square) for non-square images.
        _, _, height, width = x.shape
        grid_h, grid_w = height // self.patch_size, width // self.patch_size

        x = self.forward_features(x)
        batch, _, channels = x.shape
        x = x.transpose(2, 1).contiguous()
        x = x.view(batch, channels, grid_h, grid_w)

        x = self.vit_downsampler1(x)
        x = self.vit_downsampler2(x)

        batch, channels, out_h, out_w = x.shape
        return x.view(batch, -1, out_h * out_w).transpose(1, 2)


class StepVLForConditionalGeneration(Step3VLForConditionalGeneration):
    """Step3-VL variant whose image tower is a ``PerceptionEncoder``.

    The constructor builds its own vision model, projector and language model;
    everything else is inherited from ``Step3VLForConditionalGeneration``.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        },
        orig_to_new_substr={
            ".attn.in_proj_weight": ".attn.qkv_proj.weight",
            ".attn.in_proj_bias": ".attn.qkv_proj.bias",
            ".mlp.c_fc": ".mlp.fc1",
            ".mlp.c_proj": ".mlp.fc2",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        # Start the MRO lookup *after* Step3VLForConditionalGeneration, so that
        # class's own __init__ is skipped and the submodules are built here.
        super(Step3VLForConditionalGeneration, self).__init__()

        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        mm_config = model_config.multimodal_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.multimodal_config = mm_config
        self.use_data_parallel = mm_config.mm_encoder_tp_mode == "data"

        vision_config = hf_config.vision_config
        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = PerceptionEncoder(
                vision_config,
                get_act_fn(vision_config.hidden_act),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_model"),
                use_data_parallel=self.use_data_parallel,
            )
            # Input width matches the channel count of the encoder's second
            # downsampling convolution (4x the vision width).
            self.vit_large_projector = ColumnParallelLinear(
                vision_config.width * 4,
                hf_config.text_config.hidden_size,
                bias=hf_config.projector_bias,
                gather_output=True,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vit_large_projector"),
                disable_tp=self.use_data_parallel,
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=hf_config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _get_vision_model_output(
        self, input_tensor: torch.Tensor | None
    ) -> torch.Tensor | None:
        """Run the vision model on ``input_tensor``; ``None`` passes through."""
        if input_tensor is None:
            return None
        encoder = self.vision_model
        if not self.use_data_parallel:
            return encoder(input_tensor)
        return run_dp_sharded_vision_model(input_tensor, encoder)

    def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project vision features with ``vit_large_projector``."""
        # The projector returns a 2-tuple; only the first item is returned.
        projected, _ = self.vit_large_projector(image_features)
        return projected
