[Dynamo][Better Engineering] Type devices, resume_execution and testing utils (#158593)

As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo

This PR adds strict typing support to a set of utilities in dynamo, `device_interface.py`, `resume_execution.py`, `tensor_version_ops.py`, `test_case.py`, and `test_minifier_common.py`

Running
```
mypy torch/_dynamo/device_interface.py torch/_dynamo/resume_execution.py torch/_dynamo/tensor_version_op.py torch/_dynamo/test_case.py torch/_dynamo/test_minifier_common.py  --linecount-report /tmp/coverage_log
```

| -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  976 | 1672 | 58.37% | 76 | 112 | 67.86% |
| This PR | 1719 | 1719 | 100.00% | 112 | 112 | 100.00% |
| Delta    | +743 | +47 | +41.63% | +36 | 0 | +32.14% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158593
Approved by: https://github.com/mlazos
This commit is contained in:
Lucas Kabela
2025-07-18 18:22:01 +00:00
committed by PyTorch MergeBot
parent 6e07d6a0ff
commit 656885b614
5 changed files with 142 additions and 95 deletions

View File

@ -1,5 +1,3 @@
# mypy: allow-untyped-defs
"""
Device abstraction layer for TorchDynamo and Inductor backends.
@ -21,7 +19,7 @@ import inspect
import time
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Literal, Optional, Union
import torch
@ -44,17 +42,17 @@ class DeviceInterface:
"""
class device:
def __new__(cls, device: torch.types.Device):
def __new__(cls, device: torch.types.Device) -> Any:
raise NotImplementedError
class Event:
def __new__(cls, *args, **kwargs):
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError(
"Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo."
)
class Stream:
def __new__(cls, *args, **kwargs):
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError(
"Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo."
)
@ -68,7 +66,7 @@ class DeviceInterface:
"""
@staticmethod
def set_device(device: int):
def set_device(device: int) -> None:
raise NotImplementedError
@staticmethod
@ -76,15 +74,15 @@ class DeviceInterface:
raise NotImplementedError
@staticmethod
def get_device_properties(device: torch.types.Device = None):
def get_device_properties(device: torch.types.Device = None) -> Any:
raise NotImplementedError
@staticmethod
def current_device():
def current_device() -> int:
raise NotImplementedError
@staticmethod
def set_device(device: torch.types.Device):
def set_device(device: torch.types.Device) -> None:
raise NotImplementedError
@staticmethod
@ -96,7 +94,7 @@ class DeviceInterface:
raise NotImplementedError
@staticmethod
def device_count():
def device_count() -> int:
raise NotImplementedError
@staticmethod
@ -104,19 +102,19 @@ class DeviceInterface:
raise NotImplementedError
@staticmethod
def stream(stream: torch.Stream):
def stream(stream: torch.Stream) -> Any:
raise NotImplementedError
@staticmethod
def current_stream():
def current_stream() -> torch.Stream:
raise NotImplementedError
@staticmethod
def set_stream(stream: torch.Stream):
def set_stream(stream: torch.Stream) -> None:
raise NotImplementedError
@staticmethod
def _set_stream_by_id(stream_id: int, device_index: int, device_type: int):
def _set_stream_by_id(stream_id: int, device_index: int, device_type: int) -> None:
raise NotImplementedError
@staticmethod
@ -124,19 +122,19 @@ class DeviceInterface:
raise NotImplementedError
@staticmethod
def synchronize(device: torch.types.Device = None):
def synchronize(device: torch.types.Device = None) -> None:
raise NotImplementedError
@classmethod
def get_device_properties(cls, device: torch.types.Device = None):
def get_device_properties(cls, device: torch.types.Device = None) -> Any:
return cls.Worker.get_device_properties(device)
@staticmethod
def get_compute_capability(device: torch.types.Device = None):
def get_compute_capability(device: torch.types.Device = None) -> Any:
raise NotImplementedError
@staticmethod
def is_bf16_supported(including_emulation: bool = False):
def is_bf16_supported(including_emulation: bool = False) -> bool:
raise NotImplementedError
@classmethod
@ -188,11 +186,11 @@ class DeviceGuard:
self.idx = index
self.prev_idx = -1
def __enter__(self):
def __enter__(self) -> None:
if self.idx is not None:
self.prev_idx = self.device_interface.exchange_device(self.idx)
def __exit__(self, type: Any, value: Any, traceback: Any):
def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]:
if self.idx is not None:
self.idx = self.device_interface.maybe_exchange_device(self.prev_idx)
return False
@ -208,7 +206,7 @@ class CudaInterface(DeviceInterface):
class Worker:
@staticmethod
def set_device(device: int):
def set_device(device: int) -> None:
caching_worker_current_devices["cuda"] = device
@staticmethod
@ -218,7 +216,7 @@ class CudaInterface(DeviceInterface):
return torch.cuda.current_device()
@staticmethod
def get_device_properties(device: torch.types.Device = None):
def get_device_properties(device: torch.types.Device = None) -> Any:
if device is not None:
if isinstance(device, str):
device = torch.device(device)
@ -258,7 +256,7 @@ class CudaInterface(DeviceInterface):
return torch.cuda.is_available()
@staticmethod
def get_compute_capability(device: torch.types.Device = None):
def get_compute_capability(device: torch.types.Device = None) -> Union[int, str]:
if torch.version.hip is None:
major, min = torch.cuda.get_device_capability(device)
return major * 10 + min
@ -303,7 +301,7 @@ class XpuInterface(DeviceInterface):
class Worker:
@staticmethod
def set_device(device: int):
def set_device(device: int) -> None:
caching_worker_current_devices["xpu"] = device
@staticmethod
@ -313,7 +311,7 @@ class XpuInterface(DeviceInterface):
return torch.xpu.current_device()
@staticmethod
def get_device_properties(device: torch.types.Device = None):
def get_device_properties(device: torch.types.Device = None) -> Any:
if device is not None:
if isinstance(device, str):
device = torch.device(device)
@ -352,7 +350,7 @@ class XpuInterface(DeviceInterface):
return torch.xpu.is_available()
@staticmethod
def get_compute_capability(device: torch.types.Device = None):
def get_compute_capability(device: torch.types.Device = None) -> Any:
cc = torch.xpu.get_device_capability(device)
return cc
@ -365,7 +363,7 @@ class XpuInterface(DeviceInterface):
return True
@staticmethod
def raise_if_triton_unavailable(evice: torch.types.Device = None) -> None:
def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
import triton.backends
if "intel" not in triton.backends.backends:
@ -379,18 +377,20 @@ class CpuDeviceProperties:
class CpuInterface(DeviceInterface):
class Event(torch.Event):
def __init__(self, enable_timing=True):
def __init__(self, enable_timing: bool = True) -> None:
self.time = 0.0
def elapsed_time(self, end_event) -> float:
def elapsed_time(self, end_event: Any) -> float:
return (end_event.time - self.time) * 1000
def record(self, stream=None):
def record(self, stream: Any = None) -> None:
self.time = time.perf_counter()
class Worker:
@staticmethod
def get_device_properties(device: torch.types.Device = None):
def get_device_properties(
device: torch.types.Device = None,
) -> CpuDeviceProperties:
import multiprocessing
cpu_count = multiprocessing.cpu_count()
@ -401,7 +401,7 @@ class CpuInterface(DeviceInterface):
return True
@staticmethod
def is_bf16_supported(including_emulation: bool = False):
def is_bf16_supported(including_emulation: bool = False) -> bool:
return True
@staticmethod
@ -409,15 +409,15 @@ class CpuInterface(DeviceInterface):
return ""
@staticmethod
def get_raw_stream(device_idx) -> int:
def get_raw_stream(device_idx: Any) -> int:
return 0
@staticmethod
def current_device():
def current_device() -> int:
return 0
@staticmethod
def synchronize(device: torch.types.Device = None):
def synchronize(device: torch.types.Device = None) -> None:
pass
@staticmethod
@ -450,7 +450,7 @@ class MpsInterface(DeviceInterface):
return torch.backends.mps.is_available()
@staticmethod
def current_device():
def current_device() -> int:
return 0
@staticmethod
@ -458,16 +458,16 @@ class MpsInterface(DeviceInterface):
return ""
@staticmethod
def synchronize(device: torch.types.Device = None):
def synchronize(device: torch.types.Device = None) -> None:
torch.mps.synchronize()
class Worker:
@staticmethod
def get_device_properties(device: torch.types.Device = None):
def get_device_properties(device: torch.types.Device = None) -> dict[str, Any]:
return {}
@staticmethod
def current_device():
def current_device() -> int:
return 0
@ -477,7 +477,7 @@ _device_initialized = False
def register_interface_for_device(
device: Union[str, torch.device], device_interface: type[DeviceInterface]
):
) -> None:
if isinstance(device, torch.device):
device = device.type
device_interfaces[device] = device_interface
@ -499,7 +499,7 @@ def get_registered_device_interfaces() -> Iterable[tuple[str, type[DeviceInterfa
return device_interfaces.items()
def init_device_reg():
def init_device_reg() -> None:
global _device_initialized
register_interface_for_device("cuda", CudaInterface)
for i in range(torch.cuda.device_count()):

View File

@ -1,5 +1,3 @@
# mypy: allow-untyped-defs
"""
This module provides functionality for resuming Python execution at specific points in code,
primarily used by PyTorch Dynamo for control flow handling and optimization. It implements
@ -19,7 +17,9 @@ import copy
import dataclasses
import sys
import types
from typing import Any, cast, Optional
from collections.abc import Iterable
from contextlib import AbstractContextManager
from typing import Any, Callable, cast, Optional
from .bytecode_transformation import (
bytecode_from_template,
@ -52,7 +52,7 @@ TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in"
IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue"
def _initial_push_null(insts):
def _initial_push_null(insts: list[Instruction]) -> None:
if sys.version_info >= (3, 11):
insts.append(create_instruction("PUSH_NULL"))
if sys.version_info < (3, 13):
@ -60,7 +60,11 @@ def _initial_push_null(insts):
# Generates bytecode from template and splits the code where LOAD_FAST dummy is present.
def _bytecode_from_template_with_split(template, stack_index, varname_map=None):
def _bytecode_from_template_with_split(
template: Callable[..., Any],
stack_index: int,
varname_map: Optional[dict[str, Any]] = None,
) -> tuple[list[Instruction], list[Instruction]]:
template_code = bytecode_from_template(template, varname_map=varname_map)
template_code.append(create_instruction("POP_TOP"))
@ -90,7 +94,7 @@ def _bytecode_from_template_with_split(template, stack_index, varname_map=None):
return template_code[: dummy_idx + 1], template_code[dummy_idx + 1 :]
def _try_except_tf_mode_template(dummy, stack_var_name):
def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None:
# NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
# on torch._dynamo.utils.
global __import_torch_dot__dynamo_dot_utils
@ -108,7 +112,9 @@ class ReenterWith:
stack_index: int
target_values: Optional[tuple[Any, ...]] = None
def try_except_torch_function_mode(self, code_options, cleanup: list[Instruction]):
def try_except_torch_function_mode(
self, code_options: dict[str, Any], cleanup: list[Instruction]
) -> list[Instruction]:
"""
Codegen based off of:
try:
@ -130,7 +136,9 @@ class ReenterWith:
# If we do not want to destroy the stack, we can do the same thing as a
# `SETUP_WITH` block, only that we store the context manager in a local_symbol
def try_finally(self, code_options, cleanup: list[Instruction]):
def try_finally(
self, code_options: dict[str, Any], cleanup: list[Instruction]
) -> list[Instruction]:
"""
Codegen based off of:
load args
@ -161,7 +169,7 @@ class ReenterWith:
]
)
def _template(ctx, dummy):
def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None:
ctx.__enter__()
try:
dummy
@ -174,7 +182,9 @@ class ReenterWith:
cleanup[:] = epilogue + cleanup
return create_ctx + setup_try_finally
def __call__(self, code_options, cleanup):
def __call__(
self, code_options: dict[str, Any], cleanup: list[Instruction]
) -> tuple[list[Instruction], Optional[Instruction]]:
"""
Codegen based off of:
with ctx(args):
@ -194,7 +204,7 @@ class ReenterWith:
]
)
def _template(ctx, dummy):
def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None:
with ctx:
dummy
@ -242,7 +252,11 @@ class ResumeFunctionMetadata:
block_target_offset_remap: Optional[dict[int, int]] = None
def _filter_iter(l1, l2, cond):
def _filter_iter(
l1: Iterable[Any],
l2: Iterable[Any],
cond: Callable[[Any, Any], bool],
) -> list[Any]:
"""
Two-pointer conditional filter.
e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o)
@ -261,7 +275,7 @@ def _filter_iter(l1, l2, cond):
return res
def _load_tuple_and_call(tup):
def _load_tuple_and_call(tup: tuple[Any, ...]) -> list[Instruction]:
insts: list[Instruction] = []
_initial_push_null(insts)
insts.extend(create_load_const(val) for val in tup)
@ -274,7 +288,7 @@ class ContinueExecutionCache:
generated_code_metadata = ExactWeakKeyDictionary()
@classmethod
def lookup(cls, code, lineno, *key):
def lookup(cls, code: types.CodeType, lineno: int, *key: Any) -> types.CodeType:
if code not in cls.cache:
cls.cache[code] = {}
key = tuple(key)
@ -285,8 +299,8 @@ class ContinueExecutionCache:
@classmethod
def generate(
cls,
code,
lineno,
code: types.CodeType,
lineno: int,
offset: int,
setup_fn_target_offsets: tuple[int, ...], # only used in Python 3.11+
nstack: int,
@ -321,7 +335,9 @@ class ContinueExecutionCache:
is_py311_plus = sys.version_info >= (3, 11)
meta = ResumeFunctionMetadata(code)
def update(instructions: list[Instruction], code_options: dict[str, Any]):
def update(
instructions: list[Instruction], code_options: dict[str, Any]
) -> None:
meta.instructions = copy.deepcopy(instructions)
args = [f"___stack{i}" for i in range(nstack)]
@ -479,7 +495,7 @@ class ContinueExecutionCache:
inst.exn_tab_entry
and inst.exn_tab_entry.target in old_hook_target_remap
):
inst.exn_tab_entry.target = old_hook_target_remap[
inst.exn_tab_entry.target = old_hook_target_remap[ # type: ignore[assignment]
inst.exn_tab_entry.target
]
@ -491,7 +507,7 @@ class ContinueExecutionCache:
return new_code
@staticmethod
def unreachable_codes(code_options) -> list[Instruction]:
def unreachable_codes(code_options: dict[str, Any]) -> list[Instruction]:
"""Codegen a `raise None` to make analysis work for unreachable code"""
return [
create_load_const(None),
@ -500,8 +516,13 @@ class ContinueExecutionCache:
@classmethod
def generate_based_on_original_code_object(
cls, code, lineno, offset: int, setup_fn_target_offsets: tuple[int, ...], *args
):
cls,
code: types.CodeType,
lineno: int,
offset: int,
setup_fn_target_offsets: tuple[int, ...],
*args: Any,
) -> types.CodeType:
"""
This handles the case of generating a resume into code generated
to resume something else. We want to always generate starting
@ -517,7 +538,7 @@ class ContinueExecutionCache:
def find_new_offset(
instructions: list[Instruction], code_options: dict[str, Any]
):
) -> None:
nonlocal new_offset
(target,) = (i for i in instructions if i.offset == offset)
# match the functions starting at the last instruction as we have added a prefix
@ -541,7 +562,7 @@ class ContinueExecutionCache:
def remap_block_offsets(
instructions: list[Instruction], code_options: dict[str, Any]
):
) -> None:
# NOTE: each prefix block generates exactly one PUSH_EXC_INFO,
# so we can tell which block a prefix PUSH_EXC_INFO belongs to,
# by counting. Then we can use meta.prefix_block-target_offset_remap

View File

@ -1,5 +1,3 @@
# mypy: allow-untyped-defs
"""This module implements tensor version operations for Dynamo tracing.
It provides primitives for handling tensor versioning during tracing, particularly in the
@ -18,7 +16,11 @@ Why is this ok?
Note this is similar to how no_grad is handled.
"""
from contextlib import AbstractContextManager
from typing import Any
import torch
from torch import SymInt
from torch._prims import _make_prim, RETURN_TYPE
from torch._subclasses import FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensorMode
@ -33,13 +35,14 @@ _tensor_version = _make_prim(
)
@_tensor_version.py_impl(FakeTensorMode)
def _tensor_version_fake(fake_mode, self_tensor):
@_tensor_version.py_impl(FakeTensorMode) # type: ignore[misc]
def _tensor_version_fake(fake_mode: FakeTensorMode, self_tensor: Any) -> SymInt:
"""
The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the
`._version` into an unbacked SymInt so that we don't need to specialize on the `._version`
of input tensors to the graph.
"""
assert fake_mode.shape_env is not None
return fake_mode.shape_env.create_unbacked_symint()
@ -53,11 +56,15 @@ _unsafe_set_version_counter = _make_prim(
torch.fx.node.has_side_effect(_unsafe_set_version_counter)
@_tensor_version.py_impl(FunctionalTensorMode)
def _tensor_version_functional(mode, self):
@_tensor_version.py_impl(FunctionalTensorMode) # type: ignore[misc]
def _tensor_version_functional(mode: FunctionalTensorMode, self: Any) -> int:
return self._version
@_unsafe_set_version_counter.py_impl(FunctionalTensorMode)
def _unsafe_set_version_counter_functional(ctx, tensors, versions):
@_unsafe_set_version_counter.py_impl(FunctionalTensorMode) # type: ignore[misc]
def _unsafe_set_version_counter_functional(
ctx: AbstractContextManager[Any],
tensors: tuple[torch.Tensor, ...],
versions: tuple[int, ...],
) -> None:
torch._C._autograd._unsafe_set_version_counter(tensors, versions)

View File

@ -1,5 +1,3 @@
# mypy: allow-untyped-defs
"""Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
@ -18,7 +16,7 @@ import os
import re
import sys
import unittest
from typing import Union
from typing import Any, Callable, Union
import torch
import torch.testing
@ -151,7 +149,12 @@ class CPythonTestCase(TestCase):
fail = unittest.TestCase.fail
failureException = unittest.TestCase.failureException
def compile_fn(self, fn, backend, nopython):
def compile_fn(
self,
fn: Callable[..., Any],
backend: Union[str, Callable[..., Any]],
nopython: bool,
) -> Callable[..., Any]:
# We want to compile only the test function, excluding any setup code
# from unittest
method = getattr(self, self._testMethodName)
@ -159,7 +162,7 @@ class CPythonTestCase(TestCase):
setattr(self, self._testMethodName, method)
return fn
def _dynamo_test_key(self):
def _dynamo_test_key(self) -> str:
suffix = super()._dynamo_test_key()
test_cls = self.__class__
test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0]

