This PR is the first step towards refactoring the build for nvfuser so that the codegen becomes a standalone library.
Contents of this PR:
1. The nvfuser code base has been moved from `./torch/csrc/jit/codegen/cuda/` to `./nvfuser`, except for the registration code used for integration (interface.h/interface.cpp)
2. The build system has been split so that nvfuser generates its own `.so` files. Currently there are:
- `libnvfuser_codegen.so`, which contains the integration, codegen and runtime system of nvfuser
- `nvfuser.so`, which is nvfuser's Python API via pybind. The Python frontend is now exposed via `nvfuser._C.XXX` instead of `torch._C._nvfuser`
3. The nvfuser C++ tests are now compiled into `nvfuser_tests`
4. cmake is refactored so that:
- nvfuser now has its own `CMakeLists.txt`, which is under `torch/csrc/jit/codegen/cuda/`.
- nvfuser backend code is not compiled inside `libtorch_cuda_xxx` any more
- nvfuser is added as a subdirectory in the top-level `./CMakeLists.txt`, at the very end after torch is built.
- since nvfuser has a dependency on torch, the registration of nvfuser at runtime is done via dlopen (`at::DynamicLibrary`). This avoids a circular dependency in cmake, which would be a nightmare to handle. For details, see `torch/csrc/jit/codegen/cuda/interface.cpp::LoadingNvfuserLibrary`
Future work scoped for follow-up PRs:
- Since nvfuser codegen currently has a dependency on torch, we need to refactor that out so we can move nvfuser into a submodule and no longer rely on dlopen to load the library. @malfet
- Since we moved nvfuser into its own cmake build, we effectively disabled the bazel build for nvfuser. This could impact internal workloads at Meta, so we need to add that support back. cc'ing @vors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89621
Approved by: https://github.com/davidberard98
This PR allows transposes to be fused with other operations. If a fusion group would be formed only from operations that merely manipulate metadata in PyTorch (transpose, view, etc.), that group is not sent to nvFuser.
On top of that, if we have converted to `nvprims` but then decided not to form a fusion group, we modify the graph to call the `prim.impl_aten` attribute instead of `prim(*args, **kwargs)`, which has higher overhead.
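A rough, hypothetical sketch of that rewrite (the function name is made up; this is not the PR's actual code):
```py
import torch
import torch.fx

# Hypothetical sketch: retarget nvprim call nodes that ended up outside any
# fusion group so they call the cheaper prim.impl_aten callable directly
# instead of prim(*args, **kwargs).
def use_aten_impls(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        impl = getattr(node.target, "impl_aten", None)
        if node.op == "call_function" and impl is not None:
            node.target = impl
    gm.recompile()
    return gm
```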
cc @kevinstephano @jjsjann123
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86967
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad
I'm seeing an issue where we lower `_to_copy` into `nvprims.convert_element_type`. In cases where we are casting to a dtype that's not supported by nvfuser, this raises a runtime error.
I added a quick check in the lowering logic so that each op can peek at the `fx.Node` and make a runtime decision on whether it should be lowered to an nvprim.
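For illustration only, the check could look roughly like this (the allowlist and function name are made up; they are not the PR's actual code):
```py
import torch
import torch.fx

# Hypothetical dtype allowlist; the real supported set lives in nvFuser itself.
_NVFUSER_DTYPES = {
    torch.float16, torch.bfloat16, torch.float32, torch.float64,
    torch.int32, torch.int64, torch.bool,
    torch.complex64, torch.complex128,
}

def should_lower_to_copy(node: torch.fx.Node) -> bool:
    # Runtime decision: only lower _to_copy -> nvprims.convert_element_type
    # when the requested dtype is one nvFuser can handle.
    dtype = node.kwargs.get("dtype", None)
    return dtype is None or dtype in _NVFUSER_DTYPES
```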
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85566
Approved by: https://github.com/IvanYashchuk, https://github.com/ngimel
This PR adds an `executor_parameters` keyword argument to `torch._prims.executor.execute`.
For now there are two knobs:
* `use_python_fusion_cache: bool = True`: whether to use `lru_cache` when constructing the fusion object.
* `allow_single_op_fusion: bool = True`: whether to allow fusions with a single callable.
Behavior can be controlled by passing a dict with custom values as the `executor_parameters` argument, as shown in the sketch below.
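A minimal usage sketch (the traced function here is just an example; the knob names come from the list above):
```py
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute

a = torch.randn(3, 3, device="cuda")
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(lambda x: torch.sigmoid(x))(a)

# Override the two knobs via executor_parameters; unspecified keys keep their defaults.
result = execute(
    gm,
    a,
    executor="nvfuser",
    executor_parameters={
        "use_python_fusion_cache": False,  # rebuild the fusion object on every call
        "allow_single_op_fusion": True,    # permit fusions with a single callable
    },
)
```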
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84482
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
This PR adds an nvfuser-specific primitive, `var_mean`.
The interpretation of `torch.var_mean` as `torch.ops.nvprims.var_mean` is handled by the `TorchRefsNvfuserCapabilityMode` context manager.
I moved some helper code from `_prims/__init__.py` to `_prims_common`. Correctness is tested with OpInfo tests (see `PythonRefInfo("ops.nvprims.var_mean")`).
The layer norm reference now uses `torch.var_mean` instead of `torch._refs.var_mean` to allow interception. Here's a simple performance comparison of this PR against master (on a 3080 Ti):
```py
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute
def func(a):
    return torch.native_layer_norm(a, (1024,), None, None, 1e-6)

a = torch.randn(10, 512, 1024, dtype=torch.float16, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

for _ in range(10):
    execute(gm, a, executor="strictly_nvfuser")
```
Run with `PYTORCH_NVFUSER_DUMP=dump_eff_bandwidth python script.py`:
```py
# WITH THIS PR
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.033792 ms, achieved: 621.818 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.032608 ms, achieved: 644.396 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.032768 ms, achieved: 641.25 GB/s
# kernel1 run in 0.03072 ms, achieved: 684 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# kernel1 run in 0.031744 ms, achieved: 661.935 GB/s
# ON MASTER
# kernel1 run in 0.05632 ms, achieved: 373.091 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043808 ms, achieved: 479.649 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.044032 ms, achieved: 477.209 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
# kernel1 run in 0.043008 ms, achieved: 488.571 GB/s
```
So this PR gives about a 35% performance improvement with the nvfuser executor for this specific normalized shape.
Also this PR fixes https://github.com/pytorch/pytorch/issues/83506 (see the change in `torch/csrc/jit/python/pybind_utils.cpp`).
Ref. https://github.com/pytorch/pytorch/issues/80187
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83508
Approved by: https://github.com/ngimel
Using fx.Interpreter is a nice way of modifying the calls inside of FX graphs, but it introduces unnecessary overhead in this case.
Example:
```py
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
a = torch.randn(3, 2, dtype=torch.float16, device="cuda")
s = torch.sigmoid
d = torch.digamma # digamma is not supported in nvfuser and aten eager execution is used
def func(a):
    return s(d(s(d(s(d(s(a)))))))

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)
%%timeit
execute(gm, a, executor="nvfuser"); torch.cuda.synchronize();
# On master: 350 µs
# With this PR: 130 µs
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83607
Approved by: https://github.com/ezyang
This PR does not include an NVFuser frontend cache, but it decouples the frontend from directly exposing the backend Fusion IR. Instead, the requested definition is recorded so that it can be replayed to build a Fusion only when one does not already exist. Another PR will be put up to add the actual caching.
The main change in the Python frontend is that the NVFuser Fusion IR is no longer defined directly by the interface. Currently, there is a direct connection between the Python API and the creation of the Fusion IR and Fusion object: the user defines TensorViews and Scalars and calls arith functions (IR Expressions) on those IR Values. The goal is to disconnect the Python API from directly specifying the Fusion IR and to enable caching of the IR so that a Fusion object is not necessarily built every time a Fusion Definition is seen.
The FusionDefinition in Python will mostly look the same, except that the definition is now captured in a lightweight representation: a "Recording" of Records. If the definition is not already cached, the Records are executed to build the Fusion IR. Initially there is no caching, because I am trying to bring up the representation first and get it working correctly.
This is what the Records look like. The Records are functors that are invoked when it is necessary to build the Fusion IR; see
`torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h`.
**Tensor Definition Record**
_Note: The Tensor Definition will change for runtime contiguity caching; I am just matching what is already there for now._
```
struct InputTensorRecord : RecordFunctor {
  InputTensorRecord(
      std::vector<size_t> _outputs,
      std::vector<int64_t> _symbolic_sizes,
      std::vector<bool> _contiguous_info,
      NvfDataType _dtype)
      : RecordFunctor({}, std::move(_outputs)),
        symbolic_sizes(std::move(_symbolic_sizes)),
        contiguous_info(std::move(_contiguous_info)),
        dtype(_dtype) {}

  void operator()(FusionDefinition& fd) final {
    auto tv = TensorViewBuilder()
                  .ndims(symbolic_sizes.size())
                  .contiguity(contiguous_info)
                  .shape(symbolic_sizes)
                  .dtype(dtype)
                  .build();

    fd.fusion_state.at(outputs.at(0)) = tv;
    fd.addInput(tv);
  }

  std::vector<int64_t> symbolic_sizes;
  std::vector<bool> contiguous_info;
  NvfDataType dtype;
};
```
**Generic Templatized Op Record Definition**
Op Records are notable because they record Fusion IR arith functions as the `fusion_op_`.
```
template <class OutType, class... ArgTypes>
struct OpRecord : RecordFunctor {
  OpRecord(
      std::vector<size_t> _args,
      std::vector<size_t> _outputs,
      std::function<OutType(ArgTypes...)> fusion_op)
      : RecordFunctor(std::move(_args), std::move(_outputs)),
        fusion_op_(fusion_op) {}

  template <class TupleType, std::size_t... Is>
  OutType opFunc(
      FusionDefinition& fd,
      TupleType& tp,
      std::index_sequence<Is...>) {
    return fusion_op_(
        dynamic_cast<typename std::tuple_element<Is, TupleType>::type>(
            fd.fusion_state.at(args.at(Is)))...);
  }

  void operator()(FusionDefinition& fd) final {
    using arg_tuple_t = std::tuple<ArgTypes...>;
    auto indices =
        std::make_index_sequence<std::tuple_size<arg_tuple_t>::value>();
    arg_tuple_t inputs;
    auto output = opFunc(fd, inputs, indices);
    fd.fusion_state.at(outputs.at(0)) = output;
  }

 private:
  std::function<OutType(ArgTypes...)> fusion_op_;
};
```
Perhaps the most confusing aspect of the Python Frontend is the `FusionDefinition`. The C++ Class that is bound to is very light weight, purposely. In an attempt to make sure users don't have to touch more than one file when adding new ops, assuming an appropriate Record has already been defined, the Python bindings effectively create functions that act on the FusionDefinition and appear as part of the class in Python but are not part of the class in C++.
Here is an example of a Unary Op Macro. It is creating the binding to a lambda function that effectively appears as a FusionDefinition operation in Python. The other way to do this would have been to create a class method directly in the `FusionDefinition` C++ and have a separate binding to that method.
```
#define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name)              \
  nvf_ops.def(                                                         \
      op_str,                                                          \
      [](nvfuser::FusionDefinition::Operators& self,                   \
         nvfuser::Tensor* input) -> nvfuser::Tensor* {                 \
        nvfuser::Tensor* output = new nvfuser::Tensor(                 \
            self.fusion_definition->recording_state.size());           \
        self.fusion_definition->recording_state.emplace_back(output);  \
        self.fusion_definition->recording.emplace_back(                \
            new nvfuser::OpRecord<NvfTensorView*, NvfTensorView*>(     \
                {input->index},                                        \
                {output->index},                                       \
                static_cast<NvfTensorView* (*)(NvfTensorView*)>(       \
                    torch::jit::fuser::cuda::op_name)));               \
        return output;                                                 \
      },                                                               \
      py::return_value_policy::reference);
```
Here is the `FusionDefinition` class, edited for brevity. The replay of the records will be found under the `exit()` method, where "exit" refers to exiting the Python context manager. A `FusionDefinition` is captured through a context manager like the following:
```
fusion = Fusion()

with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(sizes=[5], strides=[1])
    t1 = fd.ops.abs(t0)
    fd.add_output(t1)
```
```
class FusionDefinition {
 public:
  FusionDefinition(FusionOwner* fusion_owner)
      : fusion_owner_(fusion_owner),
        prev_fusion_(nullptr),
        recording(),
        recording_state(),
        fusion_state(),
        ops(this) {}

  // Context Manager Methods
  FusionDefinition* enter() {
    prev_fusion_ = FusionGuard::getCurFusion();
    FusionGuard::setCurFusion(fusionPtr());
    return this;
  }
  void exit() {
    // Found in the Python Bindings, currently.
    // for (auto& record : recording) {
    //   auto functor = record.get();
    //   (*functor)(self);
    // }
    FusionGuard::setCurFusion(prev_fusion_);
    prev_fusion_ = nullptr;
  }

  void addInput(torch::jit::fuser::cuda::Val* input) {
    fusionPtr()->addInput(input);
  }
  void addOutput(torch::jit::fuser::cuda::Val* output) {
    fusionPtr()->addOutput(output);
  }

  Fusion* fusionPtr() {
    return fusion_owner_->fusionPtr();
  }

 private:
  FusionOwner* fusion_owner_;
  Fusion* prev_fusion_;

 public:
  std::vector<std::unique_ptr<RecordFunctor>> recording;
  std::vector<std::unique_ptr<State>> recording_state;
  std::vector<NvfVal*> fusion_state;

  struct Operators {
    Operators(FusionDefinition* fd) : fusion_definition(fd) {}
    // Python operations are effectively bound here.
    FusionDefinition* fusion_definition;
  };
  Operators ops;
};
```
The Fusion IR doesn't have `define_tensor` or `define_scalar` functions. I made them up as names for the Python `FusionDefinition` to provide a more understandable/convenient way to define input tensors and scalars. `TensorView` objects and Fusion IR `Val` objects are not typically defined outside of a Fusion IR `Expr` output (typically arith function outputs), except for inputs to a graph. Mechanically, there are two things you need to do to define an input in the Fusion IR: define the IR `TensorView`/`Val` object, and then record that object as an input in the `Fusion` object that encapsulates the Fusion IR. Since the `FusionDefinition` does not correspond one-to-one with the Fusion IR and `define_tensor` and `define_scalar` are made-up functions, I decided to combine the `Val` object creation and the recording of the input in the `Fusion` object into one step, to reduce the amount of syntax required to define a Fusion in the Python interface.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81578
Approved by: https://github.com/jjsjann123, https://github.com/IvanYashchuk, https://github.com/SherlockNoMad
This PR introduces a new nvFuser executor for FX graphs containing different kinds of nodes, not just the `torch.ops.prims` supported by nvFuser. The FX graph is partitioned based on whether each node is supported by nvFuser, and supported nodes are fused into subgraphs, all using Sherlock's work on the partitioner.
This new partition-based executor with ATen fallbacks is used by default with `executor="nvfuser"`, and the previous executor can be selected with `executor="strictly_nvfuser"`; naming suggestions are welcome!
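For illustration, choosing between the two executors looks like this (minimal example, reusing the tracing pattern from elsewhere in this log):
```py
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute

a = torch.randn(3, 3, device="cuda")
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(lambda x: torch.sigmoid(x))(a)

# Default: partition the graph, fuse nvFuser-supported subgraphs, fall back to ATen.
out = execute(gm, a, executor="nvfuser")

# Previous behavior: the whole graph must be executable by nvFuser.
out_strict = execute(gm, a, executor="strictly_nvfuser")
```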
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81043
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad
The default cache size of 128 has been causing cache misses on some benchmarks with more than 128 partitions. This bumps it up to a more reasonable cache size.
Note that the simple LRU cache doesn't give us any reuse across repetitive patterns, but that shouldn't be much of an issue in the next iteration of the nvfuser Python API.
Script for running the benchmarks:
https://github.com/SherlockNoMad/NvFuserSample
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81461
Approved by: https://github.com/SherlockNoMad
In the current setup, for each call of the `execute` function a `Fusion` object is constructed from the `GraphModule` and args, which is expensive.
This PR uses `functools.lru_cache` so that the `Fusion` creation cost is paid once per `GraphModule` and set of args. Currently the shape, strides, and dtype of tensors are treated as static; this can be changed later to make better use of nvFuser's internal caching mechanism (by specifying only ndim, contiguity, and dtype).
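A simplified sketch of the idea (hypothetical helper names, not the PR's actual code): the expensive GraphModule-to-Fusion construction is memoized on a key built from the graph and the static tensor properties mentioned above.
```py
from functools import lru_cache

@lru_cache(maxsize=128)  # functools' default size; later bumped, see the PR above
def _build_fusion_cached(graph_key, arg_signature):
    # The expensive GraphModule -> nvFuser Fusion construction would happen here;
    # with the cache it runs only once per (graph, signature) pair.
    ...

def get_fusion(gm, args):
    # Shape, strides, and dtype are part of the key, reflecting that they are
    # treated as static for now.
    signature = tuple(
        (tuple(t.shape), tuple(t.stride()), t.dtype) for t in args
    )
    return _build_fusion_cached(id(gm), signature)
```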
On master:
```py
In [2]: a = torch.randn(3, 3, device='cuda')
In [3]: with TorchRefsMode.push():
...: gm = make_fx(lambda x: torch.sigmoid(x))(a)
...:
In [4]: %%timeit
...: execute(gm, a, executor="nvfuser")
...: torch.cuda.synchronize()
175 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
This PR:
```py
In [2]: a = torch.randn(3, 3, device='cuda')
In [3]: with TorchRefsMode.push():
...: gm = make_fx(lambda x: torch.sigmoid(x))(a)
...:
In [4]: %%timeit
...: execute(gm, a, executor="nvfuser")
...: torch.cuda.synchronize()
62.6 µs ± 9.99 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
In addition, this PR adds support for pytree inputs and extends the test for this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80525
Approved by: https://github.com/kevinstephano, https://github.com/jjsjann123, https://github.com/SherlockNoMad