mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Rename impl_abstract to register_fake, part 1/2 (#123937)
This PR: - adds a new torch.library.register_fake and deprecates torch.library.impl_abstract. The motivation is that we have a lot of confusion around the naming so we are going to align the naming with the actual subsystem (FakeTensor). - renames `m.impl_abstract_pystub("fbgemm_gpu.sparse_ops")` to `m.has_python_registration("fbgemm_gpu.sparse_ops")`. No deprecation here yet; I need to test how this works with static initialization. - Renames a bunch of internals to match (e.g. abstractimplpystub -> pystub) I'm scared to rename the Python-side internal APIs (e.g. torch._library.abstract_impl) because of torch.package concerns. I'll do that in its own isolated PR next just in case it causes problems. DEPRECATION NOTE: torch.library.impl_abstract was renamed to to torch.library.register_fake. Please use register_fake. We'll delete impl_abstract in a future version of PyTorch. Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/123937 Approved by: https://github.com/albanD
This commit is contained in:
@ -8,6 +8,7 @@ import inspect
|
||||
import re
|
||||
import contextlib
|
||||
import sys
|
||||
import warnings
|
||||
from torch._library.custom_ops import custom_op
|
||||
|
||||
|
||||
@ -17,6 +18,7 @@ __all__ = [
|
||||
'define',
|
||||
'fallthrough_kernel',
|
||||
'impl_abstract',
|
||||
'register_fake',
|
||||
'get_ctx',
|
||||
'custom_op',
|
||||
]
|
||||
@ -244,7 +246,7 @@ def define(qualname, schema, *, lib=None, tags=()):
|
||||
This entrypoint defines the custom operator (the first step)
|
||||
you must then perform the second step by calling various
|
||||
``impl_*`` APIs, like :func:`torch.library.impl` or
|
||||
:func:`torch.library.impl_abstract`.
|
||||
:func:`torch.library.register_fake`.
|
||||
|
||||
Args:
|
||||
qualname (str): The qualified name for the operator. Should be
|
||||
@ -393,21 +395,35 @@ def _(lib: Library, name, dispatch_key=""):
|
||||
return wrap
|
||||
|
||||
|
||||
|
||||
def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
|
||||
r"""Register an abstract implementation for this operator.
|
||||
r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4.
|
||||
Please use that instead.
|
||||
"""
|
||||
warnings.warn("torch.library.impl_abstract was renamed to "
|
||||
"torch.library.register_fake. Please use that instead; "
|
||||
"we will remove torch.library.impl_abstract in a future "
|
||||
"version of PyTorch.",
|
||||
DeprecationWarning, stacklevel=2)
|
||||
return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel + 1)
|
||||
|
||||
An "abstract implementation" specifies the behavior of this operator on
|
||||
Tensors that carry no data. Given some input Tensors with certain properties
|
||||
(sizes/strides/storage_offset/device), it specifies what the properties of
|
||||
the output Tensors are.
|
||||
|
||||
The abstract implementation has the same signature as the operator.
|
||||
It is run for both FakeTensors and meta tensors. To write an abstract
|
||||
|
||||
def register_fake(qualname, func=None, /, *, lib=None, _stacklevel=1):
|
||||
r"""Register a FakeTensor implementation ("fake impl") for this operator.
|
||||
|
||||
Also sometimes known as a "meta kernel", "abstract impl".
|
||||
|
||||
An "FakeTensor implementation" specifies the behavior of this operator on
|
||||
Tensors that carry no data ("FakeTensor"). Given some input Tensors with
|
||||
certain properties (sizes/strides/storage_offset/device), it specifies
|
||||
what the properties of the output Tensors are.
|
||||
|
||||
The FakeTensor implementation has the same signature as the operator.
|
||||
It is run for both FakeTensors and meta tensors. To write a FakeTensor
|
||||
implementation, assume that all Tensor inputs to the operator are
|
||||
regular CPU/CUDA/Meta tensors, but they do not have storage, and
|
||||
you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
|
||||
The abstract implementation must consist of only PyTorch operations
|
||||
The FakeTensor implementation must consist of only PyTorch operations
|
||||
(and may not directly access the storage or data of any input or
|
||||
intermediate Tensors).
|
||||
|
||||
@ -426,8 +442,8 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
|
||||
>>> "mylib::custom_linear",
|
||||
>>> "(Tensor x, Tensor weight, Tensor bias) -> Tensor")
|
||||
>>>
|
||||
>>> @torch.library.impl_abstract("mylib::custom_linear")
|
||||
>>> def custom_linear_abstract(x, weight, bias):
|
||||
>>> @torch.library.register_fake("mylib::custom_linear")
|
||||
>>> def _(x, weight, bias):
|
||||
>>> assert x.dim() == 2
|
||||
>>> assert weight.dim() == 2
|
||||
>>> assert bias.dim() == 1
|
||||
@ -448,10 +464,10 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
|
||||
>>> # Example 2: an operator with data-dependent output shape
|
||||
>>> torch.library.define("mylib::custom_nonzero", "(Tensor x) -> Tensor")
|
||||
>>>
|
||||
>>> @torch.library.impl_abstract("mylib::custom_nonzero")
|
||||
>>> def custom_nonzero_abstract(x):
|
||||
>>> @torch.library.register_fake("mylib::custom_nonzero")
|
||||
>>> def _(x):
|
||||
>>> # Number of nonzero-elements is data-dependent.
|
||||
>>> # Since we cannot peek at the data in an abstract impl,
|
||||
>>> # Since we cannot peek at the data in an fake impl,
|
||||
>>> # we use the ctx object to construct a new symint that
|
||||
>>> # represents the data-dependent size.
|
||||
>>> ctx = torch.library.get_ctx()
|
||||
@ -478,7 +494,7 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
|
||||
source = torch._library.utils.get_source(_stacklevel + 1)
|
||||
frame = sys._getframe(_stacklevel)
|
||||
caller_module = inspect.getmodule(frame)
|
||||
# Can be none if you call impl_abstract from somewhere there isn't a module
|
||||
# Can be none if you call register_fake from somewhere there isn't a module
|
||||
# (e.g. __main__)
|
||||
caller_module_name = None if caller_module is None else caller_module.__name__
|
||||
|
||||
@ -505,8 +521,8 @@ def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
|
||||
|
||||
|
||||
# If the op was defined in C++, then we want to make sure there was an
|
||||
# m.impl_abstract_pystub(module, ...) call and that the module is the
|
||||
# same as the module that called torch.library.impl_abstract.
|
||||
# m.set_python_module(module, ...) call and that the module is the
|
||||
# same as the module that called torch.library.register_fake.
|
||||
def _check_pystubs_once(func, qualname, actual_module_name):
|
||||
checked = False
|
||||
|
||||
@ -528,8 +544,8 @@ def _check_pystubs_once(func, qualname, actual_module_name):
|
||||
cpp_filename = op._handle().debug()
|
||||
raise RuntimeError(
|
||||
f"Operator '{qualname}' was defined in C++ and has a Python "
|
||||
f"abstract impl. In this situation, we require there to also be a "
|
||||
f"companion C++ `m.impl_abstract_pystub(\"{actual_module_name}\")` "
|
||||
f"fake impl. In this situation, we require there to also be a "
|
||||
f"companion C++ `m.set_python_module(\"{actual_module_name}\")` "
|
||||
f"call, but we could not find one. Please add that to "
|
||||
f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
|
||||
f"operator was registered in ({cpp_filename})")
|
||||
@ -537,10 +553,10 @@ def _check_pystubs_once(func, qualname, actual_module_name):
|
||||
if actual_module_name != pystub_module:
|
||||
cpp_filename = op._handle().debug()
|
||||
raise RuntimeError(
|
||||
f"Operator '{qualname}' specified that its python abstract impl "
|
||||
f"Operator '{qualname}' specified that its python fake impl "
|
||||
f"is in the Python module '{pystub_module}' but it was actually found "
|
||||
f"in '{actual_module_name}'. Please either move the abstract impl "
|
||||
f"or correct the m.impl_abstract_pystub call ({cpp_filename})")
|
||||
f"in '{actual_module_name}'. Please either move the fake impl "
|
||||
f"or correct the m.set_python_module call ({cpp_filename})")
|
||||
checked = True
|
||||
return func(*args, **kwargs)
|
||||
return inner
|
||||
@ -556,7 +572,7 @@ def _check_pystubs_once(func, qualname, actual_module_name):
|
||||
def get_ctx() -> "torch._library.abstract_impl.AbstractImplCtx":
|
||||
"""get_ctx() returns the current AbstractImplCtx object.
|
||||
|
||||
Calling ``get_ctx()`` is only valid inside of an abstract impl
|
||||
(see :func:`torch.library.impl_abstract` for more usage details.
|
||||
Calling ``get_ctx()`` is only valid inside of an fake impl
|
||||
(see :func:`torch.library.register_fake` for more usage details.
|
||||
"""
|
||||
return torch._library.abstract_impl.global_ctx_getter()
|
||||
|
Reference in New Issue
Block a user