Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Move TensorImpl::CopyFrom to caffe2::Tensor (2/2) (#14858)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14858

This diff doesn't change any logic; it just takes the existing code and moves it to caffe2::Tensor.

Reviewed By: ezyang

Differential Revision: D13365817

fbshipit-source-id: bc73b27a793602cb14200dcdf357aa63233da43c
Committed by: Facebook Github Bot
Parent: 070f33f154
Commit: bb8ee2de0f
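To make the shape of the change concrete, here is a minimal sketch of the refactor pattern using hypothetical stand-in types, not the real c10::TensorImpl / caffe2::Tensor classes: logic that previously lived on the shared impl object moves onto the user-facing wrapper, which reaches through its impl_ handle for every field access.

#include <memory>
#include <utility>
#include <vector>

struct Impl {
  std::vector<float> data;
  // Before the refactor, the copy logic lived here as a member,
  // analogous to TensorImpl::CopyFrom(const TensorImpl&).
};

class Tensor {
 public:
  explicit Tensor(std::shared_ptr<Impl> impl) : impl_(std::move(impl)) {}

  // After the refactor, the wrapper owns the logic and dereferences
  // impl_ explicitly, mirroring how caffe2::Tensor::CopyFrom in the
  // diff below calls impl_->... for each operation.
  void CopyFrom(const Tensor& src) {
    if (src.impl_.get() == impl_.get()) {
      return;  // self-copy is a no-op, as in the real CopyFrom
    }
    impl_->data = src.impl_->data;
  }

 private:
  std::shared_ptr<Impl> impl_;
};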
@@ -808,80 +808,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return storage_.device();
   }
 
-  /**
-   * @brief Copies the data from a source tensor, with a context provided to
-   * carry out the underlying memcpy operation. This method respects
-   * caffe2_keep_on_shrink.
-   *
-   * After CopyFrom, this function guarantees that the destination tensor will
-   * have the same initialization state and dtype as src. This function
-   * preserves the DeviceType of this tensor (so, e.g., if you allocate
-   * a tensor on CPU and then CopyFrom a CUDA tensor, that will perform a
-   * CUDA-to-CPU transfer).
-   *
-   * 'async' parameter triggers async copy for CUDA tensors
-   */
-  void CopyFrom(const TensorImpl& src, bool async = false) {
-    AT_ASSERT(!is_variable());
-    AT_ASSERTM(
-        src.is_contiguous(),
-        "Right now only copy of contiguous source Tensor is supported.");
-    AT_ASSERTM(
-        src.storage_initialized(),
-        "Cannot copy from an uninitialized Tensor");
-
-    if (&src == this) {
-      return;
-    }
-
-    // Test if we need to allocate a new storage
-    // Uninitialized storages are guaranteed to be uniquely owned,
-    // so we don't need to swap in this case.
-    // If the dtype changed, we need to reallocate storage.
-    if (dtype() != src.dtype()) {
-      // NB: copy preserves device_type
-      // This storage will get initialized by the mutable_data call below.
-      set_storage(at::Storage(device_type(), src.dtype()));
-    }
-    Resize(src.sizes());
-
-    if (numel() > 0) {
-      if (dtype().copy()) {
-        AT_ASSERTM(
-            device_type() == DeviceType::CPU,
-            "In CopyFrom source and dest tensors must both be CPU for "
-            "non-POD copy, but dest tensor was ",
-            device_type());
-        AT_ASSERTM(
-            src.device_type() == DeviceType::CPU,
-            "In CopyFrom source and dest tensors must both be CPU for "
-            "non-POD copy, but src tensor was ",
-            src.device_type());
-        dtype().copy()(src.data(), raw_mutable_data(data_type_), numel());
-      } else {
-        // The following copy uses the current (thread local) stream for copying
-        // and also takes the GPU id from the device() field passed in.
-        //
-        // TODO: Potentially more enforcements are necessary to avoid accidental
-        // switch to sync copy if the currently set device is wrong.
-        //
-        // Specifically, we might need to switch to a different context device
-        // here explicitly to avoid relying on user synchronizing things
-        // properly.
-        //
-        // note: raw_mutable_data initializes device here
-        void* new_data = raw_mutable_data(dtype());
-        CopyBytes(
-            numel() * itemsize(),
-            src.data(),
-            src.device(),
-            new_data,
-            device(),
-            async);
-      }
-    }
-  }
-
   /**
    * @brief Extends the outer-most dimension of this tensor by num elements,
    * preserving the existing data.
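For orientation, a hedged usage sketch of the method being moved. The constructor and device constants here are assumptions about the caffe2 API of this era, not part of the diff:

#include "caffe2/core/tensor.h"

void Example(const caffe2::Tensor& gpu_src) {
  // CopyFrom preserves the destination's DeviceType: copying from a CUDA
  // source into a CPU-allocated destination performs a CUDA-to-CPU
  // transfer rather than moving the destination to CUDA.
  caffe2::Tensor cpu_dst(caffe2::CPU);
  cpu_dst.CopyFrom(gpu_src, /*async=*/false);

  // For device-to-device copies, async = true lets the copy run on the
  // current (thread-local) stream instead of synchronizing, per the
  // doc comment in the diff.
  caffe2::Tensor gpu_dst(caffe2::CUDA);
  gpu_dst.CopyFrom(gpu_src, /*async=*/true);
}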
@@ -104,8 +104,78 @@ class CAFFE2_API Tensor final {
     return impl_.get()->GetDevice();
   }
 
-  void CopyFrom(const Tensor& src, bool async = false) const {
-    impl_.get()->CopyFrom(*src.impl_, async);
-  }
+  /**
+   * @brief Copies the data from a source tensor, with a context provided to
+   * carry out the underlying memcpy operation. This method respects
+   * caffe2_keep_on_shrink.
+   *
+   * After CopyFrom, this function guarantees that the destination tensor will
+   * have the same initialization state and dtype as src. This function
+   * preserves the DeviceType of this tensor (so, e.g., if you allocate
+   * a tensor on CPU and then CopyFrom a CUDA tensor, that will perform a
+   * CUDA-to-CPU transfer).
+   *
+   * 'async' parameter triggers async copy for CUDA tensors
+   */
+  void CopyFrom(const Tensor& src, bool async = false) {
+    AT_ASSERT(!impl_->is_variable());
+    AT_ASSERTM(
+        src.impl_->is_contiguous(),
+        "Right now only copy of contiguous source Tensor is supported.");
+    AT_ASSERTM(
+        src.impl_->storage_initialized(),
+        "Cannot copy from an uninitialized Tensor");
+
+    if (src.impl_.get() == impl_.get()) {
+      return;
+    }
+
+    // Test if we need to allocate a new storage
+    // Uninitialized storages are guaranteed to be uniquely owned,
+    // so we don't need to swap in this case.
+    // If the dtype changed, we need to reallocate storage.
+    if (impl_->dtype() != src.impl_->dtype()) {
+      // NB: copy preserves device_type
+      // This storage will get initialized by the mutable_data call below.
+      impl_->set_storage(at::Storage(impl_->device_type(), src.impl_->dtype()));
+    }
+    impl_->Resize(src.impl_->sizes());
+
+    if (impl_->numel() > 0) {
+      if (impl_->dtype().copy()) {
+        AT_ASSERTM(
+            impl_->device_type() == ::at::DeviceType::CPU,
+            "In CopyFrom source and dest tensors must both be CPU for "
+            "non-POD copy, but dest tensor was ",
+            impl_->device_type());
+        AT_ASSERTM(
+            src.impl_->device_type() == ::at::DeviceType::CPU,
+            "In CopyFrom source and dest tensors must both be CPU for "
+            "non-POD copy, but src tensor was ",
+            src.impl_->device_type());
+        impl_->dtype().copy()(
+            src.impl_->data(),
+            impl_->raw_mutable_data(impl_->dtype()),
+            impl_->numel());
+      } else {
+        // The following copy uses the current (thread local) stream for copying
+        // and also takes the GPU id from the device() field passed in.
+        //
+        // TODO: Potentially more enforcements are necessary to avoid accidental
+        // switch to sync copy if the currently set device is wrong.
+        //
+        // Specifically, we might need to switch to a different context device
+        // here explicitly to avoid relying on user synchronizing things
+        // properly.
+        //
+        // note: raw_mutable_data initializes device here
+        void* new_data = impl_->raw_mutable_data(impl_->dtype());
+        at::CopyBytes(
+            impl_->numel() * impl_->itemsize(),
+            src.impl_->data(),
+            src.impl_->device(),
+            new_data,
+            impl_->device(),
+            async);
+      }
+    }
+  }
 
   /**
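The dtype().copy() branch above is the central dispatch: a dtype with a registered copy function (a non-POD element type such as std::string) must be copied element-wise, and only on CPU, while trivially copyable dtypes fall through to the raw CopyBytes path. A self-contained sketch of that dispatch idea, with hypothetical names in place of the real TypeMeta API:

#include <cstddef>
#include <cstring>
#include <string>

// Stand-in for TypeMeta: a non-null 'copy' means the element type needs
// per-element copy semantics (non-POD); nullptr means memcpy is safe.
struct TypeInfo {
  using CopyFn = void (*)(const void* src, void* dst, std::size_t n);
  CopyFn copy = nullptr;
  std::size_t itemsize = 0;
};

// Example copy function for std::string elements (the non-POD case).
void CopyStrings(const void* src, void* dst, std::size_t n) {
  auto* s = static_cast<const std::string*>(src);
  auto* d = static_cast<std::string*>(dst);
  for (std::size_t i = 0; i < n; ++i) {
    d[i] = s[i];
  }
}

void CopyElements(const TypeInfo& t, const void* src, void* dst, std::size_t n) {
  if (t.copy) {
    t.copy(src, dst, n);                    // non-POD path: element-wise, CPU only
  } else {
    std::memcpy(dst, src, n * t.itemsize);  // POD path: raw byte copy
  }
}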