Revert "Eliminate c10::guts::to_string (#108480)"

This reverts commit 4146be192ead477360a2763c5005e46a9485c3bf.

Reverted https://github.com/pytorch/pytorch/pull/108480 on behalf of https://github.com/huydhn due to Sorry for reverting this, but this is needed to keep trunk green after https://github.com/pytorch/pytorch/pull/108479 was reverted.  Both will need to be relanded ([comment](https://github.com/pytorch/pytorch/pull/108480#issuecomment-1707067595))
This commit is contained in:
PyTorch MergeBot
2023-09-05 18:04:53 +00:00
parent 5b31a41841
commit 8da04e023e
44 changed files with 225 additions and 164 deletions

View File

@ -63,7 +63,7 @@ void SGD::add_param_group(const SGDParamGroup& param_group) {
}
for (const auto& p : param_group_.params()) {
TORCH_CHECK(
state_.count(p.unsafeGetTensorImpl()) == 0,
state_.count(c10::guts::to_string(p.unsafeGetTensorImpl())) == 0,
"some parameters appear in more than one parameter group");
}
param_groups_.emplace_back(std::move(param_group_));
@ -104,12 +104,14 @@ Tensor SGD::step(const LossClosure& closure) {
}
if (momentum != 0) {
Tensor buf;
auto param_state = state_.find(p.unsafeGetTensorImpl());
auto param_state =
state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
if (param_state == state_.end()) {
buf = torch::clone(d_p).detach();
auto state = std::make_unique<SGDParamState>();
state->momentum_buffer(buf);
state_[p.unsafeGetTensorImpl()] = std::move(state);
state_[c10::guts::to_string(p.unsafeGetTensorImpl())] =
std::move(state);
} else {
buf = static_cast<SGDParamState&>(*param_state->second)
.momentum_buffer();