import glob
import json
import logging
import math
import os
import random
import re
import zipfile
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
from urllib.parse import urlparse

import numpy as np
import tree  # pip install dm_tree

try:
    from smart_open import smart_open
except ImportError:
    smart_open = None

from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID,
    MultiAgentBatch,
    SampleBatch,
    concat_samples,
    convert_ma_batch_to_sample_batch,
)
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI, override
from ray.rllib.utils.compression import unpack_if_needed
from ray.rllib.utils.spaces.space_utils import clip_action, normalize_action
from ray.rllib.utils.typing import Any, FileType, SampleBatchType

if TYPE_CHECKING:
    from ray.rllib.evaluation import RolloutWorker

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lowercase drive letters "c" through "z". Code in this module compares
# `urlparse(path).scheme` against `[""] + WINDOWS_DRIVES` to decide whether a
# path is a local file (no scheme, or a scheme that is really a Windows drive
# prefix) rather than a remote URI.
WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]


def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
    """Convert nested obs/action columns of a json record into np.ndarrays.

    For every column in `json_data` that holds actions or observations
    (either by its own name or via the `data_col` of the policy's matching
    view requirement), the nested lists/dicts are mapped onto the policy's
    action- or observation-space struct and each leaf is turned into an
    np.ndarray. Handing the raw nested lists to a SampleBatch constructor
    would confuse it.

    Action columns are only processed if the policy config has
    `_disable_action_flattening` set; observation columns only if it has
    `_disable_preprocessor_api` set. All other columns are left untouched.

    Args:
        json_data: Mapping from column name to column data. Modified in place.
        policy: The Policy providing the config, view requirements and
            space structs.

    Returns:
        The very same `json_data` dict that was passed in.
    """
    action_cols = (SampleBatch.ACTIONS, SampleBatch.PREV_ACTIONS)
    obs_cols = (SampleBatch.OBS, SampleBatch.NEXT_OBS)

    # Only existing keys get re-assigned, so a snapshot of the key list is
    # sufficient to drive the loop.
    for col in list(json_data):
        # A view requirement may redirect this column to another data column.
        if col in policy.view_requirements:
            source_col = policy.view_requirements[col].data_col
        else:
            source_col = ""

        if policy.config.get("_disable_action_flattening") and (
            col in action_cols or source_col in action_cols
        ):
            # No action flattening -> Process nested (leaf) action(s).
            space_struct = policy.action_space_struct
        elif policy.config.get("_disable_preprocessor_api") and (
            col in obs_cols or source_col in obs_cols
        ):
            # No preprocessing -> Process nested (leaf) observation(s).
            space_struct = policy.observation_space_struct
        else:
            continue

        json_data[col] = tree.map_structure_up_to(
            space_struct,
            np.array,
            json_data[col],
            check_types=False,
        )

    return json_data


@DeveloperAPI
def _adjust_dones(json_data: dict) -> dict:
    """Return a new dict in which the DONES key is renamed to TERMINATEDS.

    Args:
        json_data: Mapping from column name to column data. Not modified.

    Returns:
        A new dict holding the same values in the same order, where a
        `SampleBatch.DONES` key (if present) is stored under
        `SampleBatch.TERMINATEDS` instead. All other keys are kept as-is.
    """
    return {
        (SampleBatch.TERMINATEDS if key == SampleBatch.DONES else key): value
        for key, value in json_data.items()
    }


