Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Improve torch.ops typing (#153558)"
This reverts commit c5cba39d469151895cd0ecf7673b98e5072b69c2. https://github.com/pytorch/pytorch/pull/153558 was reverted on behalf of https://github.com/yangw-dev because the diff could not be landed to fbcode: it is suspected of breaking an internal test ([D75007157](https://www.internalfb.com/diff/D75007157)), for instance: tests_gpu/lookup_gpu_index_test.py:232:8 Undefined attribute [16]: torch._ops._OpNamespace has no attribute simple_index_mm_batch ([comment](https://github.com/pytorch/pytorch/pull/153558#issuecomment-2892506789))
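
The cited failure is a static-checker error, not a runtime one: ops such as simple_index_mm_batch are registered dynamically and only become attributes of torch.ops.<namespace> through __getattr__, so a checker that trusts a narrowed __getattr__ signature may report them as undefined. A minimal, self-contained sketch of that dynamic-namespace pattern (illustrative only, not PyTorch's actual code; _FakeOpNamespace, _register, and simple_add are invented names):

# Illustrative sketch with invented names, not PyTorch's real implementation:
# ops created at runtime are reachable solely through __getattr__, which is why
# a checker relying on a precise __getattr__ signature can flag them.
from typing import Any, Callable


class _FakeOpNamespace:
    def __init__(self, name: str) -> None:
        self._name = name
        self._registry: dict[str, Callable[..., Any]] = {}

    def _register(self, op_name: str, fn: Callable[..., Any]) -> None:
        # Runtime registration: the op never appears in the class body,
        # so static analysis cannot see it.
        self._registry[op_name] = fn

    def __getattr__(self, op_name: str) -> Callable[..., Any]:
        try:
            fn = self._registry[op_name]
        except KeyError:
            raise AttributeError(f"'{self._name}' has no attribute '{op_name}'") from None
        setattr(self, op_name, fn)  # cache so later lookups bypass __getattr__
        return fn


ns = _FakeOpNamespace("my_ops")
ns._register("simple_add", lambda a, b: a + b)
print(ns.simple_add(2, 3))  # 5, resolved dynamically at runtime

The reverted diff follows.
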
@@ -1636,7 +1636,6 @@ class Generator:
 class _DispatchOperatorHandle:
     def schema(self) -> FunctionSchema: ...
     def debug(self) -> str: ...
-    def redispatch_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Stack: ...
 
 class _DispatchModule:
     def reset(self) -> None: ...
@@ -1833,7 +1832,7 @@ class _SetExcludeDispatchKeyGuard:
 # Defined in torch/csrc/utils/schema_info.h
 
 class _SchemaInfo:
-    def __init__(self, schema: FunctionSchema) -> None: ...
+    def __init__(self, schema: _int) -> None: ...
 
     @overload
     def is_mutable(self) -> _bool: ...
@@ -6,7 +6,7 @@ import typing
 import warnings
 from collections.abc import Sequence
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.export._trace
@@ -229,7 +229,7 @@ def get_dtype_as_int(tensor):
 # Those operators will be automatically populated to a instance method
 # of TS2FXGraphConverter with name convert_<namespace>_<opname>().
 # Please check __init__ for method population implementations.
-kind_to_standard_operators: dict[str, Callable[..., Any]] = {
+kind_to_standard_operators = {
     "prim::max": builtins.max,
     "prim::min": builtins.min,
     "prim::TupleIndex": operator.getitem,
@@ -10,7 +10,7 @@ from torch._ops import OpOverload
 aten = torch.ops.aten
 
 _NON_FUNCTIONAL_TO_FUNCTIONAL_SIDE_EFFECTFUL_FUNCS: dict[OpOverload, OpOverload] = {
-    aten.sym_constrain_range.default: aten._functional_sym_constrain_range.default,
+    aten.sym_constrain_range.default: aten._functional_sym_constrain_range,
     aten._assert_async.msg: aten._functional_assert_async.msg,
 }
 
@@ -10,7 +10,7 @@ import os
 import os.path
 from collections import defaultdict
 from dataclasses import dataclass, replace
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Callable, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch._inductor.inductor_prims
@@ -2049,9 +2049,7 @@ def get_default_op_list() -> OpTypes:
     default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
     recomputable_ops = OrderedSet(default_recomputable_ops)
 
-    random_ops = OrderedSet[Callable[..., Any]](
-        [aten.native_dropout, aten.rand_like, aten.randn_like]
-    )
+    random_ops = OrderedSet([aten.native_dropout, aten.rand_like, aten.randn_like])
     compute_intensive_ops = [
         aten.mm,
         aten.convolution,
@@ -24,11 +24,11 @@ class _EffectType(Enum):
 OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]
 
 
-SIDE_EFFECTS = WeakKeyDictionary[OpType, _EffectType](
-    [
-        (torch.ops.aten._print.default, _EffectType.ORDERED),
-        (call_torchbind, _EffectType.ORDERED),
-    ]
+SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary(
+    {
+        torch.ops.aten._print.default: _EffectType.ORDERED,
+        call_torchbind: _EffectType.ORDERED,
+    }
 )
 
 
@@ -107,7 +107,7 @@ decompositions = {**core_aten_decompositions(), **inductor_decompositions}
 
 # Remove unwanted decompositions included via the core ATen decompositions from
 # the Inductor decomp table.
-decomps_to_exclude: list[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]] = [
+decomps_to_exclude = [
     aten._unsafe_index,
     aten._unsafe_masked_index,
     aten._unsafe_masked_index_put_accumulate,
@@ -522,7 +522,7 @@ def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
     return torch.where(torch.isnan(other) | (other < self), self, other)
 
 
-@register_decomposition([aten.amax])
+@register_decomposition(aten.amax)
 def amax(
     self: torch.Tensor,
     dim: Optional[int] = None,
@@ -533,7 +533,7 @@ def amax(
     return NotImplemented
 
 
-@register_decomposition([aten.amin])
+@register_decomposition(aten.amin)
 def amin(
     self: torch.Tensor,
     dim: Optional[int] = None,
@@ -581,7 +581,7 @@ def get_like_layout(
     return memory_format
 
 
-@register_decomposition([aten.rand_like])
+@register_decomposition(aten.rand_like)
 def rand_like(
     self: torch.Tensor,
     *,
@@ -598,7 +598,7 @@ def rand_like(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition([aten.randn_like])
+@register_decomposition(aten.randn_like)
 def randn_like(
     self: torch.Tensor,
     *,
@@ -615,7 +615,7 @@ def randn_like(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition([aten.full_like])
+@register_decomposition(aten.full_like)
 def full_like(
     self: torch.Tensor,
     fill_value: Union[int, float],
@@ -637,7 +637,7 @@ def full_like(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition([aten.randint_like.default])
+@register_decomposition(aten.randint_like.default)
 def randint_like(
     self: torch.Tensor,
     high: int,
@@ -657,7 +657,7 @@ def randint_like(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition([aten.randint_like.low_dtype])
+@register_decomposition(aten.randint_like.low_dtype)
 def randint_like_low(
     self: torch.Tensor,
     low: int,
@@ -678,7 +678,7 @@ def randint_like_low(
     ).to(memory_format=get_like_layout(self, memory_format))
 
 
-@register_decomposition([aten.randint.default])
+@register_decomposition(aten.randint.default)
 def randint(
     high: int,
     size: list[Union[int, torch.SymInt]],
@@ -687,7 +687,7 @@ def randint(
     return aten.randint.low(0, high, size, **kwargs)
 
 
-@register_decomposition([quantized.linear_dynamic_fp16_unpacked_weight.default])
+@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
 def linear_dynamic_fp16_unpacked_weight(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -699,7 +699,7 @@ def linear_dynamic_fp16_unpacked_weight(
     )
 
 
-@register_decomposition([_quantized.wrapped_quantized_linear.default])
+@register_decomposition(_quantized.wrapped_quantized_linear.default)
 def wrapped_quantized_linear(
     input: torch.Tensor,
     input_scale: torch.Tensor,
@@ -726,7 +726,7 @@ def wrapped_quantized_linear(
     )
 
 
-@register_decomposition([torch.ops.quantized.embedding_bag_byte_unpack])
+@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
 def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
     def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
         x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
@@ -771,7 +771,7 @@ def grid_sampler_2d(
     return output
 
 
-@register_decomposition([aten._foreach_addcmul.Scalar])
+@register_decomposition(aten._foreach_addcmul.Scalar)
 def _foreach_addcmul_scalar(
     self: list[torch.Tensor],
     left_tensors: list[torch.Tensor],
@@ -783,7 +783,7 @@ def _foreach_addcmul_scalar(
     )
 
 
-@register_decomposition([aten._foreach_addcdiv.Scalar])
+@register_decomposition(aten._foreach_addcdiv.Scalar)
 def _foreach_addcdiv_scalar(
     self: list[torch.Tensor],
     left_tensors: list[torch.Tensor],
@@ -795,7 +795,7 @@ def _foreach_addcdiv_scalar(
     )
 
 
-@register_decomposition([aten._foreach_lerp.Scalar])
+@register_decomposition(aten._foreach_lerp.Scalar)
 def _foreach_lerp_scalar(
     start_tensors: list[torch.Tensor],
     end_tensors: list[torch.Tensor],
@@ -809,7 +809,7 @@ def _foreach_lerp_scalar(
     )
 
 
-@register_decomposition([aten._foreach_lerp.ScalarList])
+@register_decomposition(aten._foreach_lerp.ScalarList)
 def _foreach_lerp_scalarlist(
     start_tensors: list[torch.Tensor],
     end_tensors: list[torch.Tensor],
@@ -824,7 +824,7 @@ def _foreach_lerp_scalarlist(
 
 
 @aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
-@register_decomposition([aten.miopen_batch_norm])
+@register_decomposition(aten.miopen_batch_norm)
 def miopen_batch_norm(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -869,7 +869,7 @@ def select_decomp_table() -> dict[Any, Callable[..., Any]]:
     return fast_random_decomps()
 
 
-@register_decomposition([aten.masked_scatter])
+@register_decomposition(aten.masked_scatter)
 def masked_scatter(
     self: torch.Tensor,
     mask: torch.Tensor,
@@ -888,7 +888,7 @@ def masked_scatter(
     return NotImplemented
 
 
-@register_decomposition([quantized_decomposed.choose_qparams.tensor])
+@register_decomposition(quantized_decomposed.choose_qparams.tensor)
 def choose_qparams_tensor(
     input: torch.Tensor,
     quant_min: int,
@@ -904,7 +904,7 @@ def choose_qparams_tensor(
     return scale.to(torch.float64), zero_point.to(torch.int64)
 
 
-@register_decomposition([aten.put])
+@register_decomposition(aten.put)
 def put(
     self: torch.Tensor,
     index: torch.Tensor,
@@ -918,7 +918,7 @@ def put(
     return flattened.reshape(self.shape)
 
 
-@register_decomposition([aten.put_])
+@register_decomposition(aten.put_)
 def put_(
     self: torch.Tensor,
     index: torch.Tensor,
@@ -929,7 +929,7 @@ def put_(
     return self.copy_(out)
 
 
-@register_decomposition([aten._softmax_backward_data.default])
+@register_decomposition(aten._softmax_backward_data.default)
 @pw_cast_for_opmath
 def _softmax_backward_data(
     grad_output: torch.Tensor,
@@ -951,7 +951,7 @@ def _softmax_backward_data(
     return grad_input.contiguous()
 
 
-@register_decomposition([aten.index_reduce])
+@register_decomposition(aten.index_reduce)
 def index_reduce(
     self: torch.Tensor,
     dim: int,
@@ -1056,7 +1056,7 @@ def _max_pool_with_indices(
     return vals, indices
 
 
-@register_decomposition([aten.max_pool2d_with_indices])
+@register_decomposition(aten.max_pool2d_with_indices)
 def max_pool2d_with_indices(
     x: torch.Tensor,
     kernel_size: list[int],
@@ -1070,7 +1070,7 @@ def max_pool2d_with_indices(
     )
 
 
-@register_decomposition([aten.max_pool3d_with_indices])
+@register_decomposition(aten.max_pool3d_with_indices)
 def max_pool3d_with_indices(
     x: torch.Tensor,
     kernel_size: list[int],
@@ -1084,7 +1084,7 @@ def max_pool3d_with_indices(
     )
 
 
-@register_decomposition([aten.adaptive_max_pool2d])
+@register_decomposition(aten.adaptive_max_pool2d)
 def adaptive_max_pool2d(
     x: torch.Tensor, output_size: list[int]
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -1102,7 +1102,7 @@ def adaptive_max_pool2d(
     return NotImplemented
 
 
-@register_decomposition([aten.searchsorted.Scalar])
+@register_decomposition(aten.searchsorted.Scalar)
 def searchsorted_scalar(
     sorted_sequence: torch.Tensor,
     self: torch.types.Number,
@@ -1122,7 +1122,7 @@ def searchsorted_scalar(
     )[0]
 
 
-@register_decomposition([aten.rrelu_with_noise_functional])
+@register_decomposition(aten.rrelu_with_noise_functional)
 def rrelu_with_noise_functional(
     self: torch.Tensor,
     noise: torch.Tensor,
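
The recurring change in the hunks above swaps @register_decomposition([op]) for @register_decomposition(op); registries like this commonly accept a single op or a sequence of ops and normalize internally. A small generic sketch of that normalization pattern (hypothetical register decorator and string keys, not Inductor's actual code):

# Hypothetical registry decorator, sketching how a single key or a list of keys
# can be accepted interchangeably; this is not Inductor's real implementation.
from typing import Any, Callable, Union

REGISTRY: dict[str, Callable[..., Any]] = {}


def register(ops: Union[str, list[str]]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    op_list = ops if isinstance(ops, list) else [ops]

    def wrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
        for op in op_list:
            REGISTRY[op] = fn  # the same decomposition is registered under each key
        return fn

    return wrapper


@register("aten.amax")  # bare form
def amax_decomp(x):
    return max(x)


@register(["aten.put", "aten.put_"])  # list form maps both names to one function
def put_decomp(x):
    return x


assert REGISTRY["aten.amax"] is amax_decomp
assert REGISTRY["aten.put"] is REGISTRY["aten.put_"] is put_decomp
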
@@ -2,7 +2,7 @@
 import functools
 import operator
 from functools import reduce
-from typing import Any, Callable
+from typing import Any
 
 import torch
 from torch._dynamo.utils import counters
@@ -1345,9 +1345,7 @@ if torch._C._has_mkldnn:
                 or V.aot_compilation
             ):
                 packed_linear_inputs += (bias, "none", [], "")
-                packed_linear_op: Callable[..., Any] = (
-                    mkldnn._linear_pointwise.default
-                )
+                packed_linear_op = mkldnn._linear_pointwise.default
             else:
                 packed_linear_inputs += (transpose_weight_node, bias, batch_size)
                 packed_linear_op = torch.ops.mkl._mkl_linear
@@ -3,12 +3,10 @@ import itertools
 import logging
 import operator
 from collections import defaultdict
-from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Union
+from typing import Any, Callable, Union
 
 import torch
 import torch.fx.node
 from torch._C._dynamo.guards import compute_overlapping_tensors
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger
@@ -178,13 +176,7 @@ _ALWAYS_MUTATING_SCATTER_OPS = OrderedSet(
 
 def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
     _, _, view_ops = node.args
-    view_ops = cast(Sequence[torch.fx.node.Argument], view_ops)
-    return any(
-        view.target in _ALWAYS_MUTATING_SCATTER_OPS
-        for view in view_ops
-        if isinstance(view, torch.fx.Node)
-        and isinstance(view.target, torch._ops.OpOverload)
-    )
+    return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops)  # type: ignore[union-attr]
 
 
 def should_reinplace_scatter(node: torch.fx.Node) -> bool:
@@ -275,7 +267,6 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
         assert len(node.args) >= 2
         inp, src = node.args[:2]
 
-        assert isinstance(node.target, torch._ops.OpOverload)
         scatter_view_op = ViewOp(
             _SCATTER_OP_TO_VIEW[node.target],
             args=node.args[2:],
@@ -340,7 +331,7 @@ def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
         handle_view_scatter(node)
 
 
-inplaceable_ops: dict[Callable[..., Any], InplaceableOp] = {
+inplaceable_ops = {
     aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
     aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
     _generalized_scatter: InplaceableOp(
@@ -352,7 +343,7 @@ inplaceable_ops: dict[Callable[..., Any], InplaceableOp] = {
 
 try:
     c10d_functional = torch.ops._c10d_functional
-    inplaceable_collective_ops: dict[Callable[..., Any], InplaceableOp] = {
+    inplaceable_collective_ops = {
         c10d_functional.all_reduce.default: InplaceableOp(
             c10d_functional.all_reduce_.default, 0
        ),

torch/_ops.py (103 changed lines)

@@ -6,17 +6,7 @@ import importlib
 import inspect
 import sys
 import types
-from functools import cached_property
-from typing import (
-    Any,
-    Callable,
-    ClassVar,
-    final,
-    Optional,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, final, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import Concatenate, ParamSpec
 
 import torch
@@ -745,14 +735,7 @@ def get_cached_ops():
 # Each OpOverload object contains pointer to a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
 # You can obtain an OpOverload object through attribute query on OpOverloadPacket.
 class OpOverload(OperatorBase):
-    def __init__(
-        self,
-        overloadpacket: "OpOverloadPacket",
-        op: Callable[_P, _T],
-        op_dk: Callable[Concatenate[DispatchKey, _P], _T],
-        schema: torch._C.FunctionSchema,
-        tags: list[Any],
-    ) -> None:
+    def __init__(self, overloadpacket, op, op_dk, schema, tags):
         super().__init__()
         self._op = op
         self._op_dk = op_dk
@@ -772,6 +755,9 @@ class OpOverload(OperatorBase):
         op.__module__ = overloadpacket.__module__
         self.__qualname__ = self._name
         self.__annotations__ = {}
+        # Only compute the OperatorHandle when we need it. Not all OpOverloads have
+        # OperatorHandles (the TorchScript ones don't...)
+        self._lazy_handle = None
 
         # If the OpOverload was constructed from a Library.def in Python.
         self._defined_in_python = self.__qualname__ in torch.library._defs
@@ -797,11 +783,13 @@ class OpOverload(OperatorBase):
     def _opname(self):
         return self._schema.name.split("::")[1]
 
-    @cached_property
-    def _handle(self) -> torch._C._DispatchOperatorHandle:
-        return torch._C._dispatch_find_schema_or_throw(
-            self._schema.name, self._schema.overload_name
-        )
+    @property
+    def _handle(self):
+        if self._lazy_handle is None:
+            self._lazy_handle = torch._C._dispatch_find_schema_or_throw(
+                self._schema.name, self._schema.overload_name
+            )
+        return self._lazy_handle
 
     # it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
     def __deepcopy__(self, memo=None):
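
The hunk above swaps functools.cached_property for an explicit _lazy_handle attribute behind @property; both defer the dispatcher lookup until first access and then reuse the result. A standalone sketch of the two idioms (the expensive lookup is replaced by a counter so the example runs on its own and makes no claim about the real torch._C call):

# Self-contained comparison of the two lazy-caching idioms seen in the hunk above.
from functools import cached_property

CALLS = {"cached_property": 0, "lazy_property": 0}


class WithCachedProperty:
    @cached_property
    def handle(self) -> int:
        CALLS["cached_property"] += 1  # runs only on the first access
        return 42  # stands in for the expensive schema lookup


class WithLazyProperty:
    def __init__(self) -> None:
        self._lazy_handle = None

    @property
    def handle(self) -> int:
        if self._lazy_handle is None:
            CALLS["lazy_property"] += 1  # runs only on the first access
            self._lazy_handle = 42
        return self._lazy_handle


a, b = WithCachedProperty(), WithLazyProperty()
assert (a.handle, a.handle) == (42, 42) and CALLS["cached_property"] == 1
assert (b.handle, b.handle) == (42, 42) and CALLS["lazy_property"] == 1
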
@@ -819,9 +807,7 @@ class OpOverload(OperatorBase):
 
     # Use positional-only argument to avoid naming collision with aten ops arguments
     # that are named "self". This way, all the aten ops can be called by kwargs.
-    def redispatch(
-        self, /, keyset: torch._C.DispatchKeySet, *args, **kwargs
-    ) -> "torch._C.Stack":
+    def redispatch(self, /, keyset, *args, **kwargs):
         return self._handle.redispatch_boxed(keyset, *args, **kwargs)
 
     def __hash__(self):
@@ -1112,22 +1098,14 @@ def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
 # OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
 # You can obtain an OpOverload object through attribute query.
 class OpOverloadPacket:
-    __file__: ClassVar[str] = "torch.ops"
-
-    def __init__(
-        self,
-        qualified_op_name: str,
-        op_name: str,
-        op: Callable[..., Any],
-        overload_names: list[str],
-    ) -> None:
+    def __init__(self, qualified_op_name, op_name, op, overload_names):
         # These attributes are accessible on the object through the properties
         # defined below but are immutable
         self._qualified_op_name = qualified_op_name
         self.__name__ = op_name
         self._op = op
         self._overload_names = overload_names
-        self._dir: list[str] = []
+        self._dir = []
         self._has_torchbind_op_overload = any(
             _has_script_object_arg(schema) for schema in self._schemas.values()
         )
@@ -1158,7 +1136,11 @@ class OpOverloadPacket:
             for overload_name in self._overload_names
         }
 
-    def __getattr__(self, key: str) -> OpOverload:
+    def __getattr__(self, key) -> Any:
+        # It is not a valid op_name when __file__ is passed in
+        if key == "__file__":
+            return "torch.ops"
+
         # ensure that query for dunder attributes that does not exist on
         # opoverloadpacket but instead exists on the self._op object does not unnecessarily call
         # `_get_operation_overload` (which is an expensive operation).
@@ -1306,18 +1288,19 @@ class _OpNamespace(types.ModuleType):
         operation will already exist).
     """
 
-    __file__ = "torch.ops"
-
-    def __init__(self, name: str) -> None:
+    def __init__(self, name):
         super().__init__("torch.ops." + name)
         self.name = name
-        self._dir: list[str] = []
+        self._dir = []
 
     def __iter__(self):
         return iter(self._dir)
 
-    def __getattr__(self, op_name: str) -> OpOverloadPacket:
-        if op_name in ("__origin__", "__self__"):
+    def __getattr__(self, op_name) -> Any:
+        # It is not a valid op_name when __file__ is passed in
+        if op_name == "__file__":
+            return "torch.ops"
+        elif op_name in ["__origin__", "__self__"]:
             raise AttributeError(
                 f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'"
             )
@@ -1370,25 +1353,19 @@ def _refresh_packet(packet):
     packet._overload_names = overload_names
 
 
-class _HigherOrderNamespace(types.ModuleType):
-    __file__ = "torch.ops"
-
-    def __init__(self) -> None:
-        super().__init__("torch.ops.higher_order")
-        self._dir: list[str] = []
-
-    def __iter__(self):
-        return iter(self._dir)
-
-    def __getattr__(self, name) -> HigherOrderOperator:
-        # Following _OpNamespace.__getattr__, we cache the op on this object.
-        op = _higher_order_ops.get(name, None)
+class _PyOpNamespace(_OpNamespace):
+    def __init__(self, name, ops):
+        super().__init__(name)
+        self._ops = ops
+
+    def __getattr__(self, name):
+        # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object.
+        op = self._ops.get(name, None)
         if op is None:
             raise AttributeError(
-                f"'_HigherOrderNamespace' 'torch.ops.higher_order' object has no attribute '{name}'"
+                f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'"
             )
         setattr(self, name, op)
         self._dir.append(name)
         return op
 
 
@@ -1398,10 +1375,16 @@ class _Ops(types.ModuleType):
     def __init__(self):
         super().__init__("torch.ops")
         self.loaded_libraries = set()
-        self.higher_order = _HigherOrderNamespace()
+        self._higher_order_op_namespace = _PyOpNamespace(
+            "torch.ops.higher_order", _higher_order_ops
+        )
         self._dir = []
 
-    def __getattr__(self, name) -> _OpNamespace:
+    def __getattr__(self, name):
+        # Check if the name is a HigherOrderOperator
+        if name == "higher_order":
+            return self._higher_order_op_namespace
+
         # Here we are creating `torch.ops.my_namespace`
         namespace = _OpNamespace(name)
         setattr(self, name, namespace)
@@ -28,7 +28,7 @@ from .utils import (
 )
 
 
-QOP_TO_ARG_NAMES_TO_SKIP: dict[Callable[..., Any], list[str]] = {
+QOP_TO_ARG_NAMES_TO_SKIP = {
     torch._ops.ops.quantized.hardswish: ["inplace"],
     torch._ops.ops.quantized.elu: ["inplace"],
     torch._ops.ops.quantized.dropout: ["inplace"],
@@ -85,11 +85,10 @@ def _find_q_dq_node_for_user(
 
     q_node = None
     if (
-        isinstance(arg := dq_node.args[0], torch.fx.Node)
-        and arg.op == "call_function"
-        and arg.target in _QUANTIZE_OPS
+        dq_node.args[0].op == "call_function"  # type: ignore[union-attr]
+        and dq_node.args[0].target in _QUANTIZE_OPS  # type: ignore[union-attr]
     ):
-        q_node = arg
+        q_node = dq_node.args[0]
    return (q_node, dq_node)
 
 
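
The removed lines above use isinstance(arg := dq_node.args[0], torch.fx.Node) so a type checker can narrow the argument once and drop the type: ignore[union-attr] comments; the restored lines access the union-typed value directly and silence the checker instead. A self-contained sketch of the narrowing idiom (the Node-like class and union value are invented for illustration):

# Invented types for illustration; the point is the assignment-expression
# isinstance check, which narrows `arg` so its attributes can be used
# without "union-attr" ignores.
from typing import Union


class Node:
    def __init__(self, op: str, target: str) -> None:
        self.op = op
        self.target = target


def first_arg(args: list[Union[Node, int]]) -> Union[Node, int]:
    return args[0]


def is_quantize_call(args: list[Union[Node, int]]) -> bool:
    if (
        isinstance(arg := first_arg(args), Node)  # narrows arg to Node below
        and arg.op == "call_function"
        and arg.target == "quantize_per_tensor"
    ):
        return True
    return False


print(is_quantize_call([Node("call_function", "quantize_per_tensor")]))  # True
print(is_quantize_call([3]))  # False, the isinstance guard short-circuits
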
@@ -77,7 +77,7 @@ _legal_ops = dict.fromkeys(
 # Dynamo is unable to trace global set[Callable].__contains__.
 # See https://github.com/pytorch/pytorch/issues/145761. Since we only have
 # a handful of ops so switch to list of callables.
-_side_effectful_need_to_be_preserved_pre_dispatch: list[Callable[..., Any]] = [
+_side_effectful_need_to_be_preserved_pre_dispatch: list[Callable] = [
     torch._C._set_grad_enabled,
     torch.amp._enter_autocast,
     torch.amp._exit_autocast,
@@ -85,7 +85,7 @@ _side_effectful_need_to_be_preserved_pre_dispatch: list[Callable[..., Any]] = [
 
 # TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs,
 # or add logic to correctly mark all inplace ops as side effectful.
-_side_effectful_functions: set[Callable[..., Any]] = {
+_side_effectful_functions: set[Callable] = {
     torch._assert,
     torch._assert_async,
     _ops.aten._assert_async.msg,
@@ -98,8 +98,7 @@ _side_effectful_functions: set[Callable[..., Any]] = {
     _ops.profiler._record_function_exit,
     _ops.inductor.accumulate_grad_.default,
     operator.setitem,
-    *_side_effectful_need_to_be_preserved_pre_dispatch,
-}
+} | set(_side_effectful_need_to_be_preserved_pre_dispatch)
 
 if hasattr(_ops.inductor, "resize_storage_bytes_"):
     _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default)
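
The last hunk above replaces unpacking the pre-dispatch list inside the set literal with a union applied after the literal; both spellings build the same set of callables. A tiny standalone check with placeholder functions instead of the real torch._ops entries:

# Placeholder callables; the two constructions below are equivalent ways to
# fold a list into a set literal.
def f(): ...
def g(): ...
def h(): ...


preserved = [g, h]  # stands in for _side_effectful_need_to_be_preserved_pre_dispatch

unpacked = {f, *preserved}        # the removed spelling
unioned = {f} | set(preserved)    # the restored spelling
assert unpacked == unioned

# Membership tests used during dead-code elimination behave the same either way.
assert g in unioned and f in unioned
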
@@ -3,7 +3,6 @@ import _operator
 import itertools
 from collections import defaultdict
 from enum import Enum
-from typing import Any, Callable
 
 import torch
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -188,7 +187,7 @@ def _maybe_get_inplace_op(op):
     return inplace_op
 
 
-_VIEW_INVERSE_MAP: dict[Callable[..., Any], Callable[..., Any]] = {
+_VIEW_INVERSE_MAP = {
     torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
     torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
     torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
@@ -253,7 +252,6 @@ def _get_view_inverse_node_usages(
         assert isinstance(base.meta["fake_result"], FakeTensor)
         assert isinstance(mutated_view, Node)
         assert isinstance(mutated_view.meta["fake_result"], FakeTensor)
-        assert not isinstance(n.target, str)
         # Check that this view_inverse op actually corresponds to taking doing the inverse
         # of one of our existing self_alias nodes.
         original_view = _VIEW_INVERSE_MAP[n.target]