Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 311 additions & 0 deletions pytensor/scalar/c_code/incbet.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
/* adapted from file incbet.c, obtained from the Cephes library (MIT License)
Cephes Math Library, Release 2.8: June, 2000
Copyright 1984, 1995, 2000 by Stephen L. Moshier
*/

//For GPU support
#ifdef __CUDACC__
#define DEVICE __device__
#else
#define DEVICE
#endif

#include <float.h>
#include <math.h>
#include <stdio.h>
#include <numpy/npy_math.h>


#define MINLOG -170.0
#define MAXLOG +170.0
#define MAXGAM 171.624376956302725
#define EPSILON 2.2204460492503131e-16

DEVICE static double pseries(double, double, double);
DEVICE static double incbcf(double, double, double);
DEVICE static double incbd(double, double, double);

static double big = 4.503599627370496e15;
static double biginv = 2.22044604925031308085e-16;


DEVICE double BetaInc(double a, double b, double x)
{
double xc, y, w, t;
/* check function arguments */
if (a <= 0.0) return NPY_NAN;
if (b <= 0.0) return NPY_NAN;
if (x < 0.0) return NPY_NAN;
if (1.0 < x) return NPY_NAN;

/* some special cases */
if (x == 0.0) return 0.0;
if (x == 1.0) return 1.0;

if ( (b * x) <= 1.0 && x <= 0.95)
{
return pseries(a, b, x);
}

xc = 1.0 - x;
/* reverse a and b if x is greater than the mean */
if (x > (a / (a + b)))
{
t = BetaInc(b, a, xc);
if (t <= EPSILON) return 1.0 - EPSILON;
return 1.0 - t;
}

/* Choose expansion for better convergence. */
y = x * (a+b-2.0) - (a-1.0);
if( y < 0.0 )
w = incbcf( a, b, x );
else
w = incbd( a, b, x ) / xc;

y = a * log(x);
t = b * log(xc);
if( (a+b) < MAXGAM && fabs(y) < MAXLOG && fabs(t) < MAXLOG )
{
t = pow(xc, b);
t *= pow(x, a);
t /= a;
t *= w;
t *= tgamma(a + b) / (tgamma(a) * tgamma(b));

return t;
}

/* Resort to logarithms. */
y += t + lgamma(a+b) - lgamma(a) - lgamma(b);
y += log(w / a);
if( y < MINLOG )
t = 0.0;
else
t = exp(y);

return t;
}

/* Continued fraction expansion #1
* for incomplete beta integral
*/

DEVICE static double incbcf(double a, double b, double x)
{
double xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
double k1, k2, k3, k4, k5, k6, k7, k8;
double r, t, ans, thresh;
int n;

k1 = a;
k2 = a + b;
k3 = a;
k4 = a + 1.0;
k5 = 1.0;
k6 = b - 1.0;
k7 = k4;
k8 = a + 2.0;

pkm2 = 0.0;
qkm2 = 1.0;
pkm1 = 1.0;
qkm1 = 1.0;
ans = 1.0;
r = 1.0;
n = 0;
thresh = 3.0 * EPSILON;
do
{

xk = -( x * k1 * k2 ) / ( k3 * k4 );
pk = pkm1 + pkm2 * xk;
qk = qkm1 + qkm2 * xk;
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;

xk = ( x * k5 * k6 ) / ( k7 * k8 );
pk = pkm1 + pkm2 * xk;
qk = qkm1 + qkm2 * xk;
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;

if( qk != 0.0 )
r = pk/qk;
if( r != 0.0 )
{
t = fabs( (ans - r) / r );
ans = r;
}
else
t = 1.0;

if( t < thresh )
break;

k1 += 1.0;
k2 += 1.0;
k3 += 2.0;
k4 += 2.0;
k5 += 1.0;
k6 -= 1.0;
k7 += 2.0;
k8 += 2.0;

if( (fabs(qk) + fabs(pk)) > big )
{
pkm2 *= biginv;
pkm1 *= biginv;
qkm2 *= biginv;
qkm1 *= biginv;
}
if( (fabs(qk) < biginv) || (fabs(pk) < biginv) )
{
pkm2 *= big;
pkm1 *= big;
qkm2 *= big;
qkm1 *= big;
}
}
while( ++n < 300 );

return ans;
}

