import itertools
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Iterable, List, Optional, TypeVar

import ray
from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, BlockAccessor
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    import pyarrow as pa

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Type parameter linking the return type of `Datasink.write` with the element
# type of `WriteResult.write_returns` (the list of those returned values).
WriteReturnType = TypeVar("WriteReturnType")
"""Generic type for the return value of `Datasink.write`."""


@dataclass
@DeveloperAPI
class WriteResult(Generic[WriteReturnType]):
    """Aggregated result of the Datasink write operations."""

    # Total number of written rows.
    num_rows: int
    # Total size in bytes of written data.
    size_bytes: int
    # All returned values of `Datasink.write`.
    write_returns: List[WriteReturnType]

    @classmethod
    def combine(cls, *wrs: "WriteResult") -> "WriteResult":
        """Merge several write results into a single ``WriteResult``.

        Row counts and byte sizes are summed, and the ``write_returns`` lists
        are concatenated in argument order.

        Args:
            *wrs: The write results to merge. Calling with no arguments yields
                an empty result (0 rows, 0 bytes, no write returns).

        Returns:
            A new ``WriteResult`` holding the combined totals.
        """
        num_rows = sum(wr.num_rows for wr in wrs)
        size_bytes = sum(wr.size_bytes for wr in wrs)
        # Flatten lazily instead of building and splatting a list of lists.
        write_returns = list(
            itertools.chain.from_iterable(wr.write_returns for wr in wrs)
        )

        return WriteResult(
            num_rows=num_rows,
            size_bytes=size_bytes,
            write_returns=write_returns,
        )


@DeveloperAPI
class Datasink(Generic[WriteReturnType]):
    """Base interface for custom write logic.

    To write data to a destination that Ray Data doesn't support out of the
    box, subclass this class and pass an instance to
    :meth:`~ray.data.Dataset.write_datasink`.
    """

    def on_write_start(self, schema: Optional["pa.Schema"] = None) -> None:
        """Hook invoked once when the write job begins.

        Override this to prepare the destination before any write task runs,
        for example by creating a staging bucket in S3.

        It runs on the driver as soon as the first input bundle is available,
        right before write tasks are submitted. Because the schema is read from
        that first bundle, initialization can depend on it.

        Args:
            schema: PyArrow schema of the data being written, taken
                automatically from the first input bundle. ``None`` when the
                input data carries no schema.
        """

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> WriteReturnType:
        """Write a sequence of blocks from within a single write task.

        Args:
            blocks: Generator yielding the data blocks to write.
            ctx: ``TaskContext`` of the calling write task.

        Returns:
            A per-task result. Once the whole write operator finishes, the
            results of every task are gathered into
            ``WriteResult.write_returns`` and handed to
            ``Datasink.on_write_complete``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def on_write_complete(self, write_result: WriteResult[WriteReturnType]):
        """Hook invoked once the write job has finished successfully.

        Use it to `commit` the written output. ``write_datasink()`` does not
        return to the user until this method succeeds; if it raises,
        ``on_write_failed()`` is called instead.

        Args:
            write_result: Aggregated result of the write operator, holding
                every task's return value together with write statistics.
        """

    def on_write_failed(self, error: Exception) -> None:
        """Hook invoked when the write job fails.

        Invocation is best-effort only.

        Args:
            error: The first error that was encountered.
        """

    def get_name(self) -> str:
        """Return a human-readable name for this datasink.

        The name is used to label the write tasks. It is derived from the
        class name with one leading underscore and a trailing ``Datasink``
        removed.
        """
        label = type(self).__name__
        if label.startswith("_"):
            label = label[1:]
        suffix = "Datasink"
        return label[: -len(suffix)] if label.endswith(suffix) else label

    @property
    def supports_distributed_writes(self) -> bool:
        """Whether write tasks may run on any node; ``False`` pins them to the driver's node."""
        return True

    @property
    def min_rows_per_write(self) -> Optional[int]:
        """Target row count for each :meth:`~ray.data.Datasink.write` call.

        ``None`` means Ray Data chooses the number of rows itself.
        """
        return None


