import dataclasses
from dataclasses import asdict, fields
from typing import Awaitable, Callable, List, Tuple

import aiohttp.web

from ray.dashboard.optional_utils import rest_response
from ray.dashboard.utils import HTTPStatusCode
from ray.util.state.common import (
    DEFAULT_LIMIT,
    DEFAULT_RPC_TIMEOUT,
    RAY_MAX_LIMIT_FROM_API_SERVER,
    ListApiOptions,
    ListApiResponse,
    PredicateType,
    StateSchema,
    SummaryApiOptions,
    SummaryApiResponse,
    SupportedFilterType,
    filter_fields,
)
from ray.util.state.exception import DataSourceUnavailable
from ray.util.state.util import convert_string_to_type


def do_reply(
    status_code: HTTPStatusCode, error_message: str, result: ListApiResponse, **kwargs
):
    """Build the REST reply for a state API call.

    Args:
        status_code: The HTTP status code of the reply.
        error_message: Forwarded to `rest_response` as `message`.
        result: The payload of the reply. Callers in this module pass either
            a dict produced by `asdict` or None.
        **kwargs: Extra keyword arguments forwarded to `rest_response`.

    Returns:
        Whatever `rest_response` returns.
    """
    # Arguments this helper always supplies; google-style conversion is
    # always disabled for these replies.
    fixed_arguments = {
        "status_code": status_code,
        "message": error_message,
        "result": result,
        "convert_google_style": False,
    }
    return rest_response(**fixed_arguments, **kwargs)


async def handle_list_api(
    list_api_fn: Callable[[ListApiOptions], Awaitable[ListApiResponse]],
    req: aiohttp.web.Request,
):
    """Run a list API call for the request and turn the outcome into a reply.

    Args:
        list_api_fn: Awaitable callable invoked with `option=<ListApiOptions>`.
        req: The incoming aiohttp request; its query string supplies the
            list options.

    Returns:
        The reply from `do_reply`: OK with the response converted via
        `asdict` on success, BAD_REQUEST when a ValueError is raised, and
        INTERNAL_ERROR when DataSourceUnavailable is raised. Failure replies
        carry the exception text as the message and None as the result.
    """
    try:
        list_options = options_from_req(req)
        response = await list_api_fn(option=list_options)
        return do_reply(
            status_code=HTTPStatusCode.OK,
            error_message="",
            result=asdict(response),
        )
    except (ValueError, DataSourceUnavailable) as err:
        # ValueError is tested first so it keeps precedence over
        # DataSourceUnavailable, exactly as with two ordered except clauses.
        if isinstance(err, ValueError):
            failure_code = HTTPStatusCode.BAD_REQUEST
        else:
            failure_code = HTTPStatusCode.INTERNAL_ERROR
        return do_reply(
            status_code=failure_code,
            error_message=str(err),
            result=None,
        )


def _get_filters_from_req(
    req: aiohttp.web.Request,
) -> List[Tuple[str, PredicateType, SupportedFilterType]]:
    """Collect the `(key, predicate, value)` filters from the request query.

    The filters are sent as three parallel, repeated query parameters:
    `filter_keys`, `filter_predicates` and `filter_values`. The i-th entry
    of each list together forms one filter.

    Args:
        req: The incoming aiohttp request.

    Returns:
        A list of `(key, predicate, value)` tuples; empty if the request
        carries no filters. Values are returned as received from the query
        (no type conversion happens here).

    Raises:
        ValueError: If the three parameter lists do not have the same length.
    """
    filter_keys = req.query.getall("filter_keys", [])
    filter_predicates = req.query.getall("filter_predicates", [])
    filter_values = req.query.getall("filter_values", [])

    # Validate explicitly instead of using `assert`: assertions are removed
    # under `python -O`, and `zip` would otherwise silently drop filters when
    # one of the lists is shorter than the others.
    if not (len(filter_keys) == len(filter_predicates) == len(filter_values)):
        raise ValueError(
            "The number of filter_keys, filter_predicates and filter_values "
            f"must match, but got {len(filter_keys)} key(s), "
            f"{len(filter_predicates)} predicate(s) and "
            f"{len(filter_values)} value(s)."
        )

    return list(zip(filter_keys, filter_predicates, filter_values))


