[codemod][lint][fbcode/c*] Enable BLACK by default

Test Plan: manual inspection & sandcastle

Reviewed By: zertosh

Differential Revision: D30279364

fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
This commit is contained in:
Zsolt Dollenstein
2021-08-12 10:56:55 -07:00
committed by Facebook GitHub Bot
parent aac3c7bd06
commit b004307252
188 changed files with 56875 additions and 28744 deletions

View File

@ -1,4 +1,5 @@
from typing import Optional, List
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
@ -37,11 +38,14 @@ class TestNativeFunctions(TestCase):
def trace_optional_floatlist(self, const):
def wrapper(values):
return torch._C._nn._test_optional_floatlist(values, const)
return torch.jit.trace(wrapper, torch.tensor([1.5, 2.5], dtype=torch.float))
def test_optional_floatlist(self):
self.do_test_optional_floatlist_with_module(FloatListWrapperModule())
self.do_test_optional_floatlist_with_module(torch.jit.script(FloatListWrapperModule()))
self.do_test_optional_floatlist_with_module(
torch.jit.script(FloatListWrapperModule())
)
traced_none = self.trace_optional_floatlist(None)
traced_list = self.trace_optional_floatlist([5.1, 4.1])
@ -61,13 +65,17 @@ class TestNativeFunctions(TestCase):
with self.assertRaisesRegex(TypeError, "must be tuple of floats, not list"):
FloatListWrapperModule()(torch.zeros(1), ["hi"])
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
with self.assertRaisesRegex(
RuntimeError, "value of type .* instead found type"
):
torch.jit.script(FloatListWrapperModule())(torch.zeros(1), ["hi"])
with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
FloatListWrapperModule()(torch.zeros(1), torch.zeros(1))
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
with self.assertRaisesRegex(
RuntimeError, "value of type .* instead found type"
):
torch.jit.script(FloatListWrapperModule())(torch.zeros(1), torch.zeros(1))
#
@ -90,11 +98,14 @@ class TestNativeFunctions(TestCase):
def trace_optional_intlist(self, const):
def wrapper(values):
return torch._C._nn._test_optional_intlist(values, const)
return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
def test_optional_intlist(self):
self.do_test_optional_intlist_with_module(IntListWrapperModule())
self.do_test_optional_intlist_with_module(torch.jit.script(IntListWrapperModule()))
self.do_test_optional_intlist_with_module(
torch.jit.script(IntListWrapperModule())
)
traced_none = self.trace_optional_intlist(None)
traced_list = self.trace_optional_intlist([5, 4])
@ -114,13 +125,17 @@ class TestNativeFunctions(TestCase):
with self.assertRaisesRegex(TypeError, "must be .* not"):
IntListWrapperModule()(torch.zeros(1), [0.5])
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
with self.assertRaisesRegex(
RuntimeError, "value of type .* instead found type"
):
torch.jit.script(IntListWrapperModule())(torch.zeros(1), [0.5])
with self.assertRaisesRegex(TypeError, "must be .* Tensor"):
IntListWrapperModule()(torch.zeros(1), torch.zeros(1))
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):
with self.assertRaisesRegex(
RuntimeError, "value of type .* instead found type"
):
torch.jit.script(IntListWrapperModule())(torch.zeros(1), torch.zeros(1))
#
@ -143,13 +158,17 @@ class TestNativeFunctions(TestCase):
def trace_optional_filled_intlist(self, const):
def wrapper(values):
return torch._C._nn._test_optional_filled_intlist(values, const)
return torch.jit.trace(wrapper, torch.tensor([1, 2], dtype=torch.int))
def test_optional_filled_intlist(self):
def f(n: int):
x = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), (n, n))
y = torch._C._nn._test_optional_filled_intlist(torch.tensor([1, 1], dtype=torch.int), n)
x = torch._C._nn._test_optional_filled_intlist(
torch.tensor([1, 1], dtype=torch.int), (n, n)
)
y = torch._C._nn._test_optional_filled_intlist(
torch.tensor([1, 1], dtype=torch.int), n
)
return x, y
# eager
@ -189,9 +208,10 @@ class TestNativeFunctions(TestCase):
def f(x):
torch._C._nn._test_string_default(x)
scripted_fn = torch.jit.script(f)
scripted_fn(dummy)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()