Add a test case for findDanglingImpls (#61104)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61104

This patch added a new test case for findDanglingImpls. The test case introduces a C++ extension which has a dangling impl such that findDanglingImpls can find it and output its information.

Test Plan:
python test/test_dispatch.py TestDispatch.test_find_dangling_impls_ext

Imported from OSS

Reviewed By: ezyang

Differential Revision: D29512520

fbshipit-source-id: 6883fb8f065f2c0ae0e7a1adf6fd298591497e2b
This commit is contained in:
Jiewen Tan
2021-07-07 13:15:40 -07:00
committed by Facebook GitHub Bot
parent 4d9fd8958b
commit 357c4d9cc4
2 changed files with 42 additions and 9 deletions

View File

@ -4,7 +4,9 @@ from torch._python_dispatcher import PythonDispatcher
from collections import namedtuple
import itertools
import os
import re
import torch.utils.cpp_extension
# TODO: Expand the dispatcher API to be a generic API for interfacing with
# the dispatcher from Python!
@ -755,6 +757,35 @@ CompositeImplicitAutograd[alias] (inactive): fn1 :: (Tensor _0) -> (Tensor _0) [
'''
)
def test_find_dangling_impls(self):
dangling_impls = C._dispatch_find_dangling_impls()
self.assertEqual(
0,
len(dangling_impls),
msg=f"Expect zero dangling impls, but found: {dangling_impls}"
)
def test_find_dangling_impls_ext(self):
extension_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cpp_extensions', 'dangling_impl_extension.cpp')
module = torch.utils.cpp_extension.load(
name="dangling_impl_extension",
sources=[
extension_path,
],
extra_cflags=["-g"],
verbose=True,
)
impls = C._dispatch_find_dangling_impls()
self.assertEqual(1, len(impls))
self.assertEqual(
'''\
name: __test::foo
schema: (none)
CPU: registered at {}:5 :: () -> () [ boxed unboxed ]
'''.format(extension_path),
impls[0])
class TestPythonDispatcher(TestCase):
def test_basic(self):
dispatcher = PythonDispatcher()
@ -886,14 +917,5 @@ CompositeImplicitAutograd[alias] fn_CompositeImplicitAutograd
r"Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed"):
dispatcher.register(["CompositeExplicitAutograd", "CompositeImplicitAutograd"])
# TODO(jwtan): Use a C++ extension, like msnpu, to introduce a dangling impl and examine the output.
def test_find_dangling_impls(self):
dangling_impls = C._dispatch_find_dangling_impls()
self.assertEqual(
0,
len(dangling_impls),
msg=f"Expect zero dangling impls, but found: {dangling_impls}"
)
if __name__ == '__main__':
run_tests()