import collections
import copy
import logging
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import (
    Any,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Union,
)
from uuid import uuid4

import numpy as np

import ray
from ray.actor import ActorHandle
from ray.data._internal.block_list import BlockList
from ray.data._internal.execution.dataset_state import DatasetState
from ray.data._internal.execution.interfaces.common import RuntimeMetricsHistogram
from ray.data._internal.execution.interfaces.op_runtime_metrics import (
    NODE_UNKNOWN,
    MetricsGroup,
    MetricsType,
    NodeMetrics,
    OpRuntimeMetrics,
)
from ray.data._internal.metadata_exporter import (
    DatasetMetadata,
    Topology,
    get_dataset_metadata_exporter,
)
from ray.data._internal.util import capfirst
from ray.data.block import BlockStats
from ray.data.context import DataContext
from ray.util.annotations import DeveloperAPI
from ray.util.metrics import Counter, Gauge, Histogram, Metric
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

# Module-level logger.
logger = logging.getLogger(__name__)

# Name and namespace under which the cluster-wide `_StatsActor` is registered
# (looked up / created by `get_or_create_stats_actor`).
STATS_ACTOR_NAME = "datasets_stats_actor"
STATS_ACTOR_NAMESPACE = "_dataset_stats_actor"
# Placeholder string for unknown values; its users are not visible in this
# part of the file.
UNKNOWN = "unknown"


# Block stats grouped by (sub-)operator name.
StatsDict = Dict[str, List[BlockStats]]


def fmt(seconds: float) -> str:
    """Render a duration given in seconds as a short human-readable string.

    Durations strictly above one second are shown in seconds ("s"), those
    strictly above one millisecond in milliseconds ("ms"), and everything
    else in microseconds ("us"). The number is rounded to two decimals.
    """
    # (strict lower bound, multiplier, unit suffix), checked in order.
    for lower_bound, multiplier, unit in ((1, 1, "s"), (0.001, 1000, "ms")):
        if seconds > lower_bound:
            return str(round(seconds * multiplier, 2)) + unit
    return str(round(seconds * 1000 * 1000, 2)) + "us"


def leveled_indent(lvl: int = 0, spaces_per_indent: int = 3) -> str:
    """Return the whitespace prefix for nesting depth `lvl`.

    The prefix consists of `lvl` indents, each made of `spaces_per_indent`
    spaces. For example:
    >>> leveled_indent(2, 3)
    '      '
    """
    one_indent = " " * spaces_per_indent
    return "".join(one_indent for _ in range(lvl))


class Timer:
    """Accumulate durations (in seconds) and track their min, max and mean.

    Samples are recorded either explicitly with `add()` or by wrapping a code
    region in `with t.timer(): ...`.
    """

    def __init__(self):
        self._total: float = 0
        self._min: float = float("inf")
        self._max: float = 0
        # Number of recorded samples (a count, hence int).
        self._total_count: int = 0

    @contextmanager
    def timer(self) -> Iterator[None]:
        """Measure the enclosed block with `time.perf_counter` and `add()` it.

        The elapsed time is recorded even if the block raises.
        """
        time_start = time.perf_counter()
        try:
            yield
        finally:
            self.add(time.perf_counter() - time_start)

    def add(self, value: float) -> None:
        """Record one sample of `value` seconds."""
        self._total += value
        if value < self._min:
            self._min = value
        if value > self._max:
            self._max = value
        self._total_count += 1

    def get(self) -> float:
        """Return the sum of all recorded samples (0 if none)."""
        return self._total

    def min(self) -> float:
        """Return the smallest recorded sample, or ``inf`` if there are none."""
        return self._min

    def max(self) -> float:
        """Return the largest recorded sample, or 0 if there are none."""
        return self._max

    def avg(self) -> float:
        """Return the mean recorded sample, or ``inf`` if there are none."""
        return self._total / self._total_count if self._total_count else float("inf")


class _DatasetStatsBuilder:
    """Helper class for building dataset stats.

    The start time is taken when this builder is created, unless
    `override_start_time` is given. When `build()` or `build_multioperator()`
    is called, the time elapsed since then is stored as the resulting
    stats' `time_total_s`."""

    def __init__(
        self,
        operator_name: str,
        parent: "DatasetStats",
        override_start_time: Optional[float],
    ):
        self.operator_name = operator_name
        self.parent = parent
        # Compare against None explicitly: a `perf_counter()` reading of 0.0
        # is a valid override and must not be replaced by the current time.
        self.start_time = (
            override_start_time
            if override_start_time is not None
            else time.perf_counter()
        )

    def build_multioperator(self, metadata: StatsDict) -> "DatasetStats":
        """Build stats from several named stats entries of one operator.

        With a single entry, it is stored under `operator_name`. With several,
        the first is stored under `operator_name + capfirst(key)` and the
        others under `<last "->" segment of operator_name> + capfirst(key)`.
        """
        has_multiple = len(metadata) > 1
        # Loop-invariant: last segment of a fused name such as "A->B".
        last_op_name = self.operator_name.split("->")[-1]
        op_metadata = {}
        for i, (k, v) in enumerate(metadata.items()):
            capped_k = capfirst(k)
            if not has_multiple:
                op_metadata[self.operator_name] = v
            elif i == 0:
                op_metadata[self.operator_name + capped_k] = v
            else:
                op_metadata[last_op_name + capped_k] = v
        stats = DatasetStats(
            metadata=op_metadata,
            parent=self.parent,
            base_name=self.operator_name,
        )
        stats.time_total_s = time.perf_counter() - self.start_time
        return stats

    def build(self, final_blocks: BlockList) -> "DatasetStats":
        """Build stats for `final_blocks`, keyed by this builder's operator name."""
        stats = DatasetStats(
            metadata={self.operator_name: final_blocks.get_metadata()},
            parent=self.parent,
        )
        stats.time_total_s = time.perf_counter() - self.start_time
        return stats


