Do not propagate real tensor in extern kernel (#151377)
Summary: See internal Diff for more details. In ExternKernel, the FakeTensors do not have associated real tensors, because they are just created from ir.Node's shape and stride.

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r aoti_data_dependent_ex
buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:aot_inductor_arrayref_cpu -- -r data_dependent_extern_kernel_op
```

Differential Revision: D73002775

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151377
Approved by: https://github.com/angelayi
Committed by: PyTorch MergeBot
Parent: 181b3883e7
Commit: 931bd05560
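To make the failure mode in the summary concrete, here is a minimal sketch (not part of the diff): ExternKernel materializes its example inputs from an ir.Node's shape and stride alone, so the resulting FakeTensors carry no real data for the mode to propagate. The `empty_strided` call below stands in for what `ir_node_to_tensor` does; the `real_tensor` inspection assumes the attribute exposed by FakeTensor builds that support `propagate_real_tensors`.

```python
# Minimal sketch (assumptions noted above): a FakeTensor built purely from
# shape/stride metadata has no real tensor attached, so real-tensor
# propagation has nothing to run the real kernel on.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Stand-in for ir_node_to_tensor: only shape and stride are known.
    fake = torch.empty_strided((10,), (1,), device="cpu")

# Expected: None -- there is no associated real tensor to propagate.
print(getattr(fake, "real_tensor", None))
```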
```diff
@@ -3583,6 +3583,41 @@ class AOTInductorTestsTemplate:
         self.assertTrue(same(optimized(*example_inputs), m(*example_inputs)))
 
+    def test_aoti_data_dependent_extern_kernel_op(self):
+        # Skip GPU because custom op only implemented for cpu
+        if self.device == GPU_TYPE:
+            raise unittest.SkipTest("skips for GPU")
+
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor a, Tensor b) -> Tensor",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+                assert a[0] != 0
+                return a + b
+
+            @torch.library.impl_abstract("mylib::foo", lib=lib)
+            def foo_fake_impl(a, b):
+                return a + b
+
+            class Model(torch.nn.Module):
+                def forward(self, a, b):
+                    res = torch.ops.mylib.foo(a, b)
+                    return res
+
+            example_inputs = (torch.ones(10), torch.randn(10))
+            from torch._functorch import config as functorch_config
+
+            # use this config to mimic FakeTensors resulting from
+            # draft export
+            with functorch_config.patch({"fake_tensor_propagate_real_tensors": True}):
+                self.check_model(Model(), example_inputs)
+
     def test_index_put_with_none_index(self):
         # index_put falls back in the deterministic mode
         with DeterministicGuard(True):
```
```diff
@@ -169,6 +169,9 @@ CPU_TEST_FAILURES = {
     "test_symbool_item": fail_minimal_arrayref_interface(is_skip=True),
     # TODO: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'
     "test_symfloat_item": fail_minimal_arrayref_interface(is_skip=True),
+    "test_aoti_data_dependent_extern_kernel_op": fail_minimal_arrayref_interface(
+        is_skip=True
+    ),
 }
```
```diff
@@ -1043,7 +1043,7 @@ class ChainedSource(Source):
         return current
 
 
-def detect_fake_mode(inputs: Any = None):
+def detect_fake_mode(inputs: Any = None, ignore_context=False):
     """
     Attempts to "detect" what the current fake mode is. If there is one ambiently
     available from TracingContext, we preferentially use that. Otherwise, we
```
```diff
@@ -1058,10 +1058,11 @@ def detect_fake_mode(inputs: Any = None):
 
     fake_modes = []
 
-    if context := TracingContext.try_get():
-        fake_mode = context.fake_mode
-        if fake_mode is not None:
-            fake_modes.append((fake_mode, "tracing context", 0))
+    if not ignore_context:
+        if context := TracingContext.try_get():
+            fake_mode = context.fake_mode
+            if fake_mode is not None:
+                fake_modes.append((fake_mode, "tracing context", 0))
 
     from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
```
```diff
@@ -34,7 +34,7 @@ import torch._library.utils as library_utils
 import torch._logging
 import torch.fx
 import torch.utils._pytree as pytree
-from torch._dynamo.utils import identity
+from torch._dynamo.utils import detect_fake_mode, identity
 from torch._export.serde.serialize import GraphModuleSerializer
 from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
 from torch._inductor import metrics
```
```diff
@@ -45,7 +45,7 @@ from torch._prims_common import (
     make_channels_last_strides_for,
     StrideType,
 )
-from torch._subclasses.fake_tensor import get_schema_info
+from torch._subclasses.fake_tensor import get_schema_info, not_progapate_real_tensors
 from torch.fx.experimental.symbolic_shapes import (
     _remove_effect_token_unbacked_bindings,
     compute_unbacked_bindings,
```
```diff
@@ -5228,7 +5228,10 @@ class ExternKernel(InputsKernel):
             example_args.append(ir_node_to_tensor(x, guard_shape=True))
 
         new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
-        example_output = kernel(*new_args, **new_kwargs)
+
+        fake_mode = detect_fake_mode(example_args, ignore_context=True)
+        with not_progapate_real_tensors(fake_mode):
+            example_output = kernel(*new_args, **new_kwargs)
 
         unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None
         if shape_env := V.fake_mode.shape_env:
```
```diff
@@ -3113,3 +3113,15 @@ def inferred_fake_kernel_from_real_out(
 
     fake_flat_out = [_infer_fake_from_real_tensor(mode, op, t) for t in real_flat_out]
     return pytree.tree_unflatten(fake_flat_out, spec)
+
+
+@contextlib.contextmanager
+def not_progapate_real_tensors(
+    fake_mode: FakeTensorMode,
+) -> Generator[None, None, None]:
+    original_value = fake_mode.propagate_real_tensors
+    fake_mode.propagate_real_tensors = False
+    try:
+        yield
+    finally:
+        fake_mode.propagate_real_tensors = original_value
```