## Summary
As part of #125683, this PR modifies existing CPU GEMM cpp template & micro-kernel template to enable int8 WoQ GEMM auto-tuning with AVX2, AVX512 & AMX ISAs (the latter is only available on Xeon 4th generation & beyond).
WoQ GEMM takes FP16/BF16 activations, int8 weights, and a scale with the same dtype as the activations.
The operation is equivalent to `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`, which is essentially what the ATen op `torch.ops.aten._weight_int8pack_mm` currently does (except that it does not cache the weights). Weights are treated as constant & cached, so this implementation is suitable for inference, not QAT. `scale` is supported as a `mul` epilogue.
Only BF16 activations are supported in this PR: for FP16 & FP32, the weight is dequantized during the constant-folding pass of freezing, and after auto-tuning, performance with a large `M` dimension may be better than either `torch.ops.aten._weight_int8pack_mm` or the WoQ micro-kernel support introduced in this PR (which dequantizes `w` within the micro-kernel).
While BF16 activations with a large `M` dimension may also benefit from dequantizing `w` beforehand, for now they use the WoQ support in the GEMM templates for auto-tuning; a subsequent PR will add logic for deciding whether or not to dequantize weights beforehand.
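For reference, a minimal sketch of the semantics described above (shapes and the randomly generated int8 weight are illustrative, not taken from this PR's tests):
```python
import torch

# Illustrative shapes; the scale has the same dtype as the activations.
M, K, N = 4, 64, 32
x = torch.randn(M, K, dtype=torch.bfloat16)
w_int8 = torch.randint(-128, 127, (N, K), dtype=torch.int8)
scale = torch.rand(N, dtype=torch.bfloat16)

# Reference semantics of WoQ GEMM.
ref = torch.nn.functional.linear(x, w_int8.to(x.dtype)) * scale
# Existing ATen op with the same semantics (it does not cache the weights).
out = torch.ops.aten._weight_int8pack_mm(x, w_int8, scale)
# The two should agree up to BF16 accumulation-order differences.
print((ref - out).abs().max())
```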
### Performance
#### AMX
Op-level speedup due to AMX micro-kernel (selected during auto-tuning) on 32 physical cores of Intel(R) Xeon(R) Platinum 8468H (of Xeon 4th generation series, codenamed Sapphire Rapids) vs. ATen kernel `torch.ops.aten._weight_int8pack_mm`. Intel OpenMP & tcmalloc were preloaded.
In a few cases with an odd `K`, the implementation added in this PR may not perform as well as the ATen kernel. This is unrelated to this PR, though, since `test_linear_amx` exhibits similar datapoints. In those cases, the AMX micro-kernel may be slower than the AVX512 micro-kernel, so if such shapes are used for auto-tuning, either the AVX512 micro-kernel implementation or the ATen kernel would be chosen instead.
Benchmarked with unit-tests.
Tabular data at https://gist.github.com/sanchitintel/294811a86c8ff6b867c668ae2107c405?permalink_comment_id=5142442#gistcomment-5142442
The AVX512 micro-kernel was disabled to collect data for AMX micro-kernel.
#### AVX2/AVX512 micro-kernels
Tabular data at https://gist.github.com/sanchitintel/52b5fa9c66f791be19e48e2aa6423dc4?permalink_comment_id=5142437#gistcomment-5142437
### Follow-up
1. int4 WoQ GEMM micro-kernel will also be added in a separate PR.
2. A subsequent PR would add logic for deciding whether or not to dequantize weights beforehand.
E2E perf measurement should be done with #131310.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131887
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
Currently, if `storage_offset` is an unbacked symbol and `is_align` cannot be computed at compile time, it hard-fails.
Doing the best we can: adding `guard_size_oblivious` and falling back to `False` if the check cannot be evaluated at compile time.
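A minimal sketch of that fallback pattern (the surrounding alignment check and the constant are illustrative, not the exact Inductor code):
```python
from torch.fx.experimental.symbolic_shapes import (
    GuardOnDataDependentSymNode,
    guard_size_oblivious,
)

ALIGNMENT = 16  # illustrative alignment value

def is_aligned(storage_offset) -> bool:
    try:
        # Evaluates the expression in a size-oblivious way; works for plain ints
        # and backed symbols alike.
        return guard_size_oblivious(storage_offset % ALIGNMENT == 0)
    except GuardOnDataDependentSymNode:
        # Unbacked symbol whose alignment cannot be decided at compile time:
        # fall back to False instead of hard-failing.
        return False
```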
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132423
Approved by: https://github.com/ezyang
Summary:
A couple of improvements to the generated comments in inductor kernels:
1. Makes the nodes in the comment topologically sorted. Having them alphabetically sorted is a gotcha; I was always confused about why the sorting in the comments did not match the code.
2. Adds a printout of the ATen graph fragment corresponding to the current Inductor kernel, to make it easier to map from ATen code to Inductor code.
Example float8-overhead-related inductor kernel comment after this PR:
```
# kernel path: /tmp/torchinductor_vasiliy/27/c27ts3rdw56ns7od5j6ovdnhxphished2lcu3adclzzixoo7khg5.py
# Source Nodes: [weight_fp8], Original ATen: [aten.mul, aten.clamp, aten._to_copy]
# Source node to ATen node mapping:
# weight_fp8 => clamp_max_1, clamp_min_3, convert_element_type_10, convert_element_type_11, convert_element_type_9, mul_3
# Graph fragment:
# %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_2, %convert_element_type_8), kwargs = {})
# %convert_element_type_9 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_3, torch.float32), kwargs = {})
# %clamp_min_3 : [num_users=1] = call_function[target=torch.ops.aten.clamp_min.default](args = (%convert_element_type_9, -448.0), kwargs = {})
# %clamp_max_1 : [num_users=1] = call_function[target=torch.ops.aten.clamp_max.default](args = (%clamp_min_3, 448.0), kwargs = {})
# %convert_element_type_10 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%clamp_max_1, torch.bfloat16), kwargs = {})
# %convert_element_type_11 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%convert_element_type_10, torch.float8_e4m3fn), kwargs = {})
triton_poi_fused__to_copy_clamp_mul_5 = async_compile.triton('triton_', '''
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126698
Approved by: https://github.com/ezyang
ghstack dependencies: #126573
Add functional support for torch.addmm with CK backend. See also #125453
# Implementation details
1. It turns out we can use the same template for addmm and matmul; essentially, matmul is addmm with an empty bias (see the sketch after this list).
2. The Python generator in CK was updated to generate the shared cpp template. The pip package can be installed with `pip install git+https://github.com/rocm/composable_kernel@add-addmm`; the branch will be merged into `develop` after this PR lands, to avoid breaking the current matmul support.
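A hedged sketch of exercising the new lowering outside the unit test, assuming a ROCm build where `"CK"` is an accepted value for `max_autotune_gemm_backends`:
```python
import torch
import torch._inductor.config as inductor_config

# Let max-autotune pick between ATen and CK GEMM instances.
inductor_config.max_autotune_gemm_backends = "ATEN,CK"

@torch.compile(mode="max-autotune")
def addmm(bias, a, b):
    # matmul is the same template with an empty bias
    return torch.addmm(bias, a, b)

dev = "cuda"  # ROCm devices also show up as "cuda"
out = addmm(
    torch.randn(64, device=dev, dtype=torch.float16),
    torch.randn(128, 256, device=dev, dtype=torch.float16),
    torch.randn(256, 64, device=dev, dtype=torch.float16),
)
```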
# Testing
`pytest test/inductor/test_ck_backend.py -k addmm`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130576
Approved by: https://github.com/chenyang78
This PR mostly refactors by moving code into utils files so that it can be shared between codecache.py and compile_fx.py. It then changes compile_fx so that:
- When saving to FXGraphCache, we save onto the CompiledFXGraph all the necessary metadata for running post compile steps (realigning inputs, cudagraphification).
- When loading from FXGraphCache, we use the saved information directly instead of recalculating it from scratch.
What this does is make it so that `FXGraphCache.load()` is a perfect cache on compile_fx_inner, in that it **returns exactly what compile_fx_inner returns**. This also makes it possible for AOTAutogradCache, given a key to the fx graph cache and example inputs, to get back the full return value of compile_fx_inner.
## What's a post compile step?
We define a **post-compile** to be the set of actions that need to run after FXGraphCache either loads from the cache or misses and runs compilation. These steps include:
- Setting the tracing context's output strides
- Running cudagraphs if enabled
- Maybe realign inputs if cudagraphs didn't run
To run these steps, we save all the necessary metadata in CompiledFxGraph, and use them on a cache hit to reconstruct the object.
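A hedged pseudocode sketch of that flow (field and helper names here are illustrative, not the actual CompiledFxGraph API):
```python
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional

@dataclass
class CachedFxGraph:
    compiled_fn: Callable[..., Any]
    output_strides: Optional[List[Any]] = None            # saved at compile time
    cudagraph_fail_reasons: List[str] = field(default_factory=list)
    fx_kwargs: dict = field(default_factory=dict)          # see implementation notes below

# Stand-in helpers so the sketch is self-contained; the real steps live in Inductor.
def set_output_strides(strides): pass
def cudagraphify(fn, example_inputs): return fn
def align_inputs(fn, example_inputs): return fn

def post_compile(graph: CachedFxGraph, example_inputs, cudagraphs_enabled: bool):
    # 1. Set the tracing context's output strides from the saved metadata.
    set_output_strides(graph.output_strides)
    if cudagraphs_enabled and not graph.cudagraph_fail_reasons:
        # 2. Run cudagraphs: the fail reasons were computed on the original cache miss.
        graph.compiled_fn = cudagraphify(graph.compiled_fn, example_inputs)
    else:
        # 3. Otherwise, maybe realign inputs before calling the compiled function.
        graph.compiled_fn = align_inputs(graph.compiled_fn, example_inputs)
    return graph.compiled_fn
```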
## Splitting cudagraphs work into pre/post compile
Cudagraphs does a lot of work on the input graph module to determine if cudagraphs can be enabled. This is the code that involves cudagraph_tests and stack traces. This will work in a world where we have access to the input graph module, but with AOTAutograd warm start, we won't have access to that information anymore. Therefore we can split cudagraphs work into two parts: on a cache miss (and therefore a full compile), we do the cudagraphs testing work, and save cudagraph_fail_reasons into the cache. Then on a cache hit, we know whether or not we can run cudagraphs, and if we can't, we can emit the correct error messages.
Implementation notes:
- We save `fx_kwargs` directly onto the CompiledFXGraph. `fx_kwargs` is already, by definition, part of the cache key, so this is safe to do when it comes to cache correctness.
- ^ Why do we do above even though FXGraphCache.load takes fx_kwargs as an argument? Because AOTAutogradCache **doesn't** have access to fx_kwargs: they're annoyingly encoded in the functools.partial() of the fw_compiler, so *only* inductor knows about these options. They're fully captured by the AOTAutogradCache key (since every key to fx_kwargs is either a global config, or a field that's deterministic based on an input graph module), but their values are still needed to run cudagraphs/postprocessing. Therefore, it's easier/safer to store it on the cached result.
- Willing to hear other approaches here if we think saving these extra fields is not reasonable, though I can't think of another way to do this that's less complicated to explain.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130572
Approved by: https://github.com/eellison
Add the Inductor lowering for `torch._scaled_mm`, whose API was last updated in https://github.com/pytorch/pytorch/pull/128683.
The lowering does:
- for tensor-wise scaling, auto-tune between the default ATen kernel (cuBLAS) and Triton kernel configurations.
- for row-wise scaling, auto-tune between the default ATen kernel (CUTLASS kernel added in https://github.com/pytorch/pytorch/pull/125204) and Triton kernel configurations.
The Triton kernel template is based on 3ad9031d02 (D56337896) by @choutim (without using SPLIT_K) and on the mm template in `torch/_inductor/kernel/mm.py`.
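For context, a minimal sketch of calling the op under max-autotune (assumes a CUDA device with FP8 support; shapes, scales, and dtypes are illustrative):
```python
import torch

M, K, N = 256, 256, 256
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# mat2 must be column-major for the FP8 GEMM.
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda")  # tensor-wise scaling: scalar scales
scale_b = torch.tensor(1.0, device="cuda")

@torch.compile(mode="max-autotune")
def f(a, b, scale_a, scale_b):
    return torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                            out_dtype=torch.bfloat16)

out = f(a, b, scale_a, scale_b)  # `AUTOTUNE scaled_mm` shows up in the logs
```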
## Testing:
- Logging shows max-autotune tuning (`AUTOTUNE scaled_mm`) for both tensor-wise and row-wise scaling.
- Row-wise scaling allows operator fusion between preceding pointwise/reduction op and amax/cast:
- output code Evaluating m=256, n=256, k=256, fusion_case='pointwise', scaling_mode='row'
- P1477224245 - 2 kernels
- output code Evaluating m=2048, n=256, k=2048, fusion_case='reduction', scaling_mode='row'
- P1477227340 - 2 kernels
- UT `python test/inductor/test_fp8.py -- TestFP8Lowering`
## Benchmarking
Eager/compiled tensor-wise/row-wise scaling for various shapes:
https://docs.google.com/spreadsheets/d/1VfWEVuyrwoWysfbS0_u2VHJ-PsdWkF1qIsiD60AzTes/edit?gid=2113587669#gid=2113587669
- Some of the “compiled” cases are slightly slower than “eager”. It’s because max-autotune selected the ATen kernel in the compiled case, and I think the discrepancy is variance.
Eager/compiled tensor-wise/row-wise scaling with pointwise/reduction preceding op for various shapes:
https://docs.google.com/spreadsheets/d/1Nv07NrdffQIoDeMjo9E0V-E-EYrEN0WysO_bn1bc6ns/edit?gid=1715488446#gid=1715488446
## Questions for reviewers:
- Should the accumulator type `ACC_TYPE` always be float32? If not, where is this type set (the output layout?)?
## Todo:
- Make the Triton template use the improved persistent kernel version (https://github.com/pytorch/FBGEMM/pull/2735 by @htyu)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130422
Approved by: https://github.com/ipiszy
```
# Mode to emulate pytorch eager numerics for lower precision (fp16, bf16)
# Pytorch eager computes bf16/fp16 by upcasting inputs to fp32 and downcasting after
# For multiple, fused pointwise nodes, inductor will elide the intermediary upcasts and downcasts
# Typically this should be closer to fp64 ref numerics. However, it can be useful for debugging
# to emulate the eager numerics.
```
We add extra upcasts and downcasts for pointwise nodes that correspond to casts that existed in the original user program (excluding pointwise nodes that are emitted during decomposition). Since this is mostly for debugging, I added this information in the `meta` so that this mode does not have unintended side effects like changing pattern matching.
In theory there could also be some other casts with fused reduction -> reduction, although I haven't seen this in practice as much; it could be done as a follow-up. Note: this only works with the CUDA backend right now.
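Assuming the quoted config block corresponds to an Inductor flag named `emulate_precision_casts` (treat the name as an assumption if it has since changed), enabling the mode looks roughly like:
```python
import torch
import torch._inductor.config as inductor_config

inductor_config.emulate_precision_casts = True  # assumed flag name, see note above

def f(x, y):
    # A fused pointwise chain in bf16: with the mode on, the intermediary
    # upcasts/downcasts from the user program are kept to match eager numerics.
    return (x + y).relu() * y

compiled = torch.compile(f)
x = torch.randn(8, device="cuda", dtype=torch.bfloat16)
y = torch.randn(8, device="cuda", dtype=torch.bfloat16)
print((compiled(x, y) - f(x, y)).abs().max())
```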
This mode was sufficient to eliminate compile differences from https://fb.workplace.com/groups/385893200869952/posts/464263173032954/?comment_id=465199259606012&reply_comment_id=465676792891592.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131595
Approved by: https://github.com/shunting314, https://github.com/bdhirsh, https://github.com/jansel
This PR enables the Inductor compute/comm reordering passes for Traceable FSDP2 to achieve overlap. Note that the overlap is not maximally optimized yet; follow-up work will be done in subsequent PRs.
Test commands:
- `pytest -rA test/distributed/test_compute_comm_reordering.py::TestComputeCommReorderingMultiProc`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131614
Approved by: https://github.com/yifuwang
ghstack dependencies: #131510
This PR creates these `GroupedSchedulerNode`s:
- One for each all-gather code block (cast + copy-in + all-gather)
- One for each all-gather-wait code block (all-gather-wait + copy-out)
- One for each reduce-scatter code block (copy-in + reduce-scatter)
- One for each reduce-scatter-wait code block (reduce-scatter-wait)
This serves two goals:
- Prevent outside ops from being fused into these op groups, in order to have more predictable memory usage.
- Make it easier to specify dependencies, e.g. from the `i+1`-th all-gather group node to the `i`-th all-gather-wait group node, to enforce FSDP2 comm ordering (i.e. "serialization of comms").
The actual "reorder-for-FSDP-compute-comm-overlap" PR will come next.
Test commands:
- `pytest -rA test/distributed/test_compute_comm_reordering.py::TestComputeCommReorderingMultiProc`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131510
Approved by: https://github.com/yifuwang
Currently we require `n % register_block_n == 0`, which typically brings good perf when `n` is a multiple of 8, 16, 32, etc., but falls back to the reference micro-gemm otherwise (where `register_block_n == 1`). This PR optimizes this by padding `n` of the packed weight to a multiple of `register_block_n` (8, 16, 32, etc.), so the micro-gemm can work as-is on the padded `n`. When the weight is padded, we use a local accumulation buffer to collect the micro-gemm result and then unpad (slice) it before storing back to the output buffer.
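A tiny sketch of the padding rule (the helper name is illustrative):
```python
def pad_n(n: int, register_block_n: int) -> int:
    # Round n up to the next multiple of register_block_n for the packed weight.
    return (n + register_block_n - 1) // register_block_n * register_block_n

# e.g. the N=3073 case below gets padded so the micro-gemm runs unmodified,
# and the result is sliced back to 3073 columns before the final store.
assert pad_n(3073, 16) == 3088
```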
Performance numbers measured on "Intel (R) Xeon (R) CPU Max 9480", single core, bf16.
Before
AUTOTUNE linear_unary(512x768, 3073x768, 3073)
_linear_pointwise 2.3563 ms 100.0%
cpp_packed_gemm_0 710.5902 ms 0.3%
After
AUTOTUNE linear_unary(512x768, 3073x768, 3073)
cpp_packed_gemm_0 1.8909 ms 100.0%
_linear_pointwise 2.1016 ms 90.0%
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130690
Approved by: https://github.com/leslie-fang-intel, https://github.com/jansel
ghstack dependencies: #130675
- More conservative estimation of plannable inputs
- Consider constant_pad_nd as a pointwise node in concat lowering
- Use aten.cat instead of constant_pad_nd when padding just a single dimension, because it can be memory-planned away
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128909
Approved by: https://github.com/Chillee
FSDP2 eager pre-allocates the output buffer for the AllGather, and the AllGather just writes into that buffer. However, under compile we use an out-of-place AllGather by default, which means that in the Traceable FSDP2 case we would unnecessarily use more memory than eager. We want to re-inplace that AllGather instead.
This PR adds a post_grad pass to re-inplace all_gather_into_tensor (i.e. changing it from `all_gather_into_tensor.default` out-of-place op to `all_gather_into_tensor_out.default` out-variant op).
One thing to note is that since with this pass we are introducing a mutable op into the post_grad FX graph, we must do this pass after `reinplace_inplaceable_ops` (at which point we are okay again with having mutable ops in the graph). To facilitate this, this PR adds a `post_grad_custom_post_reinplace_pass` extension point to allow user-defined post-reinplace FX passes.
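A hedged sketch of hooking into that extension point (the attribute name is taken from the description above and the pass body is illustrative; a real pass would rewrite the call rather than just report it):
```python
import torch
import torch._inductor.config as inductor_config

def post_reinplace_pass(graph: torch.fx.Graph) -> None:
    for node in graph.nodes:
        if node.target is torch.ops._c10d_functional.all_gather_into_tensor.default:
            # A real pass would rewrite this call to the out-variant
            # all_gather_into_tensor_out.default, reusing FSDP2's pre-allocated buffer.
            print("found an out-of-place all-gather to re-inplace:", node)

# Assumed attribute name from this PR's description of the extension point.
inductor_config.post_grad_custom_post_reinplace_pass = post_reinplace_pass
```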
---
Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_fullgraph_backend_inductor`
---
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129773
Approved by: https://github.com/eellison
In this PR, we abstracted the different types of aten operation parameters as `ParameterMetadata`. This structure intends to be used to represent and store the metadata of each aten operation parameter. Currently, it only supports `Tensor`, `TensorList`, and `Scalar`.
```C++
using ParameterMetadataValue = std::variant<TensorMetadata, std::vector<TensorMetadata>, c10::Scalar>;
```
With this PR, we can extend support for other parameter types in a more modular way, e.g. `string`, `int`, `double`.
Differential Revision: [D59399546](https://our.internmc.facebook.com/intern/diff/D59399546)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125308
Approved by: https://github.com/jgong5, https://github.com/jansel, https://github.com/atalman
There is one huge problem this fixes: today, sympify(symint) produces a float (!!) because Sympy attempts to see if you can coerce the symint to float in sympify, and of course this works on SymInt.
However, this also has another nontrivial effect: anywhere in Inductor where sympy expressions are passed around, it is also valid to pass around a SymInt now. I'm ambivalent about this: it's currently a mistake to be passing around a SymInt when a sympy expression is expected. But maybe this is fine?
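A small, PyTorch-independent illustration of the coercion path described above (in recent SymPy versions, the non-strict sympify fallback tries `float()` on unknown objects):
```python
import sympy

class FloatCoercible:
    """Stands in for any object that supports float(), like a SymInt."""
    def __float__(self):
        return 3.0

# sympify has no converter registered for this type, so it falls back to
# float coercion and returns a sympy Float instead of an Integer.
print(sympy.sympify(FloatCoercible()))  # 3.00000000000000
```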
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130166
Approved by: https://github.com/yf225
**Summary**
This PR mainly refactors two things:
1. Passing in weight's data type explicitly in `create_micro_gemm` as `input2.dtype`. When registering `CppMicroGemmConfig`, we will reuse `input.dtype` if `input2.dtype` is not explicitly registered.
2. Add a util function to get the output data type and compute data type from the input data type (see the sketch below).
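A hedged sketch of the kind of utility described in item 2 (the actual function name and dtype mapping in the PR may differ):
```python
import torch

def output_and_compute_dtype(input_dtype: torch.dtype):
    # Assumed mapping for illustration: integer GEMMs accumulate in int32,
    # floating-point GEMMs accumulate in fp32.
    if input_dtype in (torch.uint8, torch.int8):
        return torch.int32, torch.int32
    return torch.float32, torch.float32

assert output_and_compute_dtype(torch.bfloat16) == (torch.float32, torch.float32)
```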
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129221
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825, #129048, #129049, #129103, #129220
This PR adds an alternative backend for Inductor, adding Composable Kernel Universal GEMM instances to the autotune instance selection.
The implementation is heavily influenced by the series of PRs which add the CUTLASS backend (https://github.com/pytorch/pytorch/issues/106991). The main differences are
(1) customizing compiler for the ROCm platform
(2) customizing template code generation for Composable Kernel Universal GEMM instances.
We provide config tuning knobs to balance the compilation time of the instance sources against finding the best instance.
### Testing
Install the CK library
```
pip install git+https://github.com/rocm/composable_kernel@develop
```
Run the test
```
TORCH_LOGS=+torch._inductor \
pytest --capture=tee-sys test/inductor/test_ck_backend.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125453
Approved by: https://github.com/eellison, https://github.com/jansel
Summary: Found during testing with remote caching: Use the same output logger object between graph.py and codecache.py since it's patched in `run_and_get_cpp_code`. That allows us to capture any logging produced from the codecache path when using `run_and_get_cpp_code`. I'm also fixing a few tests that were passing mistakenly because logging was missing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128794
Approved by: https://github.com/oulgen, https://github.com/leslie-fang-intel
In this PR, we abstracted the different types of aten operation parameters as `ParameterMetadata`. This structure intends to be used to represent and store the metadata of each aten operation parameter. Currently, it only supports `Tensor`, `TensorList`, and `Scalar`.
```C++
using ParameterMetadataValue = std::variant<TensorMetadata, std::vector<TensorMetadata>, c10::Scalar>;
```
With this PR, we can extend support for other parameter types in a more modular way, e.g. `string`, `int`, `double`, and the other types summarized in the following list. The list is collected from all aten operations and ordered by how often each type is used.
- `Tensor`
- `bool`
- `int64_t`
- `TensorList`
- `Scalar`
- `c10::SymIntArrayRef`
- `::std::optional<Tensor>`
- `IntArrayRef`
- `double`
- `c10::SymInt`
- `::std::optional<ScalarType>`
- `::std::optional<double>`
- `::std::optional<bool>`
- `::std::optional<Layout>`
- `::std::optional<Device>`
- `::std::optional<int64_t>`
- `Dimname`
- `::std::optional<Generator>`
- `c10::string_view`
- `::std::optional<c10::string_view>`
- `OptionalIntArrayRef`
- `::std::optional<Scalar>`
- `OptionalSymIntArrayRef`
- `::std::optional<MemoryFormat>`
- `::std::optional<c10::SymInt>`
- `ScalarType`
- `ArrayRef<Scalar>`
- `DimnameList`
- `::std::optional<ArrayRef<double>>`
- `::std::array<bool,3>`
- `::std::optional<DimnameList>`
- `c10::List<::std::optional<Tensor>>`
- `::std::array<bool,2>`
- `Storage`
- `::std::array<bool,4>`
- `Device`
- `DeviceIndex`
- `ITensorListRef`
- `Stream`
- `Layout`
- `MemoryFormat`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125308
Approved by: https://github.com/jgong5, https://github.com/jansel
This PR implements "V0" of AOTAutogradCache. Given an input to AOTAutograd, we calculate a cache key, then save an AOTAutogradCacheEntry.
Each AOTAutogradCacheEntry has:
- A CompiledForward and optionally a CompiledBackward
- A bunch of metadata.
CompiledForward and CompiledBackward each save the *key* to the FXGraphCache associated with the compiled object. FXGraphCache populates this key field as long as it's able to return a compiled graph given a set of inputs. We then load the same object from the FXGraphCache on an AOTAutogradCache hit.
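A hedged sketch of the shape of a cache entry as described above (field names are illustrative, not the exact definitions):
```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class CompiledForward:
    fx_graph_cache_key: str  # key into FXGraphCache for the compiled forward graph

@dataclass
class CompiledBackward:
    fx_graph_cache_key: str  # key into FXGraphCache for the compiled backward graph

@dataclass
class AOTAutogradCacheEntry:
    compiled_fw: CompiledForward
    compiled_bw: Optional[CompiledBackward]
    # Metadata needed to rewrap the loaded callable with post-compile wrappers on a hit.
    runtime_metadata: dict
```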
On cache miss:
- Run AOTAutograd, up to AOTAutogradDispatch.post_compile.
- Save an AOTAutogradCacheEntry to the cache after compiling the necessary portions and receiving a cache key from FXGraphCache. In this PR we *always* compile the backward ahead of time. The PR above this one implements lazy backward caching, so that we only save to the cache after compiling the backward in a lazy backward scenario.
- Return the resulting object
On cache hit:
- Run AOTAutogradCacheEntry.post_compile() on the cache key.
- This attempts to load the forward and backward graphs from FXGraphCache
- As long as we successfully load from FXGraphCache, it's a hit. We then rewrap the callable with post compile wrappers using our saved metadata.
For now, we ignore the fakified out and debug wrappers. We only save to the cache if Fakified out is turned off.
V0 Guards behavior:
FXGraphCache serializes guards that are needed in the shape_env based on the symint inputs to the graph. The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly the same as the ones it passes to inductor, for both the forward and backward passes. (This does *not* mean that the tensor values passed in are the same: only that their symints are). That is, AOTAutograd and Inductor never create new guards based on symints with *different sources* than the ones passed in by dynamo.
We don't currently store any AOTAutograd specific guards: my hypothesis is that FXGraphCache already stores these, as any guards generated by AOTAutograd should already be in the shape_env before calling into inductor, and we don't generate new guards post inductor. If this is needed, I'll add it in another diff.
Testing:
We'll start with some basic unit tests, but I'll be adding more and more complicated testing as the next step.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126791
Approved by: https://github.com/bdhirsh