import json
import logging
import os
import posixpath
import time
from datetime import datetime
from typing import Any, Dict, List
from urllib.parse import urlparse

import numpy as np

try:
    from smart_open import smart_open
except ImportError:
    smart_open = None

from ray.air._internal.json import SafeFallbackEncoder
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI, override
from ray.rllib.utils.compression import compression_supported, pack
from ray.rllib.utils.typing import FileType, SampleBatchType

logger = logging.getLogger(__name__)

# Lower-case drive letters "c" through "z". `urlparse` reports the drive of a
# Windows path such as "C:\\out" as a one-letter, lower-cased scheme, so
# JsonWriter treats these schemes as local paths rather than remote URIs.
# NOTE(review): drive letters "a" and "b" are not included — confirm intended.
WINDOWS_DRIVES = list(map(chr, range(ord("c"), ord("z") + 1)))


# TODO(jungong): use DatasetWriter to back JsonWriter, so we reduce codebase complexity
#  without losing existing functionality.
@PublicAPI
class JsonWriter(OutputWriter):
    """Writer object that saves experiences in JSON file chunks.

    Each `write` call serializes one batch to a single JSON line and appends
    it to the current output file. Files are named
    ``output-<timestamp>_worker-<worker_index>_<file_index>.json``. A new file
    is started once ``max_file_size`` bytes have been written to the current
    one. Call `close` when done so the last file gets closed. Buffered data,
    e.g. for remote `smart_open` targets, may not be persisted until the file
    is closed.
    """

    @PublicAPI
    def __init__(
        self,
        path: str,
        ioctx: IOContext = None,
        max_file_size: int = 64 * 1024 * 1024,
        compress_columns: List[str] = frozenset(["obs", "new_obs"]),
    ):
        """Initializes a JsonWriter instance.

        Args:
            path: a path/URI of the output directory to save files in. Local
                paths (including Windows drive paths) are user-expanded, made
                absolute and created if missing. A path with any other URL
                scheme is treated as a URI and written via `smart_open`.
            ioctx: current IO context object. A default `IOContext()` is
                used when omitted.
            max_file_size: max size of single files before rolling over.
            compress_columns: collection of sample batch columns to compress
                (any container supporting `in` works).
        """
        logger.info(
            "You are using JSONWriter. It is recommended to use "
            + "DatasetWriter instead."
        )

        self.ioctx = ioctx or IOContext()
        self.max_file_size = max_file_size
        self.compress_columns = compress_columns
        # `urlparse` reports a Windows drive letter (e.g. "C:\\out") as a
        # one-letter scheme, so those are treated as local paths as well.
        if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
            self.path_is_uri = True
        else:
            path = os.path.abspath(os.path.expanduser(path))
            # Try to create local dirs if they don't exist
            os.makedirs(path, exist_ok=True)
            assert os.path.exists(path), "Failed to create {}".format(path)
            self.path_is_uri = False
        self.path = path
        self.file_index = 0
        self.bytes_written = 0
        self.cur_file = None

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        """Appends `sample_batch` as one JSON line to the current output file.

        Args:
            sample_batch: the batch to save. Columns listed in
                `compress_columns` are compressed before serialization.
        """
        start = time.time()
        data = _to_json(sample_batch, self.compress_columns)
        f = self._get_file()
        f.write(data)
        f.write("\n")
        if hasattr(f, "flush"):  # legacy smart_open impls
            f.flush()
        # Include the trailing newline so rollover tracks the real file size.
        self.bytes_written += len(data) + 1
        logger.debug(
            "Wrote %s bytes to %s in %ss", len(data), f, time.time() - start
        )

    def close(self) -> None:
        """Closes the current output file, if one is open.

        Safe to call more than once. A later `write` starts a new file.
        """
        if self.cur_file is not None:
            self.cur_file.close()
            self.cur_file = None

    def _get_file(self) -> FileType:
        """Returns the file to write to, rolling over to a new one if needed.

        A new file is opened on first use, after `close`, and whenever
        `bytes_written` has reached `max_file_size`. Any previously open file
        is closed first.

        Raises:
            ValueError: if the output path is a URI and `smart_open` is not
                installed.
        """
        if self.cur_file is not None and self.bytes_written < self.max_file_size:
            return self.cur_file

        self.close()
        timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
        filename = "output-{}_worker-{}_{}.json".format(
            timestr, self.ioctx.worker_index, self.file_index
        )
        if self.path_is_uri:
            # URIs always use "/" separators; `os.path.join` would insert
            # "\\" when running on Windows.
            path = posixpath.join(self.path, filename)
            if smart_open is None:
                raise ValueError(
                    "You must install the `smart_open` module to write "
                    "to URIs like {}".format(path)
                )
            self.cur_file = smart_open(path, "w")
        else:
            path = os.path.join(self.path, filename)
            self.cur_file = open(path, "w", encoding="utf-8")
        self.file_index += 1
        self.bytes_written = 0
        logger.info("Writing to new output file %s", self.cur_file)
        return self.cur_file


def _to_jsonable(v, compress: bool) -> Any:
    if compress and compression_supported():
        return str(pack(v))
    elif isinstance(v, np.ndarray):
        return v.tolist()

    return v


def _to_json_dict(batch: SampleBatchType, compress_columns: List[str]) -> Dict:
    """Builds the JSON-serializable dict representation of a batch.

    Args:
        batch: a `MultiAgentBatch` or a single-agent sample batch.
        compress_columns: column names whose values should be compressed.

    Returns:
        For a `MultiAgentBatch`: ``{"type": "MultiAgentBatch", "count": ...,
        "policy_batches": {policy_id: {column: value}}}``. Otherwise
        ``{"type": "SampleBatch", column: value, ...}``. Every column value
        is converted with `_to_jsonable`.
    """

    def _encode_columns(columns) -> Dict:
        # Convert every column, compressing the ones listed by the caller.
        return {
            col: _to_jsonable(val, compress=col in compress_columns)
            for col, val in columns.items()
        }

    if not isinstance(batch, MultiAgentBatch):
        return {"type": "SampleBatch", **_encode_columns(batch)}

    return {
        "type": "MultiAgentBatch",
        "count": batch.count,
        "policy_batches": {
            pid: _encode_columns(sub_batch)
            for pid, sub_batch in batch.policy_batches.items()
        },
    }


def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
    """Serializes `batch` to a JSON string (no indentation, so one line).

    Values the standard encoder cannot handle are delegated to
    `SafeFallbackEncoder`.
    """
    return json.dumps(
        _to_json_dict(batch, compress_columns), cls=SafeFallbackEncoder
    )
