pyfmt lint more export files (#155783)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155783
Approved by: https://github.com/Skylion007
ghstack dependencies: #155782
This commit is contained in:
Laith Sakka
2025-06-11 23:41:50 -07:00
committed by PyTorch MergeBot
parent 86b1116f22
commit 2903e5ad3c
5 changed files with 102 additions and 47 deletions

View File

@ -1304,17 +1304,6 @@ exclude_patterns = [
'torch/_export/db/examples/type_reflection_method.py',
'torch/_export/db/gen_example.py',
'torch/_export/db/logging.py',
'torch/_export/error.py',
'torch/_export/exported_program.py',
'torch/_export/pass_base.py',
'torch/_export/pass_infra/__init__.py',
'torch/_export/pass_infra/node_metadata.py',
'torch/_export/pass_infra/proxy_value.py',
'torch/_export/passes/__init__.py',
'torch/_export/passes/add_runtime_assertions_for_constraints_pass.py',
'torch/_export/passes/const_prop_pass.py',
'torch/_export/passes/functionalize_side_effectful_ops_pass.py',
'torch/_export/passes/replace_sym_size_ops_pass.py',
'torch/testing/_internal/__init__.py',
'torch/testing/_internal/autocast_test_lists.py',
'torch/testing/_internal/autograd_function_db.py',

View File

@ -6,20 +6,23 @@ from contextlib import nullcontext
from typing import Any, Callable, Optional, Union
import torch
from torch._higher_order_ops.map import _unstack_pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._higher_order_ops.map import _unstack_pytree
from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.experimental.symbolic_shapes import (
compute_unbacked_bindings,
PropagateUnbackedSymInts,
)
from torch.fx.graph import CodeGen
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils import _pytree as pytree
from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings
__all__ = ["_ExportPassBaseDeprecatedDoNotUse"]
@ -56,9 +59,10 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
def _create_dummy_node_metadata():
return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
class ExportTracer(PythonKeyTracer):
def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
def __init__(
self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen
) -> None:
super().__init__()
self.callback = callback
self.root = torch.nn.Module()
@ -92,12 +96,24 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
return node
def set_metadata(
self, node: torch.fx.Node, value: Argument,
self,
node: torch.fx.Node,
value: Argument,
) -> None:
# propagate the fake tensor or sym nodes
def make_val(
x: Argument,
) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
) -> Union[
FakeTensor,
torch.SymInt,
torch.SymFloat,
torch.SymBool,
int,
float,
bool,
str,
None,
]:
if isinstance(x, FakeTensor):
return x
elif isinstance(x, torch.Tensor):
@ -124,7 +140,18 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
)
fake_tensor = None
return fake_tensor
elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
elif isinstance(
x,
(
torch.SymInt,
torch.SymFloat,
torch.SymBool,
int,
float,
bool,
str,
),
):
return x
else:
return None
@ -153,7 +180,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)
class ExportInterpreter(fx.Interpreter):
def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
def __init__(
self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule
) -> None:
super().__init__(gm)
self.callback = callback
self.node: torch.fx.Node = next(iter(gm.graph.nodes))
@ -186,13 +215,19 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
if target == operator.getitem:
value, key = args
return self.callback.call_getitem(value, key, meta)
elif getattr(target, "__module__", None) in {"_operator", "builtins", "math"}:
elif getattr(target, "__module__", None) in {
"_operator",
"builtins",
"math",
}:
assert callable(target)
return self.callback.call_sym(target, args, meta)
elif target in _TORCH_SYM_OPS:
assert callable(target)
return self.callback.call_sym(target, args, meta)
elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
elif isinstance(
target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
):
return self.callback.call_operator(
target,
args,
@ -269,7 +304,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
if isinstance(target, torch._ops.OpOverload):
name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)
res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
res_proxy = self.tracer.create_proxy(
kind, target, args_proxy, kwargs_proxy, name=name
)
res_proxy.node.meta.update(meta.data)
if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):

View File

@ -1,12 +1,13 @@
# pyre-strict
from typing import Union, Generic
from collections.abc import Iterator, Iterable
from collections.abc import Iterable, Iterator
from typing import Generic, TypeVar, Union
import torch
from typing import TypeVar
_T = TypeVar("_T")
class ProxyValue(Generic[_T]):
# pyre-ignore
def __init__(self, data: Iterable[_T], proxy: Union[torch.fx.Proxy, torch.fx.Node]):

View File

