# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers modeling backend utilities."""

from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Literal

import torch
from torch import nn

from vllm.config.utils import getattr_iter
from vllm.logger import init_logger
from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.transformers_utils.config import is_rope_parameters_nested

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.quantization import QuantizationConfig


# Logger for this module, obtained from vLLM's `init_logger` with the module name.
logger = init_logger(__name__)


# Copied from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: torch.device):
    """
    Context manager that places every parameter registered inside it on
    `device`, while leaving buffers wherever they were created.

    It works by temporarily swapping `nn.Module.register_parameter` for a
    version that re-creates each registered parameter on `device`. Buffer
    registration is not touched. The original method is restored on exit,
    including when the body raises.

    Args:
        device (`torch.device`):
            Device on which parameters are re-created.
    """

    original_register = nn.Module.register_parameter

    def _register_on_device(module, name, param):
        # Register normally first, then swap the stored parameter for a copy
        # that lives on the target device.
        original_register(module, name, param)
        if param is None:
            return
        registered = module._parameters[name]
        # Extra attributes stored on the parameter are forwarded to the
        # constructor of its class, together with `requires_grad`.
        ctor_kwargs = registered.__dict__
        ctor_kwargs["requires_grad"] = param.requires_grad
        module._parameters[name] = type(registered)(
            registered.to(device), **ctor_kwargs
        )

    # Table of torch constructor names to patch. It is empty here, so the two
    # loops below do nothing.
    patched_constructors = {}

    def _force_device(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = _register_on_device
        for fn_name in patched_constructors:
            setattr(torch, fn_name, _force_device(getattr(torch, fn_name)))
        yield
    finally:
        # Always undo the monkeypatches, even if the body raised.
        nn.Module.register_parameter = original_register
        for fn_name, original_fn in patched_constructors.items():
            setattr(torch, fn_name, original_fn)


# Names of the tensor parallel styles used as the `style` argument of
# `replace_linear_class`.
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]


def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: "QuantizationConfig | None" = None,
    *,
    prefix: str = "",
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
    """
    Build the vLLM linear layer that stands in for an `nn.Linear`.

    The new layer is constructed from the input/output sizes of `linear` and
    from whether it has a bias. It is always created with `return_bias=False`.

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear, e.g. "colwise". A
            string that is not one of the known styles yields a
            `ReplicatedLinear`.
        quant_config: Quantization config for the new linear.
        prefix: Forwarded as the `prefix` argument of the new linear.
    Returns:
        The new linear.
    Raises:
        ValueError: If `style` is not a string.
    """

    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    # Pick the layer class and the style-specific constructor arguments.
    extra_kwargs = {}
    if style == "colwise":
        layer_cls = ColumnParallelLinear
    elif style == "colwise_rep":
        layer_cls = ColumnParallelLinear
        extra_kwargs = {"gather_output": True}
    elif style == "rowwise":
        layer_cls = RowParallelLinear
    elif style == "rowwise_rep":
        layer_cls = RowParallelLinear
        extra_kwargs = {"input_is_parallel": False}
    else:
        # "replicate" and any unrecognised style string both land here.
        layer_cls = ReplicatedLinear

    has_bias = linear.bias is not None
    return layer_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=has_bias,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **extra_kwargs,
    )


# Torch convolution modules taken by `replace_conv_class`, and the vLLM
# convolution layers it can return in their place.
TorchConv = nn.Conv2d | nn.Conv3d
VllmConv = Conv2dLayer | Conv3dLayer


def replace_conv_class(conv: TorchConv) -> VllmConv | TorchConv:
    """Build the vLLM conv layer that stands in for a torch Conv2d/Conv3d.

    The replacement is constructed from the hyperparameters of `conv` and the
    dtype of its weight. No weight values are copied here.

    Args:
        conv: `nn.Conv2d` or `nn.Conv3d` to be replaced.
    Returns:
        The new `Conv2dLayer` or `Conv3dLayer`. If the conv module is not supported
        (non-"zeros" padding mode, or a type other than exactly `nn.Conv2d` /
        `nn.Conv3d`), returns the original conv module.
    """
    # vLLM does not handle non-zero padding modes
    if conv.padding_mode != "zeros":
        return conv

    # Exact type match only: subclasses are handed back unchanged.
    conv_type = type(conv)
    if conv_type is nn.Conv2d:
        layer_cls = Conv2dLayer
    elif conv_type is nn.Conv3d:
        layer_cls = Conv3dLayer
    else:
        return conv

    layer_kwargs = {
        "in_channels": conv.in_channels,
        "out_channels": conv.out_channels,
        "kernel_size": conv.kernel_size,
        "stride": conv.stride,
        "padding": conv.padding,
        "dilation": conv.dilation,
        "groups": conv.groups,
        "bias": conv.bias is not None,
        "padding_mode": conv.padding_mode,
        "params_dtype": conv.weight.dtype,
    }
    return layer_cls(**layer_kwargs)


