"""Utility functions for importing modules in the LLM module."""
import importlib
import logging
from types import ModuleType
from typing import Any, Optional, Type

# Module-level logger; try_import uses it to warn about failed imports.
logger = logging.getLogger(__name__)


def try_import(name: str, error: bool = False) -> Optional[ModuleType]:
    """Try importing a module and return it, or None if the import fails.

    Any ImportError raised by the import is handled, which includes an
    ImportError raised while the target module itself is executing (for
    example, because one of its own dependencies is missing).

    Args:
        name: The name of the module to import.
        error: Whether to raise an error if the module cannot be imported.
            When False, a warning is logged and None is returned instead.

    Returns:
        The imported module, or None if it cannot be imported and
        ``error`` is False.

    Raises:
        ImportError: If error=True and the module cannot be imported. The
            original exception is attached as ``__cause__``.
    """
    try:
        return importlib.import_module(name)
    except ImportError as err:
        if error:
            # Chain explicitly so the traceback shows the underlying import
            # failure as the direct cause rather than an unrelated error
            # that happened "during handling".
            raise ImportError(f"Could not import {name}") from err
        logger.warning("Could not import %s", name)
    return None


def load_class(path: str) -> Type[Any]:
    """Load a class from a string path.

    Two forms are accepted: ``"package.module:ClassName"`` and
    ``"package.module.ClassName"``. If the path contains a colon it is split
    at the last colon; otherwise it is split at the last dot. The left part
    is imported as a module and the right part is looked up on it with
    ``getattr``.

    Args:
        path: The import path of the class to load.

    Returns:
        The attribute of the imported module named by the final component
        of ``path``.

    Raises:
        ValueError: If ``path`` contains no separator or has an empty
            module part.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    separator = ":" if ":" in path else "."
    module_path, found, class_name = path.rpartition(separator)
    if not found or not module_path:
        # Fail early with a message naming the bad input instead of an
        # opaque tuple-unpacking or "Empty module name" error.
        raise ValueError(
            f"Invalid class path {path!r}: expected 'module.ClassName' "
            "or 'module:ClassName'"
        )

    module = try_import(module_path, error=True)
    return getattr(module, class_name)
