# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import abc
import dataclasses
import enum
import importlib
import json
import warnings
from typing import Any, Dict

import torch

# Names exported by `from <this module> import *`.
__all__ = ["AOBaseConfig", "config_from_dict", "config_to_dict", "ALLOWED_AO_MODULES"]

# Version every config starts out with unless it declares its own `version`;
# this value is a fixed baseline and must never change.
_DEFAULT_VERSION = 1


class AOBaseConfig(abc.ABC):
    """
    Base class for workflow configs consumed by `quantize_`.

    If a workflow config inherits from this then `quantize_` knows
    how to apply it to a model. For example::

        # user facing code
        class WorkflowFooConfig(AOBaseConfig):
            # configuration for workflow `Foo` is defined here
            bar = 'baz'

        # non user facing code
        @register_quantize_module_handler(WorkflowFooConfig)
        def _transform(
            mod: torch.nn.Module,
            config: WorkflowFooConfig,
        ) -> torch.nn.Module:
            # the transform is implemented here, usually a tensor subclass
            # weight swap or a module swap
            ...

        # then, the user calls `quantize_` with a config, and `_transform` is called
        # under the hood by `quantize_`.

    Versioning:

    The `version` attribute below is not the version of AOBaseConfig itself. It is
    the default version for instances of all child configs inheriting from
    AOBaseConfig, so it must always be `_DEFAULT_VERSION` and never change. Defining
    it here makes sure every config instance has a version. A child config that
    needs to bump its default version overrides it by defining its own
    `version: int = ...` (the type annotation is required). Different child config
    instances therefore maintain their own version.

    Why is version an instance variable instead of a class variable? When multiple
    versions co-exist, we need to be able to load objects saved with earlier
    versions. A class-level version is global and cannot achieve this, so the
    version is tracked per instance.
    """

    # Default version of a config; should never change.
    version: int = _DEFAULT_VERSION


class ConfigJSONEncoder(json.JSONEncoder):
    """Custom JSON encoder for AOBaseConfig objects.

    Every object this encoder knows about is written as a dict with a
    ``"_type"`` key holding the bare class name (no module path) and a
    ``"_data"`` key holding the payload. AOBaseConfig subclasses, NamedTuples
    and dataclasses additionally carry a ``"_version"`` key. Attribute and
    field values are encoded recursively, including inside lists and dicts.
    """

    def default(self, o):
        """Return a JSON-serializable representation of ``o``.

        Raises:
            NotImplementedError: If ``o`` is a tuple. JSON would silently turn
                it into a list, so lists should be used instead.
            TypeError: If ``o`` is of an unsupported type (raised by
                ``json.JSONEncoder.default``).
        """
        # torch.dtype, stored by name: torch.bfloat16 -> "bfloat16".
        if isinstance(o, torch.dtype):
            return {"_type": "torch.dtype", "_data": str(o).split(".")[-1]}

        # Enum members, stored by member name (only the class name is kept as
        # "_type"). This must run before the "Layout" class-name heuristic
        # below. Otherwise an enum whose class name contains "Layout" would be
        # encoded from its (all underscore-prefixed) instance attributes, i.e.
        # as an empty dict.
        if isinstance(o, enum.Enum):
            return {"_type": o.__class__.__name__, "_data": o.name}

        # AOBaseConfig subclasses: public instance attributes. `version` is
        # emitted separately as "_version", so it is left out of "_data".
        if isinstance(o, AOBaseConfig):
            data_dict = {
                k: self.encode_value(v)
                for k, v in o.__dict__.items()
                if not k.startswith("_") and k != "version"
            }
            return {
                # Only store the class name, not the full module path
                "_type": o.__class__.__name__,
                "_version": getattr(o, "version", _DEFAULT_VERSION),
                "_data": data_dict,
            }

        # NamedTuple instances, detected by their `_fields` / `_asdict` API.
        if hasattr(o, "_fields") and hasattr(o, "_asdict"):
            processed_data = {k: self.encode_value(v) for k, v in o._asdict().items()}
            return {
                "_type": o.__class__.__name__,
                "_version": getattr(o, "version", _DEFAULT_VERSION),
                "_data": processed_data,
            }

        # Dataclass instances (dataclass *types* are not encoded here).
        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            data_dict = {
                f.name: self.encode_value(getattr(o, f.name))
                for f in dataclasses.fields(o)
                if f.name != "version"
            }
            return {
                # Only store the class name for dataclasses too
                "_type": o.__class__.__name__,
                "_version": getattr(o, "version", _DEFAULT_VERSION),
                "_data": data_dict,
            }

        # Layout objects, recognized by class name: public instance attributes.
        if "Layout" in o.__class__.__name__:
            return {
                "_type": o.__class__.__name__,
                "_data": {
                    k: self.encode_value(v)
                    for k, v in o.__dict__.items()
                    if not k.startswith("_")
                },
            }

        # Containers. json.JSONEncoder only calls `default` for objects it
        # cannot serialize natively, so in practice these branches are reached
        # through `encode_value` when recursing into attribute values.
        if isinstance(o, list):
            return [self.encode_value(item) for item in o]

        if isinstance(o, tuple):
            raise NotImplementedError(
                "Tuples will be serialized as List in JSON, so we recommend to use "
                f"Lists instead to avoid surprises. got: {o}"
            )

        if isinstance(o, dict):
            return {k: self.encode_value(v) for k, v in o.items()}

        # Unsupported type: the base class raises TypeError.
        return super().default(o)

    def encode_value(self, value):
        """Encode ``value`` with `default` when it is a supported custom type.

        Values that `default` rejects with TypeError (plain ints, floats,
        strings, None, ...) are returned unchanged, to be handled later by the
        standard JSON encoder. Other exceptions, such as the
        NotImplementedError raised for tuples, propagate to the caller.
        """
        try:
            return self.default(value)
        except TypeError:
            return value


