[dynamo] Support builtin bool on non-constant VTs (#155863)

In practice `bool(...)` is either constant folded by Dynamo or used for
branching (so most of its emulation logic lived in
`InstructionTranslator.generic_jump`).

This patch adds a dedicated `bool` handler (only for symbolic
bool/int/float for now), and fixes #136075.
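
For example, a minimal repro in the spirit of the new tests below (the
backend, shapes, and config values here are illustrative, not part of the
patch) now compiles instead of failing:

    import torch

    def f(x):
        # `bool` on a data-dependent scalar previously failed to trace (#136075)
        return bool(x.item())

    opt_f = torch.compile(f, backend="eager", fullgraph=True)

    with torch._dynamo.config.patch(capture_scalar_outputs=True):
        x = torch.randint(10, (1,))
        assert opt_f(x) == f(x)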

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155863
Approved by: https://github.com/williamwen42
Ryan Guo
2025-06-18 09:49:46 -07:00
committed by PyTorch MergeBot
parent 6b45af38a5
commit 640f5a7090
3 changed files with 68 additions and 0 deletions

View File

@@ -4578,6 +4578,20 @@ def forward(self, x, b, y):
        out = graph(x)
        self.assertEqual(ref_out, out)

    def test_strict_fake_tensor_prop_real_tensors(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return bool(x.eq(0.1).any().item())

        model = Foo()
        inputs = (torch.randn(64),)
        ref = model(*inputs)
        with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
            ep = torch.export.export(model, inputs, strict=True)
        res = ep.module()(*inputs)
        self.assertEqual(ref, res)


class ExportTestsDevice(torch._dynamo.test_case.TestCase):
    def test_export_with_parameters(self, device):

View File

@@ -12567,6 +12567,42 @@ fn
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            opt_mod(x)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_builtin_bool_on_symint(self):
        def f(x):
            return bool(x.item())

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        x = torch.randint(10, (1,))
        ref = f(x)
        res = opt_f(x)
        self.assertEqual(ref, res)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_builtin_bool_on_symfloat(self):
        def f(x):
            return bool(x.item())

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        x = torch.randn(1)
        ref = f(x)
        res = opt_f(x)
        self.assertEqual(ref, res)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_builtin_bool_on_symbool(self):
        def f(x):
            return bool(x.item())

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        x = torch.randn(1) == 1
        ref = f(x)
        res = opt_f(x)
        self.assertEqual(ref, res)


class TestTracer(JitTestCase):
    def test_jit_save(self):

View File

@@ -1303,6 +1303,24 @@ class BuiltinVariable(VariableTracker):
    call_int = _call_int_float
    call_float = _call_int_float

    def call_bool(self, tx: "InstructionTranslator", arg):
        # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`.
        # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697
        if isinstance(arg, SymNodeVariable):
            # Note that we delay specializing on symbolic values to avoid
            # unnecessary guards. Specialization will happen later if, e.g., the
            # resulting boolean is used for branching.
            if isinstance(arg.sym_num, torch.SymBool):
                return arg

            # Emulate `nb_bool` of int/float objects
            # - https://github.com/python/cpython/blob/3.12/Objects/longobject.c#L4940-L4944
            # - https://github.com/python/cpython/blob/3.12/Objects/floatobject.c#L878-L882
            assert istype(arg.sym_num, (torch.SymInt, torch.SymFloat))
            return SymNodeVariable.create(tx, arg.as_proxy() != 0)
        # TODO handle more cases and merge this with `generic_jump`.

    def call_str(self, tx: "InstructionTranslator", arg):
        # Handle `str` on a user defined function or object
        if isinstance(arg, (variables.UserFunctionVariable)):