mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37174 ghstack-source-id: 106938112 Test Plan: Upcoming diffs use this for upsampling. Differential Revision: D21210002 fbshipit-source-id: d6a55ab6420c05a92873a569221b613149aa0daa
62 lines
2.5 KiB
Python
62 lines
2.5 KiB
Python
from typing import Optional, List
|
|
import torch
|
|
from torch.testing._internal.common_utils import TestCase
|
|
|
|
# End-to-end tests of features in native_functions.yaml
|
|
|
|
class IntListWrapperModule(torch.nn.Module):
    """Module wrapper around the native ``_test_optional_intlist`` test operator.

    Packaging the call as an ``nn.Module`` lets the same operator invocation
    be run eagerly or handed to ``torch.jit.script``.
    """

    def forward(self, values, incr: Optional[List[int]]):
        """Forward ``values`` and the optional int list ``incr`` to the operator.

        The ``Optional[List[int]]`` annotation on ``incr`` is kept explicit so
        that a scripted version of this module accepts either ``None`` or a
        list of ints for that argument.
        """
        result = torch._C._nn._test_optional_intlist(values, incr)
        return result
|
|
|
|
|
|
class TestNativeFunctions(TestCase):
    """Tests for the ``_test_optional_intlist`` native test operator.

    The operator is exercised through an eager module, a scripted module and
    traced functions, with both ``None`` and a concrete int list.
    """

    def do_test_optional_intlist_with_module(self, module):
        """Check ``module(values, None)`` and ``module(values, [5, 4])``.

        ``module`` is any callable taking ``(values, optional_int_list)``.
        """
        base = torch.tensor([1, 2], dtype=torch.int)

        out = module(base, None)
        self.assertEqual(base, out)
        # Make sure that it's an alias, indicating that the operator saw a
        # nullopt: a write to the input must be visible through the output.
        base[0] = 3
        self.assertEqual(base, out)

        out = module(base, [5, 4])
        # The input must be left untouched when a list is supplied ...
        self.assertEqual(base, torch.tensor([3, 2], dtype=torch.int))
        # ... and the result holds the element-wise sums.
        self.assertEqual(out, torch.tensor([8, 6], dtype=torch.int))

    def trace_optional_intlist(self, const):
        """Return a traced function that calls the operator with ``const`` baked in."""

        def call_with_const(values):
            return torch._C._nn._test_optional_intlist(values, const)

        example = torch.tensor([1, 2], dtype=torch.int)
        return torch.jit.trace(call_with_const, example)

    def test_optional_intlist(self):
        # Eager and scripted modules go through the same checks.
        for module in (
            IntListWrapperModule(),
            torch.jit.script(IntListWrapperModule()),
        ):
            self.do_test_optional_intlist_with_module(module)

        traced_none = self.trace_optional_intlist(None)
        traced_list = self.trace_optional_intlist([5, 4])

        # Not really a module, just lets us use our two traced functions to
        # handle the specific cases of passing None and [5, 4].
        def fake_module(values, const):
            if const is None:
                traced = traced_none
            elif const == [5, 4]:
                traced = traced_list
            else:
                raise Exception("Invalid argument")
            return traced(values)

        self.do_test_optional_intlist_with_module(fake_module)

    def test_optional_intlist_invalid(self):
        # Each case builds its module inside the assertRaisesRegex context,
        # so an error raised during construction/scripting is checked too.
        def make_eager():
            return IntListWrapperModule()

        def make_scripted():
            return torch.jit.script(IntListWrapperModule())

        def expect_failure(exc_type, pattern, make_module, bad_incr):
            with self.assertRaisesRegex(exc_type, pattern):
                make_module()(torch.zeros(1), bad_incr)

        scripted_pattern = "value of type .* instead found type"

        # A list of floats is not a list of ints.
        expect_failure(TypeError, "must be .* but found", make_eager, [0.5])
        expect_failure(RuntimeError, scripted_pattern, make_scripted, [0.5])

        # A tensor is not a list of ints either.
        expect_failure(TypeError, "must be .* Tensor", make_eager, torch.zeros(1))
        expect_failure(RuntimeError, scripted_pattern, make_scripted, torch.zeros(1))
|