mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fx][pass] Support converting a float32 tensor to a scalar in FX trace. (#158216)
Fixes https://github.com/pytorch/pytorch/issues/158083 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158216 Approved by: https://github.com/laithsakka
This commit is contained in:
committed by
PyTorch MergeBot
parent
01f66d08d9
commit
29712314dd
@ -714,6 +714,40 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(fn_opt(x, y3), fn(x, y3))
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_tensorfiy_python_scalars_1(self):
|
||||
@torch.compile(backend="aot_eager")
|
||||
def f(x):
|
||||
y = x.sum()
|
||||
return x + y.item()
|
||||
|
||||
dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64]
|
||||
for i, dtype in enumerate(dtypes):
|
||||
x = torch.ones(3, 3, dtype=dtype)
|
||||
self.assertEqual(f(x), x + x.sum().item())
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_tensorfiy_python_scalars_2(self):
|
||||
@torch.compile(backend="aot_eager")
|
||||
def f(x):
|
||||
return x.item() * x.item() * torch.ones((), dtype=torch.float64)
|
||||
|
||||
x = torch.tensor(1e20, dtype=torch.float32)
|
||||
self.assertEqual(
|
||||
f(x), x.item() * x.item() * torch.ones((), dtype=torch.float64)
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_tensorfiy_python_scalars_3(self):
|
||||
@torch.compile(backend="aot_eager")
|
||||
def f(x):
|
||||
y = x.item() * 101
|
||||
return y * torch.tensor([1], dtype=torch.float32)
|
||||
|
||||
finfo_float16 = torch.finfo(torch.float16)
|
||||
x = torch.tensor([finfo_float16.max], dtype=torch.float16)
|
||||
self.assertEqual(f(x), x.item() * 101 * torch.tensor([1], dtype=torch.float32))
|
||||
|
||||
@torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False)
|
||||
def test_unspec_float_input_f64(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
@ -203,7 +203,7 @@ def tensorify_python_scalars(
|
||||
and node.target is torch.ops.aten._local_scalar_dense.default
|
||||
):
|
||||
dtype = node.args[0].meta["val"].dtype
|
||||
if dtype != torch.float64:
|
||||
if not dtype.is_floating_point:
|
||||
continue
|
||||
|
||||
assert isinstance(node.args[0], fx.Node), node.args[0]
|
||||
@ -212,6 +212,10 @@ def tensorify_python_scalars(
|
||||
expr_to_tensor_proxy[s] = MetaProxy(
|
||||
node.args[0], tracer=tracer, fake_mode=fake_mode
|
||||
)
|
||||
# Upcast the float tensor to torch.float64 to avoid precision problem
|
||||
expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
|
||||
expr_to_tensor_proxy[s], torch.float64
|
||||
)
|
||||
expr_to_sym_proxy[s] = MetaProxy(
|
||||
node, tracer=tracer, fake_mode=fake_mode
|
||||
)
|
||||
|
Reference in New Issue
Block a user