mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Disallow changing the device of a tensor via set_. (#18832)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18832 ghimport-source-id: fde4ad90541ba52dfa02bdd83466f17e6541e535 Stack from [ghstack](https://github.com/ezyang/ghstack): * #18833 [STACK] Cache device on TensorImpl; clean up TensorImpl constructors. * **#18832 [STACK] Disallow changing the device of a tensor via set_.** * #18831 [STACK] Stop swapping in Storages of the wrong device for Tensors. This is necessary to cache the device on a TensorImpl. Differential Revision: D14766231 fbshipit-source-id: bba61634b2d6252ac0697b96033c9eea680956e8
This commit is contained in:
committed by
Facebook Github Bot
parent
15b318de84
commit
8732a1b42e
@ -56,9 +56,10 @@ void THTensor_setStorageNd(THTensor *self, THStorage *storage, ptrdiff_t storage
|
||||
}
|
||||
|
||||
/* storageOffset */
|
||||
if(storageOffset < 0)
|
||||
if(storageOffset < 0) {
|
||||
THError("Tensor: invalid storage offset");
|
||||
self->set_storage_offset(storageOffset);
|
||||
}
|
||||
self->set_storage_offset(storageOffset);
|
||||
|
||||
/* size and stride */
|
||||
THTensor_resizeNd(self, nDimension, size, stride);
|
||||
@ -160,5 +161,15 @@ void THTensor_stealAndSetStoragePtr(THTensor* tensor, THStorage* storage) {
|
||||
// Caffe2 might have tensors whose storages are null, but we
|
||||
// don't allow it in PyTorch.
|
||||
AT_ASSERT(storage);
|
||||
// Caffe2 also has uninitialized dtype states, which we disallow here
|
||||
AT_ASSERT(tensor->storage().dtype() == storage->dtype());
|
||||
|
||||
// We used to allow this, but this breaks device caching,
|
||||
// see Note [We regret making Variable hold a Tensor]
|
||||
// Let's put an actual error message for this one.
|
||||
AT_CHECK(tensor->storage().device() == storage->device(),
|
||||
"Attempted to set the storage of a tensor on device ", tensor->storage().device(),
|
||||
" to a storage on different device ", storage->device(),
|
||||
". This is no longer allowed; the devices must match.");
|
||||
tensor->set_storage(at::Storage(c10::intrusive_ptr<THStorage>::reclaim(storage)));
|
||||
}
|
||||
|
@ -8352,6 +8352,42 @@ class _TestTorchMixin(object):
|
||||
t1.set_(t2)
|
||||
self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
|
||||
|
||||
def test_tensor_set_errors(self):
|
||||
f_cpu = torch.randn((2, 3), dtype=torch.float32)
|
||||
d_cpu = torch.randn((2, 3), dtype=torch.float64)
|
||||
|
||||
# change dtype
|
||||
self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage()))
|
||||
self.assertRaises(RuntimeError,
|
||||
lambda: f_cpu.set_(d_cpu.storage(), 0, d_cpu.size(), d_cpu.stride()))
|
||||
self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu))
|
||||
|
||||
# change device
|
||||
if torch.cuda.is_available():
|
||||
f_cuda = torch.randn((2, 3), dtype=torch.float32, device='cuda')
|
||||
|
||||
# cpu -> cuda
|
||||
self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda.storage()))
|
||||
self.assertRaises(RuntimeError,
|
||||
lambda: f_cpu.set_(f_cuda.storage(), 0, f_cuda.size(), f_cuda.stride()))
|
||||
self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda))
|
||||
|
||||
# cuda -> cpu
|
||||
self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu.storage()))
|
||||
self.assertRaises(RuntimeError,
|
||||
lambda: f_cuda.set_(f_cpu.storage(), 0, f_cpu.size(), f_cpu.stride()))
|
||||
self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu))
|
||||
|
||||
@unittest.skipIf(torch.cuda.device_count() < 2, 'less than 2 GPUs detected')
|
||||
def test_tensor_set_errors_multigpu(self):
|
||||
f_cuda0 = torch.randn((2, 3), dtype=torch.float32, device='cuda:0')
|
||||
f_cuda1 = torch.randn((2, 3), dtype=torch.float32, device='cuda:1')
|
||||
|
||||
self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1.storage()))
|
||||
self.assertRaises(RuntimeError,
|
||||
lambda: f_cuda0.set_(f_cuda1.storage(), 0, f_cuda1.size(), f_cuda1.stride()))
|
||||
self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1))
|
||||
|
||||
def test_equal(self):
|
||||
# Contiguous, 1D
|
||||
t1 = torch.Tensor((3, 4, 9, 10))
|
||||
|
@ -229,7 +229,7 @@ at::Tensor ScriptModuleDeserializer::loadTensor(
|
||||
.set_(storage_it->second, tensor_proto.offset(), dims, strides);
|
||||
} else if (device.type() == at::DeviceType::CUDA) {
|
||||
result =
|
||||
at::empty({0}, at::CUDA(type).options())
|
||||
at::empty({0}, c10::TensorOptions(type).device(storage_it->second.device()))
|
||||
.set_(storage_it->second, tensor_proto.offset(), dims, strides);
|
||||
}
|
||||
AT_ASSERT(result.defined());
|
||||
|
Reference in New Issue
Block a user