[[Page Maintainers|Where or how should I add documentation?]]: [@bdhirsh](mailto:briandhirsh@gmail.com)
# Codegen + Structured Kernels Overview
Adapted from: http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/
Ed has a really great overview of code generation and why we have it in PyTorch; check out his podcast:
https://pytorch-dev-podcast.simplecast.com/episodes/code-generation.

This document will go over our codegen subsystem + structured kernels in more
detail, and will involve using gdb to jump through the different code-generated
files that are part of a call into torch.add().

## Why care about a dispatcher?
* PyTorch has a lot of systems: autograd, tracing, vmap
* and a lot of backend devices: XLA, CUDA, CPU, ...
* We could write a single at::add function that handles all of the above
* It would probably have a big fat switch statement with a lot of code in it.
* Think about packing the VariableType (autogenerated) code, the CUDA code, and the CPU code all into one function!

### What is the dispatcher?
* Each operator has a *dispatch table*, a table of function pointers for each key.
* The dispatch keys [are sorted by *priority*](https://github.com/pytorch/pytorch/blob/f588ad6a35c3f52da8e8180c7b51de954fce5fd1/c10/core/DispatchKey.h#L21).
* When you call an operator, the dispatcher looks at the current set of DispatchKeys to figure out which function pointer to call.
* Each Tensor has a DispatchKeySet.
* To figure out which function pointer to call, we:
  * Union all the dispatch keys in the input Tensors
  * Union some global dispatch keys (just BackendSelect*)
  * Union a set of “Local Include” keys. These are usually set in a thread-local way.
  * Remove a set of “Local Exclude” keys. These are also usually set in a thread-local way.
  * The code for that lives [here](https://github.com/pytorch/pytorch/blob/65f33ec85c2a7d8fb9bf582017d3170bf89e6c12/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h#L23).
* Once we have our final key set, we pick the first (highest-priority) dispatch key.

## Let's go through an example of what happens when we add two Tensors together.

```py
x = torch.randn(3, device='cuda')
y = torch.randn(1, device='cuda')
torch.add(x, y)

# Tensor dispatch keys:
# x has the AutogradCUDA and CUDA dispatch keys
# y has the AutogradCUDA and CUDA dispatch keys
# Global dispatch keys:
# [BackendSelect]
# Local include: []
# Local exclude: []
# The final set, ordered by priority, is [AutogradCUDA, BackendSelect, CUDA]
```

Ultimately, what happens is that we will make it to `native::add`. native::add is [registered for at::add on the CPU and CUDA keys](https://github.com/pytorch/pytorch/blob/f588ad6a35c3f52da8e8180c7b51de954fce5fd1/aten/src/ATen/native/native_functions.yaml#L372).

## The codegen pipeline

### What it is
We have a code-generation pipeline that runs as part of the PyTorch build: it
reads in some yaml files, and spits out a bunch of C++ files.

### Why we have it
So, why do we have codegen? One big motivating factor is to reduce boilerplate.
PyTorch has a lot of operators, and there's a lot of stuff that should “just
work” for every operator. We don't want to make someone hand-write all of that
functionality whenever a new operator is added. Instead, we code-generate it.
A (non-exhaustive) list of functionality (we need all of this for every
operator, so multiply by ~2000):
- bindings to python
- the frontend C++ API
- autograd support
- registering kernels to the dispatcher
- other stuff
  - special logic for factory functions
  - torch.jit.trace functionality

### Inputs
We have a yaml file, native_functions.yaml, which describes metadata about each
operator that gets consumed by the codegen:
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
We're going to focus on the operator torch.add(a, b, out=c), which corresponds
to the yaml entry add.out:

```yaml
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck   # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  ufunc_inner_loop:
    Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf)
    ScalarOnly: add (Bool)
  dispatch:
    SparseCPU: add_out_sparse_cpu
    SparseCUDA: add_out_sparse_cuda
    SparseCsrCPU: add_out_sparse_compressed_cpu
    SparseCsrCUDA: add_out_sparse_compressed_cuda
    MkldnnCPU: mkldnn_add_out
    MPS: add_out_mps
  tags: pointwise
```

There's public documentation on each of the different pieces of yaml here:
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/README.md

The codegen is written in a functional style, using python dataclasses to
represent the different inputs/intermediates/outputs. For example, each entry in
`native_functions.yaml` is represented in the codegen as a NativeFunction
object:
https://github.com/pytorch/pytorch/blob/6596a3f23dfe1ea4175637fa979bcbfbff397737/torchgen/model.py#L427

Finally, one of the main entry points to the codegen is in
`tools/codegen/gen.py` (there's also a separate entry point for the autograd
codegen pipeline). You can see the part of the file where we generate the C++
API, for example `Functions.h`:
(https://github.com/pytorch/pytorch/blob/f8e14f3b46e68a5271a8c57ce749ad8057d77ddd/torchgen/gen.py#L1781)
It reads in a template file, `aten/src/ATen/templates/Functions.h`
(https://github.com/pytorch/pytorch/blob/f8e14f3b46e68a5271a8c57ce749ad8057d77ddd/aten/src/ATen/templates/Functions.h),
and generates the file `build/aten/src/ATen/Functions.h`.
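For a flavor of what ends up in the generated file, the add entry in `Functions.h` looks roughly like this (a sketch; the exact generated signature and annotations vary between PyTorch versions):

```cpp
// build/aten/src/ATen/Functions.h (illustrative sketch of the generated code)
// aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
inline at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha = 1) {
    return at::_ops::add_Tensor::call(self, other, alpha);
}
```

In other words, the public C++ API is a thin wrapper that immediately hands the call off to the dispatcher (we'll see `at::_ops::add_Tensor::call` again in the stack trace below).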
# Exercise 0: Full Stack Trace of torch.add
For this exercise you'll need to have pytorch built with debug symbols. I
usually do that with `USE_CUDA=0 DEBUG=1 python setup.py develop` (the
`USE_CUDA=0` is because we don't need it, and building with cuda takes a long
time).
We're going to run a small python program using gdb to view the full stack
trace. Create a python script, `tmp.py`, with the following:
```python
import torch
a = torch.tensor([1, 1])
b = torch.tensor([1, 1])
c = torch.add(a, b)
```
Run `gdb python` (or `lldb python -- tmp.py`) to start up `gdb`. We're going to set
a breakpoint in the `add` kernel - to do that, in the `gdb` prompt, type `break
structured_ufunc_add_CPU::impl` (or `b structured_ufunc_add_CPU::impl` in lldb).
Then run your script inside of `gdb` with `run tmp.py` (or `r` in lldb).
The debugger should pause inside of the add kernel. Type `bt` to view the
current stack trace.
Ignoring the first ~10 function calls through the python interpreter, you should
see a stack trace that looks something like the following:
```
* thread #1, name = 'python', stop reason = breakpoint 1.1
* frame #0: 0x00007fffd38ed42a libtorch_cpu.so`at::native::structured_ufunc_add_CPU::impl(this=0x00007fffffffae60, self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0, out=0x00007fffffffb190) at UfuncCPU_add.cpp:30:11
frame #1: 0x00007fffd2ae81aa libtorch_cpu.so`at::(anonymous namespace)::wrapper_CPU_add_Tensor(self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at RegisterCPU.cpp:1576:8
frame #2: 0x00007fffd2c9079d libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, const at::Tensor&, const c10::Scalar&), at::(anonymous namespace)::wrapper_CPU_add_Tensor>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, const at::Tensor&, const c10::Scalar&> >, at::Tensor(const at::Tensor&, const at::Tensor&, const c10::Scalar&)>::call(c10::OperatorKernel *, c10::DispatchKeySet, const at::Tensor &, const at::Tensor &, const c10::Scalar &) [inlined] operator(args#2=0x00007fffffffbbb0, args#1=0x00007fffffffbbd8, args#0=0x00007fffffffbbe0, this=0x0000555556633ac0) at WrapFunctionIntoFunctor.h:13:72
frame #3: 0x00007fffd2c90759 libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(const at::Tensor&, const at::Tensor&, const c10::Scalar&), at::(anonymous namespace)::wrapper_CPU_add_Tensor>, at::Tensor, c10::guts::typelist::typelist<const at::Tensor&, const at::Tensor&, const c10::Scalar&> >, at::Tensor(const at::Tensor&, const at::Tensor&, const c10::Scalar&)>::call(functor=0x0000555556633ac0, (null)=(repr_ = 32769), args#0=0x00007fffffffbbe0, args#1=0x00007fffffffbbd8, args#2=0x00007fffffffbbb0) at make_boxed_from_unboxed_functor.h:468:63
frame #4: 0x00007fffd1f5dec7 libtorch_cpu.so`at::Tensor c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&, at::Tensor const&, c10::Scalar const&>(unboxed_kernel_func=0x00007fffd2c906ee, functor=0x0000555556633ac0, dispatchKeySet=(repr_ = 32769), (null)=0x00007fffffffbbe0, (null)=0x00007fffffffbbd8, (null)=0x00007fffffffbbb0) at KernelFunction_impl.h:52:72
frame #5: 0x00007fffd1e1ea24 libtorch_cpu.so`at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, at::Tensor const&, c10::Scalar const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, c10::Scalar const&)> const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::Scalar const&) const at KernelFunction_impl.h:104:87
frame #6: 0x00007fffd1e1e9aa libtorch_cpu.so`at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, at::Tensor const&, c10::Scalar const&>(this=0x00007fffe61a1de0, op=0x00007fffe61c7db0, currentDispatchKeySet=(repr_ = 32769), (null)=0x00007fffffffbbe0, (null)=0x00007fffffffbbd8, (null)=0x00007fffffffbbb0) const at Dispatcher.h:712:102
frame #7: 0x00007fffd2332b7c libtorch_cpu.so`at::_ops::add_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::Scalar const&) [inlined] c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, c10::Scalar const&)>::redispatch(args#2=0x00007fffffffbbb0, args#1=0x00007fffffffbbd8, args#0=0x00007fffffffbbe0, currentDispatchKeySet=(repr_ = 32769), this=<unavailable>) const at Dispatcher.h:532:126
frame #8: 0x00007fffd2332acd libtorch_cpu.so`at::_ops::add_Tensor::redispatch(dispatchKeySet=(repr_ = 32769), self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at Operators_2.cpp:1049:60
frame #9: 0x00007fffd502cbf2 libtorch_cpu.so`at::redispatch::add(dispatchKeySet=(repr_ = 32769), self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at RedispatchFunctions.h:607:83
frame #10: 0x00007fffd4ef7650 libtorch_cpu.so`operator(__closure=0x00007fffffffb5c0) at VariableType_2.cpp:5969:85
frame #11: 0x00007fffd4ef7b7c libtorch_cpu.so`torch::autograd::VariableType::(anonymous namespace)::add_Tensor(ks=(repr_ = 274877939713), self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at VariableType_2.cpp:5970:6
frame #12: 0x00007fffd4ff0ad9 libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&), torch::autograd::VariableType::(anonymous namespace)::add_Tensor>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&> >, at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&)>::call(c10::OperatorKernel *, c10::DispatchKeySet, const at::Tensor &, const at::Tensor &, const c10::Scalar &) [inlined] operator(args#3=0x00007fffffffbbb0, args#2=0x00007fffffffbbd8, args#1=0x00007fffffffbbe0, args#0=(repr_ = 274877939713), this=0x0000555557b02710) at WrapFunctionIntoFunctor.h:13:72
frame #13: 0x00007fffd4ff0a80 libtorch_cpu.so`c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&), torch::autograd::VariableType::(anonymous namespace)::add_Tensor>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&> >, at::Tensor(c10::DispatchKeySet, const at::Tensor&, const at::Tensor&, const c10::Scalar&)>::call(functor=0x0000555557b02710, dispatchKeySet=(repr_ = 274877939713), args#0=0x00007fffffffbbe0, args#1=0x00007fffffffbbd8, args#2=0x00007fffffffbbb0) at make_boxed_from_unboxed_functor.h:485:79
frame #14: 0x00007fffd1f5dec7 libtorch_cpu.so`at::Tensor c10::callUnboxedKernelFunction<at::Tensor, at::Tensor const&, at::Tensor const&, c10::Scalar const&>(unboxed_kernel_func=0x00007fffd4ff0a0b, functor=0x0000555557b02710, dispatchKeySet=(repr_ = 274877939713), (null)=0x00007fffffffbbe0, (null)=0x00007fffffffbbd8, (null)=0x00007fffffffbbb0) at KernelFunction_impl.h:52:72
frame #15: 0x00007fffd233293b libtorch_cpu.so`at::_ops::add_Tensor::call(at::Tensor const&, at::Tensor const&, c10::Scalar const&) at KernelFunction_impl.h:104:87
frame #16: 0x00007fffd23328ac libtorch_cpu.so`at::_ops::add_Tensor::call(at::Tensor const&, at::Tensor const&, c10::Scalar const&) at Dispatcher.h:694:97
frame #17: 0x00007fffd2332695 libtorch_cpu.so`at::_ops::add_Tensor::call(at::Tensor const&, at::Tensor const&, c10::Scalar const&) [inlined] c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&, c10::Scalar const&)>::call(args#2=0x00007fffffffbbb0, args#1=0x00007fffffffbbd8, args#0=0x00007fffffffbbe0, this=<unavailable>) const at Dispatcher.h:527:97
frame #18: 0x00007fffd233257b libtorch_cpu.so`at::_ops::add_Tensor::call(self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at Operators_2.cpp:1042:38
frame #19: 0x00007fffe74b677c libtorch_python.so`at::Tensor::add(this=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) const at TensorBody.h:1664:79
frame #20: 0x00007fffe75f1164 libtorch_python.so`operator(__closure=0x00007fffffffbaad, self=0x00007fffffffbbe0, other=0x00007fffffffbbd8, alpha=0x00007fffffffbbb0) at python_torch_functions_2.cpp:1400:39
frame #21: 0x00007fffe75f1777 libtorch_python.so`torch::autograd::THPVariable_add(self_=0x0000000000000000, args=0x00007ffd68f74b80, kwargs=0x0000000000000000) at python_torch_functions_2.cpp:1402:33
```
That's a lot of function calls! We're going to walk through the main pieces that
are relevant to codegen and where they live. For each piece, I listed the
relevant frame numbers from the gdb stack trace.
> Tip: In `lldb`, if you are curious about the absolute path of the source files
> listed in the frame backtrace (for example, the 12th frame), you can first
> switch to that frame with `f 12` and then show its source info with `so i`.
### (1) Python Bindings
> #21: torch::autograd::THPVariable_add
This is the first stop that we hit after going through the python interpreter:
python bindings. This is the code that interfaces directly with cpython to bind
our C++ functions to python.
You can see a snippet of the function below: Its job is basically to take all of
the PyObjects that it was handed from CPython, parse them into actual C++ types
(like `at::Tensor`), and call into the C++ API. It does that below by calling into
the Tensor add method: self.add(other, alpha).
So here's what happens:
* at::add(x, y) invokes the dispatcher, which combines the dispatch keys into a DispatchKeySet as described above. In this example the inputs are CPU tensors, so the highest priority key is the `AutogradCPU` key (it would be `AutogradCUDA` for the CUDA example earlier).
  * The file that this function lives in is actually codegen'd, so if you want to view it in source you'll need to build pytorch. Then you can view it at `build/aten/src/ATen/Functions.h`.
* at::add(x, y) dispatches to the Autograd implementation of add (covered in section (3) below).
  * Don't worry too much about the `Autograd` vs. `AutogradCPU`/`AutogradCUDA` distinction; in 99% of cases you can treat them as identical. If you're curious though, `Autograd` is an [alias dispatch key](https://github.com/pytorch/pytorch/blob/65f33ec85c2a7d8fb9bf582017d3170bf89e6c12/c10/core/DispatchKeySet.h#L211).
  * This function is also codegen'd - after building, you can view it at `torch/csrc/autograd/generated/VariableTypeEverything.cpp`.
```cpp
static PyObject * THPVariable_add(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "add(Tensor input, Scalar alpha, Tensor other, *, Tensor out=None)|deprecated",
    "add(Tensor input, Tensor other, *, Scalar alpha=1, Tensor out=None)",
  }, /*traceable=*/true);

  ParsedArgs<4> parsed_args;
  auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  ...
  auto dispatch_add = [](const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) -> at::Tensor {
    pybind11::gil_scoped_release no_gil;
    return self.add(other, alpha);
  };
  return wrap(dispatch_add(_r.tensor(0), _r.tensor(1), _r.scalar(2)));
  ...
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```
These are all codegen'd and live in
`torch/csrc/autograd/generated/python_torch_functions_2.cpp`.
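A quick way to exercise this binding from Python (the parser handles the optional `alpha` argument from the schema above):

```py
import torch

a = torch.tensor([1, 1])
b = torch.tensor([2, 2])
print(torch.add(a, b))           # tensor([3, 3])
print(torch.add(a, b, alpha=2))  # tensor([5, 5]), i.e. a + alpha * b
```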
### (2) C++ API
> #18: at::\_ops::add_Tensor::call
> #19: at::Tensor::add
The next stop is the C++ method API, which is one of the top-level APIs for
calling into the dispatcher. The dispatcher then looks at all of the arguments +
any thread-local state to figure out which kernel to dispatch to.
https://github.com/pytorch/pytorch/wiki/PyTorch-dispatcher-walkthrough has some
more details about the dispatcher key-calculation process.
In `build/aten/src/ATen/core/TensorBody.h`:
```cpp
//namespace at
inline at::Tensor Tensor::add(const at::Tensor & other, const at::Scalar & alpha) const {
return at::_ops::add_Tensor::call(const_cast<Tensor&>(*this), other, alpha);
}
```
In `build/aten/src/ATen/Operators_2.cpp`:
```cpp
static C10_NOINLINE c10::TypedOperatorHandle<add_Tensor::schema> create_add_Tensor_typed_handle() {
return c10::Dispatcher::singleton()
.findSchemaOrThrow(add_Tensor::name, add_Tensor::overload_name)
.typed<add_Tensor::schema>();
}
at::Tensor add_Tensor::call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
static auto op = create_add_Tensor_typed_handle();
return op.call(self, other, alpha);
}
```
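Both the method API (`x.add(y)`, from `TensorBody.h`) and the free-function API (`at::add(x, y)`, from `Functions.h`) funnel into the same `at::_ops::add_Tensor::call` entry point. A minimal C++ sketch, assuming you're linking against libtorch:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor a = at::ones({2, 2});
  at::Tensor b = at::ones({2, 2});
  at::Tensor c = at::add(a, b);  // free-function form (Functions.h)
  at::Tensor d = a.add(b);       // method form (TensorBody.h); same dispatcher entry point
  std::cout << c.equal(d) << std::endl;  // 1
  return 0;
}
```

From Python, the closest analogue to this entry point is `torch.ops.aten.add.Tensor(x, y)`.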
### (3) Autograd kernel
> #11: torch::autograd::VariableType::(anonymous namespace)::add_Tensor
After a bunch of dispatcher-related functions, the dispatcher eventually takes
us to the autograd add kernel. The autograd kernel:
- saves some metadata for autograd
- re-invokes the dispatcher by calling `at::redispatch::add(ks &
c10::after_autograd_keyset, self_, other_, alpha);`
In `torch/csrc/autograd/generated/VariableType_2.cpp`:
```cpp
// namespace at::VariableType
at::Tensor add_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
...
}
// Register `add_Tensor` so that it can be found
TORCH_LIBRARY_IMPL(aten, Autograd, m) {
...
m.impl("add.Tensor",TORCH_FN(VariableType::add_Tensor));
...
}
```
The autograd kernel ends up calling back into the C++ API (by calling
`at::redispatch::add`), which then calls back into the dispatcher and calculates
the next kernel to dispatch to.
### (4) CPU kernel
> #0: at::native::structured_ufunc_add_CPU::impl
> #1: at::(anonymous namespace)::wrapper_CPU_add_Tensor
After a few more function hops through the dispatcher, we eventually dispatch to
the CPU add kernel, which has to actually carry out the computation. The code
for the cpu kernel (and the code that registers the kernel to the dispatcher)
looks like this:
In `build/aten/src/ATen/RegisterCPU.cpp`:
```cpp
at::Tensor wrapper_CPU_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
structured_ufunc_add_CPU_functional op;
op.meta(self, other, alpha);
op.impl(self, other, alpha, op.outputs_[0]);
return std::move(op.outputs_[0]);
}
TORCH_LIBRARY_IMPL(aten, CPU, m) {
m.impl("add.Tensor", TORCH_FN(wrapper_CPU_add_Tensor));
}
```
This code looks a little funky: it calls into `meta()` and `impl()` functions
that are defined elsewhere. This is because add is implemented as a structured
kernel - a new way of implementing operators in pytorch.
The wrapper above is just scaffolding. The call to `op.impl()` corresponds directly
to the add kernel in `build/aten/src/ATen/UfuncCPU_add.cpp`:
```cpp
TORCH_IMPL_FUNC(ufunc_add_CPU)(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out) {
add_stub(device_type(), *this, alpha);
}
```
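The `add_stub` call above is the extra indirection alluded to in the note below: it's a `DispatchStub`, a second, smaller dispatch layer keyed on CPU capability (AVX2/AVX512/etc.) rather than on dispatch keys. Roughly, it's declared and registered like this (a sketch; the exact typedef names and file locations may differ in your checkout):

```cpp
// aten/src/ATen/native/BinaryOps.h (sketch): declare the stub
DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);

// build/aten/src/ATen/UfuncCPUKernel_add.cpp (sketch): point the stub at the
// vectorized add_kernel, which is compiled once per supported CPU capability
REGISTER_DISPATCH(add_stub, &add_kernel);
```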
(Note: there's a bit more indirection inside that kernel before reaching the
main part of the add kernel, which lives in
`build/aten/src/ATen/UfuncCPUKernel_add.cpp`:)
```cpp
void add_kernel(TensorIteratorBase& iter, const at::Scalar & alpha) {
  AT_DISPATCH_SWITCH(iter.common_dtype(), "add_stub",
    ...
    AT_DISPATCH_CASE(at::ScalarType::Long,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); });
      }
    )
    ...
  )
}
```
The code above lives in the code-generated file
`build/aten/src/ATen/UfuncCPUKernel_add.cpp`.
So, the code above calls into our hand-written CPU add kernel, returns a new
output tensor containing the result, and we're done!

One more note on the autograd kernel from section (3): for reference, a (slightly
older) fully expanded version of it, from
`torch/csrc/autograd/generated/VariableTypeEverything.cpp`, looked like this:
```cpp
Tensor add_Tensor(const Tensor & self, const Tensor & other, Scalar alpha) {
  auto& self_ = unpack(self, "self", 0);
  auto& other_ = unpack(other, "other", 1);
  std::shared_ptr<AddBackward0> grad_fn;
  if (compute_requires_grad( self, other )) {
    grad_fn = std::shared_ptr<AddBackward0>(new AddBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( self, other ));
    grad_fn->alpha = alpha;
  }
  #ifndef NDEBUG
  c10::optional<Storage> self__storage_saved =
    self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
  c10::intrusive_ptr<TensorImpl> self__impl_saved;
  if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
  c10::optional<Storage> other__storage_saved =
    other_.has_storage() ? c10::optional<Storage>(other_.storage()) : c10::nullopt;
  c10::intrusive_ptr<TensorImpl> other__impl_saved;
  if (other_.defined()) other__impl_saved = other_.getIntrusivePtr();
  #endif
  auto tmp = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return at::add(self_, other_, alpha);
  })();
  auto result = std::move(tmp);
  #ifndef NDEBUG
  if (self__storage_saved.has_value())
    AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
  if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
  if (other__storage_saved.has_value())
    AT_ASSERT(other__storage_saved.value().is_alias_of(other_.storage()));
  if (other__impl_saved) AT_ASSERT(other__impl_saved == other_.getIntrusivePtr());
  #endif
  if (grad_fn) {
    set_history(flatten_tensor_args( result ), grad_fn);
  }
  return result;
}
```
* None of the tensors require grad, so none of the autograd-specific logic actually happens.
* Fun fact: we're planning on changing this behavior in the near-to-mid future! Eventually, we'd like it if the autograd kernel doesn't ever get called unless the input tensors actually require gradients (specifying `requires_grad=True`).

The part that re-invokes the dispatcher is this block:
```cpp
auto tmp = ([&]() {
  at::AutoNonVariableTypeMode non_var_type_mode(true);
  return at::add(self_, other_, alpha);
})();
```
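To see the autograd-specific logic actually fire, re-run the example with a floating-point input that requires grad (a small sketch; integer tensors can't require grad):

```py
import torch

a = torch.tensor([1., 1.], requires_grad=True)
b = torch.tensor([1., 1.])
c = torch.add(a, b)
print(c.grad_fn)  # <AddBackward0 ...>: this time the autograd kernel recorded a node
```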
## Takeaway
The main takeaway from the exercise above is that:
- A lot of stuff happens when you call an operator
- ...most of which is code-generated! A lot of this logic is _really_ similar
across PyTorch's ~2000 operators, and ripe for abstracting over (through
something like code generation).
Sometimes when you're working on / debugging a feature, it can be useful to know
which bits of logic are codegen'd, and where that logic lives.
# Structured Kernels
Another big part of the codegen is “structured kernels” - a new way of writing
kernels in PyTorch, which uses some clever factoring + a bunch of codegen to
reduce the amount of boilerplate required when writing kernels. torch.add is
implemented as a “structured kernel”, so we're going to walk through the bits of
it related to structured kernels.
The process of implementing an operator as a structured kernel involves writing
two functions:
- A “meta” function, which asserts that the inputs have the correct shape/dtype
and figures out what size the output tensor should be.
- An “impl” function, which does the actual computation. There will be a
separate impl() function for every backend (cpu, cuda, xla, etc).
The codegen is responsible for taking these two functions, and plugging them
together in the right way to create all 3 variants of the operator for you:
- at::add() (functional version)
- at::add\_() (inplace version)
- at::add_out() (out= version)
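Concretely, the two functions you write look roughly like this (a sketch for a hypothetical unary operator `myop`; the operator and the exact `set_output_*` call are illustrative - check `NativeMetaFunctions.h` in your build for the real declarations):

```cpp
// meta: validate inputs and declare the output's size/dtype. No computation happens here.
TORCH_META_FUNC(myop) (const Tensor& self) {
  TORCH_CHECK(self.dim() <= 2, "myop: expected a tensor with at most 2 dimensions");
  set_output_raw_strided(0, self.sizes(), {}, self.options());
}

// impl: do the actual computation, writing into the already-allocated `result`.
TORCH_IMPL_FUNC(myop_out_cpu) (const Tensor& self, const Tensor& result) {
  result.copy_(self);  // placeholder computation
}
```

The codegen then generates the class declarations these expand into, plus the functional/inplace/out= wrappers described above.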
Helpful reading: this presentation on structured kernels, including a diagram on
the class hierarchy (which will be useful in the exercise further down).
https://drive.google.com/file/d/16qPvpCF4Jbh7ss2lCQMk5hmcyzJvUyQj/view?usp=sharing
See also: the structured kernels RFC contains a more detailed overview of what
they are and what the codegen creates:
https://github.com/pytorch/rfcs/blob/rfc-0005/RFC-0005-structured-kernel-definitions.md
## Structured Kernel codegen output example: torch.add
The CPU kernel for the torch.add operator lives in
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/BinaryOps.cpp#L16,
and has two components:
### Meta function:
```cpp
// expands to structured_add_Tensor::meta() { ... }
TORCH_META_FUNC2(add, Tensor) (
const Tensor& self, const Tensor& other, const Scalar& alpha
) {
build_borrowing_binary_op(maybe_get_output(), self, other);
native::alpha_check(dtype(), alpha);
}
```

### Impl function:
```cpp
// expands to structured_add_out::impl() { ... }
TORCH_IMPL_FUNC(add_out) (
  const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result
) {
  add_stub(device_type(), *this, alpha);
  TORCH_INTERNAL_ASSERT(result.scalar_type() == output().dtype());
}
```

A quick aside, picking up the dispatcher walkthrough from the autograd kernel we looked at earlier: the `at::AutoNonVariableTypeMode` guard there creates a [Local Exclude set of [Autograd]](https://github.com/pytorch/pytorch/blob/f588ad6a35c3f52da8e8180c7b51de954fce5fd1/aten/src/ATen/core/LegacyTypeDispatch.h#L50). So when that kernel calls `at::add(self_, other_, alpha);`, the key computation happens again:
```
# Tensor dispatch keys:
# x has the AutogradCUDA and CUDA dispatch keys
# y has the AutogradCUDA and CUDA dispatch keys
# Global include dispatch keys:
# BackendSelect
# Local include: []
# Local exclude: [AutogradCPU, AutogradCUDA, AutogradXLA]
# The final set is [BackendSelect, CUDA]
```
* Cool! So now the dispatcher looks up BackendSelect's implementation for add.
* Checking `build/aten/src/ATen/BackendSelectRegister.cpp`, there is no real BackendSelect implementation for add. BackendSelect's add implementation is a special [*fallback kernel*](https://github.com/pytorch/pytorch/blob/f588ad6a35c3f52da8e8180c7b51de954fce5fd1/aten/src/ATen/core/BackendSelectFallbackKernel.cpp#L3-L5) that says “there is nothing here, instead, hop over to the next dispatch key”. More on fallback kernels later.
* In fact, the Dispatcher actually has an optimization to avoid calling the fallthrough kernel at all. It figures out which kernels are “fallthrough” kernels at static initialization time, and adds them to a bitset mask to skip them entirely. If you're curious, the logic for that lives [here](https://github.com/pytorch/pytorch/blob/a029422cae70b019222c00558da5437020550173/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h#L145).

Back to structured kernels: the code above implements the two functions structured_add_Tensor::meta()
and structured_add_out::impl(), but where are they declared? The codegen creates
declarations for them.
In NativeMetaFunctions.h:
```cpp
// namespace at::meta
struct TORCH_API structured_add_Tensor : public TensorIteratorBase {
void meta(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
};
```
In NativeFunctions.h:
```cpp
// namespace at::native
struct TORCH_API structured_add_out : public at::meta::structured_add_Tensor {
void impl(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out);
};
```
You can see that the codegen generated declarations for the two functions, and
we hand-implemented them ourselves in BinaryOps.cpp. But how does the codegen
use them?
The code-generated logic that stitches them together lives in the code-generated
file `RegisterCPU.cpp`, and looks like this:
```cpp
// functional version
at::Tensor wrapper_CPU_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
structured_ufunc_add_CPU_functional op;
op.meta(self, other, alpha);
op.impl(self, other, alpha, op.outputs_[0]);
return std::move(op.outputs_[0]);
}
// inplace version
at::Tensor & wrapper_CPU_add__Tensor(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
structured_ufunc_add_CPU_inplace op(self);
op.meta(self, other, alpha);
op.impl(self, other, alpha, op.outputs_[0]);
if (op.proxy_outputs_[0].has_value()) op.outputs_[0].get().copy_(*op.proxy_outputs_[0]);
return self;
}
// out= version
at::Tensor & wrapper_CPU_add_out_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) {
structured_ufunc_add_CPU_out op(out);
op.meta(self, other, alpha);
op.impl(self, other, alpha, op.maybe_get_output(0));
if (op.proxy_outputs_[0].has_value()) op.outputs_[0].get().copy_(*op.proxy_outputs_[0]);
return out;
}
// registering the 3 kernels above to the dispatcher, under the CPU Dispatch Key.
TORCH_LIBRARY_IMPL(aten, CPU, m) {
...
m.impl("add.Tensor", TORCH_FN(wrapper_CPU_add_Tensor));
m.impl("add.out", TORCH_FN(wrapper_CPU_add_out_out));
m.impl("add_.Tensor", TORCH_FN(wrapper_CPU_add__Tensor));
}
```
This is the "final" output - the 3 operators that we needed. The codegen created
3 new kernels, each of which calls into our `meta()` and `impl()` functions. The
only difference between the 3 is that they use different classes, each of which
has a different implementation of `set_output()`. You can also find the
definition of all 3 of these classes in `RegisterCPU.cpp`, but below is the
example for `structured_ufunc_add_CPU_functional`:
```cpp
struct structured_ufunc_add_CPU_functional final : public at::native::structured_ufunc_add_CPU {
    void set_output_strided(
        int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
        TensorOptions options, DimnameList names
    ) override {
        outputs_[output_idx] = create_out(sizes, strides, options);
        if (!names.empty()) {
          namedinference::propagate_names(outputs_[output_idx], names);
        }
        // super must happen after, so that downstream can use maybe_get_output
        // to retrieve the output
        at::native::structured_ufunc_add_CPU::set_output_raw_strided(output_idx, sizes, strides, options, names);
    }
    void set_output_raw_strided(
        int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
        TensorOptions options, DimnameList names
    ) override {
        outputs_[output_idx] = create_out(sizes, strides, options);
        if (!names.empty()) {
          namedinference::propagate_names(outputs_[output_idx], names);
        }
        // super must happen after, so that downstream can use maybe_get_output
        // to retrieve the output
        at::native::structured_ufunc_add_CPU::set_output_raw_strided(output_idx, sizes, strides, options, names);
    }
    const Tensor& maybe_get_output(int64_t output_idx) override {
      return outputs_[output_idx];
    }
    std::array<Tensor, 1> outputs_;
};
```
You can see that it has its own definition of `set_output_strided()` and
`set_output_raw_strided()` - in this case,
it's implementing the functional `at::Tensor::add` kernel, so it needs to
allocate a new tensor as the output (it does that using `at::create_out()`).

One loose end from the dispatcher example earlier: BackendSelect's add “kernel” is just a fallthrough registration, which looks like this:
```cpp
TORCH_LIBRARY_IMPL(_, BackendSelect, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}
```
So the dispatcher goes and picks the next key down the list, which is CUDA. We now invoke at::native::add. We've reached the end!
## Example number 2: factory function
```py
x = torch.randn(3, 3, device='cuda')
# Upon calling at::randn, the dispatch keys are:
# Global include set: [BackendSelect]
# Local include set: []
# Local exclude set: []
# So we select the BackendSelect version of randn.
```
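You can poke at key sets like these yourself from Python, using an internal helper (a sketch; `torch._C._dispatch_keys` is an internal, unstable API, and the exact set printed depends on your build and version):

```py
import torch

x = torch.randn(3, 3)  # CPU here, so this also works on a USE_CUDA=0 build
print(torch._C._dispatch_keys(x))
# prints something like: DispatchKeySet(CPU, ADInplaceOrView, AutogradCPU, AutocastCPU)
# with device='cuda' you would see CUDA / AutogradCUDA instead
```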
* Factory functions like randn are treated specially, and do not get Fallthrough kernels registered to the dispatcher. Again, you can see the kernel for randn that's registered to the `BackendSelect` key in `build/aten/src/ATen/BackendSelectRegister.cpp`.
* BackendSelect version of randn:
```cpp
C10_ALWAYS_INLINE
at::Tensor randn(at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
  static auto op = c10::Dispatcher::singleton()
    .findSchemaOrThrow("aten::randn", "")
    .typed<at::Tensor (at::IntArrayRef, c10::optional<at::ScalarType>, c10::optional<at::Layout>, c10::optional<at::Device>, c10::optional<bool>)>();
  DispatchKeySet _dk = c10::DispatchKeySet(c10::computeDispatchKey(dtype, layout, device));
  return op.redispatch(_dk, size, dtype, layout, device, pin_memory);
}
```
* It computes a dispatch key based on the dtype, layout, and device. In our case, the computed dispatch key is CUDA, so it straight up calls native::randn: https://github.com/pytorch/pytorch/blob/2c554266108f1b556dd49f7c3c06c08f2bbd3cbe/aten/src/ATen/native/TensorFactories.cpp#L616-L623
```cpp
Tensor randn(IntArrayRef size, c10::optional<Generator> generator, const TensorOptions& options) {
auto result = at::empty(size, options);
return result.normal_(0, 1, generator);
}
```
* But we're not done! at::empty got invoked. [at::empty also invokes](https://github.com/pytorch/pytorch/blob/2c554266108f1b556dd49f7c3c06c08f2bbd3cbe/aten/src/ATen/native/native_functions.yaml#L1620) the [CUDA version](https://github.com/pytorch/pytorch/blob/2c554266108f1b556dd49f7c3c06c08f2bbd3cbe/aten/src/ATen/native/cuda/TensorFactories.cu#L44):
```yaml
- func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  # use_c10_dispatcher: full
  dispatch:
    CPU: empty_cpu
    CUDA: empty_cuda
```
* Great. `result` is a Tensor with dispatch keys [AutogradCUDA, CUDA].
* `result.normal_(...)` dispatches to the Autograd implementation of normal_
```cpp
// See VariableTypeEverything.cpp
Tensor & normal_(Tensor & self, double mean, double std, c10::optional<Generator> generator) {
auto& self_ = unpack(self, "self", 0);
check_inplace(self);
std::shared_ptr<NormalBackward0> grad_fn;
if (compute_requires_grad( self )) {
grad_fn = std::shared_ptr<NormalBackward0>(new NormalBackward0(), deleteNode);
grad_fn->set_next_edges(collect_next_edges( self ));
}
#ifndef NDEBUG
c10::optional<Storage> self__storage_saved =
self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
c10::intrusive_ptr<TensorImpl> self__impl_saved;
if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
#endif
{
at::AutoNonVariableTypeMode non_var_type_mode(true);
self_.normal_(mean, std, generator);
}
#ifndef NDEBUG
if (self__storage_saved.has_value())
AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
#endif
increment_version(self);
if (grad_fn) {
rebase_history(flatten_tensor_args( self ), grad_fn);
}
return self;
}
```
* Which then goes to:
```cpp
{
at::AutoNonVariableTypeMode non_var_type_mode(true);
self_.normal_(mean, std, generator);
}
```
* `self_` is a Tensor with dispatch keys [AutogradCUDA, CUDA]. The AutoNonVariableTypeMode adds a Local Exclude of [AutogradCUDA]. So the final set is [BackendSelect, CUDA] and the dispatcher selects the CUDA implementation of normal_
* Which straight up goes to [native::normal_](https://github.com/pytorch/pytorch/blob/2c554266108f1b556dd49f7c3c06c08f2bbd3cbe/aten/src/ATen/native/native_functions.yaml#L6793):
```yaml
- func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
  variants: method
  dispatch:
    CPU, CUDA: normal_
```
## How do we populate the dispatch table?
The dispatcher has a registration API
* see “Operator Registration” in http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/
* Our codegen pipeline takes care of the work of calling the registration API, and registering all of our different kernels to most of the important dispatch keys:
* CPU
* CUDA
* Autograd
* BackendSelect
* The API includes the ability to define a Fallback kernel (that does nothing), which you saw used by BackendSelect
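For reference, here's roughly what using that registration API by hand looks like (a sketch with a made-up `myops::myadd` operator; the code-generated registration files like `RegisterCPU.cpp` use the same `TORCH_LIBRARY_IMPL` / `m.impl` calls):

```cpp
#include <torch/library.h>
#include <ATen/ATen.h>

// A toy CPU kernel for our made-up operator.
at::Tensor myadd_cpu(const at::Tensor& self, const at::Tensor& other) {
  return self + other;  // lean on existing kernels for the actual math
}

// Define the operator's schema once...
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}

// ...and register a kernel for it under the CPU dispatch key.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", TORCH_FN(myadd_cpu));
}
```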
## Boxing vs unboxing
Helpful resources:
* See “Unboxing” in http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/
* See this wiki page, which has a really useful diagram: https://github.com/pytorch/pytorch/wiki/Boxing-and-Unboxing-in-the-PyTorch-Operator-Library
**Understanding boxed vs. unboxed representations**
Unboxed representation
* Objects have a different layout depending on the data in question
* What you expect from C++: each struct is a different size depending on its type
* This is great for efficiency - your data is packed together tightly, only takes up as much space as it needs!
An unboxed data representation has a downside though: you can't write a single function that works over all of your different objects!
Well, you sort of can with templates in C++:
```cpp
template<typename T>
void foo(T obj) {...}
```
In the above, `void foo(T)` is a function template that you call with different types. But if I call `foo("a"); foo(123);`, the compiler generates and stamps out two completely different implementations of foo() - one that looks like `void foo(const char*)`, and another that looks like `void foo(int)`! Templates are handy for avoiding code duplication, but we still end up producing a new specialized function for every different type that's passed into the template.
Contrast that to a boxed representation:
Boxed representation
* Objects have a unified layout.
* In general: Different programming languages may choose to use a boxed layout by default for all of their types, e.g. Java.
* In PyTorch: We have our own boxed layout implemented in C++: Some of our APIs shove values into these things [called IValues](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/core/ivalue.h) and toss them onto a stack.
* IValue is a union between Tensor, int64, float64, etc...
Having a boxed representation for types lets us write boxed functions in PyTorch:
* Boxed functions can be written once, and (if implemented correctly) work for all operators
* Boxed functions in PyTorch have a very specific schema:
* `void my_boxed_f(const OperatorHandle& op, std::vector<c10::IValue> stack)`
* Defined here: https://github.com/pytorch/pytorch/blob/8216da1f23b893c074e76e8a9aa7127efbda4287/aten/src/ATen/core/boxing/KernelFunction.h#L104
* `OperatorHandle` is a class that represents an operator (e.g. `torch.add`, `torch.mul`), and the `std::vector<c10::IValue>` is a stack of IValues that are the inputs to that operator
* The general idea when writing a boxed function: Pop the inputs off of the stack, compute some output(s), and push them back onto the stack.
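Putting that convention into code, a boxed version of add might look like this (an illustrative sketch only; the argument order on the stack follows the operator schema, and a real boxed kernel would be registered via `torch::CppFunction::makeFromBoxedFunction`):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/stack.h>

// Boxed calling convention: pop IValue inputs off the stack, compute, push IValue outputs back.
void my_boxed_add(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // For aten::add.Tensor the stack holds (self, other, alpha), pushed left to right,
  // so we pop them in reverse order.
  at::Scalar alpha = torch::jit::pop(*stack).toScalar();
  at::Tensor other = torch::jit::pop(*stack).toTensor();
  at::Tensor self  = torch::jit::pop(*stack).toTensor();
  torch::jit::push(*stack, at::add(self, other, alpha));
}
```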
**Example where boxing is used: Batched Fallback kernel.**
* https://github.com/pytorch/pytorch/blob/2c554266108f1b556dd49f7c3c06c08f2bbd3cbe/aten/src/ATen/BatchedFallback.cpp#L238
* `void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {...}`
* Essentially, all of the arguments are IValues on the stack, and we have a handle to an operator in the DispatchTable. A fallback tells us what to do to handle this
* What BatchedFallback does is the following:
* For each sample in the batch, call op(sample)
* For example:
```py
x = torch.randn(N, 3)
y = torch.randn(N, 3)
torch.add(x, y)
```
* There can be multiple tensors present in `stack`.
* BatchedFallback takes all the IValues, converts some to Tensor, slices them in the batch dimension as necessary, and calls `op` multiple times.
* It ends up doing: `torch.stack([x[0] + y[0], x[1] + y[1], x[2] + y[2], ...])`
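In plain Python, that behavior is equivalent to the following (a sketch of the semantics only, not of the actual boxed implementation):

```py
import torch

N = 4
x = torch.randn(N, 3)
y = torch.randn(N, 3)

# What the batched fallback computes, one slice of the batch at a time:
out = torch.stack([torch.add(x[i], y[i]) for i in range(N)])

# ...which matches the ordinary batched add:
assert torch.allclose(out, torch.add(x, y))
```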
**Benefits of boxing in PyTorch:**
There are some benefits to writing batching logic as a boxed fallback like the above:
* Complexity decrease. This is kind of subjective, but arguably a single boxed kernel like the above is easier to maintain compared to the alternatives. If you want to write some functionality that works for every operator in PyTorch, some alternatives to writing a boxed fallback are:
* Manually write 1000+ versions of your code, one for each operator (ouch)
* Write fancy template metaprogramming logic to templatize over the operator and argument types (ouch)
* Write codegen that generates all of the different kernels for you.
* We actually do this in some cases: for autograd, and (currently) for tracing. This code is faster than a boxed fallback, but also requires work and careful design to make it maintainable.
* Binary size: We only have one function, instead of having separate specialized functions for every operator (and there are 1000+ operators).
* This is especially important for the mobile use case: mobile cares a lot about having a small binary size!
* Mobile internally also uses the Lite Interpreter, which executes ops in a boxed format. I'm not an expert on what this looks like though.
**How is this related to the Dispatcher?**
So, how is this whole notion of boxed vs. unboxed kernels relevant to the dispatcher?
Well, suppose we have a boxed kernel for Batching like we described above, and with batching turned on, I call `torch.sin(x)` on a cpu tensor. We expect batching logic to run before we eventually hit the sin() kernel.
The dispatcher is responsible for going from
* `at::sin()` (normal, unboxed frontend C++ entry point)
* to `void batchedTensorForLoopFallback(const c10::OperatorHandle&, torch::jit::Stack*)` (BOXED kernel that performs batching. Somehow all of the arguments to sin() need to be wrapped into a Stack!)
* to `at::native::sin()` (normal, UNBOXED cpu kernel for sin(). Somehow, the arguments need to be unboxed again so we can call this unboxed function!)
Where does all of this boxing and unboxing conversion logic happen? Since the dispatcher provides APIs for registering both unboxed and boxed kernels, then it also needs to know how to convert arguments and operators between the unboxed and boxed world between invocations.
* If you're curious, some of the template magic that does that lives around here: https://github.com/pytorch/pytorch/blob/8216da1f23b893c074e76e8a9aa7127efbda4287/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h#L521
The `structured_ufunc_add_CPU_functional` class from the structured kernels section
corresponds to one of the leaves of the structured-kernel class hierarchy - a picture
of the full class hierarchy can be found in the linked presentation
(https://drive.google.com/file/d/16qPvpCF4Jbh7ss2lCQMk5hmcyzJvUyQj/view?usp=sharing)