def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
    """
    Serialize an AOBaseConfig instance into a plain, JSON-compatible dict.

    Args:
        config: The config to serialize; must be an AOBaseConfig subclass instance.

    Returns:
        The dict produced by `ConfigJSONEncoder` for the config (``_type``,
        ``_version`` and ``_data`` keys), containing only JSON types.

    Raises:
        TypeError: If ``config`` is not an AOBaseConfig instance. Errors raised
            by `ConfigJSONEncoder` while encoding also propagate.
    """
    if isinstance(config, AOBaseConfig):
        # Dump to JSON text with the custom encoder and parse it straight back,
        # so every nested object comes out as plain dicts/lists/scalars.
        json_text = json.dumps(config, cls=ConfigJSONEncoder)
        return json.loads(json_text)
    raise TypeError(f"Expected AOBaseConfig instance, got {type(config)}")


# Modules that `config_from_dict` searches, by bare class name, when resolving
# the "_type" of a serialized object. Only names found in one of these modules
# can be deserialized.
ALLOWED_AO_MODULES = {
    # non-prototype modules
    "torchao.dtypes",
    "torchao.quantization",
    "torchao.quantization.quantize_.common",
    "torchao.quantization.quantize_.workflows",
    "torchao.sparsity.sparse_api",
    # prototype modules
    "torchao.prototype.awq",
    "torchao.prototype.mx_formats",
    "torchao.prototype.parq",
    "torchao.prototype.parq.quant",
    "torchao.prototype.quantization",
}


def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
    """
    Create an AOBaseConfig subclass instance from a dictionary.

    Args:
        data: Dictionary containing serialized config data

    Returns:
        An instance of the appropriate AOBaseConfig subclass

    Raises:
        ValueError: If deserialization fails for other reasons
    """
    if not isinstance(data, dict):
        raise TypeError(f"Expected dictionary, got {type(data)}")

    if "_type" not in data or "_data" not in data:
        raise ValueError("Input dictionary missing required '_type' or '_data' fields")

    type_path = data["_type"]
    stored_version = data.get("_version", _DEFAULT_VERSION)
    obj_data = data["_data"]

    # Handle torch.dtype
    if type_path == "torch.dtype":
        import torch

        return getattr(torch, obj_data)
    # Try to find the class in any of the allowed modules
    cls = None
    for module_path in ALLOWED_AO_MODULES:
        try:
            module = importlib.import_module(module_path)
            cls = getattr(module, type_path)
            break  # Found the class, exit the loop
        except (ImportError, AttributeError):
            continue  # Try the next module

    # If we couldn't find the class in any allowed module, raise an error
    if cls is None:
        allowed_modules_str = ", ".join(ALLOWED_AO_MODULES)
        raise ValueError(
            f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}"
        )

    current_default_version = getattr(cls, "version", _DEFAULT_VERSION)
    if stored_version != current_default_version:
        warnings.warn(
            f"Stored version is not the same as current default version of the config: {stored_version=}, {current_default_version=}, please check the deprecation warning"
        )

    # Handle the case where obj_data is not a dictionary
    if not isinstance(obj_data, dict):
        if issubclass(cls, enum.Enum):
            # For enums, convert string to enum value
            return getattr(cls, obj_data)
        else:
            # For other primitive types, create an instance with the value
            try:
                return cls(obj_data)
            except:
                return obj_data

    # Process nested structures for dictionary obj_data
    if stored_version != current_default_version:
        processed_data = {"version": stored_version}
    else:
        processed_data = {}

    for key, value in obj_data.items():
        if isinstance(value, dict) and "_type" in value and "_data" in value:
            # Recursively handle nested configs
            processed_data[key] = config_from_dict(value)
        elif isinstance(value, list):
            # Handle lists or tuples of possible configs
            processed_data[key] = [
                config_from_dict(item)
                if isinstance(item, dict) and "_type" in item and "_data" in item
                else item
                for item in value
            ]
        elif isinstance(value, tuple):
            raise NotImplementedError(
                "Tuples will be serialized as List in JSON, so we recommend to use "
                f"Lists instead to avoid surprises. got: {value}"
            )
        elif isinstance(value, dict):
            # Handle dicts of possible configs
            processed_data[key] = {
                k: config_from_dict(v)
                if isinstance(v, dict) and "_type" in v and "_data" in v
                else v
                for k, v in value.items()
            }
        else:
            processed_data[key] = value

    # Create and return the instance
    try:
        return cls(**processed_data)
    except Exception as e:
        raise ValueError(f"Failed to create instance of {cls.__name__}: {e}")
