Commit Graph

14275 Commits

Author SHA1 Message Date
a6a3f2e06b [MPS] Fixes GELU, LeakyRELU and MISH on non-contiguous tensors (#123049)
Fixes GELU, LeakyRELU and MISH activation functions on non-contiguous tensors (for instance, when a transpose operation was applied on the tensors prior to the MPS operator), forward and backward passes.

I also extended tests on the 3 activation functions to check: full-precision and half-precision, contiguous and non-contiguous, and several dims of tensors: scalars, 1D, empty, 2D, > 3D.

I had issues with Mish and GELU activations when asserting the gradients vs. CPU with sum() in some cases, so I reverted to the previous setup by passing a gradient argument to .backward().
This PR also fixes an issue with LeakyRELU on empty tensors.

Fixes #98212 huggingface/transformers#22468 huggingface/transformers#19353
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123049
Approved by: https://github.com/kulinseth
2024-04-21 00:12:32 +00:00
929242a15c Revert "torch.mtia module for MTIA device backend (#123612)"
This reverts commit d7e1bf9ff908d2a9c20d5354426d34c539fcb7a1.

Reverted https://github.com/pytorch/pytorch/pull/123612 on behalf of https://github.com/jeffdaily due to This broke ROCm. see test_overrides.py ([comment](https://github.com/pytorch/pytorch/pull/123611#issuecomment-2067363780))
2024-04-19 22:44:26 +00:00
d8a98ddd60 Prep PR for cutlass 3.5 update (#124412)
# Summary
These changes are needed for the upgrade to cutlass 3.5
#123458

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124412
Approved by: https://github.com/Skylion007, https://github.com/nWEIdia, https://github.com/malfet
2024-04-19 22:10:37 +00:00
c74dfca5e7 Int4MM: Unswizzle for different dtypes (#124448)
If dtype is not the one this platform is optimized for, it might need different unswizzling patterns. Implement ones for the non-vectorized flavor of the kernel, so that int4mm can be used with float32 and float16 dtypes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124448
Approved by: https://github.com/jgong5, https://github.com/mikekgfb
2024-04-19 21:17:15 +00:00
e6a788ac26 Fix compilation on aarch64 with gcc (#124511)
Which is more stringent than clang when equivalently sized NEON registers are cast to each other. In particular, at one point `uint16x4_t` were cast to `int16x4_t`, which gcc does not allow. Added `vreinterpret_s16_u16` (which is a no-op) to solve this and tested in https://godbolt.org/z/sYb4ThM6M

Test plan: Build aarch64 wheels
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124511
Approved by: https://github.com/mikekgfb
2024-04-19 19:53:19 +00:00
661fd23640 [AMD] TunableOp take priority over DISABLE_ADDMM_HIP_LT (#124161)
Summary: It seems super confusing that if we set DISABLE_ADDMM_HIP_LT + PYTORCH_TUNABLEOP_ENABLED, the former takes priority. This is because the former goes through the gemm_and_bias path, while tunable op is integrated with the gemm path. Until we can integrate tunable op with gemm_and_bias, we'll just let tunable op take priority.

Test Plan: Run a simple linear program and verified.

Differential Revision: D56183954

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124161
Approved by: https://github.com/jeffdaily, https://github.com/nmacchioni
2024-04-19 19:08:06 +00:00
8869b543e8 [AMD] Remove deprecated macro from ConvUtils (#124158)
Summary:
This is not great, but our ATen-cpu is not completely GPU agnostic. Previously we have worked on D54453492 (https://github.com/pytorch/pytorch/pull/121082) and D54528255, but there are a few things we haven't resolved, and it's exploding here. So we'll continue to fix them until all are gone.

This ROCm block is for 4.3 which is very old. I don't think it should be supported any more. So let's just kill this macro

Test Plan: CI

Differential Revision: D56172660

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124158
Approved by: https://github.com/jeffdaily, https://github.com/nmacchioni
2024-04-19 19:00:31 +00:00
e62169a8fa Support torchbind op dispatch in python (#123367)
We override the `__call__` method and register fake, functional, proxy default dispatch mode implementation in its python_key_mode_table.

The idea is:
1. when inputs contain a FakeScriptObject, we dispatch through the _get_dispatch mechanism. We implement dispatch mode keys automatically in the operator's constructor.
2. when inputs are not fakified, we dispatch through the original c++ dispatcher.
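
The two-way rule above can be sketched in plain Python. This is only an illustration under stated assumptions: the `FakeScriptObject` class below is a stand-in, and `call_op`, `python_dispatch` and `cpp_dispatch` are hypothetical names, not PyTorch APIs.

```python
class FakeScriptObject:
    """Stand-in for a fakified script object (hypothetical, for illustration)."""


def call_op(args, python_dispatch, cpp_dispatch):
    # Rule 1: any fakified input routes the call through the Python-side path.
    if any(isinstance(a, FakeScriptObject) for a in args):
        return python_dispatch(*args)
    # Rule 2: otherwise fall through to the original (C++) dispatcher.
    return cpp_dispatch(*args)


route_fake = call_op([FakeScriptObject(), 1], lambda *a: "python", lambda *a: "cpp")
route_real = call_op([object(), 1], lambda *a: "python", lambda *a: "cpp")
print(route_fake, route_real)
```

A single fakified argument is enough to select the Python path in this sketch; whether the real implementation checks inputs the same way is not stated in the commit message.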

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123367
Approved by: https://github.com/zou3519
2024-04-19 17:17:27 +00:00
c9db59e9e4 [sparse] Add fast semi-structured sparsification kernels (#122350)
This PR adds in fast semi-structured sparsification kernels to PyTorch.

These kernels accelerate semi-structured sparsification in PyTorch.

The kernels have been added as aten native functions

In particular, three new functions have been added:

* `torch._sparse_semi_structured_tile`

This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.

* `torch._sparse_semi_structured_apply`

This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.

* `torch._sparse_semi_structured_apply_dense`

This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation
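
For readers unfamiliar with the "1x4 strip" baseline mentioned above, here is a minimal pure-Python sketch of 2:4 pruning on a single row. It assumes magnitude-based selection (keep the two largest-magnitude values per group of four); the helper name `prune_2_4_strip` is hypothetical, and the 4x4 tile algorithm used by the new kernels is not reproduced here.

```python
def prune_2_4_strip(row):
    """Zero the two smallest-magnitude values in every group of four."""
    assert len(row) % 4 == 0, "row length must be a multiple of 4"
    out = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest-magnitude entries in this 1x4 strip.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        out.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return out


pruned = prune_2_4_strip([0.1, -2.0, 0.3, 4.0, 1.0, 1.5, -0.2, 0.0])
print(pruned)
```

Every group of four keeps exactly two entries, which is the constraint a packed 2:4 representation relies on.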

The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentation on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.

To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`

Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
2024-04-19 13:31:58 +00:00
88fa843e58 Add vectorized norm fill for ppc64le (#113351)
This patch adds the vectorized norm fill for ppc64le.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113351
Approved by: https://github.com/jgong5
2024-04-19 12:34:00 +00:00
b412b75b42 [optim] add fused_adam/adamw_kernel support for CPU device (#123074)
On par with `CUDA` implementation.

For `autocast` logic, same with `CUDA` + `Fused Adam`:
 - check inf in `gradscalar.step`
 - In fused kernel, if there is `inf`, do nothing. If not, unscale the grad ( also write back) and update the param.
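
The skip-or-unscale behaviour described above can be sketched as follows. This is a simplified illustration, not the actual kernel: a plain gradient-descent update stands in for the Adam math, and `fused_step` with its parameters is a hypothetical name.

```python
import math


def fused_step(params, grads, scale, lr=0.5):
    """Skip the update if any gradient is non-finite; otherwise unscale and step."""
    if any(not math.isfinite(g) for g in grads):
        return False  # found inf/nan: leave params and grads untouched
    for i, g in enumerate(grads):
        grads[i] = g / scale        # unscale the gradient and write it back
        params[i] -= lr * grads[i]  # plain SGD-style update stands in for Adam
    return True


p, g = [1.0, 2.0], [4.0, 8.0]
stepped = fused_step(p, g, scale=4.0)

p_inf, g_inf = [1.0], [float("inf")]
skipped = not fused_step(p_inf, g_inf, scale=4.0)
print(stepped, p, g, skipped, p_inf)
```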

**TestPlan**:
```
# extend CUDA only test for CPU fused adagrad
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused

# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict

# newly added test (follow 6b1f13ea2f/test/test_cuda.py (L1108))
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```

**Benchmark**:
**5.1x** on 56 core SPR
**Parameter-size=1M**
**Nparams=10**
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)

```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```

**Note: Fused kernel accuracy**
The accuracy failure in CI shows a difference slightly above the default tolerance
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debugged it step by step and, unfortunately, we may not be able to make the `fused kernel` exactly the same as the `non fused` one due to compiler optimizations.
For example, in non-fused impl
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
and in fused impl
```
  exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
  //  std::cout << "exp_avg_sq " <<   exp_avg_sq_ptr[d] << std::endl;
  exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
      scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep `std::cout`, I get exactly the same results in the UT
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment it out, there is a difference
```
===============param
0.6796758770942688
0.6796759366989136
```
So I will make the tolerance a little higher than the default one.
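
As a rough way to quantify the mismatch above, the sketch below compares the two printed parameter values at the float32 bit level and applies an allclose-style check with the tolerances shown in the CI log (`rtol=1.3e-06`, `atol=1e-05`). The helper names are hypothetical, and treating the printed values as float32 is an assumption.

```python
import struct


def float32_bits(x):
    """Return the IEEE-754 binary32 bit pattern of x as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def within_tolerance(actual, expected, rtol, atol):
    """allclose-style check: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


non_fused = 0.6796758770942688
fused = 0.6796759366989136

# Distance between the two values, counted in float32 representable steps.
ulp_gap = abs(float32_bits(fused) - float32_bits(non_fused))
print("float32 ULP gap:", ulp_gap)
print("within tolerance:", within_tolerance(fused, non_fused, rtol=1.3e-6, atol=1e-5))
```

A last-bit difference in a single update is far below the tolerance on its own; the larger differences reported in the CI log presumably come from such rounding differences compounding over several optimizer steps.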

Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
2024-04-19 11:14:04 +00:00
b2f6cfd9c0 Fix AVX2 int4pack_mm_kernel crash if weights are unaligned (#124433)
Followup after https://github.com/pytorch/pytorch/pull/124128
`s/_mm256_load_si128/_mm256_loadu_si128/`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124433
Approved by: https://github.com/desertfire
2024-04-19 05:17:38 +00:00
e0792cf3d6 Make copy_cast, softmax and cat_out unranked (#123191)
This helps with the performance as it removes multiple copies of the graphs saved due to their shapes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123191
Approved by: https://github.com/DenisVieriu97
2024-04-18 23:14:55 +00:00
d7e1bf9ff9 torch.mtia module for MTIA device backend (#123612)
MTIA device has its own Module in PyTorch now.
torch.mtia has the following APIs, similar to other backends. lazy_init is also supported.
```
__all__ = [
    "init",
    "is_available",
    "synchronize",
    "device_count",
    "current_device",
    "current_stream",
    "default_stream",
    "set_stream",
    "stream",
    "device",
]

```
------------
For device management, we expand AcceleratorHooksInterface to support generic device management, and it can be used in both C++ and Python.
```
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int : ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int : ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int : ...
```

---------
Adding get_device_module API to retrieve device modules for different device types.
```
def get_device_module(device: Optional[Union[torch.device, str]] = None)
```
---------
@exported-using-ghexport

Differential Revision: [D52923602](https://our.internmc.facebook.com/intern/diff/D52923602/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123612
Approved by: https://github.com/albanD
ghstack dependencies: #123611
2024-04-18 17:38:06 +00:00
a8cf91c395 Fix predispatch tracing for aten::lift_fresh_copy (#124198)
Differential Revision: D56200666

Previously, when we hit the Functionalize kernel for lift_fresh_copy, we directly dispatched self.clone() to proxy dispatch. As a result, we ended up receiving a functional tensor at proxy dispatch. As a workaround, I unwrap self manually. Not sure why it works OK in aot-dispatch, though.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124198
Approved by: https://github.com/bdhirsh
2024-04-18 17:02:38 +00:00
415a8f6398 Fixed issue in affine_grid_backward when grad_grid is non-contiguous (#124370)
Description:
- replaced .view with .reshape to fix the problem when grad_grid is channels last 2d/3d
- added a consistency test

Fixes #124154

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124370
Approved by: https://github.com/lezcano
2024-04-18 16:30:10 +00:00
5677128cb8 [MPS] Fix crash when binary_cross_entropy is invoked for half dtypes (#124258)
By creating constants using the input tensor's dtype

One line reproducer:
```
python -c "import torch; x=torch.arange(3, dtype=torch.float16,device='mps');print(torch.nn.functional.binary_cross_entropy(x, x))"
```

Before the change
```
loc("mps_subtract"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":233:0)): error: input types 'tensor<f32>' and 'tensor<3xf16>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).
```
After
```
tensor(-33.7812, device='mps:0', dtype=torch.float16)
```

Fixes https://github.com/pytorch/pytorch/issues/124252

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124258
Approved by: https://github.com/kulinseth
2024-04-18 15:21:01 +00:00
1325fd94a4 Support xpu autocast policy (#124052)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124052
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
2024-04-18 14:06:48 +00:00
a0466061e1 Support xpu host allocator (#123080)
# Motivation
This PR mainly covers caching host allocator supported on xpu backend.

# Solution
`XPUCachingHostAllocator` adopts the **same** caching mechanism as cuda via two abstract interfaces -`CachingHostAllocatorImpl` and `CachingHostAllocatorInterface`.

# Additional Context
Following CUDA, this PR adds a new API `getPinnedMemoryAllocator` to support pinning a tensor's memory.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123080
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
2024-04-18 12:29:21 +00:00
6fcbeb3489 [ATen] Add CPU fp16 support for nll_loss and cross_entropy_loss (#123256)
Add CPU FP16 support for nll_loss and cross_entropy_loss.
Resolve issue #123328.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123256
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/malfet
2024-04-18 11:44:38 +00:00
ec608a5d66 Refactor CUDA's amp cast policy to be generic (#124051)
# Motivation
This PR intends to create several op lists for different policies:
- `AT_FORALL_LOWER_PRECISION_FP` for policy `lower_precision_fp`
- `AT_FORALL_FP32` for policy `fp32`
- `AT_FORALL_FP32_SET_OPT_DTYPE` for policy `fp32_set_opt_dtype`
- `AT_FORALL_PROMOTE` for policy `promote`.

This makes sure that other backends can reuse the policy op lists.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124051
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
ghstack dependencies: #124050
2024-04-18 04:35:25 +00:00
de1c0d2497 [cublas] Keep explicit workspace creation to avoid OOM (#124250)
Summary:
We explicitly set the cublas workspace even though CUDA 12.2+ fixed the issue where memory usage increased during graph capture. Original issue: https://github.com/pytorch/pytorch/pull/83461

This is because in CUDA 12.2+, the use of cudaMallocAsync in cublas will allocate memory dynamically (even if they're cheap) outside PyTorch's CUDA caching allocator. It's possible that CCA used up all the memory and cublas's cudaMallocAsync will return OOM

Test Plan: CI

Differential Revision: D56226746

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124250
Approved by: https://github.com/houseroad, https://github.com/eqy
2024-04-18 04:17:38 +00:00
cc66c43d51 Make macro with AMP more generic (#124050)
# Motivation
According to [[RFC] Intel GPU Upstreaming](https://github.com/pytorch/pytorch/issues/114723), we would like to upstream amp autocast policy to facilitate the functionality and accuracy of `torch.compile` on e2e benchmarks.

# Solution
The first PR aims to make the macro `KERNEL` generic. It accepts two types of inputs, like `(DISPATCH, OP, POLICY)` and `(DISPATCH, OP, OVERLOAD, POLICY)`.
The second PR intends to refactor CUDA's autocast policy so that it can be shared with the `XPU` backend.
The final PR adds support for the XPU autocast policy, which shares the same recipe as the `CUDA` backend.

# Additional Context
Another motivation is we would like to unify autocast API and provide the generic APIs, like:
- `torch.get_autocast_dtype(device_type)`
- `torch.set_autocast_dtype(device_type)`
- `torch.is_autocast_enabled(device_type)`
- `torch.set_autocast_enabled(device_type)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124050
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/albanD
2024-04-18 01:15:03 +00:00
00372b1211 Extend int[48]mm ops to float32 input (#124287)
Just for completeness

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124287
Approved by: https://github.com/mikekgfb
2024-04-17 23:10:49 +00:00
9875a834e4 [Intel GPU] oneDNN GPU GEMM support (#117202)
# Motivation

This PR is a part of RFC #114848, and it is a successor PR of #116249 and #116019. This PR depends on the oneDNN compilation in #116249. Some runtime support is needed from #116019.

Aten operators like `addmm`, `baddmm` are defined in `Blas.cpp` in `aten/src/ATen/native/mkldnn/xpu/`.

Alongside these files, which provide the core functionality, `BlasImpl.h`, `Utils.h` and other files provide basic utilities. For instance, `Utils.h` provides common memory descriptor query utilities for `Matmul.h`, and these utility functions will also be used in other primitives, like `convolution`. `BlasImpl.h` is a header file that provides helpers for shape info processing in matmul-related operators. It helps not only basic GEMM operators like `addmm, baddmm` but also fusion operators used in `torch.compile`, like `linear_pointwise` in #117824.

In the next stage, we will continue completing the oneDNN support by enabling `matmul fusion` and `convolution` related code.

Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
Co-authored-by: lei,zhenyuan <zhenyuan.lei@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117202
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/malfet
ghstack dependencies: #117098, #117112
2024-04-17 23:06:38 +00:00
cc18afa25f Intel GPU oneDNN upstreaming for primitive integration (#117112)
# Motivation

As proposed in https://github.com/pytorch/pytorch/issues/114848 and https://github.com/pytorch/pytorch/issues/114723, the oneDNN library is an important component of the Intel GPU software ecosystem.

The current PR is based on #117098, where the oneDNN library for Intel GPU should be ready. This PR is the integration code from aten to oneDNN. GEMM integration code is the core part of this PR. Along with GEMM, more basic support such as runtime (device, stream) and primitive attr is also included.

We put the oneDNN integration code in directory `aten/src/ATen/native/mkldnn/xpu/detail`. We add a namespace `at::native::xpu::onednn` for oneDNN integration.

The code in this PR will be used in following PRs, where aten operators call the functions in this integration code. We separate the PRs because the oneDNN integration is logically separable from the aten operator implementation. This also eases the review burden by avoiding too much code in a single PR.

Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
Co-authored-by: lei,zhenyuan <zhenyuan.lei@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117112
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/albanD
2024-04-17 22:49:56 +00:00
46324fe073 Speedup int4mm_kernel with NEON (#124257)
By unrolling the middle loop by 16 elements and using NEON to decode packed int4 to float32.
  Unrolling the entire `n` loop actually makes it a tad slower, probably because ARM has a smaller register file than x86
  Before/after performance running stories110M on M2Pro

 | eager (before) | eager (after) | compile(before) | compile (after) |
 | ---- | --- | -- | -- |
 | 28 | 57  | 31 | 104 |
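
To illustrate what "decode packed int4 to float32" means, here is a scalar pure-Python sketch. The packing layout is an assumption made for the example (two values per byte, low nibble first, zero point of 8); the actual kernel's layout, scales and zero points are not described in this message, and `unpack_int4` is a hypothetical helper.

```python
def unpack_int4(packed, zero_point=8):
    """Decode bytes holding two 4-bit values each into a list of floats."""
    out = []
    for byte in packed:
        lo = byte & 0x0F         # low nibble comes first (assumed order)
        hi = (byte >> 4) & 0x0F  # high nibble second
        out.append(float(lo - zero_point))
        out.append(float(hi - zero_point))
    return out


decoded = unpack_int4(bytes([0x80, 0xF7]))
print(decoded)
```

A vectorized kernel performs the same mask, shift and convert steps on many bytes at once, which is where the unrolling by 16 elements comes in.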

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124257
Approved by: https://github.com/mikekgfb
2024-04-17 16:04:25 +00:00
9b1d6c8d98 improve F.adaptive_avg_pool2d error messages on mps (#124143)
Gives better error messages on mps. Partially fixes #123725 in the case of `F.adaptive_avg_pool2d`. This also relates to #96056.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124143
Approved by: https://github.com/albanD, https://github.com/malfet
2024-04-17 16:04:09 +00:00
b880a71010 [BE] Add missing std:: prefix to Unique.mm (#124232)
Follow up after https://github.com/pytorch/pytorch/pull/124117 fixes following warning
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/operations/Unique.mm:282:26: warning: use of function template name with no prior declaration in function call with explicit template arguments is a C++20 extension [-Wc++20-extensions]
  return std::make_tuple(get<0>(out).to("mps"), get<1>(out).to("mps"), get<2>(out).to("mps"));
                         ^
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124232
Approved by: https://github.com/kit1980, https://github.com/Skylion007
2024-04-17 14:12:29 +00:00
47dbfecd37 Rename impl_abstract to register_fake, part 1/2 (#123937)
This PR:
- adds a new torch.library.register_fake and deprecates
  torch.library.impl_abstract. The motivation is that we have a lot of
  confusion around the naming so we are going to align the naming with
  the actual subsystem (FakeTensor).
- renames `m.impl_abstract_pystub("fbgemm_gpu.sparse_ops")` to
  `m.has_python_registration("fbgemm_gpu.sparse_ops")`. No deprecation
  here yet; I need to test how this works with static initialization.
- Renames a bunch of internals to match (e.g. abstractimplpystub ->
  pystub)

I'm scared to rename the Python-side internal APIs (e.g.
torch._library.abstract_impl) because of torch.package concerns. I'll do
that in its own isolated PR next just in case it causes problems.

DEPRECATION NOTE: torch.library.impl_abstract was renamed to
torch.library.register_fake. Please use register_fake. We'll delete
impl_abstract in a future version of PyTorch.
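
The rename-with-deprecation pattern described here can be sketched in plain Python as below. The signatures are deliberately simplified and hypothetical; the real `torch.library.register_fake` and `torch.library.impl_abstract` take different arguments, and this is not their implementation.

```python
import warnings


def register_fake(name):
    """Hypothetical new-name entry point; returns the name it was given."""
    return name


def impl_abstract(name):
    """Old name kept as a thin wrapper that warns and forwards."""
    warnings.warn(
        "impl_abstract is deprecated; use register_fake instead",
        DeprecationWarning,
        stacklevel=2,
    )
    return register_fake(name)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = impl_abstract("mylib::op")

print(result, [str(w.message) for w in caught])
```

Keeping the old name as a forwarding wrapper lets existing callers keep working while the warning points them at the new name.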

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123937
Approved by: https://github.com/albanD
2024-04-17 12:46:01 +00:00
2dc15b6849 Revert "[sparse] Add fast semi-structured sparsification kernels (#122350)"
This reverts commit 14b2273b0c58b4000e10b2e441341eeafb7dd2f6.

Reverted https://github.com/pytorch/pytorch/pull/122350 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/122350#issuecomment-2061070350))
2024-04-17 11:47:02 +00:00
acc466751b Add bfloat16 support to binary_cross_entropy for CPU (#123823)
Fixes #123715

As the title stated.

But, maybe we should pay attention to this https://github.com/pytorch/pytorch/pull/33206, which removed the half support for cpu about 4 years ago.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123823
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-04-17 09:44:07 +00:00
ed22dde877 Pointer to the nonzero limit ticket (#124244)
For the nonzero impl limits, we are still asking at runtime to file a new ticket, but we already had more than one.
So I am pointing to the current open ticket.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124244
Approved by: https://github.com/ezyang
2024-04-17 06:15:36 +00:00
14b2273b0c [sparse] Add fast semi-structured sparsification kernels (#122350)
This PR adds in fast semi-structured sparsification kernels to PyTorch.

These kernels accelerate semi-structured sparsification in PyTorch.

The kernels have been added as aten native functions

In particular, three new functions have been added:

* `torch._sparse_semi_structured_tile`

This function will return the packed representation and metadata for
both X and X', as well as the thread masks. Note that this applies 2:4
sparsity in a 4x4 tile instead of a 1x4 strip as usual.

* `torch._sparse_semi_structured_apply`

This function takes in an input tensor and thread masks from the above
function and returns a packed representation and metadata from applying
thread masks to the input tensor.

* `torch._sparse_semi_structured_apply_dense`

This function does the same thing as above but instead of returning the
tensor in the sparse representation it returns it in the dense
representation

The subclasses have also been updated to add a new
`prune_dense_static_sort`
classmethod to create sparse tensors with this format. I've added some
additional documentation on how to calculate the compressed tensors
needed to create a SparseSemiStructuredTensor oneself.

To this end, there are two new helper functions added:
`sparse_semi_structured_tile`
`compute_compressed_swizzled_bitmask`

Differential Revision: [D56190801](https://our.internmc.facebook.com/intern/diff/D56190801)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122350
Approved by: https://github.com/cpuhrsch
2024-04-16 20:31:52 +00:00
1f89bf4188 Revert "[reland] _foreach_copy with different src/dst dtypes (#123844)"
This reverts commit ff1e3ff5a503a520c1a310c8e72a383657f9a4bc.

Reverted https://github.com/pytorch/pytorch/pull/123844 on behalf of https://github.com/malfet due to Perhaps it enabled it for different dtype, but broke for the same ([comment](https://github.com/pytorch/pytorch/pull/123844#issuecomment-2059861767))
2024-04-16 20:23:14 +00:00
72271fb07e Add NEON ISA support on aarch64 (#123584)
Fixes #104729

This improves the compiled mode performance of Softmax (by 20%) and other operations (like batchnorm) that invoke the reduce_all function. It also improves BERT inference by around 8%.

Tested on a graviton 3 instance (c7g.4xl). Tests were run in a single-threaded manner.

Script attached below.
Command: `OMP_NUM_THREADS=1 LRU_CACHE_CAPACITY=1024 DNNL_DEFAULT_FPMATH_MODE=BF16 python TestSoftmax.py`
[TestSoftmax.txt](https://github.com/pytorch/pytorch/files/14910754/TestSoftmax.txt)
```python
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity

model = nn.Softmax().eval()
compiled_model = torch.compile(model)
inputs = torch.randn(1024, 1024)

with torch.set_grad_enabled(False):
    for _ in range(50):
        compiled_model(inputs) #Warmup
    print("Warmup over")
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        with record_function("model_inference"):
            for _ in range(100):
                compiled_model(inputs)

print(prof.key_averages().table(sort_by="self_cpu_time_total"))
# Check if the compiled model inference and the eager model inference are similar using torch.allclose
print(torch.allclose(compiled_model(inputs), model(inputs)))
```

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123584
Approved by: https://github.com/jgong5, https://github.com/malfet
2024-04-16 18:49:52 +00:00
83ef3bb128 Fix AVX512 int4pack_mm_kernel crash if weights are unaligned (#124128)
By replacing `_mm256_load_si256` with `_mm256_loadu_si256`, as there are no guarantees that the tensor is aligned

Fixes crash reported in https://github.com/pytorch/pytorch/issues/124034 though I'm unsure about perf implications if tensors are properly aligned

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124128
Approved by: https://github.com/mikekgfb
2024-04-16 04:35:25 +00:00
ff1e3ff5a5 [reland] _foreach_copy with different src/dst dtypes (#123844)
Attempt to reland https://github.com/pytorch/pytorch/pull/121717.
The change is the array bounds check.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123844
Approved by: https://github.com/janeyx99
2024-04-16 02:20:58 +00:00
a4c8002ee0 MPS FFT implementation bug (#123274)
The current implementation drops the negative frequency components even when the user doesn't ask for the one-sided transform. The tests for the negative frequency components seem to have worked by accident due to internal implementation details, but the issue becomes evident in macOS 14.4.
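
To make the one-sided versus two-sided distinction concrete, here is a small pure-Python sketch using a naive DFT. It is not the MPS implementation; `dft` is a hypothetical helper written for this example.

```python
import cmath


def dft(x):
    """Naive O(n^2) discrete Fourier transform."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


signal = [1.0, 2.0, 0.0, -1.0, 3.0, 0.5]
full = dft(signal)                        # two-sided: n bins
one_sided = full[: len(signal) // 2 + 1]  # one-sided: n // 2 + 1 bins

# For real input, bin n-k is the conjugate of bin k, so the one-sided
# form loses nothing; a two-sided request must still return all n bins.
mirror_ok = all(
    abs(full[len(signal) - k] - full[k].conjugate()) < 1e-9
    for k in range(1, len(signal))
)
print(len(full), len(one_sided), mirror_ok)
```
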
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123274
Approved by: https://github.com/malfet
2024-04-16 02:02:37 +00:00
eeb626b46a [BE] Do not use using namespace in mps headers (#124117)
- Remove `using namespace std` from `MPSDevice.h`
- Add `std::` prefix to 1st argument of `MPSProfiler::StartTrace`
- Do the same in front of `numeric_limits` template instantiation in `ReduceOps.mm`
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124117
Approved by: https://github.com/malfet
2024-04-16 01:39:42 +00:00
bd222473fc [EZ][BE] Fix unknown pragma warning (#124086)
By using `C10_DIAGNOSTIC_` macros instead of `#pragma clang diagnostic` that puts appropriate compiler supported pragmas. Fixes following warning during the bazel build
```
INFO: From Compiling aten/src/ATen/native/TensorFactories.cpp:
aten/src/ATen/native/TensorFactories.cpp:372: warning: ignoring #pragma clang diagnostic [-Wunknown-pragmas]
  372 | #pragma clang diagnostic push
      |
aten/src/ATen/native/TensorFactories.cpp:373: warning: ignoring #pragma clang diagnostic [-Wunknown-pragmas]
  373 | #pragma clang diagnostic ignored "-Wmissing-prototypes"
      |
aten/src/ATen/native/TensorFactories.cpp:375: warning: ignoring #pragma clang diagnostic [-Wunknown-pragmas]
  375 | #pragma clang diagnostic pop
      |
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124086
Approved by: https://github.com/kit1980, https://github.com/seemethere, https://github.com/Skylion007
2024-04-15 21:44:31 +00:00
221c397e2e Use NEON to speedup int8pack_mm on aarch64 (#124023)
Just vectorizing the inner loop as follows:
```cpp
float32x4_t c_val = vdupq_n_f32(0.0);
for (int k = 0; k < K; k += 8) {
  float16x8_t a_val = vld1q_f16(reinterpret_cast<const float16_t *>(A) + m * lda + k);
  int16x8_t b_val = vmovl_s8(vld1_s8(B + n * ldb + k));
  auto a_val_low = vcvt_f32_f16(vget_low_f16(a_val));
  auto a_val_high = vcvt_f32_f16(vget_high_f16(a_val));
  auto b_val_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_val)));
  auto b_val_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_val)));
  c_val = vaddq_f32(c_val, vmulq_f32(a_val_low, b_val_low));
  c_val = vaddq_f32(c_val, vmulq_f32(a_val_high, b_val_high));
}
float scale_val = static_cast<float>(scales[n]);
C[m * ldc + n] = reduce(c_val) * scale_val;
```

Which bumps perf from 35 to 58 tokens per second (65% perf gain).
Unrolling both inner and outer loops bumps perf to 64 tokens per sec
(i.e. another 10% gain)

Before/after performance running stories110M on M2Pro
| eager (before) | eager (after) | compile(before) | compile (after) |
| ---- | --- | -- | -- |
| 35 | 64  | 56 | 132 |
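
For reference, the computation that the NEON loop above vectorizes can be written as a scalar pure-Python sketch. The indexing mirrors the C++ snippet (`A[m][k]`, `B[n][k]`, `scales[n]`); the function name `int8pack_mm_reference` and the sample values are made up for illustration.

```python
def int8pack_mm_reference(A, B, scales):
    """C[m][n] = scales[n] * sum_k A[m][k] * B[n][k], with one B row per output column."""
    M, N, K = len(A), len(B), len(A[0])
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):  # the NEON kernel handles 8 of these per iteration
                acc += A[m][k] * float(B[n][k])
            C[m][n] = acc * scales[n]
    return C


A = [[1.0, 2.0], [0.5, -1.0]]  # activations (float)
B = [[3, -4], [2, 1]]          # int8 weights, one row per output column
C = int8pack_mm_reference(A, B, scales=[0.5, 2.0])
print(C)
```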

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124023
Approved by: https://github.com/mikekgfb
ghstack dependencies: #124022
2024-04-15 18:57:59 +00:00
a096e99a5d Enable int8mm kernel for float16 (#124022)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124022
Approved by: https://github.com/mikekgfb
2024-04-14 19:48:43 +00:00
f5331aade5 Simplify ATen sparse semi-structured operators based on CUTLASS (#123473)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123473
Approved by: https://github.com/cpuhrsch
2024-04-14 06:57:41 +00:00
97261be0a8 Revert "Simplify ATen sparse semi-structured operators based on CUTLASS (#123473)"
This reverts commit b2a0b8c446234f0b35a66aff87501c4596ea5d51.

Reverted https://github.com/pytorch/pytorch/pull/123473 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/123473#issuecomment-2053561077))
2024-04-13 07:47:32 +00:00
7e3f80f00f accelerate binary_cross_entropy_with_logits (#122789)
Following https://github.com/pytorch/pytorch/pull/115539

Same benchmark in #115539:
|avg time (ms)|with `pos_weight`|no `pos_weight`|
|-|-|-|
|before #115539 |2049|1736|
|after #115539    |1320|1049|
|this PR               |907  |801|

This PR is 24-31% faster than the version after #115539.
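
The quoted range can be checked from the table, assuming "faster" means the relative reduction in average time compared with the row after #115539. `time_reduction_pct` is a helper written for this check.

```python
def time_reduction_pct(before_ms, after_ms):
    """Percentage by which the average time dropped."""
    return (before_ms - after_ms) / before_ms * 100.0


with_pos_weight = time_reduction_pct(1320, 907)
no_pos_weight = time_reduction_pct(1049, 801)
print(f"{with_pos_weight:.1f}% {no_pos_weight:.1f}%")
```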

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122789
Approved by: https://github.com/peterbell10
2024-04-13 04:18:47 +00:00
e4c887fbf6 [AOTAutograd] Replay views on output using FunctionalTensor metas. (#121007)
Fix: #120336

This PR fixes an issue on AOTAutograd, specifically on backends that don't support views
by themselves (e.g. XLA). Previously, AOTAutograd tried to reconstruct output views by
calling `as_strided` on the concrete bases using sizes and strides of the outputs that
aliased them. Since backends such as XLA don't support tensor aliasing, the sizes and
strides would be that of a contiguous tensor (not a view tensor). Because of that, calling
`as_strided` would error, since the output tensor would be bigger than its base. Instead,
this PR applies the sequence of `ViewMeta` gathered for each output during the
functionalization phase.

**Note:** we intentionally don't support base tensors that went through metadata mutation,
i.e. in-place view operations.

In summary, this PR:

- Introduces one `FunctionalTensorWrapper` member function alongside its Python APIs
    - `apply_view_metas(base)`: applies the `ViewMeta` sequence of the given instance onto
      another base
- Introduces a `OutputAliasInfo.functional_tensor` field
    - Saves the `FunctionalTensorWrapper` instance collected by the functionalization phase
    - Wraps it with a new `FunctionalTensorMetadataEq` class for comparing only the
      metadata of the tensors
- Plumbs `OutputAliasInfo.functional_tensor` to `gen_alias_from_base` function
    - Applies the `ViewMeta` sequence of the saved `FunctionalTensor` onto `aliased_base_tensor`
- Propagates `OutputAliasInfo.functional_tensor` when updating `fw_metadata`

(this PR description was updated in order to reflect the most recent changes)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121007
Approved by: https://github.com/bdhirsh
2024-04-12 16:54:13 +00:00
757daece95 [sparse] add meta support for add operation (and copy) (#123594)
This is a small step towards #117188
@pearu to review (this was split off from #117907)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123594
Approved by: https://github.com/pearu, https://github.com/peterbell10
2024-04-12 15:50:30 +00:00
2cb3301f80 [ROCm] Add cast to kFloat in amax calculation (#123872)
necessary cast to kFloat missed in previous amax PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123872
Approved by: https://github.com/drisspg
2024-04-12 15:38:41 +00:00
3120dbbf81 Revert "[sparse] Add fast semi-structured sparsification kernels (#122350)"
This reverts commit aaec97a40364bb6ccfd968f28d309cfff8748d20.

Reverted https://github.com/pytorch/pytorch/pull/122350 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/122350#issuecomment-2051757450))
2024-04-12 13:26:10 +00:00