Revert "[CUDA][CUTLASS][submodule] Fixes for CUTLASS upgrade (#131493)"
This reverts commit 4aa66f68a803927ddd127ceaaa1521b8d6e90e5f. Reverted https://github.com/pytorch/pytorch/pull/131493 on behalf of https://github.com/izaitsevfb because it breaks internal builds: the identifier "std::numeric_limits< ::cutlass::half_t>::infinity" is undefined in device code ([comment](https://github.com/pytorch/pytorch/pull/131493#issuecomment-2293939390))
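For context, a minimal sketch of the failure mode behind this revert (an illustration, not code from the tree; it assumes CUTLASS's `cutlass/half.h`, whose `platform::numeric_limits` specializations are marked `CUTLASS_HOST_DEVICE`): `std::numeric_limits` has no device-callable specialization for `cutlass::half_t`, so NVCC rejects the call when it appears in device code, while the CUTLASS shim compiles:

```cpp
// Sketch only -- reproduces the reported error, then the working spelling.
#include <cutlass/half.h>  // cutlass::half_t + platform::numeric_limits shim
#include <limits>

__global__ void fill_inf(cutlass::half_t* out) {
  // Fails in device code (the reason for this revert):
  //   error: identifier "std::numeric_limits< ::cutlass::half_t>::infinity"
  //   is undefined in device code
  // *out = std::numeric_limits<cutlass::half_t>::infinity();

  // Works: the CUTLASS shim marks infinity() CUTLASS_HOST_DEVICE.
  *out = cutlass::platform::numeric_limits<cutlass::half_t>::infinity();
}
```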
@@ -6,8 +6,6 @@
 // Doesn't work on ROCm or Windows yet
 // TODO: Add compiler warning? Add PyTorch config flag?
 #else
 #include <cuda_fp16.h>
-
-#include <cuda_runtime.h>
 #include <cutlass/cutlass.h>
 #include <cutlass/tensor_ref.h>

@@ -141,13 +141,13 @@ void f8f8bf16_rowwise_impl(
       cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

   using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      0,
+      PONG ? 2 : 1,
       TileShape,
       ElementComputeEpilogue,
       cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

   using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      0,
+      PONG ? 2 : 1,
       TileShape,
       ElementBias,
       cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
@@ -9,6 +9,16 @@
 // sparsification, as a bitmask.
 // NOTE: Algorithms might select LESS than 8 values in total in some cases.

+namespace platform {
+template <>
+struct numeric_limits<cutlass::bfloat16_t> {
+  CUTLASS_HOST_DEVICE
+  static cutlass::bfloat16_t infinity() {
+    return cutlass::bfloat16_t::bitcast(0x7f80);
+  }
+};
+} // namespace platform
+
 namespace at::native{

 template <typename Element, typename Pointwise>
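As a quick sanity check on the `bitcast(0x7f80)` constant added above (a standalone sketch, independent of this diff): bfloat16 is a float32 truncated to its upper 16 bits, so its +infinity pattern must equal the top half of the float32 +infinity pattern `0x7f800000`:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // bfloat16 layout: 1 sign bit, 8 exponent bits, 7 mantissa bits.
  // +infinity = sign 0, exponent all ones, mantissa 0 -> 0b0'11111111'0000000.
  const uint16_t bf16_inf = 0x7f80;
  const uint32_t f32_inf = 0x7f800000u;  // IEEE-754 float32 +infinity
  assert(bf16_inf == static_cast<uint16_t>(f32_inf >> 16));
  return 0;
}
```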
@@ -58,7 +68,7 @@ template <typename Op = IdentityOp>
 struct LargestValuesGreedy {
   template <typename T>
   static CUTLASS_DEVICE T outOfBoundsFillValue() {
-    return -std::numeric_limits<T>::infinity();
+    return -platform::numeric_limits<T>::infinity();
   }

   template <typename Tile4x4Accessor>

@@ -118,7 +128,7 @@ template <typename Op = IdentityOp>
 struct Causal1122 {
   template <typename T>
   static CUTLASS_DEVICE T outOfBoundsFillValue() {
-    return -std::numeric_limits<T>::infinity();
+    return -platform::numeric_limits<T>::infinity();
   }

   template <typename Tile4x4Accessor>
@@ -44,7 +44,6 @@ template <
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
-    typename Operator,
     typename LayoutInputA,
     typename LayoutInputB,
     bool use_bias,

@@ -63,6 +62,7 @@ Tensor two_four_sgemm(
     using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
     using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across wide range of operand sizes.
     constexpr int NumStages = 3; // This choice provides good performance across wide range of operand sizes.
+    using Operator = cutlass::arch::OpMultiplyAdd;
     constexpr int NumEVTEpilogueStages = 1;

     constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;

@@ -317,7 +317,6 @@ template <
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
-   typename Operator,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,

@@ -346,7 +345,6 @@ Tensor two_four_sgemm_dispatch_layouts(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        cutlass::layout::RowMajor,
        cutlass::layout::RowMajor,
        use_bias,

@@ -369,7 +367,6 @@ Tensor two_four_sgemm_dispatch_layouts(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        cutlass::layout::RowMajor,
        cutlass::layout::ColumnMajor,
        use_bias,

@@ -392,7 +389,6 @@ Tensor two_four_sgemm_dispatch_layouts(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        cutlass::layout::ColumnMajor,
        cutlass::layout::RowMajor,
        use_bias,

@@ -415,7 +411,6 @@ Tensor two_four_sgemm_dispatch_layouts(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        cutlass::layout::ColumnMajor,
        cutlass::layout::ColumnMajor,
        use_bias,

@@ -445,7 +440,6 @@ template <
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
-   typename Operator,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,

@@ -463,7 +457,6 @@ Tensor two_four_sgemm_dispatch_layouts_bias(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -483,7 +476,6 @@ Tensor two_four_sgemm_dispatch_layouts_bias(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -506,7 +498,6 @@ template <
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
-   typename Operator,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,

@@ -528,7 +519,6 @@ Tensor two_four_sgemm_dispatch_layouts_bias_activation(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -550,7 +540,6 @@ Tensor two_four_sgemm_dispatch_layouts_bias_activation(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -572,7 +561,6 @@ Tensor two_four_sgemm_dispatch_layouts_bias_activation(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -729,7 +717,6 @@ Tensor _sparse_semi_structured_linear(
        cutlass::gemm::GemmShape<128, 128, 128>;
    using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
-   using Operator = cutlass::arch::OpMultiplyAddSaturate;
    const auto EnableRowMajorRowMajorLayouts = false;
    const auto EnableRowMajorColumnMajorLayouts = true;
    const auto EnableColumnMajorRowMajorLayouts = false;

@@ -747,7 +734,6 @@ Tensor _sparse_semi_structured_linear(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -770,7 +756,6 @@ Tensor _sparse_semi_structured_linear(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -796,7 +781,6 @@ Tensor _sparse_semi_structured_linear(
    using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
    using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
-   using Operator = cutlass::arch::OpMultiplyAdd;
    const auto EnableRowMajorRowMajorLayouts = true;
    const auto EnableRowMajorColumnMajorLayouts = true;
    const auto EnableColumnMajorRowMajorLayouts = true;

@@ -812,7 +796,6 @@ Tensor _sparse_semi_structured_linear(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -837,7 +820,6 @@ Tensor _sparse_semi_structured_linear(
    using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
    using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
-   using Operator = cutlass::arch::OpMultiplyAdd;
    const auto EnableRowMajorRowMajorLayouts = true;
    const auto EnableRowMajorColumnMajorLayouts = true;
    const auto EnableColumnMajorRowMajorLayouts = true;

@@ -853,7 +835,6 @@ Tensor _sparse_semi_structured_linear(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -878,7 +859,6 @@ Tensor _sparse_semi_structured_linear(
    using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
    using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
-   using Operator = cutlass::arch::OpMultiplyAdd;
    const auto EnableRowMajorRowMajorLayouts = true;
    const auto EnableRowMajorColumnMajorLayouts = true;
    const auto EnableColumnMajorRowMajorLayouts = true;

@@ -894,7 +874,6 @@ Tensor _sparse_semi_structured_linear(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,
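The bulk of the hunks above (and the same pattern in the next file) undo one mechanical refactor. A hypothetical reduction of it, with the surrounding GEMM machinery omitted (names follow the diff; this is not the literal PyTorch code): #131493 had turned the MMA operator into a template parameter supplied by each caller, and the revert pins it back inside the kernel template:

```cpp
#include <cutlass/arch/mma.h>  // cutlass::arch::OpMultiplyAdd et al.

// After the revert: the op is fixed where the GEMM is instantiated.
template <typename ElementInputA /*, ... */>
void two_four_sgemm_like() {
  using Operator = cutlass::arch::OpMultiplyAdd;
  // ... build the CUTLASS sparse GEMM with Operator ...
}

// What #131493 had introduced: callers choose the op (e.g.
// OpMultiplyAddSaturate on the int8 path) and pass it down.
template <typename Operator /*, ... */>
void two_four_sgemm_like_parameterized() {
  // ... build the CUTLASS sparse GEMM with Operator ...
}
```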
@@ -41,7 +41,6 @@ template <
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
-   typename Operator,
    typename LayoutInputA,
    typename LayoutInputB,
    bool use_tensor_c>

@@ -58,6 +57,7 @@ void spgemm_cutlass(
    using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
    using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across wide range of operand sizes.
    constexpr int NumStages = 3; // This choice provides good performance across wide range of operand sizes.
+   using Operator = cutlass::arch::OpMultiplyAdd;
    constexpr int NumEVTEpilogueStages = 1;

    constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;

@@ -305,7 +305,6 @@ template <
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
-   typename Operator,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,

@@ -334,7 +333,6 @@ void spgemm_cutlass_dispatch_layouts(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        cutlass::layout::RowMajor,
        cutlass::layout::RowMajor,
        use_tensor_c>(

@@ -360,7 +358,6 @@ void spgemm_cutlass_dispatch_layouts(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        cutlass::layout::RowMajor,
        cutlass::layout::ColumnMajor,
        use_tensor_c>(

@@ -386,7 +383,6 @@ void spgemm_cutlass_dispatch_layouts(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        cutlass::layout::ColumnMajor,
        cutlass::layout::RowMajor,
        use_tensor_c>(

@@ -412,7 +408,6 @@ void spgemm_cutlass_dispatch_layouts(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        cutlass::layout::ColumnMajor,
        cutlass::layout::ColumnMajor,
        use_tensor_c>(

@@ -444,7 +439,6 @@ template <
    typename ThreadblockShape,
    typename WarpShape,
    typename InstructionShape,
-   typename Operator,
    bool EnableRowMajorRowMajorLayouts,
    bool EnableRowMajorColumnMajorLayouts,
    bool EnableColumnMajorRowMajorLayouts,

@@ -462,7 +456,6 @@ void spgemm_cutlass_dispatch_layouts_tensor_c(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -484,7 +477,6 @@ void spgemm_cutlass_dispatch_layouts_tensor_c(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -637,7 +629,6 @@ Tensor sparse_semi_structured_mad_op(
        cutlass::gemm::GemmShape<128, 128, 128>;
    using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
-   using Operator = cutlass::arch::OpMultiplyAddSaturate;
    const auto EnableRowMajorRowMajorLayouts = false;
    const auto EnableRowMajorColumnMajorLayouts = true;
    const auto EnableColumnMajorRowMajorLayouts = false;

@@ -652,7 +643,6 @@ Tensor sparse_semi_structured_mad_op(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -674,7 +664,6 @@ Tensor sparse_semi_structured_mad_op(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -698,7 +687,6 @@ Tensor sparse_semi_structured_mad_op(
    using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
    using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
-   using Operator = cutlass::arch::OpMultiplyAdd;
    const auto EnableRowMajorRowMajorLayouts = true;
    const auto EnableRowMajorColumnMajorLayouts = true;
    const auto EnableColumnMajorRowMajorLayouts = true;

@@ -711,7 +699,6 @@ Tensor sparse_semi_structured_mad_op(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -734,7 +721,6 @@ Tensor sparse_semi_structured_mad_op(
    using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
    using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
-   using Operator = cutlass::arch::OpMultiplyAdd;
    const auto EnableRowMajorRowMajorLayouts = true;
    const auto EnableRowMajorColumnMajorLayouts = true;
    const auto EnableColumnMajorRowMajorLayouts = true;

@@ -747,7 +733,6 @@ Tensor sparse_semi_structured_mad_op(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,

@@ -770,7 +755,6 @@ Tensor sparse_semi_structured_mad_op(
    using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
    using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
-   using Operator = cutlass::arch::OpMultiplyAdd;
    const auto EnableRowMajorRowMajorLayouts = true;
    const auto EnableRowMajorColumnMajorLayouts = true;
    const auto EnableColumnMajorRowMajorLayouts = true;

@@ -783,7 +767,6 @@ Tensor sparse_semi_structured_mad_op(
        ThreadblockShape,
        WarpShape,
        InstructionShape,
-       Operator,
        EnableRowMajorRowMajorLayouts,
        EnableRowMajorColumnMajorLayouts,
        EnableColumnMajorRowMajorLayouts,
@@ -4,7 +4,7 @@

 #pragma once

-#include <cute/tensor.hpp>
+#include <cute/algorithm/copy.hpp>

 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>

@@ -4,7 +4,7 @@

 #pragma once

-#include <cute/tensor.hpp>
+#include <cute/algorithm/copy.hpp>

 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
Submodule third_party/cutlass updated: fb170439e8...bbe579a9e3
@@ -125,7 +125,6 @@ class CUTLASSArgs:
     generator_target = ""
     kernels = "all"
     ignore_kernels = ""
-    exclude_kernels = ""
     # TODO: these three look dead?
     kernel_filter_file: None = None
     selected_kernel_list: None = None