# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional

import torch
import torch.nn.functional as F

from torchao.quantization.quant_primitives import TorchAODType
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.utils import get_group_qparams_symmetric

from .fake_quantize_config import (
    FakeQuantizeConfigBase,
    IntxFakeQuantizeConfig,
)
from .fake_quantizer import FakeQuantizerBase
from .utils import (
    _get_qmin_qmax,
)


class FakeQuantizedEmbedding(torch.nn.Embedding):
    """
    Embedding layer whose weight can be fake quantized on every forward.

    How the weight is fake quantized (target dtype, granularity, scheme
    etc.) is described by `weight_config`. Without a config the raw
    weight is used for the lookup unchanged.

    Example usage::

        weight_config = IntxFakeQuantizeConfig(
            dtype=torch.int4,
            group_size=8,
            is_symmetric=True,
        )
        # `weight_config` has to be passed by keyword: the third
        # positional argument is `padding_idx`
        fq_embedding = FakeQuantizedEmbedding(5, 16, weight_config=weight_config)
        fq_embedding(torch.LongTensor([3]))
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        weight_config: Optional[FakeQuantizeConfigBase] = None,
        *args,
        **kwargs,
    ) -> None:
        # The leading arguments stay positional so that any extra `*args`
        # keep lining up with the remaining `nn.Embedding` parameters
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            *args,
            **kwargs,
        )
        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedEmbedding")
        # `None` means "no fake quantization"; `forward` checks for it
        self.weight_fake_quantizer = (
            None
            if weight_config is None
            else FakeQuantizerBase.from_config(weight_config)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Look up `x` in the weight, fake quantizing the weight first if a
        weight fake quantizer is configured.
        """
        weight = self.weight
        if self.weight_fake_quantizer is not None:
            weight = self.weight_fake_quantizer(weight)
        return F.embedding(
            x,
            weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )

    def to_embedding(self) -> torch.nn.Embedding:
        """
        Build a plain `torch.nn.Embedding` with the same settings, device
        and dtype as this module, reusing this module's weight.
        """
        plain_embedding = torch.nn.Embedding(
            num_embeddings=self.num_embeddings,
            embedding_dim=self.embedding_dim,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
            device=self.weight.device,
            dtype=self.weight.dtype,
        )
        # In distributed training, the model may be instantiated
        # on the meta device, in which case there is no need to
        # copy the weights, and doing so will result in an error
        if self.weight.device == torch.device("meta"):
            return plain_embedding
        plain_embedding.weight = self.weight
        return plain_embedding

    @classmethod
    def from_embedding(
        cls,
        mod: torch.nn.Embedding,
        weight_config: Optional[FakeQuantizeConfigBase] = None,
    ) -> "FakeQuantizedEmbedding":
        """
        Build a `FakeQuantizedEmbedding` with the same settings, device and
        dtype as `mod`, reusing `mod`'s weight.

        Note that the result is always a `FakeQuantizedEmbedding`, even
        when this is called through a subclass.
        """
        fq_embedding = FakeQuantizedEmbedding(
            num_embeddings=mod.num_embeddings,
            embedding_dim=mod.embedding_dim,
            padding_idx=mod.padding_idx,
            max_norm=mod.max_norm,
            norm_type=mod.norm_type,
            scale_grad_by_freq=mod.scale_grad_by_freq,
            sparse=mod.sparse,
            weight_config=weight_config,
            device=mod.weight.device,
            dtype=mod.weight.dtype,
        )
        # In distributed training, the model may be instantiated
        # on the meta device, in which case there is no need to
        # copy the weights, and doing so will result in an error
        if mod.weight.device == torch.device("meta"):
            return fq_embedding
        fq_embedding.weight = mod.weight
        return fq_embedding


# ======================================
# |   Embedding int4 weight-only QAT   |
# ======================================


