Skip to content

Commit 669a207

Browse files
authored
Remove ParamSpec from pjrt.py to unblock nightly build. (#4037)
1 parent 8fff44e commit 669a207

File tree

1 file changed

+3
-12
lines changed

1 file changed

+3
-12
lines changed

torch_xla/experimental/pjrt.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,7 @@
66
import os
77
import tempfile
88
from itertools import chain
9-
from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar
10-
11-
import sys
12-
if sys.version_info < (3, 10):
13-
from typing_extensions import ParamSpec, Concatenate
14-
else:
15-
from typing import ParamSpec, Concatenate
9+
from typing import Callable, Dict, List, Optional, Tuple, TypeVar
1610

1711
import torch
1812
import torch.distributed as dist
@@ -23,7 +17,6 @@
2317
import torch_xla.utils.utils as xu
2418
from torch_xla.experimental import tpu
2519

26-
P = ParamSpec('P')
2720
R = TypeVar('R')
2821
FN = TypeVar('FN')
2922

@@ -198,8 +191,7 @@ def _thread_fn(device: torch.device):
198191

199192

200193
@requires_pjrt
201-
def _run_multiprocess(fn: Callable[P, R], *args: P.args,
202-
**kwargs: P.kwargs) -> Dict[int, R]:
194+
def _run_multiprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
203195
"""Runs `fn` on all devices available to PjRt.
204196
205197
Spawns one process per physical device (e.g. TPU chip).
@@ -237,8 +229,7 @@ def _run_multiprocess(fn: Callable[P, R], *args: P.args,
237229
class _SpawnFn:
238230
"""Pickle-able wrapper around `fn` that passes the ordinal before `args`"""
239231

240-
def __init__(self, fn: Callable[Concatenate[int, P], R], *args: P.args,
241-
**kwargs: P.kwargs):
232+
def __init__(self, fn: Callable[..., R], *args, **kwargs):
242233
self.fn = fn
243234
self.args = args
244235
self.kwargs = kwargs

0 commit comments

Comments
 (0)