Revert "[CUDA][CUTLASS][submodule] Fixes for CUTLASS upgrade (#131493)"

This reverts commit 4aa66f68a803927ddd127ceaaa1521b8d6e90e5f.

Reverted https://github.com/pytorch/pytorch/pull/131493 on behalf of https://github.com/izaitsevfb due to breaking internal builds with the error: identifier "std::numeric_limits< ::cutlass::half_t> ::infinity" is undefined in device code ([comment](https://github.com/pytorch/pytorch/pull/131493#issuecomment-2293939390))
Author: PyTorch MergeBot
Date: 2024-08-16 18:09:33 +00:00
Parent: 4ee65c7e4e
Commit: b833990a8f
9 changed files with 19 additions and 50 deletions
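
For context, a minimal sketch of the pattern that broke and the workaround this revert restores. The sketch is illustrative and self-contained, not the exact PyTorch sources; in the real tree the specialization below is re-added by the third file in this commit, and the quoted error is for half_t, for which an analogous device-callable path is needed. The error means std::numeric_limits<T>::infinity() did not resolve to a device-callable function for CUTLASS element types in the internal build; the pre-#131493 code instead routed through a platform::numeric_limits specialization whose accessor is explicitly marked CUTLASS_HOST_DEVICE:

// Minimal sketch, assuming CUTLASS headers are on the include path. The
// primary template is declared here only to keep the sketch self-contained;
// the real code specializes a pre-existing platform::numeric_limits.
#include <cutlass/cutlass.h>   // CUTLASS_HOST_DEVICE, CUTLASS_DEVICE
#include <cutlass/bfloat16.h>  // cutlass::bfloat16_t

namespace platform {

template <typename T>
struct numeric_limits;

// bf16 +infinity: sign 0, exponent all ones, mantissa 0 => bit pattern 0x7f80.
// The CUTLASS_HOST_DEVICE annotation is what makes this callable from device
// code, which std::numeric_limits' members were not in the failing build.
template <>
struct numeric_limits<cutlass::bfloat16_t> {
  CUTLASS_HOST_DEVICE
  static cutlass::bfloat16_t infinity() {
    return cutlass::bfloat16_t::bitcast(0x7f80);
  }
};

}  // namespace platform

// Mirrors the outOfBoundsFillValue() call sites restored below: the
// -std::numeric_limits<T>::infinity() form is what triggered the
// "undefined in device code" error for CUTLASS element types.
template <typename T>
CUTLASS_DEVICE T outOfBoundsFillValue() {
  return -platform::numeric_limits<T>::infinity();
}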

View File

@@ -6,8 +6,6 @@
// Doesn't work on ROCm or Windows yet
// TODO: Add compiler warning? Add PyTorch config flag?
#else
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cutlass/cutlass.h>
#include <cutlass/tensor_ref.h>

View File

@@ -141,13 +141,13 @@ void f8f8bf16_rowwise_impl(
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
PONG ? 2 : 1,
TileShape,
ElementComputeEpilogue,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
PONG ? 2 : 1,
TileShape,
ElementBias,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
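
A hedged reading of the hunk above: the first template argument of Sm90RowBroadcast is its smem staging count, and this revert swaps the 0 that #131493 used for the upgraded CUTLASS (whose row broadcast no longer stages through smem) back to the pre-upgrade PONG ? 2 : 1 pipelining, where PONG selects the ping-pong kernel schedule in this file. Reassembled from the hunk, the restored alias reads:

// Restored form of the WScale alias (the Bias alias changes identically).
// That the first argument counts smem stages is an assumption from the
// CUTLASS 3.x Sm90RowBroadcast signature, not stated in the diff itself.
using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
    PONG ? 2 : 1,  // broadcast smem stages: 2 for ping-pong, 1 otherwise
    TileShape,
    ElementComputeEpilogue,
    cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;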

View File

@@ -9,6 +9,16 @@
// sparsification, as a bitmask.
// NOTE: Algorithms might select LESS than 8 values in total in some cases.
namespace platform {
template <>
struct numeric_limits<cutlass::bfloat16_t> {
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t infinity() {
return cutlass::bfloat16_t::bitcast(0x7f80);
}
};
} // namespace platform
namespace at::native{
template <typename Element, typename Pointwise>
@@ -58,7 +68,7 @@ template <typename Op = IdentityOp>
struct LargestValuesGreedy {
template <typename T>
static CUTLASS_DEVICE T outOfBoundsFillValue() {
return -std::numeric_limits<T>::infinity();
return -platform::numeric_limits<T>::infinity();
}
template <typename Tile4x4Accessor>
@@ -118,7 +128,7 @@ template <typename Op = IdentityOp>
struct Causal1122 {
template <typename T>
static CUTLASS_DEVICE T outOfBoundsFillValue() {
return -std::numeric_limits<T>::infinity();
return -platform::numeric_limits<T>::infinity();
}
template <typename Tile4x4Accessor>

View File

@@ -44,7 +44,6 @@ template <
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename Operator,
typename LayoutInputA,
typename LayoutInputB,
bool use_bias,
@@ -63,6 +62,7 @@ Tensor two_four_sgemm(
using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across wide range of operand sizes.
constexpr int NumStages = 3; // This choice provides good performance across wide range of operand sizes.
using Operator = cutlass::arch::OpMultiplyAdd;
constexpr int NumEVTEpilogueStages = 1;
constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
@@ -317,7 +317,6 @@ template <
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename Operator,
bool EnableRowMajorRowMajorLayouts,
bool EnableRowMajorColumnMajorLayouts,
bool EnableColumnMajorRowMajorLayouts,
@@ -346,7 +345,6 @@ Tensor two_four_sgemm_dispatch_layouts(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
use_bias,
@@ -369,7 +367,6 @@ Tensor two_four_sgemm_dispatch_layouts(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
use_bias,
@@ -392,7 +389,6 @@ Tensor two_four_sgemm_dispatch_layouts(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
use_bias,
@@ -415,7 +411,6 @@ Tensor two_four_sgemm_dispatch_layouts(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
use_bias,
@@ -445,7 +440,6 @@ template <
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename Operator,
bool EnableRowMajorRowMajorLayouts,
bool EnableRowMajorColumnMajorLayouts,
bool EnableColumnMajorRowMajorLayouts,
@@ -463,7 +457,6 @@ Tensor two_four_sgemm_dispatch_layouts_bias(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -483,7 +476,6 @@ Tensor two_four_sgemm_dispatch_layouts_bias(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -506,7 +498,6 @@ template <
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename Operator,
bool EnableRowMajorRowMajorLayouts,
bool EnableRowMajorColumnMajorLayouts,
bool EnableColumnMajorRowMajorLayouts,
@@ -528,7 +519,6 @@ Tensor two_four_sgemm_dispatch_layouts_bias_activation(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -550,7 +540,6 @@ Tensor two_four_sgemm_dispatch_layouts_bias_activation(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -572,7 +561,6 @@ Tensor two_four_sgemm_dispatch_layouts_bias_activation(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -729,7 +717,6 @@ Tensor _sparse_semi_structured_linear(
cutlass::gemm::GemmShape<128, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
using Operator = cutlass::arch::OpMultiplyAddSaturate;
const auto EnableRowMajorRowMajorLayouts = false;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = false;
@@ -747,7 +734,6 @@ Tensor _sparse_semi_structured_linear(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -770,7 +756,6 @@ Tensor _sparse_semi_structured_linear(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -796,7 +781,6 @@ Tensor _sparse_semi_structured_linear(
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using Operator = cutlass::arch::OpMultiplyAdd;
const auto EnableRowMajorRowMajorLayouts = true;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = true;
@@ -812,7 +796,6 @@ Tensor _sparse_semi_structured_linear(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -837,7 +820,6 @@ Tensor _sparse_semi_structured_linear(
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using Operator = cutlass::arch::OpMultiplyAdd;
const auto EnableRowMajorRowMajorLayouts = true;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = true;
@@ -853,7 +835,6 @@ Tensor _sparse_semi_structured_linear(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -878,7 +859,6 @@ Tensor _sparse_semi_structured_linear(
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using Operator = cutlass::arch::OpMultiplyAdd;
const auto EnableRowMajorRowMajorLayouts = true;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = true;
@@ -894,7 +874,6 @@ Tensor _sparse_semi_structured_linear(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
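
The hunks in this file and in the mirror file below are one mechanical change repeated at every dispatch and call site. As a reduced sketch (function names here are hypothetical, not the real two_four_sgemm/spgemm_cutlass declarations): #131493 threaded the MMA operator through a typename Operator template parameter so the int8 branch could pass OpMultiplyAddSaturate, and the revert deletes that parameter and pins a local alias back inside the GEMM template:

// Reduced sketch of the two styles this revert toggles between.
#include <cstdint>
#include <cutlass/arch/mma.h>  // cutlass::arch::OpMultiplyAdd / OpMultiplyAddSaturate

// Style removed by the revert (introduced in #131493): the operator is a
// template parameter, chosen per dtype at the dispatch site.
template <typename ElementInputA, typename Operator>
void gemm_with_operator_param() {
  // ... Operator flows into the cutlass::gemm kernel configuration ...
}

// Style restored by the revert: the operator is fixed inside the template,
// and the per-dtype `using Operator = ...;` lines at the call sites go away.
template <typename ElementInputA>
void gemm_with_local_operator() {
  using Operator = cutlass::arch::OpMultiplyAdd;
  // ... kernel configuration uses Operator ...
}

int main() {
  // The int8 branch passed the saturating operator; other dtypes passed
  // plain multiply-add (mirroring the `using Operator` lines in the hunks).
  gemm_with_operator_param<std::int8_t, cutlass::arch::OpMultiplyAddSaturate>();
  gemm_with_local_operator<float>();
  return 0;
}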

View File

@@ -41,7 +41,6 @@ template <
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename Operator,
typename LayoutInputA,
typename LayoutInputB,
bool use_tensor_c>
@@ -58,6 +57,7 @@ void spgemm_cutlass(
using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across wide range of operand sizes.
constexpr int NumStages = 3; // This choice provides good performance across wide range of operand sizes.
using Operator = cutlass::arch::OpMultiplyAdd;
constexpr int NumEVTEpilogueStages = 1;
constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
@@ -305,7 +305,6 @@ template <
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename Operator,
bool EnableRowMajorRowMajorLayouts,
bool EnableRowMajorColumnMajorLayouts,
bool EnableColumnMajorRowMajorLayouts,
@@ -334,7 +333,6 @@ void spgemm_cutlass_dispatch_layouts(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
use_tensor_c>(
@@ -360,7 +358,6 @@ void spgemm_cutlass_dispatch_layouts(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
use_tensor_c>(
@@ -386,7 +383,6 @@ void spgemm_cutlass_dispatch_layouts(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
use_tensor_c>(
@@ -412,7 +408,6 @@ void spgemm_cutlass_dispatch_layouts(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
use_tensor_c>(
@@ -444,7 +439,6 @@ template <
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename Operator,
bool EnableRowMajorRowMajorLayouts,
bool EnableRowMajorColumnMajorLayouts,
bool EnableColumnMajorRowMajorLayouts,
@@ -462,7 +456,6 @@ void spgemm_cutlass_dispatch_layouts_tensor_c(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -484,7 +477,6 @@ void spgemm_cutlass_dispatch_layouts_tensor_c(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -637,7 +629,6 @@ Tensor sparse_semi_structured_mad_op(
cutlass::gemm::GemmShape<128, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
using Operator = cutlass::arch::OpMultiplyAddSaturate;
const auto EnableRowMajorRowMajorLayouts = false;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = false;
@@ -652,7 +643,6 @@ Tensor sparse_semi_structured_mad_op(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -674,7 +664,6 @@ Tensor sparse_semi_structured_mad_op(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -698,7 +687,6 @@ Tensor sparse_semi_structured_mad_op(
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using Operator = cutlass::arch::OpMultiplyAdd;
const auto EnableRowMajorRowMajorLayouts = true;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = true;
@@ -711,7 +699,6 @@ Tensor sparse_semi_structured_mad_op(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -734,7 +721,6 @@ Tensor sparse_semi_structured_mad_op(
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
using Operator = cutlass::arch::OpMultiplyAdd;
const auto EnableRowMajorRowMajorLayouts = true;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = true;
@@ -747,7 +733,6 @@ Tensor sparse_semi_structured_mad_op(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
@@ -770,7 +755,6 @@ Tensor sparse_semi_structured_mad_op(
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using Operator = cutlass::arch::OpMultiplyAdd;
const auto EnableRowMajorRowMajorLayouts = true;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = true;
@@ -783,7 +767,6 @@ Tensor sparse_semi_structured_mad_op(
ThreadblockShape,
WarpShape,
InstructionShape,
Operator,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,

View File

@@ -4,7 +4,7 @@
#pragma once
#include <cute/tensor.hpp>
#include <cute/algorithm/copy.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>

View File

@@ -4,7 +4,7 @@
#pragma once
#include <cute/tensor.hpp>
#include <cute/algorithm/copy.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>

View File

@@ -125,7 +125,6 @@ class CUTLASSArgs:
generator_target = ""
kernels = "all"
ignore_kernels = ""
exclude_kernels = ""
# TODO: these three look dead?
kernel_filter_file: None = None
selected_kernel_list: None = None