
    `i>                         d dl Z d dlmZ d dlmZ d dlmZ d dlmZm	Z	m
Z
 d dlmZmZ dZ e j        edd	 d
D             d d
D             z             Zd ZddZddZddZd Zd ZddZddZdS )    N)get_typename)_normalize_axis_index)lfilter)
axis_sliceaxis_assignaxis_reverse)collapse_2dapply_iir_sosaj
  
#include <cupy/math_constants.h>
#include <cupy/carray.cuh>

template<typename T>
__device__ T _compute_symiirorder2_fwd_hc(
        const int k, const T cs, const T r, const T omega) {
    T base;

    if(k < 0) {
        return 0;
    }

    if(omega == 0.0) {
        base = cs * pow(r, ((T) k)) * (k + 1);
    } else if(omega == M_PI) {
        base = cs * pow(r, ((T) k)) * (k + 1) * (1 - 2 * (k % 2));
    } else {
        base = (cs * pow(r, ((T) k)) * sin(omega * (k + 1)) /
                sin(omega));
    }
    return base;
}

template<typename T>
__global__ void compute_symiirorder2_fwd_sc(
        const int n, const int off, const T* cs_ptr, const T* r_ptr,
        const T* omega_ptr, const double precision, bool* valid, T* out) {

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if(idx + off >= n) {
        return;
    }

    const T cs = cs_ptr[0];
    const T r = r_ptr[0];
    const T omega = omega_ptr[0];

    T val = _compute_symiirorder2_fwd_hc<T>(idx + off + 1, cs, r, omega);
    T err = val * val;

    out[idx] = val;
    valid[idx] = err <= precision;
}

template<typename T>
__device__ T _compute_symiirorder2_bwd_hs(
        const int ki, const T cs, const T rsq, const T omega) {
    T c0;
    T gamma;

    T cssq = cs * cs;
    int k = abs(ki);
    T rsupk = pow(rsq, ((T) k) / ((T) 2.0));


    if(omega == 0.0) {
        c0 = (1 + rsq) / ((1 - rsq) * (1 - rsq) * (1 - rsq)) * cssq;
        gamma = (1 - rsq) / (1 + rsq);
        return c0 * rsupk * (1 + gamma * k);
    }

    if(omega == M_PI) {
        c0 = (1 + rsq) / ((1 - rsq) * (1 - rsq) * (1 - rsq)) * cssq;
        gamma = (1 - rsq) / (1 + rsq) * (1 - 2 * (k % 2));
        return c0 * rsupk * (1 + gamma * k);
    }

    c0 = (cssq * (1.0 + rsq) / (1.0 - rsq) /
                (1 - 2 * rsq * cos(2 * omega) + rsq * rsq));
    gamma = (1.0 - rsq) / (1.0 + rsq) / tan(omega);
    return c0 * rsupk * (cos(omega * k) + gamma * sin(omega * k));
}

template<typename T>
__global__ void compute_symiirorder2_bwd_sc(
        const int n, const int off, const int l_off, const int r_off,
        const T* cs_ptr, const T* rsq_ptr, const T* omega_ptr,
        const double precision, bool* valid, T* out) {

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if(idx + off >= n) {
        return;
    }

    const T cs = cs_ptr[0];
    const T rsq = rsq_ptr[0];
    const T omega = omega_ptr[0];

    T v1 = _compute_symiirorder2_bwd_hs<T>(idx + l_off + off, cs, rsq, omega);
    T v2 = _compute_symiirorder2_bwd_hs<T>(idx + r_off + off, cs, rsq, omega);

    T diff = v1 + v2;
    T err = diff * diff;
    out[idx] = diff;
    valid[idx] = err <= precision;
}
)z
-std=c++11c                     g | ]}d | d	S )zcompute_symiirorder2_bwd_sc<> .0ts     o/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/cupyx/scipy/signal/_splines.py
<listcomp>r   q   s3     4 4 4 :Q999 4 4 4    )floatdoublec                     g | ]}d | d	S )zcompute_symiirorder2_fwd_sc<r   r   r   s     r   r   r   s   s3     # # #	
 )A((( # # #r   )codeoptionsname_expressionsc                     d |D             }d                     |          }|r| d| dn|}|                     |          }|S )Nc                 6    g | ]}t          |j                  S r   )r   dtype)r   args     r   r   z$_get_module_func.<locals>.<listcomp>x   s"    DDDs<	**DDDr   z, <r   )joinget_function)module	func_nametemplate_argsargs_dtypestemplatekernel_namekernels          r   _get_module_funcr(   w   s^    DDmDDDKyy%%H0=LY,,,,,,9K  --FMr   c           
         t          j        |           d         dz   |z   }t           j        }|j        dk    rPt          j        |d         |k    t           j        t	          ||d         dz
  |z
  |d         |z
  |                    }|S )Nr      axis)cupywherenansizer   )	all_validcum_polynoffr-   indiceszis          r   _find_initial_condr8      s    j##A&*S0G	B|aZAJ!OTXxa#!5qzC'd4 4 45 5 Ir         c                    t          || j                  }| j        }| j        }| j        dk    rt          | |          \  } }t	          j        |          dk    rt          d          |dk    s|dk    rr| j        t	          j        t          j                  u rd}nJ| j        t	          j        t          j	                  u rd}n"dt	          j
        | j                  j         z  }||z  }t	          j        d|d         dz   | j        	          }||z  }|t	          j        |          z  }	t	          j        || z  d
          t          | ddd
          z   }
|	|k    }t!          ||
|d                   }t	          j        t	          j        |                    rt          d          d}|dk    rd| j        d         df}t	          j        || j        	          }t)          ||dd          }t          j        dddd| df         }t	          j        |          }t/          t          | d          ||| j        d          \  }}t          j        ||f         }| |dz
  z  t          |d          z  }t)          ||dd          }t          j        |ddd| df         }t	          j        |          }t/          t          |dd          ||| j                  \  }}|dk    r"t          j        t3          |          |f         }n!t          j        t3          |          |f         }|dk    rK|                    |          }t	          j        |d|          }|j        j        s|                                }|S )Nr+   |z1| must be less than 1.0              ?ư>gMbP?
   r)   r   r,   r   ;Sum to find symmetric boundary conditions did not converge.r+      rC      F)r7   r   	apply_firstepr7   r   )r   ndimshaper	   r.   abs
ValueErrorr   float64float32finfoiexparange	conjugatecumsumr   r8   anyisnanzerosr   r_
atleast_2dr
   c_r   reshapemoveaxisflagsc_contiguouscopy)inputc0z1	precisionr-   input_shape
input_ndimpospow_z1diffr3   r2   r7   zi_shapeall_zicoefy1_outs                      r   _symiirorder1_ndro      sw    uz22D+KJzA~~(55{x||q5666C9s??;$*T\2222II[DJt|4444IItz%+66;;;II
+aR1,EK
@
@
@C3YFDN6***D{R! ! !#-eQ#C#C#CDH 	!I	IxR	A	ABx
2 KIK K 	K HA~~u{1~q)Z444FQ**F71aAsA%&D?4  D*UA.. %u> > >EB	RB S	Jr2..	.BQ**F72q!QQ&'D?4  D2r###TfEKI I IFC A~~gl3''+,gl3''+,A~~kk+&&mCT**y% 	((**CJr   c                 V   t          j        |g| j                  }t          j        |g| j                  }t          j        |          dk    rt	          d          |dk    s|dk    rt          j        | j                  j        }||z  }t          j        d| j        dz   | j                  }||z  }|t          j	        |          z  }t          j
        || z            | d         z   }||k    }t          ||| j                  }	t          j        |	          rt	          d          t           j        d| f         }
|
                    | j                  }
t          t          j        d| j                  |
| dd         |		          \  }}t           j        |	|f         }| |dz
  z  |d
         z  }	t           j        d| f         }
|
                    | j                  }
t          ||
|dd
         ddd
         |		          \  }}t           j        |ddd
         |	f         S )aR  
    Implement a smoothing IIR filter with mirror-symmetric boundary conditions
    using a cascade of first-order sections.  The second section uses a
    reversed sequence.  This implements a system with the following
    transfer function and mirror-symmetric boundary conditions::

                           c0
           H(z) = ---------------------
                   (1-z1/z) (1 - z1 z)

    The resulting signal will have mirror symmetric boundary conditions
    as well.

    Parameters
    ----------
    input : ndarray
        The input signal.
    c0, z1 : scalar
        Parameters in the transfer function.
    precision :
        Specifies the precision for calculating initial conditions
        of the recursive filter based on mirror-symmetric input.

    Returns
    -------
    output : ndarray
        The filtered signal.
    r+   r;   r<   r=   r@   r   rA   Nr7   r)   )r.   asarrayr   rL   rM   rP   
resolutionrR   r1   rS   rT   r8   rV   rX   astyper   ones)r`   ra   rb   rc   rf   rg   rh   r3   r2   r7   arl   rm   rn   s                 r   symiirorder1rw      s   : 
rdEK	(	(B	rdEK	(	(Bx||q5666C9s??Ju{++6	I
+aau{
;
;
;C3YFDN6***D{6E>**U1X5H	!I	Ix	<	<Bz"~~ KIK K 	K 	B3A	A 	!5;'''E!""I"> > >EB	RB S	BrF	"BB3A	ARBssGDDbDMb111FC73ttt9b=!!r   c                    d }|dk    r|t          j        ||           z  | dz   z  }n|t           j        k    r+|t          j        ||           z  | dz   z  dd| dz  z  z
  z  }nH|t          j        ||           z  t          j        || dz   z            z  t          j        |          z  }t          j        | dk     d|          S )Nr<   r+      r   )r.   powerpisinr/   )kcsromegabases        r   _compute_symiirorder2_fwd_hcr     s    D||DJq!$$$!,	$'		DJq!$$$A.!a1q5k/BTZ1%%%!a%(A(AA :a!eS$'''r   c                    ||z  }t          j        |           } t          j        || dz            }|dk    r3d|z   d|z
  d|z
  z  d|z
  z  z  |z  }d|z
  d|z   z  }||z  d|| z  z   z  S |t           j        k    r?d|z   d|z
  d|z
  z  d|z
  z  z  |z  }d|z
  d|z   z  dd| dz  z  z
  z  }||z  d|| z  z   z  S |d|z   z  d|z
  z  dd|z  t          j        d|z            z  z
  ||z  z   z  }d|z
  d|z   z  t          j        |          z  }||z  t          j        || z            |t          j        || z            z  z   z  S )Ng       @r<   r+   ry   r=   )r.   rL   rz   r{   costanr|   )r}   r~   rsqr   cssqrsupkra   gammas           r   _compute_symiirorder2_bwd_hsr   &  s   7DAJsAG$$E||#g1s7q3w/1s7;<tCSQW%EzQ]++#g1s7q3w/1s7;<tCSQW%Q!a%[9EzQ]++
