Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Branch: main, 17271 commits

Author | SHA1 | Message | Date
---|---|---|---
602ace5eb4 |
Revert "[ATen] Fix CUDA reduction warp shuffle order (#164790)"
This reverts commit 36371b8ec7a1baed255c18451b2c716386a54c95. Reverted https://github.com/pytorch/pytorch/pull/164790 on behalf of https://github.com/clee2000 due to was reverted due to failing internal tests after merge D84992607 ([comment](https://github.com/pytorch/pytorch/pull/164790#issuecomment-3420373755)) |
|||
8a8329b51f |
[ATen] Switch order of blocked reduce when vectorize loads (#165178)
Performance benchmarking, perf neutral: ``` ================================================================================================================================================================================================================================================ Tensor Shape Operation Full reduce (ms) Non-Contig dim (ms) Contig dim (ms) Full reduce (ms) Non-Contig dim (ms) Contig dim (ms) Full diff % Non-Contig diff % Contig diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.015684 0.017056 0.008287 0.016015 0.016929 0.008170 -2.07% +0.75% +1.43% (256, 256) sum 0.015774 0.016638 0.007926 0.015811 0.016935 0.008330 -0.23% -1.75% -4.85% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.013385 0.025742 0.008629 0.013046 0.026005 0.008924 +2.60% -1.01% -3.31% (512, 512) sum 0.013390 0.026059 0.009116 0.013054 0.025696 0.008952 +2.57% +1.41% +1.83% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014213 0.015467 0.010334 0.013862 0.015082 0.010318 +2.53% +2.55% +0.16% (1024, 1024) sum 0.014179 0.015446 0.010774 0.014132 0.015073 0.010350 +0.33% +2.47% +4.10% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018234 0.019487 0.014812 0.018482 0.019397 0.014802 -1.34% +0.46% +0.07% (2048, 2048) sum 0.018202 0.019529 0.015195 0.018122 0.019485 0.015129 +0.44% +0.23% +0.44% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.033582 0.039378 0.030751 0.033810 0.039673 0.031019 -0.67% -0.74% -0.86% (4096, 4096) sum 0.033604 0.039777 0.030809 0.033530 0.039386 0.031113 +0.22% +0.99% -0.98% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.085824 0.091133 0.084200 0.085431 0.091364 0.084303 +0.46% -0.25% -0.12% (8192, 8192) sum 0.085763 0.091442 0.084180 0.085508 0.091419 0.084595 +0.30% +0.03% -0.49% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.146480 0.147666 0.138807 0.146515 0.147987 0.138930 -0.02% -0.22% -0.09% (8192, 16384) sum 0.146446 0.147593 0.138559 0.146151 0.147982 0.139120 +0.20% -0.26% -0.40% 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266047 0.265386 0.253837 0.265648 0.265885 0.253652 +0.15% -0.19% +0.07% (8192, 32768) sum 0.266093 0.265421 0.253890 0.265458 0.265591 0.253567 +0.24% -0.06% +0.13% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.498632 0.508976 0.481865 0.498237 0.508777 0.481476 +0.08% +0.04% +0.08% (8192, 65536) sum 0.498917 0.508202 0.481883 0.498104 0.508016 0.481972 +0.16% +0.04% -0.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.957633 0.968519 0.938172 0.956766 0.968267 0.938196 +0.09% +0.03% -0.00% (8192, 131072) sum 0.956972 0.968140 0.937741 0.957365 0.968404 0.938056 -0.04% -0.03% -0.03% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.906661 1.928377 1.861846 1.907327 1.928811 1.862083 -0.03% -0.02% -0.01% (8192, 262144) sum 1.905976 1.928362 1.862399 1.907098 1.928844 1.861782 -0.06% -0.02% +0.03% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.956852 0.970101 0.936524 0.957263 0.969809 0.936965 -0.04% +0.03% -0.05% (4096, 262144) sum 0.957117 0.969933 0.936247 0.956675 0.969451 0.936395 +0.05% +0.05% -0.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.498813 0.511299 0.483415 0.498567 0.511482 0.483376 +0.05% -0.04% +0.01% (2048, 262144) sum 0.498813 0.510834 0.483641 0.498875 0.511036 0.483338 -0.01% -0.04% +0.06% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266157 0.276751 0.255192 0.265966 0.276808 0.255544 +0.07% -0.02% -0.14% (1024, 262144) sum 0.266133 0.276709 0.255528 0.265658 0.276685 0.255287 +0.18% +0.01% +0.09% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.085941 0.081184 0.087931 0.085591 0.080832 0.088008 +0.41% +0.44% -0.09% (512, 131072) sum 0.085962 0.081107 0.088045 0.085882 0.081160 0.088024 +0.09% -0.07% +0.02% 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014203 0.045859 0.010310 0.013885 0.046132 0.010621 +2.29% -0.59% -2.93% (1000, 1000) sum 0.014180 0.046165 0.010756 0.013893 0.046109 0.010338 +2.07% +0.12% +4.04% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.012953 0.016751 0.008536 0.012977 0.016714 0.008916 -0.18% +0.22% -4.26% (1024, 129) sum 0.013356 0.016806 0.008722 0.013003 0.017071 0.008611 +2.71% -1.55% +1.29% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.013075 0.016787 0.009102 0.013116 0.016769 0.008679 -0.31% +0.11% +4.87% (1024, 257) sum 0.013092 0.016842 0.008786 0.013126 0.017128 0.008771 -0.26% -1.67% +0.17% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.013662 0.017412 0.010055 0.013659 0.017019 0.010033 +0.02% +2.31% +0.22% (1024, 587) sum 0.013636 0.017473 0.010163 0.013642 0.017363 0.010101 -0.04% +0.63% +0.61% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015276 0.027873 0.012531 0.015241 0.027783 0.012467 +0.23% +0.32% +0.51% (2048, 977) sum 0.015345 0.027949 0.012192 0.015255 0.027839 0.012485 +0.59% +0.40% -2.35% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.012806 0.014020 0.008291 0.013137 0.014309 0.007908 -2.52% -2.02% +4.84% (1024, 128) sum 0.012769 0.014308 0.007924 0.012788 0.014236 0.008038 -0.15% +0.51% -1.42% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014145 0.023049 0.009143 0.014104 0.023298 0.009501 +0.29% -1.07% -3.77% (8192, 128) sum 0.014132 0.023082 0.009638 0.014107 0.023331 0.009244 +0.18% -1.07% +4.26% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.013420 0.025834 0.008949 0.013368 0.025724 0.008918 +0.39% +0.43% +0.35% (1024, 130) sum 0.013300 0.025940 0.009113 0.013266 0.025419 0.008922 +0.26% +2.05% +2.14% 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.013993 0.017883 0.009661 0.014275 0.018220 0.009596 -1.98% -1.85% +0.68% (8192, 130) sum 0.014026 0.018297 0.010066 0.014326 0.018257 0.009659 -2.09% +0.22% +4.21% ================================================================================================================================================================================================================================================ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165178 Approved by: https://github.com/ngimel ghstack dependencies: #165494, #164790, #165055 |
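For reference, a minimal sketch (not the harness used for the table above) of how such reduction timings can be collected with CUDA events; shapes, warmup, and iteration counts are arbitrary:
```python
import torch

def bench_ms(fn, iters=200, warmup=20):
    # time a CUDA op with events; returns average milliseconds per call
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(4096, 4096, device="cuda")
print("full reduce       :", bench_ms(lambda: x.sum()), "ms")
print("non-contiguous dim:", bench_ms(lambda: x.sum(dim=0)), "ms")
print("contiguous dim    :", bench_ms(lambda: x.sum(dim=1)), "ms")
```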
|||
a1114beed2 |
Deprecate overlapped functions in CUDAAllocatorConfig (#165289)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165289 Approved by: https://github.com/albanD ghstack dependencies: #165288 |
|||
35e51893bd |
Remove CUDA 11 workarounds for CUB_SUPPORTS_SCAN_BY_KEY and CUB_SUPPORTS_UNIQUE_BY_KEY (#164637)
`CUB_SUPPORTS_SCAN_BY_KEY` and `CUB_SUPPORTS_UNIQUE_BY_KEY` are true since CUDA 12. This PR removes the old branches and source files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164637 Approved by: https://github.com/ezyang |
|||
d14cbb4476 |
Add NVFP4 two-level scaling to scaled_mm (#165774)
Summary: * Add second-level scaling dispatch to scaled_mm, tying into optional `alpha` passing * Add two-level tests Test Plan: ``` pytest -svv -k "nvfp4_global_scale" test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/165774 Approved by: https://github.com/drisspg |
|||
ad67170c8b |
[MPS] sparse matmuls (#165232)
Implements matmuls for sparse tensors. With this commit most of the core sparse operations should be implemented. Fixes: https://github.com/pytorch/pytorch/issues/156540 https://github.com/pytorch/pytorch/issues/129842 Should be merged after: https://github.com/pytorch/pytorch/pull/165102 To compare MPS and CPU, you can use this script:
```python
import torch
import time
import matplotlib.pyplot as plt

B, I, J, K = 8, 20000, 20000, 20000
num_iterations = 500

nnz_values = [10, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 100000]
speedups = []

for nnz in nnz_values:
    indices = torch.stack([
        torch.randint(0, B, (nnz,)),
        torch.randint(0, I, (nnz,)),
        torch.randint(0, J, (nnz,)),
    ])
    values = torch.rand(nnz)
    sparse = torch.sparse_coo_tensor(indices, values, size=(B, I, J), device="mps").coalesce()
    dense = torch.randn(B, J, 200, device="mps")

    t1 = time.time()
    for _ in range(num_iterations):
        result = torch.bmm(sparse, dense)
    torch.mps.synchronize()
    t2 = time.time()
    mps_time = (t2 - t1) / num_iterations

    sparse_cpu = sparse.cpu()
    dense_cpu = dense.cpu()
    t1 = time.time()
    for _ in range(num_iterations):
        result_cpu = torch.bmm(sparse_cpu, dense_cpu)
    t2 = time.time()
    cpu_time = (t2 - t1) / num_iterations

    speedup = cpu_time / mps_time
    speedups.append(speedup)
    print(f"nnz={nnz}: MPS={mps_time:.6f}s, CPU={cpu_time:.6f}s, Speedup={speedup:.2f}x")

plt.figure(figsize=(10, 6))
plt.plot(nnz_values, speedups, marker='o', linewidth=2, markersize=8)
plt.xlabel('Number of Non-Zero Elements (nnz)', fontsize=12)
plt.ylabel('Speedup (CPU time / MPS time)', fontsize=12)
plt.title('MPS vs CPU Speedup for Sparse-Dense BMM', fontsize=14)
plt.grid(True, alpha=0.3)
plt.axhline(y=1, color='r', linestyle='--', alpha=0.5)
plt.xscale('log')
plt.tight_layout()
plt.show()
```
## Tested on M1 Pro
<img width="1000" height="600" alt="Figure_1" src="https://github.com/user-attachments/assets/4a2402ec-3dc4-402d-8196-a0426906ca3d" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165232 Approved by: https://github.com/malfet |
|||
23417ae50f |
[Submodule] Bump FBGEMM to latest (#165544)
Summary: * FBGEMM submodule updated to main * CMake updated to reflect necessary changes * Notably pulls in NVFP4 grouped gemm kernels Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/165544 Approved by: https://github.com/cyyever, https://github.com/jeffdaily |
|||
de09bab4b6 |
[BE]: Update cudnn frontend submodule to 1.15.0 (#165776)
Update cudnn frontend submodule to 1.15.0 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165776 Approved by: https://github.com/eqy |
|||
29b029648e |
Fixed issue with GradTrackingTensor not properly propagating sparse layout (#165765)
Fixes #164286 Fixed issue with GradTrackingTensor not properly propagating sparse layout. @ezyang @jcaip Pull Request resolved: https://github.com/pytorch/pytorch/pull/165765 Approved by: https://github.com/ezyang |
|||
9c12651417 |
Improve error message for non-positive groups in convolution (#165669)
Prevents a segmentation fault for invalid (non-positive) groups values in convolution. Fixes #142835 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165669 Approved by: https://github.com/mikaylagawarecki |
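Illustrative repro sketch (not from the PR); the exact error text is whatever the check added, and shapes here are arbitrary:
```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)
w = torch.randn(4, 4, 3, 3)

try:
    # non-positive groups: with the check in place this raises instead of segfaulting
    F.conv2d(x, w, groups=0)
except Exception as e:
    print("caught:", e)
```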
|||
faff826a46 |
Revert "[ROCm] new implementation of upsample_bilinear2d_backward (#164572)"
This reverts commit 53f9ae0e50d4dcc47f2ca4bf854803f9d4f875ae. Reverted https://github.com/pytorch/pytorch/pull/164572 on behalf of https://github.com/seemethere due to Looks like this is failing in our internal builds, will post a suggestion for a fix but want you to double verify that this behavior is correct ([comment](https://github.com/pytorch/pytorch/pull/164572#issuecomment-3416262676)) |
|||
935ccdbe75 |
[MPS] Fix internal assertion in torch.linalg.solve for singular matrices (#165254)
Fixes #163962 by special casing MPS in the negative status code branch in `_linalg_check_errors`. Checks if info is [`MPSMatrixDecompositionStatus.singular`](https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus/singular) (which has a raw value of -2). I didn't find an official Apple source with this raw value (besides printing the enum value), so I'm not sure if we can (or should) depend on it? Is there a way to directly get the Objective-C enum value in C++? Pull Request resolved: https://github.com/pytorch/pytorch/pull/165254 Approved by: https://github.com/malfet |
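A sketch of the user-facing behavior being fixed; assumes an MPS device is available, and the exact exception type/message comes from `_linalg_check_errors`:
```python
import torch

A = torch.zeros(3, 3, device="mps")  # singular by construction
b = torch.ones(3, device="mps")

try:
    torch.linalg.solve(A, b)
except Exception as e:  # expected: a clear singular-matrix error, not an internal assertion
    print(type(e).__name__, ":", e)
```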
|||
ce29d0d796 |
[ATen] Vectorize 8 elements on 16 bit data types for sum/mean (#165055)
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for FP16/BF16, ~6% improvement on average across shapes, up to ~24% for single reduction on contiguous dimension and 46% for full reduce: **BF16** ``` Tensor Shape Operation Full reduce (ms) Contiguous dim (ms) Full reduce (ms) Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022686 0.008263 0.015498 0.008117 +46.38% +1.80% (256, 256) sum 0.022769 0.008269 0.015628 0.008185 +45.69% +1.03% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014116 0.009545 0.012892 0.008839 +9.49% +7.99% (512, 512) sum 0.014110 0.009892 0.012891 0.008878 +9.46% +11.42% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014727 0.012642 0.014061 0.010519 +4.74% +20.18% (1024, 1024) sum 0.014376 0.012636 0.014069 0.010595 +2.18% +19.26% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018663 0.018294 0.018171 0.014678 +2.71% +24.64% (2048, 2048) sum 0.018638 0.017931 0.018142 0.014713 +2.73% +21.87% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034216 0.036953 0.033520 0.030585 +2.08% +20.82% (4096, 4096) sum 0.034196 0.036942 0.033518 0.030676 +2.02% +20.43% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087763 0.095201 0.085439 0.084960 +2.72% +12.05% (8192, 8192) sum 0.088079 0.095592 0.085353 0.084632 +3.19% +12.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148174 0.149705 0.146274 0.138865 +1.30% +7.81% (8192, 16384) sum 0.147820 0.149371 0.146419 0.138752 +0.96% +7.65% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266144 0.260807 0.265953 0.253330 +0.07% +2.95% (8192, 32768) sum 0.266572 0.261163 0.265729 0.253294 +0.32% +3.11% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.502034 0.486312 0.498417 0.481246 +0.73% +1.05% (8192, 65536) sum 0.501597 0.486351 0.497735 0.481579 +0.78% +0.99% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.971178 0.942988 0.957164 0.938316 +1.46% +0.50% (8192, 
131072) sum 0.971189 0.943232 0.956814 0.937816 +1.50% +0.58% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.953728 1.877648 1.904937 1.861692 +2.56% +0.86% (8192, 262144) sum 1.953969 1.877538 1.905990 1.862547 +2.52% +0.80% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.970408 0.940965 0.957871 0.936732 +1.31% +0.45% (4096, 262144) sum 0.970919 0.941652 0.957765 0.936676 +1.37% +0.53% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501477 0.486976 0.497964 0.483570 +0.71% +0.70% (2048, 262144) sum 0.501955 0.487213 0.498210 0.483218 +0.75% +0.83% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266536 0.257111 0.265642 0.255439 +0.34% +0.65% (1024, 262144) sum 0.266613 0.257096 0.265427 0.255472 +0.45% +0.64% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091200 0.085818 0.087851 +2.32% +3.81% (512, 131072) sum 0.087788 0.091249 0.085373 0.087944 +2.83% +3.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014503 0.012328 0.013663 0.010190 +6.15% +20.98% (1000, 1000) sum 0.014545 0.012378 0.013662 0.010579 +6.46% +17.01% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014163 0.008371 0.012893 0.008828 +9.85% -5.18% (1024, 129) sum 0.014132 0.008751 0.013234 0.008868 +6.79% -1.32% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014296 0.009101 0.013334 0.008563 +7.21% +6.28% (1024, 257) sum 0.014302 0.009058 0.013020 0.008672 +9.85% +4.45% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014127 0.010997 0.013443 0.009944 +5.09% +10.59% (1024, 587) sum 0.014471 0.011373 0.013123 0.010354 +10.27% +9.84% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015607 0.013566 0.015089 0.012152 +3.43% +11.64% (2048, 977) sum 0.015953 0.013580 0.015039 0.011861 +6.08% +14.49% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.013982 0.008058 0.012747 0.008139 +9.69% -1.00% (1024, 128) sum 0.013967 0.008071 0.012726 0.007859 +9.75% +2.70% 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014378 0.009627 0.013712 0.009395 +4.86% +2.47% (8192, 128) sum 0.014389 0.009965 0.013718 0.009521 +4.89% +4.66% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.014156 0.008267 0.012895 0.008833 +9.78% -6.41% (1024, 130) sum 0.013797 0.008277 0.012903 0.008512 +6.93% -2.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.014977 0.010026 0.013911 0.009876 +7.66% +1.52% (8192, 130) sum 0.014994 0.010043 0.014235 0.009604 +5.33% +4.57% ==================================================================================================================================================================================== ``` **FP16** ``` Tensor Shape Operation Full reduce (ms) Contiguous dim (ms) Full reduce (ms) Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022804 0.008298 0.015888 0.007848 +43.53% +5.73% (256, 256) sum 0.023215 0.008328 0.015677 0.007850 +48.08% +6.09% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.013777 0.009988 0.012884 0.008512 +6.93% +17.34% (512, 512) sum 0.013775 0.009622 0.012870 0.009028 +7.03% +6.58% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014740 0.012322 0.013708 0.010239 +7.53% +20.34% (1024, 1024) sum 0.014762 0.012756 0.013722 0.010307 +7.58% +23.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018700 0.018364 0.018135 0.015078 +3.12% +21.79% (2048, 2048) sum 0.018276 0.018415 0.018471 0.015127 -1.06% +21.74% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034518 0.037000 0.033838 0.030617 +2.01% +20.85% (4096, 4096) sum 0.034569 0.037448 0.033842 0.031100 +2.15% +20.41% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087675 0.095176 0.085328 0.084105 +2.75% +13.16% (8192, 8192) sum 0.088102 0.095211 0.085707 0.084090 +2.79% +13.23% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.147800 0.149263 0.146388 0.138390 +0.96% +7.86% (8192, 16384) sum 0.148147 0.148957 0.146439 0.138801 +1.17% +7.32% 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266316 0.260294 0.265829 0.253411 +0.18% +2.72% (8192, 32768) sum 0.266562 0.260717 0.265744 0.253308 +0.31% +2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.502035 0.486077 0.498139 0.481374 +0.78% +0.98% (8192, 65536) sum 0.501571 0.485733 0.498353 0.481350 +0.65% +0.91% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.971343 0.943016 0.956600 0.938622 +1.54% +0.47% (8192, 131072) sum 0.971463 0.942991 0.957352 0.938334 +1.47% +0.50% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.952722 1.877165 1.906406 1.861455 +2.43% +0.84% (8192, 262144) sum 1.952634 1.876388 1.904677 1.861282 +2.52% +0.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.970697 0.941298 0.956964 0.936160 +1.44% +0.55% (4096, 262144) sum 0.969981 0.941078 0.957016 0.936260 +1.35% +0.51% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501577 0.487208 0.498422 0.483493 +0.63% +0.77% (2048, 262144) sum 0.502029 0.487124 0.497854 0.483643 +0.84% +0.72% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266416 0.257383 0.265928 0.255140 +0.18% +0.88% (1024, 262144) sum 0.266434 0.257081 0.265817 0.255143 +0.23% +0.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087858 0.091296 0.085816 0.087745 +2.38% +4.05% (512, 131072) sum 0.088144 0.091314 0.085664 0.087864 +2.90% +3.93% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014977 0.012393 0.014141 0.010614 +5.91% +16.76% (1000, 1000) sum 0.014589 0.012804 0.014118 0.010320 +3.34% +24.07% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014208 0.008383 0.013273 0.008440 +7.04% -0.68% (1024, 129) sum 0.013804 0.008863 0.013265 0.009003 +4.06% -1.56% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014378 0.009109 0.013037 0.009038 +10.29% +0.79% (1024, 257) sum 0.014387 0.009113 0.013396 0.008698 +7.40% +4.77% 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014207 0.011037 0.013182 0.010391 +7.78% +6.22% (1024, 587) sum 0.014588 0.011453 0.013539 0.010049 +7.75% +13.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.016024 0.013614 0.015448 0.011845 +3.73% +14.93% (2048, 977) sum 0.015990 0.014033 0.015406 0.012278 +3.79% +14.29% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.014037 0.007804 0.013143 0.008242 +6.80% -5.31% (1024, 128) sum 0.014041 0.007847 0.012759 0.007850 +10.05% -0.04% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014361 0.009644 0.014075 0.009061 +2.03% +6.43% (8192, 128) sum 0.014366 0.010032 0.013702 0.009181 +4.85% +9.27% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.014226 0.008696 0.012894 0.008835 +10.33% -1.57% (1024, 130) sum 0.013830 0.008740 0.013288 0.008989 +4.08% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.015036 0.010019 0.013917 0.009538 +8.04% +5.04% (8192, 130) sum 0.014652 0.010403 0.013900 0.009565 +5.41% +8.76% ==================================================================================================================================================================================== ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165055 Approved by: https://github.com/ngimel ghstack dependencies: #165494, #164790 |
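A quick correctness sanity check one might run alongside such a change: compare a bf16 reduction over the contiguous dimension (where the vectorized loads apply) against a float32 reference; the shape and tolerances here are arbitrary:
```python
import torch

x = torch.randn(8192, 16384, device="cuda", dtype=torch.bfloat16)

out = x.sum(dim=-1)          # reduction along the contiguous dimension
ref = x.float().sum(dim=-1)  # float32 reference

torch.testing.assert_close(out.float(), ref, rtol=1e-2, atol=1.0)
print("bf16 reduction matches fp32 reference within tolerance")
```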
|||
7231118db3 |
Turn some const variables into constexpr in C++ code (#165401)
This PR checks the C++ code and turns some const variables into constexpr. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165401 Approved by: https://github.com/Skylion007 |
|||
4a22139eea |
[MPS][BE] Fix unused variable warning (#165726)
Namely this one
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:19:18: warning: unused variable 'output_sizes' [-Wunused-variable]
  constant auto& output_sizes = shared_params.output_sizes;
                 ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:85:1: note: in instantiation of function template specialization 'cat<long, float, float>' requested here
REGISTER_CAT_FOR_INDEX_TYPE(int64_t);
^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:69:3: note: expanded from macro 'REGISTER_CAT_FOR_INDEX_TYPE'
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float); \
  ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:55:3: note: expanded from macro 'REGISTER_CAT_OP_ALL_INPUT_TYPES'
  REGISTER_CAT_OP(I, float, T_out); \
  ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:47:15: note: expanded from macro 'REGISTER_CAT_OP'
  kernel void cat<I, T_in, T_out>( \
              ^
```
Repeated about 20-30 times Pull Request resolved: https://github.com/pytorch/pytorch/pull/165726 Approved by: https://github.com/Skylion007 |
|||
cb6e4d7d82 |
User-passed alpha to scaled_gemm (#165563)
Summary: Add optional user-passed `alpha` argument to `at::cuda::blas::scaled_gemm`, necessary for two-level-scaled NVFP4 gemm calls (where the global de-scales are folded into the `alpha` argument). Global de-scales are naturally device tensors, but using cublas' device-pointer mode for `alpha`/`beta` has an interesting lifetime implication - the `alpha` tensor must be valid & correct until the end of the matmul call, *not* just the launch (as for host values). To enable this, I added device-constant memory for `one` and `zero`, along with a statically-held single-fp32-value tensor, which is valid from the first passed-`alpha` invocation of `scaled_gemm` to the end of the program. User-passed values are copied into this perpetual buffer to ensure lifetime requirements are met. Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/165563 Approved by: https://github.com/drisspg, https://github.com/eqy |
|||
202f83dc4e |
[ROCm][layer_norm] Use __builtin_amdgcn_rcpf(x) instead of 1.f/x (#165589)
Replace (more) exact calculation with hardware approximation. Benefits: Reduced code size. Improved performance for certain scenarios. Experiments show low reduction in precision. Experiments show no significant performance regressions. bfloat16 as well as float16 related calculations may benefit largely from this change. Co-author: @mhalk @amd-hhashemi Pull Request resolved: https://github.com/pytorch/pytorch/pull/165589 Approved by: https://github.com/jeffdaily |
|||
9e94ec76b8 |
Revert "Turn some const variables into constexpr in C++ code (#165401)"
This reverts commit 5b2afe4c5dc87786ca65bf22ca9a78f7c21a33a4.
Reverted https://github.com/pytorch/pytorch/pull/165401 on behalf of https://github.com/seemethere due to This is breaking test/distributions/test_distributions.py::TestDistributions::test_binomial_sample on HUD, see
|
|||
364624e209 |
[codemod][lowrisk] Remove unused exception parameter from some files (#165700)
Summary: `-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it. This: ``` try { ... } catch (exception& e) { // no use of e } ``` should instead be written as ``` } catch (exception&) { ``` If the code compiles, this is safe to land. Test Plan: Sandcastle Differential Revision: D84868162 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165700 Approved by: https://github.com/Skylion007 |
|||
5b2afe4c5d |
Turn some const variables into constexpr in C++ code (#165401)
This PR checks the C++ code and turns some const variables into constexpr. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165401 Approved by: https://github.com/Skylion007 |
|||
e0fe37fa68 |
[MPS] Move torch.cat impl to Metal (#165373)
After this change, all of the cases tested in [this performance measurement script](
|
|||
d2c82bafb7 |
Revert "158232 Fix autocast cache incorrectly retaining no_grad state (#165068)"
This reverts commit 5daef30b26b794d237fbbc399c1d47ec0380200a.
Reverted https://github.com/pytorch/pytorch/pull/165068 on behalf of https://github.com/jeffdaily due to This broke ROCm CI. test/test_transformers.py::TestTransformersCUDA::test_transformerencoder_fastpath_use_torchscript_False_enable_nested_tensor_True_use_autocast_True_d_model_256_cuda [GH job link](https://github.com/pytorch/pytorch/actions/runs/18572589089/job/52952074008) [HUD commit link](
|
|||
cbc08c8993 |
Add NEON acceleration for Vectorized<int[8|16|32|64]> (#165273)
Summary: Adding NEON specializations of Vectorized<T> for int8, int16, int32 and int64. Correctness has been checked using test_ops.py, and the comprehensive torch test operator_benchmark_test.py has been enhanced by adding cases of bitwise operations, boolean ops and integer ops. The benchmark, which uses the PyTorch API, shows significant enhancements in a wide variety of operations:

Before:
bitwise xor: 779.882us
boolean any: 636.209us
boolean all: 538.621us
integer mul: 304.457us
integer asr: 447.997us

After:
bitwise xor: 680.221us ---> 15% higher throughput
boolean any: 391.468us ---> 63% higher throughput
boolean all: 390.189us ---> 38% higher throughput
integer mul: 193.532us ---> 57% higher throughput
integer asr: 179.929us ---> 149% higher throughput

Test Plan: Correctness: buck2 test @mode/opt //caffe2/test:test_ops buck2 test @mode/opt //caffe2/test:torch buck2 test @mode/opt //caffe2/test/distributed/launcher/fb:fb_run_test Performance: buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test Differential Revision: D84424638 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165273 Approved by: https://github.com/malfet |
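As an illustration of the kind of op-level measurement involved (this uses the public `torch.utils.benchmark` API rather than the internal operator_benchmark suite; tensor sizes are arbitrary):
```python
import torch
import torch.utils.benchmark as benchmark

a = torch.randint(0, 100, (1 << 20,), dtype=torch.int32)
b = torch.randint(1, 100, (1 << 20,), dtype=torch.int32)
m = a > 50  # boolean tensor for any/all cases

cases = {
    "bitwise xor": "a ^ b",
    "boolean any": "m.any()",
    "boolean all": "m.all()",
    "integer mul": "a * b",
    "integer asr": "a >> 3",
}
for label, stmt in cases.items():
    t = benchmark.Timer(stmt=stmt, globals={"a": a, "b": b, "m": m})
    print(label, t.timeit(1000))
```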
|||
aba8c43594 |
Register var for MTIA (#165382)
Summary: Registers variance kernel Reviewed By: srsuryadev Differential Revision: D84546250 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165382 Approved by: https://github.com/malfet |
|||
5daef30b26 |
158232 Fix autocast cache incorrectly retaining no_grad state (#165068)
Fixes #158232 The autocast caching heuristic in `aten/src/ATen/autocast_mode.cpp:139` did not account for gradient mode state when deciding whether to cache. FSDP2 is not directly related. ~~This PR adds `GradMode::is_enabled()` check to caching condition. Caching is now disabled in `no_grad()` contexts to prevent storing tensors with incorrect gradient state. Ensures correctness at the cost of using cache.~~ This PR proposes separate caches for gradient-enabled and gradient-disabled modes. Adds tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165068 Approved by: https://github.com/ngimel, https://github.com/janeyx99 |
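A minimal repro sketch of the interaction described above (issue #158232); the module and shapes are arbitrary:
```python
import torch

linear = torch.nn.Linear(16, 16).cuda()
x = torch.randn(2, 16, device="cuda")

with torch.autocast("cuda", dtype=torch.float16):
    with torch.no_grad():
        linear(x)    # may populate the autocast weight-cast cache while grad is disabled
    out = linear(x)  # same weight cast, now with grad enabled

# with separate caches per grad mode, the cached cast no longer leaks the no_grad state
print(out.requires_grad)  # expected: True
```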
|||
7669ac9402 |
[ROCm] Add scaled_mm v2 support. (#165528)
Add mx fp4 support in Blas.cpp. Updated the scale_kernel_dispatch array and ScaledGemmImplementation enum to include MXFP4 support. Modify the tests under test_scaled_matmul_cuda accordingly. PYTORCH_TEST_WITH_ROCM=1 python test/test_scaled_matmul_cuda.py -v -k test_blockwise 115 test passed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165528 Approved by: https://github.com/jeffdaily |
|||
fe5ccb1a74 |
bf16 support for per tensor backward (#165362)
Adding bf16 for the backward pass of `torch._fake_quantize_learnable_per_tensor_affine()`. Note that for testing, we modified the seed to avoid increasing tolerance due to cases where difference in Python vs CPP downcasting causes tensor mismatches. (e.g. 27.87704 vs 27.8408 before downcasting, 27.7500 vs 27.8750 after downcasting for Python vs CPP op) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165362 Approved by: https://github.com/andrewor14 |
|||
1a5b7eca7b |
[BE] Fold cond into TORCH_CHECK(false,...) (#165593)
Replace `if (!foo) { TORCH_CHECK(false, "bar");}` with `TORCH_CHECK(foo,"bar");` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165593 Approved by: https://github.com/albanD ghstack dependencies: #165594 |
|||
8573574b32 |
[MPS] sparse mask implementation (#165102)
sparse mask implementation Pull Request resolved: https://github.com/pytorch/pytorch/pull/165102 Approved by: https://github.com/malfet |
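For context, `Tensor.sparse_mask` keeps the dense tensor's values only at the indices present in the sparse pattern; a small sketch on MPS (device availability assumed):
```python
import torch

dense = torch.randn(4, 4, device="mps")
idx = torch.tensor([[0, 1, 3], [0, 2, 3]])
vals = torch.ones(3)
pattern = torch.sparse_coo_tensor(idx, vals, (4, 4), device="mps").coalesce()

masked = dense.sparse_mask(pattern)  # sparse result holding dense's values at pattern's indices
print(masked.to_dense())
```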
|||
e6033f6efb |
[MPS] Improve index_fill_ error handling (#165594)
It should not throw "Cannot convert a float64 Tensor to MPS", but rather a sensible "Converting complex Scalar to non-complex type is not supported". Add a TODO about complex support; probably a good reason to rip out MPSGraph from index_fill as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165594 Approved by: https://github.com/dcci, https://github.com/kulinseth |
|||
d73c283c3a |
[CUDA] Large tensor maxpool crash fix (#165374)
Fixes #165297 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165374 Approved by: https://github.com/eqy, https://github.com/malfet |
|||
eaeaa08e3a |
[PowerPC] Disable MKLDNN TF32 on PowerPC to fix build failure (#163454)
Commit f4d8bc46c7706f872abcb4ec41f0b32207d5d826 added TF32 support for x86 CPUs, which causes build failures on PowerPC systems with mkldnn. This patch disables TF32 paths on PowerPC while keeping x86 TF32 support intact, allowing PyTorch to build successfully on PowerPC. I have run the mkldnn test case on PowerPC, and it passed successfully. `pytest test/test_mkldnn.py`: 87 passed, 2 skipped in 1709.02s (0:28:29) Pull Request resolved: https://github.com/pytorch/pytorch/pull/163454 Approved by: https://github.com/jgong5, https://github.com/malfet |
|||
26f3803433 |
Remove workaround to old CUDA bug (#164354)
As in the title. A check for https://github.com/pytorch/pytorch/issues/164348 to see if the workaround can be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164354 Approved by: https://github.com/janeyx99, https://github.com/ngimel, https://github.com/malfet, https://github.com/jeffdaily ghstack dependencies: #164350 |
|||
36371b8ec7 |
[ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern: <img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" /> which is exhibited in Triton generated by torch.compile: <img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" /> Switch the warp shuffle order to make bitwise equivalence between the 2 easier. PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/ Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order: ``` Tensor Shape Operation New all dims (ms) New dim=0 (ms) New dim=1 (ms) Old all dims (ms) Old dim=0 (ms) Old dim=1 (ms) ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.015817 0.016259 0.013642 0.015990 0.016258 0.013631 (1024, 1024) sum 0.015917 0.015906 0.013359 0.015707 0.016266 0.013226 (1024, 1024) min 0.016021 0.024625 0.015631 0.015761 0.024485 0.015317 (1024, 1024) max 0.016349 0.024971 0.015972 0.015771 0.025001 0.015314 (1024, 1024) argmin 0.018070 0.024448 0.015578 0.018135 0.025370 0.015322 (1024, 1024) argmax 0.018427 0.024859 0.015932 0.018164 0.024452 0.015639 (1024, 1024) var 0.020078 0.026413 0.020295 0.020199 0.026381 0.020214 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.023826 0.023726 0.022273 0.023236 0.023776 0.022248 (2048, 2048) sum 0.023840 0.023355 0.021974 0.023294 0.023354 0.021884 (2048, 2048) min 0.024519 0.041263 0.024620 0.023292 0.041491 0.024358 (2048, 2048) max 0.024509 0.041670 0.024277 0.023334 0.041231 0.024395 (2048, 2048) argmin 0.026125 0.041282 0.024567 0.026772 0.041773 0.024296 (2048, 2048) argmax 0.026117 0.041487 0.024572 0.026412 0.041477 0.024273 (2048, 2048) var 0.026603 0.048581 0.031308 0.027587 0.048603 0.030860 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.053927 0.057070 0.054073 0.053028 0.057544 0.053935 (4096, 4096) sum 0.053604 0.057410 0.054451 0.053076 0.057033 0.054266 (4096, 4096) min 0.054293 0.109122 0.058363 0.053821 0.108689 0.058382 (4096, 4096) max 0.054258 0.108035 0.058703 0.053492 0.110552 0.058376 (4096, 4096) argmin 0.056805 0.111167 0.058301 0.056836 0.112325 0.058292 (4096, 4096) argmax 0.056488 0.110958 0.058636 0.056844 0.111000 0.057928 (4096, 4096) var 0.058936 0.141755 0.068693 0.059735 0.141284 0.068500 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.145552 0.148082 0.138647 0.145364 0.147818 0.138207 (8192, 8192) sum 0.145985 0.147900 0.138714 0.145755 0.148031 0.138616 (8192, 8192) min 0.146566 0.205359 0.192739 0.145611 0.205237 0.182335 (8192, 8192) max 0.146526 0.204844 0.193050 0.146073 0.205457 0.182697 (8192, 8192) argmin 0.150190 0.206605 0.192543 0.150654 0.206847 0.182007 (8192, 8192) argmax 0.150481 0.206368 0.192535 0.150845 0.206430 0.182022 (8192, 8192) var 0.150884 
0.184546 0.203900 0.151594 0.184172 0.197983 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1, 1024, 128) mean 0.014293 0.008119 0.014533 0.013861 0.008022 0.014449 (1, 1024, 128) sum 0.014039 0.007877 0.014111 0.014219 0.008227 0.014045 (1, 1024, 128) min 0.014159 0.011354 0.023493 0.014271 0.010862 0.023644 (1, 1024, 128) max 0.014154 0.011027 0.023368 0.014259 0.011234 0.023692 (1, 1024, 128) argmin 0.016403 0.005677 0.023328 0.016273 0.005683 0.024073 (1, 1024, 128) argmax 0.016734 0.005675 0.023437 0.016580 0.005318 0.023331 (1, 1024, 128) var 0.018338 0.009549 0.025538 0.018528 0.009391 0.024777 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (5, 1024, 128) mean 0.014873 0.010131 0.015546 0.015123 0.010131 0.015481 (5, 1024, 128) sum 0.015334 0.009673 0.015824 0.014736 0.009671 0.015438 (5, 1024, 128) min 0.015047 0.013252 0.024573 0.014803 0.013163 0.024551 (5, 1024, 128) max 0.015050 0.013339 0.024197 0.014810 0.013525 0.024230 (5, 1024, 128) argmin 0.017341 0.012737 0.024306 0.017471 0.012379 0.024991 (5, 1024, 128) argmax 0.017345 0.012411 0.024421 0.017422 0.012471 0.024237 (5, 1024, 128) var 0.019973 0.011453 0.026188 0.020050 0.011438 0.026282 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (10, 1024, 128) mean 0.016976 0.011575 0.016831 0.016722 0.011927 0.017173 (10, 1024, 128) sum 0.017039 0.011841 0.017159 0.016385 0.011860 0.016753 (10, 1024, 128) min 0.017036 0.015331 0.026770 0.016944 0.015205 0.027166 (10, 1024, 128) max 0.017369 0.015348 0.027077 0.016531 0.015716 0.026819 (10, 1024, 128) argmin 0.019203 0.014447 0.026813 0.018994 0.014497 0.027313 (10, 1024, 128) argmax 0.019563 0.014795 0.027140 0.019460 0.014912 0.026733 (10, 1024, 128) var 0.020529 0.014316 0.030405 0.020719 0.013960 0.029964 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (100, 1024, 128) mean 0.045046 0.039168 0.046082 0.044839 0.039217 0.045782 (100, 1024, 128) sum 0.045094 0.039150 0.045777 0.044496 0.039542 0.046083 (100, 1024, 128) min 0.045768 0.054466 0.076244 0.044915 0.053943 0.076599 (100, 1024, 128) max 0.045748 0.054459 0.076188 0.044931 0.053949 0.076856 (100, 1024, 128) argmin 0.048275 0.054046 0.076647 0.048694 0.054105 0.077004 (100, 1024, 128) argmax 0.048267 0.054395 0.077401 0.048691 0.054131 0.076751 (100, 1024, 128) var 0.049710 0.043254 0.083077 0.050971 0.043251 0.082378 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000, 100) mean 0.202312 0.196723 0.197765 0.201774 0.196641 0.197459 (1000, 1000, 100) sum 0.202651 0.196682 0.197736 0.202175 0.196313 0.197523 (1000, 1000, 100) min 0.203022 0.264762 0.269200 0.202729 0.264129 0.268694 (1000, 1000, 100) max 0.202864 0.264396 0.269388 0.202486 0.263896 0.268720 (1000, 1000, 100) argmin 0.226727 0.263781 0.268651 0.226597 0.264676 0.268983 (1000, 1000, 100) argmax 0.226412 0.264469 0.269090 0.226570 0.264595 0.269178 (1000, 1000, 100) var 0.243223 0.204079 0.216096 
0.241942 0.204079 0.215925 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (10000, 100) mean 0.016193 0.020277 0.014316 0.016152 0.020324 0.013712 (10000, 100) sum 0.016289 0.020237 0.014034 0.016168 0.020265 0.013708 (10000, 100) min 0.016046 0.030872 0.019609 0.016208 0.030867 0.018627 (10000, 100) max 0.016369 0.030835 0.019257 0.016218 0.030861 0.018209 (10000, 100) argmin 0.017957 0.031171 0.019517 0.018050 0.031556 0.018077 (10000, 100) argmax 0.017961 0.031658 0.019521 0.018060 0.031564 0.018087 (10000, 100) var 0.020393 0.035652 0.019339 0.020144 0.035987 0.019171 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (100000, 10) mean 0.015718 0.016576 0.016555 0.015999 0.016246 0.014869 (100000, 10) sum 0.015833 0.016247 0.016572 0.016007 0.016627 0.014872 (100000, 10) min 0.015888 0.020510 0.023920 0.015671 0.020821 0.021417 (100000, 10) max 0.015889 0.020479 0.023918 0.016077 0.020386 0.021421 (100000, 10) argmin 0.018233 0.020863 0.023647 0.017574 0.020864 0.021103 (100000, 10) argmax 0.017896 0.020527 0.023296 0.017569 0.020447 0.021098 (100000, 10) var 0.020005 0.024198 0.024372 0.020075 0.024167 0.022415 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 1023) mean 1.874816 1.963506 1.903909 1.873279 1.963859 1.903230 (1023, 1023, 1023) sum 1.875030 1.965716 1.902458 1.873566 1.960730 1.901642 (1023, 1023, 1023) min 1.878563 2.473455 2.179092 1.875174 2.482086 2.183027 (1023, 1023, 1023) max 1.879128 2.474803 2.178895 1.874831 2.482253 2.183884 (1023, 1023, 1023) argmin 1.921800 2.476629 2.174831 1.923987 2.472641 2.170453 (1023, 1023, 1023) argmax 1.922605 2.476688 2.177927 1.923366 2.472808 2.172979 (1023, 1023, 1023) var 1.972606 3.088695 2.758797 1.978679 3.095658 2.762243 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 255) mean 0.489984 0.500954 0.492957 0.489891 0.500654 0.491971 (1023, 1023, 255) sum 0.490228 0.500764 0.492289 0.489624 0.501089 0.492824 (1023, 1023, 255) min 0.491457 0.563560 0.553334 0.490355 0.564709 0.554754 (1023, 1023, 255) max 0.491396 0.563628 0.553345 0.490017 0.565004 0.554947 (1023, 1023, 255) argmin 0.503666 0.561512 0.551831 0.503845 0.560972 0.551017 (1023, 1023, 255) argmax 0.503602 0.561185 0.551407 0.504328 0.561267 0.551448 (1023, 1023, 255) var 0.510844 0.709452 0.701630 0.512693 0.710365 0.701965 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 377) mean 0.707439 0.727646 0.712019 0.706769 0.727101 0.711632 (1023, 1023, 377) sum 0.707780 0.727453 0.711554 0.706807 0.726656 0.711729 (1023, 1023, 377) min 0.709423 0.819809 0.794379 0.707847 0.822086 0.796664 (1023, 1023, 377) max 0.709297 0.819780 0.794308 0.707566 0.821913 0.796690 (1023, 1023, 377) argmin 0.725028 0.817088 0.791695 0.726039 0.816445 0.790828 (1023, 1023, 377) argmax 0.725301 0.817011 0.791420 0.726040 0.816917 0.791143 (1023, 1023, 377) var 0.740859 1.034165 1.006712 0.743413 1.035506 
1.007638 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790 Approved by: https://github.com/ngimel, https://github.com/eqy ghstack dependencies: #165494 |
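An illustrative check of the bitwise-equivalence goal mentioned above: compare an eager reduction against the torch.compile/Triton one. Whether the two are bit-identical depends on matching accumulation order, which is what this change aligns; the shape is arbitrary:
```python
import torch

x = torch.randn(4096, 4096, device="cuda")

eager = torch.sum(x, dim=1)
compiled = torch.compile(lambda t: torch.sum(t, dim=1))(x)

print("bitwise equal:", torch.equal(eager, compiled))
```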
|||
53f9ae0e50 |
[ROCm] new implementation of upsample_bilinear2d_backward (#164572)
Changed the implementation from an output-based approach to an input-based one to remove `atomicAdd` operations, and it appears to deliver at least a 20× speedup. The changes are from Yu-Yun <YuYun.Chang@amd.com>.
# Summary: Refactor of the implementation of the `upsample_bilinear2d_backward` operation on MI300X/MI325X
- The original "scatter-add" approach
  - Each thread, representing an output pixel, scattered gradient contributions to four input pixels, using costly atomic operations on MI300X/MI325X GPUs.
- The new "gather-sum" approach (see the Python sketch below)
  - Each thread is responsible for a single input pixel and gathers all relevant gradient contributions from a small, calculated region of the output tensor (done by the `compute_output_range` device function).
# Breakdown of the code changes
- Inversion of the parallelization strategy of the kernel function `upsample_bilinear2d_backward_out_frame`
  - Originally, the main kernel loop was parallelized over the number of elements in the output gradient tensor (`const size_t o_numel = nc * width2 * height2;`).
    - Each thread processed one output pixel.
  - The new loop is parallelized over the number of elements in the input gradient tensor (`const size_t i_numel = nc * height1 * width1;`).
    - Each thread is responsible for calculating the final gradient for a single input pixel.
  - The kernel launch changes accordingly in the function `upsample_bilinear2d_backward_out_cuda_template`.
- Added a device function for calculating the range of output pixels that could have used the input pixel (`input_pos`) during the forward-pass interpolation
  - This is essentially the mathematical inverse of the forward pass.
  - This function prunes a thread's search space so that it only needs to inspect a small, local window of the output tensor.
- Gradient calculation switched from "scatter-add" to "gather-sum"
  - Scatter-add
    - For each output pixel, the thread calculated 4 gradient contributions and used `fastAtomicAdd` 4 times to add these values to 4 different (and potentially highly contended) memory locations in the input gradient tensor.
  - Gather-sum
    - A thread responsible for one input pixel calls `compute_output_range` to determine the small rectangular region of output pixels that influence the input's final gradient value.
    - The thread iterates through this region, and for each output pixel in the region it re-calculates the interpolation weights to determine the exact contribution to its specific input pixel.
    - All these contributions are accumulated into a private, per-thread register variable (`accscalar_t grad_sum = 0;`).
    - Since the accumulation happens in a register, without global memory writes, it is extremely fast.
    - When the loops are done, the thread performs a single, direct, non-atomic write of the final summed gradient to its designated location in global memory (`idata[index] = static_cast<scalar_t>(grad_sum);`).
# Why performance gets boosted
- Analysis of the root cause of the performance drop
  - Ref. (internal only): https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1140493327/PyTorch__upsample_bilinear2d_backward
- First and foremost, elimination of atomic-operation contention
  - Many parallel threads frequently called `atomicAdd` on the exact same memory location in the input gradient tensor at the same time.
  - The GPU's memory controller has to serialize these operations, effectively nullifying parallelism at those contention points.
  - The chiplet-based CDNA 3 architecture of MI300X/MI325X amplified the issue.
    - When contending threads reside on different XCDs, resolving the atomic operation requires high-latency coherence traffic across the Infinity Fabric interconnect.
  - The implementation change eliminates the hardware-level serialization and cross-chiplet coherence traffic caused by the many `atomicAdd` calls.
- Improved memory access pattern and locality
  - Write coalescing
    - The regular sum writes `idata[index] = static_cast<scalar_t>(grad_sum);` can be perfectly coalesced by the GPU.
  - Read locality
    - Even though there are many (potentially repeated) reads from the output tensor (`static_cast<accscalar_t>(odata[output_idx])`), these are highly cache-friendly: the data for one thread is likely already in the L1 or L2 cache due to an access from a neighboring thread.
- Trade-off: computation for memory synchronization
  - The recalculation of interpolation weights maps well onto high-computational-throughput GPUs like MI300X/MI325X.
  - Removing the atomic operations avoids expensive memory synchronization.

---

Optimizations of `grid_sampler_2d_backward` will be addressed in a separate PR. Doc for reference: (internal only) https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1162750701/PyTorch__grid_sampler_2d_backward
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164572 Approved by: https://github.com/jeffdaily |
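For intuition, here is a minimal, single-channel Python sketch of the gather-sum idea. It assumes `align_corners=True`; the function name, the candidate-range math, and the pure-Python loops are illustrative only and do not mirror the actual HIP kernel:
```python
import math
import torch

def upsample_bilinear2d_backward_gather(grad_out, in_h, in_w):
    """Gather-sum backward for bilinear upsampling, single channel,
    align_corners=True. Illustration only -- not the actual HIP kernel."""
    out_h, out_w = grad_out.shape
    grad_in = torch.zeros(in_h, in_w, dtype=grad_out.dtype)
    # forward-pass scales under the align_corners=True convention
    rh = (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
    rw = (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
    for h1 in range(in_h):          # one "thread" per input pixel
        for w1 in range(in_w):
            grad_sum = 0.0          # per-thread register accumulator
            # inverse of the forward mapping: only outputs whose 2x2 source
            # window can touch (h1, w1) are inspected (range is over-inclusive)
            h2_lo = 0 if rh == 0 else max(0, math.floor((h1 - 1) / rh))
            h2_hi = out_h - 1 if rh == 0 else min(out_h - 1, math.ceil((h1 + 1) / rh))
            w2_lo = 0 if rw == 0 else max(0, math.floor((w1 - 1) / rw))
            w2_hi = out_w - 1 if rw == 0 else min(out_w - 1, math.ceil((w1 + 1) / rw))
            for h2 in range(h2_lo, h2_hi + 1):
                h1r = rh * h2
                h1_0 = int(h1r)
                h1lam = h1r - h1_0
                wt_h = (1 - h1lam) if h1 == h1_0 else (h1lam if h1 == h1_0 + 1 else 0.0)
                if wt_h == 0.0:
                    continue
                for w2 in range(w2_lo, w2_hi + 1):
                    w1r = rw * w2
                    w1_0 = int(w1r)
                    w1lam = w1r - w1_0
                    wt_w = (1 - w1lam) if w1 == w1_0 else (w1lam if w1 == w1_0 + 1 else 0.0)
                    # re-derived forward interpolation weight of (h2, w2) on (h1, w1)
                    grad_sum += wt_h * wt_w * float(grad_out[h2, w2])
            grad_in[h1, w1] = grad_sum   # single non-atomic write per input pixel
    return grad_in

# quick check against autograd on a tiny example
x = torch.randn(1, 1, 3, 3, requires_grad=True)
y = torch.nn.functional.interpolate(x, size=(5, 7), mode="bilinear", align_corners=True)
go = torch.randn_like(y)
y.backward(go)
torch.testing.assert_close(
    upsample_bilinear2d_backward_gather(go[0, 0], 3, 3), x.grad[0, 0],
    rtol=1e-4, atol=1e-4,
)
```
The two outer loops play the role of one GPU thread per input pixel; the weight check zeros out candidate outputs that did not actually touch `(h1, w1)` in the forward pass.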
|||
66ea76ec44 |
[ROCm][tunableop] Improvements to tunableop Numerical Check (#163079)
Modified the PYTORCH_TUNABLEOP_NUMERICAL_CHECK flag so that it accepts numerical tolerances in the format atol_rtol, instead of the previous 0/1 toggle. The previous behavior is retained via default values. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163079 Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily |
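A usage sketch, assuming a ROCm build with TunableOp available; the underscore-separated value follows the atol_rtol format described above, and the specific tolerances are illustrative, not recommendations:
```python
import os

# Enable TunableOp and numerical checking with explicit tolerances (atol_rtol).
# Set before importing torch so the environment is read at initialization.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1e-5_1e-5"  # illustrative values

import torch

a = torch.randn(512, 512, device="cuda", dtype=torch.half)
b = torch.randn(512, 512, device="cuda", dtype=torch.half)
c = a @ b  # GEMMs are tuned (and numerically checked) via TunableOp
```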
|||
7f9b745494 |
[ROCm][tunableop] Modified Online Tuning Mode to add Instant Logging (#163965)
- Added instant logging in online tuning mode, so that each tuned GEMM is written out immediately - This keeps tuning configs saved in case of a crash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163965 Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily |
|||
066f818eea |
Refactor and unify v1/v2 _scaled_mm codes (#165436)
Summary: * Refactor out some core routines (scaled_gemm, auto-tuned scaled_gemm) * Unify v1/v2 dispatch calls where possible * Simplify call pattern w.r.t. CUDA/ROCM for easier readability. Test Plan: ``` pytest -svv test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/165436 Approved by: https://github.com/drisspg |
|||
7c6c5d04fe |
Add scaled_grouped_mm_v2 and python API (#165154)
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/165154 Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre |
|||
7719cb75bf |
[ATen][CMake] Fix duplicated CUTLASS path (#165424)
Fixes #165110 The `PUBLIC` scope causes FBGEMM's CUTLASS to be included for all PyTorch targets, including the special matmuls (RowwiseScaledMM, ScaledGroupMM and GroupMM). Due to the version mismatch between FBGEMM/CUTLASS and PyTorch/CUTLASS, it is unacceptable to use FBGEMM/CUTLASS in PyTorch targets. This PR limits the scope of FBGEMM/CUTLASS to the `fbgemm_genai` target only. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165424 Approved by: https://github.com/cthi, https://github.com/eqy, https://github.com/danielvegamyhre |
|||
712f54d453 |
[ATen] Remove explicit casting of complex nansum during accumulation (#165494)
https://github.com/pytorch/pytorch/pull/164790 modifies ATen to perform a different intra-warp reduction order. However, this change exposed a large difference in a sum for complex32, namely the case:
```
import torch

a = torch.tensor([[ 4.82031250+7.34765625j,   -3.37109375-1.9501953125j],
                  [ 3.7832031250-2.43359375j, -6.07812500+5.32812500j]],
                 dtype=torch.complex32, device='cuda:0')

sum_out = torch.sum(a)
nansum_out = torch.nansum(a)

torch.testing.assert_close(
    sum_out,
    nansum_out,
    rtol=0,
    atol=0,
)
```
Here, the results of `sum` and `nansum` differed by as much as 1e-2. Further investigation showed that the explicit casting of `b` back to `arg_t` from `scalar_t` was the root cause: `arg_t` is the accumulator dtype (ComplexFloat) and `scalar_t` is the input dtype (ComplexHalf). With the cast happening during accumulation, the values being combined are effectively still ComplexHalf, which loses precision on the intermediate sums. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165494 Approved by: https://github.com/ngimel |
|||
8e510e1095 |
[MPS] fix empty dot op crash (#165237)
Reproducer:
```
import torch

# does not crash
a = torch.rand((0), device="cpu")
b = torch.rand((0), device="cpu")
a.dot(b)

# crashes due to internal assert
a = torch.rand((0), device="mps")
b = torch.rand((0), device="mps")
a.dot(b)
```
Discovered when implementing an op for the SparseMPS backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165237 Approved by: https://github.com/malfet |
|||
b4fd47179e |
feat(dynamo): IS#160752 make F.one_hot work with jacfwd + torch.compile(dynamic=True) (#160837)
Fixes #160752

# Background:
`torch.func.jacfwd` is implemented as vmap over forward-mode JVP. With torch.compile(dynamic=True), FakeTensor + SymInt shape reasoning is used while tracing through the transform. The old vmap rule for one_hot decomposed into “zeros_symint + scatter,” which interacted poorly with the transform stack and dynamic shapes, leading to failures mid-trace. Using a functional equality construction makes one_hot composable with vmap/JVP and friendly to dynamic shape tracing.

# Changes:
- The functorch vmap batching rule for `aten::one_hot` now uses a purely functional formulation:
  - Replace “zeros + scatter” with eq(self.unsqueeze(-1), arange(num_classes)).to(kLong) under FuncTorchBatched (a plain-Python check of this formulation follows the repro below).
- The one_hot native path remains unchanged for regular eager; the vmap transform no longer relies on scatter, which was fragile under dynamic shape tracing.

The minimal repro from the issue is now fixed:
```python
import torch
import torch.nn.functional as F

MAX, BATCH = 3, 37

def func(x, idxs):
    return x.square() * F.one_hot(idxs, MAX)

def jacfunc(x, idxs):
    return torch.func.jacfwd(func, argnums=0)(x, idxs)

idxs = torch.randint(MAX, (BATCH,), dtype=torch.int64)
x = torch.rand((BATCH, MAX), dtype=torch.float64)

# eager
out_eager = jacfunc(x, idxs)

# compiled dynamic
jacfunc_c = torch.compile(jacfunc, dynamic=True)
out_comp = jacfunc_c(x, idxs)

torch.testing.assert_close(out_eager, out_comp)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160837 Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519 |
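As a standalone sanity check of that functional formulation (a plain-Python analogue of the new batching rule, not the C++ code itself):
```python
import torch
import torch.nn.functional as F

num_classes = 3
idx = torch.randint(num_classes, (37,))

# eq(self.unsqueeze(-1), arange(num_classes)).to(long), as in the new vmap rule
functional = (idx.unsqueeze(-1) == torch.arange(num_classes)).to(torch.long)

torch.testing.assert_close(functional, F.one_hot(idx, num_classes))
```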
|||
4f400ab520 |
Fix: nDims is mutated inside the loop in Shape.cu (#165446)
Summary: The `nDims` variable is mutated inside the loop but never restored to its original value. This affects subsequent iterations of the outer loop. Each batch iteration may get incorrect `nDims` after the first batch. Test Plan: CI Reviewed By: ngimel Differential Revision: D84612194 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165446 Approved by: https://github.com/ngimel |
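For illustration, a hypothetical sketch of this bug class (the names and logic are stand-ins, not the actual Shape.cu code): a shared value adjusted inside the per-batch loop must be copied into a local, otherwise later iterations see the mutated value.
```python
# Hypothetical sketch of the bug pattern (not the actual Shape.cu kernel code).
def effective_dims_buggy(batch_shapes, n_dims):
    out = []
    for shape in batch_shapes:
        if shape[-1] == 1:      # some per-batch condition
            n_dims -= 1         # bug: mutates the shared value for all later batches
        out.append(n_dims)
    return out

def effective_dims_fixed(batch_shapes, n_dims):
    out = []
    for shape in batch_shapes:
        local = n_dims - 1 if shape[-1] == 1 else n_dims  # per-batch copy instead
        out.append(local)
    return out

shapes = [(4, 1), (4, 2), (4, 2)]
print(effective_dims_buggy(shapes, 3))  # [2, 2, 2] -- later batches are wrong
print(effective_dims_fixed(shapes, 3))  # [2, 3, 3]
```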
|||
a20afb6100 |
Allow at::native::offset_t to be offset using operator+= (#164570)
This will be required by CCCL 3.1. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164570 Approved by: https://github.com/Skylion007, https://github.com/eqy |
|||
382d04a51e |
[Inductor][ATen][FP8] Add note for supported blockwise scaling strategy pairs (#165450)
Summary: Add note mentioning which scaling type pairs are supported in Inductor ATen, since this was a source of confusion and also informs which scaling strategies we choose to support for other backends, like Triton. Test Plan: n/a Reviewed By: lw Differential Revision: D84522373 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165450 Approved by: https://github.com/NikhilAPatel |
|||
7fee6bbf34 |
[Fix] Completely remove stride normalization on DLPack Tensor (#164161)
A follow-up to PR #163282. Fixes #163274. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164161 Approved by: https://github.com/ngimel, https://github.com/eqy |
|||
c733072874 |
Fix IValue from SymBool on big-endian system (#163647)
Skip test_compiled_autograd_attribution on s390x. It fails on both s390x and x86_64, at least under some circumstances. Disable it on s390x for now until it works reliably. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163647 Approved by: https://github.com/malfet |
|||
56d6229ff9 |
[MPS] fix comment for normcdf (#165233)
Just a small comment fix for normcdf Pull Request resolved: https://github.com/pytorch/pytorch/pull/165233 Approved by: https://github.com/malfet |
|||
a856a17799 |
bf16 support for per_channel bwd (#165325)
Follow-up to #165098 - adding bf16 support for the backward pass. To avoid BC-breaking changes and precision loss, we upcast the parameters to fp32 after the op gets called and downcast the gradients to bf16 before returning. For testing, we upcast to fp32 before calling the reference function. We increase the tolerance to 1e-2 for bf16 inputs because of a difference in casting behavior between Python's `x.to(torch.bfloat16)` and C++'s `x.to(at::kBFloat16)` (after comparing intermediate tensors, we found that the numerics diverge after the final cast). We don't explicitly cast in the C++ op but rather let autograd/the optimizer handle it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165325 Approved by: https://github.com/andrewor14
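To make the testing convention concrete, a self-contained sketch; the per-channel op here is a simple stand-in, not the actual fused fake-quantize kernel: the reference path runs in fp32 on upcast inputs, the bf16 path runs directly, and the comparison uses the looser 1e-2 tolerance for bf16.
```python
import torch

def per_channel_scale(x, scale):
    # stand-in for a per-channel op; illustrates dtype handling only
    return x * scale.unsqueeze(-1)

x = torch.randn(8, 16, dtype=torch.bfloat16)
scale = torch.rand(8, dtype=torch.bfloat16) + 0.5

ref = per_channel_scale(x.float(), scale.float())   # fp32 reference path (upcast inputs)
out = per_channel_scale(x, scale)                   # bf16 path under test

torch.testing.assert_close(out.float(), ref, rtol=1e-2, atol=1e-2)
```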