Skip to content

Commit 4e358e5

Browse files
hbq1ChexDev
authored andcommitted
Update pytypes.
PiperOrigin-RevId: 513161806
1 parent b988ae0 commit 4e358e5

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

chex/_src/pytypes.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,37 +14,31 @@
1414
# ==============================================================================
1515
"""Type definitions to use for type annotations."""
1616

17-
from typing import Any, Iterable, Mapping, Union
17+
from typing import Any, TypeAlias, Union
1818
import jax
1919
import jax.numpy as jnp
2020
import numpy as np
2121

2222
# Special types of arrays.
23-
ArrayBatched = jax.interpreters.batching.BatchTracer
24-
ArrayNumpy = np.ndarray
25-
ArraySharded = jax.interpreters.pxla.ShardedDeviceArray
23+
ArrayBatched: TypeAlias = jax.interpreters.batching.BatchTracer
24+
ArrayNumpy: TypeAlias = np.ndarray
25+
ArraySharded: TypeAlias = jax.interpreters.pxla.ShardedDeviceArray
2626
# For instance checking, use `isinstance(x, jax.Array)`.
27-
if hasattr(jax, 'Array'):
28-
ArrayDevice = jax.Array # jax >= 0.3.20
29-
elif hasattr(jax.interpreters.xla, '_DeviceArray'): # 0.2.5 < jax < 0.3.20
30-
ArrayDevice = jax.interpreters.xla._DeviceArray # pylint:disable=protected-access
31-
else: # jax <= 0.2.5
32-
ArrayDevice = jax.interpreters.xla.DeviceArray
27+
ArrayDevice: TypeAlias = jax.Array # jax >= 0.3.20
3328

3429
# Generic array type.
35-
Array = Union[ArrayDevice, ArrayNumpy, ArrayBatched, ArraySharded]
30+
Array = Union[jax.Array, np.ndarray]
31+
ArrayLike: TypeAlias = jax.typing.ArrayLike
3632

3733
# A tree of generic arrays.
38-
ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
34+
ArrayTree = Any
35+
# Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
3936

4037
# Other types.
4138
Scalar = Union[float, int]
4239
Numeric = Union[Array, Scalar]
43-
Shape = jax.core.Shape
44-
PRNGKey = jax.random.KeyArray
45-
PyTreeDef = type(jax.tree_util.tree_structure(None))
46-
if hasattr(jax, 'Device'):
47-
Device = jax.Device # jax >= 0.4.3
48-
else:
49-
Device = jax.lib.xla_extension.Device
40+
Shape: TypeAlias = jax.core.Shape
41+
PRNGKey: TypeAlias = jax.random.KeyArray
42+
PyTreeDef: TypeAlias = jax.tree_util.PyTreeDef
43+
Device: TypeAlias = jax.Device # jax >= 0.4.3
5044
ArrayDType = type(jnp.float32)

0 commit comments

Comments
 (0)