mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via `git diff --ignore-all-space --ignore-blank-lines HEAD~1`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129756 Approved by: https://github.com/ezyang
195 lines
6.1 KiB
Python
195 lines
6.1 KiB
Python
from __future__ import annotations
|
|
|
|
import modulefinder
|
|
import os
|
|
import sys
|
|
import warnings
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
# Repository root: three ``.parent`` hops above this file's resolved path.
REPO_ROOT = Path(__file__).resolve().parent.parent.parent

# Test modules slow enough that it is worth first checking whether a change
# touched anything they depend on. This list is curated by hand; runs using
# --determine-from build a further list from this one combined with the
# previous test statistics.
#
# test_autograd is intentionally absent: it is not slow, and adding it back
# would trip https://bugs.python.org/issue40350 because it imports files
# under test/autograd/.
TARGET_DET_LIST = [
    "test_binary_ufuncs",
    # C++ extension suites
    "test_cpp_extensions_aot_ninja",
    "test_cpp_extensions_aot_no_ninja",
    "test_cpp_extensions_jit",
    "test_cpp_extensions_open_device_registration",
    "test_cpp_extensions_stream_and_event",
    "test_cpp_extensions_mtia_backend",
    # CUDA
    "test_cuda",
    "test_cuda_primary_ctx",
    # General suites
    "test_dataloader",
    "test_determination",
    "test_futures",
    "test_jit",
    "test_jit_legacy",
    "test_jit_profiling",
    "test_linalg",
    "test_multiprocessing",
    "test_nn",
    "test_numpy_interop",
    "test_optim",
    "test_overrides",
    "test_pruning_op",
    "test_quantization",
    "test_reductions",
    "test_serialization",
    "test_shape_ops",
    "test_sort_and_select",
    "test_tensorboard",
    "test_testing",
    "test_torch",
    "test_utils",
    "test_view_ops",
]

# Memo for get_dep_modules(): test name -> set of module names found for it.
_DEP_MODULES_CACHE: dict[str, set[str]] = {}
|
|
|
|
|
|
def should_run_test(
    target_det_list: list[str], test: str, touched_files: list[str], options: Any
) -> bool:
    """Decide whether ``test`` has to run for a change touching ``touched_files``.

    Tests not in ``target_det_list`` always run (they are cheaper to execute
    than to analyse). For listed tests every touched file is classified with
    ``test_impact_of_file``:

    * ``NONE`` files are ignored;
    * ``CI`` and ``UNKNOWN`` files force the test to run;
    * ``TORCH`` / ``CAFFE2`` / ``TEST`` files force the test to run when the
      file's dotted module name is the test itself or appears in
      ``get_dep_modules(test)``.

    Args:
        target_det_list: test module names for which determination is done.
        test: test name; anything after the first ``.`` is ignored.
        touched_files: changed file paths, split on ``os.sep``.
        options: object whose ``verbose`` attribute enables stderr diagnostics.

    Returns:
        True if the test should run, False if determination skips it.
    """
    test = parse_test_module(test)
    # Some tests are faster to execute than to determine.
    if test not in target_det_list:
        if options.verbose:
            print_to_stderr(f"Running {test} without determination")
        return True

    # HACK: "no_ninja" is not a real module; both the "_no_ninja" and the
    # "_ninja" variants are mapped back onto the same base test name.
    if test.endswith("_no_ninja"):
        test = test[: -len("_no_ninja")]
    if test.endswith("_ninja"):
        test = test[: -len("_ninja")]

    # Dotted form of the test's own path, compared against touched modules.
    test_module = test.replace("/", ".")
    # get_dep_modules() runs modulefinder over the test's import graph, which
    # can be slow. Compute it lazily, only once a Python source file actually
    # has to be matched, so NONE-only changes (or a CI/UNKNOWN hit first)
    # never pay for it.
    dep_modules: set[str] | None = None

    for touched_file in touched_files:
        file_type = test_impact_of_file(touched_file)
        if file_type == "NONE":
            continue
        if file_type in ("CI", "UNKNOWN"):
            # Changes to the CI configuration force every test to run, and
            # uncategorized source files are assumed to affect every test.
            log_test_reason(file_type, touched_file, test, options)
            return True
        if file_type in ("TORCH", "CAFFE2", "TEST"):
            touched_module = _touched_file_to_module(touched_file)
            if touched_module == test_module:
                log_test_reason(file_type, touched_file, test, options)
                return True
            if dep_modules is None:
                dep_modules = get_dep_modules(test)
            if touched_module in dep_modules:
                log_test_reason(file_type, touched_file, test, options)
                return True

    # No touched file required this test, so determination skips it.
    if options.verbose:
        print_to_stderr(f"Determination is skipping {test}")

    return False


