mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: In TorchScript and C++ extensions we currently advocate a mix of `torch::` and `at::` namespace usage. In the C++ frontend I had instead exported all symbols from `at::` and some from `c10::` into the `torch::` namespace. This is far, far easier for users to understand, and also avoids bugs around creating tensors vs. variables. The same should from now on be true for the TorchScript C++ API (for running and loading models) and all C++ extensions. Note that since we're just talking about typedefs, this change does not break any existing code. Once this lands I will update stuff in `pytorch/tutorials` too. zdevito ezyang gchanan Pull Request resolved: https://github.com/pytorch/pytorch/pull/13523 Differential Revision: D12942787 Pulled By: goldsborough fbshipit-source-id: 76058936bd8707b33d9e5bbc2d0705fc3d820763
38 lines
1000 B
C++
38 lines
1000 B
C++
#include <torch/nn/modules/dropout.h>
|
|
|
|
#include <torch/types.h>
|
|
|
|
#include <c10/util/Exception.h>
|
|
|
|
#include <cstddef>
|
|
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace nn {
|
|
namespace detail {
|
|
template <typename Derived>
|
|
DropoutImplBase<Derived>::DropoutImplBase(DropoutOptions options_)
|
|
: options(options_) {
|
|
AT_CHECK(options.rate_ >= 0, "Dropout rate must not be less than zero");
|
|
AT_CHECK(options.rate_ <= 1, "Dropout rate must not be greater than one");
|
|
}
|
|
|
|
// Intentionally a no-op: the dropout modules defined here keep no parameters
// or buffers, so there is no state to re-initialize.
template <typename Derived>
void DropoutImplBase<Derived>::reset() {}
|
|
|
|
// Explicit instantiations for the two concrete dropout modules, so the
// template member definitions above can live in this .cpp file instead of
// the header.
template class DropoutImplBase<DropoutImpl>;
template class DropoutImplBase<FeatureDropoutImpl>;
|
|
} // namespace detail
|
|
|
|
// Options holder for dropout modules; stores the zeroing probability.
DropoutOptions::DropoutOptions(double rate) {
  rate_ = rate;
}
|
|
|
|
// Applies element-wise dropout via torch::dropout, using the configured
// rate. Whether dropout is actually active follows the module's current
// training/eval mode.
Tensor DropoutImpl::forward(Tensor input) {
  const auto rate = options.rate_;
  const auto training = this->is_training();
  return torch::dropout(input, rate, training);
}
|
|
|
|
// Same contract as DropoutImpl::forward, but delegates to
// torch::feature_dropout (which operates on whole feature channels rather
// than individual elements).
Tensor FeatureDropoutImpl::forward(Tensor input) {
  const auto rate = options.rate_;
  const auto training = this->is_training();
  return torch::feature_dropout(input, rate, training);
}
|
|
} // namespace nn
|
|
} // namespace torch
|