@ray.remote(num_cpus=0)
class _StatsActor:
    """Actor holding dataset state and Ray Data Prometheus metrics.

    This actor is shared across all datasets created in the same cluster.
    In order to cap memory usage, at most `max_stats` datasets are kept in
    `self.datasets`; once the limit is exceeded, the oldest finished (or
    failed) datasets are evicted in FIFO order.

    TODO(ekl) we should consider refactoring LazyBlockList so stats can be
    extracted without using an out-of-band actor."""

    def __init__(self, max_stats=1000):
        # Mapping from uuid -> (task_id -> list of blocks statistics).
        self.metadata = collections.defaultdict(dict)
        self.last_time = {}
        self.start_time = {}
        self.max_stats = max_stats

        # Assign dataset uuids with a global counter.
        self.next_dataset_id = 0
        # Dataset metadata to be queried directly by DashboardHead api.
        self.datasets: Dict[str, Any] = {}

        # Cache of node_id -> NodeName built from ray.nodes(), so that the
        # cluster-wide query is not issued on every per-node metrics update.
        self._ray_nodes_cache: Dict[str, str] = {}

        # Initialize the metadata exporter
        self._metadata_exporter = get_dataset_metadata_exporter()
        self.dataset_metadatas: Dict[str, DatasetMetadata] = {}

        # A FIFO queue of dataset_tags for finished datasets. This is used to
        # efficiently evict the oldest finished datasets when max_stats is reached.
        self.finished_datasets_queue = collections.deque()

        # Ray Data dashboard metrics
        # Everything is a gauge because we need to reset all of
        # a dataset's metrics to 0 after each finishes execution.
        op_tags_keys = ("dataset", "operator")

        # TODO(scottjlee): move these overview metrics as fields in a
        # separate dataclass, similar to OpRuntimeMetrics.
        self.spilled_bytes = Gauge(
            "data_spilled_bytes",
            description="""Bytes spilled by dataset operators.
                DataContext.enable_get_object_locations_for_metrics
                must be set to True to report this metric""",
            tag_keys=op_tags_keys,
        )
        self.freed_bytes = Gauge(
            "data_freed_bytes",
            description="Bytes freed by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.current_bytes = Gauge(
            "data_current_bytes",
            description="Bytes currently in memory store used by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.cpu_usage_cores = Gauge(
            "data_cpu_usage_cores",
            description="CPUs allocated to dataset operators",
            tag_keys=op_tags_keys,
        )
        self.gpu_usage_cores = Gauge(
            "data_gpu_usage_cores",
            description="GPUs allocated to dataset operators",
            tag_keys=op_tags_keys,
        )
        self.output_bytes = Gauge(
            "data_output_bytes",
            description="Bytes outputted by dataset operators",
            tag_keys=op_tags_keys,
        )
        self.output_rows = Gauge(
            "data_output_rows",
            description="Rows outputted by dataset operators",
            tag_keys=op_tags_keys,
        )

        # === Metrics from OpRuntimeMetrics ===
        # Inputs-related metrics
        self.execution_metrics_inputs = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.INPUTS,
                tag_keys=op_tags_keys,
            )
        )

        # Outputs-related metrics
        self.execution_metrics_outputs = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.OUTPUTS,
                tag_keys=op_tags_keys,
            )
        )

        # Task-related metrics
        self.execution_metrics_tasks = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.TASKS,
                tag_keys=op_tags_keys,
            )
        )

        # Object store memory-related metrics
        self.execution_metrics_obj_store_memory = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.OBJECT_STORE_MEMORY,
                tag_keys=op_tags_keys,
            )
        )

        # Actor related metrics
        self.execution_metrics_actors = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.ACTORS,
                tag_keys=op_tags_keys,
            )
        )

        # Miscellaneous metrics
        self.execution_metrics_misc = (
            self._create_prometheus_metrics_for_execution_metrics(
                metrics_group=MetricsGroup.MISC,
                tag_keys=op_tags_keys,
            )
        )

        # Per Node metrics
        self.per_node_metrics = self._create_prometheus_metrics_for_per_node_metrics()

        iter_tag_keys = ("dataset",)

        self.time_to_first_batch_s = Gauge(
            "data_iter_time_to_first_batch_seconds",
            description="Total time spent waiting for the first batch after starting iteration. "
            "This includes the dataset pipeline warmup time. This metric is accumulated across different epochs.",
            tag_keys=iter_tag_keys,
        )

        self.iter_block_fetching_s = Gauge(
            "data_iter_block_fetching_seconds",
            description="Seconds taken to fetch (with ray.get) blocks by iter_batches()",
            tag_keys=iter_tag_keys,
        )
        self.iter_batch_shaping_s = Gauge(
            "data_iter_batch_shaping_seconds",
            description="Seconds taken to shape batch from incoming blocks by iter_batches()",
            tag_keys=iter_tag_keys,
        )
        self.iter_batch_formatting_s = Gauge(
            "data_iter_batch_formatting_seconds",
            description="Seconds taken to format batches by iter_batches()",
            tag_keys=iter_tag_keys,
        )
        self.iter_batch_collating_s = Gauge(
            "data_iter_batch_collating_seconds",
            description="Seconds taken to collate batches by iter_batches()",
            tag_keys=iter_tag_keys,
        )
        self.iter_batch_finalizing_s = Gauge(
            "data_iter_batch_finalizing_seconds",
            description="Seconds taken to finalize batches by iter_batches()",
            tag_keys=iter_tag_keys,
        )

        self.iter_total_blocked_s = Gauge(
            "data_iter_total_blocked_seconds",
            description="Seconds user thread is blocked by iter_batches()",
            tag_keys=iter_tag_keys,
        )
        self.iter_user_s = Gauge(
            "data_iter_user_seconds",
            description="Seconds spent in user code",
            tag_keys=iter_tag_keys,
        )
        self.iter_initialize_s = Gauge(
            "data_iter_initialize_seconds",
            description="Seconds spent in iterator initialization code",
            tag_keys=iter_tag_keys,
        )
        self.iter_get_ref_bundles_s = Gauge(
            "data_iter_get_ref_bundles_seconds",
            description="Seconds spent getting RefBundles from the dataset iterator",
            tag_keys=iter_tag_keys,
        )
        self.iter_get_s = Gauge(
            "data_iter_get_seconds",
            description="Seconds spent in ray.get() while resolving block references",
            tag_keys=iter_tag_keys,
        )
        self.iter_next_batch_s = Gauge(
            "data_iter_next_batch_seconds",
            description="Seconds spent getting the next batch from the block buffer",
            tag_keys=iter_tag_keys,
        )
        self.iter_format_batch_s = Gauge(
            "data_iter_format_batch_seconds",
            description="Seconds spent formatting the batch",
            tag_keys=iter_tag_keys,
        )
        self.iter_collate_batch_s = Gauge(
            "data_iter_collate_batch_seconds",
            description="Seconds spent collating the batch",
            tag_keys=iter_tag_keys,
        )
        self.iter_finalize_batch_s = Gauge(
            "data_iter_finalize_batch_seconds",
            description="Seconds spent finalizing the batch",
            tag_keys=iter_tag_keys,
        )
        self.iter_blocks_local = Gauge(
            "data_iter_blocks_local",
            description="Number of blocks already on the local node",
            tag_keys=iter_tag_keys,
        )
        self.iter_blocks_remote = Gauge(
            "data_iter_blocks_remote",
            description="Number of blocks that require fetching from another node",
            tag_keys=iter_tag_keys,
        )
        self.iter_unknown_location = Gauge(
            "data_iter_unknown_location",
            description="Number of blocks that have unknown locations",
            tag_keys=iter_tag_keys,
        )
        self.iter_prefetched_bytes = Gauge(
            "data_iter_prefetched_bytes",
            description="Current bytes of prefetched blocks in the iterator",
            tag_keys=iter_tag_keys,
        )

        # === Dataset and Operator Metadata Metrics ===
        dataset_tags = ("dataset", "job_id", "start_time")
        self.data_dataset_estimated_total_blocks = Gauge(
            "data_dataset_estimated_total_blocks",
            description="Total work units in blocks for dataset",
            tag_keys=dataset_tags,
        )
        self.data_dataset_estimated_total_rows = Gauge(
            "data_dataset_estimated_total_rows",
            description="Total work units in rows for dataset",
            tag_keys=dataset_tags,
        )
        self.data_dataset_state = Gauge(
            "data_dataset_state",
            description=f"State of dataset ({', '.join([f'{s.value}={s.name}' for s in DatasetState])})",
            tag_keys=dataset_tags,
        )

        operator_tags = ("dataset", "operator")
        self.data_operator_estimated_total_blocks = Gauge(
            "data_operator_estimated_total_blocks",
            description="Total work units in blocks for operator",
            tag_keys=operator_tags,
        )
        self.data_operator_estimated_total_rows = Gauge(
            "data_operator_estimated_total_rows",
            description="Total work units in rows for operator",
            tag_keys=operator_tags,
        )
        self.data_operator_queued_blocks = Gauge(
            "data_operator_queued_blocks",
            description="Number of queued blocks for operator",
            tag_keys=operator_tags,
        )
        self.data_operator_state = Gauge(
            "data_operator_state",
            description=f"State of operator ({', '.join([f'{s.value}={s.name}' for s in DatasetState])})",
            tag_keys=operator_tags,
        )

    def _create_prometheus_metrics_for_execution_metrics(
        self, metrics_group: MetricsGroup, tag_keys: Tuple[str, ...]
    ) -> Dict[str, Metric]:
        """Create one Prometheus metric per OpRuntimeMetrics metric in a group.

        Args:
            metrics_group: Only metrics belonging to this group are created.
            tag_keys: Tag keys attached to every created metric.

        Returns:
            Mapping from the OpRuntimeMetrics metric name to a Gauge, Histogram
            or Counter (chosen by `metrics_type`) named `data_<name>`. Metrics
            of any other type are skipped.
        """
        metrics = {}
        for metric in OpRuntimeMetrics.get_metrics():
            if metric.metrics_group != metrics_group:
                continue
            metric_name = f"data_{metric.name}"
            metric_description = metric.description
            if metric.metrics_type == MetricsType.Gauge:
                metrics[metric.name] = Gauge(
                    metric_name,
                    description=metric_description,
                    tag_keys=tag_keys,
                )
            elif metric.metrics_type == MetricsType.Histogram:
                metrics[metric.name] = Histogram(
                    metric_name,
                    description=metric_description,
                    tag_keys=tag_keys,
                    **metric.metrics_args,
                )
            elif metric.metrics_type == MetricsType.Counter:
                metrics[metric.name] = Counter(
                    metric_name,
                    description=metric_description,
                    tag_keys=tag_keys,
                )
        return metrics

    def _create_prometheus_metrics_for_per_node_metrics(self) -> Dict[str, Gauge]:
        """Create a `data_<field>_per_node` Gauge for every NodeMetrics field.

        The gauges are keyed by field name and tagged by dataset and node IP.
        """
        metrics = {}
        for field in fields(NodeMetrics):
            metric_name = f"data_{field.name}_per_node"
            metrics[field.name] = Gauge(
                metric_name,
                description="",
                tag_keys=("dataset", "node_ip"),
            )
        return metrics

    def gen_dataset_id(self) -> str:
        """Generate a unique dataset_id for tracking datasets."""
        dataset_id = str(self.next_dataset_id)
        self.next_dataset_id += 1
        return dataset_id

    def update_execution_metrics(
        self,
        dataset_tag: str,
        op_metrics: List[Dict[str, Union[int, float]]],
        operator_tags: List[str],
        state: Dict[str, Any],
        per_node_metrics: Optional[Dict[str, Dict[str, Union[int, float]]]] = None,
    ):
        """Publish operator (and optionally per-node) metrics of one dataset.

        Args:
            dataset_tag: Tag of the dataset being executed.
            op_metrics: One metrics dict per operator, parallel to
                `operator_tags`; missing metrics are reported as 0.
            operator_tags: Operator tags, in the same order as `op_metrics`.
            state: Partial dataset state merged in via `update_dataset`.
            per_node_metrics: Optional mapping node_id -> NodeMetrics field
                name -> value.
        """

        def _record(
            prom_metric: Metric,
            value: Union[int, float, List[int], RuntimeMetricsHistogram],
            tags: Optional[Dict[str, str]] = None,
        ):
            # Gauges are overwritten, counters incremented; histograms are
            # only exported when given a RuntimeMetricsHistogram (a plain
            # default like 0 is ignored).
            if isinstance(prom_metric, Gauge):
                prom_metric.set(value, tags)
            elif isinstance(prom_metric, Counter):
                prom_metric.inc(value, tags)
            elif isinstance(prom_metric, Histogram):
                if isinstance(value, RuntimeMetricsHistogram):
                    value.export_to(prom_metric, tags)

        execution_metric_groups = (
            self.execution_metrics_inputs,
            self.execution_metrics_outputs,
            self.execution_metrics_tasks,
            self.execution_metrics_obj_store_memory,
            self.execution_metrics_actors,
            self.execution_metrics_misc,
        )

        for stats, operator_tag in zip(op_metrics, operator_tags):
            tags = self._create_tags(dataset_tag, operator_tag)

            self.spilled_bytes.set(stats.get("obj_store_mem_spilled", 0), tags)
            self.freed_bytes.set(stats.get("obj_store_mem_freed", 0), tags)
            self.current_bytes.set(stats.get("obj_store_mem_used", 0), tags)
            self.output_bytes.set(stats.get("bytes_task_outputs_generated", 0), tags)
            self.output_rows.set(stats.get("row_outputs_taken", 0), tags)
            self.cpu_usage_cores.set(stats.get("cpu_usage", 0), tags)
            self.gpu_usage_cores.set(stats.get("gpu_usage", 0), tags)
            for metric_group in execution_metric_groups:
                for field_name, prom_metric in metric_group.items():
                    _record(prom_metric, stats.get(field_name, 0), tags)

        # Update per node metrics if they exist, the creation of these metrics is controlled
        # by the _data_context.enable_per_node_metrics flag in the streaming executor but
        # that is not exposed in the _StatsActor so here we simply check if the metrics exist
        # and if so, update them
        if per_node_metrics is not None:
            nodes_cache_rebuilt = False
            for node_id, node_metrics in per_node_metrics.items():
                # Translate node_id into node_name (the node ip). A rebuild
                # fetches all nodes, so it is done at most once per call even
                # if several node ids are unknown.
                if node_id not in self._ray_nodes_cache and not nodes_cache_rebuilt:
                    self._rebuild_ray_nodes_cache()
                    nodes_cache_rebuilt = True

                node_ip = self._ray_nodes_cache.get(node_id, NODE_UNKNOWN)

                tags = self._create_tags(dataset_tag=dataset_tag, node_ip_tag=node_ip)
                for metric_name, metric_value in node_metrics.items():
                    prom_metric = self.per_node_metrics[metric_name]
                    _record(prom_metric, metric_value, tags)

        # This update is called from a dataset's executor,
        # so all tags should contain the same dataset
        self.update_dataset(dataset_tag, state)

    def _rebuild_ray_nodes_cache(self) -> None:
        """Refresh the node_id -> NodeName cache from `ray.nodes()`.

        Entries are added or overwritten; existing entries are never removed.
        """
        current_nodes = ray.nodes()
        for node in current_nodes:
            node_id = node.get("NodeID", None)
            node_name = node.get("NodeName", None)
            if node_id is not None and node_name is not None:
                self._ray_nodes_cache[node_id] = node_name

    def update_iteration_metrics(
        self,
        stats: "DatasetStats",
        dataset_tag,
    ):
        """Publish iterator timers and block-location counters from `stats`.

        NOTE(review): the iter_block_fetching / iter_batch_* gauges are set
        from the same timers as iter_get_s / iter_next_batch_s / ...; they
        appear to be older aliases — confirm before removing either.
        """
        tags = self._create_tags(dataset_tag)

        self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags)
        self.iter_get_ref_bundles_s.set(stats.iter_get_ref_bundles_s.get(), tags)
        self.iter_get_s.set(stats.iter_get_s.get(), tags)
        self.iter_next_batch_s.set(stats.iter_next_batch_s.get(), tags)
        self.iter_format_batch_s.set(stats.iter_format_batch_s.get(), tags)
        self.iter_collate_batch_s.set(stats.iter_collate_batch_s.get(), tags)
        self.iter_finalize_batch_s.set(stats.iter_finalize_batch_s.get(), tags)
        self.iter_blocks_local.set(stats.iter_blocks_local, tags)
        self.iter_blocks_remote.set(stats.iter_blocks_remote, tags)
        self.iter_unknown_location.set(stats.iter_unknown_location, tags)
        self.iter_prefetched_bytes.set(stats.iter_prefetched_bytes, tags)

        self.iter_block_fetching_s.set(stats.iter_get_s.get(), tags)
        self.iter_batch_shaping_s.set(stats.iter_next_batch_s.get(), tags)
        self.iter_batch_formatting_s.set(stats.iter_format_batch_s.get(), tags)
        self.iter_batch_collating_s.set(stats.iter_collate_batch_s.get(), tags)
        self.iter_batch_finalizing_s.set(stats.iter_finalize_batch_s.get(), tags)

        self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags)

        self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags)
        self.iter_user_s.set(stats.iter_user_s.get(), tags)

    def register_dataset(
        self,
        job_id: str,
        dataset_tag: str,
        operator_tags: List[str],
        topology: Topology,
        data_context: DataContext,
    ):
        """Start tracking `dataset_tag` and its operators in PENDING state.

        If a metadata exporter is configured, the initial DatasetMetadata
        (including the data context) is exported as well.
        """
        start_time = time.time()
        self.datasets[dataset_tag] = {
            "job_id": job_id,
            "state": DatasetState.PENDING.name,
            "progress": 0,
            "total": 0,
            "total_rows": 0,
            "start_time": start_time,
            "end_time": None,
            "operators": {
                operator: {
                    "state": DatasetState.PENDING.name,
                    "progress": 0,
                    "total": 0,
                    "queued_blocks": 0,
                }
                for operator in operator_tags
            },
        }
        if self._metadata_exporter is not None:
            self.dataset_metadatas[dataset_tag] = DatasetMetadata(
                job_id=job_id,
                topology=topology,
                dataset_id=dataset_tag,
                start_time=start_time,
                data_context=data_context,
                execution_start_time=None,
                execution_end_time=None,
                state=DatasetState.PENDING.name,
            )
            self._metadata_exporter.export_dataset_metadata(
                self.dataset_metadatas[dataset_tag]
            )

    def update_dataset(self, dataset_tag: str, state: Dict[str, Any]):
        """Merge `state` into the stored state of `dataset_tag` and publish it.

        Sets dataset- and operator-level gauges from the merged state and
        forwards state changes to the metadata exporter. When the dataset is
        finished or failed, the oldest finished datasets are evicted while
        more than `max_stats` datasets are stored.

        Raises:
            KeyError: If `dataset_tag` is not currently tracked.
        """
        terminal_states = {DatasetState.FINISHED.name, DatasetState.FAILED.name}
        dataset_state = self.datasets[dataset_tag]
        # Whether the dataset had already reached a terminal state before this
        # update; used so it is enqueued for eviction only once even if later
        # updates repeat the terminal state.
        was_terminal = dataset_state.get("state") in terminal_states
        dataset_state.update(state)

        job_id = dataset_state.get("job_id", "None")
        start_time = str(int(dataset_state.get("start_time", 0)))

        # Update dataset-level metrics
        dataset_tags = {
            "dataset": dataset_tag,
            "job_id": job_id,
            "start_time": start_time,
        }
        self.data_dataset_estimated_total_blocks.set(
            dataset_state.get("total", 0), dataset_tags
        )
        self.data_dataset_estimated_total_rows.set(
            dataset_state.get("total_rows", 0), dataset_tags
        )
        state_string = dataset_state.get("state", DatasetState.UNKNOWN.name)
        state_enum = DatasetState.from_string(state_string)
        self.data_dataset_state.set(state_enum.value, dataset_tags)
        self.update_dataset_metadata_state(dataset_tag, state_string)

        # Update operator-level metrics
        operator_states: Dict[str, str] = {}
        for operator, op_state in dataset_state.get("operators", {}).items():
            operator_tags = {
                "dataset": dataset_tag,
                "operator": operator,
            }
            self.data_operator_estimated_total_blocks.set(
                op_state.get("total", 0), operator_tags
            )
            self.data_operator_estimated_total_rows.set(
                op_state.get("total_rows", 0), operator_tags
            )
            self.data_operator_queued_blocks.set(
                op_state.get("queued_blocks", 0), operator_tags
            )

            # Get state code directly from enum
            op_state_string = op_state.get("state", DatasetState.UNKNOWN.name)
            op_state_enum = DatasetState.from_string(op_state_string)
            self.data_operator_state.set(op_state_enum.value, operator_tags)
            operator_states[operator] = op_state_string

        self.update_dataset_metadata_operator_states(dataset_tag, operator_states)

        # Evict the oldest finished datasets to ensure the `max_stats` limit is enforced.
        if dataset_state["state"] in terminal_states:
            if not was_terminal:
                self.finished_datasets_queue.append(dataset_tag)
            while len(self.datasets) > self.max_stats and self.finished_datasets_queue:
                tag_to_evict = self.finished_datasets_queue.popleft()
                self.datasets.pop(tag_to_evict, None)
                self.dataset_metadatas.pop(tag_to_evict, None)

    def get_datasets(self, job_id: Optional[str] = None):
        """Return tracked datasets, filtered by `job_id` when it is truthy."""
        if not job_id:
            return self.datasets
        return {k: v for k, v in self.datasets.items() if v["job_id"] == job_id}

    def update_dataset_metadata_state(self, dataset_id: str, new_state: str):
        """Record a dataset state change in its exported metadata.

        No-op if the dataset has no metadata or the state is unchanged. On
        RUNNING the execution start time is set; on FINISHED/FAILED the end
        time is set and still-RUNNING operators adopt the final state.
        """
        if dataset_id not in self.dataset_metadatas:
            return
        update_time = time.time()
        dataset_metadata = self.dataset_metadatas[dataset_id]
        if dataset_metadata.state == new_state:
            return
        # Work on a copy so the stored metadata is replaced atomically.
        updated_dataset_metadata = copy.deepcopy(dataset_metadata)
        updated_dataset_metadata.state = new_state
        if new_state == DatasetState.RUNNING.name:
            updated_dataset_metadata.execution_start_time = update_time
        elif new_state in (DatasetState.FINISHED.name, DatasetState.FAILED.name):
            updated_dataset_metadata.execution_end_time = update_time
            # Update metadata of running operators
            for operator in updated_dataset_metadata.topology.operators:
                if operator.state == DatasetState.RUNNING.name:
                    operator.state = new_state
                    operator.execution_end_time = update_time

        self.dataset_metadatas[dataset_id] = updated_dataset_metadata
        if self._metadata_exporter is not None:
            self._metadata_exporter.export_dataset_metadata(
                updated_dataset_metadata,
                include_data_context=False,
                include_op_args=False,
            )

    def update_dataset_metadata_operator_states(
        self, dataset_id: str, operator_states: Dict[str, str]
    ):
        """Record operator state changes (operator id -> state name) in metadata.

        Exports updated metadata only if at least one operator's state changed.
        """
        if dataset_id not in self.dataset_metadatas:
            return

        dataset_metadata = self.dataset_metadatas[dataset_id]
        update_needed = any(
            operator.id in operator_states
            and operator.state != operator_states[operator.id]
            for operator in dataset_metadata.topology.operators
        )

        if not update_needed:
            return

        updated_dataset_metadata = copy.deepcopy(dataset_metadata)
        update_time = time.time()
        for operator in updated_dataset_metadata.topology.operators:
            if operator.id in operator_states:
                new_state = operator_states[operator.id]
                if operator.state == new_state:
                    continue
                operator.state = new_state
                if new_state == DatasetState.RUNNING.name:
                    operator.execution_start_time = update_time
                elif new_state in (
                    DatasetState.FINISHED.name,
                    DatasetState.FAILED.name,
                ):
                    operator.execution_end_time = update_time
                    # Handle outlier case for InputDataBuffer, which is marked as finished immediately and does not have a RUNNING state.
                    # Set the execution time the same as its end time
                    if not operator.execution_start_time:
                        operator.execution_start_time = update_time

        self.dataset_metadatas[dataset_id] = updated_dataset_metadata
        if self._metadata_exporter is not None:
            self._metadata_exporter.export_dataset_metadata(
                updated_dataset_metadata,
                include_data_context=False,
                include_op_args=False,
            )

    def _create_tags(
        self,
        dataset_tag: str,
        operator_tag: Optional[str] = None,
        node_ip_tag: Optional[str] = None,
    ):
        """Build a metric tag dict; optional tags are included only if given."""
        tags = {"dataset": dataset_tag}
        if operator_tag is not None:
            tags["operator"] = operator_tag
        if node_ip_tag is not None:
            tags["node_ip"] = node_ip_tag
        return tags


