Do not propagate real tensor in extern kernel (#151377)

Summary: See internal Diff for more details.

In ExternKernel, the FakeTensors used as example inputs do not have associated real tensors, because they are created only from the ir.Node's shape and stride. Real-tensor propagation therefore has nothing to propagate, so it is temporarily disabled while the extern kernel computes its example output.
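
The sketch below is a rough standalone illustration of the pattern this change applies inside `ExternKernel` (it is not the PR's code: the helper name `disable_real_tensor_propagation` and the example op are stand-ins for the `not_progapate_real_tensors` helper and the extern kernel call added in the diff below). Example inputs are built as FakeTensors from shape/stride metadata only, so the kernel call is wrapped with the mode's `propagate_real_tensors` flag switched off.

```python
import contextlib
from typing import Generator

import torch
from torch._subclasses.fake_tensor import FakeTensorMode


@contextlib.contextmanager
def disable_real_tensor_propagation(
    fake_mode: FakeTensorMode,
) -> Generator[None, None, None]:
    # Temporarily turn off real-tensor propagation on the given mode,
    # restoring the previous setting afterwards (illustrative helper,
    # mirroring the context manager added in this PR).
    saved = fake_mode.propagate_real_tensors
    fake_mode.propagate_real_tensors = False
    try:
        yield
    finally:
        fake_mode.propagate_real_tensors = saved


fake_mode = FakeTensorMode()
with fake_mode:
    # Example inputs built purely from shape/stride metadata, the way
    # ExternKernel materializes ir.Node inputs; they carry no real data.
    a = torch.empty_strided((10,), (1,), dtype=torch.float32)
    b = torch.empty_strided((10,), (1,), dtype=torch.float32)

    with disable_real_tensor_propagation(fake_mode):
        # The "kernel" runs only on fake tensors here; nothing tries to
        # look up real tensors that do not exist.
        example_output = torch.add(a, b)

print(example_output.shape)  # torch.Size([10])
```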

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r aoti_data_dependent_ex

buck2 run mode/dev-nosan  fbcode//caffe2/test/inductor:aot_inductor_arrayref_cpu -- -r data_dependent_extern_kernel_op
```

Differential Revision: D73002775

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151377
Approved by: https://github.com/angelayi
Author: Shangdi Yu (2025-04-18 17:28:10 +00:00), committed by PyTorch MergeBot
Parent: 181b3883e7
Commit: 931bd05560

5 changed files with 62 additions and 8 deletions


```diff
@@ -3583,6 +3583,41 @@ class AOTInductorTestsTemplate:
         self.assertTrue(same(optimized(*example_inputs), m(*example_inputs)))

+    def test_aoti_data_dependent_extern_kernel_op(self):
+        # Skip GPU because custom op only implemented for cpu
+        if self.device == GPU_TYPE:
+            raise unittest.SkipTest("skips for GPU")
+
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor a, Tensor b) -> Tensor",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+                assert a[0] != 0
+                return a + b
+
+            @torch.library.impl_abstract("mylib::foo", lib=lib)
+            def foo_fake_impl(a, b):
+                return a + b
+
+            class Model(torch.nn.Module):
+                def forward(self, a, b):
+                    res = torch.ops.mylib.foo(a, b)
+                    return res
+
+            example_inputs = (torch.ones(10), torch.randn(10))
+
+            from torch._functorch import config as functorch_config
+
+            # use this config to mimic FakeTensors resulting from
+            # draft export
+            with functorch_config.patch({"fake_tensor_propagate_real_tensors": True}):
+                self.check_model(Model(), example_inputs)
+
     def test_index_put_with_none_index(self):
         # index_put falls back in the deterministic mode
         with DeterministicGuard(True):
```


```diff
@@ -169,6 +169,9 @@ CPU_TEST_FAILURES = {
     "test_symbool_item": fail_minimal_arrayref_interface(is_skip=True),
     # TODO: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'
     "test_symfloat_item": fail_minimal_arrayref_interface(is_skip=True),
+    "test_aoti_data_dependent_extern_kernel_op": fail_minimal_arrayref_interface(
+        is_skip=True
+    ),
 }
```


```diff
@@ -1043,7 +1043,7 @@ class ChainedSource(Source):
         return current


-def detect_fake_mode(inputs: Any = None):
+def detect_fake_mode(inputs: Any = None, ignore_context=False):
     """
     Attempts to "detect" what the current fake mode is. If there is one ambiently
     available from TracingContext, we preferentially use that. Otherwise, we
@@ -1058,10 +1058,11 @@ def detect_fake_mode(inputs: Any = None):
     fake_modes = []

-    if context := TracingContext.try_get():
-        fake_mode = context.fake_mode
-        if fake_mode is not None:
-            fake_modes.append((fake_mode, "tracing context", 0))
+    if not ignore_context:
+        if context := TracingContext.try_get():
+            fake_mode = context.fake_mode
+            if fake_mode is not None:
+                fake_modes.append((fake_mode, "tracing context", 0))

     from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
```


```diff
@@ -34,7 +34,7 @@ import torch._library.utils as library_utils
 import torch._logging
 import torch.fx
 import torch.utils._pytree as pytree
-from torch._dynamo.utils import identity
+from torch._dynamo.utils import detect_fake_mode, identity
 from torch._export.serde.serialize import GraphModuleSerializer
 from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
 from torch._inductor import metrics
@@ -45,7 +45,7 @@ from torch._prims_common import (
     make_channels_last_strides_for,
     StrideType,
 )
-from torch._subclasses.fake_tensor import get_schema_info
+from torch._subclasses.fake_tensor import get_schema_info, not_progapate_real_tensors
 from torch.fx.experimental.symbolic_shapes import (
     _remove_effect_token_unbacked_bindings,
     compute_unbacked_bindings,
@@ -5228,7 +5228,10 @@ class ExternKernel(InputsKernel):
                 example_args.append(ir_node_to_tensor(x, guard_shape=True))

         new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
-        example_output = kernel(*new_args, **new_kwargs)
+
+        fake_mode = detect_fake_mode(example_args, ignore_context=True)
+        with not_progapate_real_tensors(fake_mode):
+            example_output = kernel(*new_args, **new_kwargs)

         unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None
         if shape_env := V.fake_mode.shape_env:
```


```diff
@@ -3113,3 +3113,15 @@ def inferred_fake_kernel_from_real_out(
     fake_flat_out = [_infer_fake_from_real_tensor(mode, op, t) for t in real_flat_out]

     return pytree.tree_unflatten(fake_flat_out, spec)
+
+
+@contextlib.contextmanager
+def not_progapate_real_tensors(
+    fake_mode: FakeTensorMode,
+) -> Generator[None, None, None]:
+    original_value = fake_mode.propagate_real_tensors
+    fake_mode.propagate_real_tensors = False
+    try:
+        yield
+    finally:
+        fake_mode.propagate_real_tensors = original_value
```