Rename impl_abstract to register_fake, part 1/2 (#123937)

This PR:
- adds a new torch.library.register_fake and deprecates
  torch.library.impl_abstract. The motivation is that we have a lot of
  confusion around the naming so we are going to align the naming with
  the actual subsystem (FakeTensor).
- renames `m.impl_abstract_pystub("fbgemm_gpu.sparse_ops")` to
  `m.has_python_registration("fbgemm_gpu.sparse_ops")`. No deprecation
  here yet; I need to test how this works with static initialization.
- Renames a bunch of internals to match (e.g. abstractimplpystub ->
  pystub)

I'm scared to rename the Python-side internal APIs (e.g.
torch._library.abstract_impl) because of torch.package concerns. I'll do
that in its own isolated PR next just in case it causes problems.

DEPRECATION NOTE: torch.library.impl_abstract was renamed to to
torch.library.register_fake. Please use register_fake. We'll delete
impl_abstract in a future version of PyTorch.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123937
Approved by: https://github.com/albanD
This commit is contained in:
rzou
2024-04-16 10:57:00 -07:00
committed by PyTorch MergeBot
parent 6efcb6c718
commit 47dbfecd37
13 changed files with 108 additions and 87 deletions

View File

@ -8,6 +8,7 @@ import inspect
import re
import contextlib
import sys
import warnings
from torch._library.custom_ops import custom_op
@ -17,6 +18,7 @@ __all__ = [
'define',
'fallthrough_kernel',
'impl_abstract',
'register_fake',
'get_ctx',
'custom_op',
]
@ -244,7 +246,7 @@ def define(qualname, schema, *, lib=None, tags=()):
This entrypoint defines the custom operator (the first step)
you must then perform the second step by calling various
``impl_*`` APIs, like :func:`torch.library.impl` or
:func:`torch.library.impl_abstract`.
:func:`torch.library.register_fake`.
Args:
qualname (str): The qualified name for the operator. Should be
@ -393,21 +395,35 @@ def _(lib: Library, name, dispatch_key=""):
return wrap
def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
r"""Register an abstract implementation for this operator.
r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4.
Please use that instead.
"""
warnings.warn("torch.library.impl_abstract was renamed to "
"torch.library.register_fake. Please use that instead; "
"we will remove torch.library.impl_abstract in a future "
"version of PyTorch.",
DeprecationWarning, stacklevel=2)
return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel + 1)
An "abstract implementation" specifies the behavior of this operator on
Tensors that carry no data. Given some input Tensors with certain properties
(sizes/strides/storage_offset/device), it specifies what the properties of
the output Tensors are.
The abstract implementation has the same signature as the operator.
It is run for both FakeTensors and meta tensors. To write an abstract
def register_fake(qualname, func=None, /, *, lib=None, _stacklevel=1):
r"""Register a FakeTensor implementation ("fake impl") for this operator.
Also sometimes known as a "meta kernel", "abstract impl".
An "FakeTensor implementation" specifies the behavior of this operator on
Tensors that carry no data ("FakeTensor"). Given some input Tensors with
certain properties (sizes/strides/storage_offset/device), it specifies
what the properties of the output Tensors are.
The FakeTensor implementation has the same signature as the operator.
It is run for both FakeTensors and meta tensors. To write a FakeTensor
implementation, assume that all Tensor inputs to the operator are
regular CPU/CUDA/Meta tensors, but they do not have storage, and
you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
The abstract implementation must consist of only PyTorch operations
The FakeTensor implementation must consist of only PyTorch operations
(and may not directly access the storage or data of any input or
intermediate Tensors).
@ -426,8 +442,8 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
>>> "mylib::custom_linear",
>>> "(Tensor x, Tensor weight, Tensor bias) -> Tensor")
>>>
>>> @torch.library.impl_abstract("mylib::custom_linear")
>>> def custom_linear_abstract(x, weight, bias):
>>> @torch.library.register_fake("mylib::custom_linear")
>>> def _(x, weight, bias):
>>> assert x.dim() == 2
>>> assert weight.dim() == 2
>>> assert bias.dim() == 1
@ -448,10 +464,10 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
>>> # Example 2: an operator with data-dependent output shape
>>> torch.library.define("mylib::custom_nonzero", "(Tensor x) -> Tensor")
>>>
>>> @torch.library.impl_abstract("mylib::custom_nonzero")
>>> def custom_nonzero_abstract(x):
>>> @torch.library.register_fake("mylib::custom_nonzero")
>>> def _(x):
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an abstract impl,
>>> # Since we cannot peek at the data in an fake impl,
>>> # we use the ctx object to construct a new symint that
>>> # represents the data-dependent size.
>>> ctx = torch.library.get_ctx()
@ -478,7 +494,7 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
source = torch._library.utils.get_source(_stacklevel + 1)
frame = sys._getframe(_stacklevel)
caller_module = inspect.getmodule(frame)
# Can be none if you call impl_abstract from somewhere there isn't a module
# Can be none if you call register_fake from somewhere there isn't a module
# (e.g. __main__)
caller_module_name = None if caller_module is None else caller_module.__name__
@ -505,8 +521,8 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
# If the op was defined in C++, then we want to make sure there was an
# m.impl_abstract_pystub(module, ...) call and that the module is the
# same as the module that called torch.library.impl_abstract.
# m.set_python_module(module, ...) call and that the module is the
# same as the module that called torch.library.register_fake.
def _check_pystubs_once(func, qualname, actual_module_name):
checked = False
@ -528,8 +544,8 @@ def _check_pystubs_once(func, qualname, actual_module_name):
cpp_filename = op._handle().debug()
raise RuntimeError(
f"Operator '{qualname}' was defined in C++ and has a Python "
f"abstract impl. In this situation, we require there to also be a "
f"companion C++ `m.impl_abstract_pystub(\"{actual_module_name}\")` "
f"fake impl. In this situation, we require there to also be a "
f"companion C++ `m.set_python_module(\"{actual_module_name}\")` "
f"call, but we could not find one. Please add that to "
f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
f"operator was registered in ({cpp_filename})")
@ -537,10 +553,10 @@ def _check_pystubs_once(func, qualname, actual_module_name):
if actual_module_name != pystub_module:
cpp_filename = op._handle().debug()
raise RuntimeError(
f"Operator '{qualname}' specified that its python abstract impl "
f"Operator '{qualname}' specified that its python fake impl "
f"is in the Python module '{pystub_module}' but it was actually found "
f"in '{actual_module_name}'. Please either move the abstract impl "
f"or correct the m.impl_abstract_pystub call ({cpp_filename})")
f"in '{actual_module_name}'. Please either move the fake impl "
f"or correct the m.set_python_module call ({cpp_filename})")
checked = True
return func(*args, **kwargs)
return inner
@ -556,7 +572,7 @@ def _check_pystubs_once(func, qualname, actual_module_name):
def get_ctx() -> "torch._library.abstract_impl.AbstractImplCtx":
"""get_ctx() returns the current AbstractImplCtx object.
Calling ``get_ctx()`` is only valid inside of an abstract impl
(see :func:`torch.library.impl_abstract` for more usage details.
Calling ``get_ctx()`` is only valid inside of an fake impl
(see :func:`torch.library.register_fake` for more usage details.
"""
return torch._library.abstract_impl.global_ctx_getter()