# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional

import torch
import torch.nn.functional as F

from torchao.dtypes.utils import is_device
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.linear_quant_modules import (
    Int8DynActInt4WeightLinear,
    WeightOnlyInt4Linear,
    _check_linear_int4_k,
    _replace_linear_8da4w,
    _replace_linear_int4,
    groupwise_affine_quantize_tensor,
)
from torchao.quantization.quant_primitives import (
    TorchAODType,
    ZeroPointDomain,
)
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.utils import get_group_qparams_symmetric

from .fake_quantize_config import (
    FakeQuantizeConfigBase,
    Float8FakeQuantizeConfig,
    IntxFakeQuantizeConfig,
)
from .fake_quantizer import (
    FakeQuantizerBase,
)
from .utils import (
    _get_qmin_qmax,
)


class FakeQuantizedLinear(torch.nn.Linear):
    """
    Linear layer that optionally fake quantizes its input activations
    and/or its weight before running the regular linear op.

    Each side is described by its own config (target dtype, granularity,
    scheme etc.). A side whose config is `None` gets no fake quantizer
    and is used as is.

    Example usage::

        activation_config = IntxFakeQuantizeConfig(
            dtype=torch.int8,
            granularity="per_token",
            is_symmetric=False,
        )
        weight_config = IntxFakeQuantizeConfig(
            dtype=torch.int4,
            group_size=8,
            is_symmetric=True,
        )
        fq_linear = FakeQuantizedLinear(
            16, 32, False, activation_config, weight_config,
        )
        fq_linear(torch.randn(16))
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        activation_config: Optional[FakeQuantizeConfigBase] = None,
        weight_config: Optional[FakeQuantizeConfigBase] = None,
        *args,
        **kwargs,
    ) -> None:
        """
        Build the underlying `torch.nn.Linear` and attach the fake quantizers.

        Extra positional/keyword arguments are forwarded to `torch.nn.Linear`.

        Raises:
            ValueError: if `weight_config` is a per-group
                `IntxFakeQuantizeConfig` whose `group_size` does not evenly
                divide `in_features`.
        """
        super().__init__(in_features, out_features, bias, *args, **kwargs)
        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedLinear")

        # Activation side: no config means activations are passed through
        self.activation_fake_quantizer = (
            None
            if activation_config is None
            else FakeQuantizerBase.from_config(activation_config)
        )

        # Weight side: no config means the raw weight is used
        if weight_config is None:
            self.weight_fake_quantizer = None
            return

        uses_weight_groups = isinstance(
            weight_config, IntxFakeQuantizeConfig
        ) and isinstance(weight_config.granularity, PerGroup)
        if uses_weight_groups:
            # Every row of the weight must split into whole groups
            group_size = weight_config.group_size
            if group_size is not None and in_features % group_size != 0:
                raise ValueError(
                    "in_features (%s) %% group_size (%s) must be == 0"
                    % (in_features, group_size)
                )
        self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the configured fake quantizers (activation first, then weight)
        and return `F.linear` of the results with this layer's bias.
        """
        act_fq = self.activation_fake_quantizer
        weight_fq = self.weight_fake_quantizer
        inputs = x if act_fq is None else act_fq(x)
        weight = self.weight if weight_fq is None else weight_fq(self.weight)
        return F.linear(inputs, weight, self.bias)

    def to_linear(self) -> torch.nn.Linear:
        """
        Return a plain `torch.nn.Linear` with the same shape, device and
        dtype, sharing this module's weight and bias unless on meta device.
        """
        has_bias = self.bias is not None
        plain = torch.nn.Linear(
            self.in_features,
            self.out_features,
            has_bias,
            device=self.weight.device,
            dtype=self.weight.dtype,
        )
        # In distributed training, the model may be instantiated
        # on the meta device, in which case there is no need to
        # copy the weights, and doing so will result in an error
        if self.weight.device == torch.device("meta"):
            return plain
        plain.weight = self.weight
        plain.bias = self.bias
        return plain

    @classmethod
    def from_linear(
        cls,
        mod: torch.nn.Linear,
        activation_config: Optional[FakeQuantizeConfigBase] = None,
        weight_config: Optional[FakeQuantizeConfigBase] = None,
    ):
        """
        Build a `FakeQuantizedLinear` mirroring `mod` (shape, device, dtype)
        with the given configs, sharing `mod`'s weight and bias unless `mod`
        lives on the meta device.
        """
        has_bias = mod.bias is not None
        fq_linear = FakeQuantizedLinear(
            mod.in_features,
            mod.out_features,
            has_bias,
            activation_config=activation_config,
            weight_config=weight_config,
            device=mod.weight.device,
            dtype=mod.weight.dtype,
        )
        # In distributed training, the model may be instantiated
        # on the meta device, in which case there is no need to
        # copy the weights, and doing so will result in an error
        if mod.weight.device == torch.device("meta"):
            return fq_linear
        fq_linear.weight = mod.weight
        fq_linear.bias = mod.bias
        return fq_linear