@DeveloperAPI
def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
    """Clip and/or normalize the actions in `batch`, depending on the config.

    If `ioctx.config["clip_actions"]` is truthy, the ACTIONS column(s) are
    clipped via `clip_action`. If `actions_in_input_normalized` is False and
    `normalize_actions` is True, the ACTIONS column(s) are normalized via
    `normalize_action`. In both cases the policy's `action_space_struct` is
    used and the ACTIONS column is overwritten on the given batch.

    For a SampleBatch, the policy stored under DEFAULT_POLICY_ID is used (or
    the only policy in the worker's policy map). For any other batch type,
    each entry of `batch.policy_batches` is processed with the policy of the
    same ID.

    Args:
        batch: The SampleBatch or MultiAgentBatch to process.
        ioctx: The IOContext providing `config` and `worker`.

    Returns:
        The same `batch` object, with ACTIONS possibly overwritten.

    Raises:
        ValueError: If clipping or normalizing is requested but
            `ioctx.worker` is None, or if normalization is requested for a
            tuple/dict action space struct while
            `_disable_action_flattening` is not set on the policy.
    """
    cfg = ioctx.config

    # Clip actions (from any values into env's bounds), if necessary.
    # TODO(jungong): We should not clip_action in input reader.
    #  Use connector to handle this.
    if cfg.get("clip_actions"):
        if ioctx.worker is None:
            raise ValueError(
                "clip_actions is True but cannot clip actions since no workers exist"
            )

        if isinstance(batch, SampleBatch):
            policy = _single_agent_policy(ioctx.worker.policy_map)
            batch[SampleBatch.ACTIONS] = clip_action(
                batch[SampleBatch.ACTIONS], policy.action_space_struct
            )
        else:
            for pid, b in batch.policy_batches.items():
                b[SampleBatch.ACTIONS] = clip_action(
                    b[SampleBatch.ACTIONS],
                    ioctx.worker.policy_map[pid].action_space_struct,
                )

    # Re-normalize actions (from env's bounds to zero-centered), if
    # necessary.
    if (
        cfg.get("actions_in_input_normalized") is False
        and cfg.get("normalize_actions") is True
    ):
        if ioctx.worker is None:
            # Note the trailing space after "but": the two literals are
            # concatenated into a single message.
            raise ValueError(
                "actions_in_input_normalized is False but "
                "cannot normalize actions since no workers exist"
            )

        # If we have a complex action space and actions were flattened
        # and we have to normalize -> Error.
        error_msg = (
            "Normalization of offline actions that are flattened is not "
            "supported! Make sure that you record actions into offline "
            "file with the `_disable_action_flattening=True` flag OR "
            "as already normalized (between -1.0 and 1.0) values. "
            "Also, when reading already normalized action values from "
            "offline files, make sure to set "
            "`actions_in_input_normalized=True` so that RLlib will not "
            "perform normalization on top."
        )

        if isinstance(batch, SampleBatch):
            policy = _single_agent_policy(ioctx.worker.policy_map)
            if isinstance(
                policy.action_space_struct, (tuple, dict)
            ) and not policy.config.get("_disable_action_flattening"):
                raise ValueError(error_msg)
            batch[SampleBatch.ACTIONS] = normalize_action(
                batch[SampleBatch.ACTIONS], policy.action_space_struct
            )
        else:
            for pid, b in batch.policy_batches.items():
                policy = ioctx.worker.policy_map[pid]
                if isinstance(
                    policy.action_space_struct, (tuple, dict)
                ) and not policy.config.get("_disable_action_flattening"):
                    raise ValueError(error_msg)
                b[SampleBatch.ACTIONS] = normalize_action(
                    b[SampleBatch.ACTIONS], policy.action_space_struct
                )

    return batch


def _single_agent_policy(policy_map):
    """Return the policy to use for a single-agent SampleBatch.

    Prefers the entry stored under DEFAULT_POLICY_ID; if there is none, the
    map is expected to hold exactly one policy, which is returned.
    """
    policy = policy_map.get(DEFAULT_POLICY_ID)
    if policy is None:
        assert len(policy_map) == 1
        policy = next(iter(policy_map.values()))
    return policy


@DeveloperAPI
def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
    """Build a SampleBatch or MultiAgentBatch from a decoded json record.

    The record's "type" field (removed from `json_data` in the process)
    selects the batch class. Column values are passed through
    `unpack_if_needed`, nested obs/action columns are adjusted to the
    respective policy's spaces if a `worker` is given, and DONES columns are
    translated into TERMINATEDS.

    Args:
        json_data: The decoded json record. Modified in place.
        worker: The RolloutWorker whose `policy_map` is consulted, or None.

    Returns:
        A SampleBatch or a MultiAgentBatch, depending on the "type" field.

    Raises:
        ValueError: If the "type" field is missing or has an unsupported
            value, or if a single-agent record is found while the worker's
            policy map does not hold exactly one policy.
    """
    # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
    if "type" not in json_data:
        raise ValueError("JSON record missing 'type' field")
    data_type = json_data.pop("type")

    if data_type == "SampleBatch":
        if worker is not None and len(worker.policy_map) != 1:
            raise ValueError(
                "Found single-agent SampleBatch in input file, but our "
                "PolicyMap contains more than 1 policy!"
            )
        # Unpack every column in place.
        for column in list(json_data):
            json_data[column] = unpack_if_needed(json_data[column])
        if worker is not None:
            only_policy = next(iter(worker.policy_map.values()))
            json_data = _adjust_obs_actions_for_policy(json_data, only_policy)
        return SampleBatch(_adjust_dones(json_data))

    if data_type == "MultiAgentBatch":
        per_policy = {}
        for pid, raw_batch in json_data["policy_batches"].items():
            # Unpack all columns, translating DONES into TERMINATEDS.
            columns = {
                (
                    SampleBatch.TERMINATEDS if name == SampleBatch.DONES else name
                ): unpack_if_needed(data)
                for name, data in raw_batch.items()
            }
            if worker is not None:
                columns = _adjust_obs_actions_for_policy(
                    columns, worker.policy_map[pid]
                )
            per_policy[pid] = SampleBatch(_adjust_dones(columns))
        return MultiAgentBatch(per_policy, json_data["count"])

    raise ValueError(
        "Type field must be one of ['SampleBatch', 'MultiAgentBatch']", data_type
    )


