diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index 509db83a..d42521de 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -1270,4 +1270,4 @@ def get_len(var):
 from . import numpy2cupy
 from .contrib import concat
 from .misc import *
-from . import sparse
+from . import sparse
\ No newline at end of file
diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py
index c3b3606f..eba5d574 100644
--- a/python/jittor/distributions.py
+++ b/python/jittor/distributions.py
@@ -27,6 +27,18 @@ def simple_presum(x):
         cpu_src=src, cuda_src=src)
 
+def lgamma(x):
+    header = '''#include <cmath>'''
+    src = '''
+    @alias(a, in0)
+    @alias(b, out0)
+    for (int i=0;i<a_shape0;i++)
+        @b(i) = lgamma(@a(i));
+    '''
+    return jt.code(x.shape, x.dtype, [x],
+        cpu_header=header, cpu_src=src)
+
+
 class OneHotCategorical:
     def __init__(self, probs=None, logits=None):
@@ -76,3 +88,77 @@ class Normal:
 
+class Uniform:
+    def __init__(self,low,high):
+        self.low = low
+        self.high = high
+        assert high > low
+
+    def sample(self,sample_shape):
+        return jt.uniform(self.low,self.high,sample_shape)
+
+    def log_prob(self,x):
+        if x < self.low or x >= self.high:
+            return -math.inf
+        return -jt.log(self.high - self.low)
+
+    def entropy(self):
+        return jt.log(self.high - self.low)
+
+    def cdf(self, x):
+        return jt.clamp((x-self.low)/(self.high - self.low),min_v=0,max_v=1)
+
+
+class Geometric:
+    def __init__(self,p=None,logits=None):
+        assert (p is not None) or (logits is not None)
+        assert p is None or 0 < p < 1
+        if p is None:
+            self.prob = jt.sigmoid(logits)
+            self.logits = logits
+        elif logits is None:
+            self.prob = p
+            self.logits = -jt.log(1. / p - 1)
+
+    def sample(self,sample_shape):
+        tiny = np.finfo('float32').tiny
+        u = jt.clamp(jt.rand(sample_shape),min_v=tiny)
+        return (jt.log(u) / (jt.log(-self.prob+1))).floor()
+
+    def log_prob(self,x):
+        return x*jt.log(-self.prob+1)+jt.log(self.prob)
+
+    def entropy(self):
+        return binary_cross_entropy_with_logits(jt.array(self.logits),jt.array(self.prob)) / self.prob
+
+
+def Poisson_sample(la, size):
+    u = jt.random(size, "float32")
+    res = jt.zeros(size)
+    for i in range(res.shape[0]):  # assumes a flat sample shape
+        k = 0
+        p = math.exp(-la)
+        s = p
+        if u[i] <= p:
+            res[i] = 0
+            continue
+        else:
+            # sequential search: accumulate the cdf until it covers u[i]
+            while u[i] > s:
+                p = la * p / (k + 1)
+                s = s + p
+                k += 1
+            res[i] = k
+    return res
+
+
+class Poisson:
+    def __init__(self, la):
+        self.la = la
+
+    def sample(self, sample_shape):
+        return Poisson_sample(self.la,sample_shape)
+
+    def log_prob(self,x):
+        # log P(x) = x*log(la) - la - log(x!), with log(x!) = lgamma(x+1)
+        return jt.log(self.la) * x - self.la - lgamma(x + 1)
 
 class Uniform:
@@ -158,15 +249,17 @@ def kl_divergence(cur_dist, old_dist):
         vr = (cur_dist.sigma / old_dist.sigma)**2
         t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
         return 0.5*(vr+t1-1-jt.log(vr))
-    if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
+    if isinstance(cur_dist,Categorical) or isinstance(cur_dist,OneHotCategorical):
         t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
         t[jt.array((old_dist.probs == 0))] = math.inf
         t[jt.array((cur_dist.probs == 0))] = 0
         return t.sum(-1)
-    if isinstance(cur_dist, Uniform):
+    if isinstance(cur_dist,Uniform):
         res = jt.log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
         if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
             res = math.inf
         return res
-    if isinstance(cur_dist, Geometric):
+    if isinstance(cur_dist,Geometric):
         return -cur_dist.entropy() - jt.log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits
+    if isinstance(cur_dist,Poisson):
+        return cur_dist.la * (jt.log(cur_dist.la) - jt.log(old_dist.la)) - (cur_dist.la - old_dist.la)
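Note on the math: the new Poisson branch of kl_divergence implements the standard closed form KL(Poisson(a) || Poisson(b)) = a*(log a - log b) - (a - b), and Poisson.log_prob implements log P(x) = x*log(la) - la - lgamma(x+1). A minimal smoke-test sketch of the additions, assuming the patch above is applied (the import aliases `jt`/`jd` follow the test file; passing the rate as a plain float mirrors how the Geometric tests pass `p`):

    import math
    import jittor as jt
    import jittor.distributions as jd

    la = 3.0
    p = jd.Poisson(la)
    # log-pmf at k=2, via the new lgamma code op vs. the closed form
    print(p.log_prob(jt.array([2.0]).float32()))
    print(2 * math.log(la) - la - math.lgamma(3.0))

    # KL(Poisson(3) || Poisson(4)) = 3*(log 3 - log 4) - (3 - 4)
    print(jd.kl_divergence(jd.Poisson(3.0), jd.Poisson(4.0)))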
diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py
index 54b75265..1aadd32f 100644
--- a/python/jittor/test/test_distributions.py
+++ b/python/jittor/test/test_distributions.py
@@ -18,6 +18,12 @@ def test_presum(self):
         a = jt.array([[1,2,3,4]])
         b = jd.simple_presum(a)
         assert (b.data == [[0,1,3,6,10]]).all()
+
+    def test_lgamma(self):
+        import torch
+        ta = np.random.uniform(2,3,(1))
+        a = jt.array(ta).float32()
+        assert np.allclose(jd.lgamma(a).data, torch.lgamma(torch.tensor(ta)).numpy()), (jd.lgamma(a).data, torch.lgamma(torch.tensor(ta)).numpy())
 
     def test_one_hot(self):
         a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25]))
@@ -31,9 +37,9 @@ def test_one_hot(self):
             probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
             probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
             import torch
+            tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs).to(torch.float32)),torch.distributions.OneHotCategorical(torch.tensor(probs2).to(torch.float32))
             jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1))
-            tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
-            assert np.allclose(jc.entropy().data,tc.entropy().numpy())
+            assert np.allclose(jc.entropy().data,tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
             x = np.zeros((4,10))
             for _ in range(4):
                 nx = np.random.randint(0,9)
@@ -72,7 +78,6 @@ def test_categorical(self):
         for _ in range(4):
             probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
             probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
-            jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1))
             tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
             assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
             x = np.random.randint(0,10,(4))
@@ -101,8 +106,7 @@ def test_geometric(self):
         assert np.allclose(jg.entropy().data,tg.entropy().numpy())
         x = np.random.randint(1,10)
         assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)))
-        # print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
         assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
== "__main__": - unittest.main() \ No newline at end of file + unittest.main()