# Runs several scenarios with varying max batch size, max concurrent queries,
# number of replicas, and with intermediate serve handles (to simulate ensemble
# models) either on or off.

import asyncio
import logging
from pprint import pprint
from typing import Dict, Union

import aiohttp
from starlette.requests import Request

import ray
from ray import serve
from ray.serve._private.benchmarks.common import run_throughput_benchmark
from ray.serve.handle import DeploymentHandle

# Number of Client actors created for the multi-client phase of each trial.
NUM_CLIENTS = 8
# Number of sequential requests a client sends per invocation of the
# benchmarked function; also used as the throughput multiplier.
CALLS_PER_BATCH = 100


async def fetch(session, data):
    """Send one GET request to the local HTTP endpoint and verify the reply.

    Args:
        session: Object whose ``get(url, data=...)`` returns an async context
            manager yielding a response with an awaitable ``text()``
            (an ``aiohttp.ClientSession`` at every call site in this file).
        data: Request body, passed through unchanged as ``data=``.

    Raises:
        AssertionError: If the response text is not exactly ``"ok"``; the
            unexpected text is attached as the assertion message.
    """
    pending = session.get("http://localhost:8000/", data=data)
    async with pending as reply:
        body = await reply.text()
        assert body == "ok", body


@ray.remote
class Client:
    """Ray actor that issues HTTP requests against the Serve endpoint."""

    def ready(self):
        """Return ``"ok"``; the benchmark calls this to wait for the actor."""
        return "ok"

    async def do_queries(self, num, data):
        """Send ``num`` requests one after another, all on a single session.

        Args:
            num: How many requests to send.
            data: Request body handed to ``fetch`` for every request.
        """
        async with aiohttp.ClientSession() as http:
            for _ in range(num):
                await fetch(http, data)


def build_app(
    intermediate_handles: bool,
    num_replicas: int,
    max_batch_size: int,
    max_ongoing_requests: int,
):
    """Assemble the Serve application for one benchmark configuration.

    Args:
        intermediate_handles: When true, an ``Upstream`` deployment receives
            the HTTP request and forwards its body to ``Downstream`` through a
            deployment handle; when false, ``Downstream`` is bound directly.
        num_replicas: ``num_replicas`` for the ``Downstream`` deployment.
        max_batch_size: ``max_batch_size`` for ``serve.batch``; the batched
            path is only taken when this is greater than 1.
        max_ongoing_requests: ``max_ongoing_requests`` for ``Downstream``.

    Returns:
        The bound application (``Upstream`` wrapping ``Downstream``, or
        ``Downstream`` alone).
    """

    def _mute_access_log():
        # Turn off access log.
        logging.getLogger("ray.serve").setLevel(logging.WARNING)

    @serve.deployment(
        num_replicas=num_replicas,
        max_ongoing_requests=max_ongoing_requests,
    )
    class Downstream:
        def __init__(self):
            _mute_access_log()

        @serve.batch(max_batch_size=max_batch_size)
        async def batch(self, reqs):
            # One constant reply per request in the batch.
            return [b"ok" for _ in reqs]

        async def __call__(self, req: Union[bytes, Request]):
            # With a batch size of 1 the batching path is bypassed entirely.
            if max_batch_size <= 1:
                return b"ok"
            return await self.batch(req)

    @serve.deployment(max_ongoing_requests=1000)
    class Upstream:
        def __init__(self, handle: DeploymentHandle):
            _mute_access_log()
            self._handle = handle

        async def __call__(self, req: Request):
            # Only the raw body is forwarded, not the Request object.
            payload = await req.body()
            return await self._handle.remote(payload)

    downstream_app = Downstream.bind()
    return Upstream.bind(downstream_app) if intermediate_handles else downstream_app


