mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-11 22:34:53 +08:00)
Summary: Should be non-semantic. Uses https://en.wikipedia.org/wiki/Wikipedia:Lists_of_common_misspellings/For_machines to find likely typos, with https://github.com/bwignall/typochecker to help automate the checking. Uses an updated version of the tool used in https://github.com/pytorch/pytorch/pull/30606.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/31523

Differential Revision: D19216749

Pulled By: mrshenli

fbshipit-source-id: 7fd489cb9a77cd7e4950c1046f925d57524960ea
141 lines
4.0 KiB
C++
#ifndef CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_
#define CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_

#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"

namespace caffe2 {

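// Common base for the cuDNN activation operators below: it creates and owns
// the shared cuDNN tensor and activation descriptors, and caches the last
// input size so the tensor descriptor is only rebuilt when the size changes.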
class CuDNNActivationOpBase : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  template <class... Args>
  explicit CuDNNActivationOpBase(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
    CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&act_desc_));
  }

  virtual ~CuDNNActivationOpBase() {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
    CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(act_desc_));
  }

 protected:
  void SetTensorDescriptor(
      const cudnnDataType_t data_type,
      const int data_size) {
    if (data_size != input_size_) {
      // Since the best performance is obtained when the tensor is HW-packed,
      // we put X.size() to W.
      input_size_ = data_size;
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          data_desc_,
          GetCudnnTensorFormat(StorageOrder::NCHW),
          data_type,
          1,
          1,
          1,
          input_size_));
    }
  }

  CuDNNWrapper cudnn_wrapper_;
  cudnnTensorDescriptor_t data_desc_;
  cudnnActivationDescriptor_t act_desc_;

  int input_size_ = 0;
};

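// Forward pass: applies the cuDNN activation selected by the
// kCuDNNActivationMode template parameter to the input X and writes the
// result to Y via cudnnActivationForward. Dispatches on float and at::Half
// inputs and treats the tensor as a flat 1x1x1xN blob (see
// SetTensorDescriptor above).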
template <cudnnActivationMode_t kCuDNNActivationMode>
class CuDNNActivationOp final : public CuDNNActivationOpBase {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  template <class... Args>
  explicit CuDNNActivationOp(Args&&... args)
      : CuDNNActivationOpBase(std::forward<Args>(args)...) {
    CUDNN_ENFORCE(cudnnSetActivationDescriptor(
        act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
  }

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& X = Input(0);

    auto* Y = Output(0, X.sizes(), at::dtype<T>());
    if (X.numel() == 0) {
      Y->template mutable_data<T>();
      return true;
    }
    this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, X.numel());
    CUDNN_ENFORCE(cudnnActivationForward(
        this->cudnn_wrapper_.inline_cudnn_handle(),
        this->act_desc_,
        cudnnTypeWrapper<T>::kOne(),
        this->data_desc_,
        X.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        this->data_desc_,
        Y->template mutable_data<T>()));
    return true;
  }
};

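// Backward pass for the activation selected by kCuDNNActivationMode: given
// the forward output Y and the output gradient dY, computes the input
// gradient dX via cudnnActivationBackward. The cuDNN API also takes the
// forward input X, but the gradient here is computed from Y, so Y is passed
// again in place of X (see the "placeholder" comment at the call site).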
template <cudnnActivationMode_t kCuDNNActivationMode>
class CuDNNActivationGradientOp final : public CuDNNActivationOpBase {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  template <class... Args>
  explicit CuDNNActivationGradientOp(Args&&... args)
      : CuDNNActivationOpBase(std::forward<Args>(args)...) {
    CUDNN_ENFORCE(cudnnSetActivationDescriptor(
        act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
  }

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& Y = Input(0);
    const auto& dY = Input(1);

    auto* dX = Output(0, Y.sizes(), at::dtype<T>());
    if (Y.numel() == 0) {
      dX->template mutable_data<T>();
      return true;
    }
    this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, Y.numel());
    CUDNN_ENFORCE(cudnnActivationBackward(
        this->cudnn_wrapper_.inline_cudnn_handle(),
        this->act_desc_,
        cudnnTypeWrapper<T>::kOne(),
        this->data_desc_,
        Y.template data<T>(),
        this->data_desc_,
        dY.template data<T>(),
        this->data_desc_,
        Y.template data<T>(), // Use Y_data as placeholder here.
        cudnnTypeWrapper<T>::kZero(),
        this->data_desc_,
        dX->template mutable_data<T>()));
    return true;
  }
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_
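// Illustrative usage (a minimal sketch, not part of this header): concrete
// activation operators are created by instantiating the templates above with
// a cuDNN activation mode and registering them from a .cc file. The snippet
// below shows how a ReLU operator could be wired up; the operator names and
// the REGISTER_CUDNN_OPERATOR / CUDNN_ACTIVATION_RELU identifiers are assumed
// to be the standard Caffe2 and cuDNN ones.
//
//   #include "caffe2/operators/activation_ops_cudnn.h"
//
//   namespace caffe2 {
//
//   // Forward: Y = ReLU(X), backed by cudnnActivationForward.
//   REGISTER_CUDNN_OPERATOR(Relu, CuDNNActivationOp<CUDNN_ACTIVATION_RELU>);
//
//   // Gradient: dX from (Y, dY), backed by cudnnActivationBackward.
//   REGISTER_CUDNN_OPERATOR(
//       ReluGradient,
//       CuDNNActivationGradientOp<CUDNN_ACTIVATION_RELU>);
//
//   } // namespace caffe2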