Revert "Improve torch.ops typing (#153558)"

This reverts commit c5cba39d469151895cd0ecf7673b98e5072b69c2.

Reverted https://github.com/pytorch/pytorch/pull/153558 on behalf of https://github.com/yangw-dev due to Your diff will not be landed to fbcode since we suspect it caused the following breakage in an internal test:[D75007157](https://www.internalfb.com/diff/D75007157) for instance: tests_gpu/lookup_gpu_index_test.py:232:8 Undefined attribute [16]: torch._ops._OpNamespace has no attribute simple_index_mm_batch ([comment](https://github.com/pytorch/pytorch/pull/153558#issuecomment-2892506789))
This commit is contained in:
PyTorch MergeBot
2025-05-19 23:32:36 +00:00
parent 701e22112d
commit d81217be2e
13 changed files with 96 additions and 131 deletions

View File

@ -1636,7 +1636,6 @@ class Generator:
class _DispatchOperatorHandle:
def schema(self) -> FunctionSchema: ...
def debug(self) -> str: ...
def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Stack: ...
class _DispatchModule:
def reset(self) -> None: ...
@ -1833,7 +1832,7 @@ class _SetExcludeDispatchKeyGuard:
# Defined in torch/csrc/utils/schema_info.h
class _SchemaInfo:
def __init__(self, schema: FunctionSchema) -> None: ...
def __init__(self, schema: _int) -> None: ...
@overload
def is_mutable(self) -> _bool: ...

View File

