Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Tensor reinitialization codemod - 3/5 (#15912)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15912 Codemod generated with clangr shard mode, 25 files per diff. To eliminate partially initialized Tensors, we split the initialization of local Tensor variables into two steps: first declare an uninitialized Tensor, then call `ReinitializeTensor` to initialize it. Motivation: https://github.com/pytorch/pytorch/pull/12407 Reviewed By: dzhulgakov Differential Revision: D13586734 fbshipit-source-id: 8485d2c51225343961351c7a2e8f95055534f9a9
committed by Facebook Github Bot
parent 57d29ffa9c
commit d277f77da2
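The diff below applies the pattern described in the summary: member Tensors lose their device-at-construction and are materialized lazily with `ReinitializeTensor` (or, for optional helpers, with `caffe2::empty`). A minimal sketch of the before/after shape of the codemod, assuming the usual caffe2 operator context; `MyOp` and `aux_` are illustrative names, not taken from any file in this diff:

// Sketch only: assumes namespace caffe2 and "caffe2/core/operator.h".
//
// Before the codemod, the member carried a device at construction and was
// resized on every run:
//   Tensor aux_{Context::GetDeviceType()};
//   aux_.Resize(2 * N);
//
// After the codemod, the member is declared uninitialized and initialized
// explicitly in the run method:
class MyOp final : public Operator<CPUContext> {
 public:
  using Operator<CPUContext>::Operator;

  bool RunOnDevice() override {
    const auto& X = Input(0);
    const int64_t N = X.numel();
    // ReinitializeTensor allocates (or reuses) storage with an explicit
    // dtype and device, so the tensor is never left partially initialized.
    ReinitializeTensor(&aux_, {2 * N}, at::dtype<float>().device(CPU));
    float* aux_data = aux_.mutable_data<float>();
    math::Set<float, CPUContext>(2 * N, 0.0f, aux_data, &context_);
    return true;
  }

 private:
  Tensor aux_;  // no device baked in at declaration
};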
@@ -316,7 +316,7 @@ bool CosineSimilarityOp<float, CUDAContext>::RunOnDevice() {
   const float* X_data = X.data<float>();
   const float* Y_data = Y.data<float>();
   // Auxiliary arrays, one allocation of memory
-  aux_.Resize(2 * N);
+  ReinitializeTensor(&aux_, {2 * N}, at::dtype<float>().device(CUDA));
   float* aux_data = aux_.mutable_data<float>();
   float* x2 = aux_data;
   float* y2 = aux_data + N;
@@ -371,7 +371,7 @@ bool CosineSimilarityGradientOp<float, CUDAContext>::RunOnDevice() {
   auto* dY_data = dY->template mutable_data<float>();

   // one memory allocation, a few arrays
-  aux_.Resize(6 * N);
+  ReinitializeTensor(&aux_, {6 * N}, at::dtype<float>().device(CUDA));
   float* aux_data = aux_.mutable_data<float>();
   float* xn = aux_data;
   float* yn = aux_data + N;
@@ -156,7 +156,7 @@ class CosineSimilarityOp : public Operator<Context> {
   OUTPUT_TAGS(COS_OUT);

  private:
-  Tensor aux_{Context::GetDeviceType()};
+  Tensor aux_;
 };

 template <typename T, class Context>
@@ -173,7 +173,7 @@ class CosineSimilarityGradientOp final : public Operator<Context> {
   OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);

  private:
-  Tensor aux_{Context::GetDeviceType()};
+  Tensor aux_;
 };

 template <typename T, class Context>
@@ -98,15 +98,12 @@ void fc_op_cpu_impl(
       math_type);
   // Add bias term
   Tensor bias_multiplier(cache->bias_multiplier_);
-  if (bias_multiplier.numel() != M) {
-    // If the helper bias multiplier is not M, reshape and fill it with one.
-    bias_multiplier.Resize(M);
-    caffe2::math::Set<DataType, Context>(
-        M,
-        caffe2::convert::To<float, DataType>(1),
-        bias_multiplier.template mutable_data<DataType>(),
-        static_cast<Context*>(&context));
-  }
+  ReinitializeTensor(&bias_multiplier, {M}, at::dtype<DataType>().device(CPU));
+  caffe2::math::Set<DataType, Context>(
+      M,
+      caffe2::convert::To<float, DataType>(1),
+      bias_multiplier.template mutable_data<DataType>(),
+      static_cast<Context*>(&context));
   caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
       CblasNoTrans,
       CblasNoTrans,
@@ -12,7 +12,7 @@ struct FullyConnected final {

   struct Cache final {
     vector<int64_t> Y_shape_cache_;
-    C10Tensor bias_multiplier_ = C10Tensor(Tensor{CPU});
+    C10Tensor bias_multiplier_ = C10Tensor(Tensor());
   };

   using Signature = void(
@@ -1,6 +1,7 @@
 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_

+#include <c10/util/Optional.h>
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
 #include "caffe2/utils/conversions.h"
@@ -104,15 +105,22 @@ class FullyConnectedOp final : public Operator<Context> {
         &context_,
         math_type);
     // Add bias term
-    if (bias_multiplier_.numel() != M) {
-      // If the helper bias multiplier is not M, reshape and fill it with one.
-      bias_multiplier_.Resize(M);
+    if (!bias_multiplier_.has_value()) {
+      bias_multiplier_ = caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
       math::Set<T_B, Context>(
           M,
           convert::To<float, T_B>(1),
-          bias_multiplier_.template mutable_data<T_B>(),
+          bias_multiplier_->template mutable_data<T_B>(),
+          &context_);
+    } else if (bias_multiplier_->numel() != M) {
+      bias_multiplier_->Resize(M);
+      math::Set<T_B, Context>(
+          M,
+          convert::To<float, T_B>(1),
+          bias_multiplier_->template mutable_data<T_B>(),
           &context_);
     }

     math::Gemm<T_B, Context, Engine>(
         CblasNoTrans,
         CblasNoTrans,
@@ -120,7 +128,7 @@ class FullyConnectedOp final : public Operator<Context> {
         N,
         1,
         1,
-        bias_multiplier_.template data<T_B>(),
+        bias_multiplier_->template data<T_B>(),
         b.template data<T_B>(),
         1,
         Y->template mutable_data<T_Y>(),
@@ -144,7 +152,7 @@ class FullyConnectedOp final : public Operator<Context> {
   // A local vector to cache the output shape so we don't need to recreate
   // a vector object every time we run Run().
   vector<int64_t> Y_shape_cache_;
-  Tensor bias_multiplier_{Context::GetDeviceType()};
+  c10::optional<Tensor> bias_multiplier_;

   bool float16_compute_;
 };
@@ -250,14 +258,19 @@ class FullyConnectedGradientOp : public Operator<Context> {
         dW->template mutable_data<T_DW>(),
         &context_,
         math_type);
-    if (bias_multiplier_.numel() != M) {
-      // If the helper bias multiplier is not M, reshape and fill it
-      // with one.
-      bias_multiplier_.Resize(M);
+    if (!bias_multiplier_.has_value()) {
+      bias_multiplier_ = caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
+      math::Set<T_B, Context>(
+          M,
+          convert::To<float, T_B>(1),
+          bias_multiplier_->template mutable_data<T_B>(),
+          &context_);
+    } else if (bias_multiplier_->numel() != M) {
+      bias_multiplier_->Resize(M);
       math::Set<T_B, Context>(
           M,
           convert::To<float, T_B>(1),
-          bias_multiplier_.template mutable_data<T_B>(),
+          bias_multiplier_->template mutable_data<T_B>(),
           &context_);
     }
     // Compute dB
@@ -267,7 +280,7 @@ class FullyConnectedGradientOp : public Operator<Context> {
         N,
         1,
         dY.template data<T_DY>(),
-        bias_multiplier_.template data<T_B>(),
+        bias_multiplier_->template data<T_B>(),
         0,
         db->template mutable_data<T_DB>(),
         &context_);
@@ -307,7 +320,7 @@ class FullyConnectedGradientOp : public Operator<Context> {
  protected:
   size_t axis_{1};
   size_t axis_w_{1};
-  Tensor bias_multiplier_{Context::GetDeviceType()};
+  c10::optional<Tensor> bias_multiplier_;
   bool float16_compute_;
 };

@@ -55,13 +55,13 @@ class GivenTensorByteStringToUInt8FillOp final : public FillerOp<Context> {
         << " given size: " << source_values.size();

     auto str = source_values[0];
-    values_.Resize(str.size());
+    ReinitializeTensor(&values_, {static_cast<int64_t>(str.size())}, at::dtype<uint8_t>().device(CPU));
     uint8_t* values_data = values_.template mutable_data<uint8_t>();
     for (int i = 0; i < str.size(); i++) {
       values_data[i] = static_cast<uint8_t>(str[i]);
     }
   }

-  Tensor values_{CPU};
+  Tensor values_;
 };
 } // namespace caffe2
@@ -60,7 +60,7 @@ class GivenTensorFillOp final : public FillerOp<Context> {
   void ExtractValues() {
     auto source_values =
         this->template GetRepeatedArgument<Type>("values");
-    values_.Resize(source_values.size());
+    ReinitializeTensor(&values_, {static_cast<int64_t>(source_values.size())}, at::dtype<Type>().device(CPU));
     Type* values_data = values_.template mutable_data<Type>();
     for (int i = 0; i < source_values.size(); i++) {
       values_data[i] = static_cast<Type>(source_values[i]);
@@ -83,6 +83,6 @@ class GivenTensorFillOp final : public FillerOp<Context> {
   }

   bool (GivenTensorFillOp::*body_)(Tensor* output);
-  Tensor values_{CPU};
+  Tensor values_;
 };
 } // namespace caffe2
@@ -104,8 +104,10 @@ bool GroupNormGradientOp<T, Context>::RunOnDeviceImpl(
   // dL/ds = Sum(dL/dY * gamma * X)
   // dL/db = Sum(dL/dY * gamma)
   const int C = G * D;
-  ds_.Resize(N, G);
-  db_.Resize(N, G);
+  ReinitializeTensor(
+      &ds_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
+  ReinitializeTensor(
+      &db_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
   T* ds_data = ds_.template mutable_data<T>();
   T* db_data = db_.template mutable_data<T>();
   math::Set<T, Context>(N * G, T(0), ds_data, &context_);
@@ -326,8 +326,10 @@ bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
     float* dbeta_data) {
   const int size = N * G * D * HxW;
   const int C = G * D;
-  ds_.Resize(N, G);
-  db_.Resize(N, G);
+  ReinitializeTensor(
+      &ds_, {N, G}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(
+      &db_, {N, G}, at::dtype<float>().device(CUDA));
   float* ds_data = ds_.mutable_data<float>();
   float* db_data = db_.mutable_data<float>();
   if (order_ == StorageOrder::NCHW) {
@@ -57,8 +57,8 @@ class GroupNormOp final : public Operator<Context> {
       mu_data = mu->template mutable_data<T>();
       rsig_data = rsig->template mutable_data<T>();
     } else {
-      mu_.Resize(N, G);
-      rsig_.Resize(N, G);
+      ReinitializeTensor(&mu_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
+      ReinitializeTensor(&rsig_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
       mu_data = mu_.template mutable_data<T>();
       rsig_data = rsig_.template mutable_data<T>();
     }
@@ -88,8 +88,8 @@ class GroupNormOp final : public Operator<Context> {
       T* mu,
       T* rsig) {
     const int C = G * D;
-    scale_.Resize(N, C);
-    bias_.Resize(N, C);
+    ReinitializeTensor(&scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(&bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
     T* scale_data = scale_.template mutable_data<T>();
     T* bias_data = bias_.template mutable_data<T>();
     if (order_ == StorageOrder::NCHW) {
@@ -175,10 +175,10 @@ class GroupNormOp final : public Operator<Context> {
   const StorageOrder order_;
   const bool is_test_;

-  Tensor mu_{Context::GetDeviceType()};
-  Tensor rsig_{Context::GetDeviceType()};
-  Tensor scale_{Context::GetDeviceType()};
-  Tensor bias_{Context::GetDeviceType()};
+  Tensor mu_;
+  Tensor rsig_;
+  Tensor scale_;
+  Tensor bias_;

   // Input: X, gamma, beta
   // Output: Y, mu, inv_sig
@@ -255,8 +255,8 @@ class GroupNormGradientOp final : public Operator<Context> {
   const int group_;
   const StorageOrder order_;

-  Tensor ds_{Context::GetDeviceType()};
-  Tensor db_{Context::GetDeviceType()};
+  Tensor ds_;
+  Tensor db_;

   // Input: dY, X, gamma, beta, mu, inv_sig
   // Output: dX, dgamma, dbeta
@@ -24,34 +24,39 @@ float HSoftmaxOp<float, CPUContext>::RunForwardSingle(const float* X,
   //Softmax
   float* softmax_output_data = int_output + int_output_offset;

-  if (scale_.numel() != 1) {
-    scale_.Resize(1);
+  if (!scale_.has_value()) {
+    scale_ = caffe2::empty({1}, at::dtype<float>().device(CPU));
   }
-  if (sum_multiplier_.numel() != dim_out) {
-    sum_multiplier_.Resize(dim_out);
+  if (!sum_multiplier_.has_value()) {
+    sum_multiplier_ = caffe2::empty({dim_out}, at::dtype<float>().device(CPU));
     math::Set<float, CPUContext>(dim_out, 1.f,
-      sum_multiplier_.mutable_data<float>(), &context_);
+      sum_multiplier_->mutable_data<float>(), &context_);
+  } else if (sum_multiplier_->numel() != dim_out) {
+    sum_multiplier_->Resize(dim_out);
+    math::Set<float, CPUContext>(dim_out, 1.f,
+      sum_multiplier_->mutable_data<float>(), &context_);
   }
   math::RowwiseMax<float, CPUContext>(1, dim_out, fc_output_data,
-    scale_.mutable_data<float>(), &context_);
+    scale_->mutable_data<float>(), &context_);

   // Put the intermediate result X - max(X) into Y
   context_.template CopyFromCPU<float>(
     dim_out, fc_output_data, softmax_output_data);
   // Subtract the scale
   math::Gemv<float, CPUContext>(CblasNoTrans, dim_out, 1, -1,
-    sum_multiplier_.data<float>(), scale_.data<float>(), 1, softmax_output_data,
+    sum_multiplier_->data<float>(), scale_->data<float>(), 1, softmax_output_data,
     &context_);

   // Exponentiation
   math::Exp<float, CPUContext>(dim_out, softmax_output_data,
     softmax_output_data, &context_);
   math::Gemv<float, CPUContext>(CblasNoTrans, 1, dim_out, 1,
-    softmax_output_data, sum_multiplier_.data<float>(), 0,
-    scale_.mutable_data<float>(), &context_);
+    softmax_output_data, sum_multiplier_->data<float>(), 0,
+    scale_->mutable_data<float>(), &context_);

   // Do division
-  const float scale = *scale_.data<float>();
+  const float scale = *(scale_->data<float>());
   for (int j = 0; j < dim_out; ++j) {
     softmax_output_data[j] /= scale;
   }
@@ -94,10 +99,14 @@ bool HSoftmaxOp<float, CPUContext>::RunOnDevice() {
   float* int_output_data = intermediate_output->template mutable_data<float>();
   int int_output_offset = 0;

-  if (bias_multiplier_.numel() != M) {
-    bias_multiplier_.Resize(M);
+  if (!bias_multiplier_.has_value()) {
+    bias_multiplier_ = caffe2::empty({M}, at::dtype<float>().device(CPU));
     math::Set<float, CPUContext>(M, static_cast<float>(1),
-      bias_multiplier_.mutable_data<float>(), &context_);
+      bias_multiplier_->mutable_data<float>(), &context_);
+  } else if (bias_multiplier_->numel() != M) {
+    bias_multiplier_->Resize(M);
+    math::Set<float, CPUContext>(M, static_cast<float>(1),
+      bias_multiplier_->mutable_data<float>(), &context_);
   }

   for (int sample = 0; sample < M; ++sample) {
@@ -112,7 +121,7 @@ bool HSoftmaxOp<float, CPUContext>::RunOnDevice() {
       //Adding log probabilities
       Ydata[sample] += RunForwardSingle(X.data<float>() + sample*K,
         W.data<float>() + w_offset*K, b.data<float>() + w_offset, target,
-        int_output_data, bias_multiplier_.data<float>()+sample, w_length, K,
+        int_output_data, bias_multiplier_->data<float>()+sample, w_length, K,
         int_output_offset);
     }
   }
@@ -137,15 +146,19 @@ void HSoftmaxGradientOp<float, CPUContext>::RunBackwardSingle(const float* X,
   int_output_offset -= dim_out;

   //Softmax
-  if (scale_.numel() != 1) {
-    scale_.Resize(1);
+  if (!scale_.has_value()) {
+    scale_ = caffe2::empty({1}, at::dtype<float>().device(CPU));
   }
-  float* scaledata = scale_.mutable_data<float>();
+  float* scaledata = scale_->mutable_data<float>();

-  if (sum_multiplier_.numel() != dim_out) {
-    sum_multiplier_.Resize(dim_out);
+  if (!sum_multiplier_.has_value()) {
+    sum_multiplier_ = caffe2::empty({dim_out}, at::dtype<float>().device(CPU));
     math::Set<float, CPUContext>(dim_out, 1.f,
-      sum_multiplier_.mutable_data<float>(), &context_);
+      sum_multiplier_->mutable_data<float>(), &context_);
+  } else if (sum_multiplier_->numel() != dim_out) {
+    sum_multiplier_->Resize(dim_out);
+    math::Set<float, CPUContext>(dim_out, 1.f,
+      sum_multiplier_->mutable_data<float>(), &context_);
   }

   float* dX_softmax = dint_output + int_output_offset - dim_out;
@@ -154,19 +167,19 @@ void HSoftmaxGradientOp<float, CPUContext>::RunBackwardSingle(const float* X,
   math::Dot<float, CPUContext>(dim_out, X_entropy, dX_entropy, scaledata,
     &context_);
   math::Gemv<float, CPUContext>(CblasTrans, 1, dim_out, -1,
-    sum_multiplier_.data<float>(), scaledata , 1, dX_softmax, &context_);
+    sum_multiplier_->data<float>(), scaledata , 1, dX_softmax, &context_);
   math::Mul<float, CPUContext>(dim_out, dX_softmax, X_entropy, dX_softmax,
     &context_);

   int_output_offset -= dim_out;

   //FC
-  if (bias_multiplier_.numel() != 1) {
+  if (!bias_multiplier_.has_value()) {
     // If the helper bias multiplier has not been created, reshape and fill
     // it with 1
-    bias_multiplier_.Resize(1);
+    bias_multiplier_ = caffe2::empty({1}, at::dtype<float>().device(CPU));
     math::Set<float, CPUContext>(1, static_cast<float>(1),
-      bias_multiplier_.template mutable_data<float>(), &context_);
+      bias_multiplier_->template mutable_data<float>(), &context_);
   }

   // Compute dW and add incrementally
@@ -177,7 +190,7 @@ void HSoftmaxGradientOp<float, CPUContext>::RunBackwardSingle(const float* X,
   // Compute dB and add incrementally
   // db = db + dX_softmax*bias_multiplier_
   math::Gemv<float, CPUContext>(CblasTrans, 1, dim_out, 1, dX_softmax,
-    bias_multiplier_.template data<float>(), 1, db, &context_);
+    bias_multiplier_->template data<float>(), 1, db, &context_);

   // Compute dX and add incrementally
   // dX = dX + W'dX_softmax
@@ -265,7 +278,7 @@ bool HSoftmaxSearchOp<float, CPUContext>::pruning(
       b + w_offset,
       -1,
       int_output_data,
-      bias_multiplier_.template data<float>() + sample,
+      bias_multiplier_->template data<float>() + sample,
       w_length,
       K,
       int_output_offset);
@@ -351,13 +364,14 @@ bool HSoftmaxSearchOp<float, CPUContext>::RunOnDevice() {
   auto* Y_names = Output(0, {M, top_n_}, at::dtype<string>());
   auto* Y_scores = Output(1, {M, top_n_}, at::dtype<float>());

-  if (bias_multiplier_.numel() != M) {
-    bias_multiplier_.Resize(M);
-    math::Set<float, CPUContext>(
-        M,
-        static_cast<float>(1),
-        bias_multiplier_.mutable_data<float>(),
-        &context_);
+  if (!bias_multiplier_.has_value()) {
+    bias_multiplier_ = caffe2::empty({M}, at::dtype<float>().device(CPU));
+    math::Set<float, CPUContext>(M, static_cast<float>(1),
+      bias_multiplier_->mutable_data<float>(), &context_);
+  } else if (bias_multiplier_->numel() != M) {
+    bias_multiplier_->Resize(M);
+    math::Set<float, CPUContext>(M, static_cast<float>(1),
+      bias_multiplier_->mutable_data<float>(), &context_);
   }

   for (int sample = 0; sample < M; ++sample) {
@@ -1,6 +1,7 @@
 #ifndef CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
 #define CAFFE2_OPERATORS_H_SOFTMAX_OP_H_

+#include <c10/util/Optional.h>
 #include "caffe2/core/context.h"
 #include "caffe2/core/logging.h"
 #include "caffe2/core/operator.h"
@@ -25,9 +26,9 @@ class HSoftmaxOpBase : public Operator<Context> {

  protected:
   std::unordered_map<int, PathProto> hierarchy_all_map_;
-  Tensor scale_{Context::GetDeviceType()};
-  Tensor sum_multiplier_{Context::GetDeviceType()};
-  Tensor bias_multiplier_{Context::GetDeviceType()};
+  c10::optional<Tensor> scale_;
+  c10::optional<Tensor> sum_multiplier_;
+  c10::optional<Tensor> bias_multiplier_;
   static constexpr T kLOG_THRESHOLD() {
     return 1e-20f;
   }
@@ -39,10 +39,12 @@ bool InstanceNormGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {

   // Resize before we get into the per-instance loop
   if (InputSize() < 5) {
-    mean_.Resize(N, C);
+    ReinitializeTensor(
+        &mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
   }
   if (InputSize() < 6) {
-    inv_stdev_.Resize(N, C);
+    ReinitializeTensor(
+        &inv_stdev_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
   }

   // looping over per-instance and using Eigen blocks to extract out
@@ -174,7 +176,8 @@ bool InstanceNormGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {

   // Compute mean if it wasn't passed in
   if (InputSize() < 5) {
-    mean_.Resize(N, C);
+    ReinitializeTensor(
+        &mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
     EigenVectorArrayMap<T> mean_mutable_arr(
         mean_.template mutable_data<T>(), N * C);
     mean_mutable_arr = input_mat.colwise().mean();
@@ -189,7 +192,8 @@ bool InstanceNormGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {

   // compute 1 / stdev if not passed in
   if (InputSize() < 6) {
-    inv_stdev_.Resize(N, C);
+    ReinitializeTensor(
+        &inv_stdev_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
     EigenVectorArrayMap<T> inv_stdev_mutable_arr(
         inv_stdev_.template mutable_data<T>(), N * C);

@@ -378,7 +378,7 @@ bool InstanceNormGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
   const auto dim_stride = C;

   if (InputSize() < 5) {
-    mean_.Resize(N, C);
+    ReinitializeTensor(&mean_, {N, C}, at::dtype<float>().device(CUDA));
     auto mean_mutable_data = mean_.mutable_data<float>();
     InstanceNormMeanKernel<<<
         CAFFE_GET_BLOCKS(N * C),
@@ -401,7 +401,7 @@ bool InstanceNormGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
   const auto mean_data = mean.data<float>();

   if (InputSize() < 6) {
-    inv_stdev_.Resize(N, C);
+    ReinitializeTensor(&inv_stdev_, {N, C}, at::dtype<float>().device(CUDA));
     auto inv_stdev_mutable_data = inv_stdev_.mutable_data<float>();
     InstanceNormInvStdevKernel<<<
         CAFFE_GET_BLOCKS(N * C),
@@ -81,8 +81,8 @@ class InstanceNormGradientOp : public Operator<Context> {

   // temp results that could get passed through to this gradient, but if not,
   // are stored here
-  Tensor mean_{Context::GetDeviceType()};
-  Tensor inv_stdev_{Context::GetDeviceType()};
+  Tensor mean_;
+  Tensor inv_stdev_;

   INPUT_TAGS(INPUT, SCALE, BIAS, OUTPUT_GRAD, MEAN, INV_STDEV);
   OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
@@ -175,7 +175,7 @@ bool IntegralImageGradientOp<float, CUDAContext>::RunOnDevice() {
   // Col pass reduces shape to (N, C, H, W)
   vector<int64_t> row_pass_shape(dY.sizes().vec());
   row_pass_shape[3] -= 1;
-  row_pass_buffer_.Resize(row_pass_shape);
+  ReinitializeTensor(&row_pass_buffer_, row_pass_shape, at::dtype<float>().device(CUDA));
   const int chans = row_pass_buffer_.dim32(1);
   const int rows_out = row_pass_buffer_.dim32(2);
   const int cols_out = row_pass_buffer_.dim32(3);
@@ -28,7 +28,7 @@ class IntegralImageGradientOp final : public Operator<Context> {
   bool RunOnDevice() override;

  protected:
-  Tensor row_pass_buffer_{Context::GetDeviceType()};
+  Tensor row_pass_buffer_;
 };

 } // namespace caffe2
@@ -132,11 +132,11 @@ class LayerNormGradientOp final : public Operator<Context> {
     const int N = X.size_from_dim(canonical_axis);

     auto* dX = Output(0, X.sizes(), at::dtype<T>());
-    ds_.Resize(M);
-    db_.Resize(M);
-    dY_scale_.Resize(M);
-    X_scale_.Resize(M);
-    bias_.Resize(M);
+    ReinitializeTensor(&ds_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(&db_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(&dY_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(&X_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(&bias_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
     const T* dY_data = dY.template data<T>();
     const T* X_data = X.template data<T>();
     const T* mean_data = mean.template data<T>();
@@ -200,11 +200,11 @@ class LayerNormGradientOp final : public Operator<Context> {

   const int axis_;

-  Tensor ds_{Context::GetDeviceType()};
-  Tensor db_{Context::GetDeviceType()};
-  Tensor dY_scale_{Context::GetDeviceType()};
-  Tensor X_scale_{Context::GetDeviceType()};
-  Tensor bias_{Context::GetDeviceType()};
+  Tensor ds_;
+  Tensor db_;
+  Tensor dY_scale_;
+  Tensor X_scale_;
+  Tensor bias_;
 };

 } // namespace caffe2
@@ -22,7 +22,7 @@ template <>
 bool LengthsTileOp<CUDAContext>::RunOnDevice() {
   auto& data = Input(DATA);
   auto& lengths = Input(LENGTHS);


   CAFFE_ENFORCE_EQ(lengths.ndim(), 1, "LENGTHS must be 1-D");
   CAFFE_ENFORCE_GE(data.ndim(), 1, "DATA should be at least 1-D");
@@ -45,8 +45,8 @@ bool LengthsTileOp<CUDAContext>::RunOnDevice() {
   auto numElements = total_length * numElementsPerRow;
   auto numBlocks = CAFFE_GET_BLOCKS(numElements);

-  rowMappingHost_.Resize(total_length);
-  rowMappingDevice_.Resize(total_length);
+  ReinitializeTensor(&rowMappingHost_, {total_length}, at::dtype<int32_t>().device(CPU));
+  ReinitializeTensor(&rowMappingDevice_, {total_length}, at::dtype<int32_t>().device(CPU));
   auto* rowOffsets = rowMappingHost_.mutable_data<int32_t>();
   int32_t outputRow = 0;
   for (int64_t i = 0; i < lengths_size; i++) {
@@ -20,8 +20,8 @@ class LengthsTileOp : public Operator<Context> {

  private:
   Tensor lengths_host_{CPU};
-  Tensor rowMappingHost_{CPU};
-  Tensor rowMappingDevice_{Context::GetDeviceType()};
+  Tensor rowMappingHost_;
+  Tensor rowMappingDevice_;
 };

 } // namespace caffe2
@@ -52,7 +52,10 @@ void LambdaRankNdcgOp<float, CPUContext>::ResizeInvLogITensor(int size) {
     new_size <<= 1;
   }
   if (new_size != old_size) {
-    inv_log_i_.Resize(new_size);
+    ReinitializeTensor(
+        &inv_log_i_,
+        {new_size},
+        at::dtype<float>().device(CPU));
     auto* data = inv_log_i_.template mutable_data<float>();
     EigenVectorArrayMap<float> vec(data, inv_log_i_.numel());
     const float log2f_ = std::log(2.f);
@@ -64,7 +67,8 @@ void LambdaRankNdcgOp<float, CPUContext>::ResizeInvLogITensor(int size) {

 template <>
 void LambdaRankNdcgOp<float, CPUContext>::ComputeDiscounts(int* idx, int N) {
-  discount_.Resize(N);
+  ReinitializeTensor(
+      &discount_, {N}, at::dtype<float>().device(CPU));
   auto* discount_data = discount_.template mutable_data<float>();
   auto* inv_log_i_data = inv_log_i_.template mutable_data<float>();
   for (int i = 0; i < N; i++) {
@@ -94,8 +98,10 @@ float LambdaRankNdcgOp<float, CPUContext>::LambdaRankNdcgSession(
     return 0;
   }

-  ideal_idx_.Resize(N);
-  rank_idx_.Resize(N);
+  ReinitializeTensor(
+      &ideal_idx_, {N}, at::dtype<int>().device(CPU));
+  ReinitializeTensor(
+      &rank_idx_, {N}, at::dtype<int>().device(CPU));
   auto* rank_idx_data = rank_idx_.template mutable_data<int>();
   auto* ideal_idx_data = ideal_idx_.template mutable_data<int>();

@@ -114,7 +120,8 @@ float LambdaRankNdcgOp<float, CPUContext>::LambdaRankNdcgSession(
   }

   const double log2f_ = std::log(2.f);
-  gain_.Resize(N);
+  ReinitializeTensor(
+      &gain_, {N}, at::dtype<float>().device(CPU));
   auto* gain_data = gain_.template mutable_data<float>();
   EigenVectorArrayMap<float> gain_vec(gain_data, gain_.numel());

@@ -141,7 +148,8 @@ float LambdaRankNdcgOp<float, CPUContext>::LambdaRankNdcgSession(
   // similar to ideal but replace with actual discounts
   double dcg = (gain_vec * discount_vec).sum();

-  lambda_.Resize(N * N);
+  ReinitializeTensor(
+      &lambda_, {N * N}, at::dtype<float>().device(CPU));
   auto* lambda_data = lambda_.template mutable_data<float>();
   EigenArrayMap<float> lambda_mat(lambda_data, N, N);
   // computes lambda weight (i, j) = abs(gain_dff * discount_diff)
@@ -35,12 +35,12 @@ class LambdaRankNdcgOp final : public Operator<Context> {
       Tensor** dy);
   bool use_ndcg_as_loss_;
   bool use_exp_gain_;
-  Tensor gain_{Context::GetDeviceType()};
-  Tensor discount_{Context::GetDeviceType()};
-  Tensor rank_idx_{Context::GetDeviceType()};
-  Tensor ideal_idx_{Context::GetDeviceType()};
-  Tensor lambda_{Context::GetDeviceType()};
-  Tensor inv_log_i_{Context::GetDeviceType()};
+  Tensor gain_;
+  Tensor discount_;
+  Tensor rank_idx_;
+  Tensor ideal_idx_;
+  Tensor lambda_;
+  Tensor inv_log_i_;
 };

 template <typename T, class Context>
@@ -25,8 +25,14 @@ bool PercentileOp<CPUContext>::RunOnDevice() {
       num_values,
       "Sum of lengths should be equal to the total number of samples");

-  values_tensor.Resize(num_values);
-  percentiles_tensor.Resize(num_values);
+  ReinitializeTensor(
+      &values_tensor,
+      {num_values},
+      at::dtype<float>().device(CPU));
+  ReinitializeTensor(
+      &percentiles_tensor,
+      {num_values},
+      at::dtype<float>().device(CPU));
   float* values_tensor_data = values_tensor.template mutable_data<float>();
   float* percentiles_tensor_data =
       percentiles_tensor.template mutable_data<float>();
@@ -25,8 +25,8 @@ class PercentileOp final : public Operator<Context> {
  protected:
   INPUT_TAGS(X, VAL_PCT_PAIRS, LENS);
   OUTPUT_TAGS(PCT);
-  Tensor values_tensor{Context::GetDeviceType()};
-  Tensor percentiles_tensor{Context::GetDeviceType()};
+  Tensor values_tensor;
+  Tensor percentiles_tensor;
 };

 } // namespace caffe2
@@ -55,7 +55,7 @@ class Int8ConcatOp final : public Operator<CPUContext> {
       }
       Y_dims[axis_] += Xi.t.size(axis_);
     }
-    Y->t.Resize(Y_dims);
+    ReinitializeTensor(&Y->t, Y_dims, at::dtype<uint8_t>().device(CPU));
     int before = X0.t.size_to_dim(axis_);
     int after = X0.t.size_from_dim(axis_ + 1);
     const auto C_total = Y_dims[axis_];
@@ -43,7 +43,7 @@ class Int8FCOp final : public Operator<CPUContext> {
     CHECK_EQ(K, W.t.size(1));
     CHECK_EQ(N, B.t.numel());
     const auto M = X.t.numel() / K;
-    Y->t.Resize(M, N);
+    ReinitializeTensor(&Y->t, {M, N}, at::dtype<uint8_t>().device(CPU));

     runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
       initQNNPACK();
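For stateful helper tensors such as bias_multiplier_, scale_, and sum_multiplier_, the diff above takes a second route: the member becomes a c10::optional<Tensor>, allocated with caffe2::empty on first use and only Resize'd on later runs when the size changes. A minimal sketch of that lazy-initialization branch, assuming it sits inside a caffe2 CPU operator (namespace caffe2, context_ member available); EnsureBiasMultiplier is a hypothetical helper name used only for illustration:

// Sketch only: mirrors the has_value()/numel() pattern applied to
// bias_multiplier_ in the FullyConnected and HSoftmax operators above.
c10::optional<Tensor> bias_multiplier_;

void EnsureBiasMultiplier(int M) {
  if (!bias_multiplier_.has_value()) {
    // First use: allocate with an explicit dtype and device, then fill with 1.
    bias_multiplier_ = caffe2::empty({M}, at::dtype<float>().device(CPU));
    math::Set<float, CPUContext>(
        M, 1.0f, bias_multiplier_->mutable_data<float>(), &context_);
  } else if (bias_multiplier_->numel() != M) {
    // Already allocated but the batch size changed: reshape and refill.
    bias_multiplier_->Resize(M);
    math::Set<float, CPUContext>(
        M, 1.0f, bias_multiplier_->mutable_data<float>(), &context_);
  }
}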