Disallow changing the device of a tensor via set_. (#18832)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18832
ghimport-source-id: fde4ad90541ba52dfa02bdd83466f17e6541e535

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18833 [STACK] Cache device on TensorImpl; clean up TensorImpl constructors.
* **#18832 [STACK] Disallow changing the device of a tensor via set_.**
* #18831 [STACK] Stop swapping in Storages of the wrong device for Tensors.

This is necessary to cache the device on a TensorImpl.

Differential Revision: D14766231

fbshipit-source-id: bba61634b2d6252ac0697b96033c9eea680956e8
This commit is contained in:
Gregory Chanan
2019-04-04 11:12:13 -07:00
committed by Facebook Github Bot
parent 15b318de84
commit 8732a1b42e
3 changed files with 50 additions and 3 deletions

View File

@ -56,9 +56,10 @@ void THTensor_setStorageNd(THTensor *self, THStorage *storage, ptrdiff_t storage
}
/* storageOffset */
if(storageOffset < 0)
if(storageOffset < 0) {
THError("Tensor: invalid storage offset");
self->set_storage_offset(storageOffset);
}
self->set_storage_offset(storageOffset);
/* size and stride */
THTensor_resizeNd(self, nDimension, size, stride);
@ -160,5 +161,15 @@ void THTensor_stealAndSetStoragePtr(THTensor* tensor, THStorage* storage) {
// Caffe2 might have tensors whose storages are null, but we
// don't allow it in PyTorch.
AT_ASSERT(storage);
// Caffe2 also has uninitialized dtype states, which we disallow here
AT_ASSERT(tensor->storage().dtype() == storage->dtype());
// We used to allow this, but this breaks device caching,
// see Note [We regret making Variable hold a Tensor]
// Let's put an actual error message for this one.
AT_CHECK(tensor->storage().device() == storage->device(),
"Attempted to set the storage of a tensor on device ", tensor->storage().device(),
" to a storage on different device ", storage->device(),
". This is no longer allowed; the devices must match.");
tensor->set_storage(at::Storage(c10::intrusive_ptr<THStorage>::reclaim(storage)));
}

View File

@ -8352,6 +8352,42 @@ class _TestTorchMixin(object):
t1.set_(t2)
self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
def test_tensor_set_errors(self):
f_cpu = torch.randn((2, 3), dtype=torch.float32)
d_cpu = torch.randn((2, 3), dtype=torch.float64)
# change dtype
self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage()))
self.assertRaises(RuntimeError,
lambda: f_cpu.set_(d_cpu.storage(), 0, d_cpu.size(), d_cpu.stride()))
self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu))
# change device
if torch.cuda.is_available():
f_cuda = torch.randn((2, 3), dtype=torch.float32, device='cuda')
# cpu -> cuda
self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda.storage()))
self.assertRaises(RuntimeError,
lambda: f_cpu.set_(f_cuda.storage(), 0, f_cuda.size(), f_cuda.stride()))
self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda))
# cuda -> cpu
self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu.storage()))
self.assertRaises(RuntimeError,
lambda: f_cuda.set_(f_cpu.storage(), 0, f_cpu.size(), f_cpu.stride()))
self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu))
@unittest.skipIf(torch.cuda.device_count() < 2, 'less than 2 GPUs detected')
def test_tensor_set_errors_multigpu(self):
    """set_ across two distinct CUDA devices must be rejected."""
    on_gpu0 = torch.randn((2, 3), dtype=torch.float32, device='cuda:0')
    on_gpu1 = torch.randn((2, 3), dtype=torch.float32, device='cuda:1')

    # Each set_ overload should refuse a storage/tensor living on another GPU.
    self.assertRaises(RuntimeError, lambda: on_gpu0.set_(on_gpu1.storage()))
    self.assertRaises(
        RuntimeError,
        lambda: on_gpu0.set_(on_gpu1.storage(), 0, on_gpu1.size(), on_gpu1.stride()))
    self.assertRaises(RuntimeError, lambda: on_gpu0.set_(on_gpu1))
def test_equal(self):
# Contiguous, 1D
t1 = torch.Tensor((3, 4, 9, 10))

View File

@ -229,7 +229,7 @@ at::Tensor ScriptModuleDeserializer::loadTensor(
.set_(storage_it->second, tensor_proto.offset(), dims, strides);
} else if (device.type() == at::DeviceType::CUDA) {
result =
at::empty({0}, at::CUDA(type).options())
at::empty({0}, c10::TensorOptions(type).device(storage_it->second.device()))
.set_(storage_it->second, tensor_proto.offset(), dims, strides);
}
AT_ASSERT(result.defined());