Skip to content

Commit 19d77cb

Browse files
committed
Register get_cpu_capability for jit
Context: In torchvision we ensure that functional ops are torchsciptable. Recently exposed `torch.backends.cpu.get_cpu_capability()` in pytorch#100164 is failing in torchvision CI ``` RuntimeError: Python builtin <built-in function _get_cpu_capability> is currently not supported in Torchscript: File "/usr/local/lib/python3.10/dist-packages/torch/backends/cpu/__init__.py", line 17 - "AVX512" """ return torch._C._get_cpu_capability() ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE ``` Ref: pytorch/vision#7557 In this PR, `torch._C._get_cpu_capability()` is explicitly registered for JIT and tested.
1 parent d9d98b4 commit 19d77cb

File tree

3 files changed

+10
-0
lines changed

3 files changed

+10
-0
lines changed

test/test_torch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7445,8 +7445,13 @@ def test_parallel_info(self):
74457445
torch.__config__.parallel_info()
74467446

74477447
def test_get_cpu_capability(self):
7448+
# This method is primarily exposed for torchvision's resize
74487449
torch.backends.cpu.get_cpu_capability()
74497450

7451+
# We have to ensure that method is torchscriptable as torchvision's resize
7452+
# should be torchscriptable
7453+
torch.jit.script(torch.backends.cpu.get_cpu_capability)
7454+
74507455
@slowTest
74517456
def test_slow_test(self):
74527457
# Just a smoketest to make sure our slowTest decorator works.

torch/csrc/jit/runtime/register_special_ops.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,10 @@ RegisterOperators reg({
457457
"aten::set_grad_enabled(bool val) -> ()",
458458
[](Stack& stack) { torch::GradMode::set_enabled(pop(stack).toBool()); },
459459
aliasAnalysisConservative()),
460+
Operator(
461+
"aten::_get_cpu_capability() -> str",
462+
[](Stack& stack) { push(stack, at::get_cpu_capability()); },
463+
aliasAnalysisConservative()),
460464
});
461465
} // namespace
462466
} // namespace jit

torch/jit/_builtins.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
(torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
9696
(torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
9797
(torch._C._get_tracing_state, "aten::_get_tracing_state"),
98+
(torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
9899
(warnings.warn, "aten::warn"),
99100
(torch._VF.stft, "aten::stft"), # type: ignore[attr-defined]
100101
(torch._VF.istft, "aten::istft"), # type: ignore[attr-defined]

0 commit comments

Comments
 (0)