Fixes GELU, LeakyRELU and MISH activation functions on non-contiguous tensors (for instance, when a transpose operation was applied on the tensors prior to the MPS operator), forward and backward passes.
I also extended tests on the 3 activation functions to check: full-precision and half-precision, contiguous and non-contiguous, and several dims of tensors: scalars, 1D, empty, 2D, > 3D.
I had issues with Mish and GELU activations when asserting the gradients vs. CPU with sum() on some cases, so I reverted to the previous setup by setting a gradient parameter on .backwards().
This PR also fixes an issue with LeakyRELU on empty tensors.
Fixes#98212huggingface/transformers#22468huggingface/transformers#19353
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123049
Approved by: https://github.com/kulinseth
Which is more stringent than clang when equivalently sized NEON registers are cast to each other. In particular, at one point `uint16x4_t` were cast to `int16x4_t`, which gcc does not allow. Added `vreinterpret_s16_u16` (which is a no-op) to solve this and tested in https://godbolt.org/z/sYb4ThM6M
Test plan: Build aarch64 wheels
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124511
Approved by: https://github.com/mikekgfb
Summary: It seems super confusing that if we set DISABLE_ADDMM_HIP_LT + PYTORCH_TUNABLEOP_ENABLED, the former takes priority. This is because the former goes through the gemm_and_bias and tunable op is integrated with gemm path. Before we can integrate tunable op with gemm_and_bias, we'll probably just let tunable op takes priority
Test Plan: Run a simple linear program and verified.
Differential Revision: D56183954
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124161
Approved by: https://github.com/jeffdaily, https://github.com/nmacchioni
Summary:
This is not great, but our ATen-cpu is not completely GPU agnostic. Previously we have worked on D54453492 (https://github.com/pytorch/pytorch/pull/121082) and D54528255, but there are a few things we haven't resolved, and it's exploding here. So we'll continue to fix them until all are gone.
This ROCm block is for 4.3 which is very old. I don't think it should be supported any more. So let's just kill this macro
Test Plan: CI
Differential Revision: D56172660
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124158
Approved by: https://github.com/jeffdaily, https://github.com/nmacchioni
We override the `__call__` method and register fake, functional, proxy default dispatch mode implementation in its python_key_mode_table.
The idea is:
1. when inputs contains FakeScriptObject, we dispatch it through _get_dispatch mechanism. We implement dispatch mode keys automatically in the operator's constructor.
2. when inputs are not fakified, we dispatch through the original c++ dispatcher.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123367
Approved by: https://github.com/zou3519
This PR adds in fast semi-structured sparsification kernels to PyTorch.
These kernels allow for accelerated semi-structured sparsification
kernels in PyTorch.
The kernels have been added as aten native functions
In particular, three new functions have been added:
* `torch._sparse_semi_structured_tile`
This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.
* `torch._sparse_semi_structured_apply`
This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.
* `torch._sparse_semi_structured_apply_dense`
This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation
The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentatino on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.
To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`
Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
On par with `CUDA` implementation.
For `autocast` logic, same with `CUDA` + `Fused Adam`:
- check inf in `gradscalar.step`
- In fused kernel, if there is `inf`, do nothing. If not, unscale the grad ( also write back) and update the param.
**TestPlan**:
```
# extend CUDA only test for CPU fused adagrad
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused
# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict
# newly added test (follow 6b1f13ea2f/test/test_cuda.py (L1108))
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```
**Benchmark**:
**5.1x** on 56 core SPR
**Parameter-size=1M**
**Nparams=10**
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)
```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```
**Note: Fused kernel accuracy**
The accuracy failure in CI shows a little higher than default tolerance
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debug it step by step and unfortunately we may not able to make the `fused kernel` exactly same with `non fused` one due to compiler optimizations.
For example, in non-fused impl
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
and in fused impl
```
exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
// std::cout << "exp_avg_sq " << exp_avg_sq_ptr[d] << std::endl;
exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep `std::cout`, I can get exactly same results in UT
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment out it, there will be a difference
```
===============param
0.6796758770942688
0.6796759366989136
```
So I will make the tolerance a little higher than default one.
Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
MTIA device has its own Module in PyTorch now.
torch.mtia has following APIs similar to other backends. The lazy_init is also supported.
```
__all__ = [
"init",
"is_available",
"synchronize",
"device_count",
"current_device",
"current_stream",
"default_stream",
"set_stream",
"stream",
"device",
]
```
------------
For device management. We expand AccleratorHooksInterface to support generic device management and it can be used in both C++ and PyThon.
```
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int : ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int : ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int : ...
```
---------
Adding get_device_module API to retrieve device modules for different device types.
```
def get_device_module(device: Optional[Union[torch.device, str]] = None)
```
---------
@exported-using-ghexport
Differential Revision: [D52923602](https://our.internmc.facebook.com/intern/diff/D52923602/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123612
Approved by: https://github.com/albanD
ghstack dependencies: #123611
Differential Revision: D56200666
Previously, when we hit the Functionalize kernel for lift_fresh_copy, we directly dispatch self.clone() to proxy dispatch. As a result, we end up receiving a functional tensor at proxy dispatch. As a work around, I unwrap self manually. Not sure, why it works ok in aot-dispatch tho
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124198
Approved by: https://github.com/bdhirsh
By creating constants using input tensors dtype
One line reproducer:
```
python -c "import torch; x=torch.arange(3, dtype=torch.float16,device='mps');print(torch.nn.functional.binary_cross_entropy(x, x))"
```
Before the change
```
loc("mps_subtract"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":233:0)): error: input types 'tensor<f32>' and 'tensor<3xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
```
After
```
tensor(-33.7812, device='mps:0', dtype=torch.float16)
```
Fixes https://github.com/pytorch/pytorch/issues/124252
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124258
Approved by: https://github.com/kulinseth
Summary:
We explicitly set the cublas workspace even though CUDA 12.2+ fixed the issue where memory usage increased during graph capture. Original issue: https://github.com/pytorch/pytorch/pull/83461
This is because in CUDA 12.2+, the use of cudaMallocAsync in cublas will allocate memory dynamically (even if they're cheap) outside PyTorch's CUDA caching allocator. It's possible that CCA used up all the memory and cublas's cudaMallocAsync will return OOM
Test Plan: CI
Differential Revision: D56226746
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124250
Approved by: https://github.com/houseroad, https://github.com/eqy
# Motivation
According to [[RFC] Intel GPU Upstreaming](https://github.com/pytorch/pytorch/issues/114723), we would like to upstream amp autocast policy to facilitate the functionality and accuracy of `torch.compile` on e2e benchmarks.
# Solution
The first PR aims to make macro `KERNEL` to be generic. It accepts two types of inputs, like `(DISPATCH, OP, POLICY)` and `(DISPATCH, OP, OVERLOAD, POLICY)`.
The second PR intends to refactor CUDA's autocast policy to make it can be shared with `XPU` backend.
The final PR would like to support XPU autocast policy which shares the same recipe with `CUDA` backend.
# Additional Context
Another motivation is we would like to unify autocast API and provide the generic APIs, like:
- `torch.get_autocast_dtype(device_type)`
- `torch.set_autocast_dtype(device_type)`
- `torch.is_autocast_enabled(device_type)`
- `torch.set_autocast_enabled(device_type)`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124050
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
# Motivation
This PR is a part of RFC #114848, and it is a successor PR of #116249 and #116019. This PR would depend on oneDNN compilation in #116249. Some runtime support is needed in #116019.
Aten operators like `addmm`, `baddmm` is defined in `Blas.cpp` in `aten/src/ATen/native/mkldnn/xpu/`.
Accompanied with these files provide core functionaliy, `BlasImpl.h`, `Utils.h` and other file provide basic utilities for them. For instance, `Utils.h` provide common memory descriptor query utils for `Matmul.h` and these utility function will also be used in other primitive, like `convolution`. `BlasImpl.h` is a header file that provide helper for handling shape info processing in matmul related operators. It would not only help basic GEMM operator like `addmm, baddmm` but also help fusion operators used in `torch.compile` like `linear_pointwise` in #117824.
In next stage, we would continually complete the oneDNN support through enabling `matmul fusion` and `convolution` related code.
Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
Co-authored-by: lei,zhenyuan <zhenyuan.lei@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117202
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/malfet
ghstack dependencies: #117098, #117112
# Motivation
As proposed in https://github.com/pytorch/pytorch/issues/114848 and https://github.com/pytorch/pytorch/issues/114723, oneDNN library is an important component for Intel GPU software ecosystem.
Current PR is based on #117098, where oneDNN library for Intel GPU should be ready. This PR is the integration code from aten to oneDNN. GEMM integration code is the core part in this PR. Accompanied with GEMM, more basic support like runtime (device, stream), primitive attr is also included.
We put the oneDNN integration code in directory `aten/src/ATen/native/mkldnn/xpu/detail`. We add a namespace `at::native::xpu::onednn` for oneDNN integration.
The code in this PR would be used in following PRs, where aten operators would call the functions in these integration code.. We separate the prs due to onednn integration is logically separable with aten operator implementation. Also, this can ease the burden of reviewing by avoid too much codes in single PR.
Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
Co-authored-by: lei,zhenyuan <zhenyuan.lei@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117112
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/albanD
By unrolling middle loop by 16 elements and using neon to decode packed int4 to float32.
Unrolling entire `n` loop actually makes it a tad slower, probably because ARM has smaller register file that x86
Before/after performance running stories110M on M2Pro
| eager (before) | eager (after) | compile(before) | compile (after) |
| ---- | --- | -- | -- |
| 28 | 57 | 31 | 104 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124257
Approved by: https://github.com/mikekgfb
This PR:
- adds a new torch.library.register_fake and deprecates
torch.library.impl_abstract. The motivation is that we have a lot of
confusion around the naming so we are going to align the naming with
the actual subsystem (FakeTensor).
- renames `m.impl_abstract_pystub("fbgemm_gpu.sparse_ops")` to
`m.has_python_registration("fbgemm_gpu.sparse_ops")`. No deprecation
here yet; I need to test how this works with static initialization.
- Renames a bunch of internals to match (e.g. abstractimplpystub ->
pystub)
I'm scared to rename the Python-side internal APIs (e.g.
torch._library.abstract_impl) because of torch.package concerns. I'll do
that in its own isolated PR next just in case it causes problems.
DEPRECATION NOTE: torch.library.impl_abstract was renamed to to
torch.library.register_fake. Please use register_fake. We'll delete
impl_abstract in a future version of PyTorch.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123937
Approved by: https://github.com/albanD
This PR adds in fast semi-structured sparsification kernels to PyTorch.
These kernels allow for accelerated semi-structured sparsification
kernels in PyTorch.
The kernels have been added as aten native functions
In particular, three new functions have been added:
* `torch._sparse_semi_structured_tile`
This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.
* `torch._sparse_semi_structured_apply`
This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.
* `torch._sparse_semi_structured_apply_dense`
This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation
The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentatino on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.
To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`
Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
Fixes#104729
This improves the compiled mode performance of Softmax (by 20%) and other operations (like batchnorm) that invoke the reduce_all function. Thereby also improves BERT inference by around 8%.
Tested on a graviton 3 instance (c7g.4xl). Tests were run in a single-threaded manner.
Script attached below.
Command: `OMP_NUM_THREADS=1 LRU_CACHE_CAPACITY=1024 DNNL_DEFAULT_FPMATH_MODE=BF16 python TestSoftmax.py`
[TestSoftmax.txt](https://github.com/pytorch/pytorch/files/14910754/TestSoftmax.txt)
```python
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity
model = nn.Softmax().eval()
compiled_model = torch.compile(model)
inputs = torch.randn(1024, 1024)
with torch.set_grad_enabled(False):
for _ in range(50):
compiled_model(inputs) #Warmup
print("Warmup over")
with profile(activities=[ProfilerActivity.CPU]) as prof:
with record_function("model_inference"):
for _ in range(100):
compiled_model(inputs)
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
# Check if the compiled model inference and the eager model inference are similar using torch.allclose
print(torch.allclose(compiled_model(inputs), model(inputs)))
```
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123584
Approved by: https://github.com/jgong5, https://github.com/malfet
Current implementation drops the negative frequency components even when the user doesn't ask for the one-sided transform. The tests for the negative frequency components seem to have worked by accident due to internal implementation details but the issue becomes evident in MacOs 14.4.
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123274
Approved by: https://github.com/malfet
Fix: #120336
This PR fixes an issue on AOTAutograd, specifically on backends that don't support views
by themselves (e.g. XLA). Previously, AOTAutograd tried to reconstruct output views by
calling `as_strided` on the concrete bases using sizes and strides of the outputs that
aliased them. Since backends such as XLA doesn't support tensor aliasing, the sizes and
strides would be that of a contiguous tensor (not a view tensor). Because of that, calling
`as_strided` would error, since the output tensor would be bigger than its base. Instead,
this PR applies the sequence of `ViewMeta` gathered for each output during the
functionalization phase.
**Note:** we intentionally don't support base tensors that went through metadata mutation,
i.e. in-place view operations.
In summary, this PR:
- Introduces one `FunctionalTensorWrapper` member function alongside its Python APIs
- `apply_view_metas(base)`: applies the `ViewMeta` sequence of the given instance onto
another base
- Introduces a `OutputAliasInfo.functional_tensor` field
- Saves the `FunctionalTensorWrapper` instance collected by the functionalization phase
- Wraps it with a new `FunctionalTensorMetadataEq` class for comparing only the
metadata of the tensors
- Plumbs `OutputAliasInfo.functional_tensor` to `gen_alias_from_base` function
- Applies the `ViewMeta` sequence of the saved `FunctionalTensor` onto `aliased_base_tensor`
- Propagates `OutputAliasInfo.functional_tensor` when updating `fw_metadata`
(this PR description was updated in order to reflect the most recent changes)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121007
Approved by: https://github.com/bdhirsh