async def trial(
    intermediate_handles: bool,
    num_replicas: int,
    max_batch_size: int,
    max_ongoing_requests: int,
    data_size: str,
) -> Dict[str, float]:
    """Deploy one configuration and measure single- and multi-client throughput.

    Args:
        intermediate_handles: Forwarded to ``build_app``.
        num_replicas: Forwarded to ``build_app``.
        max_batch_size: Forwarded to ``build_app``.
        max_ongoing_requests: Forwarded to ``build_app``.
        data_size: ``"small"`` (no request body) or ``"large"`` (1 MiB body).

    Returns:
        A dict with two entries, keyed by ``num_client:<n>/<configuration>``,
        holding the first value returned by ``run_throughput_benchmark`` for
        the single-client and multi-client runs respectively.

    Raises:
        ValueError: If ``data_size`` is neither ``"small"`` nor ``"large"``.
    """
    # Validate the argument before deploying anything, so a bad value fails
    # fast instead of after the application has already been started.
    if data_size == "small":
        data = None
    elif data_size == "large":
        data = b"a" * 1024 * 1024
    else:
        raise ValueError("data_size should be 'small' or 'large'.")

    results = {}

    trial_key_base = (
        f"replica:{num_replicas}/batch_size:{max_batch_size}/"
        f"concurrent_queries:{max_ongoing_requests}/"
        f"data_size:{data_size}/intermediate_handle:{intermediate_handles}"
    )

    print(
        f"intermediate_handles={intermediate_handles},"
        f"num_replicas={num_replicas},"
        f"max_batch_size={max_batch_size},"
        f"max_ongoing_requests={max_ongoing_requests},"
        f"data_size={data_size}"
    )

    app = build_app(
        intermediate_handles, num_replicas, max_batch_size, max_ongoing_requests
    )
    serve.run(app)

    # Phase 1: one client in this process, reusing a single HTTP session.
    async with aiohttp.ClientSession() as session:

        async def single_client():
            for _ in range(CALLS_PER_BATCH):
                await fetch(session, data)

        single_client_avg_tps, single_client_std_tps = await run_throughput_benchmark(
            single_client,
            multiplier=CALLS_PER_BATCH,
        )
        print(
            f"\tsingle client {data_size} data "
            f"{single_client_avg_tps} +- {single_client_std_tps} requests/s"
        )
        results[f"num_client:1/{trial_key_base}"] = single_client_avg_tps

    # Phase 2: NUM_CLIENTS Client actors sending requests at the same time.
    clients = [Client.remote() for _ in range(NUM_CLIENTS)]
    # ObjectRefs are awaitable; awaiting them (rather than calling the
    # blocking ray.get) keeps the event loop free while the actors start.
    await asyncio.gather(*[client.ready.remote() for client in clients])

    async def many_clients():
        # Await the refs instead of blocking inside a coroutine with ray.get.
        await asyncio.gather(
            *[client.do_queries.remote(CALLS_PER_BATCH, data) for client in clients]
        )

    multi_client_avg_tps, _ = await run_throughput_benchmark(
        many_clients,
        multiplier=CALLS_PER_BATCH * len(clients),
    )

    results[f"num_client:{len(clients)}/{trial_key_base}"] = multi_client_avg_tps
    return results


async def main():
    """Run every benchmark configuration in turn and collect the results.

    Returns:
        The merged dict of all entries returned by the individual ``trial``
        calls; it is also pretty-printed before returning.
    """
    # TODO(edoakes): large data causes broken pipe errors.
    data_sizes = ["small"]
    batch_and_concurrency = [
        (1, 1),
        (1, 10000),
        (10000, 10000),
    ]

    # Flattened in the same order nested loops would visit them:
    # handles -> replicas -> (batch size, ongoing requests) -> data size.
    configurations = [
        (
            intermediate_handles,
            num_replicas,
            max_batch_size,
            max_ongoing_requests,
            data_size,
        )
        for intermediate_handles in [False, True]
        for num_replicas in [1, 8]
        for max_batch_size, max_ongoing_requests in batch_and_concurrency
        for data_size in data_sizes
    ]

    results = {}
    for configuration in configurations:
        trial_results = await trial(*configuration)
        results.update(trial_results)

    print("Results from all conditions:")
    pprint(results)
    return results


if __name__ == "__main__":
    # Connect to (or start) Ray and bring up Serve before running any trial.
    ray.init()
    serve.start()
    # asyncio.run owns the event loop's lifecycle: it creates a fresh loop,
    # runs main() to completion, and closes the loop afterwards, including
    # when main() raises.
    asyncio.run(main())
