Using the latest torch.library.register_fake API instead of torch.library.impl_abstract (#158839)

As the title states.

`torch.library.impl_abstract` has been deprecated since PyTorch 2.4, so switch to the new `torch.library.register_fake` API.
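For reference, the migration is mechanical. A minimal sketch of the new usage (the `mylib::foo` op and `lib` fragment below are illustrative, not taken from this PR):

```python
import torch

# Illustrative namespace and op, not part of this change.
lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("foo(Tensor x) -> Tensor")

# Before (deprecated since PyTorch 2.4):
#   @torch.library.impl_abstract("mylib::foo", lib=lib)
# After:
@torch.library.register_fake("mylib::foo", lib=lib)
def foo_fake(x):
    # A fake (meta) implementation only describes output metadata.
    return torch.empty_like(x)
```

Call sites that passed the function as the `func=` keyword to `impl_abstract` now pass it positionally, as the test changes below show.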

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158839
Approved by: https://github.com/jingsh, https://github.com/zou3519
ghstack dependencies: #158838
FFFrog
2025-07-24 11:32:18 +08:00
committed by PyTorch MergeBot
parent c60d382870
commit 6fc0ad22f0
11 changed files with 26 additions and 26 deletions


@@ -6,7 +6,7 @@ import torch
torch.ops.load_library(get_custom_op_library_path())
-@torch.library.impl_abstract("custom::nonzero")
+@torch.library.register_fake("custom::nonzero")
def nonzero_abstract(x):
n = x.dim()
ctx = torch.library.get_ctx()


@@ -6,6 +6,6 @@ import torch
torch.ops.load_library(get_custom_op_library_path())
-@torch.library.impl_abstract("custom::sin")
+@torch.library.register_fake("custom::sin")
def sin_abstract(x):
return torch.empty_like(x)


@@ -8,12 +8,12 @@ torch.ops.load_library(get_custom_op_library_path())
# NB: The impl_abstract_pystub for cos actually
# specifies it should live in the my_custom_ops2 module.
-@torch.library.impl_abstract("custom::cos")
+@torch.library.register_fake("custom::cos")
def cos_abstract(x):
return torch.empty_like(x)
# NB: There is no impl_abstract_pystub for tan
-@torch.library.impl_abstract("custom::tan")
+@torch.library.register_fake("custom::tan")
def tan_abstract(x):
return torch.empty_like(x)


@@ -911,7 +911,7 @@ class TestConverter(TestCase):
return x + x
# Meta function of the custom op.
-@torch.library.impl_abstract(
+@torch.library.register_fake(
"mylib::foo",
lib=lib,
)


@@ -146,7 +146,7 @@ torch.library.define(
@torch.library.impl("testlib::returns_tensor_symint", "cpu")
-@torch.library.impl_abstract("testlib::returns_tensor_symint")
+@torch.library.register_fake("testlib::returns_tensor_symint")
def returns_tensor_symint_impl(x):
return x, x.shape[0]
@@ -159,7 +159,7 @@ def foo_impl(x, z):
return x, z, x + z
-@torch.library.impl_abstract("testlib::foo")
+@torch.library.register_fake("testlib::foo")
def foo_abstract(x, z):
return x, z, x + z


@@ -795,7 +795,7 @@ class TestDeserialize(TestCase):
)
@torch.library.impl("mylib::foo", "cpu", lib=lib)
-@torch.library.impl_abstract("mylib::foo")
+@torch.library.register_fake("mylib::foo")
def foo_impl(a, b, c):
res2 = None
if c is not None:
@@ -884,21 +884,21 @@
)
@torch.library.impl("mylib::foo1", "cpu", lib=lib)
-@torch.library.impl_abstract("mylib::foo1")
+@torch.library.register_fake("mylib::foo1")
def foo1_impl(x, y, z, w, n):
x.add_(y[0] + w)
z.add_(y[1] + n)
return n + n
@torch.library.impl("mylib::foo2", "cpu", lib=lib)
@torch.library.impl_abstract("mylib::foo2")
@torch.library.register_fake("mylib::foo2")
def foo2_impl(x, y, z, w, n):
x.add_(y[0] + w)
z.add_(y[1] + n)
return (n + n, n * n)
@torch.library.impl("mylib::foo3", "cpu", lib=lib)
@torch.library.impl_abstract("mylib::foo3")
@torch.library.register_fake("mylib::foo3")
def foo3_impl(x, y, z, w, n):
x.add_(y[0] + w)
z.add_(y[1] + n)


@@ -328,7 +328,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
return
# Meta function of the custom op
-@torch.library.impl_abstract(
+@torch.library.register_fake(
"mylib::record_scalar_tensor",
lib=lib,
)


@@ -4315,7 +4315,7 @@ class AOTInductorTestsTemplate:
def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a[: b.item()]
-@torch.library.impl_abstract("mylib::foo", lib=lib)
+@torch.library.register_fake("mylib::foo", lib=lib)
def foo_fake_impl(a, b):
ctx = torch.library.get_ctx()
u = ctx.new_dynamic_size()


@@ -217,7 +217,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
z.add_(y[1] + n)
return y[0] + w, y[1] + n
-@torch.library.impl_abstract("mylib::foo", lib=lib)
+@torch.library.register_fake("mylib::foo", lib=lib)
def foo_abstract(x, y, z, w, n):
return y[0] + w, y[1] + n
@@ -495,7 +495,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
z.add_(y[1] + n)
return y[0] + w, y[1] + n
-@torch.library.impl_abstract("mylib::foo", lib=lib)
+@torch.library.register_fake("mylib::foo", lib=lib)
def foo_abstract(x, y, z, w, n):
return y[0] + w, y[1] + n


@@ -1608,7 +1608,7 @@ class TestCustomOp(CustomOpTestCaseBase):
lib = self.lib()
lib.define("sin.blah(Tensor x) -> Tensor")
-torch.library.impl_abstract(
+torch.library.register_fake(
f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib
)
@@ -1621,7 +1621,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
raise NotImplementedError
-@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+@torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
def foo_meta(x, dim):
output_shape = list(x.shape)
del output_shape[dim]
@@ -1637,7 +1637,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
raise NotImplementedError
-@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+@torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
def foo_meta(x, dim):
output_shape = list(x.shape)
del output_shape[dim]
@@ -1645,7 +1645,7 @@ class TestCustomOp(CustomOpTestCaseBase):
with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"):
-@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+@torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
def foo_meta2(x, dim):
output_shape = list(x.shape)
del output_shape[dim]
@@ -1656,7 +1656,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
-@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+@torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
def foo_meta(x):
ctx = torch.library.get_ctx()
r = ctx.new_dynamic_size(min=1)
@@ -1683,7 +1683,7 @@ class TestCustomOp(CustomOpTestCaseBase):
def foo(x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
-@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+@torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
def foo_meta(x):
return x.sum()
@@ -1827,7 +1827,7 @@ Dynamic shape operator
lib.define("foo(Tensor x) -> Tensor")
qualname = f"{self.test_ns}::foo"
-@torch.library.impl_abstract(qualname, lib=self.lib())
+@torch.library.register_fake(qualname, lib=self.lib())
def foo_impl(x):
return x.sin()
@@ -1850,7 +1850,7 @@ Dynamic shape operator
op = self.get_op(qualname)
with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
-torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
+torch.library.register_fake(qualname, foo_impl, lib=self.lib())
def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
lib = self.lib()
@@ -1864,7 +1864,7 @@ Dynamic shape operator
op = self.get_op(qualname)
with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
-torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
+torch.library.register_fake(qualname, foo_impl, lib=self.lib())
def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
lib = self.lib()
@@ -1877,7 +1877,7 @@ Dynamic shape operator
lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
op = self.get_op(qualname)
-torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
+torch.library.register_fake(qualname, lambda x: x.sum(), lib=self.lib())
with torch._subclasses.FakeTensorMode():
x = torch.randn(10)
result = op(x)


@@ -403,7 +403,7 @@ class CustomOpDef:
(sizes/strides/storage_offset/device), it specifies what the properties of
the output Tensors are.
-Please see :func:`torch.library.impl_abstract` for more details.
+Please see :func:`torch.library.register_fake` for more details.
Args:
fn (Callable): The function to register as the FakeTensor
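For context on the docstring above: a fake impl only describes the metadata of the outputs. A hedged sketch using the `CustomOpDef.register_fake` decorator with a data-dependent output size (the op name and body are illustrative, not part of this change):

```python
import torch

# Illustrative op, not from this PR: returns the nonzero values of x,
# so the output length depends on the data.
@torch.library.custom_op("mylib::nonzero_vals", mutates_args=())
def nonzero_vals(x: torch.Tensor) -> torch.Tensor:
    return x[x != 0]

@nonzero_vals.register_fake
def _(x):
    # Metadata only: ask the context for an unbacked symbolic size
    # to stand in for the data-dependent output length.
    ctx = torch.library.get_ctx()
    n = ctx.new_dynamic_size()
    return x.new_empty(n)
```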