Rename register_fake_profile to unsafe_generate_fake_kernels (#151797)

Fixes https://docs.google.com/document/d/1BZsuUR1zJ-52Y7wP4yWX8beB4dwYbgdu5o1qKam_iWg/edit?disco=AAABiJdX1XU
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151797
Approved by: https://github.com/zou3519
This commit is contained in:
angelayi
2025-04-21 23:08:12 +00:00
committed by PyTorch MergeBot
parent efdcc981d0
commit 01f1cc44cb
4 changed files with 57 additions and 12 deletions

View File

@ -56,7 +56,9 @@ class TestDraftExport(TestCase):
inp = (torch.randn(3, 3), torch.randn(3, 3))
self.assertEqual(ep.module()(*inp), M()(*inp))
with torch._library.fake_profile.register_fake_profile(report.op_profiles):
with torch._library.fake_profile.unsafe_generate_fake_kernels(
report.op_profiles
):
ep.run_decompositions()
def test_missing_meta_kernel_impl(self):
@ -95,7 +97,9 @@ class TestDraftExport(TestCase):
self.assertEqual(len(report.op_profiles["mylib.foo.default"]), 1)
print(report.op_profiles)
with torch._library.fake_profile.register_fake_profile(report.op_profiles):
with torch._library.fake_profile.unsafe_generate_fake_kernels(
report.op_profiles
):
ep = ep.run_decompositions()
self.assertEqual(ep.module()(*inp), M()(*inp))
@ -129,7 +133,9 @@ class TestDraftExport(TestCase):
self.assertEqual(len(report.op_profiles), 1)
self.assertEqual(len(report.op_profiles["mylib.foo3.default"]), 2)
with torch._library.fake_profile.register_fake_profile(report.op_profiles):
with torch._library.fake_profile.unsafe_generate_fake_kernels(
report.op_profiles
):
ep.run_decompositions()
def test_missing_meta_kernel_custom_op_update_profile(self):
@ -159,7 +165,9 @@ class TestDraftExport(TestCase):
torch.ones(2, 3, 4),
)
with torch._library.fake_profile.register_fake_profile(report.op_profiles):
with torch._library.fake_profile.unsafe_generate_fake_kernels(
report.op_profiles
):
with FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()):
torch.ops.mylib.foo8(*inp)
with self.assertRaisesRegex(
@ -173,7 +181,7 @@ class TestDraftExport(TestCase):
self.assertEqual(len(report.op_profiles), 1)
self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1)
with torch._library.fake_profile.register_fake_profile(
with torch._library.fake_profile.unsafe_generate_fake_kernels(
report.op_profiles
), FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()):
torch.ops.mylib.foo8(*new_inp)
@ -560,7 +568,9 @@ class TestDraftExport(TestCase):
],
)
with torch._library.fake_profile.register_fake_profile(report.op_profiles):
with torch._library.fake_profile.unsafe_generate_fake_kernels(
report.op_profiles
):
ep.run_decompositions()
def test_override_incorrectly_aliasing_kernel(self):
@ -641,7 +651,9 @@ class TestDraftExport(TestCase):
report.failures[0].data["reason"],
"Dtypes torch.bfloat16 and torch.float32 are not equal!",
)
with torch._library.fake_profile.register_fake_profile(report.op_profiles):
with torch._library.fake_profile.unsafe_generate_fake_kernels(
report.op_profiles
):
ep.run_decompositions()
# https://github.com/pytorch/pytorch/issues/140625

View File

@ -4517,7 +4517,10 @@ class TestOpProfiles(TestCase):
):
torch.ops.mylib.foo(t1, t2)
with torch._library.fake_profile.register_fake_profile(op_profiles), fm:
with (
torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles),
fm,
):
torch.ops.mylib.foo(t1, t2)
with self.assertRaisesRegex(MissingOpProfile, "mylib::foo"):
@ -4560,7 +4563,7 @@ class TestOpProfiles(TestCase):
with fm:
self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.bfloat16)
with torch._library.fake_profile.register_fake_profile(op_profiles):
with torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles):
with fm:
self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.float32)
@ -4591,7 +4594,7 @@ class TestOpProfiles(TestCase):
"mylib.foo1.default": self.get_sample_op_profile()["mylib.foo.default"]
}
with torch._library.fake_profile.register_fake_profile(op_profiles):
with torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles):
with fm:
self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.float32)

View File

@ -87,7 +87,7 @@ def _generate_fake_kernel(op_name: str, op_profile: set[OpProfile]) -> Callable:
@contextlib.contextmanager
def register_fake_profile(op_profiles: dict[str, set[OpProfile]]) -> Generator:
def unsafe_generate_fake_kernels(op_profiles: dict[str, set[OpProfile]]) -> Generator:
"""
Registers a fake kernel based on the given operator profiles. This fake
kernel registration will override any existing fake kernel registrations.
@ -99,10 +99,40 @@ def register_fake_profile(op_profiles: dict[str, set[OpProfile]]) -> Generator:
an output with the same metadata as in the recorded profile. If a profile
doesn't exist then an exception will be thrown.
The fake kernel generation is considered unsafe because it relies on the
rigid, pre-defined operator profiles that do not account for potential
variations in output behavior. Specifically, the generated kernels assume a
fixed relationship between input and output ranks. However, in reality, it's
possible that data-dependent operations may produce outputs of different
ranks even when given inputs of the same rank. The generated fake kernels
are inflexible and unable to accommodate these nuances, making them
potentially unsafe.
Args:
op_profiles (dict[str, set[OpProfile]]): A dictionary mapping operator
name to a set of operator profiles from which we will generate fake
kernels.
Examples:
>>> # Example: Registering an op-profile from draft-export
>>> import torch
>>> from torch.export._draft_export import draft_export
>>>
>>> @torch.library.custom_op("mylib::foo", mutates_args=())
>>> def foo(x: Tensor, y: Tensor) -> Tensor:
>>> return x + y
>>>
>>> class M(torch.nn.Module):
>>> def forward(self, a, b):
>>> res = torch.ops.mylib.foo(a, b) # no fake impl
>>> return res
>>>
>>> ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4)))
>>>
>>> with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
>>> decomp = ep.run_decompositions()
"""
libs: list[torch.library.Library] = []

View File

@ -486,7 +486,7 @@ While tracing we found {len(report.op_profiles)} operator(s) which do not have a
If you intend to retrace the exported graph or run it with fake tensors, please run it under the
following context manager, which will register a fake kernel for those operators.
```
with torch._library.fake_profile.register_fake_profile(ep._report.op_profiles):
with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
# run with fake tensors
```
"""