Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions thinc/backends/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,29 @@ def reduce_sum(self, X: Floats2d, lengths: Ints1d) -> Floats2d:
Y[i] = 0.0
return Y

def reduce_first(self, X: Floats2d, lengths: Ints1d) -> Tuple[Floats2d, Ints1d]:
    """Extract the first row of each sequence from concatenated sequence data.

    X: concatenated sequences, shape (sum(lengths), width).
    lengths: length of each sequence; all lengths must be strictly positive.
    RETURNS: a (n_seqs, width) array of first rows, plus the start/end
        offsets array of length n_seqs + 1 (needed for the backward pass).

    Raises ValueError for zero-length sequences and IndexError when the
    lengths do not sum to the number of rows in X.
    """
    if lengths.size == 0:
        # No sequences at all: empty output, pass lengths through unchanged.
        return self.alloc2f(0, X.shape[1]), lengths
    if not self.xp.all(lengths > 0):
        # NOTE(review): the guard requires strictly positive lengths; the
        # message text is kept verbatim because tests match it exactly.
        # (Dropped the redundant f-prefix — no placeholders in the string.)
        raise ValueError("all sequence lengths must be >= 0")
    # starts_ends[i] is the offset where sequence i begins; the final entry
    # is the total number of rows, so starts_ends[:-1] are the first indices.
    starts_ends = self.alloc1i(lengths.shape[0] + 1, zeros=False)
    starts_ends[0] = 0
    starts_ends[1:] = lengths.cumsum()
    if starts_ends[-1] != X.shape[0]:
        raise IndexError("lengths must sum up to the number of rows")

    return X[starts_ends[:-1]], starts_ends

def reduce_last(self, X: Floats2d, lengths: Ints1d) -> Tuple[Floats2d, Ints1d]:
    """Extract the last row of each sequence from concatenated sequence data.

    X: concatenated sequences, shape (sum(lengths), width).
    lengths: length of each sequence; all lengths must be strictly positive.
    RETURNS: a (n_seqs, width) array of last rows, plus the index of the
        last row of each sequence (needed for the backward pass).

    Raises ValueError for zero-length sequences and IndexError when the
    lengths do not sum to the number of rows in X.
    """
    if lengths.size == 0:
        # No sequences at all: empty output, pass lengths through unchanged.
        return self.alloc2f(0, X.shape[1]), lengths
    if not self.xp.all(lengths > 0):
        # NOTE(review): the guard requires strictly positive lengths; the
        # message text is kept verbatim because tests match it exactly.
        # (Dropped the redundant f-prefix — no placeholders in the string.)
        raise ValueError("all sequence lengths must be >= 0")
    # Index of the final row of each sequence: cumulative end offset minus one.
    lasts = lengths.cumsum() - 1
    if lasts[-1] + 1 != X.shape[0]:
        raise IndexError("lengths must sum up to the number of rows")
    return X[lasts], lasts

def reduce_mean(self, X: Floats2d, lengths: Ints1d) -> Floats2d:
Y = self.alloc2f(lengths.shape[0], X.shape[1], zeros=False)
start = 0
Expand Down Expand Up @@ -1187,6 +1210,26 @@ def reduce_max(self, X: Floats2d, lengths: Ints1d) -> Tuple[Floats2d, Ints2d]:
start += length
return Y, which

def backprop_reduce_first(
    self, d_firsts: Floats2d, starts_ends: Ints1d
) -> Floats2d:
    """Backpropagate the reduce_first operation.

    d_firsts: gradient of the pooled outputs, shape (n_seqs, width).
    starts_ends: the offsets array returned by reduce_first; its final
        entry is the total number of input rows.
    RETURNS: gradient of the concatenated input rows — d_firsts scattered
        to each sequence's first row, zeros elsewhere.
    """
    if starts_ends.size < 2:
        # Fixed garbled wording ("should least have") and dropped the
        # redundant f-prefix — the string has no placeholders.
        raise ValueError("starts_ends should at least have size 2")
    dX = self.alloc2f(
        starts_ends[-1], d_firsts.shape[1], dtype=d_firsts.dtype, zeros=True
    )
    # starts_ends[:-1] are the first-row indices of each sequence.
    dX[starts_ends[:-1]] = d_firsts
    return dX

def backprop_reduce_last(self, d_lasts: Floats2d, lasts: Ints1d) -> Floats2d:
    """Backpropagate the reduce_last operation.

    d_lasts: gradient of the pooled outputs, shape (n_seqs, width).
    lasts: the last-row indices returned by reduce_last; the final entry
        plus one is the total number of input rows.
    RETURNS: gradient of the concatenated input rows — d_lasts scattered
        to each sequence's last row, zeros elsewhere.
    """
    if lasts.size < 1:
        # Fixed wrong message: the guard requires size >= 1, not 2; also
        # fixed garbled wording and dropped the redundant f-prefix.
        raise ValueError("lasts should at least have size 1")
    dX = self.alloc2f(
        lasts[-1] + 1, d_lasts.shape[1], dtype=d_lasts.dtype, zeros=True
    )
    dX[lasts] = d_lasts
    return dX

