|
6 | 6 | import os
|
7 | 7 | import tempfile
|
8 | 8 | from itertools import chain
|
9 |
| -from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar |
10 |
| - |
11 |
| -import sys |
12 |
| -if sys.version_info < (3, 10): |
13 |
| - from typing_extensions import ParamSpec, Concatenate |
14 |
| -else: |
15 |
| - from typing import ParamSpec, Concatenate |
| 9 | +from typing import Callable, Dict, List, Optional, Tuple, TypeVar |
16 | 10 |
|
17 | 11 | import torch
|
18 | 12 | import torch.distributed as dist
|
|
23 | 17 | import torch_xla.utils.utils as xu
|
24 | 18 | from torch_xla.experimental import tpu
|
25 | 19 |
|
26 |
| -P = ParamSpec('P') |
27 | 20 | R = TypeVar('R')
|
28 | 21 | FN = TypeVar('FN')
|
29 | 22 |
|
@@ -198,8 +191,7 @@ def _thread_fn(device: torch.device):
|
198 | 191 |
|
199 | 192 |
|
200 | 193 | @requires_pjrt
|
201 |
| -def _run_multiprocess(fn: Callable[P, R], *args: P.args, |
202 |
| - **kwargs: P.kwargs) -> Dict[int, R]: |
| 194 | +def _run_multiprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]: |
203 | 195 | """Runs `fn` on all devices available to PjRt.
|
204 | 196 |
|
205 | 197 | Spawns one process per physical device (e.g. TPU chip).
|
@@ -237,8 +229,7 @@ def _run_multiprocess(fn: Callable[P, R], *args: P.args,
|
237 | 229 | class _SpawnFn:
|
238 | 230 | """Pickle-able wrapper around `fn` that passes the ordinal before `args`"""
|
239 | 231 |
|
240 |
| - def __init__(self, fn: Callable[Concatenate[int, P], R], *args: P.args, |
241 |
| - **kwargs: P.kwargs): |
| 232 | + def __init__(self, fn: Callable[..., R], *args, **kwargs): |
242 | 233 | self.fn = fn
|
243 | 234 | self.args = args
|
244 | 235 | self.kwargs = kwargs
|
|
0 commit comments