mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes some files in https://github.com/pytorch/pytorch/issues/123062 Run lintrunner on files: test/test_package.py test/test_per_overload_api.py ```bash $ lintrunner -a --take UFMT --all-files ok No lint issues. Successfully applied all patches. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/125834 Approved by: https://github.com/malfet
80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
# Owner(s): ["module: unknown"]
|
|
import copy
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
class TestPerOverloadAPI(TestCase):
|
|
def test_basics_opoverloadpacket(self):
|
|
# add is ony used as an example here. It is ok to update the test
|
|
# if the semantics of add are modified in the future.
|
|
add_packet = torch.ops.aten.add
|
|
|
|
# class attributes
|
|
self.assertEqual(add_packet.__name__, "add")
|
|
self.assertEqual(str(add_packet), "aten.add")
|
|
|
|
# callable
|
|
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
|
|
|
# correct module
|
|
self.assertEqual(add_packet.__module__, add_packet.op.__module__)
|
|
|
|
# caching
|
|
another_add_packet = torch.ops.aten.add
|
|
self.assertEqual(id(add_packet), id(another_add_packet))
|
|
|
|
# deepcopy is a no-op
|
|
self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
|
|
|
|
# pretty print
|
|
self.assertEqual(repr(add_packet), "<OpOverloadPacket(op='aten.add')>")
|
|
|
|
self.assertRaises(AttributeError, lambda: add_packet.foo)
|
|
|
|
def test_basics_opoverload(self):
|
|
add_packet = torch.ops.aten.add
|
|
add_tensoroverload = add_packet.Tensor
|
|
|
|
# class attributes
|
|
self.assertEqual(str(add_tensoroverload), "aten.add.Tensor")
|
|
self.assertEqual(add_tensoroverload.__name__, "add.Tensor")
|
|
self.assertEqual(add_tensoroverload.overloadpacket, add_packet)
|
|
|
|
# deepcopy is a no-op
|
|
self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
|
|
|
|
# caching
|
|
another_add_tensoroverload = torch.ops.aten.add.Tensor
|
|
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
|
|
|
|
# pretty print
|
|
self.assertEqual(
|
|
repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>"
|
|
)
|
|
|
|
# callable
|
|
self.assertEqual(
|
|
add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5)
|
|
)
|
|
|
|
a = torch.tensor(2)
|
|
b = torch.tensor(0)
|
|
torch.ops.aten.add.out(a, a, out=b)
|
|
self.assertEqual(b, torch.tensor(4))
|
|
|
|
self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
|
|
|
|
def test_decompose(self):
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(5, 3)
|
|
self.assertEqual(
|
|
torch.ops.aten.linear.default.decompose(x, y),
|
|
torch.ops.aten.linear.default(x, y),
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|