[Dynamo][Better Engineering] Type devices, resume_execution and testing utils (#158593)
As part of better engineering week, we would like to improve our type support to improve the dev experience in dynamo. This PR adds strict typing support to a set of utilities in dynamo: `device_interface.py`, `resume_execution.py`, `tensor_version_op.py`, `test_case.py`, and `test_minifier_common.py`.

Running

```
mypy torch/_dynamo/device_interface.py torch/_dynamo/resume_execution.py torch/_dynamo/tensor_version_op.py torch/_dynamo/test_case.py torch/_dynamo/test_minifier_common.py --linecount-report /tmp/coverage_log
```

|          | Lines Annotated | Lines Total | % lines covered | Funcs Annotated | Funcs Total | % funcs covered |
| -------- | --------------- | ----------- | --------------- | --------------- | ----------- | --------------- |
| Main     | 976             | 1672        | 58.37%          | 76              | 112         | 67.86%          |
| This PR  | 1719            | 1719        | 100.00%         | 112             | 112         | 100.00%         |
| Delta    | +743            | +47         | +41.63%         | +36             | 0           | +32.14%         |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158593
Approved by: https://github.com/mlazos
committed by: PyTorch MergeBot
parent: 6e07d6a0ff
commit: 656885b614
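The changes below follow one mechanical pattern: each signature gains parameter and return annotations (plus a targeted `# type: ignore[...]` where a decorator defeats inference), which lets every file drop its `# mypy: allow-untyped-defs` header. A minimal sketch of the before/after idea — `ExampleInterface` is hypothetical, not code from this PR:

```python
from typing import Any, Literal


class ExampleInterface:
    # Before: `def set_device(device):` type-checked only because the file
    # opted out via `# mypy: allow-untyped-defs`. After: fully annotated,
    # so it passes strict mypy and counts as a covered function.
    @staticmethod
    def set_device(device: int) -> None:
        raise NotImplementedError

    def __enter__(self) -> None:
        pass

    # A context manager that never suppresses exceptions can declare
    # `Literal[False]` instead of plain `bool`, mirroring the
    # DeviceGuard.__exit__ annotation added in the diff below.
    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Literal[False]:
        return False
```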
--- a/torch/_dynamo/device_interface.py
+++ b/torch/_dynamo/device_interface.py
@@ -1,5 +1,3 @@
-# mypy: allow-untyped-defs
-
 """
 Device abstraction layer for TorchDynamo and Inductor backends.
 
@@ -21,7 +19,7 @@ import inspect
 import time
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union
 
 import torch
 
@@ -44,17 +42,17 @@ class DeviceInterface:
     """
 
     class device:
-        def __new__(cls, device: torch.types.Device):
+        def __new__(cls, device: torch.types.Device) -> Any:
             raise NotImplementedError
 
     class Event:
-        def __new__(cls, *args, **kwargs):
+        def __new__(cls, *args: Any, **kwargs: Any) -> Any:
             raise NotImplementedError(
                 "Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo."
             )
 
     class Stream:
-        def __new__(cls, *args, **kwargs):
+        def __new__(cls, *args: Any, **kwargs: Any) -> Any:
            raise NotImplementedError(
                 "Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo."
             )
@@ -68,7 +66,7 @@ class DeviceInterface:
         """
 
         @staticmethod
-        def set_device(device: int):
+        def set_device(device: int) -> None:
             raise NotImplementedError
 
         @staticmethod
@@ -76,15 +74,15 @@
             raise NotImplementedError
 
         @staticmethod
-        def get_device_properties(device: torch.types.Device = None):
+        def get_device_properties(device: torch.types.Device = None) -> Any:
             raise NotImplementedError
 
     @staticmethod
-    def current_device():
+    def current_device() -> int:
         raise NotImplementedError
 
     @staticmethod
-    def set_device(device: torch.types.Device):
+    def set_device(device: torch.types.Device) -> None:
         raise NotImplementedError
 
     @staticmethod
@@ -96,7 +94,7 @@ class DeviceInterface:
         raise NotImplementedError
 
     @staticmethod
-    def device_count():
+    def device_count() -> int:
         raise NotImplementedError
 
     @staticmethod
@@ -104,19 +102,19 @@
         raise NotImplementedError
 
     @staticmethod
-    def stream(stream: torch.Stream):
+    def stream(stream: torch.Stream) -> Any:
         raise NotImplementedError
 
     @staticmethod
-    def current_stream():
+    def current_stream() -> torch.Stream:
         raise NotImplementedError
 
     @staticmethod
-    def set_stream(stream: torch.Stream):
+    def set_stream(stream: torch.Stream) -> None:
         raise NotImplementedError
 
     @staticmethod
-    def _set_stream_by_id(stream_id: int, device_index: int, device_type: int):
+    def _set_stream_by_id(stream_id: int, device_index: int, device_type: int) -> None:
         raise NotImplementedError
 
     @staticmethod
@@ -124,19 +122,19 @@
         raise NotImplementedError
 
     @staticmethod
-    def synchronize(device: torch.types.Device = None):
+    def synchronize(device: torch.types.Device = None) -> None:
         raise NotImplementedError
 
     @classmethod
-    def get_device_properties(cls, device: torch.types.Device = None):
+    def get_device_properties(cls, device: torch.types.Device = None) -> Any:
         return cls.Worker.get_device_properties(device)
 
     @staticmethod
-    def get_compute_capability(device: torch.types.Device = None):
+    def get_compute_capability(device: torch.types.Device = None) -> Any:
         raise NotImplementedError
 
     @staticmethod
-    def is_bf16_supported(including_emulation: bool = False):
+    def is_bf16_supported(including_emulation: bool = False) -> bool:
         raise NotImplementedError
 
     @classmethod
@@ -188,11 +186,11 @@ class DeviceGuard:
         self.idx = index
         self.prev_idx = -1
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         if self.idx is not None:
             self.prev_idx = self.device_interface.exchange_device(self.idx)
 
-    def __exit__(self, type: Any, value: Any, traceback: Any):
+    def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]:
         if self.idx is not None:
             self.idx = self.device_interface.maybe_exchange_device(self.prev_idx)
         return False
@@ -208,7 +206,7 @@ class CudaInterface(DeviceInterface):
 
     class Worker:
         @staticmethod
-        def set_device(device: int):
+        def set_device(device: int) -> None:
             caching_worker_current_devices["cuda"] = device
 
         @staticmethod
@@ -218,7 +216,7 @@ class CudaInterface(DeviceInterface):
             return torch.cuda.current_device()
 
         @staticmethod
-        def get_device_properties(device: torch.types.Device = None):
+        def get_device_properties(device: torch.types.Device = None) -> Any:
             if device is not None:
                 if isinstance(device, str):
                     device = torch.device(device)
@@ -258,7 +256,7 @@ class CudaInterface(DeviceInterface):
         return torch.cuda.is_available()
 
     @staticmethod
-    def get_compute_capability(device: torch.types.Device = None):
+    def get_compute_capability(device: torch.types.Device = None) -> Union[int, str]:
         if torch.version.hip is None:
             major, min = torch.cuda.get_device_capability(device)
             return major * 10 + min
@@ -303,7 +301,7 @@ class XpuInterface(DeviceInterface):
 
     class Worker:
         @staticmethod
-        def set_device(device: int):
+        def set_device(device: int) -> None:
             caching_worker_current_devices["xpu"] = device
 
         @staticmethod
@@ -313,7 +311,7 @@ class XpuInterface(DeviceInterface):
             return torch.xpu.current_device()
 
         @staticmethod
-        def get_device_properties(device: torch.types.Device = None):
+        def get_device_properties(device: torch.types.Device = None) -> Any:
             if device is not None:
                 if isinstance(device, str):
                     device = torch.device(device)
@@ -352,7 +350,7 @@ class XpuInterface(DeviceInterface):
         return torch.xpu.is_available()
 
     @staticmethod
-    def get_compute_capability(device: torch.types.Device = None):
+    def get_compute_capability(device: torch.types.Device = None) -> Any:
         cc = torch.xpu.get_device_capability(device)
         return cc
 
@@ -365,7 +363,7 @@ class XpuInterface(DeviceInterface):
         return True
 
     @staticmethod
-    def raise_if_triton_unavailable(evice: torch.types.Device = None) -> None:
+    def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
         import triton.backends
 
         if "intel" not in triton.backends.backends:
@@ -379,18 +377,20 @@ class CpuDeviceProperties:
 
 class CpuInterface(DeviceInterface):
     class Event(torch.Event):
-        def __init__(self, enable_timing=True):
+        def __init__(self, enable_timing: bool = True) -> None:
             self.time = 0.0
 
-        def elapsed_time(self, end_event) -> float:
+        def elapsed_time(self, end_event: Any) -> float:
            return (end_event.time - self.time) * 1000
 
-        def record(self, stream=None):
+        def record(self, stream: Any = None) -> None:
             self.time = time.perf_counter()
 
     class Worker:
         @staticmethod
-        def get_device_properties(device: torch.types.Device = None):
+        def get_device_properties(
+            device: torch.types.Device = None,
+        ) -> CpuDeviceProperties:
             import multiprocessing
 
             cpu_count = multiprocessing.cpu_count()
@@ -401,7 +401,7 @@ class CpuInterface(DeviceInterface):
         return True
 
     @staticmethod
-    def is_bf16_supported(including_emulation: bool = False):
+    def is_bf16_supported(including_emulation: bool = False) -> bool:
         return True
 
     @staticmethod
@@ -409,15 +409,15 @@ class CpuInterface(DeviceInterface):
         return ""
 
     @staticmethod
-    def get_raw_stream(device_idx) -> int:
+    def get_raw_stream(device_idx: Any) -> int:
         return 0
 
     @staticmethod
-    def current_device():
+    def current_device() -> int:
         return 0
 
     @staticmethod
-    def synchronize(device: torch.types.Device = None):
+    def synchronize(device: torch.types.Device = None) -> None:
         pass
 
     @staticmethod
@@ -450,7 +450,7 @@ class MpsInterface(DeviceInterface):
         return torch.backends.mps.is_available()
 
     @staticmethod
-    def current_device():
+    def current_device() -> int:
         return 0
 
     @staticmethod
@@ -458,16 +458,16 @@ class MpsInterface(DeviceInterface):
         return ""
 
     @staticmethod
-    def synchronize(device: torch.types.Device = None):
+    def synchronize(device: torch.types.Device = None) -> None:
         torch.mps.synchronize()
 
     class Worker:
         @staticmethod
-        def get_device_properties(device: torch.types.Device = None):
+        def get_device_properties(device: torch.types.Device = None) -> dict[str, Any]:
             return {}
 
         @staticmethod
-        def current_device():
+        def current_device() -> int:
             return 0
 
 
@@ -477,7 +477,7 @@ _device_initialized = False
 
 def register_interface_for_device(
     device: Union[str, torch.device], device_interface: type[DeviceInterface]
-):
+) -> None:
     if isinstance(device, torch.device):
         device = device.type
     device_interfaces[device] = device_interface
@@ -499,7 +499,7 @@ def get_registered_device_interfaces() -> Iterable[tuple[str, type[DeviceInterfa
     return device_interfaces.items()
 
 
-def init_device_reg():
+def init_device_reg() -> None:
     global _device_initialized
     register_interface_for_device("cuda", CudaInterface)
     for i in range(torch.cuda.device_count()):

--- a/torch/_dynamo/resume_execution.py
+++ b/torch/_dynamo/resume_execution.py
@@ -1,5 +1,3 @@
-# mypy: allow-untyped-defs
-
 """
 This module provides functionality for resuming Python execution at specific points in code,
 primarily used by PyTorch Dynamo for control flow handling and optimization. It implements
@@ -19,7 +17,9 @@ import copy
 import dataclasses
 import sys
 import types
-from typing import Any, cast, Optional
+from collections.abc import Iterable
+from contextlib import AbstractContextManager
+from typing import Any, Callable, cast, Optional
 
 from .bytecode_transformation import (
     bytecode_from_template,
@@ -52,7 +52,7 @@ TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in"
 IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue"
 
 
-def _initial_push_null(insts):
+def _initial_push_null(insts: list[Instruction]) -> None:
     if sys.version_info >= (3, 11):
         insts.append(create_instruction("PUSH_NULL"))
         if sys.version_info < (3, 13):
@@ -60,7 +60,11 @@ def _initial_push_null(insts):
 
 
 # Generates bytecode from template and splits the code where LOAD_FAST dummy is present.
-def _bytecode_from_template_with_split(template, stack_index, varname_map=None):
+def _bytecode_from_template_with_split(
+    template: Callable[..., Any],
+    stack_index: int,
+    varname_map: Optional[dict[str, Any]] = None,
+) -> tuple[list[Instruction], list[Instruction]]:
     template_code = bytecode_from_template(template, varname_map=varname_map)
     template_code.append(create_instruction("POP_TOP"))
 
@@ -90,7 +94,7 @@ def _bytecode_from_template_with_split(template, stack_index, varname_map=None):
     return template_code[: dummy_idx + 1], template_code[dummy_idx + 1 :]
 
 
-def _try_except_tf_mode_template(dummy, stack_var_name):
+def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None:
     # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source
     # on torch._dynamo.utils.
     global __import_torch_dot__dynamo_dot_utils
@@ -108,7 +112,9 @@ class ReenterWith:
     stack_index: int
     target_values: Optional[tuple[Any, ...]] = None
 
-    def try_except_torch_function_mode(self, code_options, cleanup: list[Instruction]):
+    def try_except_torch_function_mode(
+        self, code_options: dict[str, Any], cleanup: list[Instruction]
+    ) -> list[Instruction]:
         """
         Codegen based off of:
         try:
@@ -130,7 +136,9 @@ class ReenterWith:
 
     # If we do not want to destroy the stack, we can do the same thing as a
     # `SETUP_WITH` block, only that we store the context manager in a local_symbol
-    def try_finally(self, code_options, cleanup: list[Instruction]):
+    def try_finally(
+        self, code_options: dict[str, Any], cleanup: list[Instruction]
+    ) -> list[Instruction]:
         """
         Codegen based off of:
         load args
@@ -161,7 +169,7 @@ class ReenterWith:
             ]
         )
 
-        def _template(ctx, dummy):
+        def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None:
            ctx.__enter__()
             try:
                 dummy
@@ -174,7 +182,9 @@ class ReenterWith:
         cleanup[:] = epilogue + cleanup
         return create_ctx + setup_try_finally
 
-    def __call__(self, code_options, cleanup):
+    def __call__(
+        self, code_options: dict[str, Any], cleanup: list[Instruction]
+    ) -> tuple[list[Instruction], Optional[Instruction]]:
         """
         Codegen based off of:
         with ctx(args):
@@ -194,7 +204,7 @@ class ReenterWith:
             ]
         )
 
-        def _template(ctx, dummy):
+        def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None:
             with ctx:
                 dummy
 
@@ -242,7 +252,11 @@ class ResumeFunctionMetadata:
     block_target_offset_remap: Optional[dict[int, int]] = None
 
 
-def _filter_iter(l1, l2, cond):
+def _filter_iter(
+    l1: Iterable[Any],
+    l2: Iterable[Any],
+    cond: Callable[[Any, Any], bool],
+) -> list[Any]:
     """
     Two-pointer conditional filter.
     e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o)
@@ -261,7 +275,7 @@ def _filter_iter(l1, l2, cond):
     return res
 
 
-def _load_tuple_and_call(tup):
+def _load_tuple_and_call(tup: tuple[Any, ...]) -> list[Instruction]:
     insts: list[Instruction] = []
     _initial_push_null(insts)
     insts.extend(create_load_const(val) for val in tup)
@@ -274,7 +288,7 @@ class ContinueExecutionCache:
     generated_code_metadata = ExactWeakKeyDictionary()
 
     @classmethod
-    def lookup(cls, code, lineno, *key):
+    def lookup(cls, code: types.CodeType, lineno: int, *key: Any) -> types.CodeType:
         if code not in cls.cache:
             cls.cache[code] = {}
         key = tuple(key)
@@ -285,8 +299,8 @@ class ContinueExecutionCache:
     @classmethod
     def generate(
         cls,
-        code,
-        lineno,
+        code: types.CodeType,
+        lineno: int,
         offset: int,
         setup_fn_target_offsets: tuple[int, ...],  # only used in Python 3.11+
         nstack: int,
@@ -321,7 +335,9 @@ class ContinueExecutionCache:
         is_py311_plus = sys.version_info >= (3, 11)
         meta = ResumeFunctionMetadata(code)
 
-        def update(instructions: list[Instruction], code_options: dict[str, Any]):
+        def update(
+            instructions: list[Instruction], code_options: dict[str, Any]
+        ) -> None:
             meta.instructions = copy.deepcopy(instructions)
 
             args = [f"___stack{i}" for i in range(nstack)]
@@ -479,7 +495,7 @@ class ContinueExecutionCache:
                     inst.exn_tab_entry
                     and inst.exn_tab_entry.target in old_hook_target_remap
                 ):
-                    inst.exn_tab_entry.target = old_hook_target_remap[
+                    inst.exn_tab_entry.target = old_hook_target_remap[  # type: ignore[assignment]
                         inst.exn_tab_entry.target
                     ]
 
@@ -491,7 +507,7 @@ class ContinueExecutionCache:
         return new_code
 
     @staticmethod
-    def unreachable_codes(code_options) -> list[Instruction]:
+    def unreachable_codes(code_options: dict[str, Any]) -> list[Instruction]:
         """Codegen a `raise None` to make analysis work for unreachable code"""
         return [
             create_load_const(None),
@@ -500,8 +516,13 @@ class ContinueExecutionCache:
 
     @classmethod
     def generate_based_on_original_code_object(
-        cls, code, lineno, offset: int, setup_fn_target_offsets: tuple[int, ...], *args
-    ):
+        cls,
+        code: types.CodeType,
+        lineno: int,
+        offset: int,
+        setup_fn_target_offsets: tuple[int, ...],
+        *args: Any,
+    ) -> types.CodeType:
         """
         This handles the case of generating a resume into code generated
         to resume something else. We want to always generate starting
@@ -517,7 +538,7 @@ class ContinueExecutionCache:
 
         def find_new_offset(
             instructions: list[Instruction], code_options: dict[str, Any]
-        ):
+        ) -> None:
            nonlocal new_offset
             (target,) = (i for i in instructions if i.offset == offset)
             # match the functions starting at the last instruction as we have added a prefix
@@ -541,7 +562,7 @@ class ContinueExecutionCache:
 
         def remap_block_offsets(
            instructions: list[Instruction], code_options: dict[str, Any]
-        ):
+        ) -> None:
             # NOTE: each prefix block generates exactly one PUSH_EXC_INFO,
             # so we can tell which block a prefix PUSH_EXC_INFO belongs to,
             # by counting. Then we can use meta.prefix_block-target_offset_remap

--- a/torch/_dynamo/tensor_version_op.py
+++ b/torch/_dynamo/tensor_version_op.py
@@ -1,5 +1,3 @@
-# mypy: allow-untyped-defs
-
 """This module implements tensor version operations for Dynamo tracing.
 
 It provides primitives for handling tensor versioning during tracing, particularly in the
@@ -18,7 +16,11 @@ Why is this ok?
 Note this is similar to how no_grad is handled.
 """
 
+from contextlib import AbstractContextManager
+from typing import Any
+
 import torch
+from torch import SymInt
 from torch._prims import _make_prim, RETURN_TYPE
 from torch._subclasses import FakeTensorMode
 from torch._subclasses.functional_tensor import FunctionalTensorMode
@@ -33,13 +35,14 @@ _tensor_version = _make_prim(
 )
 
 
-@_tensor_version.py_impl(FakeTensorMode)
-def _tensor_version_fake(fake_mode, self_tensor):
+@_tensor_version.py_impl(FakeTensorMode)  # type: ignore[misc]
+def _tensor_version_fake(fake_mode: FakeTensorMode, self_tensor: Any) -> SymInt:
     """
     The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the
     `._version` into an unbacked SymInt so that we don't need to specialize on the `._version`
     of input tensors to the graph.
     """
+    assert fake_mode.shape_env is not None
     return fake_mode.shape_env.create_unbacked_symint()
 
 
@@ -53,11 +56,15 @@ _unsafe_set_version_counter = _make_prim(
 torch.fx.node.has_side_effect(_unsafe_set_version_counter)
 
 
-@_tensor_version.py_impl(FunctionalTensorMode)
-def _tensor_version_functional(mode, self):
+@_tensor_version.py_impl(FunctionalTensorMode)  # type: ignore[misc]
+def _tensor_version_functional(mode: FunctionalTensorMode, self: Any) -> int:
     return self._version
 
 
-@_unsafe_set_version_counter.py_impl(FunctionalTensorMode)
-def _unsafe_set_version_counter_functional(ctx, tensors, versions):
+@_unsafe_set_version_counter.py_impl(FunctionalTensorMode)  # type: ignore[misc]
+def _unsafe_set_version_counter_functional(
+    ctx: AbstractContextManager[Any],
+    tensors: tuple[torch.Tensor, ...],
+    versions: tuple[int, ...],
+) -> None:
     torch._C._autograd._unsafe_set_version_counter(tensors, versions)

--- a/torch/_dynamo/test_case.py
+++ b/torch/_dynamo/test_case.py
@@ -1,5 +1,3 @@
-# mypy: allow-untyped-defs
-
 """Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality.
 
 This module extends PyTorch's testing framework with Dynamo-specific testing capabilities.
@@ -18,7 +16,7 @@ import os
 import re
 import sys
 import unittest
-from typing import Union
+from typing import Any, Callable, Union
 
 import torch
 import torch.testing
@@ -151,7 +149,12 @@ class CPythonTestCase(TestCase):
     fail = unittest.TestCase.fail
     failureException = unittest.TestCase.failureException
 
-    def compile_fn(self, fn, backend, nopython):
+    def compile_fn(
+        self,
+        fn: Callable[..., Any],
+        backend: Union[str, Callable[..., Any]],
+        nopython: bool,
+    ) -> Callable[..., Any]:
         # We want to compile only the test function, excluding any setup code
         # from unittest
         method = getattr(self, self._testMethodName)
@@ -159,7 +162,7 @@ class CPythonTestCase(TestCase):
         setattr(self, self._testMethodName, method)
         return fn
 
-    def _dynamo_test_key(self):
+    def _dynamo_test_key(self) -> str:
         suffix = super()._dynamo_test_key()
         test_cls = self.__class__
         test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0]

--- a/torch/_dynamo/test_minifier_common.py
+++ b/torch/_dynamo/test_minifier_common.py
@@ -1,5 +1,3 @@
-# mypy: allow-untyped-defs
-
 """Common utilities for testing Dynamo's minifier functionality.
 
 This module provides the base infrastructure for running minification tests in Dynamo.
@@ -25,7 +23,8 @@ import subprocess
 import sys
 import tempfile
 import traceback
-from typing import Optional
+from collections.abc import Sequence
+from typing import Any, Optional, Union
 from unittest.mock import patch
 
 import torch
@@ -40,7 +39,7 @@ class MinifierTestResult:
     minifier_code: str
     repro_code: str
 
-    def _get_module(self, t):
+    def _get_module(self, t: str) -> str:
         match = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t)
         assert match is not None, "failed to find module"
         r = match.group(0)
@@ -48,7 +47,7 @@ class MinifierTestResult:
         r = re.sub(r"\n{3,}", "\n\n", r)
         return r.strip()
 
-    def get_exported_program_path(self):
+    def get_exported_program_path(self) -> Optional[str]:
         # Extract the exported program file path from AOTI minifier's repro.py
         # Regular expression pattern to match the file path
         pattern = r'torch\.export\.load\(\s*["\'](.*?)["\']\s*\)'
@@ -60,10 +59,10 @@ class MinifierTestResult:
             return file_path
         return None
 
-    def minifier_module(self):
+    def minifier_module(self) -> str:
         return self._get_module(self.minifier_code)
 
-    def repro_module(self):
+    def repro_module(self) -> str:
         return self._get_module(self.repro_code)
 
 
@@ -71,7 +70,7 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
     DEBUG_DIR = tempfile.mkdtemp()
 
     @classmethod
-    def setUpClass(cls):
+    def setUpClass(cls) -> None:
         super().setUpClass()
         if not os.path.exists(cls.DEBUG_DIR):
             cls.DEBUG_DIR = tempfile.mkdtemp()
@@ -94,14 +93,14 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
         )
 
     @classmethod
-    def tearDownClass(cls):
+    def tearDownClass(cls) -> None:
         if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1":
             shutil.rmtree(cls.DEBUG_DIR)
         else:
             print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}")
         cls._exit_stack.close()  # type: ignore[attr-defined]
 
-    def _gen_codegen_fn_patch_code(self, device, bug_type):
+    def _gen_codegen_fn_patch_code(self, device: str, bug_type: str) -> str:
         assert bug_type in ("compile_error", "runtime_error", "accuracy")
         return f"""\
 {torch._dynamo.config.codegen_config()}
@@ -109,7 +108,9 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase):
 torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r}
 """
 
-    def _maybe_subprocess_run(self, args, *, isolate, cwd=None):
+    def _maybe_subprocess_run(
+        self, args: Sequence[Any], *, isolate: bool, cwd: Optional[str] = None
+    ) -> subprocess.CompletedProcess[bytes]:
         if not isolate:
             assert len(args) >= 2, args
             assert args[0] == "python3", args
@@ -174,7 +175,9 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
     # Run `code` in a separate python process.
     # Returns the completed process state and the directory containing the
     # minifier launcher script, if `code` outputted it.
-    def _run_test_code(self, code, *, isolate):
+    def _run_test_code(
+        self, code: str, *, isolate: bool
+    ) -> tuple[subprocess.CompletedProcess[bytes], Union[str, Any]]:
         proc = self._maybe_subprocess_run(
             ["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR
         )
@@ -190,8 +193,13 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
 
     # Runs the minifier launcher script in `repro_dir`
     def _run_minifier_launcher(
-        self, repro_dir, isolate, *, minifier_args=(), repro_after=None
-    ):
+        self,
+        repro_dir: str,
+        isolate: bool,
+        *,
+        minifier_args: Sequence[Any] = (),
+        repro_after: Optional[str] = None,
+    ) -> tuple[subprocess.CompletedProcess[bytes], str]:
         self.assertIsNotNone(repro_dir)
         launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py"))
         with open(launch_file) as f:
@@ -212,7 +220,9 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
         return launch_proc, launch_code
 
     # Runs the repro script in `repro_dir`
-    def _run_repro(self, repro_dir, *, isolate=True):
+    def _run_repro(
+        self, repro_dir: str, *, isolate: bool = True
+    ) -> tuple[subprocess.CompletedProcess[bytes], str]:
         self.assertIsNotNone(repro_dir)
         repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py"))
         with open(repro_file) as f:
@@ -230,7 +240,7 @@ torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_
     # `run_code` is the code to run for the test case.
     # `patch_code` is the code to be patched in every generated file; usually
     # just use this to turn on bugs via the config
-    def _gen_test_code(self, run_code, repro_after, repro_level):
+    def _gen_test_code(self, run_code: str, repro_after: str, repro_level: int) -> str:
         repro_after_line = ""
         if repro_after == "aot_inductor":
             repro_after_line = (
@@ -263,7 +273,13 @@ torch._dynamo.config.debug_dir_root = "{_as_posix_path(self.DEBUG_DIR)}"
     # isolate=True only if the bug you're testing would otherwise
     # crash the process
     def _run_full_test(
-        self, run_code, repro_after, expected_error, *, isolate, minifier_args=()
-    ):
+        self,
+        run_code: str,
+        repro_after: str,
+        expected_error: Optional[str],
+        *,
+        isolate: bool,
+        minifier_args: Sequence[Any] = (),
+    ) -> Optional[MinifierTestResult]:
         if isolate:
             repro_level = 3