Files
pytorch/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
Yuanyuan Chen 9fff8155c3 [2/N] Fix clang-tidy readability checks (#164652)
This PR applies clang-tidy readability checks to jit sources and all headers in the code base.
`readability-redundant-inline-specifier` is suppressed because enabling it would cause too many changes. That check flags redundant `inline` specifiers on function and variable declarations, and many in-class method definitions in the code base are explicitly marked `inline`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164652
Approved by: https://github.com/Skylion007
2025-10-06 01:06:01 +00:00

1354 lines
42 KiB
C++

#include <torch/csrc/jit/tensorexpr/cuda_codegen.h>
#include <torch/csrc/jit/tensorexpr/half_support.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/native/cuda/jit_utils.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
#include <torch/csrc/jit/codegen/fuser/cuda/resource_strings.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/cuda_random.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/registerizer.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_strided_native.h>
#endif
#include <unordered_map>
#include <utility>
namespace torch::jit::tensorexpr {
// RAII guard that binds a variable to a printed name in a VarNameMap for the
// guard's lifetime and removes that binding again when the guard is destroyed.
// TODO: move this to a more shared place.
class ScopedVarName {
 public:
  // Adds (var -> name) to `mapping`. Throws std::runtime_error when `var` is
  // already bound; in that case nothing is inserted.
  ScopedVarName(VarNameMap* mapping, const VarPtr& var, const std::string& name)
      : mapping_(mapping), var_(var) {
    const bool inserted = mapping->emplace(var, name).second;
    if (!inserted) {
      throw std::runtime_error("Duplicate var entry: " + var->name_hint());
    }
  }

  // Same as above, binding into the manager's unique-name table.
  ScopedVarName(
      UniqueNameManager* manager,
      const VarPtr& var,
      const std::string& name)
      : ScopedVarName(&manager->unique_name_mapping_, var, name) {}

  ScopedVarName(const ScopedVarName&) = delete;
  ScopedVarName& operator=(const ScopedVarName&) = delete;

  // Drops the binding added by the constructor.
  ~ScopedVarName() noexcept(false) {
    mapping_->erase(var_);
  }

 private:
  VarNameMap* mapping_ = nullptr;
  VarPtr var_ = nullptr;
};
// Returns true only when intValue() yields a value for `expr` and that value
// is 0; if intValue() yields nothing the result is false.
static bool is_zero(const ExprPtr& expr) {
  if (auto folded = intValue(expr)) {
    return *folded == 0;
  }
  return false;
}
// Shorthand for the NVRTC / CUDA driver function table held by the ATen
// global context (used below for calls such as cuLaunchKernel).
static const at::cuda::NVRTC& nvrtc() {
  auto& context = at::globalContext();
  return context.getNVRTC();
}
std::string CudaPrinter::dtypeToCppString(const Dtype& dtype) {
switch (dtype.scalar_type()) {
case ScalarType::Bool:
return "bool";
case ScalarType::Half:
return "half";
case ScalarType::BFloat16:
return fuser::cuda::bfloat16_type_string;
case ScalarType::Char:
return "char";
case ScalarType::Byte:
return "unsigned char";
case ScalarType::Short:
return "short";
case ScalarType::Long:
return "long long";
default:
return dtype.ToCppString();
}
}
void CudaAnalysis::visit(const FreePtr& v) {
if (thread_local_bufs_.count(v->buffer_var()) == 0 &&
cross_block_bufs_.count(v->buffer_var()) == 0) {
throw std::runtime_error("Global free not supported yet");
}
}
void CudaAnalysis::visit(const AllocatePtr& v) {
StmtPtr p = v->get_parent();
while (p) {
ForPtr for_v = to<For>(p);
if (for_v) {
if (for_v->loop_options().is_gpu_block_index()) {
// TODO: This isn't right if there's a thread index at a higher level
// than this.
cross_block_bufs_.insert(v->buffer_var());
return;
} else if (for_v->loop_options().is_gpu_thread_index()) {
thread_local_bufs_.insert(v->buffer_var());
return;
}
}
p = p->get_parent();
}
throw std::runtime_error("Global alloc not supported yet");
}
// Placement allocations (memory reuse) cannot be lowered by this backend yet,
// so every one of them is rejected.
void CudaAnalysis::visit(const PlacementAllocatePtr& /*v*/) {
  throw std::runtime_error("Memory reuse not supported yet");
}
void CudaAnalysis::visit(const ForPtr& v) {
// Recurse first.
v->body()->accept(this);
const LoopOptions& loop_options = v->loop_options();
if (loop_options.is_gpu_block_index()) {
int gpu_block_index = loop_options.gpu_block_index();
if (gpu_block_index >= 3) {
throw std::runtime_error("support only 3D gpu_block_index");
}
ExprPtr prev = nullptr;
if (gpu_block_extents_.size() <= static_cast<size_t>(gpu_block_index)) {
gpu_block_extents_.resize(gpu_block_index + 1);
} else {
prev = gpu_block_extents_[gpu_block_index];
}
if (!is_zero(v->start())) {
throw std::runtime_error(
"start must be zero for gpu_block_index: " +
std::to_string(v->start()));
}
// NOLINTNEXTLINE(bugprone-branch-clone)
if (prev == nullptr) {
gpu_block_extents_[gpu_block_index] = v->stop();
} else if (prev->isConstant() && immediateEquals(prev, 1)) {
// extents must be positive so if the current extent is 1 then even if the
// stop is symbolic it's the max.
gpu_block_extents_[gpu_block_index] = v->stop();
} else {
gpu_block_extents_[gpu_block_index] =
IRSimplifier::simplify(alloc<Max>(prev, v->stop(), true));
}
} else if (loop_options.is_gpu_thread_index()) {
int gpu_thread_index = loop_options.gpu_thread_index();
if (gpu_thread_index >= 3) {
throw std::runtime_error("support only 3D gpu_thread_index");
}
ExprPtr prev = nullptr;
if (gpu_thread_extents_.size() <= static_cast<size_t>(gpu_thread_index)) {
gpu_thread_extents_.resize(gpu_thread_index + 1);
} else {
prev = gpu_thread_extents_[gpu_thread_index];
}
if (!is_zero(v->start())) {
throw std::runtime_error(
"start must be zero for gpu_thread_index: " +
std::to_string(v->start()));
}
// NOLINTNEXTLINE(bugprone-branch-clone)
if (prev == nullptr) {
gpu_thread_extents_[gpu_thread_index] = v->stop();
} else if (prev->isConstant() && immediateEquals(prev, 1)) {
// extents must be positive so if the current extent is 1 then even if the
// stop is symbolic it's the max.
gpu_thread_extents_[gpu_thread_index] = v->stop();
} else {
gpu_thread_extents_[gpu_thread_index] =
IRSimplifier::simplify(alloc<Max>(prev, v->stop(), true));
}
}
}
void CudaPrinter::print_flat_alloc(const AllocatePtr& alloc) {
std::vector<ExprPtr> dims = alloc->dims();
// TODO: this should be merged with the storage flattener.
int64_t flat_size = 1;
for (const auto& dim : dims) {
auto dim_i = intValue(dim);
if (dim_i) {
flat_size *= *dim_i;
} else {
throw std::runtime_error("Only integer dimensions are supported for now");
}
}
os() << dtypeToCppString(alloc->dtype()) << " " << (*alloc->buffer_var())
<< "[" << flat_size << "];" << '\n';
}
void CudaPrinter::visit(const AllocatePtr& v) {
// TODO: handle dynamic shapes here.
if (cuda_analysis_->cross_block_bufs().count(v->buffer_var()) != 0) {
emitIndent();
os() << "__shared__ ";
print_flat_alloc(v);
return;
}
if (cuda_analysis_->thread_local_bufs().count(v->buffer_var()) != 0) {
emitIndent();
print_flat_alloc(v);
return;
}
throw std::runtime_error("Encountered Alloc not local to block or thread");
}
void CudaPrinter::visit(const FreePtr& v) {
// do nothing
}
void CudaPrinter::visit(const ForPtr& v) {
IRPrinter::visit(v);
}
void CudaPrinter::visit(const CastPtr& v) {
std::string castFn = v->dtype().scalar_type() == ScalarType::Half
? "__float2half"
: v->dtype().scalar_type() == ScalarType::BFloat16 ? "__float2bfloat16"
: v->src_value()->dtype().scalar_type() == ScalarType::Half
? "__half2float"
: v->src_value()->dtype().scalar_type() == ScalarType::BFloat16
? "__bfloat162float"
: ("(" + dtypeToCppString(v->dtype()) + ")");
os() << castFn << "(";
v->src_value()->accept(this);
os() << ")";
}
void CudaPrinter::visit(const IntrinsicsPtr& v) {
if (v->op_type() == IntrinsicsOp::kRand) {
os() << "Uint32ToFloat(" << *rand_func_ << "())";
return;
}
std::string func_name = v->func_name();
// get type of resulting expression.
ScalarType returnType = v->param(0)->dtype().scalar_type();
for (size_t i = 1; i < v->nparams(); ++i) {
returnType = promoteTypes(returnType, v->param(i)->dtype().scalar_type());
}
if (returnType == ScalarType::Half || returnType == ScalarType::Float) {
func_name = func_name + "f";
}
if (v->op_type() == IntrinsicsOp::kAbs &&
!c10::isIntegralType(returnType, true)) {
// since kAbs's func_name is `abs`, prefix `f` for floating point
func_name = "f" + func_name;
}
if (v->op_type() == IntrinsicsOp::kIsNan) {
func_name = "isnan";
}
os() << func_name << "(";
for (const auto i : c10::irange(v->nparams())) {
if (i > 0) {
os() << ", ";
}
os() << *v->param(i);
}
os() << ")";
}
// ExternalCall nodes have no CUDA lowering; they are rejected with
// unimplemented_lowering.
void CudaPrinter::visit(const ExternalCallPtr& v) {
  throw unimplemented_lowering(v);
}
void CudaPrinter::visit(const LoadPtr& v) {
// TODO: find a better metric in using ldg or not. Support different dtypes.
// Detects whether the load target is also a store target.
// TODO: this is currently too wide. It detects whether a store-target
// exists within the program. In fact, this check is only necessary within a
// kernel.
if (v->indices().empty()) {
os() << *v->base_handle();
return;
}
if (v->dtype().scalar_type() == ScalarType::Bool ||
v->dtype().scalar_type() == ScalarType::Half ||
v->dtype().scalar_type() == ScalarType::BFloat16) {
// There's no __ldg overload for bool or half.
os() << *v->base_handle() << "[" << *v->flat_index() << "]";
return;
}
if (cuda_analysis_->is_buf_store_target(v->buf())) {
// Cuda __ldg can only be applied on read-only buffers.
os() << *v->base_handle() << "[" << *v->flat_index() << "]";
return;
}
os() << "__ldg(" << *v->base_handle() << " + " << *v->flat_index() << ")";
}
// TODO: maybe this should be a more shared location?
// TODO: investigate how "ExprPtr" can be implicitly converted to "ExprHandle"
// as a bool.
static bool CheckEqual(const ExprPtr& lhs, const ExprPtr& rhs) {
// The fast path. Checks if the pointers are the same.
if (lhs == rhs) {
return true;
}
ExprHandle diff = Sub::make(ExprHandle(lhs), ExprHandle(rhs));
ExprHandle diff_s = IRSimplifier::simplify(diff);
return immediateEquals(diff_s.node(), 0);
}
class AtomicAddFuser : public IRMutator {
public:
AtomicAddFuser(
const std::unordered_set<VarPtr>& thread_local_bufs,
const GPUMetaVarRewriter& metavars)
: thread_local_bufs_(thread_local_bufs) {
const std::vector<ExprPtr>& block_extents = metavars.gpu_block_extents();
const std::vector<VarPtr>& block_vars = metavars.gpu_block_vars();
for (size_t i = 0; i < block_extents.size(); ++i) {
MetaVarExtent extent{block_extents[i], false};
if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
extent.trivial = true;
} else {
nontrivial_metavars_.insert(block_vars[i]);
}
metavars_[block_vars[i]] = extent;
}
const std::vector<ExprPtr>& thread_extents = metavars.gpu_thread_extents();
const std::vector<VarPtr>& thread_vars = metavars.gpu_thread_vars();
for (size_t i = 0; i < thread_extents.size(); ++i) {
MetaVarExtent extent{thread_extents[i], false};
if (extent.expr->isConstant() && immediateEquals(extent.expr, 1)) {
extent.trivial = true;
} else {
nontrivial_metavars_.insert(thread_vars[i]);
}
metavars_[thread_vars[i]] = extent;
}
}
StmtPtr mutate(const StorePtr& v) override {
BufPtr buf = v->buf();
// Thread locals never need to be atomic.
if (thread_local_bufs_.count(buf->base_handle()) != 0) {
return v;
}
ScalarType dtype = v->value()->dtype().scalar_type();
if (dtype != ScalarType::Float && dtype != ScalarType::Double) {
return v;
}
AddPtr add_v = to<Add>(v->value());
if (!add_v) {
return v;
}
LoadPtr load_v = to<Load>(add_v->lhs());
if (!load_v) {
return v;
}
if (v->base_handle() != load_v->base_handle()) {
return v;
}
if (v->indices().empty() && load_v->indices().empty()) {
return v;
}
bool index_equal = CheckEqual(v->flat_index(), load_v->flat_index());
if (!index_equal) {
return v;
}
// TODO: this checks that the metavars occur directly as an index, but this
// is pessimistic, blockIdx.x + 1 is fine too if there is no overlapping.
std::unordered_set<VarPtr> vars_to_find = nontrivial_metavars_;
for (const ExprPtr& e : v->indices()) {
if (VarPtr v = to<Var>(e)) {
vars_to_find.erase(v);
}
}
if (vars_to_find.empty()) {
// All metavars accounted for.
return v;
}
return alloc<AtomicAdd>(buf, v->indices(), add_v->rhs());
}
private:
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::unordered_set<VarPtr>& thread_local_bufs_;
struct MetaVarExtent {
ExprPtr expr{nullptr};
bool trivial{false};
};
std::unordered_map<VarPtr, MetaVarExtent> metavars_;
std::unordered_set<VarPtr> nontrivial_metavars_;
};
// Prints a store as `name = value;` for scalars or `name[flat_idx] = value;`
// for indexed buffers, followed by a newline.
void CudaPrinter::visit(const StorePtr& v) {
  emitIndent();
  os() << *v->base_handle();
  if (!v->indices().empty()) {
    os() << "[" << *v->flat_index() << "]";
  }
  os() << " = " << *v->value() << ";" << '\n';
}
// Prints an atomic accumulation. Thread-local buffers get a plain `+=`
// because atomicAdd only works on global and shared memory. All other buffers
// use CUDA's atomicAdd on the element's address.
void CudaPrinter::visit(const AtomicAddPtr& v) {
  emitIndent();
  const bool is_thread_local_buf =
      cuda_analysis_->thread_local_bufs().count(v->base_handle()) > 0;
  if (!is_thread_local_buf) {
    os() << "atomicAdd(&" << *v->base_handle() << "[" << *v->flat_index() << "]"
         << ", " << *v->value() << ");";
  } else {
    // atomicAdd only works on global and shared memory
    os() << *v->base_handle() << "[" << *v->flat_index()
         << "] += " << *v->value() << ";";
  }
  os() << '\n';
}
// Prints Max as the integer `max(a,b)` for integral dtypes, and as the
// NaN-propagating `maximum(a,b)` helper for everything else.
void CudaPrinter::visit(const MaxPtr& v) {
  os() << (v->dtype().is_integral() ? "max(" : "maximum(");
  v->lhs()->accept(this);
  os() << ",";
  v->rhs()->accept(this);
  os() << ")";
}
// Prints Min as the integer `min(a,b)` for integral dtypes, and as the
// NaN-propagating `minimum(a,b)` helper for everything else.
void CudaPrinter::visit(const MinPtr& v) {
  os() << (v->dtype().is_integral() ? "min(" : "minimum(");
  v->lhs()->accept(this);
  os() << ",";
  v->rhs()->accept(this);
  os() << ")";
}
void CudaPrinter::visit(const IfThenElsePtr& v) {
os() << "((";
v->condition()->accept(this);
os() << ") ? ";
v->true_value()->accept(this);
os() << " : ";
v->false_value()->accept(this);
os() << ")";
}
void CudaPrinter::visit(const BlockPtr& v) {
os() << "{" << '\n';
indent_++;
for (const StmtPtr& s : v->stmts()) {
s->accept(this);
}
indent_--;
emitIndent();
os() << "}";
}
void CudaPrinter::visit(const LetPtr& v) {
emitIndent();
os() << dtypeToCppString(v->var()->dtype());
os() << " " << *v->var() << " = ";
v->value()->accept(this);
os() << ";" << '\n';
}
class PrioritizeLoad : public IRMutator {
public:
ExprPtr mutate(const LoadPtr& v) override {
// Look at the declaration of this variable for more details.
if (nested_if_then_else_ > 0) {
return IRMutator::mutate(v);
}
if (nested_let_) {
return IRMutator::mutate(v);
}
if (thread_local_bufs_.count(v->base_handle()) > 0) {
return IRMutator::mutate(v);
}
if (v->indices().empty()) {
return IRMutator::mutate(v);
}
if (nested_store_) {
if (v->base_handle() == nested_store_->buf()->base_handle() &&
v->indices().size() == nested_store_->indices().size()) {
// also check indices
bool same = true;
for (const auto i : c10::irange(v->indices().size())) {
if (!exprEquals(v->indices()[i], nested_store_->indices()[i])) {
same = false;
break;
}
}
if (same) {
return IRMutator::mutate(v);
}
} else if (nested_store_->indices().empty()) {
return IRMutator::mutate(v);
}
}
MemLoadList& load_list = load_stack_.back();
VarPtr load_new_var = alloc<Var>("v", v->dtype());
ExprPtr new_value = IRMutator::mutate(v);
load_list.emplace_back(load_new_var, new_value);
return load_new_var;
}
ExprPtr mutate(const CastPtr& v) override {
LoadPtr src_load = to<Load>(v->src_value());
ExprPtr new_src = v->src_value()->accept_mutator(this);
VarPtr new_var = to<Var>(new_src);
if (!src_load || !new_var) {
return alloc<Cast>(v->dtype(), new_src);
}
// We just did the prioritize load, let's fold in the Cast.
MemLoadList& load_list = load_stack_.back();
assert(!load_list.empty());
auto pair = load_list.back();
assert(pair.first == new_var);
load_list.pop_back();
new_var = alloc<Var>("v", v->dtype());
ExprPtr new_value = alloc<Cast>(v->dtype(), pair.second);
load_list.emplace_back(new_var, new_value);
return new_var;
}
StmtPtr mutate(const StorePtr& v) override {
StorePtr last = nested_store_;
nested_store_ = v;
StmtPtr s = IRMutator::mutate(v);
nested_store_ = last;
return s;
}
StmtPtr mutate(const LetPtr& v) override {
nested_let_ = true;
StmtPtr s = IRMutator::mutate(v);
nested_let_ = false;
return s;
}
StmtPtr mutate(const BlockPtr& v) override {
std::list<StmtPtr> stmts = v->stmts();
for (const StmtPtr& stmt : stmts) {
PushList();
StmtPtr stmt_new = stmt->accept_mutator(this);
AddMemLoadsFromList(v, stmt);
PopList();
if (stmt_new == stmt) {
continue;
}
v->replace_stmt(stmt, stmt_new);
}
return v;
}
ExprPtr mutate(const IfThenElsePtr& v) override {
nested_if_then_else_++;
ExprPtr new_v = IRMutator::mutate(v);
nested_if_then_else_--;
return new_v;
}
private:
using MemLoadEntry = std::pair<VarPtr, ExprPtr>;
using MemLoadList = std::vector<MemLoadEntry>;
using MemoryLoadStack = std::vector<MemLoadList>;
void PushList() {
load_stack_.emplace_back();
}
void PopList() {
load_stack_.pop_back();
}
void AddMemLoadsFromList(const BlockPtr& block, const StmtPtr& last) {
MemLoadList& load_list = load_stack_.back();
if (load_list.empty()) {
return;
}
for (auto& pair : load_list) {
StmtPtr news = alloc<Let>(pair.first, pair.second);
block->insert_stmt_before(news, last);
}
}
MemoryLoadStack load_stack_;
// TODO: For now, we are not moving the loads with the IfThenElse.
// Eventually, we should switch to a more generic structure like:
// int v2 = IfThenElse(cond, true_v, false_v) + 2 ->
//
// int v;
// if (cond) {
// v = true_v;
// } else {
// v = false_v;
// }
// int v2 = v + 2;
int nested_if_then_else_{0};
StorePtr nested_store_{nullptr};
bool nested_let_{false};
std::unordered_set<VarPtr> thread_local_bufs_;
};
// Returns `func_prefix` if it is not in taken_func_names. Otherwise returns
// the first free `<func_prefix>_<n>` for n = 0, 1, .... The chosen name is
// added to taken_func_names before returning.
std::string CudaCodeGen::GetUniqueFuncName(const std::string& func_prefix) {
  std::string candidate = func_prefix;
  for (int64_t suffix = 0; taken_func_names.count(candidate) != 0; ++suffix) {
    candidate = func_prefix + "_" + std::to_string(suffix);
  }
  taken_func_names.insert(candidate);
  return candidate;
}
bool GPUMetaVarRewriter::isFullExtent() {
{
auto& extents = cuda_analysis_->gpu_block_extents();
for (int i = 0; i < 3; ++i) {
if (!exprEquals(current_block_reach_[i], extents[i])) {
return false;
}
}
}
{
auto& extents = cuda_analysis_->gpu_thread_extents();
for (int i = 0; i < 3; ++i) {
if (!exprEquals(current_thread_reach_[i], extents[i])) {
return false;
}
}
}
return true;
}
// Removes loops bound to a GPU block/thread axis. Such a loop is replaced by
// its body, with the loop variable substituted by the corresponding launch
// meta-variable (gpu_block_vars_ / gpu_thread_vars_). While the body is
// mutated, the current reach for that axis is raised to include the loop's
// stop value; the previous reach is restored afterwards. Loops not bound to
// a GPU axis are kept, with their body mutated recursively.
// Throws std::runtime_error for axis indices >= 3.
StmtPtr GPUMetaVarRewriter::mutate(const ForPtr& v) {
  StmtPtr body = v->body();
  // Reach of this loop's axis before entering it; restored on exit.
  ExprPtr old_reach = nullptr;
  const LoopOptions& loop_options = v->loop_options();
  if (loop_options.is_gpu_block_index()) {
    int gpu_block_index = loop_options.gpu_block_index();
    if (gpu_block_index >= 3) {
      throw std::runtime_error("support only 3D gpu_block_index");
    }
    old_reach = current_block_reach_[gpu_block_index];
    // Extents must be positive, assume >= 1.
    if (old_reach->isConstant() && immediateEquals(old_reach, 1)) {
      current_block_reach_[gpu_block_index] = v->stop();
    } else {
      current_block_reach_[gpu_block_index] =
          IRSimplifier::simplify(alloc<Max>(old_reach, v->stop(), true));
    }
    // Replace the loop variable with the launch meta-variable for this axis.
    VarPtr metaVar = gpu_block_vars_[gpu_block_index];
    body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
  } else if (loop_options.is_gpu_thread_index()) {
    int gpu_thread_index = loop_options.gpu_thread_index();
    if (gpu_thread_index >= 3) {
      throw std::runtime_error("support only 3D gpu_thread_index");
    }
    old_reach = current_thread_reach_[gpu_thread_index];
    // Extents must be positive, assume >= 1.
    if (old_reach->isConstant() && immediateEquals(old_reach, 1)) {
      current_thread_reach_[gpu_thread_index] = v->stop();
    } else {
      current_thread_reach_[gpu_thread_index] =
          IRSimplifier::simplify(alloc<Max>(old_reach, v->stop(), true));
    }
    // Replace the loop variable with the launch meta-variable for this axis.
    VarPtr metaVar = gpu_thread_vars_[gpu_thread_index];
    body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
  }
  // Recurse into body block.
  body = Stmt::clone(body->accept_mutator(this));
  // pop the internal reach off the stack.
  // GPU-bound loops disappear: their (rewritten) body is returned directly.
  if (loop_options.is_gpu_block_index()) {
    current_block_reach_[loop_options.gpu_block_index()] = old_reach;
    return body;
  } else if (loop_options.is_gpu_thread_index()) {
    current_thread_reach_[loop_options.gpu_thread_index()] = old_reach;
    return body;
  }
  return v->cloneWithNewBody(body);
}
// Splits the block into segments that share one launch reach.
// - If the current reach equals the kernel's full extents on every axis, no
//   masks are needed and the segments are flattened back into one Block.
// - Otherwise each maskable segment is wrapped in `if (idx < reach)` guards,
//   one for every thread and block axis whose reach is smaller than the
//   kernel extent.
// - Segments masked on a thread axis are bracketed by SyncThreads statements.
// - Loops, Allocate and Free statements each form their own unmasked segment.
StmtPtr GPUMetaVarRewriter::mutate(const BlockPtr& v) {
  std::vector<Segment> innerSegments;
  Segment current;
  // Closes `current` (if non-empty) and starts a new segment whose masking
  // flag is `mask`.
  auto pushAndReset = [&](bool mask) {
    if (!current.empty()) {
      innerSegments.push_back(current);
    }
    current.reset(mask);
  };
  // Here we're slicing the Block's contents into segments that should have
  // the same launch reach. Segments are comprised of all statements that aren't
  // loops - which are their own segments. Some operations, such as threading
  // and memory ops should never be masked and so also get their own segment.
  for (const StmtPtr& stmt : *v) {
    StmtPtr stmt_new = stmt->accept_mutator(this);
    if (stmt == stmt_new) {
      stmt_new = Stmt::clone(stmt_new);
    }
    // Allocate and Free should never be masked, and a loop masks its own
    // contents; each of them is a segment boundary.
    if (to<Allocate>(stmt) || to<Free>(stmt) || to<For>(stmt)) {
      pushAndReset(false);
    }
    current.stmts().push_back(stmt_new);
    // if the current segment should not be masked, it's a segment boundary on
    // the far side as well, so an unmasked segment holds exactly one stmt.
    if (!current.mask()) {
      pushAndReset(true);
    }
  }
  if (!current.empty()) {
    innerSegments.push_back(current);
  }
  // We are max extent in all dimensions, so need no masks at this level.
  if (isFullExtent()) {
    // flatten inner segments.
    std::vector<StmtPtr> stmts;
    for (auto& segment : innerSegments) {
      for (const auto& s : segment.stmts()) {
        stmts.push_back(s);
      }
    }
    return alloc<Block>(stmts);
  }
  std::vector<StmtPtr> stmts;
  for (auto& segment : innerSegments) {
    bool need_sync = false;
    // We never mask loops, they'll mask their contents.
    if (!segment.mask()) {
      TORCH_INTERNAL_ASSERT(segment.stmts().size() == 1, buildErrorMessage());
      stmts.push_back(segment.stmts()[0]);
      continue;
    }
    // If we get here, we must mask since we're not full reach and our direct
    // child isn't a For.
    StmtPtr inner = alloc<Block>(segment.stmts());
    // threads inside blocks.
    auto& thread_extents = cuda_analysis_->gpu_thread_extents();
    for (size_t i = 0; i < gpu_thread_vars_.size(); ++i) {
      if (!exprEquals(current_thread_reach_[i], thread_extents[i])) {
        need_sync = true;
        // Mask it against the current dimensions.
        inner = alloc<Cond>(
            alloc<CompareSelect>(
                gpu_thread_vars_[i],
                current_thread_reach_[i],
                CompareSelectOperation::kLT),
            inner,
            nullptr);
      }
    }
    auto& block_extents = cuda_analysis_->gpu_block_extents();
    for (size_t i = 0; i < gpu_block_vars_.size(); ++i) {
      if (!exprEquals(current_block_reach_[i], block_extents[i])) {
        // Mask it against the current dimensions.
        inner = alloc<Cond>(
            alloc<CompareSelect>(
                gpu_block_vars_[i],
                current_block_reach_[i],
                CompareSelectOperation::kLT),
            inner,
            nullptr);
      }
    }
    if (need_sync) {
      stmts.push_back(alloc<SyncThreads>());
    }
    stmts.push_back(inner);
    if (need_sync) {
      stmts.push_back(alloc<SyncThreads>());
    }
  }
  return alloc<Block>(stmts);
}
// Streams a list of expressions separated by ", " (no surrounding brackets).
static std::ostream& operator<<(
    std::ostream& out,
    const std::vector<ExprPtr>& exprs) {
  const char* separator = "";
  for (const auto& expr : exprs) {
    out << separator << *expr;
    separator = ", ";
  }
  return out;
}
// Preamble macros written at the top of every generated kernel source.
// NAN / POS_INFINITY / NEG_INFINITY are built from their IEEE-754
// single-precision bit patterns via __int_as_float.
static const char* device_resource_string = R"(
#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)
)";
// Device helpers printed by CudaPrinter for non-integral Max/Min. They
// propagate NaN from either side: a NaN `a` is returned directly, and a NaN
// `b` makes the comparison false, so `b` is returned.
static const char* shared_resource_string = R"(
template<typename T>
__device__ T maximum(T a, T b) {
return isnan(a) ? a : (a > b ? a : b);
}
template<typename T>
__device__ T minimum(T a, T b) {
return isnan(a) ? a : (a < b ? a : b);
}
)";
// Generates the CUDA source for this codegen's statement, prepares launch
// metadata, and compiles the kernel. In order, it:
//  1. expands generic intrinsics and detects whether rand() is used;
//  2. writes the preamble: NaN/infinity macros, maximum/minimum helpers,
//     and Philox / half / bfloat16 support code when needed;
//  3. writes the kernel signature from buffer_args(), plus seed/offset
//     parameters when rand() is used;
//  4. lowers the statement (CudaAnalysis, GPUMetaVarRewriter, AtomicAddFuser,
//     registerize, PrioritizeLoad, HalfRewriter, simplify) and prints it;
//  5. precomputes thread_block_size_ for call_with_numel() and builds
//     evaluators for the block/thread extents;
//  6. compiles the source via CompileToNVRTC().
// Throws std::runtime_error if some gpu block extent was never set.
void CudaCodeGen::Initialize() {
  // TODO: handle multiple kernels.
  // TODO: handle dynamic dimension.
  // TODO: call nvrtc.
  // TODO: merge HasRand with CudaAnalysis.
  GenericIntrinsicsExpander intrinsics_expander;
  apply_mutator(&intrinsics_expander);
  HasRand has_rand_func(stmt());
  has_random_ = has_rand_func.has_rand();
  cuda_analysis_ = std::make_unique<CudaAnalysis>();
  printer_ =
      std::make_unique<CudaPrinter>(&oss_, cuda_analysis_.get(), has_random_);
  metavar_rewriter_ =
      std::make_unique<GPUMetaVarRewriter>(cuda_analysis_.get());
  // Check whether the statement uses the Half type, if so add the
  // half_support_literal.
  StmtPtr stmt_v = stmt();
  HalfChecker halfChecker(buffer_args());
  stmt_v->accept(&halfChecker);
  // Preamble: always the NaN/infinity macros and maximum/minimum helpers;
  // the random / half / bfloat16 support code only when used.
  os() << device_resource_string << shared_resource_string;
  if (has_random_) {
    os() << philox_random_string << '\n';
  }
  if (halfChecker.hasHalf()) {
    os() << fuser::cuda::half_support_literal << '\n';
  }
  if (halfChecker.hasBFloat16()) {
    os() << fuser::cuda::bfloat16_support_literal << '\n';
  }
  std::string func_name = GetUniqueFuncName(kernel_func_name());
  os() << "extern \"C\" __global__" << '\n';
#if defined(USE_ROCM)
  // CUDA has a default limit of threads per block (=flat work group size)
  // of 1024, but ROCm uses 256 by default. At the time of writing
  // (#45506), I am unaware of a stricter limit that TensorExpr imposes
  // (maybe for perf),so I use 1024 as maximum flat work group size.
  // We put a minimum value of 1, this is also used by hip (ROCm 3.8) in
  // the __launch_bound__ implementation. The arguments for the attribute
  // are (min, max), for details see the documentation at
  // https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
  os() << "__attribute__((amdgpu_flat_work_group_size(1, 1024)))" << std::endl;
#endif
  // Kernel parameters: scalars (isVar) are passed by value, buffers as
  // pointers.
  os() << "void " << func_name << "(";
  const std::vector<BufferArg> buffer_args = this->buffer_args();
  for (size_t i = 0; i < buffer_args.size(); i++) {
    if (i > 0) {
      os() << ", ";
    }
    const BufferArg& buffer_arg = buffer_args[i];
    VarPtr var = buffer_arg.var();
    Dtype dtype = buffer_arg.dtype();
    os() << printer_->dtypeToCppString(dtype)
         << (buffer_arg.isVar() ? " " : "* ")
         << name_manager()->get_unique_name(var);
  }
  VarPtr rand_seed;
  VarPtr rand_offset;
  if (has_random_) {
    // TODO: switch to kUint64 when it is available.
    rand_seed = alloc<Var>("rand_seed", kInt);
    rand_offset = alloc<Var>("rand_offset", kInt);
    std::string uint64_str = "unsigned long long";
    os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str << " "
         << *rand_offset;
  }
  os() << ") {";
  os() << '\n';
  if (has_random_) {
    // Seed the Philox generator per thread, using the flattened 1-D thread
    // index as its second argument.
    VarPtr idx = alloc<Var>("idx", kInt);
    os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;" << '\n';
    VarPtr rand_func = printer_->rand_func();
    os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
         << *rand_offset << ");" << '\n';
    os() << '\n';
  }
  // Lowering pipeline; the order of these passes matters.
  stmt_v->accept(cuda_analysis_.get());
  stmt_v = stmt_v->accept_mutator(metavar_rewriter_.get());
  AtomicAddFuser atomic_add_fuser(
      cuda_analysis_->thread_local_bufs(), *metavar_rewriter_);
  stmt_v = stmt_v->accept_mutator(&atomic_add_fuser);
  stmt_v = registerize(stmt_v);
  PrioritizeLoad prioritize_load;
  stmt_v = stmt_v->accept_mutator(&prioritize_load);
  // The registerizer might insert half-type scalars, we don't want this.
  HalfRewriter hsFix;
  stmt_v = stmt_v->accept_mutator(&hsFix);
  stmt_v = IRSimplifier::simplify(stmt_v);
  set_stmt(stmt_v);
  stmt_v->accept(printer_.get());
  os() << '\n';
  os() << "}";
  // Check that all block extents had been set.
  const std::vector<ExprPtr>& gpu_block_extents =
      metavar_rewriter_->gpu_block_extents();
  for (size_t i = 0; i < gpu_block_extents.size(); i++) {
    if (!gpu_block_extents[i]) {
      throw std::runtime_error("Missing gpu_block_index: " + std::to_string(i));
    }
  }
  // Precompute block and thread extents for call_with_numel(). If
  // precomputation can't be done (block/thread extents aren't
  // constant), then disallow call_with_numel.
  // It is allowed only without rand(), with every block/thread extent past
  // axis 0 equal to the constant 1, and with a constant thread_extents[0].
  auto block_extents = metavar_rewriter_->gpu_block_extents();
  auto thread_extents = metavar_rewriter_->gpu_thread_extents();
  bool canCallWithNumel =
      !has_random_ && !block_extents.empty() && !thread_extents.empty();
  for (size_t i = 1; i < block_extents.size() && canCallWithNumel; i++) {
    canCallWithNumel = canCallWithNumel && block_extents[i]->isConstant() &&
        immediateAs<int>(block_extents[i]) == 1;
  }
  for (size_t i = 1; i < thread_extents.size() && canCallWithNumel; i++) {
    canCallWithNumel = canCallWithNumel && thread_extents[i]->isConstant() &&
        immediateAs<int>(thread_extents[i]) == 1;
  }
  if (canCallWithNumel && thread_extents[0]->isConstant()) {
    // We assume block_extents[0] is output.numel()/thread_block_size_.
    thread_block_size_ = immediateAs<int>(thread_extents[0]);
  } else {
    // Disable call_with_numel.
    thread_block_size_ = -1;
  }
  // Build an LLVM based eval expression for the extents
  block_extents_eval_.reserve(block_extents.size());
  std::vector<BufferArg> extents_buffer_args;
  // We need to extract the args that are used in the thread and block extents
  // from bufferArgs and only use those for the `ExprEval` below. Without this,
  // bufferArgs might contain arbitrary types that are not handled by LLVM and
  // hence would result in an error.
  std::unordered_set<VarPtr> vars_in_extents;
  for (const auto& be : block_extents) {
    auto v = VarFinder::find(be);
    vars_in_extents.insert(v.begin(), v.end());
  }
  for (const auto& te : thread_extents) {
    auto v = VarFinder::find(te);
    vars_in_extents.insert(v.begin(), v.end());
  }
  // arg_pos_in_extents_[i] records whether buffer argument i feeds the extent
  // evaluators, so callers can select the matching runtime args.
  for (const size_t i : c10::irange(buffer_args.size())) {
    if (vars_in_extents.count(buffer_args[i].var())) {
      extents_buffer_args.push_back(buffer_args[i]);
      arg_pos_in_extents_.push_back(true);
    } else {
      arg_pos_in_extents_.push_back(false);
    }
  }
  for (const auto& be : block_extents) {
#ifdef TORCH_ENABLE_LLVM
    block_extents_eval_.emplace_back(
        ExprEval<LLVMCodeGen>(ExprHandle(be), extents_buffer_args));
#else
    block_extents_eval_.emplace_back(ExprHandle(be), extents_buffer_args);
#endif
  }
  thread_extents_eval_.reserve(thread_extents.size());
  for (const auto& te : thread_extents) {
#ifdef TORCH_ENABLE_LLVM
    thread_extents_eval_.emplace_back(
        ExprEval<LLVMCodeGen>(ExprHandle(te), extents_buffer_args));
#else
    thread_extents_eval_.emplace_back(ExprHandle(te), extents_buffer_args);
#endif
  }
  GRAPH_DEBUG(
      "Fused TE CUDA kernel:\n",
      oss_.str(),
      "\n",
      "gpu_block_extents: (",
      metavar_rewriter_->gpu_block_extents(),
      ")\n",
      "gpu_thread_extents: (",
      metavar_rewriter_->gpu_thread_extents(),
      ")");
  CompileToNVRTC(oss_.str(), func_name);
}
// Launches the kernel as a 1-D grid sized from `numel`, without evaluating
// the block/thread extent expressions. It uses thread_block_size_ threads
// per block and ceil(numel / thread_block_size_) blocks, and returns
// immediately when numel == 0.
// `args` holds one entry per buffer argument, in buffer_args() order: a
// pointer to the value for scalar args, the device pointer itself for
// buffers.
// Requires thread_block_size_ > 0 (it is set to -1 by Initialize() when
// call_with_numel is not allowed).
void CudaCodeGen::call_with_numel(void** args, int64_t numel) {
  if (C10_UNLIKELY(numel == 0)) {
    return;
  }
  // Must be strictly positive: -1 means "disabled", and 0 would be a
  // division by zero in the block count below. (The previous `>= 0` check
  // let 0 through.)
  TORCH_INTERNAL_ASSERT(
      thread_block_size_ > 0,
      "call_with_numel() requires a precomputed thread block size");
  auto const& buffer_args = this->buffer_args();
  size_t gpu_block_extents =
      (numel + thread_block_size_ - 1) / thread_block_size_;
  size_t gpu_thread_extents = thread_block_size_;
  // In CUDA we need to pass pointers to pointers for buffers, thus we need to
  // go over args and add an extra indirection for such non-scalar
  // arguments.
  // Why? See some details here:
  // https://stackoverflow.com/questions/34388712/cannot-understand-how-jcuda-culaunchkernel-work
  std::vector<void*> ptr_to_args(buffer_args.size());
  for (size_t i = 0; i < buffer_args.size(); i++) {
    ptr_to_args[i] = buffer_args[i].isVar() ? args[i] : (&args[i]);
  }
  const auto device = this->device().index();
  const auto prior_device = at::cuda::current_device();
  if (prior_device != device) {
    at::cuda::set_device(device);
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::jit::initializeCudaContext();
  // NOTE(review): if the launch check below throws, the prior device is not
  // restored; consider an RAII device guard.
  AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
      function_,
      gpu_block_extents,
      1,
      1,
      gpu_thread_extents,
      1,
      1,
      0,
      stream,
      ptr_to_args.data(),
      nullptr));
  if (prior_device != device) {
    at::cuda::set_device(prior_device);
  }
}
void CudaCodeGen::call_raw(const std::vector<void*>& raw_args) {
  // Launches the compiled kernel (function_) with the given raw argument
  // pointers. Grid/block extents are either constant or evaluated from the
  // subset of raw_args flagged in arg_pos_in_extents_.
  auto const& buffer_args = this->buffer_args();
  // TODO: move as much of this into the constructors.
  const std::vector<ExprPtr>& gpu_block_extents =
      metavar_rewriter_->gpu_block_extents();
  const std::vector<ExprPtr>& gpu_thread_extents =
      metavar_rewriter_->gpu_thread_extents();
  if (gpu_block_extents.size() > 3 || gpu_thread_extents.size() > 3) {
    throw malformed_input(
        "cuda_codegen: block or thread extent greater than 3D");
  }
  // Dimensions not mentioned by the kernel default to an extent of 1.
  std::vector<int64_t> gpu_block_extents_v(3, 1);
  std::vector<int64_t> gpu_thread_extents_v(3, 1);
  // evaluate all the block/thread extents into values
  // TODO: eventually, codegen these calculations and make them part of the
  // module.
  std::vector<void*> extent_args;
  size_t raw_args_size = raw_args.size();
  extent_args.reserve(raw_args_size);
  for (size_t i = 0; i < raw_args_size; ++i) {
    if (arg_pos_in_extents_[i]) {
      extent_args.push_back(raw_args[i]);
    }
  }
  for (size_t i = 0; i < gpu_block_extents.size(); i++) {
    if (gpu_block_extents[i]->isConstant()) {
      gpu_block_extents_v[i] = immediateAs<int64_t>(gpu_block_extents[i]);
      continue;
    }
    {
      // invocation of block_extents_eval_ isn't thread safe and this function
      // may be invoked by multiple threads
      std::lock_guard<std::mutex> guard(eval_lock_);
      gpu_block_extents_v[i] =
          block_extents_eval_[i].value<int64_t>(extent_args);
    }
  }
  for (size_t i = 0; i < gpu_thread_extents.size(); i++) {
    if (gpu_thread_extents[i]->isConstant()) {
      gpu_thread_extents_v[i] = immediateAs<int64_t>(gpu_thread_extents[i]);
      continue;
    }
    {
      std::lock_guard<std::mutex> guard(eval_lock_);
      gpu_thread_extents_v[i] =
          thread_extents_eval_[i].value<int64_t>(extent_args);
    }
  }
  // Skip launching the kernel if there are no elements to process. A zero
  // extent in any grid OR block dimension means no thread would run, and a
  // zero-sized dimension is not a valid launch configuration, so both are
  // checked.
  for (auto extent : gpu_block_extents_v) {
    if (extent == 0) {
      return;
    }
  }
  for (auto extent : gpu_thread_extents_v) {
    if (extent == 0) {
      return;
    }
  }
  auto ptr_count = buffer_args.size();
  // If the kernel has a rand call in it, add two extra arguments for random
  // seed and offset.
  if (has_random_) {
    ptr_count += 2;
  }
  std::vector<void*> ptr_to_args(ptr_count);
  // In CUDA we need to pass pointers to pointers for buffers, thus we need to
  // go over raw_args and add an extra indirection for such non-scalar
  // arguments.
  // Why? See some details here:
  // https://stackoverflow.com/questions/34388712/cannot-understand-how-jcuda-culaunchkernel-work
  for (size_t i = 0; i < buffer_args.size(); i++) {
    ptr_to_args[i] =
        buffer_args[i].isVar() ? raw_args[i] : const_cast<void**>(&raw_args[i]);
  }
  // rand_seed / rand_offset must outlive the launch below because their
  // addresses are stored in ptr_to_args.
  uint64_t rand_seed = uint64_t(-1);
  uint64_t rand_offset = uint64_t(-1);
  if (has_random_) {
    auto gen = at::cuda::detail::getDefaultCUDAGenerator();
    // TODO: total hack. Switch to numel when it is available.
    int64_t total_elements_per_thread = (1LL << 28);
    {
      std::lock_guard<std::mutex> lock(gen.mutex());
      auto philox_engine_inputs =
          at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(
              total_elements_per_thread);
      rand_seed = philox_engine_inputs.first;
      rand_offset = philox_engine_inputs.second;
    }
    ptr_to_args[buffer_args.size()] = &rand_seed;
    ptr_to_args[buffer_args.size() + 1] = &rand_offset;
  }
  // Switch to the codegen's device for the launch and restore the caller's
  // device afterwards -- also when the launch (or stream/context setup)
  // throws, so a failed launch does not leave the thread on another device.
  const auto device = this->device().index();
  const auto prior_device = at::cuda::current_device();
  const bool switched_device = prior_device != device;
  if (switched_device) {
    at::cuda::set_device(device);
  }
  try {
    // Launch the kernels
    auto stream = at::cuda::getCurrentCUDAStream();
    at::cuda::jit::initializeCudaContext();
    AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
        function_,
        gpu_block_extents_v[0],
        gpu_block_extents_v[1],
        gpu_block_extents_v[2],
        gpu_thread_extents_v[0],
        gpu_thread_extents_v[1],
        gpu_thread_extents_v[2],
        0,
        stream,
        ptr_to_args.data(),
        nullptr));
  } catch (...) {
    if (switched_device) {
      at::cuda::set_device(prior_device);
    }
    throw;
  }
  if (switched_device) {
    at::cuda::set_device(prior_device);
  }
}
void CudaCodeGen::call(const std::vector<CallArg>& args) {
  // Converts each CallArg into the raw pointer form expected by call_raw,
  // pairing it positionally with the corresponding formal buffer argument.
  auto const& formal_args = this->buffer_args();
  if (args.size() != formal_args.size()) {
    throw malformed_input("cuda_codegen: wrong number of args in call");
  }
  std::vector<void*> raw_args;
  raw_args.reserve(formal_args.size());
  for (size_t idx = 0; idx < formal_args.size(); ++idx) {
    raw_args.push_back(argToPtr(formal_args[idx], args[idx]));
  }
  call_raw(raw_args);
}
at::Tensor CudaCodeGen::empty_strided(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    std::optional<c10::ScalarType> dtype_opt,
    std::optional<c10::Layout> layout_opt,
    std::optional<c10::Device> device_opt,
    std::optional<bool> pin_memory_opt) {
  // Allocates a strided CUDA tensor while the requested device is active.
  // device_opt must be engaged: .value() throws if it is empty.
  const c10::Device target_device = device_opt.value();
  c10::DeviceGuard active_device(target_device);
  return at::native::empty_strided_cuda(
      size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
}
void CudaCodeGen::CompileToNVRTC(
    const std::string& code,
    const std::string& func_name) {
  // Compiles `code` with NVRTC for the codegen's device, loads the resulting
  // image as a module and stores the kernel named `func_name` in function_.
  // Throws std::runtime_error (including the NVRTC log and the source) when
  // compilation fails.
  at::cuda::jit::initializeCudaContext();
  // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work
  // properly in some scenarios. The device is switched manually and the
  // prior device is restored on every exit path, including when compilation,
  // module loading or kernel lookup throws.
  const auto prior_device = at::cuda::current_device();
  const auto target_device = this->device().index();
  const bool switched_device = prior_device != target_device;
  if (switched_device) {
    at::cuda::set_device(target_device);
  }
  try {
    // Acquires device and NVRTC properties (for compile arch and occupancy
    // calculations)
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    int major = 0;
    int minor = 0;
    bool compile_to_sass = false;
    fuser::cuda::codegenOutputQuery(prop, major, minor, compile_to_sass);

    // Creates the NVRTC program. The cleanup guard is registered immediately
    // after creation so the program is destroyed on every path, including
    // the compilation-failure path that throws below.
    nvrtcProgram program{nullptr};
    AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
        &program, code.c_str(), nullptr, 0, nullptr, nullptr));
    ResourceGuard holdProgram(
        [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });

#if defined(USE_ROCM)
    std::vector<const char*> args = {"--std=c++17"};
    args.push_back("-hip-pch");
#else
    // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
    // which gives better backwards compatibility to work on older driver,
    // (since older driver doesn't necessarily recognize PTX emitted by new
    // toolkit);
    // Meanwhile, for forward compatibility (future device with
    // `compile_to_sass==false`), since SASS are not necessarily compatible,
    // we fallback to PTX instead.
    const std::string compute = std::string("--gpu-architecture=") +
        (compile_to_sass ? "sm_" : "compute_") + std::to_string(major) +
        std::to_string(minor);
    const std::vector<const char*> args = {
        "--std=c++17", compute.c_str(), "-default-device"};
#endif

    const auto result = nvrtc().nvrtcCompileProgram(
        program, static_cast<int>(args.size()), args.data());
    if (result != NVRTC_SUCCESS) {
      size_t logsize = 0;
      AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
      std::vector<char> log(logsize);
      AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
      std::stringstream cu;
      cu << log.data() << '\n';
      cu << "nvrtc compilation failed: " << '\n';
      cu << code << '\n';
      throw std::runtime_error(cu.str());
    }

#if !defined(USE_ROCM)
    // compile_to_sass determines whether we are generating SASS (CUBIN) or
    // PTX, hence the different API used to fetch the compiled image.
    auto getSize = compile_to_sass
        ? at::globalContext().getNVRTC().nvrtcGetCUBINSize
        : at::globalContext().getNVRTC().nvrtcGetPTXSize;
    auto getFunc = compile_to_sass
        ? at::globalContext().getNVRTC().nvrtcGetCUBIN
        : at::globalContext().getNVRTC().nvrtcGetPTX;
#else
    auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
    auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
#endif
    size_t image_size = 0;
    AT_CUDA_NVRTC_CHECK(getSize(program, &image_size));
    std::vector<char> image(image_size);
    AT_CUDA_NVRTC_CHECK(getFunc(program, image.data()));

    CUmodule module{nullptr};
    AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, image.data()));
    AT_CUDA_DRIVER_CHECK(
        nvrtc().cuModuleGetFunction(&function_, module, func_name.c_str()));
  } catch (...) {
    if (switched_device) {
      at::cuda::set_device(prior_device);
    }
    throw;
  }
  if (switched_device) {
    at::cuda::set_device(prior_device);
  }
}
// Destructor defined out-of-line in this translation unit (defaulted).
CudaCodeGen::~CudaCodeGen() = default;
// Registers CudaCodeGen with RegisterCodeGen under the name "cuda_codegen".
RegisterCodeGen<CudaCodeGen> cuda_codegen_reg{"cuda_codegen"};
} // namespace torch::jit::tensorexpr