mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[tensorboard] Handle bfloat16 type in add_histogram (#120087)
Summary: add_histogram fails for this data type. Updating conversion code to handle it. Stack trace for the failure - ` [trainer0]Traceback (most recent call last): [trainer0] File "<torch_package_0>.tensorboard/logging/summary_v2.py", line 203, in unscriptable_record_summary [trainer0] unscriptable_histogram(name, t, step, ranks) [trainer0] File "<torch_package_0>.tensorboard/logging/fx_v1.py", line 146, in unscriptable_histogram [trainer0] Adhoc.writer().add_histogram(tag, x, step.int()) [trainer0] File "/tmp/aienv/images/aienv_image_09slg3j1/torch/utils/tensorboard/writer.py", line 40, in wrapper [trainer0] resp = super_method(*args, **kwargs) [trainer0] File "/tmp/aienv/images/aienv_image_09slg3j1/torch/utils/tensorboard/writer_oss.py", line 526, in add_histogram [trainer0] histogram(tag, values, bins, max_bins=max_bins), global_step, walltime [trainer0] File "/tmp/aienv/images/aienv_image_09slg3j1/torch/utils/tensorboard/summary.py", line 482, in histogram [trainer0] values = make_np(values) [trainer0] File "/tmp/aienv/images/aienv_image_09slg3j1/torch/utils/tensorboard/_convert_np.py", line 23, in make_np [trainer0] return _prepare_pytorch(x) [trainer0] File "/tmp/aienv/images/aienv_image_09slg3j1/torch/utils/tensorboard/_convert_np.py", line 30, in _prepare_pytorch [trainer0] x = x.detach().cpu().numpy() [trainer0]TypeError: Got unsupported ScalarType BFloat16 ` Test Plan: Updated unit test that was failing before but passes after this change. Reviewed By: hamzajzmati, jcarreiro Differential Revision: D53841197 Pull Request resolved: https://github.com/pytorch/pytorch/pull/120087 Approved by: https://github.com/jcarreiro, https://github.com/yanboliang
This commit is contained in:
committed by
PyTorch MergeBot
parent
a3a8137484
commit
bcf35c6ae6
@ -131,6 +131,7 @@ class TestTensorBoardPyTorchNumpy(BaseTestCase):
|
||||
with self.createSummaryWriter() as w:
|
||||
w.add_histogram('float histogram', torch.rand((50,)))
|
||||
w.add_histogram('int histogram', torch.randint(0, 100, (50,)))
|
||||
w.add_histogram('bfloat16 histogram', torch.rand(50, dtype=torch.bfloat16))
|
||||
|
||||
def test_pytorch_histogram_raw(self):
|
||||
with self.createSummaryWriter() as w:
|
||||
|
Reference in New Issue
Block a user