# TODO(jungong) : use DatasetReader to back JsonReader, so we reduce
#     codebase complexity without losing existing functionality.
@PublicAPI
class JsonReader(InputReader):
    """Reader object that loads experiences from JSON file chunks.

    The input files will be read from in random order.
    """

    @PublicAPI
    def __init__(
        self, inputs: Union[str, List[str]], ioctx: Optional[IOContext] = None
    ):
        """Initializes a JsonReader instance.

        Args:
            inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`,
                or a list of single file paths or URIs, e.g.,
                ["s3://bucket/file.json", "s3://bucket/file2.json"].
            ioctx: Current IO context object or None.

        Raises:
            ValueError: If `inputs` is neither a str nor a list/tuple, if a
                string input cannot be globbed, or if no input files are
                found.
        """
        logger.info(
            "You are using JSONReader. It is recommended to use "
            + "DatasetReader instead for better sharding support."
        )

        self.ioctx = ioctx or IOContext()
        self.default_policy = self.policy_map = None
        self.batch_size = 1
        if self.ioctx:
            self.batch_size = self.ioctx.config.get("train_batch_size", 1)
            num_workers = self.ioctx.config.get("num_env_runners", 0)
            # Each reader only has to deliver its share of the train batch.
            if num_workers:
                self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)

        if self.ioctx.worker is not None:
            self.policy_map = self.ioctx.worker.policy_map
            self.default_policy = self.policy_map.get(DEFAULT_POLICY_ID)
            # No policy under the default ID -> There must be exactly one
            # policy, which then serves as the default.
            if self.default_policy is None:
                assert len(self.policy_map) == 1
                self.default_policy = next(iter(self.policy_map.values()))

        if isinstance(inputs, str):
            inputs = os.path.abspath(os.path.expanduser(inputs))
            if os.path.isdir(inputs):
                inputs = [os.path.join(inputs, "*.json"), os.path.join(inputs, "*.zip")]
                logger.warning(f"Treating input directory as glob patterns: {inputs}")
            else:
                inputs = [inputs]

            # NOTE(review): `inputs` has already gone through
            # `os.path.abspath` above, so a URI-like string presumably no
            # longer exposes its scheme to `urlparse` here -- confirm whether
            # this check can still trigger for e.g. "s3://..." inputs.
            if any(urlparse(i).scheme not in [""] + WINDOWS_DRIVES for i in inputs):
                raise ValueError(
                    "Don't know how to glob over `{}`, ".format(inputs)
                    + "please specify a list of files to read instead."
                )
            else:
                self.files = []
                for i in inputs:
                    self.files.extend(glob.glob(i))
        elif isinstance(inputs, (list, tuple)):
            # Lists/tuples are taken verbatim (no globbing).
            self.files = list(inputs)
        else:
            raise ValueError(
                "type of inputs must be list or str, not {}".format(inputs)
            )
        if self.files:
            logger.info("Found {} input files.".format(len(self.files)))
        else:
            raise ValueError("No files found matching {}".format(inputs))
        # The currently open file object (None until the first read).
        self.cur_file = None

    @override(InputReader)
    def next(self) -> SampleBatchType:
        """Returns the next batch, holding at least `self.batch_size` steps.

        Keeps parsing lines (retrying up to 100 times on lines that do not
        yield a batch) and concatenates the resulting batches until their
        combined `count` reaches `self.batch_size`.

        Raises:
            ValueError: If no valid batch could be read after the retries.
        """
        ret = []
        count = 0
        while count < self.batch_size:
            batch = self._try_parse(self._next_line())
            tries = 0
            while not batch and tries < 100:
                tries += 1
                logger.debug("Skipping empty line in {}".format(self.cur_file))
                batch = self._try_parse(self._next_line())
            if not batch:
                raise ValueError(
                    "Failed to read valid experience batch from file: {}".format(
                        self.cur_file
                    )
                )
            batch = self._postprocess_if_needed(batch)
            count += batch.count
            ret.append(batch)
        ret = concat_samples(ret)
        return ret

    def read_all_files(self) -> SampleBatchType:
        """Reads through all files once and yields one batch per line.

        This is a generator. It walks over `self.files` in order, a single
        time, and ends after the last file. Each file is closed as soon as
        it has been consumed (or when the generator gets closed early).

        Yields:
            One SampleBatch or MultiAgentBatch per line in all input files.
        """
        for path in self.files:
            file = self._try_open_file(path)
            try:
                while True:
                    line = file.readline()
                    if not line:
                        break
                    batch = self._try_parse(line)
                    # NOTE(review): An empty or corrupt line ends reading of
                    # the current file (rest of the file is dropped) --
                    # confirm that skipping only that line is not intended.
                    if batch is None:
                        break
                    yield batch
            finally:
                # Do not leak the file handle (legacy smart_open impls may
                # not offer `close`).
                if hasattr(file, "close"):
                    file.close()

    def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
        """Runs the default policy's trajectory postprocessing, if configured.

        Does nothing unless `postprocess_inputs` is set in the IO context's
        config. Otherwise, splits the batch by episode, postprocesses each
        episode with `self.default_policy` and concatenates the results.

        Raises:
            NotImplementedError: If the batch is still not a SampleBatch
                after `convert_ma_batch_to_sample_batch`.
        """
        if not self.ioctx.config.get("postprocess_inputs"):
            return batch

        batch = convert_ma_batch_to_sample_batch(batch)

        if isinstance(batch, SampleBatch):
            out = []
            for sub_batch in batch.split_by_episode():
                out.append(self.default_policy.postprocess_trajectory(sub_batch))
            return concat_samples(out)
        else:
            # TODO(ekl) this is trickier since the alignments between agent
            #  trajectories in the episode are not available any more.
            raise NotImplementedError(
                "Postprocessing of multi-agent data not implemented yet."
            )

    def _try_open_file(self, path):
        """Opens `path` for reading and returns the file object.

        Paths with a URI scheme are opened via `smart_open`. Local paths are
        resolved ("~/" shortcut, fallback relative to the grandparent
        directory of this module), `.zip` files are extracted next to the
        archive and the reader is re-pointed to the `.json` file of the same
        name.

        Raises:
            ValueError: If a URI is given but `smart_open` is not installed.
            FileNotFoundError: If a local file cannot be found.
        """
        if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
            if smart_open is None:
                raise ValueError(
                    "You must install the `smart_open` module to read "
                    "from URIs like {}".format(path)
                )
            ctx = smart_open
        else:
            # Allow shortcut for home directory ("~/" -> env[HOME]).
            if path.startswith("~/"):
                path = os.path.join(os.environ.get("HOME", ""), path[2:])

            # If path doesn't exist, try to interpret is as relative to the
            # rllib directory (located ../../ from this very module).
            path_orig = path
            if not os.path.exists(path):
                path = os.path.join(Path(__file__).parent.parent, path)
            if not os.path.exists(path):
                raise FileNotFoundError(f"Offline file {path_orig} not found!")

            # Unzip files, if necessary and re-point to extracted json file.
            if re.search("\\.zip$", path):
                with zipfile.ZipFile(path, "r") as zip_ref:
                    zip_ref.extractall(Path(path).parent)
                path = re.sub("\\.zip$", ".json", path)
                assert os.path.exists(path)
            ctx = open
        file = ctx(path, "r")
        return file

    def _try_parse(self, line: str) -> Optional[SampleBatchType]:
        """Parses one line into a batch; returns None if that is not possible.

        Empty (whitespace-only) lines yield None. Lines that fail to parse
        are logged and yield None as well. Successfully parsed batches are
        passed through `postprocess_actions` before being returned.
        """
        line = line.strip()
        if not line:
            return None
        try:
            batch = self._from_json(line)
        except Exception:
            logger.exception(
                "Ignoring corrupt json record in {}: {}".format(self.cur_file, line)
            )
            return None

        batch = postprocess_actions(batch, self.ioctx)

        return batch

    def _next_line(self) -> str:
        """Returns the next line, moving on to other files when one runs out.

        If the current file has no more lines, it is closed and another file
        is opened; this is retried up to 100 times.

        Raises:
            ValueError: If no line could be read after the retries.
        """
        if not self.cur_file:
            self.cur_file = self._next_file()
        line = self.cur_file.readline()
        tries = 0
        while not line and tries < 100:
            tries += 1
            if hasattr(self.cur_file, "close"):  # legacy smart_open impls
                self.cur_file.close()
            self.cur_file = self._next_file()
            line = self.cur_file.readline()
            if not line:
                logger.debug("Ignoring empty file {}".format(self.cur_file))
        if not line:
            raise ValueError(
                "Failed to read next line from files: {}".format(self.files)
            )
        return line

    def _next_file(self) -> FileType:
        """Picks the next input file and returns it opened for reading."""
        # If this is the first time, we open a file, make sure all workers
        # start with a different one if possible.
        if self.cur_file is None and self.ioctx.worker is not None:
            idx = self.ioctx.worker.worker_index
            total = self.ioctx.worker.num_workers or 1
            path = self.files[round((len(self.files) - 1) * (idx / total))]
        # After the first file, pick all others randomly.
        else:
            path = random.choice(self.files)
        return self._try_open_file(path)

    def _from_json(self, data: str) -> SampleBatchType:
        """Decodes a json string (or utf-8 bytes) into a batch object."""
        if isinstance(data, bytes):  # smart_open S3 doesn't respect "r"
            data = data.decode("utf-8")
        json_data = json.loads(data)
        return from_json_data(json_data, self.ioctx.worker)