Remove unnecessary printing from tests

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19606

Differential Revision: D15046583

Pulled By: ezyang

fbshipit-source-id: ea9bb691d23855e7eddbabe68bf112a726641ba4
This commit is contained in:
Tongzhou Wang
2019-04-23 09:16:05 -07:00
committed by Facebook Github Bot
parent 36084908e4
commit 3b4d4ef503
2 changed files with 18 additions and 9 deletions

View File

@ -2691,8 +2691,6 @@ class _TestTorchMixin(object):
zero_point = 2
qr = r.quantize_linear(scale, zero_point)
rqr = qr.dequantize()
print(r.numpy())
print(rqr.numpy())
self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
def test_qtensor_creation(self):
@ -2701,12 +2699,7 @@ class _TestTorchMixin(object):
val = 100
numel = 10
q = torch._empty_affine_quantized(numel, dtype=torch.qint8, scale=scale, zero_point=zero_point)
# for i in range(numel):
# # wait for th_fill
# q[i] = val
# r = q.dequantize()
# for i in range(numel):
# self.assertEqual(r[i], (val - zero_point) * scale)
# TODO: check dequantized values?
@unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected')
def test_device_guard(self):

View File

@ -6,6 +6,7 @@ import shutil
import random
import tempfile
import unittest
import contextlib
import torch
import torch.nn as nn
import torch.utils.data
@ -492,11 +493,26 @@ class TestONNXUtils(TestCase):
try_check_onnx_broadcast(dims1, dims2, True, False)
# Errors will still be raised and reported
@contextlib.contextmanager
def suppress_stderr():
original = sys.stderr
sys.stderr = open(os.devnull, 'w')
yield
sys.stderr = original
class TestHub(TestCase):
@classmethod
@skipIfNoTorchVision
def setUpClass(cls):
cls.resnet18_pretrained = models.__dict__['resnet18'](pretrained=True).state_dict()
# The current torchvision code does not provide a way to disable tqdm
# progress bar, leading this download printing a huge number of lines
# in CI.
# TODO: remove this context manager when torchvision provides a way.
# See pytorch/torchvision#862
with suppress_stderr():
cls.resnet18_pretrained = models.__dict__['resnet18'](pretrained=True).state_dict()
@skipIfNoTorchVision
def test_load_from_github(self):