from sklearn.datasets import load_breast_cancer

from ray import tune
from ray.data import Dataset, Datasource, ReadTask, read_datasource
from ray.data.block import BlockMetadata
from ray.tune.impl.utils import execute_dataset


# TODO(xwjiang): Enable this when Clark's out-of-band serialization has landed.
class TestDatasource(Datasource):
    """Datasource yielding the sklearn breast-cancer data as one Arrow table.

    Used by the tests in this file as a small input for ``read_datasource``.
    """

    # The ``Test`` prefix matches pytest's default class-collection pattern;
    # mark this helper as not-a-test so pytest does not try to collect it.
    __test__ = False

    def prepare_read(self, parallelism: int, **read_args):
        """Return a single read task that loads the whole dataset.

        Args:
            parallelism: Not used here; exactly one read task is always
                returned regardless of its value.
            **read_args: Not used.

        Returns:
            A one-element list holding a ``ReadTask`` whose read function
            returns a list containing one ``pyarrow.Table`` (the feature
            columns plus a ``target`` column).
        """
        import pyarrow as pa

        def load_data():
            data_raw = load_breast_cancer(as_frame=True)
            # ``assign`` returns a new frame with the label column appended,
            # instead of inserting a column in place into the frame held by
            # the sklearn result object.
            dataset_df = data_raw["data"].assign(target=data_raw["target"])
            return [pa.Table.from_pandas(dataset_df)]

        # All metadata fields are left as None.
        meta = BlockMetadata(
            num_rows=None,
            size_bytes=None,
            input_files=None,
            exec_stats=None,
        )
        return [ReadTask(load_data, meta)]


def gen_dataset_func() -> Dataset:
    """Build a ``Dataset`` by reading from a newly created ``TestDatasource``."""
    return read_datasource(TestDatasource())


def test_grid_search():
    """Datasets wrapped in ``tune.grid_search`` have a snapshot after execution.

    Two lazy datasets start without a final-stage snapshot; after
    ``execute_dataset(param_space)`` both entries under the ``"grid_search"``
    key must report one.
    """
    lazy_datasets = [gen_dataset_func().lazy().map(lambda x: x) for _ in range(2)]
    for pending in lazy_datasets:
        assert not pending._plan._has_final_stage_snapshot()

    param_space = {"train_dataset": tune.grid_search(lazy_datasets)}
    execute_dataset(param_space)

    # Read the datasets back out of the param space rather than reusing the
    # local list, so the check covers what the param space now holds.
    executed_ds = param_space["train_dataset"]["grid_search"]
    assert len(executed_ds) == 2
    for done in executed_ds:
        assert done._plan._has_final_stage_snapshot()


def test_choice():
    """Datasets wrapped in ``tune.choice`` have a snapshot after execution.

    Two lazy datasets start without a final-stage snapshot; after
    ``execute_dataset(param_space)`` both entries in the sampler's
    ``categories`` must report one.
    """
    lazy_datasets = [gen_dataset_func().lazy().map(lambda x: x) for _ in range(2)]
    for pending in lazy_datasets:
        assert not pending._plan._has_final_stage_snapshot()

    param_space = {"train_dataset": tune.choice(lazy_datasets)}
    execute_dataset(param_space)

    # Read the datasets back out of the param space rather than reusing the
    # local list, so the check covers what the param space now holds.
    executed_ds = param_space["train_dataset"].categories
    assert len(executed_ds) == 2
    for done in executed_ds:
        assert done._plan._has_final_stage_snapshot()


if __name__ == "__main__":
    # Allow running this file directly: run pytest on it verbosely, stop at
    # the first failure, and propagate pytest's exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", "-x", __file__])
    sys.exit(exit_code)
