"""Abstractions around GCP resources and nodes.

The logic has been abstracted away here to allow for different GCP resources
(API endpoints), which can differ widely, making it impossible to use
the same logic for everything.

Classes inheriting from ``GCPResource`` represent different GCP resources -
API endpoints that allow for nodes to be created, removed, listed and
otherwise managed. Those classes contain methods abstracting GCP REST API
calls.
Each resource has a corresponding node type, represented by a
class inheriting from ``GCPNode``. Those classes are essentially dicts
with some extra methods. The instances of those classes will be created
from API responses.

The ``GCPNodeType`` enum is a lightweight way to classify nodes.

Currently, Compute and TPU resources & nodes are supported.

In order to add support for new resources, create classes inheriting from
``GCPResource`` and ``GCPNode``, update the ``GCPNodeType`` enum,
update the ``_generate_node_name`` function and finally update the
node provider.
"""

import abc
import logging
import re
import time
from collections import UserDict
from copy import deepcopy
from enum import Enum
from functools import wraps
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import uuid4

import httplib2
from google_auth_httplib2 import AuthorizedHttp
from googleapiclient.discovery import Resource
from googleapiclient.errors import HttpError

from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME

logger = logging.getLogger(__name__)

# Length limits used when generating node names (see _generate_node_name).
INSTANCE_NAME_MAX_LEN = 64
INSTANCE_NAME_UUID_LEN = 8
# Default number of polls when waiting for an operation / retrying a call.
MAX_POLLS = 12
# TPUs take a long while to respond, so we increase the MAX_POLLS
# considerably - this probably could be smaller
# TPU deletion uses MAX_POLLS
MAX_POLLS_TPU = MAX_POLLS * 8
# Seconds to sleep between polls / retries.
POLL_INTERVAL = 5


