
Commit ee6e22f

Update on "[ONNX] Introduce FX-ONNX dispatcher"
Needs microsoft/onnxscript#721

The current FX exporter uses a manually maintained dictionary to map an ATen op to its OnnxFunction. This breaks down when either the ATen op or the OnnxFunction has overloads, which a one-to-one mapping cannot resolve. For example, `aten::arange` has the overloads `aten::arange.start` and `aten::arange.start_step`, and for `aten::argmax`, torchlib provides two functions: `aten_argmax` and `aten_argmax_dim`. This PR uses the newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator against each overload's schema and find the correct one.

### OnnxRegistry

Heavily based on the [TorchScript registry](#84382). The only difference is that in the FX registry, an ATen operator at a specific opset version maps to a list of overloaded functions.

* The global registry is no longer used. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* The multiple-opset-version layer is kept through `_SymbolicFunctionGroup`, but torchlib currently only supports opset 18.
* The basic custom-operator API (`register`, `unregister`, and `is_registered_op`) is kept for future development. To complete it, follow-up PRs should address:
  - How to let users remove or override a specific overload. Using OpSchema to differentiate?
  - What happens when a user registers a new overload with the same OpSchema as an already registered one.

### OnnxDispatcher

Dispatches an ATen operator to the matching overload by comparing each overload's OpSchema with the input arguments.

* `OpSchemaWrapper` wraps the ONNX schema and records the matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types and find the best-matching overload. If the match isn't perfect, a warning is recorded in diagnostics.
* `dispatch_opset_version` is carried over from #84382 and kept, but torchlib does not support opset versions other than 18.
* Because (1) OnnxFunction arguments are currently typed by hand and (2) ORT may deviate from the ONNX type spec, we relax the schema match with a matching score system.
* Follow-up PRs should address:
  - Where to add `op.Cast` for autocast: in torchlib or in the converter?
  - The need for type promotion can be captured by the dispatcher, but it requires OpSchema to expose the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system

This is a temporary solution for targeting the correct ONNX overload, given that we only have manually annotated arguments (a potentially inaccurate schema) and limited AttributeProto support.

1. Perfect match: if all args/kwargs match the schema, return the function without any warnings.
2. Best match: the system adds one point for each correctly matched input, in order, and subtracts the size of the symmetric difference between the expected and provided attribute sets; the overload with the highest score is selected. If the selection is not a perfect match, a warning is sent to SARIF. A sketch of this scoring appears below.
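For illustration only, the best-match scoring could be computed roughly as follows. This is a minimal sketch of the idea, not the converter's actual code; the helper name `matching_score` and the simplified `isinstance` type check are illustrative assumptions (the real matcher compares ONNX type strings from OpSchema):

```python
from typing import Mapping, Sequence


def matching_score(
    expected_input_types: Sequence[type],
    expected_attributes: set,
    args: Sequence[object],
    kwargs: Mapping[str, object],
) -> int:
    """Toy score: +1 per positionally matched input, minus the symmetric
    difference between the expected and the provided attribute names."""
    score = 0
    for expected_type, actual in zip(expected_input_types, args):
        if isinstance(actual, expected_type):  # simplified stand-in for type matching
            score += 1
    score -= len(expected_attributes ^ set(kwargs))
    return score


# The dispatcher would evaluate this score for every overload registered for
# the ATen op and pick the highest scorer, emitting a SARIF warning when the
# winner is not a perfect match.
```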
#### Examples of overloads

1. Different types: caused by differences between the ONNX spec and PyTorch. The matching system finds the correct one.

```python
@torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...


@torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by the unsupported `op.OptionalHasElement` (to be supported in opset version 20); `dim` could be `None`.

```python
@torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...


@torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate automatically, as both overloads might receive `dim` in kwargs, so please make sure the one with `dim: int` is turned into a private function.

3. Optional dtype: `dtype` could be unprovided. The difference from case 2 is that `dtype` would not be `None`.

```python
@torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...


@torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depending on whether `dtype` is provided, the matching system dispatches the ATen op to the correct function.

4. `None`, `[]`, and `NoneType` are considered failing the match.
5. Two functions with the same score are recorded in SARIF.

### TODOs

1. Type promotion can be captured by the dispatcher only if OpSchema can provide the information; the "graph-level pass" vs. "in-op promotion" implementation can be discussed further in microsoft/onnxscript#563.
2. torchlib should provide the opset version to OnnxRegistry.
3. How to expose OnnxRegistry with custom add/remove op APIs needs further discussion (a hypothetical usage sketch follows at the end of this message).

Co-authored-by: Justin Chu <justinchuby@microsoft.com>

[ghstack-poisoned]
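For context on that last TODO, here is a hypothetical sketch of how the custom-operator API kept in this PR (`register`, `unregister`, `is_registered_op`) might eventually be used. The registration signatures and the availability of a `registry` instance are assumptions for illustration, not the shipped API; only the onnxscript decorator pattern follows the documented torchlib style:

```python
from onnxscript import FLOAT, script
from onnxscript import opset18 as op


@script()
def custom_aten_relu(self: FLOAT[...]) -> FLOAT[...]:
    # A user-provided ONNX implementation for aten::relu.
    return op.Relu(self)


# OnnxRegistry is internal in this PR (created inside ResolvedExportOptions);
# assume an instance named `registry` is in scope here.
registry.register("aten::relu", opset_version=18, function=custom_aten_relu)  # assumed signature
assert registry.is_registered_op("aten::relu", opset_version=18)
registry.unregister("aten::relu", opset_version=18)  # assumed signature
```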
2 parents 153ba39 + 282518a commit ee6e22f

File tree

285 files changed (+13427, -5906 lines)

.github/ci_commit_pins/torchbench.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-44ccad1d0647d56fe4e56c0b1933102be5dcc874
+f4acd1a7fcce986155c5e20beffa92b24ae0a3fa

.github/ci_commit_pins/vision.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-99ec261c72c097160e94653cce6f90f2d1209222
+d2f7486ccaef461913cdb51990ff353addf6f064

.github/scripts/label_utils.py

Lines changed: 4 additions & 2 deletions
@@ -15,7 +15,7 @@

 BOT_AUTHORS = ["github-actions", "pytorchmergebot", "pytorch-bot"]

-LABEL_ERR_MSG_TITLE = "This PR needs a label"
+LABEL_ERR_MSG_TITLE = "This PR needs a `release notes:` label"
 LABEL_ERR_MSG = f"""# {LABEL_ERR_MSG_TITLE}
 If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.

@@ -111,7 +111,9 @@ def has_required_labels(pr: "GitHubPR") -> bool:


 def is_label_err_comment(comment: GitHubComment) -> bool:
+    # comment.body_text returns text without markdown
+    no_format_title = LABEL_ERR_MSG_TITLE.replace("`", "")
     return (
-        comment.body_text.lstrip(" #").startswith(LABEL_ERR_MSG_TITLE)
+        comment.body_text.lstrip(" #").startswith(no_format_title)
         and comment.author_login in BOT_AUTHORS
     )

.github/scripts/test_check_labels.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def mock_get_comments() -> List[GitHubComment]:
         ),
         # Case 2 - a label err comment
         GitHubComment(
-            body_text=" #" + LABEL_ERR_MSG_TITLE,
+            body_text=" #" + LABEL_ERR_MSG_TITLE.replace("`", ""),
             created_at="",
             author_login=BOT_AUTHORS[1],
             author_association="",

.github/scripts/trymerge.py

Lines changed: 3 additions & 3 deletions
@@ -1262,9 +1262,9 @@ def find_matching_merge_rule(
             reject_reason_score = num_matching_files
             reject_reason = "\n".join(
                 (
-                    f"Not all files match rule `{rule_name}`."
-                    f"{num_matching_files} files matched, but there are still non-matching files:"
-                    f"{','.join(non_matching_files[:5])}{', ...' if len(non_matching_files) > 5 else ''}"
+                    f"Not all files match rule `{rule_name}`.",
+                    f"{num_matching_files} files matched, but there are still non-matching files:",
+                    f"{','.join(non_matching_files[:5])}{', ...' if len(non_matching_files) > 5 else ''}",
                 )
             )
             continue

.github/scripts/update_commit_hashes.py

Lines changed: 2 additions & 2 deletions
@@ -6,13 +6,13 @@

 import requests

-MERGEBOT_TOKEN = os.environ["MERGEBOT_TOKEN"]
+UPDATEBOT_TOKEN = os.environ["UPDATEBOT_TOKEN"]
 PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"]
 OWNER, REPO = "pytorch", "pytorch"


 def git_api(
-    url: str, params: Dict[str, str], type: str = "get", token: str = MERGEBOT_TOKEN
+    url: str, params: Dict[str, str], type: str = "get", token: str = UPDATEBOT_TOKEN
 ) -> Any:
     headers = {
         "Accept": "application/vnd.github.v3+json",

.github/workflows/_update-commit-hash.yml

Lines changed: 5 additions & 5 deletions
@@ -22,7 +22,7 @@ on:
       required: false
       default: .github/ci_commit_pins
   secrets:
-    MERGEBOT_TOKEN:
+    UPDATEBOT_TOKEN:
       required: true
       description: Permissions for opening PR
     PYTORCHBOT_TOKEN:
@@ -41,7 +41,7 @@ jobs:
         with:
          fetch-depth: 1
          submodules: false
-         token: ${{ secrets.MERGEBOT_TOKEN }}
+         token: ${{ secrets.UPDATEBOT_TOKEN }}

       - name: Checkout
         shell: bash
@@ -54,11 +54,11 @@ jobs:
          REPO_NAME: ${{ inputs.repo-name }}
          BRANCH: ${{ inputs.branch }}
          PIN_FOLDER: ${{ inputs.pin-folder }}
-         MERGEBOT_TOKEN: ${{ secrets.MERGEBOT_TOKEN }}
+         UPDATEBOT_TOKEN: ${{ secrets.UPDATEBOT_TOKEN }}
          PYTORCHBOT_TOKEN: ${{ secrets.PYTORCHBOT_TOKEN }}
         run: |
          # put this here instead of the script to prevent accidentally changing the config when running the script locally
-         git config --global user.name "PyTorch MergeBot"
-         git config --global user.email "pytorchmergebot@users.noreply.github.com"
+         git config --global user.name "PyTorch UpdateBot"
+         git config --global user.email "pytorchupdatebot@users.noreply.github.com"

          python .github/scripts/update_commit_hashes.py --repo-name "${REPO_NAME}" --branch "${BRANCH}" --pin-folder "${PIN_FOLDER}"

.github/workflows/check-labels.yml

Lines changed: 14 additions & 0 deletions
@@ -1,8 +1,22 @@
 name: Check Labels

 on:
+  # We need pull_request_target to be able to post comments on PRs from forks.
+  # Only allow pull_request_target when merging to main, not some historical branch.
+  #
+  # Make sure to don't introduce explicit checking out and installing/running
+  # untrusted user code into this workflow!
+  pull_request_target:
+    types: [opened, synchronize, reopened, labeled, unlabeled]
+    branches: [main]
+    paths-ignore: [.github]
+
+  # To allow testing PRs that change workflows.
+  # May be triggered together with pull_request_target, it's OK.
   pull_request:
     types: [opened, synchronize, reopened, labeled, unlabeled]
+    paths: [.github]
+
   workflow_dispatch:

 concurrency:

.github/workflows/inductor.yml

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ jobs:
     docker-image-name: pytorch-linux-focal-py3.8-gcc7
     test-matrix: |
       { include: [
-        { config: "inductor_huggingface_cpu_accuracy", shard: 1, num_shards: 1, runner: "linux.4xlarge" },
+        { config: "inductor_huggingface_cpu_accuracy", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
         { config: "inductor_timm_cpu_accuracy", shard: 1, num_shards: 2, runner: "linux.4xlarge" },
         { config: "inductor_timm_cpu_accuracy", shard: 2, num_shards: 2, runner: "linux.4xlarge" },
         { config: "inductor_torchbench_cpu_accuracy", shard: 1, num_shards: 1, runner: "linux.4xlarge" },

.github/workflows/lint-bc.yml

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@ on:
       - reopened
       - labeled
       - unlabeled
+    branches-ignore:
+      - nightly
   workflow_dispatch:

 jobs:

.github/workflows/lint.yml

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@ name: Lint

 on:
   pull_request:
+    branches-ignore:
+      - nightly
   push:
     branches:
       - main

.github/workflows/nightly.yml

Lines changed: 1 addition & 1 deletion
@@ -43,5 +43,5 @@ jobs:
       repo-name: vision
       branch: main
     secrets:
-      MERGEBOT_TOKEN: ${{ secrets.MERGEBOT_TOKEN }}
+      UPDATEBOT_TOKEN: ${{ secrets.UPDATEBOT_TOKEN }}
       PYTORCHBOT_TOKEN: ${{ secrets.GH_PYTORCHBOT_TOKEN }}

.github/workflows/pull.yml

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@ name: pull

 on:
   pull_request:
+    branches-ignore:
+      - nightly
   push:
     branches:
       - main

.github/workflows/revert.yml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ jobs:
   do_revert:
     name: try_revert_pr_${{ github.event.client_payload.pr_num }}
     runs-on: linux.20_04.4x
+    environment: mergebot
     env:
       GH_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
     steps:

.github/workflows/trymerge.yml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ jobs:
   do_merge:
     name: try_merge_pr_${{ github.event.client_payload.pr_num }}
     runs-on: linux.20_04.4x
+    environment: mergebot
     env:
       GH_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
     steps:

.github/workflows/tryrebase.yml

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ on:
 jobs:
   do_rebase:
     runs-on: ubuntu-20.04
+    environment: mergebot
     env:
       GH_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
     steps:

.github/workflows/update-viablestrict.yml

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ concurrency:
 jobs:
   do_update_viablestrict:
     runs-on: ubuntu-20.04
+    environment: mergebot
     steps:
       - name: Checkout repo
         uses: actions/checkout@v3

.github/workflows/upload-test-stats.yml

Lines changed: 0 additions & 2 deletions
@@ -91,8 +91,6 @@ jobs:
       env:
         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
         PYTORCHBOT_TOKEN: ${{ secrets.GH_PYTORCHBOT_TOKEN}}
-        MERGEBOT_TOKEN: ${{ secrets.MERGEBOT_TOKEN}}
       run: |
         curl -H "Accept: application/vnd.github.v3+json" -H "Authorization: token $GITHUB_TOKEN" https://api.github.com/rate_limit
         curl -H "Accept: application/vnd.github.v3+json" -H "Authorization: token $PYTORCHBOT_TOKEN" https://api.github.com/rate_limit
-        curl -H "Accept: application/vnd.github.v3+json" -H "Authorization: token $MERGEBOT_TOKEN" https://api.github.com/rate_limit

.github/workflows/weekly.yml

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ jobs:
       repo-name: xla
       branch: master
     secrets:
-      MERGEBOT_TOKEN: ${{ secrets.MERGEBOT_TOKEN }}
+      UPDATEBOT_TOKEN: ${{ secrets.UPDATEBOT_TOKEN }}
       PYTORCHBOT_TOKEN: ${{ secrets.GH_PYTORCHBOT_TOKEN }}

   update-triton-commit-hash:
@@ -26,5 +26,5 @@ jobs:
       branch: main
       pin-folder: .ci/docker/ci_commit_pins
     secrets:
-      MERGEBOT_TOKEN: ${{ secrets.MERGEBOT_TOKEN }}
+      UPDATEBOT_TOKEN: ${{ secrets.UPDATEBOT_TOKEN }}
       PYTORCHBOT_TOKEN: ${{ secrets.GH_PYTORCHBOT_TOKEN }}

.lintrunner.toml

Lines changed: 1 addition & 1 deletion
@@ -1018,6 +1018,6 @@ init_command = [
     'python3',
     'tools/linter/adapters/pip_init.py',
     '--dry-run={{DRYRUN}}',
-    'ruff==0.0.265',
+    'ruff==0.0.269',
 ]
 is_formatter = true

BUILD.bazel

Lines changed: 34 additions & 2 deletions
@@ -1696,6 +1696,36 @@ pybind_extension(
     ],
 )

+cc_library(
+    name = "functorch",
+    hdrs = glob([
+        "functorch/csrc/dim/*.h",
+    ]),
+    srcs = glob([
+        "functorch/csrc/dim/*.cpp",
+    ]),
+    deps = [
+        ":aten_nvrtc",
+        ":torch_python",
+        "@pybind11",
+    ],
+)
+
+pybind_extension(
+    name = "functorch/_C",
+    copts=[
+        "-DTORCH_EXTENSION_NAME=_C"
+    ],
+    srcs = [
+        "functorch/csrc/init_dim_only.cpp",
+    ],
+    deps = [
+        ":functorch",
+        ":torch_python",
+        ":aten_nvrtc",
+    ],
+)
+
 cc_binary(
     name = "torch/bin/torch_shm_manager",
     srcs = [
@@ -1724,7 +1754,7 @@ template_rule(
 rules.py_library(
     name = "pytorch_py",
     visibility = ["//visibility:public"],
-    srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"],
+    srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"] + glob(["functorch/**/*.py"]),
     deps = [
         rules.requirement("future"),
         rules.requirement("numpy"),
@@ -1737,6 +1767,7 @@ rules.py_library(
     ],
     data = [
         ":torch/_C.so",
+        ":functorch/_C.so",
         ":torch/bin/torch_shm_manager",
     ],
 )
@@ -1903,7 +1934,8 @@ cc_test(

 py_test(
     name = "test_bazel",
-    srcs = ["test/test_bazel.py"],
+    srcs = ["test/_test_bazel.py"],
+    main = "test/_test_bazel.py",
     deps = [":pytorch_py"],
 )

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
@@ -1228,7 +1228,7 @@ an active PR, CI jobs will be run automatically. Some of these may
 fail and you will need to find out why, by looking at the logs.

 Fairly often, a CI failure might be unrelated to your changes. You can
-confirm by going to our [HUD](hud.pytorch.org) and seeing if the CI job
+confirm by going to our [HUD](https://hud.pytorch.org) and seeing if the CI job
 is failing upstream already. In this case, you
 can usually ignore the failure. See [the following
 subsection](#which-commit-is-used-in-ci) for more details.

README.md

Lines changed: 2 additions & 2 deletions
@@ -434,14 +434,14 @@ Three-pointers to get you started:

 ## Releases and Contributing

-PyTorch has a 90-day release cycle (major releases). Please let us know if you encounter a bug by [filing an issue](https://github.com/pytorch/pytorch/issues).
+Typically, PyTorch has three major releases a year. Please let us know if you encounter a bug by [filing an issue](https://github.com/pytorch/pytorch/issues).

 We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion.

 If you plan to contribute new features, utility functions, or extensions to the core, please first open an issue and discuss the feature with us.
 Sending a PR without discussion might end up resulting in a rejected PR because we might be taking the core in a different direction than you might be aware of.

-To learn more about making a contribution to Pytorch, please see our [Contribution page](CONTRIBUTING.md).
+To learn more about making a contribution to Pytorch, please see our [Contribution page](CONTRIBUTING.md). For more information about PyTorch releases, see [Release page](RELEASE.md).

 ## The Team

aten/conda/meta.yaml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ requirements:
 about:
   home: https://github.com/pytorch/pytorch
   license: BSD
-  summary: A TENsor library for C++14
+  summary: A TENsor library for C++17

 extra:
   recipe-maintainers:

aten/src/ATen/ATen.h

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 #pragma once

-#if !defined(_MSC_VER) && __cplusplus < 201402L
-#error C++14 or later compatible compiler is required to use ATen.
+#if !defined(_MSC_VER) && __cplusplus < 201703L
+#error C++17 or later compatible compiler is required to use ATen.
 #endif

 #include <ATen/Context.h>

aten/src/ATen/LegacyVmapTransforms.h

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ using VmapPhysicalViewVec =
 // dimensions to get 8. Adjust this number as necessary
 constexpr int64_t kVmapStaticDimVecSize = 8;
 using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
+using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;

 // NOTE: [What is an VmapTransform?]
 // An *VmapTransform* converts logical views of tensors to physical views.

aten/src/ATen/cpu/tbb/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
@@ -121,9 +121,9 @@ endif()
 if(UNIX)
   add_definitions(-DUSE_PTHREAD)

-  check_cxx_compiler_flag("-std=c++14" SUPPORTS_STDCXX14)
-  if(SUPPORTS_STDCXX14)
-    set(CMAKE_CXX_FLAGS "-std=c++14 ${CMAKE_CXX_FLAGS}")
+  check_cxx_compiler_flag("-std=c++17" SUPPORTS_STDCXX17)
+  if(SUPPORTS_STDCXX17)
+    set(CMAKE_CXX_FLAGS "-std=c++17 ${CMAKE_CXX_FLAGS}")
   endif()

   check_cxx_compiler_flag("-mrtm -Werror" SUPPORTS_MRTM)

aten/src/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h

Lines changed: 24 additions & 0 deletions
@@ -127,6 +127,30 @@ inline void convert(const int64_t* src, double* dst, int64_t n) {
     dst[i] = static_cast<double>(src[i]);
   }
 }
+//Generic implementation to fix compiler error
+//TO-DO : Add optimized version for ppc64
+inline std::tuple<Vectorized<float>, Vectorized<float>> convert_half_float(
+    const Vectorized<Half>& a) {
+  constexpr int64_t K = Vectorized<Half>::size();
+  __at_align__ float arr[K];
+  __at_align__ Half arr2[K];
+  a.store(arr2);
+  convert(arr2, arr, K);
+  return std::make_tuple(
+      Vectorized<float>::loadu(arr),
+      Vectorized<float>::loadu(arr + Vectorized<float>::size()));
+}
+
+inline Vectorized<Half> convert_float_half(
+    const Vectorized<float>& a, const Vectorized<float>& b) {
+  constexpr int64_t K = Vectorized<Half>::size();
+  __at_align__ float arr[K];
+  __at_align__ Half arr2[K];
+  a.store(arr);
+  b.store(arr + Vectorized<float>::size());
+  convert(arr, arr2, K);
+  return Vectorized<Half>::loadu(arr2);
+};

 template <>
 std::pair<Vectorized<double>, Vectorized<double>> inline interleave2<double>(
