From 81b7a1eca2658f5189fddb0d8ae369f1a7e1c833 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 13 Dec 2024 16:44:26 +0100 Subject: [PATCH] Use static shape in join_nonshared_inputs --- pymc/pytensorf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 6fd44b0382..ecfa5ff26d 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -596,13 +596,13 @@ def join_nonshared_inputs( raise ValueError("Empty list of input variables.") raveled_inputs = pt.concatenate([var.ravel() for var in inputs]) + size = sum(point[var_name].size for var_name in point) if not make_inputs_shared: - tensor_type = raveled_inputs.type - joined_inputs = tensor_type("joined_inputs") + joined_inputs = pt.tensor("joined_inputs", shape=(size,), dtype=raveled_inputs.dtype) else: joined_values = np.concatenate([point[var.name].ravel() for var in inputs]) - joined_inputs = pytensor.shared(joined_values, "joined_inputs") + joined_inputs = pytensor.shared(joined_values, "joined_inputs", shape=(size,)) if pytensor.config.compute_test_value != "off": joined_inputs.tag.test_value = raveled_inputs.tag.test_value