#)
c	
*q3w!e),,,,sSy8:B3Y39%7E:%!),,utx	7J7J/JJKKr   c                 B   |dk    rt          d          |dk    s|dk    rr| j        t          j        t          j                  u rd}nJ| j        t          j        t          j                  u rd}n"dt          j        | j                  j         z  }t          || j                  }| j	        }| j        }| j        dk    rt          | |          \  } }d}||z  }d	|z  t          j        |          z  }	| }
t          j        dd	|z  t          j        |          z  z
  |z             }t          j        ||j                  }t          j        ||j                  }t          j        ||j                  }||z  }t          t          d
|          }t          j        |dz   f|j                  }t          j        |dz   ft          j                  }t          j        d	| j                  }t'          ||||          }t          j        }t          j        }t+          d| j	        d         d	z   |          D ]} |d|dz   f| j	        d         d	z   |||||||f           t-          | |||z             }|d d         d |j	        d                  }|dd          d |j	        d                  }t          j        |          rlt          j        ||z  d          |d         t-          | dd          z  z   }t3          |d d         d |j	        d                  || j	        d         |          }t          j        |          rt          j        ||z  d          |d         t-          | dd	          z  z   |d         t-          | dd          z  z   }t3          |dd          d |j	        d                  || j	        d         |          }t          j        t          j        t          j        ||f                             s nt          j        t          j        t          j        ||f                             rt          d          d}|dk    rd| j	        d         df}t          j        t          j        |ddd|	 |
 f                   }|                    | j                  }t          j        || j                  }t?          ||d	d          }t?          ||dd          }tA          t-          | d	          ||| j                  \  }}|dk    rt          j!        |||f         }nt          j        |||f         }t          t          d|          }t          j        |f|j                  }t          j        |ft          j                  }tE          |           }t          j        }t+          d| j	        d         dz   |          D ]} |d|f| j	        d         dz   |dd|t          j        ||j                  t          j        ||j                  |||f
           t-          ||||z             }t          j        |d |j	        d                  |z  d          }t3          |d |j	        d                  || j	        d         |          }t          j        t          j        |                    s nt          j        t          j        |                    rt          d          t          j        }t+          d| j	        d         dz   |          D ]} |d|f| j#        dz   |dd	|t          j        ||j                  t          j        ||j                  |||f
           t-          ||||z             }t          j        |d |j	        d                  |z  d          }t3          |d |j#                 || j#        |          }t          j        t          j        |                    s nt          j        t          j        |                    rt          d          t?          ||d	d          }t?          ||dd          }tA          t-          |dd          ||          \  }}|dk    r#t          j!        tE          |          ||f         }n"t          j        tE          |          ||f         }|dk    rK|$                    |          }t          j%        |d|          }|j&        j'        s|(                                }|S )Nr=   zr must be less than 1.0r<   gdy=r>   r?   r+      ry   compute_symiirorder2_fwd_scr@   r   r)   )r+   r,   rA   rB   rC   rD   rI   compute_symiirorder2_bwd_scrG   rq   ))rM   r   r.   rN   rO   rP   rQ   r   rJ   rK   r	   r   
