mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle Reviewed By: zertosh Differential Revision: D30279364 fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
aac3c7bd06
commit
b004307252
@ -1,15 +1,16 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch.testing._internal.common_utils as common
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA
|
||||
import torch
|
||||
import torch.backends.cudnn
|
||||
import torch.testing._internal.common_utils as common
|
||||
import torch.utils.cpp_extension
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
|
||||
try:
|
||||
import pytest
|
||||
|
||||
HAS_PYTEST = True
|
||||
except ImportError as e:
|
||||
HAS_PYTEST = False
|
||||
@ -102,28 +103,28 @@ class TestCppExtensionAOT(common.TestCase):
|
||||
|
||||
class TestMSNPUTensor(common.TestCase):
|
||||
def test_unregistered(self):
|
||||
a = torch.arange(0, 10, device='cpu')
|
||||
a = torch.arange(0, 10, device="cpu")
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not run"):
|
||||
b = torch.arange(0, 10, device='msnpu')
|
||||
b = torch.arange(0, 10, device="msnpu")
|
||||
|
||||
def test_zeros(self):
|
||||
a = torch.empty(5, 5, device='cpu')
|
||||
self.assertEqual(a.device, torch.device('cpu'))
|
||||
a = torch.empty(5, 5, device="cpu")
|
||||
self.assertEqual(a.device, torch.device("cpu"))
|
||||
|
||||
b = torch.empty(5, 5, device='msnpu')
|
||||
self.assertEqual(b.device, torch.device('msnpu', 0))
|
||||
b = torch.empty(5, 5, device="msnpu")
|
||||
self.assertEqual(b.device, torch.device("msnpu", 0))
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 0)
|
||||
self.assertEqual(torch.get_default_dtype(), b.dtype)
|
||||
|
||||
c = torch.empty((5, 5), dtype=torch.int64, device='msnpu')
|
||||
c = torch.empty((5, 5), dtype=torch.int64, device="msnpu")
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 0)
|
||||
self.assertEqual(torch.int64, c.dtype)
|
||||
|
||||
def test_add(self):
|
||||
a = torch.empty(5, 5, device='msnpu', requires_grad=True)
|
||||
a = torch.empty(5, 5, device="msnpu", requires_grad=True)
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 0)
|
||||
|
||||
b = torch.empty(5, 5, device='msnpu')
|
||||
b = torch.empty(5, 5, device="msnpu")
|
||||
self.assertEqual(msnpu_extension.get_test_int(), 0)
|
||||
|
||||
c = a + b
|
||||
@ -132,9 +133,9 @@ class TestMSNPUTensor(common.TestCase):
|
||||
def test_conv_backend_override(self):
|
||||
# To simplify tests, we use 4d input here to avoid doing view4d( which
|
||||
# needs more overrides) in _convolution.
|
||||
input = torch.empty(2, 4, 10, 2, device='msnpu', requires_grad=True)
|
||||
weight = torch.empty(6, 4, 2, 2, device='msnpu', requires_grad=True)
|
||||
bias = torch.empty(6, device='msnpu')
|
||||
input = torch.empty(2, 4, 10, 2, device="msnpu", requires_grad=True)
|
||||
weight = torch.empty(6, 4, 2, 2, device="msnpu", requires_grad=True)
|
||||
bias = torch.empty(6, device="msnpu")
|
||||
|
||||
# Make sure forward is overriden
|
||||
out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
|
||||
@ -151,7 +152,6 @@ class TestMSNPUTensor(common.TestCase):
|
||||
|
||||
|
||||
class TestRNGExtension(common.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestRNGExtension, self).setUp()
|
||||
|
||||
@ -161,7 +161,7 @@ class TestRNGExtension(common.TestCase):
|
||||
t = torch.empty(10, dtype=torch.int64).random_()
|
||||
self.assertNotEqual(t, fourty_two)
|
||||
|
||||
gen = torch.Generator(device='cpu')
|
||||
gen = torch.Generator(device="cpu")
|
||||
t = torch.empty(10, dtype=torch.int64).random_(generator=gen)
|
||||
self.assertNotEqual(t, fourty_two)
|
||||
|
||||
@ -187,7 +187,6 @@ class TestRNGExtension(common.TestCase):
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
|
||||
class TestTorchLibrary(common.TestCase):
|
||||
|
||||
def test_torch_library(self):
|
||||
import torch_test_cpp_extension.torch_library # noqa: F401
|
||||
|
||||
@ -203,7 +202,7 @@ class TestTorchLibrary(common.TestCase):
|
||||
self.assertFalse(s(True, False))
|
||||
self.assertFalse(s(False, True))
|
||||
self.assertFalse(s(False, False))
|
||||
self.assertIn('torch_library::logical_and', str(s.graph))
|
||||
self.assertIn("torch_library::logical_and", str(s.graph))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user