pyfmt lint more export files (#155783)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155783
Approved by: https://github.com/Skylion007
ghstack dependencies: #155782
This commit is contained in:
Laith Sakka
2025-06-11 23:41:50 -07:00
committed by PyTorch MergeBot
parent 86b1116f22
commit 2903e5ad3c
5 changed files with 102 additions and 47 deletions

View File

@ -1304,17 +1304,6 @@ exclude_patterns = [
'torch/_export/db/examples/type_reflection_method.py', 'torch/_export/db/examples/type_reflection_method.py',
'torch/_export/db/gen_example.py', 'torch/_export/db/gen_example.py',
'torch/_export/db/logging.py', 'torch/_export/db/logging.py',
'torch/_export/error.py',
'torch/_export/exported_program.py',
'torch/_export/pass_base.py',
'torch/_export/pass_infra/__init__.py',
'torch/_export/pass_infra/node_metadata.py',
'torch/_export/pass_infra/proxy_value.py',
'torch/_export/passes/__init__.py',
'torch/_export/passes/add_runtime_assertions_for_constraints_pass.py',
'torch/_export/passes/const_prop_pass.py',
'torch/_export/passes/functionalize_side_effectful_ops_pass.py',
'torch/_export/passes/replace_sym_size_ops_pass.py',
'torch/testing/_internal/__init__.py', 'torch/testing/_internal/__init__.py',
'torch/testing/_internal/autocast_test_lists.py', 'torch/testing/_internal/autocast_test_lists.py',
'torch/testing/_internal/autograd_function_db.py', 'torch/testing/_internal/autograd_function_db.py',

View File

@ -6,20 +6,23 @@ from contextlib import nullcontext
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import torch import torch
from torch._higher_order_ops.map import _unstack_pytree
from torch import fx from torch import fx
from torch._dispatch.python import enable_python_dispatcher from torch._dispatch.python import enable_python_dispatcher
from torch._export.pass_infra.node_metadata import NodeMetadata from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue from torch._export.pass_infra.proxy_value import ProxyValue
from torch._higher_order_ops.map import _unstack_pytree
from torch._subclasses import FakeTensor, UnsupportedFakeTensorException from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
from torch._subclasses.fake_tensor import FakeTensorMode from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import traceback as fx_traceback from torch.fx import traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import PythonKeyTracer from torch.fx.experimental.proxy_tensor import PythonKeyTracer
from torch.fx.experimental.symbolic_shapes import (
compute_unbacked_bindings,
PropagateUnbackedSymInts,
)
from torch.fx.graph import CodeGen from torch.fx.graph import CodeGen
from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.utils import _pytree as pytree from torch.utils import _pytree as pytree
from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings
__all__ = ["_ExportPassBaseDeprecatedDoNotUse"] __all__ = ["_ExportPassBaseDeprecatedDoNotUse"]
@ -56,9 +59,10 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
def _create_dummy_node_metadata(): def _create_dummy_node_metadata():
return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))}) return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
class ExportTracer(PythonKeyTracer): class ExportTracer(PythonKeyTracer):
def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None: def __init__(
self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen
) -> None:
super().__init__() super().__init__()
self.callback = callback self.callback = callback
self.root = torch.nn.Module() self.root = torch.nn.Module()
@ -92,12 +96,24 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
return node return node
def set_metadata( def set_metadata(
self, node: torch.fx.Node, value: Argument, self,
node: torch.fx.Node,
value: Argument,
) -> None: ) -> None:
# propagate the fake tensor or sym nodes # propagate the fake tensor or sym nodes
def make_val( def make_val(
x: Argument, x: Argument,
) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]: ) -> Union[
FakeTensor,
torch.SymInt,
torch.SymFloat,
torch.SymBool,
int,
float,
bool,
str,
None,
]:
if isinstance(x, FakeTensor): if isinstance(x, FakeTensor):
return x return x
elif isinstance(x, torch.Tensor): elif isinstance(x, torch.Tensor):
@ -124,7 +140,18 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
) )
fake_tensor = None fake_tensor = None
return fake_tensor return fake_tensor
elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)): elif isinstance(
x,
(
torch.SymInt,
torch.SymFloat,
torch.SymBool,
int,
float,
bool,
str,
),
):
return x return x
else: else:
return None return None
@ -153,7 +180,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value) node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)
class ExportInterpreter(fx.Interpreter): class ExportInterpreter(fx.Interpreter):
def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None: def __init__(
self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule
) -> None:
super().__init__(gm) super().__init__(gm)
self.callback = callback self.callback = callback
self.node: torch.fx.Node = next(iter(gm.graph.nodes)) self.node: torch.fx.Node = next(iter(gm.graph.nodes))
@ -186,13 +215,19 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
if target == operator.getitem: if target == operator.getitem:
value, key = args value, key = args
return self.callback.call_getitem(value, key, meta) return self.callback.call_getitem(value, key, meta)
elif getattr(target, "__module__", None) in {"_operator", "builtins", "math"}: elif getattr(target, "__module__", None) in {
"_operator",
"builtins",
"math",
}:
assert callable(target) assert callable(target)
return self.callback.call_sym(target, args, meta) return self.callback.call_sym(target, args, meta)
elif target in _TORCH_SYM_OPS: elif target in _TORCH_SYM_OPS:
assert callable(target) assert callable(target)
return self.callback.call_sym(target, args, meta) return self.callback.call_sym(target, args, meta)
elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): elif isinstance(
target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
):
return self.callback.call_operator( return self.callback.call_operator(
target, target,
args, args,
@ -269,7 +304,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
if isinstance(target, torch._ops.OpOverload): if isinstance(target, torch._ops.OpOverload):
name = self.tracer.graph._target_to_str(target.overloadpacket.__name__) name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)
res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name) res_proxy = self.tracer.create_proxy(
kind, target, args_proxy, kwargs_proxy, name=name
)
res_proxy.node.meta.update(meta.data) res_proxy.node.meta.update(meta.data)
if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env): if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
if symbol_to_path := compute_unbacked_bindings(shape_env, res_data): if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):

