Fix mypy check errors in torch/_inductor/compile_fx.py (#107508)

Mypy errors in `compile_fx.py` blocked the merging of [PR1](https://github.com/pytorch/pytorch/pull/107127) and [PR2](https://github.com/pytorch/pytorch/pull/107448).
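
The fixes follow two common patterns for satisfying mypy: bind the `Optional` result of `torch._guards.TracingContext.get()` to a local variable and narrow it with an `is not None` check or an `assert`, and add error-code-specific `# type: ignore[...]` comments where mypy cannot follow dynamically installed attributes or signatures (the `config.patch` and `V.set_*` calls). The sketch below is a minimal, self-contained illustration of the narrowing pattern; `Context`, `get_context`, and the helper functions are hypothetical stand-ins, not PyTorch APIs.

```python
from typing import List, Optional


class Context:
    """Hypothetical stand-in for torch._guards.TracingContext (illustration only)."""

    def __init__(self) -> None:
        self.output_strides: Optional[List[int]] = []
        self.params_flat: Optional[List[object]] = []


_current: Optional[Context] = Context()


def get_context() -> Optional[Context]:
    """Shaped like TracingContext.get(): the return value may be None."""
    return _current


def collect_params() -> List[object]:
    # Bind the Optional result once and narrow the local variable.  Calling
    # get_context() twice (`if get_context(): get_context().params_flat`)
    # is rejected by mypy with [union-attr], because the second call is a
    # fresh Optional value that the first check does not narrow.
    ctx = get_context()
    if ctx is None:
        return []

    if ctx.output_strides:
        ctx.output_strides.clear()

    # When an invariant guarantees a value that the annotations cannot
    # express, suppress only that error code on that one line.
    return [p for p in ctx.params_flat if p is not None]  # type: ignore[union-attr]


def collect_params_strict() -> List[object]:
    # Where a missing context would be a hard bug, assert to narrow:
    # mypy treats ctx as non-Optional after the assert.
    ctx = get_context()
    assert ctx is not None
    assert ctx.params_flat is not None
    return list(ctx.params_flat)
```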

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107508
Approved by: https://github.com/ezyang
Author: FFFrog
Date: 2023-08-22 22:33:32 +00:00
Committed by: PyTorch MergeBot
Parent: 5025fb9213
Commit: 4d13422997


@@ -168,7 +168,7 @@ def count_bytes_inner(
shape_env = _shape_env_from_inputs(example_inputs)
graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
with V.set_graph_handler(graph), V.set_real_inputs(example_inputs):
with V.set_graph_handler(graph), V.set_real_inputs(example_inputs): # type: ignore[call-arg]
graph.run(*example_inputs)
num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
metrics.num_bytes_accessed += num_bytes
@@ -197,7 +197,7 @@ def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
kwargs_patched = {**kwargs, "cpp_wrapper": True}
return inner_compile(gm, example_inputs, **kwargs_patched)
else:
with config.patch(
with config.patch( # type: ignore[attr-defined]
{
"triton.store_cubin": True,
}
@@ -214,11 +214,6 @@ def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
compiled = inner_compile(
clone_graph(gm), example_inputs, **kwargs_patched
)
if (
torch._guards.TracingContext.get()
and torch._guards.TracingContext.get().output_strides
):
torch._guards.TracingContext.get().output_strides.clear()
def materialize(x):
if isinstance(x, (torch.SymInt, torch.SymFloat)):
@@ -228,10 +223,14 @@ def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
assert not isinstance(x, FakeTensor)
return x
if torch._guards.TracingContext.get():
tracing_context = torch._guards.TracingContext.get()
if tracing_context:
if tracing_context.output_strides:
tracing_context.output_strides.clear()
params_flat = [
param
for param in torch._guards.TracingContext.get().params_flat
for param in tracing_context.params_flat # type: ignore[union-attr]
if param is not None
]
real_inputs = [
@@ -532,12 +531,12 @@ def fx_codegen_and_compile(
# on node.meta["val"]. if in the future we rely on these being
# correct we will need to fix.
with V.set_fake_mode(fake_mode):
with V.set_fake_mode(fake_mode): # type: ignore[call-arg]
# has some issues with memory in training
post_grad_passes(gm, is_inference=is_inference)
V.debug.fx_graph_transformed(gm, example_inputs)
with V.set_fake_mode(fake_mode):
with V.set_fake_mode(fake_mode): # type: ignore[call-arg]
graph = GraphLowering(
gm,
shape_env=shape_env,
@@ -547,7 +546,7 @@ def fx_codegen_and_compile(
aot_mode=aot_mode,
user_visible_outputs=user_visible_outputs,
)
with V.set_graph_handler(graph):
with V.set_graph_handler(graph): # type: ignore[call-arg]
graph.run(*example_inputs)
context = torch._guards.TracingContext.get()
if context is not None and context.output_strides is not None:
@@ -557,7 +556,7 @@ def fx_codegen_and_compile(
for out in graph.graph_outputs:
if hasattr(out, "layout"):
context.output_strides.append(
tuple(
tuple( # type: ignore[arg-type]
V.graph.sizevars.size_hint(s) for s in out.layout.stride
)
)
@@ -576,7 +575,7 @@ def fx_codegen_and_compile(
device_types=graph.device_types,
device_idxs=graph.device_idxs,
mutated_inputs=graph.mutated_inputs,
mutated_input_idxs=graph.mutated_input_idxs,
mutated_input_idxs=set(graph.mutated_input_idxs),
)
return compiled_graph
@@ -913,7 +912,10 @@ def fw_compiler_freezing(
]
# constant params will be real tensors, not fake
params_flat = torch._guards.TracingContext.get().params_flat
tracing_context = torch._guards.TracingContext.get()
assert tracing_context is not None
params_flat = tracing_context.params_flat
assert params_flat is not None
for i in range(len(params_flat)):
if i not in preserved_arg_indices:
params_flat[i] = None
@@ -955,17 +957,17 @@ def compile_fx(
):
"""Main entrypoint to a compile given FX graph"""
if config_patches:
with config.patch(config_patches):
with config.patch(config_patches): # type: ignore[attr-defined]
return compile_fx(
model_,
example_inputs_,
# need extra layer of patching as backwards is compiled out of scope
inner_compile=config.patch(config_patches)(inner_compile),
inner_compile=config.patch(config_patches)(inner_compile), # type: ignore[attr-defined]
decompositions=decompositions,
)
if config.cpp_wrapper:
with config.patch(
with config.patch( # type: ignore[attr-defined]
{
"cpp_wrapper": False,
"triton.autotune_cublasLt": False,
@@ -973,7 +975,9 @@ def compile_fx(
# CudaWrapperCodeGen relies on kernel name to find the autotuned cubin file
"triton.unique_kernel_names": True,
}
), V.set_real_inputs(example_inputs_):
), V.set_real_inputs(
example_inputs_
): # type: ignore[call-arg]
return compile_fx(
model_,
example_inputs_,
@@ -1140,7 +1144,7 @@ def compile_fx(
torch._guards.TracingContext.get() or torch._guards.TracingContext(fake_mode)
)
with V.set_fake_mode(fake_mode), torch._guards.tracing(
with V.set_fake_mode(fake_mode), torch._guards.tracing( # type: ignore[call-arg]
tracing_context
), compiled_autograd.disable():
return aot_autograd(
@@ -1155,8 +1159,8 @@ def compile_fx(
# pass config dict back to user
def get_patched_config_dict(config_patches=None):
with config.patch(config_patches):
return config.get_config_copy()
with config.patch(config_patches): # type: ignore[attr-defined]
return config.get_config_copy() # type: ignore[attr-defined]
def _shape_env_from_inputs(inputs: List[torch.Tensor]):