Closes #120988
Currently operators that hit the autograd fallback call `check_inplace`
on all mutated inputs, including out arguments. This leads to a slightly
confusing error message:
```
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```
Compared to functions that don't fallback, which raise
```
RuntimeError: add(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.
```
This changes the error message to make clear the issue is with the out argument,
but does not tighten the check to outright ban out arguments that require grad.
Instead, I use the same checks from `check_inplace` which allows non-leaf tensors
that require grad to pass without error.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121089
Approved by: https://github.com/lezcano, https://github.com/soulitzer
ghstack dependencies: #121142
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.
```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
virtual ~ViewFunc() {}
/// Returns any SymInts in the saved state.
virtual std::vector<c10::SymInt> get_symints() const { return {}; }
/// Returns the number of SymInts in the saved state.
virtual size_t num_symints() const { return 0; }
/// Returns any tensors in the saved state.
virtual std::vector<at::Tensor> get_tensors() const { return {}; }
/// Returns the number of tensors in the saved state.
virtual size_t num_tensors() const { return 0; }
/// Reapplies the view on the given base using the saved state.
virtual at::Tensor operator()(const at::Tensor&) const = 0;
/// Returns a clone of this ViewFunc, optionally with the specified saved state.
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;
protected:
/// Sets the values of any SymInts in the saved state. The input vector size must
/// match the number of SymInts in the saved state (i.e. the size of the list
/// returned by get_symints()).
virtual void set_symints(std::vector<c10::SymInt>) {}
/// Sets the values of any Tensors in the saved state. The input vector size must
/// match the number of Tensors in the saved state (i.e. the size of the list
/// returned by get_tensors()).
virtual void set_tensors(std::vector<at::Tensor>) {}
};
```
New codegen files:
* `torch/csrc/autograd/generated/ViewFunc.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`
The templates for these also contains impls for `ChainedViewFunc` and `ErroringViewFunc` which are used in a few places within autograd.
Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
{};
virtual ~SliceTensorViewFunc() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override;
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override;
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;
protected:
virtual void set_symints(std::vector<c10::SymInt>) override;
virtual void set_tensors(std::vector<at::Tensor>) override;
private:
int64_t dim;
c10::optional<c10::SymInt> start;
c10::optional<c10::SymInt> end;
c10::SymInt step;
};
...
// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
::std::vector<c10::SymInt> symints;
symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
if(start.has_value()) symints.insert(symints.end(), *(start));
if(end.has_value()) symints.insert(symints.end(), *(end));
symints.push_back(step);
return symints;
}
size_t SliceTensorViewFunc::num_symints() const {
return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}
void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
auto i = 0;
if(start.has_value()) start = symints[i];
i += (start.has_value() ? 1 : 0);
if(end.has_value()) end = symints[i];
i += (end.has_value() ? 1 : 0);
step = symints[i];
}
std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
::std::vector<at::Tensor> tensors;
return tensors;
}
size_t SliceTensorViewFunc::num_tensors() const {
return static_cast<size_t>(0);
}
void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());
}
at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}
std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
std::optional<std::vector<c10::SymInt>> symints,
std::optional<std::vector<at::Tensor>> tensors) const {
auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
if (symints.has_value()) {
output->set_symints(std::move(*(symints)));
}
if (tensors.has_value()) {
output->set_tensors(std::move(*(tensors)));
}
return output;
}
```
The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.
For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.
```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
virtual ~ViewFunc() {}
/// Returns any SymInts in the saved state.
virtual std::vector<c10::SymInt> get_symints() const { return {}; }
/// Returns the number of SymInts in the saved state.
virtual size_t num_symints() const { return 0; }
/// Returns any tensors in the saved state.
virtual std::vector<at::Tensor> get_tensors() const { return {}; }
/// Returns the number of tensors in the saved state.
virtual size_t num_tensors() const { return 0; }
/// Reapplies the view on the given base using the saved state.
virtual at::Tensor operator()(const at::Tensor&) const = 0;
/// Returns a clone of this ViewFunc, optionally with the specified saved state.
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;
protected:
/// Sets the values of any SymInts in the saved state. The input vector size must
/// match the number of SymInts in the saved state (i.e. the size of the list
/// returned by get_symints()).
virtual void set_symints(std::vector<c10::SymInt>) {}
/// Sets the values of any Tensors in the saved state. The input vector size must
/// match the number of Tensors in the saved state (i.e. the size of the list
/// returned by get_tensors()).
virtual void set_tensors(std::vector<at::Tensor>) {}
};
```
New codegen files:
* `torch/csrc/autograd/generated/ViewFunc.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`
The templates for these also contains impls for `ChainedViewFunc` and `ErroringViewFunc` which are used in a few places within autograd.
Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
{};
virtual ~SliceTensorViewFunc() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override;
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override;
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;
protected:
virtual void set_symints(std::vector<c10::SymInt>) override;
virtual void set_tensors(std::vector<at::Tensor>) override;
private:
int64_t dim;
c10::optional<c10::SymInt> start;
c10::optional<c10::SymInt> end;
c10::SymInt step;
};
...
// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
::std::vector<c10::SymInt> symints;
symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
if(start.has_value()) symints.insert(symints.end(), *(start));
if(end.has_value()) symints.insert(symints.end(), *(end));
symints.push_back(step);
return symints;
}
size_t SliceTensorViewFunc::num_symints() const {
return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}
void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
auto i = 0;
if(start.has_value()) start = symints[i];
i += (start.has_value() ? 1 : 0);
if(end.has_value()) end = symints[i];
i += (end.has_value() ? 1 : 0);
step = symints[i];
}
std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
::std::vector<at::Tensor> tensors;
return tensors;
}
size_t SliceTensorViewFunc::num_tensors() const {
return static_cast<size_t>(0);
}
void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());
}
at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}
std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
std::optional<std::vector<c10::SymInt>> symints,
std::optional<std::vector<at::Tensor>> tensors) const {
auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
if (symints.has_value()) {
output->set_symints(std::move(*(symints)));
}
if (tensors.has_value()) {
output->set_tensors(std::move(*(tensors)));
}
return output;
}
```
The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.
For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
The flag is not correctly set when PyTorch is compiled with GPU support resulting in failures in
`test_ops.py::test_python_ref_meta__refs_linalg_svd_cpu_complex`.
Use a similar approach to test_meta and skip the check for this function.
Workaround for #105068
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117972
Approved by: https://github.com/lezcano
Part 2 of implementation for general [subclass view fake-ification](https://docs.google.com/document/d/1C5taWiplmX7nKiURXDOAZG2W5VNJ2iV0fQFq92H0Cxw).
Details:
* Codegen `rev_view_func()` alongside `view_func()`
* Reverse view_func gives you a "base" from a "view": `rev_view_func(new_view) -> new_base` AKA it plays the original view backwards
* Utilizes the functional inverses defined in `FunctionalInverses.cpp`, passing `InverseReturnMode::AlwaysView`
* Manually implements functional inverses for `narrow()` and `chunk()`
* **NB: Multi-output views now set view_func() / rev_view_func() for each of the output views!**
* Due to this, the `as_view()` overload that operates on a list of views is scrapped in favor of iteration via codegen
Example codegen in `ADInplaceOrViewTypeN.cpp`:
```cpp
at::Tensor narrow(c10::DispatchKeySet ks, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) {
auto _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return at::_ops::narrow::redispatch(ks & c10::after_ADInplaceOrView_keyset, self, dim, start, length);
})();
std::function<at::Tensor(const at::Tensor&)> func=nullptr;
std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
if (false || !self.unsafeGetTensorImpl()->support_as_strided() ||
c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
func = [=](const at::Tensor& input_base) {
return at::_ops::narrow::call(input_base, dim, start, length);
};
rev_func = [=](const at::Tensor& input_view) {
// NB: args from narrow() signature are passed along to the inverse
return at::functionalization::FunctionalInverses::narrow_copy_inverse(self, input_view, at::functionalization::InverseReturnMode::AlwaysView, dim, start, length);
};
}
auto result = as_view(/* base */ self, /* output */ _tmp, /* is_bw_differentiable */ true, /* is_fw_differentiable */ true, /* view_func */ func, /* rev_view_func */ rev_func, /* creation_meta */ InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : (at::GradMode::is_enabled() ? CreationMeta::DEFAULT : CreationMeta::NO_GRAD_MODE));
return result;
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115894
Approved by: https://github.com/soulitzer
Summary:
This is important for writing aten IR based graph transformation.
```
In [4]: [x.name for x in torch.ops.aten.reshape.default._schema.arguments]
Out[4]: ['self', 'shape']
In [8]: torch.ops.aten.reshape.default(torch.rand(1,2), shape=[2])
Out[8]: tensor([0.7584, 0.4834])
# === CANNOT CALL `self` BY KWARGS ===
In [7]: torch.ops.aten.reshape.default(self=torch.rand(1,2), shape=[2])
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[7], line 1
----> 1 torch.ops.aten.reshape.default(self=torch.rand(1,2), shape=[2])
TypeError: OpOverload.__call__() got multiple values for argument 'self'
```
# Where's the problem?
1. the aten ops first arg is usually named `self` (aten/src/ATen/native/native_functions.yaml)
2. Unfortunately, in `torch._ops.{OpOverload, OpOverloadPacket}.__call__()`, the first arg is (by python convention) named `self` too.
So when call `self` by kwargs, `OpOverloadPacket.__call__` received:
```
OpOverloadPacket.__call__(self, {"self": ...})
```
It is Python that does not allow some argument named "arg" to appear twice. and hence
> TypeError: OpOverload.__call__() got multiple values for argument 'self'
# How to fix?
**Note that**, in above, `self` is an instance of `OpOverloadPacket`, and the "self" kwarg is the input tensor to the aten op. To fix, we only need to differentiate the two `self`s.
In Python, first arg of a method does not need to be named `self`. So we change the `__call__` definition to:
```
def __call__(_self, ...):
```
Now the call becomes:
```
OpOverloadPacket.__call__(_self, {"self": ...})
```
where:
* `_self` is the instance to the `OpOverloadPacket`
* `"self"` is the input tensor to the aten op.
Test Plan:
```
In [4]: [x.name for x in torch.ops.aten.reshape.default._schema.arguments]
Out[4]: ['self', 'shape']
In [3]: torch.ops.aten.reshape.default(self=torch.rand(1,2), shape=[2])
Out[3]: tensor([0.5127, 0.3051])
```
Differential Revision: D51731996
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114920
Approved by: https://github.com/houseroad
Constant time access of first value in collection. This is a constant time operation instead of converting the item to a list to get the first item which is linear. The rule is turned on which automatically autofixes and enforces this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115507
Approved by: https://github.com/malfet
Fixes#111222
This pull request adds tests for factory functions that create tensors with a strided layout. The tests are added to the `test_ops.py` file and check the behavior of the `empty`, `zeros`, `ones`, and `rand` factory functions when used with the `layout=torch.strided` argument.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111463
Approved by: https://github.com/lezcano
## Context
Introduce a core decomposition for `aten.floor_divide` into other `aten` ops, and add it to the core ATen decomposition table.
This replaces the decomposition of `floor_divide` that was used by Inductor. I noticed there was a note on that decomposition
```
# TorchInductor-only decomposition. It should not be taken to core.
# See https://github.com/pytorch/torchdynamo/pull/1120
```
but couldn't discern the reason why this is the case. cc: @lezcano
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110046
Approved by: https://github.com/peterbell10
Python decompositions wrapped by `out_wrapper` need to be unwrapped before compiling with TorchScript since:
- `out_wrapper` extends the decompositions signature with an out parameter, however this `out` parameter is not present in the source code of the original decomposition so the resulting `ScriptFunction` will not have an `out` parameter
- `out_wrapper` is in the `torch._prims_common.wrappers` module so its `globals()` are different to the globals of the decomposition to be wrapped. This may cause symbol resolution to fail with the TorchScript compiler since it is compiling the unwrapped decomps source code rather than the wrapper
The python decomposition for `aten.trace` is wrapped as an example, other decompositions are to be fixed in https://github.com/pytorch/pytorch/pull/107707
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109367
Approved by: https://github.com/lezcano
This fixes numerous tests which were xfailing. For instance, the
`_segment_reduce.lengths` OpInfo test, which was previously relying on
the fallback kernel to determine the shape of the meta tensor. The
fallback kernel would fail with
segment_reduce(): Expected all rows of lengths along axis to sum to data.size(lengths.dim()-1) when !unsafe.
as it was trying to read the values of a meta tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109359
Approved by: https://github.com/ezyang
Fixes#68972
Relands #107246
To avoid causing Meta-internal CI failures, this PR avoids always asserting that the default dtype is float in the `TestCase.setUp/tearDown` methods. Instead, the assert is only done if `TestCase._default_dtype_check_enabled == True`. `_default_dtype_check_enabled` is set to True in the `if __name__ == "__main__":` blocks of all the relevant test files that have required changes for this issue
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108088
Approved by: https://github.com/ezyang