Files
pytorch/test/test_native_functions.py
David Reiss 5e03a1e926 Add support for int[]? arguments in native_functions.yaml (#37174)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37174

ghstack-source-id: 106938112

Test Plan: Upcoming diffs use this for upsampling.

Differential Revision: D21210002

fbshipit-source-id: d6a55ab6420c05a92873a569221b613149aa0daa
2020-07-07 13:52:20 -07:00

62 lines
2.5 KiB
Python

from typing import Optional, List
import torch
from torch.testing._internal.common_utils import TestCase
# End-to-end tests of features in native_functions.yaml
class IntListWrapperModule(torch.nn.Module):
    """Expose ``torch._C._nn._test_optional_intlist`` as an ``nn.Module``.

    Packaging the op in a module lets the tests call it eagerly and also
    compile it with ``torch.jit.script``. ``incr`` carries an explicit
    ``Optional[List[int]]`` annotation so the scripted module accepts either
    ``None`` or a list of ints for it.
    """

    def forward(self, values, incr: Optional[List[int]]):
        # ``incr`` is handed to the op untouched, including when it is None.
        out = torch._C._nn._test_optional_intlist(values, incr)
        return out
class TestNativeFunctions(TestCase):
def do_test_optional_intlist_with_module(self, module):
values = torch.tensor([1, 2], dtype=torch.int)
returned = module(values, None)
self.assertEqual(values, returned)
# Make sure that it's an alias, indicating that the operator saw a nullopt.
values[0] = 3
self.assertEqual(values, returned)
returned = module(values, [5, 4])
self.assertEqual(values, torch.tensor([3, 2], dtype=torch.int))
self.assertEqual(returned, torch.tensor([8, 6], dtype=torch.int))
def trace_optional_intlist(self, const):
def wrapper(values):
return torch._C._nn._test_optional_intlist(values, const)
return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
def test_optional_intlist(self):
self.do_test_optional_intlist_with_module(IntListWrapperModule())
self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))
traced_none = self.trace_optional_intlist(None)
traced_list = self.trace_optional_intlist([5, 4])
# Not really a module, just lets us use our two traced functions to handle
# the specific cases of passing None and [5, 4].
def fake_module(values, const):
if const is None:
return traced_none(values)
if const == [5, 4]:
return traced_list(values)
raise Exception("Invalid argument")
self.do_test_optional_intlist_with_module(fake_module)
def test_optional_intlist_invalid(self):
with self.assertRaisesRegex(TypeError, "must be .* but found"):
IntListWrapperModule()(torch.zeros(1), [0.5])
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
torch.jit.script(IntListWrapperModule())(torch.zeros(1), [0.5])
with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
IntListWrapperModule()(torch.zeros(1), torch.zeros(1))
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
torch.jit.script(IntListWrapperModule())(torch.zeros(1), torch.zeros(1))