@@ -377,3 +377,35 @@ def test_fit(self):
377
377
(params .sigma , 1.759817171541185 ),
378
378
(params .loc , 0 ),
379
379
)
380
+
381
+
382
+ class Test_truncated_normal (CheckDist_Mixin ):
383
+ def setup (self ):
384
+ self .dist = dist .truncated_normal
385
+ self .cargs = []
386
+ self .ckwds = dict (lower = - 0.5 , upper = 2.5 , mu = 1 , sigma = 4 )
387
+
388
+ self .np_rand_fxn = stats .truncnorm .rvs
389
+ self .npargs = [- 0.375 , 0.375 ]
390
+ self .npkwds = dict (loc = 1 , scale = 4 )
391
+
392
+ def test_processargs (self ):
393
+ result = self .dist ._process_args (lower = - 0.5 , upper = 2.5 , mu = 1 , sigma = 4 )
394
+ expected = dict (a = - 0.375 , b = 0.375 , loc = 1 , scale = 4 )
395
+ assert result == expected
396
+
397
+ result = self .dist ._process_args (upper = 2.5 , mu = 1 , sigma = 4 , fit = True )
398
+ expected = dict (f0 = None , f1 = 0.375 , floc = 1 , fscale = 4 )
399
+ assert result == expected
400
+
401
+ @seed
402
+ def test_fit (self ):
403
+ stn = stats .truncnorm (- 0.375 , 0.375 , loc = 1 , scale = 4 )
404
+ data = stn .rvs (size = 37000 )
405
+ params = self .dist .fit (data , lower = - 0.5 , mu = 1 , sigma = 4 )
406
+ check_params (
407
+ (params .lower , - 0.5 ),
408
+ (params .upper , 2.4999301 ),
409
+ (params .mu , 1 ),
410
+ (params .sigma , 4 ),
411
+ )
0 commit comments