Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
Fix mypy check errors in torch/_inductor/compile_fx.py (#107508)

Mypy errors in `compile_fx.py` blocked the merging of [PR1](https://github.com/pytorch/pytorch/pull/107127) and [PR2](https://github.com/pytorch/pytorch/pull/107448).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107508
Approved by: https://github.com/ezyang
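The diff below mostly adds error-code-scoped `# type: ignore[...]` comments (for example `[attr-defined]` on `config.patch`, which is presumably attached to the config module at runtime and therefore invisible to mypy) and restructures a few `Optional` accesses so mypy can narrow them. As a minimal, self-contained sketch of the scoped-ignore pattern (the `Config` class below is hypothetical, not part of PyTorch):

```python
class Config:
    """Hypothetical stand-in for an object whose attributes are added at runtime."""

    debug = False


cfg = Config()
# The attribute is attached dynamically, so mypy reports [attr-defined] on both
# lines; the scoped comment silences only that error code, not other mistakes.
cfg.patch_enabled = True  # type: ignore[attr-defined]
if cfg.patch_enabled:  # type: ignore[attr-defined]
    print("patching enabled")
```

A bare `# type: ignore` would also pass, but it hides every other error on the line, which is why the commit consistently spells out the error code.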
@@ -168,7 +168,7 @@ def count_bytes_inner(
     shape_env = _shape_env_from_inputs(example_inputs)
 
     graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
-    with V.set_graph_handler(graph), V.set_real_inputs(example_inputs):
+    with V.set_graph_handler(graph), V.set_real_inputs(example_inputs):  # type: ignore[call-arg]
         graph.run(*example_inputs)
         num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
         metrics.num_bytes_accessed += num_bytes
@@ -197,7 +197,7 @@ def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
             kwargs_patched = {**kwargs, "cpp_wrapper": True}
             return inner_compile(gm, example_inputs, **kwargs_patched)
         else:
-            with config.patch(
+            with config.patch(  # type: ignore[attr-defined]
                 {
                     "triton.store_cubin": True,
                 }
@@ -214,11 +214,6 @@ def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
                 compiled = inner_compile(
                     clone_graph(gm), example_inputs, **kwargs_patched
                 )
-                if (
-                    torch._guards.TracingContext.get()
-                    and torch._guards.TracingContext.get().output_strides
-                ):
-                    torch._guards.TracingContext.get().output_strides.clear()
 
                 def materialize(x):
                     if isinstance(x, (torch.SymInt, torch.SymFloat)):
@@ -228,10 +223,14 @@ def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
                     assert not isinstance(x, FakeTensor)
                     return x
 
-                if torch._guards.TracingContext.get():
+                tracing_context = torch._guards.TracingContext.get()
+                if tracing_context:
+                    if tracing_context.output_strides:
+                        tracing_context.output_strides.clear()
+
                     params_flat = [
                         param
-                        for param in torch._guards.TracingContext.get().params_flat
+                        for param in tracing_context.params_flat  # type: ignore[union-attr]
                         if param is not None
                     ]
                     real_inputs = [
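The hunk above binds `torch._guards.TracingContext.get()` to a local `tracing_context` and branches on it once, instead of calling `get()` repeatedly. Besides avoiding redundant lookups, this lets mypy narrow the `Optional` result inside the `if`, so most later accesses no longer need ignores. A minimal sketch of the narrowing, with a hypothetical `Context` class standing in for `TracingContext`:

```python
from typing import List, Optional, Tuple


class Context:
    """Hypothetical stand-in for torch._guards.TracingContext."""

    def __init__(self) -> None:
        self.output_strides: List[Tuple[int, ...]] = [(1, 2)]


_current: Optional[Context] = None


def get() -> Optional[Context]:
    # May return None when no context is active, mirroring TracingContext.get().
    return _current


# Calling get() twice yields two independent Optional[Context] values, so mypy
# flags attribute access on each; binding once lets `if ctx:` narrow the type.
ctx = get()
if ctx:
    if ctx.output_strides:
        ctx.output_strides.clear()  # mypy knows ctx is Context here
```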
@@ -532,12 +531,12 @@ def fx_codegen_and_compile(
     # on node.meta["val"]. if in the future we rely on these being
     # correct we will need to fix.
 
-    with V.set_fake_mode(fake_mode):
+    with V.set_fake_mode(fake_mode):  # type: ignore[call-arg]
         # has some issues with memory in training
         post_grad_passes(gm, is_inference=is_inference)
         V.debug.fx_graph_transformed(gm, example_inputs)
 
-    with V.set_fake_mode(fake_mode):
+    with V.set_fake_mode(fake_mode):  # type: ignore[call-arg]
         graph = GraphLowering(
             gm,
             shape_env=shape_env,
@@ -547,7 +546,7 @@ def fx_codegen_and_compile(
             aot_mode=aot_mode,
             user_visible_outputs=user_visible_outputs,
         )
-        with V.set_graph_handler(graph):
+        with V.set_graph_handler(graph):  # type: ignore[call-arg]
             graph.run(*example_inputs)
             context = torch._guards.TracingContext.get()
             if context is not None and context.output_strides is not None:
@@ -557,7 +556,7 @@ def fx_codegen_and_compile(
                 for out in graph.graph_outputs:
                     if hasattr(out, "layout"):
                         context.output_strides.append(
-                            tuple(
+                            tuple(  # type: ignore[arg-type]
                                 V.graph.sizevars.size_hint(s) for s in out.layout.stride
                             )
                         )
@@ -576,7 +575,7 @@ def fx_codegen_and_compile(
             device_types=graph.device_types,
             device_idxs=graph.device_idxs,
             mutated_inputs=graph.mutated_inputs,
-            mutated_input_idxs=graph.mutated_input_idxs,
+            mutated_input_idxs=set(graph.mutated_input_idxs),
         )
         return compiled_graph
 
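In the hunk above, the `mutated_input_idxs` argument is wrapped in `set(...)`, presumably so it matches a parameter annotated as a set rather than whatever iterable `graph.mutated_input_idxs` happens to be. A small sketch of the kind of `[arg-type]` mismatch this sidesteps (the function and names are illustrative, not Inductor's actual signatures):

```python
from typing import List, Set


def record_mutations(mutated_input_idxs: Set[int]) -> None:
    # Expects a set; duplicates are meaningless for index bookkeeping.
    print(sorted(mutated_input_idxs))


collected: List[int] = [0, 2, 2]

# Passing `collected` directly would be an [arg-type] error under this
# annotation; converting to a set satisfies mypy and de-duplicates indices.
record_mutations(set(collected))
```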
@@ -913,7 +912,10 @@ def fw_compiler_freezing(
     ]
 
     # constant params will be real tensors, not fake
-    params_flat = torch._guards.TracingContext.get().params_flat
+    tracing_context = torch._guards.TracingContext.get()
+    assert tracing_context is not None
+    params_flat = tracing_context.params_flat
+    assert params_flat is not None
     for i in range(len(params_flat)):
         if i not in preserved_arg_indices:
             params_flat[i] = None
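Here the old one-liner chained through an `Optional` return value (`TracingContext.get().params_flat`), which mypy rejects because `get()` may return `None`. The new version binds each step and asserts it is not `None`, which both documents the invariant and narrows the types. A minimal sketch with hypothetical stand-ins:

```python
from typing import List, Optional


class Context:
    """Hypothetical stand-in for TracingContext."""

    def __init__(self) -> None:
        self.params_flat: Optional[List[Optional[str]]] = ["w", "b", "scale"]


def get() -> Optional[Context]:
    return Context()


tracing_context = get()
assert tracing_context is not None  # narrows Optional[Context] to Context
params_flat = tracing_context.params_flat
assert params_flat is not None  # narrows Optional[List[...]] to List[...]

preserved_arg_indices = {0}
for i in range(len(params_flat)):
    if i not in preserved_arg_indices:
        params_flat[i] = None
print(params_flat)  # ['w', None, None]
```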
@@ -955,17 +957,17 @@ def compile_fx(
 ):
     """Main entrypoint to a compile given FX graph"""
     if config_patches:
-        with config.patch(config_patches):
+        with config.patch(config_patches):  # type: ignore[attr-defined]
             return compile_fx(
                 model_,
                 example_inputs_,
                 # need extra layer of patching as backwards is compiled out of scope
-                inner_compile=config.patch(config_patches)(inner_compile),
+                inner_compile=config.patch(config_patches)(inner_compile),  # type: ignore[attr-defined]
                 decompositions=decompositions,
             )
 
     if config.cpp_wrapper:
-        with config.patch(
+        with config.patch(  # type: ignore[attr-defined]
             {
                 "cpp_wrapper": False,
                 "triton.autotune_cublasLt": False,
@@ -973,7 +975,9 @@ def compile_fx(
                 # CudaWrapperCodeGen relies on kernel name to find the autotuned cubin file
                 "triton.unique_kernel_names": True,
             }
-        ), V.set_real_inputs(example_inputs_):
+        ), V.set_real_inputs(
+            example_inputs_
+        ):  # type: ignore[call-arg]
             return compile_fx(
                 model_,
                 example_inputs_,
@@ -1140,7 +1144,7 @@ def compile_fx(
         torch._guards.TracingContext.get() or torch._guards.TracingContext(fake_mode)
     )
 
-    with V.set_fake_mode(fake_mode), torch._guards.tracing(
+    with V.set_fake_mode(fake_mode), torch._guards.tracing(  # type: ignore[call-arg]
         tracing_context
     ), compiled_autograd.disable():
         return aot_autograd(
@@ -1155,8 +1159,8 @@ def compile_fx(
 
 # pass config dict back to user
 def get_patched_config_dict(config_patches=None):
-    with config.patch(config_patches):
-        return config.get_config_copy()
+    with config.patch(config_patches):  # type: ignore[attr-defined]
+        return config.get_config_copy()  # type: ignore[attr-defined]
 
 
 def _shape_env_from_inputs(inputs: List[torch.Tensor]):
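One hedged way to confirm the file type-checks after these edits (assuming mypy is installed and the command is run from the PyTorch repo root so the repository's mypy configuration is picked up) is to drive mypy through its Python API:

```python
# Sketch: run mypy on the touched file and report its exit status.
# Assumes mypy is installed and this is executed from the repo root.
from mypy import api

stdout, stderr, exit_status = api.run(["torch/_inductor/compile_fx.py"])
print(stdout)
print("mypy exit status:", exit_status)  # 0 means no type errors reported
```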