mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PyTorch-TB] Write full tensor as tensor proto (#105186)
Write full tensor as tensor proto Pull Request resolved: https://github.com/pytorch/pytorch/pull/105186 Approved by: https://github.com/atalman
This commit is contained in:
committed by
PyTorch MergeBot
parent
233f917c83
commit
d855c6c7de
@ -79,6 +79,8 @@ if TEST_TENSORBOARD:
|
||||
from tensorboard.compat.proto.graph_pb2 import GraphDef
|
||||
from torch.utils.tensorboard import summary, SummaryWriter
|
||||
from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
|
||||
from tensorboard.compat.proto.types_pb2 import DataType
|
||||
from torch.utils.tensorboard.summary import tensor_proto
|
||||
from torch.utils.tensorboard._convert_np import make_np
|
||||
from torch.utils.tensorboard._pytorch_graph import graph
|
||||
from google.protobuf import text_format
|
||||
@ -862,5 +864,43 @@ class TestTensorBoardNumpy(BaseTestCase):
|
||||
)
|
||||
compare_proto(graph, self)
|
||||
|
||||
class TestTensorProtoSummary(BaseTestCase):
|
||||
def test_float_tensor_proto(self):
|
||||
float_values = [1.0, 2.0, 3.0]
|
||||
actual_proto = (
|
||||
tensor_proto("dummy", torch.tensor(float_values)).value[0].tensor
|
||||
)
|
||||
self.assertEqual(actual_proto.float_val, float_values)
|
||||
self.assertTrue(actual_proto.dtype == DataType.DT_FLOAT)
|
||||
|
||||
def test_int_tensor_proto(self):
|
||||
int_values = [1, 2, 3]
|
||||
actual_proto = (
|
||||
tensor_proto("dummy", torch.tensor(int_values, dtype=torch.int32))
|
||||
.value[0]
|
||||
.tensor
|
||||
)
|
||||
self.assertEqual(actual_proto.int_val, int_values)
|
||||
self.assertTrue(actual_proto.dtype == DataType.DT_INT32)
|
||||
|
||||
def test_scalar_tensor_proto(self):
|
||||
scalar_value = 0.1
|
||||
actual_proto = (
|
||||
tensor_proto("dummy", torch.tensor(scalar_value)).value[0].tensor
|
||||
)
|
||||
self.assertAlmostEqual(actual_proto.float_val[0], scalar_value)
|
||||
|
||||
def test_complex_tensor_proto(self):
|
||||
real = torch.tensor([1.0, 2.0])
|
||||
imag = torch.tensor([3.0, 4.0])
|
||||
actual_proto = (
|
||||
tensor_proto("dummy", torch.complex(real, imag)).value[0].tensor
|
||||
)
|
||||
self.assertEqual(actual_proto.scomplex_val, [1.0, 3.0, 2.0, 4.0])
|
||||
|
||||
def test_empty_tensor_proto(self):
|
||||
actual_proto = tensor_proto("dummy", torch.empty(0)).value[0].tensor
|
||||
self.assertEqual(actual_proto.float_val, [])
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user