pytorch/torch/csrc/autograd/variable_info.h
Joel Schlosser 19918a1863 Fix autograd.Function + NJT when an output grad is None (#136875)
For `autograd.Function`, the engine will try to allocate correctly-shaped zeros for `None` grads (i.e. in the case where the output isn't used downstream). It determines the shape of these zeros from the `VariableInfo` entry, which is derived from the forward output shape. When the forward output is an NJT (nested jagged tensor), the stored size contains a nested int for the ragged dimension, and calling `zeros()` with this size throws:
```
RuntimeError: .../build/aten/src/ATen/RegisterCPU.cpp:5260: SymIntArrayRef expected to contain only concrete integers
```
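To make the failure concrete, here is a minimal, self-contained C++ model of the check that fires. It does not use libtorch: `SymDim`, `NestedInt`, and `expect_concrete` are hypothetical stand-ins for `c10::SymInt` and the concrete-integer check on the dense `zeros()` path, and only the error text is copied from the message above. A ragged dimension has no single integer length, so a dense factory has nothing to allocate for it.

```cpp
#include <cstdint>
#include <stdexcept>
#include <variant>
#include <vector>

// Hypothetical model of a symbolic size entry: either a concrete extent or a
// "nested int" standing for a ragged dimension whose length varies per row.
struct NestedInt {
  int id;
};
using SymDim = std::variant<std::int64_t, NestedInt>;

// Mirrors the check behind the error above: a dense factory such as zeros()
// needs every dimension to be a concrete integer, so a nested int is rejected.
std::vector<std::int64_t> expect_concrete(const std::vector<SymDim>& sizes) {
  std::vector<std::int64_t> out;
  out.reserve(sizes.size());
  for (const auto& d : sizes) {
    if (const auto* v = std::get_if<std::int64_t>(&d)) {
      out.push_back(*v);
    } else {
      throw std::runtime_error(
          "SymIntArrayRef expected to contain only concrete integers");
    }
  }
  return out;
}
```

A dense shape such as `{2, 5}` passes through unchanged, while an NJT-like shape `{3, NestedInt{1}, 8}` throws, which is the situation the engine hits when it builds zeros from the stored size.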

This PR fixes this by storing the full tensor in the `VariableInfo` for the nested case and calling `zeros_like()` to allocate correctly-shaped zeros. This is inefficient, since the whole output tensor is kept around just to recover its shape; ideally we would save only the NJT shape and construct zeros from it, but this requires factory function support for nested ints (WIP). So this is a short-term fix until we have that.
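The fix can be sketched with another libtorch-free model, assuming a "tensor" is just a list of rows that may differ in length (standing in for an NJT with one jagged dimension). `MockTensor` and `MockVariableInfo` are hypothetical names; the constructor flag and the optional stored tensor mirror `use_zeros_like` and `the_var` in the header, but this is an illustration of the idea, not the real implementation.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a tensor: rows whose lengths may differ.
// Ragged rows model a nested jagged tensor with one jagged dimension.
using MockTensor = std::vector<std::vector<float>>;

// Hypothetical analogue of VariableInfo, reduced to what zeros() needs.
struct MockVariableInfo {
  std::vector<std::int64_t> size;     // dense shape [rows, cols]; unused when the_var is set
  std::optional<MockTensor> the_var;  // set only when use_zeros_like is true

  explicit MockVariableInfo(const MockTensor& var, bool use_zeros_like = false) {
    if (use_zeros_like) {
      // Keep the whole tensor: its shape cannot be stored as plain integers.
      the_var = var;
    } else {
      size = {static_cast<std::int64_t>(var.size()),
              var.empty() ? 0 : static_cast<std::int64_t>(var.front().size())};
    }
  }

  // Analogue of zeros_like(): copy the (possibly ragged) structure, zero the values.
  static MockTensor zeros_like(const MockTensor& t) {
    MockTensor out;
    out.reserve(t.size());
    for (const auto& row : t) out.emplace_back(row.size(), 0.0f);
    return out;
  }

  MockTensor zeros() const {
    if (the_var) return zeros_like(*the_var);  // nested path
    // Dense path: every dimension is a concrete integer.
    return MockTensor(static_cast<std::size_t>(size[0]),
                      std::vector<float>(static_cast<std::size_t>(size[1]), 0.0f));
  }
};
```

The trade-off from the commit message shows up directly: the dense path stores two integers, while the jagged path holds on to the whole tensor (this mock copies it; a real `Variable` would instead share storage by reference count).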
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136875
Approved by: https://github.com/soulitzer, https://github.com/huydhn
2024-10-14 19:31:50 +00:00

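#pragma once

#include <torch/csrc/autograd/variable.h>

namespace torch::autograd {

// Metadata recorded from a forward output so that a correctly-shaped zero
// gradient can be allocated later if that output's grad is None.
struct TORCH_API VariableInfo {
  explicit VariableInfo();
  // With use_zeros_like = true, the output itself is kept in the_var so that
  // zeros() can allocate via zeros_like() instead of from `size`.
  explicit VariableInfo(const Variable& var, bool use_zeros_like = false);

  // Allocates zeros matching the recorded metadata.
  Variable zeros(at::OptionalDeviceGuard& device_guard) const;

  at::Layout layout = at::Layout::Strided;
  at::Device device = at::kCPU;
  at::ScalarType scalar_type = at::kFloat;
  std::vector<c10::SymInt> size;
  bool requires_grad;
  bool is_empty;
  // needed for e.g. NJTs since they only support zeros_like()
  std::optional<Variable> the_var;
};

} // namespace torch::autograd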
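On the engine side, the behavior described in the commit message amounts to: for each output whose grad is None, substitute the `zeros()` of that output's `VariableInfo`. The generic helper below sketches this under simplifying assumptions: `materialize_grads` is a hypothetical name, a missing grad is modeled as `std::nullopt`, and the `Info` type's `zeros()` takes no `OptionalDeviceGuard` argument, unlike the real method declared above.

```cpp
#include <cstddef>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical sketch: Grad is any value type, and Info is anything with a
// zeros() member returning a Grad (a simplified VariableInfo::zeros()).
template <typename Grad, typename Info>
std::vector<Grad> materialize_grads(const std::vector<std::optional<Grad>>& grads,
                                    const std::vector<Info>& infos) {
  if (grads.size() != infos.size()) {
    throw std::invalid_argument("expected one info entry per output grad");
  }
  std::vector<Grad> out;
  out.reserve(grads.size());
  for (std::size_t i = 0; i < grads.size(); ++i) {
    // A missing (None) grad is replaced by zeros shaped like forward output i;
    // a present grad is passed through unchanged.
    out.push_back(grads[i] ? *grads[i] : infos[i].zeros());
  }
  return out;
}
```

Keeping the shape logic inside `Info::zeros()` is what lets the fix stay local: the loop does not care whether zeros come from a stored size or from `zeros_like()` on a stored tensor.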