import logging
import os
import time
from typing import Dict, List

from ray import data
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_writer import _to_json_dict
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import PublicAPI, override
from ray.rllib.utils.typing import SampleBatchType

logger = logging.getLogger(__name__)


@PublicAPI
class DatasetWriter(OutputWriter):
    """Writer object that saves experiences using Datasets.

    Incoming sample batches are converted to JSON-serializable dicts and
    buffered in memory. Once the buffer holds ``max_num_samples_per_file``
    entries, it is written out as a single-block Ray Dataset in the
    configured format ("json" or "parquet") and the buffer is cleared.
    """

    @PublicAPI
    def __init__(
        self,
        ioctx: IOContext = None,
        compress_columns: List[str] = frozenset(["obs", "new_obs"]),
    ):
        """Initializes a DatasetWriter instance.

        Examples:
            config = {
                "output": "dataset",
                "output_config": {
                    "format": "json",
                    "path": "/tmp/test_samples/",
                    "max_num_samples_per_file": 100000,
                }
            }

        Args:
            ioctx: Current IO context object. If None, a default
                ``IOContext()`` is created and used.
            compress_columns: Sample batch columns to compress before
                writing.

        Raises:
            AssertionError: If ``output_config`` does not specify "format"
                or "path".
        """
        self.ioctx = ioctx or IOContext()

        # Read from self.ioctx (not the raw `ioctx` argument) so the default
        # IOContext is honored when the caller passes None. Fall back to an
        # empty dict so a missing config hits the clear assertions below.
        output_config: Dict = self.ioctx.output_config or {}
        assert (
            "format" in output_config
        ), "output_config.format must be specified when using Dataset output."
        assert (
            "path" in output_config
        ), "output_config.path must be specified when using Dataset output."

        self.format = output_config["format"]
        self.path = os.path.abspath(os.path.expanduser(output_config["path"]))
        # Threshold on the number of buffered entries; one entry is appended
        # per `write()` call (i.e. per sample batch, not per timestep).
        self.max_num_samples_per_file = output_config.get(
            "max_num_samples_per_file", 100000
        )
        self.compress_columns = compress_columns

        # Buffer of JSON-serializable sample batch dicts awaiting a write.
        self.samples = []

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        """Buffers `sample_batch` and writes the buffer out once it is full.

        Args:
            sample_batch: The batch to record.

        Raises:
            ValueError: If the buffer is full and ``self.format`` is neither
                "json" nor "parquet".
        """
        start = time.time()

        # Make sure columns like obs are compressed and writable.
        d = _to_json_dict(sample_batch, self.compress_columns)
        self.samples.append(d)

        # TODO: We should flush at the end of sampling even if this
        #  condition was not reached.
        if len(self.samples) >= self.max_num_samples_per_file:
            self._write_buffer()
            logger.debug("Wrote dataset in %ss", time.time() - start)

    def _write_buffer(self):
        """Writes all buffered samples as one Dataset block, then clears them."""
        ds = data.from_items(self.samples).repartition(num_blocks=1, shuffle=False)
        if self.format == "json":
            ds.write_json(self.path, try_create_dir=True)
        elif self.format == "parquet":
            ds.write_parquet(self.path, try_create_dir=True)
        else:
            raise ValueError(f"Unknown output type: {self.format}")
        # Cleared only after a successful write, so a failed write keeps the
        # buffered samples.
        self.samples = []