def backprop_reduce_sum(self, d_sums: Floats2d, lengths: Ints1d) -> Floats2d:
dX = self.alloc2f(
lengths.sum(), d_sums.shape[1], dtype=d_sums.dtype, zeros=False
Expand Down
20 changes: 8 additions & 12 deletions thinc/layers/reduce_first.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Callable, Tuple, cast, TypeVar
from typing import Callable, Tuple, cast

from ..model import Model
from ..config import registry
from ..types import Ragged, ArrayXd
from ..types import Ragged, Floats2d
from ..util import ArrayInfo

OutT = TypeVar("OutT", bound=ArrayXd)

InT = Ragged
OutT = Floats2d


@registry.layers("reduce_first.v1")
Expand All @@ -17,19 +19,13 @@ def reduce_first() -> Model[Ragged, OutT]:
def forward(
    model: Model[Ragged, OutT], Xr: Ragged, is_train: bool
) -> Tuple[OutT, Callable[[OutT], Ragged]]:
    """Pool each ragged sequence to its first row via Ops.reduce_first.

    Reconstructed clean implementation: the pasted span interleaved the old
    (removed) lines with the new ones; only the new code belongs here.
    """
    Y, starts_ends = model.ops.reduce_first(cast(Floats2d, Xr.data), Xr.lengths)

    array_info = ArrayInfo.from_array(Y)

    def backprop(dY: OutT) -> Ragged:
        array_info.check_consistency(dY)
        # Scatter the pooled gradients back to each sequence's first row.
        dX = model.ops.backprop_reduce_first(dY, starts_ends)
        return Ragged(dX, Xr.lengths)

    return Y, backprop
19 changes: 8 additions & 11 deletions thinc/layers/reduce_last.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Callable, Tuple, cast, TypeVar
from typing import Callable, Tuple, cast

from ..model import Model
from ..config import registry
from ..types import Ragged, ArrayXd
from ..types import Ragged, Floats2d
from ..util import ArrayInfo

OutT = TypeVar("OutT", bound=ArrayXd)
InT = Ragged
OutT = Floats2d


@registry.layers("reduce_last.v1")
Expand All @@ -17,16 +18,12 @@ def reduce_last() -> Model[Ragged, OutT]:
def forward(
    model: Model[Ragged, OutT], Xr: Ragged, is_train: bool
) -> Tuple[OutT, Callable[[OutT], Ragged]]:
    """Pool each ragged sequence to its last row via Ops.reduce_last.

    Reconstructed clean implementation: the pasted span interleaved the old
    (removed) lines with the new ones; only the new code belongs here.
    """
    Y, lasts = model.ops.reduce_last(cast(Floats2d, Xr.data), Xr.lengths)
    array_info = ArrayInfo.from_array(Y)

    def backprop(dY: OutT) -> InT:
        array_info.check_consistency(dY)
        # Scatter the pooled gradients back to each sequence's last row.
        dX = model.ops.backprop_reduce_last(dY, lasts)
        return Ragged(dX, Xr.lengths)

    return Y, backprop
62 changes: 62 additions & 0 deletions thinc/tests/backends/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,68 @@ def test_backprop_fails_with_incorrect_length(ops, dtype):
)


@pytest.mark.parametrize("ops", ALL_OPS)
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
def test_reduce_first(ops, dtype):
    """reduce_first returns each sequence's first row and the start/end offsets."""
    X = ops.asarray2f(
        [[1.0, 6.0], [2.0, 7.0], [3.0, 8.0], [4.0, 9.0], [5.0, 10.0]], dtype=dtype
    )
    lengths = ops.asarray1i([3, 2])
    Y, starts_ends = ops.reduce_first(X, lengths)
    ops.xp.testing.assert_equal(starts_ends, ops.asarray1i([0, 3, 5]))
    ops.xp.testing.assert_allclose(Y, [[1.0, 6.0], [4.0, 9.0]])

    # Fixed copy-paste bug: these error paths called ops.reduce_last.
    lengths = ops.asarray1i([3, 0, 2])
    with pytest.raises(ValueError, match=r"all sequence lengths must be >= 0"):
        ops.reduce_first(X, lengths)

    lengths = ops.asarray1i([3, 2, 1])
    with pytest.raises(IndexError, match=r"lengths must sum up to the number of rows"):
        ops.reduce_first(X, lengths)


@pytest.mark.parametrize("ops", ALL_OPS)
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
def test_backprop_reduce_first(ops, dtype):
    # Gradients of the two pooled vectors scatter back to rows 0 and 3;
    # all other rows receive zeros.
    d_firsts = ops.asarray2f([[1.0, 3.0], [2.0, 4.0]], dtype=dtype)
    offsets = ops.asarray1i([0, 3, 5])
    d_inputs = ops.backprop_reduce_first(d_firsts, offsets)
    expected = [[1.0, 3.0], [0.0, 0.0], [0.0, 0.0], [2.0, 4.0], [0.0, 0.0]]
    ops.xp.testing.assert_allclose(d_inputs, expected)


