Skip to content

Commit 0e5599e

Browse files
ricardoV94twiecki
authored andcommitted
Reintroduce test_types.py
1 parent e66a8fb commit 0e5599e

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

.github/workflows/pytest.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ jobs:
4141
--ignore=pymc/tests/test_step.py
4242
--ignore=pymc/tests/test_tuning.py
4343
--ignore=pymc/tests/test_transforms.py
44-
--ignore=pymc/tests/test_types.py
4544
--ignore=pymc/tests/test_variational_inference.py
4645
--ignore=pymc/tests/test_sampling_jax.py
4746
--ignore=pymc/tests/test_dist_math.py

pymc/tests/test_types.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import aesara
1818
import numpy as np
19+
import pytest
1920

2021
from pymc.distributions import Normal
2122
from pymc.model import Model
@@ -45,7 +46,7 @@ def test_float64(self):
4546

4647
for sampler in self.samplers:
4748
with model:
48-
sample(10, sampler())
49+
sample(draws=10, tune=10, chains=1, step=sampler())
4950

5051
@aesara.config.change_flags({"floatX": "float32", "warn_float64": "warn"})
5152
def test_float32(self):
@@ -58,8 +59,9 @@ def test_float32(self):
5859

5960
for sampler in self.samplers:
6061
with model:
61-
sample(10, sampler())
62+
sample(draws=10, tune=10, chains=1, step=sampler())
6263

64+
@pytest.mark.xfail(reason="MLDA not refactored for V4 yet")
6365
@aesara.config.change_flags({"floatX": "float64", "warn_float64": "ignore"})
6466
def test_float64_MLDA(self):
6567
data = np.random.randn(5)
@@ -76,8 +78,9 @@ def test_float64_MLDA(self):
7678
assert obs.dtype == "float64"
7779

7880
with model:
79-
sample(10, MLDA(coarse_models=[coarse_model]))
81+
sample(draws=10, tune=10, chains=1, step=MLDA(coarse_models=[coarse_model]))
8082

83+
@pytest.mark.xfail(reason="MLDA not refactored for V4 yet")
8184
@aesara.config.change_flags({"floatX": "float32", "warn_float64": "warn"})
8285
def test_float32_MLDA(self):
8386
data = np.random.randn(5).astype("float32")
@@ -94,4 +97,4 @@ def test_float32_MLDA(self):
9497
assert obs.dtype == "float32"
9598

9699
with model:
97-
sample(10, MLDA(coarse_models=[coarse_model]))
100+
sample(draws=10, tune=10, chains=1, step=MLDA(coarse_models=[coarse_model]))

0 commit comments

Comments
 (0)