mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
UFMT formatting on test/autograd test/ao test/cpp test/backends (#123369)
Partially addresses #123062. Ran lintrunner on - test/_test_bazel.py - test/ao - test/autograd - test/backends - test/benchmark_utils - test/conftest.py - test/bottleneck_test - test/cpp Pull Request resolved: https://github.com/pytorch/pytorch/pull/123369 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
de7edeea25
commit
f71e368969
@ -1,5 +1,6 @@
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@ -21,7 +22,7 @@ class FileSetup:
|
||||
|
||||
|
||||
class EvalModeForLoadedModule(FileSetup):
|
||||
path = 'dropout_model.pt'
|
||||
path = "dropout_model.pt"
|
||||
|
||||
def setup(self):
|
||||
class Model(torch.jit.ScriptModule):
|
||||
@ -40,7 +41,7 @@ class EvalModeForLoadedModule(FileSetup):
|
||||
|
||||
|
||||
class SerializationInterop(FileSetup):
|
||||
path = 'ivalue.pt'
|
||||
path = "ivalue.pt"
|
||||
|
||||
def setup(self):
|
||||
ones = torch.ones(2, 2)
|
||||
@ -53,7 +54,7 @@ class SerializationInterop(FileSetup):
|
||||
|
||||
# See testTorchSaveError in test/cpp/jit/tests.h for usage
|
||||
class TorchSaveError(FileSetup):
|
||||
path = 'eager_value.pt'
|
||||
path = "eager_value.pt"
|
||||
|
||||
def setup(self):
|
||||
ones = torch.ones(2, 2)
|
||||
@ -63,8 +64,9 @@ class TorchSaveError(FileSetup):
|
||||
|
||||
torch.save(value, self.path, _use_new_zipfile_serialization=False)
|
||||
|
||||
|
||||
class TorchSaveJitStream_CUDA(FileSetup):
|
||||
path = 'saved_stream_model.pt'
|
||||
path = "saved_stream_model.pt"
|
||||
|
||||
def setup(self):
|
||||
if not torch.cuda.is_available():
|
||||
@ -77,7 +79,9 @@ class TorchSaveJitStream_CUDA(FileSetup):
|
||||
b = torch.rand(3, 4, device="cuda")
|
||||
|
||||
with torch.cuda.stream(s):
|
||||
is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id()
|
||||
is_stream_s = (
|
||||
torch.cuda.current_stream(s.device_index()).id() == s.id()
|
||||
)
|
||||
c = torch.cat((a, b), 0).to("cuda")
|
||||
s.synchronize()
|
||||
return is_stream_s, a, b, c
|
||||
@ -93,9 +97,10 @@ tests = [
|
||||
EvalModeForLoadedModule(),
|
||||
SerializationInterop(),
|
||||
TorchSaveError(),
|
||||
TorchSaveJitStream_CUDA()
|
||||
TorchSaveJitStream_CUDA(),
|
||||
]
|
||||
|
||||
|
||||
def setup():
    """Run the setup step of every registered file-setup test."""
    for t in tests:
        t.setup()
|
||||
|
Reference in New Issue
Block a user