
Commit 6257f20

Remove Join view flag
Do not normalize the constant axis in make_node, and fix a rewrite that assumed the axis would always be positive.
1 parent 0f5da80 commit 6257f20

6 files changed: +157 −299 lines changed


pytensor/scan/checkpoints.py (+1 −4)
@@ -1,6 +1,5 @@
 import pytensor.tensor.basic as ptb
 from pytensor.scan.basic import scan
-from pytensor.tensor.basic import Join
 from pytensor.tensor.math import ceil, eq, neq
 from pytensor.tensor.subtensor import set_subtensor
 
@@ -127,14 +126,12 @@ def scan_checkpoints(
 
     # Pad the sequences if needed
     if padding:
-        # Since padding could be an empty tensor, Join returns a view of s.
-        join = Join(view=0)
         for i, s in enumerate(sequences):
             overshoots_by = s.shape[0] % save_every_N
             overshoots = neq(overshoots_by, 0)
             n = (save_every_N - overshoots_by) * overshoots
             z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype)
-            sequences[i] = join(0, s, z)
+            sequences[i] = ptb.join(0, s, z)
 
     # Establish the input variables of the outer scan
     o_sequences = [
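
The change above replaces the parametrized Join(view=0) Op with the plain ptb.join helper. A minimal sketch (illustrative, not part of the diff) of the padding step it builds:

    import pytensor.tensor as pt

    s = pt.matrix("s")                    # a sequence of unknown length
    n = pt.iscalar("n")                   # padding amount, possibly zero
    z = pt.zeros((n, s.shape[1]), dtype=s.dtype)
    padded = pt.join(0, s, z)             # plain concatenation along axis 0

With the view flag gone, joining with an empty z goes through the ordinary concatenation path rather than returning a view of s.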

pytensor/tensor/basic.py (+102 −164)
@@ -2439,27 +2439,7 @@ class Join(COp):
     """
 
     check_input = False
-    __props__ = ("view",)
-
-    def __init__(self, view=-1):
-        self.view = view
-        if view != -1:
-            # since the first input is always the axis, the tensors
-            # start from index 1.
-            self.view_map = {0: [1 + view]}
-
-    def __str__(self):
-        if self.view == -1:
-            return self.__class__.__name__
-        else:
-            classname = self.__class__.__name__
-            args = ", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__)
-            return f"{classname}{{{args}}}"
-
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-        if not hasattr(self, "view"):
-            self.view = -1
+    __props__ = ()
 
     def make_node(self, axis, *tensors):
         """
@@ -2476,74 +2456,62 @@ def make_node(self, axis, *tensors):
         if not tensors:
             raise ValueError("Cannot join an empty list of tensors")
 
+        axis = as_tensor_variable(axis)
+        if axis.type.dtype not in int_dtypes:
+            raise TypeError(f"Axis {axis} must be an integer type.")
+        if axis.type.ndim > 0:
+            raise TypeError(f"Axis {axis} must be 0-d.")
+
         tensors = [as_tensor_variable(x) for x in tensors]
-        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
 
-        if not builtins.all(targs.type.ndim for targs in tensors):
+        if not builtins.all(targs.type.ndim > 0 for targs in tensors):
             raise TypeError(
                 "Join cannot handle arguments of dimension 0."
-                " Use `stack` to join scalar values."
+                " Use `stack` to join scalar values and/or increase rank of scalars."
             )
 
         if len(tensors) == 1:
             out_shape = tensors[0].type.shape
         else:
-            # When the axis is fixed, a dimension should be
-            # broadcastable if at least one of the inputs is
-            # broadcastable on that dimension (see justification below),
-            # except for the axis dimension.
-            # Initialize bcastable all false, and then fill in some trues with
-            # the loops.
-
-            if not isinstance(axis, int):
-                try:
-                    axis = int(get_scalar_constant_value(axis))
-                except NotScalarConstantError:
-                    pass
-
             ndim = tensors[0].type.ndim
-            if isinstance(axis, int):
-                # Basically, broadcastable -> length 1, but the
-                # converse does not hold. So we permit e.g. T/F/T
-                # joins, and if they fail at runtime they fail, but if
-                # they don't then it means that the argument where
-                # that broadcastable flag was False had length 1 along
-                # this dimension, and therefore this dimension should
-                # be broadcastable for the output.
-
-                if axis < -ndim:
-                    raise IndexError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                if axis < 0:
-                    axis += ndim
-                if axis > ndim - 1:
-                    raise ValueError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                # NOTE: Constant negative axis can no longer be negative at this point.
-
-                in_shapes = [x.type.shape for x in tensors]
-                in_ndims = [len(s) for s in in_shapes]
-                if set(in_ndims) != {ndim}:
-                    raise TypeError(
-                        "Only tensors with the same number of dimensions can be joined."
-                        f" Input ndims were: {in_ndims}."
-                    )
+
+            if not builtins.all(x.ndim == ndim for x in tensors):
+                raise TypeError(
+                    "Only tensors with the same number of dimensions can be joined"
+                )
+
+            try:
+                # Note: This is dubious; if a user passed a constant we should
+                # propagate it to the inputs, not override it.
+                static_axis = int(get_scalar_constant_value(axis))
+            except NotScalarConstantError:
+                static_axis = None
+
+            if static_axis is None:
+                # When the axis isn't static, we can't conclude anything about the
+                # output dimensions (unless we had some degenerate zero arrays that
+                # can be removed during rewrites).
+                # We could also raise errors if any dimensions are pairwise
+                # inconsistent across all the axes, as no matter the join it would
+                # be invalid. However, a dynamic axis is so rare that it is not
+                # worth the trouble.
+                out_shape = [None] * ndim
+
+            else:  # We know the axis statically
+                static_axis = normalize_axis_index(static_axis, ndim)
+                static_shapes = [x.type.shape for x in tensors]
 
                 # Determine output shapes from a matrix of input shapes
-                in_shapes = np.array(in_shapes)
+                static_shapes = np.array(static_shapes)
                 out_shape = [None] * ndim
                 for d in range(ndim):
-                    ins = in_shapes[:, d]
-                    if d == axis:
-                        # Any unknown size along the axis means we can't sum
+                    ins = static_shapes[:, d]
+                    if d == static_axis:
+                        # Any unknown size along the axis means we can't infer it
                         if None in ins:
                             out_shape[d] = None
                         else:
                             out_shape[d] = sum(ins)
                     else:
-                        inset = set(in_shapes[:, d])
+                        inset = set(static_shapes[:, d])
                         # Other dims must match exactly,
                         # or if a mix of None and ? the output will be ?
                         # otherwise the input shapes are incompatible.
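
The static-shape rule introduced above is easy to state on its own: sum the sizes along the (statically known) join axis; every other dimension must agree across inputs or be unknown (None). A standalone sketch in plain Python (names illustrative, not part of the diff):

    def join_static_shape(static_axis, *shapes):
        # Mirrors the loop above: sum along the join axis, unify the rest.
        ndim = len(shapes[0])
        out = [None] * ndim
        for d in range(ndim):
            sizes = [s[d] for s in shapes]
            if d == static_axis:
                out[d] = None if None in sizes else sum(sizes)
            else:
                known = set(sizes) - {None}
                if len(known) > 1:
                    raise ValueError(f"incompatible sizes along dimension {d}: {sizes}")
                out[d] = known.pop() if known else None
        return tuple(out)

    assert join_static_shape(0, (2, 3), (None, 3)) == (None, 3)
    assert join_static_shape(1, (2, 3), (2, 4)) == (2, 7)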
@@ -2553,100 +2521,71 @@ def make_node(self, axis, *tensors):
                             (out_shape[d],) = inset - {None}
                         else:
                             raise ValueError(
-                                f"all input array dimensions other than the specified `axis` ({axis})"
+                                f"all input array dimensions other than the specified `axis` ({static_axis})"
                                 " must match exactly, or be unknown (None),"
                                 f" but along dimension {d}, the input shapes are incompatible: {ins}"
                             )
-        else:
-            # When the axis may vary, no dimension can be guaranteed to be
-            # broadcastable.
-            out_shape = [None] * tensors[0].type.ndim
-
-        if not builtins.all(x.ndim == len(out_shape) for x in tensors):
-            raise TypeError(
-                "Only tensors with the same number of dimensions can be joined"
-            )
-
-        inputs = [as_tensor_variable(axis), *tensors]
-
-        if inputs[0].type.dtype not in int_dtypes:
-            raise TypeError(f"Axis value {inputs[0]} must be an integer type")
 
+        inputs = [axis, *tensors]
+        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
         return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])
 
-    def perform(self, node, axis_and_tensors, out_):
-        (out,) = out_
-        view = self.view
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
-        # we check these tensors for being empty.
-        if (view != -1) and all(
-            tensor.shape[axis] == 0 for tensor in tens[0:view] + tens[view + 1 :]
-        ):
-            out[0] = tens[view]
-
-        else:
-            ndim = tens[0].ndim
-            if axis < -ndim:
-                raise IndexError(
-                    f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
-                )
-
-            out[0] = np.asarray(
-                np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype
-            )
+    def perform(self, node, inputs, output_storage):
+        axis, *arrays = inputs
+        output_storage[0][0] = np.concatenate(
+            arrays, axis=axis, dtype=node.outputs[0].type.dtype
+        )
 
     def c_code_cache_version(self):
-        return (5,)
+        return (6,)
 
     def c_code(self, node, name, inputs, outputs, sub):
-        axis, tens = inputs[0], inputs[1:]
-        view = self.view
-        non_empty_tensor = tens[view]
-        input_1 = tens[0]
-        l = len(tens)
-        (out,) = outputs
+        axis, *arrays = inputs
+        [out] = outputs
+        n = len(arrays)
+        ndim = node.outputs[0].type.ndim
         fail = sub["fail"]
-        adtype = node.inputs[0].type.dtype_specs()[1]
 
-        copy_to_list = (
-            f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});"""
-            for i, inp in enumerate(tens)
-        )
+        # Most of the time the axis is constant; inline it.
+        # This is safe to do because the hash of the c_code includes the constant signature.
+        if isinstance(node.inputs[0], Constant):
+            static_axis = int(node.inputs[0].data)
+            static_axis = normalize_axis_index(static_axis, ndim)
+            axis_def = f"{static_axis};"
+            axis_check = ""
+        else:
+            axis_dtype = node.inputs[0].type.dtype_specs()[1]
+            axis_def = f"(({axis_dtype} *)PyArray_DATA({axis}))[0];"
+            axis_check = f"""
+        if (axis < 0){{
+            axis = {ndim} + axis;
+        }}
+        if (axis >= {ndim} || axis < 0) {{
+            PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
+            {fail}
+        }}
+        """
 
-        copy_inputs_to_list = "\n".join(copy_to_list)
-        n = len(tens)
+        copy_arrays_to_tuple = "\n".join(
+            (
+                f"""Py_INCREF({array}); PyTuple_SetItem(arrays_tuple, {i}, (PyObject*){array});"""
+                for i, array in enumerate(arrays)
+            )
+        )
 
         code = f"""
-        int axis = (({adtype} *)PyArray_DATA({axis}))[0];
-        PyObject* list = PyList_New({l});
-        {copy_inputs_to_list}
-        int tensors_lens_sum;
-        if({view} != -1) {{
-            tensors_lens_sum = 0;
-
-            for(int i=0; i < {n}; i++){{
-                tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis);
-            }}
-            tensors_lens_sum -= PyArray_DIM({non_empty_tensor}, axis);
-        }}
-        if({view} != -1 && tensors_lens_sum == 0) {{
-            Py_XDECREF({out});
-            Py_INCREF({non_empty_tensor});
-            {out} = {non_empty_tensor};
-        }}else{{
-            //PyObject* PyArray_Concatenate(PyObject* obj, int axis)
-            int ndim = PyArray_NDIM({input_1});
-            if( axis < -ndim ){{
-                PyErr_Format(PyExc_IndexError,
-                             "Join axis %d out of bounds [0, %d)", axis, ndim);
-                {fail}
-            }}
-            Py_XDECREF({out});
-            {out} = (PyArrayObject *)PyArray_Concatenate(list, axis);
-            Py_DECREF(list);
-            if(!{out}){{
-                {fail}
-            }}
+        int axis = {axis_def}
+        PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
+
+        {axis_check}
+
+        Py_XDECREF({out});
+        PyObject* arrays_tuple = PyTuple_New({n});
+        {copy_arrays_to_tuple}
+        {out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
+        Py_DECREF(arrays_tuple);
+        if(!{out}){{
+            {fail}
         }}
         """
         return code
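
The rewritten perform delegates axis normalization and the output-dtype cast to NumPy in a single call, and the C code now inlines a constant axis instead of branching on the removed view flag. A standalone sketch of the NumPy behavior perform relies on (assuming NumPy >= 1.20, where concatenate gained the dtype keyword):

    import numpy as np

    arrays = [np.zeros((2, 3), dtype="float32"), np.ones((4, 3), dtype="float64")]
    out = np.concatenate(arrays, axis=0, dtype="float64")  # negative axes are also accepted
    assert out.shape == (6, 3) and out.dtype == np.float64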
@@ -2656,22 +2595,21 @@ def R_op(self, inputs, eval_points):
             return [None]
         return self.make_node(inputs[0], *eval_points[1:]).outputs
 
-    def grad(self, axis_and_tensors, grads):
+    def L_op(self, inputs, outputs, grads):
         """The gradient wrt a join op is a `Split`, used to partition
         the gradient along the `axis` which was used for joining.
         """
-        (gz,) = grads
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
+        [gz] = grads
+        [out] = outputs
+        axis, *tensors = inputs
 
         rval = [grad_undefined(self, 0, axis)]
-
-        dtypes = [as_tensor_variable(x).type.dtype for x in tens]
-        out_dtype = ps.upcast(*dtypes)
+        out_dtype = out.type.dtype
 
         if "float" in out_dtype or "complex" in out_dtype:
             # assume that this is differentiable
-            split = Split(len(tens))
-            split_gz = split(gz, axis, stack([shape(x)[axis] for x in tens]))
+            split_sizes = stack([shape(x)[axis] for x in tensors])
+            split_gz = split(gz, split_sizes, n_splits=len(tensors), axis=axis)
             # If there is only one split, it might not be in a list.
             if not isinstance(split_gz, list):
                 split_gz = [split_gz]
@@ -2684,13 +2622,12 @@ def grad(self, axis_and_tensors, grads):
                 else specify_broadcastable(
                     g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
                 )
-                for t, g in zip(tens, split_gz, strict=True)
+                for t, g in zip(tensors, split_gz, strict=True)
             ]
             rval = rval + split_gz
         else:
-            # the output has integer type, so the gradient through it
-            # is 0
-            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tens]
+            # the output has integer type, so the gradient through it is 0
+            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors]
 
         return rval
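
As the L_op above shows, the gradient of a join is a split of the output gradient back into per-input chunks. A minimal usage sketch (assuming the public pytensor API):

    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    y = pt.matrix("y")
    cost = pt.join(0, x, y).sum()
    gx, gy = pytensor.grad(cost, [x, y])  # one gradient per joined input, each with its input's shape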

@@ -2710,7 +2647,8 @@ def infer_shape(self, fgraph, node, ishapes):
         # An axis < -n_dim or >= ndim would be invalid, but this is
         # not checked here. A `CheckAndRaise` `Op` would be a way of
         # addressing that, but it may disrupt optimizations.
-        join_dim = switch(ge(node.inputs[0], 0), node.inputs[0], node.inputs[0] + n_dim)
+        axis = node.inputs[0]
+        join_dim = switch(ge(axis, 0), axis, axis + n_dim)
         out_shapes = []
         for dim in range(n_dim):
             # we have to deal with 2 possible cases in here :
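
For a concrete axis value, the symbolic switch above reduces to a plain wrap-around. An illustrative one-liner (not part of the diff):

    def normalize_join_dim(axis: int, n_dim: int) -> int:
        # Mirrors switch(ge(axis, 0), axis, axis + n_dim) for a concrete axis
        return axis if axis >= 0 else axis + n_dim

    assert normalize_join_dim(-1, 3) == 2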
@@ -2733,7 +2671,7 @@ def infer_shape(self, fgraph, node, ishapes):
         return [tuple(out_shapes)]
 
 
-join_ = Join()
+_join = Join()
 pprint.assign(Join, printing.FunctionPrinter(["join"]))
 

@@ -2776,7 +2714,7 @@ def join(axis, *tensors_list):
     if len(tensors_list) == 1:
         return tensors_list[0]
     else:
-        return join_(axis, *tensors_list)
+        return _join(axis, *tensors_list)
 
 
 @_vectorize_node.register(Join)
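
Per the commit title, a constant negative axis is no longer normalized in make_node, while the static output shape is still resolved. A quick sketch of the user-facing call (illustrative; later rewrites may still canonicalize the axis):

    import pytensor.tensor as pt

    x = pt.tensor3("x")
    y = pt.tensor3("y")
    z = pt.join(-1, x, y)  # the constant -1 stays on the graph; shape inference handles it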
