[fx][pass] Support converting a float32 tensor to a scalar in FX trace. (#158216)

Fixes https://github.com/pytorch/pytorch/issues/158083

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158216
Approved by: https://github.com/laithsakka
Author: thenumberouscode
Date: 2025-08-09 15:13:13 +00:00
Committed by: PyTorch MergeBot
Parent: 01f66d08d9
Commit: 29712314dd
2 changed files with 39 additions and 1 deletion
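
For orientation, a minimal sketch (hypothetical, not taken from the PR) of the pattern this change targets: a non-float64 tensor whose .item() result flows back into tensor compute under torch.compile with scalar outputs captured.

import torch

# Hypothetical repro sketch: .item() on a float32 tensor inside a compiled
# function. Before this change the tensorify pass only handled float64
# scalars; the second hunk below relaxes the check to any floating-point dtype.
@torch.compile(backend="aot_eager")
def f(x):
    y = x.sum()          # float32 tensor
    return x + y.item()  # Python float captured as a scalar in the FX graph

with torch._dynamo.config.patch(capture_scalar_outputs=True):
    x = torch.ones(3, 3, dtype=torch.float32)
    print(f(x))  # should match x + x.sum().item()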


@@ -714,6 +714,40 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(fn_opt(x, y3), fn(x, y3))
         self.assertEqual(cnt.frame_count, 1)
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_tensorfiy_python_scalars_1(self):
+        @torch.compile(backend="aot_eager")
+        def f(x):
+            y = x.sum()
+            return x + y.item()
+
+        dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64]
+        for i, dtype in enumerate(dtypes):
+            x = torch.ones(3, 3, dtype=dtype)
+            self.assertEqual(f(x), x + x.sum().item())
+
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_tensorfiy_python_scalars_2(self):
+        @torch.compile(backend="aot_eager")
+        def f(x):
+            return x.item() * x.item() * torch.ones((), dtype=torch.float64)
+
+        x = torch.tensor(1e20, dtype=torch.float32)
+        self.assertEqual(
+            f(x), x.item() * x.item() * torch.ones((), dtype=torch.float64)
+        )
+
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_tensorfiy_python_scalars_3(self):
+        @torch.compile(backend="aot_eager")
+        def f(x):
+            y = x.item() * 101
+            return y * torch.tensor([1], dtype=torch.float32)
+
+        finfo_float16 = torch.finfo(torch.float16)
+        x = torch.tensor([finfo_float16.max], dtype=torch.float16)
+        self.assertEqual(f(x), x.item() * 101 * torch.tensor([1], dtype=torch.float32))
+
     @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False)
     def test_unspec_float_input_f64(self):
         cnts = torch._dynamo.testing.CompileCounter()


@@ -203,7 +203,7 @@ def tensorify_python_scalars(
                 and node.target is torch.ops.aten._local_scalar_dense.default
             ):
                 dtype = node.args[0].meta["val"].dtype
-                if dtype != torch.float64:
+                if not dtype.is_floating_point:
                     continue
 
                 assert isinstance(node.args[0], fx.Node), node.args[0]
@@ -212,6 +212,10 @@ def tensorify_python_scalars(
                 expr_to_tensor_proxy[s] = MetaProxy(
                     node.args[0], tracer=tracer, fake_mode=fake_mode
                 )
+                # Upcast the float tensor to torch.float64 to avoid precision problems
+                expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
+                    expr_to_tensor_proxy[s], torch.float64
+                )
                 expr_to_sym_proxy[s] = MetaProxy(
                     node, tracer=tracer, fake_mode=fake_mode
                 )
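
The upcast matters for range as well as precision. A standalone sketch (plain eager PyTorch, not the pass itself) of the failure mode the float64 conversion guards against, mirroring the float16 test above:

import torch

# Standalone illustration: 65504 (float16 max) * 101 = 6615904, far outside
# float16 range, so the product overflows to inf unless the scalar is carried
# in float64 first.
x = torch.tensor([torch.finfo(torch.float16).max], dtype=torch.float16)

in_float16 = x * 101                    # stays float16 -> tensor([inf])
in_float64 = x.to(torch.float64) * 101  # upcast first  -> tensor([6615904.])

print(in_float16, in_float64)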