@@ -797,6 +797,7 @@ class exponential(BaseDist_Mixin):
     numpy.random.exponential
 
     """
+
     dist = stats.expon
     param_template = namedtuple('params', ['lamda', 'loc'])
     name = 'exponential'
@@ -833,7 +834,7 @@ class rice(BaseDist_Mixin):
     R : float
         The shape parameter of the distribution.
     sigma : float
-        The standard deviate of the distribution.
+        The standard deviation of the distribution.
     loc : float, optional
         Location parameter of the distribution. This defaults to, and
         should probably be left at, 0.
@@ -879,6 +880,7 @@ class rice(BaseDist_Mixin):
     numpy.random.exponential
 
     """
+
     dist = stats.rice
     param_template = namedtuple('params', ['R', 'sigma', 'loc'])
     name = 'rice'
@@ -904,6 +906,114 @@ def fit(cls, data, **guesses):
         return cls.param_template(R=b * sigma, loc=loc, sigma=sigma)
 
 
+class truncated_normal(BaseDist_Mixin):
+    """
+    Create and fit data to a truncated normal distribution.
+
+    Methods
+    -------
+    fit
+        Use scipy's maximum likelihood estimation methods to estimate
+        the parameters of the data's distribution.
+    from_params
+        Create a new distribution instance from the ``namedtuple``
+        result of the :meth:`~fit` method.
+
+    Parameters
+    ----------
+    lower, upper : float
+        The lower and upper limits of the distribution that serve as its
+        shape parameters.
+    mu : float, optional (default = 0)
+        The expected value (mean) of the underlying normal distribution.
+        Acts as the location parameter of the distribution.
+    sigma : float, optional (default = 1)
+        The standard deviation of the underlying normal distribution.
+        Also acts as the scale parameter of the distribution.
+
+    Examples
+    --------
+    >>> import numpy
+    >>> import paramnormal as pn
+    >>> numpy.random.seed(0)
+    >>> pn.truncated_normal(lower=-0.5, upper=0.5).rvs(size=3)
+    array([ 0.04687082, 0.20804061, 0.09879796])
+
+    >>> # you can also use greek letters
+    >>> numpy.random.seed(0)
+    >>> pn.truncated_normal(lower=-0.5, upper=2.5, σ=2).rvs(size=3)
+    array([ 0.8902748 , 1.37377049, 1.04012565])
+
+    >>> # silly fake data
+    >>> numpy.random.seed(0)
+    >>> data = pn.truncated_normal(lower=-0.5, upper=2.5, mu=0, sigma=2).rvs(size=37)
+    >>> # pretend `data` is unknown and we want to fit a dist. to it
+    >>> pn.truncated_normal.fit(data)
+    params(lower=1.040124, upper=1.082447, mu=-8.097877e-06, sigma=1.033405)
+
+    In scipy, the distribution is defined as
+    ``stats.truncnorm(a, b, loc, scale)`` where
+
+    .. math::
+
+        a = \frac{x_{\mathrm{lower\ bound}} - \mu}{\sigma}
+
+    and
+
+    .. math::
+
+        b = \frac{x_{\mathrm{upper\ bound}} - \mu}{\sigma}
+
+    Since ``a`` and ``b`` are directly linked to the location and scale
+    of the distribution as well as the lower and upper limits,
+    respectively, it's difficult to use the ``fit`` method of this
+    distribution without either knowing a lot about it `a priori` or
+    assuming just as much.
+
+    References
+    ----------
+    http://scipy.github.io/devdocs/generated/scipy.stats.truncnorm
+    https://en.wikipedia.org/wiki/Truncated_normal_distribution
+
+    See Also
+    --------
+    scipy.stats.truncnorm
+    numpy.random.normal
+
+    """
+
+    dist = stats.truncnorm
+    param_template = namedtuple('params', ['lower', 'upper', 'mu', 'sigma'])
+    name = 'truncated normal'
+
+    @staticmethod
+    @utils.greco_deco
+    def _process_args(lower=None, upper=None, mu=None, sigma=None, fit=False):
+        a = None
+        b = None
+        if lower is not None and mu is not None and sigma is not None:
+            a = (lower - mu) / sigma
+
+        if upper is not None and mu is not None and sigma is not None:
+            b = (upper - mu) / sigma
+
+        loc_key, scale_key = utils._get_loc_scale_keys(fit=fit)
+        if fit:
+            akey = 'f0'
+            bkey = 'f1'
+        else:
+            akey = 'a'
+            bkey = 'b'
+        return {akey: a, bkey: b, loc_key: mu, scale_key: sigma}
+
+    @classmethod
+    def fit(cls, data, **guesses):
+        a, b, mu, sigma = cls._fit(data, **guesses)
+        lower = a * sigma + mu
+        upper = b * sigma + mu
+        return cls.param_template(lower=lower, upper=upper, mu=mu, sigma=sigma)
+
+
 __all__ = [
     'normal',
     'lognormal',
@@ -915,4 +1025,5 @@ def fit(cls, data, **guesses):
     'pareto',
     'exponential',
     'rice',
+    'truncated_normal',
 ]
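
For anyone trying the new class locally, here is a minimal sketch (not part of the patch itself) of how the lower/upper/mu/sigma parameters above relate to scipy.stats.truncnorm's (a, b, loc, scale) parameterization. It assumes only scipy and mirrors what _process_args and fit do in the diff; the variable names are purely illustrative.

from scipy import stats

lower, upper, mu, sigma = -0.5, 2.5, 0.0, 2.0

# Standardize the bounds into scipy's shape parameters, as _process_args does.
a = (lower - mu) / sigma
b = (upper - mu) / sigma

# Draw a sample from the truncated normal defined by those parameters.
dist = stats.truncnorm(a, b, loc=mu, scale=sigma)
data = dist.rvs(size=37, random_state=0)

# Fitting is most reliable with the shape parameters held fixed
# (scipy's f0/f1 keywords, the same ones _process_args emits when fit=True).
a_hat, b_hat, loc_hat, scale_hat = stats.truncnorm.fit(data, f0=a, f1=b)

# Map the fitted values back to the friendlier parameters, as fit() does.
print('lower =', a_hat * scale_hat + loc_hat)
print('upper =', b_hat * scale_hat + loc_hat)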