/* Continued fraction expansion #2
* for incomplete beta integral
*/

DEVICE static double incbd(double a, double b, double x)
{
double xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
double k1, k2, k3, k4, k5, k6, k7, k8;
double r, t, ans, z, thresh;
int n;

k1 = a;
k2 = b - 1.0;
k3 = a;
k4 = a + 1.0;
k5 = 1.0;
k6 = a + b;
k7 = a + 1.0;;
k8 = a + 2.0;

pkm2 = 0.0;
qkm2 = 1.0;
pkm1 = 1.0;
qkm1 = 1.0;
z = x / (1.0-x);
ans = 1.0;
r = 1.0;
n = 0;
thresh = 3.0 * EPSILON;
do
{

xk = -( z * k1 * k2 ) / ( k3 * k4 );
pk = pkm1 + pkm2 * xk;
qk = qkm1 + qkm2 * xk;
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;

xk = ( z * k5 * k6 ) / ( k7 * k8 );
pk = pkm1 + pkm2 * xk;
qk = qkm1 + qkm2 * xk;
pkm2 = pkm1;
pkm1 = pk;
qkm2 = qkm1;
qkm1 = qk;

if( qk != 0 )
r = pk/qk;
if( r != 0 )
{
t = fabs( (ans - r) / r );
ans = r;
}
else
t = 1.0;

if( t < thresh )
break;

k1 += 1.0;
k2 -= 1.0;
k3 += 2.0;
k4 += 2.0;
k5 += 1.0;
k6 += 1.0;
k7 += 2.0;
k8 += 2.0;

if( (fabs(qk) + fabs(pk)) > big )
{
pkm2 *= biginv;
pkm1 *= biginv;
qkm2 *= biginv;
qkm1 *= biginv;
}
if( (fabs(qk) < biginv) || (fabs(pk) < biginv) )
{
pkm2 *= big;
pkm1 *= big;
qkm2 *= big;
qkm1 *= big;
}
}
while( ++n < 300 );

return ans;
}


/* Power series for incomplete beta integral.
Use when b*x is small and x not too close to 1. */

DEVICE static double pseries(double a, double b, double x)
{
double s, t, u, v, n, t1, z, ai;

ai = 1.0 / a;
u = (1.0 - b) * x;
v = u / (a + 1.0);
t1 = v;
t = u;
n = 2.0;
s = 0.0;
z = EPSILON * ai;
while( fabs(v) > z )
{
u = (n - b) * x / n;
t *= u;
v = t / (a + n);
s += v;
n += 1.0;
}
s += t1;
s += ai;

u = a * log(x);
if( (a+b) < MAXGAM && fabs(u) < MAXLOG )
{
t = tgamma(a + b) / (tgamma(a) * tgamma(b));
s = s * t * pow(x,a);
}
else
{
t = lgamma(a + b) - lgamma(a) - lgamma(b) + u + log(s);
if( t < MINLOG )
s = 0.0;
else
s = exp(t);
}
return s;
}
21 changes: 19 additions & 2 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,8 +1495,25 @@ def grad(self, inp, grads):
),
]

def c_code(self, *args, **kwargs):
raise NotImplementedError()
def c_support_code(self, **kwargs):
with open(os.path.join(os.path.dirname(__file__), "c_code", "incbet.c")) as f:
raw = f.read()
return raw

def c_code(self, node, name, inp, out, sub):
(a, b, x) = inp
(z,) = out
if (
node.inputs[0].type in float_types
and node.inputs[1].type in float_types
and node.inputs[2].type in float_types
):
return f"""{z} = BetaInc({a}, {b}, {x});"""

raise NotImplementedError("type not supported", type)

def c_code_cache_version(self):
return (1,)


betainc = BetaInc(upgrade_to_float_no_complex, name="betainc")
Expand Down