mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Enable typechecking for testing.py (#112129)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112129 Approved by: https://github.com/Skylion007 ghstack dependencies: #111894, #111992, #112031, #112127, #112128
This commit is contained in:
@ -186,6 +186,7 @@ include_patterns = [
|
||||
'torch/_dynamo/funcname_cache.py',
|
||||
'torch/_dynamo/convert_frame.py',
|
||||
'torch/_dynamo/symbolic_convert.py',
|
||||
'torch/_dynamo/testing.py',
|
||||
'torch/_dynamo/types.py',
|
||||
'torch/_dynamo/output_graph.py',
|
||||
'torch/_dynamo/guards.py',
|
||||
|
@ -41,6 +41,9 @@ ignore_errors = True
|
||||
[mypy-torch.fb.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-torch_xla.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-torchvision.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
|
@ -8,9 +8,10 @@ import re
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from typing import Sequence, Union
|
||||
from typing import List, Optional, Sequence, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
np: Optional[types.ModuleType] = None
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
@ -62,7 +63,7 @@ def named_buffers_for_optimized_module(mod):
|
||||
return mod._orig_mod.named_buffers
|
||||
|
||||
|
||||
def remove_optimized_module_prefix(name):
|
||||
def remove_optimized_module_prefix(name) -> str:
|
||||
return re.sub(r"^_orig_mod[.]", "", name)
|
||||
|
||||
|
||||
@ -140,21 +141,21 @@ def reduce_to_scalar_loss(out):
|
||||
raise NotImplementedError("Don't know how to reduce", type(out))
|
||||
|
||||
|
||||
def debug_dir():
|
||||
def debug_dir() -> str:
|
||||
path = os.path.join(os.path.dirname(__file__), "../debug")
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
return path
|
||||
|
||||
|
||||
def debug_dump(name, code: types.CodeType, extra=""):
|
||||
def debug_dump(name, code: types.CodeType, extra="") -> None:
|
||||
with open(os.path.join(debug_dir(), name), "w") as fd:
|
||||
fd.write(
|
||||
f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n"
|
||||
)
|
||||
|
||||
|
||||
def debug_insert_nops(frame, cache_size, hooks, _):
|
||||
def debug_insert_nops(frame, cache_size, hooks, _) -> Optional[GuardedCode]:
|
||||
"""used to debug jump updates"""
|
||||
|
||||
def insert_nops(instructions, code_options):
|
||||
@ -187,7 +188,7 @@ class CompileCounter:
|
||||
self.frame_count = 0
|
||||
self.op_count = 0
|
||||
|
||||
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
|
||||
def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
||||
self.frame_count += 1
|
||||
for node in gm.graph.nodes:
|
||||
if "call" in node.op:
|
||||
@ -206,7 +207,7 @@ class CompileCounterWithBackend:
|
||||
self.backend = backend
|
||||
self.graphs = []
|
||||
|
||||
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
|
||||
def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
||||
from .backends.registry import lookup_backend
|
||||
|
||||
self.frame_count += 1
|
||||
@ -223,21 +224,21 @@ class EagerAndRecordGraphs:
|
||||
def __init__(self):
|
||||
self.graphs = []
|
||||
|
||||
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
|
||||
def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
|
||||
self.graphs.append(gm)
|
||||
return gm
|
||||
|
||||
|
||||
def strip_comment(code):
|
||||
def strip_comment(code) -> str:
|
||||
code = str(code)
|
||||
return re.sub(r"(?m)^ *#.*\n?", "", code)
|
||||
|
||||
|
||||
def remove_trailing_space(code):
|
||||
def remove_trailing_space(code) -> str:
|
||||
return "\n".join([line.rstrip() for line in code.split("\n")])
|
||||
|
||||
|
||||
def normalize_gm(gm_str):
|
||||
def normalize_gm(gm_str) -> str:
|
||||
# strip comments as comments have path to files which may differ from
|
||||
# system to system.
|
||||
return remove_trailing_space(strip_comment(gm_str))
|
||||
@ -252,7 +253,7 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None)
|
||||
expected = CompileCounter()
|
||||
try:
|
||||
gm = torch.fx.symbolic_trace(fn)
|
||||
expected(gm)
|
||||
expected(gm) # type: ignore[call-arg] # FIXME: https://github.com/pytorch/pytorch/issues/112230
|
||||
print("\nfx.symbolic_trace graph:")
|
||||
gm.graph.print_tabular()
|
||||
expected_ops = expected.op_count
|
||||
|
Reference in New Issue
Block a user