def _retry_on_exception(
    exception: Union[type, Tuple[type, ...]],
    regex: Optional[str] = None,
    max_retries: int = MAX_POLLS,
    retry_interval_s: int = POLL_INTERVAL,
):
    """Retry a function call for as long as it raises a matching exception.

    Args:
        exception: Exception class (or tuple of classes) that triggers a
            retry. Any other exception propagates immediately.
        regex: If given, only exceptions whose ``str()`` matches this pattern
            (via ``re.search``) are retried; others propagate immediately.
        max_retries: Total number of attempts. Values below 1 are treated as
            a single attempt.
        retry_interval_s: Seconds to sleep between consecutive attempts.

    Returns:
        A decorator. The decorated function returns the first successful
        result (whatever its type), or re-raises the last matching exception
        once all attempts are exhausted.
    """

    def dec(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            attempts = max(1, max_retries)
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except exception as e:
                    if regex and not re.search(regex, str(e)):
                        raise
                    if attempt == attempts - 1:
                        raise
                    # Only sleep when another attempt will actually follow.
                    time.sleep(retry_interval_s)

        return wrapper

    return dec


def _generate_node_name(labels: dict, node_suffix: str) -> str:
    """Build a unique node name of the form ``<name>-<uuid>-<suffix>``.

    The trailing suffix encodes the resource type, so the correct GCP
    resource can be selected when the autoscaler only knows the node name
    (see ``GCPNodeType.name_to_type``). ``node_suffix`` is expected to be
    one of the ``GCPNodeType`` values ('compute' or 'tpu').

    Args:
        labels: Node labels; must contain ``TAG_RAY_NODE_NAME``.
        node_suffix: Resource-type suffix appended to the name.

    Returns:
        The generated node name.
    """
    base_name = labels[TAG_RAY_NODE_NAME]
    # NOTE(review): this budget reserves room for the UUID and a single
    # separator only; the second separator and ``node_suffix`` are not
    # counted, so the final name can exceed INSTANCE_NAME_MAX_LEN.
    max_base_len = INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1
    assert len(base_name) <= max_base_len, (base_name, len(base_name))
    unique_part = uuid4().hex[:INSTANCE_NAME_UUID_LEN]
    return f"{base_name}-{unique_part}-{node_suffix}"


class GCPNodeType(Enum):
    """Lightweight classification of GCP nodes (compute & tpu)."""

    COMPUTE = "compute"
    TPU = "tpu"

    @staticmethod
    def from_gcp_node(node: "GCPNode"):
        """Map a ``GCPNode`` instance to its ``GCPNodeType``.

        Raises:
            TypeError: If ``node`` is neither a TPU nor a Compute node.
        """
        if isinstance(node, GCPTPUNode):
            node_type = GCPNodeType.TPU
        elif isinstance(node, GCPComputeNode):
            node_type = GCPNodeType.COMPUTE
        else:
            raise TypeError(f"Wrong GCPNode type {type(node)}.")
        return node_type

    @staticmethod
    def name_to_type(name: str):
        """Infer the node type from a node name.

        Names are expected to look like '[NAME]-[UUID]-[TYPE]', where [TYPE]
        is 'compute' or 'tpu'; only the text after the last '-' is used.
        Raises ValueError if that text is not a ``GCPNodeType`` value.
        """
        _, _, type_suffix = name.rpartition("-")
        return GCPNodeType(type_suffix)


class GCPNode(UserDict, metaclass=abc.ABCMeta):
    """Dict-like wrapper around an API response describing a compute or tpu node.

    Subclasses define ``STATUS_FIELD`` (the response key holding the node
    status) together with the ``RUNNING_STATUSES`` and
    ``NON_TERMINATED_STATUSES`` sets used to interpret it.
    """

    NON_TERMINATED_STATUSES = None
    RUNNING_STATUSES = None
    STATUS_FIELD = None

    def __init__(self, base_dict: dict, resource: "GCPResource", **kwargs) -> None:
        super().__init__(base_dict, **kwargs)
        # Keep a reference to the GCPResource this node was built from.
        self.resource = resource
        assert isinstance(self.resource, GCPResource)

    def is_running(self) -> bool:
        """True if the node's status is one of ``RUNNING_STATUSES``."""
        status = self.get(self.STATUS_FIELD)
        return status in self.RUNNING_STATUSES

    def is_terminated(self) -> bool:
        """True if the node's status is not one of ``NON_TERMINATED_STATUSES``."""
        status = self.get(self.STATUS_FIELD)
        return status not in self.NON_TERMINATED_STATUSES

    @abc.abstractmethod
    def get_labels(self) -> dict:
        """Return the node's labels."""

    @abc.abstractmethod
    def get_external_ip(self) -> str:
        """Return the node's external IP address."""

    @abc.abstractmethod
    def get_internal_ip(self) -> str:
        """Return the node's internal IP address."""

    def __repr__(self) -> str:
        node_name = self.get("name")
        return f"<{type(self).__name__}: {node_name}>"


class GCPComputeNode(GCPNode):
    """GCPNode backed by a Compute Engine instance resource."""

    # https://cloud.google.com/compute/docs/instances/instance-life-cycle
    NON_TERMINATED_STATUSES = {"PROVISIONING", "STAGING", "RUNNING"}
    TERMINATED_STATUSES = {"TERMINATED", "SUSPENDED"}
    RUNNING_STATUSES = {"RUNNING"}
    STATUS_FIELD = "status"

    def get_labels(self) -> dict:
        """Return the instance labels ({} when none are set)."""
        return self.get("labels", {})

    def get_external_ip(self) -> str:
        """Return ``natIP`` of the first access config of the first NIC.

        Returns None when the first NIC has no ``accessConfigs`` entry or
        that entry has no ``natIP``.
        """
        first_nic = self.get("networkInterfaces", [{}])[0]
        first_access_config = first_nic.get("accessConfigs", [{}])[0]
        return first_access_config.get("natIP", None)

    def get_internal_ip(self) -> str:
        """Return ``networkIP`` of the first NIC (None if absent)."""
        first_nic = self.get("networkInterfaces", [{}])[0]
        return first_nic.get("networkIP")


class GCPTPUNode(GCPNode):
    """GCPNode backed by a Cloud TPU node resource.

    A TPU node may consist of several workers; each one is described by an
    entry of the ``networkEndpoints`` list in the API response.
    """

    # https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes#State

    NON_TERMINATED_STATUSES = {"CREATING", "STARTING", "RESTARTING", "READY"}
    RUNNING_STATUSES = {"READY"}
    STATUS_FIELD = "state"

    def get_labels(self) -> dict:
        """Return the node labels ({} when none are set)."""
        return self.get("labels", {})

    @property
    def num_workers(self) -> int:
        """Number of workers, i.e. entries in ``networkEndpoints``.

        Reports 1 when ``networkEndpoints`` is absent from the response.
        """
        return len(self.get("networkEndpoints", [{}]))

    def get_external_ips(self) -> List[Dict[str, Any]]:
        """Return the raw ``networkEndpoints`` entries, one dict per worker.

        Despite the name these are endpoint dicts, not IP strings; use
        ``get_external_ip(worker_index)`` to extract an address.
        """
        return self.get("networkEndpoints", [{}])

    def get_external_ip(self, worker_index: int = 0) -> Optional[str]:
        """Return ``accessConfig.externalIp`` of worker ``worker_index``.

        Returns None if that worker has no external IP. Raises IndexError if
        ``worker_index`` is out of range.
        """
        return (
            self.get_external_ips()[worker_index]
            .get("accessConfig", {})
            .get("externalIp", None)
        )

    def get_internal_ips(self) -> List[Dict[str, Any]]:
        """Return the raw ``networkEndpoints`` entries, one dict per worker.

        Despite the name these are endpoint dicts, not IP strings; use
        ``get_internal_ip(worker_index)`` to extract an address.
        """
        return self.get("networkEndpoints", [{}])

    def get_internal_ip(self, worker_index: int = 0) -> Optional[str]:
        """Return ``ipAddress`` of worker ``worker_index`` (None if absent).

        Raises IndexError if ``worker_index`` is out of range.
        """
        return self.get_internal_ips()[worker_index].get("ipAddress", None)


class GCPResource(metaclass=abc.ABCMeta):
    """Abstraction around compute and TPU resources.

    Holds a discovery-based API client (``resource``) together with the
    project, availability zone and Ray cluster name that implementations use
    to scope their API calls.
    """

    def __init__(
        self,
        resource: Resource,
        project_id: str,
        availability_zone: str,
        cluster_name: str,
    ) -> None:
        self.resource = resource
        self.project_id = project_id
        self.availability_zone = availability_zone
        self.cluster_name = cluster_name

    @abc.abstractmethod
    def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
        """Generate a new AuthorizedHttp object with the given credentials."""
        return

    @abc.abstractmethod
    def wait_for_operation(
        self,
        operation: dict,
        max_polls: int = MAX_POLLS,
        poll_interval: int = POLL_INTERVAL,
    ) -> dict:
        """Poll ``operation`` until it completes.

        Polls at most ``max_polls`` times, sleeping ``poll_interval`` seconds
        between polls, and returns the last operation status fetched.
        """
        return None

    @abc.abstractmethod
    def list_instances(
        self,
        label_filters: Optional[dict] = None,
        is_terminated: bool = False,
    ) -> List["GCPNode"]:
        """Returns a filtered list of this cluster's instances.

        By default, terminated instances are removed. When ``is_terminated``
        is True, terminated instances are listed instead (not every
        implementation honours this flag). If ``label_filters`` are
        provided, all instances whose labels do not match them are removed
        as well.
        """
        return

    @abc.abstractmethod
    def get_instance(self, node_id: str) -> "GCPNode":
        """Returns a single instance identified by ``node_id``."""
        return

    @abc.abstractmethod
    def set_labels(
        self, node: GCPNode, labels: dict, wait_for_operation: bool = True
    ) -> dict:
        """Sets labels on an instance and returns result.

        ``labels`` are merged into the node's existing labels (new values
        win) and the merged dictionary then replaces the instance's labels.
        """
        return

    @abc.abstractmethod
    def create_instance(
        self, base_config: dict, labels: dict, wait_for_operation: bool = True
    ) -> Tuple[dict, str]:
        """Creates a single instance and returns result.

        Returns a tuple of (result, node_name), where ``result`` is the
        finished operation if ``wait_for_operation`` is True and the pending
        operation otherwise.
        """
        return

    def create_instances(
        self,
        base_config: dict,
        labels: dict,
        count: int,
        wait_for_operation: bool = True,
    ) -> List[Tuple[dict, str]]:
        """Creates multiple instances and returns result.

        All ``count`` creation requests are issued first, without waiting;
        only then, if ``wait_for_operation`` is True, is each operation
        waited on in turn.

        Returns a list of tuples of (result, node_name).
        """
        operations = [
            self.create_instance(base_config, labels, wait_for_operation=False)
            for _ in range(count)
        ]

        if wait_for_operation:
            results = [
                (self.wait_for_operation(operation), node_name)
                for operation, node_name in operations
            ]
        else:
            results = operations

        return results

    @abc.abstractmethod
    def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Deletes an instance and returns result."""
        return

    @abc.abstractmethod
    def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Stops an instance and returns result."""
        return

    @abc.abstractmethod
    def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Starts a single instance and returns result."""
        return


class GCPCompute(GCPResource):
    """Abstraction around GCP compute resource.

    All calls target ``self.project_id`` and the zone
    ``self.availability_zone``. Every request is executed with a freshly
    built ``AuthorizedHttp`` (see ``get_new_authorized_http``).
    """

    def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
        """Generate a new AuthorizedHttp object with the given credentials.

        The credentials of ``http`` are reused, but the ``httplib2.Http``
        transport is new: ``httplib2.Http`` objects are not thread-safe, so
        requests should not share one.
        """
        new_http = AuthorizedHttp(http.credentials, http=httplib2.Http())
        return new_http

    def wait_for_operation(
        self,
        operation: dict,
        max_polls: int = MAX_POLLS,
        poll_interval: int = POLL_INTERVAL,
    ) -> dict:
        """Poll for compute zone operation until finished.

        Args:
            operation: Operation dict returned by a mutating API call; its
                ``name`` is used to look the operation up.
            max_polls: Maximum number of status requests.
            poll_interval: Seconds to sleep between status requests.

        Returns:
            The last operation status fetched. If the operation is not
            ``DONE`` after ``max_polls`` polls, a warning is logged and the
            unfinished status is returned. If ``max_polls`` is not positive,
            ``operation`` itself is returned.

        Raises:
            Exception: If a fetched status contains an ``error``.
        """
        logger.info(
            "wait_for_compute_zone_operation: "
            f"Waiting for operation {operation['name']} to finish..."
        )

        result = operation
        for _ in range(max_polls):
            result = (
                self.resource.zoneOperations()
                .get(
                    project=self.project_id,
                    operation=operation["name"],
                    zone=self.availability_zone,
                )
                .execute(http=self.get_new_authorized_http(self.resource._http))
            )
            if "error" in result:
                raise Exception(result["error"])

            if result["status"] == "DONE":
                logger.info(
                    "wait_for_compute_zone_operation: "
                    f"Operation {operation['name']} finished."
                )
                break

            time.sleep(poll_interval)
        else:
            # Previously a timeout was silent; make it visible in the logs.
            logger.warning(
                "wait_for_compute_zone_operation: "
                f"Operation {operation['name']} did not finish after "
                f"{max_polls} polls."
            )

        return result

    def list_instances(
        self,
        label_filters: Optional[dict] = None,
        is_terminated: bool = False,
    ) -> List[GCPComputeNode]:
        """List this cluster's compute instances matching the given filters.

        Args:
            label_filters: Optional mapping of label key -> value; an
                instance must match all of them.
            is_terminated: If True, list instances whose status is in
                ``GCPComputeNode.TERMINATED_STATUSES``; otherwise those in
                ``GCPComputeNode.NON_TERMINATED_STATUSES``.

        Returns:
            Matching instances wrapped as ``GCPComputeNode``. Instances with
            a ``tpu_cores`` label are excluded. Only the items of the first
            response page are read (``nextPageToken`` is not followed).
        """
        label_filters = label_filters or {}

        if label_filters:
            label_filter_expr = (
                "("
                + " AND ".join(
                    f"(labels.{key} = {value})" for key, value in label_filters.items()
                )
                + ")"
            )
        else:
            label_filter_expr = ""

        statuses = (
            GCPComputeNode.TERMINATED_STATUSES
            if is_terminated
            else GCPComputeNode.NON_TERMINATED_STATUSES
        )

        instance_state_filter_expr = (
            "(" + " OR ".join(f"(status = {status})" for status in statuses) + ")"
        )

        cluster_name_filter_expr = (
            f"(labels.{TAG_RAY_CLUSTER_NAME} = {self.cluster_name})"
        )

        # TPU VMs spawn accompanying Compute Instances that must be filtered out,
        # else this results in duplicated nodes.
        tpu_negation_filter_expr = "(NOT labels.tpu_cores:*)"

        not_empty_filters = [
            f
            for f in [
                label_filter_expr,
                instance_state_filter_expr,
                cluster_name_filter_expr,
                tpu_negation_filter_expr,
            ]
            if f
        ]

        filter_expr = " AND ".join(not_empty_filters)

        response = (
            self.resource.instances()
            .list(
                project=self.project_id,
                zone=self.availability_zone,
                filter=filter_expr,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        instances = response.get("items", [])
        return [GCPComputeNode(i, self) for i in instances]

    def get_instance(self, node_id: str) -> GCPComputeNode:
        """Fetch a single instance by name and wrap it as ``GCPComputeNode``."""
        instance = (
            self.resource.instances()
            .get(
                project=self.project_id,
                zone=self.availability_zone,
                instance=node_id,
            )
            # Fresh transport, consistent with every other request here.
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        return GCPComputeNode(instance, self)

    def set_labels(
        self, node: GCPComputeNode, labels: dict, wait_for_operation: bool = True
    ) -> dict:
        """Merge ``labels`` into the instance's labels and apply them.

        The node's ``labelFingerprint`` is sent along; the API uses it to
        detect concurrent label changes.

        Returns:
            The finished operation if ``wait_for_operation`` is True,
            otherwise the pending operation.
        """
        body = {
            # A node without a "labels" key previously raised KeyError.
            "labels": dict(node.get("labels", {}), **labels),
            "labelFingerprint": node["labelFingerprint"],
        }
        node_id = node["name"]
        operation = (
            self.resource.instances()
            .setLabels(
                project=self.project_id,
                zone=self.availability_zone,
                instance=node_id,
                body=body,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation

        return result

    def _convert_resources_to_urls(
        self, configuration_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Ensures that resources are in their full URL form.

        GCP expects machineType and acceleratorType to be a full URL (e.g.
        `zones/us-west1-a/machineTypes/n1-standard-2`) instead of just the
        type (`n1-standard-2`).

        Args:
            configuration_dict: Dict of options that will be passed to GCP.
                It is not modified; a deep copy is returned.
        Returns:
            A copy of the input dictionary, with `machineType` (if present)
                and each `guestAccelerators[*].acceleratorType` expanded.
        """
        configuration_dict = deepcopy(configuration_dict)
        existing_machine_type = configuration_dict.get("machineType")
        # machineType may be omitted when a sourceInstanceTemplate provides
        # it; previously that raised KeyError.
        if existing_machine_type is not None and not re.search(
            r".*/machineTypes/.*", existing_machine_type
        ):
            configuration_dict["machineType"] = (
                f"zones/{self.availability_zone}/machineTypes/{existing_machine_type}"
            )

        for accelerator in configuration_dict.get("guestAccelerators", []):
            gpu_type = accelerator["acceleratorType"]
            if not re.search(r".*/acceleratorTypes/.*", gpu_type):
                accelerator["acceleratorType"] = (
                    f"projects/{self.project_id}/zones/{self.availability_zone}"
                    f"/acceleratorTypes/{gpu_type}"
                )

        return configuration_dict

    def create_instance(
        self, base_config: dict, labels: dict, wait_for_operation: bool = True
    ) -> Tuple[dict, str]:
        """Create one compute instance from ``base_config``.

        ``base_config`` is not modified. ``labels`` are merged over the
        config's labels and the cluster-name label is always added.

        Returns:
            A tuple of (result, node_name), where ``result`` is the finished
            operation if ``wait_for_operation`` is True and the pending
            operation otherwise.
        """
        config = self._convert_resources_to_urls(base_config)
        # removing TPU-specific default key set in config.py
        config.pop("networkConfig", None)
        name = _generate_node_name(labels, GCPNodeType.COMPUTE.value)

        labels = dict(config.get("labels", {}), **labels)

        config.update(
            {
                "labels": dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name}),
                "name": name,
            }
        )

        # Allow Google Compute Engine instance templates.
        #
        # Config example:
        #
        #     ...
        #     node_config:
        #         sourceInstanceTemplate: global/instanceTemplates/worker-16
        #         machineType: e2-standard-16
        #     ...
        #
        # node_config parameters override matching template parameters, if any.
        #
        # https://cloud.google.com/compute/docs/instance-templates
        # https://cloud.google.com/compute/docs/reference/rest/v1/instances/insert
        source_instance_template = config.pop("sourceInstanceTemplate", None)

        operation = (
            self.resource.instances()
            .insert(
                project=self.project_id,
                zone=self.availability_zone,
                sourceInstanceTemplate=source_instance_template,
                body=config,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation

        return result, name

    def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Delete instance ``node_id``; returns the (finished) operation."""
        operation = (
            self.resource.instances()
            .delete(
                project=self.project_id,
                zone=self.availability_zone,
                instance=node_id,
            )
            # Fresh transport, consistent with every other request here.
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation

        return result

    def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Stop instance ``node_id``; returns the (finished) operation."""
        operation = (
            self.resource.instances()
            .stop(
                project=self.project_id,
                zone=self.availability_zone,
                instance=node_id,
            )
            # Fresh transport, consistent with every other request here.
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation
        return result

    def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Start instance ``node_id``; returns the (finished) operation."""
        operation = (
            self.resource.instances()
            .start(
                project=self.project_id,
                zone=self.availability_zone,
                instance=node_id,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation
        return result


class GCPTPU(GCPResource):
    """Abstraction around GCP TPU resource.

    Every request is executed with a freshly built ``AuthorizedHttp`` (see
    ``get_new_authorized_http``).
    """

    # node names already contain the path, but this is required for `parent`
    # arguments
    @property
    def path(self):
        return f"projects/{self.project_id}/locations/{self.availability_zone}"

    def get_new_authorized_http(self, http: AuthorizedHttp) -> AuthorizedHttp:
        """Generate a new AuthorizedHttp object with the given credentials.

        The credentials of ``http`` are reused, but the ``httplib2.Http``
        transport is new: ``httplib2.Http`` objects are not thread-safe, so
        requests should not share one.
        """
        new_http = AuthorizedHttp(http.credentials, http=httplib2.Http())
        return new_http

    def wait_for_operation(
        self,
        operation: dict,
        max_polls: int = MAX_POLLS_TPU,
        poll_interval: int = POLL_INTERVAL,
    ) -> dict:
        """Poll for TPU operation until finished.

        The operation counts as finished once its status contains a
        ``response``.

        Args:
            operation: Operation dict returned by a mutating API call; its
                ``name`` is used to look the operation up.
            max_polls: Maximum number of status requests.
            poll_interval: Seconds to sleep between status requests.

        Returns:
            The last operation status fetched. If the operation has not
            finished after ``max_polls`` polls, a warning is logged and the
            unfinished status is returned. If ``max_polls`` is not positive,
            ``operation`` itself is returned.

        Raises:
            Exception: If a fetched status contains an ``error``.
        """
        logger.info(
            "wait_for_tpu_operation: "
            f"Waiting for operation {operation['name']} to finish..."
        )

        result = operation
        for _ in range(max_polls):
            result = (
                self.resource.projects()
                .locations()
                .operations()
                .get(name=f"{operation['name']}")
                .execute(http=self.get_new_authorized_http(self.resource._http))
            )
            if "error" in result:
                raise Exception(result["error"])

            if "response" in result:
                logger.info(
                    "wait_for_tpu_operation: "
                    f"Operation {operation['name']} finished."
                )
                break

            time.sleep(poll_interval)
        else:
            # Previously a timeout was silent; make it visible in the logs.
            logger.warning(
                "wait_for_tpu_operation: "
                f"Operation {operation['name']} did not finish after "
                f"{max_polls} polls."
            )

        return result

    def list_instances(
        self,
        label_filters: Optional[dict] = None,
        is_terminated: bool = False,
    ) -> List[GCPTPUNode]:
        """List this cluster's non-terminated TPU nodes matching the filters.

        Args:
            label_filters: Optional mapping of label key -> value; a node
                must match all of them. The caller's dict is not modified.
            is_terminated: NOTE(review): accepted for interface compatibility
                but ignored; terminated TPU nodes are always filtered out.

        Returns:
            Matching nodes wrapped as ``GCPTPUNode``. Only the nodes of the
            first response page are read (``nextPageToken`` is not followed).
        """
        response = (
            self.resource.projects()
            .locations()
            .nodes()
            .list(parent=self.path)
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        instances = response.get("nodes", [])
        instances = [GCPTPUNode(i, self) for i in instances]

        # filter_expr cannot be passed directly to API
        # so we need to filter the results ourselves

        # same logic as in GCPCompute.list_instances.
        # Copy the filters: adding the cluster-name key used to mutate the
        # caller's dict.
        required_labels = dict(label_filters or {})
        required_labels[TAG_RAY_CLUSTER_NAME] = self.cluster_name

        def filter_instance(instance: GCPTPUNode) -> bool:
            if instance.is_terminated():
                return False

            labels = instance.get_labels()
            return all(
                key in labels and labels[key] == value
                for key, value in required_labels.items()
            )

        return [instance for instance in instances if filter_instance(instance)]

    def get_instance(self, node_id: str) -> GCPTPUNode:
        """Fetch a single TPU node by full resource name."""
        instance = (
            self.resource.projects()
            .locations()
            .nodes()
            .get(name=node_id)
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        return GCPTPUNode(instance, self)

    # this sometimes fails without a clear reason, so we retry it
    # MAX_POLLS times
    @_retry_on_exception(HttpError, "unable to queue the operation")
    def set_labels(
        self, node: GCPTPUNode, labels: dict, wait_for_operation: bool = True
    ) -> dict:
        """Merge ``labels`` into the node's labels and patch the node.

        Returns:
            The finished operation if ``wait_for_operation`` is True,
            otherwise the pending operation.
        """
        body = {
            # A node without a "labels" key previously raised KeyError.
            "labels": dict(node.get("labels", {}), **labels),
        }
        update_mask = "labels"

        operation = (
            self.resource.projects()
            .locations()
            .nodes()
            .patch(
                name=node["name"],
                updateMask=update_mask,
                body=body,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation

        return result

    def create_instance(
        self, base_config: dict, labels: dict, wait_for_operation: bool = True
    ) -> Tuple[dict, str]:
        """Create one TPU node from ``base_config``.

        ``base_config`` is not modified. ``labels`` are merged over the
        config's labels and the cluster-name label is always added.

        Returns:
            A tuple of (result, node_name), where ``result`` is the finished
            operation if ``wait_for_operation`` is True and the pending
            operation otherwise.
        """
        # Deep copy: the code below mutates nested dicts (networkConfig and
        # the service account). With a shallow copy those edits leaked into
        # base_config, and a second call with the same base_config (e.g. via
        # create_instances) failed with KeyError on "scopes".
        config = deepcopy(base_config)
        # removing Compute-specific default key set in config.py
        config.pop("networkInterfaces", None)
        name = _generate_node_name(labels, GCPNodeType.TPU.value)

        labels = dict(config.get("labels", {}), **labels)
        config["labels"] = dict(labels, **{TAG_RAY_CLUSTER_NAME: self.cluster_name})

        network_config = config.setdefault("networkConfig", {})
        if "enableExternalIps" not in network_config:
            # this is required for SSH to work, per google documentation
            # https://cloud.google.com/tpu/docs/users-guide-tpu-vm#create-curl
            network_config["enableExternalIps"] = True

        # replace serviceAccounts with serviceAccount, and scopes with scope
        # this is necessary for the head node to work
        # see here: https://tpu.googleapis.com/$discovery/rest?version=v2alpha1
        if "serviceAccounts" in config:
            service_accounts = config.pop("serviceAccounts")
            if service_accounts:
                service_account = service_accounts[0]
                if "scopes" in service_account:
                    service_account["scope"] = service_account.pop("scopes")
                config["serviceAccount"] = service_account

        operation = (
            self.resource.projects()
            .locations()
            .nodes()
            .create(
                parent=self.path,
                body=config,
                nodeId=name,
            )
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation)
        else:
            result = operation

        return result, name

    def delete_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Delete TPU node ``node_id``; returns the (finished) operation."""
        operation = (
            self.resource.projects()
            .locations()
            .nodes()
            .delete(name=node_id)
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        # No need to increase MAX_POLLS for deletion
        if wait_for_operation:
            result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
        else:
            result = operation

        return result

    def stop_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Stop TPU node ``node_id``; returns the (finished) operation."""
        operation = (
            self.resource.projects()
            .locations()
            .nodes()
            .stop(name=node_id)
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
        else:
            result = operation

        return result

    def start_instance(self, node_id: str, wait_for_operation: bool = True) -> dict:
        """Start TPU node ``node_id``; returns the (finished) operation."""
        operation = (
            self.resource.projects()
            .locations()
            .nodes()
            .start(name=node_id)
            .execute(http=self.get_new_authorized_http(self.resource._http))
        )

        if wait_for_operation:
            result = self.wait_for_operation(operation, max_polls=MAX_POLLS)
        else:
            result = operation

        return result
