Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**

This was requested by someone at Facebook; this lint is turned on for Facebook by default. "Sure, why not."

I had to noqa a number of imports in __init__. Hypothetically we're supposed to use __all__ in this case, but I was too lazy to fix it. Left for future work.

Be careful! flake8-2 and flake8-3 behave differently with respect to import resolution for # type: comments. flake8-3 will report an import unused; flake8-2 will not. For now, I just noqa'd all these sites.

All the changes were done by hand.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D14687478

fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
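As a sketch of the flake8-2 / flake8-3 divergence described above (a hypothetical module, not code from this patch): `Tuple` below is consumed only inside a `# type:` comment, so flake8-3 reports the import as unused (F401) while flake8-2 does not; the `# noqa: F401` keeps both quiet.

```python
# Hypothetical example (not from this patch): an import that is only
# "used" inside a `# type:` comment. flake8-3 flags it as F401;
# the noqa silences the lint under both flake8-2 and flake8-3.
from typing import Tuple  # noqa: F401


def bounds(xs):
    # type: (list) -> Tuple[int, int]
    return min(xs), max(xs)
```

The `__all__` alternative mentioned above also works because pyflakes treats a name listed in `__all__` as used, so the re-export needs no noqa.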
import torch
import torch.jit
import numpy as np
import unittest

from caffe2.python import core
from common_utils import TestCase, run_tests


def canonical(graph):
    return str(torch._C._jit_pass_canonicalize(graph))


@unittest.skipIf("Relu_ENGINE_DNNLOWP" not in core._REGISTERED_OPERATORS,
                 "fbgemm-based Caffe2 ops are not linked")
class TestQuantized(TestCase):
    def test_relu(self):
        # A quantized tensor is represented as (uint8 data, scale, zero_point).
        a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5)
        r = torch.ops.c10.quantized_relu(a)
        # relu clamps values below the zero point (5) up to the zero point.
        np.testing.assert_equal(r[0].numpy(),
                                torch.tensor([5, 6, 5, 10], dtype=torch.uint8).numpy())
        np.testing.assert_almost_equal(0.01, r[1])
        self.assertEqual(5, r[2])

    def test_quantize(self):
        a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5)
        r = torch.ops.c10.dequantize(a)
        # dequantize computes (value - zero_point) * scale.
        np.testing.assert_almost_equal(r.numpy(), [-0.01, 0.01, -0.04, 0.05])
        # default args
        q_def = torch.ops.c10.quantize(r)
        # specified
        q = torch.ops.c10.quantize(r, scale=0.01, zero_point=5)
        # Round-tripping through dequantize/quantize recovers the original tuple.
        np.testing.assert_equal(q[0].numpy(), a[0].numpy())
        np.testing.assert_almost_equal(q[1], a[1])
        self.assertEqual(q[2], a[2])

    def test_script(self):
        @torch.jit.script
        def foo(x):
            # type: (Tuple[Tensor, float, int]) -> Tuple[Tensor, float, int]
            return torch.ops.c10.quantized_relu(x)
        self.assertExpectedInline(canonical(foo.graph), '''\
graph(%x : (Tensor, float, int)):
  %1 : (Tensor, float, int) = c10::quantized_relu(%x)
  return (%1)
''')


if __name__ == '__main__':
    run_tests()