Clarify API and add examples for all methods (#20008)

Summary:
As a part of supporting writing data into TensorBoard readable format, we show more example on how to use the function in addition to the API docs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20008

Reviewed By: natalialunova

Differential Revision: D15261502

Pulled By: orionr

fbshipit-source-id: 16611695a27e74bfcdf311e7cad40196e0947038
This commit is contained in:
Tzu-Wei Huang
2019-05-08 14:03:04 -07:00
committed by Facebook Github Bot
parent 4a086f700f
commit 7edf9a25e8
9 changed files with 211 additions and 46 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 47 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 157 KiB

View File

@ -48,13 +48,46 @@ and runnable with::
pip install tb-nightly # Until 1.14 moves to the release channel
tensorboard --logdir=runs
Lots of information can be logged for one experiment. To avoid cluttering
the UI and have better result clustering, we can group plots by naming them
hierarchically. For example, "Loss/train" and "Loss/test" will be grouped
together, while "Accuracy/train" and "Accuracy/test" will be grouped separately
in the TensorBoard interface.
.. code:: python
from torch.utils.tensorboard import SummaryWriter
import numpy as np
writer = SummaryWriter()
for n_iter in range(100):
writer.add_scalar('Loss/train', np.random.random(), n_iter)
writer.add_scalar('Loss/test', np.random.random(), n_iter)
writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
Expected result:
.. image:: _static/img/tensorboard/hier_tags.png
:scale: 75 %
|
|
.. currentmodule:: torch.utils.tensorboard.writer
.. autoclass:: SummaryWriter
.. automethod:: __init__
.. automethod:: add_scalar
.. automethod:: add_scalars
.. automethod:: add_histogram
.. automethod:: add_image
.. automethod:: add_images
.. automethod:: add_figure
.. automethod:: add_video
.. automethod:: add_audio

View File

@ -170,6 +170,12 @@ class GraphPy(object):
self.nodes_io[key].uniqueName = self.unique_name_to_scoped_name[node.uniqueName]
def to_proto(self):
"""
Converts graph representation of GraphPy object to TensorBoard
required format.
"""
# TODO: compute correct memory usage and CPU time once
# PyTorch supports it
nodes = []
node_stats = []
for v in self.nodes_io.values():

View File

@ -35,11 +35,7 @@ class FileWriter(object):
training.
"""
def __init__(self,
logdir,
max_queue=10,
flush_secs=120,
filename_suffix=''):
def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=''):
"""Creates a `FileWriter` and an event file.
On construction the writer creates a new event file in `logdir`.
The other arguments to the constructor control the asynchronous writes to
@ -49,10 +45,11 @@ class FileWriter(object):
logdir: A string. Directory where event file will be written.
max_queue: Integer. Size of the queue for pending events and
summaries before one of the 'add' calls forces a flush to disk.
Default is ten items.
flush_secs: Number. How often, in seconds, to flush the
pending events and summaries to disk.
filename_suffix: A string. Suffix added to all event filenames.
More details on event filename construction in
pending events and summaries to disk. Default is every two minutes.
filename_suffix: A string. Suffix added to all event filenames
in the logdir directory. More details on filename construction in
tensorboard.summary.writer.event_file_writer.EventFileWriter.
"""
# Sometimes PosixPath is passed in and we need to coerce it to
@ -150,7 +147,7 @@ class FileWriter(object):
class SummaryWriter(object):
"""Writes entries directly to event files in the log_dir to be
"""Writes entries directly to event files in the logdir to be
consumed by TensorBoard.
The `SummaryWriter` class provides a high-level API to create an event file
@ -160,37 +157,61 @@ class SummaryWriter(object):
training.
"""
def __init__(self, log_dir=None, comment='', **kwargs):
def __init__(self, logdir=None, comment='', purge_step=None, max_queue=10,
flush_secs=120, filename_suffix=''):
"""Creates a `SummaryWriter` that will write out events and summaries
to the event file.
Args:
log_dir (string): save location, default is: runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each
run. Use hierarchical folder structure to compare between runs easily. e.g. pass in
'runs/exp1', 'runs/exp2', etc. for each new experiment to compare across. Defaults
to ``./runs/``.
comment (string): comment that appends to the default ``log_dir``. If ``log_dir`` is assigned,
this argument will no effect.
logdir (string): Save directory location. Default is
runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each run.
Use hierarchical folder structure to compare
between runs easily. e.g. pass in 'runs/exp1', 'runs/exp2', etc.
for each new experiment to compare across them.
comment (string): Comment logdir suffix appended to the default
``logdir``. If ``logdir`` is assigned, this argument has no effect.
purge_step (int):
When logging crashes at step :math:`T+X` and restarts at step :math:`T`, any events
whose global_step larger or equal to :math:`T` will be purged and hidden from TensorBoard.
Note that the resumed experiment and crashed experiment should have the same ``log_dir``.
filename_suffix (string):
Every event file's name is suffixed with suffix. Example: ``SummaryWriter(filename_suffix='.123')``
More details on event filename construction in
When logging crashes at step :math:`T+X` and restarts at step :math:`T`,
any events whose global_step larger or equal to :math:`T` will be
purged and hidden from TensorBoard.
Note that crashed and resumed experiments should have the same ``logdir``.
max_queue (int): Size of the queue for pending events and
summaries before one of the 'add' calls forces a flush to disk.
Default is ten items.
flush_secs (int): How often, in seconds, to flush the
pending events and summaries to disk. Default is every two minutes.
filename_suffix (string): Suffix added to all event filenames in
the logdir directory. More details on filename construction in
tensorboard.summary.writer.event_file_writer.EventFileWriter.
kwargs: extra keyword arguments for FileWriter (e.g. 'flush_secs'
controls how often to flush pending events). For more arguments
please refer to docs for 'tf.summary.FileWriter'.
Examples::
from torch.utils.tensorboard import SummaryWriter
# create a summary writer with automatically generated folder name.
writer = SummaryWriter()
# folder location: runs/May04_22-14-54_s-MacBook-Pro.local/
# create a summary writer using the specified folder name.
writer = SummaryWriter("my_experiment")
# folder location: my_experiment
# create a summary writer with comment appended.
writer = SummaryWriter(comment="LR_0.1_BATCH_16")
# folder location: runs/May04_22-14-54_s-MacBook-Pro.localLR_0.1_BATCH_16/
"""
if not log_dir:
if not logdir:
import socket
from datetime import datetime
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join(
logdir = os.path.join(
'runs', current_time + '_' + socket.gethostname() + comment)
self.log_dir = log_dir
self.kwargs = kwargs
self.logdir = logdir
self.purge_step = purge_step
self.max_queue = max_queue
self.flush_secs = flush_secs
self.filename_suffix = filename_suffix
# Initialize the file writers, but they can be cleared out on close
# and recreated later as needed.
@ -225,16 +246,16 @@ class SummaryWriter(object):
def _get_file_writer(self):
"""Returns the default FileWriter instance. Recreates it if closed."""
if self.all_writers is None or self.file_writer is None:
if 'purge_step' in self.kwargs.keys():
most_recent_step = self.kwargs.pop('purge_step')
self.file_writer = FileWriter(logdir=self.log_dir, **self.kwargs)
self.file_writer = FileWriter(self.logdir, self.max_queue,
self.flush_secs, self.filename_suffix)
self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
if self.purge_step is not None:
most_recent_step = self.purge_step
self.file_writer.add_event(
Event(step=most_recent_step, file_version='brain.Event:2'))
self.file_writer.add_event(
Event(step=most_recent_step, session_log=SessionLog(status=SessionLog.START)))
else:
self.file_writer = FileWriter(logdir=self.log_dir, **self.kwargs)
self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
self.purge_step = None
return self.file_writer
def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
@ -246,6 +267,21 @@ class SummaryWriter(object):
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time())
with seconds after epoch of event
Examples::
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
x = range(100)
for i in x:
writer.add_scalar('y=2x', i * 2, i)
writer.close()
Expected result:
.. image:: _static/img/tensorboard/add_scalar.png
:scale: 50 %
"""
if self._check_caffe2_blob(scalar_value):
scalar_value = workspace.FetchBlob(scalar_value)
@ -266,11 +302,22 @@ class SummaryWriter(object):
Examples::
writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
'xcosx':i*np.cos(i/r),
'arctanx': numsteps*np.arctan(i/r)}, i)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
r = 5
for i in range(100):
writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
'xcosx':i*np.cos(i/r),
'tanx': np.tan(i/r)}, i)
writer.close()
# This call adds three values to the same scalar plot with the tag
# 'run_14h' in TensorBoard's scalar section.
Expected result:
.. image:: _static/img/tensorboard/add_scalars.png
:scale: 50 %
"""
walltime = time.time() if walltime is None else walltime
fw_logdir = self._get_file_writer().get_logdir()
@ -279,7 +326,8 @@ class SummaryWriter(object):
if fw_tag in self.all_writers.keys():
fw = self.all_writers[fw_tag]
else:
fw = FileWriter(logdir=fw_tag)
fw = FileWriter(fw_tag, self.max_queue, self.flush_secs,
self.filename_suffix)
self.all_writers[fw_tag] = fw
if self._check_caffe2_blob(scalar_value):
scalar_value = workspace.FetchBlob(scalar_value)
@ -293,10 +341,26 @@ class SummaryWriter(object):
tag (string): Data identifier
values (torch.Tensor, numpy.array, or string/blobname): Values to build histogram
global_step (int): Global step value to record
bins (string): one of {'tensorflow','auto', 'fd', ...}, this determines how the bins are made. You can find
bins (string): One of {'tensorflow','auto', 'fd', ...}. This determines how the bins are made. You can find
other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
Examples::
from torch.utils.tensorboard import SummaryWriter
import numpy as np
writer = SummaryWriter()
for i in range(10):
x = np.random.random(1000)
writer.add_histogram('distribution centers', x + i, i)
writer.close()
Expected result:
.. image:: _static/img/tensorboard/add_histogram.png
:scale: 50 %
"""
if self._check_caffe2_blob(values):
values = workspace.FetchBlob(values)
@ -352,6 +416,31 @@ class SummaryWriter(object):
convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitible as long as
corresponding ``dataformats`` argument is passed. e.g. CHW, HWC, HW.
Examples::
from torch.utils.tensorboard import SummaryWriter
import numpy as np
img = np.zeros((3, 100, 100))
img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000
img_HWC = np.zeros((100, 100, 3))
img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000
writer = SummaryWriter()
writer.add_image('my_image', img, 0)
# If you have non-default dimension setting, set the dataformats argument.
writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')
writer.close()
Expected result:
.. image:: _static/img/tensorboard/add_image.png
:scale: 50 %
"""
if self._check_caffe2_blob(img_tensor):
img_tensor = workspace.FetchBlob(img_tensor)
@ -372,6 +461,26 @@ class SummaryWriter(object):
Shape:
img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shape will be
accepted. e.g. NCHW or NHWC.
Examples::
from torch.utils.tensorboard import SummaryWriter
import numpy as np
img_batch = np.zeros((16, 3, 100, 100))
for i in range(16):
img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i
img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i
writer = SummaryWriter()
writer.add_images('my_image_batch', img_batch, 0)
writer.close()
Expected result:
.. image:: _static/img/tensorboard/add_images.png
:scale: 30 %
"""
if self._check_caffe2_blob(img_tensor):
img_tensor = workspace.FetchBlob(img_tensor)
@ -410,7 +519,7 @@ class SummaryWriter(object):
Args:
tag (string): Data identifier
figure (matplotlib.pyplot.figure) or list of figures: figure or a list of figures
figure (matplotlib.pyplot.figure) or list of figures: Figure or a list of figures
global_step (int): Global step value to record
close (bool): Flag to automatically close the figure
walltime (float): Optional override default walltime (time.time())
@ -483,14 +592,14 @@ class SummaryWriter(object):
"""Add graph data to summary.
Args:
model (torch.nn.Module): model to draw.
input_to_model (torch.Tensor or list of torch.Tensor): a variable or a tuple of
model (torch.nn.Module): Model to draw.
input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of
variables to be fed.
verbose (bool): Whether to print graph structure in console.
omit_useless_nodes (bool): Default to ``true``, which eliminates unused nodes.
operator_export_type (string): One of: ``"ONNX"``, ``"RAW"``. This determines
the optimization level of the graph. If error happens during exporting
the graph, use ``"RAW"`` may help.
the graph, using ``"RAW"`` might help.
"""
if hasattr(model, 'forward'):
@ -601,17 +710,34 @@ class SummaryWriter(object):
def add_pr_curve(self, tag, labels, predictions, global_step=None,
num_thresholds=127, weights=None, walltime=None):
"""Adds precision recall curve.
Plotting a precision-recall curve lets you understand your model's
performance under different threshold settings. With this function,
you provide the ground truth labeling (T/F) and prediction confidence
(usually the output of your model) for each target. The TensorBoard UI
will let you choose the threshold interactively.
Args:
tag (string): Data identifier
labels (torch.Tensor, numpy.array, or string/blobname): Ground truth data. Binary label for each element.
labels (torch.Tensor, numpy.array, or string/blobname):
Ground truth data. Binary label for each element.
predictions (torch.Tensor, numpy.array, or string/blobname):
The probability that an element be classified as true. Value should in [0, 1]
The probability that an element be classified as true.
Value should in [0, 1]
global_step (int): Global step value to record
num_thresholds (int): Number of thresholds used to draw the curve.
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
Examples::
from torch.utils.tensorboard import SummaryWriter
import numpy as np
labels = np.random.randint(2, size=100) # binary label
predictions = np.random.rand(100)
writer = SummaryWriter()
writer.add_pr_curve('pr_curve', labels, predictions, 0)
writer.close()
"""
labels, predictions = make_np(labels), make_np(predictions)
self._get_file_writer().add_summary(