# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from compressed_tensors.transform import HadamardFactory, TransformFactory
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
from torch import device, dtype
from torch.nn import Parameter


@TransformFactory.register("random-hadamard")
class RandomHadamardFactory(HadamardFactory):
    """
    Factory which applies randomly generated hadamard transforms to a model

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used to randomize the transform weights
    """

    def _create_weight(
        self,
        size: int,
        device: device,
        construct_device: device,
        precision: dtype,
    ) -> Parameter:
        """
        Build the weight parameter backing a random hadamard transform

        :param size: size forwarded to `random_hadamard_matrix`
            (presumably the side length of the square matrix; confirm
            against that helper)
        :param device: device the returned parameter's data is moved to
        :param construct_device: device forwarded to `random_hadamard_matrix`
            for building the matrix
        :param precision: dtype forwarded to `random_hadamard_matrix`
        :return: parameter wrapping the generated matrix, with `requires_grad`
            taken from the transform scheme
        """
        # randomness is drawn from this factory's generator rather than
        # from the global torch RNG state
        matrix = random_hadamard_matrix(
            size, precision, construct_device, self.generator
        )

        # construction and final placement may use different devices, so
        # relocate only after the matrix has been generated
        placed = matrix.to(device=device)

        return Parameter(placed, requires_grad=self.scheme.requires_grad)