def get_or_create_stats_actor() -> ActorHandle[_StatsActor]:
    """Return the cluster-wide `_StatsActor`, creating it if it does not exist.

    Each cluster holds exactly one `_StatsActor`, registered as a detached
    actor named `STATS_ACTOR_NAME` in `STATS_ACTOR_NAMESPACE`. When it has to
    be created, it is pinned (hard node affinity) to the node of the calling
    process.

    Raises:
        RuntimeError: If this process is not connected to a Ray cluster.
    """
    global_node = ray._private.worker._global_node
    if global_node is None:
        raise RuntimeError(
            "Global node is not initialized. Driver might be not connected to Ray."
        )

    logger.debug(f"Stats Actor located on cluster_id={global_node.cluster_id}")

    # Pin the actor to the caller's node; soft=False forbids falling back to
    # any other node.
    pin_to_this_node = NodeAffinitySchedulingStrategy(
        ray.get_runtime_context().get_node_id(),
        soft=False,
    )
    actor_options = dict(
        name=STATS_ACTOR_NAME,
        namespace=STATS_ACTOR_NAMESPACE,
        get_if_exists=True,
        lifetime="detached",
        scheduling_strategy=pin_to_this_node,
    )
    return _StatsActor.options(**actor_options).remote()


class _StatsManager:
    """A class containing util functions that manage remote calls to _StatsActor.

    Ray Data updates metrics through the _StatsManager; direct remote calls
    to the _StatsActor are discouraged. Some functionalities provided by
    _StatsManager:
        - Format and update iteration metrics
        - Format and update execution metrics
        - Aggregate per node metrics
        - Dataset registration
    """

    @staticmethod
    def _aggregate_per_node_metrics(
        op_metrics: List[OpRuntimeMetrics],
    ) -> Optional[Mapping[str, Mapping[str, Union[int, float]]]]:
        """Aggregate per-node metrics from a list of OpRuntimeMetrics objects.

        If per-node metrics are disabled in the current DataContext, returns None.
        Otherwise, sums every NodeMetrics field across the provided metrics and
        returns a nested plain dict mapping each node ID to a dict of
        ``{field_name: summed_value}`` (keys in NodeMetrics field order).
        """
        if not DataContext.get_current().enable_per_node_metrics:
            return None

        # The NodeMetrics field names are the same for every node; compute once.
        field_names = [f.name for f in fields(NodeMetrics)]

        # Plain dicts (rather than a lambda-backed defaultdict) are shipped to
        # the stats actor, so the receiver cannot auto-vivify missing entries
        # and the payload matches the declared Mapping return type.
        aggregated_by_node: Dict[str, Dict[str, Union[int, float]]] = {}
        for metrics in op_metrics:
            for node_id, node_metrics in metrics._per_node_metrics.items():
                agg_node_metrics = aggregated_by_node.get(node_id)
                if agg_node_metrics is None:
                    agg_node_metrics = dict.fromkeys(field_names, 0)
                    aggregated_by_node[node_id] = agg_node_metrics
                for name in field_names:
                    agg_node_metrics[name] += getattr(node_metrics, name)

        return aggregated_by_node

    @staticmethod
    def update_execution_metrics(
        dataset_tag: str,
        op_metrics: List[OpRuntimeMetrics],
        operator_tags: List[str],
        state: Dict[str, Any],
    ):
        """Send execution metrics for a dataset to the stats actor.

        The remote call is best-effort: failures to reach the actor are
        logged as warnings and otherwise ignored.
        """
        per_node_metrics = _StatsManager._aggregate_per_node_metrics(op_metrics)
        op_metrics_dicts = [metric.as_dict() for metric in op_metrics]
        args = (
            dataset_tag,
            op_metrics_dicts,
            operator_tags,
            state,
            per_node_metrics,
        )
        try:
            get_or_create_stats_actor().update_execution_metrics.remote(*args)
        except Exception as e:
            logger.warning(
                f"Error occurred during update_execution_metrics.remote call to _StatsActor: {e}",
                exc_info=True,
            )

    @staticmethod
    def update_iteration_metrics(stats: "DatasetStats", dataset_tag: str):
        """Send iteration stats for a dataset to the stats actor.

        Best-effort: failures are logged as warnings and otherwise ignored.
        """
        args = (stats, dataset_tag)
        try:
            get_or_create_stats_actor().update_iteration_metrics.remote(*args)
        except Exception as e:
            logger.warning(
                f"Error occurred during update_iteration_metrics.remote call to _StatsActor: {e}",
                exc_info=True,
            )

    @staticmethod
    def register_dataset_to_stats_actor(
        dataset_tag: str,
        operator_tags: List[str],
        topology: Topology,
        data_context: DataContext,
    ):
        """Register a dataset with the stats actor.

        Unlike the metric-update helpers, errors raised while obtaining the
        stats actor propagate to the caller.

        Args:
            dataset_tag: Tag for the dataset
            operator_tags: List of operator tags
            topology: Topology representing the DAG structure to export
            data_context: The DataContext attached to the dataset
        """

        get_or_create_stats_actor().register_dataset.remote(
            ray.get_runtime_context().get_job_id(),
            dataset_tag,
            operator_tags,
            topology,
            data_context,
        )

    @staticmethod
    def gen_dataset_id_from_stats_actor() -> str:
        """Return a dataset id generated by the stats actor.

        Falls back to a random uuid4 hex string if the actor call fails.
        """
        try:
            stats_actor = get_or_create_stats_actor()

            return ray.get(stats_actor.gen_dataset_id.remote())
        except Exception as e:
            logger.warning(
                f"Failed to generate dataset_id, falling back to random uuid_v4: {e}"
            )
            # Getting dataset id from _StatsActor may fail, in this case
            # fall back to uuid4
            return uuid4().hex