def enable_linear_fake_quant(
    mod: torch.nn.Module,
    enabled: bool = True,
):
    """
    Set the `enabled` flag on every fake quantizer attached to `mod`.

    Modules that are not `FakeQuantizedLinear` are left untouched, as are
    fake quantizer slots that are `None`.
    """
    if not isinstance(mod, FakeQuantizedLinear):
        return
    for fake_quantizer in (mod.activation_fake_quantizer, mod.weight_fake_quantizer):
        if fake_quantizer is not None:
            fake_quantizer.enabled = enabled


def disable_linear_fake_quant(mod: torch.nn.Module):
    """
    Turn fake quantization off for `mod` if it is a `FakeQuantizedLinear`.

    Thin wrapper over `enable_linear_fake_quant` with `enabled=False`.
    """
    enable_linear_fake_quant(mod, False)


# ===========================
# | QAT quantizer interface |
# ===========================


class _LegacyQATQuantizer(TwoStepQuantizer):
    """
    Common base for the legacy QAT quantizers.

    Provides default getters for the fake quantize configs; both return
    `None` unless a subclass overrides them.
    """

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        """Return the weight fake quantize config (`None` by default)."""
        return None

    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        """Return the activation fake quantize config (`None` by default)."""
        return None


# ===========================================
# | int8 dynamic activations + int4 weights |
# ===========================================


class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
    """
    QAT quantizer whose linear layers use int8 dynamic per token fake
    quantized activations together with int4 fake quantized grouped per
    channel weights.

    `prepare` swaps linears for `Int8DynActInt4WeightQATLinear`; `convert`
    swaps those for `Int8DynActInt4WeightLinear`.
    """

    def __init__(
        self,
        groupsize: int = 256,
        padding_allowed: bool = False,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
    ) -> None:
        """Store the quantization settings used by `prepare` and the config getters."""
        super().__init__()
        torch._C._log_api_usage_once(
            "torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer"
        )
        self.groupsize: int = groupsize
        self.padding_allowed: bool = padding_allowed
        self.precision: torch.dtype = precision
        self.scales_precision: torch.dtype = scales_precision
        # TODO: generalize this
        self.activation_scales_precision = torch.float32

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Run `_replace_linear_8da4w` on `model` with
        `Int8DynActInt4WeightQATLinear` as the replacement class (weights
        copied) and return the same `model` object.
        """
        _replace_linear_8da4w(
            model,
            self.groupsize,
            self.padding_allowed,
            self.precision,
            self.scales_precision,
            Int8DynActInt4WeightQATLinear,
            copy_weights=True,
        )
        return model

    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """Convert all QAT linears under `model` in place and return `model`."""
        self._convert_qat_linear_8da4w(model)
        return model

    def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
        """
        Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.

        Walks the children of `module` recursively; each matching child is
        swapped for a quantized linear loaded with its quantized weight,
        scales, zero points and (if present) bias.
        """
        for child_name, submodule in module.named_children():
            if not isinstance(submodule, Int8DynActInt4WeightQATLinear):
                # Not a QAT linear: keep descending
                self._convert_qat_linear_8da4w(submodule)
                continue

            weight_config = submodule.weight_fake_quantizer.config
            group_size = weight_config.group_size
            scale_dtype = weight_config.scale_precision

            converted = Int8DynActInt4WeightLinear(
                submodule.in_features,
                submodule.out_features,
                submodule.bias is not None,
                groupsize=group_size,
                precision=submodule.weight.dtype,
                scales_precision=scale_dtype,
            )
            setattr(module, child_name, converted)

            # Load weights and qparams into quantized linear
            n_bit = 4
            qmin, qmax = _get_qmin_qmax(n_bit)
            scales, zero_points = get_group_qparams_symmetric(
                submodule.weight,
                n_bit,
                group_size,
                precision=scale_dtype,
            )
            zero_points = zero_points.to(weight_config.zero_point_precision)

            # Imported lazily, only once a QAT linear actually needs converting
            from torchao._executorch_ops import (
                _quantized_decomposed_quantize_per_channel_group_wrapper,
            )

            converted.weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
                submodule.weight,
                scales,
                zero_points,
                qmin,
                qmax,
                torch.int8,
                group_size,
            )
            converted.scales = scales
            converted.zeros = zero_points
            if submodule.bias is not None:
                converted.bias = submodule.bias

    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        """Return the activation config built from `activation_scales_precision`."""
        return _get_8da4w_activation_config(self.activation_scales_precision)

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        """Return the weight config built from `groupsize` and `scales_precision`."""
        return _get_8da4w_weight_config(self.groupsize, self.scales_precision)


class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
    """
    Linear layer with int8 dynamic per token fake quantized activations
    and int4 fake quantized grouped per channel weights.

    args:
        groupsize: the number of elements in each quantized group for weights
        precision: precision of weights
        scales_precision: precision of per group scales and zero points

    Note: we hardcode activation scales to use torch.fp32, but allow users to specify the weight scales (defaults to torch.fp32).
    To get an exact numerical match with Int8DynamicActivationInt4WeightConfig, users must use the same dtype for both the weights
    and the scales. Here scales_precision refers specifically to the weight scales only, not the activation scales.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        device: torch.device = None,
        groupsize: int = 256,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
    ) -> None:
        """Build the 8da4w activation/weight configs and initialize the base layer."""
        super().__init__(
            in_features,
            out_features,
            bias,
            # Use torch.float32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant,
            # which is used in PTQ routines
            # TODO: generalize this
            activation_config=_get_8da4w_activation_config(torch.float32),
            weight_config=_get_8da4w_weight_config(groupsize, scales_precision),
            device=device,
            dtype=precision,
        )

    def enable_fake_quant(self, enabled: bool = True):
        """Set `enabled` on both the activation and the weight fake quantizer."""
        for fake_quantizer in (
            self.activation_fake_quantizer,
            self.weight_fake_quantizer,
        ):
            fake_quantizer.enabled = enabled

    def disable_fake_quant(self):
        """Shorthand for `enable_fake_quant(False)`."""
        self.enable_fake_quant(False)