@DeveloperAPI
class DummyOutputDatasink(Datasink[None]):
    """An example implementation of a writable datasink for testing.

    Write tasks forward every block to a helper Ray actor that only counts
    rows. The ``on_write_complete`` / ``on_write_failed`` callbacks count how
    many times they were invoked.

    Examples:
        >>> import ray
        >>> from ray.data.datasource import DummyOutputDatasink
        >>> output = DummyOutputDatasink()
        >>> ray.data.range(10).write_datasink(output)
        >>> assert output.num_ok == 1
    """

    def __init__(self):
        ctx = ray.data.DataContext.get_current()

        # Set up a dummy actor to receive the data. In a real datasink, write
        # tasks would send data to an external system instead of a Ray actor.
        @ray.remote(scheduling_strategy=ctx.scheduling_strategy)
        class DataSink:
            def __init__(self):
                self.rows_written = 0
                self.enabled = True

            def write(self, block: Block) -> None:
                # Only the row count is recorded; block contents are dropped.
                self.rows_written += BlockAccessor.for_block(block).num_rows()

            def get_rows_written(self):
                return self.rows_written

        # Handle to the row-counting actor that `write` sends blocks to.
        self.data_sink = DataSink.remote()
        # Number of times `on_write_complete` / `on_write_failed` were called.
        self.num_ok = 0
        self.num_failed = 0
        # When False, `write` raises ValueError("disabled") instead of writing.
        self.enabled = True

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Send every block to the counting actor and wait for all sends.

        Args:
            blocks: Blocks to forward to the actor.
            ctx: ``TaskContext`` of the write task (unused here).

        Raises:
            ValueError: If ``self.enabled`` is False.
        """
        if not self.enabled:
            raise ValueError("disabled")
        # Submit every remote write first, then wait on the whole batch once.
        ray.get([self.data_sink.write.remote(b) for b in blocks])

    def on_write_complete(self, write_result: WriteResult[None]):
        """Record one successful write job."""
        self.num_ok += 1

    def on_write_failed(self, error: Exception) -> None:
        """Record one failed write job; ``error`` itself is ignored."""
        self.num_failed += 1


def _gen_datasink_write_result(
    write_result_blocks: List[Block],
) -> WriteResult:
    """Aggregate per-task write statistics into a single ``WriteResult``.

    Each element of ``write_result_blocks`` must be a one-row pandas
    DataFrame with ``num_rows``, ``size_bytes`` and ``write_return`` columns.

    Args:
        write_result_blocks: One single-row statistics DataFrame per write
            task.

    Returns:
        A ``WriteResult`` whose ``num_rows`` and ``size_bytes`` are the totals
        across all blocks (as Python ints), and whose ``write_returns`` holds
        each block's ``write_return`` value in input order.

    Raises:
        AssertionError: If any element is not a single-row DataFrame (only
            when assertions are enabled).
    """
    import pandas as pd

    assert all(
        isinstance(block, pd.DataFrame) and len(block) == 1
        for block in write_result_blocks
    )

    # pandas reductions return NumPy scalars; cast so the totals match the
    # plain ``int`` fields declared on WriteResult.
    total_num_rows = int(
        sum(result["num_rows"].sum() for result in write_result_blocks)
    )
    total_size_bytes = int(
        sum(result["size_bytes"].sum() for result in write_result_blocks)
    )

    # ``.iloc[0]`` takes the single row by position. A plain ``[0]`` is a
    # label lookup, which fails (or is deprecated) when the index isn't 0-based.
    write_returns = [
        result["write_return"].iloc[0] for result in write_result_blocks
    ]
    return WriteResult(total_num_rows, total_size_bytes, write_returns)