@ -9,10 +9,11 @@ import sympy
import torch
import torch.fx
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.numbers import int_oo
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import ValueRanges
__all__ = ["InputDim"]
@ -30,9 +31,7 @@ def _convert_to_int(val):
return -math.inf
if isinstance(val, sympy.Integer):
return int(val)
raise RuntimeError(
"Export constraints cannot be non-integer expressions"
)
raise RuntimeError("Export constraints cannot be non-integer expressions")
def _convert_range_to_int(range: ValueRanges):
@ -55,10 +54,14 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
def _assert_range_constraint(self, node, lower, upper, assert_msg):
last_node = node
if lower > -math.inf:
last_node = self._insert_assert_async(last_node, operator.ge, node, lower, assert_msg)
last_node = self._insert_assert_async(
last_node, operator.ge, node, lower, assert_msg
)
if upper < math.inf:
last_node = self._insert_assert_async(last_node, operator.le, node, upper, assert_msg)
last_node = self._insert_assert_async(
last_node, operator.le, node, upper, assert_msg
)
def _insert_assert_async(self, last_node, op, lower, upper, assert_msg):
"""
@ -70,7 +73,9 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
with graph.inserting_after(last_node):
cmp = graph.call_function(op, (lower, upper), {})
with graph.inserting_after(cmp):
cmp_tensor = graph.call_function(torch.ops.aten.scalar_tensor.default, (cmp,), {})
cmp_tensor = graph.call_function(
torch.ops.aten.scalar_tensor.default, (cmp,), {}
)
with graph.inserting_after(cmp_tensor):
assert_async = graph.call_function(
torch.ops.aten._assert_async.msg,
@ -111,7 +116,9 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
symbol = val.node.expr
if symbol in self.existing_inline_assertions:
return call_backs, messages
if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol):
if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(
symbol
):
if symbol in self._asserts_generated_unbacked_symbols:
return call_backs, messages
# We only care about unbacked symints for these inline
@ -120,7 +127,11 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
min_val, max_val = _convert_range_to_int(constraint)
assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
call_backs.append(
partial(self._assert_range_constraint, lower=min_val, upper=max_val)
partial(
self._assert_range_constraint,
lower=min_val,
upper=max_val,
)
)
messages.append(assert_msg)
self._asserts_generated_unbacked_symbols.add(symbol)
@ -129,6 +140,7 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
for i, sym in enumerate(val.shape):
cbs, msgs = add_assertions(sym)
for cb, msg in zip(cbs, msgs):
def sym_size_cb(node, assert_msg, dim):
with node.graph.inserting_after(node):
dim_node = module.graph.call_function(
@ -137,6 +149,7 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
{},
)
cb(node=dim_node, assert_msg=assert_msg)
call_backs.append(partial(sym_size_cb, dim=i))
messages.append(f".shape[{i}]" + msg)
return call_backs, messages
@ -149,12 +162,18 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
# Sometimes this pass would return a wrong graph where we have mismatched
# node names in signature. Before we fix it, let's just skip it.
if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass:
if (
self.counter == 0
and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass
):
return PassResult(graph_module, False)
# Populate the stack trace with dummy vals to respect IR
for node in graph_module.graph.nodes:
if not node.meta.get("stack_trace", None) and node.op not in ["placeholder", "output"]:
if not node.meta.get("stack_trace", None) and node.op not in [
"placeholder",
"output",
]:
node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))
return PassResult(graph_module, True)
@ -179,10 +198,10 @@ def _get_existing_inline_assertions(
compare_arg = node.args[0]
if not (
isinstance(compare_arg, torch.fx.Node) and
compare_arg.op == "call_function" and
compare_arg.target in (operator.le, operator.ge) and
len(compare_arg.args) == 2
isinstance(compare_arg, torch.fx.Node)
and compare_arg.op == "call_function"
and compare_arg.target in (operator.le, operator.ge)
and len(compare_arg.args) == 2
):
continue
@ -191,9 +210,9 @@ def _get_existing_inline_assertions(
def maybe_get_symint(x):
if (
isinstance(x, torch.fx.Node) and
"val" in x.meta and
isinstance(x.meta["val"], torch.SymInt)
isinstance(x, torch.fx.Node)
and "val" in x.meta
and isinstance(x.meta["val"], torch.SymInt)
):
return x.meta["val"].node.expr
return x
@ -214,9 +233,13 @@ def _get_existing_inline_assertions(
continue
if symint not in range_constraints:
raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}")
raise RuntimeError(
f"Unable to find symint {symint} in {range_constraints}"
)
previous_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf))
previous_range = existing_inline_assertions.get(
symint, ValueRanges(-math.inf, math.inf)
)
if symint is lhs:
bounds = ValueRanges(-math.inf, scalar)

View File

@ -2,11 +2,16 @@ import copy
from typing import Optional
import torch
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
from torch._export.pass_base import (
_ExportPassBaseDeprecatedDoNotUse,
Argument,
PassResult,
)
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue
from torch._ops import OpOverload
aten = torch.ops.aten
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {