mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle Reviewed By: zertosh Differential Revision: D30279364 fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
aac3c7bd06
commit
b004307252
@ -2,14 +2,16 @@ import torch
|
||||
|
||||
OUTPUT_DIR = "src/androidTest/assets/"
|
||||
|
||||
|
||||
def scriptAndSave(module, fileName):
|
||||
print('-' * 80)
|
||||
print("-" * 80)
|
||||
script_module = torch.jit.script(module)
|
||||
print(script_module.graph)
|
||||
outputFileName = OUTPUT_DIR + fileName
|
||||
script_module.save(outputFileName)
|
||||
print("Saved to " + outputFileName)
|
||||
print('=' * 80)
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
class Test(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
@ -73,7 +75,9 @@ class Test(torch.jit.ScriptModule):
|
||||
return res
|
||||
|
||||
@torch.jit.script_method
|
||||
def tupleIntSumReturnTuple(self, input: Tuple[int, int, int]) -> Tuple[Tuple[int, int, int], int]:
|
||||
def tupleIntSumReturnTuple(
|
||||
self, input: Tuple[int, int, int]
|
||||
) -> Tuple[Tuple[int, int, int], int]:
|
||||
sum = 0
|
||||
for x in input:
|
||||
sum += x
|
||||
@ -114,7 +118,7 @@ class Test(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
|
||||
r = torch.nn.functional.conv2d(x, w)
|
||||
if (toChannelsLast):
|
||||
if toChannelsLast:
|
||||
r = r.contiguous(memory_format=torch.channels_last)
|
||||
else:
|
||||
r = r.contiguous()
|
||||
@ -132,4 +136,5 @@ class Test(torch.jit.ScriptModule):
|
||||
def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
|
||||
return x.contiguous(memory_format=torch.channels_last_3d)
|
||||
|
||||
|
||||
scriptAndSave(Test(), "test.pt")
|
||||
|
Reference in New Issue
Block a user