Skip to content

Commit c30f4bc

Browse files
committed
basic truncated normal dist
1 parent 9f3d9aa commit c30f4bc

File tree

1 file changed

+112
-1
lines changed

1 file changed

+112
-1
lines changed

paramnormal/dist.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,7 @@ class exponential(BaseDist_Mixin):
797797
numpy.random.exponential
798798
799799
"""
800+
800801
dist = stats.expon
801802
param_template = namedtuple('params', ['lamda', 'loc'])
802803
name = 'exponential'
@@ -833,7 +834,7 @@ class rice(BaseDist_Mixin):
833834
R : float
834835
The shape parameter of the distribution.
835836
sigma : float
836-
The standard deviate of the distribution.
837+
The standard deviation of the distribution.
837838
loc : float, optional
838839
Location parameter of the distribution. This defaults to, and
839840
should probably be left at, 0.
@@ -879,6 +880,7 @@ class rice(BaseDist_Mixin):
879880
numpy.random.exponential
880881
881882
"""
883+
882884
dist = stats.rice
883885
param_template = namedtuple('params', ['R', 'sigma', 'loc'])
884886
name = 'rice'
@@ -904,6 +906,114 @@ def fit(cls, data, **guesses):
904906
return cls.param_template(R=b*sigma, loc=loc, sigma=sigma)
905907

906908

909+
class truncated_normal(BaseDist_Mixin):
910+
"""
911+
Create and fit data to a truncated normal distribution.
912+
913+
Methods
914+
-------
915+
fit
916+
Use scipy's maximum likelihood estimation methods to estimate
917+
the parameters of the data's distribution.
918+
from_params
919+
Create a new distribution instances from the ``namedtuple``
920+
result of the :meth:`~fit` method.
921+
922+
Parameters
923+
----------
924+
lower, upper : float
925+
The lower and upper limits of the distribution that serve as its
926+
shape parameters.
927+
mu : float, optional (default = 0)
928+
The expected value (mean) of the underlying normal distribution.
929+
Acts as the location parameter of the distribution.
930+
sigma : float, optional (default = 1)
931+
The standard deviation of the underlying normal distribution.
932+
Also acts as the scale parameter of distribution.
933+
934+
Examples
935+
--------
936+
>>> import numpy
937+
>>> import paramnormal as pn
938+
>>> numpy.random.seed(0)
939+
>>> pn.truncated_normal(lower=-0.5, upper=0.5).rvs(size=3)
940+
array([ 0.04687082, 0.20804061, 0.09879796])
941+
942+
>>> # you can also use greek letters
943+
>>> numpy.random.seed(0)
944+
>>> pn.truncated_normal(lower=-0.5, upper=2.5, σ=2).rvs(size=3)
945+
array([ 0.8902748 , 1.37377049, 1.04012565])
946+
947+
>>> # silly fake data
948+
>>> numpy.random.seed(0)
949+
>>> data = pn.truncated_normal(lower=-0.5, upper=2.5, mu=0, sigma=2).rvs(size=37)
950+
>>> # pretend `data` is unknown and we want to fit a dist. to it
951+
>>> pn.truncated_normal.fit(data)
952+
params(lower=1.040124, upper=1.082447, mu=-8.097877e-06, sigma=1.033405)
953+
954+
In scipy, the distribution is defined as
955+
``stats.truncnorm(a, b, loc, scale)`` where
956+
957+
.. math::
958+
959+
a = \frac{\mathrm{lower bound}} - \mu}{\sigma}
960+
961+
and
962+
963+
.. math::
964+
965+
b = \frac{x_{\mathrm{upper bound}} - \mu}{\sigma}
966+
967+
Since ``a`` and ``b`` are directly linked to the location and scale
968+
of the distribution as well as the lower and upper limits,
969+
respectively, it's difficult to use the ``fit`` method of this
970+
distirbution without either knowing a lot about it `a priori` or
971+
assuming just as much.
972+
973+
References
974+
----------
975+
http://scipy.github.io/devdocs/generated/scipy.stats.truncnorm
976+
https://en.wikipedia.org/wiki/Rice_distribution
977+
978+
See Also
979+
--------
980+
scipy.stats.rice
981+
numpy.random.exponential
982+
983+
"""
984+
985+
dist = stats.truncnorm
986+
param_template = namedtuple('params', ['lower', 'upper', 'mu', 'sigma'])
987+
name = 'truncated normal'
988+
989+
@staticmethod
990+
@utils.greco_deco
991+
def _process_args(lower=None, upper=None, mu=None, sigma=None, fit=False):
992+
a = None
993+
b = None
994+
if lower is not None and mu is not None and sigma is not None:
995+
a = (lower - mu) / sigma
996+
997+
if upper is not None and mu is not None and sigma is not None:
998+
b = (upper - mu) / sigma
999+
1000+
loc_key, scale_key = utils._get_loc_scale_keys(fit=fit)
1001+
if fit:
1002+
akey = 'f0'
1003+
bkey = 'f1'
1004+
else:
1005+
akey = 'a'
1006+
bkey = 'b'
1007+
return {akey: a, bkey: b, loc_key: mu, scale_key: sigma}
1008+
1009+
@classmethod
1010+
def fit(cls, data, **guesses):
1011+
a, b, mu, sigma = cls._fit(data, **guesses)
1012+
lower = a * sigma + mu
1013+
upper = b * sigma + mu
1014+
return cls.param_template(lower=lower, upper=upper, mu=mu, sigma=sigma)
1015+
1016+
9071017
__all__ = [
9081018
'normal',
9091019
'lognormal',
@@ -915,4 +1025,5 @@ def fit(cls, data, **guesses):
9151025
'pareto',
9161026
'exponential',
9171027
'rice',
1028+
'truncated_normal',
9181029
]

0 commit comments

Comments
 (0)