def _touched_file_to_module(touched_file: str) -> str:
    """Convert an ``os.sep``-separated source path into a dotted module name.

    Modules under ``test/`` are not in a ``test.`` namespace, so that leading
    component is dropped (on POSIX: ``test/foo/test_bar.py`` -> ``foo.test_bar``).
    """
    module = ".".join(os.path.splitext(touched_file)[0].split(os.sep))
    prefix = "test."
    if module.startswith(prefix):
        # Slice off only the leading package: ``module.split("test.")[1]``
        # would also cut at any later "test." substring
        # (e.g. "test.contest.x" -> "con").
        module = module[len(prefix):]
    return module
|
|
|
|
|
|
def test_impact_of_file(filename: str) -> str:
    """Determine what class of impact this file has on test runs.

    ``filename`` is split on ``os.sep``. Only its first component (top-level
    directory or root file name) and, for the Python categories, the suffix of
    its last component are inspected.

    Possible values:
        TORCH   - torch python code (``.py``/``.pyi`` under ``torch/``)
        CAFFE2  - caffe2 python code (``.py``/``.pyi`` under ``caffe2/``)
        TEST    - torch test code (``.py``/``.pyi`` under ``test/``)
        UNKNOWN - may affect all tests (anything not matched above)
        NONE    - known to have no effect on test outcome
        CI      - CI configuration files
    """
    parts = filename.split(os.sep)
    top_level = parts[0]

    if top_level in (".jenkins", ".circleci", ".ci"):
        return "CI"
    if top_level in ("docs", "scripts", "CODEOWNERS", "README.md"):
        return "NONE"

    python_category = {"torch": "TORCH", "caffe2": "CAFFE2", "test": "TEST"}.get(
        top_level
    )
    if python_category is not None and parts[-1].endswith((".py", ".pyi")):
        return python_category

    # Non-Python files under torch/, caffe2/ or test/, and everything else.
    return "UNKNOWN"


# The ``test_`` prefix would make pytest collect this helper as a test (and
# error out looking for a ``filename`` fixture) in any test module that
# imports it by name; opt out of collection explicitly.
test_impact_of_file.__test__ = False  # type: ignore[attr-defined]
|
|
|
|
|
|
def log_test_reason(file_type: str, filename: str, test: str, options: Any) -> None:
    """Report on stderr which touched file caused ``test`` to be run.

    Does nothing unless ``options.verbose`` is truthy.
    """
    if not options.verbose:
        return
    reason = f"Determination found {file_type} file {filename} -- running {test}"
    print_to_stderr(reason)
|
|
|
|
|
|
def get_dep_modules(test: str) -> set[str]:
|
|
# Cache results in case of repetition
|
|
if test in _DEP_MODULES_CACHE:
|
|
return _DEP_MODULES_CACHE[test]
|
|
|
|
test_location = REPO_ROOT / "test" / f"{test}.py"
|
|
|
|
# HACK: some platforms default to ascii, so we can't just run_script :(
|
|
finder = modulefinder.ModuleFinder(
|
|
# Ideally exclude all third party modules, to speed up calculation.
|
|
excludes=[
|
|
"scipy",
|
|
"numpy",
|
|
"numba",
|
|
"multiprocessing",
|
|
"sklearn",
|
|
"setuptools",
|
|
"hypothesis",
|
|
"llvmlite",
|
|
"joblib",
|
|
"email",
|
|
"importlib",
|
|
"unittest",
|
|
"urllib",
|
|
"json",
|
|
"collections",
|
|
# Modules below are excluded because they are hitting https://bugs.python.org/issue40350
|
|
# Trigger AttributeError: 'NoneType' object has no attribute 'is_package'
|
|
"mpl_toolkits",
|
|
"google",
|
|
"onnx",
|
|
# Triggers RecursionError
|
|
"mypy",
|
|
],
|
|
)
|
|
|
|
with warnings.catch_warnings():
|
|
warnings.simplefilter("ignore")
|
|
finder.run_script(str(test_location))
|
|
dep_modules = set(finder.modules.keys())
|
|
_DEP_MODULES_CACHE[test] = dep_modules
|
|
return dep_modules
|
|
|
|
|
|
def parse_test_module(test: str) -> str:
    """Return the portion of ``test`` that precedes the first ``.`` (or all of it)."""
    module, _sep, _rest = test.partition(".")
    return module
|
|
|
|
|
|
def print_to_stderr(message: str) -> None:
    """Print ``message`` (newline-terminated) to the current ``sys.stderr``."""
    stream = sys.stderr
    print(message, end="\n", file=stream)
|