mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable some PyTorch core tests with inductor (#87490)
Summary: 1) Graph break on torch.random.set_rng_state since it blocks running inductor core tests; 2) Add several inductor-specific skips; 3) Enable several core tests for inductor CI; cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87490 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
f7a04f310b
commit
2c1efe7472
@ -251,13 +251,10 @@ test_dynamo_shard() {
|
||||
|
||||
|
||||
test_inductor() {
|
||||
echo "TODO: enable inductor unit tests"
|
||||
# time python test/run_test.py --core --exclude test_autograd --continue-through-error --verbose
|
||||
|
||||
# PYTORCH_TEST_WITH_DYNAMO and PYTORCH_TEST_WITH_INDUCTOR are only needed for PyTorch tests not written with
|
||||
# using dynamo/inductor. For dynamo/inductor unit tests, specifying them will trigger an error like
|
||||
# "Detected two calls to `torchdynamo.optimize(...)` with a different backend compiler arguments."
|
||||
# PYTORCH_TEST_WITH_DYNAMO=0 PYTORCH_TEST_WITH_INDUCTOR=0 pytest test/inductor
|
||||
python test/test_modules.py --verbose
|
||||
# TODO: investigate "RuntimeError: CUDA driver API confirmed a leak"
|
||||
# seen in test_ops_gradients.py
|
||||
# pytest test/test_ops_gradients.py --verbose -k "not _complex and not test_inplace_grad_acos_cuda_float64"
|
||||
}
|
||||
|
||||
test_inductor_huggingface_shard() {
|
||||
|
@ -1016,6 +1016,8 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, 8)
|
||||
|
||||
# TODO: make set_rng_state work with FakeTensor/aot_autograd
|
||||
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
|
||||
def test_rng_state(self):
|
||||
def fn():
|
||||
state = torch.get_rng_state()
|
||||
|
@ -11,7 +11,8 @@ from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol, skipMeta)
|
||||
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck, skipIfMps)
|
||||
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
|
||||
gradgradcheck, skipIfMps, skipIfTorchInductor)
|
||||
from unittest.mock import patch, call
|
||||
|
||||
|
||||
@ -326,6 +327,7 @@ class TestModule(TestCase):
|
||||
|
||||
@skipIfMps
|
||||
@modules(module_db)
|
||||
@skipIfTorchInductor("to be fixed")
|
||||
def test_non_contiguous_tensors(self, device, dtype, module_info, training):
|
||||
# Check modules work with non-contiguous tensors
|
||||
|
||||
@ -489,6 +491,7 @@ class TestModule(TestCase):
|
||||
@toleranceOverride({torch.float32: tol(5e-2, 0),
|
||||
torch.float64: tol(4e-4, 0)})
|
||||
@modules(module_db)
|
||||
@skipIfTorchInductor("to be fixed")
|
||||
def test_cpu_gpu_parity(self, device, dtype, module_info, training):
|
||||
# TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
|
||||
# nicer way for eval mode only.
|
||||
@ -579,6 +582,7 @@ class TestModule(TestCase):
|
||||
|
||||
@skipIfMps
|
||||
@modules(module_db)
|
||||
@skipIfTorchInductor("to be fixed")
|
||||
def test_memory_format(self, device, dtype, module_info, training):
|
||||
is_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6)
|
||||
# TODO tighten it to a specific module
|
||||
|
@ -36,6 +36,7 @@ from torch.testing._internal.common_utils import (
|
||||
first_sample,
|
||||
parametrize,
|
||||
skipIfSlowGradcheckEnv,
|
||||
skipIfTorchInductor,
|
||||
slowTest,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import (
|
||||
@ -209,6 +210,7 @@ class TestCommon(TestCase):
|
||||
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
|
||||
@onlyNativeDeviceTypes
|
||||
@ops(python_ref_db)
|
||||
@skipIfTorchInductor("Takes too long for inductor")
|
||||
def test_python_ref_meta(self, device, dtype, op):
|
||||
with FakeTensorMode() as mode:
|
||||
pass
|
||||
@ -374,6 +376,7 @@ class TestCommon(TestCase):
|
||||
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
|
||||
@onlyNativeDeviceTypes
|
||||
@ops(python_ref_db)
|
||||
@skipIfTorchInductor("Takes too long for inductor")
|
||||
def test_python_ref(self, device, dtype, op):
|
||||
# In this test, primTorch refs call into the refs namespace
|
||||
# For example, a ref with torch.foo in it will call refs.foo instead
|
||||
@ -386,6 +389,7 @@ class TestCommon(TestCase):
|
||||
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
|
||||
@onlyNativeDeviceTypes
|
||||
@ops(python_ref_db)
|
||||
@skipIfTorchInductor("Takes too long for inductor")
|
||||
def test_python_ref_torch_fallback(self, device, dtype, op):
|
||||
# In this test, refs call into the torch namespace (after the initial invocation)
|
||||
# For example, a ref with torch.foo in it will call torch.foo instead of refs.foo
|
||||
@ -397,6 +401,7 @@ class TestCommon(TestCase):
|
||||
@skipCUDAIfRocm
|
||||
@ops(python_ref_db)
|
||||
@parametrize('executor', ['aten', 'nvfuser'])
|
||||
@skipIfTorchInductor("Takes too long for inductor")
|
||||
def test_python_ref_executor(self, device, dtype, op, executor):
|
||||
# TODO: Not all dtypes are supported with nvfuser
|
||||
from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
|
||||
@ -457,6 +462,7 @@ class TestCommon(TestCase):
|
||||
@skipMeta
|
||||
@onlyNativeDeviceTypes
|
||||
@ops([op for op in python_ref_db if op.error_inputs_func is not None], dtypes=OpDTypes.none)
|
||||
@skipIfTorchInductor("Takes too long for inductor")
|
||||
def test_python_ref_errors(self, device, op):
|
||||
mode = FakeTensorMode()
|
||||
with mode:
|
||||
|
@ -4,8 +4,9 @@ from functools import partial, wraps
|
||||
from itertools import chain
|
||||
import torch
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck, is_slow_gradcheck_env)
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck, is_slow_gradcheck_env,
|
||||
skipIfTorchInductor)
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, OpDTypes)
|
||||
@ -253,6 +254,7 @@ class TestGradients(TestCase):
|
||||
self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False)
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
@skipIfTorchInductor("to be fixed")
|
||||
def test_inplace_forward_mode_AD(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
|
||||
|
@ -320,6 +320,9 @@ class TorchVariable(VariableTracker):
|
||||
assert isinstance(args[0], TensorVariable)
|
||||
|
||||
if config.fake_tensor_propagation:
|
||||
unimplemented(
|
||||
"TODO: make torch.random.set_rng_state work with FakeTensor/aot_autograd"
|
||||
)
|
||||
# In fake tensor case, this state doesn't matter, but
|
||||
# it needs to be valid to not segfault. Pull a real tensor out.
|
||||
# The value won't matter since we are running with fake tensors anyway, so rng doesn't matter.
|
||||
|
Reference in New Issue
Block a user