def options_from_req(req: aiohttp.web.Request) -> ListApiOptions:
    """Obtain `ListApiOptions` from the aiohttp request.

    Reads `limit`, `timeout`, `detail`, `exclude_driver` and the filter
    parameters from the request's query string, falling back to defaults
    for any that are absent.

    Args:
        req: The incoming aiohttp request.

    Returns:
        The `ListApiOptions` built from the query parameters.

    Raises:
        ValueError: If `limit` exceeds `RAY_MAX_LIMIT_FROM_API_SERVER`, or if
            `limit` / `timeout` cannot be parsed as integers.
    """
    limit = int(req.query.get("limit", DEFAULT_LIMIT))

    if limit > RAY_MAX_LIMIT_FROM_API_SERVER:
        raise ValueError(
            f"Given limit {limit} exceeds the supported "
            f"limit {RAY_MAX_LIMIT_FROM_API_SERVER}. Use a lower limit, or set the "
            f"`RAY_MAX_LIMIT_FROM_API_SERVER` environment variable to a larger value."
        )

    # NOTE(review): `summary_options_from_req` defaults its timeout to
    # DEFAULT_RPC_TIMEOUT while this function uses a literal 30. The literal
    # is kept to avoid changing behavior -- confirm whether they should match.
    timeout = int(req.query.get("timeout", 30))
    filters = _get_filters_from_req(req)
    detail = convert_string_to_type(req.query.get("detail", False), bool)
    exclude_driver = convert_string_to_type(req.query.get("exclude_driver", True), bool)

    return ListApiOptions(
        limit=limit,
        timeout=timeout,
        filters=filters,
        detail=detail,
        exclude_driver=exclude_driver,
    )


def summary_options_from_req(req: aiohttp.web.Request) -> SummaryApiOptions:
    """Obtain `SummaryApiOptions` from the aiohttp request.

    Args:
        req: The incoming aiohttp request. `timeout` (defaulting to
            `DEFAULT_RPC_TIMEOUT`), `summary_by` (defaulting to None) and the
            filter parameters are read from its query string.

    Returns:
        The `SummaryApiOptions` built from the query parameters.
    """
    query = req.query
    return SummaryApiOptions(
        timeout=int(query.get("timeout", DEFAULT_RPC_TIMEOUT)),
        filters=_get_filters_from_req(req),
        summary_by=query.get("summary_by", None),
    )


async def handle_summary_api(
    summary_fn: Callable[[SummaryApiOptions], SummaryApiResponse],
    req: aiohttp.web.Request,
):
    """Run a summary API call for the request and reply with its result.

    Args:
        summary_fn: Callable invoked with `option=<SummaryApiOptions>`; its
            return value is awaited.
        req: The incoming aiohttp request; its query string supplies the
            summary options.

    Returns:
        An OK reply from `do_reply` whose result is the summary response
        converted via `asdict`. Exceptions are not handled here and
        propagate to the caller.
    """
    summary_options = summary_options_from_req(req)
    summary = await summary_fn(option=summary_options)
    payload = asdict(summary)
    return do_reply(
        status_code=HTTPStatusCode.OK,
        error_message="",
        result=payload,
    )


def _coerce_filter_value(
    col: str, val: SupportedFilterType, column_type
) -> SupportedFilterType:
    """Convert one filter value to the type of the column it filters on.

    Args:
        col: The filter column name; only used in error messages.
        val: The filter value as provided by the user.
        column_type: The column's type as found in the schema.

    Returns:
        `val` converted to int, float or bool when the column has that type
        and `val` is not already an instance of it; otherwise `val` unchanged.

    Raises:
        ValueError: If `val` cannot be converted to the column's type.
    """
    try:
        if isinstance(val, column_type):
            # Already the right type; nothing to convert.
            return val
    except TypeError:
        # Calling `isinstance` with a Literal type raises a TypeError.
        # Such columns are left unconverted.
        return val

    # NOTE(review): `isinstance` also raises TypeError when `column_type` is a
    # plain string, so the "integer" / "number" / "boolean" comparisons below
    # look unreachable. They are kept as-is -- confirm the intended behavior.
    if column_type is int or column_type == "integer":
        try:
            return convert_string_to_type(val, int)
        except ValueError as err:
            raise ValueError(
                f"Invalid filter `--filter {col} {val}` for a int type "
                "column. Please provide an integer filter "
                f"`--filter {col} [int]`"
            ) from err
    if column_type is float or column_type == "number":
        try:
            return convert_string_to_type(val, float)
        except ValueError as err:
            raise ValueError(
                f"Invalid filter `--filter {col} {val}` for a float "
                "type column. Please provide a float filter "
                f"`--filter {col} [float]`"
            ) from err
    if column_type is bool or column_type == "boolean":
        try:
            return convert_string_to_type(val, bool)
        except ValueError as err:
            raise ValueError(
                f"Invalid filter `--filter {col} {val}` for a boolean "
                "type column. Please provide "
                f"`--filter {col} [True|true|1]` for True or "
                f"`--filter {col} [False|false|0]` for False."
            ) from err
    return val


