Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Set seed per sample for OpInfo tests + support for restricting to a single sample input (#128238)
This PR:
* Sets a random seed before generating each sample for an OpInfo test. It does this by intercepting the sample input iterator via `TrackedInputIter`, optionally setting the seed to a test-name-specific seed before each iterator call (the default is to set the seed).
  * Some quick and dirty benchmarking shows (hopefully) negligible overhead from setting the random seed before each sample input generation. For a trivial (single assert) test that uses `@ops`:

    **before (no seed setting):**
    ```
    real    0m21.306s
    user    0m19.053s
    sys     0m5.192s
    ```

    **after (with seed setting):**
    ```
    real    0m21.905s
    user    0m19.578s
    sys     0m5.390s
    ```
  * Uncovered a bunch of test issues. Test breakdown (>100 total):
    * a lot of tolerance issues (tweaked tolerance values to fix)
    * 1 broken OpInfo (`sample_inputs_masked_fill` was generating a sample of the wrong dtype)
    * 3 actually broken semantics (for masked tensor; added xfails)
    * 4 Jacobian mismatches (added xfails)
    * 2 nan results (skipped for now; need fixing)
    * 3 results too far from the reference result (added xfails)
  * Skips MPS tests for now (there are so many failures!). Those will default to the old behavior.
* Utilizing the above for reproducible sample input generation, adds support for restricting the iterator to a single sample input. This is done via the env var `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX`, and its usage is included in the repro command:

```
======================================================================
ERROR: test_bar_add_cuda_uint8 (__main__.TestFooCUDA.test_bar_add_cuda_uint8)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 971, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jbschlosser/branches/testing_updates/test/test_ops.py", line 2671, in test_bar
    self.assertFalse(True)
AssertionError: True is not false

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 2816, in wrapper
    method(*args, **kwargs)
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 2816, in wrapper
    method(*args, **kwargs)
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_utils.py", line 1426, in wrapper
    fn(*args, **kwargs)
  File "/home/jbschlosser/branches/testing_updates/torch/testing/_internal/common_device_type.py", line 982, in test_wrapper
    raise new_e from e
Exception: Caused by sample input at index 3: SampleInput(input=Tensor[size=(10, 5), device="cuda:0", dtype=torch.uint8], args=TensorList[Tensor[size=(), device="cuda:0", dtype=torch.uint8]], kwargs={}, broadcasts_input=False, name='')

To execute this test, run the following from the base repo dir:
    PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=3 python test/test_ops.py -k TestFooCUDA.test_bar_add_cuda_uint8

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 1 test in 0.037s

FAILED (errors=1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128238
Approved by: https://github.com/janeyx99, https://github.com/justinchuby
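For illustration, here is a minimal standalone sketch of the two mechanisms described above: reseeding before each sample is generated so that any given sample is reproducible on its own, and restricting iteration to the single index named by `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX`. The helper names (`tracked_samples`, `_seed_for_test`) are hypothetical; the real logic lives in `TrackedInputIter` and differs in detail.

```python
import hashlib
import itertools
import os

import torch

# Env var from the PR; everything else below is an illustrative assumption.
SAMPLE_INPUT_INDEX_VAR = "PYTORCH_OPINFO_SAMPLE_INPUT_INDEX"


def _seed_for_test(test_name: str) -> int:
    # hashlib (rather than hash()) keeps the seed stable across interpreter runs.
    return int.from_bytes(hashlib.sha256(test_name.encode()).digest()[:4], "little")


def tracked_samples(sample_iter, test_name: str, set_seed: bool = True):
    """Yield (index, sample) pairs from a lazy sample-input iterator.

    A test-name-specific seed is set before each sample is pulled, so a
    sample's random contents do not depend on how many samples were drawn
    before it. If PYTORCH_OPINFO_SAMPLE_INPUT_INDEX is set, only that one
    sample is yielded (earlier samples are still generated so indices and
    random state line up with a full run).
    """
    restrict_to = os.environ.get(SAMPLE_INPUT_INDEX_VAR)
    restrict_to = int(restrict_to) if restrict_to is not None else None

    for idx in itertools.count():
        if set_seed:
            # Reseed immediately before generating the next sample.
            torch.manual_seed(_seed_for_test(test_name))
        try:
            sample = next(sample_iter)
        except StopIteration:
            return
        if restrict_to is not None and idx != restrict_to:
            continue  # skip everything except the requested sample
        yield idx, sample
```

Under these assumptions, a failure reported for sample index 3 can be replayed in isolation via the repro command shown in the example output above, e.g. `PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=3 python test/test_ops.py -k TestFooCUDA.test_bar_add_cuda_uint8`.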
Committed by: PyTorch MergeBot
Parent: acf9e31cf8
Commit: c8ab2e8b63
```
@@ -3397,6 +3397,8 @@ class TestComposability(TestCase):
        new_cotangent = torch.randn(())
        self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))

    # FIXME: test fails in Windows
    @unittest.skipIf(IS_WINDOWS, "fails in Windows; needs investigation")
    @unittest.skipIf(IS_FBCODE, "can't subprocess in fbcode")
    # it is redundant to run this test twice on a machine that has GPUs
    @onlyCPU
```