mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[PyTorch TB] Write raw tensor as tensor_proto (#104908)"
This reverts commit dceae41c29782399c84304812696a8382e9b4292. Reverted https://github.com/pytorch/pytorch/pull/104908 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/104908#issuecomment-1634532376))
This commit is contained in:
@ -77,7 +77,6 @@ class BaseTestCase(TestCase):
|
||||
|
||||
if TEST_TENSORBOARD:
|
||||
from tensorboard.compat.proto.graph_pb2 import GraphDef
|
||||
from tensorboard.compat.proto.types_pb2 import DataType
|
||||
from torch.utils.tensorboard import summary, SummaryWriter
|
||||
from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
|
||||
from torch.utils.tensorboard._convert_np import make_np
|
||||
@ -558,34 +557,6 @@ def write_proto(str_to_compare, function_ptr):
|
||||
with open(expected_file, 'w') as f:
|
||||
f.write(str(str_to_compare))
|
||||
|
||||
class TestTensorProtoSummary(TestCase):
|
||||
def test_float_tensor_proto(self):
|
||||
float_values = [1., 2., 3.]
|
||||
actual_proto = summary.tensor_proto('dummy', torch.tensor(float_values)).value[0].tensor
|
||||
self.assertEqual(actual_proto.float_val, float_values)
|
||||
self.assertTrue(actual_proto.dtype == DataType.DT_FLOAT)
|
||||
|
||||
def test_int_tensor_proto(self):
|
||||
int_values = [1, 2, 3]
|
||||
actual_proto = summary.tensor_proto('dummy', torch.tensor(int_values, dtype=torch.int32)).value[0].tensor
|
||||
self.assertEqual(actual_proto.int_val, int_values)
|
||||
self.assertTrue(actual_proto.dtype == DataType.DT_INT32)
|
||||
|
||||
def test_scalar_tensor_proto(self):
|
||||
scalar_value = 0.1
|
||||
actual_proto = summary.tensor_proto('dummy', torch.tensor(scalar_value)).value[0].tensor
|
||||
self.assertAlmostEqual(actual_proto.float_val[0], scalar_value)
|
||||
|
||||
def test_complex_tensor_proto(self):
|
||||
real = torch.tensor([1., 2.])
|
||||
imag = torch.tensor([3., 4.])
|
||||
actual_proto = summary.tensor_proto('dummy', torch.complex(real, imag)).value[0].tensor
|
||||
self.assertEqual(actual_proto.scomplex_val, [1., 3., 2., 4.])
|
||||
|
||||
def test_empty_tensor_proto(self):
|
||||
actual_proto = summary.tensor_proto('dummy', torch.empty(0)).value[0].tensor
|
||||
self.assertEqual(actual_proto.float_val, [])
|
||||
|
||||
class TestTensorBoardPytorchGraph(BaseTestCase):
|
||||
def test_pytorch_graph(self):
|
||||
dummy_input = (torch.zeros(1, 3),)
|
||||
|
Reference in New Issue
Block a user