|
14 | 14 | # ==============================================================================
|
15 | 15 | """Type definitions to use for type annotations."""
|
16 | 16 |
|
17 |
| -from typing import Any, Iterable, Mapping, Union |
| 17 | +from typing import Any, TypeAlias, Union |
18 | 18 | import jax
|
19 | 19 | import jax.numpy as jnp
|
20 | 20 | import numpy as np
|
21 | 21 |
|
22 | 22 | # Special types of arrays.
|
23 |
| -ArrayBatched = jax.interpreters.batching.BatchTracer |
24 |
| -ArrayNumpy = np.ndarray |
25 |
| -ArraySharded = jax.interpreters.pxla.ShardedDeviceArray |
| 23 | +ArrayBatched: TypeAlias = jax.interpreters.batching.BatchTracer |
| 24 | +ArrayNumpy: TypeAlias = np.ndarray |
| 25 | +ArraySharded: TypeAlias = jax.interpreters.pxla.ShardedDeviceArray |
26 | 26 | # For instance checking, use `isinstance(x, jax.Array)`.
|
27 |
| -if hasattr(jax, 'Array'): |
28 |
| - ArrayDevice = jax.Array # jax >= 0.3.20 |
29 |
| -elif hasattr(jax.interpreters.xla, '_DeviceArray'): # 0.2.5 < jax < 0.3.20 |
30 |
| - ArrayDevice = jax.interpreters.xla._DeviceArray # pylint:disable=protected-access |
31 |
| -else: # jax <= 0.2.5 |
32 |
| - ArrayDevice = jax.interpreters.xla.DeviceArray |
| 27 | +ArrayDevice: TypeAlias = jax.Array # jax >= 0.3.20 |
33 | 28 |
|
34 | 29 | # Generic array type.
|
35 |
| -Array = Union[ArrayDevice, ArrayNumpy, ArrayBatched, ArraySharded] |
| 30 | +Array = Union[jax.Array, np.ndarray] |
| 31 | +ArrayLike: TypeAlias = jax.typing.ArrayLike |
36 | 32 |
|
37 | 33 | # A tree of generic arrays.
|
38 |
| -ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']] |
| 34 | +ArrayTree = Any |
| 35 | +# Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']] |
39 | 36 |
|
40 | 37 | # Other types.
|
41 | 38 | Scalar = Union[float, int]
|
42 | 39 | Numeric = Union[Array, Scalar]
|
43 |
| -Shape = jax.core.Shape |
44 |
| -PRNGKey = jax.random.KeyArray |
45 |
| -PyTreeDef = type(jax.tree_util.tree_structure(None)) |
46 |
| -if hasattr(jax, 'Device'): |
47 |
| - Device = jax.Device # jax >= 0.4.3 |
48 |
| -else: |
49 |
| - Device = jax.lib.xla_extension.Device |
| 40 | +Shape: TypeAlias = jax.core.Shape |
| 41 | +PRNGKey: TypeAlias = jax.random.KeyArray |
| 42 | +PyTreeDef: TypeAlias = jax.tree_util.PyTreeDef |
| 43 | +Device: TypeAlias = jax.Device # jax >= 0.4.3 |
50 | 44 | ArrayDType = type(jnp.float32)
|
0 commit comments