mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Support builtin bool on non-constant VTs (#155863)
In practice `bool(...)` is either constant folded by Dynamo or used for branching (so most of its emulation logic lived in `InstructionTranslator.generic_jump`. This patch adds a dedicated `bool` hanlder (only for symbolic bool/int/float for now), and fixes #136075. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155863 Approved by: https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
6b45af38a5
commit
640f5a7090
@ -4578,6 +4578,20 @@ def forward(self, x, b, y):
|
||||
out = graph(x)
|
||||
self.assertEqual(ref_out, out)
|
||||
|
||||
def test_strict_fake_tensor_prop_real_tensors(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return bool(x.eq(0.1).any().item())
|
||||
|
||||
model = Foo()
|
||||
inputs = (torch.randn(64),)
|
||||
ref = model(*inputs)
|
||||
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
|
||||
ep = torch.export.export(model, inputs, strict=True)
|
||||
res = ep.module()(*inputs)
|
||||
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
|
||||
class ExportTestsDevice(torch._dynamo.test_case.TestCase):
|
||||
def test_export_with_parameters(self, device):
|
||||
|
@ -12567,6 +12567,42 @@ fn
|
||||
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
||||
opt_mod(x)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_builtin_bool_on_symint(self):
|
||||
def f(x):
|
||||
return bool(x.item())
|
||||
|
||||
opt_f = torch.compile(f, backend="eager", fullgraph=True)
|
||||
x = torch.randint(10, (1,))
|
||||
|
||||
ref = f(x)
|
||||
res = opt_f(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_builtin_bool_on_symfloat(self):
|
||||
def f(x):
|
||||
return bool(x.item())
|
||||
|
||||
opt_f = torch.compile(f, backend="eager", fullgraph=True)
|
||||
x = torch.randn(1)
|
||||
|
||||
ref = f(x)
|
||||
res = opt_f(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_builtin_bool_on_symbool(self):
|
||||
def f(x):
|
||||
return bool(x.item())
|
||||
|
||||
opt_f = torch.compile(f, backend="eager", fullgraph=True)
|
||||
x = torch.randn(1) == 1
|
||||
|
||||
ref = f(x)
|
||||
res = opt_f(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
|
||||
class TestTracer(JitTestCase):
|
||||
def test_jit_save(self):
|
||||
|
@ -1303,6 +1303,24 @@ class BuiltinVariable(VariableTracker):
|
||||
call_int = _call_int_float
|
||||
call_float = _call_int_float
|
||||
|
||||
def call_bool(self, tx: "InstructionTranslator", arg):
|
||||
# Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`.
|
||||
# https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697
|
||||
if isinstance(arg, SymNodeVariable):
|
||||
# Note that we delay specializing on symbolic values to avoid
|
||||
# unnecessary guards. Specialization will happen later if, e.g., the
|
||||
# resulting boolean is used for branching.
|
||||
if isinstance(arg.sym_num, torch.SymBool):
|
||||
return arg
|
||||
|
||||
# Emulate `nb_bool` of int/float objects
|
||||
# - https://github.com/python/cpython/blob/3.12/Objects/longobject.c#L4940-L4944
|
||||
# - https://github.com/python/cpython/blob/3.12/Objects/floatobject.c#L878-L882
|
||||
assert istype(arg.sym_num, (torch.SymInt, torch.SymFloat))
|
||||
return SymNodeVariable.create(tx, arg.as_proxy() != 0)
|
||||
|
||||
# TODO handle more cases and merge this with this with `generic_jump`.
|
||||
|
||||
def call_str(self, tx: "InstructionTranslator", arg):
|
||||
# Handle `str` on a user defined function or object
|
||||
if isinstance(arg, (variables.UserFunctionVariable)):
|
||||
|
Reference in New Issue
Block a user