# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
from compressed_tensors.config import CompressionFormat, SparsityStructure
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
from compressed_tensors.quantization.utils import is_module_quantized
from loguru import logger


# Names exported by `from <module> import *`: only the model-level entry point
# is listed; the per-module helpers defined below are not part of this list.
__all__ = ["infer_and_set_per_module_quantization_format"]


def _get_quant_compression_format(
    input_args: Optional[QuantizationArgs],
    weight_args: Optional[QuantizationArgs],
    sparsity_structure: Optional[str] = None,
) -> CompressionFormat:
    """
    Pick the compression format for a single module from its weight and input
    quantization args and an optional sparsity structure.

    The checks are applied in this order:

    1. 4-bit float weights: ``mxfp4_pack_quantized`` when the group size is
       32, ``nvfp4_pack_quantized`` otherwise.
    2. Weights and inputs both quantized: ``float_quantized`` for 8-bit float
       weights, ``int_quantized`` for int weights, ``naive_quantized`` for
       anything else.
    3. Weight-only quantization: ``naive_quantized`` unless the weights are
       4- or 8-bit int; ``marlin_24`` when the sparsity structure is
       ``TWO_FOUR`` and the strategy is channel or group; ``pack_quantized``
       otherwise.

    :param input_args: input quantization parameters, None when inputs are
        not quantized
    :param weight_args: weight quantization parameters; its attributes are
        read unconditionally, so callers are expected to pass a non-None value
    :param sparsity_structure: optional (global) model sparsity structure
    :return: CompressionFormat for the module
    """
    # The sparsity structure is resolved before any quantization arg is read.
    has_24_sparsity = (
        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
    )
    weight_only = input_args is None and weight_args is not None

    num_bits = weight_args.num_bits

    # fp4 weights take priority over every other rule below
    if num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
        return (
            CompressionFormat.mxfp4_pack_quantized
            if weight_args.group_size == 32
            else CompressionFormat.nvfp4_pack_quantized
        )

    if not weight_only:
        # w8a8 float and int
        if weight_args.type == QuantizationType.FLOAT.value and num_bits == 8:
            return CompressionFormat.float_quantized
        if weight_args.type == QuantizationType.INT.value:
            return CompressionFormat.int_quantized
        return CompressionFormat.naive_quantized

    # w4a16 and w8a16: packing is only valid for int4 and int8
    if num_bits not in (4, 8) or weight_args.type != QuantizationType.INT.value:
        return CompressionFormat.naive_quantized

    # marlin24 kernel only applicable for channel/group quantization
    # Note: vLLM may only support group quant for marlin24
    use_marlin_24 = has_24_sparsity and weight_args.strategy in (
        QuantizationStrategy.CHANNEL.value,
        QuantizationStrategy.GROUP.value,
    )
    if use_marlin_24:
        return CompressionFormat.marlin_24
    return CompressionFormat.pack_quantized


def set_per_module_format(
    module: torch.nn.Module,
    sparsity_structure: Optional[str] = None,
    quantization_format: Optional[str] = None,
):
    """
    Infer the quantization format of a module from its quantization scheme
    and sparsity structure, then reconcile it with any explicitly given format.

    Precedence, highest first:

    1. ``quantization_format`` (global override): always written to
       ``module.quantization_scheme.format``
    2. a format already present on ``module.quantization_scheme``: kept as is
    3. the inferred format: written only when neither of the above is set

    In cases 1 and 2 a warning is logged when the explicit format differs from
    the inferred one. Modules whose scheme has no weight quantization are
    left untouched.

    :param module: module which has its quantization format inferred; must
        carry a ``quantization_scheme`` attribute
    :param sparsity_structure: optional sparsity applied to the module
    :param quantization_format: optional global format to override
        the per module formats
    """
    scheme = module.quantization_scheme
    weight_args = scheme.weights
    input_args = scheme.input_activations

    # no weight quant - nothing to compress
    if weight_args is None:
        return

    inferred_format = _get_quant_compression_format(
        input_args, weight_args, sparsity_structure
    ).value

    mismatch_warning = (
        "The provided format for the module does not match the "
        "inferred format. Compression may fail "
    )

    # A global format overrides whatever the module carries
    if quantization_format is not None:
        if quantization_format != inferred_format:
            logger.warning(mismatch_warning)
        scheme.format = quantization_format
        return

    if scheme.format is None:
        # Neither a global nor a per module format was provided: use ours
        scheme.format = inferred_format
    elif scheme.format != inferred_format:
        # A per module format was provided but disagrees with the inferred one
        logger.warning(mismatch_warning)

def infer_and_set_per_module_quantization_format(
    model: torch.nn.Module,
    sparsity_structure: Optional[str] = None,
    quantization_format: Optional[str] = None,
) -> List[str]:
    """
    Walk every quantized submodule of the model, set its
    quantization_scheme.format via `set_per_module_format`, and collect the
    distinct formats encountered in first-seen order.

    Submodules whose format is still empty afterwards do not contribute to
    the result. When no submodule contributes a format, the result is
    ``[CompressionFormat.dense.value]``.

    For a summary of the formats, see `docs/guides/compression_formats.md`.

    :param model: model to check for quantization
    :param sparsity_structure: optional sparsity applied to the module
    :param quantization_format: optional global format to override
        the per module formats
    :return: list of the unique compression formats found in the model
    """
    formats_found = []
    for layer in model.modules():
        if not is_module_quantized(layer):
            continue

        assert hasattr(layer, "quantization_scheme")
        set_per_module_format(layer, sparsity_structure, quantization_format)

        layer_format = layer.quantization_scheme.format
        if layer_format and layer_format not in formats_found:
            formats_found.append(layer_format)

    # Fall back to the dense format when nothing was collected
    return formats_found or [CompressionFormat.dense.value]
