diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index c0a49c10e7..8ab49e6bb5 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -759,7 +759,21 @@ def cast(x, dtype: Union[str, np.dtype]) -> TensorVariable: @scalar_elemwise def switch(cond, ift, iff): - """if cond then ift else iff""" + r""" + if cond then ift else iff + + Examples + -------- + .. code-block:: python + + import pymc as pm + with pm.Model() as model: + x = pm.Normal('x', mu=0, sigma=1) + mu = pm.math.switch(x > 0, 2.0, -2.0) + + x1 = pm.Normal('x1', mu=mu, sigma=1) + + """ where = switch