Using the latest torch.library.register_fake API instead of torch.library.impl_abstract (#158839)

As the title states.

`torch.library.impl_abstract` has been deprecated since PyTorch 2.4, so this change switches to the new `torch.library.register_fake` API.
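For reference, a minimal sketch of the new-style registration (the `mylib` namespace, the `foo` op, and all names below are hypothetical, purely for illustration):

import torch

# Hypothetical library and op for illustration only; not part of this change.
lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("foo(Tensor x) -> Tensor")

# Old (deprecated since PyTorch 2.4):
#   @torch.library.impl_abstract("mylib::foo", lib=lib)
# New:
@torch.library.register_fake("mylib::foo", lib=lib)
def foo_fake(x):
    # A fake impl only describes output metadata (shape/dtype/device);
    # it runs under FakeTensorMode / meta tracing instead of a real kernel.
    return torch.empty_like(x)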

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158839
Approved by: https://github.com/jingsh, https://github.com/zou3519
ghstack dependencies: #158838
commit 6fc0ad22f0 (parent c60d382870)
Author: FFFrog
Date: 2025-07-24 11:32:18 +08:00
Committed by: PyTorch MergeBot
11 changed files with 26 additions and 26 deletions


@@ -1608,7 +1608,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         lib = self.lib()
         lib.define("sin.blah(Tensor x) -> Tensor")
-        torch.library.impl_abstract(
+        torch.library.register_fake(
             f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib
         )
@@ -1621,7 +1621,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
             raise NotImplementedError

-        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+        @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
         def foo_meta(x, dim):
             output_shape = list(x.shape)
             del output_shape[dim]
@@ -1637,7 +1637,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
             raise NotImplementedError

-        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+        @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
         def foo_meta(x, dim):
             output_shape = list(x.shape)
             del output_shape[dim]
@@ -1645,7 +1645,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"):

-            @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+            @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
             def foo_meta2(x, dim):
                 output_shape = list(x.shape)
                 del output_shape[dim]
@@ -1656,7 +1656,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         def foo(x: torch.Tensor) -> torch.Tensor:
             raise NotImplementedError

-        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+        @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
         def foo_meta(x):
             ctx = torch.library.get_ctx()
             r = ctx.new_dynamic_size(min=1)
@@ -1683,7 +1683,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         def foo(x: torch.Tensor) -> torch.Tensor:
             raise NotImplementedError

-        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+        @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
         def foo_meta(x):
             return x.sum()
@@ -1827,7 +1827,7 @@ Dynamic shape operator
         lib.define("foo(Tensor x) -> Tensor")
         qualname = f"{self.test_ns}::foo"

-        @torch.library.impl_abstract(qualname, lib=self.lib())
+        @torch.library.register_fake(qualname, lib=self.lib())
         def foo_impl(x):
             return x.sin()
@@ -1850,7 +1850,7 @@ Dynamic shape operator
         op = self.get_op(qualname)
         with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
-            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
+            torch.library.register_fake(qualname, foo_impl, lib=self.lib())

     def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
         lib = self.lib()
@@ -1864,7 +1864,7 @@ Dynamic shape operator
         op = self.get_op(qualname)
         with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
-            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
+            torch.library.register_fake(qualname, foo_impl, lib=self.lib())

     def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
         lib = self.lib()
@@ -1877,7 +1877,7 @@ Dynamic shape operator
         lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
         op = self.get_op(qualname)

-        torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
+        torch.library.register_fake(qualname, lambda x: x.sum(), lib=self.lib())
         with torch._subclasses.FakeTensorMode():
             x = torch.randn(10)
             result = op(x)
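One detail visible in the last few hunks: `impl_abstract` accepted the fake implementation via a `func=` keyword, whereas `register_fake` takes it positionally (the parameter is positional-only). A minimal self-contained sketch of the call-style migration, again using a hypothetical `mylib2::foo` op:

import torch

# Hypothetical op for illustration only; not part of this change.
lib = torch.library.Library("mylib2", "FRAGMENT")
lib.define("foo(Tensor x) -> Tensor")

# Old: torch.library.impl_abstract(qualname, func=fake_impl, lib=lib)
# New: pass the fake impl positionally.
torch.library.register_fake("mylib2::foo", lambda x: torch.empty_like(x), lib=lib)

# Under FakeTensorMode, dispatch should hit the registered fake impl.
with torch._subclasses.FakeTensorMode():
    x = torch.randn(10)
    result = torch.ops.mylib2.foo(x)  # a FakeTensor with shape (10,)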