Compare commits

1 commit

7141b78dba  Move decomp up to functional_tensor.py  2025-07-10 15:43:26 -07:00

6 changed files with 81 additions and 20 deletions

test.py  Normal file  (+8 lines)
View File

@@ -0,0 +1,8 @@
import torch
a = torch.zeros(3, 3)
b = torch.ones(3)
torch.compile(torch.ops.aten.diagonal_scatter, backend="aot_eager")(a, b)
print(f"yay, test passed I guess?")

View File

@@ -2745,6 +2745,12 @@ def pad_sequence(sequences, batch_first=False, padding_value=0.0):
    return out


@register_decomposition(aten.diagonal_scatter)
def diagonal_scatter(input: TensorLike, src: TensorLike):
    result = input.clone()
    result.diagonal().copy_(src)
    return result


@register_decomposition(aten.index_copy_)
def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
    return _index_copy(x, dim, index, tensor, inplace=True)
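
As a quick illustration (added here, not part of the diff), what this simplified decomposition computes can be checked in eager mode:

import torch

x = torch.zeros(3, 3)
src = torch.arange(3, dtype=torch.float32)

# The decomposition: clone the input, then overwrite its main diagonal in place.
result = x.clone()
result.diagonal().copy_(src)

# Matches the eager op for the default offset=0, dim1=0, dim2=1 case.
assert torch.equal(result, torch.diagonal_scatter(x, src))

Note that, unlike the reference implementation commented out in a later hunk, this version takes no offset/dim1/dim2 arguments and does not validate the shape of src.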

View File

@@ -46,6 +46,7 @@ aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
    # FunctionalTensorMode must be enabled here.
    # See Note [Accessing .grad_fn on FunctionalTensor]
    # NOTE: this seems to be where it is happening, but why is it not going in?
    with (
        enable_python_dispatcher(),
        FunctionalTensorMode(

View File

@@ -4413,24 +4413,24 @@ def diag(
    return torch.diagonal_copy(self, offset)


@register_decomposition(aten.diagonal_scatter)
@out_wrapper()
def diagonal_scatter(
    input: TensorLikeType,
    src: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    out = utils.clone_preserve_strides(input)
    diag = out.diagonal(offset, dim1, dim2)
    torch._check(
        diag.shape == src.shape,
        lambda: "expected src to have a size equal to the diagonal of the input."
        f"Got {src.shape} for a diagonal of shape {diag.shape}",
    )
    copy_to(diag, src)
    return out
# @register_decomposition(aten.diagonal_scatter)
# @out_wrapper()
# def diagonal_scatter(
#     input: TensorLikeType,
#     src: TensorLikeType,
#     offset: int = 0,
#     dim1: int = 0,
#     dim2: int = 1,
# ) -> TensorLikeType:
#     out = utils.clone_preserve_strides(input)
#     diag = out.diagonal(offset, dim1, dim2)
#     torch._check(
#         diag.shape == src.shape,
#         lambda: "expected src to have a size equal to the diagonal of the input."
#         f"Got {src.shape} for a diagonal of shape {diag.shape}",
#     )
#     copy_to(diag, src)
#     return out


@register_decomposition(aten.diagonal)
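
For illustration (not from this diff), the shape check in the reference above corresponds to the eager requirement that src match the selected diagonal:

import torch

x = torch.zeros(4, 4)

# diagonal(offset=1) of a 4x4 matrix has 3 elements, so src must have shape (3,).
src = torch.ones(3)
out = torch.diagonal_scatter(x, src, offset=1)
assert torch.equal(out.diagonal(offset=1), src)

# A mismatched src shape trips the torch._check above with the same kind of message:
# torch.diagonal_scatter(x, torch.ones(4), offset=1)  # raises RuntimeError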

View File

@@ -4,9 +4,10 @@ import warnings
import weakref
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, TYPE_CHECKING
import torch
import torch._ops
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
@@ -17,7 +18,15 @@ from torch.utils._python_dispatch import (
    return_and_correct_aliasing,
    TorchDispatchMode,
)
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager

if TYPE_CHECKING:
    import types
    from collections.abc import MutableMapping
    import sympy
    from torch._ops import OpOverload


not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
@@ -42,6 +51,19 @@ def _conversion_method_template(**extra_kwargs):
    return _


CURRENT_DECOMPOSITION_TABLE: Mapping[Any, Callable] = {}


@contextmanager
def decompose(
    decomposition_table: Optional[Mapping[Any, Callable]],
) -> Generator[Mapping[Any, Callable], None, None]:
    global CURRENT_DECOMPOSITION_TABLE
    old_decomposition_table = CURRENT_DECOMPOSITION_TABLE
    CURRENT_DECOMPOSITION_TABLE = decomposition_table or {}
    try:
        yield CURRENT_DECOMPOSITION_TABLE
    finally:
        CURRENT_DECOMPOSITION_TABLE = old_decomposition_table


class FunctionalTensor(torch.Tensor):
    """
@@ -411,6 +433,9 @@ class FunctionalTensorMode(TorchDispatchMode):
            # in normal torch.compile IR, we decompose functional composite ops
            return True

        # r = maybe_handle_decomp(self, func, args, kwargs)
        # if r is not NotImplemented:
        #     return r
        if (
            func not in FunctionalTensor.metadata_fns
            and _can_decompose(func)
@@ -562,6 +587,28 @@ class FunctionalTensorMode(TorchDispatchMode):
    def is_infra_mode(cls) -> bool:
        return True


aten = torch._ops.ops.aten


def maybe_handle_decomp(
    proxy_mode: FunctionalTensorMode,
    op: Any,
    args: tuple[object, ...],
    kwargs: dict[str, object],
) -> object:
    from torch._inductor.compiler_bisector import CompilerBisector

    # print(op)
    # if op == aten.native_batch_norm.out:
    #     breakpoint()
    if op in CURRENT_DECOMPOSITION_TABLE:
        if CompilerBisector.disable_subsystem(
            "aot_eager_decomp_partition", "decomposition", lambda: repr(op)
        ):
            return NotImplemented
        with proxy_mode:
            out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
            return out
    return NotImplemented


@contextlib.contextmanager
def disable_functional_mode():
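
For context, a sketch of how the call site commented out in __torch_dispatch__ above would consume this function's NotImplemented sentinel (illustrative, not part of the diff):

# Inside FunctionalTensorMode.__torch_dispatch__ (mirrors the commented-out lines):
r = maybe_handle_decomp(self, func, args, kwargs)
if r is not NotImplemented:
    # A decomposition from CURRENT_DECOMPOSITION_TABLE handled the op.
    return r
# Otherwise fall through to the normal functionalization path.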

View File

@@ -2417,7 +2417,6 @@ def maybe_handle_decomp(
            "aot_eager_decomp_partition", "decomposition", lambda: repr(op)
        ):
            return NotImplemented
        with proxy_mode:
            proxy_mode.decomp_layers += 1
            out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)