[PyTorch-TB] Write full tensor as tensor proto (#105186)

Write full tensor as tensor proto
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105186
Approved by: https://github.com/atalman
This commit is contained in:
Seth Sangsu Lee
2023-07-14 18:04:09 +00:00
committed by PyTorch MergeBot
parent 233f917c83
commit d855c6c7de
3 changed files with 162 additions and 0 deletions

View File

@ -79,6 +79,8 @@ if TEST_TENSORBOARD:
from tensorboard.compat.proto.graph_pb2 import GraphDef
from torch.utils.tensorboard import summary, SummaryWriter
from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
from tensorboard.compat.proto.types_pb2 import DataType
from torch.utils.tensorboard.summary import tensor_proto
from torch.utils.tensorboard._convert_np import make_np
from torch.utils.tensorboard._pytorch_graph import graph
from google.protobuf import text_format
@ -862,5 +864,43 @@ class TestTensorBoardNumpy(BaseTestCase):
)
compare_proto(graph, self)
class TestTensorProtoSummary(BaseTestCase):
def test_float_tensor_proto(self):
float_values = [1.0, 2.0, 3.0]
actual_proto = (
tensor_proto("dummy", torch.tensor(float_values)).value[0].tensor
)
self.assertEqual(actual_proto.float_val, float_values)
self.assertTrue(actual_proto.dtype == DataType.DT_FLOAT)
def test_int_tensor_proto(self):
int_values = [1, 2, 3]
actual_proto = (
tensor_proto("dummy", torch.tensor(int_values, dtype=torch.int32))
.value[0]
.tensor
)
self.assertEqual(actual_proto.int_val, int_values)
self.assertTrue(actual_proto.dtype == DataType.DT_INT32)
def test_scalar_tensor_proto(self):
scalar_value = 0.1
actual_proto = (
tensor_proto("dummy", torch.tensor(scalar_value)).value[0].tensor
)
self.assertAlmostEqual(actual_proto.float_val[0], scalar_value)
def test_complex_tensor_proto(self):
real = torch.tensor([1.0, 2.0])
imag = torch.tensor([3.0, 4.0])
actual_proto = (
tensor_proto("dummy", torch.complex(real, imag)).value[0].tensor
)
self.assertEqual(actual_proto.scomplex_val, [1.0, 3.0, 2.0, 4.0])
def test_empty_tensor_proto(self):
actual_proto = tensor_proto("dummy", torch.empty(0)).value[0].tensor
self.assertEqual(actual_proto.float_val, [])
if __name__ == '__main__':
run_tests()