[PyTorch Mobile] Move train-related files to their own folder (#58205)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58205. It's worth moving train-related files into their own folder, since we are adding more code under the mobile directory. This diff does that.
Test Plan: run unit tests and CI.
Reviewed By: iseeyuan
Differential Revision: D28402432
fbshipit-source-id: cd76a1c4f8ff06508cdc3aad8a169fbf34bb4995
Committed by: Facebook GitHub Bot
Parent: 49a8942a77
Commit: 73d51406fa
torch/csrc/jit/mobile/train/optim/sgd.cpp (new file, 133 lines)
@@ -0,0 +1,133 @@
#include <torch/csrc/jit/mobile/train/optim/sgd.h>

#include <torch/types.h>
#include <torch/utils.h>

#include <ATen/ATen.h>

#include <functional>

namespace torch {
namespace jit {
namespace mobile {

bool SGDParamGroup::has_options() const {
  return options_ != nullptr;
}

SGDOptions& SGDParamGroup::options() {
  TORCH_CHECK(has_options());
  return *options_.get();
}

const SGDOptions& SGDParamGroup::options() const {
  TORCH_CHECK(has_options());
  return *options_.get();
}

void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
  options_ = std::move(options);
}

std::vector<Tensor>& SGDParamGroup::params() {
  return params_;
}

const std::vector<Tensor>& SGDParamGroup::params() const {
  return params_;
}

SGDOptions::SGDOptions(double lr) : lr_(lr) {}

bool operator==(const SGDOptions& lhs, const SGDOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.momentum() == rhs.momentum()) &&
      (lhs.dampening() == rhs.dampening()) &&
      (lhs.weight_decay() == rhs.weight_decay()) &&
      (lhs.nesterov() == rhs.nesterov());
}

bool operator==(const SGDParamState& lhs, const SGDParamState& rhs) {
  return torch::equal(lhs.momentum_buffer(), rhs.momentum_buffer());
}

// Validate a new parameter group, fill in default options if none were
// given, and reject parameters that already belong to another group.
void SGD::add_param_group(const SGDParamGroup& param_group) {
  for (const auto& param : param_group.params()) {
    TORCH_CHECK(param.is_leaf(), "can't optimize a non-leaf Tensor");
  }
  TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
  SGDParamGroup param_group_(param_group.params());
  if (!param_group.has_options()) {
    param_group_.set_options(defaults_->clone());
  } else {
    param_group_.set_options(param_group.options().clone());
  }
  for (const auto& p : param_group_.params()) {
    TORCH_CHECK(
        state_.count(c10::guts::to_string(p.unsafeGetTensorImpl())) == 0,
        "some parameters appear in more than one parameter group");
  }
  param_groups_.emplace_back(std::move(param_group_));
}

// Detach and zero the gradients of every parameter in every group.
void SGD::zero_grad() {
  for (auto& group : param_groups_) {
    for (auto& p : group.params()) {
      if (p.grad().defined()) {
        p.grad().detach_();
        p.grad().zero_();
      }
    }
  }
}

// Perform one optimization step. If a closure is supplied, re-evaluate the
// loss with gradients enabled before applying the update.
Tensor SGD::step(const LossClosure& closure) {
  NoGradGuard no_grad;
  Tensor loss = {};
  if (closure != nullptr) {
    at::AutoGradMode enable_grad(true);
    loss = closure();
  }
  for (auto& group : param_groups_) {
    auto& options = static_cast<SGDOptions&>(group.options());
    auto weight_decay = options.weight_decay();
    auto momentum = options.momentum();
    auto dampening = options.dampening();
    auto nesterov = options.nesterov();

    for (auto& p : group.params()) {
      if (!p.grad().defined()) {
        continue;
      }
      auto d_p = p.grad().data();
      if (weight_decay != 0) {
        // d_p += weight_decay * p
        d_p = d_p.add(p.data(), weight_decay);
      }
      if (momentum != 0) {
        Tensor buf;
        auto param_state =
            state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
        if (param_state == state_.end()) {
          // First step for this parameter: initialize the momentum buffer
          // to a detached copy of the (weight-decayed) gradient.
          buf = torch::clone(d_p).detach();
          auto state = std::make_unique<SGDParamState>();
          state->momentum_buffer(buf);
          state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
              std::move(state);
        } else {
          // buf = momentum * buf + (1 - dampening) * d_p
          buf = static_cast<SGDParamState&>(*param_state->second)
                    .momentum_buffer();
          buf.mul_(momentum).add_(d_p, 1 - dampening);
        }
        if (nesterov) {
          // Nesterov: step along the gradient plus the momentum buffer.
          d_p = d_p.add(buf, momentum);
        } else {
          d_p = buf;
        }
      }
      // p -= lr * d_p
      p.data().add_(d_p, -1 * options.lr());
    }
  }
  return loss;
}
} // namespace mobile
} // namespace jit
} // namespace torch
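For reference, step() applies the classic SGD update: d_p = grad + weight_decay * p; with momentum, buf <- momentum * buf + (1 - dampening) * d_p (the buffer starts as a detached clone of d_p on first use); under Nesterov, d_p <- d_p + momentum * buf, otherwise d_p <- buf; finally p <- p - lr * d_p. Below is a minimal usage sketch, not part of this commit. It assumes the SGD constructor declared in sgd.h takes a parameter vector plus SGDOptions, that SGDOptions exposes chained setters such as .momentum() (mirroring the full-size torch::optim::SGD API), and that LossClosure is std::function<Tensor()>.

// Minimal usage sketch (assumptions noted above; not part of this diff).
#include <torch/csrc/jit/mobile/train/optim/sgd.h>
#include <torch/torch.h>

int main() {
  // Hypothetical trainable parameters; in practice these would come from
  // a loaded mobile module.
  std::vector<torch::Tensor> params = {
      torch::randn({2, 2}, torch::requires_grad())};

  // Assumed constructor: parameter vector + SGDOptions with chained setters.
  torch::jit::mobile::SGD optimizer(
      params, torch::jit::mobile::SGDOptions(/*lr=*/0.1).momentum(0.9));

  // step() runs the closure with gradients enabled, then applies the
  // update above to every parameter group and returns the loss.
  auto loss = optimizer.step([&] {
    optimizer.zero_grad();
    auto out = (params[0] * params[0]).sum();
    out.backward();
    return out;
  });
}

Passing the forward/backward work as a closure keeps the rest of step() under the optimizer's NoGradGuard while at::AutoGradMode re-enables gradients just for the loss evaluation; calling step with a null closure skips the re-evaluation and simply applies whatever gradients are already buffered.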