mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix max_width computation in _tensor_str._Formatter (#126859)
Previous version of `torch._tensor_str._Formatter` was not using `PRINT_OPTS.sci_mode` for the `max_width` computation but was using it for the formatting of values leading to a weird discrepancy. Now, the code first checks if it should be in sci_mode, then compute `max_width` Here is an example to test the behavior: ```python A = torch.tensor([10, 1e-1, 1e-2]) B = torch.tensor([10, 1e-1, 1e-1]) print("================= Default =================") print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}") print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}") print("================= sci_mode=False =================") with torch._tensor_str.printoptions(sci_mode=False): print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}") print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}") print("================= sci_mode=True =================") with torch._tensor_str.printoptions(sci_mode=True): print(A, f"Formatter max_width: {torch._tensor_str._Formatter(A).max_width}") print(B, f"Formatter max_width: {torch._tensor_str._Formatter(B).max_width}") ``` In the current version this prints: ``` ================= Default ================= tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10 tensor([10.0000, 0.1000, 0.1000]) Formatter max_width: 7 ================= sci_mode=False ================= tensor([ 10.0000, 0.1000, 0.0100]) Formatter max_width: 10 tensor([10.0000, 0.1000, 0.1000]) Formatter max_width: 7 ================= sci_mode=True ================= tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10 tensor([1.0000e+01, 1.0000e-01, 1.0000e-01]) Formatter max_width: 7 ``` On can see that in `sci_mode=False`, the values of A are prefixed with unneeded 0 and does not have the same `max_width` as B (It keeps the `max_width` from `sci_mode = None`) Also in `sci_mode = True`, for B, the `max_width` is 7 but each value takes 10 chars... (But it is fine as the code that uses `max_width` do not rely much on it, but still, this is missleading) After this commit, this will print ``` ================= Default ================= tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10 tensor([10.0000, 0.1000, 0.1000]) Formatter max_width: 7 ================= sci_mode=False ================= tensor([10.0000, 0.1000, 0.0100]) Formatter max_width: 7 tensor([10.0000, 0.1000, 0.1000]) Formatter max_width: 7 ================= sci_mode=True ================= tensor([1.0000e+01, 1.0000e-01, 1.0000e-02]) Formatter max_width: 10 tensor([1.0000e+01, 1.0000e-01, 1.0000e-01]) Formatter max_width: 10 ``` This also allows to align A with B for `sci_mode=False`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126859 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
b0b3e6e48b
commit
ee2649219c
@ -8337,7 +8337,7 @@ class TestTorch(TestCase):
|
||||
self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''')
|
||||
torch.set_printoptions(sci_mode=False)
|
||||
self.assertEqual(x.__repr__(), str(x))
|
||||
self.assertExpectedInline(str(x), '''tensor([ 100.0000, 0.0100])''')
|
||||
self.assertExpectedInline(str(x), '''tensor([100.0000, 0.0100])''')
|
||||
torch.set_printoptions(sci_mode=None) # reset to the default value
|
||||
|
||||
# test no leading space if all elements positive
|
||||
|
@ -178,14 +178,18 @@ class _Formatter:
|
||||
self.int_mode = False
|
||||
break
|
||||
|
||||
self.sci_mode = (
|
||||
nonzero_finite_max / nonzero_finite_min > 1000.0
|
||||
or nonzero_finite_max > 1.0e8
|
||||
or nonzero_finite_min < 1.0e-4
|
||||
if PRINT_OPTS.sci_mode is None
|
||||
else PRINT_OPTS.sci_mode
|
||||
)
|
||||
|
||||
if self.int_mode:
|
||||
# in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
|
||||
# to indicate that the tensor is of floating type. add 1 to the len to account for this.
|
||||
if (
|
||||
nonzero_finite_max / nonzero_finite_min > 1000.0
|
||||
or nonzero_finite_max > 1.0e8
|
||||
):
|
||||
self.sci_mode = True
|
||||
if self.sci_mode:
|
||||
for value in nonzero_finite_vals:
|
||||
value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
|
||||
self.max_width = max(self.max_width, len(value_str))
|
||||
@ -195,12 +199,7 @@ class _Formatter:
|
||||
self.max_width = max(self.max_width, len(value_str) + 1)
|
||||
else:
|
||||
# Check if scientific representation should be used.
|
||||
if (
|
||||
nonzero_finite_max / nonzero_finite_min > 1000.0
|
||||
or nonzero_finite_max > 1.0e8
|
||||
or nonzero_finite_min < 1.0e-4
|
||||
):
|
||||
self.sci_mode = True
|
||||
if self.sci_mode:
|
||||
for value in nonzero_finite_vals:
|
||||
value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
|
||||
self.max_width = max(self.max_width, len(value_str))
|
||||
@ -209,9 +208,6 @@ class _Formatter:
|
||||
value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
|
||||
self.max_width = max(self.max_width, len(value_str))
|
||||
|
||||
if PRINT_OPTS.sci_mode is not None:
|
||||
self.sci_mode = PRINT_OPTS.sci_mode
|
||||
|
||||
def width(self):
|
||||
return self.max_width
|
||||
|
||||
|
Reference in New Issue
Block a user