from typing import Dict, List, Optional

from ray.tune.experiment import Trial
from ray.tune.search import ConcurrencyLimiter, Searcher
from ray.tune.search.search_generator import SearchGenerator


class _MockSearcher(Searcher):
    """In-memory searcher stub that records every callback it receives.

    Attributes:
        live_trials: Maps the id of each suggested, not-yet-completed trial
            to the placeholder value ``1``.
        counter: Number of ``on_trial_result`` / ``on_trial_complete`` calls,
            stored under the keys ``"result"`` and ``"complete"``.
        results: Every intermediate result received, in arrival order.
        final_results: Every truthy final result received, in arrival order.
        stall: While true, ``suggest`` returns ``None`` instead of a config.
    """

    def __init__(self, **kwargs):
        """Set up empty bookkeeping, then forward ``kwargs`` to ``Searcher``."""
        self.stall = False
        self.live_trials = {}
        self.results = []
        self.final_results = []
        self.counter = {"result": 0, "complete": 0}
        super().__init__(**kwargs)

    def suggest(self, trial_id: str):
        """Return a fixed configuration for ``trial_id``, or ``None`` if stalled.

        A successful suggestion also registers ``trial_id`` as live.
        """
        if self.stall:
            return None
        self.live_trials[trial_id] = 1
        return {"test_variable": 2}

    def on_trial_result(self, trial_id: str, result: Dict):
        """Count the intermediate ``result`` and keep it in ``results``."""
        self.counter["result"] += 1
        self.results.append(result)

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Count the completion, record a truthy ``result``, drop the trial.

        ``error`` is accepted but not inspected by this mock.
        """
        self.counter["complete"] += 1
        if result:
            self._process_result(result)
        # Completing a trial that was never suggested is tolerated silently.
        self.live_trials.pop(trial_id, None)

    def _process_result(self, result: Dict):
        """Store ``result`` as a final result."""
        self.final_results.append(result)


class _MockSuggestionAlgorithm(SearchGenerator):
    """``SearchGenerator`` driven by a ``_MockSearcher``.

    Args:
        max_concurrent: When truthy, the mock searcher is wrapped in a
            ``ConcurrencyLimiter`` with this limit. ``None`` (the default)
            or ``0`` leaves the mock searcher unwrapped.
        **kwargs: Forwarded unchanged to ``_MockSearcher``.
    """

    def __init__(self, max_concurrent: Optional[int] = None, **kwargs):
        self.searcher = _MockSearcher(**kwargs)
        if max_concurrent:
            self.searcher = ConcurrencyLimiter(
                self.searcher, max_concurrent=max_concurrent
            )
        super().__init__(self.searcher)

    @property
    def live_trials(self) -> List[Trial]:
        """Return the ``live_trials`` attribute of the active searcher.

        NOTE(review): the unwrapped mock searcher stores a dict keyed by
        trial id here rather than a list of ``Trial`` objects; the return
        annotation is kept as-is for compatibility.
        """
        return self.searcher.live_trials

    @property
    def results(self) -> List[Dict]:
        """Return the intermediate results recorded by the mock searcher."""
        searcher = self.searcher
        # The limiter does not expose ``results`` itself, so read them from
        # the mock searcher it wraps.
        if isinstance(searcher, ConcurrencyLimiter):
            searcher = searcher.searcher
        return searcher.results
