[tensorboard] Handle bfloat16 type in add_histogram (#120087)

Summary:
add_histogram fails for this data type. Updating conversion code to handle it.

Stack trace for the failure -

`
[trainer0]Traceback (most recent call last):
[trainer0]  File "<torch_package_0>.tensorboard/logging/summary_v2.py", line 203, in unscriptable_record_summary
[trainer0]    unscriptable_histogram(name, t, step, ranks)
[trainer0]  File "<torch_package_0>.tensorboard/logging/fx_v1.py", line 146, in unscriptable_histogram
[trainer0]    Adhoc.writer().add_histogram(tag, x, step.int())
[trainer0]  File "/tmp/aienv/images/aienv_image_09slg3j1/torch/utils/tensorboard/writer.py", line 40, in wrapper
[trainer0]    resp = super_method(*args, **kwargs)
[trainer0]  File "/tmp/aienv/images/aienv_image_09slg3j1/torch/utils/tensorboard/writer_oss.py", line 526, in add_histogram
[trainer0]    histogram(tag, values, bins, max_bins=max_bins), global_step, walltime
[trainer0]  File "/tmp/aienv/images/aienv_image_09slg3j1/torch/utils/tensorboard/summary.py", line 482, in histogram
[trainer0]    values = make_np(values)
[trainer0]  File "/tmp/aienv/images/aienv_image_09slg3j1/torch/utils/tensorboard/_convert_np.py", line 23, in make_np
[trainer0]    return _prepare_pytorch(x)
[trainer0]  File "/tmp/aienv/images/aienv_image_09slg3j1/torch/utils/tensorboard/_convert_np.py", line 30, in _prepare_pytorch
[trainer0]    x = x.detach().cpu().numpy()
[trainer0]TypeError: Got unsupported ScalarType BFloat16
`

Test Plan: Updated unit test that was failing before but passes after this change.

Reviewed By: hamzajzmati, jcarreiro

Differential Revision: D53841197

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120087
Approved by: https://github.com/jcarreiro, https://github.com/yanboliang
This commit is contained in:
Vikram Srivastava
2024-03-05 00:27:17 +00:00
committed by PyTorch MergeBot
parent a3a8137484
commit bcf35c6ae6
2 changed files with 3 additions and 0 deletions

View File

@ -131,6 +131,7 @@ class TestTensorBoardPyTorchNumpy(BaseTestCase):
with self.createSummaryWriter() as w:
w.add_histogram('float histogram', torch.rand((50,)))
w.add_histogram('int histogram', torch.randint(0, 100, (50,)))
w.add_histogram('bfloat16 histogram', torch.rand(50, dtype=torch.bfloat16))
def test_pytorch_histogram_raw(self):
with self.createSummaryWriter() as w: