dd2c22f4bb
bsr_dense_bmm(): enable more precise float32 support with float64 accumulators (#100882)
...
Float64 is there in Triton! This PR increases precision for float32 inputs by using a float64 accumulation dtype.
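For intuition, a minimal sketch (plain PyTorch, not the Triton kernel itself) of why float64 accumulators help for float32 inputs:

```python
import torch

torch.manual_seed(0)
a = torch.randn(512, 512, dtype=torch.float32)
b = torch.randn(512, 512, dtype=torch.float32)

# Ground truth computed entirely in float64.
ref = a.double() @ b.double()

# Plain float32 matmul: inputs, accumulation, and output all float32.
f32 = a @ b

# Emulate float32 inputs with float64 accumulators: compute in float64,
# round the final result back to float32 once at the end.
f64_acc = (a.double() @ b.double()).float()

print("float32 accumulation error:", (f32.double() - ref).abs().max().item())
print("float64 accumulation error:", (f64_acc.double() - ref).abs().max().item())
```

The only error left in the second case is the final rounding to float32, rather than rounding accumulated across every partial sum.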
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100882
Approved by: https://github.com/cpuhrsch
2023-05-11 11:22:55 +00:00
0141a242fd
bsr_dense_bmm(): remove sparse_rowspace kernel and some dead code (#100876)
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100876
Approved by: https://github.com/cpuhrsch, https://github.com/Skylion007
2023-05-09 16:12:11 +00:00
c4bc259f00
bsr_dense_mm(): better test coverage (#100543)
...
This PR improves test coverage for `bsr_dense_mm` by:
- ~~enabling correctness tests for `float32`~~.
- extending and testing input correctness checks (see the sketch below).
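A hypothetical sketch of such an input check test; the `torch.sparse._triton_ops.bsr_dense_mm` import path and the exact exception types are assumptions, not taken from the PR:

```python
import torch

# Assumed import path: bsr_dense_mm is an internal helper living in
# torch/sparse/_triton_ops.py in this era of PyTorch.
from torch.sparse._triton_ops import bsr_dense_mm

bsr = torch.randn(64, 64, dtype=torch.half, device="cuda").to_sparse_bsr((16, 16))
dense = torch.randn(64, 32, dtype=torch.float32, device="cuda")  # dtype mismatch

# An input correctness check should reject the mismatched dtypes up
# front instead of launching the kernel on inconsistent inputs.
try:
    bsr_dense_mm(bsr, dense)
except (ValueError, RuntimeError) as exc:
    print("rejected:", exc)
```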
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100543
Approved by: https://github.com/cpuhrsch, https://github.com/malfet
2023-05-09 09:26:02 +00:00
cd8b82e5c6
bsr_dense_mm(): code refactoring (#100634)
...
Code unification/refactoring for better re-use. Intended to ease a future `sampled_addmm` implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100634
Approved by: https://github.com/cpuhrsch
2023-05-08 13:27:39 +00:00
05dda7ff65
bsr_dense_mm Triton kernel: fix out kwarg (#96648)
...
As per title. The kernel did not handle `out=` correctly: it returned a different tensor that merely shared storage with `out`.
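A minimal sketch of the contract this fixes, assuming the same internal `bsr_dense_mm` entry point as above:

```python
import torch
from torch.sparse._triton_ops import bsr_dense_mm  # assumed internal path

bsr = torch.randn(64, 64, dtype=torch.half, device="cuda").to_sparse_bsr((16, 16))
dense = torch.randn(64, 64, dtype=torch.half, device="cuda")
out = torch.empty(64, 64, dtype=torch.half, device="cuda")

result = bsr_dense_mm(bsr, dense, out=out)
# After the fix, the kernel writes into `out` and returns it, rather
# than returning a different tensor that merely aliases its storage.
assert result is out
```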
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96648
Approved by: https://github.com/cpuhrsch
2023-03-14 18:01:22 +00:00
76cac70939
new triton main pin (#95896)
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95896
Approved by: https://github.com/jansel, https://github.com/malfet
2023-03-10 06:30:41 +00:00
d0731271cd
Revert "new triton main pin ( #95896 )"
...
This reverts commit 6e0359dd4233b0cec51521bec8869f0a46ebd98b.
Reverted https://github.com/pytorch/pytorch/pull/95896 on behalf of https://github.com/huydhn due to: I am not quite sure what this is about yet, but testing the 3.8 wheel starts to fail (6e0359dd42)
2023-03-10 05:41:45 +00:00
6e0359dd42
new triton main pin (#95896)
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95896
Approved by: https://github.com/jansel
2023-03-10 03:40:37 +00:00
d809020fc8
Triton kernel for bsr @ dense (#94823)
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94823
Approved by: https://github.com/cpuhrsch, https://github.com/malfet
2023-03-03 15:11:28 +00:00
7012d985fa
Revert "Improve bsr @ strided
performance in baddmm
for bfloat16/half
with Triton kernels. ( #88078 )"
...
This reverts commit 46f16b93636615a81242b0d5cded84c5a57fd2e2.
Reverted https://github.com/pytorch/pytorch/pull/88078 on behalf of https://github.com/ZainRizvi due to: Causing a test to fail consistently: test_decomp.py::HasDecompTest::test_has_decomposition
2023-01-26 16:22:29 +00:00
46f16b9363
Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)
...
As per title.
Additionally, we introduce support for (see the usage sketch below):
- Rectangular block sizes that are powers of 2 and at least 16 (Triton's `dot` limitation).
- Batch support with broadcasting for either argument.
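A usage sketch of the two additions; the kernel-level batch broadcasting is illustrated here via ordinary dense `matmul` semantics, not the Triton kernel itself:

```python
import torch

# Rectangular 32x16 blocks: both dimensions are powers of 2 and >= 16,
# matching the `dot` constraint quoted above.
mat = torch.randn(128, 64, dtype=torch.bfloat16)
bsr = mat.to_sparse_bsr((32, 16))
print(bsr.values().shape)  # torch.Size([16, 32, 16]): 16 blocks of 32x16

# Batch broadcasting semantics: one (128, 64) sparse operand against a
# batched (10, 64, 32) dense operand yields a (10, 128, 32) result,
# shown here with strided matmul broadcasting.
dense = torch.randn(10, 64, 32, dtype=torch.bfloat16)
print(torch.matmul(mat.unsqueeze(0), dense).shape)  # torch.Size([10, 128, 32])
```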
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
2023-01-26 07:58:27 +00:00
60bf851931
Revert "Improve bsr @ strided
performance in baddmm
for bfloat16/half
with Triton kernels. ( #88078 )"
...
This reverts commit 8383b5c488399f2ae295c7c0f993bdd353dfd75c.
Reverted https://github.com/pytorch/pytorch/pull/88078 on behalf of https://github.com/malfet due to: This seems to have broken sm_86 testing, see https://hud.pytorch.org/hud/pytorch/pytorch/master/1?per_page=50&name_filter=sm86%20%2F%20test%20(default%2C%203
2023-01-19 23:37:59 +00:00
8383b5c488
Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)
...
As per title.
Additionally, we introduce support for:
- Rectangular block sizes that are powers of 2 and at least 16 (Triton's `dot` limitation).
- Batch support with broadcasting for either argument.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
2023-01-19 03:14:54 +00:00
89f1ad08b4
Revert "Improve bsr @ strided
performance in baddmm
for bfloat16/half
with Triton kernels. ( #88078 )"
...
This reverts commit 7f256fff77c49729131aa6d092e60e891d0c4948.
Reverted https://github.com/pytorch/pytorch/pull/88078 on behalf of https://github.com/huydhn due to: This breaks lint (7f256fff77)
2023-01-17 22:14:37 +00:00
7f256fff77
Improve bsr @ strided performance in baddmm for bfloat16/half with Triton kernels. (#88078)
...
As per title.
Additionally, we introduce support for:
- Rectangular block sizes that are powers of 2 and at least 16 (Triton's `dot` limitation).
- Batch support with broadcasting for either argument.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
2023-01-17 21:43:20 +00:00