This PR is the first step towards refactoring the nvfuser build so that the codegen becomes a standalone library.
Contents of this PR:
1. The nvfuser code base has been moved from `./torch/csrc/jit/codegen/cuda/` to `./nvfuser`, except for the registration code used for integration (interface.h/interface.cpp).
2. The build system is split so that nvfuser generates its own `.so` files. Currently there are:
- `libnvfuser_codegen.so`, which contains the integration, codegen and runtime system of nvfuser
- `nvfuser.so`, which is nvfuser's python API via pybind. The Python frontend is now exposed via `nvfuser._C.XXX` instead of `torch._C._nvfuser`
3. The nvfuser C++ tests are currently compiled into `nvfuser_tests`.
4. cmake is refactored so that:
- nvfuser now has its own `CMakeLists.txt`, which lives under `torch/csrc/jit/codegen/cuda/`.
- nvfuser backend code is no longer compiled into `libtorch_cuda_xxx`
- nvfuser is added as a subdirectory in the top-level `./CMakeLists.txt`, at the very end after torch is built.
- since nvfuser depends on torch, the registration of nvfuser at runtime is done via dlopen (`at::DynamicLibrary`). This avoids a circular dependency in cmake, which would be a nightmare to handle. For details, see `torch/csrc/jit/codegen/cuda/interface.cpp::LoadingNvfuserLibrary`; a minimal sketch of this loading pattern is shown after this list.
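The dlopen-based registration above can be pictured with a small sketch. This is not the PR's actual code; the class name, warning text, and error handling are illustrative, and only `at::DynamicLibrary` and the `libnvfuser_codegen.so` name come from the description above. The real logic lives in `torch/csrc/jit/codegen/cuda/interface.cpp::LoadingNvfuserLibrary`.

```cpp
// Hedged sketch of dlopen-style registration, assuming at::DynamicLibrary
// takes a library path and throws on failure. Loading the .so runs its
// static initializers, which register the nvfuser backend with the JIT.
#include <ATen/DynamicLibrary.h>
#include <c10/util/Exception.h>

#include <memory>

namespace {

class LoadNvfuserLibrarySketch {
 public:
  LoadNvfuserLibrarySketch() {
    try {
      // Resolved through the normal dynamic-linker search path; the name
      // comes from the PR description above.
      nvfuser_lib_ =
          std::make_unique<at::DynamicLibrary>("libnvfuser_codegen.so");
    } catch (const std::exception& e) {
      // Treat nvfuser as optional: warn and continue without it.
      TORCH_WARN("Loading nvfuser library failed with: ", e.what());
    }
  }

 private:
  std::unique_ptr<at::DynamicLibrary> nvfuser_lib_;
};

// A single static instance triggers the load once, when libtorch is loaded.
static LoadNvfuserLibrarySketch load_nvfuser_library_once;

} // namespace
```

Loading the shared object at runtime means torch never links against nvfuser at build time, so cmake only ever sees the one-way dependency nvfuser -> torch.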
Future work scoped for follow-up PRs:
- The nvfuser codegen currently depends on torch; that dependency needs to be refactored out so nvfuser can be moved into a submodule and no longer rely on dlopen to load the library. @malfet
- Moving nvfuser into a cmake build effectively disabled the bazel build for nvfuser. This could impact internal workloads at Meta, so support needs to be added back. cc'ing @vors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89621
Approved by: https://github.com/davidberard98
2708 lines · 94 KiB · C++
#include <codegen.h>
#include <expr_evaluator.h>
#include <instrumentation.h>
#include <kernel_expr_evaluator.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <scheduler/mma_utils.h>
#include <type.h>
#include <utils.h>

#include <array>
#include <cmath>
#include <sstream>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace codegen {

namespace {

std::string ptrType(DataType dt) {
  std::stringstream ss;
  ss << dt << "*";
  return ss.str();
}

//! Utility class to build an argument list
class ArgumentBuilder {
 public:
  //! Build an argument list where each argument is separated with a comma
  ArgumentBuilder() = default;

  //! Build an argument list where each argument has its own line
  ArgumentBuilder(int indent_level, const char* tab) {
    std::stringstream ss;
    for (const auto i : c10::irange(indent_level)) {
      (void)i; // Suppress unused variable warning
      ss << tab;
    }
    sep_ = ",\n" + ss.str();
  }

  //! Add a new argument
  template <typename T>
  ArgumentBuilder& arg(const T& x) {
    addSeparator();
    return append(x);
  }

  //! Append to the last argument
  template <typename T>
  ArgumentBuilder& append(const T& arg) {
    ss_ << arg;
    return *this;
  }

  //! Get a string of the argument list
  std::string str() const {
    return ss_.str();
  }

  friend std::ostream& operator<<(std::ostream& os, const ArgumentBuilder& ab) {
    return os << ab.str();
  }

 private:
  void addSeparator() {
    if (ss_.tellp() != 0) {
      ss_ << sep_;
    }
  }

 private:
  std::string sep_ = ", ";
  std::stringstream ss_;
};

//! Append to the last argument
template <>
ArgumentBuilder& ArgumentBuilder::append<bool>(const bool& arg) {
  ss_ << (arg ? "true" : "false");
  return *this;
}

//! Returns "template_name<template_arg>"
template <typename TemplateNameT, typename TemplateArgT>
std::string genTemplate(
    const TemplateNameT& template_name,
    const TemplateArgT& template_arg) {
  std::stringstream ss;
  ss << template_name << "<" << template_arg << ">";
  return ss.str();
}

//! Returns "func_name(func_arg)"
template <typename FuncNameT, typename FuncArgT>
std::string genCall(const FuncNameT& func_name, const FuncArgT& func_arg) {
  std::stringstream ss;
  ss << func_name << "(" << func_arg << ")";
  return ss.str();
}

//! Returns "func_name<template_arg>(func_arg)"
template <typename FuncNameT, typename TemplateArgT, typename FuncArgT>
std::string genCall(
    const FuncNameT& func_name,
    const TemplateArgT& template_arg,
    const FuncArgT& func_arg) {
  std::stringstream ss;
  ss << func_name << "<" << template_arg << ">(" << func_arg << ")";
  return ss.str();
}

//! A utility class to check if an expression of a particular type exists
class ExprFinder : kir::ConstIrVisitor {
 public:
  //! True if expr or any of its nested expressions is included in
  //! expr_types
  static bool exists(
      const Expr* expr,
      const std::unordered_set<ExprType>& expr_types) {
    ExprFinder finder(expr_types);
    finder.handle(std::vector<const Expr*>{expr});
    return finder.is_found_;
  }

 private:
  ExprFinder(const std::unordered_set<ExprType>& expr_types)
      : expr_types_(expr_types) {}

  using kir::ConstIrVisitor::handle;

  void handle(const Expr* expr) final {
    if (expr_types_.find(expr->etype()) != expr_types_.end()) {
      is_found_ = true;
      return;
    }
    kir::ConstIrVisitor::handle(expr);
  }

 private:
  const std::unordered_set<ExprType>& expr_types_;
  bool is_found_ = false;
};

class CudaKernelGenerator : private OptOutConstDispatch {
  static constexpr const char* kTab = " ";

 public:
  static std::string generateKernelDefinition(
      const kir::Kernel* kernel,
      const std::string& kernel_name) {
    CudaKernelGenerator codegen(kernel);
    codegen.genDeclaration(kernel_name);
    codegen.startBlock();
    codegen.genPrologue();
    codegen.genBody();
    codegen.endBlock();
    TORCH_CHECK(codegen.block_nest_level_ == 0);
    return codegen.code_.str();
  }

 private:
  explicit CudaKernelGenerator(const kir::Kernel* kernel) : kernel_(kernel) {
    initStringStreamFormat(code_);
  }

  void initStringStreamFormat(std::stringstream& ss) {
    const int digits = std::numeric_limits<Double::ScalarType>::max_digits10;
    ss.imbue(std::locale("C"));
    ss << std::scientific << std::setprecision(digits);
  }

  // Generates the kernel function declaration
  void genDeclaration(const std::string& kernel_name) {
    const auto& kernel_summary = kernel_->summary();

    code_ << "__global__ void " << kernel_name << "(";

    std::unordered_set<Val*> unique_args;

    std::vector<Val*> params;

    // Inputs & Outputs
    for (auto val : kernel_->inputs()) {
      params.push_back(val);
    }
    for (auto val : kernel_->outputs()) {
      TORCH_INTERNAL_ASSERT(
          !val->isScalar(), "No scalar output is allowed: ", val->toString());
      params.push_back(val);
    }

    // Generate parameter declarations
    unsigned int duplicate_counter = 0;
    for (auto i : c10::irange(params.size())) {
      std::stringstream var_name_ss;
      if (params[i]->isA<TensorView>()) {
        var_name_ss << varName(params[i]->as<TensorView>());
      } else {
        var_name_ss << gen(params[i]);
      }

      // If value is duplicate in arguments change the name to avoid name
      // conflicts in args.
      if (!unique_args.emplace(params[i]).second) {
        var_name_ss << "_duplicate_" << duplicate_counter++;
      }

      if (const auto tv = dynamic_cast<TensorView*>(params[i])) {
        if (tv->isCpuScalar()) {
          code_ << " CpuScalarTensor<" << params[i]->dtype() << "> "
              << var_name_ss.str();
        } else {
          code_
              << "Tensor<" << params[i]->dtype() << ", "
              << TensorDomain::noReductions(tv->getMaybeRFactorDomain()).size()
              << "> " << var_name_ss.str();
        }
      } else {
        TORCH_INTERNAL_ASSERT(params[i]->isScalar()); // NOLINT (LLVM bug 48525)
        TORCH_INTERNAL_ASSERT(params[i]->definition() == nullptr);
        code_ << params[i]->dtype() << " " << var_name_ss.str();
      }

      if (i + 1 != params.size()) {
        code_ << ", ";
      }
    }

    // Global buffers
    for (auto allocate : kernel_summary.global_allocations) {
      TORCH_INTERNAL_ASSERT(allocate->buffer()->isA<TensorView>());
      const auto tv = allocate->buffer()->as<TensorView>();
      const auto& maybe_rfactor_domain = tv->domain()->hasRFactor()
          ? tv->domain()->getRFactorDomain()
          : tv->domain()->getRootDomain();
      const auto nDims = std::count_if(
          maybe_rfactor_domain.begin(),
          maybe_rfactor_domain.end(),
          [](const IterDomain* id) { return !id->isReduction(); });
      code_ << ", Tensor<" << tv->dtype() << ", " << nDims << "> "
          << varName(tv);
    }

    // Kernels generating random numbers take extra (seed, offset) arguments
    if (kernel_summary.max_rng_offsets >= 0) {
      code_ << ", at::PhiloxCudaState philox_args";
    }

    code_ << ") ";
  }

  // Generates setup code which is executed before the kernel body
  void genPrologue() {
    const auto& kernel_summary = kernel_->summary();

    // Random number generator (optional)
    if (kernel_summary.max_rng_offsets >= 0) {
      indent() << "auto philox_offset = philox_args.captured_ ?\n";
      indent()
          << " static_cast<uint64_t>(*(philox_args.offset_.ptr) + philox_args.offset_intragraph_) :\n";
      indent() << " philox_args.offset_.val;\n";
      indent() << "uint4 rng_result;\n";
      indent() << "nvfuser_index_t rng_subseq = -1;\n";
      indent() << "nvfuser_index_t rng_offset = -1;\n";
    }

    // Do we have any dynamic shared memory buffers?
    const bool has_dynamic_smem =
        !kernel_summary.dynamic_smem_allocations.empty();

    // Do we have any reductions?
    const bool has_reductions = kernel_summary.has_block_reductions ||
        kernel_summary.has_grid_reductions;
    const bool has_parallel_welford =
        kernel_summary.has_block_welford || kernel_summary.has_grid_welford;

    // Shared memory
    if (has_dynamic_smem || has_reductions || has_parallel_welford) {
      indent() << "alignas("
#ifndef USE_ROCM
          << 16 // always align to 16B for any shared mem allocation
#else
          << 8 // for HIP, we want 8-aligned even for smaller datatypes
#endif
          << ") extern __shared__ char array[];\n";

      if (has_dynamic_smem) {
        indent() << "unsigned smem_offset = 0;\n";
      }

      if (has_reductions || has_parallel_welford) {
        indent() << "void* shared_mem = array;\n";
        if (has_dynamic_smem) {
          if (has_parallel_welford) {
            indent() << "smem_offset += "
                << "((blockDim.x * blockDim.y * blockDim.z) * 3 * sizeof("
                << kernel_summary.largest_smem_data_type << "));\n";
          } else {
            indent() << "smem_offset += "
                << "((blockDim.x * blockDim.y * blockDim.z) * sizeof("
                << kernel_summary.largest_smem_data_type << "));\n";
          }
        }

        if (has_parallel_welford) {
          // Unpack shared mem pointer
          auto space_type = kernel_summary.largest_smem_data_type;
          indent()
              << "nvfuser_index_t block_size = blockDim.x*blockDim.y*blockDim.z;\n";
          indent() << space_type << " *shared_mem_var = "
              << "static_cast<" << space_type << "*>("
              << "shared_mem);\n";
          indent() << space_type
              << " *shared_mem_avg = shared_mem_var + block_size;\n";
          indent() << space_type
              << " *shared_mem_n = shared_mem_avg + block_size;\n";
        }
      }
    }

    // Call the initialization function if using a custom block sync
    if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
      indent() << "block_sync::init();\n";
    }
  }

  void genBody() {
    for (auto expr : kernel_->topLevelExprs()) {
      OptOutConstDispatch::handle(expr);
    }
  }

  void startBlock(bool continuation = false) {
    if (continuation) {
      code_ << "{\n";
    } else {
      indent() << "{\n";
    }
    ++block_nest_level_;
  }

  void endBlock(const char* sep = "\n") {
    --block_nest_level_;
    TORCH_CHECK(block_nest_level_ >= 0);
    indent() << "}" << sep;
  }

  std::ostream& indent() {
    for (const auto i : c10::irange(block_nest_level_)) {
      (void)i; // Suppress unused variable warning
      code_ << kTab;
    }
    return code_;
  }

  std::string gen(const Statement* stmt) {
    std::stringstream tmp_code;
    initStringStreamFormat(tmp_code);
    std::swap(tmp_code, code_);
    OptOutConstDispatch::handle(stmt);
    std::swap(tmp_code, code_);
    return tmp_code.str();
  }

  std::string varName(const Val* val) {
    std::stringstream name;
    if (val->isA<TensorView>()) {
      name << "T";
    } else if (val->isA<kir::IntPair>()) {
      name << "ip";
    } else {
      name << typePrefix(val->dtype());
    }
    name << val->name();
    return name.str();
  }

  std::string genInline(const Statement* stmt) {
    const bool saved_inline = print_inline_;
    print_inline_ = true;
    auto result = gen(stmt);
    print_inline_ = saved_inline;
    // NOLINTNEXTLINE(performance-no-automatic-move)
    return result;
  }

  void handle(const kir::Predicate* pred) final {
    TORCH_INTERNAL_ASSERT(pred->hasValue());
    code_ << gen(pred->value());
  }

  void handle(const Bool* pred) final {
    const auto def = pred->definition();
    const bool has_alloc = alloc_map_.find(pred) != alloc_map_.end();
    if (def != nullptr && !has_alloc) {
      code_ << "(" << gen(def) << ")";
    } else if (pred->isConst()) {
      code_ << (*pred->value() ? "true" : "false");
    } else {
      code_ << varName(pred);
    }
  }

  void handle(const Double* d) final {
    const auto def = d->definition();
    const bool has_alloc = alloc_map_.find(d) != alloc_map_.end();
    if (def != nullptr && !has_alloc) {
      code_ << "(" << gen(def) << ")";
    } else if (d->isConst()) {
      auto val = *d->value();
      // note: default inf/nan doesn't work and should be replaced with macros
      // `NAN`, `POS_INFINITY` and `NEG_INFINITY` instead.
      if (std::isinf(val)) {
        if (val > 0) {
          code_ << "POS_INFINITY";
        } else {
          code_ << "NEG_INFINITY";
        }
      } else if (std::isnan(val)) {
        code_ << "NAN";
      } else {
        code_ << val;
      }
    } else {
      code_ << varName(d);
    }
  }

  void handle(const Int* i) final {
    // Check the replacement map first. If there's an entry for i, use
    // the corresponding replacement.
    auto replace_it = index_replacement_map_.find(i);
    if (replace_it != index_replacement_map_.end()) {
      code_ << replace_it->second;
      return;
    }

    const auto def = i->definition();
    const bool has_alloc = alloc_map_.find(i) != alloc_map_.end();
    if (def != nullptr && !has_alloc) {
      code_ << "(" << genInline(def) << ")";
    } else if (i->isConst()) {
      code_ << *i->value();
    } else {
      code_ << varName(i);
    }
  }

  void handle(const ComplexDouble* c) final {
    const auto def = c->definition();
    const bool has_alloc = alloc_map_.find(c) != alloc_map_.end();
    if (def != nullptr && !has_alloc) {
      code_ << "(" << gen(def) << ")";
    } else if (c->isConst()) {
      code_ << "std::complex<double>" << *c->value();
    } else {
      code_ << varName(c);
    }
  }

  void handle(const NamedScalar* ns) final {
    // dim3 components are unsigned int. Cast to signed integer to
    // support negative indexing
    if (ns->getParallelIndex().has_value() ||
        ns->getParallelDim().has_value()) {
      code_ << "((nvfuser_index_t)" << ns->name() << ")";
    } else {
      code_ << ns->name();
    }
  }

  //! Returns the sum of all indices in a TensorIndex,
  //! or 0 if the indices vector is empty.
  //! Used lowering generic tensor index and lowering
  //! mma fragment indices.
  std::string genTensorIndex(const kir::TensorIndex* ti) {
    bool first = true;
    std::stringstream index;
    for (auto* ind : ti->indices()) {
      if (!ind->isZeroInt()) {
        if (!first) {
          index << " + ";
        }
        index << genInline(ind);
        first = false;
      }
    }

    if (first) {
      index << "0";
    }

    return index.str();
  }

  void handle(const kir::TensorIndex* ti) final {
    bool is_volatile = ti->view()->getMemoryType() == MemoryType::Global &&
        kernel_->summary().sync_map.needsRawSync(ti->view()).hasBID();
    if (is_volatile) {
      code_ << "*(volatile " << ti->getDataType().value() << "*)&";
    }
    code_ << varName(ti->view()) << "[" << genTensorIndex(ti) << "]";
  }

  void handle(const ViewAsScalar* sv) final {
    indent() << gen(sv->output(0)) << " = " << gen(sv->input(0)) << "["
        << gen(sv->index()) << "];\n";
  }

  void handle(const IterDomain*) final {
    TORCH_INTERNAL_ASSERT(false, "Unreachable");
  }

  void handle(const TensorDomain*) final {
    TORCH_INTERNAL_ASSERT(false, "Unreachable");
  }

  void handle(const TensorView*) final {
    TORCH_INTERNAL_ASSERT(false, "Unreachable");
  }

  //! Utility for generating vectorized pointer access in ldsm and
  //! cpasync.
  //! TODO: this access pattern as is could be merged with exisiting
  //! vectorization handling logic but this path will be updated in
  //! follow ups to optimize the generated assembly so keeping them
  //! separate path for now.
  std::string genVectorPointer(Val* val, DataType dtype, int vec_size) {
    std::stringstream ss;

    ss << "reinterpret_cast<Array<" << dtype << "," << vec_size << ","
        << vec_size << ">*>(&" << gen(val) << ")";

    return ss.str();
  }

  // Utility function to emit a cp.async intrinsic
  void genCpAsync(const LoadStoreOp* ldst, int vec_size) {
    auto dtype = ldst->in()->getDataType().value();

    if (ldst->predicate() == nullptr) {
      // Out of line predicate variant
      indent() << "Ampere::cpAsync("
          << genVectorPointer(ldst->out(), dtype, vec_size) << ","
          << genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
    } else {
      // Inline predicate variant
      indent() << "Ampere::cpAsync("
          << genVectorPointer(ldst->out(), dtype, vec_size) << ","
          << genVectorPointer(ldst->in(), dtype, vec_size) << ","
          << genInline(ldst->predicate()) << ");\n";
    }
  }

  void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) {
    auto dtype = ldst->in()->getDataType().value();
    indent() << "Turing::ldMatrix";
    if (ldst->opType() == LoadStoreOpType::LdMatrixTranspose) {
      code_ << "T";
    }
    code_ << " (";
    code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size)
        << ","
        << "&" << gen(ldst->in()) << ");\n";
  }

  void handle(const FullOp* fop) final {
    indent() << gen(fop->output(0)) << " = (" << fop->dtype() << ")"
        << gen(fop->getFillValue()) << ";\n";
  }

  void handle(const ARangeOp* aop) final {
    auto index =
        genTensorIndex(aop->getLinearLogicalIndex()->as<kir::TensorIndex>());
    indent() << gen(aop->output(0)) << " = arange<" << aop->dtype() << ">";
    code_ << "(" << index << ", " << gen(aop->start()) << ", "
        << gen(aop->step()) << ");\n";
  }

  void handle(const EyeOp* aop) final {
    auto index1 = gen(aop->getIndex1());
    auto index2 = gen(aop->getIndex2());
    indent() << gen(aop->output(0)) << " = (" << aop->dtype() << ")";
    code_ << "(" << index1 << " == " << index2 << ");\n";
  }

  void handle(const UnaryOp* uop) final {
    bool is_vector_op = false;
    size_t vector_word_size = 1;

    if (uop->out()->isA<kir::TensorIndex>()) {
      auto out_tv = uop->out()->as<kir::TensorIndex>()->view();
      if (std::any_of(
              out_tv->domain()->domain().begin(),
              out_tv->domain()->domain().end(),
              [&](IterDomain* id) { return id->isMma(); })) {
        auto mma = dynamic_cast<MmaOp*>(
            uop->out()->as<kir::TensorIndex>()->view()->definition());
        TORCH_INTERNAL_ASSERT(
            mma != nullptr, "CodeGen: mma op not in mma loop");
        genMmaInitialization(mma, uop);
        return;
      }
    }

    if (vectorize_scope_ && uop->out()->isA<kir::TensorIndex>()) {
      auto ti = uop->out()->as<kir::TensorIndex>();

      bool vectorize_op = false;
      bool misaligned_op = false;

      for (auto id : ti->view()->domain()->domain()) {
        if (!isParallelTypeVectorize(id->getParallelType())) {
          continue;
        }

        ExpressionEvaluator expr_eval(id->fusion());
        auto vector_size_optional = expr_eval.evaluate(id->extent());

        TORCH_INTERNAL_ASSERT(
            vector_size_optional.has_value(),
            "Could not evaluate constant value bound to vectorized dim.");

        vector_word_size = vector_size_optional->as<int64_t>();

        vectorize_op = id->getParallelType() == ParallelType::Vectorize;
        misaligned_op =
            id->getParallelType() == ParallelType::MisalignedVectorize;
        break;
      }

      if (vectorize_op) {
        TORCH_INTERNAL_ASSERT(
            uop->getUnaryOpType() == UnaryOpType::Set,
            "Cannot vectorize operations that are not sets. ",
            "Use cacheBefore and cacheAfter to store/load with vectorized reads into buffers.");
        is_vector_op = true;
      }

      if (misaligned_op) {
        is_vector_op = (uop->getUnaryOpType() == UnaryOpType::Set);
      }

      if (is_vector_op && !uop->in()->isScalar()) {
        TORCH_INTERNAL_ASSERT(
            uop->out()->dtype() == uop->in()->dtype(),
            "Vectorized store/load requires input and output datatypes match.");
      }

      if (is_vector_op) {
        auto out_tv = uop->out()->as<kir::TensorIndex>()->view();
        if (uop->in()->isScalar()) {
          // Note:
          // Double buffered local tensors need indexed initialization,
          // so will need to use `arraySet` option.
          if (out_tv->getMemoryType() == MemoryType::Local &&
              !(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered())) {
            // Vectorized initialization
            indent() << varName(out_tv) << ".set(" << gen(uop->in()) << ");\n";
          } else {
            // Note: currently arraySet option is not vectorized, so it will
            // rely on auto vectorization pass of cuda compiler.
            indent() << "arraySet<" << out_tv->getDataType().value() << ", "
                << vector_word_size << ">(&" << gen(uop->out()) << ", "
                << "(" << out_tv->getDataType().value() << ")"
                << gen(uop->in()) << ");\n";
          }
        } else {
          // Vectorized load
          TORCH_INTERNAL_ASSERT(
              uop->in()->isA<kir::TensorIndex>(),
              "Invalid input to unary op with tensor output, found: ",
              uop->in()->toString());

          auto in_tv = uop->in()->as<kir::TensorIndex>()->view();
          bool localToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
              in_tv->getMemoryType() == MemoryType::Local;

          bool globalToLocal = out_tv->getMemoryType() == MemoryType::Local &&
              in_tv->getMemoryType() == MemoryType::Global;

          bool globalToGlobal = out_tv->getMemoryType() == MemoryType::Global &&
              in_tv->getMemoryType() == MemoryType::Global;

          bool is_volatile_to = out_tv->getMemoryType() == MemoryType::Global &&
              kernel_->summary().sync_map.needsRawSync(out_tv).hasBID();

          bool is_volatile_from =
              in_tv->getMemoryType() == MemoryType::Global &&
              kernel_->summary().sync_map.needsRawSync(in_tv).hasBID();

          if (localToGlobal) {
            indent() << "loadLocalToGlobal<" << uop->out()->dtype() << ", "
                << vector_word_size << ", "
                << (is_volatile_to ? "true" : "false") << ">(";
            code_ << " &" << gen(uop->out()) << ", &" << gen(uop->in())
                << ");\n";
          } else if (globalToLocal) {
            indent() << "loadGlobalToLocal<" << uop->out()->dtype() << ", "
                << vector_word_size << ", "
                << (is_volatile_from ? "true" : "false") << ">(&"
                << gen(uop->out()) << ", ";
            code_ << " &" << gen(uop->in()) << ");\n";
          } else if (globalToGlobal) {
            indent() << "loadGlobalToGlobal<" << uop->out()->dtype() << ", "
                << vector_word_size << ", "
                << (is_volatile_to ? "true" : "false") << ", "
                << (is_volatile_from ? "true" : "false") << ">(";
            code_ << " &" << gen(uop->out()) << ", ";
            code_ << " &" << gen(uop->in()) << ");\n";
          } else {
            indent() << "loadGeneric<" << uop->out()->dtype() << ", "
                << vector_word_size << ">(";
            code_ << " &" << gen(uop->out()) << ", ";
            code_ << " &" << gen(uop->in()) << ");\n";
          }
        }
        return;
      }
    }

    const auto op_type = uop->getUnaryOpType();

    if (uop->out()->isA<NamedScalar>()) {
      if (auto op = inline_op_str(op_type)) {
        indent() << gen(uop->out()) << " = " << *op << genInline(uop->in())
            << ";\n";
      }
      return;
    }

    if (!print_inline_) {
      indent() << gen(uop->out());
      if (!uop->out()->isScalar() && !uop->in()->isScalar()) {
        code_ << "\n";
        indent() << kTab;
      }
      code_ << " = ";
    }

    if (auto op = inline_op_str(op_type)) {
      if (alsoBooleanOperator(op_type) &&
          uop->out()->dtype() == DataType::Bool) {
        code_ << stringifyBooleanOp(op_type) << gen(uop->in());
      } else {
        code_ << *op << gen(uop->in());
      }
    } else {
      if (op_type == UnaryOpType::Cast) {
        const auto cast_str =
            cast_func_str({uop->in()->dtype(), uop->out()->dtype()});
        TORCH_INTERNAL_ASSERT(
            cast_str.has_value(),
            "Invalid cast. Input type: ",
            uop->in()->dtype(),
            ", output type: ",
            uop->out()->dtype());
        code_ << cast_str.value();
      } else {
        code_ << op_type;
        if (needFloatSuffix(op_type) &&
            uop->out()->dtype() == DataType::Float) {
          code_ << "f";
        }
      }

      code_ << "(" << gen(uop->in()) << ")";
    }

    if (!print_inline_) {
      code_ << ";\n";
    }
  }

  void handle(const RNGOp* rop) final {
    // TODO: TORCH_INTERNAL_ASSERT that the scheduler correctly creates an
    // innermost ID of size 4 (float) or size 2 (double)?
    auto index = genTensorIndex(rop->getPhiloxIndex()->as<kir::TensorIndex>());
    int multiple = rop->dtype() == DataType::Double ? 2 : 4;
    indent() << "nvfuser_index_t linear_index" << rop->name() << " = " << index
        << ";\n";
    indent() << "nvfuser_index_t rng_subseq" << rop->name() << " = linear_index"
        << rop->name() << " / " << multiple << ";\n";
    indent() << "nvfuser_index_t rng_component" << rop->name()
        << " = linear_index" << rop->name() << " % " << multiple << ";\n";
    indent() << "nvfuser_index_t rng_offset" << rop->name() << " = "
        << rop->getRNGOffset() << ";\n";
    indent() << "if (rng_subseq != rng_subseq" << rop->name()
        << " || rng_offset != rng_offset" << rop->name() << ") {\n";
    indent() << " auto seed = philox_args.captured_ ?\n"
        << " static_cast<uint64_t>(*(philox_args.seed_.ptr)) : \n"
        << " philox_args.seed_.val;\n";
    indent() << " rng_result = philox(seed, rng_subseq" << rop->name()
        << ", philox_offset / 4 + rng_offset" << rop->name() << ");\n";
    indent() << " rng_subseq = rng_subseq" << rop->name() << ";\n";
    indent() << " rng_offset = rng_offset" << rop->name() << ";\n";
    indent() << "}\n";
    auto op_type = rop->getRNGOpType();
    indent() << gen(rop->output(0)) << " = " << op_type;
    if (needFloatSuffix(op_type) && rop->dtype() == DataType::Float) {
      code_ << "f";
    }
    code_ << "(rng_result, rng_component" << rop->name();
    switch (op_type) {
      case RNGOpType::UniformRange: {
        auto parameters = rop->getParameters();
        TORCH_INTERNAL_ASSERT(parameters.size() == 2);
        code_ << ", " << gen(parameters[0]) << ", " << gen(parameters[1]);
        break;
      }
      default:;
    }
    code_ << ");\n";
  }

  std::string genBinaryOp(
      BinaryOpType op_type,
      DataType data_type,
      const std::string& lhs,
      const std::string& rhs) {
    std::stringstream expr;
    if (auto op = inline_op_str(op_type)) {
      expr << lhs << " ";
      if (alsoBooleanOperator(op_type) && data_type == DataType::Bool) {
        expr << stringifyBooleanOp(op_type);
      } else {
        expr << *op;
      }
      expr << " " << rhs;
    } else {
      if (integer_op_str(op_type) && isIntegralType(data_type)) {
        auto int_op = integer_op_str(op_type);
        expr << *int_op;
      } else if (bool_op_str(op_type) && isBooleanType(data_type)) {
        auto bool_op = bool_op_str(op_type);
        expr << *bool_op;
      } else {
        expr << op_type;
        if (needFloatSuffix(op_type) && data_type == DataType::Float) {
          expr << "f";
        }
      }
      expr << "(" << lhs << ", " << rhs << ")";
    }
    return expr.str();
  }

  // If one argument is a tensorview and the other is a scalar, make sure we
  // cast the scalar to the tensorview type
  std::string scalarCast(Val* lhs, Val* rhs) {
    // If neither are scalars return
    if (!((lhs->isScalar() || rhs->isScalar()) &&
          (lhs->isA<kir::TensorIndex>() || rhs->isA<kir::TensorIndex>()))) {
      return "";
    }

    // Looking for mixed tensorview scalar options where types don't match
    // but are either both floating or both int types. We should cast
    // scalar to tensorview type in these instances.
    auto lhs_t = lhs->dtype();
    auto rhs_t = rhs->dtype();

    // If same type, don't cast anything
    if (lhs_t == rhs_t) {
      return "";
    }

    // Don't do anything when dealing with bools
    if (lhs_t == DataType::Bool || rhs_t == DataType::Bool) {
      return "";
    }

    // Mixing floating and int combination
    if ((isFloatingPointType(lhs_t) != isFloatingPointType(rhs_t)) ||
        (isIntegralType(lhs_t) != isIntegralType(rhs_t))) {
      return "";
    }

    std::stringstream cast;
    cast << "(" << (lhs->isA<kir::TensorIndex>() ? lhs_t : rhs_t) << ") ";
    return cast.str();
  }

  // If possible, replace pow with mul. Return true when successful.
  bool genPowerWithMul(const BinaryOp* bop) {
    if (bop->getBinaryOpType() != BinaryOpType::Pow) {
      return false;
    }

    auto rhs = bop->rhs();
    c10::optional<double> exponent;
    if (auto val_int = dynamic_cast<Int*>(rhs)) {
      if (val_int->isConst()) {
        exponent = val_int->value().value();
      }
    } else if (auto val_float = dynamic_cast<Double*>(rhs)) {
      if (val_float->isConst()) {
        auto fp_exp = val_float->value().value();
        double int_exp = 0;
        if (std::modf(fp_exp, &int_exp) == 0) {
          exponent = int_exp;
        }
      }
    }

    if (!exponent.has_value()) {
      return false;
    }

    // Only **2 and **3 are considered
    if (!(exponent.value() == 2 || exponent.value() == 3)) {
      return false;
    }

    auto lhs = gen(bop->lhs());

    if (print_inline_) {
      code_ << lhs << " * " << lhs;
      if (exponent.value() == 3) {
        code_ << " * " << lhs;
      }
    } else {
      indent() << gen(bop->out());
      if (bop->out()->isScalar()) {
        code_ << " = " << lhs << " * " << lhs;
        if (exponent.value() == 3) {
          code_ << " * " << lhs;
        }
      } else {
        code_ << "\n";
        indent() << kTab << "= " << lhs << "\n";
        indent() << kTab << "* " << lhs;
        if (exponent.value() == 3) {
          code_ << "\n";
          indent() << kTab << "* " << lhs;
        }
      }
    }

    code_ << ";\n";
    return true;
  }

  void handle(const BinaryOp* bop) final {
    // Try replacing pow with mul
    if (genPowerWithMul(bop)) {
      return;
    }

    const auto op_type = bop->getBinaryOpType();
    if (print_inline_) {
      // Inline expression: `lhs op rhs`
      code_ << genBinaryOp(
          op_type, bop->out()->dtype(), gen(bop->lhs()), gen(bop->rhs()));
    } else {
      indent() << gen(bop->out());
      if (bop->out()->isScalar()) {
        // Single line: `out = lhs op rhs;`
        code_ << " = "
            << genBinaryOp(
                op_type,
                bop->out()->dtype(),
                gen(bop->lhs()),
                gen(bop->rhs()));
      } else {
        // Split TensorView expressions across multiple lines:
        //
        // out
        //   = lhs
        //   op rhs;
        //

        auto cast = scalarCast(bop->lhs(), bop->rhs());
        if (auto op = inline_op_str(op_type)) {
          code_ << "\n";
          indent() << kTab << "= " << (bop->lhs()->isScalar() ? cast : "")
              << gen(bop->lhs()) << "\n";
          indent() << kTab;
          if (alsoBooleanOperator(op_type) &&
              bop->out()->dtype() == DataType::Bool) {
            code_ << stringifyBooleanOp(op_type);
          } else {
            code_ << *op;
          }
          code_ << " " << (bop->rhs()->isScalar() ? cast : "")
              << gen(bop->rhs());
        } else {
          if (integer_op_str(op_type) && isIntegralType(bop->out()->dtype())) {
            auto int_op = integer_op_str(op_type);
            code_ << " = " << *int_op << "(\n";
          } else if (
              bool_op_str(op_type) && isBooleanType(bop->out()->dtype())) {
            auto bool_op = bool_op_str(op_type);
            code_ << " = " << *bool_op << "(\n";
          } else {
            std::stringstream op_str;
            op_str << op_type;
            if (needFloatSuffix(op_type) &&
                bop->out()->dtype() == DataType::Float) {
              op_str << "f";
            }
            code_ << " = " << op_str.str() << "(\n";
          }
          indent() << kTab << (bop->lhs()->isScalar() ? cast : "")
              << gen(bop->lhs()) << ",\n";
          indent() << kTab << (bop->rhs()->isScalar() ? cast : "")
              << gen(bop->rhs()) << ")";
        }
      }
      code_ << ";\n";
    }
  }

  void handle(const TernaryOp* top) final {
    if (!print_inline_) {
      indent() << gen(top->out());
      if (!top->out()->isScalar()) {
        code_ << "\n";
        indent() << kTab;
      }
      code_ << " = ";
    }

    code_ << top->getTernaryOpType() << "(" << gen(top->in1()) << ", ";

    // Make sure the two operands of where has the same
    // type. Note that compiling "where(0.0f, 0.0)" fails because of
    // the overloading ambiguity.
    if (top->getTernaryOpType() == TernaryOpType::Where) {
      auto cast = scalarCast(top->in2(), top->in3());
      code_ << (top->in2()->isScalar() ? cast : "") << gen(top->in2()) << ", "
          << (top->in3()->isScalar() ? cast : "") << gen(top->in3()) << ")";
    } else {
      code_ << gen(top->in2()) << ", " << gen(top->in3()) << ")";
    }

    if (!print_inline_) {
      code_ << ";\n";
    }
  }

  std::string genArchString(MmaOptions::MacroType macro) {
    std::stringstream ss;
    if (isVolta(macro)) {
      ss << "Volta";
    } else if (isTuring(macro)) {
      ss << "Turing";
    } else if (isAmpere(macro)) {
      ss << "Ampere";
    } else {
      TORCH_INTERNAL_ASSERT(false, "mma macro unknown arch");
    }
    return ss.str();
  }

  std::string genMmaOp(const MmaOp* mma, bool init = false) {
    std::stringstream ss;
    auto options = mma->options();
    ss << genArchString(options.macro) << "::";
    if (init) {
      ss << "init";
    }
    ss << toString(options.macro);

    if (isVolta(options.macro)) {
      ss << toString(options.operand_layout);
    } else if (isTuring(options.macro) || isAmpere(options.macro)) {
      // mma's in turing and ampere TN only, transpose is handled either
      // via ldmatrix for fp16 or explicitly for other types.
      ss << "TN";
    }
    // TODO: additional parameter could be removed by swizzling iterdomain
    auto acc_stride = mma->accStride();
    TORCH_INTERNAL_ASSERT(acc_stride > 0);
    ss << "<" << acc_stride << ">";
    return ss.str();
  }

  void genMmaOperands(const MmaOp* mma) {
    std::stringstream ss;
    auto options = mma->options();
    auto in_a = mma->inA()->as<kir::TensorIndex>()->view();
    auto dtype = in_a->getDataType().value();
    indent() << kTab << "&(reinterpret_cast<Array<" << dtype << ","
        << getInputARegisterSize(options.macro) << ","
        << getInputARegisterSize(options.macro) << ">*>(&"
        << varName(mma->inA()->as<kir::TensorIndex>()->view()) << ")["
        << genTensorIndex(mma->inA()->as<kir::TensorIndex>()) << "])"
        << ",\n";
    indent() << kTab << "&(reinterpret_cast<Array<" << dtype << ","
        << getInputBRegisterSize(options.macro) << ","
        << getInputBRegisterSize(options.macro) << ">*>(&"
        << varName(mma->inB()->as<kir::TensorIndex>()->view()) << ")["
        << genTensorIndex(mma->inB()->as<kir::TensorIndex>()) << "])";
  }

  void genMmaInitialization(const MmaOp* mma, const UnaryOp* uop) {
    auto options = mma->options();

    indent() << genMmaOp(mma, true) << "(reinterpret_cast<Array<"
        << mma->out()->getDataType().value() << ","
        << getOutputRegisterSize(options.macro) << ","
        << getOutputRegisterSize(options.macro) << ">*>"
        << "(&" << gen(uop->out()) << "));\n";
  }

  void handle(const MmaOp* mma) final {
    auto options = mma->options();
    auto out = mma->out()->as<kir::TensorIndex>();
    indent() << genMmaOp(mma) << "(\n";
    indent() << kTab << "reinterpret_cast<Array<"
        << out->view()->getDataType().value() << ","
        << getOutputRegisterSize(options.macro) << ","
        << getOutputRegisterSize(options.macro) << ">*>(&"
        << gen(mma->out()) << "),\n";
    genMmaOperands(mma);
    code_ << ");\n";
  }

  std::string genReductionOp(BinaryOpType op_type, DataType data_type) {
    std::stringstream lambda;
    lambda << "[](" << data_type << " &a, " << data_type << " b) "
        << "{ a = " << genBinaryOp(op_type, data_type, "a", "b") << "; }";
    return lambda.str();
  }

  void handle(const BroadcastOp* stmt) final {
    TORCH_INTERNAL_ASSERT(stmt->out()->isA<kir::TensorIndex>());

    const ParallelTypeBitmap parallel_types =
        kernel_->summary().broadcast_parallel_types.at(stmt);

    if (parallel_types.none()) {
      // Not parallelized
      indent() << gen(stmt->out()) << "\n";
      indent() << kTab << " = " << gen(stmt->in()) << ";\n";
      return;
    }

    TORCH_INTERNAL_ASSERT(
        !parallel_types.hasBID(),
        "Parallel broadcast across blocks should have been translated to a GridBroadcast IR node");

    std::stringstream flags_str;
    for (const ParallelType pt : kParallelTypeTIDs) {
      const bool parallel_bcast = parallel_types.get(pt);
      if (pt != kParallelTypeTIDs[0]) {
        flags_str << ", ";
      }
      flags_str << (parallel_bcast ? "true" : "false");
    }

    const auto data_type = stmt->out()->dtype();
    indent() << "broadcast::blockBroadcast<" << flags_str.str() << ">(\n";
    indent() << kTab << gen(stmt->out()) << ",\n";
    indent() << kTab << gen(stmt->in()) << ",\n";
    indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
    TORCH_INTERNAL_ASSERT(
        stmt->predicate() != nullptr && stmt->predicate()->hasValue());
    indent() << kTab << genInline(stmt->predicate()) << ");\n";
  }

  void genSerialReduction(
      const kir::TensorIndex* output,
      const Val* input,
      BinaryOpType reduction_op_type) {
    const auto gen_out = gen(output);
    indent() << gen_out << " = "
        << genBinaryOp(
            reduction_op_type, output->dtype(), gen_out, gen(input))
        << ";\n";
    return;
  }

  void genWarpReduction(
      const kir::TensorIndex* output,
      const kir::TensorIndex* input,
      const Val* init,
      BinaryOpType reduction_op_type,
      kir::Predicate* read_pred) {
    bool is_single_warp =
        kernel_->getWarpPaddedParallelInfo().is_tidx_single_warp;

    indent() << "warp::warpReduceTIDX";
    if (is_single_warp) {
      code_ << "<true>(\n";
    } else {
      code_ << "<false>(\n";
    }
    indent() << kTab << gen(output) << ",\n";
    indent() << kTab << gen(input) << ",\n";
    indent() << kTab << genReductionOp(reduction_op_type, output->dtype())
        << ",\n";
    indent() << kTab << "threadIdx,\n";
    indent() << kTab << "blockDim,\n";
    indent() << kTab << "static_cast<" << output->dtype()
        << "*>(shared_mem),\n";
    TORCH_INTERNAL_ASSERT(read_pred != nullptr && read_pred->hasValue());
    indent() << kTab << genInline(read_pred) << ",\n";
    indent() << kTab << output->dtype() << "(" << genInline(init) << "));\n";
  }

  void genBlockReduction(
      const kir::TensorIndex* output,
      const kir::TensorIndex* input,
      const Val* init,
      BinaryOpType reduction_op_type,
      kir::Predicate* read_pred,
      kir::Predicate* write_pred) {
    const auto par_domains = ir_utils::getParallelDomains(output);
    // Get parallel reduction domains
    const bool tidx =
        par_domains.find(ParallelType::TIDx) != par_domains.end() &&
        par_domains.at(ParallelType::TIDx)->isReduction();
    const bool tidy =
        par_domains.find(ParallelType::TIDy) != par_domains.end() &&
        par_domains.at(ParallelType::TIDy)->isReduction();
    const bool tidz =
        par_domains.find(ParallelType::TIDz) != par_domains.end() &&
        par_domains.at(ParallelType::TIDz)->isReduction();

    const auto data_type = output->dtype();

    indent() << "blockReduce<" << (tidx ? "true" : "false") << ", "
        << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
        << ">(\n";
    indent() << kTab << gen(output) << ",\n";
    indent() << kTab << gen(input) << ",\n";
    indent() << kTab << genReductionOp(reduction_op_type, output->dtype())
        << ",\n";
    indent() << kTab << "threadIdx,\n";
    indent() << kTab << "blockDim,\n";
    indent() << kTab << "static_cast<" << data_type << "*>(shared_mem),\n";
    TORCH_INTERNAL_ASSERT(read_pred != nullptr && read_pred->hasValue());
    indent() << kTab << genInline(read_pred) << ",\n";
    // Pass the write predicate if available and different from the
    // default predicate. The blockReduce runtime function uses the
    // default predicate for both read and write when only the
    // default one is given.
    if (write_pred != nullptr) {
      TORCH_INTERNAL_ASSERT(write_pred->hasValue());
      indent() << kTab << genInline(write_pred) << ",\n";
    }
    indent() << kTab << data_type << "(" << genInline(init) << "));\n";
  }

  void handle(const ReductionOp* rop) final {
    TORCH_INTERNAL_ASSERT(rop->out()->isA<kir::TensorIndex>());

    const auto output = rop->out()->as<kir::TensorIndex>();
    const auto input = rop->in()->as<kir::TensorIndex>();
    const auto domain = output->view()->domain();
    const auto op_type = rop->getReductionOpType();

    const bool has_block_reduce = domain->hasBlockReduction();
    const bool has_grid_reduce = domain->hasGridReduction();

    TORCH_INTERNAL_ASSERT(
        !has_grid_reduce,
        "ReductionOp does not support block parallelization. GridReductionOp must be used. ",
        rop->toString());

    if (!has_block_reduce) {
      genSerialReduction(output, input, op_type);
    } else if (
        auto reduction_id = ir_utils::getMaybeWarpReductionDim(output, input)) {
      genWarpReduction(output, input, rop->init(), op_type, rop->predicate());
    } else {
      genBlockReduction(
          output,
          input,
          rop->init(),
          op_type,
          rop->predicate(),
          rop->writePredicate());
    }
  }

  void handle(const LoadStoreOp* ldst) final {
    // TODO:
    // Need to gradually merge the code path of this
    // with UnaryOp::Set for vectorization.
    // There is quite a bit of possible clean up.
    bool vectorize_op = false;
    size_t vector_word_size = 1;
    auto ti = ldst->out()->as<kir::TensorIndex>();

    // Check vectorization and set vector word size
    for (auto id : ti->view()->domain()->domain()) {
      if (!isParallelTypeVectorize(id->getParallelType())) {
        continue;
      }

      ExpressionEvaluator expr_eval(id->fusion());
      auto vector_size_optional = expr_eval.evaluate(id->extent());

      TORCH_INTERNAL_ASSERT(
          vector_size_optional.has_value(),
          "Could not evaluate constant value bound to vectorized dim.");

      TORCH_INTERNAL_ASSERT(
          id->getParallelType() != ParallelType::MisalignedVectorize,
          "LoadStoreOp: no support yet for mis-aligned vectorization");
      vector_word_size = vector_size_optional->as<int64_t>();
      vectorize_op = true;
      break;
    }

    // Dispatch instruction generation:
    switch (ldst->opType()) {
      case LoadStoreOpType::LdMatrix:
      case LoadStoreOpType::LdMatrixTranspose:
        TORCH_INTERNAL_ASSERT(
            vectorize_op, "LdMatrix: Vectorization required: ", ldst);
        genLdMatrix(ldst, vector_word_size);
        break;
      case LoadStoreOpType::CpAsync:
        genCpAsync(ldst, vector_word_size);
        break;
      default:
        TORCH_INTERNAL_ASSERT(false, "LoadStoreOp: Unknown op type");
    }
  }

  void handle(const WelfordOp* wop) final {
    TORCH_INTERNAL_ASSERT(wop->out()->isA<kir::TensorIndex>());

    const auto out = wop->out()->as<kir::TensorIndex>();
    const auto domain = out->view()->domain();

    const auto out_var = wop->outVar();
    const auto out_avg = wop->outAvg();
    const auto out_N = wop->outN();

    const auto in_var = wop->inVar();
    const auto in_avg = wop->inAvg();
    const auto in_N = wop->inN();

    // inVar was allowed to be nullptr. Make sure it isn't.
    TORCH_INTERNAL_ASSERT(
        in_var != nullptr, "Welford var input nullptr not allowed");

    const bool has_block_reduce = domain->hasBlockReduction();
    const bool has_grid_reduce = domain->hasGridReduction();

    // Serial WelfordOp generation
    if (!has_block_reduce && !has_grid_reduce) {
      indent() << "welfordCombine ("
          << "\n";
      indent() << kTab << gen(out_avg) << ",\n";
      indent() << kTab << gen(out_var) << ",\n";
      indent() << kTab << gen(out_N) << ",\n";
      indent() << kTab << gen(in_avg) << ",\n";
      indent() << kTab << "(" << out_avg->dtype() << ")" << gen(in_var)
          << ",\n";
      indent() << kTab << "(" << out_N->dtype() << ")" << gen(in_N) << ");\n";
      return;
    }

    const auto par_domains = ir_utils::getParallelDomains(wop->out());
    // Get parallel reduction domains
    const bool tidx =
        par_domains.find(ParallelType::TIDx) != par_domains.end() &&
        par_domains.at(ParallelType::TIDx)->isReduction();
    const bool tidy =
        par_domains.find(ParallelType::TIDy) != par_domains.end() &&
        par_domains.at(ParallelType::TIDy)->isReduction();
    const bool tidz =
        par_domains.find(ParallelType::TIDz) != par_domains.end() &&
        par_domains.at(ParallelType::TIDz)->isReduction();

    const auto data_type = wop->out()->dtype();

    if (has_block_reduce) {
      if (has_grid_reduce) {
        // allocate block result
        indent() << data_type << " "
            << "block_result_avg_" << block_reduce_name_ << " = "
            << gen(wop->initAvg()) << ";\n";
        indent() << data_type << " "
            << "block_result_var_" << block_reduce_name_ << " = "
            << gen(wop->initVar()) << ";\n";
        indent() << out_N->dtype() << " "
            << "block_result_n_" << block_reduce_name_ << " = "
            << gen(wop->initN()) << ";\n";
      }
      indent() << "blockWelford<" << (tidx ? "true" : "false") << ", "
          << (tidy ? "true" : "false") << ", " << (tidz ? "true" : "false")
          << ">(\n";
      if (has_grid_reduce) {
        indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n";
        indent() << kTab << "block_result_var_" << block_reduce_name_ << ",\n";
        indent() << kTab << "block_result_n_" << block_reduce_name_ << ",\n";
      } else {
        indent() << kTab << gen(wop->outAvg()) << ",\n";
        indent() << kTab << gen(wop->outVar()) << ",\n";
        indent() << kTab << gen(wop->outN()) << ",\n";
      }
      indent() << kTab << gen(in_avg) << ",\n";
      indent() << kTab << out_avg->dtype() << "(" << gen(in_var) << "),\n";
      indent() << kTab << out_N->dtype() << "(" << gen(in_N) << "),\n";
      indent() << kTab << "threadIdx,\n";
      indent() << kTab << "blockDim,\n";
      indent() << kTab << "reinterpret_cast<" << data_type
          << "*>(shared_mem_avg),\n";
      indent() << kTab << "reinterpret_cast<" << data_type
          << "*>(shared_mem_var),\n";
      indent() << kTab << "reinterpret_cast<" << out_N->dtype()
          << "*>(shared_mem_n),\n";
      TORCH_INTERNAL_ASSERT(wop->predicate() != nullptr);
      TORCH_INTERNAL_ASSERT(
          wop->predicate() != nullptr && wop->predicate()->hasValue());
      auto read_pred = genInline(wop->predicate());
      indent() << kTab << read_pred << ",\n";
      if (wop->writePredicate() != nullptr) {
        TORCH_INTERNAL_ASSERT(wop->writePredicate()->hasValue());
        auto write_pred = genInline(wop->writePredicate());
        indent() << kTab << write_pred << ",\n";
      }
      indent() << kTab << data_type << "(0));\n";
    }
  }

  // Support ReductionOp and WelfordOp
  template <typename REDUCTION_OP>
  std::string generateGridReduceTemplateFlags(
      const REDUCTION_OP* rop,
      const ParallelTypeBitmap& thread_pred) {
    TORCH_INTERNAL_ASSERT(
        !rop->isAllreduce(),
        "This is not for the allreduce reduction kernel\n");

    const auto par_domains = ir_utils::getParallelDomains(rop->outputs()[0]);
    ArgumentBuilder flags;
    for (const ParallelType pt : kParallelTypeThreads) {
      const bool parallel_reduction =
          par_domains.find(pt) != par_domains.end() &&
          par_domains.at(pt)->isReduction();
      const bool pred = thread_pred.get(pt);
      TORCH_INTERNAL_ASSERT(
          !(parallel_reduction && pred), "Cannot reduce predicated axis: ", pt);
      bool flag = false;
      // Currently assumed that no dimensions parallelized with blocks
      // are predicated. This assumption may be lifted, but
      // gridReduction would need some changes.
      if (isParallelTypeBlockDim(pt)) {
        TORCH_INTERNAL_ASSERT(
            !pred, "Predication on block dimensions not allowed: ", pt);
        flag = parallel_reduction;
      } else {
        flag = !pred && !parallel_reduction;
      }
      flags.arg(flag);
    }
    return flags.str();
  }

  // TODO: This should replace generateGridReduceTemplateFlags once
  // GridWelford is refactored as GridReduction.
  template <typename REDUCTION_OP>
  std::string generateGridReduceTemplateFlags2(
      const REDUCTION_OP* rop,
      const ParallelTypeBitmap& thread_pred) {
    TORCH_INTERNAL_ASSERT(
        !rop->isAllreduce(),
        "This is not for the allreduce reduction kernel\n");

    const auto par_domains =
        ir_utils::getParallelDomains(ir_utils::getTvOutput(rop));
    ArgumentBuilder flags;
    for (const ParallelType pt : kParallelTypeThreads) {
      const bool parallel_reduction =
          par_domains.find(pt) != par_domains.end() &&
          par_domains.at(pt)->isReduction();
      const bool pred = thread_pred.get(pt);
      TORCH_INTERNAL_ASSERT(
          !(parallel_reduction && pred), "Cannot reduce predicated axis: ", pt);
      // Currently assumed that no dimensions parallelized with blocks
      // are predicated. This assumption may be lifted, but
      // gridReduction would need some changes.
      if (isParallelTypeBlockDim(pt)) {
        TORCH_INTERNAL_ASSERT(
            !pred, "Predication on block dimensions not allowed: ", pt);
      }
      flags.arg(parallel_reduction);
    }
    return flags.str();
  }

  void addProfileArguments(ArgumentBuilder& func_args, const Expr* expr) {
    if (isOptionEnabled(EnableOption::KernelProfile) &&
        kernel_->profile().isProfiled(expr)) {
      const auto& buffer_indices =
          kernel_->profile().getIndicesInProfileBuffer(expr);
      auto buffer = kernel_->profile().getBuffer();
      TORCH_INTERNAL_ASSERT(buffer != nullptr);
      for (const auto& index : buffer_indices) {
        func_args.arg(varName(buffer)).append("[").append(index).append("]");
      }
    }
  }

  void handle(const kir::GridReduction* grop) final {
    TORCH_INTERNAL_ASSERT(grop->out()->isA<kir::TensorIndex>());

    const auto out = grop->out()->as<kir::TensorIndex>();
    const auto domain = out->view()->domain();
    TORCH_INTERNAL_ASSERT(domain->hasGridReduction());

    const auto data_type = grop->out()->dtype();
    const auto op_type = grop->getReductionOpType();

    TORCH_INTERNAL_ASSERT(
        grop->reduction_buffer()->buffer()->isA<TensorView>());
    TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA<TensorView>());
    const auto work_buffer =
        grop->reduction_buffer()->buffer()->as<TensorView>();
    const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>();

    if (grop->isAllreduce()) {
      generateGridAllreduce(grop);
      return;
    }

    const std::string flags_str =
        generateGridReduceTemplateFlags2(grop, grop->threadPredicate());

    const bool persistent_sync =
        kernel_->summary().has_cooperative_grid_reduction;

    // Since block-level reduction is already done, those dimensions
    // with tidx/y/z being true do not participate in the grid
    // reduction.
    ArgumentBuilder template_args;
    template_args.arg(flags_str).arg(persistent_sync);

    ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
    func_args.arg(gen(grop->out()));
    func_args.arg(gen(grop->in()));
    func_args.arg(genReductionOp(op_type, out->dtype()));
    func_args.arg("&").append(varName(work_buffer)).append("[0]");
    func_args.arg("&").append(varName(sync_buffer)).append("[0]");
    func_args.arg(genCall("static_cast", ptrType(data_type), "shared_mem"));
    // read and write predicates
    TORCH_INTERNAL_ASSERT(
        grop->predicate() != nullptr && grop->predicate()->hasValue());
    const auto read_pred = genInline(grop->predicate());
    func_args.arg(read_pred);
    if (grop->writePredicate() != nullptr) {
      TORCH_INTERNAL_ASSERT(grop->writePredicate()->hasValue());
      func_args.arg(genInline(grop->writePredicate()));
    } else {
      func_args.arg(read_pred);
    }
    // Init val
    func_args.arg(genCall(data_type, genInline(grop->init())));
    func_args.arg(genInline(grop->entrance_index()));
    func_args.arg(genInline(grop->entrances()));

    addProfileArguments(func_args, grop);

    indent() << "reduction::gridReduce<" << template_args << ">(\n";
    indent() << kTab << func_args << ");\n";
  }

  std::string genFusedReductionName(const TensorView* reduction_out) {
    return varName(reduction_out) + "_reduction";
  }

  void generateGridAllreduce(const kir::GridReduction* grop) {
    TORCH_INTERNAL_ASSERT(grop->isAllreduce());

    const auto out = grop->out()->as<kir::TensorIndex>();

    const auto data_type = grop->out()->dtype();
    const auto op_type = grop->getReductionOpType();

    const auto work_buffer =
        grop->reduction_buffer()->buffer()->as<TensorView>();
    const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>();

    const auto reduction_name = genFusedReductionName(out->view());

    // template <typename Func, typename... Types>
    // __device__ __inline__ void reduce(
    //     RefTuple<Types...> out,
    //     const LocalTuple<Types...>& inp,
    //     VolatilePtrTuple<Types...> global_work_buffer,
    //     int64_t* global_sync_buffer, // Allocated as product of all
    //     // non-participating Grid dimension
    //     PtrTuple<Types...> shared_buf,
    //     bool read_pred, // Prevent reading from out of bounds memory
    //     bool write_pred, // Prevent from writing out of bounds
    //     const LocalTuple<Types...>& init_val,
    //     Func reduction_op);

    indent() << reduction_name << ".reduce(\n";

    ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
    // out
    func_args.arg(genCall("RefTuple", data_type, gen(grop->out())));
    // inp
    func_args.arg(genCall("ConstRefTuple", data_type, gen(grop->in())));
    // global_work_buffer
    func_args.arg(genCall(
        "VolatilePtrTuple", data_type, "&" + varName(work_buffer) + "[0]"));
    // global_sync_buffer
    func_args.arg("&").append(varName(sync_buffer)).append("[0]");
    // shared_buf
    func_args.arg(genCall(
        "PtrTuple",
        data_type,
        genCall("static_cast", ptrType(data_type), "shared_mem")));
    // read and write predicates
    TORCH_INTERNAL_ASSERT(
        grop->predicate() != nullptr && grop->predicate()->hasValue());
    const auto read_pred = genInline(grop->predicate());
    auto write_pred = read_pred;
    if (grop->writePredicate() != nullptr) {
      TORCH_INTERNAL_ASSERT(grop->writePredicate()->hasValue());
      write_pred = genInline(grop->writePredicate());
    }
    func_args.arg(read_pred).arg(write_pred);
    // init_val
    func_args.arg(genCall("LocalTuple", data_type, genInline(grop->init())));
    // reduction_op
    func_args.arg(genReductionOp(op_type, out->dtype()));

    addProfileArguments(func_args, grop);

    indent() << kTab << func_args << ");\n";
  }

void handle(const kir::GroupedGridReduction* grouped_grop) final {
|
|
const auto out = ir_utils::getTvOutput(grouped_grop);
|
|
const auto domain = out->domain();
|
|
TORCH_INTERNAL_ASSERT(domain->hasGridReduction());
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
grouped_grop->sync_buffer()->buffer()->isA<TensorView>());
|
|
const auto sync_buffer =
|
|
grouped_grop->sync_buffer()->buffer()->as<TensorView>();
|
|
|
|
if (grouped_grop->isAllreduce()) {
|
|
generateGroupedGridAllreduce(grouped_grop);
|
|
return;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
grouped_grop->numExprs() == 2,
|
|
"Only grouping of 2 reductions is supported. ",
|
|
grouped_grop->toString());
|
|
|
|
const std::string flags_str = generateGridReduceTemplateFlags2(
|
|
grouped_grop, grouped_grop->threadPredicate());
|
|
|
|
const bool persistent_sync =
|
|
kernel_->summary().has_cooperative_grid_reduction;
|
|
|
|
// Since block-level reduction is already done, those dimensions
|
|
// with tidx/y/z being true do not participate in the grid
|
|
// reduction.
|
|
ArgumentBuilder template_args;
|
|
template_args.arg(flags_str).arg(persistent_sync);
|
|
|
|
ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
|
|
|
|
// Append arguments for each reduction
|
|
for (const auto i : c10::irange(grouped_grop->numExprs())) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
grouped_grop->reduction_buffers().at(i)->buffer()->isA<TensorView>());
|
|
const auto work_buffer =
|
|
grouped_grop->reduction_buffers().at(i)->buffer()->as<TensorView>();
|
|
|
|
func_args.arg(gen(grouped_grop->output(i)));
|
|
func_args.arg(gen(grouped_grop->input(i)));
|
|
func_args.arg(genCall(
|
|
grouped_grop->output(i)->dtype(),
|
|
genInline(grouped_grop->initVal(i))));
|
|
func_args.arg(genReductionOp(
|
|
grouped_grop->getReductionOpType(i),
|
|
grouped_grop->output(i)->dtype()));
|
|
func_args.arg("&").append(varName(work_buffer)).append("[0]");
|
|
}
|
|
|
|
// The rest of the arguments are common between the reductions
|
|
func_args.arg("&").append(varName(sync_buffer)).append("[0]");
|
|
func_args.arg("shared_mem");
|
|
// read and write predicates
|
|
TORCH_INTERNAL_ASSERT(
|
|
grouped_grop->predicate() != nullptr &&
|
|
grouped_grop->predicate()->hasValue());
|
|
const auto read_pred = genInline(grouped_grop->predicate());
|
|
func_args.arg(read_pred);
|
|
if (grouped_grop->writePredicate() != nullptr) {
|
|
TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
|
|
func_args.arg(genInline(grouped_grop->writePredicate()));
|
|
} else {
|
|
func_args.arg(read_pred);
|
|
}
|
|
|
|
func_args.arg(genInline(grouped_grop->entrance_index()));
|
|
func_args.arg(genInline(grouped_grop->entrances()));
|
|
|
|
addProfileArguments(func_args, grouped_grop);
|
|
|
|
indent() << "reduction::gridReduceGroup<" << template_args << ">(\n";
|
|
indent() << kTab << func_args << ");\n";
|
|
}
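// Illustration (tensor and buffer names are hypothetical): for two grouped
// float sum reductions, the code emitted above is roughly
//   reduction::gridReduceGroup<...parallel-type flags..., persistent>(
//       T2[...], T1[...], float(0), reduction_op_0, &work_buf_0[0],
//       T4[...], T3[...], float(0), reduction_op_1, &work_buf_1[0],
//       &sync_buf[0], shared_mem, read_pred, write_pred,
//       entrance_index, n_entrances);
// with one (output, input, init, op, work buffer) group per reduction,
// followed by the arguments shared by both reductions.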
|
|
|
|
void handle(const kir::GroupedGridWelford* grouped_gwop) final {
|
|
if (grouped_gwop->isAllreduce()) {
|
|
generateGroupedGridAllreduceWelford(grouped_gwop);
|
|
return;
|
|
} else {
|
|
TORCH_INTERNAL_ASSERT(
|
|
false, "Non-allreduce grouped grid welford is not yet supported");
|
|
}
|
|
}
|
|
|
|
// Enumerates all combinations of index values of grouped
|
|
// loops. Each combination is a vector of loop index values. The
|
|
// length of the vector is the number of grouped loops.
|
|
//
|
|
// Example 1: only one domain of extent 2 is grouped: {{0}, {1}}.
|
|
// Example 2: two domains of extents 2 and 3 are grouped: {{0, 0},
|
|
// {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}}
|
|
std::vector<std::vector<int64_t>> getGroupedLoopIndexConcreteIntSets() {
  std::vector<std::vector<int64_t>> index_combinations;

  // Initialize with an empty vector
  index_combinations.push_back(std::vector<int64_t>());

  // Incrementally build a combinatorial set
  for (const auto loop : grouped_loops_) {
    const auto iter_count = loop->stop()->evaluateInt();
    std::vector<std::vector<int64_t>> new_combinations;
    // Append integers from 0 to iter_count - 1 to all the vectors built
    // so far
    for (const auto& index_vec : index_combinations) {
      for (int64_t i = 0; i < iter_count; ++i) {
        auto index_vec_appended = index_vec;
        index_vec_appended.push_back(i);
        new_combinations.push_back(index_vec_appended);
      }
    }
    index_combinations = std::move(new_combinations);
  }

  return index_combinations;
}
|
|
|
|
//! Returns all combinations of maps from index Vals of grouped loops to
//! their concrete integers.
std::vector<std::unordered_map<const Int*, int64_t>>
getLoopIndexReplacementMaps() {
  std::vector<std::unordered_map<const Int*, int64_t>> maps;

  if (grouped_loops_.empty()) {
    std::unordered_map<const Int*, int64_t> empty_map;
    return {empty_map};
  }

  // Vector of indices of grouped loops
  std::vector<Int*> loop_indices;
  std::transform(
      grouped_loops_.begin(),
      grouped_loops_.end(),
      std::back_inserter(loop_indices),
      [](const kir::ForLoop* loop) { return loop->index()->as<Int>(); });

  // All combinations of loop index integer values
  const auto index_val_sets = getGroupedLoopIndexConcreteIntSets();

  // Create maps from loop index Vals to integers
  for (const auto& index_values : index_val_sets) {
    TORCH_INTERNAL_ASSERT(loop_indices.size() == index_values.size());
    std::unordered_map<const Int*, int64_t> index_val_map;
    for (const auto i : c10::irange(loop_indices.size())) {
      auto loop_index = loop_indices.at(i);
      auto index_val = index_values.at(i);
      index_val_map.emplace(loop_index, index_val);
    }
    maps.emplace_back(std::move(index_val_map));
  }

  return maps;
}
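// Illustration (hypothetical index names): with two grouped loops of extents
// 2 and 3 whose index Vals are i0 and i1, the call above yields six maps:
//   {i0->0, i1->0}, {i0->0, i1->1}, {i0->0, i1->2},
//   {i0->1, i1->0}, {i0->1, i1->1}, {i0->1, i1->2}
// Each map is later used to substitute the symbolic loop indices with the
// concrete values of one grouped iteration.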
|
|
|
|
void generateGroupedGridAllreduce(
|
|
const kir::GroupedGridReduction* grouped_grop) {
|
|
TORCH_INTERNAL_ASSERT(grouped_grop->isAllreduce());
|
|
|
|
// There are two dimensions of grouping: horizontal grouping and
|
|
// iteration grouping. The total number of individual reductions
|
|
// is the number of horizontal reductions * the extent of grouped
|
|
// iterations. All of them are packed into a single grid reduction
|
|
// call. The number of reductions is limited, and currently it is
|
|
// simply an error if exceeded. This could be avoided by
|
|
// decomposing grouped_grop into smaller groups within the
|
|
// limit. TODO: Support a larger number of reductions.
|
|
|
|
// First, enumerate all combinations of loop index values of the
// grouped IterDomains. If only a single domain is grouped, this
// is simply a 1D vector of integers from 0 to extent-1. If
// two domains are grouped, combinations of two integer vectors
// are returned. These loop index value vectors are returned as
// maps from loop index Vals to concrete int values.
|
|
const auto index_replacement_maps = getLoopIndexReplacementMaps();
|
|
const auto num_grouped_iterations = index_replacement_maps.size();
|
|
|
|
// This is also checked at lowering validation time, so it
// isn't strictly necessary here.
|
|
TORCH_INTERNAL_ASSERT(
|
|
num_grouped_iterations * grouped_grop->numExprs() <=
|
|
kMaxNumGroupedReductions,
|
|
"Too many grouped reductions: ",
|
|
grouped_grop->toString(),
|
|
". Up to ",
|
|
kMaxNumGroupedReductions,
|
|
" reductions are allowed.");
|
|
|
|
ArgumentBuilder types;
|
|
ArgumentBuilder outputs;
|
|
ArgumentBuilder inputs;
|
|
ArgumentBuilder work_bufs;
|
|
ArgumentBuilder init_vals;
|
|
ArgumentBuilder reduction_ops;
|
|
|
|
ArgumentBuilder bool_types;
|
|
ArgumentBuilder read_preds;
|
|
ArgumentBuilder write_preds;
|
|
|
|
for (const auto expr_index : c10::irange(grouped_grop->numExprs())) {
|
|
const auto data_type = grouped_grop->outputs().at(expr_index)->dtype();
|
|
TORCH_INTERNAL_ASSERT(grouped_grop->reduction_buffers()
|
|
.at(expr_index)
|
|
->buffer()
|
|
->isA<TensorView>());
|
|
|
|
for (const auto& group_index :
|
|
c10::irange(index_replacement_maps.size())) {
|
|
// Set the index replacement map with the concrete values of
|
|
// indices of grouped loops.
|
|
index_replacement_map_ = index_replacement_maps.at(group_index);
|
|
|
|
types.arg(data_type);
|
|
|
|
// out
|
|
outputs.arg(gen(grouped_grop->outputs().at(expr_index)));
|
|
|
|
// inp
|
|
inputs.arg(gen(grouped_grop->inputs().at(expr_index)));
|
|
|
|
// global_work_buffer
|
|
const auto work_buffer = grouped_grop->reduction_buffers()
|
|
.at(expr_index)
|
|
->buffer()
|
|
->as<TensorView>();
|
|
// A separate work buffer is used for each reduction.
|
|
auto work_buffer_offset = group_index == 0
|
|
? "0"
|
|
: (genInline(grouped_grop->buffer_stride()) + " * " +
|
|
std::to_string(group_index));
|
|
work_bufs.arg("&")
|
|
.append(varName(work_buffer))
|
|
.append("[")
|
|
.append(work_buffer_offset)
|
|
.append("]");
|
|
init_vals.arg(genInline(grouped_grop->initVal(expr_index)));
|
|
|
|
reduction_ops.arg(genReductionOp(
|
|
grouped_grop->getReductionOpType(expr_index),
|
|
grouped_grop->output(expr_index)->dtype()));
|
|
|
|
// read and write predicates
|
|
bool_types.arg("bool");
|
|
// Same argument for all inputs. Different predicates would be
|
|
// used when grouping is done across iterations
|
|
TORCH_INTERNAL_ASSERT(
|
|
grouped_grop->predicate() != nullptr &&
|
|
grouped_grop->predicate()->hasValue());
|
|
const auto read_pred = genInline(grouped_grop->predicate());
|
|
read_preds.arg(read_pred);
|
|
if (grouped_grop->writePredicate() != nullptr) {
|
|
TORCH_INTERNAL_ASSERT(grouped_grop->writePredicate()->hasValue());
|
|
write_preds.arg(genInline(grouped_grop->writePredicate()));
|
|
} else {
|
|
write_preds.arg(read_pred);
|
|
}
|
|
|
|
index_replacement_map_.clear();
|
|
}
|
|
}
|
|
|
|
ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
|
|
func_args.arg(genCall("RefTuple", types, outputs));
|
|
func_args.arg(genCall("ConstRefTuple", types, inputs));
|
|
func_args.arg(genCall("VolatilePtrTuple", types, work_bufs));
|
|
func_args.arg(genCall("LocalTuple", types, init_vals));
|
|
|
|
// global_sync_buffer
|
|
const auto sync_buffer =
|
|
grouped_grop->sync_buffer()->buffer()->as<TensorView>();
|
|
func_args.arg("&").append(varName(sync_buffer)).append("[0]");
|
|
|
|
// shared_buf
|
|
func_args.arg("shared_mem");
|
|
|
|
func_args.arg(genCall("LocalTuple", bool_types, read_preds));
|
|
func_args.arg(genCall("LocalTuple", bool_types, write_preds));
|
|
|
|
addProfileArguments(func_args, grouped_grop);
|
|
|
|
func_args.arg(reduction_ops);
|
|
|
|
indent() << genFusedReductionName(ir_utils::getTvOutput(grouped_grop))
|
|
<< ".reduceGroup(\n";
|
|
indent() << kTab << func_args << ");\n";
|
|
}
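// Illustration (hypothetical names): grouping two float sum reductions over
// a grouped iteration of extent 2 packs four individual reductions into one
// call, emitted roughly as
//   T2_reduction.reduceGroup(
//       RefTuple<float, float, float, float>(out0, out1, out2, out3),
//       ConstRefTuple<float, float, float, float>(in0, in1, in2, in3),
//       VolatilePtrTuple<float, float, float, float>(
//           &work0[0], &work0[stride], &work1[0], &work1[stride]),
//       LocalTuple<float, float, float, float>(0.f, 0.f, 0.f, 0.f),
//       &sync_buf[0], shared_mem,
//       LocalTuple<bool, bool, bool, bool>(read_preds...),
//       LocalTuple<bool, bool, bool, bool>(write_preds...),
//       op0, op0, op1, op1);
// where entries are ordered expr-major, i.e. all grouped iterations of the
// first reduction come before those of the second.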
|
|
|
|
// Mostly the same as the grouped grid reduction version
|
|
void generateGroupedGridAllreduceWelford(
|
|
const kir::GroupedGridWelford* grouped_gwop) {
|
|
TORCH_INTERNAL_ASSERT(grouped_gwop->isAllreduce());
|
|
|
|
const auto index_replacement_maps = getLoopIndexReplacementMaps();
|
|
const auto num_grouped_iterations = index_replacement_maps.size();
|
|
|
|
// This is also checked at lowering validation time, so it
// isn't strictly necessary here.
|
|
TORCH_INTERNAL_ASSERT(
|
|
num_grouped_iterations * grouped_gwop->numExprs() <=
|
|
kMaxNumGroupedReductions,
|
|
"Too many grouped reductions: ",
|
|
grouped_gwop->toString(),
|
|
". Up to ",
|
|
kMaxNumGroupedReductions,
|
|
" reductions are allowed.");
|
|
|
|
ArgumentBuilder data_types;
|
|
ArgumentBuilder index_types;
|
|
|
|
// Note that the data type of avg and var, and that of N, are each the
// same across all the welford ops since we only support grouping of
// iterations.
|
|
const auto data_type = grouped_gwop->outputVals().at(0).avg()->dtype();
|
|
const auto index_type = grouped_gwop->outputVals().at(0).N()->dtype();
|
|
|
|
std::array<ArgumentBuilder, 3> out_args;
|
|
std::array<ArgumentBuilder, 3> in_args;
|
|
std::array<ArgumentBuilder, 3> init_args;
|
|
std::array<ArgumentBuilder, 3> work_bufs;
|
|
|
|
ArgumentBuilder bool_types;
|
|
ArgumentBuilder read_preds;
|
|
ArgumentBuilder write_preds;
|
|
|
|
for (const auto expr_index : c10::irange(grouped_gwop->numExprs())) {
|
|
const auto& output = grouped_gwop->outputVals().at(expr_index);
|
|
const auto& input = grouped_gwop->inputVals().at(expr_index);
|
|
const auto& init = grouped_gwop->initVals().at(expr_index);
|
|
|
|
for (const auto& group_index :
|
|
c10::irange(index_replacement_maps.size())) {
|
|
// Set the index replacement map with the concrete values of
|
|
// indices of grouped loops.
|
|
index_replacement_map_ = index_replacement_maps.at(group_index);
|
|
|
|
data_types.arg(data_type);
|
|
index_types.arg(index_type);
|
|
|
|
auto work_buffer_offset = group_index == 0
|
|
? "0"
|
|
: (genInline(grouped_gwop->buffer_stride()) + " * " +
|
|
std::to_string(group_index));
|
|
|
|
// Setup arguments for avg, var, and N
|
|
for (const auto i : c10::irange(3)) {
|
|
out_args[i].arg(gen(output.get(i)));
|
|
in_args[i].arg(gen(input.get(i)));
|
|
init_args[i].arg(gen(init.get(i)));
|
|
const auto work_buffer = grouped_gwop->reduction_buffers()[i]
|
|
.at(expr_index)
|
|
->buffer()
|
|
->as<TensorView>();
|
|
work_bufs[i]
|
|
.arg("&")
|
|
.append(varName(work_buffer))
|
|
.append("[")
|
|
.append(work_buffer_offset)
|
|
.append("]");
|
|
}
|
|
|
|
// read and write predicates
|
|
bool_types.arg("bool");
|
|
// Same argument for all inputs. Different predicates would be
|
|
// used when grouping is done across iterations
|
|
TORCH_INTERNAL_ASSERT(
    grouped_gwop->predicate() != nullptr &&
    grouped_gwop->predicate()->hasValue());
|
|
const auto read_pred = genInline(grouped_gwop->predicate());
|
|
read_preds.arg(read_pred);
|
|
if (grouped_gwop->writePredicate() != nullptr) {
|
|
TORCH_INTERNAL_ASSERT(grouped_gwop->writePredicate()->hasValue());
|
|
write_preds.arg(genInline(grouped_gwop->writePredicate()));
|
|
} else {
|
|
write_preds.arg(read_pred);
|
|
}
|
|
|
|
index_replacement_map_.clear();
|
|
}
|
|
}
|
|
|
|
ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
|
|
// output
|
|
func_args.arg(genCall("RefTuple", data_types, out_args[0]));
|
|
func_args.arg(genCall("RefTuple", data_types, out_args[1]));
|
|
func_args.arg(genCall("RefTuple", index_types, out_args[2]));
|
|
// input
|
|
func_args.arg(genCall("ConstRefTuple", data_types, in_args[0]));
|
|
func_args.arg(genCall("ConstRefTuple", data_types, in_args[1]));
|
|
func_args.arg(genCall("ConstRefTuple", index_types, in_args[2]));
|
|
// init
|
|
func_args.arg(genCall("LocalTuple", data_types, init_args[0]));
|
|
func_args.arg(genCall("LocalTuple", data_types, init_args[1]));
|
|
func_args.arg(genCall("LocalTuple", index_types, init_args[2]));
|
|
// work buffer
|
|
func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[0]));
|
|
func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[1]));
|
|
func_args.arg(genCall("VolatilePtrTuple", index_types, work_bufs[2]));
|
|
// global_sync_buffer
|
|
const auto sync_buffer =
|
|
grouped_gwop->sync_buffer()->buffer()->as<TensorView>();
|
|
func_args.arg("&").append(varName(sync_buffer)).append("[0]");
|
|
|
|
// shared_buf
|
|
ArgumentBuilder smem_buffer_args;
|
|
smem_buffer_args.arg(
|
|
genCall("reinterpret_cast", ptrType(data_type), "shared_mem_avg"));
|
|
smem_buffer_args.arg(
|
|
genCall("reinterpret_cast", ptrType(data_type), "shared_mem_var"));
|
|
smem_buffer_args.arg(
|
|
genCall("reinterpret_cast", ptrType(index_type), "shared_mem_n"));
|
|
func_args.arg(genCall(
|
|
"PtrTuple",
|
|
ArgumentBuilder().arg(data_type).arg(data_type).arg(index_type),
|
|
smem_buffer_args));
|
|
|
|
func_args.arg(genCall("LocalTuple", bool_types, read_preds));
|
|
func_args.arg(genCall("LocalTuple", bool_types, write_preds));
|
|
|
|
addProfileArguments(func_args, grouped_gwop);
|
|
|
|
ArgumentBuilder func_template_args;
|
|
func_template_args.arg(
|
|
grouped_gwop->numExprs() * index_replacement_maps.size());
|
|
func_template_args.arg(data_type);
|
|
func_template_args.arg(index_type);
|
|
|
|
indent() << genCall(
|
|
genFusedReductionName(ir_utils::getTvOutput(grouped_gwop)) +
|
|
".welfordGroup",
|
|
func_template_args,
|
|
func_args)
|
|
<< ";\n";
|
|
}
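// Illustration (hypothetical names): a single grouped welford over a grouped
// iteration of extent 2, with float data and int64_t counts, is emitted
// with the shape
//   T2_reduction.welfordGroup<2, float, int64_t>(
//       /* avg/var/N output RefTuples, input ConstRefTuples,
//          init LocalTuples, work-buffer VolatilePtrTuples, */
//       &sync_buf[0],
//       PtrTuple<float, float, int64_t>(
//           reinterpret_cast<float*>(shared_mem_avg),
//           reinterpret_cast<float*>(shared_mem_var),
//           reinterpret_cast<int64_t*>(shared_mem_n)),
//       LocalTuple<bool, bool>(read_preds...),
//       LocalTuple<bool, bool>(write_preds...));
// where the first template argument is numExprs() times the number of
// grouped iterations.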
|
|
|
|
void handle(const kir::GridBroadcast* grop) final {
|
|
const auto bop = grop->broadcast_op();
|
|
TORCH_INTERNAL_ASSERT(bop->out()->isA<kir::TensorIndex>());
|
|
|
|
const ParallelTypeBitmap parallel_types =
|
|
kernel_->summary().broadcast_parallel_types.at(bop);
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
parallel_types.hasBID(),
|
|
"GridBroadcast needs to be used with a broadcast op that is parallelized with the BID parallel types");
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
grop->broadcast_buffer()->buffer()->isA<TensorView>());
|
|
TORCH_INTERNAL_ASSERT(grop->sync_buffer()->buffer()->isA<TensorView>());
|
|
const auto work_buffer =
|
|
grop->broadcast_buffer()->buffer()->as<TensorView>();
|
|
const auto sync_buffer = grop->sync_buffer()->buffer()->as<TensorView>();
|
|
|
|
std::stringstream flags_str;
|
|
for (const ParallelType pt : kParallelTypeThreads) {
|
|
const bool parallel_bcast = parallel_types.get(pt);
|
|
if (pt != kParallelTypeThreads[0]) {
|
|
flags_str << ", ";
|
|
}
|
|
flags_str << (parallel_bcast ? "true" : "false");
|
|
}
|
|
|
|
// Since a block-level broadcast has not necessarily been performed before
// this call, the grid broadcast may be broadcasting across both the grid
// and the block level.
|
|
indent() << "grid_broadcast::broadcast<" << flags_str.str() << ">(\n";
|
|
indent() << kTab << gen(bop->out()) << ",\n";
|
|
indent() << kTab << gen(bop->in()) << ",\n";
|
|
indent() << kTab << "&" << varName(work_buffer) << "[0],\n";
|
|
indent() << kTab << varName(sync_buffer) << ",\n";
|
|
TORCH_INTERNAL_ASSERT(
|
|
grop->predicate() != nullptr && grop->predicate()->hasValue());
|
|
indent() << kTab << genInline(grop->predicate()) << ");\n";
|
|
}
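// Illustration (hypothetical names): broadcasting T3 across one block
// dimension only is emitted roughly as
//   grid_broadcast::broadcast<true, false, false, false, false, false>(
//       T3[...],
//       T2[...],
//       &bcast_work_buf[0],
//       bcast_sync_buf,
//       pred);
// with one boolean template flag per entry of kParallelTypeThreads, set to
// true for the parallel types the broadcast covers.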
|
|
|
|
void handle(const kir::GridWelford* gwop) final {
|
|
const auto wop = gwop->welford_op();
|
|
TORCH_INTERNAL_ASSERT(wop->outAvg()->isA<kir::TensorIndex>());
|
|
|
|
const auto out = wop->out()->as<kir::TensorIndex>();
|
|
const auto domain = out->view()->domain();
|
|
TORCH_INTERNAL_ASSERT(domain->hasGridReduction());
|
|
|
|
const auto data_type = out->dtype();
|
|
|
|
TORCH_INTERNAL_ASSERT(gwop->var_buffer()->buffer()->isA<TensorView>());
|
|
TORCH_INTERNAL_ASSERT(gwop->sync_buffer()->buffer()->isA<TensorView>());
|
|
|
|
const auto avg_buffer = gwop->avg_buffer()->buffer()->as<TensorView>();
|
|
const auto var_buffer = gwop->var_buffer()->buffer()->as<TensorView>();
|
|
const auto n_buffer = gwop->N_buffer()->buffer()->as<TensorView>();
|
|
const auto sync_buffer = gwop->sync_buffer()->buffer()->as<TensorView>();
|
|
|
|
if (wop->isAllreduce()) {
|
|
generateGridAllreduce(gwop);
|
|
return;
|
|
}
|
|
|
|
const bool persistent_sync =
|
|
kernel_->summary().has_cooperative_grid_reduction;
|
|
|
|
const std::string flags_str =
|
|
generateGridReduceTemplateFlags(wop, gwop->threadPredicate());
|
|
|
|
// Since block-level reduction is already done, those dimensions
|
|
// with tidx/y/z being true do not participate in the grid reduction.
|
|
indent() << "welford::gridWelford<" << flags_str << ", "
|
|
<< (persistent_sync ? "true" : "false") << ">(\n";
|
|
indent() << kTab << gen(wop->outAvg()) << ",\n";
|
|
indent() << kTab << gen(wop->outVar()) << ",\n";
|
|
indent() << kTab << gen(wop->outN()) << ",\n";
|
|
if (domain->hasBlockReduction()) {
|
|
indent() << kTab << "block_result_avg_" << block_reduce_name_ << ",\n";
|
|
indent() << kTab << "block_result_var_" << block_reduce_name_ << ",\n";
|
|
indent() << kTab << "block_result_n_" << block_reduce_name_ << ",\n";
|
|
block_reduce_name_++;
|
|
} else {
|
|
indent() << kTab << gen(wop->inAvg()) << ",\n";
|
|
TORCH_INTERNAL_ASSERT(
|
|
wop->inVar() != nullptr, "Welford var input nullptr not allowed");
|
|
indent() << kTab << "(" << wop->outVar()->dtype() << ")"
|
|
<< gen(wop->inVar()) << ",\n";
|
|
indent() << kTab << "(" << wop->outN()->dtype() << ")" << gen(wop->inN())
|
|
<< ",\n";
|
|
}
|
|
indent() << kTab << "&" << varName(avg_buffer) << "[0],\n";
|
|
indent() << kTab << "&" << varName(var_buffer) << "[0],\n";
|
|
indent() << kTab << "&" << varName(n_buffer) << "[0],\n";
|
|
indent() << kTab << varName(sync_buffer) << ",\n";
|
|
indent() << kTab << "reinterpret_cast<" << data_type
|
|
<< "*>(shared_mem_avg),\n";
|
|
indent() << kTab << "reinterpret_cast<" << data_type
|
|
<< "*>(shared_mem_var),\n";
|
|
indent() << kTab << "reinterpret_cast<" << wop->outN()->dtype()
|
|
<< "*>(shared_mem_n),\n";
|
|
TORCH_INTERNAL_ASSERT(
|
|
gwop->predicate() != nullptr && gwop->predicate()->hasValue());
|
|
auto read_pred = genInline(gwop->predicate());
|
|
indent() << kTab << read_pred << ",\n";
|
|
if (gwop->writePredicate() != nullptr) {
|
|
TORCH_INTERNAL_ASSERT(gwop->writePredicate()->hasValue());
|
|
auto write_pred = genInline(gwop->writePredicate());
|
|
indent() << kTab << write_pred << ",\n";
|
|
} else {
|
|
indent() << kTab << read_pred << ",\n";
|
|
}
|
|
// TODO: init value support or remove.
|
|
indent() << kTab << data_type << "(0),\n";
|
|
indent() << kTab << genInline(gwop->entrance_index()) << ",\n";
|
|
indent() << kTab << genInline(gwop->entrances());
|
|
code_ << ");\n";
|
|
}
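// Illustration (hypothetical names): when a block-level welford result is
// already available, the call emitted above has the shape
//   welford::gridWelford<...flags..., true>(
//       T4_avg[...], T4_var[...], T4_N[...],
//       block_result_avg_0, block_result_var_0, block_result_n_0,
//       &avg_buf[0], &var_buf[0], &n_buf[0], sync_buf,
//       reinterpret_cast<float*>(shared_mem_avg),
//       reinterpret_cast<float*>(shared_mem_var),
//       reinterpret_cast<int64_t*>(shared_mem_n),
//       read_pred, write_pred, float(0), entrance_index, n_entrances);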
|
|
|
|
void generateGridAllreduce(const kir::GridWelford* gwop) {
|
|
const auto wop = gwop->welford_op();
|
|
TORCH_INTERNAL_ASSERT(wop->isAllreduce());
|
|
|
|
const auto out = wop->out()->as<kir::TensorIndex>();
|
|
|
|
const auto data_type = wop->outAvg()->dtype();
|
|
const auto index_type = wop->outN()->dtype();
|
|
TORCH_INTERNAL_ASSERT(wop->outAvg()->dtype() == wop->outVar()->dtype());
|
|
|
|
ArgumentBuilder data_type_args;
|
|
data_type_args.arg(data_type).arg(data_type).arg(index_type);
|
|
|
|
const auto sync_buffer = gwop->sync_buffer()->buffer()->as<TensorView>();
|
|
|
|
const auto reduction_name = genFusedReductionName(out->view());
|
|
|
|
// template <typename Func, typename... Types>
|
|
// __device__ __inline__ void reduce(
|
|
// RefTuple<Types...> out,
|
|
// const LocalTuple<Types...>& inp,
|
|
// VolatilePtrTuple<Types...> global_work_buffer,
|
|
// int64_t* global_sync_buffer, // Allocated as product of all
|
|
// // non-participating Grid dimension
|
|
// PtrTuple<Types...> shared_buf,
|
|
// bool read_pred, // Prevent reading from out of bounds memory
|
|
// bool write_pred, // Prevent from writing out of bounds
|
|
// const LocalTuple<Types...>& init_val,
|
|
// Func reduction_op);
|
|
|
|
ArgumentBuilder out_args;
|
|
out_args.arg(gen(wop->outAvg()));
|
|
out_args.arg(gen(wop->outVar()));
|
|
out_args.arg(gen(wop->outN()));
|
|
|
|
ArgumentBuilder in_args;
|
|
in_args.arg(gen(wop->inAvg()));
|
|
if (wop->inVar() != nullptr) {
|
|
in_args.arg(gen(wop->inVar()));
|
|
} else {
|
|
in_args.arg("(").append(data_type).append(")0");
|
|
}
|
|
in_args.arg(gen(wop->inN()));
|
|
|
|
ArgumentBuilder init_args;
|
|
init_args.arg(gen(wop->initAvg()));
|
|
init_args.arg(gen(wop->initVar()));
|
|
init_args.arg(gen(wop->initN()));
|
|
|
|
ArgumentBuilder work_buffer_args;
|
|
work_buffer_args.arg("&")
|
|
.append(varName(gwop->avg_buffer()->buffer()->as<TensorView>()))
|
|
.append("[0]");
|
|
work_buffer_args.arg("&")
|
|
.append(varName(gwop->var_buffer()->buffer()->as<TensorView>()))
|
|
.append("[0]");
|
|
work_buffer_args.arg("&")
|
|
.append(varName(gwop->N_buffer()->buffer()->as<TensorView>()))
|
|
.append("[0]");
|
|
|
|
ArgumentBuilder smem_buffer_args;
|
|
smem_buffer_args.arg(
|
|
genCall("reinterpret_cast", ptrType(data_type), "shared_mem_avg"));
|
|
smem_buffer_args.arg(
|
|
genCall("reinterpret_cast", ptrType(data_type), "shared_mem_var"));
|
|
smem_buffer_args.arg(
|
|
genCall("reinterpret_cast", ptrType(index_type), "shared_mem_n"));
|
|
|
|
ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
|
|
// out
|
|
func_args.arg(genCall("RefTuple", data_type_args, out_args));
|
|
// inp
|
|
func_args.arg(genCall("ConstRefTuple", data_type_args, in_args));
|
|
// global_work_buffer
|
|
func_args.arg(
|
|
genCall("VolatilePtrTuple", data_type_args, work_buffer_args));
|
|
// global_sync_buffer
|
|
func_args.arg("&").append(varName(sync_buffer)).append("[0]");
|
|
// shared_buf
|
|
func_args.arg(genCall("PtrTuple", data_type_args, smem_buffer_args));
|
|
// read and write predicates
|
|
TORCH_INTERNAL_ASSERT(
|
|
gwop->predicate() != nullptr && gwop->predicate()->hasValue());
|
|
const auto read_pred = genInline(gwop->predicate());
|
|
auto write_pred = read_pred;
|
|
if (gwop->writePredicate() != nullptr) {
|
|
TORCH_INTERNAL_ASSERT(gwop->writePredicate()->hasValue());
|
|
write_pred = genInline(gwop->writePredicate());
|
|
}
|
|
func_args.arg(read_pred).arg(write_pred);
|
|
// init_val
|
|
func_args.arg(genCall("LocalTuple", data_type_args, init_args));
|
|
// reduction_op
|
|
func_args.arg(genTemplate(
|
|
"welfordCombine", ArgumentBuilder().arg(data_type).arg(index_type)));
|
|
|
|
indent() << reduction_name << ".reduce(\n";
|
|
indent() << kTab << func_args << ");\n";
|
|
}
|
|
|
|
void handle(const kir::AllocateFusedReduction* alloc_fused_reduction) final {
|
|
// See the runtime file of the fused reduction
|
|
enum class ReductionParallelTypeState { Reduce, Iter, Pred, Inactive };
|
|
|
|
using ReductionParallelTypeStateArray =
|
|
ParallelTypeMap<ReductionParallelTypeState>;
|
|
|
|
ReductionParallelTypeStateArray states(
|
|
ReductionParallelTypeState::Inactive);
|
|
|
|
for (const ParallelType pt : kParallelTypeThreads) {
|
|
// It may be better to predicate grid reductions on dimensions they don't
// actively use. However, since that should generally be discouraged (they
// should be part of the iter portion of the operation, or they should be
// predicated out), we just assume they're part of the iter
// dimension. This may cause more communication than strictly necessary,
// but it should not be a common use case.
|
|
auto pt_dim = kernel_->summary().parallel_dimension_map_.get(pt);
|
|
if (pt_dim == nullptr || pt_dim->isOneInt()) {
|
|
continue;
|
|
}
|
|
// Initialize pt_dim if used to an iter dimension. It may change to a
|
|
// reduction or predicated dimension later.
|
|
states[pt] = ReductionParallelTypeState::Iter;
|
|
}
|
|
|
|
for (auto id : alloc_fused_reduction->out()->view()->domain()->domain()) {
|
|
auto pt = id->getParallelType();
|
|
if (isParallelTypeThread(pt)) {
|
|
auto state = id->isReduction() ? ReductionParallelTypeState::Reduce
|
|
: ReductionParallelTypeState::Iter;
|
|
states[pt] = state;
|
|
}
|
|
}
|
|
|
|
for (const auto predicated_pt : alloc_fused_reduction->threadPredicate()) {
|
|
auto& state = states[predicated_pt];
|
|
TORCH_INTERNAL_ASSERT(
|
|
state != ReductionParallelTypeState::Reduce,
|
|
"Invalid thread predication: ",
|
|
predicated_pt);
|
|
state = ReductionParallelTypeState::Pred;
|
|
}
|
|
|
|
ArgumentBuilder flags;
|
|
for (auto pt : kParallelTypeThreads) {
|
|
flags.arg(static_cast<int>(states[pt]));
|
|
}
|
|
|
|
// Persistent
|
|
flags.arg(true);
|
|
|
|
// Broadcast is fused
|
|
flags.arg(true);
|
|
|
|
const auto reduction_name =
|
|
genFusedReductionName(alloc_fused_reduction->out()->view());
|
|
|
|
indent() << genTemplate("fused_reduction::ParallelReduce", flags) << " "
|
|
<< reduction_name << ";\n";
|
|
}
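// Illustration (hypothetical name and flag values): for a grid reduction
// parallelized over one reduction dimension and one iter dimension, with the
// remaining parallel types unused, the declaration emitted above looks
// roughly like
//   fused_reduction::ParallelReduce<0, 1, 3, 3, 3, 3, true, true>
//       T2_reduction;
// with one integer per entry of kParallelTypeThreads encoding its
// ReductionParallelTypeState (Reduce = 0, Iter = 1, Pred = 2, Inactive = 3),
// followed by the persistent and fused-broadcast flags, both hard-coded to
// true here.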
|
|
|
|
void handleScope(const kir::Scope& scope) {
  for (auto expr : scope.exprs()) {
    OptOutConstDispatch::handle(expr);
  }
}

void handleTrivialLoop(const kir::ForLoop* loop) {
  if (loop->vectorize()) {
    vectorize_scope_ = true;
  }
  handleScope(loop->body());
  if (loop->vectorize()) {
    vectorize_scope_ = false;
  }
}
|
|
|
|
void handle(const GroupedReductionOp* grouped_rop) final {
|
|
for (const auto i : c10::irange(grouped_rop->numExprs())) {
|
|
TORCH_INTERNAL_ASSERT(grouped_rop->output(i)->isA<kir::TensorIndex>());
|
|
|
|
const auto output = grouped_rop->output(i)->as<kir::TensorIndex>();
|
|
const auto input = grouped_rop->input(i)->as<kir::TensorIndex>();
|
|
const auto domain = output->view()->domain();
|
|
const auto op_type = grouped_rop->getReductionOpType(i);
|
|
|
|
const bool has_block_reduce = domain->hasBlockReduction();
|
|
const bool has_grid_reduce = domain->hasGridReduction();
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
!has_grid_reduce,
|
|
"GroupedReductionOp does not support block parallelization. GroupedGridReduction must be used. ",
|
|
grouped_rop->toString());
|
|
|
|
if (!has_block_reduce) {
|
|
genSerialReduction(output, input, op_type);
|
|
} else if (
|
|
auto reduction_id =
|
|
ir_utils::getMaybeWarpReductionDim(output, input)) {
|
|
genWarpReduction(
|
|
output,
|
|
input,
|
|
grouped_rop->initVal(i),
|
|
op_type,
|
|
grouped_rop->predicate());
|
|
} else {
|
|
genBlockReduction(
|
|
output,
|
|
input,
|
|
grouped_rop->initVal(i),
|
|
op_type,
|
|
grouped_rop->predicate(),
|
|
grouped_rop->writePredicate());
|
|
}
|
|
}
|
|
}
|
|
|
|
void handle(const GroupedWelfordOp* grouped_wop) final {
|
|
TORCH_INTERNAL_ASSERT(
|
|
false,
|
|
"Should not reach here as grouped welford is only enabled for grid welford,",
|
|
" which is handled by its own handler");
|
|
}
|
|
|
|
//! True if the loop is grouped. The IterDomain of the loop must have
//! ParallelType::Group, but that alone isn't sufficient as the loop may be
//! for an initialization expression, for which the loop should not
//! be grouped. Make sure a GroupedGridReduction or GroupedGridWelford is
//! found.
bool isGroupedLoop(const kir::ForLoop* loop) {
  if (loop->iter_domain()->getParallelType() != ParallelType::Group) {
    return false;
  }
  return ExprFinder::exists(
      loop, {ExprType::GroupedGridReduction, ExprType::GroupedGridWelford});
}
|
|
|
|
void handle(const kir::ForLoop* loop) final {
|
|
if (loop->isTrivial()) {
|
|
handleTrivialLoop(loop);
|
|
return;
|
|
}
|
|
|
|
// If a loop is grouped, no loop is created, but it isn't
|
|
// considered trivial as the loop trip count is not one.
|
|
if (isGroupedLoop(loop)) {
|
|
grouped_loops_.push_back(loop);
|
|
handleScope(loop->body());
|
|
grouped_loops_.pop_back();
|
|
return;
|
|
}
|
|
|
|
const auto gen_index = gen(loop->index());
|
|
const auto gen_start = genInline(loop->start());
|
|
const auto gen_stop = genInline(loop->stop());
|
|
const auto gen_step = genInline(loop->step());
|
|
|
|
std::stringstream step_code;
|
|
if (loop->step()->isOneInt()) {
|
|
step_code << "++" << gen_index;
|
|
} else {
|
|
step_code << gen_index << " += " << gen_step;
|
|
}
|
|
if (loop->isUnrolled()) {
|
|
indent() << "#pragma unroll\n";
|
|
} else {
|
|
indent() << "#pragma unroll 1\n";
|
|
}
|
|
|
|
indent() << "for(nvfuser_index_t " << gen_index;
|
|
if (loop->iter_domain()->isParallelized()) {
|
|
code_ << " = " << gen_start << "; ";
|
|
} else {
|
|
// Do not start at the start of the ID when not parallelized. Instead,
// start at 0. Predicates will protect buffers between 0 and ID->start().
// However, if we started at ID->start and extent == ID->start, we could
// have a "degenerate" loop (a loop with no iterations). It may not be an
// issue to have a 0-sized loop, but all potential consequences haven't
// been covered. One example is WAR analysis, which could incorrectly think
// a barrier inside a 0-sized loop actually provides protection.
|
|
code_ << " = 0; ";
|
|
}
|
|
code_ << gen_index << " < " << gen_stop << "; " << step_code.str() << ") ";
|
|
startBlock(true);
|
|
handleScope(loop->body());
|
|
endBlock();
|
|
}
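// Illustration (hypothetical index and extent names): a serial, non-unrolled
// loop is emitted as
//   #pragma unroll 1
//   for(nvfuser_index_t i37 = 0; i37 < T0.size[0]; ++i37) {
//     ...
//   }
// while a parallelized loop starts at its actual start value and an unrolled
// loop gets a plain "#pragma unroll".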
|
|
|
|
void handle(const kir::IfThenElse* ite) final {
|
|
auto conditional = ite->predicate()->value();
|
|
if (conditional->isConst()) {
|
|
// If the conditional is a constant, then the IfThenElse is not required
|
|
if (conditional->value().value()) {
|
|
handleScope(ite->thenBody());
|
|
} else {
|
|
handleScope(ite->elseBody());
|
|
}
|
|
return;
|
|
}
|
|
|
|
indent() << "if (" << genInline(conditional) << ") ";
|
|
|
|
// "then" block
|
|
startBlock(true);
|
|
handleScope(ite->thenBody());
|
|
|
|
// "else" block (optional)
|
|
if (ite->hasElse()) {
|
|
endBlock(" else ");
|
|
startBlock(true);
|
|
handleScope(ite->elseBody());
|
|
}
|
|
|
|
endBlock();
|
|
}
|
|
|
|
void handle(const kir::Allocate* alloc) final {
|
|
const auto buffer_dtype = alloc->buffer()->dtype();
|
|
|
|
TORCH_INTERNAL_ASSERT(alloc->buffer() != nullptr);
|
|
alloc_map_.emplace(alloc->buffer(), alloc);
|
|
|
|
if (!alloc->buffer()->isA<TensorView>()) {
|
|
indent() << buffer_dtype << " " << gen(alloc->buffer()) << ";\n";
|
|
return;
|
|
}
|
|
|
|
const auto tv = alloc->buffer()->as<TensorView>();
|
|
|
|
const auto size = alloc->size();
|
|
TORCH_INTERNAL_ASSERT(size != nullptr);
|
|
|
|
if (alloc->alias() != nullptr) {
|
|
// This Allocate aliases another Allocate stmt
|
|
const auto alias_tv = alloc->alias()->buffer()->as<TensorView>();
|
|
indent() << "// Alias Allocation - " << alloc->memoryType() << "\n";
|
|
indent() << "auto& " << varName(tv) << " = " << varName(alias_tv)
|
|
<< ";\n";
|
|
|
|
} else {
|
|
// Standard Memory Allocation
|
|
switch (tv->getMemoryType()) {
|
|
case MemoryType::Global:
|
|
indent() << "// Allocate global tensor " << varName(tv) << "\n";
|
|
break;
|
|
case MemoryType::Shared:
|
|
// Align Offset Position
|
|
indent() << "smem_offset = alignBufferSize(smem_offset, "
|
|
// Always align to 128b / 16B
|
|
<< 16 << ");\n";
|
|
// Shared Memory Pointer
|
|
indent() << buffer_dtype << "* " << varName(tv)
|
|
<< " = reinterpret_cast<" << buffer_dtype << "*>"
|
|
<< "(array + smem_offset);\n";
|
|
// Increment Offset Position
|
|
indent() << "smem_offset += (" << genInline(size) << " * sizeof("
|
|
<< buffer_dtype << "));\n";
|
|
break;
|
|
case MemoryType::Local: {
|
|
auto va = kernel_->summary().vectorized_accesses;
|
|
if (va.find(tv) != va.end()) {
|
|
indent() << "Array<" << buffer_dtype << ", " << genInline(size)
|
|
<< ", " << va.at(tv) << "> " << varName(tv) << ";\n";
|
|
} else {
|
|
indent() << buffer_dtype << " " << varName(tv) << "["
|
|
<< genInline(size) << "];\n";
|
|
}
|
|
} break;
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "Unexpected memory type");
|
|
}
|
|
}
|
|
}
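// Illustration (hypothetical names and sizes): a shared-memory allocation of
// 128 floats for T5 is emitted as
//   smem_offset = alignBufferSize(smem_offset, 16);
//   float* T5 = reinterpret_cast<float*>(array + smem_offset);
//   smem_offset += (128 * sizeof(float));
// whereas a vectorized local buffer becomes an Array<float, 128, 4> and a
// plain local buffer a fixed-size C array.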
|
|
|
|
void handle(const kir::BlockSync* sync) final {
|
|
// Use a custom synchronization method if enabled
|
|
if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
|
|
indent() << "block_sync::sync();\n";
|
|
} else {
|
|
indent() << "__barrier_sync(0);\n";
|
|
}
|
|
}
|
|
|
|
void handle(const kir::CpAsyncWait* cpasync_wait) final {
|
|
if (cpasync_wait->keepStages() > 0) {
|
|
// Perform partial sync, see comment on kir::CpAsyncWait.
|
|
indent() << "Ampere::cpAsyncPartialBarrier<" << cpasync_wait->keepStages()
|
|
<< ">();\n";
|
|
} else {
|
|
// Perform sync all, see comment on kir::CpAsyncWait.
|
|
indent() << "Ampere::cpAsyncBarrier();\n";
|
|
}
|
|
}
|
|
|
|
void handle(const kir::CpAsyncCommit* cpasync_commit) final {
|
|
// Commit inflight cp.async transfers. See comment on kir::CpAsyncCommit.
|
|
indent() << "Ampere::cpAsyncCommit();\n";
|
|
}
|
|
|
|
void handle(const kir::GridSync* sync) final {
|
|
// Emit a grid-wide synchronization over the parallel dimensions in
// syncDims()
|
|
bool bidx = sync->syncDims().get(ParallelType::BIDx);
|
|
bool bidy = sync->syncDims().get(ParallelType::BIDy);
|
|
bool bidz = sync->syncDims().get(ParallelType::BIDz);
|
|
|
|
ArgumentBuilder sync_call_template_parms;
|
|
sync_call_template_parms.arg(bidx).arg(bidy).arg(bidz).arg(true);
|
|
|
|
auto sync_idx = genCall(
|
|
"index_utils::maskedOffset",
|
|
ArgumentBuilder().arg(!bidx).arg(!bidy).arg(!bidz),
|
|
ArgumentBuilder().arg("blockIdx").arg("gridDim"));
|
|
|
|
auto sync_segment_size = genCall(
|
|
"index_utils::maskedSize",
|
|
ArgumentBuilder().arg(bidx).arg(bidy).arg(bidz),
|
|
ArgumentBuilder().arg("gridDim"));
|
|
|
|
ArgumentBuilder sync_call_args;
|
|
sync_call_args.arg(varName(sync->syncBuffer()))
|
|
.append("[")
|
|
.append(sync_idx)
|
|
.append("]");
|
|
sync_call_args.arg(sync_segment_size);
|
|
|
|
auto sync_call =
|
|
genCall("grid_sync::sync", sync_call_template_parms, sync_call_args);
|
|
|
|
indent() << sync_call << ";\n";
|
|
}
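// Illustration (hypothetical buffer name): a grid sync over BIDx only is
// emitted roughly as
//   grid_sync::sync<true, false, false, true>(
//       T_sync[index_utils::maskedOffset<false, true, true>(blockIdx, gridDim)],
//       index_utils::maskedSize<true, false, false>(gridDim));
// i.e. the buffer index is computed from the non-synchronizing block
// dimensions, and the segment size is the number of blocks participating in
// the sync.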
|
|
|
|
void handle(const kir::InitMagicZero*) final {
  indent() << "NVFUSER_DEFINE_MAGIC_ZERO\n";
}

void handle(const kir::UpdateMagicZero*) final {
  indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
}
|
|
|
|
void handle(const kir::Swizzle2DInt* swizzle_2d) final {
|
|
TORCH_INTERNAL_ASSERT(print_inline_);
|
|
TORCH_INTERNAL_ASSERT(
|
|
swizzle_2d->swizzleType() != Swizzle2DType::NoSwizzle,
|
|
"Swizzle type undefined.");
|
|
if (print_inline_) {
|
|
code_ << swizzle_2d->swizzleType() << "({" << gen(swizzle_2d->inX())
|
|
<< "," << gen(swizzle_2d->inY()) << "} , "
|
|
<< "{" << gen(swizzle_2d->extentX()) << ","
|
|
<< gen(swizzle_2d->extentY()) << "})";
|
|
}
|
|
}
|
|
|
|
void handle(const kir::IntPair* int_pair) final {
|
|
const auto def = int_pair->definition();
|
|
TORCH_INTERNAL_ASSERT(
|
|
def != nullptr, "no support for un-inlined int pair yet.");
|
|
code_ << gen(def);
|
|
}
|
|
|
|
void handle(const kir::PairSelect* pair_select) final {
|
|
if (print_inline_) {
|
|
code_ << gen(pair_select->in());
|
|
} else {
|
|
indent() << gen(pair_select->out()) << " = " << gen(pair_select->in());
|
|
}
|
|
|
|
switch (pair_select->selection()) {
|
|
case kir::PairSelect::Selection::X:
|
|
code_ << ".x";
|
|
break;
|
|
case kir::PairSelect::Selection::Y:
|
|
code_ << ".y";
|
|
break;
|
|
default:
|
|
TORCH_INTERNAL_ASSERT(false, "unknown select")
|
|
break;
|
|
}
|
|
|
|
if (!print_inline_) {
|
|
code_ << ";\n";
|
|
}
|
|
}
|
|
|
|
private:
  std::stringstream code_;
  const kir::Kernel* kernel_;
  int block_nest_level_ = 0;
  int block_reduce_name_ = 0;
  bool print_inline_ = false;

  // Mark when we are inside of a vectorized for-loop
  bool vectorize_scope_ = false;
  //! Keep track of Allocate node for Val. Used to determine if Val
  //! should be inlined.
  std::unordered_map<const Val*, const kir::Allocate*> alloc_map_;
  //! Keep track of grouped loops
  std::deque<const kir::ForLoop*> grouped_loops_;
  //! Used to replace symbolic indices with concrete values
  std::unordered_map<const Int*, int64_t> index_replacement_map_;
};

} // namespace

std::string generateCudaKernel(
    const kir::Kernel* kernel,
    const std::string& kernel_name) {
  FUSER_PERF_SCOPE("generateCudaKernel");
  return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
}

} // namespace codegen
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch