5 changes: 4 additions & 1 deletion aesara/scalar/basic.py
@@ -332,8 +332,11 @@ class Scalar(CType):
     ndim = 0

     def __init__(self, dtype):
-        if dtype == "floatX":
+        if isinstance(dtype, str) and dtype == "floatX":
             dtype = config.floatX
+        else:
+            dtype = np.dtype(dtype).name
+
         self.dtype = dtype
         self.dtype_specs()  # error checking
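Note: the new `else` branch leans entirely on `np.dtype(...).name`, which maps any NumPy-compatible dtype specifier to its canonical string name. A minimal sketch of that normalization (plain NumPy, nothing Aesara-specific):

    import numpy as np

    # `np.dtype` canonicalizes type classes, strings, and shorthand codes alike
    assert np.dtype(np.int32).name == "int32"
    assert np.dtype("f8").name == "float64"
    assert np.dtype(float).name == "float64"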
3 changes: 3 additions & 0 deletions aesara/tensor/__init__.py
@@ -253,3 +253,6 @@ def _as_tensor_variable(
     zvector,
 )
 from aesara.tensor.type_other import *
+
+
+__all__ = ["random"]  # noqa: F405
11 changes: 7 additions & 4 deletions aesara/tensor/basic.py
@@ -923,18 +923,21 @@ def _conversion(real_value, name):

 def cast(x, dtype):
     """Symbolically cast `x` to a Tensor of type `dtype`."""
-    if dtype == "floatX":
+
+    if isinstance(dtype, str) and dtype == "floatX":
         dtype = config.floatX

+    dtype_name = np.dtype(dtype).name
+
     _x = as_tensor_variable(x)
-    if _x.type.dtype == dtype:
+    if _x.type.dtype == dtype_name:
         return _x
-    if _x.type.dtype.startswith("complex") and not dtype.startswith("complex"):
+    if _x.type.dtype.startswith("complex") and not dtype_name.startswith("complex"):
         raise TypeError(
             "Casting from complex to real is ambiguous: consider real(), "
             "imag(), angle() or abs()"
         )
-    return _cast_mapping[dtype](x)
+    return _cast_mapping[dtype_name](x)


 ##########################
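With the dtype normalized up front, `cast` now accepts NumPy dtype objects and classes as well as strings. A usage sketch mirroring the new test added below in tests/tensor/test_basic.py:

    import numpy as np
    import aesara
    from aesara.tensor import vector
    from aesara.tensor.basic import cast

    x = vector(dtype=np.int32)  # dtype given as a NumPy type, not a string
    y = cast(x, np.int64)       # normalized internally via np.dtype(...).name
    f = aesara.function([x], y)
    assert f(np.array([1, 2], dtype=np.int32)).dtype == np.int64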
2 changes: 2 additions & 0 deletions aesara/tensor/random/__init__.py
@@ -1,3 +1,5 @@
 # Initialize `RandomVariable` optimizations
 import aesara.tensor.random.opt
+import aesara.tensor.random.utils
 from aesara.tensor.random.basic import *
+from aesara.tensor.random.utils import RandomStream
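These imports make `RandomStream` reachable directly from `aesara.tensor.random`. A rough usage sketch (the seed value is arbitrary):

    from aesara.tensor.random import RandomStream

    srng = RandomStream(seed=2392)  # a seeded helper that hands out RandomVariables
    x = srng.normal(0, 1)           # each call draws from a fresh, tracked RNG state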
38 changes: 38 additions & 0 deletions aesara/tensor/random/basic.py
@@ -693,3 +693,41 @@ def __call__(self, x, **kwargs):


 permutation = PermutationRV()
+
+
+__all__ = [
+    "permutation",
+    "choice",
+    "randint",
+    "categorical",
+    "multinomial",
+    "betabinom",
+    "nbinom",
+    "binomial",
+    "laplace",
+    "bernoulli",
+    "truncexpon",
+    "wald",
+    "invgamma",
+    "halfcauchy",
+    "cauchy",
+    "hypergeometric",
+    "geometric",
+    "poisson",
+    "dirichlet",
+    "multivariate_normal",
+    "vonmises",
+    "logistic",
+    "weibull",
+    "exponential",
+    "gumbel",
+    "pareto",
+    "chisquare",
+    "gamma",
+    "lognormal",
+    "halfnormal",
+    "normal",
+    "beta",
+    "triangular",
+    "uniform",
+]
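The explicit `__all__` pins down which constructors `from aesara.tensor.random.basic import *` (used in the package `__init__` above) re-exports. For example:

    from aesara.tensor.random.basic import *  # binds only the names listed in __all__

    z = normal(0, 1, size=(2, 3))  # `normal` is re-exported; module internals are not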
62 changes: 9 additions & 53 deletions aesara/tensor/random/op.py
@@ -9,17 +9,18 @@
 from aesara.graph.basic import Apply, Variable
 from aesara.graph.op import Op
 from aesara.misc.safe_asarray import _asarray
+from aesara.scalar.basic import Cast
 from aesara.tensor.basic import (
     as_tensor_variable,
     constant,
     get_scalar_constant_value,
     get_vector_length,
 )
+from aesara.tensor.elemwise import Elemwise
 from aesara.tensor.exceptions import NotScalarConstantError
 from aesara.tensor.random.type import RandomStateType
 from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
 from aesara.tensor.type import TensorType, all_dtypes
-from aesara.tensor.type_other import NoneConst


 def default_shape_from_params(
@@ -285,6 +286,13 @@ def compute_bcast(self, dist_params, size):
         """
         shape = self._infer_shape(size, dist_params)

+        # Ignore `Cast`s, since they do not affect broadcastables
+        if getattr(shape, "owner", None) and (
+            isinstance(shape.owner.op, Elemwise)
+            and isinstance(shape.owner.op.scalar_op, Cast)
+        ):
+            shape = shape.owner.inputs[0]
+
         # Let's try to do a better job than `_infer_ndim_bcast` when
         # dimension sizes are symbolic.
         bcast = []
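Why peeling the `Cast` matters: casting a constant `size` graph (e.g. with `.astype`) wraps it in an `Elemwise` whose scalar op is a `Cast`, hiding the constant from which size-1 dimensions could be read. A sketch of the graph shape this branch handles, mirroring the new test in tests/tensor/random/test_op.py:

    import numpy as np
    import aesara.tensor as aet

    size = aet.as_tensor((1, 2, 3), dtype=np.int32).astype(np.int64)
    assert size.owner is not None  # the cast produced an Apply node
    # size.owner.op is an Elemwise wrapping a Cast scalar op, and
    # size.owner.inputs[0] is the original int32 constant (1, 2, 3),
    # so stripping the cast lets the leading 1 be marked broadcastable.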
@@ -422,55 +430,3 @@ def grad(self, inputs, outputs):

     def R_op(self, inputs, eval_points):
         return [None for i in eval_points]
-
-
-class Observed(Op):
-    """An `Op` that represents an observed random variable.
-
-    This `Op` establishes an observation relationship between a random
-    variable and a specific value.
-    """
-
-    default_output = 0
-    view_map = {0: [1]}
-
-    def make_node(self, rv, val):
-        """Make an `Observed` random variable.
-
-        Parameters
-        ----------
-        rv: RandomVariable
-            The distribution from which `val` is assumed to be a sample value.
-        val: Variable
-            The observed value.
-        """
-        val = as_tensor_variable(val)
-
-        if rv is not None:
-            if not hasattr(rv, "type") or rv.type.convert_variable(val) is None:
-                raise TypeError(
-                    (
-                        "`rv` and `val` do not have compatible types:"
-                        f" rv={rv}, val={val}"
-                    )
-                )
-        else:
-            rv = NoneConst.clone()
-
-        inputs = [rv, val]
-
-        return Apply(self, inputs, [val.type()])
-
-    def perform(self, node, inputs, out):
-        out[0][0] = inputs[1]
-
-    def grad(self, inputs, outputs):
-        return [
-            aesara.gradient.grad_undefined(
-                self, k, inp, "No gradient defined for random variables"
-            )
-            for k, inp in enumerate(inputs)
-        ]
-
-
-observed = Observed()
6 changes: 4 additions & 2 deletions aesara/tensor/type.py
@@ -75,9 +75,11 @@ class TensorType(CType):
     """

     def __init__(self, dtype, broadcastable, name=None):
-        self.dtype = str(dtype)
-        if self.dtype == "floatX":
+        if isinstance(dtype, str) and dtype == "floatX":
             self.dtype = config.floatX
+        else:
+            self.dtype = np.dtype(dtype).name
+
         # broadcastable is immutable, and all elements are either
         # True or False
         self.broadcastable = tuple(bool(b) for b in broadcastable)
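This mirrors the `Scalar` change: anything `np.dtype` understands becomes its canonical string name, while the `"floatX"` alias still resolves through `config`. A quick sketch:

    import numpy as np
    from aesara.tensor.type import TensorType

    assert TensorType(np.int32, []).dtype == "int32"          # NumPy class accepted
    assert TensorType("float64", [False]).dtype == "float64"  # plain strings unchanged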
5 changes: 5 additions & 0 deletions tests/scalar/test_type.py
@@ -13,6 +13,11 @@
 )


+def test_numpy_dtype():
+    test_type = Scalar(np.int32)
+    assert test_type.dtype == "int32"
+
+
 def test_div_types():
     a = int8()
     b = int32()
37 changes: 8 additions & 29 deletions tests/tensor/random/test_op.py
@@ -6,10 +6,8 @@
 from aesara.assert_op import Assert
 from aesara.gradient import NullTypeGradError, grad
 from aesara.tensor.math import eq
-from aesara.tensor.random.basic import normal
-from aesara.tensor.random.op import RandomVariable, default_shape_from_params, observed
-from aesara.tensor.type import all_dtypes, iscalar, tensor, vector
-from aesara.tensor.type_other import NoneTypeT
+from aesara.tensor.random.op import RandomVariable, default_shape_from_params
+from aesara.tensor.type import all_dtypes, iscalar, tensor


 @fixture(scope="module", autouse=True)
@@ -113,6 +111,8 @@ def test_RandomVariable_basics():
     with raises(NullTypeGradError):
         grad(rv_out, [rv_node.inputs[0]])

+
+def test_RandomVariable_bcast():
     rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

     mu = tensor(config.floatX, [True, False, False])
@@ -131,6 +131,10 @@ def test_RandomVariable_basics():
     res = rv.compute_bcast([mu, sd], (s1, s2, s3))
     assert res == [False] * 3

+    size = aet.as_tensor((1, 2, 3), dtype=np.int32).astype(np.int64)
+    res = rv.compute_bcast([mu, sd], size)
+    assert res == [True, False, False]
+

 def test_RandomVariable_floatX():
     test_rv_op = RandomVariable(
@@ -149,28 +153,3 @@ def test_RandomVariable_floatX():

     with config.change_flags(floatX=new_floatX):
         assert test_rv_op(0, 1).dtype == new_floatX
-
-
-def test_observed():
-    rv_var = normal(0, 1, size=3)
-    obs_var = observed(rv_var, np.array([0.2, 0.1, -2.4], dtype=config.floatX))
-
-    assert obs_var.owner.inputs[0] is rv_var
-
-    with raises(TypeError):
-        observed(rv_var, np.array([1, 2], dtype=int))
-
-    with raises(TypeError):
-        observed(rv_var, np.array([[1.0, 2.0]], dtype=rv_var.dtype))
-
-    obs_rv = observed(None, np.array([0.2, 0.1, -2.4], dtype=config.floatX))
-
-    assert isinstance(obs_rv.owner.inputs[0].type, NoneTypeT)
-
-    rv_val = vector()
-    rv_val.tag.test_value = np.array([0.2, 0.1, -2.4], dtype=config.floatX)
-
-    obs_var = observed(rv_var, rv_val)
-
-    with raises(NullTypeGradError):
-        grad(obs_var.sum(), [rv_val])
6 changes: 6 additions & 0 deletions tests/tensor/test_basic.py
@@ -851,6 +851,12 @@ def check(dtype):


 class TestCast:
+    def test_can_use_numpy_types(self):
+        x = vector(dtype=np.int32)
+        y = cast(x, np.int64)
+        f = function([x], y)
+        assert f(np.array([1, 2], dtype=np.int32)).dtype == np.int64
+
     def test_good_between_real_types(self):
         good = itertools.chain(
             multi_dtype_cast_checks((2,), dtypes=REAL_DTYPES),
5 changes: 5 additions & 0 deletions tests/tensor/test_type.py
@@ -8,6 +8,11 @@
 from aesara.tensor.type import TensorType


+def test_numpy_dtype():
+    test_type = TensorType(np.int32, [])
+    assert test_type.dtype == "int32"
+
+
 def test_filter_variable():
     test_type = TensorType(config.floatX, [])