mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add support for parsing torch.Generator in JIT (#140489)
Fixes #140420. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140489. Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
70060b0927
commit
b34bb1f562
@ -14184,6 +14184,43 @@ dedent """
|
||||
|
||||
FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
|
||||
|
||||
def test_parse_generator(self):
    """Verify that ``parse_ir`` can parse ``torch.Generator`` constants.

    Regression test: IR text containing a
    ``prim::Constant[value=torch.Generator(device=..., seed=...)]`` node
    must round-trip through the parser, producing a node whose output type
    is ``torch._C._GeneratorType`` and whose ``value`` attribute is a real
    ``torch.Generator`` carrying the embedded seed. Out-of-range seeds must
    be rejected with the parser's error messages.
    """

    def _test_parse_generator(seed):
        # IR program: a Generator constant seeded with `seed` (%2) feeding
        # aten::uniform on an empty tensor, so the generator constant is
        # actually consumed by an op and not dead code.
        jit_graph = parse_ir(
            f"""
            graph():
                %0 : float = prim::Constant[value=-0.31622776601683789]()
                %1 : float = prim::Constant[value=0.31622776601683789]()
                %2 : Generator = prim::Constant[value=torch.Generator(device="cpu", seed={seed})]()
                %3 : NoneType = prim::Constant()
                %4 : int[] = prim::Constant[value=[]]()
                %5 : int = prim::Constant[value=6]()
                %6 : Device = prim::Constant[value="cpu"]()
                %7 : Tensor = aten::empty(%4, %5, %3, %6, %3, %3)
                %8 : Float() = aten::uniform(%7, %0, %1, %2)
                return (%8)
            """,
        )

        # Locate the node whose output is Generator-typed (the %2 constant).
        node = next(
            n
            for n in jit_graph.nodes()
            if isinstance(n.output().type(), torch._C._GeneratorType)
        )
        assert isinstance(node.output().type(), torch._C._GeneratorType)
        # The parsed constant's "value" attribute must be a genuine
        # torch.Generator whose seed matches the one written into the IR.
        g = node.ival("value")
        assert isinstance(g, torch.Generator)
        self.assertEqual(g.initial_seed(), seed)

    # A small seed and the largest accepted seed (2**63 - 1) both parse.
    _test_parse_generator(2024)
    _test_parse_generator(2**63 - 1)

    # Negative seeds are rejected by the parser.
    with self.assertRaisesRegex(RuntimeError, "Seed must be a non-negative integer"):
        _test_parse_generator(-2024)

    # Seeds of 2**63 or more exceed the parser's integer range.
    with self.assertRaisesRegex(RuntimeError, "Number is too big"):
        _test_parse_generator(2**63)
|
||||
|
||||
def test_early_return_rewrite(self):
|
||||
def test_foo(x: bool):
|
||||
if x:
|
||||
|
Reference in New Issue
Block a user