Generate meta kernel with operator profiles (#150807)

Added a context manager, `torch._library.fake_profile.register_fake_profile(op_profiles)`, which, given a set of operator profiles, generates and registers a fake impl for each operator based on its profile.

The input to `register_fake_profile` is a dictionary mapping each operator name to a set of profiles that describe the operator's inputs and outputs. Here's an example of a profile for `mylib.foo.default`:
```
"mylib.foo.default": {
    OpProfile(
        args_profile=(
            TensorMetadata(rank=2, dtype=torch.float32, device=torch.device("cpu"), layout=torch.strided),
            TensorMetadata(rank=2, dtype=torch.float32, device=torch.device("cpu"), layout=torch.strided),
        ),
        out_profile=TensorMetadata(rank=2, dtype=torch.float32, device=torch.device("cpu"), layout=torch.strided),
    )
}
```
`foo`'s entry contains a single profile, which says that given 2 input tensors of rank 2, dtype float32, device cpu, and strided layout, the operator returns one tensor with that same metadata.

From this, `register_fake_profile` generates a fake kernel that, given 2 input tensors matching this metadata, outputs one tensor of rank 2 with the recorded metadata (and fresh dynamic sizes). If the operator also supports other input ranks, more profiles can be added to the set so that the generated fake impl accepts those input types too, as sketched below.
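For instance, to additionally cover rank-3 inputs, one more `OpProfile` can be added to the set. A minimal sketch (the `cpu_f32` helper is just shorthand introduced here, not part of the API):
```
import torch
from torch._library.fake_profile import OpProfile, TensorMetadata

def cpu_f32(rank: int) -> TensorMetadata:
    # Shorthand for a contiguous float32 CPU tensor of the given rank
    return TensorMetadata(
        rank=rank, dtype=torch.float32, device=torch.device("cpu"), layout=torch.strided
    )

op_profiles = {
    "mylib.foo.default": {
        # rank-2 inputs -> rank-2 output (the profile shown above)
        OpProfile(args_profile=(cpu_f32(2), cpu_f32(2)), out_profile=cpu_f32(2)),
        # rank-3 inputs -> rank-3 output
        OpProfile(args_profile=(cpu_f32(3), cpu_f32(3)), out_profile=cpu_f32(3)),
    }
}
```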

These profiles can either be written manually or generated by draft-export, and then checked into the codebase.
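Once a set of profiles is in hand, usage looks like the following minimal sketch (mirroring the tests in this PR); it assumes `mylib::foo` has been defined with a CPU impl but has no fake impl of its own:
```
import torch
from torch._library.fake_profile import register_fake_profile
from torch.fx.experimental.symbolic_shapes import ShapeEnv

fm = torch._subclasses.FakeTensorMode(
    shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True)
)
t1 = fm.from_tensor(torch.ones(3, 3))
t2 = fm.from_tensor(torch.ones(3, 3))

with register_fake_profile(op_profiles), fm:
    # Dispatches to the generated fake impl; returns a fake tensor of
    # rank 2, dtype float32, on cpu, with fresh dynamic sizes
    out = torch.ops.mylib.foo(t1, t2)
```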

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150807
Approved by: https://github.com/zou3519
ghstack dependencies: #150806

@@ -20,8 +20,10 @@ import torch.utils.cpp_extension
from functorch import make_fx
from torch import Tensor
from torch._custom_op.impl import CustomOp, infer_schema
from torch._library.fake_profile import MissingOpProfile, OpProfile, TensorMetadata
from torch._library.infer_schema import tuple_to_list
from torch._utils_internal import get_file_path_2 # @manual
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing._internal import custom_op_db
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
@@ -4456,6 +4458,147 @@ class TestTypeConversion(TestCase):
        self.assertEqual(result_type, list[typing.Union[int, float, str]])


class TestOpProfiles(TestCase):
    def get_sample_op_profile(self) -> dict[str, set[OpProfile]]:
        return {
            "mylib.foo.default": {
                OpProfile(
                    args_profile=(
                        TensorMetadata(
                            rank=2,
                            dtype=torch.float32,
                            device=torch.device("cpu"),
                            layout=torch.strided,
                        ),
                        TensorMetadata(
                            rank=2,
                            dtype=torch.float32,
                            device=torch.device("cpu"),
                            layout=torch.strided,
                        ),
                    ),
                    out_profile=TensorMetadata(
                        rank=2,
                        dtype=torch.float32,
                        device=torch.device("cpu"),
                        layout=torch.strided,
                    ),
                )
            }
        }

    def test_fake_registration(self):
        fm = torch._subclasses.FakeTensorMode(
            shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True)
        )
        t1 = fm.from_tensor(torch.ones(3, 3))
        t2 = fm.from_tensor(torch.ones(3, 3))
        op_profiles = self.get_sample_op_profile()

        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor a, Tensor b) -> Tensor",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            def foo_impl(a, b):
                return a + b

            # Without a fake impl, calling the op under FakeTensorMode fails
            with (
                self.assertRaisesRegex(
                    torch._subclasses.fake_tensor.UnsupportedOperatorException,
                    "mylib.foo.default",
                ),
                fm,
            ):
                torch.ops.mylib.foo(t1, t2)

            with torch._library.fake_profile.register_fake_profile(op_profiles), fm:
                torch.ops.mylib.foo(t1, t2)

                # Inputs that don't match any recorded profile raise
                with self.assertRaisesRegex(MissingOpProfile, "mylib::foo"):
                    torch.ops.mylib.foo(torch.ones(3, 3, 3), torch.ones(3, 3, 3))

            # Exiting the context removes the generated fake impl again
            with (
                self.assertRaisesRegex(
                    torch._subclasses.fake_tensor.UnsupportedOperatorException,
                    "mylib.foo.default",
                ),
                fm,
            ):
                torch.ops.mylib.foo(t1, t2)

    def test_duplicate_registration_impl(self):
        fm = torch._subclasses.FakeTensorMode(
            shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True)
        )
        t1 = fm.from_tensor(torch.ones(3, 3))
        t2 = fm.from_tensor(torch.ones(3, 3))
        op_profiles = self.get_sample_op_profile()

        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor a, Tensor b) -> Tensor",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            def foo_impl(a, b):
                return a + b

            @torch.library.register_fake("mylib::foo", lib=lib)
            def foo_impl_fake(a, b):
                return (a + b).to(dtype=torch.bfloat16)

            with fm:
                self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.bfloat16)

            # The profile-generated fake impl temporarily overrides the
            # existing registration...
            with torch._library.fake_profile.register_fake_profile(op_profiles):
                with fm:
                    self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.float32)

            # ...and the original fake impl is restored afterwards
            with fm:
                self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.bfloat16)

    def test_duplicate_registration_custom_op(self):
        fm = torch._subclasses.FakeTensorMode(
            shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True)
        )
        t1 = fm.from_tensor(torch.ones(3, 3))
        t2 = fm.from_tensor(torch.ones(3, 3))

        @torch.library.custom_op("mylib::foo1", mutates_args=())
        def foo_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return a + b

        @torch.library.register_fake("mylib::foo1")
        def foo_impl_fake(a, b):
            return torch.empty_like(a, dtype=torch.bfloat16)

        with fm:
            self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.bfloat16)

        op_profiles = {
            "mylib.foo1.default": self.get_sample_op_profile()["mylib.foo.default"]
        }

        with torch._library.fake_profile.register_fake_profile(op_profiles):
            with fm:
                self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.float32)

        with fm:
            self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.bfloat16)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
instantiate_parametrized_tests(TestCustomOp)


@@ -0,0 +1,154 @@
import contextlib
import logging
from collections.abc import Generator
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch
from torch._library.custom_ops import _maybe_get_opdef


log = logging.getLogger(__name__)


class MissingOpProfile(RuntimeError):
    """
    This is raised when we don't have an operator profile available for the
    given inputs.
    """


@dataclass(frozen=True)
class TensorMetadata:
    rank: int
    dtype: torch.dtype
    device: torch.device
    layout: torch.layout

    @staticmethod
    def maybe_from_tensor(t: Any) -> Optional["TensorMetadata"]:
        if not isinstance(t, torch.Tensor):
            return None
        return TensorMetadata(t.dim(), t.dtype, t.device, t.layout)


@dataclass(frozen=True)
class OpProfile:
    args_profile: tuple[Optional[TensorMetadata]]
    out_profile: Union[TensorMetadata, tuple[TensorMetadata]]


def _generate_fake_kernel(op_name: str, op_profile: set[OpProfile]) -> Callable:
    def _match_args(args_profile: tuple[Optional[TensorMetadata]], args: Any) -> bool:
        # An argument matches if its metadata equals the recorded metadata
        # (non-tensor arguments are recorded as None)
        return all(
            TensorMetadata.maybe_from_tensor(arg) == args_profile[i]
            for i, arg in enumerate(args)
        )

    def _generate_res(
        out_profile: Union[TensorMetadata, tuple[TensorMetadata]],
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        ctx = torch.library.get_ctx()

        def _generate_tensor_out(t: TensorMetadata) -> torch.Tensor:
            # Allocate a fresh dynamic size for every dimension and compute
            # strides that are contiguous w.r.t. those symbolic sizes
            fake_shape = [ctx.new_dynamic_size() for _ in range(t.rank)]
            fake_strides = [-1] * t.rank
            fake_stride = 1
            for i in range(t.rank):
                fake_strides[i] = fake_stride  # type: ignore[assignment]
                fake_stride = fake_stride * fake_shape[i]  # type: ignore[assignment]

            return torch.empty_strided(
                fake_shape,
                fake_strides,
                device=t.device,
                dtype=t.dtype,
                layout=t.layout,
            )

        if isinstance(out_profile, TensorMetadata):
            return _generate_tensor_out(out_profile)
        else:
            return [_generate_tensor_out(t) for t in out_profile]

    def _fake_kernel(*args, **kwargs):  # type: ignore[no-untyped-def]
        # Return an output matching the first profile that fits the inputs
        for profile in op_profile:
            if _match_args(profile.args_profile, (*args, *kwargs.values())):
                return _generate_res(profile.out_profile)

        raise MissingOpProfile(
            f"No fake kernel was found for {op_name}, and although we have "
            "previously registered some profiles to generate a fake kernel, "
            f"no profiles match the given inputs: {args, kwargs}."
        )

    return _fake_kernel


@contextlib.contextmanager
def register_fake_profile(op_profiles: dict[str, set[OpProfile]]) -> Generator:
    """
    Registers a fake kernel based on the given operator profiles. This fake
    kernel registration will override any existing fake kernel registrations.

    The input is a dictionary mapping operator names to a set of operator
    profiles, which we will use to generate fake kernels. The operator profiles
    are a record of the input and output tensor metadata. Based on this
    information we will match a given input to a recorded profile, and return
    an output with the same metadata as in the recorded profile. If no matching
    profile exists then an exception will be thrown.

    Args:
        op_profiles (dict[str, set[OpProfile]]): A dictionary mapping operator
            name to a set of operator profiles from which we will generate fake
            kernels.
    """
    libs: list[torch.library.Library] = []

    # Stores old fake impls from custom ops declared through @custom_op
    old_fake_impls: dict[str, Callable] = {}

    for op_name, profiles in op_profiles.items():
        log.warning(
            "Registering fake profile for %s. This will override any existing "
            "fake kernel registration.",
            op_name,
        )
        op_name_split = op_name.split(".")
        namespace, op_name_str = op_name_split[0], op_name_split[1]
        op_str = f"{namespace}::{op_name_str}"

        fake_kernel = _generate_fake_kernel(op_str, profiles)

        if opdef := _maybe_get_opdef(op_str):
            # If the op is a CustomOpDef, save the existing abstract_fn so that
            # we can restore it after this contextmanager
            if opdef._abstract_fn is not None:
                old_fake_impls[op_str] = opdef._abstract_fn
            opdef.register_fake(fake_kernel)
        else:
            # Create a new library so that we can register a new fake impl.
            # These libraries will then be destroyed after the contextmanager,
            # which will automatically restore the previously registered fake
            # impls.
            newlib = torch.library.Library(namespace, "FRAGMENT")  # noqa: TOR901
            torch.library.register_fake(
                op_str, fake_kernel, lib=newlib, allow_override=True
            )
            libs.append(newlib)

    try:
        yield libs
    finally:
        # Destroying the libraries will automatically restore the previously
        # registered fake impls
        for lib in libs:
            lib._destroy()

        # Restore abstract_fns for CustomOpDefs
        for op_str, old_fake in old_fake_impls.items():
            opdef = _maybe_get_opdef(op_str)
            assert opdef is not None
            opdef.register_fake(old_fake)