mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle Reviewed By: zertosh Differential Revision: D30279364 fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
aac3c7bd06
commit
b004307252
@ -1,4 +1,5 @@
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
@ -37,11 +38,14 @@ class TestNativeFunctions(TestCase):
|
||||
def trace_optional_floatlist(self, const):
|
||||
def wrapper(values):
|
||||
return torch._C._nn._test_optional_floatlist(values, const)
|
||||
|
||||
return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float))
|
||||
|
||||
def test_optional_floatlist(self):
|
||||
self.do_test_optional_floatlist_with_module(FloatListWrapperModule())
|
||||
self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule()))
|
||||
self.do_test_optional_floatlist_with_module(
|
||||
torch.jit.script(FloatListWrapperModule())
|
||||
)
|
||||
|
||||
traced_none = self.trace_optional_floatlist(None)
|
||||
traced_list = self.trace_optional_floatlist([5.1, 4.1])
|
||||
@ -61,13 +65,17 @@ class TestNativeFunctions(TestCase):
|
||||
with self.assertRaisesRegex(TypeError, "must be tuple of floats, not list"):
|
||||
FloatListWrapperModule()(torch.zeros(1), ["hi"])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "value of type .* instead found type"
|
||||
):
|
||||
torch.jit.script(FloatListWrapperModule())(torch.zeros(1), ["hi"])
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
|
||||
FloatListWrapperModule()(torch.zeros(1), torch.zeros(1))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "value of type .* instead found type"
|
||||
):
|
||||
torch.jit.script(FloatListWrapperModule())(torch.zeros(1), torch.zeros(1))
|
||||
|
||||
#
|
||||
@ -90,11 +98,14 @@ class TestNativeFunctions(TestCase):
|
||||
def trace_optional_intlist(self, const):
|
||||
def wrapper(values):
|
||||
return torch._C._nn._test_optional_intlist(values, const)
|
||||
|
||||
return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
|
||||
|
||||
def test_optional_intlist(self):
|
||||
self.do_test_optional_intlist_with_module(IntListWrapperModule())
|
||||
self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))
|
||||
self.do_test_optional_intlist_with_module(
|
||||
torch.jit.script(IntListWrapperModule())
|
||||
)
|
||||
|
||||
traced_none = self.trace_optional_intlist(None)
|
||||
traced_list = self.trace_optional_intlist([5, 4])
|
||||
@ -114,13 +125,17 @@ class TestNativeFunctions(TestCase):
|
||||
with self.assertRaisesRegex(TypeError, "must be .* not"):
|
||||
IntListWrapperModule()(torch.zeros(1), [0.5])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "value of type .* instead found type"
|
||||
):
|
||||
torch.jit.script(IntListWrapperModule())(torch.zeros(1), [0.5])
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
|
||||
IntListWrapperModule()(torch.zeros(1), torch.zeros(1))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "value of type .* instead found type"
|
||||
):
|
||||
torch.jit.script(IntListWrapperModule())(torch.zeros(1), torch.zeros(1))
|
||||
|
||||
#
|
||||
@ -143,13 +158,17 @@ class TestNativeFunctions(TestCase):
|
||||
def trace_optional_filled_intlist(self, const):
|
||||
def wrapper(values):
|
||||
return torch._C._nn._test_optional_filled_intlist(values, const)
|
||||
|
||||
return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
|
||||
|
||||
def test_optional_filled_intlist(self):
|
||||
|
||||
def f(n: int):
|
||||
x = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), (n, n))
|
||||
y = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), n)
|
||||
x = torch._C._nn._test_optional_filled_intlist(
|
||||
torch.tensor([1, 1], dtype=torch.int), (n, n)
|
||||
)
|
||||
y = torch._C._nn._test_optional_filled_intlist(
|
||||
torch.tensor([1, 1], dtype=torch.int), n
|
||||
)
|
||||
return x, y
|
||||
|
||||
# eager
|
||||
@ -189,9 +208,10 @@ class TestNativeFunctions(TestCase):
|
||||
|
||||
def f(x):
|
||||
torch._C._nn._test_string_default(x)
|
||||
|
||||
scripted_fn = torch.jit.script(f)
|
||||
scripted_fn(dummy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user