Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Using the latest torch.library.register_fake API instead of torch.library.impl_abstract (#158839)
As the title states: `torch.library.impl_abstract` has been deprecated since PyTorch 2.4, so switch these call sites to the new `torch.library.register_fake` API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158839
Approved by: https://github.com/jingsh, https://github.com/zou3519
ghstack dependencies: #158838
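For readers unfamiliar with the API change, here is a minimal before/after sketch of the migration (the `mylib::foo` operator below is hypothetical, not one of the ops touched by this PR); `register_fake` takes the same arguments as `impl_abstract` did, and `impl_abstract` currently remains only as a deprecated alias:

```python
import torch

lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("foo(Tensor x) -> Tensor")

# Deprecated since PyTorch 2.4:
#   @torch.library.impl_abstract("mylib::foo", lib=lib)

# Replacement with the same arguments and semantics:
@torch.library.register_fake("mylib::foo", lib=lib)
def foo_fake(x):
    # A fake ("meta") implementation only describes output metadata
    # (shape/dtype/device); it must not depend on the tensor's data.
    return torch.empty_like(x)
```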
@@ -6,7 +6,7 @@ import torch
 torch.ops.load_library(get_custom_op_library_path())
 
 
-@torch.library.impl_abstract("custom::nonzero")
+@torch.library.register_fake("custom::nonzero")
 def nonzero_abstract(x):
     n = x.dim()
     ctx = torch.library.get_ctx()
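The hunk above is truncated. Purely as an illustration of the pattern (not necessarily the repository's exact code), a fake implementation for a data-dependent op like `custom::nonzero` typically continues by allocating an unbacked symbolic size through the context:

```python
@torch.library.register_fake("custom::nonzero")
def nonzero_abstract(x):
    n = x.dim()
    ctx = torch.library.get_ctx()
    # The number of nonzero rows is data-dependent, so ask the fake-impl
    # context for a fresh dynamic (unbacked) size rather than reading x.
    nnz = ctx.new_dynamic_size()
    return x.new_empty((nnz, n), dtype=torch.long)
```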
@@ -6,6 +6,6 @@ import torch
 torch.ops.load_library(get_custom_op_library_path())
 
 
-@torch.library.impl_abstract("custom::sin")
+@torch.library.register_fake("custom::sin")
 def sin_abstract(x):
     return torch.empty_like(x)
@@ -8,12 +8,12 @@ torch.ops.load_library(get_custom_op_library_path())
 
 # NB: The impl_abstract_pystub for cos actually
 # specifies it should live in the my_custom_ops2 module.
-@torch.library.impl_abstract("custom::cos")
+@torch.library.register_fake("custom::cos")
 def cos_abstract(x):
     return torch.empty_like(x)
 
 
 # NB: There is no impl_abstract_pystub for tan
-@torch.library.impl_abstract("custom::tan")
+@torch.library.register_fake("custom::tan")
 def tan_abstract(x):
     return torch.empty_like(x)
@@ -911,7 +911,7 @@ class TestConverter(TestCase):
                 return x + x
 
             # Meta function of the custom op.
-            @torch.library.impl_abstract(
+            @torch.library.register_fake(
                 "mylib::foo",
                 lib=lib,
             )
@@ -146,7 +146,7 @@ torch.library.define(
 
 
 @torch.library.impl("testlib::returns_tensor_symint", "cpu")
-@torch.library.impl_abstract("testlib::returns_tensor_symint")
+@torch.library.register_fake("testlib::returns_tensor_symint")
 def returns_tensor_symint_impl(x):
     return x, x.shape[0]
 
@@ -159,7 +159,7 @@ def foo_impl(x, z):
     return x, z, x + z
 
 
-@torch.library.impl_abstract("testlib::foo")
+@torch.library.register_fake("testlib::foo")
 def foo_abstract(x, z):
     return x, z, x + z
 
@@ -795,7 +795,7 @@ class TestDeserialize(TestCase):
             )
 
             @torch.library.impl("mylib::foo", "cpu", lib=lib)
-            @torch.library.impl_abstract("mylib::foo")
+            @torch.library.register_fake("mylib::foo")
            def foo_impl(a, b, c):
                 res2 = None
                 if c is not None:
@@ -884,21 +884,21 @@ class TestDeserialize(TestCase):
             )
 
             @torch.library.impl("mylib::foo1", "cpu", lib=lib)
-            @torch.library.impl_abstract("mylib::foo1")
+            @torch.library.register_fake("mylib::foo1")
             def foo1_impl(x, y, z, w, n):
                 x.add_(y[0] + w)
                 z.add_(y[1] + n)
                 return n + n
 
             @torch.library.impl("mylib::foo2", "cpu", lib=lib)
-            @torch.library.impl_abstract("mylib::foo2")
+            @torch.library.register_fake("mylib::foo2")
             def foo2_impl(x, y, z, w, n):
                 x.add_(y[0] + w)
                 z.add_(y[1] + n)
                 return (n + n, n * n)
 
             @torch.library.impl("mylib::foo3", "cpu", lib=lib)
-            @torch.library.impl_abstract("mylib::foo3")
+            @torch.library.register_fake("mylib::foo3")
             def foo3_impl(x, y, z, w, n):
                 x.add_(y[0] + w)
                 z.add_(y[1] + n)
@@ -328,7 +328,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
             return
 
         # Meta function of the custom op
-        @torch.library.impl_abstract(
+        @torch.library.register_fake(
            "mylib::record_scalar_tensor",
            lib=lib,
        )
@@ -4315,7 +4315,7 @@ class AOTInductorTestsTemplate:
         def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
             return a[: b.item()]
 
-        @torch.library.impl_abstract("mylib::foo", lib=lib)
+        @torch.library.register_fake("mylib::foo", lib=lib)
         def foo_fake_impl(a, b):
             ctx = torch.library.get_ctx()
             u = ctx.new_dynamic_size()
@@ -217,7 +217,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
                 z.add_(y[1] + n)
                 return y[0] + w, y[1] + n
 
-            @torch.library.impl_abstract("mylib::foo", lib=lib)
+            @torch.library.register_fake("mylib::foo", lib=lib)
             def foo_abstract(x, y, z, w, n):
                 return y[0] + w, y[1] + n
 
@@ -495,7 +495,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
                 z.add_(y[1] + n)
                 return y[0] + w, y[1] + n
 
-            @torch.library.impl_abstract("mylib::foo", lib=lib)
+            @torch.library.register_fake("mylib::foo", lib=lib)
             def foo_abstract(x, y, z, w, n):
                 return y[0] + w, y[1] + n
 
@@ -1608,7 +1608,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         lib = self.lib()
         lib.define("sin.blah(Tensor x) -> Tensor")
 
-        torch.library.impl_abstract(
+        torch.library.register_fake(
             f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib
         )
 
@@ -1621,7 +1621,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
             raise NotImplementedError
 
-        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+        @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
         def foo_meta(x, dim):
             output_shape = list(x.shape)
             del output_shape[dim]
@@ -1637,7 +1637,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
             raise NotImplementedError
 
-        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+        @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
         def foo_meta(x, dim):
             output_shape = list(x.shape)
             del output_shape[dim]
@@ -1645,7 +1645,7 @@ class TestCustomOp(CustomOpTestCaseBase):
 
         with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"):
 
-            @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+            @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
             def foo_meta2(x, dim):
                 output_shape = list(x.shape)
                 del output_shape[dim]
@@ -1656,7 +1656,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         def foo(x: torch.Tensor) -> torch.Tensor:
             raise NotImplementedError
 
-        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+        @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
         def foo_meta(x):
             ctx = torch.library.get_ctx()
             r = ctx.new_dynamic_size(min=1)
@@ -1683,7 +1683,7 @@ class TestCustomOp(CustomOpTestCaseBase):
         def foo(x: torch.Tensor) -> torch.Tensor:
             raise NotImplementedError
 
-        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
+        @torch.library.register_fake(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
         def foo_meta(x):
             return x.sum()
 
@@ -1827,7 +1827,7 @@ Dynamic shape operator
         lib.define("foo(Tensor x) -> Tensor")
         qualname = f"{self.test_ns}::foo"
 
-        @torch.library.impl_abstract(qualname, lib=self.lib())
+        @torch.library.register_fake(qualname, lib=self.lib())
         def foo_impl(x):
             return x.sin()
 
@@ -1850,7 +1850,7 @@ Dynamic shape operator
         op = self.get_op(qualname)
 
         with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
-            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
+            torch.library.register_fake(qualname, foo_impl, lib=self.lib())
 
     def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
         lib = self.lib()
@@ -1864,7 +1864,7 @@ Dynamic shape operator
         op = self.get_op(qualname)
 
         with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
-            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())
+            torch.library.register_fake(qualname, foo_impl, lib=self.lib())
 
     def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
         lib = self.lib()
@@ -1877,7 +1877,7 @@ Dynamic shape operator
         lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
         op = self.get_op(qualname)
 
-        torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
+        torch.library.register_fake(qualname, lambda x: x.sum(), lib=self.lib())
         with torch._subclasses.FakeTensorMode():
             x = torch.randn(10)
             result = op(x)
@@ -403,7 +403,7 @@ class CustomOpDef:
         (sizes/strides/storage_offset/device), it specifies what the properties of
         the output Tensors are.
 
-        Please see :func:`torch.library.impl_abstract` for more details.
+        Please see :func:`torch.library.register_fake` for more details.
 
         Args:
             fn (Callable): The function to register as the FakeTensor
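The last hunk updates the cross-reference in the `CustomOpDef.register_fake` docstring. As a usage sketch only (the `mylib::scaled_sin` op below is made up for illustration), this is how that method is typically applied to an op created with `torch.library.custom_op`:

```python
import torch

# Hypothetical op for illustration; any namespace::name works.
@torch.library.custom_op("mylib::scaled_sin", mutates_args=())
def scaled_sin(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.sin(x) * scale

# CustomOpDef.register_fake attaches the FakeTensor ("meta") implementation
# that torch.compile / torch.export use to infer output metadata without
# running the real kernel.
@scaled_sin.register_fake
def _(x, scale):
    return torch.empty_like(x)

# Under FakeTensorMode only the fake implementation runs, so just the
# output metadata is computed.
with torch._subclasses.FakeTensorMode():
    out = scaled_sin(torch.randn(4), 2.0)
    assert out.shape == (4,)
```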