import logging
import os
import time

from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.framework import try_import_tf
from ray.util.debug import log_once

# TensorFlow handles obtained via RLlib's import helper. In this module `tf1`
# is used for the session-era classes (`RunOptions`, `RunMetadata`).
# NOTE(review): presumably `tf1` is the TF1-compat API, `tf` the installed
# TensorFlow module and `tfv` its major version -- confirm against
# `try_import_tf`.
tf1, tf, tfv = try_import_tf()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


@OldAPIStack
class _TFRunBuilder:
    """Used to incrementally build up a TensorFlow run.

    This is particularly useful for batching ops from multiple different
    policies in the multi-agent setting.

    Callers register feeds and fetches first, then call `get()`. The first
    `get()` executes everything in one `_run_timeline` call and caches the
    result; after that, no further feeds or fetches may be added.
    """

    def __init__(self, session, debug_name):
        """Create an empty builder.

        Args:
            session: Session object forwarded to `_run_timeline`.
            debug_name: Label forwarded to `_run_timeline` (used there in the
                timeline file name).
        """
        self.session = session
        self.debug_name = debug_name
        self.feed_dict = {}
        self.fetches = []
        # None until the run has been executed; afterwards holds the results
        # returned by `_run_timeline`.
        self._executed = None

    def add_feed_dict(self, feed_dict):
        """Merge `feed_dict` into the feeds of the pending run.

        Args:
            feed_dict: Mapping of feed keys to values.

        Raises:
            AssertionError: If the run has already been executed.
            ValueError: If any key is already registered. Nothing is merged
                in that case.
        """
        # Compare against None (as `get` does) rather than using truthiness:
        # an executed run with an empty result must still count as executed.
        assert self._executed is None, "Run was already executed."
        for k in feed_dict:
            if k in self.feed_dict:
                raise ValueError("Key added twice: {}".format(k))
        self.feed_dict.update(feed_dict)

    def add_fetches(self, fetches):
        """Append `fetches` to the pending run.

        Args:
            fetches: Iterable of fetches to append.

        Returns:
            List of integer indices identifying the newly added fetches;
            pass them (or lists/tuples of them) to `get()`.

        Raises:
            AssertionError: If the run has already been executed.
        """
        assert self._executed is None, "Run was already executed."
        base_index = len(self.fetches)
        self.fetches.extend(fetches)
        return list(range(base_index, len(self.fetches)))

    def get(self, to_fetch):
        """Return the result(s) for `to_fetch`, executing the run if needed.

        The run is executed at most once, on the first call. The timeline
        directory is read from the `TF_TIMELINE_DIR` environment variable.

        Args:
            to_fetch: An int index as returned by `add_fetches`, or an
                arbitrarily nested list/tuple of such indices.

        Returns:
            The fetched value for an int; a list or tuple of results mirroring
            the structure of `to_fetch` otherwise.

        Raises:
            ValueError: If `to_fetch` is not an int, list or tuple.
            Exception: Any error raised while executing the run is logged
                and re-raised unchanged.
        """
        if self._executed is None:
            try:
                self._executed = _run_timeline(
                    self.session,
                    self.fetches,
                    self.debug_name,
                    self.feed_dict,
                    os.environ.get("TF_TIMELINE_DIR"),
                )
            except Exception:
                logger.exception(
                    "Error fetching: {}, feed_dict={}".format(
                        self.fetches, self.feed_dict
                    )
                )
                # Bare raise re-raises the active exception as-is.
                raise
        if isinstance(to_fetch, int):
            return self._executed[to_fetch]
        elif isinstance(to_fetch, list):
            return [self.get(x) for x in to_fetch]
        elif isinstance(to_fetch, tuple):
            return tuple(self.get(x) for x in to_fetch)
        else:
            raise ValueError("Unsupported fetch type: {}".format(to_fetch))


# Process-wide counter of traced runs. `_run_timeline` puts `_count % 10` in
# the trace file name and then increments it, so trace files for a given
# debug name and pid rotate over ten slots and get overwritten.
_count = 0


def _run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None):
    """Run `ops` in `sess`, optionally writing a Chrome-trace timeline file.

    Args:
        sess: Object whose `run` method executes `ops`.
        ops: Fetches passed unchanged to `sess.run`.
        debug_name: Label embedded in the timeline file name.
        feed_dict: Feeds passed to `sess.run`; `None` is treated as `{}`.
        timeline_dir: Directory for timeline files. If falsy, the run is
            executed without tracing and nothing is written.

    Returns:
        Whatever `sess.run` returns for `ops`.
    """
    global _count

    if feed_dict is None:
        feed_dict = {}

    if not timeline_dir:
        if log_once("tf_timeline"):
            logger.info(
                "Executing TF run without tracing. To dump TF timeline traces "
                "to disk, set the TF_TIMELINE_DIR environment variable."
            )
        return sess.run(ops, feed_dict=feed_dict)

    # Imported lazily so TensorFlow's timeline module is only required when
    # tracing is actually requested.
    from tensorflow.python.client import timeline

    try:
        # Read the FULL_TRACE enum from the same class that is instantiated,
        # so the lookup does not depend on a top-level `tf.RunOptions` symbol.
        run_options = tf1.RunOptions(trace_level=tf1.RunOptions.FULL_TRACE)
    except AttributeError:
        run_options = None
        # In local mode, tf1.RunOptions is not available, see #26511
        if log_once("tf1.RunOptions_not_available"):
            logger.exception(
                "Can not access tf.RunOptions.FULL_TRACE. This may be because "
                "you have used `ray.init(local_mode=True)`. RLlib will use "
                "timeline without `options=tf.RunOptions.FULL_TRACE`."
            )
    run_metadata = tf1.RunMetadata()
    start = time.time()
    fetches = sess.run(
        ops, options=run_options, run_metadata=run_metadata, feed_dict=feed_dict
    )
    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
    # `_count % 10` rotates over ten file names per (debug_name, pid), so
    # older traces are overwritten instead of accumulating without bound.
    outf = os.path.join(
        timeline_dir,
        "timeline-{}-{}-{}.json".format(debug_name, os.getpid(), _count % 10),
    )
    _count += 1
    # The context manager guarantees the trace is flushed and the handle is
    # closed even if generating or writing the trace fails.
    with open(outf, "w") as trace_file:
        trace_file.write(trace.generate_chrome_trace_format())
    # Log only after the file has really been written; the reported time
    # covers the run plus generating and writing the trace.
    logger.info(
        "Wrote tf timeline ({} s) to {}".format(
            time.time() - start, os.path.abspath(outf)
        )
    )
    return fetches
