mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-30 19:54:53 +08:00
retain undefined tensors in backward pass (#41490)
Summary: Leave undefined tensors / None returned from custom backward functions as undefined/None instead of creating a tensor full of zeros. This change improves performance in some cases. **This is BC-Breaking:** Custom backward functions that return None will now see it potentially being propagated all the way up to AccumulateGrad nodes. Potential impact is that .grad field of leaf tensors as well as the result of autograd.grad may be undefined/None where it used to be a tensor full of zeros. Also, autograd.grad may raise an error, if so, consider using allow_unused=True ([see doc](https://pytorch.org/docs/stable/autograd.html?highlight=autograd%20grad#torch.autograd.grad)) if it applies to your case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/41490 Reviewed By: albanD Differential Revision: D22578241 Pulled By: heitorschueroff fbshipit-source-id: f4966f4cb520069294f8c5c1691eeea799cc0abe
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a874c1e584
commit
cf811d2fb3
@ -171,12 +171,7 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
|
||||
continue;
|
||||
}
|
||||
if (output == Py_None) {
|
||||
auto& info = input_info[results.size()];
|
||||
if (info.requires_grad) {
|
||||
results.emplace_back(info.zeros(_device_guard));
|
||||
} else {
|
||||
results.emplace_back();
|
||||
}
|
||||
results.emplace_back();
|
||||
} else {
|
||||
if (!THPVariable_Check(output)) {
|
||||
std::string msg("expected Variable or None (got ");
|
||||
@ -779,8 +774,6 @@ PyObject * THPFunction_do_backward(THPFunction *self, PyObject *args)
|
||||
"gradient tensors (expected %d, but got %d)", THPUtils_typename(self),
|
||||
num_outputs, num_grads);
|
||||
|
||||
// If any of the remaining grad_inputs are None, zero them.
|
||||
_prepare_grads(self, grad_input, false);
|
||||
return grad_input.release();
|
||||
|
||||
} catch (python_error& e) {
|
||||
|
||||
Reference in New Issue
Block a user