mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle Reviewed By: zertosh Differential Revision: D30279364 fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
aac3c7bd06
commit
b004307252
@ -1,9 +1,8 @@
|
||||
|
||||
import torch
|
||||
from torch.utils import ThroughputBenchmark
|
||||
from torch.testing import assert_allclose
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase, TemporaryFileName
|
||||
from torch.utils import ThroughputBenchmark
|
||||
|
||||
|
||||
class TwoLayerNet(torch.jit.ScriptModule):
|
||||
def __init__(self, D_in, H, D_out):
|
||||
@ -19,6 +18,7 @@ class TwoLayerNet(torch.jit.ScriptModule):
|
||||
y_pred = self.linear2(cat)
|
||||
return y_pred
|
||||
|
||||
|
||||
class TwoLayerNetModule(torch.nn.Module):
|
||||
def __init__(self, D_in, H, D_out):
|
||||
super(TwoLayerNetModule, self).__init__()
|
||||
@ -32,6 +32,7 @@ class TwoLayerNetModule(torch.nn.Module):
|
||||
y_pred = self.linear2(cat)
|
||||
return y_pred
|
||||
|
||||
|
||||
class TestThroughputBenchmark(TestCase):
|
||||
def linear_test(self, Module, profiler_output_path=""):
|
||||
D_in = 10
|
||||
@ -67,7 +68,6 @@ class TestThroughputBenchmark(TestCase):
|
||||
|
||||
print(stats)
|
||||
|
||||
|
||||
def test_script_module(self):
|
||||
self.linear_test(TwoLayerNet)
|
||||
|
||||
@ -79,5 +79,5 @@ class TestThroughputBenchmark(TestCase):
|
||||
self.linear_test(TwoLayerNetModule, profiler_output_path=fname)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user