# Test op correctness by comparing with PyTorch results using OpInfo

`OpInfo` is PyTorch's standard mechanism for composing test data for operators.
Read more about it in ce4a097bf7/torch/testing/_internal/opinfo/core.py (L362).
## Usage

```shell
# Run all tests:
python -m pytest test_ops.py

# Run tests on a specific operator (e.g. torch.ceil):
python -m pytest test_ops.py -k ceil

# Run tests on an nn operator (e.g. nn.functional.scaled_dot_product_attention):
python -m pytest test_ops.py -k nn_functional_scaled_dot_product_attention
```
## Environment variables

- Set `CATCH_ORT_SEGFAULT=1` to catch segmentation faults in onnxruntime by
  running the inference sessions in a separate process.
- Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction
  of errors. E.g.
  `CREATE_REPRODUCTION_REPORT=1 python -m pytest test/onnx/torchlib/test_ops.py -k div_mode_int`
## How to add a new operator test

See usage in `ops_test_data.py`.
## How to add custom OpInfo tests

Sometimes there is no existing `OpInfo` that fits our need to test an operator,
so you want to create a custom `OpInfo` for it. Follow the steps below to
create new `OpInfo` tests:
1. Use the implementation for `ops.aten.slice_scatter` as a reference
   (e67335101e/tests/function_libs/torch_lib/extra_opinfo.py (L2412-L2418))
   to declare an `OpInfo` in `extra_opinfo.py`:

   ```python
   opinfo_core.OpInfo(
       "ops.aten.slice_scatter",
       aten_name="slice_scatter",
       dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool),
       sample_inputs_func=sample_inputs_slice_scatter,
       supports_out=False,
   ),
   ```

   - The first argument should be the operator name under the `torch.ops`
     namespace. For example, if you want to test the `prims.var` op, put
     `"ops.prims.var"`. It should almost always start with `ops.`.
   - Follow existing examples to specify the `dtypes` you want to test the
     op on.
   - Specify `op=` if the target operator is not the same as the `OpInfo` name
     (the first argument). For example
     (e67335101e/tests/function_libs/torch_lib/extra_opinfo.py (L2065-L2068)):

     ```python
     opinfo_core.OpInfo(
         "ops.aten.bernoulli.p_deterministic",
         op=torch.ops.aten.bernoulli.p,
     ```

     The op is `torch.ops.aten.bernoulli.p`, which is different from the name
     `ops.aten.bernoulli.p_deterministic`. `OpInfo` names need to be globally
     unique in a test suite. When `op` is not specified, the op is looked up
     under `torch.` using its name.
2. Implement the `sample_inputs_func`
   (ref: e67335101e/tests/function_libs/torch_lib/extra_opinfo.py (L1242-L1268)).

   - Copy the function and decide what the input shapes should be. Use
     `make_arg` to generate a `torch.Tensor`; alternatively, you could also
     use `torch.tensor` to generate the tensor yourself. Be sure to
     double-check the dtype and device. Finally, yield each test case with
     `yield opinfo_core.SampleInput(input, args=(...), kwargs={...})`.
     `input` is the first argument; the rest of the arguments go in `args`.
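   The step above can be sketched as follows. This is a minimal illustration,
   not an entry from the real `extra_opinfo.py`: the op, the shapes, and the
   `dim` argument are assumptions made up for the example, while `make_arg`
   mirrors the `functools.partial(torch.testing.make_tensor, ...)` pattern
   used by existing sample-input functions.

   ```python
   import functools

   import torch
   from torch.testing._internal.opinfo import core as opinfo_core


   def sample_inputs_my_op(op_info, device, dtype, requires_grad, **kwargs):
       # Hypothetical sample_inputs_func for an op taking (input, dim).
       del op_info, kwargs  # Unused; kept to match the expected signature.
       make_arg = functools.partial(
           torch.testing.make_tensor,
           device=device,
           dtype=dtype,
           requires_grad=requires_grad,
       )
       # Yield one SampleInput per shape/argument combination. Note that the
       # dtype and device come from the arguments, not hardcoded values.
       for shape, dim in [((2, 3), 0), ((4, 5, 6), -1)]:
           yield opinfo_core.SampleInput(make_arg(shape), args=(dim,))
   ```

   Each yielded `SampleInput` holds the operator's first argument in `input`
   and the remaining positional arguments in `args`.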
3. Enable the test case in `ops_test_data.py`.

   - Add a `TorchLibOpInfo` entry to the `TESTED_TORCHLIB_OPS` list
     (for example e67335101e/tests/function_libs/torch_lib/ops_test_data.py (L2116)):

     ```python
     TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter)
     ```

     You can additionally specify dtype tolerance
     (e67335101e/tests/function_libs/torch_lib/ops_test_data.py (L539))
     or conditional skips
     (e67335101e/tests/function_libs/torch_lib/ops_test_data.py (L586-L590)).
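As a rough sketch of what a tolerance override on such an entry might look
like, assuming `tolerance` maps a dtype to an `(rtol, atol)` pair; the exact
values and keyword names here are illustrative, so check the linked lines in
`ops_test_data.py` for the real signature:

```python
# Hypothetical TESTED_TORCHLIB_OPS entry that loosens float16 tolerance.
# The tolerance values below are made-up placeholders.
TorchLibOpInfo(
    "ops.aten.slice_scatter",
    core_ops.aten_slice_scatter,
    tolerance={torch.float16: (1e-3, 1e-3)},
)
```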
Now that the test is added, you can run it as described above. Set
`CREATE_REPRODUCTION_REPORT=1` to get markdown reports and view the failing
input combinations should any test case fail.