Files
pytorch/test/jit/test_python_ir.py
David Reiss 16980e455f Fix naming of "strides" method in TensorType (#35170)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35170

It looks like this method was accidentally renamed in commit 0cbd7fa46f2.

Test Plan:
Unit test.

Imported from OSS

Differential Revision: D20783298

fbshipit-source-id: 8fcc146284af022ec1afe8d651baf6721b190ad3
2020-04-08 15:59:28 -07:00

23 lines
727 B
Python

import torch
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
import io
if __name__ == '__main__':
    # These tests are meant to be driven by test/test_jit.py, so executing
    # this module as a script is rejected with usage instructions.
    _usage = ("This test file is not meant to be run directly, use:\n\n"
              "\tpython test/test_jit.py TESTNAME\n\n"
              "instead.")
    raise RuntimeError(_usage)
class TestPythonIr(JitTestCase):
def test_param_strides(self):
def trace_me(arg):
return arg
t = torch.zeros(1,3,16,16)
traced = torch.jit.trace(trace_me, t)
value = list(traced.graph.param_node().outputs())[0]
real_strides = list(t.stride())
type_strides = value.type().strides()
self.assertEqual(real_strides, type_strides)