Fix max_width computation in _tensor_str._Formatter (#126859)

The previous version of `torch._tensor_str._Formatter` did not use `PRINT_OPTS.sci_mode` when computing `max_width`, but did use it when formatting the values, leading to a discrepancy between the computed width and the printed output.

Now, the code first determines whether it should be in sci_mode, then computes `max_width` accordingly.

Here is an example to test the behavior:
```python
A = torch.tensor([10, 1e-1, 1e-2])
B = torch.tensor([10, 1e-1, 1e-1])

print("================= Default =================")
print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")

print("================= sci_mode=False =================")
with torch._tensor_str.printoptions(sci_mode=False):
    print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
    print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")

print("================= sci_mode=True =================")
with torch._tensor_str.printoptions(sci_mode=True):
    print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}")
    print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}")
```

In the current version this prints:
```
================= Default =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=False =================
tensor([   10.0000,     0.1000,     0.0100]) Formatter max_width: 10
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=True =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([1.0000e+01, 1.0000e-01, 1.0000e-01]) Formatter max_width: 7
```

One can see that with `sci_mode=False`, the values of A are padded with unneeded leading spaces and do not have the same `max_width` as B (A keeps the `max_width` computed for `sci_mode=None`).

Also, with `sci_mode=True`, the `max_width` for B is 7 even though each value takes 10 characters. (This is mostly harmless, since the code that uses `max_width` does not rely much on it, but it is still misleading.)

After this commit, this will print
```
================= Default =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=False =================
tensor([10.0000,  0.1000,  0.0100]) Formatter max_width: 7
tensor([10.0000,  0.1000,  0.1000]) Formatter max_width: 7
================= sci_mode=True =================
tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10
tensor([1.0000e+01, 1.0000e-01, 1.0000e-01]) Formatter max_width: 10
```

This also makes the output of A align with that of B for `sci_mode=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126859
Approved by: https://github.com/malfet
This commit is contained in:
Raphael Reme
2025-08-01 15:05:41 +00:00
committed by PyTorch MergeBot
parent b0b3e6e48b
commit ee2649219c
2 changed files with 11 additions and 15 deletions

View File

@ -8337,7 +8337,7 @@ class TestTorch(TestCase):
self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''')
torch.set_printoptions(sci_mode=False)
self.assertEqual(x.__repr__(), str(x))
self.assertExpectedInline(str(x), '''tensor([ 100.0000, 0.0100])''')
self.assertExpectedInline(str(x), '''tensor([100.0000, 0.0100])''')
torch.set_printoptions(sci_mode=None) # reset to the default value
# test no leading space if all elements positive

View File

@ -178,14 +178,18 @@ class _Formatter:
self.int_mode = False
break
self.sci_mode = (
nonzero_finite_max / nonzero_finite_min > 1000.0
or nonzero_finite_max > 1.0e8
or nonzero_finite_min < 1.0e-4
if PRINT_OPTS.sci_mode is None
else PRINT_OPTS.sci_mode
)
if self.int_mode:
# in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
# to indicate that the tensor is of floating type. add 1 to the len to account for this.
if (
nonzero_finite_max / nonzero_finite_min > 1000.0
or nonzero_finite_max > 1.0e8
):
self.sci_mode = True
if self.sci_mode:
for value in nonzero_finite_vals:
value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
self.max_width = max(self.max_width, len(value_str))
@ -195,12 +199,7 @@ class _Formatter:
self.max_width = max(self.max_width, len(value_str) + 1)
else:
# Check if scientific representation should be used.
if (
nonzero_finite_max / nonzero_finite_min > 1000.0
or nonzero_finite_max > 1.0e8
or nonzero_finite_min < 1.0e-4
):
self.sci_mode = True
if self.sci_mode:
for value in nonzero_finite_vals:
value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
self.max_width = max(self.max_width, len(value_str))
@ -209,9 +208,6 @@ class _Formatter:
value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
self.max_width = max(self.max_width, len(value_str))
if PRINT_OPTS.sci_mode is not None:
self.sci_mode = PRINT_OPTS.sci_mode
def width(self):
return self.max_width