Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
This allows you to directly call into the CompositeImplicitAutograd implementation of an operator, *without* changing any aspect of the dispatcher state. In particular, you can use this to recursively call into a decomposition, dispatching back to your tensor subclass/mode as desired.

Hypothetically, we should also make these available in the decompositions dictionary, but I'm leaving that as future work, since enumerating these decompositions is annoying (operators are lazily registered).

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83075
Approved by: https://github.com/albanD
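As a rough illustration of the usage pattern this enables (a minimal sketch only: it assumes the present-day TorchDispatchMode API from torch.utils._python_dispatch, and DecomposingMode and its print call are hypothetical, not part of this PR), a dispatch mode can try the decomposition first and fall back to a normal call when decompose() returns NotImplemented:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class DecomposingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Try the CompositeImplicitAutograd kernel first.  decompose() does not
        # touch any dispatcher state, so the ops the decomposition invokes are
        # dispatched normally and can themselves be intercepted.
        result = func.decompose(*args, **kwargs)
        if result is not NotImplemented:
            return result
        # No CompositeImplicitAutograd kernel for this overload: run it as-is.
        print(f"running {func} directly")
        return func(*args, **kwargs)

with DecomposingMode():
    torch.nn.functional.linear(torch.randn(2, 3), torch.randn(5, 3))

Which operators actually reach __torch_dispatch__ in composite form depends on where the mode or subclass sits in dispatch, so treat this only as a shape for the try-decompose-then-fall-through pattern described above.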
73 lines
2.5 KiB
Python
# Owner(s): ["module: unknown"]
import torch
import copy
from torch.testing._internal.common_utils import TestCase, run_tests


class TestPerOverloadAPI(TestCase):
    def test_basics_opoverloadpacket(self):
        # add is only used as an example here. It is ok to update the test
        # if the semantics of add are modified in the future.
        add_packet = torch.ops.aten.add

        # class attributes
        self.assertEqual(add_packet.__name__, 'add')
        self.assertEqual(str(add_packet), 'aten.add')

        # callable
        self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))

        # correct module
        self.assertEqual(add_packet.__module__, add_packet.op.__module__)

        # caching
        another_add_packet = torch.ops.aten.add
        self.assertEqual(id(add_packet), id(another_add_packet))

        # deepcopy is a no-op
        self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))

        # pretty print
        self.assertEqual(repr(add_packet), "<OpOverloadPacket(op='aten.add')>")

        self.assertRaises(AttributeError, lambda: add_packet.foo)

    def test_basics_opoverload(self):
        add_packet = torch.ops.aten.add
        add_tensoroverload = add_packet.Tensor

        # class attributes
        self.assertEqual(str(add_tensoroverload), 'aten.add.Tensor')
        self.assertEqual(add_tensoroverload.__name__, 'add.Tensor')
        self.assertEqual(add_tensoroverload.overloadpacket, add_packet)

        # deepcopy is a no-op
        self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))

        # caching
        another_add_tensoroverload = torch.ops.aten.add.Tensor
        self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))

        # pretty print
        self.assertEqual(repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>")

        # callable
        self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
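
        # out= overloads are resolved the same way; add.out writes its result into the out= tensor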
        a = torch.tensor(2)
        b = torch.tensor(0)
        torch.ops.aten.add.out(a, a, out=b)
        self.assertEqual(b, torch.tensor(4))
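
        # out= is rejected by overloads that do not declare an out argument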
        self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))

    def test_decompose(self):
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
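        # decompose() calls straight into the CompositeImplicitAutograd kernel for
        # linear, so it should produce the same result as a normal call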
        self.assertEqual(
            torch.ops.aten.linear.default.decompose(x, y),
            torch.ops.aten.linear.default(x, y)
        )


if __name__ == '__main__':
    run_tests()