mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Compare commits
	
		
			70 Commits
		
	
	
		
			viable/str
			...
			ciflow/tru
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| acf8783a5b | |||
| 1946c2368f | |||
| 341514bf45 | |||
| 8120d6004e | |||
| d6222a3b82 | |||
| 18af57b76b | |||
| 379d5606b9 | |||
| 70518444a8 | |||
| 55a42d3a3e | |||
| 793a5e2f86 | |||
| e7a108c32b | |||
| dd5a8d3fc8 | |||
| 1d82e429d7 | |||
| 03e312c85e | |||
| 2540eaff4d | |||
| b1e44a9ff1 | |||
| 7c585b11f9 | |||
| 77330f39e4 | |||
| eafb84aebd | |||
| bf5e9cc835 | |||
| 83370ee71f | |||
| 8e2e74b12a | |||
| dbb5565e17 | |||
| 3a41807da9 | |||
| 670828a5bb | |||
| b0c4e5ce92 | |||
| 23d74eb617 | |||
| 3cc8b64300 | |||
| 5282147127 | |||
| 367f40a7e0 | |||
| c67c516653 | |||
| a451258d9c | |||
| ef90141cbf | |||
| 5d84e13851 | |||
| ba63727c2e | |||
| 02a255ae8e | |||
| 329d47c055 | |||
| 295a042e39 | |||
| 82e7131068 | |||
| b61a11e8d9 | |||
| 00a615d1e2 | |||
| 9f582f55af | |||
| c4fc3c53e1 | |||
| cd82b0f7d9 | |||
| 139222da06 | |||
| 2450d02e97 | |||
| 0820b97e78 | |||
| b421538f59 | |||
| 3efbfb3f6f | |||
| 2325197448 | |||
| a849ab3e44 | |||
| 994fe49902 | |||
| bc85bf7ed1 | |||
| ba1fe373be | |||
| 28a754f37b | |||
| cc7c1c81f6 | |||
| f78c4dee42 | |||
| 2514f9d62f | |||
| 51122c815f | |||
| d7fd08839f | |||
| 8f73b7cb35 | |||
| c587a960fb | |||
| 25d8411fb5 | |||
| eabecd05c5 | |||
| edca2c8698 | |||
| f2de9313f4 | |||
| 3ac256e289 | |||
| 94055d73d4 | |||
| dc09f97271 | |||
| 841b8a27b8 | 
							
								
								
									
										3
									
								
								.github/workflows/inductor-rocm.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/inductor-rocm.yml
									
									
									
									
										vendored
									
									
								
							@ -1,10 +1,9 @@
 | 
			
		||||
name: inductor-rocm
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  schedule:
 | 
			
		||||
    - cron: 0 * * * *
 | 
			
		||||
  push:
 | 
			
		||||
    branches:
 | 
			
		||||
      - main
 | 
			
		||||
      - release/*
 | 
			
		||||
    tags:
 | 
			
		||||
      - ciflow/inductor-rocm/*
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/workflows/nightly.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/nightly.yml
									
									
									
									
										vendored
									
									
								
							@ -41,7 +41,7 @@ jobs:
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
 | 
			
		||||
      build-environment: linux-jammy-py3.10-gcc11
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										8
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/pull.yml
									
									
									
									
										vendored
									
									
								
							@ -66,10 +66,10 @@ jobs:
 | 
			
		||||
          { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
 | 
			
		||||
          { config: "docs_test", shard: 1, num_shards: 1,  runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
 | 
			
		||||
          { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
 | 
			
		||||
          { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
 | 
			
		||||
          { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
 | 
			
		||||
          { config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
 | 
			
		||||
          { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
 | 
			
		||||
          { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
 | 
			
		||||
          { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
@ -167,8 +167,8 @@ jobs:
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-onnx
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
 | 
			
		||||
          { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
 | 
			
		||||
          { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/rocm.yml
									
									
									
									
										vendored
									
									
								
							@ -3,13 +3,13 @@ name: rocm
 | 
			
		||||
on:
 | 
			
		||||
  push:
 | 
			
		||||
    branches:
 | 
			
		||||
      - main
 | 
			
		||||
      - release/*
 | 
			
		||||
    tags:
 | 
			
		||||
      - ciflow/rocm/*
 | 
			
		||||
  workflow_dispatch:
 | 
			
		||||
  schedule:
 | 
			
		||||
    - cron: 29 8 * * *  # about 1:29am PDT
 | 
			
		||||
    - cron: 0 * * * *
 | 
			
		||||
 | 
			
		||||
concurrency:
 | 
			
		||||
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
#include <ATen/cuda/CUDAGreenContext.h>
 | 
			
		||||
 | 
			
		||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
 | 
			
		||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
 | 
			
		||||
#include <c10/cuda/driver_api.h>
 | 
			
		||||
#include <stdexcept>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
@ -2917,7 +2917,9 @@ static Tensor& linalg_eig_make_complex_eigenvectors(Tensor& complex_vectors, con
 | 
			
		||||
DEFINE_DISPATCH(linalg_eig_stub);
 | 
			
		||||
 | 
			
		||||
static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Tensor& values, Tensor& vectors, Tensor& infos, bool compute_eigenvectors) {
 | 
			
		||||
  auto options = input.options();
 | 
			
		||||
  // MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU
 | 
			
		||||
  // therefore we create all intermediate tensors on CPU
 | 
			
		||||
  auto options = input.options().device(at::kCPU);
 | 
			
		||||
 | 
			
		||||
  // These internal asserts make explicit the assumptions in the implementation
 | 
			
		||||
  // Error check with the actual error messages are done on the higher level of the hierarchy of calls
 | 
			
		||||
@ -2926,13 +2928,16 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
 | 
			
		||||
 | 
			
		||||
  // for real-valued 'input', eigenvalues can be real-valued or complex-valued
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == values.scalar_type()) || (input.scalar_type() == values.scalar_type()));
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU);
 | 
			
		||||
 | 
			
		||||
  // for real-valued 'input', eigenvectors can be real-valued or complex-valued
 | 
			
		||||
  if (compute_eigenvectors) {
 | 
			
		||||
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == vectors.scalar_type()) || (input.scalar_type() == vectors.scalar_type()));
 | 
			
		||||
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.device() == at::kCPU);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.scalar_type() == at::kInt);
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == at::kCPU);
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.numel() == std::max<int64_t>(1, batchCount(input)));
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_contiguous());
 | 
			
		||||
 | 
			
		||||
@ -2981,7 +2986,15 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
 | 
			
		||||
  // MAGMA uses a hybrid CPU-GPU algorithm that performs well only for large matrices
 | 
			
		||||
  // See: https://github.com/pytorch/pytorch/pull/52491#issuecomment-795685687
 | 
			
		||||
  // Here we call CPU path for matrices smaller than 2048x2048
 | 
			
		||||
  // that should be in general significantly faster than calling MAGMA
 | 
			
		||||
  if (input.size(-1) <= 2048) {
 | 
			
		||||
    linalg_eig_stub(at::kCPU, real_imag_values, maybe_complex_vectors, infos, input.to(kCPU), compute_eigenvectors);
 | 
			
		||||
  } else {
 | 
			
		||||
    linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // if input is not complex we need to do some post-processing
 | 
			
		||||
  if (!input.is_complex()) {
 | 
			
		||||
@ -3006,14 +3019,7 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
 | 
			
		||||
    }
 | 
			
		||||
    if (compute_eigenvectors) {
 | 
			
		||||
      if (vectors.is_complex()) {
 | 
			
		||||
        // We move to the CPU because linalg_eig_make_complex_eigenvectors requires it.
 | 
			
		||||
        // Performance note: this function could be implemented via a TensorIterator,
 | 
			
		||||
        // which would avoid an explicit host-device synchronization.
 | 
			
		||||
        auto vectors_cpu = vectors.cpu();
 | 
			
		||||
        auto values_cpu  = values.cpu();
 | 
			
		||||
        auto maybe_complex_vectors_cpu = maybe_complex_vectors.cpu();
 | 
			
		||||
        vectors_cpu = linalg_eig_make_complex_eigenvectors(vectors_cpu, values_cpu, maybe_complex_vectors_cpu);
 | 
			
		||||
        vectors.copy_(vectors_cpu);
 | 
			
		||||
          vectors = linalg_eig_make_complex_eigenvectors(vectors, values, maybe_complex_vectors);
 | 
			
		||||
      } else {
 | 
			
		||||
        TORCH_CHECK(false, "torch.linalg.eig: imaginary part of eigenvectors is non-zero, can't safely cast eigenvectors to non-complex dtype.")
 | 
			
		||||
      }
 | 
			
		||||
@ -3033,7 +3039,8 @@ std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values,
 | 
			
		||||
  checkSameDevice("torch.linalg.eig", values, input, "eigenvalues");
 | 
			
		||||
  checkSameDevice("torch.linalg.eig", vectors, input, "eigenvectors");
 | 
			
		||||
 | 
			
		||||
  auto options = input.options();
 | 
			
		||||
  // MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU
 | 
			
		||||
  auto options = input.options().device(at::kCPU);
 | 
			
		||||
  auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
 | 
			
		||||
 | 
			
		||||
  // if result is not empty and not in batched column major format we have to allocate a temporary tensor
 | 
			
		||||
@ -3122,7 +3129,8 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
 | 
			
		||||
  checkLinalgCompatibleDtype("torch.linalg.eigvals", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
 | 
			
		||||
  checkSameDevice("torch.linalg.eigvals", values, input, "eigenvalues");
 | 
			
		||||
 | 
			
		||||
  auto options = input.options();
 | 
			
		||||
  // MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU
 | 
			
		||||
  auto options = input.options().device(at::kCPU);
 | 
			
		||||
  auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
 | 
			
		||||
 | 
			
		||||
  bool values_expected_type = (values.scalar_type() == toComplexType(input.scalar_type()));
 | 
			
		||||
@ -3151,7 +3159,6 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Tensor vectors;
 | 
			
		||||
  vectors = at::empty({0}, input.options());
 | 
			
		||||
  if (values_tmp_needed) {
 | 
			
		||||
    Tensor values_tmp = at::empty({0}, options.dtype(values_type));
 | 
			
		||||
    std::tie(values_tmp, std::ignore) = linalg_eig_out_info(input, values_tmp, vectors, infos, /*compute_eigenvectors=*/false);
 | 
			
		||||
 | 
			
		||||
@ -1881,8 +1881,6 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) {
 | 
			
		||||
 | 
			
		||||
REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel)
 | 
			
		||||
 | 
			
		||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
 | 
			
		||||
#if !AT_MAGMA_ENABLED()
 | 
			
		||||
@ -1957,6 +1955,8 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 | 
			
		||||
 | 
			
		||||
// This is a type dispatch function for 'apply_magma_eigh'
 | 
			
		||||
// For small inputs result is computed on CPU
 | 
			
		||||
void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
 | 
			
		||||
@ -2019,10 +2019,10 @@ This is an in-place routine, content of 'input', 'values', 'vectors' is overwrit
 | 
			
		||||
For more information see MAGMA's documentation for GEEV routine.
 | 
			
		||||
*/
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
void apply_magma_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
 | 
			
		||||
void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
 | 
			
		||||
#if !AT_MAGMA_ENABLED()
 | 
			
		||||
TORCH_CHECK(false, "Calling torch.linalg.eig with MAGMA requires compiling PyTorch with MAGMA. "
 | 
			
		||||
                   "Either transfer the tensor to the CPU before calling torch.linalg.eig or use cuSolver.");
 | 
			
		||||
TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling PyTorch with MAGMA. "
 | 
			
		||||
                   "Either transfer the tensor to the CPU before calling torch.linalg.eig or recompile with MAGMA.");
 | 
			
		||||
#else
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == at::kCPU);
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU);
 | 
			
		||||
@ -2076,44 +2076,22 @@ TORCH_CHECK(false, "Calling torch.linalg.eig with MAGMA requires compiling PyTor
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MAGMA wrapper: transfers tensors to CPU, calls apply_magma_eig, then copies results back.
 | 
			
		||||
void linalg_eig_magma(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors){
 | 
			
		||||
  // MAGMA doesn't have GPU interface for the eigendecomposition, and it forces us to transfer to CPU
 | 
			
		||||
  auto eigenvalues_cpu = eigenvalues.cpu();
 | 
			
		||||
  auto eigenvectors_cpu = eigenvectors.cpu();
 | 
			
		||||
  auto infos_cpu = infos.cpu();
 | 
			
		||||
 | 
			
		||||
  Tensor input_cpu = at::empty(input.sizes(), input.options().device(kCPU));
 | 
			
		||||
  input_cpu.transpose_(-2, -1);
 | 
			
		||||
  input_cpu.copy_(input);
 | 
			
		||||
 | 
			
		||||
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cuda", [&]{
 | 
			
		||||
    apply_magma_eig<scalar_t>(eigenvalues_cpu, eigenvectors_cpu, input_cpu, infos_cpu, compute_eigenvectors);
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  eigenvalues.copy_(eigenvalues_cpu);
 | 
			
		||||
  eigenvectors.copy_(eigenvectors_cpu);
 | 
			
		||||
  infos.copy_(infos_cpu);
 | 
			
		||||
}
 | 
			
		||||
// This is a type dispatching helper function for 'apply_linalg_eig'
 | 
			
		||||
void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
 | 
			
		||||
  // This function calculates the non-symmetric eigendecomposition in-place
 | 
			
		||||
  // tensors should be in batched column major memory format
 | 
			
		||||
  // the content of eigenvalues, eigenvectors and infos is overwritten by 'linalg_eig_magma' or
 | 
			
		||||
  // 'linalg_eig_cusolver_xgeev' both geev routines modify the provided input matrix in-place, therefore we need a copy
 | 
			
		||||
  // the content of eigenvalues, eigenvectors and infos is overwritten by 'apply_linalg_eig'
 | 
			
		||||
 | 
			
		||||
  // apply_linalg_eig modifies the provided input matrix in-place, therefore we need a copy
 | 
			
		||||
  // MAGMA doesn't have GPU interface for the eigendecomposition and it forces us to transfer 'input' to CPU
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
 | 
			
		||||
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
 | 
			
		||||
  auto preferred_backend = at::globalContext().linalgPreferredBackend();
 | 
			
		||||
  switch (preferred_backend) {
 | 
			
		||||
    case at::LinalgBackend::Cusolver:
 | 
			
		||||
    default:
 | 
			
		||||
      linalg_eig_cusolver_xgeev(eigenvalues, eigenvectors, input, infos, compute_eigenvectors);
 | 
			
		||||
      return;
 | 
			
		||||
    case at::LinalgBackend::Magma:
 | 
			
		||||
      break; // MAGMA path handled below
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
  linalg_eig_magma(eigenvalues, eigenvectors, infos, input, compute_eigenvectors);
 | 
			
		||||
  Tensor input_working_copy = at::empty(input.sizes(), input.options().device(kCPU));
 | 
			
		||||
  input_working_copy.transpose_(-2, -1);  // make input_working_copy to have Fortran contiguous memory layout
 | 
			
		||||
  input_working_copy.copy_(input);
 | 
			
		||||
 | 
			
		||||
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cuda", [&]{
 | 
			
		||||
    apply_linalg_eig<scalar_t>(eigenvalues, eigenvectors, input_working_copy, infos, compute_eigenvectors);
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
 | 
			
		||||
 | 
			
		||||
@ -1625,126 +1625,6 @@ void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors,
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// cuSOLVER Xgeev (requires cuSOLVER >= 11.7.2, i.e. CUDA 12.8+)
 | 
			
		||||
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
void apply_xgeev(const Tensor& values, const Tensor& vectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors) {
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_cuda());
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.is_cuda());
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_cuda());
 | 
			
		||||
 | 
			
		||||
  int   n   = cuda_int_cast(input.size(-1), "n");
 | 
			
		||||
  int   lda = std::max<int>(1, n);
 | 
			
		||||
  auto  batch_size = batchCount(input);
 | 
			
		||||
 | 
			
		||||
  if (n == 0 || batch_size == 0) {
 | 
			
		||||
    // XGeev crashes on empty input, explicitly handle empty input
 | 
			
		||||
    auto values_shape = IntArrayRef(input.sizes().data(), input.dim() - 1);
 | 
			
		||||
    values.resize_(values_shape, MemoryFormat::Contiguous);
 | 
			
		||||
    values.zero_();
 | 
			
		||||
 | 
			
		||||
    if (compute_eigenvectors) {
 | 
			
		||||
      vectors.resize_(input.sizes(), MemoryFormat::Contiguous);
 | 
			
		||||
      vectors.zero_();
 | 
			
		||||
    } else {
 | 
			
		||||
      vectors.resize_({0});
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    infos.resize_({std::max<int64_t>(1, batch_size)}, MemoryFormat::Contiguous);
 | 
			
		||||
    infos.zero_();
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  int64_t vectors_stride = 0;
 | 
			
		||||
  if (compute_eigenvectors){
 | 
			
		||||
    vectors_stride = matrixStride(vectors);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto values_stride = values.size(-1);
 | 
			
		||||
  auto vectors_data = vectors.data_ptr<scalar_t>();
 | 
			
		||||
  auto values_data = values.data_ptr<scalar_t>();
 | 
			
		||||
  auto infos_data = infos.data_ptr<int>();
 | 
			
		||||
 | 
			
		||||
  cusolverDnParams_t params = nullptr;
 | 
			
		||||
  TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(¶ms));
 | 
			
		||||
 | 
			
		||||
  Tensor A_fortran = input.mT().contiguous();
 | 
			
		||||
  auto* A_data = A_fortran.data_ptr<scalar_t>();
 | 
			
		||||
  const auto A_stride = matrixStride(A_fortran);
 | 
			
		||||
  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
 | 
			
		||||
 | 
			
		||||
  const int ldvl = 1; // ldvl >= 1 if jobvl = CUSOLVER_EIG_MODE_NOVECTOR
 | 
			
		||||
  cusolverEigMode_t jobvl = CUSOLVER_EIG_MODE_NOVECTOR;
 | 
			
		||||
 | 
			
		||||
  cusolverEigMode_t jobvr;
 | 
			
		||||
  int ldvr;
 | 
			
		||||
  if (compute_eigenvectors) {
 | 
			
		||||
    ldvr = n; // ldvr >= n if jobvr = CUSOLVER_EIG_MODE_VECTOR
 | 
			
		||||
    jobvr = CUSOLVER_EIG_MODE_VECTOR;
 | 
			
		||||
  }
 | 
			
		||||
  else {
 | 
			
		||||
    ldvr = 1; // ldvr >= 1 if jobvr = CUSOLVER_EIG_MODE_NOVECTOR
 | 
			
		||||
    jobvr = CUSOLVER_EIG_MODE_NOVECTOR;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  scalar_t*   W   = values.data_ptr<scalar_t>();
 | 
			
		||||
  scalar_t*   VL  = nullptr;
 | 
			
		||||
  scalar_t*   VR  = vectors.data_ptr<scalar_t>();
 | 
			
		||||
 | 
			
		||||
  const scalar_t*   A_const = A_data;
 | 
			
		||||
  const scalar_t*   W_const = W;
 | 
			
		||||
  const scalar_t*   VL_const = VL;
 | 
			
		||||
  const scalar_t*   VR_const = VR;
 | 
			
		||||
 | 
			
		||||
  size_t ws_dev = 0, ws_host = 0;
 | 
			
		||||
  at::cuda::solver::xgeev_bufferSize<scalar_t>(
 | 
			
		||||
    handle, params,
 | 
			
		||||
    jobvl, jobvr,
 | 
			
		||||
    n,
 | 
			
		||||
    A_const, lda,
 | 
			
		||||
    W_const,
 | 
			
		||||
    VL_const, ldvl,
 | 
			
		||||
    VR_const, ldvr,
 | 
			
		||||
    &ws_dev, &ws_host);
 | 
			
		||||
 | 
			
		||||
  auto& device_allocator  = *at::cuda::getCUDADeviceAllocator();
 | 
			
		||||
  auto  work_device_data  = device_allocator.allocate(ws_dev);
 | 
			
		||||
  // use pinned memory for best performance.
 | 
			
		||||
  auto& host_allocator    = *at::cuda::getPinnedMemoryAllocator();
 | 
			
		||||
  auto  work_host_data    = host_allocator.allocate(ws_host);
 | 
			
		||||
 | 
			
		||||
  for (decltype(batch_size) i = 0; i < batch_size; ++i) {
 | 
			
		||||
    scalar_t* Ai   = A_data      + i * A_stride;
 | 
			
		||||
    scalar_t* Wi   = values_data + i * values_stride;
 | 
			
		||||
    scalar_t* VLi  = nullptr; // xgeev does not support computing left evs
 | 
			
		||||
    scalar_t* VRi  = compute_eigenvectors ? (vectors_data + i * vectors_stride) : nullptr;
 | 
			
		||||
    int*      info = infos_data + i;
 | 
			
		||||
 | 
			
		||||
    at::cuda::solver::xgeev<scalar_t>(
 | 
			
		||||
      handle, params,
 | 
			
		||||
      jobvl, jobvr,
 | 
			
		||||
      n,
 | 
			
		||||
      Ai, lda,
 | 
			
		||||
      Wi,
 | 
			
		||||
      VLi, ldvl,
 | 
			
		||||
      VRi, ldvr,
 | 
			
		||||
      static_cast<scalar_t*>(work_device_data.get()), ws_dev,
 | 
			
		||||
      static_cast<scalar_t*>(work_host_data.get()),  ws_host,
 | 
			
		||||
      info);
 | 
			
		||||
  }
 | 
			
		||||
  TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors) {
 | 
			
		||||
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eig_cuda", [&] {
 | 
			
		||||
    apply_xgeev<scalar_t>(eigenvalues, eigenvectors, input, infos, compute_eigenvectors);
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
 | 
			
		||||
 | 
			
		||||
// The 'apply_' word is used for templated by dtype functions that call an API routine
 | 
			
		||||
// underneath. Since the cusolver API has a slightly different structure we do not prepend
 | 
			
		||||
// apply_ to this function.
 | 
			
		||||
 | 
			
		||||
@ -73,11 +73,6 @@ void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other,
 | 
			
		||||
Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau);
 | 
			
		||||
 | 
			
		||||
void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors);
 | 
			
		||||
 | 
			
		||||
void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
void lu_solve_looped_cusolver(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose);
 | 
			
		||||
 | 
			
		||||
void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);
 | 
			
		||||
 | 
			
		||||
@ -1954,336 +1954,6 @@ void xsyevd<c10::complex<double>, double>(
 | 
			
		||||
      workspaceInBytesOnHost,
 | 
			
		||||
      info));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// cuSOLVER Xgeev bindings (requires cuSOLVER >= 11.7.2, i.e. CUDA 12.8+)
 | 
			
		||||
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev_bufferSize<float>(
 | 
			
		||||
    cusolverDnHandle_t handle,
 | 
			
		||||
    cusolverDnParams_t params,
 | 
			
		||||
    cusolverEigMode_t jobvl,
 | 
			
		||||
    cusolverEigMode_t jobvr,
 | 
			
		||||
    int64_t n,
 | 
			
		||||
    const float* A,
 | 
			
		||||
    int64_t lda,
 | 
			
		||||
    const float* W,
 | 
			
		||||
    const float* VL,
 | 
			
		||||
    int64_t ldvl,
 | 
			
		||||
    const float* VR,
 | 
			
		||||
    int64_t ldvr,
 | 
			
		||||
    size_t* workspaceInBytesOnDevice,
 | 
			
		||||
    size_t* workspaceInBytesOnHost) {
 | 
			
		||||
  TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
 | 
			
		||||
      handle, params, jobvl, jobvr, n,
 | 
			
		||||
      CUDA_R_32F,
 | 
			
		||||
      reinterpret_cast<const void*>(A),
 | 
			
		||||
      lda,
 | 
			
		||||
      CUDA_R_32F,
 | 
			
		||||
      reinterpret_cast<const void*>(W),
 | 
			
		||||
      CUDA_R_32F,
 | 
			
		||||
      reinterpret_cast<const void*>(VL),
 | 
			
		||||
      ldvl,
 | 
			
		||||
      CUDA_R_32F,
 | 
			
		||||
      reinterpret_cast<const void*>(VR),
 | 
			
		||||
      ldvr,
 | 
			
		||||
      CUDA_R_32F,
 | 
			
		||||
      workspaceInBytesOnDevice,
 | 
			
		||||
      workspaceInBytesOnHost));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev_bufferSize<double>(
 | 
			
		||||
    cusolverDnHandle_t handle,
 | 
			
		||||
    cusolverDnParams_t params,
 | 
			
		||||
    cusolverEigMode_t jobvl,
 | 
			
		||||
    cusolverEigMode_t jobvr,
 | 
			
		||||
    int64_t n,
 | 
			
		||||
    const double* A,
 | 
			
		||||
    int64_t lda,
 | 
			
		||||
    const double* W,
 | 
			
		||||
    const double* VL,
 | 
			
		||||
    int64_t ldvl,
 | 
			
		||||
    const double* VR,
 | 
			
		||||
    int64_t ldvr,
 | 
			
		||||
    size_t* workspaceInBytesOnDevice,
 | 
			
		||||
    size_t* workspaceInBytesOnHost) {
 | 
			
		||||
  TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
 | 
			
		||||
      handle, params, jobvl, jobvr, n,
 | 
			
		||||
      CUDA_R_64F,
 | 
			
		||||
      reinterpret_cast<const void*>(A),
 | 
			
		||||
      lda,
 | 
			
		||||
      CUDA_R_64F,
 | 
			
		||||
      reinterpret_cast<const void*>(W),
 | 
			
		||||
      CUDA_R_64F,
 | 
			
		||||
      reinterpret_cast<const void*>(VL),
 | 
			
		||||
      ldvl,
 | 
			
		||||
      CUDA_R_64F,
 | 
			
		||||
      reinterpret_cast<const void*>(VR),
 | 
			
		||||
      ldvr,
 | 
			
		||||
      CUDA_R_64F,
 | 
			
		||||
      workspaceInBytesOnDevice,
 | 
			
		||||
      workspaceInBytesOnHost));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev_bufferSize<c10::complex<float>>(
 | 
			
		||||
    cusolverDnHandle_t handle,
 | 
			
		||||
    cusolverDnParams_t params,
 | 
			
		||||
    cusolverEigMode_t jobvl,
 | 
			
		||||
    cusolverEigMode_t jobvr,
 | 
			
		||||
    int64_t n,
 | 
			
		||||
    const c10::complex<float>* A,
 | 
			
		||||
    int64_t lda,
 | 
			
		||||
    const c10::complex<float>* W,
 | 
			
		||||
    const c10::complex<float>* VL,
 | 
			
		||||
    int64_t ldvl,
 | 
			
		||||
    const c10::complex<float>* VR,
 | 
			
		||||
    int64_t ldvr,
 | 
			
		||||
    size_t* workspaceInBytesOnDevice,
 | 
			
		||||
    size_t* workspaceInBytesOnHost) {
 | 
			
		||||
  TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
 | 
			
		||||
      handle, params, jobvl, jobvr, n,
 | 
			
		||||
      CUDA_C_32F,
 | 
			
		||||
      reinterpret_cast<const void*>(A),
 | 
			
		||||
      lda,
 | 
			
		||||
      CUDA_C_32F,
 | 
			
		||||
      reinterpret_cast<const void*>(W),
 | 
			
		||||
      CUDA_C_32F,
 | 
			
		||||
      reinterpret_cast<const void*>(VL),
 | 
			
		||||
      ldvl,
 | 
			
		||||
      CUDA_C_32F,
 | 
			
		||||
      reinterpret_cast<const void*>(VR),
 | 
			
		||||
      ldvr,
 | 
			
		||||
      CUDA_C_32F,
 | 
			
		||||
      workspaceInBytesOnDevice,
 | 
			
		||||
      workspaceInBytesOnHost));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev_bufferSize<c10::complex<double>>(
 | 
			
		||||
    cusolverDnHandle_t handle,
 | 
			
		||||
    cusolverDnParams_t params,
 | 
			
		||||
    cusolverEigMode_t jobvl,
 | 
			
		||||
    cusolverEigMode_t jobvr,
 | 
			
		||||
    int64_t n,
 | 
			
		||||
    const c10::complex<double>* A,
 | 
			
		||||
    int64_t lda,
 | 
			
		||||
    const c10::complex<double>* W,
 | 
			
		||||
    const c10::complex<double>* VL,
 | 
			
		||||
    int64_t ldvl,
 | 
			
		||||
    const c10::complex<double>* VR,
 | 
			
		||||
    int64_t ldvr,
 | 
			
		||||
    size_t* workspaceInBytesOnDevice,
 | 
			
		||||
    size_t* workspaceInBytesOnHost) {
 | 
			
		||||
  TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
 | 
			
		||||
      handle, params, jobvl, jobvr, n,
 | 
			
		||||
      CUDA_C_64F,
 | 
			
		||||
      reinterpret_cast<const void*>(A),
 | 
			
		||||
      lda,
 | 
			
		||||
      CUDA_C_64F,
 | 
			
		||||
      reinterpret_cast<const void*>(W),
 | 
			
		||||
      CUDA_C_64F,
 | 
			
		||||
      reinterpret_cast<const void*>(VL),
 | 
			
		||||
      ldvl,
 | 
			
		||||
      CUDA_C_64F,
 | 
			
		||||
      reinterpret_cast<const void*>(VR),
 | 
			
		||||
      ldvr,
 | 
			
		||||
      CUDA_C_64F,
 | 
			
		||||
      workspaceInBytesOnDevice,
 | 
			
		||||
      workspaceInBytesOnHost));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev<float>(
 | 
			
		||||
    cusolverDnHandle_t handle,
 | 
			
		||||
    cusolverDnParams_t params,
 | 
			
		||||
    cusolverEigMode_t jobvl,
 | 
			
		||||
    cusolverEigMode_t jobvr,
 | 
			
		||||
    int64_t n,
 | 
			
		||||
    float* A,
 | 
			
		||||
    int64_t lda,
 | 
			
		||||
    float* W,
 | 
			
		||||
    float* VL,
 | 
			
		||||
    int64_t ldvl,
 | 
			
		||||
    float* VR,
 | 
			
		||||
    int64_t ldvr,
 | 
			
		||||
    float* bufferOnDevice,
 | 
			
		||||
    size_t workspaceInBytesOnDevice,
 | 
			
		||||
    float* bufferOnHost,
 | 
			
		||||
    size_t workspaceInBytesOnHost,
 | 
			
		||||
    int* info) {
 | 
			
		||||
 | 
			
		||||
  TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
 | 
			
		||||
      handle,
 | 
			
		||||
      params,
 | 
			
		||||
      jobvl,
 | 
			
		||||
      jobvr,
 | 
			
		||||
      n,
 | 
			
		||||
      CUDA_R_32F,
 | 
			
		||||
      reinterpret_cast<void*>(A),
 | 
			
		||||
      lda,
 | 
			
		||||
      CUDA_R_32F,
 | 
			
		||||
      reinterpret_cast<void*>(W),
 | 
			
		||||
      CUDA_R_32F,
 | 
			
		||||
      reinterpret_cast<void*>(VL),
 | 
			
		||||
      ldvl,
 | 
			
		||||
      CUDA_R_32F,
 | 
			
		||||
      reinterpret_cast<void*>(VR),
 | 
			
		||||
      ldvr,
 | 
			
		||||
      CUDA_R_32F,
 | 
			
		||||
      reinterpret_cast<void*>(bufferOnDevice),
 | 
			
		||||
      workspaceInBytesOnDevice,
 | 
			
		||||
      reinterpret_cast<void*>(bufferOnHost),
 | 
			
		||||
      workspaceInBytesOnHost,
 | 
			
		||||
      info));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev<double>(
 | 
			
		||||
    cusolverDnHandle_t handle,
 | 
			
		||||
    cusolverDnParams_t params,
 | 
			
		||||
    cusolverEigMode_t jobvl,
 | 
			
		||||
    cusolverEigMode_t jobvr,
 | 
			
		||||
    int64_t n,
 | 
			
		||||
    double* A,
 | 
			
		||||
    int64_t lda,
 | 
			
		||||
    double* W,
 | 
			
		||||
    double* VL,
 | 
			
		||||
    int64_t ldvl,
 | 
			
		||||
    double* VR,
 | 
			
		||||
    int64_t ldvr,
 | 
			
		||||
    double* bufferOnDevice,
 | 
			
		||||
    size_t workspaceInBytesOnDevice,
 | 
			
		||||
    double* bufferOnHost,
 | 
			
		||||
    size_t workspaceInBytesOnHost,
 | 
			
		||||
    int* info) {
 | 
			
		||||
 | 
			
		||||
  TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
 | 
			
		||||
      handle,
 | 
			
		||||
      params,
 | 
			
		||||
      jobvl,
 | 
			
		||||
      jobvr,
 | 
			
		||||
      n,
 | 
			
		||||
      CUDA_R_64F,
 | 
			
		||||
      reinterpret_cast<void*>(A),
 | 
			
		||||
      lda,
 | 
			
		||||
      CUDA_R_64F,
 | 
			
		||||
      reinterpret_cast<void*>(W),
 | 
			
		||||
      CUDA_R_64F,
 | 
			
		||||
      reinterpret_cast<void*>(VL),
 | 
			
		||||
      ldvl,
 | 
			
		||||
      CUDA_R_64F,
 | 
			
		||||
      reinterpret_cast<void*>(VR),
 | 
			
		||||
      ldvr,
 | 
			
		||||
      CUDA_R_64F,
 | 
			
		||||
      reinterpret_cast<void*>(bufferOnDevice),
 | 
			
		||||
      workspaceInBytesOnDevice,
 | 
			
		||||
      reinterpret_cast<void*>(bufferOnHost),
 | 
			
		||||
      workspaceInBytesOnHost,
 | 
			
		||||
      info));
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev<c10::complex<float>>(
 | 
			
		||||
    cusolverDnHandle_t handle,
 | 
			
		||||
    cusolverDnParams_t params,
 | 
			
		||||
    cusolverEigMode_t jobvl,
 | 
			
		||||
    cusolverEigMode_t jobvr,
 | 
			
		||||
    int64_t n,
 | 
			
		||||
    c10::complex<float>* A,
 | 
			
		||||
    int64_t lda,
 | 
			
		||||
    c10::complex<float>* W,
 | 
			
		||||
    c10::complex<float>* VL,
 | 
			
		||||
    int64_t ldvl,
 | 
			
		||||
    c10::complex<float>* VR,
 | 
			
		||||
    int64_t ldvr,
 | 
			
		||||
    c10::complex<float>* bufferOnDevice,
 | 
			
		||||
    size_t workspaceInBytesOnDevice,
 | 
			
		||||
    c10::complex<float>* bufferOnHost,
 | 
			
		||||
    size_t workspaceInBytesOnHost,
 | 
			
		||||
    int* info) {
 | 
			
		||||
 | 
			
		||||
  TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
 | 
			
		||||
      handle,
 | 
			
		||||
      params,
 | 
			
		||||
      jobvl,
 | 
			
		||||
      jobvr,
 | 
			
		||||
      n,
 | 
			
		||||
      CUDA_C_32F,
 | 
			
		||||
      reinterpret_cast<void*>(A),
 | 
			
		||||
      lda,
 | 
			
		||||
      CUDA_C_32F,
 | 
			
		||||
      reinterpret_cast<void*>(W),
 | 
			
		||||
      CUDA_C_32F,
 | 
			
		||||
      reinterpret_cast<void*>(VL),
 | 
			
		||||
      ldvl,
 | 
			
		||||
      CUDA_C_32F,
 | 
			
		||||
      reinterpret_cast<void*>(VR),
 | 
			
		||||
      ldvr,
 | 
			
		||||
      CUDA_C_32F,
 | 
			
		||||
      reinterpret_cast<void*>(bufferOnDevice),
 | 
			
		||||
      workspaceInBytesOnDevice,
 | 
			
		||||
      reinterpret_cast<void*>(bufferOnHost),
 | 
			
		||||
      workspaceInBytesOnHost,
 | 
			
		||||
      info));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev<c10::complex<double>>(
 | 
			
		||||
    cusolverDnHandle_t handle,
 | 
			
		||||
    cusolverDnParams_t params,
 | 
			
		||||
    cusolverEigMode_t jobvl,
 | 
			
		||||
    cusolverEigMode_t jobvr,
 | 
			
		||||
    int64_t n,
 | 
			
		||||
    c10::complex<double>* A,
 | 
			
		||||
    int64_t lda,
 | 
			
		||||
    c10::complex<double>* W,
 | 
			
		||||
    c10::complex<double>* VL,
 | 
			
		||||
    int64_t ldvl,
 | 
			
		||||
    c10::complex<double>* VR,
 | 
			
		||||
    int64_t ldvr,
 | 
			
		||||
    c10::complex<double>* bufferOnDevice,
 | 
			
		||||
    size_t workspaceInBytesOnDevice,
 | 
			
		||||
    c10::complex<double>* bufferOnHost,
 | 
			
		||||
    size_t workspaceInBytesOnHost,
 | 
			
		||||
    int* info) {
 | 
			
		||||
 | 
			
		||||
  TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
 | 
			
		||||
      handle,
 | 
			
		||||
      params,
 | 
			
		||||
      jobvl,
 | 
			
		||||
      jobvr,
 | 
			
		||||
      n,
 | 
			
		||||
      CUDA_C_64F,
 | 
			
		||||
      reinterpret_cast<void*>(A),
 | 
			
		||||
      lda,
 | 
			
		||||
      CUDA_C_64F,
 | 
			
		||||
      reinterpret_cast<void*>(W),
 | 
			
		||||
      CUDA_C_64F,
 | 
			
		||||
      reinterpret_cast<void*>(VL),
 | 
			
		||||
      ldvl,
 | 
			
		||||
      CUDA_C_64F,
 | 
			
		||||
      reinterpret_cast<void*>(VR),
 | 
			
		||||
      ldvr,
 | 
			
		||||
      CUDA_C_64F,
 | 
			
		||||
      reinterpret_cast<void*>(bufferOnDevice),
 | 
			
		||||
      workspaceInBytesOnDevice,
 | 
			
		||||
      reinterpret_cast<void*>(bufferOnHost),
 | 
			
		||||
      workspaceInBytesOnHost,
 | 
			
		||||
      info));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
 | 
			
		||||
 | 
			
		||||
#endif // USE_CUSOLVER_64_BIT
 | 
			
		||||
 | 
			
		||||
#ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED
 | 
			
		||||
 | 
			
		||||
@ -674,66 +674,6 @@ template <>
 | 
			
		||||
void xsyevd<c10::complex<double>, double>(
 | 
			
		||||
    CUDASOLVER_XSYEVD_ARGTYPES(c10::complex<double>, double));
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// cuSOLVER Xgeev (non-Hermitian eigen decomposition, CUDA >= 12.8)
 | 
			
		||||
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
 | 
			
		||||
 | 
			
		||||
#define CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(scalar_t)                        \
 | 
			
		||||
cusolverDnHandle_t handle, cusolverDnParams_t params,                         \
 | 
			
		||||
cusolverEigMode_t jobvl, cusolverEigMode_t jobvr, int64_t n,                  \
 | 
			
		||||
const scalar_t* A, int64_t lda, const scalar_t* W,                            \
 | 
			
		||||
const scalar_t* VL, int64_t ldvl, const scalar_t* VR, int64_t ldvr,           \
 | 
			
		||||
size_t* workspaceInBytesOnDevice, size_t* workspaceInBytesOnHost
 | 
			
		||||
 | 
			
		||||
template <class scalar_t>
 | 
			
		||||
void xgeev_bufferSize(
 | 
			
		||||
    CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(scalar_t)) {
 | 
			
		||||
  static_assert(false&&sizeof(scalar_t),
 | 
			
		||||
      "at::cuda::solver::xgeev_bufferSize: not implemented");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev_bufferSize<float>(CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(float));
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev_bufferSize<double>(CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(double));
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev_bufferSize<c10::complex<float>>(
 | 
			
		||||
    CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(c10::complex<float>));
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev_bufferSize<c10::complex<double>>(
 | 
			
		||||
    CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(c10::complex<double>));
 | 
			
		||||
 | 
			
		||||
#define CUDASOLVER_XGEEV_ARGTYPES(scalar_t)                                    \
 | 
			
		||||
cusolverDnHandle_t handle, cusolverDnParams_t params,                          \
 | 
			
		||||
cusolverEigMode_t jobvl, cusolverEigMode_t jobvr, int64_t n, scalar_t *A,      \
 | 
			
		||||
int64_t lda, scalar_t *W, scalar_t *VL, int64_t ldvl, scalar_t *VR, int64_t ldvr,\
 | 
			
		||||
scalar_t *bufferOnDevice, size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost,\
 | 
			
		||||
size_t workspaceInBytesOnHost, int *info
 | 
			
		||||
 | 
			
		||||
template <class scalar_t>
 | 
			
		||||
void xgeev(CUDASOLVER_XGEEV_ARGTYPES(scalar_t)) {
 | 
			
		||||
  static_assert(false&&sizeof(scalar_t),
 | 
			
		||||
      "at::cuda::solver::xgeev: not implemented");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev<float>(CUDASOLVER_XGEEV_ARGTYPES(float));
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev<double>(CUDASOLVER_XGEEV_ARGTYPES(double));
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev<c10::complex<float>>(CUDASOLVER_XGEEV_ARGTYPES(c10::complex<float>));
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
void xgeev<c10::complex<double>>(CUDASOLVER_XGEEV_ARGTYPES(c10::complex<double>));
 | 
			
		||||
 | 
			
		||||
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
 | 
			
		||||
 | 
			
		||||
#endif // USE_CUSOLVER_64_BIT
 | 
			
		||||
 | 
			
		||||
#ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED
 | 
			
		||||
 | 
			
		||||
@ -794,16 +794,14 @@ class TestFP8Lowering(TestCase):
 | 
			
		||||
        _get_torch_cuda_version() < (12, 9),
 | 
			
		||||
        "cuBLAS blockwise scaling added in CUDA 12.9",
 | 
			
		||||
    )
 | 
			
		||||
    @parametrize("shape", ((16, 256, 256), (1024, 512, 1024)))
 | 
			
		||||
    @parametrize("use_fast_accum", (False, True))
 | 
			
		||||
    @parametrize(
 | 
			
		||||
        "scaling_block_sizes", ((1, 128, 128, 128), (1, 128, 1, 128))
 | 
			
		||||
    )  # (BlockWise1x128, BlockWise128x128), (BlockWise1x128, BlockWise1x128)
 | 
			
		||||
    def test_main_loop_scaling(
 | 
			
		||||
        "shape", ((16, 256, 256), (1024, 512, 1024))
 | 
			
		||||
    )  # TODO (jananisriram): add scaling recipe overrides for shapes like (16, 256, 64) and (256, 16, 64)
 | 
			
		||||
    @parametrize("use_fast_accum", (False, True))
 | 
			
		||||
    def test_blockwise1x128_blockwise128x128_scaling(
 | 
			
		||||
        self,
 | 
			
		||||
        shape: tuple[int, int, int],
 | 
			
		||||
        use_fast_accum: bool,
 | 
			
		||||
        scaling_block_sizes: tuple[int, int, int, int],
 | 
			
		||||
    ):
 | 
			
		||||
        # Only bf16 output type is supported for non-tensorwise scaling, not fp32
 | 
			
		||||
        dtype: torch.dtype = torch.bfloat16
 | 
			
		||||
@ -816,28 +814,20 @@ class TestFP8Lowering(TestCase):
 | 
			
		||||
        w = torch.randn(N, K, dtype=dtype, device=device)
 | 
			
		||||
        bias = None
 | 
			
		||||
 | 
			
		||||
        am, ak, bn, bk = scaling_block_sizes
 | 
			
		||||
 | 
			
		||||
        # quantize weight (prior to inference)
 | 
			
		||||
        w_fp8, w_inverse_scale = _quantize_blockwise(
 | 
			
		||||
            w, dtype_float8, block_outer=bn, block_inner=bk
 | 
			
		||||
            w, dtype_float8, block_outer=128, block_inner=128
 | 
			
		||||
        )
 | 
			
		||||
        w_t_fp8 = w_fp8.t()
 | 
			
		||||
        if (bn, bk) == (1, 128):
 | 
			
		||||
            w_inverse_scale = (
 | 
			
		||||
                w_inverse_scale.t().contiguous().t().t()
 | 
			
		||||
            )  # 1x128 blocks need scales to be outer-dim-major
 | 
			
		||||
        else:
 | 
			
		||||
            w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
 | 
			
		||||
        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
 | 
			
		||||
 | 
			
		||||
        # quantize input x
 | 
			
		||||
        x_fp8, x_inverse_scale = _quantize_blockwise(
 | 
			
		||||
            x, dtype_float8, block_outer=am, block_inner=ak
 | 
			
		||||
            x, dtype_float8, block_outer=1, block_inner=128
 | 
			
		||||
        )
 | 
			
		||||
        if (am, ak) == (1, 128):
 | 
			
		||||
            x_inverse_scale = (
 | 
			
		||||
                x_inverse_scale.t().contiguous().t()
 | 
			
		||||
            )  # 1x128 blocks need scales to be outer-dim-major
 | 
			
		||||
        x_inverse_scale = (
 | 
			
		||||
            x_inverse_scale.t().contiguous().t()
 | 
			
		||||
        )  # 1x128 blocks need scales to be outer-dim-major
 | 
			
		||||
 | 
			
		||||
        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
 | 
			
		||||
            y = torch._scaled_mm(
 | 
			
		||||
@ -882,15 +872,9 @@ class TestFP8Lowering(TestCase):
 | 
			
		||||
        FileCheck().check(
 | 
			
		||||
            f"SCALE_RECIPE_A : tl.constexpr = {ScalingType.BlockWise1x128.value}"
 | 
			
		||||
        ).run(code[0])
 | 
			
		||||
 | 
			
		||||
        if (bn, bk) == (1, 128):
 | 
			
		||||
            check_scale_recipe_b = ScalingType.BlockWise1x128.value
 | 
			
		||||
        else:
 | 
			
		||||
            check_scale_recipe_b = ScalingType.BlockWise128x128.value
 | 
			
		||||
        FileCheck().check(
 | 
			
		||||
            f"SCALE_RECIPE_B : tl.constexpr = {check_scale_recipe_b}"
 | 
			
		||||
            f"SCALE_RECIPE_B : tl.constexpr = {ScalingType.BlockWise128x128.value}"
 | 
			
		||||
        ).run(code[0])
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(y_eager.dtype, dtype)
 | 
			
		||||
        self.assertEqual(y_compiled.dtype, dtype)
 | 
			
		||||
        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
 | 
			
		||||
 | 
			
		||||
@ -553,7 +553,7 @@ class TestPatternMatcher(TestCase):
 | 
			
		||||
                torch.randn(16, 16, device=GPU_TYPE),
 | 
			
		||||
                torch.randn(16, 16, device=GPU_TYPE),
 | 
			
		||||
                torch.randn(16, 16, device=GPU_TYPE),
 | 
			
		||||
                True,
 | 
			
		||||
                False,
 | 
			
		||||
            ),
 | 
			
		||||
            (
 | 
			
		||||
                torch.randn(8, device=GPU_TYPE),
 | 
			
		||||
@ -687,17 +687,20 @@ class TestPatternMatcher(TestCase):
 | 
			
		||||
        FileCheck().check("call").check_not(".run").run(code[0])
 | 
			
		||||
 | 
			
		||||
    def test_cat_addmm(self):
 | 
			
		||||
        def fn(a, b, c):
 | 
			
		||||
        def fn(b1, b2, b3, mat1, mat2, mat3):
 | 
			
		||||
            return torch.cat(
 | 
			
		||||
                [
 | 
			
		||||
                    torch.addmm(a, b, c),
 | 
			
		||||
                    torch.addmm(b, c, a),
 | 
			
		||||
                    torch.addmm(c, a, b),
 | 
			
		||||
                    torch.addmm(b1, mat1, mat2),
 | 
			
		||||
                    torch.addmm(b2, mat1, mat3),
 | 
			
		||||
                    torch.addmm(b3, mat2, mat3),
 | 
			
		||||
                ],
 | 
			
		||||
                1,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        args = [
 | 
			
		||||
            torch.randn(16, device=GPU_TYPE),
 | 
			
		||||
            torch.randn(16, device=GPU_TYPE),
 | 
			
		||||
            torch.randn(16, device=GPU_TYPE),
 | 
			
		||||
            torch.randn(16, 16, device=GPU_TYPE),
 | 
			
		||||
            torch.randn(16, 16, device=GPU_TYPE),
 | 
			
		||||
            torch.randn(16, 16, device=GPU_TYPE),
 | 
			
		||||
 | 
			
		||||
@ -1693,7 +1693,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        from torch._higher_order_ops.utils import _maybe_fake_tracing
 | 
			
		||||
        from torch._inductor.utils import is_pointwise_use
 | 
			
		||||
        from torch._inductor.utils import has_only_pointwise_uses
 | 
			
		||||
 | 
			
		||||
        with tx.fake_mode:
 | 
			
		||||
            sub_args_fake = [
 | 
			
		||||
@ -1712,9 +1712,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
 | 
			
		||||
 | 
			
		||||
            for node in fx.graph.nodes:
 | 
			
		||||
                # Check that the combine_fn is pointwise, if combine_mode='pointwise'
 | 
			
		||||
                if not all(
 | 
			
		||||
                    is_pointwise_use(use) or use.op == "output" for use in node.users
 | 
			
		||||
                ):
 | 
			
		||||
                if not has_only_pointwise_uses(node, select_output=True):
 | 
			
		||||
                    raise RuntimeError(
 | 
			
		||||
                        "For combine_mode='pointwise', the combine_fn needs to be pointwise"
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
@ -1505,15 +1505,45 @@ def view_to_reshape(gm):
 | 
			
		||||
        nd.target = torch.ops.aten.reshape.default
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Relevant for addmm and (add + mm)/(mm + add)
 | 
			
		||||
# Follows the dispatch logic for cuBLASLt at
 | 
			
		||||
# aten/src/ATen/native/cuda/Blas.cpp::isInputCompliesAddmmCudaLt
 | 
			
		||||
def _cublaslt_can_fuse_bias_epilogue(inp, mat1, mat2):
 | 
			
		||||
    if config.max_autotune_gemm:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    # match the dispatch logic for cuBLASLT at aten/src/ATen/native/cuda/Blas.cpp
 | 
			
		||||
    if not (
 | 
			
		||||
        inp.is_cuda
 | 
			
		||||
        and (inp.dim() == 1 or inp.squeeze().dim == 1)
 | 
			
		||||
        and inp.is_contiguous()
 | 
			
		||||
    ):
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    if not (mat1.dim() == 2 and mat2.dim() == 2):
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    if inp.size(0) != mat2.size(1):
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    if inp.dtype != mat1.dtype or inp.dtype != mat2.dtype:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def should_prefer_unfused_addmm(match):
 | 
			
		||||
    inp = match.kwargs["inp"]
 | 
			
		||||
    if not is_gpu(inp.meta["val"].device.type):
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    return has_uses_tagged_as(
 | 
			
		||||
        match.output_node(),
 | 
			
		||||
        (torch.Tag.pointwise, torch.Tag.reduction),
 | 
			
		||||
    )
 | 
			
		||||
    if has_uses_tagged_as(
 | 
			
		||||
        match.output_node(), (torch.Tag.pointwise, torch.Tag.reduction)
 | 
			
		||||
    ):
 | 
			
		||||
        return True
 | 
			
		||||
    else:
 | 
			
		||||
        args_val = (arg.meta["val"] for arg in (inp, *match.args))
 | 
			
		||||
        return not _cublaslt_can_fuse_bias_epilogue(*args_val)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@register_graph_pattern(
 | 
			
		||||
 | 
			
		||||
@ -1545,7 +1545,6 @@ scaling_pairs = [
 | 
			
		||||
    (ScalingType.TensorWise, ScalingType.TensorWise),
 | 
			
		||||
    (ScalingType.RowWise, ScalingType.RowWise),
 | 
			
		||||
    (ScalingType.BlockWise1x128, ScalingType.BlockWise128x128),
 | 
			
		||||
    (ScalingType.BlockWise1x128, ScalingType.BlockWise1x128),
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1564,15 +1563,11 @@ def _is_rowwise_scaling(sz: Any, transpose: bool) -> bool:
 | 
			
		||||
    return V.graph.sizevars.statically_known_equals(sz[idx], 1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _is_blockwise1xTILESIZE_scaling(
 | 
			
		||||
    sz: Any, tensor_sz: Any, tile_size: int, transpose: bool
 | 
			
		||||
) -> bool:
 | 
			
		||||
    lhs = 1 if transpose else 0
 | 
			
		||||
    rhs = 0 if transpose else 1
 | 
			
		||||
def _is_blockwise1xTILESIZE_scaling(sz: Any, tensor_sz: Any, tile_size: int) -> bool:
 | 
			
		||||
    return V.graph.sizevars.statically_known_equals(
 | 
			
		||||
        sz[lhs], tensor_sz[lhs]
 | 
			
		||||
        sz[0], tensor_sz[0]
 | 
			
		||||
    ) and V.graph.sizevars.statically_known_equals(
 | 
			
		||||
        sz[rhs], ceildiv(tensor_sz[rhs], tile_size)
 | 
			
		||||
        sz[1], ceildiv(tensor_sz[1], tile_size)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1594,9 +1589,7 @@ def is_desired_scaling(
 | 
			
		||||
        case ScalingType.RowWise:
 | 
			
		||||
            return _is_rowwise_scaling(scale_size, transpose)
 | 
			
		||||
        case ScalingType.BlockWise1x128:
 | 
			
		||||
            return _is_blockwise1xTILESIZE_scaling(
 | 
			
		||||
                scale_size, t.get_size(), 128, transpose
 | 
			
		||||
            )
 | 
			
		||||
            return _is_blockwise1xTILESIZE_scaling(scale_size, t.get_size(), 128)
 | 
			
		||||
        case ScalingType.BlockWise128x128:
 | 
			
		||||
            return _is_blockwise128x128_scaling(scale_size, t.get_size())
 | 
			
		||||
        case _:
 | 
			
		||||
 | 
			
		||||
@ -80,9 +80,9 @@ from .ir import (
 | 
			
		||||
from .utils import (
 | 
			
		||||
    ceildiv,
 | 
			
		||||
    decode_device,
 | 
			
		||||
    has_only_pointwise_uses,
 | 
			
		||||
    is_dynamic,
 | 
			
		||||
    is_gpu,
 | 
			
		||||
    is_pointwise_use,
 | 
			
		||||
    is_view,
 | 
			
		||||
    needs_fallback_due_to_atomic_add_limitations,
 | 
			
		||||
    pad_listlike,
 | 
			
		||||
@ -1850,10 +1850,7 @@ def cat(inputs, dim=0):
 | 
			
		||||
        (len(inputs) <= config.max_pointwise_cat_inputs)
 | 
			
		||||
        and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
 | 
			
		||||
    ):
 | 
			
		||||
        pointwise_uses = all(
 | 
			
		||||
            is_pointwise_use(use, additional_pointwise_ops)
 | 
			
		||||
            for use in V.current_node.users
 | 
			
		||||
        )
 | 
			
		||||
        pointwise_uses = has_only_pointwise_uses(V.current_node)
 | 
			
		||||
        # fuse in case we will be used in a pointwise node, and there are any inputs we
 | 
			
		||||
        # we can prevent materialization of.
 | 
			
		||||
        fuse_pointwise_use = (
 | 
			
		||||
 | 
			
		||||
@ -529,30 +529,6 @@ def is_view(op: torch._ops.OpOverload) -> bool:
    return any(a.alias_info is not None for a in op._schema.arguments)


def is_pointwise_use(
    use: Node,
    is_pointwise_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
) -> bool:
    """
    Do all uses of this op have torch.Tag.pointwise or return True for optional `is_pointwise_fn`

    Uses in views ops will follow the views uses
    """

    if use.op != "call_function":
        return False
    if not (
        isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
    ):
        return False

    target = cast(torch._ops.OpOverload, use.target)
    if target is operator.getitem or is_view(target):
        return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)

    return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)


class LogicalConnective(enum.Enum):
    OR = enum.auto()
    AND = enum.auto()

@ -562,6 +538,8 @@ def has_uses(
    target: Node,
    use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
    *,
    select_output: bool = False,
) -> bool:
    """
    Given a target, explore the uses of `target` by applying `use_selector_fn`

@ -585,6 +563,8 @@ def has_uses(
    use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type)

    def has_uses_impl(use: Node) -> bool:
        if select_output and use.op == "output":
            return True
        if use.op != "call_function":
            return False
        if not (

@ -603,17 +583,52 @@ def has_uses(
    return use_aggregate_fn(has_uses_impl(user) for user in target.users)


def has_only_uses(
    target: Node,
    use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
    *,
    select_output: bool = False,
) -> bool:
    return has_uses(
        target, use_selector_fn, LogicalConnective.AND, select_output=select_output
    )


def has_uses_tagged_as(
    target: Node,
    use_tags: Collection[torch.Tag],
    use_aggregate_type: LogicalConnective = LogicalConnective.OR,
    *,
    select_output: bool = False,
) -> bool:
    """
    Is there a use with given tags?
    """

    return has_uses(
        target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type
        target,
        lambda use: any(tag in use_tags for tag in use.tags),
        use_aggregate_type,
        select_output=select_output,
    )


def has_only_pointwise_uses(
    target: Node,
    *,
    select_output: bool = False,
) -> bool:
    """
    Do all uses of target have torch.Tag.pointwise?

    Uses in views ops will follow the views uses
    """

    return has_uses_tagged_as(
        target,
        use_tags=(torch.Tag.pointwise,),
        use_aggregate_type=LogicalConnective.AND,
        select_output=select_output,
    )


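The new select_output keyword lets callers treat a direct use by the graph output node as acceptable, which the removed is_pointwise_use could not express. A hedged sketch under the same assumptions as above (import path and make_fx behavior are assumptions, not shown in the diff):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.utils import has_only_pointwise_uses  # assumed import path

def g(x):
    return x + 1  # the add's only user is the graph output node

gm = make_fx(g)(torch.randn(4))
add_node = next(n for n in gm.graph.nodes if n.target == torch.ops.aten.add.Tensor)
print(has_only_pointwise_uses(add_node))                      # False: the output node is not a pointwise call
print(has_only_pointwise_uses(add_node, select_output=True))  # True once output uses are allowed
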
@ -6365,18 +6365,12 @@ def meta_scaled_mm(
        n = mat2.size(1)

        is_blockwise_scaling = (
            (
                scale_a.dtype == torch.float8_e8m0fnu
                and scale_b.dtype == torch.float8_e8m0fnu
            )
            or (
                scale_a.dtype == torch.float8_e4m3fn
                and scale_b.dtype == torch.float8_e4m3fn
            )
        )  # note: this applies to blockwise scaling for non-FP8 types (FP8 accepts FP32 scales)

        def ceil_div(a, b):
            return (a + b - 1) // b
            scale_a.dtype == torch.float8_e8m0fnu
            and scale_b.dtype == torch.float8_e8m0fnu
        ) or (
            scale_a.dtype == torch.float8_e4m3fn
            and scale_b.dtype == torch.float8_e4m3fn
        )

        if scale_a.numel() == 1 and scale_b.numel() == 1:
            # tensorwise scaling
@ -6398,6 +6392,9 @@ def meta_scaled_mm(

            block_size_mn = 128

            def ceil_div(a, b):
                return (a + b - 1) // b

            num_k_blocks = ceil_div(_k, block_size_k)
            padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4

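The local ceil_div that used to sit next to the dtype check now lives inside the blockwise branch, next to the padding arithmetic that actually needs it. A short worked sketch of that arithmetic with illustrative values only; block_size_k is defined outside the lines shown in this hunk, so the 32 below is an assumption for the example.

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

_k, block_size_k = 520, 32                           # illustrative sizes, not from the diff
num_k_blocks = ceil_div(_k, block_size_k)            # ceil(520 / 32)  -> 17
padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4  # round up to a multiple of 4 -> 20
print(num_k_blocks, padded_num_k_blocks)
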
@ -6453,18 +6450,11 @@ def meta_scaled_mm(
                )
            elif (
                scale_a.size(0) == m
                and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
                and scale_b.size(1) == ceil_div(n, 128)
                and scale_a.size(1) == scale_b.size(0) == (_k + 128 - 1) // 128
                and scale_b.size(1) == (n + 128 - 1) // 128
            ):
                # (BlockWise1x128, BlockWise128x128)
                pass  # do nothing, but do not error
            elif (
                scale_a.size(0) == m
                and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
                and scale_b.size(1) == n
            ):
                # (BlockWise1x128, BlockWise1x128)
                pass  # do nothing, but do not error
            else:
                # does not match any valid scaling type
                torch._check(
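With the outer ceil_div removed, the blockwise branches spell the ceiling division out inline as (x + 128 - 1) // 128. A worked sketch of the expected scale shapes under assumed sizes (m, k, n chosen here only for illustration):

m, _k, n = 256, 512, 1024  # illustrative problem sizes

def blocks(x: int) -> int:
    # inline ceiling division, matching "(x + 128 - 1) // 128" in the new checks
    return (x + 128 - 1) // 128

# (BlockWise1x128, BlockWise128x128): scale_a is (m, ceil(k/128)), scale_b is (ceil(k/128), ceil(n/128))
print((m, blocks(_k)), (blocks(_k), blocks(n)))  # (256, 4) (4, 8)

# (BlockWise1x128, BlockWise1x128): scale_a is (m, ceil(k/128)), scale_b is (ceil(k/128), n)
print((m, blocks(_k)), (blocks(_k), n))          # (256, 4) (4, 1024)
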
@ -6473,10 +6463,8 @@ def meta_scaled_mm(
                        "Invalid scaling configuration. "
                        "For tensorwise scaling, both scales should be scalar. "
                        f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). "
                        f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
                        + f"scale_b should be ({ceil_div(_k, 128)}, {ceil_div(n, 128)}). "
                        f"For (BlockWise1x128, BlockWise1x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
                        + f"scale_b should be ({ceil_div(_k, 128)}, {n}). "
                        f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({m}, {(_k + 128 - 1) // 128}), "
                        + f"scale_b should be ({(_k + 128 - 1) // 128}, {(n + 128 - 1) // 128}). "
                        f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) "
                        f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})"
                    ),