Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Rename register_fake_profile to unsafe_generate_fake_kernels (#151797)
Fixes https://docs.google.com/document/d/1BZsuUR1zJ-52Y7wP4yWX8beB4dwYbgdu5o1qKam_iWg/edit?disco=AAABiJdX1XU

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151797
Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: efdcc981d0
Commit: 01f1cc44cb
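For context, a minimal sketch of what the rename means at a call site. It follows the example this patch adds to the `unsafe_generate_fake_kernels` docstring (see the hunk further down), so the `draft_export` import, the `ep._report.op_profiles` attribute, and the context-manager name all come from this diff; the module, op, and input shapes are only illustrative.

```python
import torch
from torch import Tensor
from torch.export._draft_export import draft_export


@torch.library.custom_op("mylib::foo", mutates_args=())
def foo(x: Tensor, y: Tensor) -> Tensor:
    return x + y


class M(torch.nn.Module):
    def forward(self, a, b):
        # mylib::foo has no fake kernel, so draft export records an op profile for it
        return torch.ops.mylib.foo(a, b)


ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4)))

# Before this PR the context manager was named register_fake_profile; the
# behavior is unchanged, only the name now flags that the generated kernels
# are best-effort.
with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
    ep.run_decompositions()
```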
@@ -56,7 +56,9 @@ class TestDraftExport(TestCase):
         inp = (torch.randn(3, 3), torch.randn(3, 3))
         self.assertEqual(ep.module()(*inp), M()(*inp))
 
-        with torch._library.fake_profile.register_fake_profile(report.op_profiles):
+        with torch._library.fake_profile.unsafe_generate_fake_kernels(
+            report.op_profiles
+        ):
             ep.run_decompositions()
 
     def test_missing_meta_kernel_impl(self):
@@ -95,7 +97,9 @@ class TestDraftExport(TestCase):
         self.assertEqual(len(report.op_profiles["mylib.foo.default"]), 1)
         print(report.op_profiles)
 
-        with torch._library.fake_profile.register_fake_profile(report.op_profiles):
+        with torch._library.fake_profile.unsafe_generate_fake_kernels(
+            report.op_profiles
+        ):
             ep = ep.run_decompositions()
         self.assertEqual(ep.module()(*inp), M()(*inp))
 
@@ -129,7 +133,9 @@ class TestDraftExport(TestCase):
         self.assertEqual(len(report.op_profiles), 1)
         self.assertEqual(len(report.op_profiles["mylib.foo3.default"]), 2)
 
-        with torch._library.fake_profile.register_fake_profile(report.op_profiles):
+        with torch._library.fake_profile.unsafe_generate_fake_kernels(
+            report.op_profiles
+        ):
             ep.run_decompositions()
 
     def test_missing_meta_kernel_custom_op_update_profile(self):
@@ -159,7 +165,9 @@ class TestDraftExport(TestCase):
             torch.ones(2, 3, 4),
         )
 
-        with torch._library.fake_profile.register_fake_profile(report.op_profiles):
+        with torch._library.fake_profile.unsafe_generate_fake_kernels(
+            report.op_profiles
+        ):
             with FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()):
                 torch.ops.mylib.foo8(*inp)
         with self.assertRaisesRegex(
@@ -173,7 +181,7 @@ class TestDraftExport(TestCase):
         self.assertEqual(len(report.op_profiles), 1)
         self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1)
 
-        with torch._library.fake_profile.register_fake_profile(
+        with torch._library.fake_profile.unsafe_generate_fake_kernels(
             report.op_profiles
         ), FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()):
             torch.ops.mylib.foo8(*new_inp)
@@ -560,7 +568,9 @@ class TestDraftExport(TestCase):
             ],
         )
 
-        with torch._library.fake_profile.register_fake_profile(report.op_profiles):
+        with torch._library.fake_profile.unsafe_generate_fake_kernels(
+            report.op_profiles
+        ):
             ep.run_decompositions()
 
     def test_override_incorrectly_aliasing_kernel(self):
@@ -641,7 +651,9 @@ class TestDraftExport(TestCase):
             report.failures[0].data["reason"],
             "Dtypes torch.bfloat16 and torch.float32 are not equal!",
         )
-        with torch._library.fake_profile.register_fake_profile(report.op_profiles):
+        with torch._library.fake_profile.unsafe_generate_fake_kernels(
+            report.op_profiles
+        ):
             ep.run_decompositions()
 
         # https://github.com/pytorch/pytorch/issues/140625
@@ -4517,7 +4517,10 @@ class TestOpProfiles(TestCase):
         ):
             torch.ops.mylib.foo(t1, t2)
 
-        with torch._library.fake_profile.register_fake_profile(op_profiles), fm:
+        with (
+            torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles),
+            fm,
+        ):
             torch.ops.mylib.foo(t1, t2)
 
         with self.assertRaisesRegex(MissingOpProfile, "mylib::foo"):
@@ -4560,7 +4563,7 @@ class TestOpProfiles(TestCase):
         with fm:
             self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.bfloat16)
 
-        with torch._library.fake_profile.register_fake_profile(op_profiles):
+        with torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles):
             with fm:
                 self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.float32)
 
@@ -4591,7 +4594,7 @@ class TestOpProfiles(TestCase):
             "mylib.foo1.default": self.get_sample_op_profile()["mylib.foo.default"]
         }
 
-        with torch._library.fake_profile.register_fake_profile(op_profiles):
+        with torch._library.fake_profile.unsafe_generate_fake_kernels(op_profiles):
             with fm:
                 self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.float32)
 
@@ -87,7 +87,7 @@ def _generate_fake_kernel(op_name: str, op_profile: set[OpProfile]) -> Callable:
 
 
 @contextlib.contextmanager
-def register_fake_profile(op_profiles: dict[str, set[OpProfile]]) -> Generator:
+def unsafe_generate_fake_kernels(op_profiles: dict[str, set[OpProfile]]) -> Generator:
     """
     Registers a fake kernel based on the given operator profiles. This fake
     kernel registration will override any existing fake kernel registrations.
@@ -99,10 +99,40 @@ def register_fake_profile(op_profiles: dict[str, set[OpProfile]]) -> Generator:
     an output with the same metadata as in the recorded profile. If a profile
     doesn't exist then an exception will be thrown.
 
+    The fake kernel generation is considered unsafe because it relies on the
+    rigid, pre-defined operator profiles that do not account for potential
+    variations in output behavior. Specifically, the generated kernels assume a
+    fixed relationship between input and output ranks. However, in reality, it's
+    possible that data-dependent operations may produce outputs of different
+    ranks even when given inputs of the same rank. The generated fake kernels
+    are inflexible and unable to accommodate these nuances, making them
+    potentially unsafe.
+
+    Args:
+        op_profiles (dict[str, set[OpProfile]]): A dictionary mapping operator
+            name to a set of operator profiles from which we will generate fake
+            kernels.
+
+    Examples:
+
+        >>> # Example: Registering an op-profile from draft-export
+        >>> import torch
+        >>> from torch.export._draft_export import draft_export
+        >>>
+        >>> @torch.library.custom_op("mylib::foo", mutates_args=())
+        >>> def foo(x: Tensor, y: Tensor) -> Tensor:
+        >>>     return x + y
+        >>>
+        >>> class M(torch.nn.Module):
+        >>>     def forward(self, a, b):
+        >>>         res = torch.ops.mylib.foo(a, b)  # no fake impl
+        >>>         return res
+        >>>
+        >>> ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4)))
+        >>>
+        >>> with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
+        >>>     decomp = ep.run_decompositions()
+
     """
 
     libs: list[torch.library.Library] = []
 
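The docstring's caveat about ranks is the heart of the rename. As a hypothetical illustration (the op below and its name are not part of the patch): a data-dependent custom op can return outputs of different ranks for inputs of the same rank, which a kernel generated from a single recorded profile cannot express.

```python
import torch
from torch import Tensor


@torch.library.custom_op("mylib::maybe_squeeze", mutates_args=())
def maybe_squeeze(x: Tensor) -> Tensor:
    # The output rank depends on the *values* in x, not just its rank or dtype.
    if bool((x > 0).all()):
        return x.squeeze(0).clone()
    return x.clone()


for t in (torch.ones(1, 3), torch.zeros(1, 3)):  # both inputs are rank 2
    print(maybe_squeeze(t).shape)
# torch.Size([3]) vs. torch.Size([1, 3]): same input rank, different output ranks.
# A fake kernel generated from one recorded profile pins a single output rank,
# so it can silently mismatch the real op -- hence "unsafe" in the new name.
```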
@@ -486,7 +486,7 @@ While tracing we found {len(report.op_profiles)} operator(s) which do not have a
 If you intend to retrace the exported graph or run it with fake tensors, please run it under the
 following context manager, which will register a fake kernel for those operators.
 ```
-with torch._library.fake_profile.register_fake_profile(ep._report.op_profiles):
+with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
     # run with fake tensors
 ```
 """
 
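The reworded warning also covers the "run it with fake tensors" path. Below is a short sketch of that path modeled on the updated tests above; the FakeTensorMode arguments come from those tests, while `ep` and `mylib::foo` refer to the illustrative sketch under the commit header, so this is not code from the patch itself.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles):
    with FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()):
        # The generated fake kernel returns an output whose metadata matches the
        # recorded profile, so the op can be traced without real data.
        out = torch.ops.mylib.foo(torch.ones(3, 4), torch.ones(3, 4))
```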