View File

@ -1,5 +1,3 @@
# mypy: allow-untyped-defs
"""Common utilities for testing Dynamo's minifier functionality.
This module provides the base infrastructure for running minification tests in Dynamo.
@ -25,7 +23,8 @@ import subprocess
import sys
import tempfile
import traceback
from typing import Optional
from collections.abc import Sequence
from typing import Any, Optional, Union
from unittest.mock import patch
import torch
@ -40,7 +39,7 @@ class MinifierTestResult:
minifier_code: str
repro_code: str
def _get_module(self, t):
def _get_module(self, t: str) -> str:
match = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t)
assert match is not None, "failed to find module"
r = match.group(0)
@ -48,7 +47,7 @@ class MinifierTestResult:
r = re.sub(r"\n{3,}", "\n\n", r)
return r.strip()
def get_exported_program_path(self):
def get_exported_program_path(self) -> Optional[str]:
# Extract the exported program file path from AOTI minifier's repro.py
# Regular expression pattern to match the file path
pattern = r'torch\.export\.load\(\s*["\'](.*?)["\']\s*\)'
@ -60,10 +59,10 @@ class MinifierTestResult:
return file_path
return None
def minifier_module(self):
def minifier_module(self) -> str:
return self._get_module(self.minifier_code)
def repro_module(self):
def repro_module(self) -> str:
return self._get_module(self.repro_code)
@ -71,7 +70,7 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
DEBUG_DIR = tempfile.mkdtemp()
@classmethod
def setUpClass(cls):
def setUpClass(cls) -> None:
super().setUpClass()
if not os.path.exists(cls.DEBUG_DIR):
cls.DEBUG_DIR = tempfile.mkdtemp()
@ -94,14 +93,14 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
)
@classmethod
def tearDownClass(cls):
def tearDownClass(cls) -> None:
if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1":
shutil.rmtree(cls.DEBUG_DIR)
else:
print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}")
cls._exit_stack.close() # type: ignore[attr-defined]
def _gen_codegen_fn_patch_code(self, device, bug_type):
def _gen_codegen_fn_patch_code(self, device: str, bug_type: str) -> str:
assert bug_type in ("compile_error", "runtime_error", "accuracy")
return f"""\
{torch._dynamo.config.codegen_config()}
@ -109,7 +108,9 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r}
"""
def _maybe_subprocess_run(self, args, *, isolate, cwd=None):
def _maybe_subprocess_run(
self, args: Sequence[Any], *, isolate: bool, cwd: Optional[str] = None
) -> subprocess.CompletedProcess[bytes]:
if not isolate:
assert len(args) >= 2, args
assert args[0] == "python3", args
@ -174,7 +175,9 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
# Run `code` in a separate python process.
# Returns the completed process state and the directory containing the
# minifier launcher script, if `code` outputted it.
def _run_test_code(self, code, *, isolate):
def _run_test_code(
self, code: str, *, isolate: bool
) -> tuple[subprocess.CompletedProcess[bytes], Union[str, Any]]:
proc = self._maybe_subprocess_run(
["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR
)
@ -190,8 +193,13 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
# Runs the minifier launcher script in `repro_dir`
def _run_minifier_launcher(
self, repro_dir, isolate, *, minifier_args=(), repro_after=None
):
self,
repro_dir: str,
isolate: bool,
*,
minifier_args: Sequence[Any] = (),
repro_after: Optional[str] = None,
) -> tuple[subprocess.CompletedProcess[bytes], str]:
self.assertIsNotNone(repro_dir)
launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py"))
with open(launch_file) as f:
@ -212,7 +220,9 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
return launch_proc, launch_code
# Runs the repro script in `repro_dir`
def _run_repro(self, repro_dir, *, isolate=True):
def _run_repro(
self, repro_dir: str, *, isolate: bool = True
) -> tuple[subprocess.CompletedProcess[bytes], str]:
self.assertIsNotNone(repro_dir)
repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py"))
with open(repro_file) as f:
@ -230,7 +240,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
# `run_code` is the code to run for the test case.
# `patch_code` is the code to be patched in every generated file; usually
# just use this to turn on bugs via the config
def _gen_test_code(self, run_code, repro_after, repro_level):
def _gen_test_code(self, run_code: str, repro_after: str, repro_level: int) -> str:
repro_after_line = ""
if repro_after == "aot_inductor":
repro_after_line = (
@ -263,7 +273,13 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
# isolate=True only if the bug you're testing would otherwise
# crash the process
def _run_full_test(
self, run_code, repro_after, expected_error, *, isolate, minifier_args=()
self,
run_code: str,
repro_after: str,
expected_error: Optional[str],
*,
isolate: bool,
minifier_args: Sequence[Any] = (),
) -> Optional[MinifierTestResult]:
if isolate:
repro_level = 3