mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle Reviewed By: zertosh Differential Revision: D30279364 fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
aac3c7bd06
commit
b004307252
@ -1,26 +1,28 @@
|
||||
import unittest
|
||||
import torch
|
||||
from torch import ops
|
||||
import torch.jit as jit
|
||||
import glob
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.jit as jit
|
||||
from torch import ops
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
|
||||
def get_custom_class_library_path():
|
||||
library_filename = glob.glob("build/*custom_class*")
|
||||
assert (len(library_filename) == 1)
|
||||
assert len(library_filename) == 1
|
||||
library_filename = library_filename[0]
|
||||
path = os.path.abspath(library_filename)
|
||||
assert os.path.exists(path), path
|
||||
return path
|
||||
|
||||
|
||||
def test_equality(f, cmp_key):
|
||||
obj1 = f()
|
||||
obj2 = jit.script(f)()
|
||||
return (cmp_key(obj1), cmp_key(obj2))
|
||||
|
||||
|
||||
class TestCustomOperators(TestCase):
|
||||
def setUp(self):
|
||||
ops.load_library(get_custom_class_library_path())
|
||||
@ -29,12 +31,14 @@ class TestCustomOperators(TestCase):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting._Foo(5, 3)
|
||||
return val.info()
|
||||
|
||||
self.assertEqual(*test_equality(f, lambda x: x))
|
||||
|
||||
def test_constructor_with_args(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting._Foo(5, 3)
|
||||
return val
|
||||
|
||||
self.assertEqual(*test_equality(f, lambda x: x.info()))
|
||||
|
||||
def test_function_call_with_args(self):
|
||||
@ -54,7 +58,9 @@ class TestCustomOperators(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected"):
|
||||
jit.script(f)()
|
||||
|
||||
@unittest.skip("We currently don't support passing custom classes to custom methods.")
|
||||
@unittest.skip(
|
||||
"We currently don't support passing custom classes to custom methods."
|
||||
)
|
||||
def test_input_class_type(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting._Foo(1, 2)
|
||||
@ -68,6 +74,7 @@ class TestCustomOperators(TestCase):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
|
||||
return val.pop()
|
||||
|
||||
self.assertEqual(*test_equality(f, lambda x: x))
|
||||
|
||||
def test_stack_push_pop(self):
|
||||
@ -76,6 +83,7 @@ class TestCustomOperators(TestCase):
|
||||
val2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
|
||||
val.push(val2.pop())
|
||||
return val.pop() + val2.pop()
|
||||
|
||||
self.assertEqual(*test_equality(f, lambda x: x))
|
||||
|
||||
|
||||
|
@ -2,9 +2,8 @@ import os.path
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from torch import ops
|
||||
|
||||
from model import Model, get_custom_op_library_path
|
||||
from torch import ops
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user