Files
pytorch/test/test_mkldnn.py
Will Feng 8cde4c4d22 Remove Variable::Impl and DifferentiableViewImpl (#17072)
Summary:
As part of the Variable/Tensor merge work: https://github.com/pytorch/pytorch/issues/13638, we make the following changes in this PR:
1. Remove the `Variable::Impl` class and the `DifferentiableViewImpl` class
2. Change all `Variable.data()` call sites to either use `Variable` directly, or use `Variable.tensor_data()`
3. Remove `Variable.data()` API
4. Add `Variable.variable_data()` that matches `tensor.data` in Python API, which creates a new `Variable` that shares the same storage and tensor metadata with the original `Variable`, but with a completely new autograd history.

After this PR, Variable no longer wraps a Tensor internally, and both Variable and Tensor use the same TensorImpl class as their `impl_`. The only difference is that a Variable always has AutogradMeta in its TensorImpl, whereas a Tensor doesn't.

**Note that this PR is BC-breaking in the following use cases:**

**Use Case 1:**
Previously, `x.data = y` worked even if `x` and `y` were of different TensorImpl types (e.g. `x` is a CPU dense tensor whose impl is of type TensorImpl, while `y` is a CPU sparse tensor whose impl is of type SparseTensorImpl). However, after this PR, `x.data = y` no longer works if `x` and `y` are of different TensorImpl types, because the underlying implementation `variable.set_data(tensor)` no longer works if `variable` and `tensor` have different TensorImpl types.

**Use Case 2:**
If a tensor `x`'s `grad` is sparse, accumulating dense gradients to `x` will change the tensor that `x.grad` is pointing to. This is better illustrated with the following example:
```python
params = torch.tensor([1.5, 1.5]).requires_grad_()
with torch.no_grad():
    # Change gradient to a sparse tensor
    params.grad = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]))

grad_saved = params.grad
params.backward(torch.tensor([1.5, 1.5]))
assert id(grad_saved) == id(params.grad)  # This will fail after this PR
```
The assertion in the last line will fail after this PR, because adding dense gradients to sparse gradients will change the `params.grad` tensor reference.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17072

Differential Revision: D14075257

Pulled By: yf225

fbshipit-source-id: 0e681df641270dea586042dd26db59f2e76b5957
2019-05-23 21:09:04 -07:00

288 lines
11 KiB
Python

from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import unittest
import torch
import torch.jit
from torch.utils import mkldnn as mkldnn_utils
from common_utils import TestCase, run_tests, TemporaryFileName
from torch.autograd.gradcheck import gradgradcheck, gradcheck
# Comment the line below to find out the CI machines having MKL-DNN build disabled
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
class TestMkldnn(TestCase):
    """Compare MKL-DNN tensor and module results against dense CPU results."""

    def test_conversion(self):
        """Round-trip dense -> MKL-DNN -> dense and check tensor metadata."""
        contiguous = torch.randn((1, 2, 3, 4),
                                 dtype=torch.float, device=torch.device('cpu'))
        # Slicing the last dimension of a 5-d tensor yields a 4-d input that
        # is a view into a larger buffer.
        sliced = torch.randn((1, 2, 3, 4, 5),
                             dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]
        for cpu_tensor in [contiguous, sliced]:
            cpu_tensor.requires_grad_()
            mkldnn_tensor = cpu_tensor.to_mkldnn()
            cpu_tensor_1 = mkldnn_tensor.to_dense()
            self.assertEqual(cpu_tensor, cpu_tensor_1)
            self.assertEqual(mkldnn_tensor.dtype, torch.float)
            self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
            self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
            self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
            self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
            with self.assertRaisesRegex(
                    RuntimeError,
                    "Cannot access data pointer of Tensor that doesn't have storage"):
                mkldnn_tensor.data_ptr()

    def test_unsupported(self):
        """Check that to_mkldnn() and MKL-DNN factory calls raise RuntimeError."""
        shape = (1, 2, 3, 4)
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        # unsupported types and unsupported types with gpu
        for dtype in [torch.double, torch.half, torch.uint8, torch.int8,
                      torch.short, torch.int, torch.long]:
            for device in devices:
                # Build the input outside the assertRaises block with
                # torch.empty: torch.randn can itself raise for some of these
                # dtypes, which would let the test pass without to_mkldnn()
                # ever being called.
                dense = torch.empty(shape, dtype=dtype, device=device)
                with self.assertRaises(RuntimeError):
                    dense.to_mkldnn()

        # supported type with gpu
        if torch.cuda.is_available():
            cuda_float = torch.randn(shape, dtype=torch.float, device=torch.device('cuda'))
            with self.assertRaises(RuntimeError):
                cuda_float.to_mkldnn()

        # some factory functions
        for creator in [torch.empty, torch.ones, torch.zeros, torch.randn, torch.rand]:
            with self.assertRaises(RuntimeError):
                creator(1, 2, 3, 4, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn)

    def test_autograd_to_mkldnn(self):
        """Run gradcheck and gradgradcheck through to_mkldnn().to_dense()."""
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)

        def func(root):
            return root.to_mkldnn().to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')
        self.assertWarnsRegex(lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')

    def test_autograd_from_mkldnn(self):
        """Run gradcheck through to_dense() starting from an MKL-DNN tensor."""
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        def func(root):
            return root.to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')

    def test_detach(self):
        """Check size and requires_grad after detach() and in-place detach_()."""
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        # Out-of-place detach leaves the source tensor's requires_grad set.
        detach = root.detach()
        self.assertEqual((4, 5), detach.size())
        self.assertFalse(detach.requires_grad)
        self.assertTrue(root.requires_grad)

        # In-place detach_ clears requires_grad on the source tensor as well.
        detach_ = root.detach_()
        self.assertEqual((4, 5), detach_.size())
        self.assertFalse(detach_.requires_grad)
        self.assertFalse(root.requires_grad)

    def test_repr(self):
        """Check that str() of an MKL-DNN tensor mentions its layout."""
        mkldnn_tensor = torch.randn((1, 2, 3, 4),
                                    dtype=torch.float, device=torch.device('cpu')).to_mkldnn()
        self.assertIn("layout=torch._mkldnn", str(mkldnn_tensor))

    def test_conv2d(self):
        """Compare Conv2d against its MKL-DNN conversion, with and without bias."""
        for groups in [1, 4]:
            N = torch.randint(3, 10, (1,)).item()
            # Channel counts are multiples of `groups`.
            C = torch.randint(1, 3, (1,)).item() * groups
            M = torch.randint(1, 3, (1,)).item() * groups
            x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
            for bias in [True, False]:
                conv2d = torch.nn.Conv2d(in_channels=C,
                                         out_channels=M,
                                         kernel_size=3,
                                         stride=2,
                                         padding=1,
                                         bias=bias,
                                         groups=groups).float()
                mkldnn_conv2d = mkldnn_utils.to_mkldnn(copy.deepcopy(conv2d))
                self.assertEqual(
                    conv2d(x),
                    mkldnn_conv2d(x.to_mkldnn()).to_dense())

                self._test_serialization(mkldnn_conv2d, (x.to_mkldnn(),))
                self._test_tracing(mkldnn_conv2d, (x.to_mkldnn(),))

    def test_relu(self):
        """Compare torch.relu on dense and MKL-DNN inputs."""
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())

    def test_relu_(self):
        """Compare in-place torch.relu_ on dense and MKL-DNN inputs."""
        x1 = torch.randn((4, 5), dtype=torch.float32) * 10
        x2 = x1.clone().to_mkldnn()
        self.assertEqual(torch.relu_(x1), torch.relu_(x2).to_dense())

    def test_max_pool2d(self):
        """Compare MaxPool2d on dense and MKL-DNN inputs."""
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10

        max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3,
            stride=2,
            padding=1)

        self.assertEqual(
            max_pool2d(x),
            max_pool2d(x.to_mkldnn()).to_dense())

    def test_avg_pool2d(self):
        """Compare AvgPool2d on dense and MKL-DNN inputs for both pad-count modes."""
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10

        for count_include_pad in [True, False]:
            avg_pool2d = torch.nn.AvgPool2d(
                kernel_size=3,
                stride=2,
                padding=1,
                count_include_pad=count_include_pad)

            self.assertEqual(
                avg_pool2d(x),
                avg_pool2d(x.to_mkldnn()).to_dense())

    def test_adaptive_avg_pool2d(self):
        """Compare AdaptiveAvgPool2d(7) on dense and MKL-DNN inputs."""
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100

        adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)

        self.assertEqual(
            adaptive_avg_pool2d(x),
            adaptive_avg_pool2d(x.to_mkldnn()).to_dense())

    def test_batch_norm2d(self):
        """Compare BatchNorm2d in eval mode against its MKL-DNN conversion."""
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10

        # TODO: support training
        for train in [False]:
            bn = torch.nn.BatchNorm2d(C).float().train(train)
            mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
            self.assertEqual(
                bn(x),
                mkldnn_bn(x.to_mkldnn()).to_dense())

            self._test_serialization(mkldnn_bn, (x.to_mkldnn(),))
            self._test_tracing(mkldnn_bn, (x.to_mkldnn(),))

    def test_add(self):
        """Compare add, in-place add and add-with-out on dense and MKL-DNN inputs."""
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        alpha = torch.randn(1, dtype=torch.float32).item()

        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        mx = x.to_mkldnn()
        my = y.to_mkldnn()

        # add
        self.assertEqual(
            x + y,
            (mx + my).to_dense())

        self.assertEqual(
            torch.add(x, y, alpha=alpha),
            torch.add(mx, my, alpha=alpha).to_dense())

        # add_
        x += y
        mx += my
        self.assertEqual(x, mx.to_dense())

        # add_out
        out = x.clone()
        mkldnn_out = out.to_mkldnn()
        torch.add(x, y, alpha=alpha, out=out)
        torch.add(mx, my, alpha=alpha, out=mkldnn_out)
        self.assertEqual(out, mkldnn_out.to_dense())

    def test_view(self):
        """Check that view() on an MKL-DNN tensor raises and points to reshape."""
        x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
        with self.assertRaisesRegex(RuntimeError, "Change to use reshape"):
            x.view(x.size(0), -1)

    def test_reshape(self):
        """Compare reshape on dense and MKL-DNN inputs."""
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        size = (x.size(0), -1)

        self.assertEqual(
            x.reshape(size),
            x.to_mkldnn().reshape(size).to_dense(),
        )

    def test_clone(self):
        """Compare clone on dense and MKL-DNN inputs."""
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        self.assertEqual(
            x.clone(),
            x.to_mkldnn().clone().to_dense(),
        )

    def test_linear(self):
        """Compare Linear against its MKL-DNN conversion, with and without bias."""
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10

        for bias in [True, False]:
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
            self.assertEqual(
                linear(x),
                mkldnn_linear(x.to_mkldnn()).to_dense())

            self._test_serialization(mkldnn_linear, (x.to_mkldnn(),))
            self._test_tracing(mkldnn_linear, (x.to_mkldnn(),))

    def test_sigmoid(self):
        """Compare sigmoid and in-place sigmoid_ on dense and MKL-DNN inputs."""
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        mkldnn_x = x.to_mkldnn()
        self.assertEqual(
            torch.sigmoid(x),
            torch.sigmoid(mkldnn_x).to_dense(),
        )
        # inplace
        torch.sigmoid_(x)
        torch.sigmoid_(mkldnn_x)
        self.assertEqual(x, mkldnn_x.to_dense())

    def _test_serialization(self, module, inputs):
        """Save `module` with torch.jit, reload it, and compare outputs on `inputs`."""
        with TemporaryFileName() as fname:
            torch.jit.save(module, fname)
            loaded = torch.jit.load(fname)
            self.assertEqual(
                module(*inputs).to_dense(),
                loaded(*inputs).to_dense())

    def _test_tracing(self, module, inputs):
        """Trace `module` with torch.jit.trace and compare outputs on `inputs`."""
        traced = torch.jit.trace(module, inputs, check_trace=False)
        self.assertEqual(
            module(*inputs).to_dense(),
            traced(*inputs).to_dense())

    def test_set_data_tensorimpl_type(self):
        """Check that assigning an MKL-DNN tensor to a dense tensor's .data raises."""
        # Dense tensor has impl of type `TensorImpl`, while MKL-DNN tensor has impl
        # of type `OpaqueTensorImpl<IDeepTensorWrapperPtr>`.
        x = torch.randn((1, 2), dtype=torch.float, device=torch.device('cpu'))
        x_mkldnn = x.to_mkldnn()
        with self.assertRaisesRegex(RuntimeError, 'different types of TensorImpl'):
            x.data = x_mkldnn
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()