# This source code contains SciPy's code.
# https://github.com/scipy/scipy/blob/master/scipy/special/cephes/psi.c
#
#
# Cephes Math Library Release 2.8:  June, 2000
# Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier
#
#
# Code for the rational approximation on [1, 2] is:
#
# (C) Copyright John Maddock 2006.
# Use, modification and distribution are subject to the
# Boost Software License, Version 1.0. (See accompanying file
# LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)

from cupy import _core


polevl_definition = '''
template<int N> static __device__ double polevl(double x, double coef[])
{
    double ans;
    double *p;

    p = coef;
    ans = *p++;

    for (int i = 0; i < N; ++i){
        ans = ans * x + *p++;
    }

    return ans;
}
'''


psi_definition = '''
__constant__ double A[] = {
    8.33333333333333333333E-2,
    -2.10927960927960927961E-2,
    7.57575757575757575758E-3,
    -4.16666666666666666667E-3,
    3.96825396825396825397E-3,
    -8.33333333333333333333E-3,
    8.33333333333333333333E-2
};

__constant__ double PI = 3.141592653589793;
__constant__ double EULER = 0.5772156649015329;

__constant__ float Y = 0.99558162689208984f;

__constant__ double root1 = 1569415565.0 / 1073741824.0;
__constant__ double root2 = (381566830.0 / 1073741824.0) / 1073741824.0;
__constant__ double root3 = 0.9016312093258695918615325266959189453125e-19;

__constant__ double P[] = {
    -0.0020713321167745952,
    -0.045251321448739056,
    -0.28919126444774784,
    -0.65031853770896507,
    -0.32555031186804491,
    0.25479851061131551
};
__constant__ double Q[] = {
    -0.55789841321675513e-6,
    0.0021284987017821144,
    0.054151797245674225,
    0.43593529692665969,
    1.4606242909763515,
    2.0767117023730469,
    1.0
};

static __device__ double digamma_imp_1_2(double x)
{
    /*
     * Rational approximation on [1, 2] taken from Boost.
     *
     * Now for the approximation, we use the form:
     *
     * digamma(x) = (x - root) * (Y + R(x-1))
     *
     * Where root is the location of the positive root of digamma,
     * Y is a constant, and R is optimised for low absolute error
     * compared to Y.
     *
     * Maximum Deviation Found:               1.466e-18
     * At double precision, max error found:  2.452e-17
     */
    double r, g;
    g = x - root1 - root2 - root3;
    r = polevl<5>(x - 1.0, P) / polevl<6>(x - 1.0, Q);

    return g * Y + g * r;
}


static __device__ double psi_asy(double x)
{
    double y, z;

    if (x < 1.0e17) {
        z = 1.0 / (x * x);
        y = z * polevl<6>(z, A);
    }
    else {
        y = 0.0;
    }

    return log(x) - (0.5 / x) - y;
}


double __device__ psi(double x)
{
    double y = 0.0;
    double q, r;
    int i, n;

    if (isnan(x)) {
        return x;
    }
    else if (isinf(x)){
        if(x > 0){
            return x;
        }else{
            return nan("");
        }
    }
    else if (x == 0) {
        return -1.0/0.0;
    }
    else if (x < 0.0) {
        /* argument reduction before evaluating tan(pi * x) */
        r = modf(x, &q);
        if (r == 0.0) {
            return nan("");
        }
        y = -PI / tan(PI * r);
        x = 1.0 - x;
    }

    /* check for positive integer up to 10 */
    if ((x <= 10.0) && (x == floor(x))) {
        n = (int)x;
        for (i = 1; i < n; i++) {
            y += 1.0 / i;
        }
        y -= EULER;
        return y;
    }

    /* use the recurrence relation to move x into [1, 2] */
    if (x < 1.0) {
        y -= 1.0 / x;
        x += 1.0;
    }
    else if (x < 10.0) {
        while (x > 2.0) {
            x -= 1.0;
            y += 1.0 / x;
        }
    }
    if ((1.0 <= x) && (x <= 2.0)) {
        y += digamma_imp_1_2(x);
        return y;
    }

    /* x is large, use the asymptotic series */
    y += psi_asy(x);
    return y;
}
'''


digamma = _core.create_ufunc(
    'cupyx_scipy_special_digamma', ('f->f', 'd->d'),
    'out0 = psi(in0)',
    preamble=polevl_definition+psi_definition,
    doc="""The digamma function.

    Args:
        x (cupy.ndarray): The input of digamma function.

    Returns:
        cupy.ndarray: Computed value of digamma function.

    .. seealso:: :data:`scipy.special.digamma`

    """)
