Files
pytorch/test/jit/test_python_ir.py
Anthony Barbier bf7e290854 Add __main__ guards to jit tests (#154725)
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In jit tests:

- Add a common `raise_on_run_directly` helper and call it when a user directly runs a test file that should not be run on its own; it prints the file the user should have run instead.
- Raise a `RuntimeError` for tests that have been disabled (i.e. are not run).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
Approved by: https://github.com/clee2000
2025-06-16 10:28:45 +00:00

99 lines
3.2 KiB
Python

# Owner(s): ["oncall: jit"]
import unittest
import numpy as np
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_MACOS, raise_on_run_directly
from torch.testing._internal.jit_utils import JitTestCase
class TestPythonIr(JitTestCase):
    """Tests for the Python bindings over TorchScript IR graphs, nodes and values."""

    def test_param_strides(self):
        """Tracing records the example input's strides on the parameter's type."""

        def trace_me(arg):
            return arg

        t = torch.zeros(1, 3, 16, 16)
        traced = torch.jit.trace(trace_me, t)
        # First output of the param node: the Value standing for ``arg``.
        # Its tensor type should carry the same strides as the example ``t``.
        value = list(traced.graph.param_node().outputs())[0]
        real_strides = list(t.stride())
        type_strides = value.type().strides()
        self.assertEqual(real_strides, type_strides)

    def test_permute_inputs_binding(self):
        """Graph.permuteInputs reorders the graph inputs by the given indices."""

        @torch.jit.script
        def foo(i, j, k):
            pass

        g = foo.graph
        idxs = []
        # Tag every input with its original position so the reordering can
        # be verified through debug names afterwards.
        for i, inp in enumerate(g.inputs()):
            inp.setDebugName(f"inp{i}")
            idxs.append(i)
        # ``tolist()`` hands the binding plain Python ints, not NumPy scalars.
        permuted_idxs = np.random.permutation(idxs).tolist()
        g.permuteInputs(permuted_idxs)
        # After permutation, position ``i`` holds the input that was
        # originally at position ``permuted_idxs[i]``.
        for i, inp in enumerate(g.inputs()):
            self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName())

    @unittest.skipIf(IS_MACOS, "Failing on MacOS only")
    def test_python_ir_utils(self):
        """insert_point_guard scopes nest correctly; Node.matches checks schemas."""

        @torch.jit.script
        def foo(inp):
            x = inp + 1
            y = x / 2
            z = y * y
            return z

        add_node = foo.graph.findNode("aten::add")
        div_node = foo.graph.findNode("aten::div")
        with foo.graph.insert_point_guard(add_node):
            with foo.graph.insert_point_guard(div_node):
                foo.graph.insertConstant("goodbye")
            # Leaving the inner guard restores the outer insertion point
            # (``add_node``), so this constant is emitted ahead of "goodbye".
            foo.graph.insertConstant("hello")
        with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")):
            foo.graph.insertConstant("hello")
        FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph)
        # A node matches its own schema but not that of a different operator.
        self.assertTrue(add_node.matches(add_node.schema()))
        self.assertFalse(add_node.matches(div_node.schema()))

    def test_python_ir_utils_graph(self):
        """Inline a scripted graph in place of ``Tensor * int`` nodes.

        Every ``aten::mul`` whose scalar operand is a constant int is replaced
        by the graph of ``unrolled_mul`` (repeated addition). The rewritten
        function must contain no ``aten::mul`` and must still compute ``x * 4``.
        """

        @torch.jit.script
        def unrolled_mul(x: torch.Tensor, y: int):
            out = x
            for _ in range(y - 1):
                out = out + x
            return out

        @torch.jit.script
        def foo(x):
            return x * 4

        g = foo.graph
        # Select every candidate up front, before the loop below starts
        # mutating the graph, instead of filtering lazily during mutation.
        mul_constant_int = [
            node
            for node in g.findAllNodes("aten::mul")
            if node.matches("aten::mul(Tensor self, Scalar other) -> Tensor")
            and isinstance(list(node.inputs())[1].toIValue(), int)
        ]
        for mul in mul_constant_int:
            with g.insert_point_guard(mul):
                outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))
                old_outputs = list(mul.outputs())
                self.assertEqual(len(outputs), len(old_outputs))
                # Rewire uses of this mul's own outputs rather than the
                # graph's outputs. The rewrite then stays correct when the mul
                # does not feed the graph's return directly, and the node has
                # no remaining uses when it is destroyed.
                for new_out, old_out in zip(outputs, old_outputs):
                    old_out.replaceAllUsesWith(new_out)
                mul.destroy()
        FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph)
        self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)
if __name__ == "__main__":
    # This module is not meant to be executed on its own: running it directly
    # raises and points the user at the test file that should be run instead.
    jit_test_entry_point = "test/test_jit.py"
    raise_on_run_directly(jit_test_entry_point)