View File

@ -1,12 +1,13 @@
# pyre-strict # pyre-strict
from typing import Union, Generic from collections.abc import Iterable, Iterator
from collections.abc import Iterator, Iterable from typing import Generic, TypeVar, Union
import torch import torch
from typing import TypeVar
_T = TypeVar("_T") _T = TypeVar("_T")
class ProxyValue(Generic[_T]): class ProxyValue(Generic[_T]):
# pyre-ignore # pyre-ignore
def __init__(self, data: Iterable[_T], proxy: Union[torch.fx.Proxy, torch.fx.Node]): def __init__(self, data: Iterable[_T], proxy: Union[torch.fx.Proxy, torch.fx.Node]):

View File

@ -9,10 +9,11 @@ import sympy
import torch import torch
import torch.fx import torch.fx
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._sympy.numbers import int_oo
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.value_ranges import ValueRanges
__all__ = ["InputDim"] __all__ = ["InputDim"]
@ -30,9 +31,7 @@ def _convert_to_int(val):
return -math.inf return -math.inf
if isinstance(val, sympy.Integer): if isinstance(val, sympy.Integer):
return int(val) return int(val)
raise RuntimeError( raise RuntimeError("Export constraints cannot be non-integer expressions")
"Export constraints cannot be non-integer expressions"
)
def _convert_range_to_int(range: ValueRanges): def _convert_range_to_int(range: ValueRanges):
@ -55,10 +54,14 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
def _assert_range_constraint(self, node, lower, upper, assert_msg): def _assert_range_constraint(self, node, lower, upper, assert_msg):
last_node = node last_node = node
if lower > -math.inf: if lower > -math.inf:
last_node = self._insert_assert_async(last_node, operator.ge, node, lower, assert_msg) last_node = self._insert_assert_async(
last_node, operator.ge, node, lower, assert_msg
)
if upper < math.inf: if upper < math.inf:
last_node = self._insert_assert_async(last_node, operator.le, node, upper, assert_msg) last_node = self._insert_assert_async(
last_node, operator.le, node, upper, assert_msg
)
def _insert_assert_async(self, last_node, op, lower, upper, assert_msg): def _insert_assert_async(self, last_node, op, lower, upper, assert_msg):
""" """
@ -70,7 +73,9 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
with graph.inserting_after(last_node): with graph.inserting_after(last_node):
cmp = graph.call_function(op, (lower, upper), {}) cmp = graph.call_function(op, (lower, upper), {})
with graph.inserting_after(cmp): with graph.inserting_after(cmp):
cmp_tensor = graph.call_function(torch.ops.aten.scalar_tensor.default, (cmp,), {}) cmp_tensor = graph.call_function(
torch.ops.aten.scalar_tensor.default, (cmp,), {}
)
with graph.inserting_after(cmp_tensor): with graph.inserting_after(cmp_tensor):
assert_async = graph.call_function( assert_async = graph.call_function(
torch.ops.aten._assert_async.msg, torch.ops.aten._assert_async.msg,
@ -111,7 +116,9 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
symbol = val.node.expr symbol = val.node.expr
if symbol in self.existing_inline_assertions: if symbol in self.existing_inline_assertions:
return call_backs, messages return call_backs, messages
if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol): if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(
symbol
):
if symbol in self._asserts_generated_unbacked_symbols: if symbol in self._asserts_generated_unbacked_symbols:
return call_backs, messages return call_backs, messages
# We only care about unbacked symints for these inline # We only care about unbacked symints for these inline
@ -120,7 +127,11 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
min_val, max_val = _convert_range_to_int(constraint) min_val, max_val = _convert_range_to_int(constraint)
assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]." assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
call_backs.append( call_backs.append(
partial(self._assert_range_constraint, lower=min_val, upper=max_val) partial(
self._assert_range_constraint,
lower=min_val,
upper=max_val,
)
) )
messages.append(assert_msg) messages.append(assert_msg)
self._asserts_generated_unbacked_symbols.add(symbol) self._asserts_generated_unbacked_symbols.add(symbol)
@ -129,6 +140,7 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
for i, sym in enumerate(val.shape): for i, sym in enumerate(val.shape):
cbs, msgs = add_assertions(sym) cbs, msgs = add_assertions(sym)
for cb, msg in zip(cbs, msgs): for cb, msg in zip(cbs, msgs):
def sym_size_cb(node, assert_msg, dim): def sym_size_cb(node, assert_msg, dim):
with node.graph.inserting_after(node): with node.graph.inserting_after(node):
dim_node = module.graph.call_function( dim_node = module.graph.call_function(
@ -137,6 +149,7 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
{}, {},
) )
cb(node=dim_node, assert_msg=assert_msg) cb(node=dim_node, assert_msg=assert_msg)
call_backs.append(partial(sym_size_cb, dim=i)) call_backs.append(partial(sym_size_cb, dim=i))
messages.append(f".shape[{i}]" + msg) messages.append(f".shape[{i}]" + msg)
return call_backs, messages return call_backs, messages
@ -149,12 +162,18 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
# Sometimes this pass would return a wrong graph where we have mismatched # Sometimes this pass would return a wrong graph where we have mismatched
# node names in signature. Before we fix it, let's just skip it. # node names in signature. Before we fix it, let's just skip it.
if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass: if (
self.counter == 0
and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass
):
return PassResult(graph_module, False) return PassResult(graph_module, False)
# Populate the stack trace with dummy vals to respect IR # Populate the stack trace with dummy vals to respect IR
for node in graph_module.graph.nodes: for node in graph_module.graph.nodes:
if not node.meta.get("stack_trace", None) and node.op not in ["placeholder", "output"]: if not node.meta.get("stack_trace", None) and node.op not in [
"placeholder",
"output",
]:
node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1)) node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))
return PassResult(graph_module, True) return PassResult(graph_module, True)
@ -179,10 +198,10 @@ def _get_existing_inline_assertions(
compare_arg = node.args[0] compare_arg = node.args[0]
if not ( if not (
isinstance(compare_arg, torch.fx.Node) and isinstance(compare_arg, torch.fx.Node)
compare_arg.op == "call_function" and and compare_arg.op == "call_function"
compare_arg.target in (operator.le, operator.ge) and and compare_arg.target in (operator.le, operator.ge)
len(compare_arg.args) == 2 and len(compare_arg.args) == 2
): ):
continue continue
@ -191,9 +210,9 @@ def _get_existing_inline_assertions(
def maybe_get_symint(x): def maybe_get_symint(x):
if ( if (
isinstance(x, torch.fx.Node) and isinstance(x, torch.fx.Node)
"val" in x.meta and and "val" in x.meta
isinstance(x.meta["val"], torch.SymInt) and isinstance(x.meta["val"], torch.SymInt)
): ):
return x.meta["val"].node.expr return x.meta["val"].node.expr
return x return x
@ -214,9 +233,13 @@ def _get_existing_inline_assertions(
continue continue
if symint not in range_constraints: if symint not in range_constraints:
raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}") raise RuntimeError(
f"Unable to find symint {symint} in {range_constraints}"
)
previous_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf)) previous_range = existing_inline_assertions.get(
symint, ValueRanges(-math.inf, math.inf)
)
if symint is lhs: if symint is lhs:
bounds = ValueRanges(-math.inf, scalar) bounds = ValueRanges(-math.inf, scalar)

View File

@ -2,11 +2,16 @@ import copy
from typing import Optional from typing import Optional
import torch import torch
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument from torch._export.pass_base import (
_ExportPassBaseDeprecatedDoNotUse,
Argument,
PassResult,
)
from torch._export.pass_infra.node_metadata import NodeMetadata from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue from torch._export.pass_infra.proxy_value import ProxyValue
from torch._ops import OpOverload from torch._ops import OpOverload
aten = torch.ops.aten aten = torch.ops.aten
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = { _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {