mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "Eliminate c10::guts::to_string (#108480)"
This reverts commit 4146be192ead477360a2763c5005e46a9485c3bf. Reverted https://github.com/pytorch/pytorch/pull/108480 on behalf of https://github.com/huydhn due to Sorry for reverting this, but this is needed to keep trunk green after https://github.com/pytorch/pytorch/pull/108479 was reverted. Both will need to be relanded ([comment](https://github.com/pytorch/pytorch/pull/108480#issuecomment-1707067595))
This commit is contained in:
@ -63,7 +63,7 @@ void SGD::add_param_group(const SGDParamGroup& param_group) {
|
||||
}
|
||||
for (const auto& p : param_group_.params()) {
|
||||
TORCH_CHECK(
|
||||
state_.count(p.unsafeGetTensorImpl()) == 0,
|
||||
state_.count(c10::guts::to_string(p.unsafeGetTensorImpl())) == 0,
|
||||
"some parameters appear in more than one parameter group");
|
||||
}
|
||||
param_groups_.emplace_back(std::move(param_group_));
|
||||
@ -104,12 +104,14 @@ Tensor SGD::step(const LossClosure& closure) {
|
||||
}
|
||||
if (momentum != 0) {
|
||||
Tensor buf;
|
||||
auto param_state = state_.find(p.unsafeGetTensorImpl());
|
||||
auto param_state =
|
||||
state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
|
||||
if (param_state == state_.end()) {
|
||||
buf = torch::clone(d_p).detach();
|
||||
auto state = std::make_unique<SGDParamState>();
|
||||
state->momentum_buffer(buf);
|
||||
state_[p.unsafeGetTensorImpl()] = std::move(state);
|
||||
state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
|
||||
std::move(state);
|
||||
} else {
|
||||
buf = static_cast<SGDParamState&>(*param_state->second)
|
||||
.momentum_buffer();
|
||||
|
Reference in New Issue
Block a user