[TB] Add support for hparam domain_discrete (#40720)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40720

Add support for populating domain_discrete field in TensorBoard add_hparams API

Test Plan: Unit test test_hparams_domain_discrete

Reviewed By: edward-io

Differential Revision: D22291347

fbshipit-source-id: 78db9f62661c9fe36cd08d563db0e7021c01428d
This commit is contained in:
Siqi Yan
2020-06-29 19:31:59 -07:00
committed by Facebook GitHub Bot
parent 53af9df557
commit 01e2099bb8
3 changed files with 95 additions and 6 deletions

View File

@ -465,6 +465,22 @@ class TestTensorBoardSummary(BaseTestCase):
mt = {'accuracy': 0.1}
self.assertTrue(compare_proto(summary.hparams(hp, mt), self))
def test_hparams_domain_discrete(self):
hp = {"lr": 0.1, "bool_var": True, "string_var": "hi"}
mt = {"accuracy": 0.1}
hp_domain = {"lr": [0.1], "bool_var": [True], "string_var": ["hi"]}
# hparam_domain_discrete keys needs to be subset of hparam_dict keys
with self.assertRaises(TypeError):
summary.hparams(hp, mt, hparam_domain_discrete={"wrong_key": []})
# hparam_domain_discrete values needs to be same type as hparam_dict values
with self.assertRaises(TypeError):
summary.hparams(hp, mt, hparam_domain_discrete={"lr": [True]})
# only smoke test. Because protobuf map serialization is nondeterministic.
summary.hparams(hp, mt, hparam_domain_discrete=hp_domain)
def test_mesh(self):
v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float)
c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int)

View File

@ -10,6 +10,7 @@ import os
# pylint: disable=unused-import
from six.moves import range
from google.protobuf import struct_pb2
from tensorboard.compat.proto.summary_pb2 import Summary
from tensorboard.compat.proto.summary_pb2 import HistogramProto
from tensorboard.compat.proto.summary_pb2 import SummaryMetadata
@ -50,7 +51,7 @@ def _draw_single_box(image, xmin, ymin, xmax, ymax, display_str, color='black',
return image
def hparams(hparam_dict=None, metric_dict=None):
def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
"""Outputs three `Summary` protocol buffers needed by hparams plugin.
`Experiment` keeps the metadata of an experiment, such as the name of the
hyperparameters and the name of the metrics.
@ -62,6 +63,8 @@ def hparams(hparam_dict=None, metric_dict=None):
and their values.
metric_dict: A dictionary that contains names of the metrics
and their values.
hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
contains names of the hyperparameters and all discrete values they can hold
Returns:
The `Summary` protobufs for Experiment, SessionStartInfo and
@ -99,6 +102,21 @@ def hparams(hparam_dict=None, metric_dict=None):
logging.warning('parameter: metric_dict should be a dictionary, nothing logged.')
raise TypeError('parameter: metric_dict should be a dictionary, nothing logged.')
hparam_domain_discrete = hparam_domain_discrete or {}
if not isinstance(hparam_domain_discrete, dict):
raise TypeError(
"parameter: hparam_domain_discrete should be a dictionary, nothing logged."
)
for k, v in hparam_domain_discrete.items():
if (
k not in hparam_dict
or not isinstance(v, list)
or not all(isinstance(d, type(hparam_dict[k])) for d in v)
):
raise TypeError(
"parameter: hparam_domain_discrete[{}] should be a list of same type as "
"hparam_dict[{}].".format(k, k)
)
hps = []
@ -108,17 +126,68 @@ def hparams(hparam_dict=None, metric_dict=None):
continue
if isinstance(v, int) or isinstance(v, float):
ssi.hparams[k].number_value = v
hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
if k in hparam_domain_discrete:
domain_discrete = struct_pb2.ListValue(
values=[
struct_pb2.Value(number_value=d)
for d in hparam_domain_discrete[k]
]
)
else:
domain_discrete = None
hps.append(
HParamInfo(
name=k,
type=DataType.Value("DATA_TYPE_FLOAT64"),
domain_discrete=domain_discrete,
)
)
continue
if isinstance(v, string_types):
ssi.hparams[k].string_value = v
hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_STRING")))
if k in hparam_domain_discrete:
domain_discrete = struct_pb2.ListValue(
values=[
struct_pb2.Value(string_value=d)
for d in hparam_domain_discrete[k]
]
)
else:
domain_discrete = None
hps.append(
HParamInfo(
name=k,
type=DataType.Value("DATA_TYPE_STRING"),
domain_discrete=domain_discrete,
)
)
continue
if isinstance(v, bool):
ssi.hparams[k].bool_value = v
hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_BOOL")))
if k in hparam_domain_discrete:
domain_discrete = struct_pb2.ListValue(
values=[
struct_pb2.Value(bool_value=d)
for d in hparam_domain_discrete[k]
]
)
else:
domain_discrete = None
hps.append(
HParamInfo(
name=k,
type=DataType.Value("DATA_TYPE_BOOL"),
domain_discrete=domain_discrete,
)
)
continue
if isinstance(v, torch.Tensor):

View File

@ -268,7 +268,9 @@ class SummaryWriter(object):
"""Returns the directory where event files will be written."""
return self.log_dir
def add_hparams(self, hparam_dict, metric_dict, run_name=None):
def add_hparams(
self, hparam_dict, metric_dict, hparam_domain_discrete=None, run_name=None
):
"""Add a set of hyperparameters to be compared in TensorBoard.
Args:
@ -281,6 +283,8 @@ class SummaryWriter(object):
here should be unique in the tensorboard record. Otherwise the value
you added by ``add_scalar`` will be displayed in hparam plugin. In most
cases, this is unwanted.
hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
contains names of the hyperparameters and all discrete values they can hold
run_name (str): Name of the run, to be included as part of the logdir.
If unspecified, will use current timestamp.
@ -301,7 +305,7 @@ class SummaryWriter(object):
torch._C._log_api_usage_once("tensorboard.logging.add_hparams")
if type(hparam_dict) is not dict or type(metric_dict) is not dict:
raise TypeError('hparam_dict and metric_dict should be dictionary.')
exp, ssi, sei = hparams(hparam_dict, metric_dict)
exp, ssi, sei = hparams(hparam_dict, metric_dict, hparam_domain_discrete)
if not run_name:
run_name = str(time.time())