Skip to content

Commit dbf2353

Browse files
authored
Document how to create OpInfo tests (#2035)
1 parent d44853e commit dbf2353

File tree

1 file changed

+57
-5
lines changed

1 file changed

+57
-5
lines changed

tests/function_libs/torch_lib/README.md

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1-
# Test op correctness by comparing with PyTorch results
1+
# Test op correctness by comparing with PyTorch results using OpInfo
2+
3+
`OpInfo` is PyTorch's standard mechanism for composing test data for operators.
4+
Read more about them on https://github.com/pytorch/pytorch/blob/ce4a097bf769d753712a1fd969b446c59e29d8b9/torch/testing/_internal/opinfo/core.py#L362.
25

36
## Usage
47

58
```bash
69
# All
7-
pytest onnxscript/tests/function_libs/torch_lib/ops_test.py
10+
python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py
811

912
# To run tests on a specific operator (e.g. torch.ceil):
10-
pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k ceil
13+
python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k ceil
1114

1215
# To run tests on a nn operator (e.g. nn.functional.scaled_dot_product_attention):
13-
pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k nn_functional_scaled_dot_product_attention
16+
python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k nn_functional_scaled_dot_product_attention
1417
```
1518

1619
### Environment variables
@@ -25,4 +28,53 @@ in onnxruntime by running the inference sessions in a separate process.
2528

2629
## How to add a new operator test
2730

28-
See _usage_ in [ops_test_data.py](./ops_test_data.py)
31+
See _usage_ in [`ops_test_data.py`](./ops_test_data.py)
32+
33+
## How to add custom OpInfo tests
34+
35+
Sometimes, there is no existing OpInfo that fits our need to test an operator. You want to create a custom OpInfo for it.
36+
37+
Follow the steps below to create new OpInfo tests:
38+
39+
1. Use the implementation for `ops.aten.slice_scatter` as a reference (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2412-L2418) to declare an OpInfo in [`extra_opinfo.py`](./extra_opinfo.py)
40+
41+
```py
42+
opinfo_core.OpInfo(
43+
"ops.aten.slice_scatter",
44+
aten_name="slice_scatter",
45+
dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool),
46+
sample_inputs_func=sample_inputs_slice_scatter,
47+
supports_out=False,
48+
),
49+
```
50+
51+
- The first argument should be the operator name under the `torch.ops` namespace. For example, if you want to test the `prims.var` op, then put `"ops.prims.var"`. It should almost always start with `ops.`.
52+
- Follow existing examples to specify the `dtypes` you want to test the op on.
53+
- Specify `op=` if the target operator is not the same as the OpInfo name (first arg). For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L2065-L2068.
54+
55+
```py
56+
opinfo_core.OpInfo(
57+
"ops.aten.bernoulli.p_deterministic",
58+
op=torch.ops.aten.bernoulli.p,
59+
```
60+
61+
The op is `torch.ops.aten.bernoulli.p`, which is different from the name `ops.aten.bernoulli.p_deterministic`. OpInfo names need to be globally unique in a test suite. When `op` is not specified, it will look for the op in `torch.` using its name.
62+
63+
2. Implement the `sample_inputs_func`. (Ref: https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/extra_opinfo.py#L1242-L1268)
64+
1. Copy the function and decide what the input shapes should be. Use `make_arg` to generate a torch.Tensor. Alternatively you could also use `torch.tensor` to generate the tensor yourself. Be sure to double check the dtype and device. Finally yield each test cases with
65+
66+
```py
67+
yield opinfo_core.SampleInput(input, args=(...), kwargs={...})
68+
```
69+
70+
`input` is the first arg. The rest of the args are in `args`.
71+
3. Enable the test case in [`ops_test_data.py`](./ops_test_data.py)
72+
1. Add a `TorchLibOpInfo` entry to the `TESTED_TORCHLIB_OPS` list. (For example https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L2116)
73+
74+
```py
75+
TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter)
76+
```
77+
78+
You can additionally specify dtype tolerance (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L539) or conditional skips (https://github.com/microsoft/onnxscript/blob/e67335101e4a06b8cc98cb4129935a9af5062c77/tests/function_libs/torch_lib/ops_test_data.py#L586-L590).
79+
80+
Now that the test is added, you may run the test like mentioned above. Set `CREATE_REPRODUCTION_REPORT=1` to get markdown reports and view failing input combinations should any test case fails.

0 commit comments

Comments
 (0)