pyfmt lint more export files (#155783)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155783
Approved by: https://github.com/Skylion007
ghstack dependencies: #155782

Committed by: PyTorch MergeBot
Parent: 86b1116f22
Commit: 2903e5ad3c
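Every hunk below is the same kind of mechanical change: the PYFMT lint (import sorting plus black-style formatting with an 88-column line limit, as used in pytorch's lintrunner setup; the exact tool configuration is inferred, not shown in this diff) wraps over-long lines and normalizes import order. A minimal before/after sketch of the dominant pattern, with made-up names:

# Before: an over-long signature sits on one physical line.
def configure(callback: "SomeLongCallbackTypeName", codegen: "SomeLongCodeGenTypeName") -> None: ...


# After: black splits the parameter list so each line fits in 88 columns.
def configure(
    callback: "SomeLongCallbackTypeName", codegen: "SomeLongCodeGenTypeName"
) -> None: ...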
@@ -1304,17 +1304,6 @@ exclude_patterns = [
     'torch/_export/db/examples/type_reflection_method.py',
     'torch/_export/db/gen_example.py',
     'torch/_export/db/logging.py',
-    'torch/_export/error.py',
-    'torch/_export/exported_program.py',
-    'torch/_export/pass_base.py',
-    'torch/_export/pass_infra/__init__.py',
-    'torch/_export/pass_infra/node_metadata.py',
-    'torch/_export/pass_infra/proxy_value.py',
-    'torch/_export/passes/__init__.py',
-    'torch/_export/passes/add_runtime_assertions_for_constraints_pass.py',
-    'torch/_export/passes/const_prop_pass.py',
-    'torch/_export/passes/functionalize_side_effectful_ops_pass.py',
-    'torch/_export/passes/replace_sym_size_ops_pass.py',
     'torch/testing/_internal/__init__.py',
     'torch/testing/_internal/autocast_test_lists.py',
     'torch/testing/_internal/autograd_function_db.py',
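The hunk above deletes the freshly formatted files from the PYFMT exclusion list, so the formatter is enforced on them from now on. The list appears to live in pytorch's .lintrunner.toml (an inference; the file name is not visible in this diff) and has this general shape:

# Assumed surrounding structure, based on the "exclude_patterns = [" context line:
exclude_patterns = [
    'torch/_export/db/examples/type_reflection_method.py',
    'torch/testing/_internal/__init__.py',
]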
@@ -6,20 +6,23 @@ from contextlib import nullcontext
 from typing import Any, Callable, Optional, Union
 
 import torch
-from torch._higher_order_ops.map import _unstack_pytree
 from torch import fx
 from torch._dispatch.python import enable_python_dispatcher
 from torch._export.pass_infra.node_metadata import NodeMetadata
 from torch._export.pass_infra.proxy_value import ProxyValue
+from torch._higher_order_ops.map import _unstack_pytree
 from torch._subclasses import FakeTensor, UnsupportedFakeTensorException
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx import traceback as fx_traceback
 from torch.fx.experimental.proxy_tensor import PythonKeyTracer
+from torch.fx.experimental.symbolic_shapes import (
+    compute_unbacked_bindings,
+    PropagateUnbackedSymInts,
+)
 from torch.fx.graph import CodeGen
 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
 from torch.utils import _pytree as pytree
-from torch.fx.experimental.symbolic_shapes import PropagateUnbackedSymInts, compute_unbacked_bindings
 
 
 __all__ = ["_ExportPassBaseDeprecatedDoNotUse"]
@@ -56,9 +59,10 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
     def _create_dummy_node_metadata():
         return NodeMetadata({"stack_trace": "".join(traceback.format_stack(limit=1))})
 
-
     class ExportTracer(PythonKeyTracer):
-        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen) -> None:
+        def __init__(
+            self, callback: "_ExportPassBaseDeprecatedDoNotUse", codegen: CodeGen
+        ) -> None:
             super().__init__()
             self.callback = callback
             self.root = torch.nn.Module()
@@ -92,12 +96,24 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
             return node
 
         def set_metadata(
-            self, node: torch.fx.Node, value: Argument,
+            self,
+            node: torch.fx.Node,
+            value: Argument,
         ) -> None:
             # propagate the fake tensor or sym nodes
             def make_val(
                 x: Argument,
-            ) -> Union[FakeTensor, torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str, None]:
+            ) -> Union[
+                FakeTensor,
+                torch.SymInt,
+                torch.SymFloat,
+                torch.SymBool,
+                int,
+                float,
+                bool,
+                str,
+                None,
+            ]:
                 if isinstance(x, FakeTensor):
                     return x
                 elif isinstance(x, torch.Tensor):
@@ -124,7 +140,18 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
                         )
                         fake_tensor = None
                     return fake_tensor
-                elif isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)):
+                elif isinstance(
+                    x,
+                    (
+                        torch.SymInt,
+                        torch.SymFloat,
+                        torch.SymBool,
+                        int,
+                        float,
+                        bool,
+                        str,
+                    ),
+                ):
                     return x
                 else:
                     return None
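For orientation, the make_val helper being rewrapped above dispatches on value type when propagating metadata onto graph nodes. A condensed, self-contained sketch of that logic as visible in the fragments (the FakeTensorMode conversion branch is elided in the diff, so fakify_tensor below is a hypothetical stand-in):

import torch
from torch._subclasses import FakeTensor
from torch._subclasses.fake_tensor import FakeTensorMode


def fakify_tensor(t):
    # Hypothetical stand-in for the elided branch: build a fake tensor,
    # fall back to None when the tensor cannot be fakified.
    try:
        with FakeTensorMode(allow_non_fake_inputs=True) as mode:
            return mode.from_tensor(t)
    except Exception:
        return None


def make_val_sketch(x):
    # Fake tensors and symbolic/primitive values are kept as-is for
    # node.meta["val"]; everything else is dropped.
    if isinstance(x, FakeTensor):
        return x
    elif isinstance(x, torch.Tensor):
        return fakify_tensor(x)
    elif isinstance(
        x, (torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool, str)
    ):
        return x
    else:
        return None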
@@ -153,7 +180,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
             node.meta["tensor_meta"] = pytree.tree_map(make_tensor_meta, value)
 
     class ExportInterpreter(fx.Interpreter):
-        def __init__(self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule) -> None:
+        def __init__(
+            self, callback: "_ExportPassBaseDeprecatedDoNotUse", gm: fx.GraphModule
+        ) -> None:
             super().__init__(gm)
             self.callback = callback
             self.node: torch.fx.Node = next(iter(gm.graph.nodes))
@@ -186,13 +215,19 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
             if target == operator.getitem:
                 value, key = args
                 return self.callback.call_getitem(value, key, meta)
-            elif getattr(target, "__module__", None) in {"_operator", "builtins", "math"}:
+            elif getattr(target, "__module__", None) in {
+                "_operator",
+                "builtins",
+                "math",
+            }:
                 assert callable(target)
                 return self.callback.call_sym(target, args, meta)
             elif target in _TORCH_SYM_OPS:
                 assert callable(target)
                 return self.callback.call_sym(target, args, meta)
-            elif isinstance(target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
+            elif isinstance(
+                target, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
+            ):
                 return self.callback.call_operator(
                     target,
                     args,
@@ -269,7 +304,9 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
         if isinstance(target, torch._ops.OpOverload):
             name = self.tracer.graph._target_to_str(target.overloadpacket.__name__)
 
-        res_proxy = self.tracer.create_proxy(kind, target, args_proxy, kwargs_proxy, name=name)
+        res_proxy = self.tracer.create_proxy(
+            kind, target, args_proxy, kwargs_proxy, name=name
+        )
         res_proxy.node.meta.update(meta.data)
         if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
             if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
@@ -1,12 +1,13 @@
 # pyre-strict
-from typing import Union, Generic
-from collections.abc import Iterator, Iterable
+from collections.abc import Iterable, Iterator
+from typing import Generic, TypeVar, Union
 
 import torch
 
-from typing import TypeVar
 
 _T = TypeVar("_T")
 
 
 class ProxyValue(Generic[_T]):
     # pyre-ignore
     def __init__(self, data: Iterable[_T], proxy: Union[torch.fx.Proxy, torch.fx.Node]):
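The hunk above (evidently torch/_export/pass_infra/proxy_value.py, judging by the ProxyValue class; the file name is not shown) merges three typing imports into two sorted lines. The Generic[_T] pattern it imports parameterizes the class by the type of the wrapped data; a minimal illustration with a made-up Box class:

from typing import Generic, TypeVar

_T = TypeVar("_T")


class Box(Generic[_T]):
    # A container whose static type tracks the type of its payload,
    # mirroring how ProxyValue is parameterized over the data it wraps.
    def __init__(self, data: _T) -> None:
        self.data = data


int_box: Box[int] = Box(3)  # type checkers infer the payload type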
@@ -9,10 +9,11 @@ import sympy
 
 import torch
 import torch.fx
-from torch.utils._sympy.value_ranges import ValueRanges
-from torch.utils._sympy.numbers import int_oo
 from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
 from torch.fx.passes.infra.pass_base import PassBase, PassResult
+from torch.utils._sympy.numbers import int_oo
+from torch.utils._sympy.value_ranges import ValueRanges
 
 
 __all__ = ["InputDim"]
 
@@ -30,9 +31,7 @@ def _convert_to_int(val):
         return -math.inf
     if isinstance(val, sympy.Integer):
         return int(val)
-    raise RuntimeError(
-        "Export constraints cannot be non-integer expressions"
-    )
+    raise RuntimeError("Export constraints cannot be non-integer expressions")
 
 
 def _convert_range_to_int(range: ValueRanges):
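Context for the hunk above: _convert_to_int maps sympy bounds to plain Python numbers so that _convert_range_to_int can hand min/max values to the runtime asserts. Only the -math.inf and sympy.Integer branches are visible in the diff; a sketch under the assumption that int_oo/-int_oo map to +/-math.inf (which the int_oo import above suggests):

import math

import sympy
from torch.utils._sympy.numbers import int_oo


def convert_to_int_sketch(val):
    # Assumed reconstruction; only the -math.inf and Integer branches
    # appear in the diff fragments above.
    if val in (-sympy.oo, -int_oo):
        return -math.inf
    if val in (sympy.oo, int_oo):
        return math.inf
    if isinstance(val, sympy.Integer):
        return int(val)
    raise RuntimeError("Export constraints cannot be non-integer expressions")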
@@ -55,10 +54,14 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
     def _assert_range_constraint(self, node, lower, upper, assert_msg):
         last_node = node
         if lower > -math.inf:
-            last_node = self._insert_assert_async(last_node, operator.ge, node, lower, assert_msg)
+            last_node = self._insert_assert_async(
+                last_node, operator.ge, node, lower, assert_msg
+            )
 
         if upper < math.inf:
-            last_node = self._insert_assert_async(last_node, operator.le, node, upper, assert_msg)
+            last_node = self._insert_assert_async(
+                last_node, operator.le, node, upper, assert_msg
+            )
 
     def _insert_assert_async(self, last_node, op, lower, upper, assert_msg):
         """
@@ -70,7 +73,9 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
         with graph.inserting_after(last_node):
             cmp = graph.call_function(op, (lower, upper), {})
         with graph.inserting_after(cmp):
-            cmp_tensor = graph.call_function(torch.ops.aten.scalar_tensor.default, (cmp,), {})
+            cmp_tensor = graph.call_function(
+                torch.ops.aten.scalar_tensor.default, (cmp,), {}
+            )
         with graph.inserting_after(cmp_tensor):
             assert_async = graph.call_function(
                 torch.ops.aten._assert_async.msg,
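The pass above materializes a runtime check as three graph nodes: a comparison, a scalar-to-tensor conversion, and an async assert. An eager-mode sketch of the same three steps (the values are made up; the real pass emits these as graph.call_function nodes instead of calling the ops directly):

import operator

import torch

value, lower = 5, 0
cmp = operator.ge(value, lower)  # the comparison node (operator.ge / operator.le)
cmp_tensor = torch.ops.aten.scalar_tensor.default(cmp)  # bool -> 0-dim tensor
torch.ops.aten._assert_async.msg(  # raises asynchronously if cmp_tensor is zero
    cmp_tensor, "value is outside of inline constraint [0, inf]."
)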
@@ -111,7 +116,9 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
                 symbol = val.node.expr
                 if symbol in self.existing_inline_assertions:
                     return call_backs, messages
-                if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(symbol):
+                if isinstance(symbol, sympy.Symbol) and free_unbacked_symbols(
+                    symbol
+                ):
                     if symbol in self._asserts_generated_unbacked_symbols:
                         return call_backs, messages
                     # We only care about unbacked symints for these inline
@@ -120,7 +127,11 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
                     min_val, max_val = _convert_range_to_int(constraint)
                     assert_msg = f" is outside of inline constraint [{min_val}, {max_val}]."
                     call_backs.append(
-                        partial(self._assert_range_constraint, lower=min_val, upper=max_val)
+                        partial(
+                            self._assert_range_constraint,
+                            lower=min_val,
+                            upper=max_val,
+                        )
                     )
                     messages.append(assert_msg)
                     self._asserts_generated_unbacked_symbols.add(symbol)
@@ -129,6 +140,7 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
                for i, sym in enumerate(val.shape):
                    cbs, msgs = add_assertions(sym)
                    for cb, msg in zip(cbs, msgs):
+
                        def sym_size_cb(node, assert_msg, dim):
                            with node.graph.inserting_after(node):
                                dim_node = module.graph.call_function(
@@ -137,6 +149,7 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
                                    {},
                                )
                            cb(node=dim_node, assert_msg=assert_msg)
+
                        call_backs.append(partial(sym_size_cb, dim=i))
                        messages.append(f".shape[{i}]" + msg)
            return call_backs, messages
@@ -149,12 +162,18 @@ class _AddRuntimeAssertionsForInlineConstraintsPass(PassBase):
 
         # Sometimes this pass would return a wrong graph where we have mismatched
         # node names in signature. Before we fix it, let's just skip it.
-        if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass:
+        if (
+            self.counter == 0
+            and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass
+        ):
             return PassResult(graph_module, False)
 
         # Populate the stack trace with dummy vals to respect IR
         for node in graph_module.graph.nodes:
-            if not node.meta.get("stack_trace", None) and node.op not in ["placeholder", "output"]:
+            if not node.meta.get("stack_trace", None) and node.op not in [
+                "placeholder",
+                "output",
+            ]:
                 node.meta["stack_trace"] = "".join(traceback.format_stack(limit=1))
         return PassResult(graph_module, True)
 
@@ -179,10 +198,10 @@ def _get_existing_inline_assertions(
 
         compare_arg = node.args[0]
         if not (
-            isinstance(compare_arg, torch.fx.Node) and
-            compare_arg.op == "call_function" and
-            compare_arg.target in (operator.le, operator.ge) and
-            len(compare_arg.args) == 2
+            isinstance(compare_arg, torch.fx.Node)
+            and compare_arg.op == "call_function"
+            and compare_arg.target in (operator.le, operator.ge)
+            and len(compare_arg.args) == 2
         ):
             continue
 
@@ -191,9 +210,9 @@ def _get_existing_inline_assertions(
 
         def maybe_get_symint(x):
             if (
-                isinstance(x, torch.fx.Node) and
-                "val" in x.meta and
-                isinstance(x.meta["val"], torch.SymInt)
+                isinstance(x, torch.fx.Node)
+                and "val" in x.meta
+                and isinstance(x.meta["val"], torch.SymInt)
             ):
                 return x.meta["val"].node.expr
             return x
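The two hunks above make the same stylistic change: in wrapped boolean expressions, each and moves from the end of a line to the start of the next, the break-before-binary-operator style that black enforces and PEP 8 recommends. A standalone illustration (not code from the diff):

import torch.fx


def is_two_arg_call(node) -> bool:
    # Each continuation line now begins with its operator.
    return (
        isinstance(node, torch.fx.Node)
        and node.op == "call_function"
        and len(node.args) == 2
    )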
@@ -214,9 +233,13 @@ def _get_existing_inline_assertions(
                continue
 
            if symint not in range_constraints:
-                raise RuntimeError(f"Unable to find symint {symint} in {range_constraints}")
+                raise RuntimeError(
+                    f"Unable to find symint {symint} in {range_constraints}"
+                )
 
-            previous_range = existing_inline_assertions.get(symint, ValueRanges(-math.inf, math.inf))
+            previous_range = existing_inline_assertions.get(
+                symint, ValueRanges(-math.inf, math.inf)
+            )
 
            if symint is lhs:
                bounds = ValueRanges(-math.inf, scalar)
@@ -2,11 +2,16 @@ import copy
 from typing import Optional
 
 import torch
-from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse, PassResult, Argument
+from torch._export.pass_base import (
+    _ExportPassBaseDeprecatedDoNotUse,
+    Argument,
+    PassResult,
+)
 from torch._export.pass_infra.node_metadata import NodeMetadata
 from torch._export.pass_infra.proxy_value import ProxyValue
 from torch._ops import OpOverload
 
 
 aten = torch.ops.aten
 
 _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {
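The final hunk, apparently from torch/_export/passes/functionalize_side_effectful_ops_pass.py (inferred from the imports), only reflows the pass_base import into a parenthesized, sorted block. The _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS table it introduces is truncated here; its shape is a mapping from side-effectful ops to functional counterparts, roughly like this (the concrete entry is an illustrative guess, not copied from the diff):

import torch
from torch._ops import OpOverload

aten = torch.ops.aten

# Illustrative guess at one entry; the real table's contents are cut off above.
_SKETCH_FUNCS: dict[OpOverload, OpOverload] = {
    aten._assert_async.msg: aten._functional_assert_async.msg,
}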