Move most methods off Variable into torch::autograd::impl functions. (#29665)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29665

Our intention is to merge the static distinction between Tensor and Variable. Ordinarily, this would entail merging the methods of Tensor and Variable. But there are a lot of "private"-ish methods on Variable that we don't actually want to dump onto the Tensor class. So, as prep work, we move all of those methods off of Variable and into the torch::autograd::impl namespace (impl as in, please don't use this, end users). This ends up being a fairly large patch because all of the call sites have to play ball too.

While I was on the topic, I also moved any of the touched functions into the C++ file, so that modifying them would not trigger a recompilation of all of torch.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D18496169

Pulled By: ezyang

fbshipit-source-id: afb203252620ec274be596b3e7b1d84d321bad3a
commit 1ab2f043ba
parent 38340f59fd
committed by Facebook Github Bot
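To make the shape of the change concrete before the diff below, here is a minimal, self-contained sketch of the pattern this patch applies: a method that used to live on Variable becomes a free function in an impl namespace, and its definition can move out of the header. The toy::Variable, toy::Edge, and toy::Node names are illustrative stand-ins, not the real torch::autograd declarations.

// Minimal sketch of the refactor pattern (illustrative names, not the real
// torch::autograd declarations).
#include <cstdint>
#include <memory>

namespace toy {

struct Node;  // stand-in for an autograd graph node

// Stand-in for an autograd Edge: a (function, input index) pair.
struct Edge {
  std::shared_ptr<Node> function;
  uint32_t input_nr = 0;
};

// Before the patch, something like `Edge Variable::gradient_edge() const`
// lived on the class itself. After the patch, Variable only carries state...
struct Variable {
  std::shared_ptr<Node> grad_fn;
  uint32_t output_nr = 0;
};

// ...and the operation becomes a free function in an `impl` namespace. In the
// real patch the definition also moves into the .cpp file, so editing it does
// not force a rebuild of everything that includes the header.
namespace impl {
inline Edge gradient_edge(const Variable& self) {
  return Edge{self.grad_fn, self.output_nr};
}
}  // namespace impl

}  // namespace toy

int main() {
  toy::Variable v;
  // Call sites change from `v.gradient_edge()` to the qualified free function.
  toy::Edge e = toy::impl::gradient_edge(v);
  return e.function ? 1 : 0;
}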
@@ -142,7 +142,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
     THPUtils_assert(THPVariable_Check(_tensor), "element %d of tensors "
         "tuple is not a Tensor", i);
     auto& variable = ((THPVariable*)_tensor)->cdata;
-    auto gradient_edge = variable.gradient_edge();
+    auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
     THPUtils_assert(gradient_edge.function,
         "element %d of tensors does not require grad and does not have a grad_fn", i);
     roots.push_back(std::move(gradient_edge));
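Outside the Python binding layer, the same call-site idiom can be written as ordinary C++ against libtorch. The sketch below is illustrative only: collect_roots is a hypothetical helper name, the include paths assume the torch/csrc autograd headers are available, and the error message is paraphrased from the hunk above.

// Illustrative sketch, assuming libtorch autograd headers are on the include
// path; collect_roots is a hypothetical helper, not part of PyTorch.
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/variable.h>
#include <c10/util/Exception.h>
#include <utility>
#include <vector>

using torch::autograd::Edge;
using torch::autograd::Variable;

std::vector<Edge> collect_roots(const std::vector<Variable>& tensors) {
  std::vector<Edge> roots;
  roots.reserve(tensors.size());
  for (const auto& variable : tensors) {
    // Formerly spelled `variable.gradient_edge()`; after this patch the
    // helper is a free function in torch::autograd::impl.
    auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
    TORCH_CHECK(gradient_edge.function,
        "tensor does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));
  }
  return roots;
}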
@@ -180,7 +180,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
     const auto output_nr = input_var->cdata.output_nr();
     auto grad_fn = input_var->cdata.grad_fn();
     if (!grad_fn) {
-      grad_fn = input_var->cdata.try_get_grad_accumulator();
+      grad_fn = torch::autograd::impl::try_get_grad_accumulator(input_var->cdata);
     }
     THPUtils_assert(input_var->cdata.requires_grad(),
         "One of the differentiated Tensors does not require grad");
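Analogously, the input-side logic in the second hunk can be expressed as a standalone helper. Again a hedged sketch: resolve_input_fn is a hypothetical name, and the comments describe the behavior of try_get_grad_accumulator only approximately (for a leaf it returns an already-created gradient accumulator node, or nullptr if none exists yet).

// Illustrative sketch against libtorch headers; resolve_input_fn is a
// hypothetical helper mirroring the hunk above, not engine code.
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <c10/util/Exception.h>
#include <memory>

using torch::autograd::Node;
using torch::autograd::Variable;

std::shared_ptr<Node> resolve_input_fn(const Variable& input) {
  // Non-leaf tensors report the Node that produced them.
  std::shared_ptr<Node> grad_fn = input.grad_fn();
  if (!grad_fn) {
    // Leaf tensors fall back to their gradient accumulator, if one has
    // already been created. Formerly `input.try_get_grad_accumulator()`.
    grad_fn = torch::autograd::impl::try_get_grad_accumulator(input);
  }
  TORCH_CHECK(input.requires_grad(),
      "One of the differentiated Tensors does not require grad");
  return grad_fn;
}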