class DatasetStats:
    """Holds the execution times for a given Dataset.

    This object contains a reference to the parent Dataset's stats as well,
    but not the Dataset object itself, to allow its blocks to be dropped from
    memory."""

    def __init__(
        self,
        *,
        metadata: StatsDict,
        parent: Union[Optional["DatasetStats"], List["DatasetStats"]],
        base_name: Optional[str] = None,
    ):
        """Create dataset stats.

        Args:
            metadata: Dict of operators used to create this Dataset from the
                previous one. Typically one entry, e.g., {"map": [...]}.
            parent: Reference to parent Dataset's stats, or a list of parents
                if there are multiple.
            base_name: The name of the base operation for a multi-operator operation.
        """

        self.metadata: StatsDict = metadata
        # Normalize a single parent into a one-element list.
        if parent is not None and not isinstance(parent, list):
            parent = [parent]
        self.parents: List["DatasetStats"] = parent or []
        # Position in the lineage: 0 for roots, otherwise one more than the
        # highest-numbered parent.
        self.number: int = (
            0 if not self.parents else max(p.number for p in self.parents) + 1
        )
        self.base_name = base_name
        # TODO(ekl) deprecate and remove the notion of dataset UUID once we move
        # fully to streaming execution.
        self.dataset_uuid: str = "unknown_uuid"
        self.time_total_s: float = 0

        # Streaming executor stats
        self.streaming_exec_schedule_s: Timer = Timer()

        # Iteration stats, filled out if the user iterates over the dataset.
        self.iter_wait_s: Timer = Timer()
        self.iter_get_ref_bundles_s: Timer = Timer()
        self.iter_get_s: Timer = Timer()
        self.iter_next_batch_s: Timer = Timer()
        self.iter_format_batch_s: Timer = Timer()
        self.iter_collate_batch_s: Timer = Timer()
        self.iter_finalize_batch_s: Timer = Timer()
        self.iter_time_to_first_batch_s: Timer = Timer()
        self.iter_total_blocked_s: Timer = Timer()
        self.iter_user_s: Timer = Timer()
        self.iter_initialize_s: Timer = Timer()
        self.iter_total_s: Timer = Timer()
        self.extra_metrics = {}

        # Block fetch stats during iteration.
        # These are stats about locations of blocks when the iterator is trying to
        # consume them. The iteration performance will be affected depending on
        # whether the block is in the local object store of the node where the
        # iterator is running.
        # This serves as an indicator of block prefetching effectiveness.
        self.iter_blocks_local: int = 0
        self.iter_blocks_remote: int = 0
        self.iter_unknown_location: int = 0
        self.iter_prefetched_bytes: int = 0

        # Memory usage stats
        self.global_bytes_spilled: int = 0
        self.global_bytes_restored: int = 0
        self.dataset_bytes_spilled: int = 0

        # Streaming split coordinator stats (dataset level)
        self.streaming_split_coordinator_s: Timer = Timer()

    @property
    def stats_actor(self):
        """Handle to the cluster-wide stats actor (created on first use)."""
        return get_or_create_stats_actor()

    def child_builder(
        self, name: str, override_start_time: Optional[float] = None
    ) -> _DatasetStatsBuilder:
        """Start recording stats for an op of the given name (e.g., map)."""
        return _DatasetStatsBuilder(name, self, override_start_time)

    def to_summary(self) -> "DatasetStatsSummary":
        """Generate a `DatasetStatsSummary` object from the given `DatasetStats`
        object, which can be used to generate a summary string.

        Parents are summarized recursively. Each operator's
        `total_input_num_rows` is derived from upstream output row counts: the
        first (or only) operator takes the summed output of the parents' last
        operators, and each subsequent sub-operator takes the output of the
        sub-operator before it.
        """
        is_sub_operator = len(self.metadata) > 1

        iter_stats = IterStatsSummary(
            self.iter_wait_s,
            self.iter_get_ref_bundles_s,
            self.iter_get_s,
            self.iter_next_batch_s,
            self.iter_format_batch_s,
            self.iter_collate_batch_s,
            self.iter_finalize_batch_s,
            self.iter_time_to_first_batch_s,
            self.iter_total_blocked_s,
            self.iter_user_s,
            self.iter_initialize_s,
            self.iter_total_s,
            self.streaming_split_coordinator_s,
            self.iter_blocks_local,
            self.iter_blocks_remote,
            self.iter_unknown_location,
            self.iter_prefetched_bytes,
        )

        stats_summary_parents = []
        if self.parents is not None:
            stats_summary_parents = [p.to_summary() for p in self.parents]

        # Collect the sum of the final output row counts from all parent nodes.
        parent_total_output = 0
        for i, parent_summary in enumerate(stats_summary_parents):
            if not parent_summary.operators_stats:
                continue
            # Only the last operator of each parent feeds into this dataset.
            last_parent_op = parent_summary.operators_stats[-1]
            op_output = (
                last_parent_op.output_num_rows.get("sum", 0)
                if isinstance(last_parent_op.output_num_rows, dict)
                else 0
            )
            logger.debug(
                "Parent %s (operator: %s) contributes %s rows to input",
                i + 1,
                last_parent_op.operator_name,
                op_output,
            )
            parent_total_output += op_output

        operators_stats = [
            OperatorStatsSummary.from_block_metadata(
                name, stats, is_sub_operator=is_sub_operator
            )
            for name, stats in self.metadata.items()
        ]

        prev_op = None
        for op_stat in operators_stats:
            if is_sub_operator and prev_op is not None:
                # Subsequent sub-operators consume the previous sub-operator's
                # output.
                prev_rows = prev_op.output_num_rows
                op_stat.total_input_num_rows = (
                    prev_rows["sum"] if (prev_rows and "sum" in prev_rows) else 0
                )
            else:
                # Single operator, or the first sub-operator: input rows are
                # the total output from all parent nodes.
                op_stat.total_input_num_rows = parent_total_output
            prev_op = op_stat

        streaming_exec_schedule_s = (
            self.streaming_exec_schedule_s.get()
            if self.streaming_exec_schedule_s
            else 0
        )
        return DatasetStatsSummary(
            operators_stats,
            iter_stats,
            stats_summary_parents,
            self.number,
            self.dataset_uuid,
            self.time_total_s,
            self.base_name,
            self.extra_metrics,
            self.global_bytes_spilled,
            self.global_bytes_restored,
            self.dataset_bytes_spilled,
            streaming_exec_schedule_s,
        )

    def runtime_metrics(self) -> str:
        """Generate a string representing the runtime metrics of a Dataset. This is
        a high level summary of the time spent in Ray Data code broken down by operator.
        It also includes the time spent in the scheduler. Times are shown as the total
        time for each operator and percentages of time are shown as a fraction of the
        total time for the whole dataset."""
        return self.to_summary().runtime_metrics()


