from __future__ import annotations

from functools import partial
import re
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Self,
)

import numpy as np

from pandas._libs import lib
from pandas.compat import (
    HAS_PYARROW,
    pa_version_under17p0,
    pa_version_under21p0,
)

if HAS_PYARROW:
    import pyarrow as pa
    import pyarrow.compute as pc

if TYPE_CHECKING:
    from collections.abc import Callable

    from pandas._typing import Scalar


class ArrowStringArrayMixin:
    _pa_array: pa.ChunkedArray

    def __init__(self, *args, **kwargs) -> None:
        raise NotImplementedError

    def _from_pyarrow_array(self, pa_array) -> Self:
        raise NotImplementedError

    def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
        # Convert a bool-dtype result to the appropriate result type
        raise NotImplementedError

    def _convert_int_result(self, result):
        # Convert an integer-dtype result to the appropriate result type
        raise NotImplementedError

    def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
        raise NotImplementedError

    @staticmethod
    def _has_unsupported_regex(pat: str | re.Pattern) -> bool:
        """
        Determine if regex pattern contains features not supported by RE2 / pyarrow.

        This includes lookaround (lookahead or lookbehind) assertions and
        backreferences.

        Parameters
        ----------
        pat: str | re.Pattern
            Regex pattern.

        Returns
        -------
        bool
            Whether `pat` contains a lookahead or lookbehind.
        """
        try:
            # error: Module "re" has no attribute "_parser"
            from re import _parser  # type: ignore[attr-defined]

            regex_parser = _parser.parse
        except Exception as err:
            raise type(err)(
                "Incompatible version for regex; you will need to upgrade pandas "
                "or downgrade Python"
            ) from err

        def has_unsupported_code(tokens):
            # For certain op codes we need to recurse.
            for op_code, argument in tokens:
                if (
                    (
                        op_code == _parser.SUBPATTERN
                        and has_unsupported_code(argument[3])
                    )
                    or (
                        op_code == _parser.BRANCH
                        and any(has_unsupported_code(tokens) for tokens in argument[1])
                    )
                    or (
                        op_code
                        in [_parser.ASSERT_NOT, _parser.ASSERT, _parser.GROUPREF]
                    )
                ):
                    return True
            return False

        str_pat = pat.pattern if isinstance(pat, re.Pattern) else pat
        tokens = regex_parser(str_pat)
        return has_unsupported_code(tokens)

    def _str_len(self):
        result = pc.utf8_length(self._pa_array)
        return self._convert_int_result(result)

    def _str_lower(self) -> Self:
        return self._from_pyarrow_array(pc.utf8_lower(self._pa_array))

    def _str_upper(self) -> Self:
        return self._from_pyarrow_array(pc.utf8_upper(self._pa_array))

    def _str_strip(self, to_strip=None) -> Self:
        if to_strip is None:
            result = pc.utf8_trim_whitespace(self._pa_array)
        else:
            result = pc.utf8_trim(self._pa_array, characters=to_strip)
        return self._from_pyarrow_array(result)

    def _str_lstrip(self, to_strip=None) -> Self:
        if to_strip is None:
            result = pc.utf8_ltrim_whitespace(self._pa_array)
        else:
            result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
        return self._from_pyarrow_array(result)

    def _str_rstrip(self, to_strip=None) -> Self:
        if to_strip is None:
            result = pc.utf8_rtrim_whitespace(self._pa_array)
        else:
            result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
        return self._from_pyarrow_array(result)

    def _str_pad(
        self,
        width: int,
        side: Literal["left", "right", "both"] = "left",
        fillchar: str = " ",
    ) -> Self:
        if side == "left":
            pa_pad = pc.utf8_lpad
        elif side == "right":
            pa_pad = pc.utf8_rpad
        elif side == "both":
            if pa_version_under17p0:
                # GH#59624 fall back to object dtype
                from pandas import array

                obj_arr = self.astype(object, copy=False)  # type: ignore[attr-defined]
                obj = array(obj_arr, dtype=object)
                result = obj._str_pad(width, side, fillchar)  # type: ignore[attr-defined]
                return type(self)._from_sequence(result, dtype=self.dtype)  # type: ignore[attr-defined]
            else:
                # GH#54792
                # https://github.com/apache/arrow/issues/15053#issuecomment-2317032347
                lean_left = (width % 2) == 0
                pa_pad = partial(pc.utf8_center, lean_left_on_odd_padding=lean_left)
        else:
            raise ValueError(
                f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
            )
        return self._from_pyarrow_array(
            pa_pad(self._pa_array, width=width, padding=fillchar)
        )

    def _str_get(self, i: int) -> Self:
        lengths = pc.utf8_length(self._pa_array)
        if i >= 0:
            out_of_bounds = pc.greater_equal(i, lengths)
            start = i
            stop = i + 1
            step = 1
        else:
            out_of_bounds = pc.greater(-i, lengths)
            start = i
            stop = i - 1
            step = -1
        not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True))
        selected = pc.utf8_slice_codeunits(
            self._pa_array, start=start, stop=stop, step=step
        )
        null_value = pa.scalar(None, type=self._pa_array.type)
        result = pc.if_else(not_out_of_bounds, selected, null_value)
        return self._from_pyarrow_array(result)

    def _str_slice(
        self, start: int | None = None, stop: int | None = None, step: int | None = None
    ) -> Self:
        if start is None:
            if step is not None and step < 0:
                # GH#59710
                start = -1
            else:
                start = 0
        if step is None:
            step = 1
        return self._from_pyarrow_array(
            pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step)
        )

    def _str_slice_replace(
        self, start: int | None = None, stop: int | None = None, repl: str | None = None
    ) -> Self:
        if repl is None:
            repl = ""
        if start is None:
            start = 0
        if stop is None:
            stop = np.iinfo(np.int64).max
        return self._from_pyarrow_array(
            pc.utf8_replace_slice(self._pa_array, start, stop, repl)
        )

    def _str_replace(
        self,
        pat: str | re.Pattern,
        repl: str | Callable,
        n: int = -1,
        case: bool = True,
        flags: int = 0,
        regex: bool = True,
    ) -> Self:
        if (
            isinstance(pat, re.Pattern)
            or callable(repl)
            or not case
            or flags
            or (isinstance(repl, str) and r"\g<" in repl)
        ):
            raise NotImplementedError(
                "replace is not supported with a re.Pattern, callable repl, "
                "case=False, flags!=0, or when the replacement string contains "
                "named group references (\\g<...>)"
            )

        func = pc.replace_substring_regex if regex else pc.replace_substring
        # https://github.com/apache/arrow/issues/39149
        # GH 56404, unexpected behavior with negative max_replacements with pyarrow.
        pa_max_replacements = None if n < 0 else n
        result = func(
            self._pa_array,
            pattern=pat,
            replacement=repl,
            max_replacements=pa_max_replacements,
        )
        return self._from_pyarrow_array(result)

    def _str_capitalize(self) -> Self:
        return self._from_pyarrow_array(pc.utf8_capitalize(self._pa_array))

    def _str_title(self) -> Self:
        return self._from_pyarrow_array(pc.utf8_title(self._pa_array))

    def _str_swapcase(self) -> Self:
        return self._from_pyarrow_array(pc.utf8_swapcase(self._pa_array))

    def _str_removeprefix(self, prefix: str):
        if prefix == "":
            return self._from_pyarrow_array(self._pa_array)
        starts_with = pc.starts_with(self._pa_array, pattern=prefix)
        removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
        result = pc.if_else(starts_with, removed, self._pa_array)
        return self._from_pyarrow_array(result)

    def _str_removesuffix(self, suffix: str):
        if suffix == "":
            return self._from_pyarrow_array(self._pa_array)
        ends_with = pc.ends_with(self._pa_array, pattern=suffix)
        removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
        result = pc.if_else(ends_with, removed, self._pa_array)
        return self._from_pyarrow_array(result)

    def _str_startswith(
        self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
    ):
        if isinstance(pat, str):
            result = pc.starts_with(self._pa_array, pattern=pat)
        elif len(pat) == 0:
            # For empty tuple we return null for missing values and False
            #  for valid values.
            result = pc.if_else(pc.is_null(self._pa_array), None, False)
        else:
            result = pc.starts_with(self._pa_array, pattern=pat[0])

            for p in pat[1:]:
                result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
        return self._convert_bool_result(result, na=na, method_name="startswith")

    def _str_endswith(
        self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
    ):
        if isinstance(pat, str):
            result = pc.ends_with(self._pa_array, pattern=pat)
        elif len(pat) == 0:
            # For empty tuple we return null for missing values and False
            #  for valid values.
            result = pc.if_else(pc.is_null(self._pa_array), None, False)
        else:
            result = pc.ends_with(self._pa_array, pattern=pat[0])

            for p in pat[1:]:
                result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
        return self._convert_bool_result(result, na=na, method_name="endswith")

    def _str_isalnum(self):
        result = pc.utf8_is_alnum(self._pa_array)
        return self._convert_bool_result(result)

    def _str_isalpha(self):
        result = pc.utf8_is_alpha(self._pa_array)
        return self._convert_bool_result(result)

    def _str_isascii(self):
        result = pc.string_is_ascii(self._pa_array)
        return self._convert_bool_result(result)

    def _str_isdecimal(self):
        result = pc.utf8_is_decimal(self._pa_array)
        return self._convert_bool_result(result)

    def _str_isdigit(self):
        if pa_version_under21p0:
            # https://github.com/pandas-dev/pandas/issues/61466
            res_list = self._apply_elementwise(str.isdigit)
            return self._convert_bool_result(
                pa.chunked_array(res_list, type=pa.bool_())
            )
        result = pc.utf8_is_digit(self._pa_array)
        return self._convert_bool_result(result)

    def _str_islower(self):
        result = pc.utf8_is_lower(self._pa_array)
        return self._convert_bool_result(result)

    def _str_isnumeric(self):
        result = pc.utf8_is_numeric(self._pa_array)
        return self._convert_bool_result(result)

    def _str_isspace(self):
        result = pc.utf8_is_space(self._pa_array)
        return self._convert_bool_result(result)

    def _str_istitle(self):
        result = pc.utf8_is_title(self._pa_array)
        return self._convert_bool_result(result)

    def _str_isupper(self):
        result = pc.utf8_is_upper(self._pa_array)
        return self._convert_bool_result(result)

    def _str_contains(
        self,
        pat,
        case: bool = True,
        flags: int = 0,
        na: Scalar | lib.NoDefault = lib.no_default,
        regex: bool = True,
    ):
        if flags:
            raise NotImplementedError(f"contains not implemented with {flags=}")

        if regex:
            pa_contains = pc.match_substring_regex
        else:
            pa_contains = pc.match_substring
        result = pa_contains(self._pa_array, pat, ignore_case=not case)
        return self._convert_bool_result(result, na=na, method_name="contains")

    def _str_match(
        self,
        pat: str,
        case: bool = True,
        flags: int = 0,
        na: Scalar | lib.NoDefault = lib.no_default,
    ):
        if not pat.startswith("^"):
            pat = f"^({pat})"
        return ArrowStringArrayMixin._str_contains(
            self, pat, case, flags, na, regex=True
        )

    def _str_fullmatch(
        self,
        pat: str,
        case: bool = True,
        flags: int = 0,
        na: Scalar | lib.NoDefault = lib.no_default,
    ):
        if (not pat.endswith("$") or pat.endswith("\\$")) and not pat.startswith("^"):
            pat = f"^({pat})$"
        elif not pat.endswith("$") or pat.endswith("\\$"):
            pat = f"^({pat[1:]})$"
        elif not pat.startswith("^"):
            pat = f"^({pat[0:-1]})$"
        return ArrowStringArrayMixin._str_match(self, pat, case, flags, na)

    def _str_find(self, sub: str, start: int | None = 0, end: int | None = None):
        """
        Return the position of the first occurrence of ``sub`` in each element.

        The search is limited to the ``[start, end)`` window, following the
        arguments of ``str.find``. Positions are reported relative to the start
        of the full string, and -1 is returned where ``sub`` is not found.
        """
        if (start == 0 or start is None) and end is None:
            # No window requested: search the whole string directly.
            result = pc.find_substring(self._pa_array, sub)
        else:
            if sub == "":
                # GH#56792
                # An empty ``sub`` with a window is delegated to Python's
                # ``str.find``, element by element.
                res_list = self._apply_elementwise(
                    lambda val: val.find(sub, start, end)
                )
                return self._convert_int_result(pa.chunked_array(res_list))
            # ``start_offset`` maps positions found inside the sliced window back
            # to positions in the original string.
            if start is None:
                start_offset = 0
                start = 0
            elif start < 0:
                # A negative ``start`` counts from the end of each element; turn
                # it into a per-element absolute offset, clamped at 0.
                start_offset = pc.add(start, pc.utf8_length(self._pa_array))
                start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
            else:
                start_offset = start
            slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
            result = pc.find_substring(slices, sub)
            found = pc.not_equal(result, pa.scalar(-1, type=result.type))
            offset_result = pc.add(result, start_offset)
            # Keep -1 for misses rather than shifting it by the offset.
            result = pc.if_else(found, offset_result, -1)
        return self._convert_int_result(result)
