Add support for parsing torch.Generator in JIT (#140489)

Fixes #140420

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140489
Approved by: https://github.com/davidberard98
This commit is contained in:
Antonio Kim
2024-11-13 23:06:54 +00:00
committed by PyTorch MergeBot
parent 70060b0927
commit b34bb1f562
2 changed files with 100 additions and 0 deletions

View File

@ -14184,6 +14184,43 @@ dedent """
FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
def test_parse_generator(self):
def _test_parse_generator(seed):
jit_graph = parse_ir(
f"""
graph():
%0 : float = prim::Constant[value=-0.31622776601683789]()
%1 : float = prim::Constant[value=0.31622776601683789]()
%2 : Generator = prim::Constant[value=torch.Generator(device="cpu", seed={seed})]()
%3 : NoneType = prim::Constant()
%4 : int[] = prim::Constant[value=[]]()
%5 : int = prim::Constant[value=6]()
%6 : Device = prim::Constant[value="cpu"]()
%7 : Tensor = aten::empty(%4, %5, %3, %6, %3, %3)
%8 : Float() = aten::uniform(%7, %0, %1, %2)
return (%8)
""",
)
node = next(
n
for n in jit_graph.nodes()
if isinstance(n.output().type(), torch._C._GeneratorType)
)
assert isinstance(node.output().type(), torch._C._GeneratorType)
g = node.ival("value")
assert isinstance(g, torch.Generator)
self.assertEqual(g.initial_seed(), seed)
_test_parse_generator(2024)
_test_parse_generator(2**63 - 1)
with self.assertRaisesRegex(RuntimeError, "Seed must be a non-negative integer"):
_test_parse_generator(-2024)
with self.assertRaisesRegex(RuntimeError, "Number is too big"):
_test_parse_generator(2**63)
def test_early_return_rewrite(self):
def test_foo(x: bool):
if x: