"""This example demonstrates the usage of Optuna define-by-run with Ray Tune.

It also checks that it is usable with a separate scheduler.

Requires the Optuna library to be installed (`pip install optuna`).

For an example of using a Tune search space, see
:doc:`/tune/examples/optuna_example`.
"""

import time
from typing import Any, Dict, Optional

import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.optuna import OptunaSearch


def evaluation_fn(step, width, height, mult=1):
    """Compute the synthetic score for one training step.

    The score is the sum of two terms: the reciprocal of
    ``0.1 + width * step / 100`` and ``height * 0.1 * mult``.

    Args:
        step: Index of the current iteration.
        width: Scales how strongly ``step`` affects the reciprocal term.
        height: Scales the additive term.
        mult: Extra multiplier applied to the additive term (defaults to 1).

    Returns:
        The score for the given step and hyperparameters.
    """
    reciprocal_term = (0.1 + width * step / 100) ** (-1)
    offset_term = height * 0.1 * mult
    return reciprocal_term + offset_term


def easy_objective(config):
    """Trainable that reports a synthetic loss to Tune once per step.

    Args:
        config: Trial configuration. Must contain ``"width"``, ``"height"``
            and ``"steps"``; ``"mult"`` is optional and defaults to 1.
    """
    # Pull the hyperparameters out of the trial config.
    width = config["width"]
    height = config["height"]
    mult = config.get("mult", 1)
    print(config)

    total_steps = config["steps"]
    for step in range(total_steps):
        # Stand-in for one iteration of an arbitrary training procedure.
        score = evaluation_fn(step, width, height, mult)
        # Feed the score back to Tune.
        metrics = {"iterations": step, "mean_loss": score}
        tune.report(metrics)
        time.sleep(0.1)


def define_by_run_func(trial) -> Optional[Dict[str, Any]]:
    """Build the search space through Optuna's define-by-run API.

    Only ``trial.suggest_*`` calls belong here; any real computation should
    live in the trainable handed to ``Tuner`` (``easy_objective`` in this
    example).

    For more information, see https://optuna.readthedocs.io/en/stable\
/tutorial/10_key_features/002_configurations.html

    Args:
        trial: Object exposing ``suggest_categorical`` and ``suggest_float``.

    Returns:
        Either None or a dict of constant values; here, the number of steps.
    """
    # Constants are handed back as a plain dictionary.
    constants = {"steps": 100}

    # "activation" is not used by the objective function; it only drives
    # the conditional part of the search space below.
    chosen_activation = trial.suggest_categorical("activation", ["relu", "tanh"])

    # Parameters that are always sampled, as (name, low, high).
    for name, low, high in (("width", 0, 20), ("height", -100, 100)):
        trial.suggest_float(name, low, high)

    # Define-by-run allows conditional search spaces: "mult" is only
    # sampled when the activation is "relu".
    if chosen_activation != "relu":
        return constants

    trial.suggest_float("mult", 1, 2)
    return constants


def run_optuna_tune(smoke_test=False):
    """Tune ``easy_objective`` with an Optuna define-by-run search space.

    Wraps an ``OptunaSearch`` in a ``ConcurrencyLimiter`` (``max_concurrent=4``),
    pairs it with an ``AsyncHyperBandScheduler``, runs the tuner minimizing
    ``"mean_loss"``, and prints the best configuration found.

    Args:
        smoke_test: If True, run 10 samples instead of 100.
    """
    trial_count = 10 if smoke_test else 100

    # The searcher owns the search space via the define-by-run function.
    searcher = ConcurrencyLimiter(
        OptunaSearch(space=define_by_run_func, metric="mean_loss", mode="min"),
        max_concurrent=4,
    )
    tune_config = tune.TuneConfig(
        metric="mean_loss",
        mode="min",
        search_alg=searcher,
        scheduler=AsyncHyperBandScheduler(),
        num_samples=trial_count,
    )

    result_grid = tune.Tuner(easy_objective, tune_config=tune_config).fit()
    best_result = result_grid.get_best_result()

    print("Best hyperparameters found were: ", best_result.config)


if __name__ == "__main__":
    import argparse

    # Command-line interface: a single flag that shortens the run.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test",
        action="store_true",
        help="Finish quickly for testing",
    )
    # parse_known_args lets unrecognized arguments pass without an error.
    known_args, _unknown = cli.parse_known_args()

    ray.init(configure_logging=False)

    run_optuna_tune(smoke_test=known_args.smoke_test)
