mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Reland torch.ops API change machinery with the core functionality disabled (#71785)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71785 see https://github.com/pytorch/pytorch/pull/67254 ghstack-source-id: 147648699 Test Plan: github CI Reviewed By: albanD Differential Revision: D33777229 fbshipit-source-id: 517b36be9743025eb40d708d380dae62e3663184 (cherry picked from commit a637e695694d3fd615dbe821394bfe53d41b6901)
This commit is contained in:
committed by
PyTorch MergeBot
parent
1fdbe9aa76
commit
a1383a9cfa
67
test/test_per_overload_api.py
Normal file
67
test/test_per_overload_api.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
# import torch
|
||||
# import copy
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
class TestPerOverloadAPI(TestCase):
|
||||
# def test_basics_opoverloadpacket(self):
|
||||
# # add is ony used as an example here. It is ok to update the test
|
||||
# # if the semantics of add are modified in the future.
|
||||
# add_packet = torch.ops.aten.add
|
||||
|
||||
# # class attributes
|
||||
# self.assertEqual(add_packet.op_name, 'add')
|
||||
# self.assertEqual(add_packet.qualified_op_name, 'aten.add')
|
||||
|
||||
# # callable
|
||||
# self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
|
||||
# # correct module
|
||||
# self.assertEqual(add_packet.__module__, add_packet.op.__module__)
|
||||
|
||||
# # caching
|
||||
# another_add_packet = torch.ops.aten.add
|
||||
# self.assertEqual(id(add_packet), id(another_add_packet))
|
||||
|
||||
# # deepcopy is a no-op
|
||||
# self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
|
||||
|
||||
# # pretty print
|
||||
# self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")
|
||||
|
||||
# self.assertRaises(AttributeError, lambda: add_packet.foo)
|
||||
|
||||
# def test_basics_opoverload(self):
|
||||
# add_packet = torch.ops.aten.add
|
||||
# add_tensoroverload = add_packet.Tensor
|
||||
|
||||
# # class attributes
|
||||
# self.assertEqual(add_tensoroverload.name, 'aten.add')
|
||||
# self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
|
||||
# self.assertEqual(add_tensoroverload.overload_packet, add_packet)
|
||||
|
||||
# # deepcopy is a no-op
|
||||
# self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
|
||||
|
||||
# # caching
|
||||
# another_add_tensoroverload = torch.ops.aten.add.Tensor
|
||||
# self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
|
||||
|
||||
# # pretty print
|
||||
# self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")
|
||||
|
||||
# # callable
|
||||
# self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
|
||||
# a = torch.tensor(2)
|
||||
# b = torch.tensor(0)
|
||||
# torch.ops.aten.add.out(a, a, out=b)
|
||||
# self.assertEqual(b, torch.tensor(4))
|
||||
|
||||
# self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
|
||||
|
||||
def do_nothing(self):
|
||||
return
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
Reference in New Issue
Block a user