mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Generate meta kernel with operator profiles (#150807)
Added a context manager, `torch._library.fake_profile.register_fake_profile(op_profiles)`, where given an operator profile, it will generate and register a fake impl for the operator based on the operator profile. The input to `register_fake_profile` is a dictionary mapping operator name to a set of profiles which describe the input and outputs of the operator. Here's an example of a profile for `mylib.foo.default`: ``` "mylib.foo.default": { OpProfile( args_profile=( TensorMetadata(rank=2, dtype=torch.float32, device=torch.device("cpu"), layout=torch.strided,), TensorMetadata(rank=2, dtype=torch.float32, device=torch.device("cpu"), layout=torch.strided,), ), out_profile=TensorMetadata(rank=2, dtype=torch.float32, device=torch.device("cpu"), layout=torch.strided,), ) } ``` `foo`'s profile contains only one profile, which says that for 2 input tensors of rank 2, dtype float32, device cpu, we will return one tensor of rank 2, dtype float32, and device cpu. This will then generate a fake kernel where given 2 input tensors of rank 2 (and the other tensor metadata), we will output one tensor of rank 2 (and the other tensor metadata). If the operator also supports other input ranks, then we can add to the profile for the fake impl to support more input types. This profile can either be manually written or created by draft-export, and then checked into the codebase. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150807 Approved by: https://github.com/zou3519 ghstack dependencies: #150806
This commit is contained in:
committed by
PyTorch MergeBot
parent
901e37515f
commit
53528440e1
@ -20,8 +20,10 @@ import torch.utils.cpp_extension
|
||||
from functorch import make_fx
|
||||
from torch import Tensor
|
||||
from torch._custom_op.impl import CustomOp, infer_schema
|
||||
from torch._library.fake_profile import MissingOpProfile, OpProfile, TensorMetadata
|
||||
from torch._library.infer_schema import tuple_to_list
|
||||
from torch._utils_internal import get_file_path_2 # @manual
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch.testing._internal import custom_op_db
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA
|
||||
from torch.testing._internal.common_device_type import (
|
||||
@ -4456,6 +4458,147 @@ class TestTypeConversion(TestCase):
|
||||
self.assertEqual(result_type, list[typing.Union[int, float, str]])
|
||||
|
||||
|
||||
class TestOpProfiles(TestCase):
|
||||
def get_sample_op_profile(self) -> dict[str, set[OpProfile]]:
|
||||
return {
|
||||
"mylib.foo.default": {
|
||||
OpProfile(
|
||||
args_profile=(
|
||||
TensorMetadata(
|
||||
rank=2,
|
||||
dtype=torch.float32,
|
||||
device=torch.device("cpu"),
|
||||
layout=torch.strided,
|
||||
),
|
||||
TensorMetadata(
|
||||
rank=2,
|
||||
dtype=torch.float32,
|
||||
device=torch.device("cpu"),
|
||||
layout=torch.strided,
|
||||
),
|
||||
),
|
||||
out_profile=TensorMetadata(
|
||||
rank=2,
|
||||
dtype=torch.float32,
|
||||
device=torch.device("cpu"),
|
||||
layout=torch.strided,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
def test_fake_registration(self):
|
||||
fm = torch._subclasses.FakeTensorMode(
|
||||
shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True)
|
||||
)
|
||||
t1 = fm.from_tensor(torch.ones(3, 3))
|
||||
t2 = fm.from_tensor(torch.ones(3, 3))
|
||||
|
||||
op_profiles = self.get_sample_op_profile()
|
||||
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor a, Tensor b) -> Tensor",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
def foo_impl(a, b):
|
||||
return a + b
|
||||
|
||||
with (
|
||||
self.assertRaisesRegex(
|
||||
torch._subclasses.fake_tensor.UnsupportedOperatorException,
|
||||
"mylib.foo.default",
|
||||
),
|
||||
fm,
|
||||
):
|
||||
torch.ops.mylib.foo(t1, t2)
|
||||
|
||||
with torch._library.fake_profile.register_fake_profile(op_profiles), fm:
|
||||
torch.ops.mylib.foo(t1, t2)
|
||||
|
||||
with self.assertRaisesRegex(MissingOpProfile, "mylib::foo"):
|
||||
torch.ops.mylib.foo(torch.ones(3, 3, 3), torch.ones(3, 3, 3))
|
||||
|
||||
with (
|
||||
self.assertRaisesRegex(
|
||||
torch._subclasses.fake_tensor.UnsupportedOperatorException,
|
||||
"mylib.foo.default",
|
||||
),
|
||||
fm,
|
||||
):
|
||||
torch.ops.mylib.foo(t1, t2)
|
||||
|
||||
def test_duplicate_registration_impl(self):
|
||||
fm = torch._subclasses.FakeTensorMode(
|
||||
shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True)
|
||||
)
|
||||
t1 = fm.from_tensor(torch.ones(3, 3))
|
||||
t2 = fm.from_tensor(torch.ones(3, 3))
|
||||
|
||||
op_profiles = self.get_sample_op_profile()
|
||||
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define(
|
||||
"mylib::foo",
|
||||
"(Tensor a, Tensor b) -> Tensor",
|
||||
tags=torch.Tag.pt2_compliant_tag,
|
||||
lib=lib,
|
||||
)
|
||||
|
||||
@torch.library.impl("mylib::foo", "cpu", lib=lib)
|
||||
def foo_impl(a, b):
|
||||
return a + b
|
||||
|
||||
@torch.library.register_fake("mylib::foo", lib=lib)
|
||||
def foo_impl_fake(a, b):
|
||||
return (a + b).to(dtype=torch.bfloat16)
|
||||
|
||||
with fm:
|
||||
self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.bfloat16)
|
||||
|
||||
with torch._library.fake_profile.register_fake_profile(op_profiles):
|
||||
with fm:
|
||||
self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.float32)
|
||||
|
||||
with fm:
|
||||
self.assertEqual(torch.ops.mylib.foo(t1, t2).dtype, torch.bfloat16)
|
||||
|
||||
def test_duplicate_registration_custom_op(self):
|
||||
fm = torch._subclasses.FakeTensorMode(
|
||||
shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True)
|
||||
)
|
||||
t1 = fm.from_tensor(torch.ones(3, 3))
|
||||
t2 = fm.from_tensor(torch.ones(3, 3))
|
||||
|
||||
op_profiles = self.get_sample_op_profile()
|
||||
|
||||
@torch.library.custom_op("mylib::foo1", mutates_args=())
|
||||
def foo_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
return a + b
|
||||
|
||||
@torch.library.register_fake("mylib::foo1")
|
||||
def foo_impl_fake(a, b):
|
||||
return torch.empty_like(a, dtype=torch.bfloat16)
|
||||
|
||||
with fm:
|
||||
self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.bfloat16)
|
||||
|
||||
op_profiles = {
|
||||
"mylib.foo1.default": self.get_sample_op_profile()["mylib.foo.default"]
|
||||
}
|
||||
|
||||
with torch._library.fake_profile.register_fake_profile(op_profiles):
|
||||
with fm:
|
||||
self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.float32)
|
||||
|
||||
with fm:
|
||||
self.assertEqual(torch.ops.mylib.foo1(t1, t2).dtype, torch.bfloat16)
|
||||
|
||||
|
||||
only_for = ("cpu", "cuda")
|
||||
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
|
||||
instantiate_parametrized_tests(TestCustomOp)
|
||||
|
154
torch/_library/fake_profile.py
Normal file
154
torch/_library/fake_profile.py
Normal file
@ -0,0 +1,154 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._library.custom_ops import _maybe_get_opdef
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MissingOpProfile(RuntimeError):
|
||||
"""
|
||||
This is raised when we don't have an operator profile available for the
|
||||
given inputs.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TensorMetadata:
|
||||
rank: int
|
||||
dtype: torch.dtype
|
||||
device: torch.device
|
||||
layout: torch.layout
|
||||
|
||||
@staticmethod
|
||||
def maybe_from_tensor(t: Any) -> Optional["TensorMetadata"]:
|
||||
if not isinstance(t, torch.Tensor):
|
||||
return None
|
||||
return TensorMetadata(t.dim(), t.dtype, t.device, t.layout)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OpProfile:
|
||||
args_profile: tuple[Optional[TensorMetadata]]
|
||||
out_profile: Union[TensorMetadata, tuple[TensorMetadata]]
|
||||
|
||||
|
||||
def _generate_fake_kernel(op_name: str, op_profile: set[OpProfile]) -> Callable:
|
||||
def _match_args(args_profile: tuple[Optional[TensorMetadata]], args: Any) -> bool:
|
||||
return all(
|
||||
TensorMetadata.maybe_from_tensor(arg) == args_profile[i]
|
||||
for i, arg in enumerate(args)
|
||||
)
|
||||
|
||||
def _generate_res(
|
||||
out_profile: Union[TensorMetadata, tuple[TensorMetadata]],
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
ctx = torch.library.get_ctx()
|
||||
|
||||
def _generate_tensor_out(t: TensorMetadata) -> torch.Tensor:
|
||||
fake_shape = [ctx.new_dynamic_size() for _ in range(t.rank)]
|
||||
fake_strides = [-1] * t.rank
|
||||
expected = 1
|
||||
fake_stride = expected
|
||||
for i in range(t.rank):
|
||||
fake_strides[i] = fake_stride # type: ignore[assignment]
|
||||
fake_stride = fake_stride * fake_shape[i] # type: ignore[assignment]
|
||||
|
||||
return torch.empty_strided(
|
||||
fake_shape,
|
||||
fake_strides,
|
||||
device=t.device,
|
||||
dtype=t.dtype,
|
||||
layout=t.layout,
|
||||
)
|
||||
|
||||
if isinstance(out_profile, TensorMetadata):
|
||||
return _generate_tensor_out(out_profile)
|
||||
else:
|
||||
return [_generate_tensor_out(t) for t in out_profile]
|
||||
|
||||
def _fake_kernel(*args, **kwargs): # type: ignore[no-untyped-def]
|
||||
for profile in op_profile:
|
||||
if _match_args(profile.args_profile, (*args, *kwargs.values())):
|
||||
return _generate_res(profile.out_profile)
|
||||
|
||||
raise MissingOpProfile(
|
||||
f"No fake kernel was found for {op_name}, and although we have "
|
||||
"previously registered some profiles to generate a fake kernel, "
|
||||
f"no profiles match the given inputs: {args, kwargs}."
|
||||
)
|
||||
|
||||
return _fake_kernel
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def register_fake_profile(op_profiles: dict[str, set[OpProfile]]) -> Generator:
|
||||
"""
|
||||
Registers a fake kernel based on the given operator profiles. This fake
|
||||
kernel registration will override any existing fake kernel registrations.
|
||||
|
||||
The input is a dictionary mapping operator names to a set of operator
|
||||
profiles, which we will use to generate fake kernels. The operator profiles
|
||||
are a record of the input and output tensor metadata. Based on this
|
||||
information we will match a given input to the recorded profile, and return
|
||||
an output with the same metadata as in the recorded profile. If a profile
|
||||
doesn't exist then an exception will be thrown.
|
||||
|
||||
Args:
|
||||
op_profiles (dict[str, set[OpProfile]]): A dictionary mapping operator
|
||||
name to a set of operator profiles from which we will generate fake
|
||||
kernels.
|
||||
"""
|
||||
|
||||
libs: list[torch.library.Library] = []
|
||||
# Stores old fake impls from custom ops declared through @custom_op
|
||||
old_fake_impls: dict[str, Callable] = {}
|
||||
for op_name, profiles in op_profiles.items():
|
||||
log.warning(
|
||||
"Registering fake profile for %s. This will override any existing "
|
||||
"fake kernel registration.",
|
||||
op_name,
|
||||
)
|
||||
|
||||
op_name_split = op_name.split(".")
|
||||
namespace, op_name_str = op_name_split[0], op_name_split[1]
|
||||
op_str = f"{namespace}::{op_name_str}"
|
||||
|
||||
fake_kernel = _generate_fake_kernel(op_str, profiles)
|
||||
|
||||
if opdef := _maybe_get_opdef(op_str):
|
||||
# If the op is a CustomOpDef, save the existing abstract_fn so that
|
||||
# we can restore it after this contextmanager
|
||||
if opdef._abstract_fn is not None:
|
||||
old_fake_impls[op_str] = opdef._abstract_fn
|
||||
opdef.register_fake(fake_kernel)
|
||||
|
||||
else:
|
||||
# Create a new library so that we can register a new fake impl.
|
||||
# These libraries will then be destroyed after the contextmanager,
|
||||
# which will automatically restore the previously registered fake
|
||||
# impls.
|
||||
newlib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901
|
||||
torch.library.register_fake(
|
||||
op_str, fake_kernel, lib=newlib, allow_override=True
|
||||
)
|
||||
libs.append(newlib)
|
||||
|
||||
try:
|
||||
yield libs
|
||||
finally:
|
||||
# Destroying the libraries will automatically restore the previously
|
||||
# registered fake impls
|
||||
for lib in libs:
|
||||
lib._destroy()
|
||||
|
||||
# Restore abstract_fns for CustomOpDefs
|
||||
for op_str, old_fake in old_fake_impls.items():
|
||||
opdef = _maybe_get_opdef(op_str)
|
||||
assert opdef is not None
|
||||
opdef.register_fake(old_fake)
|
Reference in New Issue
Block a user