Fix type annotations for a number of torch.utils submodules (#42711)

Summary:
Related issue on `torch.utils` type annotation hiccups: gh-41794

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42711

Reviewed By: mrshenli

Differential Revision: D23005434

Pulled By: malfet

fbshipit-source-id: 151554b1e7582743f032476aeccdfdad7a252095
This commit is contained in:
Ralf Gommers
2020-08-14 18:10:48 -07:00
committed by Facebook GitHub Bot
parent bcf54f9438
commit c84f78470b
9 changed files with 31 additions and 26 deletions

View File

@ -344,6 +344,7 @@ class SummaryWriter(object):
"""
torch._C._log_api_usage_once("tensorboard.logging.add_scalar")
if self._check_caffe2_blob(scalar_value):
from caffe2.python import workspace
scalar_value = workspace.FetchBlob(scalar_value)
self._get_file_writer().add_summary(
scalar(tag, scalar_value), global_step, walltime)
@ -382,6 +383,7 @@ class SummaryWriter(object):
fw_logdir = self._get_file_writer().get_logdir()
for tag, scalar_value in tag_scalar_dict.items():
fw_tag = fw_logdir + "/" + main_tag.replace("/", "_") + "_" + tag
assert self.all_writers is not None
if fw_tag in self.all_writers.keys():
fw = self.all_writers[fw_tag]
else:
@ -389,6 +391,7 @@ class SummaryWriter(object):
self.filename_suffix)
self.all_writers[fw_tag] = fw
if self._check_caffe2_blob(scalar_value):
from caffe2.python import workspace
scalar_value = workspace.FetchBlob(scalar_value)
fw.add_summary(scalar(main_tag, scalar_value),
global_step, walltime)
@ -423,6 +426,7 @@ class SummaryWriter(object):
"""
torch._C._log_api_usage_once("tensorboard.logging.add_histogram")
if self._check_caffe2_blob(values):
from caffe2.python import workspace
values = workspace.FetchBlob(values)
if isinstance(bins, six.string_types) and bins == 'tensorflow':
bins = self.default_bins
@ -540,6 +544,7 @@ class SummaryWriter(object):
"""
torch._C._log_api_usage_once("tensorboard.logging.add_image")
if self._check_caffe2_blob(img_tensor):
from caffe2.python import workspace
img_tensor = workspace.FetchBlob(img_tensor)
self._get_file_writer().add_summary(
image(tag, img_tensor, dataformats=dataformats), global_step, walltime)
@ -583,6 +588,7 @@ class SummaryWriter(object):
"""
torch._C._log_api_usage_once("tensorboard.logging.add_images")
if self._check_caffe2_blob(img_tensor):
from caffe2.python import workspace
img_tensor = workspace.FetchBlob(img_tensor)
self._get_file_writer().add_summary(
image(tag, img_tensor, dataformats=dataformats), global_step, walltime)
@ -612,8 +618,10 @@ class SummaryWriter(object):
"""
torch._C._log_api_usage_once("tensorboard.logging.add_image_with_boxes")
if self._check_caffe2_blob(img_tensor):
from caffe2.python import workspace
img_tensor = workspace.FetchBlob(img_tensor)
if self._check_caffe2_blob(box_tensor):
from caffe2.python import workspace
box_tensor = workspace.FetchBlob(box_tensor)
if labels is not None:
if isinstance(labels, str):
@ -676,6 +684,7 @@ class SummaryWriter(object):
"""
torch._C._log_api_usage_once("tensorboard.logging.add_audio")
if self._check_caffe2_blob(snd_tensor):
from caffe2.python import workspace
snd_tensor = workspace.FetchBlob(snd_tensor)
self._get_file_writer().add_summary(
audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime)