mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix add_histogram_raw (#20688)
Summary: This is a porting of the fix from: https://github.com/lanpa/tensorboardX/issues/421 cc orionr Pull Request resolved: https://github.com/pytorch/pytorch/pull/20688 Reviewed By: NarineK Differential Revision: D15415093 Pulled By: orionr fbshipit-source-id: d32a6298218fbc6fe315aa0f18b57e0c8ef92627
This commit is contained in:
committed by
Facebook Github Bot
parent
fd2aa93b37
commit
cfc98ae714
BIN
docs/source/_static/img/tensorboard/add_histogram_raw.png
Normal file
BIN
docs/source/_static/img/tensorboard/add_histogram_raw.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 27 KiB |
@ -86,6 +86,7 @@ Expected result:
|
||||
.. automethod:: add_scalar
|
||||
.. automethod:: add_scalars
|
||||
.. automethod:: add_histogram
|
||||
.. automethod:: add_histogram_raw
|
||||
.. automethod:: add_image
|
||||
.. automethod:: add_images
|
||||
.. automethod:: add_figure
|
||||
|
@ -105,7 +105,7 @@ if TEST_TENSORBOARD:
|
||||
num=num,
|
||||
sum=floats.sum().item(),
|
||||
sum_squares=sum_sq,
|
||||
bucket_limits=limits.tolist(),
|
||||
bucket_limits=limits[1:].tolist(),
|
||||
bucket_counts=counts.tolist())
|
||||
|
||||
ints = make_np(torch.randint(0, 100, (num,)))
|
||||
@ -118,7 +118,7 @@ if TEST_TENSORBOARD:
|
||||
num=num,
|
||||
sum=ints.sum().item(),
|
||||
sum_squares=sum_sq,
|
||||
bucket_limits=limits.tolist(),
|
||||
bucket_limits=limits[1:].tolist(),
|
||||
bucket_counts=counts.tolist())
|
||||
|
||||
ints = torch.tensor(range(0, 100)).float()
|
||||
|
@ -385,13 +385,48 @@ class SummaryWriter(object):
|
||||
num (int): Number of values
|
||||
sum (float or int): Sum of all values
|
||||
sum_squares (float or int): Sum of squares for all values
|
||||
bucket_limits (torch.Tensor, numpy.array): Upper value per bucket
|
||||
bucket_limits (torch.Tensor, numpy.array): Upper value per bucket.
|
||||
The number of elements of it should be the same as `bucket_counts`.
|
||||
bucket_counts (torch.Tensor, numpy.array): Number of values per bucket
|
||||
global_step (int): Global step value to record
|
||||
walltime (float): Optional override default walltime (time.time())
|
||||
seconds after epoch of event
|
||||
see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md
|
||||
|
||||
Examples::
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import numpy as np
|
||||
writer = SummaryWriter()
|
||||
dummy_data = []
|
||||
for idx, value in enumerate(range(50)):
|
||||
dummy_data += [idx + 0.001] * value
|
||||
|
||||
bins = list(range(50+2))
|
||||
bins = np.array(bins)
|
||||
values = np.array(dummy_data).astype(float).reshape(-1)
|
||||
counts, limits = np.histogram(values, bins=bins)
|
||||
sum_sq = values.dot(values)
|
||||
writer.add_histogram_raw(
|
||||
tag='histogram_with_raw_data',
|
||||
min=values.min(),
|
||||
max=values.max(),
|
||||
num=len(values),
|
||||
sum=values.sum(),
|
||||
sum_squares=sum_sq,
|
||||
bucket_limits=limits[1:].tolist(),
|
||||
bucket_counts=counts.tolist(),
|
||||
global_step=0)
|
||||
writer.close()
|
||||
|
||||
Expected result:
|
||||
|
||||
.. image:: _static/img/tensorboard/add_histogram_raw.png
|
||||
:scale: 50 %
|
||||
|
||||
"""
|
||||
if len(bucket_limits) != len(bucket_counts):
|
||||
raise ValueError('len(bucket_limits) != len(bucket_counts), see the document.')
|
||||
self._get_file_writer().add_summary(
|
||||
histogram_raw(tag,
|
||||
min,
|
||||
|
Reference in New Issue
Block a user