mirror of https://github.com/pytorch/pytorch.git
change arg order of Copy/Memcpy to follow inputs-then-outputs convention
instead of C memcpy order -- from (dst, src, n) to (n, src, dst)
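
For quick orientation, a minimal standalone sketch of the convention change (illustrative only, not part of the diff -- the struct and main below are stand-ins; in Caffe2 itself Copy is a member template of the context classes and dispatches to Memcpy, as the hunks below show):

// Illustrative sketch of the argument-order change. "Ctx" is a stand-in
// context type, not Caffe2 code. Note that the template parameters flip
// from <Dst, Src> to <Src, Dst> along with the runtime arguments.
#include <cstring>
#include <vector>

struct Ctx {
  // New convention: inputs first (count, source pointer), output last.
  template <typename T, class SrcContext, class DstContext>
  void Copy(int n, const T* src, T* dst) {
    std::memcpy(dst, src, n * sizeof(T));  // CPU-to-CPU case
  }
};

int main() {
  std::vector<float> src = {1.f, 2.f, 3.f};
  std::vector<float> dst(src.size());
  Ctx ctx;
  // Old order (C memcpy style) would have been:
  //   ctx.Copy<float, DstContext, SrcContext>(dst.data(), src.data(), 3);
  ctx.Copy<float, Ctx, Ctx>(3, src.data(), dst.data());
  return dst[2] == 3.f ? 0 : 1;  // sanity check: copy succeeded
}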
@@ -119,8 +119,8 @@ class Tensor {
   Tensor(const Tensor<dtype, SrcContext>& src, ContextForCopy* context)
       : data_(nullptr) {
     Reshape(src.dims());
-    context->template Copy<dtype, Context, SrcContext>(
-        mutable_data(), src.data(), src.size());
+    context->template Copy<dtype, SrcContext, Context>(
+        src.size(), src.data(), mutable_data());
   }

   // Creates a tensor, and fills its contents with the given values. We need to
@@ -129,16 +129,16 @@ class Tensor {
       : data_(nullptr) {
     Reshape(dims);
     CHECK_EQ(values.size(), size_);
-    context->template Copy<dtype, Context, CPUContext>(
-        mutable_data(), values.data(), values.size());
+    context->template Copy<dtype, CPUContext, Context>(
+        values.size(), values.data(), mutable_data());
   }

   // Special case of above: create a tensor of shape 1, and the given value.
   Tensor(const dtype& value, Context* context)
       : data_(nullptr) {
     Reshape(std::vector<int>());
-    context->template Copy<dtype, Context, CPUContext>(
-        mutable_data(), &value, 1);
+    context->template Copy<dtype, CPUContext, Context>(
+        1, &value, mutable_data());
   }

   virtual ~Tensor() {}
@@ -23,9 +23,8 @@ class TensorSerializerFloat : public BlobSerializerBase {
     for (int i = 0; i < input.size(); ++i) {
       proto.add_float_data(0);
     }
-    this->device_context_.template Copy<float, CPUContext, DeviceContext>(
-        proto.mutable_float_data()->mutable_data(),
-        input.data(), input.size());
+    this->device_context_.template Copy<float, DeviceContext, CPUContext>(
+        input.size(), input.data(), proto.mutable_float_data()->mutable_data());
     return proto.SerializeAsString();
   }

@@ -53,9 +52,8 @@ class TensorSerializerInt32 : public BlobSerializerBase {
     for (int i = 0; i < input.size(); ++i) {
       proto.add_int32_data(0);
     }
-    this->device_context_.template Copy<int, CPUContext, DeviceContext>(
-        proto.mutable_int32_data()->mutable_data(),
-        input.data(), input.size());
+    this->device_context_.template Copy<int, DeviceContext, CPUContext>(
+        input.size(), input.data(), proto.mutable_int32_data()->mutable_data());
     return proto.SerializeAsString();
   }

@@ -80,8 +78,8 @@ class TensorSerializerBytes : public BlobSerializerBase {
       proto.add_dims(dim);
     }
     std::unique_ptr<char[]> buffer(new char[input.size()]);
-    this->device_context_.template Copy<char, CPUContext, DeviceContext>(
-        buffer.get(), input.data(), input.size());
+    this->device_context_.template Copy<char, DeviceContext, CPUContext>(
+        input.size(), input.data(), buffer.get());
     proto.set_byte_data(buffer, input.size());
     return proto.SerializeAsString();
   }
@@ -29,13 +29,13 @@ class CPUContext {
   static void Delete(void* data) { delete[] static_cast<char*>(data); }

   // Two copy functions that deals with cross-device copies.
-  template <class DstContext, class SrcContext>
-  inline void Memcpy(void* dst, const void* src, size_t nbytes);
-  template <typename T, class DstContext, class SrcContext>
-  inline void Copy(T* dst, const T* src, int n) {
-    Memcpy<DstContext, SrcContext>(static_cast<void*>(dst),
+  template <class SrcContext, class DstContext>
+  inline void Memcpy(size_t nbytes, const void* src, void* dst);
+  template <typename T, class SrcContext, class DstContext>
+  inline void Copy(int n, const T* src, T* dst) {
+    Memcpy<SrcContext, DstContext>(n * sizeof(T),
                                    static_cast<const void*>(src),
-                                   n * sizeof(T));
+                                   static_cast<void*>(dst));
   }

  protected:
@@ -44,7 +44,7 @@ class CPUContext {

 template<>
 inline void CPUContext::Memcpy<CPUContext, CPUContext>(
-    void* dst, const void* src, size_t nbytes) {
+    size_t nbytes, const void* src, void* dst) {
   memcpy(dst, src, nbytes);
 }

@@ -106,19 +106,19 @@ class CUDAContext {
     }
   }

-  template <class DstContext, class SrcContext>
-  inline void Copy(void* dst, const void* src, size_t nbytes) {
+  template <class SrcContext, class DstContext>
+  inline void Copy(size_t nbytes, const void* src, void* dst) {
     CUDA_CHECK(cudaMemcpyAsync(
         dst, src, nbytes, cudaMemcpyDefault, cuda_stream_));
     // TODO(Yangqing): do we want to synchronize inside copy?
     CUDA_CHECK(cudaStreamSynchronize(cuda_stream_));
   }

-  template <typename T, class DstContext, class SrcContext>
-  inline void Copy(T* dst, const T* src, int n) {
-    Copy<DstContext, SrcContext>(static_cast<void*>(dst),
+  template <typename T, class SrcContext, class DstContext>
+  inline void Copy(int n, const T* src, T* dst) {
+    Copy<SrcContext, DstContext>(n * sizeof(T),
                                  static_cast<const void*>(src),
-                                 n * sizeof(T));
+                                 static_cast<void*>(dst));
   }

  protected:
@@ -132,10 +132,10 @@ class CUDAContext {
 // For the CPU context, we also allow a (probably expensive) function
 // to copy the data from a cuda context.
 template<>
-inline void CPUContext::Memcpy<CPUContext, CUDAContext>(
-    void* dst, const void* src, size_t nbytes) {
+inline void CPUContext::Memcpy<CUDAContext, CPUContext>(
+    size_t nbytes, const void* src, void* dst) {
   CUDAContext context;
-  context.Copy<CPUContext, CUDAContext>(dst, src, nbytes);
+  context.Copy<CUDAContext, CPUContext>(nbytes, src, dst);
 }

 } // namespace caffe2
@@ -34,7 +34,7 @@ TEST(CPUContextTest, TestAllocDealloc) {
   }
   DeviceOption option;
   CPUContext context(option);
-  context.Copy<float, CPUContext, CPUContext>(dst_data, data, 10);
+  context.Copy<float, CPUContext, CPUContext>(10, data, dst_data);
   for (int i = 0; i < 10; ++i) {
     EXPECT_FLOAT_EQ(dst_data[i], i);
   }
@@ -187,15 +187,15 @@ bool ImageInputOp<DeviceContext>::CopyPrefetched() {
   // The first output is the image data.
   auto* image_output = OperatorBase::Output<Tensor<float, DeviceContext> >(0);
   image_output->ReshapeLike(prefetched_image_);
-  this->device_context_.template Copy<float, DeviceContext, CPUContext>(
-      image_output->mutable_data(), prefetched_image_.data(),
-      prefetched_image_.size());
+  this->device_context_.template Copy<float, CPUContext, DeviceContext>(
+      prefetched_image_.size(), prefetched_image_.data(),
+      image_output->mutable_data());
   // The second output is the label.
   auto* label_output = OperatorBase::Output<Tensor<int, DeviceContext> >(1);
   label_output->ReshapeLike(prefetched_label_);
-  this->device_context_.template Copy<int, DeviceContext, CPUContext>(
-      label_output->mutable_data(), prefetched_label_.data(),
-      prefetched_label_.size());
+  this->device_context_.template Copy<int, CPUContext, DeviceContext>(
+      prefetched_label_.size(), prefetched_label_.data(),
+      label_output->mutable_data());
   return true;
 }

@@ -111,8 +111,8 @@ class GivenTensorFillOp final : public FillerOp<dtype, DeviceContext> {
     DCHECK_EQ(output->size(), values_.size())
         << "output size: " << output->size() << " given size: "
        << values_.size();
-    device_context_.template Copy<dtype, DeviceContext, CPUContext>(
-        output->mutable_data(), values_.data(), output->size());
+    device_context_.template Copy<dtype, CPUContext, DeviceContext>(
+        output->size(), values_.data(), output->mutable_data());
     return true;
   }

@@ -56,9 +56,8 @@ class LoadTensorOp final : public Operator<float, DeviceContext> {
         output->Reshape(
             vector<int>(proto.dims().begin(), proto.dims().end()));
         CHECK_EQ(output->size(), proto.float_data_size());
-        this->device_context_.template Copy<float, DeviceContext, CPUContext>(
-            output->mutable_data(), proto.float_data().data(),
-            output->size());
+        this->device_context_.template Copy<float, CPUContext, DeviceContext>(
+            output->size(), proto.float_data().data(), output->mutable_data());
         VLOG(1) << "Loaded float tensor " << key << ".";
         break;
       }
@@ -71,9 +70,8 @@ class LoadTensorOp final : public Operator<float, DeviceContext> {
        output->Reshape(
             vector<int>(proto.dims().begin(), proto.dims().end()));
         CHECK_EQ(output->size(), proto.int32_data_size());
-        this->device_context_.template Copy<int, DeviceContext, CPUContext>(
-            output->mutable_data(), proto.int32_data().data(),
-            output->size());
+        this->device_context_.template Copy<int, CPUContext, DeviceContext>(
+            output->size(), proto.int32_data().data(), output->mutable_data());
         VLOG(1) << "Loaded int32 tensor " << key << ".";
         break;
       }
@@ -42,7 +42,7 @@ bool LRNOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
       float* this_scale_slice = scale_data + n * image_size + c * H * W;
       // copy previous scale
       device_context_.Copy<float, CPUContext, CPUContext>(
-          this_scale_slice, this_scale_slice - H * W, H * W);
+          H * W, this_scale_slice - H * W, this_scale_slice);
       // add head
       math::Axpy<float, CPUContext>(
           H * W, &alpha_over_size, padded_square_data + (c + size_ - 1) * H * W,
@@ -24,7 +24,7 @@ bool SoftmaxOp<float, CPUContext>::RunOnDevice() {
                                 &device_context_);
   // Put the intermediate result X - max(X) into Y
   device_context_.template Copy<float, CPUContext, CPUContext>(
-      Y->mutable_data(), X.data(), X.size());
+      X.size(), X.data(), Y->mutable_data());
   // Subtract the scale
   static const float kMinusOne = -1.;
   static const float kOne = 1.;
@@ -74,7 +74,7 @@ bool SoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
   const float* Ydata = Y.data();
   const float* dYdata = dY.data();
   float* dXdata = dX->mutable_data();
-  device_context_.Copy<float, CPUContext, CPUContext>(dXdata, dYdata, Y.size());
+  device_context_.Copy<float, CPUContext, CPUContext>(Y.size(), dYdata, dXdata);
   float* scaledata = scale_.mutable_data();
   for (int i = 0; i < N; ++i) {
     math::Dot<float, CPUContext>(D, Ydata + i * D, dYdata + i * D,
@@ -100,8 +100,8 @@ bool SummarizeOp<float, CUDAContext>::RunOnDevice() {
     Y->Reshape(std::vector<int>{4});
     float output_buffer[NUM_STATS] = {result.min, result.max, result.mean,
                                       standard_deviation};
-    device_context_.Copy<float, CUDAContext, CPUContext>(
-        Y->mutable_data(), output_buffer, NUM_STATS);
+    device_context_.Copy<float, CPUContext, CUDAContext>(
+        NUM_STATS, output_buffer, Y->mutable_data());
   }
   return true;
 }
@@ -166,8 +166,8 @@ bool TensorProtosDBInput<DeviceContext>::CopyPrefetched() {
         auto& input =
             prefetched_blobs_[i]->template Get<Tensor<float, CPUContext> >();
         output->ReshapeLike(input);
-        this->device_context_.template Copy<float, DeviceContext, CPUContext>(
-            output->mutable_data(), input.data(), input.size());
+        this->device_context_.template Copy<float, CPUContext, DeviceContext>(
+            input.size(), input.data(), output->mutable_data());
         break;
       }
       case TensorProto::INT32:
@@ -176,8 +176,8 @@ bool TensorProtosDBInput<DeviceContext>::CopyPrefetched() {
         auto& input =
             prefetched_blobs_[i]->template Get<Tensor<int, CPUContext> >();
         output->ReshapeLike(input);
-        this->device_context_.template Copy<int, DeviceContext, CPUContext>(
-            output->mutable_data(), input.data(), input.size());
+        this->device_context_.template Copy<int, CPUContext, DeviceContext>(
+            input.size(), input.data(), output->mutable_data());
         break;
       }
       case TensorProto::STRING:
@@ -77,8 +77,8 @@ class PrintOp final : public Operator<dtype, DeviceContext> {
       auto& input = Input(input_id);
       DCHECK_GT(input.size(), 0);
       temp_tensor.ReshapeLike(input);
-      device_context_.template Copy<dtype, CPUContext, DeviceContext>(
-          temp_tensor.mutable_data(), input.data(), input.size());
+      device_context_.template Copy<dtype, DeviceContext, CPUContext>(
+          input.size(), input.data(), temp_tensor.mutable_data());
     }
     std::stringstream values_stream;
     int total_count = std::min(temp_tensor.size(), limit_);
@@ -206,7 +206,7 @@ class SumOp : public Operator<dtype, DeviceContext> {
     auto* output = Output(0);
     output->ReshapeLike(input);
     device_context_.template Copy<dtype, DeviceContext, DeviceContext>(
-        output->mutable_data(), input.data(), input.size());
+        input.size(), input.data(), output->mutable_data());
     for (int i = 1; i < InputSize(); ++i) {
       math::Add(output->size(), output->data(), Input(i).data(),
                 output->mutable_data(), &device_context_);
@@ -276,8 +276,8 @@ class CopyOp : public Operator<dtype, DeviceContext> {
     auto& input = OperatorBase::Input<Tensor<dtype, SrcContext> >(0);
     auto* output = OperatorBase::Output<Tensor<dtype, DstContext> >(0);
     output->ReshapeLike(input);
-    this->device_context_.template Copy<dtype, DstContext, SrcContext>(
-        output->mutable_data(), input.data(), input.size());
+    this->device_context_.template Copy<dtype, SrcContext, DstContext>(
+        input.size(), input.data(), output->mutable_data());
     return true;
   }

@@ -50,9 +50,9 @@ class LearningRateOp final : public Operator<dtype, DeviceContext> {
     dtype learning_rate = base_lr_ * (*functor_)(iter);
     // Write to output.
     auto* output = Output(0);
-    output->Reshape(std::vector<int>{1});
-    device_context_.template Copy<dtype, DeviceContext, CPUContext>(
-        Output(0)->mutable_data(), &learning_rate, 1);
+    output->Reshape(std::vector<int>());
+    device_context_.template Copy<dtype, CPUContext, DeviceContext>(
+        1, &learning_rate, Output(0)->mutable_data());
     return true;
   }

@@ -83,9 +83,9 @@ PyObject* FetchTensor(const Blob& blob) {
   // Now, copy the data to the tensor.
   // TODO(Yangqing): Is there an easier way to convert PyObject to
   // PyArrayObject?
-  context.template Copy<T, CPUContext, DeviceContext>(
-      static_cast<T*>(PyArray_DATA(reinterpret_cast<PyArrayObject*>(array))),
-      tensor.data(), tensor.size());
+  context.template Copy<T, DeviceContext, CPUContext>(
+      tensor.size(), tensor.data(),
+      static_cast<T*>(PyArray_DATA(reinterpret_cast<PyArrayObject*>(array))));
   return array;
 }

@@ -105,10 +105,9 @@ PyObject* FeedTensor(const DeviceOption& option, PyArrayObject* original_array,
   }
   tensor->Reshape(dims);
   // Now, copy the data to the tensor.
-  context.template Copy<T, DeviceContext, CPUContext>(
-      tensor->mutable_data(),
-      static_cast<T*>(PyArray_DATA(array)),
-      tensor->size());
+  context.template Copy<T, CPUContext, DeviceContext>(
+      tensor->size(), static_cast<T*>(PyArray_DATA(array)),
+      tensor->mutable_data());
   Py_XDECREF(array);
   Py_RETURN_TRUE;
 }