mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
fix add_histogram_raw (#20688)
Summary: This is a porting of the fix from: https://github.com/lanpa/tensorboardX/issues/421 cc orionr Pull Request resolved: https://github.com/pytorch/pytorch/pull/20688 Reviewed By: NarineK Differential Revision: D15415093 Pulled By: orionr fbshipit-source-id: d32a6298218fbc6fe315aa0f18b57e0c8ef92627
This commit is contained in:
committed by
Facebook Github Bot
parent
fd2aa93b37
commit
cfc98ae714
BIN
docs/source/_static/img/tensorboard/add_histogram_raw.png
Normal file
BIN
docs/source/_static/img/tensorboard/add_histogram_raw.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 27 KiB |
@ -86,6 +86,7 @@ Expected result:
|
|||||||
.. automethod:: add_scalar
|
.. automethod:: add_scalar
|
||||||
.. automethod:: add_scalars
|
.. automethod:: add_scalars
|
||||||
.. automethod:: add_histogram
|
.. automethod:: add_histogram
|
||||||
|
.. automethod:: add_histogram_raw
|
||||||
.. automethod:: add_image
|
.. automethod:: add_image
|
||||||
.. automethod:: add_images
|
.. automethod:: add_images
|
||||||
.. automethod:: add_figure
|
.. automethod:: add_figure
|
||||||
|
@ -105,7 +105,7 @@ if TEST_TENSORBOARD:
|
|||||||
num=num,
|
num=num,
|
||||||
sum=floats.sum().item(),
|
sum=floats.sum().item(),
|
||||||
sum_squares=sum_sq,
|
sum_squares=sum_sq,
|
||||||
bucket_limits=limits.tolist(),
|
bucket_limits=limits[1:].tolist(),
|
||||||
bucket_counts=counts.tolist())
|
bucket_counts=counts.tolist())
|
||||||
|
|
||||||
ints = make_np(torch.randint(0, 100, (num,)))
|
ints = make_np(torch.randint(0, 100, (num,)))
|
||||||
@ -118,7 +118,7 @@ if TEST_TENSORBOARD:
|
|||||||
num=num,
|
num=num,
|
||||||
sum=ints.sum().item(),
|
sum=ints.sum().item(),
|
||||||
sum_squares=sum_sq,
|
sum_squares=sum_sq,
|
||||||
bucket_limits=limits.tolist(),
|
bucket_limits=limits[1:].tolist(),
|
||||||
bucket_counts=counts.tolist())
|
bucket_counts=counts.tolist())
|
||||||
|
|
||||||
ints = torch.tensor(range(0, 100)).float()
|
ints = torch.tensor(range(0, 100)).float()
|
||||||
|
@ -385,13 +385,48 @@ class SummaryWriter(object):
|
|||||||
num (int): Number of values
|
num (int): Number of values
|
||||||
sum (float or int): Sum of all values
|
sum (float or int): Sum of all values
|
||||||
sum_squares (float or int): Sum of squares for all values
|
sum_squares (float or int): Sum of squares for all values
|
||||||
bucket_limits (torch.Tensor, numpy.array): Upper value per bucket
|
bucket_limits (torch.Tensor, numpy.array): Upper value per bucket.
|
||||||
|
The number of elements of it should be the same as `bucket_counts`.
|
||||||
bucket_counts (torch.Tensor, numpy.array): Number of values per bucket
|
bucket_counts (torch.Tensor, numpy.array): Number of values per bucket
|
||||||
global_step (int): Global step value to record
|
global_step (int): Global step value to record
|
||||||
walltime (float): Optional override default walltime (time.time())
|
walltime (float): Optional override default walltime (time.time())
|
||||||
seconds after epoch of event
|
seconds after epoch of event
|
||||||
see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md
|
see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
import numpy as np
|
||||||
|
writer = SummaryWriter()
|
||||||
|
dummy_data = []
|
||||||
|
for idx, value in enumerate(range(50)):
|
||||||
|
dummy_data += [idx + 0.001] * value
|
||||||
|
|
||||||
|
bins = list(range(50+2))
|
||||||
|
bins = np.array(bins)
|
||||||
|
values = np.array(dummy_data).astype(float).reshape(-1)
|
||||||
|
counts, limits = np.histogram(values, bins=bins)
|
||||||
|
sum_sq = values.dot(values)
|
||||||
|
writer.add_histogram_raw(
|
||||||
|
tag='histogram_with_raw_data',
|
||||||
|
min=values.min(),
|
||||||
|
max=values.max(),
|
||||||
|
num=len(values),
|
||||||
|
sum=values.sum(),
|
||||||
|
sum_squares=sum_sq,
|
||||||
|
bucket_limits=limits[1:].tolist(),
|
||||||
|
bucket_counts=counts.tolist(),
|
||||||
|
global_step=0)
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
Expected result:
|
||||||
|
|
||||||
|
.. image:: _static/img/tensorboard/add_histogram_raw.png
|
||||||
|
:scale: 50 %
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if len(bucket_limits) != len(bucket_counts):
|
||||||
|
raise ValueError('len(bucket_limits) != len(bucket_counts), see the document.')
|
||||||
self._get_file_writer().add_summary(
|
self._get_file_writer().add_summary(
|
||||||
histogram_raw(tag,
|
histogram_raw(tag,
|
||||||
min,
|
min,
|
||||||
|
Reference in New Issue
Block a user