Files
pytorch/test/test_per_overload_api.py
Edward Z. Yang 988bd0173c Add OpOverload.decompose API (#83075)
This allows you to directly call into the CompositeImplicitAutograd
implementation of an operator, *without* changing any aspects of the
dispatcher state.  In particular, you can use this to recursively call
into a decomposition, dispatching back to your tensor subclass/mode
as desired.

Hypothetically, we should also make these available in the
decompositions dictionary, but I'm leaving this as future work as
enumerating these decompositions is annoying (as operators are lazily
registered.)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83075
Approved by: https://github.com/albanD
2022-08-09 18:53:19 +00:00

73 lines
2.5 KiB
Python

# Owner(s): ["module: unknown"]
import torch
import copy
from torch.testing._internal.common_utils import TestCase, run_tests
class TestPerOverloadAPI(TestCase):
def test_basics_opoverloadpacket(self):
# add is ony used as an example here. It is ok to update the test
# if the semantics of add are modified in the future.
add_packet = torch.ops.aten.add
# class attributes
self.assertEqual(add_packet.__name__, 'add')
self.assertEqual(str(add_packet), 'aten.add')
# callable
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
# correct module
self.assertEqual(add_packet.__module__, add_packet.op.__module__)
# caching
another_add_packet = torch.ops.aten.add
self.assertEqual(id(add_packet), id(another_add_packet))
# deepcopy is a no-op
self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
# pretty print
self.assertEqual(repr(add_packet), "<OpOverloadPacket(op='aten.add')>")
self.assertRaises(AttributeError, lambda: add_packet.foo)
def test_basics_opoverload(self):
add_packet = torch.ops.aten.add
add_tensoroverload = add_packet.Tensor
# class attributes
self.assertEqual(str(add_tensoroverload), 'aten.add.Tensor')
self.assertEqual(add_tensoroverload.__name__, 'add.Tensor')
self.assertEqual(add_tensoroverload.overloadpacket, add_packet)
# deepcopy is a no-op
self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
# caching
another_add_tensoroverload = torch.ops.aten.add.Tensor
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
# pretty print
self.assertEqual(repr(add_tensoroverload), "<OpOverload(op='aten.add', overload='Tensor')>")
# callable
self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
a = torch.tensor(2)
b = torch.tensor(0)
torch.ops.aten.add.out(a, a, out=b)
self.assertEqual(b, torch.tensor(4))
self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
def test_decompose(self):
x = torch.randn(2, 3)
y = torch.randn(5, 3)
self.assertEqual(
torch.ops.aten.linear.default.decompose(x, y),
torch.ops.aten.linear.default(x, y)
)
if __name__ == '__main__':
run_tests()