
Commit 158a884

Remove Join view flag
1 parent 0b56ed9 commit 158a884

6 files changed: +111 −252 lines

pytensor/scan/checkpoints.py (+1 −4)

@@ -1,6 +1,5 @@
 import pytensor.tensor.basic as ptb
 from pytensor.scan.basic import scan
-from pytensor.tensor.basic import Join
 from pytensor.tensor.math import ceil, eq, neq
 from pytensor.tensor.subtensor import set_subtensor
 
@@ -127,14 +126,12 @@ def scan_checkpoints(
 
     # Pad the sequences if needed
     if padding:
-        # Since padding could be an empty tensor, Join returns a view of s.
-        join = Join(view=0)
         for i, s in enumerate(sequences):
             overshoots_by = s.shape[0] % save_every_N
             overshoots = neq(overshoots_by, 0)
             n = (save_every_N - overshoots_by) * overshoots
             z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype)
-            sequences[i] = join(0, s, z)
+            sequences[i] = ptb.join(0, s, z)
 
     # Establish the input variables of the outer scan
     o_sequences = [
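The padding above now goes through the public `ptb.join` helper instead of a hand-built `Join(view=0)` instance, so the result is always a freshly allocated concatenation rather than a possible view of `s`. A minimal sketch of the same pattern, assuming a matrix sequence and a fixed number of zero rows (names are illustrative, not from the commit):

    import pytensor.tensor as pt

    s = pt.matrix("s")                            # a sequence of shape (T, k)
    z = pt.zeros((3, s.shape[1]), dtype=s.dtype)  # zero rows used as padding
    padded = pt.join(0, s, z)                     # concatenate along the leading axis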

pytensor/tensor/basic.py (+61 −121)
@@ -2439,27 +2439,7 @@ class Join(COp):
     """
 
     check_input = False
-    __props__ = ("view",)
-
-    def __init__(self, view=-1):
-        self.view = view
-        if view != -1:
-            # since the first input is always the axis, the tensors
-            # start from index 1.
-            self.view_map = {0: [1 + view]}
-
-    def __str__(self):
-        if self.view == -1:
-            return self.__class__.__name__
-        else:
-            classname = self.__class__.__name__
-            args = ", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__)
-            return f"{classname}{{{args}}}"
-
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-        if not hasattr(self, "view"):
-            self.view = -1
+    __props__ = ()
 
     def make_node(self, axis, *tensors):
         """
@@ -2476,74 +2456,62 @@ def make_node(self, axis, *tensors):
         if not tensors:
             raise ValueError("Cannot join an empty list of tensors")
 
+        axis = as_tensor_variable(axis)
+        if axis.type.dtype not in int_dtypes:
+            raise TypeError(f"Axis {axis} must be an integer type.")
+        if axis.type.ndim > 0:
+            raise TypeError(f"Axis {axis} must be 0-d.")
+
         tensors = [as_tensor_variable(x) for x in tensors]
-        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
 
-        if not builtins.all(targs.type.ndim for targs in tensors):
+        if not builtins.all(targs.type.ndim > 0 for targs in tensors):
             raise TypeError(
                 "Join cannot handle arguments of dimension 0."
-                " Use `stack` to join scalar values."
+                " Use `stack` to join scalar values and/or increase rank of scalars."
             )
 
         if len(tensors) == 1:
             out_shape = tensors[0].type.shape
         else:
-            # When the axis is fixed, a dimension should be
-            # broadcastable if at least one of the inputs is
-            # broadcastable on that dimension (see justification below),
-            # except for the axis dimension.
-            # Initialize bcastable all false, and then fill in some trues with
-            # the loops.
-
-            if not isinstance(axis, int):
-                try:
-                    axis = int(get_scalar_constant_value(axis))
-                except NotScalarConstantError:
-                    pass
-
             ndim = tensors[0].type.ndim
-            if isinstance(axis, int):
-                # Basically, broadcastable -> length 1, but the
-                # converse does not hold. So we permit e.g. T/F/T
-                # joins, and if they fail at runtime they fail, but if
-                # they don't then it means that the argument where
-                # that broadcastable flag was False had length 1 along
-                # this dimension, and therefore this dimension should
-                # be broadcastable for the output.
-
-                if axis < -ndim:
-                    raise IndexError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                if axis < 0:
-                    axis += ndim
-                if axis > ndim - 1:
-                    raise ValueError(
-                        f"Axis value {axis} is out of range for the given input dimensions"
-                    )
-                # NOTE: Constant negative axis can no longer be negative at this point.
-
-                in_shapes = [x.type.shape for x in tensors]
-                in_ndims = [len(s) for s in in_shapes]
-                if set(in_ndims) != {ndim}:
-                    raise TypeError(
-                        "Only tensors with the same number of dimensions can be joined."
-                        f" Input ndims were: {in_ndims}."
-                    )
+
+            if not builtins.all(x.ndim == ndim for x in tensors):
+                raise TypeError(
+                    "Only tensors with the same number of dimensions can be joined"
+                )
+
+            try:
+                # Note: This is dubious; if a user passed a constant we should propagate it to the inputs,
+                # not override it.
+                static_axis = int(get_scalar_constant_value(axis))
+            except NotScalarConstantError:
+                static_axis = None
+
+            if static_axis is None:
+                # When the axis isn't static, we can't conclude anything about output dimensions
+                # (unless we had some degenerate zero arrays) that can be removed during rewrites.
+                # We could also raise errors if any dimensions are pairwise inconsistent across all the axes,
+                # as no matter the join axis it would be invalid.
+                # However, a dynamic axis is so rare that it is not worth the trouble.
+                out_shape = [None] * ndim
+
+            else:  # We know the axis statically
+                static_axis = normalize_axis_index(static_axis, ndim)
+                static_shapes = [x.type.shape for x in tensors]
 
                 # Determine output shapes from a matrix of input shapes
-                in_shapes = np.array(in_shapes)
+                static_shapes = np.array(static_shapes)
                 out_shape = [None] * ndim
                 for d in range(ndim):
-                    ins = in_shapes[:, d]
-                    if d == axis:
-                        # Any unknown size along the axis means we can't sum
+                    ins = static_shapes[:, d]
+                    if d == static_axis:
+                        # Any unknown size along the axis means we can't infer it
                         if None in ins:
                             out_shape[d] = None
                         else:
                             out_shape[d] = sum(ins)
                     else:
-                        inset = set(in_shapes[:, d])
+                        inset = set(static_shapes[:, d])
                         # Other dims must match exactly,
                         # or if a mix of None and ? the output will be ?
                         # otherwise the input shapes are incompatible.
@@ -2553,54 +2521,27 @@ def make_node(self, axis, *tensors):
                             (out_shape[d],) = inset - {None}
                         else:
                             raise ValueError(
-                                f"all input array dimensions other than the specified `axis` ({axis})"
+                                f"all input array dimensions other than the specified `axis` ({static_axis})"
                                 " must match exactly, or be unknown (None),"
                                 f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
                             )
-            else:
-                # When the axis may vary, no dimension can be guaranteed to be
-                # broadcastable.
-                out_shape = [None] * tensors[0].type.ndim
-
-        if not builtins.all(x.ndim == len(out_shape) for x in tensors):
-            raise TypeError(
-                "Only tensors with the same number of dimensions can be joined"
-            )
-
-        inputs = [as_tensor_variable(axis), *tensors]
-
-        if inputs[0].type.dtype not in int_dtypes:
-            raise TypeError(f"Axis value {inputs[0]} must be an integer type")
 
+        inputs = [axis, *tensors]
+        out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
         return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])
 
-    def perform(self, node, axis_and_tensors, out_):
-        (out,) = out_
-        view = self.view
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
-        # we check these tensors for being empty.
-        if (view != -1) and all(
-            tensor.shape[axis] == 0 for tensor in tens[0:view] + tens[view + 1 :]
-        ):
-            out[0] = tens[view]
-
-        else:
-            ndim = tens[0].ndim
-            if axis < -ndim:
-                raise IndexError(
-                    f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
-                )
-
-            out[0] = np.asarray(
-                np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype
-            )
+    def perform(self, node, inputs, output_storage):
+        axis, *arrays = inputs
+        output_storage[0][0] = np.concatenate(
+            arrays, axis=axis, dtype=node.outputs[0].type.dtype
+        )
 
     def c_code_cache_version(self):
         return (5,)
 
     def c_code(self, node, name, inputs, outputs, sub):
         axis, tens = inputs[0], inputs[1:]
-        view = self.view
+        view = -1
         non_empty_tensor = tens[view]
         input_1 = tens[0]
         l = len(tens)
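For reference, a rough standalone illustration of the static-shape bookkeeping `make_node` performs above when the axis is a compile-time constant (plain Python, not the Op itself; the helper name is made up for this sketch):

    def joined_static_shape(static_shapes, static_axis):
        # static_shapes: one tuple per input, with None marking an unknown size
        ndim = len(static_shapes[0])
        out_shape = [None] * ndim
        for d in range(ndim):
            ins = [s[d] for s in static_shapes]
            if d == static_axis:
                # Any unknown size along the join axis makes the joined size unknown
                out_shape[d] = None if None in ins else sum(ins)
            else:
                known = set(ins) - {None}
                if len(known) > 1:
                    raise ValueError(f"Incompatible sizes along dimension {d}: {ins}")
                out_shape[d] = known.pop() if known else None
        return tuple(out_shape)

    # e.g. joining (2, 3) with (None, 3) along axis 0 gives (None, 3)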
@@ -2656,22 +2597,21 @@ def R_op(self, inputs, eval_points):
             return [None]
         return self.make_node(inputs[0], *eval_points[1:]).outputs
 
-    def grad(self, axis_and_tensors, grads):
+    def L_op(self, inputs, outputs, grads):
         """The gradient wrt a join op is a `Split`, used to partition
         the gradient along the `axis` which was used for joining.
         """
-        (gz,) = grads
-        axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
+        [gz] = grads
+        [out] = outputs
+        axis, *tensors = inputs
 
         rval = [grad_undefined(self, 0, axis)]
-
-        dtypes = [as_tensor_variable(x).type.dtype for x in tens]
-        out_dtype = ps.upcast(*dtypes)
+        out_dtype = out.type.dtype
 
         if "float" in out_dtype or "complex" in out_dtype:
             # assume that this is differentiable
-            split = Split(len(tens))
-            split_gz = split(gz, axis, stack([shape(x)[axis] for x in tens]))
+            split_sizes = stack([shape(x)[axis] for x in tensors])
+            split_gz = split(gz, split_sizes, n_splits=len(tensors), axis=axis)
             # If there is only one split, it might not be in a list.
             if not isinstance(split_gz, list):
                 split_gz = [split_gz]
@@ -2684,13 +2624,12 @@ def grad(self, axis_and_tensors, grads):
                 else specify_broadcastable(
                     g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
                 )
-                for t, g in zip(tens, split_gz, strict=True)
+                for t, g in zip(tensors, split_gz, strict=True)
             ]
             rval = rval + split_gz
         else:
-            # the output has integer type, so the gradient through it
-            # is 0
-            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tens]
+            # the output has integer type, so the gradient through it is 0
+            rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors]
 
         return rval
 
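As the docstring above notes, the gradient of a join is a `Split` of the upstream gradient along the join axis. A minimal sketch of that relationship through the public API (using `pytensor.grad`; the new `L_op` builds the split directly):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    y = pt.vector("y")
    z = pt.join(0, x, y)

    # The gradient graph contains a Split that partitions d(sum(z))/dz
    # back into pieces shaped like x and y.
    gx, gy = pytensor.grad(z.sum(), [x, y])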

@@ -2710,7 +2649,8 @@ def infer_shape(self, fgraph, node, ishapes):
         # An axis < -n_dim or >= ndim would be invalid, but this is
         # not checked here. A `CheckAndRaise` `Op` would be a way of
         # addressing that, but it may disrupt optimizations.
-        join_dim = switch(ge(node.inputs[0], 0), node.inputs[0], node.inputs[0] + n_dim)
+        axis = node.inputs[0]
+        join_dim = switch(ge(axis, 0), axis, axis + n_dim)
         out_shapes = []
         for dim in range(n_dim):
             # we have to deal with 2 possible cases in here :
@@ -2733,7 +2673,7 @@ def infer_shape(self, fgraph, node, ishapes):
         return [tuple(out_shapes)]
 
 
-join_ = Join()
+_join = Join()
 pprint.assign(Join, printing.FunctionPrinter(["join"]))
 
 
@@ -2776,7 +2716,7 @@ def join(axis, *tensors_list):
     if len(tensors_list) == 1:
         return tensors_list[0]
     else:
-        return join_(axis, *tensors_list)
+        return _join(axis, *tensors_list)
 
 
 @_vectorize_node.register(Join)

pytensor/tensor/rewriting/basic.py (+17 −31)
@@ -817,52 +817,38 @@ def local_join_1(fgraph, node):
         return [tensors[0]]
 
 
-# TODO: merge in local_useless_join
-@register_infer_shape
 @register_useless
-@register_specialize
 @register_canonicalize
+@register_specialize
 @node_rewriter([Join])
 def local_join_empty(fgraph, node):
     """Join(i, x, y, empty) => Join(i, x, y)
 
     Remove empty inputs to joins. The empty inputs can be anywhere.
-
     """
-    if not isinstance(node.op, Join):
-        return
-    new_inputs = []
+    axis, *tensors = node.inputs
+
     try:
-        join_idx = get_scalar_constant_value(
+        static_axis = get_scalar_constant_value(
             node.inputs[0], only_process_constants=True
         )
     except NotScalarConstantError:
         return
-    for idx in range(1, len(node.inputs)):
-        inp = node.inputs[idx]
-        # We can not use size == 0,, as this can change shape from 3,0
-        # to 2,0. This trigger DebugMode error. This happen with
-        # stack(...,[]) as this add a dimshuffle on [], that add a
-        # dimensions with shape 1.
-        if isinstance(inp, Constant) and inp.data.shape[join_idx] == 0:
-            continue
-        new_inputs.append(inp)
-    if len(new_inputs) < len(node.inputs) - 1:
-        if len(new_inputs) == 0:
-            # at.join do not work in that case.
-            # constant folding will take care of this case.
-            return
-        ret = join(node.inputs[0], *new_inputs)
-        o = node.outputs[0]
-        if ret.dtype != o.dtype:
-            # Join can upcast some inputs
-            return
 
-        # Copy over stacktrace from previous output (after join op)
-        # to new output, because an error in the new op must be caused
-        # by an error in the old join op.
-        copy_stack_trace(node.outputs, ret)
+    new_tensors = [tensor for tensor in tensors if tensor.type.shape[static_axis] != 0]
+
+    # If there are zero tensors, the join is useless but so is any other operation
+    # Another rewrite will (one day) handle all those cases
+    if 0 < len(new_tensors) < len(tensors):
+        # join eagerly returns a tensor when there is only one, no need for us to check
+        ret = join(axis, *new_tensors)
+
+        [old_output] = node.outputs
+
+        if ret.dtype != old_output.dtype:
+            ret = ret.astype(old_output.dtype)
 
+        copy_stack_trace(old_output, ret)
         return [ret]
 
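A small sketch of the graph-level effect of the rewritten `local_join_empty`, assuming an input whose static shape along the join axis is 0 (illustrative only; the rewrite runs automatically during the canonicalize/specialize passes):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    empty = pt.zeros((0, x.shape[1]), dtype=x.dtype)  # statically empty along axis 0
    y = pt.join(0, x, empty)

    f = pytensor.function([x], y)
    # After rewriting, the compiled graph no longer concatenates the empty input.
    print(f(np.ones((2, 3), dtype=pytensor.config.floatX)))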

tests/link/numba/test_tensor_basic.py (−18)
@@ -172,24 +172,6 @@ def test_Join(vals, axis):
     )
 
 
-def test_Join_view():
-    vals, vals_test = zip(
-        *(
-            (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
-            (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
-        ),
-        strict=True,
-    )
-    g = ptb.Join(view=1)(1, *vals)
-
-    with pytest.raises(NotImplementedError):
-        compare_numba_and_py(
-            vals,
-            g,
-            vals_test,
-        )
-
-
 @pytest.mark.parametrize(
     "n_splits, axis, values, sizes",
     [
