#include #include #include #include #include namespace torch { namespace nn { namespace detail { template DropoutImplBase::DropoutImplBase(DropoutOptions options_) : options(options_) { AT_CHECK(options.rate_ >= 0, "Dropout rate must not be less than zero"); AT_CHECK(options.rate_ <= 1, "Dropout rate must not be greater than one"); } template void DropoutImplBase::reset() {} template class DropoutImplBase; template class DropoutImplBase; } // namespace detail DropoutOptions::DropoutOptions(double rate) : rate_(rate) {} Tensor DropoutImpl::forward(Tensor input) { return torch::dropout(input, options.rate_, this->is_training()); } Tensor FeatureDropoutImpl::forward(Tensor input) { return torch::feature_dropout(input, options.rate_, this->is_training()); } } // namespace nn } // namespace torch