# TODO: remove these in favor of enable_linear_fake_quant
def enable_8da4w_fake_quant(mod: torch.nn.Module):
    """
    (deprecated) Enable fake quantization for `Int8DynActInt4WeightQATLinear`.

    Any other module type is ignored.
    """
    if not isinstance(mod, Int8DynActInt4WeightQATLinear):
        return
    mod.enable_fake_quant()


# TODO: remove in favor of disable_linear_fake_quant
def disable_8da4w_fake_quant(mod: torch.nn.Module):
    """
    (deprecated) Disable fake quantization for `Int8DynActInt4WeightQATLinear`.

    Any other module type is ignored.
    """
    if not isinstance(mod, Int8DynActInt4WeightQATLinear):
        return
    mod.disable_fake_quant()


def _get_8da4w_activation_config(
    qparams_precision: torch.dtype,
) -> IntxFakeQuantizeConfig:
    """
    Return the activation `IntxFakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.

    The config is int8, per token, asymmetric and dynamic, with
    `qparams_precision` used for both scales and zero points and its
    machine epsilon used as `eps`. Only `torch.float32` is accepted.
    """
    # TODO: generalize this
    assert qparams_precision == torch.float32
    precision_eps = torch.finfo(qparams_precision).eps
    config_kwargs = dict(
        dtype=torch.int8,
        granularity="per_token",
        is_symmetric=False,
        is_dynamic=True,
        scale_precision=qparams_precision,
        zero_point_precision=qparams_precision,
        eps=precision_eps,
    )
    return IntxFakeQuantizeConfig(**config_kwargs)


def _get_8da4w_weight_config(
    group_size: int,
    qparams_precision: torch.dtype,
) -> IntxFakeQuantizeConfig:
    """
    Return the weight `IntxFakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.

    The config is int4, symmetric and dynamic with the given `group_size`;
    `qparams_precision` is used for both scales and zero points.
    """
    config_kwargs = dict(
        dtype=TorchAODType.INT4,
        group_size=group_size,
        is_symmetric=True,
        is_dynamic=True,
        scale_precision=qparams_precision,
        zero_point_precision=qparams_precision,
    )
    return IntxFakeQuantizeConfig(**config_kwargs)


# ====================
# | int4 weight-only |
# ====================


