Set seed per sample for OpInfo tests + support for restricting to a single sample input (#128238)

This PR:
* Sets a random seed before generating each sample input for an OpInfo test. It does this by intercepting the sample input iterator via `TrackedInputIter`, which by default sets a test-name-specific seed before each iterator call (seed setting can be disabled); a sketch of the seed derivation appears after the timings below.
    * Some quick-and-dirty benchmarking suggests the overhead of setting the random seed before each sample input generation is negligible; see the before/after timings below for a trivial (single-assert) test that uses `@ops`.
* Uncovered a bunch of test issues:
    * Test breakdown (>100 total)
        * A lot of tolerance issues (tweaked tolerance values to fix)
        * 1 broken OpInfo (`sample_inputs_masked_fill` was generating a sample of the wrong dtype)
        * 3 cases of actually broken semantics (for masked tensor; added xfails)
        * 4 Jacobian mismatches (added xfails)
        * 2 NaN results (skipped for now; need fixing)
        * 3 results too far from the reference result (added xfails)
* Leaves seed setting disabled for MPS tests for now (there are so many failures!); those tests keep the old behavior via `set_seed=False`.

**before (no seed setting):**
```
real	0m21.306s
user	0m19.053s
sys	0m5.192s
```

**after (with seed setting):**
```
real	0m21.905s
user	0m19.578s
sys	0m5.390s
```

* Building on the reproducible sample input generation above, adds support for restricting the iterator to a single sample input. This is done via the env var `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX`, whose usage is included in the repro command printed on failure (example output below; a sketch of the mechanism follows it).

```
======================================================================
ERROR: test_bar_add_cuda_uint8 (__main__.TestFooCUDA.test_bar_add_cuda_uint8)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 971, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jbschlosser/branches/testing_updates/test/test_ops.py", line 2671, in test_bar
    self.assertFalse(True)
AssertionError: True is not false

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 2816, in wrapper
    method(*args, **kwargs)
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 2816, in wrapper
    method(*args, **kwargs)
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 1426, in wrapper
    fn(*args, **kwargs)
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 982, in test_wrapper
    raise new_e from e
Exception: Caused by sample input at index 3: SampleInput(input=Tensor[size=(10, 5), device="cuda:0", dtype=torch.uint8], args=TensorList[Tensor[size=(), device="cuda:0", dtype=torch.uint8]], kwargs={}, broadcasts_input=False, name='')

To execute this test, run the following from the base repo dir:
    PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=3 python test/test_ops.py -k TestFooCUDA.test_bar_add_cuda_uint8

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 1 test in 0.037s

FAILED (errors=1)
```
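Under the hood, `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX` is parsed into the `restrict_to_index` argument of `TrackedInputIter` (see the `common_utils.py` and `OpInfo` diffs below), which skips samples until it reaches the requested index. A simplified standalone sketch, omitting the seed setting and input tracking (`restrict_samples` is a hypothetical stand-in):

```python
def restrict_samples(samples, restrict_to_index=None):
    # Mirror TrackedInputIter's skip loop: yield only the sample at
    # restrict_to_index, or every sample when no restriction is given.
    for input_idx, input_val in enumerate(samples):
        if restrict_to_index is None or input_idx == restrict_to_index:
            yield input_val

# PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=3 corresponds to restrict_to_index=3:
assert list(restrict_samples(["a", "b", "c", "d"], restrict_to_index=3)) == ["d"]
```

Since the per-test seed is re-set before each underlying `next()` call, the sample at index N is identical whether or not the earlier samples were consumed, which is what makes this a faithful repro.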
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128238
Approved by: https://github.com/janeyx99, https://github.com/justinchuby
Author: Joel Schlosser
Committed: 2024-07-05 18:48:32 -04:00 by PyTorch MergeBot
Parent: acf9e31cf8 · Commit: c8ab2e8b63
13 changed files with 572 additions and 58 deletions


@ -3397,6 +3397,8 @@ class TestComposability(TestCase):
new_cotangent = torch.randn(())
self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
# FIXME: test fails in Windows
@unittest.skipIf(IS_WINDOWS, "fails in Windows; needs investigation")
@unittest.skipIf(IS_FBCODE, "can't subprocess in fbcode")
# it is redundant to run this test twice on a machine that has GPUs
@onlyCPU


@ -461,6 +461,11 @@ class TestOperators(TestCase):
{torch.float32: tol(atol=3e-04, rtol=3e-04)},
device_type="cuda",
),
tol1(
"linalg.multi_dot",
{torch.float32: tol(atol=1e-05, rtol=8e-04)},
device_type="cuda",
),
tol1(
"linalg.tensorsolve",
{torch.float32: tol(atol=3e-04, rtol=3e-04)},
@ -480,6 +485,11 @@ class TestOperators(TestCase):
{torch.float32: tol(atol=3e-04, rtol=3e-04)},
device_type="cuda",
),
tol1(
"pca_lowrank",
{torch.float32: tol(atol=3e-05, rtol=4e-06)},
device_type="cpu",
),
),
)
def test_grad(self, device, dtype, op):
@ -601,6 +611,11 @@ class TestOperators(TestCase):
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)},
device_type="cuda",
),
tol1(
"masked.prod",
{torch.float32: tol(atol=1e-05, rtol=1.3e-05)},
device_type="cuda",
),
tol1(
"nn.functional.binary_cross_entropy_with_logits",
{torch.float32: tol(atol=4e-04, rtol=4e-04)},
@ -615,6 +630,9 @@ class TestOperators(TestCase):
"nn.functional.multi_head_attention_forward",
{torch.float32: tol(atol=6e-05, rtol=2e-05)},
),
tol2(
"linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-5, rtol=2e-5)}
),
),
)
def test_jvp(self, device, dtype, op):
@ -766,7 +784,7 @@ class TestOperators(TestCase):
tol2(
"linalg.pinv", "hermitian", {torch.float32: tol(atol=1e-05, rtol=1e-05)}
),
tol1("linalg.tensorsolve", {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
tol1("linalg.tensorsolve", {torch.float32: tol(atol=4e-05, rtol=5e-05)}),
tol1("linalg.multi_dot", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
tol1("svd_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
tol1("pca_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
@ -936,6 +954,11 @@ class TestOperators(TestCase):
decorate(
"linalg.householder_product", decorator=runOnRocm
), # works on ROCm
xfail(
# nans
"masked.softmax",
device_type="cpu",
),
xfail(
"nanquantile", device_type="cpu"
), # vmap not implemented for at::equal.
@ -1028,9 +1051,12 @@ class TestOperators(TestCase):
"test_vmapvjpvjp",
(
tol1("linalg.svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
tol1("linalg.lu", {torch.float32: tol(atol=5e-04, rtol=7e-04)}),
tol1("linalg.lu_factor", {torch.float32: tol(atol=2e-03, rtol=2e-02)}),
tol1("linalg.multi_dot", {torch.float32: tol(atol=2e-03, rtol=2e-04)}),
tol1("svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
tol1("matrix_exp", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
tol1("masked.prod", {torch.float32: tol(atol=2e-03, rtol=2e-04)}),
),
)
@skipOps(
@ -1175,13 +1201,23 @@ class TestOperators(TestCase):
),
tol1(
"linalg.householder_product",
{torch.float32: tol(atol=1e-04, rtol=1e-04)},
{torch.float32: tol(atol=3e-04, rtol=9e-04)},
),
tol1(
"matrix_exp",
{torch.float32: tol(atol=5e-04, rtol=1e-04)},
device_type="cuda",
),
tol1(
"nn.functional.layer_norm",
{torch.float32: tol(atol=3e-4, rtol=1e-4)},
device_type="cpu",
),
tol1(
"native_layer_norm",
{torch.float32: tol(atol=3e-4, rtol=1e-4)},
device_type="cpu",
),
),
)
@skipOps(
@ -1796,7 +1832,12 @@ class TestOperators(TestCase):
tol1("masked.cumprod", {torch.float32: tol(atol=1e-04, rtol=5e-04)}),
tol1(
"cumprod",
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)},
{torch.float32: tol(atol=1e-03, rtol=5e-04)},
device_type="cuda",
),
tol1(
"linalg.det",
{torch.float32: tol(atol=3e-05, rtol=5e-06)},
device_type="cuda",
),
tol1(
@ -2369,6 +2410,11 @@ class TestOperators(TestCase):
"TestOperators",
"test_vmap_autograd_grad",
(
tol1(
"ldexp",
{torch.float32: tol(atol=3e-04, rtol=1.6e-06)},
device_type="cuda",
),
tol1(
"linalg.householder_product",
{torch.float32: tol(atol=5e-04, rtol=9e-03)},
@ -2376,7 +2422,7 @@ class TestOperators(TestCase):
),
tol1(
"linalg.householder_product",
{torch.float32: tol(atol=1e-04, rtol=1e-04)},
{torch.float32: tol(atol=6e-03, rtol=1e-03)},
device_type="cpu",
),
tol1(


@ -1033,11 +1033,6 @@ EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES: Tuple[onnx_test_common.DecorateMeta, ...] =
dtypes=(torch.float16,),
reason="fixme: Assertion error: result mismatch",
),
xfail(
"nn.functional.tanhshrink",
dtypes=(torch.float16,),
reason="fixme: Assertion error: result mismatch",
),
xfail(
"nonzero",
dtypes=(torch.int8, torch.int16),
@ -1227,7 +1222,7 @@ EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES: Tuple[onnx_test_common.DecorateMeta, ...] =
),
xfail(
"topk",
dtypes=(torch.int64, torch.int32),
dtypes=(torch.int64, torch.int32, torch.float16),
reason="fixme: Assertion error: result mismatch",
),
xfail(
@ -1992,7 +1987,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
fp16_low_precision_dict = {
"addbmm": [2e-1, 2e-2],
"addcdiv": [3e-2, 1e-3],
"addcdiv": [3e-2, 1.4e-3],
"addcmul": [3e-2, 1e-3],
"addmv": [5e-2, 3e-2],
"addr": [3e-3, 4e-3],
@ -2000,6 +1995,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
"cumulative_trapezoid": [3e-2, 1e-3],
"cross": [3e-2, 2e-2],
"diff": [1e-2, 5e-2],
"div": [5e-3, 1e-3],
"gradient": [3e-3, 4e-3],
"linalg.cross": [1e-3, 2e-2],
"linalg.multi_dot": [3e-2, 1e-3],
@ -2008,9 +2004,10 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
"masked.std": [2e-2, 2e-3],
"masked.var": [2e-2, 2e-2],
"matmul": [2e-2, 6e-2],
"mv": [9e-3, 1e-5],
"nn.functional.batch_norm": [3e-2, 1e-3],
"nn.functional.binary_cross_entropy": [3e-2, 1e-3],
"nn.functional.binary_cross_entropy_with_logits": [3e-2, 1e-3],
"nn.functional.binary_cross_entropy_with_logits": [4e-2, 4e-3],
"nn.functional.cosine_similarity": [3e-2, 1e-3],
"nn.functional.cosine_embedding_loss": [1e-2, 1e-3],
"nn.functional.hardsigmoid": [1e-3, 5e-3],
@ -2022,7 +2019,7 @@ class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
"nn.functional.kl_div": [2e-3, 2e-4],
"nn.functional.multilabel_soft_margin_loss": [4e-2, 5e-3],
"nn.functional.local_response_norm": [1e-2, 5e-3],
"nn.functional.poisson_nll_loss": [3e-2, 1e-3],
"nn.functional.poisson_nll_loss": [4e-2, 6e-3],
"nn.functional.nll_loss": [3e-2, 1e-3],
"nn.functional.triplet_margin_loss": [2e-2, 1e-2],
"nn.functional.triplet_margin_with_distance_loss": [3e-2, 1e-2],


@ -211,11 +211,13 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
(torch.float16, torch.ops.aten.hardswish.default): 2e-7,
(torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
(torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
(torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 3e-2,
(torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 5e-2,
(torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
(torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
# see https://github.com/pytorch/pytorch/pull/96264
(torch.float16, torch.ops.aten.mv.default): 1e-5,
(torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
(torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
}
if ref.is_floating_point():
orig_diff = (orig - ref).abs().max()


@ -1,7 +1,9 @@
# Owner(s): ["module: masked operators"]
import torch
import unittest
from torch.testing._internal.common_utils import (
decorateIf,
TestCase,
run_tests,
make_tensor,
@ -883,6 +885,37 @@ class TestOperators(TestCase):
@ops(mt_binary_ufuncs, allowed_dtypes=MASKEDTENSOR_FLOAT_TYPES) # type: ignore[arg-type]
@parametrize("layout", [torch.strided, torch.sparse_coo, torch.sparse_csr])
# FIXME:
# Result is just wrong; production logic should be fixed
@decorateIf(
unittest.expectedFailure,
lambda params: (
params["op"].name == "add" and
params["dtype"] in [torch.float16, torch.float32] and
params["device"] == "cpu" and
params["layout"] == torch.sparse_csr
)
)
# Result is just wrong; production logic should be fixed
@decorateIf(
unittest.expectedFailure,
lambda params: (
params["op"].name == "sub" and
params["dtype"] in [torch.float16, torch.float32] and
params["device"] == "cpu" and
params["layout"] == torch.sparse_csr
)
)
# Result is just wrong; production logic should be fixed
@decorateIf(
unittest.expectedFailure,
lambda params: (
params["op"].name == "eq" and
params["dtype"] == torch.float64 and
params["device"] == "cpu" and
params["layout"] == torch.sparse_csr
)
)
def test_binary_core(self, device, dtype, op, layout):
self._test_unary_binary_equality(device, dtype, op, layout)


@ -11829,7 +11829,13 @@ class TestConsistency(TestCaseMPS):
self.assertEqual(device, "cpu")
def get_samples():
return op.sample_inputs(device, dtype, requires_grad=(dtype.is_floating_point or dtype.is_complex))
return op.sample_inputs(
device,
dtype,
requires_grad=(dtype.is_floating_point or dtype.is_complex),
# TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
set_seed=False,
)
cpu_samples = get_samples()
for cpu_sample in cpu_samples:
@ -11864,7 +11870,13 @@ class TestConsistency(TestCaseMPS):
self.assertEqual(device, "cpu")
def get_samples():
return op.sample_inputs(device, dtype, requires_grad=(dtype.is_floating_point or dtype.is_complex))
return op.sample_inputs(
device,
dtype,
requires_grad=(dtype.is_floating_point or dtype.is_complex),
# TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
set_seed=False,
)
cpu_samples = get_samples()
for cpu_sample in cpu_samples:
@ -11940,7 +11952,8 @@ class TestErrorInputs(TestCase):
def test_error_inputs(self, device, op):
self.assertEqual(device, "mps:0")
mps_samples = op.error_inputs(device)
# TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
mps_samples = op.error_inputs(device, set_seed=False)
for mps_sample in mps_samples:
mps_sample_input = mps_sample.sample_input
@ -12015,7 +12028,12 @@ class TestCommon(TestCase):
# A few ops are currently broken on their reference inputs, but not their sample inputs. These should
# get patched up and this workaround removed.
broken_on_ref_inputs = op.name in ['clamp', 'where']
inputs = op.reference_inputs(device, dtype) if not broken_on_ref_inputs else op.sample_inputs(device, dtype)
# TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
inputs = (
op.reference_inputs(device, dtype, set_seed=False) if not broken_on_ref_inputs
else op.sample_inputs(device, dtype, set_seed=False)
)
for sample_input in inputs:
self.compare_with_reference(op, op.ref, sample_input)


@ -1125,11 +1125,13 @@ class ops(_TestParametrizer):
except Exception as e:
tracked_input = get_tracked_input()
if PRINT_REPRO_ON_FAILURE and tracked_input is not None:
raise Exception( # noqa: TRY002
e_tracked = Exception( # noqa: TRY002
f"Caused by {tracked_input.type_desc} "
f"at index {tracked_input.index}: "
f"{_serialize_sample(tracked_input.val)}"
) from e
)
e_tracked._tracked_input = tracked_input # type: ignore[attr]
raise e_tracked from e
raise e
finally:
clear_tracked_input()


@ -6693,7 +6693,9 @@ def sample_inputs_masked_fill(op_info, device, dtype, requires_grad, **kwargs):
if torch.device(device).type == 'cuda':
# `self` and `mask` on CUDA but `value` is a CPU scalar tensor.
yield SampleInput(make_arg((S, S)), args=(torch.randn(S, S, device=device) > 0, torch.randn(())))
yield SampleInput(make_arg((S, S)),
args=(torch.randn(S, S, device=device) > 0,
make_tensor((), device="cpu", dtype=dtype)))
def error_inputs_masked_fill(op_info, device, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
@ -10015,6 +10017,14 @@ foreach_unary_op_db: List[OpInfo] = [
"test_meta_inplace",
dtypes=integral_types_and(torch.bool),
),
# FIXME: fails check
# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/TensorIterator.cpp#L508-L510
DecorateInfo(
unittest.skip("Skipped!"),
"TestForeach",
"test_parity",
dtypes=(torch.bool,),
),
),
),
ForeachFuncInfo(
@ -12201,7 +12211,7 @@ op_db: List[OpInfo] = [
toleranceOverride({torch.float32: tol(atol=1.5e-05, rtol=1e-05)}),
'TestCommon', 'test_out'),
DecorateInfo(
toleranceOverride({torch.half: tol(atol=6e-3, rtol=6e-3)}),
toleranceOverride({torch.half: tol(atol=6e-3, rtol=1e-2)}),
'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
],
skips=(
@ -12350,7 +12360,12 @@ op_db: List[OpInfo] = [
decorators=[
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-05, rtol=1e-03)}),
'TestUnaryUfuncs', device_type='cuda'),
'TestUnaryUfuncs', device_type='cuda'
),
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=8e-5, rtol=4e-5)}),
'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'
),
precisionOverride({torch.bfloat16: 1e-2}),
],
skips=(
@ -12449,7 +12464,15 @@ op_db: List[OpInfo] = [
domain=(-1, 1),
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half, torch.bfloat16),
decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
decorators=[
precisionOverride({torch.bfloat16: 1e-2}),
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=9e-3, rtol=8e-5)}),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cuda"
),
],
supports_inplace_autograd=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@ -12935,6 +12958,8 @@ op_db: List[OpInfo] = [
# return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # noqa: B950
# ~~~~~~ <--- HERE
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=1.4e-3)}),
"TestInductorOpInfo", "test_comprehensive", device_type="cpu"),
)),
OpInfo('cross',
dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
@ -13041,6 +13066,14 @@ op_db: List[OpInfo] = [
skips=(
# RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'),
# FIXME:
# torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
# output 0 with respect to input 1,
# numerical:tensor(-17746.9307, dtype=torch.float64)
# analytical:tensor(0., dtype=torch.float64)
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
'test_fn_grad', device_type='cpu',
dtypes=(torch.float64,)),
)),
BinaryUfuncInfo('div',
aliases=('divide',),
@ -13061,6 +13094,15 @@ op_db: List[OpInfo] = [
skips=(
# RuntimeError: MALFORMED INPUT: Unhandled node kind (in computeValue): aten::div
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_working'),
# FIXME:
# torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
# output 0 with respect to input 1,
# numerical:tensor(-17746.9307, dtype=torch.float64)
# analytical:tensor(0., dtype=torch.float64)
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
'test_fn_grad',
dtypes=(torch.float64,),
device_type='cpu'),
)),
BinaryUfuncInfo('true_divide',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
@ -13211,6 +13253,15 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs',
'test_reference_numerics_small_values',
dtypes=(torch.uint8,)),
# FIXME:
# torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
# output 0 with respect to input 1,
# numerical:tensor(101.6283, dtype=torch.float64)
# analytical:tensor(-18.3575, dtype=torch.float64)
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
'test_fn_grad',
dtypes=(torch.float64,),
device_type='cpu'),
)),
BinaryUfuncInfo('remainder',
ref=np.remainder,
@ -13245,6 +13296,22 @@ op_db: List[OpInfo] = [
# False is not true : Tensors failed to compare as equal!
# Attempted to compare equality of tensors with different dtypes
DecorateInfo(unittest.skip("Skipped!"), 'TestOpInfo', device_type='xla', dtypes=(torch.long,)),
# FIXME:
# torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for
# output 0 with respect to input 1,
# numerical:tensor(102.4676, dtype=torch.float64)
# analytical:tensor(-17.5182, dtype=torch.float64)
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients',
'test_fn_grad', device_type='cpu',
dtypes=(torch.float64,)),
DecorateInfo(
toleranceOverride({
torch.float16: tol(atol=5e-4, rtol=3e-3),
}),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cuda"
),
)),
UnaryUfuncInfo('frac',
ref=lambda x: np.modf(x)[0],
@ -14039,6 +14106,8 @@ op_db: List[OpInfo] = [
decorators=(
DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
"TestDecomp", "test_comprehensive", device_type="cuda"),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
"TestInductorOpInfo", "test_comprehensive", device_type="cuda"),
)),
OpInfo('var_mean',
variant_test_name='unbiased',
@ -14052,6 +14121,8 @@ op_db: List[OpInfo] = [
decorators=(
DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
"TestDecomp", "test_comprehensive", device_type="cuda"),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}),
"TestInductorOpInfo", "test_comprehensive", device_type="cuda"),
)),
OpInfo('std_mean',
dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
@ -14075,8 +14146,24 @@ op_db: List[OpInfo] = [
check_batched_forward_grad=False,
supports_fwgrad_bwgrad=True,
decorators=(
DecorateInfo(toleranceOverride({torch.float64: tol(atol=2e-7, rtol=2e-7)}),
"TestDecomp", "test_comprehensive", device_type="cuda"),
DecorateInfo(
toleranceOverride({
torch.float16: tol(atol=4e-5, rtol=9e-3),
torch.float64: tol(atol=2e-7, rtol=2e-7),
}),
"TestDecomp",
"test_comprehensive",
device_type="cuda"
),
DecorateInfo(
toleranceOverride({
torch.float16: tol(atol=4e-5, rtol=9e-3),
torch.float64: tol(atol=2e-7, rtol=2e-7),
}),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cuda"
),
)),
OpInfo('meshgrid',
variant_test_name='variadic_tensors',
@ -14401,7 +14488,7 @@ op_db: List[OpInfo] = [
supports_fwgrad_bwgrad=True,
decorators=(
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-3)}),
toleranceOverride({torch.float32: tol(atol=3e-3, rtol=1e-3)}),
"TestJit",
"test_variant_consistency_jit",
device_type="cpu",
@ -14559,6 +14646,8 @@ op_db: List[OpInfo] = [
# JIT test also tries to compute double backward, which fails
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.skip("Unsupported on MPS for now"), 'TestCommon', 'test_numpy_ref_mps'),
DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-03, rtol=5e-03)}),
"TestDecomp", "test_comprehensive", device_type="cpu"),
)),
OpInfo('native_batch_norm',
aten_name='native_batch_norm',
@ -14655,6 +14744,14 @@ op_db: List[OpInfo] = [
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1.3e-5, rtol=2e-2)}),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cuda"
),
],
sample_inputs_func=sample_inputs_cosine_similarity),
OpInfo('nn.functional.adaptive_avg_pool1d',
dtypes=floating_types_and(torch.half, torch.bfloat16),
@ -14846,7 +14943,7 @@ op_db: List[OpInfo] = [
toleranceOverride({torch.float: tol(atol=1.5e-5, rtol=1.5e-5), }),
'TestCommon', 'test_numpy_ref_mps'),
DecorateInfo(
toleranceOverride({torch.half: tol(atol=1e-3, rtol=2e-3), }),
toleranceOverride({torch.half: tol(atol=1e-3, rtol=5e-3), }),
'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
),
skips=(
@ -14892,7 +14989,7 @@ op_db: List[OpInfo] = [
toleranceOverride({torch.chalf: tol(atol=8e-2, rtol=8e-2), }),
'TestCommon', 'test_complex_half_reference_testing'),
DecorateInfo(
toleranceOverride({torch.half: tol(atol=1e-3, rtol=2e-3), }),
toleranceOverride({torch.half: tol(atol=1e-3, rtol=4e-3), }),
'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')],
skips=(
# RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
@ -14955,7 +15052,7 @@ op_db: List[OpInfo] = [
toleranceOverride({torch.chalf: tol(atol=9e-2, rtol=9e-2), }),
'TestCommon', 'test_complex_half_reference_testing'),
DecorateInfo(
toleranceOverride({torch.half: tol(atol=1e-3, rtol=2e-1), }),
toleranceOverride({torch.half: tol(atol=9e-3, rtol=2e-1), }),
'TestInductorOpInfo', 'test_comprehensive', device_type='cpu')],
skips=(
# RuntimeError: !lhs.isAliasOf(rhs)INTERNAL ASSERT FAILED at
@ -15101,7 +15198,13 @@ op_db: List[OpInfo] = [
decorators=[
# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
# Consider making it a parameter or input, or detaching the gradient
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,))
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)),
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=5e-05, rtol=3e-03)}),
"TestDecomp",
"test_comprehensive",
device_type="cpu"
),
],
sample_inputs_func=sample_inputs_group_norm,
reference_inputs_func=reference_inputs_group_norm,
@ -15518,6 +15621,12 @@ op_db: List[OpInfo] = [
"TestJit",
"test_variant_consistency_jit",
),
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=4e-3, rtol=1.3e-3)}),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cuda"
),
),
skips=(
# AssertionError: False is not true : Scalars failed to compare as equal! 0 != 4096
@ -15805,7 +15914,7 @@ op_db: List[OpInfo] = [
dtypesIfCUDA=floating_types_and(torch.float16,
*[torch.bfloat16] if SM53OrLater or TEST_WITH_ROCM else []),
decorators=(
DecorateInfo(toleranceOverride({torch.float16: tol(atol=5e-05, rtol=1e-03)}),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-03, rtol=1.3e-03)}),
'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'),
),
skips=(
@ -16848,8 +16957,12 @@ op_db: List[OpInfo] = [
# which leads to failure of this test.
DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_quick',
dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM),
# FIXME:
# Mismatched elements: 1 / 500 (0.2%)
# Greatest absolute difference: nan at index (7, 9, 0) (up to 1e-05 allowed)
# Greatest relative difference: nan at index (7, 9, 0) (up to 0.001 allowed)
DecorateInfo(unittest.skip("Skipped!"), 'TestDecomp', 'test_comprehensive',
dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM),
dtypes=(torch.complex32,)),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_complex_half_reference_testing',
dtypes=(torch.complex32,), active_if=TEST_WITH_ROCM),
DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_batch_vs_slicing',
@ -17421,6 +17534,17 @@ op_db: List[OpInfo] = [
active_if=(IS_MACOS or IS_WINDOWS)),
DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),
'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'),
# FIXME:
# Mismatched elements: 2 / 400 (0.5%)
# Greatest absolute difference: inf at index (7, 16) (up to 1e-05 allowed)
# Greatest relative difference: nan at index (7, 16) (up to 0.001 allowed)
DecorateInfo(
unittest.skip("Skipped!"),
"TestInductorOpInfo",
"test_comprehensive",
dtypes=(torch.float16,),
device_type="cuda",
),
),
# tan(pi/2 * odd_number) is nan
reference_numerics_filter=NumericsFilter(
@ -17510,7 +17634,14 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
gradcheck_wrapper=lambda *args, **kwargs: gradcheck_wrapper_triangular_input(*args, idx=1, **kwargs),
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
decorators=[
skipCUDAIfNoMagma,
skipCPUIfNoLapack,
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=3e-5, rtol=3e-6)}),
'TestConsistency', 'test_output_match', device_type='cpu',
),
],
skips=(
# AssertionError: Scalars are not equal!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
@ -17872,7 +18003,9 @@ op_db: List[OpInfo] = [
'TestFwdGradients',
'test_fn_fwgrad_bwgrad',
dtypes=[torch.complex128]),
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=3e-5, rtol=1e-3)}),
'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'),
],
skips=(
# test does not work with passing lambda for op
@ -18201,6 +18334,12 @@ op_db: List[OpInfo] = [
supports_scripting=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=(
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=2e-3, rtol=3e-2)}),
'TestInductorOpInfo', 'test_comprehensive', device_type='cpu'
),
),
sample_inputs_func=sample_inputs__unsafe_masked_index_put_accumulate),
OpInfo('__getitem__',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
@ -19418,6 +19557,12 @@ op_db: List[OpInfo] = [
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
decorators=[
DecorateInfo(
toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}),
'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'
),
],
sample_inputs_func=sample_trapezoid),
OpInfo('trapezoid',
dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
@ -19426,6 +19571,12 @@ op_db: List[OpInfo] = [
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
decorators=[
DecorateInfo(
toleranceOverride({torch.half: tol(atol=9e-4, rtol=4.3e-3)}),
'TestInductorOpInfo', 'test_comprehensive', device_type='cuda'
),
],
sample_inputs_func=sample_trapezoid),
OpInfo('cumulative_trapezoid',
dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
@ -19434,6 +19585,12 @@ op_db: List[OpInfo] = [
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
supports_out=False,
decorators=(
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=4e-3, rtol=4e-3)}),
'TestInductorOpInfo', 'test_comprehensive',
),
),
sample_inputs_func=sample_cumulative_trapezoid,),
OpInfo('unsqueeze',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
@ -19675,6 +19832,14 @@ op_db: List[OpInfo] = [
# Falling back to non-numerically stablized exp, causing nan in the results.
DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD', dtypes=[torch.complex128]),
DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad', dtypes=[torch.complex128]),
DecorateInfo(
toleranceOverride({
torch.float16: tol(atol=7e-5, rtol=6e-3),
}),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cuda"
),
),
sample_inputs_func=sample_inputs_logcumsumexp,
error_inputs_func=error_inputs_logcumsumexp),
@ -20348,6 +20513,11 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view', device_type='cuda'),
# Not a problem: embedding does weird stuff to its input (it renormalizes)
DecorateInfo(unittest.skip('Allowed exemption'), 'TestCompositeCompliance', 'test_operator'),
# Fails due to non-determinism (see issue #74679)
# TODO: Investigate why more granular skips in the test don't work in CI
DecorateInfo(unittest.skip('Skipped!'),
'TestExpandedWeightFunctional',
'test_expanded_weight_forward'),
),
supports_expanded_weight=True,
supports_out=False,
@ -20861,6 +21031,12 @@ op_db: List[OpInfo] = [
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=1e-4, rtol=2e-3)}),
'TestInductorOpInfo', 'test_comprehensive', device_type="cuda",
),
],
sample_inputs_func=sample_inputs_cosine_embedding_loss,
),
OpInfo(
@ -22471,6 +22647,15 @@ python_ref_db = [
skips=(
# RunTimeError: no _refs support for torch.Tensor.index_select
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
# Reference result was farther (1.946091651916504e-05) from the precise
# computation than the torch result was (1.1920928955078125e-06)!
DecorateInfo(
unittest.expectedFailure,
'TestCommon',
'test_python_ref_torch_fallback',
dtypes=(torch.float32,),
device_type='cpu',
),
)),
PythonRefInfo(
"_refs.nn.functional.leaky_relu",
@ -22566,6 +22751,14 @@ python_ref_db = [
PythonRefInfo(
"_refs.nn.functional.hinge_embedding_loss",
torch_opinfo_name="nn.functional.hinge_embedding_loss",
skips=(
# Reference result was farther (0.29562714856322714) from the precise
# computation than the torch result was (0.20437285143677286)!
DecorateInfo(
unittest.expectedFailure, 'TestCommon', 'test_python_ref',
dtypes=(torch.bfloat16,), device_type="cpu"
),
),
),
PythonRefInfo(
"_refs.nn.functional.nll_loss",
@ -22735,6 +22928,15 @@ python_ref_db = [
decorators=(
# See https://github.com/pytorch/pytorch/issues/111126
DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'),
# Reference result was farther (nan) from the precise computation than the
# torch result was (inf)!
DecorateInfo(
unittest.expectedFailure,
"TestCommon",
"test_python_ref",
dtypes=(torch.bfloat16,),
device_type="cpu",
),
),
),
ElementwiseBinaryPythonRefInfo(

View File

@ -14,6 +14,7 @@ import ctypes
import errno
import functools
import gc
import hashlib
import inspect
import io
import json
@ -162,6 +163,40 @@ class TestEnvironment:
assert name not in globals(), f"duplicate definition of flag '{name}'"
globals()[name] = enabled
# Defines a setting usable throughout the test suite, determining its value by querying
# the specified environment variable. This differs from a flag in that it's not restricted
# to a boolean value.
#
# Args:
# name (str): The name of the setting. A global variable with this name will be set
# for convenient access throughout the test suite.
# env_var (str): The name of the primary environment variable from which to
# determine the value of this setting. If this is None or the environment variable
# is unset, the default value will be used. Default: None
# default (Any): The default value to use for the setting if unset by the environment
# variable. Default: None
# include_in_repro (bool): Indicates whether this setting should be included in the
# repro command that is output on test failure (i.e. whether it is possibly
# relevant to reproducing the test failure). Default: True
# parse_fn (Callable): Callable parsing the env var string. Default value just uses
# the string itself.
@staticmethod
def def_setting(
name,
env_var=None,
default=None,
include_in_repro=True,
parse_fn=lambda maybe_val_str: maybe_val_str,
):
value = default if env_var is None else os.getenv(env_var)
value = parse_fn(value)
if include_in_repro and (value != default):
TestEnvironment.repro_env_vars[env_var] = value
# export setting globally for convenience
assert name not in globals(), f"duplicate definition of setting '{name}'"
globals()[name] = value
# Returns a string prefix usable to set environment variables for any test
# settings that should be explicitly set to match this instantiation of the
# test suite.
@ -211,6 +246,17 @@ TestEnvironment.def_flag(
TestEnvironment.def_flag("PRINT_REPRO_ON_FAILURE", env_var="PYTORCH_PRINT_REPRO_ON_FAILURE",
default=(not IS_FBCODE), include_in_repro=False) # noqa: F821
# possibly restrict OpInfo tests to a single sample input
TestEnvironment.def_setting(
"OPINFO_SAMPLE_INPUT_INDEX",
env_var="PYTORCH_OPINFO_SAMPLE_INPUT_INDEX",
default=None,
# Don't include the env var value in the repro command because the info will
# be queried from the tracked sample input instead
include_in_repro=False,
parse_fn=lambda val: None if val is None else int(val),
)
DEFAULT_DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json'
DEFAULT_SLOW_TESTS_FILE = '.pytorch-slow-tests.json'
@ -292,20 +338,39 @@ def clear_tracked_input():
# Wraps an iterator and tracks the most recent value the iterator produces
# for debugging purposes. Tracked values are stored on the test function.
class TrackedInputIter:
def __init__(self, child_iter, input_type_desc, callback=lambda x: x):
def __init__(self, child_iter, input_type_desc,
callback=lambda x: x, set_seed=True, restrict_to_index=None):
self.child_iter = enumerate(child_iter)
# Input type describes the things we're tracking (e.g. "sample input", "error input").
self.input_type_desc = input_type_desc
# Callback is run on each iterated thing to get the thing to track.
self.callback = callback
self.test_fn = extract_test_fn()
# Indicates whether the random seed should be set before each call to the iterator
self.set_seed = set_seed
# Indicates that iteration should be restricted to only the provided index.
# If None, no restriction is done
self.restrict_to_index = restrict_to_index
def __iter__(self):
return self
def __next__(self):
# allow StopIteration to bubble up
input_idx, input_val = next(self.child_iter)
while True:
if self.set_seed:
# use a test-name-specific hash for the seed if possible
seed = (
int.from_bytes(hashlib.sha256(
self.test_fn.__qualname__.encode("utf-8")).digest()[:4], 'little')
if self.test_fn is not None else SEED
)
set_rng_seed(seed)
# allow StopIteration to bubble up
input_idx, input_val = next(self.child_iter)
if (self.restrict_to_index is None) or (input_idx == self.restrict_to_index):
break
self._set_tracked_input(
TrackedInput(
index=input_idx, val=self.callback(input_val), type_desc=self.input_type_desc
@ -2186,16 +2251,29 @@ def skip_exception_type(exc_type):
raise unittest.SkipTest(f"not implemented: {e}") from e
@contextmanager
def print_repro_on_failure(repro_str):
def print_repro_on_failure(repro_parts):
try:
yield
except unittest.SkipTest:
raise
except Exception as e:
# Get the index of the sample input that failed the test if possible.
sample_isolation_prefix = ""
tracked_input = getattr(e, "_tracked_input", None)
if tracked_input is not None:
sample_isolation_prefix = f"PYTORCH_OPINFO_SAMPLE_INPUT_INDEX={tracked_input.index}"
repro_str = " ".join(filter(None, (sample_isolation_prefix, *repro_parts)))
repro_msg = f"""
To execute this test, run the following from the base repo dir:
{repro_str}
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
# NB: Hacking the exception args is the cleanest way I've found to append
# failure reproduction info without poisoning the stack trace.
if len(e.args) >= 1:
e.args = (f"{e.args[0]}\n{repro_str}", *e.args[1:])
e.args = (f"{e.args[0]}\n{repro_msg}", *e.args[1:])
raise
# "min_satisfying_examples" setting has been deprecated in hypothesis
@ -2641,7 +2719,6 @@ class TestCase(expecttest.TestCase):
self.wrap_with_policy(method_name, lambda: skip_exception_type(NotImplementedError))
if PRINT_REPRO_ON_FAILURE: # noqa: F821
env_var_prefix = TestEnvironment.repro_env_var_prefix()
try:
def _get_rel_test_path(abs_test_path):
# Attempt to get relative path based on the "test" dir.
@ -2662,14 +2739,12 @@ class TestCase(expecttest.TestCase):
abs_test_path = os.path.abspath(inspect.getfile(type(self)))
test_filename = _get_rel_test_path(abs_test_path)
class_name = type(self).__name__
repro_str = f"""
To execute this test, run the following from the base repo dir:
{env_var_prefix} python {test_filename} -k {class_name}.{method_name}
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
test_run_cmd = f"python {test_filename} -k {class_name}.{method_name}"
env_var_prefix = TestEnvironment.repro_env_var_prefix()
repro_parts = [env_var_prefix, test_run_cmd]
self.wrap_with_policy(
method_name,
lambda repro_str=repro_str: print_repro_on_failure(repro_str=repro_str))
lambda repro_parts=repro_parts: print_repro_on_failure(repro_parts))
except Exception as e:
# Don't fail entirely if we can't get the test filename
log.info("could not print repro string", extra=str(e))


@ -28,6 +28,7 @@ from torch.testing._internal.common_dtype import (
from torch.testing._internal.common_utils import (
is_iterable_of_tensors,
noncontiguous_like,
OPINFO_SAMPLE_INPUT_INDEX,
TEST_WITH_ROCM,
torch_to_numpy_dtype_dict,
TrackedInputIter,
@ -1177,6 +1178,7 @@ class OpInfo:
tensor in a sequence input conjugated.
"""
set_seed = kwargs.pop("set_seed", True)
samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
conj_samples = list(samples)
@ -1193,7 +1195,12 @@ class OpInfo:
else:
sample.input[0] = conjugate(sample.input[0])
return TrackedInputIter(iter(conj_samples), "conjugate sample input")
return TrackedInputIter(
iter(conj_samples),
"conjugate sample input",
set_seed=set_seed,
restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
)
def sample_inputs(self, device, dtype, requires_grad=False, **kwargs):
"""
@ -1202,6 +1209,7 @@ class OpInfo:
These samples should be sufficient to test the function works correctly
with autograd, TorchScript, etc.
"""
set_seed = kwargs.pop("set_seed", True)
samples = self.sample_inputs_func(self, device, dtype, requires_grad, **kwargs)
if kwargs.get("include_conjugated_inputs", False):
@ -1212,7 +1220,12 @@ class OpInfo:
samples_list.extend(conj_samples)
samples = tuple(samples_list)
return TrackedInputIter(iter(samples), "sample input")
return TrackedInputIter(
iter(samples),
"sample input",
set_seed=set_seed,
restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
)
def reference_inputs(self, device, dtype, requires_grad=False, **kwargs):
"""
@ -1222,11 +1235,17 @@ class OpInfo:
of inputs when reference_inputs_func is defined. If undefined this returns
the sample inputs.
"""
set_seed = kwargs.pop("set_seed", True)
if self.reference_inputs_func is None:
samples = self.sample_inputs_func(
self, device, dtype, requires_grad, **kwargs
)
return TrackedInputIter(iter(samples), "sample input")
return TrackedInputIter(
iter(samples),
"reference input",
set_seed=set_seed,
restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
)
if kwargs.get("include_conjugated_inputs", False):
raise NotImplementedError
@ -1234,15 +1253,25 @@ class OpInfo:
references = self.reference_inputs_func(
self, device, dtype, requires_grad, **kwargs
)
return TrackedInputIter(iter(references), "reference input")
return TrackedInputIter(
iter(references),
"reference input",
set_seed=set_seed,
restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
)
def error_inputs(self, device, **kwargs):
"""
Returns an iterable of ErrorInputs.
"""
set_seed = kwargs.pop("set_seed", True)
errs = self.error_inputs_func(self, device, **kwargs)
return TrackedInputIter(
iter(errs), "error input", callback=lambda e: e.sample_input
iter(errs),
"error input",
callback=lambda e: e.sample_input,
set_seed=set_seed,
restrict_to_index=OPINFO_SAMPLE_INPUT_INDEX,
)
def error_inputs_sparse(self, device, layout, **kwargs):


@ -558,6 +558,12 @@ op_db: List[OpInfo] = [
"test_mask_layout",
device_type="cpu",
),
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}),
"TestOperators",
"test_jvp",
device_type="cuda",
),
],
sample_inputs_func=sample_inputs_masked_reduction,
sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
@ -614,7 +620,7 @@ op_db: List[OpInfo] = [
device_type="cuda",
),
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=2e-3, rtol=2e-3)}),
toleranceOverride({torch.float16: tol(atol=1e-2, rtol=2.6e-3)}),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cuda",
@ -955,6 +961,16 @@ op_db: List[OpInfo] = [
"TestMasked",
"test_reference_masked",
),
DecorateInfo(
toleranceOverride(
{
torch.float16: tol(atol=4e-5, rtol=2e-2),
}
),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cuda",
),
],
sample_inputs_func=sample_inputs_masked_std_var,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
@ -1091,6 +1107,16 @@ op_db: List[OpInfo] = [
DecorateInfo(
unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"
),
# FIXME:
# Mismatched elements: 2 / 2 (100.0%)
# Greatest absolute difference: nan at index (0,) (up to 0.0001 allowed)
# Greatest relative difference: nan at index (0,) (up to 0.0001 allowed)
DecorateInfo(
unittest.skip("Skipped!"),
"TestOperators",
"test_vmapvjpvjp",
device_type="cpu",
),
),
gradcheck_wrapper=gradcheck_wrapper_masked_operation,
supports_forward_ad=True,
@ -1102,6 +1128,14 @@ op_db: List[OpInfo] = [
method_variant=None,
dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_masked_normalize,
decorators=[
DecorateInfo(
toleranceOverride({torch.float16: tol(atol=2e-5, rtol=6e-3)}),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cuda",
),
],
skips=(
DecorateInfo(
unittest.expectedFailure,
@ -1177,6 +1211,16 @@ op_db: List[OpInfo] = [
),
# all the values are the same except for -inf vs nan
DecorateInfo(unittest.skip("Skipped!"), "TestDecomp", "test_comprehensive"),
# FIXME:
# Mismatched elements: 2 / 12 (16.7%)
# Greatest absolute difference: 9223372034707292160 at index (0, 0, 0, 0)
# Greatest relative difference: 0.0 at index (0, 0, 0, 1)
DecorateInfo(
unittest.skip("Skipped!"),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cpu",
),
),
sample_inputs_func=sample_inputs_masked_reduction,
gradcheck_wrapper=gradcheck_wrapper_masked_operation,


@ -250,7 +250,7 @@ op_db: List[OpInfo] = [
precisionOverride({torch.float: 2e-4, torch.cfloat: 2e-4}),
"TestFFT",
"test_reference_nd",
)
),
],
skips=(
# Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
@ -259,6 +259,13 @@ op_db: List[OpInfo] = [
"TestSchemaCheckModeOpInfo",
"test_schema_correctness",
),
# FIXME: errors are too large; needs investigation
DecorateInfo(
unittest.skip("Skipped!"),
"TestCommon",
"test_complex_half_reference_testing",
device_type="cuda",
),
),
),
SpectralFuncInfo(
@ -708,7 +715,25 @@ python_ref_db: List[OpInfo] = [
precisionOverride({torch.float: 2e-4}),
"TestFFT",
"test_reference_nd",
)
),
# AssertionError: Reference result was farther (0.09746177145360499) from the precise
# computation than the torch result was (0.09111555632069855)
DecorateInfo(
unittest.skip("Skipped!"),
"TestCommon",
"test_python_ref_torch_fallback",
dtypes=(torch.float16,),
device_type="cuda",
),
# AssertionError: Reference result was farther (0.0953431016138116) from the precise
# computation than the torch result was (0.09305490684430734)
DecorateInfo(
unittest.skip("Skipped!"),
"TestCommon",
"test_python_ref_executor",
dtypes=(torch.float16,),
device_type="cuda",
),
],
),
SpectralFuncPythonRefInfo(
@ -760,7 +785,16 @@ python_ref_db: List[OpInfo] = [
precisionOverride({torch.float: 2e-4}),
"TestFFT",
"test_reference_nd",
)
),
# FIXME:
# Reference result was farther (0.0953431016138116) from the precise computation
# than the torch result was (0.09305490684430734)!
DecorateInfo(
unittest.skip("Skipped!"),
"TestCommon",
"test_python_ref_executor",
device_type="cuda",
),
],
),
PythonRefInfo(


@ -1342,6 +1342,12 @@ op_db: List[OpInfo] = [
"TestCommon",
"test_numpy_ref_mps",
),
DecorateInfo(
toleranceOverride({torch.half: tol(atol=1.2e-2, rtol=1.7e-2)}),
"TestInductorOpInfo",
"test_comprehensive",
device_type="cuda",
),
),
),
OpInfo(
@ -2007,7 +2013,16 @@ op_db: List[OpInfo] = [
gradcheck_fast_mode=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
decorators=[
skipCUDAIfNoMagmaAndNoCusolver,
skipCPUIfNoLapack,
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
"TestCommon",
"test_noncontiguous_samples",
device_type="cpu",
),
],
skips=(
DecorateInfo(
unittest.skip("Skipped!"),
@ -2040,7 +2055,16 @@ op_db: List[OpInfo] = [
sample_inputs_func=sample_inputs_linalg_solve,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
decorators=[
skipCUDAIfNoMagmaAndNoCusolver,
skipCPUIfNoLapack,
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1.3e-05, rtol=6e-04)}),
"TestCommon",
"test_noncontiguous_samples",
device_type="cpu",
),
],
skips=(
DecorateInfo(
unittest.skip("Skipped!"),
@ -2368,6 +2392,12 @@ op_db: List[OpInfo] = [
"test_noncontiguous_samples",
device_type="cuda",
),
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=8e-04, rtol=7e-06)}),
"TestCommon",
"test_noncontiguous_samples",
device_type="cpu",
),
],
skips=(
DecorateInfo(