[AsyncMM] re-enable and adapt to cutlass 3.6.0 (#144011)

[D68734067](https://our.internmc.facebook.com/intern/diff/D68734067)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144011
Approved by: https://github.com/Skylion007, https://github.com/drisspg
Yifu Wang
2025-01-24 17:18:36 -08:00
committed by PyTorch MergeBot
parent 1e3d1738a4
commit c70362fac8
2 changed files with 260 additions and 117 deletions

View File

@@ -5,44 +5,17 @@
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDAGuard.h>
#if false && !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \
// Two warnings in included Cutlass header files
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12000
#define BUILD_ASYNC_MM_KERNEL
#endif
#if defined(BUILD_ASYNC_MM_KERNEL)
// We are going to override the cuTensorMapEncodeTiled driver api with our lazy
// loader
static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
CUtensorMap* tensorMap,
CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank,
void* globalAddress,
const cuuint64_t* globalDim,
const cuuint64_t* globalStrides,
const cuuint32_t* boxDim,
const cuuint32_t* elementStrides,
CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion,
CUtensorMapFloatOOBfill oobFill) {
return at::globalContext().getNVRTC().cuTensorMapEncodeTiled(
tensorMap,
tensorDataType,
tensorRank,
globalAddress,
globalDim,
globalStrides,
boxDim,
elementStrides,
interleave,
swizzle,
l2Promotion,
oobFill);
}
// clang-format off
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
@@ -50,13 +23,9 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
#include <cutlass/numeric_types.h>
#include <cutlass/trace.h>
#include <cutlass/util/host_tensor.h>
// Rename the global function symbol
#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled
#include <cute/tensor.hpp>
#undef cuTensorMapEncodeTiled
// Set everything back to normal
#include <cutlass/version.h>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/epilogue/collective/collective_builder.hpp>
@@ -65,10 +34,12 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
// clang-format on
#include <torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh>
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()
namespace {
using namespace cute;
@@ -107,7 +78,7 @@ at::Tensor async_input_mm_impl(
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
void,
ElementC,
LayoutC,
AlignmentC,
ElementC,
@@ -133,7 +104,7 @@ at::Tensor async_input_mm_impl(
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>,
Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue,
cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
@@ -171,7 +142,7 @@ at::Tensor async_input_mm_impl(
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K, 1},
{M, N, K},
{
reinterpret_cast<ElementA*>(a.data_ptr<at::BFloat16>()),
stride_A,
@@ -179,7 +150,7 @@ at::Tensor async_input_mm_impl(
stride_B,
},
{{1, 1},
nullptr,
reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),
stride_C,
reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),
stride_C},
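Taken together, the hunks above converge on a CUTLASS 3.6-style instantiation along the lines of the condensed sketch below. It is a hedged reconstruction for orientation only: the element types, layouts, alignments, tile/cluster shapes, and schedules are assumptions (the includes are the ones listed in the first hunk), while the rank-3 problem shape, the ElementC epilogue source (previously void), and the real C pointer (previously nullptr) are the changes actually visible in this diff.

    // Condensed sketch -- assumed shapes, layouts, and schedules; not copied from the full source.
    using ElementA = cutlass::bfloat16_t;
    using ElementB = cutlass::bfloat16_t;
    using ElementC = cutlass::bfloat16_t;
    using ElementAccumulator = float;
    using LayoutA = cutlass::layout::RowMajor;      // assumed
    using LayoutB = cutlass::layout::ColumnMajor;   // assumed
    using LayoutC = cutlass::layout::RowMajor;      // assumed
    constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
    constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
    constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
    using TileShape        = cute::Shape<cute::_128, cute::_256, cute::_64>;      // assumed
    using ClusterShape     = cute::Shape<cute::_2, cute::_1, cute::_1>;           // assumed
    using KernelSchedule   = cutlass::gemm::KernelTmaWarpSpecializedCooperative;  // assumed
    using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;    // assumed

    using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
        TileShape, ClusterShape,
        cutlass::epilogue::collective::EpilogueTileAuto,
        ElementAccumulator, ElementAccumulator,
        ElementC, LayoutC, AlignmentC,   // CUTLASS 3.6: a real source element/layout instead of void
        ElementC, LayoutC, AlignmentC,
        EpilogueSchedule>::CollectiveOp;

    using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
        cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
        ElementA, LayoutA, AlignmentA,
        ElementB, LayoutB, AlignmentB,
        ElementAccumulator,
        TileShape, ClusterShape,
        cutlass::gemm::collective::StageCountAutoCarveout<
            static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
        KernelSchedule>::CollectiveOp;

    using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
        cute::Shape<int, int, int>,      // CUTLASS 3.6: rank-3 problem shape; the batch dim is gone
        CollectiveMainloop,
        CollectiveEpilogue,
        cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
    using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

    // The argument pack then follows the last hunk: mode kGemm, {M, N, K}, the A/B pointers and
    // strides, and an epilogue whose source is the output tensor rather than nullptr; hw_info and
    // the scheduler arguments (elided from the diff) come after.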

View File

@@ -1,6 +1,6 @@
/**
* This file contains PersistentAsyncInputScheduler, a forked version of PersistentScheduler that
* supports consuming asynchronous input. This tile scheduler introduces the following arguments:
* This file contains PersistentTileSchedulerSm90AsyncInput, a forked version of PersistentTileSchedulerSm90
* that supports consuming asynchronous input. This tile scheduler introduces the following arguments:
*
* - tiles_per_chunk_m Specifies the size of an M chunk. Chunks are the granularity at which the
* asynchronous input becomes ready. It must be an integer multiple of the size of an M tile.
@@ -22,6 +22,12 @@
* CollectiveMainloop,
* CollectiveEpilogue,
* cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
*
* Unfortunately, the CRTP base class for tile schedulers (StaticPersistentTileScheduler) doesn't
* provide enough flexibility for the required customization. We had to create a new tile scheduler
* by copying PersistentTileSchedulerSm90 and StaticPersistentTileScheduler and customizing on top of
* them. In PersistentTileSchedulerSm90AsyncInput, we marked the customizations with "CUSTOM LOGIC BEGIN"
* and "CUSTOM LOGIC END" comment blocks.
*/
#pragma once
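The async-input arguments introduced by this scheduler (tiles_per_chunk_m, tile_idx_pivot_m, chunk_signals; see the Arguments struct further down) are related by simple chunk/tile arithmetic. A small host-side sketch, with an illustrative helper name and the assumption that A arrives in equal row-chunks:

    #include <cassert>
    #include <cstdint>

    // Hypothetical helper, not part of this PR. `SchedulerArgs` is expected to be
    // PersistentTileSchedulerSm90AsyncInput::Arguments as defined below.
    template <class SchedulerArgs>
    SchedulerArgs make_async_scheduler_args(
        int M,                    // rows of the A operand / the output
        int tile_m,               // CTA tile size in M, e.g. 128 (assumed)
        int num_chunks,           // number of asynchronously produced A chunks
        int start_chunk,          // chunk whose tiles should be consumed first
        uint32_t* chunk_signals)  // one signal word per chunk, raised by the producer
    {
      int chunk_rows = M / num_chunks;
      assert(chunk_rows % tile_m == 0);    // an M chunk must be an integer multiple of an M tile
      SchedulerArgs args{};                // defaults: Heuristic raster order, swizzle size 1
      args.tiles_per_chunk_m = chunk_rows / tile_m;
      args.tile_idx_pivot_m  = start_chunk * args.tiles_per_chunk_m;
      args.chunk_signals     = chunk_signals;
      return args;
    }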
@@ -68,6 +74,7 @@ namespace cutlass::gemm::kernel::detail {
////////////////////////////////////////////////////////////////////////////////
class PersistentTileSchedulerSm90AsyncInputParams :
public PersistentTileSchedulerSm90Params {
public:
@@ -80,34 +87,92 @@ class PersistentTileSchedulerSm90AsyncInput {
private:
uint64_t current_work_linear_idx_;
uint64_t total_grid_size_;
// ==============================
// CUSTOM LOGIC BEGIN
// ==============================
bool is_mainloop_producer_;
// ==============================
// CUSTOM LOGIC END
// ==============================
public:
using WorkTileInfo = PersistentTileSchedulerSm90::WorkTileInfo;
struct WorkTileInfo {
int32_t M_idx = 0;
int32_t N_idx = 0;
int32_t L_idx = 0;
bool is_valid_tile = false;
CUTLASS_HOST_DEVICE
bool
is_valid() const {
return is_valid_tile;
}
CUTLASS_HOST_DEVICE
static WorkTileInfo
invalid_work_tile() {
return {-1, -1, -1, false};
}
CUTLASS_HOST_DEVICE
bool
is_final_split(uint32_t k_tiles_per_output_tile) const {
return true;
}
CUTLASS_HOST_DEVICE
int32_t
reduction_subtile_idx() const {
return -1;
}
};
// ==============================
// CUSTOM LOGIC BEGIN
// ==============================
using Params = PersistentTileSchedulerSm90AsyncInputParams;
// ==============================
// CUSTOM LOGIC END
// ==============================
using RasterOrder = typename Params::RasterOrder;
using RasterOrderOptions = typename Params::RasterOrderOptions;
static constexpr bool IsDynamicPersistent = false;
public:
// ==============================
// CUSTOM LOGIC BEGIN
// ==============================
struct Arguments {
int max_swizzle_size = 1;
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic;
int max_swizzle_size;
RasterOrderOptions raster_order;
// Async input specific
int tile_idx_pivot_m = 0;
int tiles_per_chunk_m = 0;
uint32_t* chunk_signals = nullptr;
int tile_idx_pivot_m;
int tiles_per_chunk_m;
uint32_t* chunk_signals;
Arguments():
max_swizzle_size(1),
raster_order(RasterOrderOptions::Heuristic),
tile_idx_pivot_m(0),
tiles_per_chunk_m(0),
chunk_signals(nullptr) {}
// ==============================
// CUSTOM LOGIC END
// ==============================
};
template <class ProblemShapeMNKL, class TileShape, class ClusterShape>
static Params
to_underlying_arguments(
ProblemShapeMNKL problem_shape_mnkl,
TileShape tile_shape,
ClusterShape cluster_shape,
[[maybe_unused]] KernelHardwareInfo const& hw_info,
Arguments const& arguments,
[[maybe_unused]] void* workspace=nullptr,
[[maybe_unused]] const uint32_t epilogue_subtile = 1) {
ProblemShapeMNKL problem_shape_mnkl,
TileShape tile_shape,
ClusterShape cluster_shape,
[[maybe_unused]] KernelHardwareInfo const& hw_info,
Arguments const& arguments,
[[maybe_unused]] void* workspace=nullptr,
[[maybe_unused]] const uint32_t epilogue_subtile = 1,
[[maybe_unused]] uint32_t ktile_start_alignment_count = 1u) {
// We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic
static_assert(cute::is_static<TileShape>::value);
@@ -123,9 +188,16 @@ public:
arguments.max_swizzle_size,
arguments.raster_order
);
// ==============================
// CUSTOM LOGIC BEGIN
// ==============================
params.tile_idx_pivot_m = arguments.tile_idx_pivot_m;
params.tiles_per_chunk_m = arguments.tiles_per_chunk_m;
params.chunk_signals = arguments.chunk_signals;
// ==============================
// CUSTOM LOGIC END
// ==============================
return params;
}
@@ -133,13 +205,13 @@ public:
CUTLASS_HOST_DEVICE
static bool
can_implement(Arguments const& args) {
return args.raster_order == RasterOrderOptions::AlongN;
return args.max_swizzle_size >= 1;
}
CUTLASS_HOST_DEVICE
PersistentTileSchedulerSm90AsyncInput() { }
CUTLASS_DEVICE explicit PersistentTileSchedulerSm90AsyncInput(Params const& params_) : params(params_) {
CUTLASS_DEVICE explicit PersistentTileSchedulerSm90AsyncInput(Params const& params_) : scheduler_params(params_) {
// MSVC requires protecting use of CUDA-specific nonstandard syntax,
// like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
@@ -150,11 +222,16 @@ public:
current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y);
}
total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);
// ==============================
// CUSTOM LOGIC BEGIN
// ==============================
int warp_group_role = canonical_warp_group_idx();
int producer_warp_group_role = canonical_warp_idx_sync() % NumWarpsPerWarpGroup;
is_mainloop_producer_ = warp_group_role == 0 && producer_warp_group_role == 0;
total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);
// ==============================
// CUSTOM LOGIC END
// ==============================
#else
CUTLASS_ASSERT(false && "This line should never be reached");
#endif
@@ -177,21 +254,24 @@ public:
CUTLASS_DEVICE
WorkTileInfo
get_current_work_for_linear_idx(uint64_t linear_idx) const {
if (linear_idx >= params.blocks_per_problem_) {
if (linear_idx >= scheduler_params.blocks_per_problem_) {
return WorkTileInfo::invalid_work_tile();
}
// Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices
uint64_t work_idx_l, remainder;
params.divmod_batch_(work_idx_l, remainder, linear_idx);
scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx);
uint64_t blk_per_grid_dim = params.divmod_cluster_shape_minor_.divide(remainder);
uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(remainder);
// ==============================
// CUSTOM LOGIC BEGIN
// ==============================
uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0;
params.divmod_cluster_shape_major_(cluster_id, cluster_major_offset, blk_per_grid_dim);
scheduler_params.divmod_cluster_shape_major_(cluster_id, cluster_major_offset, blk_per_grid_dim);
auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
if (params.raster_order_ == RasterOrder::AlongN) {
if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
cluster_minor_offset = cta_m_in_cluster;
}
else {
@@ -202,20 +282,20 @@ public:
uint64_t cluster_idx_minor_div_swizzle, extra, offset;
offset = cluster_id & ((1 << params.log_swizzle_size_) - 1);
extra = cluster_id >> params.log_swizzle_size_;
offset = cluster_id & ((1 << scheduler_params.log_swizzle_size_) - 1);
extra = cluster_id >> scheduler_params.log_swizzle_size_;
params.divmod_cluster_blk_major_(cluster_idx_minor_div_swizzle, cluster_idx_major, extra);
scheduler_params.divmod_cluster_blk_major_(cluster_idx_minor_div_swizzle, cluster_idx_major, extra);
cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << params.log_swizzle_size_) + offset;
cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << scheduler_params.log_swizzle_size_) + offset;
auto minor_work_idx = static_cast<int32_t>(cluster_idx_minor * params.divmod_cluster_shape_minor_.divisor +
auto minor_work_idx = static_cast<int32_t>(cluster_idx_minor * scheduler_params.divmod_cluster_shape_minor_.divisor +
cluster_minor_offset);
auto major_work_idx = static_cast<int32_t>(cluster_idx_major * params.divmod_cluster_shape_major_.divisor +
auto major_work_idx = static_cast<int32_t>(cluster_idx_major * scheduler_params.divmod_cluster_shape_major_.divisor +
cluster_major_offset);
int m, n;
if (params.raster_order_ == RasterOrder::AlongN) {
if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
m = minor_work_idx;
n = major_work_idx;
} else {
@@ -224,13 +304,13 @@ public:
}
// Pivot after swizzling
auto tiles_m = params.problem_tiles_m_ * params.cluster_shape_m_;
m = (m + params.tile_idx_pivot_m) % tiles_m;
auto tiles_m = scheduler_params.problem_tiles_m_ * scheduler_params.cluster_shape_m_;
m = (m + scheduler_params.tile_idx_pivot_m) % tiles_m;
if (is_mainloop_producer_) {
if (threadIdx.x == 0) {
size_t chunk_idx = m / params.tiles_per_chunk_m;
wait_signal(params.chunk_signals + chunk_idx);
size_t chunk_idx = m / scheduler_params.tiles_per_chunk_m;
wait_signal(scheduler_params.chunk_signals + chunk_idx);
}
// An arbitrary, non-default id
@@ -240,6 +320,9 @@ public:
}
return {m, n, static_cast<int32_t>(work_idx_l), true};
// ==============================
// CUSTOM LOGIC END
// ==============================
}
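The producer warp group's leader blocks above until the signal for the chunk covering tile m has been raised; wait_signal itself is defined elsewhere in this header and is not shown in this diff. Conceptually it is a device-side spin on a flag, roughly along these lines (an illustrative sketch, not the actual implementation):

    #include <cstdint>

    // Illustrative only -- not the wait_signal used above.
    __device__ inline void spin_until_nonzero(volatile uint32_t* flag) {
      while (*flag == 0) {
        // keep re-reading; volatile prevents the load from being hoisted out of the loop
      }
      // Order subsequent reads of the chunk after the signal. System scope is assumed here
      // because the producer may be a peer device writing through symmetric memory.
      __threadfence_system();
    }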
CUTLASS_DEVICE
@@ -248,6 +331,56 @@ public:
current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
}
CUTLASS_DEVICE
bool is_last_tile(WorkTileInfo& work_tile_info, uint32_t advance_count = 1) const {
if (continue_current_work(work_tile_info)) {
return false;
}
return not get_current_work_for_linear_idx(
current_work_linear_idx_ + (total_grid_size_ * uint64_t(advance_count))
).is_valid();
}
// Computes the linear index within a batch given M and N tile offsets within the batch.
// This essentially inverts the mapping performed in get_work_idx_m_and_n
static CUTLASS_DEVICE
uint64_t
get_linear_idx_from_m_and_n(
int32_t tile_m,
int32_t tile_n,
FastDivmodU64Pow2 const& divmod_cluster_shape_major,
FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
FastDivmodU64 const& divmod_cluster_blk_major,
int32_t log_swizzle_size,
RasterOrder raster_order) {
uint64_t minor_work_idx, major_work_idx, cluster_minor_offset;
if (raster_order == RasterOrder::AlongN) {
minor_work_idx = static_cast<uint64_t>(tile_m);
major_work_idx = static_cast<uint64_t>(tile_n);
uint64_t cluster_m = divmod_cluster_shape_minor.divide(tile_m) * divmod_cluster_shape_minor.divisor;
cluster_minor_offset = tile_m - cluster_m;
}
else {
major_work_idx = static_cast<uint64_t>(tile_m);
minor_work_idx = static_cast<uint64_t>(tile_n);
uint64_t cluster_n = divmod_cluster_shape_minor.divide(tile_n) * divmod_cluster_shape_minor.divisor;
cluster_minor_offset = tile_n - cluster_n;
}
uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset;
cluster_idx_minor = divmod_cluster_shape_minor.divide(minor_work_idx - cluster_minor_offset);
divmod_cluster_shape_major(cluster_idx_major, cluster_major_offset, major_work_idx);
uint64_t cluster_idx_minor_div_swizzle = cluster_idx_minor >> log_swizzle_size;
uint64_t offset = cluster_idx_minor & ((1 << log_swizzle_size) - 1);
uint64_t extra = cluster_idx_minor_div_swizzle * divmod_cluster_blk_major.divisor + cluster_idx_major;
uint64_t cluster_id = (extra << log_swizzle_size) | offset;
return (cluster_id * divmod_cluster_shape_major.divisor + cluster_major_offset) * divmod_cluster_shape_minor.divisor + cluster_minor_offset;
}
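For intuition, with a 1x1 cluster, swizzle size 1, and RasterOrder::AlongN, every divmod above collapses and the mapping reduces to row-major ordering over (M, N) tiles, with N as the fast-moving (major) dimension. A tiny standalone restatement of that degenerate case (plain integers, no CUTLASS types, ignoring the batch index):

    #include <cstdint>

    // Degenerate case of get_linear_idx_from_m_and_n: cluster 1x1, no swizzle, AlongN raster.
    inline uint64_t linear_idx_cluster1_no_swizzle_along_n(
        int32_t tile_m, int32_t tile_n, uint64_t tiles_n /* total N tiles */) {
      return static_cast<uint64_t>(tile_m) * tiles_n + static_cast<uint64_t>(tile_n);
    }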
// Given the inputs, computes the total number of output blocks over which this problem will compute.
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
@@ -263,20 +396,38 @@ public:
cta_m, cta_n
);
}
// Kernel helper function to get next work ID
template <class WorkIdPipeline, class WorkIdPipelineState>
// Overloaded interface that receives WorkTileInfo to deduce next work.
// Kernel helper function to get next work tile
CUTLASS_DEVICE
auto
fetch_next_work(
WorkTileInfo work_tile_info,
WorkIdPipeline& work_id_pipeline,
WorkIdPipelineState work_id_pipe_consumer_state) {
WorkTileInfo new_work_tile_info;
advance_to_next_work();
new_work_tile_info = get_current_work();
fetch_next_work(WorkTileInfo work_tile_info) {
if (continue_current_work(work_tile_info)) {
return cute::make_tuple(work_tile_info, true);
}
// Return true to indicate that the WorkID pipeline state should be advanced
return cute::make_tuple(new_work_tile_info, true);
advance_to_next_work();
return cute::make_tuple(get_current_work(), true);
}
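The new single-argument overload lets a persistent kernel loop re-query the scheduler with the tile it has just finished. The real driver loop lives in cutlass::gemm::kernel::GemmUniversal and is not part of this diff; its shape, sketched with only members defined in this class, is roughly:

    // Illustrative consumer loop, not the real GemmUniversal mainloop driver.
    using AsyncScheduler = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90AsyncInput;

    CUTLASS_DEVICE void consume_all_tiles(AsyncScheduler& scheduler) {
      auto work = scheduler.get_current_work();
      while (work.is_valid()) {
        // ... mainloop + epilogue for output tile (work.M_idx, work.N_idx, work.L_idx) ...
        auto [next_work, advance_pipe] = scheduler.fetch_next_work(work);
        (void)advance_pipe;  // a caller driving a work-id pipeline would act on this flag
        work = next_work;
      }
    }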
// Given the inputs, computes the total number of output blocks over which this problem will compute.
// Note that this is only the logical size of our grid, not the physical grid we will actually launch.
template<class ProblemShapeMNKL, class TileShape, class AtomThrShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl,
TileShape tile_shape_mnk,
AtomThrShape atom_thr_shape_mnk,
ClusterShape cluster_shape_mnk) {
auto [tiles_m, tiles_n, tiles_l] = product_each(ceil_div(select<0,1,3>(problem_shape_mnkl), take<0,2>(tile_shape_mnk)));
auto cta_m = round_nearest(tiles_m * size<0>(atom_thr_shape_mnk), size<0>(cluster_shape_mnk));
auto cta_n = round_nearest(tiles_n * size<1>(atom_thr_shape_mnk), size<1>(cluster_shape_mnk));
return Params::get_tiled_cta_shape_mnl(
to_gemm_coord(problem_shape_mnkl),
to_gemm_coord(cluster_shape_mnk),
cta_m, cta_n
);
}
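As a quick numeric check of the rounding in this overload, assume M = 300, an M tile of 128 rows, an atom-thread shape of 1 in M, and a cluster of 2 CTAs in M (all illustrative values):

    // Plain-integer restatement of ceil_div followed by round_nearest.
    constexpr int tiles_m = (300 + 128 - 1) / 128;            // ceil_div(300, 128)       -> 3 M tiles
    constexpr int cta_m   = ((tiles_m * 1 + 2 - 1) / 2) * 2;  // round up to cluster of 2 -> 4 CTAs
    static_assert(tiles_m == 3 && cta_m == 4, "sanity check");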
CUTLASS_DEVICE
@@ -292,17 +443,31 @@ public:
);
}
CUTLASS_DEVICE
static auto
work_tile_to_cta_coord(WorkTileInfo work_tile_info, dim3 block_id_in_cluster) {
// Get every cta coord in three dimensions of the cluster
auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = block_id_in_cluster;
return make_coord(
work_tile_info.M_idx + static_cast<int32_t>(cta_m_in_cluster),
work_tile_info.N_idx + static_cast<int32_t>(cta_n_in_cluster),
_,
work_tile_info.L_idx + static_cast<int32_t>(cta_l_in_cluster)
);
}
// Given the inputs, computes the physical grid we should launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
get_grid_shape(
ProblemShapeMNKL problem_shape_mnk,
BlockShape cta_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info,
Arguments arguments,
bool truncate_by_problem_size=true) {
[[maybe_unused]] Params const& params,
ProblemShapeMNKL problem_shape_mnk,
BlockShape cta_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info,
Arguments arguments = Arguments{},
bool truncate_by_problem_size=true) {
auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{});
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);
@@ -318,19 +483,17 @@ public:
}
// Given the inputs, computes the physical grid we should launch.
template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
CUTLASS_HOST_DEVICE static
dim3
template<class ProblemShapeMNKL, class TileShape, class AtomThrShape, class ClusterShape>
static dim3
get_grid_shape(
Params const& params,
ProblemShapeMNKL problem_shape_mnk,
BlockShape cta_shape,
ClusterShape cluster_shape,
KernelHardwareInfo hw_info) {
auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{});
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);
Params const& params,
ProblemShapeMNKL problem_shape_mnkl,
TileShape tile_shape_mnk,
AtomThrShape atom_thr_shape_mnk,
ClusterShape cluster_shape_mnk,
KernelHardwareInfo hw_info) {
dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape_mnk, atom_thr_shape_mnk, cluster_shape_mnk);
Arguments args{};
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
args.max_swizzle_size = 1 << params.log_swizzle_size_;
@@ -339,7 +502,7 @@ public:
return Params::get_grid_shape(
problem_blocks,
to_gemm_coord(cluster_shape),
to_gemm_coord(cluster_shape_mnk),
hw_info,
args.max_swizzle_size,
args.raster_order,
@@ -349,15 +512,15 @@ public:
// Convert CTA-level work tile info to cluster-level tile coord
CUTLASS_DEVICE
cute::Coord<int,int,int,int>
tile_info_to_coord_mnkl(WorkTileInfo work_tile_info) const {
auto
work_tile_to_cluster_coord_mnkl(WorkTileInfo work_tile_info) const {
// TileScheduler works at CTA-level, kernel works at cluster-level
int m_coord = idx2crd(work_tile_info.M_idx / params.cluster_shape_m_,
params.problem_tiles_m_);
int n_coord = idx2crd(work_tile_info.N_idx / params.cluster_shape_n_,
params.problem_tiles_n_);
int m_coord = idx2crd(work_tile_info.M_idx / scheduler_params.cluster_shape_m_,
scheduler_params.problem_tiles_m_);
int n_coord = idx2crd(work_tile_info.N_idx / scheduler_params.cluster_shape_n_,
scheduler_params.problem_tiles_n_);
int l_coord = idx2crd(work_tile_info.L_idx,
params.problem_tiles_l_);
scheduler_params.problem_tiles_l_);
return make_coord(m_coord, n_coord, _, l_coord);
}
@@ -398,6 +561,14 @@ public:
return false;
}
template <class ProblemShapeMNKL, class TileShape, class Shape>
CUTLASS_DEVICE
auto
get_k_tile_iterator(WorkTileInfo const& work_tile_info, ProblemShapeMNKL problem_shape_MNKL, TileShape tile_shape, Shape) {
auto k_tiles = cute::ceil_div(cute::get<2>(problem_shape_MNKL), cute::get<2>(tile_shape));
return cute::make_coord_iterator(k_tiles);
}
template <class ProblemShape, class TileShape>
CUTLASS_HOST_DEVICE
static int
@@ -463,20 +634,21 @@ public:
// The basic tile scheduler does not require any additional workspace
template <class ProblemShape, class ElementAccumulator>
static int
get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1) {
static size_t
get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1, uint32_t = 1) {
return 0;
}
template <class ProblemShape, class ElementAccumulator>
static cutlass::Status
initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&,
uint32_t, const uint32_t = 1) {
uint32_t, const uint32_t = 1, uint32_t = 1, CudaHostAdapter* cuda_adapter = nullptr) {
return Status::kSuccess;
}
public:
// Sink scheduler params as a member
Params params;
Params scheduler_params;
};
// Selector