
Commit dcde539

Merge branch 'main' into paper
2 parents ea3e2a2 + 37a02e4

14 files changed: +489 -86 lines

README.md

Lines changed: 30 additions & 8 deletions
@@ -8,6 +8,26 @@
 
 Bayesian conjugate models in Python
 
+## Overview
+
+`conjugate-models` is a modern Python package for Bayesian conjugate inference that prioritizes a clean, idiomatic API and seamless integration with widely used Python data analysis libraries. It implements the conjugate likelihood-prior pairs cataloged in [Fink's compendium](https://www.johndcook.com/CompendiumOfConjugatePriors.pdf) and [Wikipedia's conjugate prior table](https://en.wikipedia.org/wiki/Conjugate_prior), making rigorous Bayesian updating, exploration, and visualization accessible for practitioners, educators, and researchers.
+
+### Why Conjugate Priors?
+
+A prior distribution is conjugate to a likelihood when the posterior remains in the same distribution family after observing data. Conjugate priors provide closed-form posterior updates and posterior predictive distributions, eliminating the need for numerical integration or MCMC sampling. Because these updates are analytic rather than iterative, **posterior computation is instantaneous regardless of data size**—enabling real-time interactive exploration and rapid model iteration.
+
+### Key Benefits
+
+- **Instant Updates:** No MCMC or optimization required—posterior computation is immediate
+- 🔢 **Vectorized Operations:** Batch inference for multi-arm problems without explicit loops
+- 📊 **Built-in Visualization:** Plot priors, posteriors, and predictive distributions
+- 🔗 **SciPy Integration:** Direct access to scipy.stats distributions via `.dist` property
+- 📦 **Data Library Support:** Works seamlessly with numpy, pandas, polars, and general array-like objects
+- 🪶 **Lightweight Dependencies:** Minimal requirements—no heavy ML frameworks or complex toolchains
+
+### Lightweight & Easy to Install
+
+With minimal dependencies from the scientific Python stack, `conjugate-models` installs quickly without requiring heavyweight probabilistic programming frameworks, MCMC samplers, or complex compilation toolchains.
 
 ## Installation
 
@@ -19,7 +39,7 @@ pip install conjugate-models
 
 - [Interactive Distribution Explorer](https://williambdean.github.io/conjugate/explorer) for exploring probability distributions with real-time parameter adjustment
 - **[Raw Data Workflow](https://williambdean.github.io/conjugate/examples/raw-data-workflow)** - Complete examples from raw observational data to posterior distributions with helper functions
-- **Data Input Helper Functions** - Extract sufficient statistics from raw observational data for all supported models
+- **[Data Input Helper Functions](https://williambdean.github.io/conjugate/helpers)** - Extract sufficient statistics from raw observational data for all supported models
 - [Connection to Scipy Distributions](https://williambdean.github.io/conjugate/examples/scipy-connection) with `dist` attribute
 - [Built in Plotting](https://williambdean.github.io/conjugate/examples/plotting) with `plot_pdf`, `plot_pmf`, and `plot_cdf` methods
 - [Vectorized Operations](https://williambdean.github.io/conjugate/examples/vectorized-inputs) for parameters and data
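The "Why Conjugate Priors?" section added in this commit claims closed-form posterior updates; as a hedged illustration (plain Python, independent of the package's actual API), the classic Beta-Binomial case reduces to parameter arithmetic:

```python
# Beta-Binomial conjugate update: with a Beta(a, b) prior on the success
# probability and k successes in n Bernoulli trials, the posterior is
# Beta(a + k, b + n - k) -- pure arithmetic, no sampling or integration.
def beta_binomial_update(a: float, b: float, k: int, n: int) -> tuple[float, float]:
    return a + k, b + (n - k)

# Uniform prior Beta(1, 1), then 7 successes in 10 trials:
print(beta_binomial_update(1.0, 1.0, 7, 10))  # (8.0, 4.0)
```

This is why the README can promise instant updates regardless of data size: only the sufficient statistics `k` and `n` enter the computation.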
@@ -38,9 +58,11 @@ Many likelihoods are supported including
 - `Normal` (including linear regression)
 - and [many more](https://williambdean.github.io/conjugate/models/)
 
+See the [Quick Reference](https://williambdean.github.io/conjugate/quick-reference) for a complete table of likelihood → prior/posterior mappings with links to model functions and helper functions.
+
 ## Basic Usage
 
-### Working with Pre-processed Data
+### Pattern 1: Working with Pre-processed Data
 
 1. Define prior distribution from `distributions` module
 1. Pass data and prior into model from `models` modules
@@ -64,7 +86,7 @@ posterior_predictive: BetaBinomial = binomial_beta_predictive(
 )
 ```
 
-### Working with Raw Observational Data
+### Pattern 2: Working with Raw Observational Data
 
 For raw data, use **helper functions** from the `helpers` module to extract sufficient statistics:
 

@@ -99,18 +121,18 @@ from conjugate.helpers import (
 # Count data (e.g., website visits per day)
 count_data = [5, 3, 8, 2, 6, 4, 7, 1, 9, 3]
 inputs = poisson_gamma_inputs(count_data)
-# Returns: {'x': sum(count_data), 'n': len(count_data)}
+# Returns: {'x_total': sum(count_data), 'n': len(count_data)}
 
 # Continuous measurements with known variance
 measurements = [2.3, 1.9, 2.7, 2.1, 2.5]
-variance = 0.5
-inputs = normal_known_variance_inputs(measurements, variance=variance)
-# Returns: {'x_mean': mean(measurements), 'n': len(measurements), 'variance': variance}
+inputs = normal_known_variance_inputs(measurements)
+# Returns: {'x_total': sum(measurements), 'n': len(measurements)}
+# Note: variance must be passed separately to the model function
 
 # Time between events (e.g., customer arrivals)
 wait_times = [3.2, 1.8, 4.1, 2.7, 3.9]
 inputs = exponential_gamma_inputs(wait_times)
-# Returns: {'x': sum(wait_times), 'n': len(wait_times)}
+# Returns: {'x_total': sum(wait_times), 'n': len(wait_times)}
 
 # Categorical outcomes (e.g., survey responses A, B, C)
 responses = ['A', 'B', 'A', 'C', 'B', 'A', 'B']
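The renamed return keys above (`'x'`/`'x_mean'` → `'x_total'`) can be checked by hand; a minimal, hypothetical re-implementation of the documented return shape (the real helpers live in `conjugate.helpers`):

```python
# Hypothetical sketch of the sufficient-statistics shape documented above:
# these helpers reduce raw observations to {'x_total': sum(data), 'n': len(data)}.
def sufficient_stats_sketch(data):
    return {"x_total": sum(data), "n": len(data)}

count_data = [5, 3, 8, 2, 6, 4, 7, 1, 9, 3]
print(sufficient_stats_sketch(count_data))  # {'x_total': 48, 'n': 10}
```

Reducing raw data to two numbers is what lets the model functions stay agnostic about whether the input was a list, a numpy array, or a pandas/polars column.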

conjugate/distributions.py

Lines changed: 47 additions & 13 deletions
@@ -631,7 +631,9 @@ def from_inverse_gamma(
     def inverse_gamma(self) -> InverseGamma:
         return InverseGamma(alpha=self.alpha, beta=self.beta)
 
-    def sample_variance(self, size: int, random_state=None) -> NUMERIC:
+    def sample_variance(
+        self, size: int, random_state: np.random.RandomState | None = None
+    ) -> NUMERIC:
         """Sample variance from the inverse gamma distribution.
 
         Args:
@@ -644,11 +646,15 @@ def sample_variance(self, size: int, random_state=None) -> NUMERIC:
         """
         return self.inverse_gamma.dist.rvs(size=size, random_state=random_state)
 
-    def _sample_beta_1d(self, variance, size: int, random_state=None) -> NUMERIC:
+    def _sample_beta_1d(
+        self, variance, size: int, random_state: np.random.RandomState | None = None
+    ) -> NUMERIC:
         sigma = (variance / self.nu) ** 0.5
         return stats.norm(self.mu, sigma).rvs(size=size, random_state=random_state)
 
-    def _sample_beta_nd(self, variance, size: int, random_state=None) -> NUMERIC:
+    def _sample_beta_nd(
+        self, variance, size: int, random_state: np.random.RandomState | None = None
+    ) -> NUMERIC:
         variance = (self.delta_inverse[None, ...].T * variance).T
         return np.stack(
             [
@@ -663,7 +669,7 @@ def sample_mean(
         self,
         size: int,
         return_variance: bool = False,
-        random_state=None,
+        random_state: np.random.RandomState | None = None,
     ) -> NUMERIC | tuple[NUMERIC, NUMERIC]:
         """Sample the mean from the normal distribution.
@@ -681,7 +687,10 @@ def sample_mean(
         )
 
     def sample_beta(
-        self, size: int, return_variance: bool = False, random_state=None
+        self,
+        size: int,
+        return_variance: bool = False,
+        random_state: np.random.RandomState | None = None,
     ) -> NUMERIC | tuple[NUMERIC, NUMERIC]:
         """Sample beta from the normal distribution.
 
@@ -809,7 +818,11 @@ class GammaKnownRateProportional:
     c: NUMERIC
 
     def approx_log_likelihood(
-        self, alpha: NUMERIC, beta: NUMERIC, ln=np.log, gammaln=gammaln
+        self,
+        alpha: NUMERIC,
+        beta: NUMERIC,
+        ln: Callable = np.log,
+        gammaln: Callable = gammaln,
     ) -> NUMERIC:
         """Approximate log likelihood.
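The `ln: Callable = np.log` annotations in these hunks formalize a dependency-injection pattern: math callables are keyword defaults so a vectorized backend (`np.log`, `scipy.special.gammaln`) can be swapped in. A minimal stdlib sketch of the same pattern (hypothetical function, not the package's actual likelihood):

```python
from math import log, lgamma
from typing import Callable

# Log-density of Gamma(alpha, rate=beta) at x, with the math callables injectable
# so an array-capable backend can replace math.log / math.lgamma.
def log_gamma_density(x: float, alpha: float, beta: float,
                      ln: Callable = log, gammaln: Callable = lgamma) -> float:
    return alpha * ln(beta) - gammaln(alpha) + (alpha - 1.0) * ln(x) - beta * x

# Gamma(1, 1) is Exponential(1), so the log pdf at x=1 is -1.
print(log_gamma_density(1.0, 1.0, 1.0))  # -1.0
```

Typing these parameters as `Callable` documents the contract without changing behavior, which is exactly what the diff does.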
@@ -848,7 +861,11 @@ class GammaProportional:
     s: NUMERIC
 
     def approx_log_likelihood(
-        self, alpha: NUMERIC, beta: NUMERIC, ln=np.log, gammaln=gammaln
+        self,
+        alpha: NUMERIC,
+        beta: NUMERIC,
+        ln: Callable = np.log,
+        gammaln: Callable = gammaln,
     ) -> NUMERIC:
         """Approximate log likelihood.
@@ -886,7 +903,11 @@ class BetaProportional:
     k: NUMERIC
 
    def approx_log_likelihood(
-        self, alpha: NUMERIC, beta: NUMERIC, ln=np.log, gammaln=gammaln
+        self,
+        alpha: NUMERIC,
+        beta: NUMERIC,
+        ln: Callable = np.log,
+        gammaln: Callable = gammaln,
     ) -> NUMERIC:
         """Approximate log likelihood.
@@ -946,7 +967,13 @@ class VonMisesKnownConcentration:
     a: NUMERIC
     b: NUMERIC
 
-    def log_likelihood(self, mu: NUMERIC, cos=np.cos, ln=np.log, i0=i0) -> NUMERIC:
+    def log_likelihood(
+        self,
+        mu: NUMERIC,
+        cos: Callable = np.cos,
+        ln: Callable = np.log,
+        i0: Callable = i0,
+    ) -> NUMERIC:
         """Approximate log likelihood.
 
         Args:
@@ -976,7 +1003,9 @@ class VonMisesKnownDirectionProportional:
     c: NUMERIC
     r: NUMERIC
 
-    def approx_log_likelihood(self, kappa: NUMERIC, ln=np.log, i0=i0) -> NUMERIC:
+    def approx_log_likelihood(
+        self, kappa: NUMERIC, ln: Callable = np.log, i0: Callable = i0
+    ) -> NUMERIC:
         """Approximate log likelihood.
 
         Args:
@@ -1058,7 +1087,9 @@ class NormalGamma:
     def gamma(self) -> Gamma:
         return Gamma(alpha=self.alpha, beta=self.beta)
 
-    def sample_variance(self, size: int, random_state=None) -> NUMERIC:
+    def sample_variance(
+        self, size: int, random_state: np.random.RandomState | None = None
+    ) -> NUMERIC:
         """Sample precision from gamma distribution and invert.
 
         Args:
@@ -1077,7 +1108,7 @@ def sample_mean(
         self,
         size: int,
         return_variance: bool = False,
-        random_state=None,
+        random_state: np.random.RandomState | None = None,
     ) -> NUMERIC | tuple[NUMERIC, NUMERIC]:
         """Sample mean from the normal distribution.
@@ -1095,7 +1126,10 @@ def sample_mean(
         )
 
     def sample_beta(
-        self, size: int, return_variance: bool = False, random_state=None
+        self,
+        size: int,
+        return_variance: bool = False,
+        random_state: np.random.RandomState | None = None,
     ) -> NUMERIC | tuple[NUMERIC, NUMERIC]:
         """Sample beta from the normal distribution.