def convert_filters_type(
    filter: List[Tuple[str, PredicateType, SupportedFilterType]],
    schema: StateSchema,
) -> List[Tuple[str, PredicateType, SupportedFilterType]]:
    """Convert the given filters' values to the types of their columns.

    This method is necessary because click can only accept a single type
    for its tuple (which is string in this case).

    Args:
        filter: A list of filters, each a tuple of (key, predicate, val).
        schema: The state schema. It is used to infer the type of the column
            for each filter. If it is a dataclass, the column types come from
            its fields; otherwise from `schema.schema_dict()`.

    Returns:
        A new list of filters with values converted to match the schema.
        Filters whose key is not a column of the schema are passed through
        unchanged.

    Raises:
        ValueError: If a filter value cannot be converted to the int, float
            or bool type of its column.
    """
    if dataclasses.is_dataclass(schema):
        column_types = {field.name: field.type for field in fields(schema)}
    else:
        column_types = schema.schema_dict()

    new_filter = []
    for col, predicate, val in filter:
        if col in column_types:
            val = _coerce_filter_value(col, val, column_types[col])
        new_filter.append((col, predicate, val))
    return new_filter


def _normalize_for_comparison(datum_value, filter_value):
    """Return `(datum_value, filter_value)` made comparable with `==` / `!=`.

    Only string filter values are adjusted: against a string datum both sides
    are lower-cased (case-insensitive match); against a bool or int datum the
    filter value is converted to that type. Anything else is returned as-is.
    """
    if isinstance(filter_value, str):
        if isinstance(datum_value, str):
            return datum_value.lower(), filter_value.lower()
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(datum_value, bool):
            return datum_value, convert_string_to_type(filter_value, bool)
        if isinstance(datum_value, int):
            return datum_value, convert_string_to_type(filter_value, int)
    return datum_value, filter_value


def do_filter(
    data: List[dict],
    filters: List[Tuple[str, PredicateType, SupportedFilterType]],
    state_dataclass: StateSchema,
    detail: bool,
) -> List[dict]:
    """Return the filtered data given filters.

    Args:
        data: A list of state data.
        filters: A list of (key, predicate, val) tuples to filter data with.
            The predicate is either "=" or "!=". A datum is kept only if it
            satisfies every filter; a datum that lacks the filter's column
            is dropped.
        state_dataclass: The state schema.
        detail: Passed through to `filter_fields` for every kept datum.

    Returns:
        A list of filtered state data in dictionary. Each state data's
        unnecessary columns are filtered by the given state_dataclass schema.

    Raises:
        ValueError: If a filter column is not one of the schema's filterable
            columns, or if a predicate other than "=" / "!=" is given.
    """
    filters = convert_filters_type(filters, state_dataclass)
    # Loop-invariant, so look it up once; it is only needed when there is at
    # least one filter to validate.
    filterable_columns = state_dataclass.filterable_columns() if filters else None

    result = []
    for datum in data:
        match = True
        for filter_column, filter_predicate, filter_value in filters:
            filter_column = filter_column.lower()
            if filter_column not in filterable_columns:
                raise ValueError(
                    f"The given filter column {filter_column} is not supported. "
                    "Enter filters with --filter key=value "
                    "or --filter key!=value. "
                    f"Supported filter columns: {filterable_columns}"
                )

            if filter_column not in datum:
                match = False
            elif filter_predicate == "=":
                datum_value, expected = _normalize_for_comparison(
                    datum[filter_column], filter_value
                )
                match = datum_value == expected
            elif filter_predicate == "!=":
                # Apply the same normalization as "=" so that a string filter
                # value is compared correctly against bool/int columns.
                datum_value, expected = _normalize_for_comparison(
                    datum[filter_column], filter_value
                )
                match = datum_value != expected
            else:
                raise ValueError(
                    f"Unsupported filter predicate {filter_predicate} is given. "
                    "Available predicates: =, !=."
                )

            if not match:
                break

        if match:
            result.append(filter_fields(datum, state_dataclass, detail))
    return result