def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Build the vLLM norm layer that stands in for a Transformers RMSNorm.

    Returns a `GemmaRMSNorm` when a freshly constructed instance of the
    original class has an all-zero weight, otherwise an `RMSNorm`.

    This method assumes:
    - Weight is stored as `weight`.
    - Epsilon is stored as `eps` or `variance_epsilon`.
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
    and Transformers doesn't appear to have the same concept.

    Args:
        rms_norm: The Transformers norm module to replace.
        hidden_size: Hidden size to use when `rms_norm` has no `weight`;
            otherwise the first dimension of the weight is used instead.
    """
    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
    weight_meta = getattr(rms_norm, "weight", None)
    # Prefer the size of the existing weight over the caller-supplied one.
    norm_kwargs = {
        "hidden_size": hidden_size if weight_meta is None else weight_meta.size(0),
        "eps": eps,
    }

    # An all-zero initial weight indicates GemmaRMSNorm. rms_norm itself is on
    # meta, so a fresh size-1 instance is created on CPU to inspect the values.
    try:
        with torch.device("cpu"):
            probe_weight = getattr(rms_norm.__class__(1), "weight", None)
    except Exception:
        logger.warning(
            "Failed to determine if RMSNorm weight is centered on zero or one. "
            "Defaulting to one."
        )
        probe_weight = None

    if probe_weight is not None and torch.all(probe_weight == 0):
        return GemmaRMSNorm(**norm_kwargs)

    # Otherwise assume it's a regular RMSNorm
    has_weight = getattr(rms_norm, "with_scale", True)
    if weight_meta is None:
        # No weight, fall back to weightless RMSNorm
        has_weight = False
    else:
        norm_kwargs["dtype"] = weight_meta.dtype
    norm_kwargs["has_weight"] = has_weight
    return RMSNorm(**norm_kwargs)


def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module) -> None:
    """Emit a debug log line of the form `<name>: <old_module> -> <new_module>`.

    Args:
        name: Label written at the start of the log line.
        old_module: Module shown on the left of the arrow.
        new_module: Module shown on the right of the arrow.
    """
    logger.debug("%s: %s -> %s", name, old_module, new_module)


def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    """Build a hint on how to get an unsupported feature supported.

    The hint always points to the vLLM docs on writing custom models. When
    `model` is not an existing filesystem path, it is preceded by a request to
    open a discussion on the model's Hugging Face page (if `trust_remote_code`)
    or an issue on the Transformers GitHub repository (otherwise).

    Args:
        model: Model name or local path.
        trust_remote_code: Selects which of the two places the user is
            pointed to.
    Returns:
        The assembled hint text.
    """
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
    tip = f"See {doc_url} for instructions on how to add support yourself."

    # A model that exists on disk gets only the documentation pointer.
    if Path(model).exists():
        return tip

    if trust_remote_code:
        target = f"a discussion at https://huggingface.co/{model}/discussions/new"
    else:
        issues_url = "https://github.com/huggingface/transformers/issues/new/choose"
        target = f"an issue at {issues_url}"
    return f"Please open {target} to request support for this feature. {tip}"


def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
    """
    Callable to be passed to `@support_torch_compile`'s `enable_if` argument.

    Returns `True` unless one of the following holds:

    - The model uses dynamic rope scaling, i.e. any rope parameter set of the
      text config has `rope_type` equal to "dynamic".
    """
    text_config = vllm_config.model_config.hf_config.get_text_config()
    rope_parameters = getattr(text_config, "rope_parameters", None)
    # Missing or empty rope parameters: nothing can be dynamic.
    if not rope_parameters:
        return True

    # Treat a flat dict as a single parameter set so both layouts share a loop.
    if is_rope_parameters_nested(rope_parameters):
        parameter_sets = rope_parameters.values()
    else:
        parameter_sets = (rope_parameters,)

    # Dynamic rope scaling is not compatible with torch.compile
    for params in parameter_sets:
        if params["rope_type"] == "dynamic":
            return False
    return True
