[CI] Reduce CI_SERIAL_LIST list (#124085)

Add a serial marker for individual tests so that their test files can be removed from the CI serial list
Run serial-marked tests first, in serial
Run all other tests afterwards, in parallel

Slowly reduce the list and mark individual tests as serial instead

Hope the # of serial tests stays small so sharding evenness doesn't get too messed up

Hopefully we can do 3 procs for sm86 and cpu?

serial no longer looks like a real word to me
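
For illustration, a minimal sketch of what opting in looks like with the decorator this commit adds (the test class and test name below are hypothetical; serialTest, TestCase, and run_tests come from torch.testing._internal.common_utils):

from torch.testing._internal.common_utils import TestCase, run_tests, serialTest

class ExampleTests(TestCase):
    @serialTest()  # becomes pytest.mark.serial when pytest is installed
    def test_uses_lots_of_memory(self):
        pass  # e.g. a test known to OOM when run alongside others

if __name__ == "__main__":
    run_tests()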

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124085
Approved by: https://github.com/seemethere, https://github.com/malfet
Catherine Lee
2024-04-17 00:23:42 +00:00
committed by PyTorch MergeBot
parent 946b50c788
commit 0abd3f60fd
4 changed files with 42 additions and 10 deletions

pytest.ini

@@ -19,3 +19,6 @@ filterwarnings =
ignore:Module already imported so cannot be rewritten.*hypothesis:pytest.PytestAssertRewriteWarning
xfail_strict = True
markers =
serial: marks tests as needing to be run serially (deselect with '-m "not serial"')
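
With the marker registered, pytest can select or deselect these tests directly; a minimal sketch of both passes (the file path here is hypothetical):

import pytest

# First pass: only the serial-marked tests in the file.
pytest.main(["test/some_test_file.py", "-m", "serial"])
# Second pass: everything else, safe to run under parallel workers.
pytest.main(["test/some_test_file.py", "-m", "not serial"])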

test/inductor/test_torchinductor.py

@@ -70,6 +70,7 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
IS_X86,
parametrize,
serialTest,
skipIfRocm,
subtest,
TEST_WITH_ASAN,
@@ -9278,6 +9279,7 @@ class CommonTemplate:
@config.patch(
"triton.autotune_pointwise", True
) # needed to introduce configs that exceed max shared memory usage
@serialTest()
def test_large_block_sizes(self):
"""
Inductor will try triton configs like x = 64 and y = 1024 which will

test/run_test.py

@@ -246,9 +246,6 @@ CI_SERIAL_LIST = [
"test_module_hooks", # OOM
"inductor/test_max_autotune",
"inductor/test_cutlass_backend", # slow due to many nvcc compilation steps
"inductor/test_torchinductor", # OOM on test_large_block_sizes
"inductor/test_torchinductor_dynamic_shapes", # OOM on test_large_block_sizes
"inductor/test_torchinductor_codegen_dynamic_shapes", # OOM on test_large_block_sizes
"test_profiler", # test_source_multithreaded is probably not compatible with parallelism
]
# A subset of onnx tests that cannot run in parallel due to high memory usage.
@@ -1591,6 +1588,11 @@ def run_tests(
):
pool.terminate()
keep_going_message = (
"\n\nTip: You can keep running tests even on failure by passing --keep-going to run_test.py.\n"
"If running on CI, add the 'keep-going' label to your PR and rerun your jobs."
)
try:
for test in selected_tests_serial:
options_clone = copy.deepcopy(options)
@@ -1603,19 +1605,29 @@ def run_tests(
and not options.continue_through_error
and not RERUN_DISABLED_TESTS
):
raise RuntimeError(
failure.message
+ "\n\nTip: You can keep running tests even on failure by "
"passing --keep-going to run_test.py.\n"
"If running on CI, add the 'keep-going' label to "
"your PR and rerun your jobs."
)
raise RuntimeError(failure.message + keep_going_message)
# Run tests marked as serial first
for test in selected_tests_parallel:
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
options_clone.pytest = True
options_clone.additional_unittest_args.extend(["-m", "serial"])
failure = run_test_module(test, test_directory, options_clone)
test_failed = handle_error_messages(failure)
if (
test_failed
and not options.continue_through_error
and not RERUN_DISABLED_TESTS
):
raise RuntimeError(failure.message + keep_going_message)
os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
for test in selected_tests_parallel:
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
options_clone.pytest = True
options_clone.additional_unittest_args.extend(["-m", "not serial"])
pool.apply_async(
run_test_module,
args=(test, test_directory, options_clone),
@@ -1718,6 +1730,7 @@ def main():
if IS_CI:
gen_ci_artifact([x.to_json() for x in include], [x.to_json() for x in exclude])
print_to_stderr(f"Running parallel tests on {NUM_PROCS} processes")
print_to_stderr(test_batch)
print_to_stderr(test_batch_exclude)
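
Taken together, the hunks above give run_test.py a three-phase schedule. A minimal standalone sketch of that control flow, with run_one as a hypothetical stand-in for run_test_module (all names in this sketch are made up):

from multiprocessing import Pool

def run_one(test, extra_args):
    ...  # stand-in for run_test_module: invoke the test file with extra_args

def run_all(serial_list_tests, parallel_tests, num_procs):
    # Phase 1: files still on CI_SERIAL_LIST run whole-file, one at a time.
    for test in serial_list_tests:
        run_one(test, [])
    # Phase 2: serial-marked tests inside the parallel files, still one at a time.
    for test in parallel_tests:
        run_one(test, ["-m", "serial"])
    # Phase 3: everything else in those files, under a process pool.
    with Pool(num_procs) as pool:
        for test in parallel_tests:
            pool.apply_async(run_one, args=(test, ["-m", "not serial"]))
        pool.close()
        pool.join()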

torch/testing/_internal/common_utils.py

@@ -97,6 +97,11 @@ from torch.utils._import_utils import _check_module_exists
import torch.utils._pytree as pytree
from .composite_compliance import no_dispatch
try:
import pytest
has_pytest = True
except ImportError:
has_pytest = False
# Class to keep track of test flags configurable by environment variables.
@@ -1384,6 +1389,15 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
return decorator
def serialTest(condition=True):
"""
Decorator for running tests serially. Requires pytest
"""
def decorator(fn):
if has_pytest and condition:
return pytest.mark.serial(fn)
return fn
return decorator
def unMarkDynamoStrictTest(cls=None):
def decorator(cls):
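
Because serialTest takes a condition, the mark can be applied only where parallel workers actually conflict; a hedged sketch (TEST_CUDA is a real flag in common_utils, but the test class and test below are hypothetical):

from torch.testing._internal.common_utils import TestCase, serialTest, TEST_CUDA

class MoreExampleTests(TestCase):
    # Serialize only on CUDA machines, where workers share one GPU's memory.
    @serialTest(TEST_CUDA)
    def test_gpu_heavy(self):
        pass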