mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Clarify API and add examples for all methods (#20008)
Summary: As a part of supporting writing data into a TensorBoard-readable format, we show more examples of how to use the functions in addition to the API docs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/20008 Reviewed By: natalialunova Differential Revision: D15261502 Pulled By: orionr fbshipit-source-id: 16611695a27e74bfcdf311e7cad40196e0947038
This commit is contained in:
committed by
Facebook Github Bot
parent
4a086f700f
commit
7edf9a25e8
BIN
docs/source/_static/img/tensorboard/add_histogram.png
Normal file
BIN
docs/source/_static/img/tensorboard/add_histogram.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 47 KiB |
BIN
docs/source/_static/img/tensorboard/add_image.png
Normal file
BIN
docs/source/_static/img/tensorboard/add_image.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 46 KiB |
BIN
docs/source/_static/img/tensorboard/add_images.png
Normal file
BIN
docs/source/_static/img/tensorboard/add_images.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 75 KiB |
BIN
docs/source/_static/img/tensorboard/add_scalar.png
Normal file
BIN
docs/source/_static/img/tensorboard/add_scalar.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 45 KiB |
BIN
docs/source/_static/img/tensorboard/add_scalars.png
Normal file
BIN
docs/source/_static/img/tensorboard/add_scalars.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 97 KiB |
BIN
docs/source/_static/img/tensorboard/hier_tags.png
Normal file
BIN
docs/source/_static/img/tensorboard/hier_tags.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 157 KiB |
@ -48,13 +48,46 @@ and runnable with::
|
||||
pip install tb-nightly # Until 1.14 moves to the release channel
|
||||
tensorboard --logdir=runs
|
||||
|
||||
|
||||
Lots of information can be logged for one experiment. To avoid cluttering
|
||||
the UI and have better result clustering, we can group plots by naming them
|
||||
hierarchically. For example, "Loss/train" and "Loss/test" will be grouped
|
||||
together, while "Accuracy/train" and "Accuracy/test" will be grouped separately
|
||||
in the TensorBoard interface.
|
||||
|
||||
.. code:: python
|
||||
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import numpy as np
|
||||
|
||||
writer = SummaryWriter()
|
||||
|
||||
for n_iter in range(100):
|
||||
writer.add_scalar('Loss/train', np.random.random(), n_iter)
|
||||
writer.add_scalar('Loss/test', np.random.random(), n_iter)
|
||||
writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
|
||||
writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
|
||||
|
||||
|
||||
Expected result:
|
||||
|
||||
.. image:: _static/img/tensorboard/hier_tags.png
|
||||
:scale: 75 %
|
||||
|
||||
|
|
||||
|
|
||||
|
||||
.. currentmodule:: torch.utils.tensorboard.writer
|
||||
|
||||
.. autoclass:: SummaryWriter
|
||||
|
||||
.. automethod:: __init__
|
||||
.. automethod:: add_scalar
|
||||
.. automethod:: add_scalars
|
||||
.. automethod:: add_histogram
|
||||
.. automethod:: add_image
|
||||
.. automethod:: add_images
|
||||
.. automethod:: add_figure
|
||||
.. automethod:: add_video
|
||||
.. automethod:: add_audio
|
||||
|
@ -170,6 +170,12 @@ class GraphPy(object):
|
||||
self.nodes_io[key].uniqueName = self.unique_name_to_scoped_name[node.uniqueName]
|
||||
|
||||
def to_proto(self):
|
||||
"""
|
||||
Converts graph representation of GraphPy object to TensorBoard
|
||||
required format.
|
||||
"""
|
||||
# TODO: compute correct memory usage and CPU time once
|
||||
# PyTorch supports it
|
||||
nodes = []
|
||||
node_stats = []
|
||||
for v in self.nodes_io.values():
|
||||
|
@ -35,11 +35,7 @@ class FileWriter(object):
|
||||
training.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
logdir,
|
||||
max_queue=10,
|
||||
flush_secs=120,
|
||||
filename_suffix=''):
|
||||
def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=''):
|
||||
"""Creates a `FileWriter` and an event file.
|
||||
On construction the writer creates a new event file in `logdir`.
|
||||
The other arguments to the constructor control the asynchronous writes to
|
||||
@ -49,10 +45,11 @@ class FileWriter(object):
|
||||
logdir: A string. Directory where event file will be written.
|
||||
max_queue: Integer. Size of the queue for pending events and
|
||||
summaries before one of the 'add' calls forces a flush to disk.
|
||||
Default is ten items.
|
||||
flush_secs: Number. How often, in seconds, to flush the
|
||||
pending events and summaries to disk.
|
||||
filename_suffix: A string. Suffix added to all event filenames.
|
||||
More details on event filename construction in
|
||||
pending events and summaries to disk. Default is every two minutes.
|
||||
filename_suffix: A string. Suffix added to all event filenames
|
||||
in the logdir directory. More details on filename construction in
|
||||
tensorboard.summary.writer.event_file_writer.EventFileWriter.
|
||||
"""
|
||||
# Sometimes PosixPath is passed in and we need to coerce it to
|
||||
@ -150,7 +147,7 @@ class FileWriter(object):
|
||||
|
||||
|
||||
class SummaryWriter(object):
|
||||
"""Writes entries directly to event files in the log_dir to be
|
||||
"""Writes entries directly to event files in the logdir to be
|
||||
consumed by TensorBoard.
|
||||
|
||||
The `SummaryWriter` class provides a high-level API to create an event file
|
||||
@ -160,37 +157,61 @@ class SummaryWriter(object):
|
||||
training.
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir=None, comment='', **kwargs):
|
||||
def __init__(self, logdir=None, comment='', purge_step=None, max_queue=10,
|
||||
flush_secs=120, filename_suffix=''):
|
||||
"""Creates a `SummaryWriter` that will write out events and summaries
|
||||
to the event file.
|
||||
|
||||
Args:
|
||||
log_dir (string): save location, default is: runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each
|
||||
run. Use hierarchical folder structure to compare between runs easily. e.g. pass in
|
||||
'runs/exp1', 'runs/exp2', etc. for each new experiment to compare across. Defaults
|
||||
to ``./runs/``.
|
||||
comment (string): comment that appends to the default ``log_dir``. If ``log_dir`` is assigned,
|
||||
this argument will no effect.
|
||||
logdir (string): Save directory location. Default is
|
||||
runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each run.
|
||||
Use hierarchical folder structure to compare
|
||||
between runs easily. e.g. pass in 'runs/exp1', 'runs/exp2', etc.
|
||||
for each new experiment to compare across them.
|
||||
comment (string): Suffix appended to the default
|
||||
``logdir``. If ``logdir`` is assigned, this argument has no effect.
|
||||
purge_step (int):
|
||||
When logging crashes at step :math:`T+X` and restarts at step :math:`T`, any events
|
||||
whose global_step larger or equal to :math:`T` will be purged and hidden from TensorBoard.
|
||||
Note that the resumed experiment and crashed experiment should have the same ``log_dir``.
|
||||
filename_suffix (string):
|
||||
Every event file's name is suffixed with suffix. Example: ``SummaryWriter(filename_suffix='.123')``
|
||||
More details on event filename construction in
|
||||
When logging crashes at step :math:`T+X` and restarts at step :math:`T`,
|
||||
any events whose global_step is larger than or equal to :math:`T` will be
|
||||
purged and hidden from TensorBoard.
|
||||
Note that crashed and resumed experiments should have the same ``logdir``.
|
||||
max_queue (int): Size of the queue for pending events and
|
||||
summaries before one of the 'add' calls forces a flush to disk.
|
||||
Default is ten items.
|
||||
flush_secs (int): How often, in seconds, to flush the
|
||||
pending events and summaries to disk. Default is every two minutes.
|
||||
filename_suffix (string): Suffix added to all event filenames in
|
||||
the logdir directory. More details on filename construction in
|
||||
tensorboard.summary.writer.event_file_writer.EventFileWriter.
|
||||
kwargs: extra keyword arguments for FileWriter (e.g. 'flush_secs'
|
||||
controls how often to flush pending events). For more arguments
|
||||
please refer to docs for 'tf.summary.FileWriter'.
|
||||
|
||||
Examples::
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# create a summary writer with automatically generated folder name.
|
||||
writer = SummaryWriter()
|
||||
# folder location: runs/May04_22-14-54_s-MacBook-Pro.local/
|
||||
|
||||
# create a summary writer using the specified folder name.
|
||||
writer = SummaryWriter("my_experiment")
|
||||
# folder location: my_experiment
|
||||
|
||||
# create a summary writer with comment appended.
|
||||
writer = SummaryWriter(comment="LR_0.1_BATCH_16")
|
||||
# folder location: runs/May04_22-14-54_s-MacBook-Pro.localLR_0.1_BATCH_16/
|
||||
|
||||
"""
|
||||
if not log_dir:
|
||||
if not logdir:
|
||||
import socket
|
||||
from datetime import datetime
|
||||
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
|
||||
log_dir = os.path.join(
|
||||
logdir = os.path.join(
|
||||
'runs', current_time + '_' + socket.gethostname() + comment)
|
||||
self.log_dir = log_dir
|
||||
self.kwargs = kwargs
|
||||
self.logdir = logdir
|
||||
self.purge_step = purge_step
|
||||
self.max_queue = max_queue
|
||||
self.flush_secs = flush_secs
|
||||
self.filename_suffix = filename_suffix
|
||||
|
||||
# Initialize the file writers, but they can be cleared out on close
|
||||
# and recreated later as needed.
|
||||
@ -225,16 +246,16 @@ class SummaryWriter(object):
|
||||
def _get_file_writer(self):
|
||||
"""Returns the default FileWriter instance. Recreates it if closed."""
|
||||
if self.all_writers is None or self.file_writer is None:
|
||||
if 'purge_step' in self.kwargs.keys():
|
||||
most_recent_step = self.kwargs.pop('purge_step')
|
||||
self.file_writer = FileWriter(logdir=self.log_dir, **self.kwargs)
|
||||
self.file_writer = FileWriter(self.logdir, self.max_queue,
|
||||
self.flush_secs, self.filename_suffix)
|
||||
self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
|
||||
if self.purge_step is not None:
|
||||
most_recent_step = self.purge_step
|
||||
self.file_writer.add_event(
|
||||
Event(step=most_recent_step, file_version='brain.Event:2'))
|
||||
self.file_writer.add_event(
|
||||
Event(step=most_recent_step, session_log=SessionLog(status=SessionLog.START)))
|
||||
else:
|
||||
self.file_writer = FileWriter(logdir=self.log_dir, **self.kwargs)
|
||||
self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
|
||||
self.purge_step = None
|
||||
return self.file_writer
|
||||
|
||||
def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
|
||||
@ -246,6 +267,21 @@ class SummaryWriter(object):
|
||||
global_step (int): Global step value to record
|
||||
walltime (float): Optional override default walltime (time.time())
|
||||
with seconds after epoch of event
|
||||
|
||||
Examples::
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
writer = SummaryWriter()
|
||||
x = range(100)
|
||||
for i in x:
|
||||
writer.add_scalar('y=2x', i * 2, i)
|
||||
writer.close()
|
||||
|
||||
Expected result:
|
||||
|
||||
.. image:: _static/img/tensorboard/add_scalar.png
|
||||
:scale: 50 %
|
||||
|
||||
"""
|
||||
if self._check_caffe2_blob(scalar_value):
|
||||
scalar_value = workspace.FetchBlob(scalar_value)
|
||||
@ -266,11 +302,22 @@ class SummaryWriter(object):
|
||||
|
||||
Examples::
|
||||
|
||||
writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
|
||||
'xcosx':i*np.cos(i/r),
|
||||
'arctanx': numsteps*np.arctan(i/r)}, i)
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
writer = SummaryWriter()
|
||||
r = 5
|
||||
for i in range(100):
|
||||
writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
|
||||
'xcosx':i*np.cos(i/r),
|
||||
'tanx': np.tan(i/r)}, i)
|
||||
writer.close()
|
||||
# This call adds three values to the same scalar plot with the tag
|
||||
# 'run_14h' in TensorBoard's scalar section.
|
||||
|
||||
Expected result:
|
||||
|
||||
.. image:: _static/img/tensorboard/add_scalars.png
|
||||
:scale: 50 %
|
||||
|
||||
"""
|
||||
walltime = time.time() if walltime is None else walltime
|
||||
fw_logdir = self._get_file_writer().get_logdir()
|
||||
@ -279,7 +326,8 @@ class SummaryWriter(object):
|
||||
if fw_tag in self.all_writers.keys():
|
||||
fw = self.all_writers[fw_tag]
|
||||
else:
|
||||
fw = FileWriter(logdir=fw_tag)
|
||||
fw = FileWriter(fw_tag, self.max_queue, self.flush_secs,
|
||||
self.filename_suffix)
|
||||
self.all_writers[fw_tag] = fw
|
||||
if self._check_caffe2_blob(scalar_value):
|
||||
scalar_value = workspace.FetchBlob(scalar_value)
|
||||
@ -293,10 +341,26 @@ class SummaryWriter(object):
|
||||
tag (string): Data identifier
|
||||
values (torch.Tensor, numpy.array, or string/blobname): Values to build histogram
|
||||
global_step (int): Global step value to record
|
||||
bins (string): one of {'tensorflow','auto', 'fd', ...}, this determines how the bins are made. You can find
|
||||
bins (string): One of {'tensorflow','auto', 'fd', ...}. This determines how the bins are made. You can find
|
||||
other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
|
||||
walltime (float): Optional override default walltime (time.time())
|
||||
seconds after epoch of event
|
||||
|
||||
Examples::
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import numpy as np
|
||||
writer = SummaryWriter()
|
||||
for i in range(10):
|
||||
x = np.random.random(1000)
|
||||
writer.add_histogram('distribution centers', x + i, i)
|
||||
writer.close()
|
||||
|
||||
Expected result:
|
||||
|
||||
.. image:: _static/img/tensorboard/add_histogram.png
|
||||
:scale: 50 %
|
||||
|
||||
"""
|
||||
if self._check_caffe2_blob(values):
|
||||
values = workspace.FetchBlob(values)
|
||||
@ -352,6 +416,31 @@ class SummaryWriter(object):
|
||||
convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
|
||||
Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
|
||||
corresponding ``dataformats`` argument is passed. e.g. CHW, HWC, HW.
|
||||
|
||||
Examples::
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import numpy as np
|
||||
img = np.zeros((3, 100, 100))
|
||||
img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
|
||||
img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000
|
||||
|
||||
img_HWC = np.zeros((100, 100, 3))
|
||||
img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
|
||||
img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000
|
||||
|
||||
writer = SummaryWriter()
|
||||
writer.add_image('my_image', img, 0)
|
||||
|
||||
# If you have non-default dimension setting, set the dataformats argument.
|
||||
writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')
|
||||
writer.close()
|
||||
|
||||
Expected result:
|
||||
|
||||
.. image:: _static/img/tensorboard/add_image.png
|
||||
:scale: 50 %
|
||||
|
||||
"""
|
||||
if self._check_caffe2_blob(img_tensor):
|
||||
img_tensor = workspace.FetchBlob(img_tensor)
|
||||
@ -372,6 +461,26 @@ class SummaryWriter(object):
|
||||
Shape:
|
||||
img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shapes will be
|
||||
accepted. e.g. NCHW or NHWC.
|
||||
|
||||
Examples::
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import numpy as np
|
||||
|
||||
img_batch = np.zeros((16, 3, 100, 100))
|
||||
for i in range(16):
|
||||
img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i
|
||||
img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i
|
||||
|
||||
writer = SummaryWriter()
|
||||
writer.add_images('my_image_batch', img_batch, 0)
|
||||
writer.close()
|
||||
|
||||
Expected result:
|
||||
|
||||
.. image:: _static/img/tensorboard/add_images.png
|
||||
:scale: 30 %
|
||||
|
||||
"""
|
||||
if self._check_caffe2_blob(img_tensor):
|
||||
img_tensor = workspace.FetchBlob(img_tensor)
|
||||
@ -410,7 +519,7 @@ class SummaryWriter(object):
|
||||
|
||||
Args:
|
||||
tag (string): Data identifier
|
||||
figure (matplotlib.pyplot.figure) or list of figures: figure or a list of figures
|
||||
figure (matplotlib.pyplot.figure) or list of figures: Figure or a list of figures
|
||||
global_step (int): Global step value to record
|
||||
close (bool): Flag to automatically close the figure
|
||||
walltime (float): Optional override default walltime (time.time())
|
||||
@ -483,14 +592,14 @@ class SummaryWriter(object):
|
||||
"""Add graph data to summary.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model to draw.
|
||||
input_to_model (torch.Tensor or list of torch.Tensor): a variable or a tuple of
|
||||
model (torch.nn.Module): Model to draw.
|
||||
input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of
|
||||
variables to be fed.
|
||||
verbose (bool): Whether to print graph structure in console.
|
||||
omit_useless_nodes (bool): Defaults to ``True``, which eliminates unused nodes.
|
||||
operator_export_type (string): One of: ``"ONNX"``, ``"RAW"``. This determines
|
||||
the optimization level of the graph. If an error happens while exporting
|
||||
the graph, use ``"RAW"`` may help.
|
||||
the graph, using ``"RAW"`` might help.
|
||||
|
||||
"""
|
||||
if hasattr(model, 'forward'):
|
||||
@ -601,17 +710,34 @@ class SummaryWriter(object):
|
||||
def add_pr_curve(self, tag, labels, predictions, global_step=None,
|
||||
num_thresholds=127, weights=None, walltime=None):
|
||||
"""Adds precision recall curve.
|
||||
Plotting a precision-recall curve lets you understand your model's
|
||||
performance under different threshold settings. With this function,
|
||||
you provide the ground truth labeling (T/F) and prediction confidence
|
||||
(usually the output of your model) for each target. The TensorBoard UI
|
||||
will let you choose the threshold interactively.
|
||||
|
||||
Args:
|
||||
tag (string): Data identifier
|
||||
labels (torch.Tensor, numpy.array, or string/blobname): Ground truth data. Binary label for each element.
|
||||
labels (torch.Tensor, numpy.array, or string/blobname):
|
||||
Ground truth data. Binary label for each element.
|
||||
predictions (torch.Tensor, numpy.array, or string/blobname):
|
||||
The probability that an element be classified as true. Value should in [0, 1]
|
||||
The probability that an element is classified as true.
|
||||
Value should be in [0, 1]
|
||||
global_step (int): Global step value to record
|
||||
num_thresholds (int): Number of thresholds used to draw the curve.
|
||||
walltime (float): Optional override default walltime (time.time())
|
||||
seconds after epoch of event
|
||||
|
||||
Examples::
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import numpy as np
|
||||
labels = np.random.randint(2, size=100) # binary label
|
||||
predictions = np.random.rand(100)
|
||||
writer = SummaryWriter()
|
||||
writer.add_pr_curve('pr_curve', labels, predictions, 0)
|
||||
writer.close()
|
||||
|
||||
"""
|
||||
labels, predictions = make_np(labels), make_np(predictions)
|
||||
self._get_file_writer().add_summary(
|
||||
|
Reference in New Issue
Block a user