Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

[AsyncMM] re-enable and adapt to cutlass 3.6.0 (#144011)

[D68734067](https://our.internmc.facebook.com/intern/diff/D68734067)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144011
Approved by: https://github.com/Skylion007, https://github.com/drisspg

Committed by: PyTorch MergeBot
Parent: 1e3d1738a4
Commit: c70362fac8
@@ -5,44 +5,17 @@
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDAGuard.h>

#if false && !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \
// Two warnings in Cutlass included header files
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")

#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \
    CUDA_VERSION >= 12000
#define BUILD_ASYNC_MM_KERNEL
#endif

#if defined(BUILD_ASYNC_MM_KERNEL)

// We are going to override the cuTensorMapEncodeTiled driver API with our lazy
// loader
static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
    CUtensorMap* tensorMap,
    CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank,
    void* globalAddress,
    const cuuint64_t* globalDim,
    const cuuint64_t* globalStrides,
    const cuuint32_t* boxDim,
    const cuuint32_t* elementStrides,
    CUtensorMapInterleave interleave,
    CUtensorMapSwizzle swizzle,
    CUtensorMapL2promotion l2Promotion,
    CUtensorMapFloatOOBfill oobFill) {
  return at::globalContext().getNVRTC().cuTensorMapEncodeTiled(
      tensorMap,
      tensorDataType,
      tensorRank,
      globalAddress,
      globalDim,
      globalStrides,
      boxDim,
      elementStrides,
      interleave,
      swizzle,
      l2Promotion,
      oobFill);
}

// clang-format off
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
@@ -50,13 +23,9 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
#include <cutlass/numeric_types.h>
#include <cutlass/trace.h>
#include <cutlass/util/host_tensor.h>

// Rename the global function symbol
#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled
#include <cute/tensor.hpp>
#undef cuTensorMapEncodeTiled
// Set everything back to normal
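Editor's note: the define/undef pair above is a token-level interception trick. Because the included headers call cuTensorMapEncodeTiled by name, defining the macro before the include reroutes every such call to the lazy-loading wrapper, avoiding a hard dependency on the driver symbol. A minimal standalone sketch of the pattern, with hypothetical names (fast_init, slow_init, and use_library are illustrative, not from this diff):

// Illustrative sketch of the symbol-interception pattern; not part of the diff.
#include <cstdio>

static int fast_init(int x) {  // our replacement implementation
  std::printf("intercepted init(%d)\n", x);
  return 0;
}

// Code textually included while the macro is active calls fast_init wherever
// it wrote slow_init; the #undef restores the name afterwards.
#define slow_init fast_init
static int use_library(int x) { return slow_init(x); }  // stands in for a vendor header
#undef slow_init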
#include <cutlass/version.h>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/epilogue/collective/collective_builder.hpp>
@@ -65,10 +34,12 @@ static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
// clang-format on

#include <torch/csrc/distributed/c10d/cuda/cutlass/gemm/kernel/persistent_async_input_scheduler.cuh>

C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()

namespace {

using namespace cute;
@@ -107,7 +78,7 @@ at::Tensor async_input_mm_impl(
        cutlass::epilogue::collective::EpilogueTileAuto,
        ElementAccumulator,
        ElementAccumulator,
        void,
        ElementC,
        LayoutC,
        AlignmentC,
        ElementC,
@@ -133,7 +104,7 @@ at::Tensor async_input_mm_impl(
        KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,
      Shape<int, int, int>,
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
@@ -171,7 +142,7 @@ at::Tensor async_input_mm_impl(

  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K, 1},
      {M, N, K},
      {
          reinterpret_cast<ElementA*>(a.data_ptr<at::BFloat16>()),
          stride_A,
@@ -179,7 +150,7 @@ at::Tensor async_input_mm_impl(
          stride_B,
      },
      {{1, 1},
       nullptr,
       reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),
       stride_C,
       reinterpret_cast<ElementC*>(out.data_ptr<at::BFloat16>()),
       stride_C},

@@ -1,6 +1,6 @@
/**
 * This file contains PersistentAsyncInputScheduler, a forked version of PersistentScheduler that
 * supports consuming asynchronous input. This tile scheduler introduces the following arguments:
 * This file contains PersistentTileSchedulerSm90AsyncInput, a forked version of PersistentTileSchedulerSm90
 * that supports consuming asynchronous input. This tile scheduler introduces the following arguments:
 *
 * - tiles_per_chunk_m – Specifies the size of an M chunk. Chunks are the granularity at which the
 *   asynchronous input becomes ready. It must be an integer multiple of the size of an M tile.
@@ -22,6 +22,12 @@
 *     CollectiveMainloop,
 *     CollectiveEpilogue,
 *     cutlass::gemm::PersistentAsyncInputScheduler<KernelSchedule>>;
 *
 * Unfortunately, the CRTP base class for tile schedulers (StaticPersistentTileScheduler) doesn't
 * provide enough flexibility for the required customization. We had to create a new tile scheduler
 * by copying PersistentTileSchedulerSm90 and StaticPersistentTileScheduler, then customizing on top
 * of them. In PersistentTileSchedulerSm90AsyncInput, we marked the customizations with "CUSTOM LOGIC BEGIN"
 * and "CUSTOM LOGIC END" comment blocks.
 */
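Editor's note: the chunk/signal contract described above implies a producer-side release for every chunk_signals slot the scheduler waits on. A hedged sketch of what that producer side could look like (hypothetical helper, not part of this diff; assumes one uint32_t flag per chunk, flipped once the chunk's data is globally visible):

// Illustrative producer-side release, paired with the scheduler's wait_signal():
// publish the chunk's data first, then set the flag.
__device__ void release_chunk(uint32_t* chunk_signals, size_t chunk_idx) {
  __threadfence_system();                     // order chunk writes before the flag
  atomicExch(chunk_signals + chunk_idx, 1u);  // mark the chunk as ready
}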
#pragma once
@@ -68,6 +74,7 @@ namespace cutlass::gemm::kernel::detail {

////////////////////////////////////////////////////////////////////////////////

class PersistentTileSchedulerSm90AsyncInputParams :
    public PersistentTileSchedulerSm90Params {
public:
@@ -80,34 +87,92 @@ class PersistentTileSchedulerSm90AsyncInput {
private:
  uint64_t current_work_linear_idx_;
  uint64_t total_grid_size_;
  // ==============================
  // CUSTOM LOGIC BEGIN
  // ==============================
  bool is_mainloop_producer_;
  // ==============================
  // CUSTOM LOGIC END
  // ==============================

public:
  using WorkTileInfo = PersistentTileSchedulerSm90::WorkTileInfo;
  struct WorkTileInfo {
    int32_t M_idx = 0;
    int32_t N_idx = 0;
    int32_t L_idx = 0;
    bool is_valid_tile = false;

    CUTLASS_HOST_DEVICE
    bool
    is_valid() const {
      return is_valid_tile;
    }

    CUTLASS_HOST_DEVICE
    static WorkTileInfo
    invalid_work_tile() {
      return {-1, -1, -1, false};
    }

    CUTLASS_HOST_DEVICE
    bool
    is_final_split(uint32_t k_tiles_per_output_tile) const {
      return true;
    }

    CUTLASS_HOST_DEVICE
    int32_t
    reduction_subtile_idx() const {
      return -1;
    }
  };

  // ==============================
  // CUSTOM LOGIC BEGIN
  // ==============================
  using Params = PersistentTileSchedulerSm90AsyncInputParams;
  // ==============================
  // CUSTOM LOGIC END
  // ==============================
  using RasterOrder = typename Params::RasterOrder;
  using RasterOrderOptions = typename Params::RasterOrderOptions;
  static constexpr bool IsDynamicPersistent = false;

public:
  // ==============================
  // CUSTOM LOGIC BEGIN
  // ==============================
  struct Arguments {
    int max_swizzle_size = 1;
    RasterOrderOptions raster_order = RasterOrderOptions::Heuristic;
    int max_swizzle_size;
    RasterOrderOptions raster_order;

    // Async input specific
    int tile_idx_pivot_m = 0;
    int tiles_per_chunk_m = 0;
    uint32_t* chunk_signals = nullptr;
    int tile_idx_pivot_m;
    int tiles_per_chunk_m;
    uint32_t* chunk_signals;

    Arguments():
      max_swizzle_size(1),
      raster_order(RasterOrderOptions::Heuristic),
      tile_idx_pivot_m(0),
      tiles_per_chunk_m(0),
      chunk_signals(nullptr) {}
    // ==============================
    // CUSTOM LOGIC END
    // ==============================
  };
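Editor's note: a hedged sketch of host-side construction of these scheduler arguments (illustrative only; the pivot, chunk size, and signal buffer would come from the caller's communication schedule, and the header above is assumed to be included):

// Illustrative host-side setup; not part of the diff.
using AsyncScheduler = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90AsyncInput;

AsyncScheduler::Arguments make_scheduler_args(
    int pivot_tile_m, int tiles_per_chunk, uint32_t* chunk_signals_dev) {
  AsyncScheduler::Arguments args;             // defaults: swizzle 1, Heuristic raster order
  args.tile_idx_pivot_m  = pivot_tile_m;      // first M tile this rank consumes
  args.tiles_per_chunk_m = tiles_per_chunk;   // M tiles per asynchronously produced chunk
  args.chunk_signals     = chunk_signals_dev; // device array: one uint32_t flag per chunk
  return args;
}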
  template <class ProblemShapeMNKL, class TileShape, class ClusterShape>
  static Params
  to_underlying_arguments(
      ProblemShapeMNKL problem_shape_mnkl,
      TileShape tile_shape,
      ClusterShape cluster_shape,
      [[maybe_unused]] KernelHardwareInfo const& hw_info,
      Arguments const& arguments,
      [[maybe_unused]] void* workspace=nullptr,
      [[maybe_unused]] const uint32_t epilogue_subtile = 1) {
      ProblemShapeMNKL problem_shape_mnkl,
      TileShape tile_shape,
      ClusterShape cluster_shape,
      [[maybe_unused]] KernelHardwareInfo const& hw_info,
      Arguments const& arguments,
      [[maybe_unused]] void* workspace=nullptr,
      [[maybe_unused]] const uint32_t epilogue_subtile = 1,
      [[maybe_unused]] uint32_t ktile_start_alignment_count = 1u) {

    // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic
    static_assert(cute::is_static<TileShape>::value);
@@ -123,9 +188,16 @@ public:
        arguments.max_swizzle_size,
        arguments.raster_order
    );

    // ==============================
    // CUSTOM LOGIC BEGIN
    // ==============================
    params.tile_idx_pivot_m = arguments.tile_idx_pivot_m;
    params.tiles_per_chunk_m = arguments.tiles_per_chunk_m;
    params.chunk_signals = arguments.chunk_signals;
    // ==============================
    // CUSTOM LOGIC END
    // ==============================

    return params;
  }
@@ -133,13 +205,13 @@ public:
  CUTLASS_HOST_DEVICE
  static bool
  can_implement(Arguments const& args) {
    return args.raster_order == RasterOrderOptions::AlongN;
    return args.max_swizzle_size >= 1;
  }

  CUTLASS_HOST_DEVICE
  PersistentTileSchedulerSm90AsyncInput() { }

  CUTLASS_DEVICE explicit PersistentTileSchedulerSm90AsyncInput(Params const& params_) : params(params_) {
  CUTLASS_DEVICE explicit PersistentTileSchedulerSm90AsyncInput(Params const& params_) : scheduler_params(params_) {
    // MSVC requires protecting use of CUDA-specific nonstandard syntax,
    // like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
@@ -150,11 +222,16 @@ public:
      current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y);
    }

    total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);

    // ==============================
    // CUSTOM LOGIC BEGIN
    // ==============================
    int warp_group_role = canonical_warp_group_idx();
    int producer_warp_group_role = canonical_warp_idx_sync() % NumWarpsPerWarpGroup;
    is_mainloop_producer_ = warp_group_role == 0 && producer_warp_group_role == 0;
    total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z);
    // ==============================
    // CUSTOM LOGIC END
    // ==============================
#else
    CUTLASS_ASSERT(false && "This line should never be reached");
#endif
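Editor's note: the producer-role test above can be read as plain index arithmetic. A standalone restatement under the usual SM90 constants (128 threads per warp group, 4 warps per warp group; the helper name is hypothetical):

// Equivalent restatement of the is_mainloop_producer_ computation: only the
// first warp of warp group 0 (the producer warp group in a warp-specialized
// kernel) ends up waiting on chunk signals.
__device__ bool compute_is_mainloop_producer() {
  const int warp_group    = threadIdx.x / 128;       // canonical_warp_group_idx()
  const int warp_in_group = (threadIdx.x / 32) % 4;  // canonical_warp_idx_sync() % NumWarpsPerWarpGroup
  return warp_group == 0 && warp_in_group == 0;
}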
@@ -177,21 +254,24 @@ public:
  CUTLASS_DEVICE
  WorkTileInfo
  get_current_work_for_linear_idx(uint64_t linear_idx) const {
    if (linear_idx >= params.blocks_per_problem_) {
    if (linear_idx >= scheduler_params.blocks_per_problem_) {
      return WorkTileInfo::invalid_work_tile();
    }

    // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices
    uint64_t work_idx_l, remainder;
    params.divmod_batch_(work_idx_l, remainder, linear_idx);
    scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx);

    uint64_t blk_per_grid_dim = params.divmod_cluster_shape_minor_.divide(remainder);
    uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(remainder);

    // ==============================
    // CUSTOM LOGIC BEGIN
    // ==============================
    uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0;
    params.divmod_cluster_shape_major_(cluster_id, cluster_major_offset, blk_per_grid_dim);
    scheduler_params.divmod_cluster_shape_major_(cluster_id, cluster_major_offset, blk_per_grid_dim);

    auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster();
    if (params.raster_order_ == RasterOrder::AlongN) {
    if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
      cluster_minor_offset = cta_m_in_cluster;
    }
    else {
@@ -202,20 +282,20 @@ public:

    uint64_t cluster_idx_minor_div_swizzle, extra, offset;

    offset = cluster_id & ((1 << params.log_swizzle_size_) - 1);
    extra = cluster_id >> params.log_swizzle_size_;
    offset = cluster_id & ((1 << scheduler_params.log_swizzle_size_) - 1);
    extra = cluster_id >> scheduler_params.log_swizzle_size_;

    params.divmod_cluster_blk_major_(cluster_idx_minor_div_swizzle, cluster_idx_major, extra);
    scheduler_params.divmod_cluster_blk_major_(cluster_idx_minor_div_swizzle, cluster_idx_major, extra);

    cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << params.log_swizzle_size_) + offset;
    cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << scheduler_params.log_swizzle_size_) + offset;

    auto minor_work_idx = static_cast<int32_t>(cluster_idx_minor * params.divmod_cluster_shape_minor_.divisor +
    auto minor_work_idx = static_cast<int32_t>(cluster_idx_minor * scheduler_params.divmod_cluster_shape_minor_.divisor +
                                               cluster_minor_offset);
    auto major_work_idx = static_cast<int32_t>(cluster_idx_major * params.divmod_cluster_shape_major_.divisor +
    auto major_work_idx = static_cast<int32_t>(cluster_idx_major * scheduler_params.divmod_cluster_shape_major_.divisor +
                                               cluster_major_offset);

    int m, n;
    if (params.raster_order_ == RasterOrder::AlongN) {
    if (scheduler_params.raster_order_ == RasterOrder::AlongN) {
      m = minor_work_idx;
      n = major_work_idx;
    } else {
@@ -224,13 +304,13 @@ public:
    }

    // Pivot after swizzling
    auto tiles_m = params.problem_tiles_m_ * params.cluster_shape_m_;
    m = (m + params.tile_idx_pivot_m) % tiles_m;
    auto tiles_m = scheduler_params.problem_tiles_m_ * scheduler_params.cluster_shape_m_;
    m = (m + scheduler_params.tile_idx_pivot_m) % tiles_m;

    if (is_mainloop_producer_) {
      if (threadIdx.x == 0) {
        size_t chunk_idx = m / params.tiles_per_chunk_m;
        wait_signal(params.chunk_signals + chunk_idx);
        size_t chunk_idx = m / scheduler_params.tiles_per_chunk_m;
        wait_signal(scheduler_params.chunk_signals + chunk_idx);
      }

      // An arbitrary, non-default id
@@ -240,6 +320,9 @@ public:
    }

    return {m, n, static_cast<int32_t>(work_idx_l), true};
    // ==============================
    // CUSTOM LOGIC END
    // ==============================
  }
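Editor's note: wait_signal() itself is defined elsewhere in the tree. A plausible sketch of its consumer side, paired with the hypothetical release_chunk() shown earlier (illustrative, not this diff's implementation):

// Spin until the producer flips the chunk's flag, then fence so that
// subsequent reads of the chunk's data are ordered after the flag read.
__device__ void wait_signal_sketch(uint32_t* addr) {
  while (atomicAdd(addr, 0u) == 0u) {
    // busy-wait; atomicAdd(addr, 0) is an atomic read that does not modify
  }
  __threadfence_system();
}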
  CUTLASS_DEVICE
@@ -248,6 +331,56 @@ public:
    current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count);
  }

  CUTLASS_DEVICE
  bool is_last_tile(WorkTileInfo& work_tile_info, uint32_t advance_count = 1) const {
    if (continue_current_work(work_tile_info)) {
      return false;
    }
    return not get_current_work_for_linear_idx(
        current_work_linear_idx_ + (total_grid_size_ * uint64_t(advance_count))
    ).is_valid();
  }

  // Computes the linear index within a batch given M and N tile offsets within the batch.
  // This essentially inverts the mapping performed in get_work_idx_m_and_n
  static CUTLASS_DEVICE
  uint64_t
  get_linear_idx_from_m_and_n(
      int32_t tile_m,
      int32_t tile_n,
      FastDivmodU64Pow2 const& divmod_cluster_shape_major,
      FastDivmodU64Pow2 const& divmod_cluster_shape_minor,
      FastDivmodU64 const& divmod_cluster_blk_major,
      int32_t log_swizzle_size,
      RasterOrder raster_order) {

    uint64_t minor_work_idx, major_work_idx, cluster_minor_offset;
    if (raster_order == RasterOrder::AlongN) {
      minor_work_idx = static_cast<uint64_t>(tile_m);
      major_work_idx = static_cast<uint64_t>(tile_n);
      uint64_t cluster_m = divmod_cluster_shape_minor.divide(tile_m) * divmod_cluster_shape_minor.divisor;
      cluster_minor_offset = tile_m - cluster_m;
    }
    else {
      major_work_idx = static_cast<uint64_t>(tile_m);
      minor_work_idx = static_cast<uint64_t>(tile_n);
      uint64_t cluster_n = divmod_cluster_shape_minor.divide(tile_n) * divmod_cluster_shape_minor.divisor;
      cluster_minor_offset = tile_n - cluster_n;
    }

    uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset;
    cluster_idx_minor = divmod_cluster_shape_minor.divide(minor_work_idx - cluster_minor_offset);
    divmod_cluster_shape_major(cluster_idx_major, cluster_major_offset, major_work_idx);

    uint64_t cluster_idx_minor_div_swizzle = cluster_idx_minor >> log_swizzle_size;
    uint64_t offset = cluster_idx_minor & ((1 << log_swizzle_size) - 1);

    uint64_t extra = cluster_idx_minor_div_swizzle * divmod_cluster_blk_major.divisor + cluster_idx_major;

    uint64_t cluster_id = (extra << log_swizzle_size) | offset;
    return (cluster_id * divmod_cluster_shape_major.divisor + cluster_major_offset) * divmod_cluster_shape_minor.divisor + cluster_minor_offset;
  }
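Editor's note: the bit manipulation here exactly inverts the swizzle decomposition in get_current_work_for_linear_idx(). A self-contained round-trip check with plain integers standing in for the FastDivmod fields (values are illustrative):

// Illustrative verification of the swizzle math; not part of the diff.
#include <cassert>
#include <cstdint>

int main() {
  const uint64_t log_swizzle_size  = 2;  // swizzle groups of 1 << 2 = 4
  const uint64_t cluster_blk_major = 8;  // stands in for divmod_cluster_blk_major.divisor

  for (uint64_t cluster_id = 0; cluster_id < 256; ++cluster_id) {
    // Forward decomposition (as in get_current_work_for_linear_idx):
    uint64_t offset = cluster_id & ((1 << log_swizzle_size) - 1);
    uint64_t extra  = cluster_id >> log_swizzle_size;
    uint64_t idx_minor_div_swizzle = extra / cluster_blk_major;
    uint64_t idx_major             = extra % cluster_blk_major;

    // Inverse recomposition (as in get_linear_idx_from_m_and_n):
    uint64_t extra2    = idx_minor_div_swizzle * cluster_blk_major + idx_major;
    uint64_t roundtrip = (extra2 << log_swizzle_size) | offset;

    assert(roundtrip == cluster_id);  // the two mappings are inverses
  }
  return 0;
}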
  // Given the inputs, computes the total number of output blocks over which this problem will compute.
  // Note that this is only the logical size of our grid, not the physical grid we will actually launch.
  template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
@@ -263,20 +396,38 @@ public:
        cta_m, cta_n
    );
  }

  // Kernel helper function to get next work ID
  template <class WorkIdPipeline, class WorkIdPipelineState>
  // Overloaded interface that receives WorkTileInfo to deduce next work.
  // Kernel helper function to get next work tile
  CUTLASS_DEVICE
  auto
  fetch_next_work(
      WorkTileInfo work_tile_info,
      WorkIdPipeline& work_id_pipeline,
      WorkIdPipelineState work_id_pipe_consumer_state) {
    WorkTileInfo new_work_tile_info;
    advance_to_next_work();
    new_work_tile_info = get_current_work();
  fetch_next_work(WorkTileInfo work_tile_info) {
    if (continue_current_work(work_tile_info)) {
      return cute::make_tuple(work_tile_info, true);
    }

    // Return true to indicate that the WorkID pipeline state should be advanced
    return cute::make_tuple(new_work_tile_info, true);
    advance_to_next_work();
    return cute::make_tuple(get_current_work(), true);
  }
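Editor's note: a hedged sketch of how a persistent kernel might drive this cutlass 3.6-style interface (the real loop lives in CUTLASS's kernel code; scheduler_loop_sketch and process_tile are placeholders, not names from this diff):

// Illustrative consumer loop: process the current tile, then ask the
// scheduler for the next one; the bool returned alongside it (always true
// here) tells the caller to advance its pipeline state.
template <class Scheduler, class Fn>
CUTLASS_DEVICE void scheduler_loop_sketch(Scheduler& scheduler, Fn&& process_tile) {
  auto work_tile_info = scheduler.get_current_work();
  while (work_tile_info.is_valid()) {
    process_tile(work_tile_info);
    auto [next_tile, advance] = scheduler.fetch_next_work(work_tile_info);
    work_tile_info = next_tile;
  }
}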
  // Given the inputs, computes the total number of output blocks over which this problem will compute.
  // Note that this is only the logical size of our grid, not the physical grid we will actually launch.
  template<class ProblemShapeMNKL, class TileShape, class AtomThrShape, class ClusterShape>
  CUTLASS_HOST_DEVICE static
  dim3
  get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl,
                          TileShape tile_shape_mnk,
                          AtomThrShape atom_thr_shape_mnk,
                          ClusterShape cluster_shape_mnk) {
    auto [tiles_m, tiles_n, tiles_l] = product_each(ceil_div(select<0,1,3>(problem_shape_mnkl), take<0,2>(tile_shape_mnk)));
    auto cta_m = round_nearest(tiles_m * size<0>(atom_thr_shape_mnk), size<0>(cluster_shape_mnk));
    auto cta_n = round_nearest(tiles_n * size<1>(atom_thr_shape_mnk), size<1>(cluster_shape_mnk));

    return Params::get_tiled_cta_shape_mnl(
        to_gemm_coord(problem_shape_mnkl),
        to_gemm_coord(cluster_shape_mnk),
        cta_m, cta_n
    );
  }
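Editor's note: a worked instance of the tile-count arithmetic above, with illustrative numbers (not taken from the diff):

// problem (M, N, L) = (512, 384, 1), tile (M, N) = (128, 128),
// atom_thr (1, 1), cluster (2, 1):
//   tiles_m = ceil_div(512, 128) = 4 -> cta_m = round_nearest(4 * 1, 2) = 4
//   tiles_n = ceil_div(384, 128) = 3 -> cta_n = round_nearest(3 * 1, 1) = 3
//   logical grid = 4 x 3 x 1 CTA tiles, before any physical-grid truncation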
  CUTLASS_DEVICE
@@ -292,17 +443,31 @@ public:
    );
  }

  CUTLASS_DEVICE
  static auto
  work_tile_to_cta_coord(WorkTileInfo work_tile_info, dim3 block_id_in_cluster) {
    // Get every cta coord in three dimensions of the cluster
    auto [cta_m_in_cluster, cta_n_in_cluster, cta_l_in_cluster] = block_id_in_cluster;
    return make_coord(
        work_tile_info.M_idx + static_cast<int32_t>(cta_m_in_cluster),
        work_tile_info.N_idx + static_cast<int32_t>(cta_n_in_cluster),
        _,
        work_tile_info.L_idx + static_cast<int32_t>(cta_l_in_cluster)
    );
  }

  // Given the inputs, computes the physical grid we should launch.
  template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
  CUTLASS_HOST_DEVICE static
  dim3
  get_grid_shape(
      ProblemShapeMNKL problem_shape_mnk,
      BlockShape cta_shape,
      ClusterShape cluster_shape,
      KernelHardwareInfo hw_info,
      Arguments arguments,
      bool truncate_by_problem_size=true) {
      [[maybe_unused]] Params const& params,
      ProblemShapeMNKL problem_shape_mnk,
      BlockShape cta_shape,
      ClusterShape cluster_shape,
      KernelHardwareInfo hw_info,
      Arguments arguments = Arguments{},
      bool truncate_by_problem_size=true) {

    auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{});
    dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);
@@ -318,19 +483,17 @@ public:
  }

  // Given the inputs, computes the physical grid we should launch.
  template<class ProblemShapeMNKL, class BlockShape, class ClusterShape>
  CUTLASS_HOST_DEVICE static
  dim3
  template<class ProblemShapeMNKL, class TileShape, class AtomThrShape, class ClusterShape>
  static dim3
  get_grid_shape(
      Params const& params,
      ProblemShapeMNKL problem_shape_mnk,
      BlockShape cta_shape,
      ClusterShape cluster_shape,
      KernelHardwareInfo hw_info) {

    auto problem_shape_mnkl = cute::append<4>(problem_shape_mnk, cute::Int<1>{});
    dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape);
      Params const& params,
      ProblemShapeMNKL problem_shape_mnkl,
      TileShape tile_shape_mnk,
      AtomThrShape atom_thr_shape_mnk,
      ClusterShape cluster_shape_mnk,
      KernelHardwareInfo hw_info) {

    dim3 problem_blocks = get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape_mnk, atom_thr_shape_mnk, cluster_shape_mnk);
    Arguments args{};
    if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
      args.max_swizzle_size = 1 << params.log_swizzle_size_;
@@ -339,7 +502,7 @@ public:

    return Params::get_grid_shape(
        problem_blocks,
        to_gemm_coord(cluster_shape),
        to_gemm_coord(cluster_shape_mnk),
        hw_info,
        args.max_swizzle_size,
        args.raster_order,
@@ -349,15 +512,15 @@ public:

  // Convert CTA-level work tile info to cluster-level tile coord
  CUTLASS_DEVICE
  cute::Coord<int,int,int,int>
  tile_info_to_coord_mnkl(WorkTileInfo work_tile_info) const {
  auto
  work_tile_to_cluster_coord_mnkl(WorkTileInfo work_tile_info) const {
    // TileScheduler works at CTA-level, kernel works at cluster-level
    int m_coord = idx2crd(work_tile_info.M_idx / params.cluster_shape_m_,
                          params.problem_tiles_m_);
    int n_coord = idx2crd(work_tile_info.N_idx / params.cluster_shape_n_,
                          params.problem_tiles_n_);
    int m_coord = idx2crd(work_tile_info.M_idx / scheduler_params.cluster_shape_m_,
                          scheduler_params.problem_tiles_m_);
    int n_coord = idx2crd(work_tile_info.N_idx / scheduler_params.cluster_shape_n_,
                          scheduler_params.problem_tiles_n_);
    int l_coord = idx2crd(work_tile_info.L_idx,
                          params.problem_tiles_l_);
                          scheduler_params.problem_tiles_l_);
    return make_coord(m_coord, n_coord, _, l_coord);
  }

@@ -398,6 +561,14 @@ public:
    return false;
  }

  template <class ProblemShapeMNKL, class TileShape, class Shape>
  CUTLASS_DEVICE
  auto
  get_k_tile_iterator(WorkTileInfo const& work_tile_info, ProblemShapeMNKL problem_shape_MNKL, TileShape tile_shape, Shape) {
    auto k_tiles = cute::ceil_div(cute::get<2>(problem_shape_MNKL), cute::get<2>(tile_shape));
    return cute::make_coord_iterator(k_tiles);
  }

  template <class ProblemShape, class TileShape>
  CUTLASS_HOST_DEVICE
  static int
@@ -463,20 +634,21 @@ public:

  // The basic tile scheduler does not require any additional workspace
  template <class ProblemShape, class ElementAccumulator>
  static int
  get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1) {
  static size_t
  get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1, uint32_t = 1) {
    return 0;
  }

  template <class ProblemShape, class ElementAccumulator>
  static cutlass::Status
  initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&,
                       uint32_t, const uint32_t = 1) {
                       uint32_t, const uint32_t = 1, uint32_t = 1, CudaHostAdapter* cuda_adapter = nullptr) {
    return Status::kSuccess;
  }

public:
  // Sink scheduler params as a member
  Params params;
  Params scheduler_params;
};

// Selector