@ -6,7 +6,7 @@ import typing
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, Callable, Optional, Union
from typing import Any, Optional, Union
import torch
import torch.export._trace
@ -229,7 +229,7 @@ def get_dtype_as_int(tensor):
# Those operators will be automatically populated to a instance method
# of TS2FXGraphConverter with name convert_<namespace>_<opname>().
# Please check __init__ for method population implementations.
kind_to_standard_operators: dict[str, Callable[..., Any]] = {
kind_to_standard_operators = {
"prim::max": builtins.max,
"prim::min": builtins.min,
"prim::TupleIndex": operator.getitem,

View File

@ -10,7 +10,7 @@ from torch._ops import OpOverload
aten = torch.ops.aten
_NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {
aten.sym_constrain_range.default: aten._functional_sym_constrain_range.default,
aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
aten._assert_async.msg: aten._functional_assert_async.msg,
}

View File

@ -10,7 +10,7 @@ import os
import os.path
from collections import defaultdict
from dataclasses import dataclass, replace
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Callable, Optional, TYPE_CHECKING, Union
import torch
import torch._inductor.inductor_prims
@ -2049,9 +2049,7 @@ def get_default_op_list() -> OpTypes:
default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
recomputable_ops = OrderedSet(default_recomputable_ops)
random_ops = OrderedSet[Callable[..., Any]](
[aten.native_dropout, aten.rand_like, aten.randn_like]
)
random_ops = OrderedSet([aten.native_dropout, aten.rand_like, aten.randn_like])
compute_intensive_ops = [
aten.mm,
aten.convolution,

View File

@ -24,11 +24,11 @@ class _EffectType(Enum):
OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]
SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType](
[
(torch.ops.aten._print.default, _EffectType.ORDERED),
(call_torchbind, _EffectType.ORDERED),
]
SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary(
{
torch.ops.aten._print.default: _EffectType.ORDERED,
call_torchbind: _EffectType.ORDERED,
}
)

View File

@ -107,7 +107,7 @@ decompositions = {**core_aten_decompositions(), **inductor_decompositions}
# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
decomps_to_exclude: list[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]] = [
decomps_to_exclude = [
aten._unsafe_index,
aten._unsafe_masked_index,
aten._unsafe_masked_index_put_accumulate,
@ -522,7 +522,7 @@ def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
return torch.where(torch.isnan(other) | (other < self), self, other)
@register_decomposition([aten.amax])
@register_decomposition(aten.amax)
def amax(
self: torch.Tensor,
dim: Optional[int] = None,
@ -533,7 +533,7 @@ def amax(
return NotImplemented
@register_decomposition([aten.amin])
@register_decomposition(aten.amin)
def amin(
self: torch.Tensor,
dim: Optional[int] = None,
@ -581,7 +581,7 @@ def get_like_layout(
return memory_format
@register_decomposition([aten.rand_like])
@register_decomposition(aten.rand_like)
def rand_like(
self: torch.Tensor,
*,
@ -598,7 +598,7 @@ def rand_like(
).to(memory_format=get_like_layout(self, memory_format))
@register_decomposition([aten.randn_like])
@register_decomposition(aten.randn_like)
def randn_like(
self: torch.Tensor,
*,
@ -615,7 +615,7 @@ def randn_like(
).to(memory_format=get_like_layout(self, memory_format))
@register_decomposition([aten.full_like])
@register_decomposition(aten.full_like)
def full_like(
self: torch.Tensor,
fill_value: Union[int, float],
@ -637,7 +637,7 @@ def full_like(
).to(memory_format=get_like_layout(self, memory_format))
@register_decomposition([aten.randint_like.default])
@register_decomposition(aten.randint_like.default)
def randint_like(
self: torch.Tensor,
high: int,
@ -657,7 +657,7 @@ def randint_like(
).to(memory_format=get_like_layout(self, memory_format))
@register_decomposition([aten.randint_like.low_dtype])
@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
self: torch.Tensor,
low: int,
@ -678,7 +678,7 @@ def randint_like_low(
).to(memory_format=get_like_layout(self, memory_format))
@register_decomposition([aten.randint.default])
@register_decomposition(aten.randint.default)
def randint(
high: int,
size: list[Union[int, torch.SymInt]],
@ -687,7 +687,7 @@ def randint(
return aten.randint.low(0, high, size, **kwargs)
@register_decomposition([quantized.linear_dynamic_fp16_unpacked_weight.default])
@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
def linear_dynamic_fp16_unpacked_weight(
input: torch.Tensor,
weight: torch.Tensor,
@ -699,7 +699,7 @@ def linear_dynamic_fp16_unpacked_weight(
)
@register_decomposition([_quantized.wrapped_quantized_linear.default])
@register_decomposition(_quantized.wrapped_quantized_linear.default)
def wrapped_quantized_linear(
input: torch.Tensor,
input_scale: torch.Tensor,
@ -726,7 +726,7 @@ def wrapped_quantized_linear(
)
@register_decomposition([torch.ops.quantized.embedding_bag_byte_unpack])
@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
@ -771,7 +771,7 @@ def grid_sampler_2d(
return output
@register_decomposition([aten._foreach_addcmul.Scalar])
@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(
self: list[torch.Tensor],
left_tensors: list[torch.Tensor],
@ -783,7 +783,7 @@ def _foreach_addcmul_scalar(
)
@register_decomposition([aten._foreach_addcdiv.Scalar])
@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(
self: list[torch.Tensor],
left_tensors: list[torch.Tensor],
@ -795,7 +795,7 @@ def _foreach_addcdiv_scalar(
)
@register_decomposition([aten._foreach_lerp.Scalar])
@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(
start_tensors: list[torch.Tensor],
end_tensors: list[torch.Tensor],
@ -809,7 +809,7 @@ def _foreach_lerp_scalar(
)
@register_decomposition([aten._foreach_lerp.ScalarList])
@register_decomposition(aten._foreach_lerp.ScalarList)
def _foreach_lerp_scalarlist(
start_tensors: list[torch.Tensor],
end_tensors: list[torch.Tensor],
@ -824,7 +824,7 @@ def _foreach_lerp_scalarlist(
@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition([aten.miopen_batch_norm])
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
input: torch.Tensor,
weight: torch.Tensor,
@ -869,7 +869,7 @@ def select_decomp_table() -> dict[Any, Callable[..., Any]]:
return fast_random_decomps()
@register_decomposition([aten.masked_scatter])
@register_decomposition(aten.masked_scatter)
def masked_scatter(
self: torch.Tensor,
mask: torch.Tensor,
@ -888,7 +888,7 @@ def masked_scatter(
return NotImplemented
@register_decomposition([quantized_decomposed.choose_qparams.tensor])
@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
input: torch.Tensor,
quant_min: int,
@ -904,7 +904,7 @@ def choose_qparams_tensor(
return scale.to(torch.float64), zero_point.to(torch.int64)
@register_decomposition([aten.put])
@register_decomposition(aten.put)
def put(
self: torch.Tensor,
index: torch.Tensor,
@ -918,7 +918,7 @@ def put(
return flattened.reshape(self.shape)
@register_decomposition([aten.put_])
@register_decomposition(aten.put_)
def put_(
self: torch.Tensor,
index: torch.Tensor,
@ -929,7 +929,7 @@ def put_(
return self.copy_(out)
@register_decomposition([aten._softmax_backward_data.default])
@register_decomposition(aten._softmax_backward_data.default)
@pw_cast_for_opmath
def _softmax_backward_data(
grad_output: torch.Tensor,
@ -951,7 +951,7 @@ def _softmax_backward_data(
return grad_input.contiguous()
@register_decomposition([aten.index_reduce])
@register_decomposition(aten.index_reduce)
def index_reduce(
self: torch.Tensor,
dim: int,
@ -1056,7 +1056,7 @@ def _max_pool_with_indices(
return vals, indices
@register_decomposition([aten.max_pool2d_with_indices])
@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
x: torch.Tensor,
kernel_size: list[int],
@ -1070,7 +1070,7 @@ def max_pool2d_with_indices(
)
@register_decomposition([aten.max_pool3d_with_indices])
@register_decomposition(aten.max_pool3d_with_indices)
def max_pool3d_with_indices(
x: torch.Tensor,
kernel_size: list[int],
@ -1084,7 +1084,7 @@ def max_pool3d_with_indices(
)
@register_decomposition([aten.adaptive_max_pool2d])
@register_decomposition(aten.adaptive_max_pool2d)
def adaptive_max_pool2d(
x: torch.Tensor, output_size: list[int]
) -> tuple[torch.Tensor, torch.Tensor]:
@ -1102,7 +1102,7 @@ def adaptive_max_pool2d(
return NotImplemented
@register_decomposition([aten.searchsorted.Scalar])
@register_decomposition(aten.searchsorted.Scalar)
def searchsorted_scalar(
sorted_sequence: torch.Tensor,
self: torch.types.Number,
@ -1122,7 +1122,7 @@ def searchsorted_scalar(
)[0]
@register_decomposition([aten.rrelu_with_noise_functional])
@register_decomposition(aten.rrelu_with_noise_functional)
def rrelu_with_noise_functional(
self: torch.Tensor,
noise: torch.Tensor,

View File

@ -2,7 +2,7 @@
import functools
import operator
from functools import reduce
from typing import Any, Callable
from typing import Any
import torch
from torch._dynamo.utils import counters
@ -1345,9 +1345,7 @@ if torch._C._has_mkldnn:
or V.aot_compilation
):
packed_linear_inputs += (bias, "none", [], "")
packed_linear_op: Callable[..., Any] = (
mkldnn._linear_pointwise.default
)
packed_linear_op = mkldnn._linear_pointwise.default
else:
packed_linear_inputs += (transpose_weight_node, bias, batch_size)
packed_linear_op = torch.ops.mkl._mkl_linear

View File

@ -3,12 +3,10 @@ import itertools
import logging
import operator
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Callable, cast, Union
from typing import Any, Callable, Union
import torch
import torch.fx.node
from torch._C._dynamo.guards import compute_overlapping_tensors
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger
@ -178,13 +176,7 @@ _ALWAYS_MUTATING_SCATTER_OPS = OrderedSet(
def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
_, _, view_ops = node.args
view_ops = cast(Sequence[torch.fx.node.Argument], view_ops)
return any(
view.target in _ALWAYS_MUTATING_SCATTER_OPS
for view in view_ops
if isinstance(view, torch.fx.Node)
and isinstance(view.target, torch._ops.OpOverload)
)
return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops) # type: ignore[union-attr]
def should_reinplace_scatter(node: torch.fx.Node) -> bool:
@ -275,7 +267,6 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
assert len(node.args) >= 2
inp, src = node.args[:2]
assert isinstance(node.target, torch._ops.OpOverload)
scatter_view_op = ViewOp(
_SCATTER_OP_TO_VIEW[node.target],
args=node.args[2:],
@ -340,7 +331,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
handle_view_scatter(node)
inplaceable_ops: dict[Callable[..., Any], InplaceableOp] = {
inplaceable_ops = {
aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
_generalized_scatter: InplaceableOp(
@ -352,7 +343,7 @@ inplaceable_ops: dict[Callable[..., Any], InplaceableOp] = {
try:
c10d_functional = torch.ops._c10d_functional
inplaceable_collective_ops: dict[Callable[..., Any], InplaceableOp] = {
inplaceable_collective_ops = {
c10d_functional.all_reduce.default: InplaceableOp(
c10d_functional.all_reduce_.default, 0
),

View File

@ -6,17 +6,7 @@ import importlib
import inspect
import sys
import types
from functools import cached_property
from typing import (
Any,
Callable,
ClassVar,
final,
Optional,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import Any, Callable, final, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import Concatenate, ParamSpec
import torch
@ -745,14 +735,7 @@ def get_cached_ops():
# Each OpOverload object contains pointer to a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
class OpOverload(OperatorBase):
def __init__(
self,
overloadpacket: "OpOverloadPacket",
op: Callable[_P, _T],
op_dk: Callable[Concatenate[DispatchKey, _P], _T],
schema: torch._C.FunctionSchema,
tags: list[Any],
) -> None:
def __init__(self, overloadpacket, op, op_dk, schema, tags):
super().__init__()
self._op = op
self._op_dk = op_dk
@ -772,6 +755,9 @@ class OpOverload(OperatorBase):
op.__module__ = overloadpacket.__module__
self.__qualname__ = self._name
self.__annotations__ = {}
# Only compute the OperatorHandle when we need it. Not all OpOverloads have
# OperatorHandles (the TorchScript ones don't...)
self._lazy_handle = None
# If the OpOverload was constructed from a Library.def in Python.
self._defined_in_python = self.__qualname__ in torch.library._defs
@ -797,11 +783,13 @@ class OpOverload(OperatorBase):
def _opname(self):
return self._schema.name.split("::")[1]
@cached_property
def _handle(self) -> torch._C._DispatchOperatorHandle:
return torch._C._dispatch_find_schema_or_throw(
self._schema.name, self._schema.overload_name
)
@property
def _handle(self):
if self._lazy_handle is None:
self._lazy_handle = torch._C._dispatch_find_schema_or_throw(
self._schema.name, self._schema.overload_name
)
return self._lazy_handle
# it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
def __deepcopy__(self, memo=None):
@ -819,9 +807,7 @@ class OpOverload(OperatorBase):
# Use positional-only argument to avoid naming collision with aten ops arguments
# that are named "self". This way, all the aten ops can be called by kwargs.
def redispatch(
self, /, keyset: torch._C.DispatchKeySet, *args, **kwargs
) -> "torch._C.Stack":
def redispatch(self, /, keyset, *args, **kwargs):
return self._handle.redispatch_boxed(keyset, *args, **kwargs)
def __hash__(self):
@ -1112,22 +1098,14 @@ def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
# You can obtain an OpOverload object through attribute query.
class OpOverloadPacket:
__file__: ClassVar[str] = "torch.ops"
def __init__(
self,
qualified_op_name: str,
op_name: str,
op: Callable[..., Any],
overload_names: list[str],
) -> None:
def __init__(self, qualified_op_name, op_name, op, overload_names):
# These attributes are accessible on the object through the properties
# defined below but are immutable
self._qualified_op_name = qualified_op_name
self.__name__ = op_name
self._op = op
self._overload_names = overload_names
self._dir: list[str] = []
self._dir = []
self._has_torchbind_op_overload = any(
_has_script_object_arg(schema) for schema in self._schemas.values()
)
@ -1158,7 +1136,11 @@ class OpOverloadPacket:
for overload_name in self._overload_names
}
def __getattr__(self, key: str) -> OpOverload:
def __getattr__(self, key) -> Any:
# It is not a valid op_name when __file__ is passed in
if key == "__file__":
return "torch.ops"
# ensure that query for dunder attributes that does not exist on
# opoverloadpacket but instead exists on the self._op object does not unnecessarily call
# `_get_operation_overload` (which is an expensive operation).
@ -1306,18 +1288,19 @@ class _OpNamespace(types.ModuleType):
operation will already exist).
"""
__file__ = "torch.ops"
def __init__(self, name: str) -> None:
def __init__(self, name):
super().__init__("torch.ops." + name)
self.name = name
self._dir: list[str] = []
self._dir = []
def __iter__(self):
return iter(self._dir)
def __getattr__(self, op_name: str) -> OpOverloadPacket:
if op_name in ("__origin__", "__self__"):
def __getattr__(self, op_name) -> Any:
# It is not a valid op_name when __file__ is passed in
if op_name == "__file__":
return "torch.ops"
elif op_name in ["__origin__", "__self__"]:
raise AttributeError(
f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
)
@ -1370,25 +1353,19 @@ def _refresh_packet(packet):
packet._overload_names = overload_names
class _HigherOrderNamespace(types.ModuleType):
__file__ = "torch.ops"
class _PyOpNamespace(_OpNamespace):
def __init__(self, name, ops):
super().__init__(name)
self._ops = ops
def __init__(self) -> None:
super().__init__("torch.ops.higher_order")
self._dir: list[str] = []
def __iter__(self):
return iter(self._dir)
def __getattr__(self, name) -> HigherOrderOperator:
# Following _OpNamespace.__getattr__, we cache the op on this object.
op = _higher_order_ops.get(name, None)
def __getattr__(self, name):
# Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
op = self._ops.get(name, None)
if op is None:
raise AttributeError(
f"'_HigherOrderNamespace' 'torch.ops.higher_order' object has no attribute '{name}'"
f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
)
setattr(self, name, op)
self._dir.append(name)
return op
@ -1398,10 +1375,16 @@ class _Ops(types.ModuleType):
def __init__(self):
super().__init__("torch.ops")
self.loaded_libraries = set()
self.higher_order = _HigherOrderNamespace()
self._higher_order_op_namespace = _PyOpNamespace(
"torch.ops.higher_order", _higher_order_ops
)
self._dir = []
def __getattr__(self, name) -> _OpNamespace:
def __getattr__(self, name):
# Check if the name is a HigherOrderOperator
if name == "higher_order":
return self._higher_order_op_namespace
# Here we are creating `torch.ops.my_namespace`
namespace = _OpNamespace(name)
setattr(self, name, namespace)

View File

@ -28,7 +28,7 @@ from .utils import (
)
QOP_TO_ARG_NAMES_TO_SKIP: dict[Callable[..., Any], list[str]] = {
QOP_TO_ARG_NAMES_TO_SKIP = {
torch._ops.ops.quantized.hardswish: ["inplace"],
torch._ops.ops.quantized.elu: ["inplace"],
torch._ops.ops.quantized.dropout: ["inplace"],

View File

@ -85,11 +85,10 @@ def _find_q_dq_node_for_user(
q_node = None
if (
isinstance(arg := dq_node.args[0], torch.fx.Node)
and arg.op == "call_function"
and arg.target in _QUANTIZE_OPS
dq_node.args[0].op == "call_function" # type: ignore[union-attr]
and dq_node.args[0].target in _QUANTIZE_OPS # type: ignore[union-attr]
):
q_node = arg
q_node = dq_node.args[0]
return (q_node, dq_node)

View File

@ -77,7 +77,7 @@ _legal_ops = dict.fromkeys(
# Dynamo is unable to trace global set[Callable].__contains__.
# See https://github.com/pytorch/pytorch/issues/145761. Since we only have
# a handful of ops so switch to list of callables.
_side_effectful_need_to_be_preserved_pre_dispatch: list[Callable[..., Any]] = [
_side_effectful_need_to_be_preserved_pre_dispatch: list[Callable] = [
torch._C._set_grad_enabled,
torch.amp._enter_autocast,
torch.amp._exit_autocast,
@ -85,7 +85,7 @@ _side_effectful_need_to_be_preserved_pre_dispatch: list[Callable[..., Any]] = [
# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
# or add logic to correctly mark all inplace ops as side effectful.
_side_effectful_functions: set[Callable[..., Any]] = {
_side_effectful_functions: set[Callable] = {
torch._assert,
torch._assert_async,
_ops.aten._assert_async.msg,
@ -98,8 +98,7 @@ _side_effectful_functions: set[Callable[..., Any]] = {
_ops.profiler._record_function_exit,
_ops.inductor.accumulate_grad_.default,
operator.setitem,
*_side_effectful_need_to_be_preserved_pre_dispatch,
}
} | set(_side_effectful_need_to_be_preserved_pre_dispatch)
if hasattr(_ops.inductor, "resize_storage_bytes_"):
_side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default)

View File

@ -3,7 +3,6 @@ import _operator
import itertools
from collections import defaultdict
from enum import Enum
from typing import Any, Callable
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@ -188,7 +187,7 @@ def _maybe_get_inplace_op(op):
return inplace_op
_VIEW_INVERSE_MAP: dict[Callable[..., Any], Callable[..., Any]] = {
_VIEW_INVERSE_MAP = {
torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
@ -253,7 +252,6 @@ def _get_view_inverse_node_usages(
assert isinstance(base.meta["fake_result"], FakeTensor)
assert isinstance(mutated_view, Node)
assert isinstance(mutated_view.meta["fake_result"], FakeTensor)
assert not isinstance(n.target, str)
# Check that this view_inverse op actually corresponds to taking doing the inverse
# of one of our existing self_alias nodes.
original_view = _VIEW_INVERSE_MAP[n.target]