@DeveloperAPI
@dataclass
class DatasetStatsSummary:
    """Summary of a dataset's execution stats, including its parents' summaries."""

    operators_stats: List["OperatorStatsSummary"]
    iter_stats: "IterStatsSummary"
    parents: List["DatasetStatsSummary"]
    number: int
    dataset_uuid: str
    time_total_s: float
    # May be None: `DatasetStats.base_name` defaults to None.
    base_name: Optional[str]
    extra_metrics: Dict[str, Any]
    global_bytes_spilled: int
    global_bytes_restored: int
    dataset_bytes_spilled: int
    streaming_exec_schedule_s: float

    def to_string(
        self,
        already_printed: Optional[Set[str]] = None,
        include_parent: bool = True,
        add_global_stats=True,
    ) -> str:
        """Return a human-readable summary of this Dataset's stats.

        Args:
            already_printed: Set of operator IDs that have already had its stats printed
               out.
            include_parent: If true, also include parent stats summary; otherwise, only
               log stats of the latest operator.
            add_global_stats: If true, includes global stats to this summary.
        Returns:
            String with summary statistics for executing the Dataset.
        """
        if already_printed is None:
            already_printed = set()

        out = ""
        if self.parents and include_parent:
            for p in self.parents:
                parent_sum = p.to_string(already_printed, add_global_stats=False)
                if parent_sum:
                    out += parent_sum
                    out += "\n"
        operators_stats_summary = None
        if len(self.operators_stats) == 1:
            operators_stats_summary = self.operators_stats[0]
            operator_name = operators_stats_summary.operator_name
            operator_uuid = self.dataset_uuid + operator_name
            out += "Operator {} {}: ".format(self.number, operator_name)
            if operator_uuid in already_printed:
                out += "[execution cached]\n"
            else:
                already_printed.add(operator_uuid)
                out += str(operators_stats_summary)
        elif len(self.operators_stats) > 1:
            rounded_total = round(self.time_total_s, 2)
            if rounded_total <= 0:
                # Handle -0.0 case.
                rounded_total = 0
            out += "Operator {} {}: executed in {}s\n".format(
                self.number, self.base_name, rounded_total
            )
            for n, operators_stats_summary in enumerate(self.operators_stats):
                operator_name = operators_stats_summary.operator_name
                operator_uuid = self.dataset_uuid + operator_name
                out += "\n"
                out += "\tSuboperator {} {}: ".format(n, operator_name)
                if operator_uuid in already_printed:
                    out += "\t[execution cached]\n"
                else:
                    already_printed.add(operator_uuid)
                    out += str(operators_stats_summary)
        verbose_stats_logs = DataContext.get_current().verbose_stats_logs
        if verbose_stats_logs and self.extra_metrics:
            indent = (
                "\t"
                if operators_stats_summary and operators_stats_summary.is_sub_operator
                else ""
            )
            out += indent
            out += "* Extra metrics: " + str(self.extra_metrics) + "\n"
        out += str(self.iter_stats)

        if len(self.operators_stats) > 0 and add_global_stats:
            mb_spilled = round(self.global_bytes_spilled / 1e6)
            mb_restored = round(self.global_bytes_restored / 1e6)
            if mb_spilled or mb_restored:
                out += "\nCluster memory:\n"
                out += "* Spilled to disk: {}MB\n".format(mb_spilled)
                out += "* Restored from disk: {}MB\n".format(mb_restored)

            dataset_mb_spilled = round(self.dataset_bytes_spilled / 1e6)
            if dataset_mb_spilled:
                out += "\nDataset memory:\n"
                out += "* Spilled to disk: {}MB\n".format(dataset_mb_spilled)

            if self.num_rows_per_s:
                out += "\n"
                out += "Dataset throughput:\n"
                out += f"\t* Ray Data throughput: {self.num_rows_per_s} rows/s\n"
        if verbose_stats_logs and add_global_stats:
            out += "\n" + self.runtime_metrics()

        return out

    @property
    def num_rows_per_s(self) -> float:
        """Calculates the throughput in rows per second for the entire dataset."""
        # The observed dataset throughput is computed by dividing the total number
        # of rows produced by the total wall time of the dataset (i.e. from start to
        # finish how long did the dataset take to be processed). With the recursive
        # nature of the DatasetStatsSummary, we use get_total_wall_time to determine
        # the total wall time (this finds the difference between the earliest start
        # and latest end for any block in any operator).
        output_num_rows = (
            self.operators_stats[-1].output_num_rows if self.operators_stats else 0
        )
        total_num_out_rows = output_num_rows["sum"] if output_num_rows else 0
        wall_time = self.get_total_wall_time()
        if not total_num_out_rows or not wall_time:
            return 0.0
        return total_num_out_rows / wall_time

    @staticmethod
    def _collect_dataset_stats_summaries(
        curr: "DatasetStatsSummary",
    ) -> List["DatasetStatsSummary"]:
        """Return `curr` preceded by its ancestors' summaries (depth-first).

        A parent is only descended into (and included) when it itself has
        parents; parentless roots are not collected.
        """
        summs = []
        # TODO: Do operators ever have multiple parents? Do we need to deduplicate?
        for p in curr.parents:
            if p and p.parents:
                summs.extend(DatasetStatsSummary._collect_dataset_stats_summaries(p))
        return summs + [curr]

    @staticmethod
    def _find_start_and_end(summ: "DatasetStatsSummary") -> Tuple[float, float]:
        """Return (earliest start, latest end) over `summ`'s operators.

        Requires `summ.operators_stats` to be non-empty.
        """
        earliest_start = min(ops.earliest_start_time for ops in summ.operators_stats)
        latest_end = max(ops.latest_end_time for ops in summ.operators_stats)
        return earliest_start, latest_end

    def runtime_metrics(self) -> str:
        """Return per-operator, scheduling and total times with their
        percentage of the total wall time."""
        total_wall_time = self.get_total_wall_time()

        def fmt_line(name: str, time: float) -> str:
            fraction = time / total_wall_time if total_wall_time > 0 else 0
            return f"* {name}: {fmt(time)} ({fraction * 100:.3f}%)\n"

        summaries = DatasetStatsSummary._collect_dataset_stats_summaries(self)
        out = "Runtime Metrics:\n"
        for summ in summaries:
            if len(summ.operators_stats) > 0:
                earliest_start, latest_end = DatasetStatsSummary._find_start_and_end(
                    summ
                )
                op_total_time = latest_end - earliest_start
                out += fmt_line(summ.base_name, op_total_time)
        out += fmt_line("Scheduling", self.streaming_exec_schedule_s)
        out += fmt_line("Total", total_wall_time)
        return out

    def __repr__(self, level=0) -> str:
        indent = leveled_indent(level)
        operators_stats = "\n".join(
            [ss.__repr__(level + 2) for ss in self.operators_stats]
        )
        parent_stats = "\n".join([ps.__repr__(level + 2) for ps in self.parents])
        extra_metrics = "\n".join(
            f"{leveled_indent(level + 2)}{k}: {v},"
            for k, v in self.extra_metrics.items()
        )

        # Handle formatting case for empty outputs.
        operators_stats = (
            f"\n{operators_stats},\n{indent}   " if operators_stats else ""
        )
        parent_stats = f"\n{parent_stats},\n{indent}   " if parent_stats else ""
        extra_metrics = f"\n{extra_metrics}\n{indent}   " if extra_metrics else ""
        return (
            f"{indent}DatasetStatsSummary(\n"
            f"{indent}   dataset_uuid={self.dataset_uuid},\n"
            f"{indent}   base_name={self.base_name},\n"
            f"{indent}   number={self.number},\n"
            f"{indent}   extra_metrics={{{extra_metrics}}},\n"
            f"{indent}   operators_stats=[{operators_stats}],\n"
            f"{indent}   iter_stats={self.iter_stats.__repr__(level+1)},\n"
            f"{indent}   global_bytes_spilled={self.global_bytes_spilled / 1e6}MB,\n"
            f"{indent}   global_bytes_restored={self.global_bytes_restored / 1e6}MB,\n"
            f"{indent}   dataset_bytes_spilled={self.dataset_bytes_spilled / 1e6}MB,\n"
            f"{indent}   parents=[{parent_stats}],\n"
            f"{indent})"
        )

    def get_total_wall_time(self) -> float:
        """Calculate the total wall time for the dataset, this is done by finding
        the earliest start time and latest end time for any block in any operator.
        The wall time is the difference of these two times.
        """
        start_ends = [
            DatasetStatsSummary._find_start_and_end(summ)
            for summ in DatasetStatsSummary._collect_dataset_stats_summaries(self)
            if len(summ.operators_stats) > 0
        ]
        if len(start_ends) == 0:
            return 0
        else:
            earliest_start = min(start_end[0] for start_end in start_ends)
            latest_end = max(start_end[1] for start_end in start_ends)
            return latest_end - earliest_start

    def get_total_time_all_blocks(self) -> float:
        """Calculate the sum of the wall times across all blocks of all operators."""
        summaries = DatasetStatsSummary._collect_dataset_stats_summaries(self)
        return sum(
            (
                sum(
                    ops.wall_time.get("sum", 0) if ops.wall_time else 0
                    for ops in summ.operators_stats
                )
            )
            for summ in summaries
        )

    def get_total_cpu_time(self) -> float:
        """Sum the remote CPU time of all operators here and in all parents.

        Operators without execution stats have ``cpu_time is None`` and
        contribute 0 (previously this raised AttributeError).
        """
        parent_sum = sum(p.get_total_cpu_time() for p in self.parents)
        return parent_sum + sum(
            ss.cpu_time.get("sum", 0) for ss in self.operators_stats if ss.cpu_time
        )

    def get_max_heap_memory(self) -> float:
        """Return the max peak heap memory (MiB) over operators here and in parents.

        Operators without execution stats have ``memory is None`` and are
        skipped (previously this raised AttributeError). Returns 0 when no
        operator reports memory.
        """
        parent_memory = [p.get_max_heap_memory() for p in self.parents]
        parent_max = max(parent_memory) if parent_memory else 0
        op_memory = [
            ss.memory.get("max", 0) for ss in self.operators_stats if ss.memory
        ]
        # Pass a single list so max() works even when op_memory is empty.
        return max([parent_max, *op_memory])


@dataclass
class OperatorStatsSummary:
    operator_name: str
    # Whether the operator associated with this OperatorStatsSummary object
    # is a suboperator
    is_sub_operator: bool
    # This is the total walltime of the entire operator, typically obtained from
    # `DatasetStats.time_total_s`. An important distinction is that this is the
    # overall runtime of the operator, pulled from the stats actor, whereas the
    # computed walltimes in `self.wall_time` are calculated on a operator level.
    time_total_s: float
    earliest_start_time: float
    latest_end_time: float
    # String summarizing high-level statistics from executing the operator
    block_execution_summary_str: str
    # The fields below are dicts with stats aggregated across blocks
    # processed in this operator. For example:
    # {"min": ..., "max": ..., "mean": ..., "sum": ...}
    wall_time: Optional[Dict[str, float]] = None
    cpu_time: Optional[Dict[str, float]] = None
    udf_time: Optional[Dict[str, float]] = None
    # memory: no "sum" stat
    memory: Optional[Dict[str, float]] = None
    # Use the output_num_rows of the parent Operator as output_num_rows
    total_input_num_rows: Optional[int] = None
    output_num_rows: Optional[Dict[str, float]] = None
    output_size_bytes: Optional[Dict[str, float]] = None
    # node_count: "count" stat instead of "sum"
    node_count: Optional[Dict[str, float]] = None
    task_rows: Optional[Dict[str, float]] = None

    @property
    def num_rows_per_s(self) -> float:
        # The observed Ray Data operator throughput is computed by dividing the
        # total number of rows produced by the wall time of the operator,
        # time_total_s.
        if not self.output_num_rows or not self.time_total_s:
            return 0.0
        return self.output_num_rows["sum"] / self.time_total_s

    @property
    def num_rows_per_task_s(self) -> float:
        """Calculates the estimated single-task throughput in rows per second."""
        # The estimated single task operator throughput is computed by dividing the
        # total number of rows produced by the sum of the wall times across all
        # blocks of the operator. This assumes that on a single task the work done
        # would be equivalent, with no concurrency.
        if not self.output_num_rows or not self.wall_time or not self.wall_time["sum"]:
            return 0.0
        return self.output_num_rows["sum"] / self.wall_time["sum"]

    @classmethod
    def from_block_metadata(
        cls,
        operator_name: str,
        block_stats: List[BlockStats],
        is_sub_operator: bool,
    ) -> "OperatorStatsSummary":
        """Calculate the stats for a operator from a given list of blocks,
        and generates a `OperatorStatsSummary` object with the results.

        Args:
            block_stats: List of `BlockStats` to calculate stats of
            operator_name: Name of operator associated with `blocks`
            is_sub_operator: Whether this set of blocks belongs to a sub operator.
        Returns:
            A `OperatorStatsSummary` object initialized with the calculated statistics
        """
        exec_stats = [m.exec_stats for m in block_stats if m.exec_stats is not None]
        rounded_total = 0
        time_total_s = 0
        earliest_start_time, latest_end_time = 0, 0

        if exec_stats:
            # Calculate the total execution time of operator as
            # the difference between the latest end time and
            # the earliest start time of all blocks in the operator.
            earliest_start_time = min(s.start_time_s for s in exec_stats)
            latest_end_time = max(s.end_time_s for s in exec_stats)
            time_total_s = latest_end_time - earliest_start_time

        if is_sub_operator:
            exec_summary_str = "{} blocks produced\n".format(len(exec_stats))
        else:
            if exec_stats:
                rounded_total = round(time_total_s, 2)
                if rounded_total <= 0:
                    # Handle -0.0 case.
                    rounded_total = 0
                exec_summary_str = "{} blocks produced in {}s".format(
                    len(exec_stats), rounded_total
                )
            else:
                exec_summary_str = ""
            exec_summary_str += "\n"

        task_rows = collections.defaultdict(int)
        for meta in block_stats:
            if meta.num_rows is not None and meta.exec_stats is not None:
                task_rows[meta.exec_stats.task_idx] += meta.num_rows
        task_rows_stats = None
        if len(task_rows) > 0:
            task_rows_stats = {
                "min": min(task_rows.values()),
                "max": max(task_rows.values()),
                "mean": int(np.mean(list(task_rows.values()))),
                "count": len(task_rows),
            }
            exec_summary_str = "{} tasks executed, {}".format(
                len(task_rows), exec_summary_str
            )

        wall_time_stats, cpu_stats, memory_stats, udf_stats = None, None, None, None
        if exec_stats:
            wall_time_stats = {
                "min": min([e.wall_time_s for e in exec_stats]),
                "max": max([e.wall_time_s for e in exec_stats]),
                "mean": np.mean([e.wall_time_s for e in exec_stats]),
                "sum": sum([e.wall_time_s for e in exec_stats]),
            }
            cpu_stats = {
                "min": min([e.cpu_time_s for e in exec_stats]),
                "max": max([e.cpu_time_s for e in exec_stats]),
                "mean": np.mean([e.cpu_time_s for e in exec_stats]),
                "sum": sum([e.cpu_time_s for e in exec_stats]),
            }

            memory_stats_mb = [
                round((e.max_uss_bytes or 0) / (1024 * 1024), 2) for e in exec_stats
            ]
            memory_stats = {
                "min": min(memory_stats_mb),
                "max": max(memory_stats_mb),
                "mean": int(np.mean(memory_stats_mb)),
            }

            udf_stats = {
                "min": min([e.udf_time_s for e in exec_stats]),
                "max": max([e.udf_time_s for e in exec_stats]),
                "mean": np.mean([e.udf_time_s for e in exec_stats]),
                "sum": sum([e.udf_time_s for e in exec_stats]),
            }

        output_num_rows_stats = None
        output_num_rows = [m.num_rows for m in block_stats if m.num_rows is not None]
        if output_num_rows:
            output_num_rows_stats = {
                "min": min(output_num_rows),
                "max": max(output_num_rows),
                "mean": int(np.mean(output_num_rows)),
                "sum": sum(output_num_rows),
            }

        output_size_bytes_stats = None
        output_size_bytes = [
            m.size_bytes for m in block_stats if m.size_bytes is not None
        ]
        if output_size_bytes:
            output_size_bytes_stats = {
                "min": min(output_size_bytes),
                "max": max(output_size_bytes),
                "mean": int(np.mean(output_size_bytes)),
                "sum": sum(output_size_bytes),
            }

        node_counts_stats = None
        if exec_stats:
            node_tasks = collections.defaultdict(set)
            for s in exec_stats:
                node_tasks[s.node_id].add(s.task_idx)

            node_counts = {node: len(tasks) for node, tasks in node_tasks.items()}
            node_counts_stats = {
                "min": min(node_counts.values()),
                "max": max(node_counts.values()),
                "mean": int(np.mean(list(node_counts.values()))),
                "count": len(node_counts),
            }

        # Assign a value in to_summary and initialize it as None.
        total_input_num_rows = None

        return OperatorStatsSummary(
            operator_name=operator_name,
            is_sub_operator=is_sub_operator,
            time_total_s=time_total_s,
            earliest_start_time=earliest_start_time,
            latest_end_time=latest_end_time,
            block_execution_summary_str=exec_summary_str,
            wall_time=wall_time_stats,
            cpu_time=cpu_stats,
            udf_time=udf_stats,
            memory=memory_stats,
            total_input_num_rows=total_input_num_rows,
            output_num_rows=output_num_rows_stats,
            output_size_bytes=output_size_bytes_stats,
            node_count=node_counts_stats,
            task_rows=task_rows_stats,
        )

    def __str__(self) -> str:
        """Render this operator's execution statistics as a human-readable report.

        The report starts with ``block_execution_summary_str`` and then appends
        one ``* ...`` line per statistic group that is present; groups whose
        value is falsy (e.g. ``None`` or an empty dict) are skipped. Lines for
        sub-operators are prefixed with a tab.

        Time groups (wall / cpu / UDF) are rendered through ``fmt``; every
        other group is printed with its raw numbers.

        Returns:
            String with summary statistics for executing the given operator.
        """
        indent = "\t" if self.is_sub_operator else ""
        out = self.block_execution_summary_str

        # Time distributions, formatted as durations via `fmt`.
        time_lines = (
            (
                "* Remote wall time: {} min, {} max, {} mean, {} total\n",
                self.wall_time,
            ),
            (
                "* Remote cpu time: {} min, {} max, {} mean, {} total\n",
                self.cpu_time,
            ),
            (
                "* UDF time: {} min, {} max, {} mean, {} total\n",
                self.udf_time,
            ),
        )
        for template, stats in time_lines:
            if stats:
                out += indent + template.format(
                    fmt(stats["min"]),
                    fmt(stats["max"]),
                    fmt(stats["mean"]),
                    fmt(stats["sum"]),
                )

        memory_stats = self.memory
        if memory_stats:
            out += indent
            out += "* Peak heap memory usage (MiB): {} min, {} max, {} mean\n".format(
                memory_stats["min"],
                memory_stats["max"],
                memory_stats["mean"],
            )

        # Per-block output distributions; raw numbers, not durations.
        output_lines = (
            (
                "* Output num rows per block: {} min, {} max, {} mean, {} total\n",
                self.output_num_rows,
            ),
            (
                "* Output size bytes per block: {} min, {} max, {} mean, {} total\n",
                self.output_size_bytes,
            ),
        )
        for template, stats in output_lines:
            if stats:
                out += indent + template.format(
                    stats["min"],
                    stats["max"],
                    stats["mean"],
                    stats["sum"],
                )

        task_rows = self.task_rows
        if task_rows:
            out += indent
            out += (
                "* Output rows per task: {} min, {} max, {} mean, {} tasks used\n"
            ).format(
                task_rows["min"],
                task_rows["max"],
                task_rows["mean"],
                task_rows["count"],
            )

        node_count_stats = self.node_count
        if node_count_stats:
            out += indent
            out += "* Tasks per node: {} min, {} max, {} mean; {} nodes used\n".format(
                node_count_stats["min"],
                node_count_stats["max"],
                node_count_stats["mean"],
                node_count_stats["count"],
            )

        if self.num_rows_per_s and self.num_rows_per_task_s:
            total_num_in_rows = self.total_input_num_rows or 0
            # Guard against `output_num_rows` being None or empty: the other
            # groups above are all truthiness-checked, and indexing it blindly
            # here would raise TypeError/KeyError instead of printing a report.
            total_num_out_rows = (self.output_num_rows or {}).get("sum", 0)
            out += indent
            out += "* Operator throughput:\n"
            out += f"{indent}\t* Total input num rows: {total_num_in_rows} rows\n"
            out += f"{indent}\t* Total output num rows: {total_num_out_rows} rows\n"
            out += f"{indent}\t* Ray Data throughput: {self.num_rows_per_s} rows/s\n"
            out += (
                f"{indent}\t* Estimated single task throughput: "
                f"{self.num_rows_per_task_s} rows/s\n"
            )
        return out

    def __repr__(self, level=0) -> str:
        """Return a multi-line, debug-oriented representation of this summary.

        Args:
            level: Nesting depth; each level adds one ``leveled_indent`` step.
                Sub-operators get one extra step on top of ``level``.

        Returns:
            An ``OperatorStatsSummary(...)`` string with one field per line.
            Empty or missing stat dicts are shown as ``None``.
        """
        indent = leveled_indent(level) + (
            leveled_indent(1) if self.is_sub_operator else ""
        )

        def _formatted(stats):
            # NOTE: every value goes through `fmt`, including fields that are
            # not durations (memory, row/byte counts, node counts).
            return {key: fmt(val) for key, val in (stats or {}).items()} or None

        # Format all stat groups up front (before `time_total_s`), in the
        # order they are listed in the output.
        stat_fields = [
            ("wall_time", _formatted(self.wall_time)),
            ("cpu_time", _formatted(self.cpu_time)),
            ("memory", _formatted(self.memory)),
            ("output_num_rows", _formatted(self.output_num_rows)),
            ("output_size_bytes", _formatted(self.output_size_bytes)),
            ("node_count", _formatted(self.node_count)),
        ]

        pad = f"{indent}   "
        pieces = [
            f"{indent}OperatorStatsSummary(\n",
            f"{pad}operator_name='{self.operator_name}',\n",
            f"{pad}is_suboperator={self.is_sub_operator},\n",
            f"{pad}time_total_s={fmt(self.time_total_s)},\n",
            # block_execution_summary_str already ends with a newline.
            f"{pad}block_execution_summary_str={self.block_execution_summary_str}",
        ]
        pieces.extend(f"{pad}{name}={value},\n" for name, value in stat_fields)
        pieces.append(f"{indent})")
        return "".join(pieces)


@dataclass
class IterStatsSummary:
    """Aggregated timers and block-location counters for a Dataset iterator.

    Timer-typed fields are read through their ``get``/``min``/``max``/``avg``
    accessors when rendered; integer fields are reported as-is.
    """

    # Seconds spent in actor-based prefetching.
    wait_time: Timer
    # Seconds spent getting RefBundles from the dataset iterator.
    get_ref_bundles_time: Timer
    # Seconds spent inside `ray.get()`.
    get_time: Timer
    # Seconds spent building batches.
    next_time: Timer
    # Seconds spent in `_format_batch_()`.
    format_time: Timer
    # Seconds spent in the collate fn.
    collate_time: Timer
    # Seconds spent in finalize_fn.
    finalize_batch_time: Timer
    # Time the user thread is blocked waiting for the first batch.
    time_to_first_batch: Timer
    # Total time the user thread is blocked by iter_batches.
    block_time: Timer
    # Seconds spent in user code.
    user_time: Timer
    # Time spent in iterator initialization code.
    initialize_time: Timer
    # Total seconds taken by the Dataset iterator.
    total_time: Timer
    # Time spent in the streaming split coordinator.
    streaming_split_coord_time: Timer
    # Number of blocks found in the local object store.
    iter_blocks_local: int
    # Number of blocks on a remote node that had to be fetched locally.
    iter_blocks_remote: int
    # Number of blocks whose location is unknown.
    iter_unknown_location: int
    # Current bytes of prefetched blocks held by the iterator.
    iter_prefetched_bytes: int

    def __str__(self) -> str:
        return self.to_string()

    def to_string(self) -> str:
        """Render the iterator time breakdown as a human-readable report.

        Returns an empty string unless at least one of the headline timers
        (block, first-batch, total, RefBundle fetch, get, batch creation,
        formatting, collate, finalize) recorded a non-zero value.
        """
        headline_timers = (
            self.block_time,
            self.time_to_first_batch,
            self.total_time,
            self.get_ref_bundles_time,
            self.get_time,
            self.next_time,
            self.format_time,
            self.collate_time,
            self.finalize_batch_time,
        )
        if not any(timer.get() for timer in headline_timers):
            return ""

        parts = ["\nDataset iterator time breakdown:\n"]

        # Whole-iteration timers: each reported as a single total.
        total_templates = (
            (self.total_time, "* Total time overall: {}\n"),
            (
                self.initialize_time,
                "    * Total time in Ray Data iterator initialization code: "
                "{}\n",
            ),
            (
                self.block_time,
                "    * Total time user thread is blocked by Ray Data iter_batches: "
                "{}\n",
            ),
            (
                self.time_to_first_batch,
                "    * Total time spent waiting for the first batch after starting iteration: "
                "{}\n",
            ),
            (self.user_time, "    * Total execution time for user thread: {}\n"),
        )
        for timer, template in total_templates:
            if timer.get():
                parts.append(template.format(fmt(timer.get())))

        parts.append(
            "* Batch iteration time breakdown (summed across prefetch threads):\n"
        )

        # Per-stage timers: reported as min / max / avg / total.
        stage_templates = (
            (
                self.get_ref_bundles_time,
                "    * In get RefBundles: {} min, {} max, {} avg, {} total\n",
            ),
            (
                self.get_time,
                "    * In ray.get(): {} min, {} max, {} avg, {} total\n",
            ),
            (
                self.next_time,
                "    * In batch creation: {} min, {} max, {} avg, {} total\n",
            ),
            (
                self.format_time,
                "    * In batch formatting: {} min, {} max, {} avg, {} total\n",
            ),
            (
                self.collate_time,
                "    * In collate_fn: {} min, {} max, {} avg, {} total\n",
            ),
            (
                self.finalize_batch_time,
                "    * In host->device transfer: {} min, {} max, {} avg, {} total\n",
            ),
        )
        for timer, template in stage_templates:
            if timer.get():
                parts.append(
                    template.format(
                        fmt(timer.min()),
                        fmt(timer.max()),
                        fmt(timer.avg()),
                        fmt(timer.get()),
                    )
                )

        if DataContext.get_current().enable_get_object_locations_for_metrics:
            parts.append("Block locations:\n")
            parts.append("    * Num blocks local: {}\n".format(self.iter_blocks_local))
            parts.append(
                "    * Num blocks remote: {}\n".format(self.iter_blocks_remote)
            )
            parts.append(
                "    * Num blocks unknown location: {}\n".format(
                    self.iter_unknown_location
                )
            )
        if self.iter_prefetched_bytes:
            parts.append(
                "    * Prefetched bytes: {}\n".format(self.iter_prefetched_bytes)
            )
        if self.streaming_split_coord_time.get() != 0:
            parts.append("Streaming split coordinator overhead time: ")
            parts.append(f"{fmt(self.streaming_split_coord_time.get())}\n")

        return "".join(parts)

    def __repr__(self, level=0) -> str:
        """Return a multi-line debug representation indented by ``level``.

        The opening ``IterStatsSummary(`` line carries no indent, so the
        result can be embedded after a ``name=`` prefix by a caller.
        """
        indent = leveled_indent(level)
        entries = [
            ("wait_time", fmt(self.wait_time.get()) or None),
            ("get_ref_bundles_time", fmt(self.get_ref_bundles_time.get()) or None),
            ("get_time", fmt(self.get_time.get()) or None),
            ("iter_blocks_local", self.iter_blocks_local or None),
            ("iter_blocks_remote", self.iter_blocks_remote or None),
            ("iter_unknown_location", self.iter_unknown_location or None),
            ("iter_prefetched_bytes", self.iter_prefetched_bytes or None),
            ("next_time", fmt(self.next_time.get()) or None),
            ("format_time", fmt(self.format_time.get()) or None),
            ("user_time", fmt(self.user_time.get()) or None),
            ("total_time", fmt(self.total_time.get()) or None),
        ]
        body = "".join(f"{indent}   {name}={value},\n" for name, value in entries)
        return f"IterStatsSummary(\n{body}{indent})"
