[dynamo] Enable typechecking for testing.py (#112129)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112129
Approved by: https://github.com/Skylion007
ghstack dependencies: #111894, #111992, #112031, #112127, #112128
This commit is contained in:
Jez Ng
2023-10-26 20:01:01 -07:00
committed by PyTorch MergeBot
parent d3bf6803b6
commit a26cb0a3f2
3 changed files with 17 additions and 12 deletions

View File

@ -186,6 +186,7 @@ include_patterns = [
'torch/_dynamo/funcname_cache.py',
'torch/_dynamo/convert_frame.py',
'torch/_dynamo/symbolic_convert.py',
'torch/_dynamo/testing.py',
'torch/_dynamo/types.py',
'torch/_dynamo/output_graph.py',
'torch/_dynamo/guards.py',

View File

@ -41,6 +41,9 @@ ignore_errors = True
[mypy-torch.fb.*]
ignore_missing_imports = True
[mypy-torch_xla.*]
ignore_missing_imports = True
[mypy-torchvision.*]
ignore_missing_imports = True

View File

@ -8,9 +8,10 @@ import re
import sys
import types
import unittest
from typing import Sequence, Union
from typing import List, Optional, Sequence, Union
from unittest.mock import patch
np: Optional[types.ModuleType] = None
try:
import numpy as np
except ModuleNotFoundError:
@ -62,7 +63,7 @@ def named_buffers_for_optimized_module(mod):
return mod._orig_mod.named_buffers
def remove_optimized_module_prefix(name):
def remove_optimized_module_prefix(name) -> str:
return re.sub(r"^_orig_mod[.]", "", name)
@ -140,21 +141,21 @@ def reduce_to_scalar_loss(out):
raise NotImplementedError("Don't know how to reduce", type(out))
def debug_dir():
def debug_dir() -> str:
path = os.path.join(os.path.dirname(__file__), "../debug")
if not os.path.exists(path):
os.mkdir(path)
return path
def debug_dump(name, code: types.CodeType, extra=""):
def debug_dump(name, code: types.CodeType, extra="") -> None:
with open(os.path.join(debug_dir(), name), "w") as fd:
fd.write(
f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n"
)
def debug_insert_nops(frame, cache_size, hooks, _):
def debug_insert_nops(frame, cache_size, hooks, _) -> Optional[GuardedCode]:
"""used to debug jump updates"""
def insert_nops(instructions, code_options):
@ -187,7 +188,7 @@ class CompileCounter:
self.frame_count = 0
self.op_count = 0
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
self.frame_count += 1
for node in gm.graph.nodes:
if "call" in node.op:
@ -206,7 +207,7 @@ class CompileCounterWithBackend:
self.backend = backend
self.graphs = []
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
from .backends.registry import lookup_backend
self.frame_count += 1
@ -223,21 +224,21 @@ class EagerAndRecordGraphs:
def __init__(self):
self.graphs = []
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
self.graphs.append(gm)
return gm
def strip_comment(code):
def strip_comment(code) -> str:
code = str(code)
return re.sub(r"(?m)^ *#.*\n?", "", code)
def remove_trailing_space(code):
def remove_trailing_space(code) -> str:
return "\n".join([line.rstrip() for line in code.split("\n")])
def normalize_gm(gm_str):
def normalize_gm(gm_str) -> str:
# strip comments as comments have path to files which may differ from
# system to system.
return remove_trailing_space(strip_comment(gm_str))
@ -252,7 +253,7 @@ def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None)
expected = CompileCounter()
try:
gm = torch.fx.symbolic_trace(fn)
expected(gm)
expected(gm) # type: ignore[call-arg] # FIXME: https://github.com/pytorch/pytorch/issues/112230
print("\nfx.symbolic_trace graph:")
gm.graph.print_tabular()
expected_ops = expected.op_count