# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase

# Public names exported by ``from <this module> import *``.
__all__ = [
    "create_block_sparse_tensor",
    "create_semi_structured_tensor",
    "PerChannelNormObserver",
    "mask_creator",
]


def create_block_sparse_tensor(M, N, blocksize, sparsity, dtype, device="cuda"):
    """
    Return a random block-sparse 0/1 matrix.

    The matrix is tiled into ``blocksize x blocksize`` blocks. Each block is
    independently all ones with probability ``1 - sparsity`` and all zeros
    otherwise.

    :param M: number of rows (expected to be a multiple of ``blocksize``)
    :param N: number of columns (expected to be a multiple of ``blocksize``)
    :param blocksize: edge length of each square block
    :param sparsity: probability in [0, 1] that a block is zero
    :param dtype: dtype of the returned tensor (floating or integer)
    :param device: device of the returned tensor; defaults to ``"cuda"``
    :return: contiguous tensor of dtype ``dtype`` on ``device``
    """
    assert sparsity <= 1.0 and sparsity >= 0.0, (
        "sparsity should be a value between 0 and 1"
    )
    # NOTE: if M or N is not a multiple of blocksize, the trailing partial
    # block is dropped and the result is
    # (M // blocksize * blocksize, N // blocksize * blocksize).
    #
    # torch.bernoulli only accepts floating-point probabilities, so sample in
    # float32 and convert at the end; this makes integer ``dtype``s work too.
    # Sampling happens on the default device before the final move, so seeded
    # runs keep drawing from the same generator regardless of ``device``.
    A = torch.bernoulli(
        torch.full(
            (M // blocksize, N // blocksize), 1 - sparsity, dtype=torch.float32
        )
    )
    A = torch.repeat_interleave(A, blocksize, dim=0)
    A = torch.repeat_interleave(A, blocksize, dim=1)
    return A.to(dtype).contiguous().to(device)


def create_semi_structured_tensor(r, c, dtype, device="cuda"):
    """
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.

    In every row, each pair of adjacent columns (0-1, 2-3, ...) holds exactly
    one nonzero entry. Nonzero values are ``1 + U[0, 1)``, i.e. in [1, 2).

    :param r: number of rows
    :param c: number of columns; must be even for the 1:2 pattern to exist
    :param dtype: dtype of the returned tensor
    :param device: device of the returned tensor; defaults to ``"cuda"``
    :raises ValueError: if ``c`` is odd
    """
    # With an odd column count the one-hot pairs below would straddle row
    # boundaries (or reshape would fail), so the result could not be 1:2 sparse.
    if c % 2 != 0:
        raise ValueError(
            f"create_semi_structured_tensor requires an even number of columns, got c={c}"
        )

    # Choices are [0, 1] and [1, 0] - this is practically one-hot
    # encoding for two classes, so for better performance the mask is
    # built as random selection between these two encodings.
    choice_indices = torch.randint(0, 2, (r * c // 2,)).to(device)
    mask = (
        torch.nn.functional.one_hot(choice_indices, num_classes=2)
        .reshape(r, c)
        .contiguous()
        .to(torch.int32)
    )

    sparse_weight = mask + (torch.rand(r, c).to(device) * mask)
    return sparse_weight.to(dtype)


# Observers
class PerChannelNormObserver(UniformQuantizationObserverBase):
    """
    A custom observer that accumulates the squared L2 norm of each channel.

    The channel axis is the last dimension of the observed tensor. Every
    non-empty call to ``forward`` adds the per-channel sum of squares of the
    input to the ``norm`` buffer, so after several batches ``norm`` holds the
    running total over all of them. The observed input is returned unchanged.
    """

    def __init__(self, **kwargs) -> None:
        # init with fixed qparams for quantization flow
        super().__init__(
            dtype=torch.quint8,
            qscheme=torch.per_channel_affine,
            reduce_range=False,
            quant_min=None,
            quant_max=None,
            eps=torch.finfo(torch.float32).eps,
            **kwargs,
        )
        # set averaging constant so quantization flow knows observer is memoryless.
        self.averaging_constant = 1.0
        # Empty until the first non-empty forward() sizes it to the channel count.
        self.register_buffer("norm", torch.tensor([]))

    def forward(self, x_orig):
        """
        Add the per-channel squared L2 norms of ``x_orig`` to ``self.norm``.

        :param x_orig: tensor whose last dimension indexes channels
        :return: ``x_orig`` unchanged
        """
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()  # avoid keeping autograd tape

        # channel_ax is always the last dimension: swap it to the front.
        new_axis_list = list(range(x.dim()))
        new_axis_list[0], new_axis_list[-1] = new_axis_list[-1], new_axis_list[0]
        y = x.permute(new_axis_list)
        # Collapse to (channels, everything else). reshape is used instead of
        # flatten(start_dim=1) so that a 1-D input becomes (channels, 1)
        # rather than raising a dimension-out-of-range error.
        y = y.reshape(y.shape[0], -1)
        norm = torch.linalg.vector_norm(y, dim=1) ** 2

        if self.norm.numel() == 0:
            self.norm.resize_(norm.shape)
            self.norm.copy_(norm)
        else:
            self.norm += norm

        return x_orig

    def calculate_qparams(self):
        """Not supported: this observer only records channel norms."""
        raise NotImplementedError(
            "PerChannelNormObserver is designed to store activations only. "
        )


def mask_creator(
    tensor: torch.Tensor,
    N: int = 2,
    M: int = 4,
) -> torch.Tensor:
    """
    Create an N:M sparsity mask for ``tensor``.

    The elements of ``tensor`` are taken in flattened order and split into
    consecutive groups of M. Within each group the N entries with the largest
    absolute value are kept (mask value 1) and the remaining M - N entries are
    pruned (mask value 0). Tie-breaking between equal magnitudes follows
    ``torch.argsort`` and is not guaranteed to be stable.

    :param tensor: The input tensor to create a mask for
    :param N: The number of weights in a group to keep
    :param M: The size of a weight group
    :return: A float32 mask with the same shape and device as ``tensor``
    :raises ValueError: if M is not positive, N is outside [0, M], or the
        number of elements of ``tensor`` is not a multiple of M
    """
    if M <= 0:
        raise ValueError(f"Group size M must be positive, got M={M}")
    # N > M would turn the slice below into a negative index and silently
    # keep the wrong number of weights; N < 0 is meaningless.
    if not 0 <= N <= M:
        raise ValueError(f"N must satisfy 0 <= N <= M, got N={N}, M={M}")
    if tensor.numel() % M != 0:
        raise ValueError(
            f"Tensor of size {tensor.shape} can't be evenly divided into groups of {M}"
        )

    num_groups = tensor.numel() // M

    # Rank each group by magnitude; the first M - N indices are the smallest.
    tensor_temp = tensor.detach().abs().reshape(num_groups, M)
    prune_index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]

    mask = torch.ones(tensor_temp.shape, device=tensor_temp.device)
    return mask.scatter_(dim=1, index=prune_index, value=0).reshape(tensor.shape)