class Int4WeightOnlyEmbeddingQATQuantizer(TwoStepQuantizer):
    """
    Two-step QAT quantizer for embedding layers, using int4 fake quantized
    grouped per channel weights.

    `prepare` swaps `nn.Embedding` modules for `Int4WeightOnlyQATEmbedding`
    and `convert` swaps those for `Int4WeightOnlyEmbedding`.

    args:
        group_size: the number of elements in each quantized group for weights
        scale_precision: precision of per group scales
        zero_point_precision: precision of per group zero points
    """

    def __init__(
        self,
        group_size: int = 256,
        scale_precision: torch.dtype = torch.float32,
        zero_point_precision: torch.dtype = torch.int32,
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(
            "torchao.quantization.qat.Int4WeightOnlyEmbeddingQATQuantizer"
        )
        self.bit_width = 4
        self.group_size: int = group_size
        self.scale_precision: torch.dtype = scale_precision
        self.zero_point_precision: torch.dtype = zero_point_precision

    @staticmethod
    def _embedding_kwargs(embedding: torch.nn.Module) -> dict:
        """
        Collect the `nn.Embedding` constructor arguments stored on `embedding`.
        """
        return {
            "num_embeddings": embedding.num_embeddings,
            "embedding_dim": embedding.embedding_dim,
            "padding_idx": embedding.padding_idx,
            "max_norm": embedding.max_norm,
            "norm_type": embedding.norm_type,
            "scale_grad_by_freq": embedding.scale_grad_by_freq,
            "sparse": embedding.sparse,
        }

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Swap `nn.Embedding` modules with `Int4WeightOnlyQATEmbedding`.

        The extra `*args` and `**kwargs` are not used. The `model` object
        that was passed in is returned.
        """
        # avoid circular imports
        from torchao.quantization.quant_api import (
            _replace_with_custom_fn_if_matches_filter,
        )

        def is_embedding(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, torch.nn.Embedding)

        def to_qat_embedding(child: torch.nn.Module) -> torch.nn.Module:
            qat_embedding = Int4WeightOnlyQATEmbedding(
                # nn.Embedding args
                **self._embedding_kwargs(child),
                # quantization args
                group_size=self.group_size,
                scale_precision=self.scale_precision,
                zero_point_precision=self.zero_point_precision,
                device=child.weight.device,
                dtype=child.weight.dtype,
            )
            # In distributed training, the model may be instantiated
            # on the meta device, in which case there is no need to
            # copy the weights, and doing so will result in an error
            if child.weight.device == torch.device("meta"):
                return qat_embedding
            qat_embedding.weight = child.weight
            return qat_embedding

        _replace_with_custom_fn_if_matches_filter(model, to_qat_embedding, is_embedding)
        return model

    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Swap all `Int4WeightOnlyQATEmbedding` modules with `Int4WeightOnlyEmbedding`.

        The children of `model` are replaced in place and the same `model`
        object is returned. The extra `*args` and `**kwargs` are not used.
        """
        self._convert_helper(model)
        return model

    def _convert_helper(self, module: torch.nn.Module):
        """
        Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
        modules with `Int4WeightOnlyEmbedding`

        Only the children of `module` are inspected, never `module` itself.
        """
        from torchao._executorch_ops import (
            _quantized_decomposed_quantize_per_channel_group_wrapper,
        )

        for name, child in module.named_children():
            if not isinstance(child, Int4WeightOnlyQATEmbedding):
                # Not a QAT embedding: keep looking further down the tree
                self._convert_helper(child)
                continue

            # Quantization settings come from the child's own config,
            # not from this quantizer's attributes
            config = child.weight_fake_quantizer.config
            group_size = config.group_size
            scale_precision = config.scale_precision
            zero_point_precision = config.zero_point_precision

            quantized_embedding = Int4WeightOnlyEmbedding(
                # nn.Embedding args
                **self._embedding_kwargs(child),
                # quantization args
                group_size=group_size,
                scale_precision=scale_precision,
                zero_point_precision=zero_point_precision,
                device=child.weight.device,
                output_dtype=child.weight.dtype,
            )
            setattr(module, name, quantized_embedding)

            # Load weights and qparams into quantized embedding
            qmin, qmax = _get_qmin_qmax(self.bit_width)
            scale, zero_point = get_group_qparams_symmetric(
                child.weight,
                self.bit_width,
                group_size,
                precision=scale_precision,
            )
            zero_point = zero_point.to(zero_point_precision)
            quantized_embedding.weight = (
                _quantized_decomposed_quantize_per_channel_group_wrapper(
                    child.weight,
                    scale,
                    zero_point,
                    qmin,
                    qmax,
                    torch.int8,
                    group_size,
                )
            )
            quantized_embedding.scale = scale.to(scale_precision)
            # Already cast to `zero_point_precision` above
            quantized_embedding.zero_point = zero_point


class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
    """
    This module implements an embedding layer with int4 fake quantized
    grouped per channel weights.

    The weight config passed to `FakeQuantizedEmbedding` is always int4,
    symmetric and dynamic; only the settings below can be chosen. The
    remaining arguments are forwarded to `FakeQuantizedEmbedding`.

    args:
        group_size: the number of elements in each quantized group for weights
        scale_precision: precision of per group scales
        zero_point_precision: precision of per group zero points
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        group_size: int = 32,
        scale_precision: torch.dtype = torch.float32,
        zero_point_precision: torch.dtype = torch.int32,
        *args,
        **kwargs,
    ):
        # The config is passed positionally (in the `weight_config` slot)
        # because extra `*args` may follow it
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            IntxFakeQuantizeConfig(
                dtype=TorchAODType.INT4,
                group_size=group_size,
                is_symmetric=True,
                is_dynamic=True,
                scale_precision=scale_precision,
                zero_point_precision=zero_point_precision,
            ),
            *args,
            **kwargs,
        )

    def enable_fake_quant(self, enabled: bool = True):
        """
        Set the `enabled` flag of the weight fake quantizer to `enabled`.
        """
        fake_quantizer = self.weight_fake_quantizer
        fake_quantizer.enabled = enabled

    def disable_fake_quant(self):
        """
        Shorthand for `enable_fake_quant(False)`.
        """
        self.enable_fake_quant(False)


class Int4WeightOnlyEmbedding(torch.nn.Module):
    """
    This module implements an embedding layer with int4 quantized
    grouped per channel weights.

    The quantized weight and its per group scales and zero points live in
    the `weight`, `scale` and `zero_point` buffers. The buffers are
    allocated uninitialized here and have to be filled in before the module
    is used (`Int4WeightOnlyEmbeddingQATQuantizer` does this on convert).
    On every forward the weight is dequantized to `output_dtype` and then
    used for a regular embedding lookup.

    args:
        num_embeddings: number of rows of the weight
        embedding_dim: number of columns of the weight
        padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse:
            stored as is and forwarded to `F.embedding` in `forward`
        group_size: the number of elements in each quantized group for
            weights, must be a positive divisor of `embedding_dim`
        scale_precision: dtype of the `scale` buffer
        zero_point_precision: dtype of the `zero_point` buffer
        device: device the buffers are allocated on
        output_dtype: dtype of the dequantized weight used for the lookup

    raises:
        ValueError: if `group_size` is not a positive divisor of
            `embedding_dim`
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        group_size: int = 32,
        scale_precision: torch.dtype = torch.float32,
        zero_point_precision: torch.dtype = torch.int32,
        device: Optional[torch.device] = None,
        output_dtype: torch.dtype = torch.float32,
    ):
        super().__init__()

        # Fail early: flooring `embedding_dim // group_size` would otherwise
        # silently allocate `scale` / `zero_point` buffers that do not cover
        # the whole weight
        if group_size <= 0 or embedding_dim % group_size != 0:
            raise ValueError(
                f"group_size ({group_size}) must be a positive divisor of "
                f"embedding_dim ({embedding_dim})"
            )

        # nn.Embedding args
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        # quantization args
        self.bit_width = 4
        self.group_size = group_size
        self.scale_precision = scale_precision
        self.zero_point_precision = zero_point_precision
        self.output_dtype = output_dtype

        # one scale and one zero point per group of `group_size` columns
        num_groups = embedding_dim // group_size

        # currently storing unpacked int8 weights
        self.register_buffer(
            "weight",
            torch.empty(
                (num_embeddings, embedding_dim), dtype=torch.int8, device=device
            ),
        )
        self.register_buffer(
            "scale",
            torch.empty(
                (num_embeddings, num_groups),
                dtype=scale_precision,
                device=device,
            ),
        )
        self.register_buffer(
            "zero_point",
            torch.empty(
                (num_embeddings, num_groups),
                dtype=zero_point_precision,
                device=device,
            ),
        )

    def forward(self, x):
        """
        Dequantize the weight to `output_dtype` and look up `x` in it.
        """
        from torchao.quantization.quant_primitives import (
            dequantize_affine,
        )

        qmin, qmax = _get_qmin_qmax(self.bit_width)

        # dequantize_affine casts to output_dtype before scaling
        # dequantize_per_channel_group scales and then casts to output_dtype
        # The two do not agree when dtype != torch.float32
        w_dq = dequantize_affine(
            self.weight,
            [1, self.group_size],  # block size: one row, `group_size` columns
            self.scale,
            self.zero_point,
            torch.int8,
            qmin,
            qmax,
            output_dtype=self.output_dtype,
        )
        return F.embedding(
            x,
            w_dq,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
