mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
In pytorch/torch/_C/__init__.pyi, Graph.addInput has the signature `def addInput(self, name: str) -> Value: ...`, which doesn't match the corresponding function `Value* addInput(const std::string& name = "") { return block_->addInput(name); }` in python_ir.cpp. This PR aligns the bound function on both the C++ and Python sides. Without this PR, mypy will complain whenever a change contains calls to `addInput`; for example, Pull Request resolved: https://github.com/pytorch/pytorch/pull/88528 Approved by: https://github.com/davidberard98
116 lines
3.6 KiB
Python
116 lines
3.6 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import torch
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
if __name__ == "__main__":
    # Direct execution is refused; the message points at the
    # test/test_jit.py entry point to use instead.
    _direct_run_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TestPythonBindings\n\n"
        "instead."
    )
    raise RuntimeError(_direct_run_hint)
|
|
|
|
|
|
class TestPythonBindings(JitTestCase):
|
|
def test_cu_get_functions(self):
|
|
@torch.jit.script
|
|
def test_get_python_cu_fn(x: torch.Tensor):
|
|
return 2 * x
|
|
|
|
cu = torch.jit._state._python_cu
|
|
self.assertTrue(
|
|
"test_get_python_cu_fn" in (str(fn.name) for fn in cu.get_functions())
|
|
)
|
|
|
|
def test_cu_create_function(self):
|
|
@torch.jit.script
|
|
def fn(x: torch.Tensor):
|
|
return 2 * x
|
|
|
|
cu = torch._C.CompilationUnit()
|
|
cu.create_function("test_fn", fn.graph)
|
|
|
|
inp = torch.randn(5)
|
|
|
|
self.assertEqual(inp * 2, cu.find_function("test_fn")(inp))
|
|
self.assertEqual(cu.find_function("doesnt_exist"), None)
|
|
self.assertEqual(inp * 2, cu.test_fn(inp))
|
|
with self.assertRaises(AttributeError):
|
|
cu.doesnt_exist(inp)
|
|
|
|
def test_invalidation(self):
|
|
@torch.jit.script
|
|
def test_invalidation_fn(x: torch.Tensor):
|
|
return 2 * x
|
|
|
|
gr = test_invalidation_fn.graph.copy()
|
|
n = gr.insertNode(gr.create("prim::profile"))
|
|
v = n.output()
|
|
# check that they work
|
|
str((n, v))
|
|
torch._C._jit_pass_dce(gr)
|
|
with self.assertRaisesRegex(RuntimeError, "invalidated"):
|
|
str(n)
|
|
with self.assertRaisesRegex(RuntimeError, "invalidated"):
|
|
str(v)
|
|
|
|
def test_graph_iterator_keepalive(self):
|
|
@torch.jit.script
|
|
def test_iterator_keepalive_fn(x: torch.Tensor):
|
|
return 2 * x
|
|
|
|
# the list would segfault before because inlined_graph
|
|
# is temporary and had been deleted (see issue #50454)
|
|
n = test_iterator_keepalive_fn.inlined_graph.nodes()
|
|
list(n)
|
|
i = test_iterator_keepalive_fn.inlined_graph.inputs()
|
|
list(i)
|
|
o = test_iterator_keepalive_fn.inlined_graph.outputs()
|
|
list(o)
|
|
|
|
def test_aliasdb(self):
|
|
@torch.jit.script
|
|
def test_aliasdb_fn(x: torch.Tensor):
|
|
return 2 * x
|
|
|
|
gr = test_aliasdb_fn.graph.copy()
|
|
alias_db = gr.alias_db()
|
|
self.assertTrue("WILDCARD" in str(alias_db))
|
|
self.assertTrue("digraph alias_db" in alias_db.to_graphviz_str())
|
|
|
|
def test_graph_create(self):
|
|
gr = torch._C.Graph()
|
|
with self.assertRaises(ValueError):
|
|
gr.create("prim::Constant", [None])
|
|
|
|
def test_add_input(self):
|
|
gr = torch._C.Graph()
|
|
foo_value = gr.addInput("foo")
|
|
assert foo_value in gr.inputs()
|
|
|
|
def test_canonicalize(self):
|
|
ir = """
|
|
graph(%p207 : Tensor,
|
|
%1 : Tensor,
|
|
%p407 : int):
|
|
%11 : Tensor = aten::view_expand_placeholder(%1)
|
|
%12 : Tensor = aten::pointwise_placeholder(%11, %p207, %p407)
|
|
%13 : Tensor = aten::view_expand_placeholder(%12)
|
|
%14 : Tensor = aten::pointwise_placeholder(%13)
|
|
return (%14)
|
|
"""
|
|
|
|
graph1 = torch._C.parse_ir(ir)
|
|
graph1 = torch._C._jit_pass_canonicalize(graph1, True)
|
|
|
|
graph2 = torch._C.parse_ir(ir)
|
|
graph2 = torch._C._jit_pass_canonicalize(graph2)
|
|
|
|
self.assertEqual(str(graph1), str(graph2))
|
|
FileCheck().check("%p207").check_not("%14").run(graph1)
|
|
|
|
graph3 = torch._C.parse_ir(ir)
|
|
graph3 = torch._C._jit_pass_canonicalize(graph3, False)
|
|
FileCheck().check_not("%p207").run(graph3)
|