mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Using the latest torch.library.register_fake API instead of torch.library.impl_abstract (#158839)
As the title stated. `torch.library.impl_abstract` have beed deprecated in PyTorch2.4, so change to use the new API. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158839 Approved by: https://github.com/jingsh, https://github.com/zou3519 ghstack dependencies: #158838
This commit is contained in:
@ -1608,7 +1608,7 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
lib = self.lib()
|
||||
lib.define("sin.blah(Tensor x) -> Tensor")
|
||||
|
||||
torch.library.impl_abstract(
|
||||
torch.library.register_fake(
|
||||
f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib
|
||||
)
|
||||
|
||||
@ -1621,7 +1621,7 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
|
||||
@torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
|
||||
def foo_meta(x, dim):
|
||||
output_shape = list(x.shape)
|
||||
del output_shape[dim]
|
||||
@ -1637,7 +1637,7 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
|
||||
@torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
|
||||
def foo_meta(x, dim):
|
||||
output_shape = list(x.shape)
|
||||
del output_shape[dim]
|
||||
@ -1645,7 +1645,7 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"):
|
||||
|
||||
@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
|
||||
@torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
|
||||
def foo_meta2(x, dim):
|
||||
output_shape = list(x.shape)
|
||||
del output_shape[dim]
|
||||
@ -1656,7 +1656,7 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
def foo(x: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
|
||||
@torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
|
||||
def foo_meta(x):
|
||||
ctx = torch.library.get_ctx()
|
||||
r = ctx.new_dynamic_size(min=1)
|
||||
@ -1683,7 +1683,7 @@ class TestCustomOp(CustomOpTestCaseBase):
|
||||
def foo(x: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
|
||||
@torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
|
||||
def foo_meta(x):
|
||||
return x.sum()
|
||||
|
||||
@ -1827,7 +1827,7 @@ Dynamic shape operator
|
||||
lib.define("foo(Tensor x) -> Tensor")
|
||||
qualname = f"{self.test_ns}::foo"
|
||||
|
||||
@torch.library.impl_abstract(qualname, lib=self.lib())
|
||||
@torch.library.register_fake(qualname, lib=self.lib())
|
||||
def foo_impl(x):
|
||||
return x.sin()
|
||||
|
||||
@ -1850,7 +1850,7 @@ Dynamic shape operator
|
||||
op = self.get_op(qualname)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
|
||||
torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
|
||||
torch.library.register_fake(qualname, foo_impl, lib=self.lib())
|
||||
|
||||
def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
|
||||
lib = self.lib()
|
||||
@ -1864,7 +1864,7 @@ Dynamic shape operator
|
||||
op = self.get_op(qualname)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
|
||||
torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
|
||||
torch.library.register_fake(qualname, foo_impl, lib=self.lib())
|
||||
|
||||
def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
|
||||
lib = self.lib()
|
||||
@ -1877,7 +1877,7 @@ Dynamic shape operator
|
||||
lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
|
||||
op = self.get_op(qualname)
|
||||
|
||||
torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
|
||||
torch.library.register_fake(qualname, lambda x: x.sum(), lib=self.lib())
|
||||
with torch._subclasses.FakeTensorMode():
|
||||
x = torch.randn(10)
|
||||
result = op(x)
|
||||
|
Reference in New Issue
Block a user