@pytest.mark.parametrize("ops", ALL_OPS)
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
def test_reduce_last(ops, dtype):
    # Five rows split into sequences of lengths 3 and 2.
    rows = [[1.0, 6.0], [2.0, 7.0], [3.0, 8.0], [4.0, 9.0], [5.0, 10.0]]
    X = ops.asarray2f(rows, dtype=dtype)
    Y, lasts = ops.reduce_last(X, ops.asarray1i([3, 2]))
    ops.xp.testing.assert_equal(lasts, ops.asarray1i([2, 4]))
    ops.xp.testing.assert_allclose(Y, [[3.0, 8.0], [5.0, 10.0]])

    # A zero-length sequence is rejected.
    with pytest.raises(ValueError, match=r"all sequence lengths must be >= 0"):
        ops.reduce_last(X, ops.asarray1i([3, 0, 2]))

    # Lengths that do not sum to the number of rows are rejected.
    with pytest.raises(IndexError, match=r"lengths must sum up to the number of rows"):
        ops.reduce_last(X, ops.asarray1i([3, 2, 1]))


@pytest.mark.parametrize("ops", ALL_OPS)
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
def test_backprop_reduce_last(ops, dtype):
    # Gradients of the two pooled vectors scatter back to the last row of
    # each sequence (rows 2 and 4); all other rows receive zeros.
    d_lasts = ops.asarray2f([[1.0, 3.0], [2.0, 4.0]], dtype=dtype)
    d_inputs = ops.backprop_reduce_last(d_lasts, ops.asarray1i([2, 4]))
    expected = [[0.0, 0.0], [0.0, 0.0], [1.0, 3.0], [0.0, 0.0], [2.0, 4.0]]
    ops.xp.testing.assert_allclose(d_inputs, expected)


@pytest.mark.parametrize("ops", ALL_OPS)
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
def test_reduce_max_sm(ops, dtype):
Expand Down
76 changes: 76 additions & 0 deletions website/docs/api-backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,82 @@ Backpropagate the hard Swish MobileNet activation.
| `inplace` | <tt>bool</tt> | If `True`, the `dY` array is modified in place. |
| **RETURNS** | <tt>FloatsXd</tt> | The gradient of the input. |

### Ops.reduce_first {#reduce_first tag="method"}

<inline-list>

- **default:** <i name="yes"></i>
- **numpy:** default
- **cupy:** default

</inline-list>

Perform sequence-wise first pooling for data in the ragged format. Zero-length
sequences are not allowed. A `ValueError` is raised if any element in `lengths`
is zero.

| Argument | Type | Description |
| ----------- | ------------------------------- | --------------------------------------------------------------------- |
| `X` | <tt>Floats2d</tt> | The concatenated sequences. |
| `lengths` | <tt>Ints1d</tt> | The sequence lengths. |
| **RETURNS** | <tt>Tuple[Floats2d,Ints1d]</tt> | The first vector of each sequence and the sequence start/end indices. |

### Ops.backprop_reduce_first {#backprop_reduce_first tag="method"}

<inline-list>

- **default:** <i name="yes"></i>
- **numpy:** default
- **cupy:** default

</inline-list>

Backpropagate the `reduce_first` operation.

| Argument | Type | Description |
| ------------- | ----------------- | ------------------------------------------- |
| `d_firsts` | <tt>Floats2d</tt> | The gradient of the outputs. |
| `starts_ends` | <tt>Ints1d</tt> | The sequence start/end indices. |
| **RETURNS** | <tt>Floats2d</tt> | The gradient of the concatenated sequences. |

### Ops.reduce_last {#reduce_last tag="method"}

<inline-list>

- **default:** <i name="yes"></i>
- **numpy:** default
- **cupy:** default

</inline-list>

Perform sequence-wise last pooling for data in the ragged format. Zero-length
sequences are not allowed. A `ValueError` is raised if any element in `lengths`
is zero.

| Argument | Type | Description |
| ----------- | ------------------------------- | ------------------------------------------------------------------------------- |
| `X` | <tt>Floats2d</tt> | The concatenated sequences. |
| `lengths` | <tt>Ints1d</tt> | The sequence lengths. |
| **RETURNS** | <tt>Tuple[Floats2d,Ints1d]</tt> | The last vector of each sequence and the indices of the last sequence elements. |

### Ops.backprop_reduce_last {#backprop_reduce_last tag="method"}

<inline-list>

- **default:** <i name="yes"></i>
- **numpy:** default
- **cupy:** default

</inline-list>

Backpropagate the `reduce_last` operation.

| Argument | Type | Description |
| ----------- | ----------------- | ------------------------------------------- |
| `d_lasts` | <tt>Floats2d</tt> | The gradient of the outputs. |
| `lasts` | <tt>Ints1d</tt> | Indices of the last sequence elements. |
| **RETURNS** | <tt>Floats2d</tt> | The gradient of the concatenated sequences. |

### Ops.reduce_sum {#reduce_sum tag="method"}

<inline-list>
Expand Down