class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
    """
    QAT quantizer whose linear layers use int4 fake quantized grouped per
    channel weights (activations are not fake quantized).

    `prepare` swaps linears for `Int4WeightOnlyQATLinear`; `convert` swaps
    those for `WeightOnlyInt4Linear`.
    """

    def __init__(
        self,
        groupsize: int = 256,
        inner_k_tiles: Optional[int] = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        """
        Store the quantization settings.

        `inner_k_tiles` must be one of 2, 4, 8 and `groupsize` one of
        32, 64, 128, 256; both are checked with `assert`.
        """
        super().__init__()
        torch._C._log_api_usage_once(
            "torchao.quantization.qat.Int4WeightOnlyQATQuantizer"
        )
        assert inner_k_tiles in [2, 4, 8]
        assert groupsize in [32, 64, 128, 256]
        self.inner_k_tiles = inner_k_tiles
        self.groupsize = groupsize
        self.precision = precision
        self.scales_precision = scales_precision

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Run `_replace_linear_int4` on `model` with `Int4WeightOnlyQATLinear`
        as the replacement class (weights copied) and return the same
        `model` object.
        """
        _replace_linear_int4(
            model,
            self.groupsize,
            self.inner_k_tiles,
            padding_allowed=True,
            precision=self.precision,
            scales_precision=self.scales_precision,
            linear_class=Int4WeightOnlyQATLinear,
            copy_weights=True,
        )
        return model

    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """Convert all QAT linears under `model` in place and return `model`."""
        self._convert_qat_linear_4w(model)
        return model

    def _convert_qat_linear_4w(self, module: torch.nn.Module):
        """
        Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`.

        Walks the children of `module` recursively; each matching child is
        swapped for a bias-free quantized linear loaded with the packed int4
        weight and its scales/zeros.
        """
        for child_name, submodule in module.named_children():
            if not isinstance(submodule, Int4WeightOnlyQATLinear):
                # Not a QAT linear: keep descending
                self._convert_qat_linear_4w(submodule)
                continue

            weight_config = submodule.weight_fake_quantizer.config
            converted = WeightOnlyInt4Linear(
                submodule.in_features,
                submodule.out_features,
                bias=False,
                groupsize=weight_config.group_size,
                inner_k_tiles=submodule.inner_k_tiles,
                precision=submodule.weight.dtype,
                scales_precision=weight_config.scale_precision,
                device=next(submodule.parameters()).device,
            )
            setattr(module, child_name, converted)

            # Load weights and qparams into quantized linear
            n_bit = 4
            q_weight, scales_and_zeros = groupwise_affine_quantize_tensor(
                submodule.weight,
                n_bit,
                weight_config.group_size,
            )
            # The packing op depends on the device the quantized weight is on
            pack_int4 = (
                torch.ops.aten._convert_weight_to_int4pack_for_cpu
                if is_device(q_weight.device.type, "cpu")
                else torch.ops.aten._convert_weight_to_int4pack
            )
            converted.weight = pack_int4(
                q_weight.to(submodule.weight.device),
                submodule.inner_k_tiles,
            )
            converted.scales_and_zeros = scales_and_zeros

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        """Return the weight config built from `groupsize` and `scales_precision`."""
        return _get_4w_weight_config(self.groupsize, self.scales_precision)


class Int4WeightOnlyQATLinear(FakeQuantizedLinear):
    """
    This module implements a linear layer with int4 fake quantized grouped
    per channel weights, with forward numerics matching `WeightOnlyInt4Linear`,
    which uses the efficient int4 tinygemm kernel.

    Activations are not fake quantized: the base class is initialized with
    `activation_config=None`.

    args:
        groupsize: the number of elements in each quantized group for weights
        inner_k_tiles: stored on the module as `inner_k_tiles` and validated
            together with `in_features` and `groupsize`
        precision: precision of weights
        scales_precision: precision of per group scales and zero points
            (must be torch.bfloat16)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        device: torch.device = None,
        groupsize: int = 256,
        inner_k_tiles: int = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        """
        Validate the int4 layout constraints and initialize the base layer
        with a weight-only fake quantize config.

        Raises:
            AssertionError: if `scales_precision` is not `torch.bfloat16`.
            ValueError: if `_check_linear_int4_k` rejects the combination of
                `in_features`, `groupsize` and `inner_k_tiles`.
        """
        assert scales_precision == torch.bfloat16, "only bf16 is supported for scales"
        if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
            raise ValueError("Padding for QAT 4w is not supported yet")
        self.inner_k_tiles = inner_k_tiles
        weight_config = _get_4w_weight_config(groupsize, scales_precision)
        super().__init__(
            in_features,
            out_features,
            bias,
            activation_config=None,
            weight_config=weight_config,
            device=device,
            dtype=precision,
        )

    def enable_fake_quant(self, enabled: bool = True):
        """
        Set `enabled` on every fake quantizer this layer actually has.

        This layer is weight-only, so `activation_fake_quantizer` is `None`;
        it must be skipped rather than dereferenced, otherwise toggling fake
        quantization raises `AttributeError`.
        """
        for fake_quantizer in (
            self.activation_fake_quantizer,
            self.weight_fake_quantizer,
        ):
            if fake_quantizer is not None:
                fake_quantizer.enabled = enabled

    def disable_fake_quant(self):
        """Shorthand for `enable_fake_quant(False)`."""
        self.enable_fake_quant(False)


# TODO: remove these in favor of enable_linear_fake_quant
def enable_4w_fake_quant(mod: torch.nn.Module):
    """
    (deprecated) Enable fake quantization for `Int4WeightOnlyQATLinear`.

    Any other module type is ignored.
    """
    if not isinstance(mod, Int4WeightOnlyQATLinear):
        return
    mod.enable_fake_quant()


# TODO: remove these in favor of disable_linear_fake_quant
def disable_4w_fake_quant(mod: torch.nn.Module):
    """
    (deprecated) Disable fake quantization for `Int4WeightOnlyQATLinear`.

    Any other module type is ignored.
    """
    if not isinstance(mod, Int4WeightOnlyQATLinear):
        return
    mod.disable_fake_quant()


def _get_4w_weight_config(
    group_size: int,
    qparams_precision: torch.dtype,
) -> IntxFakeQuantizeConfig:
    """
    Return the weight `IntxFakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`.

    The config is uint4, asymmetric and dynamic with the given `group_size`,
    a float zero point domain, and `qparams_precision` used for both scales
    and zero points.
    """
    config_kwargs = dict(
        dtype=torch.uint4,
        group_size=group_size,
        is_symmetric=False,
        is_dynamic=True,
        scale_precision=qparams_precision,
        zero_point_precision=qparams_precision,
        zero_point_domain=ZeroPointDomain.FLOAT,
    )
    return IntxFakeQuantizeConfig(**config_kwargs)


# =============================================
# | float8 rowwise activations + int4 weights |
# =============================================


class Float8ActInt4WeightQATQuantizer(_LegacyQATQuantizer):
    """
    QAT quantizer for applying dynamic rowwise float8 activation + int4
    per group/channel symmetric weight fake quantization to linear layers
    in the model. Currently only supports rowwise granularity for float8
    activations.

    args:
        group_size (Optional[int]): the number of elements in each quantized
            group for weights, defaults to 64. Use None for per channel.
        scale_precision: precision of weight scales, defaults to torch.bfloat16.
    """

    def __init__(
        self,
        group_size: Optional[int] = 64,
        scale_precision: torch.dtype = torch.bfloat16,
    ):
        """Build and store the activation and weight fake quantize configs."""
        # Initialize the base class like the other legacy QAT quantizers do
        super().__init__()
        torch._C._log_api_usage_once(
            "torchao.quantization.qat.Float8ActInt4WeightQATQuantizer"
        )
        # A group size selects per-group weights; None selects per-channel
        if group_size is not None:
            weight_granularity = "per_group"
        else:
            weight_granularity = "per_channel"
        self._activation_config = Float8FakeQuantizeConfig(
            dtype=torch.float8_e4m3fn,
            granularity=PerRow(),
        )
        self._weight_config = IntxFakeQuantizeConfig(
            dtype=torch.int4,
            granularity=weight_granularity,
            group_size=group_size,
            is_symmetric=True,
            is_dynamic=True,
            scale_precision=scale_precision,
        )

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Swap all `nn.Linear` with `FakeQuantizedLinear` with float8
        fake quantizer for activations and int4 fake quantizer for weights.

        Only children of `model` are inspected (recursively); `model` itself
        is returned after its submodules have been swapped in place.
        """
        for name, child in model.named_children():
            if isinstance(child, torch.nn.Linear):
                new_linear = FakeQuantizedLinear.from_linear(
                    child,
                    activation_config=self._activation_config,
                    weight_config=self._weight_config,
                )
                setattr(model, name, new_linear)
            else:
                self.prepare(child)
        return model

    # TODO: add convert path
    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """Not supported yet; always raises `NotImplementedError`."""
        raise NotImplementedError

    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        """Return the float8 activation config built in `__init__`."""
        # The float8 config is created and stored in __init__, so hand it
        # back instead of raising a "does not exist yet" error.
        return self._activation_config

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        """Return the int4 weight config built in `__init__`."""
        # The attribute is `_weight_config`; `weight_config` is never set
        return self._weight_config
