Compare commits

..

70 Commits

SHA1  Message  Date
acf8783a5b  Update [ghstack-poisoned]  2025-11-03 16:15:50 +00:00
1946c2368f  Update [ghstack-poisoned]  2025-11-03 15:55:54 +00:00
341514bf45  Update (base update) [ghstack-poisoned]  2025-11-03 15:55:54 +00:00
8120d6004e  Update [ghstack-poisoned]  2025-11-03 12:35:05 +00:00
d6222a3b82  Update (base update) [ghstack-poisoned]  2025-11-03 12:35:05 +00:00
18af57b76b  Update [ghstack-poisoned]  2025-11-03 11:55:12 +00:00
379d5606b9  Update (base update) [ghstack-poisoned]  2025-11-03 11:55:12 +00:00
70518444a8  Update [ghstack-poisoned]  2025-10-31 10:18:29 +00:00
55a42d3a3e  Update (base update) [ghstack-poisoned]  2025-10-31 10:18:29 +00:00
793a5e2f86  Update [ghstack-poisoned]  2025-10-30 15:49:36 +00:00
e7a108c32b  Update (base update) [ghstack-poisoned]  2025-10-30 15:49:36 +00:00
dd5a8d3fc8  Update [ghstack-poisoned]  2025-10-30 15:44:01 +00:00
1d82e429d7  Update (base update) [ghstack-poisoned]  2025-10-30 15:44:01 +00:00
03e312c85e  Update [ghstack-poisoned]  2025-10-30 14:29:34 +00:00
2540eaff4d  Update (base update) [ghstack-poisoned]  2025-10-30 14:29:34 +00:00
b1e44a9ff1  Update [ghstack-poisoned]  2025-10-30 13:05:51 +00:00
7c585b11f9  Update (base update) [ghstack-poisoned]  2025-10-30 13:05:51 +00:00
77330f39e4  Update [ghstack-poisoned]  2025-10-29 14:33:47 +00:00
eafb84aebd  Update (base update) [ghstack-poisoned]  2025-10-29 14:33:47 +00:00
bf5e9cc835  Update [ghstack-poisoned]  2025-10-29 13:16:06 +00:00
83370ee71f  Update (base update) [ghstack-poisoned]  2025-10-29 13:16:06 +00:00
8e2e74b12a  Update [ghstack-poisoned]  2025-10-29 10:50:26 +00:00
dbb5565e17  Update (base update) [ghstack-poisoned]  2025-10-29 10:50:26 +00:00
3a41807da9  Update [ghstack-poisoned]  2025-10-29 10:45:32 +00:00
670828a5bb  Update (base update) [ghstack-poisoned]  2025-10-29 10:45:32 +00:00
b0c4e5ce92  Update [ghstack-poisoned]  2025-10-28 16:04:05 +00:00
23d74eb617  Update (base update) [ghstack-poisoned]  2025-10-28 16:04:05 +00:00
3cc8b64300  Update [ghstack-poisoned]  2025-10-28 15:45:49 +00:00
5282147127  Update (base update) [ghstack-poisoned]  2025-10-28 15:45:49 +00:00
367f40a7e0  Update [ghstack-poisoned]  2025-10-28 15:39:23 +00:00
c67c516653  Update (base update) [ghstack-poisoned]  2025-10-28 15:39:23 +00:00
a451258d9c  Update [ghstack-poisoned]  2025-10-28 15:20:44 +00:00
ef90141cbf  Update (base update) [ghstack-poisoned]  2025-10-28 15:20:44 +00:00
5d84e13851  Update [ghstack-poisoned]  2025-10-28 15:12:10 +00:00
ba63727c2e  Update (base update) [ghstack-poisoned]  2025-10-28 15:12:10 +00:00
02a255ae8e  Update [ghstack-poisoned]  2025-10-28 15:08:06 +00:00
329d47c055  Update (base update) [ghstack-poisoned]  2025-10-28 15:08:06 +00:00
295a042e39  Update [ghstack-poisoned]  2025-10-28 14:49:48 +00:00
82e7131068  Update (base update) [ghstack-poisoned]  2025-10-28 14:49:48 +00:00
b61a11e8d9  Update [ghstack-poisoned]  2025-10-28 14:07:59 +00:00
00a615d1e2  Update (base update) [ghstack-poisoned]  2025-10-28 14:07:59 +00:00
9f582f55af  Update [ghstack-poisoned]  2025-10-28 13:58:07 +00:00
c4fc3c53e1  Update (base update) [ghstack-poisoned]  2025-10-28 13:58:07 +00:00
cd82b0f7d9  Update [ghstack-poisoned]  2025-10-28 13:44:50 +00:00
139222da06  Update (base update) [ghstack-poisoned]  2025-10-28 13:44:50 +00:00
2450d02e97  Update [ghstack-poisoned]  2025-10-28 12:02:22 +00:00
0820b97e78  Update (base update) [ghstack-poisoned]  2025-10-28 12:02:22 +00:00
b421538f59  Update [ghstack-poisoned]  2025-10-28 11:48:27 +00:00
3efbfb3f6f  Update (base update) [ghstack-poisoned]  2025-10-28 11:48:27 +00:00
2325197448  Update [ghstack-poisoned]  2025-10-27 17:15:01 +00:00
a849ab3e44  Update (base update) [ghstack-poisoned]  2025-10-27 17:15:01 +00:00
994fe49902  Update [ghstack-poisoned]  2025-10-27 15:12:21 +00:00
bc85bf7ed1  Update (base update) [ghstack-poisoned]  2025-10-27 15:12:21 +00:00
ba1fe373be  Update [ghstack-poisoned]  2025-10-27 15:00:12 +00:00
28a754f37b  Update (base update) [ghstack-poisoned]  2025-10-27 15:00:12 +00:00
cc7c1c81f6  Update [ghstack-poisoned]  2025-10-27 12:38:28 +00:00
f78c4dee42  Update (base update) [ghstack-poisoned]  2025-10-27 12:38:28 +00:00
2514f9d62f  Update [ghstack-poisoned]  2025-10-27 12:28:17 +00:00
51122c815f  Update (base update) [ghstack-poisoned]  2025-10-27 12:28:17 +00:00
d7fd08839f  Update [ghstack-poisoned]  2025-10-27 12:04:23 +00:00
8f73b7cb35  Update (base update) [ghstack-poisoned]  2025-10-27 12:04:23 +00:00
c587a960fb  Update [ghstack-poisoned]  2025-10-27 11:48:10 +00:00
25d8411fb5  Update (base update) [ghstack-poisoned]  2025-10-27 11:48:10 +00:00
eabecd05c5  Update [ghstack-poisoned]  2025-10-27 11:41:21 +00:00
edca2c8698  Update (base update) [ghstack-poisoned]  2025-10-27 11:28:48 +00:00
f2de9313f4  Update [ghstack-poisoned]  2025-10-27 11:28:48 +00:00
3ac256e289  Update (base update) [ghstack-poisoned]  2025-10-24 16:55:04 +00:00
94055d73d4  Update [ghstack-poisoned]  2025-10-24 16:55:04 +00:00
dc09f97271  Update (base update) [ghstack-poisoned]  2025-10-24 16:48:45 +00:00
841b8a27b8  Update [ghstack-poisoned]  2025-10-24 16:48:45 +00:00
19 changed files with 158 additions and 681 deletions

View File

@ -1,10 +1,9 @@
name: inductor-rocm
on:
schedule:
- cron: 0 * * * *
push:
branches:
- main
- release/*
tags:
- ciflow/inductor-rocm/*

View File

@ -41,7 +41,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
build-environment: linux-jammy-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11
secrets: inherit

View File

@ -66,10 +66,10 @@ jobs:
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
]}
secrets: inherit
@ -167,8 +167,8 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-onnx
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
{ config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
]}
secrets: inherit

View File

@ -3,13 +3,13 @@ name: rocm
on:
push:
branches:
- main
- release/*
tags:
- ciflow/rocm/*
workflow_dispatch:
schedule:
- cron: 29 8 * * * # about 1:29am PDT
- cron: 0 * * * *
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}

View File

@ -1,6 +1,6 @@
#include <ATen/cuda/CUDAGreenContext.h>
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <stdexcept>
#include <vector>

View File

@ -2917,7 +2917,9 @@ static Tensor& linalg_eig_make_complex_eigenvectors(Tensor& complex_vectors, con
DEFINE_DISPATCH(linalg_eig_stub);
static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Tensor& values, Tensor& vectors, Tensor& infos, bool compute_eigenvectors) {
auto options = input.options();
// MAGMA doesn't have a GPU interface for the GEEV routine; it requires inputs to be on CPU,
// so we create all intermediate tensors on CPU
auto options = input.options().device(at::kCPU);
// These internal asserts make the implementation's assumptions explicit
// Error checks with the actual error messages are done at a higher level of the call hierarchy
@ -2926,13 +2928,16 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
// for real-valued 'input', eigenvalues can be real-valued or complex-valued
TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == values.scalar_type()) || (input.scalar_type() == values.scalar_type()));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU);
// for real-valued 'input', eigenvectors can be real-valued or complex-valued
if (compute_eigenvectors) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == vectors.scalar_type()) || (input.scalar_type() == vectors.scalar_type()));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.device() == at::kCPU);
}
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.scalar_type() == at::kInt);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == at::kCPU);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.numel() == std::max<int64_t>(1, batchCount(input)));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_contiguous());
@ -2981,7 +2986,15 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
}
}
linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
// MAGMA uses a hybrid CPU-GPU algorithm that performs well only for large matrices
// See: https://github.com/pytorch/pytorch/pull/52491#issuecomment-795685687
// Here we take the CPU path for matrices of size up to 2048x2048,
// which should in general be significantly faster than calling MAGMA
if (input.size(-1) <= 2048) {
linalg_eig_stub(at::kCPU, real_imag_values, maybe_complex_vectors, infos, input.to(kCPU), compute_eigenvectors);
} else {
linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
}
// if input is not complex we need to do some post-processing
if (!input.is_complex()) {
@ -3006,14 +3019,7 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
}
if (compute_eigenvectors) {
if (vectors.is_complex()) {
// We move to the CPU because linalg_eig_make_complex_eigenvectors requires it.
// Performance note: this function could be implemented via a TensorIterator,
// which would avoid an explicit host-device synchronization.
auto vectors_cpu = vectors.cpu();
auto values_cpu = values.cpu();
auto maybe_complex_vectors_cpu = maybe_complex_vectors.cpu();
vectors_cpu = linalg_eig_make_complex_eigenvectors(vectors_cpu, values_cpu, maybe_complex_vectors_cpu);
vectors.copy_(vectors_cpu);
vectors = linalg_eig_make_complex_eigenvectors(vectors, values, maybe_complex_vectors);
} else {
TORCH_CHECK(false, "torch.linalg.eig: imaginary part of eigenvectors is non-zero, can't safely cast eigenvectors to non-complex dtype.")
}
@ -3033,7 +3039,8 @@ std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values,
checkSameDevice("torch.linalg.eig", values, input, "eigenvalues");
checkSameDevice("torch.linalg.eig", vectors, input, "eigenvectors");
auto options = input.options();
// MAGMA doesn't have a GPU interface for the GEEV routine; it requires inputs to be on CPU
auto options = input.options().device(at::kCPU);
auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
// if result is not empty and not in batched column major format we have to allocate a temporary tensor
@ -3122,7 +3129,8 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
checkLinalgCompatibleDtype("torch.linalg.eigvals", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
checkSameDevice("torch.linalg.eigvals", values, input, "eigenvalues");
auto options = input.options();
// MAGMA doesn't have a GPU interface for the GEEV routine; it requires inputs to be on CPU
auto options = input.options().device(at::kCPU);
auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
bool values_expected_type = (values.scalar_type() == toComplexType(input.scalar_type()));
@ -3151,7 +3159,6 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
}
Tensor vectors;
vectors = at::empty({0}, input.options());
if (values_tmp_needed) {
Tensor values_tmp = at::empty({0}, options.dtype(values_type));
std::tie(values_tmp, std::ignore) = linalg_eig_out_info(input, values_tmp, vectors, infos, /*compute_eigenvectors=*/false);
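
The hunks above concern where torch.linalg.eig stages its intermediates (MAGMA's GEEV has no GPU interface and needs CPU inputs) and when the CPU path is preferred over MAGMA. As a user-level sanity check, here is a minimal sketch, separate from the diff, that runs torch.linalg.eig and verifies the defining relation A V = V diag(w); the 512x512 size and the device choice are arbitrary. Whichever backend computes underneath, the eigenpairs come back on the input's device.

import torch

# Arbitrary example: a real 512x512 matrix on whichever device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.randn(512, 512, dtype=torch.float64, device=device)

w, V = torch.linalg.eig(A)  # complex eigenvalues and eigenvectors
assert w.device == A.device and V.device == A.device

# Check A @ V == V @ diag(w); broadcasting w scales the columns of V.
residual = A.to(V.dtype) @ V - V * w
print(residual.abs().max())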

View File

@ -1881,8 +1881,6 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) {
REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename scalar_t>
static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
#if !AT_MAGMA_ENABLED()
@ -1957,6 +1955,8 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const
#endif
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This is a type dispatch function for 'apply_magma_eigh'
// For small inputs result is computed on CPU
void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
@ -2019,10 +2019,10 @@ This is an in-place routine, content of 'input', 'values', 'vectors' is overwrit
For more information see MAGMA's documentation for GEEV routine.
*/
template <typename scalar_t>
void apply_magma_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
#if !AT_MAGMA_ENABLED()
TORCH_CHECK(false, "Calling torch.linalg.eig with MAGMA requires compiling PyTorch with MAGMA. "
"Either transfer the tensor to the CPU before calling torch.linalg.eig or use cuSolver.");
TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling PyTorch with MAGMA. "
"Either transfer the tensor to the CPU before calling torch.linalg.eig or recompile with MAGMA.");
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == at::kCPU);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU);
@ -2076,44 +2076,22 @@ TORCH_CHECK(false, "Calling torch.linalg.eig with MAGMA requires compiling PyTor
#endif
}
// MAGMA wrapper: transfers tensors to CPU, calls apply_magma_eig, then copies results back.
void linalg_eig_magma(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors){
// MAGMA doesn't have GPU interface for the eigendecomposition, and it forces us to transfer to CPU
auto eigenvalues_cpu = eigenvalues.cpu();
auto eigenvectors_cpu = eigenvectors.cpu();
auto infos_cpu = infos.cpu();
Tensor input_cpu = at::empty(input.sizes(), input.options().device(kCPU));
input_cpu.transpose_(-2, -1);
input_cpu.copy_(input);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cuda", [&]{
apply_magma_eig<scalar_t>(eigenvalues_cpu, eigenvectors_cpu, input_cpu, infos_cpu, compute_eigenvectors);
});
eigenvalues.copy_(eigenvalues_cpu);
eigenvectors.copy_(eigenvectors_cpu);
infos.copy_(infos_cpu);
}
// This is a type dispatching helper function for 'apply_linalg_eig'
void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
// This function calculates the non-symmetric eigendecomposition in-place
// tensors should be in batched column major memory format
// the content of eigenvalues, eigenvectors and infos is overwritten by 'linalg_eig_magma' or
// 'linalg_eig_cusolver_xgeev' both geev routines modify the provided input matrix in-place, therefore we need a copy
// the content of eigenvalues, eigenvectors and infos is overwritten by 'apply_linalg_eig'
// apply_linalg_eig modifies the provided input matrix in-place, therefore we need a copy
// MAGMA doesn't have GPU interface for the eigendecomposition and it forces us to transfer 'input' to CPU
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
auto preferred_backend = at::globalContext().linalgPreferredBackend();
switch (preferred_backend) {
case at::LinalgBackend::Cusolver:
default:
linalg_eig_cusolver_xgeev(eigenvalues, eigenvectors, input, infos, compute_eigenvectors);
return;
case at::LinalgBackend::Magma:
break; // MAGMA path handled below
}
#endif
linalg_eig_magma(eigenvalues, eigenvectors, infos, input, compute_eigenvectors);
Tensor input_working_copy = at::empty(input.sizes(), input.options().device(kCPU));
input_working_copy.transpose_(-2, -1); // make input_working_copy to have Fortran contiguous memory layout
input_working_copy.copy_(input);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cuda", [&]{
apply_linalg_eig<scalar_t>(eigenvalues, eigenvectors, input_working_copy, infos, compute_eigenvectors);
});
}
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
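
linalg_eig_kernel above picks a path based on the globally preferred linalg library. For orientation, here is a small sketch, separate from the diff, of the Python-level switch that corresponds to at::globalContext().linalgPreferredBackend(); the 64x64 size is arbitrary and "magma" is only an example argument.

import torch

# Query the current preference; returns a _LinalgBackend enum value.
print(torch.backends.cuda.preferred_linalg_library())

# Request MAGMA ("cusolver" and "default" are the other accepted strings).
# PyTorch may still fall back if the chosen library lacks a given routine.
torch.backends.cuda.preferred_linalg_library("magma")

if torch.cuda.is_available():
    A = torch.randn(64, 64, device="cuda")
    w, V = torch.linalg.eig(A)  # routed through the CUDA dispatch registered above
    print(w.shape, V.shape)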

View File

@ -1625,126 +1625,6 @@ void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors,
#endif
}
// cuSOLVER Xgeev (requires cuSOLVER >= 11.7.2, i.e. CUDA 12.8+)
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
template <typename scalar_t>
void apply_xgeev(const Tensor& values, const Tensor& vectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_cuda());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.is_cuda());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_cuda());
int n = cuda_int_cast(input.size(-1), "n");
int lda = std::max<int>(1, n);
auto batch_size = batchCount(input);
if (n == 0 || batch_size == 0) {
// XGeev crashes on empty input, explicitly handle empty input
auto values_shape = IntArrayRef(input.sizes().data(), input.dim() - 1);
values.resize_(values_shape, MemoryFormat::Contiguous);
values.zero_();
if (compute_eigenvectors) {
vectors.resize_(input.sizes(), MemoryFormat::Contiguous);
vectors.zero_();
} else {
vectors.resize_({0});
}
infos.resize_({std::max<int64_t>(1, batch_size)}, MemoryFormat::Contiguous);
infos.zero_();
return;
}
int64_t vectors_stride = 0;
if (compute_eigenvectors){
vectors_stride = matrixStride(vectors);
}
auto values_stride = values.size(-1);
auto vectors_data = vectors.data_ptr<scalar_t>();
auto values_data = values.data_ptr<scalar_t>();
auto infos_data = infos.data_ptr<int>();
cusolverDnParams_t params = nullptr;
TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(&params));
Tensor A_fortran = input.mT().contiguous();
auto* A_data = A_fortran.data_ptr<scalar_t>();
const auto A_stride = matrixStride(A_fortran);
auto handle = at::cuda::getCurrentCUDASolverDnHandle();
const int ldvl = 1; // ldvl >= 1 if jobvl = CUSOLVER_EIG_MODE_NOVECTOR
cusolverEigMode_t jobvl = CUSOLVER_EIG_MODE_NOVECTOR;
cusolverEigMode_t jobvr;
int ldvr;
if (compute_eigenvectors) {
ldvr = n; // ldvr >= n if jobvr = CUSOLVER_EIG_MODE_VECTOR
jobvr = CUSOLVER_EIG_MODE_VECTOR;
}
else {
ldvr = 1; // ldvr >= 1 if jobvr = CUSOLVER_EIG_MODE_NOVECTOR
jobvr = CUSOLVER_EIG_MODE_NOVECTOR;
}
scalar_t* W = values.data_ptr<scalar_t>();
scalar_t* VL = nullptr;
scalar_t* VR = vectors.data_ptr<scalar_t>();
const scalar_t* A_const = A_data;
const scalar_t* W_const = W;
const scalar_t* VL_const = VL;
const scalar_t* VR_const = VR;
size_t ws_dev = 0, ws_host = 0;
at::cuda::solver::xgeev_bufferSize<scalar_t>(
handle, params,
jobvl, jobvr,
n,
A_const, lda,
W_const,
VL_const, ldvl,
VR_const, ldvr,
&ws_dev, &ws_host);
auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
auto work_device_data = device_allocator.allocate(ws_dev);
// use pinned memory for best performance.
auto& host_allocator = *at::cuda::getPinnedMemoryAllocator();
auto work_host_data = host_allocator.allocate(ws_host);
for (decltype(batch_size) i = 0; i < batch_size; ++i) {
scalar_t* Ai = A_data + i * A_stride;
scalar_t* Wi = values_data + i * values_stride;
scalar_t* VLi = nullptr; // xgeev does not support computing left evs
scalar_t* VRi = compute_eigenvectors ? (vectors_data + i * vectors_stride) : nullptr;
int* info = infos_data + i;
at::cuda::solver::xgeev<scalar_t>(
handle, params,
jobvl, jobvr,
n,
Ai, lda,
Wi,
VLi, ldvl,
VRi, ldvr,
static_cast<scalar_t*>(work_device_data.get()), ws_dev,
static_cast<scalar_t*>(work_host_data.get()), ws_host,
info);
}
TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params));
}
void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eig_cuda", [&] {
apply_xgeev<scalar_t>(eigenvalues, eigenvectors, input, infos, compute_eigenvectors);
});
}
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
// The 'apply_' prefix is used for functions templated by dtype that call an API routine
// underneath. Since the cuSOLVER API has a slightly different structure, we do not prepend
// 'apply_' to this function.

View File

@ -73,11 +73,6 @@ void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other,
Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau);
void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors);
void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors);
void lu_solve_looped_cusolver(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose);
void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);

View File

@ -1954,336 +1954,6 @@ void xsyevd<c10::complex<double>, double>(
workspaceInBytesOnHost,
info));
}
// cuSOLVER Xgeev bindings (requires cuSOLVER >= 11.7.2, i.e. CUDA 12.8+)
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
template <>
void xgeev_bufferSize<float>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
const float* A,
int64_t lda,
const float* W,
const float* VL,
int64_t ldvl,
const float* VR,
int64_t ldvr,
size_t* workspaceInBytesOnDevice,
size_t* workspaceInBytesOnHost) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
handle, params, jobvl, jobvr, n,
CUDA_R_32F,
reinterpret_cast<const void*>(A),
lda,
CUDA_R_32F,
reinterpret_cast<const void*>(W),
CUDA_R_32F,
reinterpret_cast<const void*>(VL),
ldvl,
CUDA_R_32F,
reinterpret_cast<const void*>(VR),
ldvr,
CUDA_R_32F,
workspaceInBytesOnDevice,
workspaceInBytesOnHost));
}
template <>
void xgeev_bufferSize<double>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
const double* A,
int64_t lda,
const double* W,
const double* VL,
int64_t ldvl,
const double* VR,
int64_t ldvr,
size_t* workspaceInBytesOnDevice,
size_t* workspaceInBytesOnHost) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
handle, params, jobvl, jobvr, n,
CUDA_R_64F,
reinterpret_cast<const void*>(A),
lda,
CUDA_R_64F,
reinterpret_cast<const void*>(W),
CUDA_R_64F,
reinterpret_cast<const void*>(VL),
ldvl,
CUDA_R_64F,
reinterpret_cast<const void*>(VR),
ldvr,
CUDA_R_64F,
workspaceInBytesOnDevice,
workspaceInBytesOnHost));
}
template <>
void xgeev_bufferSize<c10::complex<float>>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
const c10::complex<float>* A,
int64_t lda,
const c10::complex<float>* W,
const c10::complex<float>* VL,
int64_t ldvl,
const c10::complex<float>* VR,
int64_t ldvr,
size_t* workspaceInBytesOnDevice,
size_t* workspaceInBytesOnHost) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
handle, params, jobvl, jobvr, n,
CUDA_C_32F,
reinterpret_cast<const void*>(A),
lda,
CUDA_C_32F,
reinterpret_cast<const void*>(W),
CUDA_C_32F,
reinterpret_cast<const void*>(VL),
ldvl,
CUDA_C_32F,
reinterpret_cast<const void*>(VR),
ldvr,
CUDA_C_32F,
workspaceInBytesOnDevice,
workspaceInBytesOnHost));
}
template <>
void xgeev_bufferSize<c10::complex<double>>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
const c10::complex<double>* A,
int64_t lda,
const c10::complex<double>* W,
const c10::complex<double>* VL,
int64_t ldvl,
const c10::complex<double>* VR,
int64_t ldvr,
size_t* workspaceInBytesOnDevice,
size_t* workspaceInBytesOnHost) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
handle, params, jobvl, jobvr, n,
CUDA_C_64F,
reinterpret_cast<const void*>(A),
lda,
CUDA_C_64F,
reinterpret_cast<const void*>(W),
CUDA_C_64F,
reinterpret_cast<const void*>(VL),
ldvl,
CUDA_C_64F,
reinterpret_cast<const void*>(VR),
ldvr,
CUDA_C_64F,
workspaceInBytesOnDevice,
workspaceInBytesOnHost));
}
template <>
void xgeev<float>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
float* A,
int64_t lda,
float* W,
float* VL,
int64_t ldvl,
float* VR,
int64_t ldvr,
float* bufferOnDevice,
size_t workspaceInBytesOnDevice,
float* bufferOnHost,
size_t workspaceInBytesOnHost,
int* info) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
handle,
params,
jobvl,
jobvr,
n,
CUDA_R_32F,
reinterpret_cast<void*>(A),
lda,
CUDA_R_32F,
reinterpret_cast<void*>(W),
CUDA_R_32F,
reinterpret_cast<void*>(VL),
ldvl,
CUDA_R_32F,
reinterpret_cast<void*>(VR),
ldvr,
CUDA_R_32F,
reinterpret_cast<void*>(bufferOnDevice),
workspaceInBytesOnDevice,
reinterpret_cast<void*>(bufferOnHost),
workspaceInBytesOnHost,
info));
}
template <>
void xgeev<double>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
double* A,
int64_t lda,
double* W,
double* VL,
int64_t ldvl,
double* VR,
int64_t ldvr,
double* bufferOnDevice,
size_t workspaceInBytesOnDevice,
double* bufferOnHost,
size_t workspaceInBytesOnHost,
int* info) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
handle,
params,
jobvl,
jobvr,
n,
CUDA_R_64F,
reinterpret_cast<void*>(A),
lda,
CUDA_R_64F,
reinterpret_cast<void*>(W),
CUDA_R_64F,
reinterpret_cast<void*>(VL),
ldvl,
CUDA_R_64F,
reinterpret_cast<void*>(VR),
ldvr,
CUDA_R_64F,
reinterpret_cast<void*>(bufferOnDevice),
workspaceInBytesOnDevice,
reinterpret_cast<void*>(bufferOnHost),
workspaceInBytesOnHost,
info));
}
template <>
void xgeev<c10::complex<float>>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
c10::complex<float>* A,
int64_t lda,
c10::complex<float>* W,
c10::complex<float>* VL,
int64_t ldvl,
c10::complex<float>* VR,
int64_t ldvr,
c10::complex<float>* bufferOnDevice,
size_t workspaceInBytesOnDevice,
c10::complex<float>* bufferOnHost,
size_t workspaceInBytesOnHost,
int* info) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
handle,
params,
jobvl,
jobvr,
n,
CUDA_C_32F,
reinterpret_cast<void*>(A),
lda,
CUDA_C_32F,
reinterpret_cast<void*>(W),
CUDA_C_32F,
reinterpret_cast<void*>(VL),
ldvl,
CUDA_C_32F,
reinterpret_cast<void*>(VR),
ldvr,
CUDA_C_32F,
reinterpret_cast<void*>(bufferOnDevice),
workspaceInBytesOnDevice,
reinterpret_cast<void*>(bufferOnHost),
workspaceInBytesOnHost,
info));
}
template <>
void xgeev<c10::complex<double>>(
cusolverDnHandle_t handle,
cusolverDnParams_t params,
cusolverEigMode_t jobvl,
cusolverEigMode_t jobvr,
int64_t n,
c10::complex<double>* A,
int64_t lda,
c10::complex<double>* W,
c10::complex<double>* VL,
int64_t ldvl,
c10::complex<double>* VR,
int64_t ldvr,
c10::complex<double>* bufferOnDevice,
size_t workspaceInBytesOnDevice,
c10::complex<double>* bufferOnHost,
size_t workspaceInBytesOnHost,
int* info) {
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
handle,
params,
jobvl,
jobvr,
n,
CUDA_C_64F,
reinterpret_cast<void*>(A),
lda,
CUDA_C_64F,
reinterpret_cast<void*>(W),
CUDA_C_64F,
reinterpret_cast<void*>(VL),
ldvl,
CUDA_C_64F,
reinterpret_cast<void*>(VR),
ldvr,
CUDA_C_64F,
reinterpret_cast<void*>(bufferOnDevice),
workspaceInBytesOnDevice,
reinterpret_cast<void*>(bufferOnHost),
workspaceInBytesOnHost,
info));
}
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
#endif // USE_CUSOLVER_64_BIT
#ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED

View File

@ -674,66 +674,6 @@ template <>
void xsyevd<c10::complex<double>, double>(
CUDASOLVER_XSYEVD_ARGTYPES(c10::complex<double>, double));
// cuSOLVER Xgeev (non-Hermitian eigen decomposition, CUDA >= 12.8)
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
#define CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(scalar_t) \
cusolverDnHandle_t handle, cusolverDnParams_t params, \
cusolverEigMode_t jobvl, cusolverEigMode_t jobvr, int64_t n, \
const scalar_t* A, int64_t lda, const scalar_t* W, \
const scalar_t* VL, int64_t ldvl, const scalar_t* VR, int64_t ldvr, \
size_t* workspaceInBytesOnDevice, size_t* workspaceInBytesOnHost
template <class scalar_t>
void xgeev_bufferSize(
CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(scalar_t)) {
static_assert(false&&sizeof(scalar_t),
"at::cuda::solver::xgeev_bufferSize: not implemented");
}
template <>
void xgeev_bufferSize<float>(CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(float));
template <>
void xgeev_bufferSize<double>(CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(double));
template <>
void xgeev_bufferSize<c10::complex<float>>(
CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(c10::complex<float>));
template <>
void xgeev_bufferSize<c10::complex<double>>(
CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(c10::complex<double>));
#define CUDASOLVER_XGEEV_ARGTYPES(scalar_t) \
cusolverDnHandle_t handle, cusolverDnParams_t params, \
cusolverEigMode_t jobvl, cusolverEigMode_t jobvr, int64_t n, scalar_t *A, \
int64_t lda, scalar_t *W, scalar_t *VL, int64_t ldvl, scalar_t *VR, int64_t ldvr,\
scalar_t *bufferOnDevice, size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost,\
size_t workspaceInBytesOnHost, int *info
template <class scalar_t>
void xgeev(CUDASOLVER_XGEEV_ARGTYPES(scalar_t)) {
static_assert(false&&sizeof(scalar_t),
"at::cuda::solver::xgeev: not implemented");
}
template <>
void xgeev<float>(CUDASOLVER_XGEEV_ARGTYPES(float));
template <>
void xgeev<double>(CUDASOLVER_XGEEV_ARGTYPES(double));
template <>
void xgeev<c10::complex<float>>(CUDASOLVER_XGEEV_ARGTYPES(c10::complex<float>));
template <>
void xgeev<c10::complex<double>>(CUDASOLVER_XGEEV_ARGTYPES(c10::complex<double>));
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
#endif // USE_CUSOLVER_64_BIT
#ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED

View File

@ -794,16 +794,14 @@ class TestFP8Lowering(TestCase):
_get_torch_cuda_version() < (12, 9),
"cuBLAS blockwise scaling added in CUDA 12.9",
)
@parametrize("shape", ((16, 256, 256), (1024, 512, 1024)))
@parametrize("use_fast_accum", (False, True))
@parametrize(
"scaling_block_sizes", ((1, 128, 128, 128), (1, 128, 1, 128))
) # (BlockWise1x128, BlockWise128x128), (BlockWise1x128, BlockWise1x128)
def test_main_loop_scaling(
"shape", ((16, 256, 256), (1024, 512, 1024))
) # TODO (jananisriram): add scaling recipe overrides for shapes like (16, 256, 64) and (256, 16, 64)
@parametrize("use_fast_accum", (False, True))
def test_blockwise1x128_blockwise128x128_scaling(
self,
shape: tuple[int, int, int],
use_fast_accum: bool,
scaling_block_sizes: tuple[int, int, int, int],
):
# Only bf16 output type is supported for non-tensorwise scaling, not fp32
dtype: torch.dtype = torch.bfloat16
@ -816,28 +814,20 @@ class TestFP8Lowering(TestCase):
w = torch.randn(N, K, dtype=dtype, device=device)
bias = None
am, ak, bn, bk = scaling_block_sizes
# quantize weight (prior to inference)
w_fp8, w_inverse_scale = _quantize_blockwise(
w, dtype_float8, block_outer=bn, block_inner=bk
w, dtype_float8, block_outer=128, block_inner=128
)
w_t_fp8 = w_fp8.t()
if (bn, bk) == (1, 128):
w_inverse_scale = (
w_inverse_scale.t().contiguous().t().t()
) # 1x128 blocks need scales to be outer-dim-major
else:
w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N)
w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N)
# quantize input x
x_fp8, x_inverse_scale = _quantize_blockwise(
x, dtype_float8, block_outer=am, block_inner=ak
x, dtype_float8, block_outer=1, block_inner=128
)
if (am, ak) == (1, 128):
x_inverse_scale = (
x_inverse_scale.t().contiguous().t()
) # 1x128 blocks need scales to be outer-dim-major
x_inverse_scale = (
x_inverse_scale.t().contiguous().t()
) # 1x128 blocks need scales to be outer-dim-major
def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
y = torch._scaled_mm(
@ -882,15 +872,9 @@ class TestFP8Lowering(TestCase):
FileCheck().check(
f"SCALE_RECIPE_A : tl.constexpr = {ScalingType.BlockWise1x128.value}"
).run(code[0])
if (bn, bk) == (1, 128):
check_scale_recipe_b = ScalingType.BlockWise1x128.value
else:
check_scale_recipe_b = ScalingType.BlockWise128x128.value
FileCheck().check(
f"SCALE_RECIPE_B : tl.constexpr = {check_scale_recipe_b}"
f"SCALE_RECIPE_B : tl.constexpr = {ScalingType.BlockWise128x128.value}"
).run(code[0])
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
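
The test above feeds torch._scaled_mm a BlockWise1x128-scaled activation and a BlockWise128x128-scaled weight, and reshuffles the 1x128 scales to be outer-dim-major before the call. The snippet below is a rough standalone sketch of that layout trick only (it does not reproduce the test's _quantize_blockwise helper); the M and K sizes are illustrative.

import torch

# One scale per 1x128 strip along K gives an (M, ceil(K / 128)) scale tensor.
M, K = 1024, 512
x_scale = torch.rand(M, (K + 127) // 128)

# The test makes 1x128 scales outer-dim-major without changing any values:
# transpose, materialize contiguously, transpose back. Only the strides change.
x_scale_major = x_scale.t().contiguous().t()

print(x_scale.stride(), x_scale_major.stride())  # (4, 1) vs (1, 1024) for these sizes
assert torch.equal(x_scale, x_scale_major)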

View File

@ -553,7 +553,7 @@ class TestPatternMatcher(TestCase):
torch.randn(16, 16, device=GPU_TYPE),
torch.randn(16, 16, device=GPU_TYPE),
torch.randn(16, 16, device=GPU_TYPE),
True,
False,
),
(
torch.randn(8, device=GPU_TYPE),
@ -687,17 +687,20 @@ class TestPatternMatcher(TestCase):
FileCheck().check("call").check_not(".run").run(code[0])
def test_cat_addmm(self):
def fn(a, b, c):
def fn(b1, b2, b3, mat1, mat2, mat3):
return torch.cat(
[
torch.addmm(a, b, c),
torch.addmm(b, c, a),
torch.addmm(c, a, b),
torch.addmm(b1, mat1, mat2),
torch.addmm(b2, mat1, mat3),
torch.addmm(b3, mat2, mat3),
],
1,
)
args = [
torch.randn(16, device=GPU_TYPE),
torch.randn(16, device=GPU_TYPE),
torch.randn(16, device=GPU_TYPE),
torch.randn(16, 16, device=GPU_TYPE),
torch.randn(16, 16, device=GPU_TYPE),
torch.randn(16, 16, device=GPU_TYPE),
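
For readers skimming the rewritten test_cat_addmm above, a quick shape sketch with the same sizes (CPU tensors here so the snippet runs anywhere): each torch.addmm broadcasts a length-16 bias over a (16, 16) matmul, and concatenating the three results along dim 1 yields a (16, 48) tensor.

import torch

b1, b2, b3 = (torch.randn(16) for _ in range(3))
mat1, mat2, mat3 = (torch.randn(16, 16) for _ in range(3))

out = torch.cat(
    [
        torch.addmm(b1, mat1, mat2),  # (16,) bias broadcast over (16, 16) @ (16, 16)
        torch.addmm(b2, mat1, mat3),
        torch.addmm(b3, mat2, mat3),
    ],
    dim=1,
)
print(out.shape)  # torch.Size([16, 48])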

View File

@ -1693,7 +1693,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
from torch._higher_order_ops.utils import _maybe_fake_tracing
from torch._inductor.utils import is_pointwise_use
from torch._inductor.utils import has_only_pointwise_uses
with tx.fake_mode:
sub_args_fake = [
@ -1712,9 +1712,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
for node in fx.graph.nodes:
# Check that the combine_fn is pointwise, if combine_mode='pointwise'
if not all(
is_pointwise_use(use) or use.op == "output" for use in node.users
):
if not has_only_pointwise_uses(node, select_output=True):
raise RuntimeError(
"For combine_mode='pointwise', the combine_fn needs to be pointwise"
)

View File

@ -1505,15 +1505,45 @@ def view_to_reshape(gm):
nd.target = torch.ops.aten.reshape.default
# Relevant for addmm and (add + mm)/(mm + add)
# Follows the dispatch logic for cuBLASLt at
# aten/src/ATen/native/cuda/Blas.cpp::isInputCompliesAddmmCudaLt
def _cublaslt_can_fuse_bias_epilogue(inp, mat1, mat2):
if config.max_autotune_gemm:
return False
# match the dispatch logic for cuBLASLT at aten/src/ATen/native/cuda/Blas.cpp
if not (
inp.is_cuda
and (inp.dim() == 1 or inp.squeeze().dim() == 1)
and inp.is_contiguous()
):
return False
if not (mat1.dim() == 2 and mat2.dim() == 2):
return False
if inp.size(0) != mat2.size(1):
return False
if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
return False
return True
def should_prefer_unfused_addmm(match):
inp = match.kwargs["inp"]
if not is_gpu(inp.meta["val"].device.type):
return False
return has_uses_tagged_as(
match.output_node(),
(torch.Tag.pointwise, torch.Tag.reduction),
)
if has_uses_tagged_as(
match.output_node(), (torch.Tag.pointwise, torch.Tag.reduction)
):
return True
else:
args_val = (arg.meta["val"] for arg in (inp, *match.args))
return not _cublaslt_can_fuse_bias_epilogue(*args_val)
@register_graph_pattern(
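
The should_prefer_unfused_addmm variant shown above prefers the unfused form when a pointwise/reduction consumer exists, or when cuBLASLt could not fuse the bias epilogue anyway. Below is a standalone paraphrase of that eligibility test with plain tensors instead of FX metadata; the function name and the example sizes are made up, and the config.max_autotune_gemm early exit is omitted.

import torch

def cublaslt_bias_epilogue_ok(inp: torch.Tensor,
                              mat1: torch.Tensor,
                              mat2: torch.Tensor) -> bool:
    # Mirrors _cublaslt_can_fuse_bias_epilogue: a contiguous, effectively 1-D CUDA
    # bias whose length matches mat2's second dim, with all operands sharing a dtype.
    if not (inp.is_cuda
            and (inp.dim() == 1 or inp.squeeze().dim() == 1)
            and inp.is_contiguous()):
        return False
    if mat1.dim() != 2 or mat2.dim() != 2:
        return False
    if inp.size(0) != mat2.size(1):
        return False
    return inp.dtype == mat1.dtype and inp.dtype == mat2.dtype

device = "cuda" if torch.cuda.is_available() else "cpu"
mat1 = torch.randn(32, 64, device=device)
mat2 = torch.randn(64, 16, device=device)
bias = torch.randn(16, device=device)
print(cublaslt_bias_epilogue_ok(bias, mat1, mat2))  # True on CUDA, False on CPU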

View File

@ -1545,7 +1545,6 @@ scaling_pairs = [
(ScalingType.TensorWise, ScalingType.TensorWise),
(ScalingType.RowWise, ScalingType.RowWise),
(ScalingType.BlockWise1x128, ScalingType.BlockWise128x128),
(ScalingType.BlockWise1x128, ScalingType.BlockWise1x128),
]
@ -1564,15 +1563,11 @@ def _is_rowwise_scaling(sz: Any, transpose: bool) -> bool:
return V.graph.sizevars.statically_known_equals(sz[idx], 1)
def _is_blockwise1xTILESIZE_scaling(
sz: Any, tensor_sz: Any, tile_size: int, transpose: bool
) -> bool:
lhs = 1 if transpose else 0
rhs = 0 if transpose else 1
def _is_blockwise1xTILESIZE_scaling(sz: Any, tensor_sz: Any, tile_size: int) -> bool:
return V.graph.sizevars.statically_known_equals(
sz[lhs], tensor_sz[lhs]
sz[0], tensor_sz[0]
) and V.graph.sizevars.statically_known_equals(
sz[rhs], ceildiv(tensor_sz[rhs], tile_size)
sz[1], ceildiv(tensor_sz[1], tile_size)
)
@ -1594,9 +1589,7 @@ def is_desired_scaling(
case ScalingType.RowWise:
return _is_rowwise_scaling(scale_size, transpose)
case ScalingType.BlockWise1x128:
return _is_blockwise1xTILESIZE_scaling(
scale_size, t.get_size(), 128, transpose
)
return _is_blockwise1xTILESIZE_scaling(scale_size, t.get_size(), 128)
case ScalingType.BlockWise128x128:
return _is_blockwise128x128_scaling(scale_size, t.get_size())
case _:

View File

@ -80,9 +80,9 @@ from .ir import (
from .utils import (
ceildiv,
decode_device,
has_only_pointwise_uses,
is_dynamic,
is_gpu,
is_pointwise_use,
is_view,
needs_fallback_due_to_atomic_add_limitations,
pad_listlike,
@ -1850,10 +1850,7 @@ def cat(inputs, dim=0):
(len(inputs) <= config.max_pointwise_cat_inputs)
and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
):
pointwise_uses = all(
is_pointwise_use(use, additional_pointwise_ops)
for use in V.current_node.users
)
pointwise_uses = has_only_pointwise_uses(V.current_node)
# fuse in case we will be used in a pointwise node, and there are any inputs
# we can prevent materialization of.
fuse_pointwise_use = (

View File

@ -529,30 +529,6 @@ def is_view(op: torch._ops.OpOverload) -> bool:
return any(a.alias_info is not None for a in op._schema.arguments)
def is_pointwise_use(
use: Node,
is_pointwise_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
) -> bool:
"""
Do all uses of this op have torch.Tag.pointwise, or return True for the optional `is_pointwise_fn`?
Uses in view ops are followed through to the views' own uses.
"""
if use.op != "call_function":
return False
if not (
isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
):
return False
target = cast(torch._ops.OpOverload, use.target)
if target is operator.getitem or is_view(target):
return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
class LogicalConnective(enum.Enum):
OR = enum.auto()
AND = enum.auto()
@ -562,6 +538,8 @@ def has_uses(
target: Node,
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
*,
select_output: bool = False,
) -> bool:
"""
Given a target, explore the uses of `target` by applying `use_selector_fn`
@ -585,6 +563,8 @@ def has_uses(
use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)
def has_uses_impl(use: Node) -> bool:
if select_output and use.op == "output":
return True
if use.op != "call_function":
return False
if not (
@ -603,17 +583,52 @@ def has_uses(
return use_aggregate_fn(has_uses_impl(user) for user in target.users)
def has_only_uses(
target: Node,
use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
*,
select_output: bool = False,
) -> bool:
return has_uses(
target, use_selector_fn, LogicalConnective.AND, select_output=select_output
)
def has_uses_tagged_as(
target: Node,
use_tags: Collection[torch.Tag],
use_aggregate_type: LogicalConnective = LogicalConnective.OR,
*,
select_output: bool = False,
) -> bool:
"""
Is there a use with given tags?
"""
return has_uses(
target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
target,
lambda use: any(tag in use_tags for tag in use.tags),
use_aggregate_type,
select_output=select_output,
)
def has_only_pointwise_uses(
target: Node,
*,
select_output: bool = False,
) -> bool:
"""
Do all uses of target have torch.Tag.pointwise?
Uses in view ops are followed through to the views' own uses.
"""
return has_uses_tagged_as(
target,
use_tags=(torch.Tag.pointwise,),
use_aggregate_type=LogicalConnective.AND,
select_output=select_output,
)
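
has_only_pointwise_uses above reduces "is every consumer of this node pointwise?" to a tag lookup on ATen OpOverloads. Here is a hedged illustration of that tag machinery outside Inductor, using make_fx to obtain an ATen-level graph; the traced function is arbitrary.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = torch.mm(x, x)        # aten.mm is not tagged pointwise
    return torch.relu(y + 1)  # aten.add / aten.relu are tagged pointwise

gm = make_fx(f)(torch.randn(4, 4))

for node in gm.graph.nodes:
    if node.op != "call_function":
        continue
    all_users_pointwise = all(
        isinstance(u.target, torch._ops.OpOverload)
        and torch.Tag.pointwise in u.target.tags
        for u in node.users
        if u.op == "call_function"
    )
    print(node.target, "-> all call_function users pointwise:", all_users_pointwise)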

View File

@ -6365,18 +6365,12 @@ def meta_scaled_mm(
n = mat2.size(1)
is_blockwise_scaling = (
(
scale_a.dtype == torch.float8_e8m0fnu
and scale_b.dtype == torch.float8_e8m0fnu
)
or (
scale_a.dtype == torch.float8_e4m3fn
and scale_b.dtype == torch.float8_e4m3fn
)
) # note: this applies to blockwise scaling for non-FP8 types (FP8 accepts FP32 scales)
def ceil_div(a, b):
return (a + b - 1) // b
scale_a.dtype == torch.float8_e8m0fnu
and scale_b.dtype == torch.float8_e8m0fnu
) or (
scale_a.dtype == torch.float8_e4m3fn
and scale_b.dtype == torch.float8_e4m3fn
)
if scale_a.numel() == 1 and scale_b.numel() == 1:
# tensorwise scaling
@ -6398,6 +6392,9 @@ def meta_scaled_mm(
block_size_mn = 128
def ceil_div(a, b):
return (a + b - 1) // b
num_k_blocks = ceil_div(_k, block_size_k)
padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4
@ -6453,18 +6450,11 @@ def meta_scaled_mm(
)
elif (
scale_a.size(0) == m
and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
and scale_b.size(1) == ceil_div(n, 128)
and scale_a.size(1) == scale_b.size(0) == (_k + 128 - 1) // 128
and scale_b.size(1) == (n + 128 - 1) // 128
):
# (BlockWise1x128, BlockWise128x128)
pass # do nothing, but do not error
elif (
scale_a.size(0) == m
and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
and scale_b.size(1) == n
):
# (BlockWise1x128, BlockWise1x128)
pass # do nothing, but do not error
else:
# does not match any valid scaling type
torch._check(
@ -6473,10 +6463,8 @@ def meta_scaled_mm(
"Invalid scaling configuration. "
"For tensorwise scaling, both scales should be scalar. "
f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). "
f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
+ f"scale_b should be ({ceil_div(_k, 128)}, {ceil_div(n, 128)}). "
f"For (BlockWise1x128, BlockWise1x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
+ f"scale_b should be ({ceil_div(_k, 128)}, {n}). "
f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({m}, {(_k + 128 - 1) // 128}), "
+ f"scale_b should be ({(_k + 128 - 1) // 128}, {(n + 128 - 1) // 128}). "
f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) "
f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})"
),
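
The meta function above accepts the (BlockWise1x128, BlockWise128x128) combination by checking the scale shapes against ceil(k / 128) and ceil(n / 128). A tiny arithmetic sketch of those expectations, with made-up GEMM sizes:

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

m, k, n = 1024, 512, 2048
expected_scale_a = (m, ceil_div(k, 128))                 # one scale per 1x128 strip of mat1
expected_scale_b = (ceil_div(k, 128), ceil_div(n, 128))  # one scale per 128x128 tile of mat2
print(expected_scale_a, expected_scale_b)                # (1024, 4) (4, 16)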