[tensorboard] Fix TensorBoard summary encoding for torch.bfloat16 tensors (#108351)

Summary:
The `tensor_proto` function in the TensorBoard summary writer code doesn't correctly encode `torch.bfloat16` tensors; it tries to use a data type of `DT_BFLOAT` when creating the protobuf, but `DT_BFLOAT` is not a valid enum value (see `types.proto`). The correct value to use when encoding tensors of this type is `DT_BFLOAT16`. This diff updates the type map in the summary code to use the correct type.

While fixing this error, I also noticed the wrong field of the protobuf was being used when encoding tensors of this type; per the docs in the proto file, the DT_HALF and DT_BFLOAT16 types should use the `half_val` field, not `float_val`. Since this might confuse folks trying to read this data from storage in the future, I've updated the code to correctly use the `half_val` field for these cases. Note that there's no real size advantage from doing this, since both the `half_val` and `float_val` fields are 32 bits long.

Test Plan:
Added a parameterized unit test that tests encoding tensors with `torch.half`, `torch.float16`, and `torch.bfloat16` data types.

# Before this change
The test fails with a `ValueError` due to the incorrect enum label:
```
======================================================================
ERROR: test_bfloat16_tensor_proto (test_tensorboard.TestTensorProtoSummary)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data/users/jcarreiro/fbsource/buck-out/v2/gen/fbcode/f88b3f368c9334db/caffe2/test/__tensorboard__/tensorboard#link-tree/torch/testing/_internal/common_utils.py", line 2382, in wrapper
    method(*args, **kwargs)
  File "/data/users/jcarreiro/fbsource/buck-out/v2/gen/fbcode/f88b3f368c9334db/caffe2/test/__tensorboard__/tensorboard#link-tree/test_tensorboard.py", line 871, in test_bfloat16_tensor_proto
    tensor_proto(
  File "/data/users/jcarreiro/fbsource/buck-out/v2/gen/fbcode/f88b3f368c9334db/caffe2/test/__tensorboard__/tensorboard#link-tree/torch/utils/tensorboard/summary.py", line 400, in tensor_proto
    tensor_proto = TensorProto(**tensor_proto_args)
ValueError: unknown enum label "DT_BFLOAT"

To execute this test, run the following from the base repo dir:
     python test/__tensorboard__/tensorboard#link-tree/test_tensorboard.py -k test_bfloat16_tensor_proto

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
```

# After this change
The test passes.

Reviewed By: tanvigupta17

Differential Revision: D48828958

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108351
Approved by: https://github.com/hamzajzmati, https://github.com/XilunWu
This commit is contained in:
Jason Carreiro
2023-09-14 23:12:22 +00:00
committed by PyTorch MergeBot
parent bf5622e965
commit fa62308673
2 changed files with 102 additions and 54 deletions

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: unknown"]
import expecttest
import io
import numpy as np
import os
@ -7,7 +8,6 @@ import shutil
import sys
import tempfile
import unittest
import expecttest
TEST_TENSORBOARD = True
try:
@ -43,7 +43,14 @@ except ImportError:
skipIfNoMatplotlib = unittest.skipIf(not TEST_MATPLOTLIB, "no matplotlib")
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ASAN, TEST_WITH_CROSSREF
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
TestCase,
run_tests,
TEST_WITH_ASAN,
TEST_WITH_CROSSREF,
)
def tensor_N(shape, dtype=float):
numel = np.prod(shape)
@ -80,7 +87,7 @@ if TEST_TENSORBOARD:
from torch.utils.tensorboard import summary, SummaryWriter
from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
from tensorboard.compat.proto.types_pb2 import DataType
from torch.utils.tensorboard.summary import tensor_proto
from torch.utils.tensorboard.summary import int_to_half, tensor_proto
from torch.utils.tensorboard._convert_np import make_np
from torch.utils.tensorboard._pytorch_graph import graph
from google.protobuf import text_format
@ -865,6 +872,25 @@ class TestTensorBoardNumpy(BaseTestCase):
compare_proto(graph, self)
class TestTensorProtoSummary(BaseTestCase):
@parametrize(
    "tensor_type,proto_type",
    [
        (torch.float16, DataType.DT_HALF),
        (torch.bfloat16, DataType.DT_BFLOAT16),
    ],
)
def test_half_tensor_proto(self, tensor_type, proto_type):
    """Check the proto encoding of 16-bit floating point tensors.

    Builds a summary with ``tensor_proto`` from a small tensor of the
    parametrized ``tensor_type`` and verifies that:

    * the entries stored in ``half_val`` decode (via ``int_to_half``)
      back to the original float values, and
    * the proto's ``dtype`` field equals the expected ``proto_type``.
    """
    float_values = [1.0, 2.0, 3.0]
    actual_proto = tensor_proto(
        "dummy",
        torch.tensor(float_values, dtype=tensor_type),
    ).value[0].tensor
    # half_val holds integer encodings; decode each one before comparing.
    self.assertSequenceEqual(
        [int_to_half(x) for x in actual_proto.half_val],
        float_values,
    )
    # assertEqual (rather than assertTrue(a == b)) reports both enum
    # values on failure instead of just "False is not true".
    self.assertEqual(actual_proto.dtype, proto_type)
def test_float_tensor_proto(self):
float_values = [1.0, 2.0, 3.0]
actual_proto = (
@ -902,5 +928,7 @@ class TestTensorProtoSummary(BaseTestCase):
actual_proto = tensor_proto("dummy", torch.empty(0)).value[0].tensor
self.assertEqual(actual_proto.float_val, [])
instantiate_parametrized_tests(TestTensorProtoSummary)
if __name__ == '__main__':
run_tests()