mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 04:54:55 +08:00
[JIT] python IR bindings: consolidate tests, add short docs in OVERVIEW.md (#118319)
Document the existence of python IR bindings; quick comments about it; and consolidate tests in one file to serve as examples to users. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118319 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
9bce208dfb
commit
40c08795b0
@ -1,7 +1,12 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
@ -18,3 +23,70 @@ class TestPythonIr(JitTestCase):
|
||||
real_strides = list(t.stride())
|
||||
type_strides = value.type().strides()
|
||||
self.assertEqual(real_strides, type_strides)
|
||||
|
||||
def test_permute_inputs_binding(self):
|
||||
@torch.jit.script
|
||||
def foo(i, j, k):
|
||||
pass
|
||||
|
||||
g = foo.graph
|
||||
|
||||
idxs = []
|
||||
for i, inp in enumerate(g.inputs()):
|
||||
inp.setDebugName(f"inp{i}")
|
||||
idxs.append(i)
|
||||
|
||||
permuted_idxs = list(np.random.permutation(idxs))
|
||||
g.permuteInputs(permuted_idxs)
|
||||
for i, inp in enumerate(g.inputs()):
|
||||
self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName())
|
||||
|
||||
@unittest.skipIf(IS_MACOS, "Failing on MacOS only")
|
||||
def test_python_ir_utils(self):
|
||||
@torch.jit.script
|
||||
def foo(inp):
|
||||
x = inp + 1
|
||||
y = x / 2
|
||||
z = y * y
|
||||
return z
|
||||
|
||||
add_node = foo.graph.findNode("aten::add")
|
||||
div_node = foo.graph.findNode("aten::div")
|
||||
|
||||
with foo.graph.insert_point_guard(add_node):
|
||||
with foo.graph.insert_point_guard(div_node):
|
||||
foo.graph.insertConstant("goodbye")
|
||||
foo.graph.insertConstant("hello")
|
||||
with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")):
|
||||
foo.graph.insertConstant("hello")
|
||||
FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph)
|
||||
|
||||
self.assertTrue(add_node.matches(add_node.schema()))
|
||||
self.assertFalse(add_node.matches(div_node.schema()))
|
||||
|
||||
def test_python_ir_utils_graph(self):
|
||||
@torch.jit.script
|
||||
def unrolled_mul(x: torch.Tensor, y: int):
|
||||
out = x
|
||||
for _ in range(y - 1):
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
@torch.jit.script
|
||||
def foo(x):
|
||||
return x * 4
|
||||
|
||||
g = foo.graph
|
||||
muls = g.findAllNodes("aten::mul")
|
||||
scalar_muls = filter(lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls)
|
||||
mul_constant_int = filter(lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls)
|
||||
for mul in mul_constant_int:
|
||||
with g.insert_point_guard(mul):
|
||||
outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))
|
||||
assert len(outputs) == len(list(mul.outputs()))
|
||||
for new_out, old_out in zip(outputs, g.outputs()):
|
||||
old_out.replaceAllUsesWith(new_out)
|
||||
mul.destroy()
|
||||
|
||||
FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
|
||||
self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)
|
||||
|
||||
Reference in New Issue
Block a user