atleast_1drr   r(   SYMIIR2_MODULEemptybool_rR   r   r0   ranger   rV   rT   r8   rU   rX   rY   rt   rW   r   r
   rZ   r   r1   r[   r\   r]   r^   r_   ) r`   r   r   rc   r-   rd   re   block_szr   a2a3r~   r   rh   r2   starting_diffy0rl   iinput_slicediff_y0diff_y1cum_poly_y0cum_poly_y1ri   sosrj   y_fwdrm   r   	rev_inputrn   s                                    r   _symiirorder2_ndr   ;  s,	   Cxx2333C9s??;$*T\2222II[DJt|4444IItz%+66;;;I uz22D+KJzA~~(55{H
a%C	
Q%	 B
B	QUTXe__44s:	;	;BL))EQ!!A
,sBH
%
%CI #35r#; #; :x!|oRX666D
HqL?$*===IK555M0AuMMM	B	B1ek"o)844  ##8a</B!#QAui$	 	 	
 !1x<88ss)2[.r223qrr(1K-b112:b>> 	$+g&;"EEEa :eQ#:#::<K##2#5 1" 556B$ $B :b>> 	$;w'<2FFF(+j1.E.EEF(+j1.E.EEFK $!""4{0445{B$ $B x
472r6?3344 	E	 x
472r6?++,, KIK K 	K HA~~u{1~q)
/$'"aAsRC"78
9
9C
**U[
!
!CZ444FQ**FQ**F5!cfEKA A AHE1A~~B&B& #35r#; #; :xk222D
H;dj999IU##I	B1ek"o)844  ##8+B!#Q1b$,sBH2M2MUBH--y)T K	L 	L 	L
 !Aq8|<<k$'=(9"(='=">"L')+ + +,{(,,-{EKOQP Px
2'' 	E	 x
2 KIK K 	K 
B1ek"o)844  ##8+
Q2q"dl3.I.IUBH--y)T K	L 	L 	L
 !Aq8|<<k$'=(9"(='=">"L')+ + +'{''(+uz1F Fx
2'' 	E	 x
2 KIK K 	K Q**FQ**F:eRb99936JJJFCA~~gl3''R/0gl3''R/0A~~kk+&&mCT**y% 	((**CJr   c                 &    t          | |||          S )ak  
    Implement a smoothing IIR filter with mirror-symmetric boundary conditions
    using a cascade of second-order sections.  The second section uses a
    reversed sequence.  This implements the following transfer function::

                                  cs^2
         H(z) = ---------------------------------------
                (1 - a2/z - a3/z^2) (1 - a2 z - a3 z^2 )

    where::

          a2 = 2 * r * cos(omega)
          a3 = - r ** 2
          cs = 1 - 2 * r * cos(omega) + r ** 2

    Parameters
    ----------
    input : ndarray
        The input signal.
    r, omega : float
        Parameters in the transfer function.
    precision : float
        Specifies the precision for calculating initial conditions
        of the recursive filter based on mirror-symmetric input.

    Returns
    -------
    output : ndarray
        The filtered signal.
    )r   )r`   r   r   rc   s       r   symiirorder2r     s    > E1eY777r   )r   r)   )r9   r)   )r9   )r.   cupy._core._scalarr   cupy._core.internalr   cupyx.scipy.signal._signaltoolsr   cupyx.scipy.signal._arraytoolsr   r   r   cupyx.scipy.signal._iir_utilsr	   r
   SYMIIR2_KERNEL	RawModuler   r(   r8   ro   rw   r   r   r   r   r   r   r   <module>r      s    + + + + + + 5 5 5 5 5 5 3 3 3 3 3 3+ + + + + + + + + + D D D D D D D DaF  	4 424 4 4# #!# # ##$ $ $     G G G GTC" C" C" C"L	( 	( 	(L L L*Z Z Z Zz8 8 8 8 8 8r   