mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[tensorboard] Fix TensorBoard summary encoding for torch.bfloat16 tensors (#108351)
Summary: The `tensor_proto` function in the TensorBoard summary writer code doesn't correctly encode `torch.bfloat16` tensors; it tries to use a data type of `DT_BFLOAT` when creating the protobuf, but `DT_BFLOAT` is not a valid enum value (see `types.proto`). The correct value to use when encoding tensors of this type is `DT_BFLOAT16`. This diff updates the type map in the summary code to use the correct type. While fixing this error, I also noticed the wrong field of the protobuf was being used when encoding tensors of this type; per the docs in the proto file, the DT_HALF and DT_BFLOAT16 types should use the `half_val` field, not `float_val`. Since this might confuse folks trying to read this data from storage in the future, I've updated the code to correctly use the `half_val` field for these cases. Note that there's no real size advantage from doing this, since both the `half_val` and `float_val` fields are 32 bits long. Test Plan: Added a parameterized unit test that tests encoding tensors with `torch.half`, `torch.float16`, and `torch.bfloat16` data types.
# Before this change The test fails with a `ValueError` due to the incorrect enum label: ``` ====================================================================== ERROR: test_bfloat16_tensor_proto (test_tensorboard.TestTensorProtoSummary) ---------------------------------------------------------------------- Traceback (most recent call last): File "/data/users/jcarreiro/fbsource/buck-out/v2/gen/fbcode/f88b3f368c9334db/caffe2/test/__tensorboard__/tensorboard#link-tree/torch/testing/_internal/common_utils.py", line 2382, in wrapper method(*args, **kwargs) File "/data/users/jcarreiro/fbsource/buck-out/v2/gen/fbcode/f88b3f368c9334db/caffe2/test/__tensorboard__/tensorboard#link-tree/test_tensorboard.py", line 871, in test_bfloat16_tensor_proto tensor_proto( File "/data/users/jcarreiro/fbsource/buck-out/v2/gen/fbcode/f88b3f368c9334db/caffe2/test/__tensorboard__/tensorboard#link-tree/torch/utils/tensorboard/summary.py", line 400, in tensor_proto tensor_proto = TensorProto(**tensor_proto_args) ValueError: unknown enum label "DT_BFLOAT" To execute this test, run the following from the base repo dir: python test/__tensorboard__/tensorboard#link-tree/test_tensorboard.py -k test_bfloat16_tensor_proto This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0 ---------------------------------------------------------------------- ``` # After this change The test passes. Reviewed By: tanvigupta17 Differential Revision: D48828958 Pull Request resolved: https://github.com/pytorch/pytorch/pull/108351 Approved by: https://github.com/hamzajzmati, https://github.com/XilunWu
This commit is contained in:
committed by
PyTorch MergeBot
parent
bf5622e965
commit
fa62308673
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import expecttest
|
||||
import io
|
||||
import numpy as np
|
||||
import os
|
||||
@ -7,7 +8,6 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import expecttest
|
||||
|
||||
TEST_TENSORBOARD = True
|
||||
try:
|
||||
@ -43,7 +43,14 @@ except ImportError:
|
||||
skipIfNoMatplotlib = unittest.skipIf(not TEST_MATPLOTLIB, "no matplotlib")
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ASAN, TEST_WITH_CROSSREF
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
TestCase,
|
||||
run_tests,
|
||||
TEST_WITH_ASAN,
|
||||
TEST_WITH_CROSSREF,
|
||||
)
|
||||
|
||||
def tensor_N(shape, dtype=float):
|
||||
numel = np.prod(shape)
|
||||
@ -80,7 +87,7 @@ if TEST_TENSORBOARD:
|
||||
from torch.utils.tensorboard import summary, SummaryWriter
|
||||
from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
|
||||
from tensorboard.compat.proto.types_pb2 import DataType
|
||||
from torch.utils.tensorboard.summary import tensor_proto
|
||||
from torch.utils.tensorboard.summary import int_to_half, tensor_proto
|
||||
from torch.utils.tensorboard._convert_np import make_np
|
||||
from torch.utils.tensorboard._pytorch_graph import graph
|
||||
from google.protobuf import text_format
|
||||
@ -865,6 +872,25 @@ class TestTensorBoardNumpy(BaseTestCase):
|
||||
compare_proto(graph, self)
|
||||
|
||||
class TestTensorProtoSummary(BaseTestCase):
|
||||
@parametrize(
    "tensor_type,proto_type",
    [
        (torch.float16, DataType.DT_HALF),
        (torch.bfloat16, DataType.DT_BFLOAT16),
    ],
)
def test_half_tensor_proto(self, tensor_type, proto_type):
    """Check that 16-bit float tensors are encoded via the ``half_val`` field.

    Runs once per ``(tensor_type, proto_type)`` pair: builds a summary from a
    small tensor of ``tensor_type`` and verifies that the values read back
    from ``half_val`` match the inputs and that the proto's ``dtype`` is the
    expected ``proto_type`` enum value.
    """
    float_values = [1.0, 2.0, 3.0]
    actual_proto = tensor_proto(
        "dummy",
        torch.tensor(float_values, dtype=tensor_type),
    ).value[0].tensor
    # NOTE(review): int_to_half presumably decodes each stored half_val
    # entry back into a Python-comparable float -- confirm against
    # torch.utils.tensorboard.summary.
    self.assertSequenceEqual(
        [int_to_half(x) for x in actual_proto.half_val],
        float_values,
    )
    # assertEqual (rather than assertTrue(a == b)) reports both values on
    # failure, which makes a wrong enum label immediately visible.
    self.assertEqual(actual_proto.dtype, proto_type)
|
||||
|
||||
def test_float_tensor_proto(self):
|
||||
float_values = [1.0, 2.0, 3.0]
|
||||
actual_proto = (
|
||||
@ -902,5 +928,7 @@ class TestTensorProtoSummary(BaseTestCase):
|
||||
actual_proto = tensor_proto("dummy", torch.empty(0)).value[0].tensor
|
||||
self.assertEqual(actual_proto.float_val, [])
|
||||
|
||||
instantiate_parametrized_tests(TestTensorProtoSummary)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user