#include <torch/csrc/jit/codegen/cuda/index_compute.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/contiguity.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/index_reference_replay.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/lower_double_buffer.h>
#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
#include <torch/csrc/jit/codegen/cuda/lower_shift.h>
#include <torch/csrc/jit/codegen/cuda/lower_unroll.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

// Update the HaloInfo mappings for a reference tensor by propagating
// the halo information from the consumer tensor.
void updateHaloInfoForReference(
    const ReferenceTensor& reference,
    const TensorView* consumer_tv) {
  const auto gpu_lower = GpuLower::current();

  auto& halo_info = gpu_lower->haloInfo();

  auto reference_domain = reference.domain;

  // First, propagate the halo information of the consumer root domain
  // to the reference root domain.
  for (auto consumer_root_id : consumer_tv->getRootDomain()) {
    auto consumer_index_concrete_id =
        gpu_lower->caIndexMap().getConcreteMappedID(consumer_root_id);
    auto reference_it =
        reference.concrete_to_id.find(consumer_index_concrete_id);
    if (reference_it == reference.concrete_to_id.end()) {
      // This happens when consumer_root_id is a broadcast or an
      // initialization of a reduction buffer. In those cases, since
      // the domain is not going to be predicated, it's not necessary
      // to propagate halo information to the reference tensor.
      continue;
    }
    auto reference_id = reference_it->second;
    halo_info.setRootAxisInfo(
        reference_id, halo_info.getRootAxisInfo(consumer_root_id));
  }

  // Now that the reference root has halo information copied from
  // the consumer, propagate it down to non-root domains.
  halo_info.build(reference_domain);

  return;
}

// Get a map of IterDomains to halo-extended extents of corresponding
// reference IterDomains.
//
// ref_map: ref-to-consumer in consumer indexing; ref-to-producer in
// producer indexing
std::unordered_map<IterDomain*, Val*> getReferenceHaloExtentMap(
    const ReferenceTensor& reference,
    const std::unordered_map<IterDomain*, IterDomain*>& index_map_from_ref) {
  const auto& halo_info = GpuLower::current()->haloInfo();

  std::unordered_map<IterDomain*, Val*> reference_halo_extent_map;

  // Propagate halo extents of the reference to the consumer or
  // producer tensor
  for (auto kv : index_map_from_ref) {
    auto ref_id = kv.first;
    auto producer_or_consumer_id = kv.second;
    auto extent = halo_info.getExtent(ref_id);
    if (extent != nullptr) {
      reference_halo_extent_map[producer_or_consumer_id] = extent;
    }
  }

  return reference_halo_extent_map;
}

//! Offset of an index of a producer axis with respect to its
//! corresponding consumer index
int getProducerHaloOffset(
    const TensorView* producer_tv,
    size_t producer_axis,
    const TensorView* consumer_tv) {
  auto p2c =
      PairwiseRootDomainMap(producer_tv, consumer_tv)
          .mapProducerToConsumer(producer_tv->domain(), consumer_tv->domain());

  auto producer_id = producer_tv->getMaybeRFactorDomain()[producer_axis];

  auto it = p2c.find(producer_id);
  // p2c should always have a mapping for producer_id. The only case
  // where no mapping exists for a producer axis is when it is a
  // reduction axis. Since this function is only used for indexing
  // producer tensors, where reduction axes are skipped, producer_id
  // should never be a reduction axis.
  TORCH_INTERNAL_ASSERT(it != p2c.end());
  IterDomain* consumer_id = it->second;

  const auto& halo_map = GpuLower::current()->haloInfo();
  const auto p_pad = halo_map.getRootAxisInfo(producer_id).width(0);
  const auto c_pad = halo_map.getRootAxisInfo(consumer_id).width(0);

  auto offset = p_pad - c_pad;

  // If the consumer is a result of shifting the producer, adjust the
  // producer index per the offsets argument of the shift op.
  if (auto shift_op = dynamic_cast<const ShiftOp*>(consumer_tv->definition())) {
    offset -= shift_op->offset(producer_axis);
  }

  return offset;
}
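
// Illustrative example: if the producer root axis is padded by 1 on the left
// (p_pad = 1), the matching consumer axis has no halo (c_pad = 0), and the
// consumer is shift(producer, {1}), the resulting offset is 1 - 0 - 1 = 0,
// i.e. the producer can be indexed with the consumer index unchanged.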

//! Offset producer index when necessary
Val* getProducerIndexWithHalo(
    const TensorView* producer_tv,
    size_t producer_axis,
    Val* producer_index,
    const TensorView* consumer_tv) {
  const auto offset =
      getProducerHaloOffset(producer_tv, producer_axis, consumer_tv);

  if (offset == 0) {
    return producer_index;
  }

  producer_index = SimplifyingIrBuilder::addExpr(producer_index, offset);

  return producer_index;
}

//! Create a producer offset based off a consumer index
//!
//! \param consumer_root_axis Position of corresponding consumer axis
//! \param consumer_tv Consumer TensorView
//! \param index_map Mappings from consumer or reference to indices
//! \param use_reference_map True when index_map maps reference domains
//! \param concrete_to_ref_map Mappings from concrete to reference domains
Val* getProducerOffsetWithGather(
    size_t consumer_root_axis,
    const TensorView* consumer_tv,
    const std::unordered_map<IterDomain*, Val*>& index_map,
    bool use_reference_map = false,
    const std::unordered_map<IterDomain*, IterDomain*>& concrete_to_ref_map =
        {}) {
  const auto gpu_lower = GpuLower::current();

  const auto gather_expr = dynamic_cast<GatherOp*>(consumer_tv->definition());

  if (gather_expr == nullptr) {
    return gpu_lower->kernel()->zeroVal();
  }

  // If the window extent is one, no specific offsetting
  // is necessary
  if (consumer_root_axis >= gather_expr->windowShape().size() ||
      gather_expr->windowShape()[consumer_root_axis] == 1) {
    return gpu_lower->kernel()->zeroVal();
  }

  // Basically, the goal is to build an expression of producer_index +
  // window_index, so we first need to locate the index expression
  // that corresponds to the window axis of this producer axis.

  const auto window_axis = gather_expr->gatherAxis(consumer_root_axis);
  auto window_id = consumer_tv->getRootDomain().at(window_axis);

  // When index_map maps a reference tensor, find the corresponding
  // reference ID of window_id.
  if (use_reference_map) {
    auto concrete_window_id =
        gpu_lower->caIndexMap().getConcreteMappedID(window_id);
    auto concrete_2_ref_it = concrete_to_ref_map.find(concrete_window_id);
    TORCH_INTERNAL_ASSERT(concrete_2_ref_it != concrete_to_ref_map.end());
    window_id = concrete_2_ref_it->second;
  }

  auto window_idx = index_map.at(window_id);

  // Positive padding at offset zero means the indexing shifted to the
  // negative direction.
  auto pad_width = gather_expr->padWidth()[consumer_root_axis][0];

  // producer offset: window_index - padding
  auto producer_offset = SimplifyingIrBuilder::subExpr(
      window_idx, SimplifyingIrBuilder::create<Int>(pad_width));
  return producer_offset;
}
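
// Illustrative example: for a gather with a window of size 3 and a left pad
// width of 1 on this axis, the producer offset is window_index - 1, so the
// producer ends up being accessed at consumer_index - 1, consumer_index, or
// consumer_index + 1 as window_index ranges over 0, 1, 2.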

//! Offset a producer index of a gather expression
//!
//! Given an index of a producer root axis, build a new index
//! expression that accesses a window position that the current loop
//! structure refers to. Use getGatherProducerOffset to create an
//! offset Val.
Val* getProducerIndexWithGather(
    Val* producer_index,
    size_t producer_root_axis,
    const TensorView* producer_tv,
    const TensorView* consumer_tv,
    const std::unordered_map<IterDomain*, IterDomain*>& concrete_to_ref_map,
    const std::unordered_map<IterDomain*, Val*>& ref_index_map) {
  auto gather_op = dynamic_cast<const GatherOp*>(consumer_tv->definition());

  // Just return the producer index as is if this is not a gather
  if (gather_op == nullptr) {
    return producer_index;
  }

  // Consumer axis that corresponds to the producer axis
  int consumer_axis = -1;
  for (const auto i : c10::irange(producer_root_axis + 1)) {
    if (producer_tv->getMaybeRFactorDomain()[i]->isReduction() ||
        producer_tv->getMaybeRFactorDomain()[i]->isStride()) {
      continue;
    }
    ++consumer_axis;
  }

  TORCH_INTERNAL_ASSERT(
      consumer_axis >= 0 &&
          consumer_axis < (int)gather_op->windowShape().size(),
      "Invalid consumer axis",
      consumer_axis,
      ", producer_axis: ",
      producer_root_axis);

  auto offset = getProducerOffsetWithGather(
      consumer_axis, consumer_tv, ref_index_map, true, concrete_to_ref_map);
  return SimplifyingIrBuilder::addExpr(producer_index, offset);
}

// Adjusts a global consumer index when its root domain is partially
// split. Note that non-global consumer indices don't need any
// adjustment.
Val* getGlobalConsumerOffsetWithPartialSplit(IterDomain* root_id) {
  auto offset = GpuLower::current()->partialSplitMap().getStartOffset(root_id);
  if (offset == nullptr) {
    return GpuLower::current()->kernel()->zeroVal();
  } else {
    return offset;
  }
}

// Adjusts a global producer index when its root domain and
// corresponding consumer root domain have non-matching split
// offsets. Specifically, since producer_index is calculated based on
// the consumer, if the consumer has a non-zero offset,
// it needs to be added to the index. Also, when the producer itself
// also has a non-zero split offset, that needs to be subtracted from
// the index.
Val* getProducerIndexWithPartialSplit(
    Val* producer_index,
    IterDomain* producer_root_id,
    const TensorView* producer_tv,
    const TensorView* consumer_tv) {
  const auto gpu_lower = GpuLower::current();

  auto p2c =
      PairwiseRootDomainMap(producer_tv, consumer_tv)
          .mapProducerToConsumer(producer_tv->domain(), consumer_tv->domain());

  auto it = p2c.find(producer_root_id);
  if (it == p2c.end()) {
    return producer_index;
  }

  auto consumer_root_id = it->second;

  auto consumer_offset =
      gpu_lower->partialSplitMap().getStartOffset(consumer_root_id);
  consumer_offset = consumer_offset == nullptr ? gpu_lower->kernel()->zeroVal()
                                               : consumer_offset;

  auto producer_offset =
      gpu_lower->partialSplitMap().getStartOffset(producer_root_id);
  producer_offset = producer_offset == nullptr ? gpu_lower->kernel()->zeroVal()
                                               : producer_offset;

  // If the producer is on global memory, it's always allocated
  // without trimming the out-of-bounds region, so the consumer offset
  // should be added to the index.
  if (producer_tv->getMemoryType() == MemoryType::Global) {
    if (consumer_offset->isZeroInt()) {
      return producer_index;
    } else {
      return SimplifyingIrBuilder::addExpr(producer_index, consumer_offset);
    }
  }

  // Non-global case. The difference of the split offsets must be
  // accounted for.

  auto diff = SimplifyingIrBuilder::subExpr(consumer_offset, producer_offset);
  kir::ExpressionEvaluator ee;
  auto diff_eval = ee.evaluate(diff);
  // We currently only allow constant offsetting
  TORCH_INTERNAL_ASSERT(diff_eval.has_value(), "Invalid partial split");

  if (diff_eval.value() == 0) {
    return producer_index;
  }

  return SimplifyingIrBuilder::addExpr(
      producer_index, SimplifyingIrBuilder::create<Int>(diff_eval.value()));
}

} // namespace
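
// Illustrative example: if a consumer root domain is partially split with a
// start offset of 2 while the producer has no offset, a non-global producer
// index is adjusted by the constant difference 2 - 0 = 2; a global producer
// only needs the consumer offset added, since it is allocated without
// trimming the out-of-bounds region.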

void IndexCompute::handle(Split* split) {
  auto in_id = split->in()->as<IterDomain>();
  auto outer_id = split->outer()->as<IterDomain>();
  auto inner_id = split->inner()->as<IterDomain>();

  auto outer_it = index_map_.find(outer_id);
  auto inner_it = index_map_.find(inner_id);
  if (outer_it == index_map_.end() || inner_it == index_map_.end()) {
    return;
  }

  const auto outer_ind = outer_it->second;
  const auto inner_ind = inner_it->second;

  const bool outer_zero = isZero(outer_id);
  const bool inner_zero = isZero(inner_id);

  // We want to mark as zero merged in if we're working with shared or local
  // memory, and the dimension we're working with is not part of the
  // allocation, as we have special propagation rules for that scenario.

  // Maybe clear in_id as it could have been mapped over from another
  // IndexCompute. Uncertain if this is needed but seems to be safe.
  bool zero_merged_in = hasZeroMerged(in_id) || hasZeroMerged(inner_id) ||
      hasZeroMerged(outer_id);

  // If both are zero, the split input is also zero
  if (inner_zero && outer_zero) {
    zero_domains_.emplace(in_id);
  }

  if (zero_merged_in) {
    zero_merged_in_.emplace(in_id);
  }

  if (isZero(in_id)) {
    index_map_[in_id] = GpuLower::current()->kernel()->zeroVal();
    extent_map_[in_id] = GpuLower::current()->kernel()->zeroVal();
  } else if (zero_merged_in && outer_zero) {
    index_map_[in_id] = inner_ind;
    extent_map_[in_id] = getExtent(inner_id);
  } else if (zero_merged_in && inner_zero) {
    index_map_[in_id] = outer_ind;
    extent_map_[in_id] = getExtent(outer_id);
  } else {
    index_map_[in_id] = SimplifyingIrBuilder::addExpr(
        SimplifyingIrBuilder::mulExpr(outer_ind, getExtent(inner_id)),
        inner_ind);
    // The extent should be updated only when its allocation is
    // partial, i.e., zero_merged_in is true. See PR #1270.
    if (zero_merged_in) {
      extent_map_[in_id] = SimplifyingIrBuilder::mulExpr(
          getExtent(outer_id), getExtent(inner_id));
    }
  }
}
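
// Illustrative example of the backward split handling above: splitting an
// IterDomain of extent 12 by a factor of 4 gives outer extent 3 and inner
// extent 4; with outer index i and inner index j, the input index is
// reconstructed as i * 4 + j.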

void IndexCompute::handle(Merge* merge) {
  auto out_id = merge->out();
  auto outer_id = merge->outer();
  auto inner_id = merge->inner();

  auto out_it = index_map_.find(out_id);
  if (out_it == index_map_.end()) {
    return;
  }
  auto out_ind = out_it->second;

  auto zero = GpuLower::current()->kernel()->zeroVal();

  if (isZero(out_id)) {
    index_map_[outer_id] = zero;
    index_map_[inner_id] = zero;
    extent_map_[outer_id] = zero;
    extent_map_[inner_id] = zero;
    zero_domains_.emplace(outer_id);
    zero_domains_.emplace(inner_id);
    return;
  }

  if (!hasZeroMerged(out_id) && contig_ids.find(out_id) != contig_ids.end()) {
    // Contiguous indexing path
    auto input_ids = ir_utils::iterDomainInputsOfOrderedAs(
        {merge->out()}, td_->getMaybeRFactorDomain());

    // Shouldn't hit this, but don't want to segfault if somehow we do.
    TORCH_INTERNAL_ASSERT(!input_ids.empty());

    for (auto root_id : input_ids) {
      index_map_[root_id] = zero;
    }

    index_map_[*(input_ids.end() - 1)] = out_ind;
    return;
  }

  Val* inner_extent = getExtent(inner_id);

  // When the reference has halo extent for inner_id, that extent needs to
  // be used to un-merge
  if (reference_halo_extent_map_.find(inner_id) !=
      reference_halo_extent_map_.end()) {
    inner_extent = reference_halo_extent_map_[inner_id];
  }

  const auto outer_extent = getExtent(outer_id);

  if (inner_id->isBroadcast() && inner_extent->isOneInt()) {
    // Propagate away from broadcast dims
    index_map_[outer_id] = out_ind;
    index_map_[inner_id] = zero;

    extent_map_[outer_id] = getExtent(out_id);
  } else if (outer_id->isBroadcast() && outer_extent->isOneInt()) {
    // Propagate away from broadcast dims
    index_map_[outer_id] = zero;
    index_map_[inner_id] = out_ind;

    extent_map_[inner_id] = getExtent(out_id);
  } else if (hasZeroMerged(out_id)) {
    // Don't propagate to inner id if it's comprised of only broadcast root
    // domains, unless outer is also all broadcast domains. Index shouldn't be
    // anything but zero if both inner and outer are all broadcast domains, but
    // didn't add a hard check for this. See FusionAdvancedIndexing5_CUDA
    if (!inner_id->isBroadcast() && !outer_id->isBroadcast()) {
      // If neither dimension is a broadcast (should be true for reference
      // indexing) pick the preferred path or the inner path.
      if (preferred_paths_.find(outer_id) != preferred_paths_.end() &&
          preferred_paths_.find(inner_id) == preferred_paths_.end()) {
        // Marked that we should prop through outer, not inner.
        index_map_[outer_id] = out_ind;
        extent_map_[outer_id] = getExtent(out_id);
        index_map_[inner_id] = zero;
        extent_map_[inner_id] = zero;
        zero_domains_.emplace(inner_id);
      } else {
        // Prop through inner
        index_map_[inner_id] = out_ind;
        extent_map_[inner_id] = getExtent(out_id);
        index_map_[outer_id] = zero;
        extent_map_[outer_id] = zero;
        zero_domains_.emplace(outer_id);
      }
    } else if (inner_id->isBroadcast() && !outer_id->isBroadcast()) {
      // Inner is broadcast and outer isn't, prop through outer
      index_map_[outer_id] = out_ind;
      extent_map_[outer_id] = getExtent(out_id);
      index_map_[inner_id] = zero;
      extent_map_[inner_id] = zero;
      zero_domains_.emplace(inner_id);
    } else {
      // Default to propagating through inner
      index_map_[inner_id] = out_ind;
      extent_map_[inner_id] = getExtent(out_id);
      index_map_[outer_id] = zero;
      extent_map_[outer_id] = zero;
      zero_domains_.emplace(outer_id);
    }
    zero_merged_in_.emplace(inner_id);
    zero_merged_in_.emplace(outer_id);
  } else {
    index_map_[outer_id] = SimplifyingIrBuilder::divExpr(out_ind, inner_extent);
    index_map_[inner_id] = SimplifyingIrBuilder::modExpr(out_ind, inner_extent);
  }
}
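
// Illustrative example of the backward merge handling above (final branch):
// if an output domain of extent 12 came from merging outer extent 3 with
// inner extent 4, an output index of 7 is un-merged into outer index
// 7 / 4 = 1 and inner index 7 % 4 = 3.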

void IndexCompute::handle(Expr* e) {
  switch (e->getExprType().value()) {
    case (ExprType::Split):
    case (ExprType::Merge):
      break;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "Invalid expr type found in transform traversal.");
  }
  BackwardVisitor::handle(e);
}

// Otherwise warning on runBackward as it hides an overloaded virtual
// using TransformIter::runBackward;
IndexCompute::IndexCompute(
    const TensorDomain* _td,
    std::unordered_map<IterDomain*, Val*> initial_index_map,
    std::unordered_map<IterDomain*, Val*> extent_map,
    std::unordered_set<IterDomain*> zero_domains,
    std::unordered_set<IterDomain*> zero_merged_in,
    const std::vector<bool>& root_contiguity,
    std::unordered_set<IterDomain*> preferred_paths,
    std::unordered_map<IterDomain*, Val*> reference_halo_extent_map)
    : td_(_td),
      index_map_(std::move(initial_index_map)),
      extent_map_(std::move(extent_map)),
      zero_domains_(std::move(zero_domains)),
      zero_merged_in_(std::move(zero_merged_in)),
      preferred_paths_(std::move(preferred_paths)),
      reference_halo_extent_map_(std::move(reference_halo_extent_map)) {
  FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");

  // Make sure we recompute any indices we can that map to a contiguous access
  // in physical memory.
  if (std::any_of(root_contiguity.begin(), root_contiguity.end(), [](bool b) {
        return b;
      })) {
    ContigIDs contig_finder(
        td_->domain(), td_->getMaybeRFactorDomain(), root_contiguity);
    contig_ids = contig_finder.contigIDs();
    root_to_contig_id_ = contig_finder.rootToIndexedID();
    const auto& within_contig = contig_finder.withinContigIDs();
    for (auto contig_id : contig_ids) {
      if (index_map_.find(contig_id) != index_map_.end()) {
        TORCH_INTERNAL_ASSERT(
            within_contig.find(contig_id) != within_contig.end());
        for (auto id : within_contig.at(contig_id)) {
          index_map_.erase(id);
        }
      }
    }
  } else {
    for (auto root_id : td_->getMaybeRFactorDomain()) {
      root_to_contig_id_[root_id] = root_id;
    }
  }
}

void IndexCompute::run() {
  const std::vector<Val*> domain_vals(
      td_->domain().begin(), td_->domain().end());

  traverseFrom(td_->fusion(), domain_vals, false);
}

Val* IndexCompute::getExtent(IterDomain* id) const {
  // Pick from extent_map_ if available. Previously, parallel
  // dimensions were used (e.g., blockDim.x); however, that would result
  // in out-of-bounds errors when the extent of an IterDomain is smaller
  // than the threading dimension.
  if (extent_map_.find(id) != extent_map_.end()) {
    return extent_map_.at(id);
  } else {
    return id->extent();
  }
}

bool IndexCompute::hasZeroMerged(IterDomain* id) const {
  return zero_merged_in_.find(id) != zero_merged_in_.end() || isZero(id);
}

bool IndexCompute::isZero(IterDomain* id) const {
  return zero_domains_.find(id) != zero_domains_.end();
}

IndexCompute IndexCompute::updateIndexCompute(
    const TensorDomain* new_td,
    const std::unordered_map<IterDomain*, IterDomain*>& id_map,
    const std::vector<bool>& root_contiguity,
    const std::unordered_map<IterDomain*, Val*>& reference_halo_extent_map)
    const {
  FUSER_PERF_SCOPE("GpuLower::Lower::updateIndexCompute");

  std::unordered_map<IterDomain*, Val*> updated_index_map;
  std::unordered_map<IterDomain*, Val*> updated_extent_map;
  std::unordered_set<IterDomain*> updated_zero_domains;
  std::unordered_set<IterDomain*> updated_zero_merged_in;

  for (auto id_entry : id_map) {
    IterDomain* prev_id = id_entry.first;
    IterDomain* new_id = id_entry.second;

    if (index_map_.find(prev_id) != index_map_.end()) {
      updated_index_map[new_id] = index_map_.at(prev_id);
    }

    updated_extent_map[new_id] = getExtent(prev_id);

    if (zero_domains_.find(prev_id) != zero_domains_.end()) {
      updated_zero_domains.emplace(new_id);
    }

    if (zero_merged_in_.find(prev_id) != zero_merged_in_.end()) {
      updated_zero_merged_in.emplace(new_id);
    }
  }

  IndexCompute updated_index_compute(
      new_td,
      updated_index_map,
      updated_extent_map,
      updated_zero_domains,
      updated_zero_merged_in,
      root_contiguity,
      {},
      reference_halo_extent_map);
  updated_index_compute.run();

  return updated_index_compute;
}

namespace {
// Map indices down to the leaf domains for applying swizzle
class UpdateLeafIndices : public IterVisitor {
 public:
  UpdateLeafIndices(
      const TensorDomain* td,
      std::unordered_map<IterDomain*, Val*> initial_index_map,
      std::unordered_map<IterDomain*, Val*> extent_map)
      : td_(td),
        index_map_(std::move(initial_index_map)),
        extent_map_(std::move(extent_map)) {
    const std::vector<Val*> domain_vals(
        td_->domain().begin(), td_->domain().end());

    traverseFrom(td_->fusion(), domain_vals, false);
  }

  const std::unordered_map<IterDomain*, Val*>& indexMap() const {
    return index_map_;
  }

  const std::unordered_map<IterDomain*, Val*>& extentMap() const {
    return extent_map_;
  }

 private:
  using IterVisitor::handle;

  void handle(Split* split) override {
    auto in_id = split->in();
    auto outer_id = split->outer();
    auto inner_id = split->inner();

    // Nothing needs to be done when mappings for the output axes
    // already exist.
    if (index_map_.find(outer_id) != index_map_.end()) {
      TORCH_INTERNAL_ASSERT(
          index_map_.find(inner_id) != index_map_.end(),
          "Outer exists but inner not found");
      return;
    }

    auto factor = split->factor();
    index_map_[inner_id] =
        SimplifyingIrBuilder::modExpr(index_map_[in_id], factor);
    extent_map_[inner_id] = factor;
    index_map_[outer_id] =
        SimplifyingIrBuilder::divExpr(index_map_[in_id], factor);
    extent_map_[outer_id] =
        SimplifyingIrBuilder::ceilDivExpr(getExtent(in_id), factor);
  }

  void handle(Merge* merge) override {
    auto out_id = merge->out();
    auto outer_id = merge->outer();
    auto inner_id = merge->inner();

    // Nothing needs to be done when mappings for the output axes
    // already exist.
    if (index_map_.find(out_id) != index_map_.end()) {
      return;
    }

    TORCH_INTERNAL_ASSERT(
        index_map_.find(outer_id) != index_map_.end(), "Outer ID not found");
    TORCH_INTERNAL_ASSERT(
        index_map_.find(inner_id) != index_map_.end(), "Inner ID not found");

    index_map_[out_id] = SimplifyingIrBuilder::mulExpr(
        index_map_[inner_id],
        SimplifyingIrBuilder::mulExpr(
            index_map_[outer_id], getExtent(inner_id)));

    extent_map_[out_id] =
        SimplifyingIrBuilder::mulExpr(getExtent(outer_id), getExtent(inner_id));
  }

  // Return extent_map_[id] if it exists, else return id->extent()
  Val* getExtent(IterDomain* id) {
    if (extent_map_.find(id) != extent_map_.end()) {
      return extent_map_.at(id);
    } else {
      return id->extent();
    }
  }

 private:
  const TensorDomain* td_;
  std::unordered_map<IterDomain*, Val*> index_map_;
  std::unordered_map<IterDomain*, Val*> extent_map_;
};
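
// Illustrative example of the forward split handling above: with an input
// index of 10 and a split factor of 4, UpdateLeafIndices assigns inner index
// 10 % 4 = 2 with extent 4, and outer index 10 / 4 = 2 with extent
// ceilDiv(in_extent, 4).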

// Returns halo-extended extent if id has halo. Otherwise, just
// returns id->extent.
Val* getHaloExtentOfRootAxis(IterDomain* id, Val* normal_extent = nullptr) {
  if (normal_extent == nullptr) {
    normal_extent = id->extent();
  }

  const auto& halo = GpuLower::current()->haloInfo().getRootAxisInfo(id);
  if (halo.hasHalo()) {
    auto halo_extent = SimplifyingIrBuilder::addExpr(
        normal_extent, SimplifyingIrBuilder::create<Int>(halo.width()));
    return halo_extent;
  } else {
    return normal_extent;
  }
}

} // namespace

IndexSwizzle::IndexSwizzle(
    const TensorView* tv,
    std::unordered_map<IterDomain*, Val*> initial_index_map,
    std::unordered_map<IterDomain*, Val*> extent_map,
    std::unordered_set<IterDomain*> zero_domains,
    std::unordered_set<IterDomain*> zero_merged_in)
    : IndexCompute(
          tv->domain(),
          std::move(initial_index_map),
          std::move(extent_map),
          std::move(zero_domains),
          std::move(zero_merged_in),
          std::vector<bool>(tv->getRootDomain().size(), false)),
      tv_(tv),
      swizzle_type_(tv->swizzleType()),
      ids_to_swizzle_(tv->axesToSwizzle()) {}

void IndexSwizzle::run() {
  TORCH_INTERNAL_ASSERT(
      swizzle_type_ == SwizzleType::NoSwizzle ||
          swizzle_type_ == SwizzleType::Transpose,
      "Invalid swizzle type");
  if (swizzle_type_ == SwizzleType::Transpose) {
    // Shifts the second axis by the first axis as ((idx_1 + idx_2) %
    // ext). Alternatively, ((idx_1 - idx_2) & (ext - 1)) would also
    // work if ext is a power of two. Practically, ext should be 32 if
    // the data type of the tensor is float, so the latter approach
    // should also be fine.
    TORCH_INTERNAL_ASSERT(tv_->getMemoryType() == MemoryType::Shared);
    TORCH_INTERNAL_ASSERT(tv_->axesToSwizzle().size() == 2);

    UpdateLeafIndices update_leaves(td_, indexMap(), extentMap());
    index_map_ = update_leaves.indexMap();
    extent_map_ = update_leaves.extentMap();

    IterDomain* id_to_swizzle_i = ids_to_swizzle_.at(0);
    IterDomain* id_to_swizzle_j = ids_to_swizzle_.at(1);

    if (indexMap().find(id_to_swizzle_i) != indexMap().end() &&
        indexMap().find(id_to_swizzle_j) != indexMap().end()) {
      auto idx_to_swizzle_i = indexMap().at(id_to_swizzle_i);
      auto idx_to_swizzle_j = indexMap().at(id_to_swizzle_j);

      auto swizzled_idx = SimplifyingIrBuilder::modExpr(
          SimplifyingIrBuilder::addExpr(idx_to_swizzle_i, idx_to_swizzle_j),
          id_to_swizzle_j->extent());
      index_map_[id_to_swizzle_j] = swizzled_idx;
      swizzled_ids_.insert(id_to_swizzle_j);
      IndexCompute::run();
    }
  }
}
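
// Illustrative example: for a 32x32 shared-memory tile with swizzled axes
// (i, j), element (i, j) is remapped to column (i + j) % 32, so accessing the
// tile along either axis tends to spread requests across distinct
// shared-memory banks.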

void IndexSwizzle::handle(Expr* e) {
  auto out_ids = ir_utils::filterByType<IterDomain>(e->outputs());
  bool needs_update =
      std::any_of(out_ids.begin(), out_ids.end(), [this](IterDomain* id) {
        return swizzled_ids_.find(id) != swizzled_ids_.end();
      });
  if (!needs_update) {
    return;
  }

  IndexCompute::handle(e);
  for (auto input : ir_utils::filterByType<IterDomain>(e->inputs())) {
    swizzled_ids_.insert(input);
  }
}

namespace {

// Used for local and shared index mapping. Returns a map from loops
// to loop indices as well as a set of loops that do not contribute to
// indexing.
std::pair<
    std::unordered_map<kir::ForLoop*, Val*>,
    std::unordered_set<kir::ForLoop*>>
indexMapFromTV(
    const TensorView* tv,
    const std::vector<kir::ForLoop*>& loops,
    kir::ForLoop* alloc_loop,
    bool as_consumer,
    kir::ForLoop* double_buffer_loop = nullptr) {
  bool within_alloc = false;
  if (alloc_loop == nullptr) {
    within_alloc = true;
  }

  const bool is_global = tv->getMemoryType() == MemoryType::Global;
  const bool is_shared = tv->getMemoryType() == MemoryType::Shared;
  const bool is_local = tv->getMemoryType() == MemoryType::Local;

  std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;

  // Check if the current op has an implicit loop implemented
  // within an mma instruction.
  bool within_mma_loops =
      std::any_of(loops.begin(), loops.end(), [](kir::ForLoop* fl) {
        return fl->iter_domain()->isMma();
      });

  // When indexed as a producer, the parallel types of the
  // producer domains may not be the same as those of the loops, but
  // that's still valid parallelization. However, in that case, using
  // the parallel types of the loops to decide replacement of indices
  // with zero isn't valid. That's only valid when there's a matching
  // IterDomain in the producer tensor that has the same parallel
  // type.
  auto find_matching_parallel_domain = [tv](IterDomain* id) -> bool {
    const auto gpu_lower = GpuLower::current();
    auto it = std::find_if(
        tv->domain()->domain().begin(),
        tv->domain()->domain().end(),
        [&](IterDomain* tv_id) {
          // Matching is done using the index and loop maps. See
          // validateParallelize as well.
          return gpu_lower->caIndexMap().areMapped(id, tv_id) ||
              (gpu_lower->caLoopMap().areMapped(id, tv_id) &&
               ir_utils::derivedFromRootCAAxes(tv, tv_id));
        });
    if (it == tv->domain()->domain().end()) {
      return false;
    }

    auto corresponding_domain = *it;
    return corresponding_domain->getParallelType() == id->getParallelType();
  };

  // Track domains that do not contribute to the resulting
  // index. Previously, index->isZeroInt() was used to detect such
  // domains, but that's not a reliable method as we may set an
  // initial index to zero for unswitch.
  std::unordered_set<kir::ForLoop*> zero_loops;

  for (auto loop : loops) {
    Val* idx = nullptr;
    const auto same_parallel_type = as_consumer ||
        find_matching_parallel_domain(loop->iter_domain()) ||
        // Note && TODO:
        // mma swizzled lane_id does not map naturally from producer
        // to consumer, but it should still be detected as the same
        // parallel type. In a follow-up we may want to extend
        // find_matching_parallel_domain to cover this case.
        (within_mma_loops &&
         loop->iter_domain()->getParallelType() == ParallelType::TIDx);
    // See also LoopNestGenerator::pushAlloc.
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (!within_alloc) {
      if ((loop->iter_domain()->isThreadDim() && is_shared) ||
          (loop->iter_domain()->isThread() && is_global)) {
        idx = loop->index();
      } else {
        idx = GpuLower::current()->kernel()->zeroVal();
        zero_loops.insert(loop);
      }
    } else if (
        // For shared-memory tensors, when a domain is parallelized by
        // BID, the index can be replaced with zero as long as the
        // tensor has a matching domain that has the same parallel
        // type. Matching can be omitted when indexed as a consumer
        // since it is always the case. When indexed as a producer, to
        // replace it with zero, the same parallel type of BID must be
        // used by the producer tensor. Thus, since this is a shared
        // memory tensor, when a producer domain is parallelized by
        // BID, there must be a matching consumer domain with the same
        // parallel type, which must be the IterDomain of the
        // loop.
        (loop->iter_domain()->isBlockDim() && is_shared &&
         same_parallel_type) ||
        // Similarly for local memory tensors, zero replacement can
        // only be done when there's a matching domain with the same
        // parallel type
        (loop->iter_domain()->isThread() && is_local && same_parallel_type)) {
      idx = GpuLower::current()->kernel()->zeroVal();
      zero_loops.insert(loop);
    } else {
      idx = loop->index();
    }

    // If the loop is trivial, the loop index can only be the loop
    // start value.
    if (idx == loop->index() && loop->isTrivial()) {
      idx = loop->start();
    }

    if (loop == double_buffer_loop) {
      idx = SimplifyingIrBuilder::addExpr(
          idx, GpuLower::current()->kernel()->oneVal());
    }

    loop_to_ind_map[loop] = idx;

    if (!within_alloc && loop == alloc_loop) {
      within_alloc = true;
    }
  }
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  return {loop_to_ind_map, zero_loops};
}
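
// Illustrative example: for a Local tensor allocated at loop j of serial
// loops [i, j, k], indexMapFromTV maps i and j to zero (recording them in
// zero_loops, since they are outside the allocation scope) and maps k to its
// loop index.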

//! Require "#pragma unroll" on loops that the indexing of Local
//! tensors depends on.
//!
//! \param tv Indexed tensor
//! \param alloc_loop Allocation loop of tv
//! \param loops The current loop structure
//! \param id_map Producer-to-consumer map in case of indexing as producer
void ensureStaticIndexing(
    const TensorView* tv,
    kir::ForLoop* alloc_loop,
    const std::vector<kir::ForLoop*>& loops,
    const std::unordered_map<IterDomain*, IterDomain*>& id_map = {}) {
  if (tv->getMemoryType() != MemoryType::Local) {
    return;
  }

  bool within_alloc = false;
  if (alloc_loop == nullptr) {
    within_alloc = true;
  }

  for (auto loop : loops) {
    if (!within_alloc) {
      if (loop == alloc_loop) {
        within_alloc = true;
      }
      continue;
    }
    IterDomain* loop_id = loop->iter_domain();
    if (loop->vectorize() || loop_id->isThread()) {
      continue;
    }
    // Look for a domain that is mapped with the loop. If mapped in
    // the loop map, the loop index should be used for indexing of the
    // tensor, except for broadcast and reduction domains.
    auto it = std::find_if(
        tv->domain()->domain().begin(),
        tv->domain()->domain().end(),
        [loop_id, &id_map](IterDomain* id) {
          if (id->isBroadcast() || id->isReduction() || id->isStride()) {
            return false;
          }
          auto id_replacement = id_map.find(id);
          if (id_replacement != id_map.end()) {
            id = id_replacement->second;
          }
          return GpuLower::current()->caLoopMap().areMapped(loop_id, id);
        });
    if (it != tv->domain()->domain().end()) {
      loop->requireUnroll();
    }
  }
}

// Map everything we can from reference to provided tv using the provided
// compute at map. If root_only is true, only root domains are included.
// We can't simply try to use the provided tv root domains and
// map those to the reference as the provided tv may have root domains that
// don't exist in reference. This can happen when the provided tv is from
// before a view, but all the loops are generated from TVs generated after the
// view operation.
std::unordered_map<IterDomain*, IterDomain*> indexMapReferenceTo(
    const TensorView* tv,
    const ComputeAtMap& ca_map,
    const std::unordered_map<IterDomain*, IterDomain*>&
        reference_concrete_to_id_map,
    bool root_only = false) {
  std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_producer;

  auto gen_map = [&](const auto& pids) {
    for (auto p_id : pids) {
      auto concrete_id = ca_map.getConcreteMappedID(p_id);
      auto ref_id_it = reference_concrete_to_id_map.find(concrete_id);
      if (ref_id_it != reference_concrete_to_id_map.end()) {
        index_map_ref_to_producer[ref_id_it->second] = p_id;
      }
    }
  };

  if (root_only) {
    gen_map(tv->getRootDomain());
  } else {
    auto all_pid_vals = DependencyCheck::getAllValsBetween(
        {tv->getRootDomain().begin(), tv->getRootDomain().end()},
        {tv->domain()->domain().begin(), tv->domain()->domain().end()});
    auto all_pids = ir_utils::filterByType<IterDomain>(all_pid_vals);
    gen_map(all_pids);
  }

  return index_map_ref_to_producer;
}

Val* hoistConsumerIndex(
    IterDomain* consumer_root_id,
    const TensorView* consumer_tv,
    const IndexCompute& consumer_indexing,
    TensorDomain* ref_td,
    const IndexCompute& ref_indexing,
    const std::vector<kir::ForLoop*>& loops,
    Val* index) {
  // If index has no defining expression, there's nothing to hoist
  if (disableIndexHoisting() || index->definition() == nullptr) {
    return index;
  }

  // The old swizzle interface, which should be deprecated, is not
  // supported.
  if (consumer_tv->swizzleType() != SwizzleType::NoSwizzle) {
    return index;
  }

  // Find the true indexed domain, which can be a merged contiguous domain.
  auto indexed_consumer_id_it =
      consumer_indexing.rootToContigID().find(consumer_root_id);
  TORCH_INTERNAL_ASSERT(
      indexed_consumer_id_it != consumer_indexing.rootToContigID().end(),
      "Consumer indexed ID not found: ",
      consumer_root_id->toString());
  auto indexed_consumer_id = indexed_consumer_id_it->second;

  // Insert the index into the common index map. A previously inserted
  // val can be returned.
  auto common_index = GpuLower::current()
                          ->commonIndexMap()
                          .insert(
                              indexed_consumer_id,
                              consumer_tv->domain(),
                              ref_td,
                              ref_indexing.indexMap(),
                              loops,
                              index)
                          .first;

  return common_index;
}
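
// Illustrative example: when two tensors indexed in the same loop nest
// produce the same index expression for the indexed domain, the second
// insert() call above returns the Val recorded by the first one, so the
// hoisted expression can be computed once and reused.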

std::unordered_map<IterDomain*, IterDomain*> invertOneToOneMap(
    const std::unordered_map<IterDomain*, IterDomain*>& map) {
  std::unordered_map<IterDomain*, IterDomain*> inverted;
  for (const auto& kv : map) {
    bool inserted = inverted.emplace(kv.second, kv.first).second;
    TORCH_INTERNAL_ASSERT(
        inserted,
        "Multiple mappings to the same value detected: ",
        kv.second->toString());
  }
  return inverted;
}

Val* hoistProducerIndex(
    IterDomain* producer_root_id,
    const TensorView* producer_tv,
    const IndexCompute& producer_indexing,
    const TensorView* consumer_tv,
    const std::unordered_map<IterDomain*, IterDomain*>& p2c_map,
    TensorDomain* ref_td,
    const IndexCompute& ref_indexing,
    const std::vector<kir::ForLoop*>& loops,
    Val* index) {
  // If index has no defining expression, there's nothing to hoist
  if (disableIndexHoisting() || index->definition() == nullptr) {
    return index;
  }

  // The old swizzle interface, which should be deprecated, is not
  // supported.
  if (producer_tv->swizzleType() != SwizzleType::NoSwizzle) {
    return index;
  }

  auto indexed_producer_id_it =
      producer_indexing.rootToContigID().find(producer_root_id);
  TORCH_INTERNAL_ASSERT(
      indexed_producer_id_it != producer_indexing.rootToContigID().end(),
      "Producer indexed ID not found: ",
      producer_root_id->toString());
  auto indexed_producer_id = indexed_producer_id_it->second;

  // Use the corresponding consumer domain to find matching
  // for-loops. Note that there's no CA mapping with the producer
  // domains as the producer TensorDomain is a temporary replay
  // domain.

  auto indexed_consumer_id_it = p2c_map.find(indexed_producer_id);

  // There can be no corresponding consumer ID. For example, consider:
  //  consumer: [b1, i2, i3]
  //  producer: [i2, i3].
  // Suppose the consumer is transformed as:
  //  consumer: [(b1*i2)*i3]
  // Then the producer would be transformed when indexed:
  //  producer: [i2*i3]
  // Assuming i2 and i3 are contiguous, the producer indexing is done
  // with the merged i2*i3 domain, but there's no domain in the
  // consumer that maps with the producer indexed domain.
  // It seems non-trivial to support patterns like this. Skip for now.
  if (indexed_consumer_id_it == p2c_map.end()) {
    return index;
  }

  IterDomain* indexed_consumer_id = indexed_consumer_id_it->second;

  auto common_index = GpuLower::current()
                          ->commonIndexMap()
                          .insert(
                              indexed_consumer_id,
                              consumer_tv->domain(),
                              ref_td,
                              ref_indexing.indexMap(),
                              loops,
                              index)
                          .first;

  return common_index;
}

} // namespace

std::vector<Val*> Index::getGlobalProducerStridedIndices(
    TensorView* producer_tv,
    const TensorView* consumer_tv,
    const std::vector<kir::ForLoop*>& loops) {
  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex");
  const auto gpu_lower = GpuLower::current();

  // Get a reference tensor replayed as existing loop structure
  auto reference = IndexReferenceReplay::getReference(loops);
  auto reference_domain = reference.domain;
  auto reference_id_map = reference.concrete_to_id;

  // Replay producer to look like consumer so we can index on producer since
  // our loop nests look like consumer
  auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
  auto producerAsC =
      TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map)
          .first;

  // Make the producer_tv look like consumer while performing indexing math
  ir_utils::TVDomainGuard domain_guard(producer_tv, producerAsC);

  // Map everything we can from reference to producer using compute at index
  // map. Use consumer as a proxy between producer and the generated reference.
  std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_producer;

  // This replay has to be consistent with compute at index map.
  BestEffortReplay replay_producer_as_consumer(
      producer_tv->domain()->domain(),
      consumer_tv->domain()->domain(),
      pairwise_map.mapConsumerToProducer(
          consumer_tv->domain(), producer_tv->domain()));

  const auto& c2p_map = replay_producer_as_consumer.getReplay();
  const auto p2c_map = invertOneToOneMap(c2p_map);
  {
    std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_consumer =
        indexMapReferenceTo(
            consumer_tv, gpu_lower->caIndexMap(), reference_id_map);

    for (auto entry : index_map_ref_to_consumer) {
      auto r_id = entry.first;
      auto c_id = entry.second;
      auto c2p_it = c2p_map.find(c_id);
      if (c2p_it != c2p_map.end()) {
        auto p_id = c2p_it->second;
        index_map_ref_to_producer[r_id] = p_id;
      }
    }
  }

  kir::ForLoop* db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
      consumer_tv, loops, true);

  // Index into the reference tensor. Reference indexing will handle vectorized
  // dims where index should be set to 0
  auto ref_compute = getReferenceIndexing(loops, reference_domain, db_loop);

  // Forward vectorized IDs to index into producer correctly
  // We want p_id to be vectorized like consumer just for the indexing, then we
  // need to switch it back later. Store previous state here when changing. We
  // need to do this as replaying producer as consumer can use replay best
  // effort which means some domains may be producer's original domains.
  std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
  for (auto entry : index_map_ref_to_producer) {
    auto ref_id = entry.first;
    auto p_id = entry.second;
    if (ref_id->getParallelType() == ParallelType::Vectorize) {
      p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
      p_id->parallelize(ParallelType::Vectorize);
    } else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) {
      p_id->parallelize(ParallelType::MisalignedVectorize);
    }
  }

  // Adds halo info mappings for the reference
  updateHaloInfoForReference(reference, consumer_tv);

  const auto reference_halo_extent_map =
      getReferenceHaloExtentMap(reference, index_map_ref_to_producer);

  // Index into producer using reference indexing
  auto producer_indexing = ref_compute.updateIndexCompute(
      producer_tv->domain(),
      index_map_ref_to_producer,
      producer_tv->domain()->contiguity(),
      reference_halo_extent_map);

  // Revert p_ids
  for (auto entry : p_id_backup) {
    entry.first->parallelize(entry.second);
  }

  // Indices should now be mapped onto IterDomains in producer, so just grab
  // and use them.
  auto root_dom = producer_tv->getMaybeRFactorDomain();

  // TODO: Abstract stride logic to reuse with consumer indexing
  std::vector<Val*> strides(root_dom.size(), nullptr);
  {
    int stride_i = 0;
    for (const auto i : c10::irange(root_dom.size())) {
      if (root_dom[i]->isReduction() ||
          root_dom[i]->getIterType() == IterType::BroadcastWithoutStride) {
        strides[i] = GpuLower::current()->kernel()->oneVal();
        continue;
      }
      std::stringstream ss;
      ss << "T" << producer_tv->name() << ".stride[" << stride_i++ << "]";
      strides[i] =
          SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int);
    }
  }

  TORCH_INTERNAL_ASSERT(
      root_dom.size() == producer_tv->domain()->contiguity().size());
  Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal();
  for (const auto i : c10::irange(root_dom.size())) {
    auto dim = root_dom.size() - i - 1;
    if (root_dom[dim]->isReduction()) {
      continue;
    }
    if (root_dom[dim]->getIterType() == IterType::BroadcastWithoutStride) {
      continue;
    }

    Val* root_ind = nullptr;
    if (producer_indexing.indexMap().find(root_dom[dim]) !=
        producer_indexing.indexMap().end()) {
      root_ind = producer_indexing.indexMap().at(root_dom[dim]);
    } else if (root_dom[dim]->getIterType() == IterType::BroadcastWithStride) {
      root_ind = GpuLower::current()->kernel()->zeroVal();
    }

    TORCH_INTERNAL_ASSERT(
        root_ind != nullptr,
        "Couldn't find root mapping for TV",
        producer_tv->name(),
        " dim: ",
        i,
        " id: ",
        root_dom[dim]);

    if (producer_tv->domain()->contiguity()[dim]) {
      // If contiguous, use the stored stride, which may be the previous
      // dimension's stride * the previous dimension's size
      strides[dim] = cur_contig_stride;
      // Prepare for the next dimension, which may also be contiguous; multiply
      // by the extent of this dimension
      auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
      cur_contig_stride =
          SimplifyingIrBuilder::mulExpr(cur_contig_stride, root_dim_extent);
    } else {
      // If non-contiguous dimension, keep local stride information, set cur
      // stride to local stride * local raw extent
      auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
      cur_contig_stride =
          SimplifyingIrBuilder::mulExpr(strides[dim], root_dim_extent);
    }
  }

  auto vectorize_shift =
      loops.empty() ? nullptr : loops.back()->vectorize_shift();

  // Global striding
  std::vector<Val*> strided_inds(
      root_dom.size(), GpuLower::current()->kernel()->zeroVal());
  for (const auto i : c10::irange(root_dom.size())) {
    // If the domain is derived from a trivial reduction, no indexing
    // to create.
    if (root_dom[i]->isReduction() ||
        root_dom[i]->getIterType() == IterType::BroadcastWithoutStride ||
        root_dom[i]->getIterType() == IterType::BroadcastWithStride ||
        gpu_lower->trivialReductionInfo().isDerived(root_dom[i])) {
      continue;
    }

    TORCH_INTERNAL_ASSERT(
        producer_indexing.indexMap().find(root_dom[i]) !=
            producer_indexing.indexMap().end(),
        "Couldn't find root mapping for TV",
        producer_tv->name(),
        " dim: ",
        i,
        " id: ",
        root_dom[i]->toString());

    auto root_ind = producer_indexing.indexMap().at(root_dom[i]);

    // index hoist must be done before the adjustments for halo
    root_ind = hoistProducerIndex(
        root_dom[i],
        producer_tv,
        producer_indexing,
        consumer_tv,
        p2c_map,
        reference.domain,
        ref_compute,
        loops,
        root_ind);

    root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv);

    root_ind = getProducerIndexWithGather(
        root_ind,
        i,
        producer_tv,
        consumer_tv,
        reference_id_map,
        ref_compute.indexMap());

    root_ind = getProducerIndexWithPartialSplit(
        root_ind, root_dom[i], producer_tv, consumer_tv);

    if (root_ind->isZeroInt()) {
      continue;
    } else {
      auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]);
      if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
        strided_inds[i] =
            SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift);
      } else {
        strided_inds[i] = strided_ind;
      }
    }
  }

  return strided_inds;
}
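
// Illustrative example of the global striding above: for a 2-D producer
// named T0 with root indices (i0, i1), the returned entries are
// i0 * T0.stride[0] and i1 * T0.stride[1]; when the root domains are marked
// contiguous, the strides are instead folded into products of the inner
// (halo-extended) extents.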
|
|
|
|
// Producer index for either shared or local memory
|
|
std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
|
|
TensorView* producer_tv,
|
|
const TensorView* consumer_tv,
|
|
const std::vector<kir::ForLoop*>& loops) {
|
|
const auto gpu_lower = GpuLower::current();
|
|
|
|
// Get a reference tensor replayed as existing loop structure
|
|
auto reference = IndexReferenceReplay::getReference(loops);
|
|
auto reference_domain = reference.domain;
|
|
auto reference_id_map = reference.concrete_to_id;
|
|
|
|
// Replay producer to look like consumer so we can index on producer since our
|
|
// loop nests look like consumer
|
|
auto pairwise_map = PairwiseRootDomainMap(producer_tv, consumer_tv);
|
|
auto producer_replayed_as_consumer =
|
|
TransformReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map)
|
|
.first;
|
|
|
|
ir_utils::TVDomainGuard domain_guard(
|
|
producer_tv, producer_replayed_as_consumer);
|
|
|
|
// This map has forwarded broadcast axes, it should only be used to compute
|
|
// the allocation position of the producer, and to figure out which producer
|
|
// indices are mapped to consumer trivial reductions.
|
|
std::unordered_map<IterDomain*, IterDomain*> p2c_alloc_map;
|
|
|
|
// We want to play producer as consumer instead of the other way around
|
|
// since consumer may have some broadcasted axes producer doesn't have
|
|
// merged into loops producer may use. If we did consumer as producer we
|
|
// wouldn't have this information in the mapping.
|
|
auto replay_PasC =
|
|
BestEffortReplay::replayPasC(producer_tv, consumer_tv, -1, pairwise_map);
|
|
|
|
const auto& c2p_map = replay_PasC.getReplay();
|
|
const auto p2c_map = invertOneToOneMap(c2p_map);
|
|
|
|
// Grab consumer domain entries and reverse replay map. TODO: Maybe
|
|
// TransformReplay::replayPasC could return this map
|
|
for (auto id : consumer_tv->domain()->domain()) {
|
|
auto c2p_it = c2p_map.find(id);
|
|
if (c2p_it != c2p_map.end()) {
|
|
auto c_id = c2p_it->first;
|
|
auto p_id = c2p_it->second;
|
|
p2c_alloc_map[p_id] = c_id;
|
|
}
|
|
}
|
|
|
|
kir::ForLoop* consumer_db_loop =
|
|
gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
|
|
consumer_tv, loops, true);
|
|
|
|
// Find allocation point of producer relative to loop nests. P2C map is
|
|
// required because producer was replayed as consumer, so we can't use the
|
|
// regular compute at maps to line up its iter domains with the for loops.
|
|
auto alloc_info =
|
|
loop_utils::getAllocInformation(producer_tv, loops, p2c_alloc_map, true);
|
|
std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
|
|
std::unordered_set<kir::ForLoop*> zero_loops;
|
|
std::tie(loop_to_ind_map, zero_loops) = indexMapFromTV(
|
|
producer_tv, loops, alloc_info.init_for_loop, false, consumer_db_loop);
|
|
|
|
ensureStaticIndexing(
|
|
producer_tv, alloc_info.init_for_loop, loops, p2c_alloc_map);
|
|
|
|
// Map loop nests to indicies, zeroing out those not used due to locality of
|
|
// memory
|
|
std::unordered_map<IterDomain*, Val*> ref_id_to_ind_map;
|
|
// Track which domains are not used
|
|
std::unordered_set<IterDomain*> ref_zero_domains;
|
|
|
|
// Due to rfactor/initialization reference_domain may be bigger than loop nest
|
|
// structure, ignore IterDomains that aren't present in the loop nest when
|
|
// indexing reference.
|
|
TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
|
|
for (const auto loop_i : c10::irange(loops.size())) {
|
|
auto ref_axis = reference_domain->axis(loop_i);
|
|
ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]];
|
|
if (zero_loops.count(loops[loop_i]) > 0) {
|
|
ref_zero_domains.insert(ref_axis);
|
|
}
|
|
}
|
|
|
|
// Map everything we can from reference to producer using compute at index
|
|
// map. All producer id's don't exist in the compute at map. The rfactor axes
|
|
// all may be, but since I haven't proven that to be the case, going to do a
|
|
// more conservative approach, which is to use the consumer as a proxy between
|
|
// producer to reference.
|
|
std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_producer;
|
|
{
|
|
// This replay has to be consistent with compute at index map.
|
|
BestEffortReplay replay_producer_as_consumer(
|
|
producer_tv->domain()->domain(),
|
|
consumer_tv->domain()->domain(),
|
|
pairwise_map.mapConsumerToProducer(
|
|
consumer_tv->domain(), producer_tv->domain()));
|
|
|
|
const auto& c2p_map = replay_producer_as_consumer.getReplay();
|
|
|
|
std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_consumer =
|
|
indexMapReferenceTo(
|
|
consumer_tv, gpu_lower->caIndexMap(), reference_id_map);
|
|
|
|
for (auto entry : index_map_ref_to_consumer) {
|
|
auto r_id = entry.first;
|
|
auto c_id = entry.second;
|
|
auto c2p_it = c2p_map.find(c_id);
|
|
if (c2p_it != c2p_map.end()) {
|
|
auto p_id = c2p_it->second;
|
|
index_map_ref_to_producer[r_id] = p_id;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Grab roots that map into producer and save them into the preferred roots
|
|
// set for references indexing
|
|
std::unordered_set<IterDomain*> preferred_roots;
|
|
for (auto entry : index_map_ref_to_producer) {
|
|
if (entry.second->isBroadcast() || entry.second->isReduction() ||
|
|
entry.second->isStride()) {
|
|
continue;
|
|
}
|
|
preferred_roots.emplace(entry.first);
|
|
}
|
|
|
|
// Make sure propagation of indexing while mixing with 0 indicies we propagate
|
|
// in a way that the producer will be able to see what's going on (propagating
|
|
// into common roots of reference and producer).
|
|
auto preferred_paths = buildPreferredPaths(reference_domain, preferred_roots);
|
|
|
|
// Index into the reference tensor
|
|
auto ref_compute = getReferenceIndexing(
|
|
loops,
|
|
reference_domain,
|
|
ref_id_to_ind_map,
|
|
ref_zero_domains,
|
|
preferred_paths);
|
|
|
|
// Forward vectorized IDs to index into producer correctly
|
|
// We want p_id to be vectorized like consumer just for the indexing, then we
|
|
// need to switch it back later. Store previous state here when changing. We
|
|
// need to do this as replaying producer as consumer can use replay best
|
|
// effort which means some domains may be the originals.
|
|
std::vector<std::pair<IterDomain*, ParallelType>> p_id_backup;
|
|
for (auto entry : index_map_ref_to_producer) {
|
|
auto ref_id = entry.first;
|
|
auto p_id = entry.second;
|
|
if (ref_id->getParallelType() == ParallelType::Vectorize) {
|
|
p_id_backup.emplace_back(std::make_pair(p_id, p_id->getParallelType()));
|
|
p_id->parallelize(ParallelType::Vectorize);
|
|
} else if (ref_id->getParallelType() == ParallelType::MisalignedVectorize) {
|
|
p_id->parallelize(ParallelType::MisalignedVectorize);
|
|
}
|
|
}
|
|
|
|
// Index into producer using reference indexing
|
|
|
|
// Adds halo info mappings for the reference
|
|
updateHaloInfoForReference(reference, consumer_tv);
|
|
|
|
const auto reference_halo_extent_map =
|
|
getReferenceHaloExtentMap(reference, index_map_ref_to_producer);
|
|
|
|
auto producer_indexing = ref_compute.updateIndexCompute(
|
|
producer_tv->domain(),
|
|
index_map_ref_to_producer,
|
|
producer_tv->domain()->contiguity(),
|
|
reference_halo_extent_map);
|
|
|
|
// Revert p_ids
|
|
for (auto entry : p_id_backup) {
|
|
entry.first->parallelize(entry.second);
|
|
}
|
|
|
|
IndexSwizzle index_swizzle(
|
|
producer_tv,
|
|
producer_indexing.indexMap(),
|
|
producer_indexing.extentMap(),
|
|
producer_indexing.zeroDomains(),
|
|
producer_indexing.zeroMergedIn());
|
|
|
|
index_swizzle.run();
|
|
|
|
const auto& index_map = index_swizzle.indexMap();
|
|
const auto& extent_map = producer_indexing.extentMap();
|
|
const auto& zero_domain_map = producer_indexing.zeroDomains();
|
|
// Indices should now be mapped onto IterDomains in producer, so just grab
|
|
// and use them.
|
|
auto root_dom = producer_tv->getMaybeRFactorDomain();
|
|
|
|
// Figure out which root axes we don't need to index
|
|
std::unordered_set<IterDomain*> skip_indexing;
|
|
|
|
for (auto root_id : root_dom) {
|
|
// Already taken care of because we can detect no indexing required
|
|
if (root_id->isBroadcast() || root_id->isReduction() ||
|
|
gpu_lower->trivialReductionInfo().isDerived(root_id) ||
|
|
root_id->isStride()) {
|
|
skip_indexing.insert(root_id);
|
|
continue;
|
|
}
|
|
|
|
// Already an entry for this root domain, continue
|
|
if (index_map.find(root_id) != index_map.end()) {
|
|
continue;
|
|
}
|
|
|
|
    // Maps to the consumer's trivial reduction, don't index
|
|
if (p2c_alloc_map.find(root_id) != p2c_alloc_map.end() &&
|
|
gpu_lower->trivialReductionInfo().isDerived(
|
|
p2c_alloc_map.at(root_id))) {
|
|
skip_indexing.emplace(root_id);
|
|
}
|
|
}
|
|
|
|
std::vector<Val*> strided_inds(
|
|
root_dom.size(), GpuLower::current()->kernel()->zeroVal());
|
|
for (const auto i : c10::irange(root_dom.size())) {
|
|
if (skip_indexing.count(root_dom[i])) {
|
|
continue;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
index_map.find(root_dom[i]) != index_map.end(),
|
|
"Couldn't find root mapping for TV",
|
|
producer_tv->name(),
|
|
" dim: ",
|
|
i,
|
|
" id: ",
|
|
root_dom[i]->toString());
|
|
|
|
auto root_ind_i = index_map.at(root_dom[i]);
|
|
|
|
// index hoist must be done before the adjustments for halo
|
|
root_ind_i = hoistProducerIndex(
|
|
root_dom[i],
|
|
producer_tv,
|
|
producer_indexing,
|
|
consumer_tv,
|
|
c2p_map,
|
|
reference.domain,
|
|
ref_compute,
|
|
loops,
|
|
root_ind_i);
|
|
|
|
root_ind_i =
|
|
getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv);
|
|
|
|
root_ind_i = getProducerIndexWithGather(
|
|
root_ind_i,
|
|
i,
|
|
producer_tv,
|
|
consumer_tv,
|
|
reference_id_map,
|
|
ref_compute.indexMap());
|
|
|
|
root_ind_i = getProducerIndexWithPartialSplit(
|
|
root_ind_i, root_dom[i], producer_tv, consumer_tv);
|
|
|
|
if (root_ind_i->isZeroInt()) {
|
|
continue;
|
|
}
|
|
|
|
// Compute striding for this index.
|
|
Val* stride = nullptr;
|
|
for (const auto j : c10::irange(i + 1, root_dom.size())) {
|
|
if (skip_indexing.count(root_dom[j])) {
|
|
continue;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
index_map.find(root_dom[j]) != index_map.end(),
|
|
"Couldn't find root mapping for TV",
|
|
consumer_tv->name(),
|
|
" dim: ",
|
|
i,
|
|
" id: ",
|
|
root_dom[i]);
|
|
|
|
auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end()
|
|
? root_dom[j]->extent()
|
|
: extent_map.at(root_dom[j]);
|
|
|
|
root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j);
|
|
|
|
if (zero_domain_map.count(root_dom[j]) == 0) {
|
|
if (stride == nullptr) {
|
|
stride = root_ext_j;
|
|
} else {
|
|
stride = SimplifyingIrBuilder::mulExpr(stride, root_ext_j);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (stride != nullptr) {
|
|
strided_inds[i] = SimplifyingIrBuilder::mulExpr(root_ind_i, stride);
|
|
} else {
|
|
strided_inds[i] = root_ind_i;
|
|
}
|
|
}
|
|
|
|
if (producer_tv->isDoubleBuffered()) {
|
|
auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
|
|
producer_tv, loops, true);
|
|
if (db_loop != nullptr) {
|
|
auto loop_index =
|
|
db_loop->isTrivial() ? db_loop->start() : db_loop->index();
|
|
auto db_switch_index = SimplifyingIrBuilder::modExpr(
|
|
loop_index, SimplifyingIrBuilder::create<Int>(2));
|
|
auto original_alloc_size =
|
|
gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv);
|
|
auto db_strided_index =
|
|
SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size);
|
|
strided_inds.push_back(db_strided_index);
|
|
}
|
|
}
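  // Illustrative sketch, not generated code: for a double-buffered producer
  // the allocation is doubled, and the extra index appended above selects
  // which half is read. Assuming a main-loop index i and an original
  // allocation size S, the appended term is roughly
  //
  //   (i % 2) * S
  //
  // so even and odd iterations alternate between the two halves.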
return strided_inds;
|
|
}
|
|
|
|
std::vector<Val*> Index::getGlobalConsumerStridedIndices(
|
|
const TensorView* consumer_tv,
|
|
const std::vector<kir::ForLoop*>& loops) {
|
|
FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
|
|
const auto gpu_lower = GpuLower::current();
|
|
|
|
// Get a reference tensor replayed as existing loop structure
|
|
auto reference = IndexReferenceReplay::getReference(loops);
|
|
auto reference_domain = reference.domain;
|
|
auto reference_id_map = reference.concrete_to_id;
|
|
|
|
// Map everything we can from reference to consumer using compute at index
|
|
// map.
|
|
std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_consumer =
|
|
indexMapReferenceTo(
|
|
consumer_tv, gpu_lower->caIndexMap(), reference_id_map);
|
|
|
|
// Index into the reference tensor. Reference indexing will handle vectorized
|
|
// dims where index should be set to 0
|
|
auto ref_compute = getReferenceIndexing(loops, reference_domain);
|
|
|
|
// Index into consumer using reference indexing
|
|
|
|
// Adds halo info mappings for the reference
|
|
updateHaloInfoForReference(reference, consumer_tv);
|
|
|
|
const auto reference_halo_extent_map =
|
|
getReferenceHaloExtentMap(reference, index_map_ref_to_consumer);
|
|
|
|
auto consumer_indexing = ref_compute.updateIndexCompute(
|
|
consumer_tv->domain(),
|
|
index_map_ref_to_consumer,
|
|
consumer_tv->domain()->contiguity(),
|
|
reference_halo_extent_map);
|
|
|
|
// Indices should now be mapped onto IterDomains in consumer, so just grab
|
|
// and use them.
|
|
auto root_dom = consumer_tv->getMaybeRFactorDomain();
|
|
|
|
// TODO: Abstract stride logic to reuse with producer indexing
|
|
std::vector<Val*> strides(
|
|
root_dom.size(), GpuLower::current()->kernel()->oneVal());
|
|
{
|
|
int stride_i = 0;
|
|
for (const auto i : c10::irange(root_dom.size())) {
|
|
if (root_dom[i]->isReduction() ||
|
|
root_dom[i]->getIterType() == IterType::BroadcastWithoutStride ||
|
|
root_dom[i]->isStride()) {
|
|
strides[i] = GpuLower::current()->kernel()->oneVal();
|
|
continue;
|
|
}
|
|
std::stringstream ss;
|
|
ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]";
|
|
strides[i] =
|
|
SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int);
|
|
}
|
|
}
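  // For illustration only: a global consumer named T2 (a hypothetical name)
  // with a 3-D root domain and no reduction/broadcast/stride axes would get
  // the named scalars T2.stride[0], T2.stride[1] and T2.stride[2] here, which
  // are expected to be bound to the actual tensor strides at launch time.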
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
root_dom.size() == consumer_tv->domain()->contiguity().size());
|
|
Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal();
|
|
for (const auto i : c10::irange(root_dom.size())) {
|
|
auto dim = root_dom.size() - i - 1;
|
|
if (root_dom[dim]->isReduction() || root_dom[dim]->isStride()) {
|
|
continue;
|
|
}
|
|
if (root_dom[dim]->getIterType() == IterType::BroadcastWithoutStride) {
|
|
continue;
|
|
}
|
|
|
|
Val* root_ind = nullptr;
|
|
if (consumer_indexing.indexMap().find(root_dom[dim]) !=
|
|
consumer_indexing.indexMap().end()) {
|
|
root_ind = consumer_indexing.indexMap().at(root_dom[dim]);
|
|
} else if (root_dom[dim]->getIterType() == IterType::BroadcastWithStride) {
|
|
root_ind = GpuLower::current()->kernel()->zeroVal();
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
root_ind != nullptr,
|
|
"Couldn't find root mapping for TV",
|
|
consumer_tv->name(),
|
|
" dim: ",
|
|
i,
|
|
" id: ",
|
|
root_dom[dim]);
|
|
|
|
if (consumer_tv->domain()->contiguity()[dim]) {
|
|
      // If contiguous, use the stored stride, which may be the previous
      // dimension's stride * the previous dimension's extent
|
|
strides[dim] = cur_contig_stride;
|
|
// Prepare for the next dimension which may also be contiguous, multiply
|
|
// by extent of this dimension
|
|
auto root_dim_extent = getHaloExtentOfRootAxis(root_dom[dim]);
|
|
cur_contig_stride =
|
|
SimplifyingIrBuilder::mulExpr(cur_contig_stride, root_dim_extent);
|
|
} else {
|
|
      // If a non-contiguous dimension, keep the local stride information; set
      // cur_contig_stride to local stride * local raw extent
|
|
cur_contig_stride = SimplifyingIrBuilder::mulExpr(
|
|
strides[dim], getHaloExtentOfRootAxis(root_dom[dim]));
|
|
}
|
|
}
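  // A minimal example of the loop above, not generated code: for a fully
  // contiguous consumer with root extents [N, H, W] and no halo, the strides
  // computed here collapse to [H * W, W, 1], i.e. the innermost dimension
  // gets stride one and each outer dimension multiplies in the inner extents.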
|
|
|
|
auto vectorize_shift =
|
|
loops.empty() ? nullptr : loops.back()->vectorize_shift();
|
|
|
|
// Global striding
|
|
std::vector<Val*> strided_inds(
|
|
root_dom.size(), GpuLower::current()->kernel()->zeroVal());
|
|
for (const auto i : c10::irange(root_dom.size())) {
|
|
// See a comment in indexing to root domains in getGlobalProducerIndex.
|
|
if (root_dom[i]->isReduction() ||
|
|
root_dom[i]->getIterType() == IterType::BroadcastWithoutStride ||
|
|
root_dom[i]->getIterType() == IterType::BroadcastWithStride ||
|
|
gpu_lower->trivialReductionInfo().isDerived(root_dom[i]) ||
|
|
root_dom[i]->isStride()) {
|
|
continue;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
consumer_indexing.indexMap().find(root_dom[i]) !=
|
|
consumer_indexing.indexMap().end(),
|
|
"Couldn't find root mapping for TV",
|
|
consumer_tv->name(),
|
|
" dim: ",
|
|
i,
|
|
" id: ",
|
|
root_dom[i]->toString());
|
|
|
|
auto root_ind = consumer_indexing.indexMap().at(root_dom[i]);
|
|
|
|
// index hoist must be done before the adjustments for halo
|
|
root_ind = hoistConsumerIndex(
|
|
root_dom[i],
|
|
consumer_tv,
|
|
consumer_indexing,
|
|
reference.domain,
|
|
ref_compute,
|
|
loops,
|
|
root_ind);
|
|
|
|
root_ind = SimplifyingIrBuilder::addExpr(
|
|
root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i]));
|
|
|
|
if (root_ind->isZeroInt()) {
|
|
continue;
|
|
} else {
|
|
auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]);
|
|
if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
|
|
strided_inds[i] =
|
|
SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift);
|
|
} else {
|
|
strided_inds[i] = strided_ind;
|
|
}
|
|
}
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size());
|
|
|
|
return strided_inds;
|
|
}
|
|
|
|
// Consumer index for either shared or local memory
|
|
std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
|
|
const TensorView* consumer_tv,
|
|
const std::vector<kir::ForLoop*>& loops) {
|
|
const auto gpu_lower = GpuLower::current();
|
|
|
|
// Get a reference tensor replayed as existing loop structure
|
|
auto reference = IndexReferenceReplay::getReference(loops);
|
|
auto reference_domain = reference.domain;
|
|
auto reference_id_map = reference.concrete_to_id;
|
|
|
|
auto alloc_info = loop_utils::getAllocInformation(consumer_tv, loops);
|
|
std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
|
|
std::unordered_set<kir::ForLoop*> zero_loops;
|
|
std::tie(loop_to_ind_map, zero_loops) =
|
|
indexMapFromTV(consumer_tv, loops, alloc_info.init_for_loop, true);
|
|
|
|
ensureStaticIndexing(consumer_tv, alloc_info.init_for_loop, loops);
|
|
|
|
  // Map loop nests to indices, zeroing out those not used due to locality of
  // memory
  std::unordered_map<IterDomain*, Val*> ref_id_to_ind_map;
  std::unordered_set<IterDomain*> ref_zero_domains;
|
|
|
|
// Due to rfactor/initialization reference_domain may be bigger than loop nest
|
|
// structure, ignore IterDomains that aren't present in the loop nest when
|
|
// indexing reference.
|
|
TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
|
|
for (const auto loop_i : c10::irange(loops.size())) {
|
|
auto ref_axis = reference_domain->axis(loop_i);
|
|
ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loops[loop_i]];
|
|
if (zero_loops.count(loops[loop_i]) > 0) {
|
|
ref_zero_domains.insert(ref_axis);
|
|
}
|
|
}
|
|
|
|
// Map everything we can from reference to consumer using compute at index
|
|
// map.
|
|
std::unordered_map<IterDomain*, IterDomain*> index_map_ref_to_consumer =
|
|
indexMapReferenceTo(
|
|
consumer_tv, gpu_lower->caIndexMap(), reference_id_map);
|
|
|
|
// Grab roots that map into consumer and save them into the preferred roots
|
|
// set for references indexing
|
|
std::unordered_set<IterDomain*> preferred_roots;
|
|
for (auto entry : index_map_ref_to_consumer) {
|
|
if (entry.second->isBroadcast() || entry.second->isReduction() ||
|
|
entry.second->isStride()) {
|
|
continue;
|
|
}
|
|
preferred_roots.emplace(entry.first);
|
|
}
|
|
|
|
  // Make sure that, when mixing propagated indices with zero indices, we
  // propagate in a way that the consumer will be able to see what's going on.
|
|
auto preferred_paths = buildPreferredPaths(reference_domain, preferred_roots);
|
|
|
|
// Index into the reference tensor
|
|
auto ref_compute = getReferenceIndexing(
|
|
loops,
|
|
reference_domain,
|
|
ref_id_to_ind_map,
|
|
ref_zero_domains,
|
|
preferred_paths);
|
|
|
|
// Adds halo info mappings for the reference
|
|
updateHaloInfoForReference(reference, consumer_tv);
|
|
|
|
const auto reference_halo_extent_map =
|
|
getReferenceHaloExtentMap(reference, index_map_ref_to_consumer);
|
|
|
|
// Index into consumer using reference indexing
|
|
auto consumer_indexing = ref_compute.updateIndexCompute(
|
|
consumer_tv->domain(),
|
|
index_map_ref_to_consumer,
|
|
consumer_tv->domain()->contiguity(),
|
|
reference_halo_extent_map);
|
|
|
|
IndexSwizzle index_swizzle(
|
|
consumer_tv,
|
|
consumer_indexing.indexMap(),
|
|
consumer_indexing.extentMap(),
|
|
consumer_indexing.zeroDomains(),
|
|
consumer_indexing.zeroMergedIn());
|
|
|
|
index_swizzle.run();
|
|
|
|
const auto& index_map = index_swizzle.indexMap();
|
|
const auto& extent_map = consumer_indexing.extentMap();
|
|
const auto& zero_domain_map = consumer_indexing.zeroDomains();
|
|
|
|
// Indices should now be mapped onto IterDomains in consumer, so just grab
|
|
// and use them.
|
|
auto root_dom = consumer_tv->getMaybeRFactorDomain();
|
|
std::vector<Val*> strided_inds(
|
|
root_dom.size(), GpuLower::current()->kernel()->zeroVal());
|
|
for (const auto i : c10::irange(root_dom.size())) {
|
|
if (root_dom[i]->isReduction() || root_dom[i]->isBroadcast() ||
|
|
gpu_lower->trivialReductionInfo().isDerived(root_dom[i]) ||
|
|
root_dom[i]->isStride()) {
|
|
continue;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
index_map.find(root_dom[i]) != index_map.end(),
|
|
"Couldn't find root mapping for TV",
|
|
consumer_tv->name(),
|
|
" dim: ",
|
|
i,
|
|
" id: ",
|
|
root_dom[i]->toString());
|
|
|
|
auto root_ind_i = index_map.at(root_dom[i]);
|
|
if (root_ind_i->isZeroInt()) {
|
|
continue;
|
|
}
|
|
|
|
// index hoist must be done before the adjustments for halo
|
|
root_ind_i = hoistConsumerIndex(
|
|
root_dom[i],
|
|
consumer_tv,
|
|
consumer_indexing,
|
|
reference.domain,
|
|
ref_compute,
|
|
loops,
|
|
root_ind_i);
|
|
|
|
// Compute striding for this index.
|
|
Val* stride = nullptr;
|
|
for (const auto j : c10::irange(i + 1, root_dom.size())) {
|
|
if (root_dom[j]->isBroadcast() || root_dom[j]->isReduction() ||
|
|
gpu_lower->trivialReductionInfo().isDerived(root_dom[j]) ||
|
|
root_dom[j]->isStride()) {
|
|
continue;
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
index_map.find(root_dom[j]) != index_map.end(),
|
|
"Couldn't find root mapping for TV",
|
|
consumer_tv->name(),
|
|
" dim: ",
|
|
i,
|
|
" id: ",
|
|
root_dom[i]);
|
|
|
|
auto root_ext_j = extent_map.find(root_dom[j]) == extent_map.end()
|
|
? root_dom[j]->extent()
|
|
: extent_map.at(root_dom[j]);
|
|
|
|
root_ext_j = getHaloExtentOfRootAxis(root_dom[j], root_ext_j);
|
|
|
|
if (zero_domain_map.count(root_dom[j]) == 0) {
|
|
if (stride == nullptr) {
|
|
stride = root_ext_j;
|
|
} else {
|
|
stride = SimplifyingIrBuilder::mulExpr(stride, root_ext_j);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (stride != nullptr) {
|
|
strided_inds[i] = SimplifyingIrBuilder::mulExpr(root_ind_i, stride);
|
|
} else {
|
|
strided_inds[i] = root_ind_i;
|
|
}
|
|
}
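  // Rough sketch of what the loop above produces, assuming a shared-memory
  // consumer with two surviving root axes of extents [E0, E1] and indices
  // [i0, i1]: strided_inds becomes {i0 * E1, i1}, which the rest of lowering
  // combines into the flat offset i0 * E1 + i1. Axes mapped to zero domains
  // drop out of the stride product.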
|
|
|
|
// This check was originally done in getConsumerStridedIndices, but
|
|
// the number of strided index values depends on the loop where the
|
|
// consumer tensor is located. If it's double buffered and not in
|
|
// the prologue loop, strided_inds ends up having one more
|
|
// index, so it's just much simpler to check here before adding the
|
|
// additional index for double buffering.
|
|
TORCH_INTERNAL_ASSERT(
|
|
strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size());
|
|
|
|
if (consumer_tv->isDoubleBuffered()) {
|
|
auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
|
|
consumer_tv, loops, true);
|
|
if (db_loop != nullptr) {
|
|
auto db_switch_index = SimplifyingIrBuilder::subExpr(
|
|
gpu_lower->kernel()->oneVal(),
|
|
SimplifyingIrBuilder::modExpr(
|
|
db_loop->index(), SimplifyingIrBuilder::create<Int>(2)));
|
|
auto original_alloc_size =
|
|
gpu_lower->doubleBufferInfo().getOriginalAllocSize(consumer_tv);
|
|
auto db_strided_index =
|
|
SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size);
|
|
strided_inds.push_back(db_strided_index);
|
|
}
|
|
}
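  // Illustrative sketch, not generated code: on the consumer (write) side the
  // extra index appended above selects the buffer half being filled. With a
  // main-loop index i and original allocation size S it is roughly
  //
  //   (1 - i % 2) * S
  //
  // i.e. the opposite half of the one read on the producer side in the same
  // iteration, presumably because the prologue loop has already filled the
  // first half.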
|
|
|
|
return strided_inds;
|
|
}
|
|
|
|
std::vector<Val*> Index::getProducerStridedIndices(
|
|
TensorView* producer,
|
|
const TensorView* consumer,
|
|
const std::vector<kir::ForLoop*>& loops) {
|
|
FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices");
|
|
if (producer->domain()->noReductions().size() == 0) {
|
|
return std::vector<Val*>(
|
|
producer->getMaybeRFactorDomain().size(),
|
|
GpuLower::current()->kernel()->zeroVal());
|
|
}
|
|
|
|
std::vector<Val*> strided_indices;
|
|
if (producer->getMemoryType() == MemoryType::Global) {
|
|
strided_indices =
|
|
getGlobalProducerStridedIndices(producer, consumer, loops);
|
|
} else {
|
|
strided_indices =
|
|
getNonGlobalProducerStridedIndices(producer, consumer, loops);
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
strided_indices.size() ==
|
|
producer->getMaybeRFactorDomain().size() +
|
|
(producer->isDoubleBuffered() ? 1 : 0));
|
|
|
|
return strided_indices;
|
|
}
|
|
|
|
// Producer is the inputs of an expression
|
|
kir::TensorIndex* Index::getProducerIndex(
|
|
TensorView* producer,
|
|
const TensorView* consumer,
|
|
const std::vector<kir::ForLoop*>& loops) {
|
|
auto strided_indices = getProducerStridedIndices(producer, consumer, loops);
|
|
return SimplifyingIrBuilder::create<kir::TensorIndex>(
|
|
producer, strided_indices);
|
|
}
|
|
|
|
std::vector<Val*> Index::getConsumerStridedIndices(
|
|
const TensorView* consumer,
|
|
const std::vector<kir::ForLoop*>& loops) {
|
|
FUSER_PERF_SCOPE("GpuLower::Lower::Index::getConsumerStridedIndices");
|
|
if (consumer->domain()->noReductions().size() == 0) {
|
|
return std::vector<Val*>(
|
|
consumer->getMaybeRFactorDomain().size(),
|
|
GpuLower::current()->kernel()->zeroVal());
|
|
}
|
|
|
|
std::vector<Val*> strided_indices;
|
|
if (consumer->getMemoryType() == MemoryType::Global) {
|
|
strided_indices = getGlobalConsumerStridedIndices(consumer, loops);
|
|
} else {
|
|
strided_indices = getNonGlobalConsumerStridedIndices(consumer, loops);
|
|
}
|
|
|
|
return strided_indices;
|
|
}
|
|
|
|
// Consumer is the output of an expression
|
|
kir::TensorIndex* Index::getConsumerIndex(
|
|
const TensorView* consumer,
|
|
const std::vector<kir::ForLoop*>& loops) {
|
|
auto strided_indices = getConsumerStridedIndices(consumer, loops);
|
|
return SimplifyingIrBuilder::create<kir::TensorIndex>(
|
|
consumer, strided_indices);
|
|
}
|
|
|
|
namespace {
|
|
|
|
struct PredicateDomainInfo {
|
|
public:
|
|
// Iteration domain to predicate
|
|
IterDomain* id = nullptr;
|
|
// The set of iteration domains that make up the id. If this is for
|
|
// a non-divisible split, the set only contains the id itself. This
|
|
// set is used to remove redundant predicates when gathering
|
|
// unswitch predicates.
|
|
std::unordered_set<IterDomain*> covered_ids;
|
|
// True if this predicate is for a non-divisible split
|
|
bool is_non_divisible_split = false;
|
|
};
|
|
|
|
// Find iteration domains in the history of a consumer to predicate comprised
|
|
// only of merge operations. Only return iteration domains that are subsequently
|
|
// fed into a split, or are in the provided domain. In other words, we don't
|
|
// want to return every IterDomain that's contiguous, just the one closest to
|
|
// the leaves. Predicates are not associated with physical memory so we can
|
|
// treat all of them as contiguous merges.
|
|
std::vector<PredicateDomainInfo> getPredicateContigIds(
|
|
TensorView* consumer_tv) {
|
|
const auto gpu_lower = GpuLower::current();
|
|
|
|
const auto& consumer_root_domain = consumer_tv->getRootDomain();
|
|
|
|
std::vector<IterDomain*> contiguous_ids = consumer_root_domain;
|
|
|
|
if (contiguous_ids.empty()) {
|
|
return std::vector<PredicateDomainInfo>();
|
|
}
|
|
|
|
// If root IDs are partial, i.e., start is non-zero and stop is not
|
|
// equal to extent, predication can't be done with merged domains as
|
|
// start and stop information is only available with root
|
|
// domains. Similarly, merged domains don't have enough information
|
|
// about halo to do correct predication, so they must be excluded.
|
|
std::unordered_set<IterDomain*> excluded_ids;
|
|
|
|
for (auto consumer_root_id : consumer_root_domain) {
|
|
if (gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id).hasHalo()) {
|
|
excluded_ids.insert(consumer_root_id);
|
|
continue;
|
|
}
|
|
if (consumer_root_id->maybePartial()) {
|
|
excluded_ids.insert(consumer_root_id);
|
|
continue;
|
|
}
|
|
// When consumer_root_id is a broadcast domain, do not allow contig
|
|
// predication as the merged output is not mapped with the
|
|
// reference unless the concrete domain is also a broadcast
|
|
// domain.
|
|
if (consumer_root_id->isBroadcast() &&
|
|
!gpu_lower->caLoopMap()
|
|
.getConcreteMappedID(consumer_root_id)
|
|
->isBroadcast()) {
|
|
excluded_ids.insert(consumer_root_id);
|
|
continue;
|
|
}
|
|
// Shifted or gathered axes need to be predicated at the root domain
|
|
auto shift_expr = dynamic_cast<ShiftOp*>(consumer_tv->definition());
|
|
auto gather_expr = dynamic_cast<GatherOp*>(consumer_tv->definition());
|
|
if (shift_expr == nullptr && gather_expr == nullptr) {
|
|
continue;
|
|
}
|
|
auto consumer_root_pos = consumer_tv->domain()->rootPosOf(consumer_root_id);
|
|
if ((shift_expr && shift_expr->offset(consumer_root_pos) != 0) ||
|
|
(gather_expr && consumer_root_pos < gather_expr->windowShape().size() &&
|
|
gather_expr->windowShape().at(consumer_root_pos) != 1)) {
|
|
excluded_ids.insert(consumer_root_id);
|
|
}
|
|
}
|
|
|
|
// Run through iteration domain history
|
|
auto exprs = StmtSort::getExprs(
|
|
consumer_tv->fusion(),
|
|
{consumer_tv->domain()->domain().begin(),
|
|
consumer_tv->domain()->domain().end()});
|
|
|
|
for (auto expr : exprs) {
|
|
// If not a merge, output is not contiguous
|
|
if (expr->isA<Merge>()) {
|
|
auto merge = expr->as<Merge>();
|
|
auto inner_contig_it = std::find(
|
|
contiguous_ids.begin(), contiguous_ids.end(), merge->inner());
|
|
auto outer_contig_it = std::find(
|
|
contiguous_ids.begin(), contiguous_ids.end(), merge->outer());
|
|
|
|
if (excluded_ids.count(merge->inner()) > 0 ||
|
|
excluded_ids.count(merge->outer()) > 0) {
|
|
continue;
|
|
}
|
|
|
|
if (inner_contig_it != contiguous_ids.end() &&
|
|
outer_contig_it != contiguous_ids.end()) {
|
|
// If inner and outer are contiguous, out must be contiguous. Remove
|
|
// inner and outer, and add out.
|
|
contiguous_ids.erase(outer_contig_it);
|
|
contiguous_ids.erase(std::find(
|
|
contiguous_ids.begin(), contiguous_ids.end(), merge->inner()));
|
|
contiguous_ids.emplace_back(merge->out());
|
|
}
|
|
}
|
|
}
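  // For example, illustration only: if root domains I0 and I1 are both
  // usable for contiguous predication and are merged into I0*I1, the loop
  // above replaces the pair with the merged domain, so a single predicate of
  // the form "index < extent(I0) * extent(I1)" can cover both root axes.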
|
|
|
|
std::vector<PredicateDomainInfo> contig_id_infos;
|
|
|
|
// Create entries and return them
|
|
for (auto contig_id : contiguous_ids) {
|
|
    // Pick inputs from the starting domains, i.e., the consumer root domain.
|
|
auto contig_root_vals = IterVisitor::getInputsTo(
|
|
{contig_id},
|
|
{consumer_root_domain.begin(), consumer_root_domain.end()});
|
|
auto contig_root_ids = ir_utils::filterByType<IterDomain>(contig_root_vals);
|
|
PredicateDomainInfo contig_id_info;
|
|
contig_id_info.id = contig_id;
|
|
contig_id_info.covered_ids = std::unordered_set<IterDomain*>(
|
|
contig_root_ids.begin(), contig_root_ids.end());
|
|
contig_id_infos.push_back(contig_id_info);
|
|
}
|
|
return contig_id_infos;
|
|
}
|
|
|
|
IterDomain* getMappedReferenceDomain(
|
|
IterDomain* id,
|
|
const ReferenceTensor& reference) {
|
|
// Partially overlaps with getPredicateContigIds()
|
|
auto concrete_id = GpuLower::current()->caIndexMap().getConcreteMappedID(id);
|
|
auto it = reference.concrete_to_id.find(concrete_id);
|
|
if (it == reference.concrete_to_id.end()) {
|
|
return nullptr;
|
|
}
|
|
return it->second;
|
|
}
|
|
|
|
std::vector<PredicateDomainInfo> getNonDivisibleConsumerDomainsToPredicate(
|
|
TensorView* consumer_tv) {
|
|
const auto& non_divisible_split_info =
|
|
GpuLower::current()->nonDivisibleSplitInfo();
|
|
|
|
std::vector<PredicateDomainInfo> pred_info_vec;
|
|
|
|
auto it = non_divisible_split_info.splitsToPredicate().find(consumer_tv);
|
|
if (it == non_divisible_split_info.splitsToPredicate().end()) {
|
|
return {};
|
|
}
|
|
|
|
const auto& splits_to_predicate = it->second;
|
|
|
|
for (auto split : splits_to_predicate) {
|
|
PredicateDomainInfo info{split->in(), {split->in()}, true};
|
|
pred_info_vec.emplace_back(info);
|
|
}
|
|
|
|
return pred_info_vec;
|
|
}
|
|
|
|
bool needsPadding(TensorView* tv) {
|
|
auto shift_expr = dynamic_cast<ShiftOp*>(tv->definition());
|
|
auto gather_expr = dynamic_cast<GatherOp*>(tv->definition());
|
|
|
|
return (shift_expr != nullptr && shift_expr->hasPadding()) ||
|
|
(gather_expr != nullptr && gather_expr->hasPadding());
|
|
}
|
|
|
|
// Get an additional offset of a stop index when building a predicate
|
|
// for unswitch. Initial stop indices generated at getPredicateReferenceIndexing
|
|
// do not take halo into account, and the adjustment for halo is done as an
|
|
// additional offset to the final index value so that unswitch predicates can be
|
|
// compared with each other by just looking at the additional offsets.
|
|
//
|
|
// consumer_root_id: the domain for which a stop predicate is being built.
|
|
int getUnswitchStopOffset(
|
|
IterDomain* consumer_root_id,
|
|
TensorView* consumer_tv) {
|
|
const auto gpu_lower = GpuLower::current();
|
|
|
|
AxisHaloInfo halo_info =
|
|
gpu_lower->haloInfo().getRootAxisInfo(consumer_root_id);
|
|
|
|
// If the consumer root domain to predicate does not have halo, no
|
|
// adjustment is required.
|
|
if (!halo_info.hasHalo()) {
|
|
return 0;
|
|
}
|
|
|
|
// Find if this contig_id is used in the unswitched domains
|
|
auto unswitch_it = std::find_if(
|
|
consumer_tv->domain()->domain().begin(),
|
|
consumer_tv->domain()->domain().end(),
|
|
[](IterDomain* id) {
|
|
return id->getParallelType() == ParallelType::Unswitch ||
|
|
id->getParallelType() == ParallelType::Unroll ||
|
|
id->getParallelType() == ParallelType::Vectorize;
|
|
});
|
|
|
|
// If any of the unswitched leaf domains inherits the halo from the
|
|
// root domain, the halo width needs to be added to the stop offset
|
|
if (std::any_of(
|
|
unswitch_it,
|
|
consumer_tv->domain()->domain().end(),
|
|
[&gpu_lower, &consumer_root_id](auto leaf_id) {
|
|
return gpu_lower->haloInfo().isHaloInherited(
|
|
consumer_root_id, leaf_id);
|
|
})) {
|
|
return halo_info.width();
|
|
} else {
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
std::pair<Val*, Val*> getStartAndStopOffsetsForShift(
|
|
TensorView* consumer_tv,
|
|
IterDomain* consumer_id,
|
|
bool padding_predicate) {
|
|
TORCH_INTERNAL_ASSERT(consumer_id != nullptr);
|
|
|
|
auto shift_expr = dynamic_cast<ShiftOp*>(consumer_tv->definition());
|
|
|
|
// Adjustment is not necessary if not shift.
|
|
// Even so, padding predicate does not need any adjustment.
|
|
if (shift_expr == nullptr || padding_predicate) {
|
|
return {
|
|
GpuLower::current()->kernel()->zeroVal(),
|
|
GpuLower::current()->kernel()->zeroVal()};
|
|
}
|
|
|
|
const auto root_axis_pos = consumer_tv->domain()->rootPosOf(consumer_id);
|
|
|
|
// The first or last N elements, where N is the padding width,
|
|
// correspond to the padding predicate.
|
|
|
|
const auto shift_offset = shift_expr->offset(root_axis_pos);
|
|
const auto pad_width = shift_expr->padWidth().at(root_axis_pos);
|
|
|
|
int start_offset = 0;
|
|
int stop_offset = 0;
|
|
|
|
if (shift_offset > 0) {
|
|
start_offset = -pad_width;
|
|
} else if (shift_offset < 0) {
|
|
stop_offset = pad_width;
|
|
}
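  // As a concrete illustration with hypothetical values: a shift with offset
  // +1 and pad width 1 on this axis yields {start_offset, stop_offset} =
  // {-1, 0}, while offset -1 with pad width 1 yields {0, 1}.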
|
|
|
|
return {
|
|
SimplifyingIrBuilder::create<Int>(start_offset),
|
|
SimplifyingIrBuilder::create<Int>(stop_offset)};
|
|
}
|
|
|
|
std::pair<Val*, Val*> getStartAndStopOffsetsForGather(
|
|
TensorView* consumer_tv,
|
|
IterDomain* consumer_id,
|
|
const std::unordered_map<IterDomain*, Val*>& ref_start_index_map,
|
|
const std::unordered_map<IterDomain*, Val*>& ref_stop_index_map,
|
|
bool padding_predicate) {
|
|
TORCH_INTERNAL_ASSERT(consumer_id != nullptr);
|
|
|
|
// Adjustment is not necessary if not gather. Even so, padding
|
|
// predicate does not need any adjustment.
|
|
if (!consumer_tv->definition()->isA<GatherOp>() || padding_predicate) {
|
|
return {
|
|
GpuLower::current()->kernel()->zeroVal(),
|
|
GpuLower::current()->kernel()->zeroVal()};
|
|
}
|
|
|
|
const auto root_axis_pos = consumer_tv->domain()->rootPosOf(consumer_id);
|
|
|
|
auto producer_start_offset = getProducerOffsetWithGather(
|
|
root_axis_pos, consumer_tv, ref_start_index_map);
|
|
|
|
auto producer_stop_offset = getProducerOffsetWithGather(
|
|
root_axis_pos, consumer_tv, ref_stop_index_map);
|
|
|
|
auto consumer_start_offset = GpuLower::current()->kernel()->zeroVal();
|
|
auto consumer_stop_offset = GpuLower::current()->kernel()->zeroVal();
|
|
|
|
if (producer_start_offset->isZeroInt() && producer_stop_offset->isZeroInt()) {
|
|
return {consumer_start_offset, consumer_stop_offset};
|
|
}
|
|
|
|
Val* start_offset = nullptr;
|
|
Val* stop_offset = nullptr;
|
|
|
|
  // In the normal case, take the minimum of the start offsets and the
  // maximum of the stop offsets. If there's no padding, the producer offset
  // must always be larger than the consumer offset, so the consumer and
  // producer offsets can always be used for the start and stop offsets,
  // respectively.
|
|
const auto pad_left =
|
|
consumer_tv->definition()->as<GatherOp>()->padWidth()[root_axis_pos][0];
|
|
const auto pad_right =
|
|
consumer_tv->definition()->as<GatherOp>()->padWidth()[root_axis_pos][1];
|
|
const auto window_size =
|
|
consumer_tv->definition()->as<GatherOp>()->windowShape()[root_axis_pos];
|
|
|
|
// consumer index: index
|
|
// producer index: index + window_index - pad_left
|
|
//
|
|
// consumer extent: ext
|
|
// producer extent: ext + window_size - 1 - pad_left - pad_right
|
|
//
|
|
// consumer stop pred: index < ext
|
|
// producer stop pred: index + window_index - pad_left < ext + window_size - 1
|
|
// - pad_left - pad_right
|
|
// -> index + window_index - pad_left - (window_size - 1 -
|
|
// pad_left - pad_right) < ext
|
|
// -> index + window_index - (window_size - 1 - pad_right) <
|
|
// ext
|
|
//
|
|
// consumer start pred: index >= 0
|
|
// producer start pred: index + window_index - pad_left >= 0
|
|
|
|
const auto producer_ext_adj = window_size - 1 - pad_left - pad_right;
|
|
producer_stop_offset = SimplifyingIrBuilder::subExpr(
|
|
producer_stop_offset,
|
|
SimplifyingIrBuilder::create<Int>(producer_ext_adj));
|
|
|
|
// As commented above, when pad_left is zero, the consumer predicate
|
|
// is always more restrictive than the producer predicate.
|
|
if (pad_left == 0) {
|
|
start_offset = consumer_start_offset;
|
|
} else {
|
|
start_offset = SimplifyingIrBuilder::minExpr(
|
|
consumer_start_offset, producer_start_offset);
|
|
}
|
|
|
|
// As commented above, when pad_right is zero, the consumer
|
|
// predicate is always more restrictive than the producer
|
|
// predicate.
|
|
if (pad_right == 0) {
|
|
stop_offset = consumer_stop_offset;
|
|
} else {
|
|
stop_offset = SimplifyingIrBuilder::maxExpr(
|
|
consumer_stop_offset, producer_stop_offset);
|
|
}
|
|
|
|
TORCH_INTERNAL_ASSERT(start_offset != nullptr);
|
|
TORCH_INTERNAL_ASSERT(stop_offset != nullptr);
|
|
|
|
return {start_offset, stop_offset};
|
|
}
|
|
|
|
// Get the start and stop limit offsets that define the valid range to
|
|
// compute. In the simplest case, they are just 0 and
|
|
// IterDomain::extent. However, IterDomain may have non-zero start and
|
|
// stop that's different from extent. Also, when IterDomain has halo,
|
|
// the actual offsets of the logical start and stop positions are
|
|
// shifted.
|
|
std::pair<Val*, Val*> getStartAndStopLimitOffsets(
|
|
IterDomain* consumer_id,
|
|
bool padding_predicate,
|
|
bool non_divisible_pred) {
|
|
const auto gpu_lower = GpuLower::current();
|
|
|
|
TORCH_INTERNAL_ASSERT(consumer_id != nullptr);
|
|
|
|
Val* start_limit = consumer_id->start();
|
|
Val* stop_limit = SimplifyingIrBuilder::negExpr(consumer_id->stopOffset());
|
|
|
|
if (!non_divisible_pred) {
|
|
AxisHaloInfo halo_info = gpu_lower->haloInfo().getRootAxisInfo(consumer_id);
|
|
|
|
// Below, "left" and "right" halo mean halo at offset zero and
|
|
// axis extent, respectively.
|
|
//
|
|
// The consumer axis looks like this:
|
|
//
|
|
// [0, left halo)[start_limit, stop_limit)[0, right halo)
|
|
//
|
|
if (!padding_predicate) {
|
|
start_limit =
|
|
SimplifyingIrBuilder::addExpr(start_limit, halo_info.width(0));
|
|
stop_limit =
|
|
SimplifyingIrBuilder::addExpr(stop_limit, halo_info.width(0));
|
|
} else {
|
|
// In case of the padding predicate, the whole range, including both left
|
|
// and right halo regions, is computed.
|
|
stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo_info.width());
|
|
}
|
|
} else {
|
|
// For non-divisible predicates, the index must be predicated such
|
|
// that it is less than the extent of the predicated ID +
|
|
// halo. Note that getRootAxisInfo doesn't work since consumer_id
|
|
// isn't a root domain.
|
|
if (gpu_lower->haloInfo().hasHaloWidth(consumer_id)) {
|
|
auto halo = gpu_lower->haloInfo().getHaloWidth(consumer_id);
|
|
stop_limit = SimplifyingIrBuilder::addExpr(stop_limit, halo);
|
|
}
|
|
}
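  // Minimal example, illustration only: for a root axis with start 0, stop
  // offset 0 and a left halo of width 1, the normal predicate uses
  // start_limit = 1 and stop_limit = 1 (both shifted by the left halo width),
  // whereas the padding predicate keeps start_limit = 0 and extends
  // stop_limit by the total halo width.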
|
|
|
|
return {start_limit, stop_limit};
|
|
}
|
|
|
|
// Return an IndexCompute for a predicate reference tensor. Two different
|
|
// maps are used when generating predicates for unswitched expressions
|
|
// as start and stop conditions need to use different loop-to-index
|
|
// mappings.
|
|
auto getPredicateReferenceIndexing(
|
|
const std::vector<kir::ForLoop*>& loops,
|
|
const ReferenceTensor& reference,
|
|
kir::ForLoop* unswitch_or_vec_loop,
|
|
IterDomain* double_buffer_axis,
|
|
bool start) {
|
|
auto reference_domain = reference.domain;
|
|
|
|
std::unordered_map<kir::ForLoop*, Val*> loop_to_ind_map;
|
|
|
|
std::transform(
|
|
loops.begin(),
|
|
loops.end(),
|
|
std::inserter(loop_to_ind_map, loop_to_ind_map.begin()),
|
|
[](kir::ForLoop* fl) { return std::make_pair(fl, fl->index()); });
|
|
|
|
// If unswitch don't directly use indices from for loop, use zero
|
|
// and for loop extent minus 1
|
|
if (unswitch_or_vec_loop != nullptr) {
|
|
    // Vectorized predicates are different from unswitch predicates. For
    // unswitch predicates, all loops within the unswitch (the outermost
    // unswitch) are generated with loop->extent - 1 as the index. With
    // vectorized predicates, only the vectorized loop should be like this.
|
|
|
|
bool vectorized_pred =
|
|
unswitch_or_vec_loop->iter_domain()->getParallelType() ==
|
|
ParallelType::Vectorize;
|
|
|
|
TORCH_INTERNAL_ASSERT(
|
|
loops.size() <= reference_domain->nDims(),
|
|
"Invalid reference generated.");
|
|
|
|
bool within_unswitch = false;
|
|
|
|
for (const auto loop_i : c10::irange(loops.size())) {
|
|
auto loop = loops[loop_i];
|
|
auto loop_id = loop->iter_domain();
|
|
auto loop_pt = loop_id->getParallelType();
|
|
auto ref_id = reference_domain->axis(loop_i);
|
|
|
|
if (loop == unswitch_or_vec_loop) {
|
|
within_unswitch = true;
|
|
}
|
|
|
|
if (within_unswitch) {
|
|
// Rely on the reference to check broadcasting. The for loop could be
|
|
// broadcasted on a constant value from an unroll split. Since reference
|
|
// may convert this to an iter domain, that for loop could be valid to
|
|
// generate predication from.
|
|
|
|
// Note that loop->stop() is not used below. Instead,
|
|
// loop->iter_domain()->extent() is used, which is uniform
|
|
// across the mapped domains irrespective of halo. Predicates are
|
|
        // compared with each other to pick the most restrictive ones. The
// comparison is done by only using the offset, which is the
|
|
// term added to the index. So, the index term must be the
|
|
// same among all predicates, otherwise the comparison would
|
|
// be invalid. The effect by halo is added to the offset
|
|
// term. See getUnswitchStopOffset.
|
|
|
|
if (ref_id->isBroadcast()) {
|
|
// Ignore indexing into broadcasted dimensions.
|
|
continue;
|
|
} else if (loop_id->isThread()) {
|
|
// When parallelized, if the loop stop is the same as the
|
|
// extent of the associated IterDomain, i.e., no extra
|
|
// iterations for halo, predicating with the threading index
|
|
// is sufficient for both the start and stop
|
|
          // predicates. That isn't the case if the loop has halo, in
          // which case either the minimum or the maximum value of the
          // iteration domain needs to be used.
|
|
//
|
|
// Note: Better performance was obtained if using
|
|
// threadIdx in unswitch predicates was avoided. More
|
|
// specifically, in the Hdiff stencil example, instead of
|
|
// predicating with threadIdx.x for both the start and stop
|
|
// predicates, using zero and (blockDim.x - 1) for the start
|
|
// and stop predicates, respectively, resulted in less
|
|
// register pressure. The alternative codegen can be done by
|
|
// adding this to the first if condition:
|
|
// loop_id->isBlockDim(). This would not be a concern if the
|
|
// else part could be omitted, so canOmitElseClause should
|
|
// be used as well.
|
|
if (loop->stop() == loop_id->extent()) {
|
|
loop_to_ind_map[loop] = loop->start();
|
|
} else if (start) {
|
|
loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal();
|
|
} else {
|
|
// Note that the parallel dimension is used rather than
|
|
// loop-stop(). See the above comment.
|
|
            loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr(
                GpuLower::current()->parallelDimensionMap().get(loop_pt),
                GpuLower::current()->kernel()->oneVal());
|
|
}
|
|
} else if (start) {
|
|
loop_to_ind_map[loop] = GpuLower::current()->kernel()->zeroVal();
|
|
} else {
|
|
          // Similar to the above, loop_id->extent() is used here
          // instead of loop->stop(). See the above comment.
|
|
loop_to_ind_map[loop] = SimplifyingIrBuilder::subExpr(
|
|
loop_id->extent(), GpuLower::current()->kernel()->oneVal());
|
|
}
|
|
}
|
|
|
|
// If a vectorized predicate, bail after the vectorized loop was found.
|
|
// Don't continue unswitching loops.
|
|
if (vectorized_pred && within_unswitch) {
|
|
break;
|
|
}
|
|
}
|
|
}
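  // Sketch of the resulting mapping, illustration only: for a serial loop of
  // index j and extent E inside the unswitched region, the stop-predicate
  // indexing uses E - 1 in place of j while the start-predicate indexing uses
  // 0, so a single unswitch predicate covers every iteration of the loop.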
|
|
|
|
for (const auto loop : loops) {
|
|
auto& idx = loop_to_ind_map.at(loop);
|
|
// If the loop is trivial, the loop index can only be the loop
|
|
// start value.
|
|
if (idx == loop->index() && loop->isTrivial()) {
|
|
idx = loop->start();
|
|
}
|
|
}
|
|
|
|
if (double_buffer_axis != nullptr) {
|
|
auto db_loop = GpuLower::current()->doubleBufferInfo().getDoubleBufferLoop(
|
|
double_buffer_axis, loops, true);
|
|
if (db_loop != nullptr) {
|
|
auto loop_to_ind_map_it = loop_to_ind_map.find(db_loop);
|
|
TORCH_INTERNAL_ASSERT(loop_to_ind_map_it != loop_to_ind_map.end());
|
|
auto cur_index = loop_to_ind_map_it->second;
|
|
      // If cur_index is not the same as the index of db_loop, the index
      // must have been modified to support unswitch. In that case, it is
      // not necessary to advance the index for double buffering.
|
|
if (cur_index == db_loop->index()) {
|
|
loop_to_ind_map[db_loop] = SimplifyingIrBuilder::addExpr(
|
|
cur_index, GpuLower::current()->kernel()->oneVal());
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add magic zero to a loop pretty far inside in indexing
|
|
IterDomain* magic_zero_loop = nullptr;
|
|
std::unordered_map<IterDomain*, Val*> ref_id_to_ind_map;
|
|
// Due to rfactor/initialization reference_domain may be bigger than loop nest
|
|
// structure
|
|
TORCH_INTERNAL_ASSERT(loops.size() <= reference_domain->nDims());
|
|
for (const auto loop_i : c10::irange(loops.size())) {
|
|
auto loop = loops[loop_i];
|
|
auto ind = loop_to_ind_map[loops[loop_i]];
|
|
auto ref_axis = reference_domain->axis(loop_i);
|
|
|
|
if (Index::protectWithMagicZero(loop, ref_axis, ind)) {
|
|
magic_zero_loop = ref_axis;
|
|
}
|
|
|
|
ref_id_to_ind_map[ref_axis] = loop_to_ind_map[loop];
|
|
}
|
|
|
|
if (ref_id_to_ind_map.count(magic_zero_loop)) {
|
|
auto& ind = ref_id_to_ind_map[magic_zero_loop];
|
|
if (!ind->isConstScalar()) {
|
|
ind = SimplifyingIrBuilder::addExpr(
|
|
ind, GpuLower::current()->kernel()->magicZeroVal());
|
|
}
|
|
}
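  // Illustration only: if the selected axis has the index "i * 4 + j", it is
  // protected as "i * 4 + j + magic_zero", where magic_zero is the kernel's
  // magic-zero value that evaluates to 0 at runtime, so the index value
  // itself is unchanged.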
|
|
|
|
std::unordered_map<IterDomain*, IterDomain*> ref_self_map;
|
|
auto all_vals = DependencyCheck::getAllValsBetween(
|
|
{reference_domain->getRootDomain().begin(),
|
|
reference_domain->getRootDomain().end()},
|
|
{reference_domain->domain().begin(), reference_domain->domain().end()});
|
|
auto all_ids = ir_utils::filterByType<IterDomain>(all_vals);
|
|
std::for_each(all_ids.begin(), all_ids.end(), [&ref_self_map](auto id) {
|
|
ref_self_map.insert({id, id});
|
|
});
|
|
|
|
std::unordered_map<IterDomain*, Val*> reference_halo_extent_map =
|
|
getReferenceHaloExtentMap(reference, ref_self_map);
|
|
|
|
// Index into the reference tensor
|
|
auto index_compute = getReferenceIndexing(
|
|
loops,
|
|
reference_domain,
|
|
ref_id_to_ind_map,
|
|
{},
|
|
{},
|
|
reference_halo_extent_map);
|
|
|
|
return index_compute;
|
|
}
|
|
|
|
// Get the offsets for the start and stop predicates. The offsets
|
|
// are to be added to the index.
|
|
std::pair<Val*, Val*> getStartAndStopOffsets(
|
|
IterDomain* consumer_id,
|
|
TensorView* consumer_tv,
|
|
const ReferenceTensor& reference,
|
|
const std::unordered_map<IterDomain*, Val*>& consumer_start_index_map,
|
|
const std::unordered_map<IterDomain*, Val*>& consumer_stop_index_map,
|
|
bool padding_predicate,
|
|
bool unswitch,
|
|
bool non_divisible_pred) {
|
|
  // By default, the offsets for the start and stop predicates are
  // just zero. All halo-related adjustments are done at root domains, so
  // if consumer_id is not a root domain, no adjustment is required.
|
|
if (consumer_id->definition() != nullptr && !non_divisible_pred) {
|
|
return {
|
|
GpuLower::current()->kernel()->zeroVal(),
|
|
GpuLower::current()->kernel()->zeroVal()};
|
|
}
|
|
|
|
auto consumer_def = consumer_tv->definition();
|
|
|
|
Val* start_offset = GpuLower::current()->kernel()->zeroVal();
|
|
Val* stop_offset = GpuLower::current()->kernel()->zeroVal();
|
|
|
|
// These adjustments are not required when predicating non-divisible splits
|
|
if (!non_divisible_pred) {
|
|
if (consumer_def->isA<ShiftOp>()) {
|
|
std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForShift(
|
|
consumer_tv, consumer_id, padding_predicate);
|
|
} else if (consumer_def->isA<GatherOp>()) {
|
|
std::tie(start_offset, stop_offset) = getStartAndStopOffsetsForGather(
|
|
consumer_tv,
|
|
consumer_id,
|
|
consumer_start_index_map,
|
|
consumer_stop_index_map,
|
|
padding_predicate);
|
|
}
|
|
|
|
// Adjustment for partial split
|
|
auto partial_split_offset =
|
|
getGlobalConsumerOffsetWithPartialSplit(consumer_id);
|
|
start_offset =
|
|
SimplifyingIrBuilder::addExpr(start_offset, partial_split_offset);
|
|
stop_offset =
|
|
SimplifyingIrBuilder::addExpr(stop_offset, partial_split_offset);
|
|
|
|
// If generating a predicate for unswitch, adjust the stop offset to
|
|
// accommodate the addition of halo to the loop stop. See the
|
|
// comment in getPredicateReferenceIndexing as well.
|
|
if (unswitch) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
!padding_predicate, "Unswitch should not use the padding predicate");
|
|
auto stop_unswitch_offset =
|
|
getUnswitchStopOffset(consumer_id, consumer_tv);
|
|
stop_offset =
|
|
SimplifyingIrBuilder::addExpr(stop_offset, stop_unswitch_offset);
|
|
}
|
|
}
|
|
|
|
// Get the boundaries of two ends
|
|
auto limits = getStartAndStopLimitOffsets(
|
|
consumer_id, padding_predicate, non_divisible_pred);
|
|
|
|
// At this point, we have everything to create both start and stop
|
|
// predicates as:
|
|
//
|
|
// index + start_offset >= start_limit
|
|
// index + stop_offset < extent + stop_limit
|
|
//
|
|
// In order to enable consolidating unswitch predicates, organize
|
|
// the predicates as:
|
|
//
|
|
// index + (start_offset - start_limit) >= 0
|
|
// index + (stop_offset - stop_limit) < extent
|
|
|
|
start_offset = SimplifyingIrBuilder::subExpr(start_offset, limits.first);
|
|
stop_offset = SimplifyingIrBuilder::subExpr(stop_offset, limits.second);
|
|
|
|
return {start_offset, stop_offset};
|
|
}
|
|
|
|
// A partial value of a start offset is returned if determined to be
|
|
// safe. Nullptr is returned if it can be omitted completely.
|
|
Val* simplifyStartOffset(Val* start_offset) {
|
|
// Start predicate can be omitted when start_offset >= 0.
|
|
auto offset_val = start_offset->as<Int>()->value();
|
|
if (offset_val.has_value() && offset_val.value() >= 0) {
|
|
return nullptr;
|
|
}
|
|
|
|
  // start_offset may look like min(0, window_index - pad). In that case, the
  // min can be removed and only the rhs kept.
|
|
auto def = dynamic_cast<BinaryOp*>(start_offset->definition());
|
|
if (def != nullptr && def->getBinaryOpType() == BinaryOpType::Min &&
|
|
def->lhs()->isZeroInt()) {
|
|
return def->rhs();
|
|
}
|
|
|
|
return start_offset;
|
|
}
|
|
|
|
bool canOmitStopPredicate(
|
|
Val* stop_index,
|
|
Val* stop_offset,
|
|
IterDomain* contig_id) {
|
|
bool index_simple = stop_index->definition() == nullptr;
|
|
// The definition may be just adding the magic zero, which can be
|
|
// effectively considered "simple"
|
|
if (!index_simple && isProtectedWithMagicZero(stop_index)) {
|
|
// Make sure the lhs of stop_index is simple.
|
|
auto lhs = stop_index->definition()->as<BinaryOp>()->lhs();
|
|
if (lhs->definition() == nullptr) {
|
|
index_simple = true;
|
|
}
|
|
}
|
|
|
|
if (!index_simple) {
|
|
return false;
|
|
}
|
|
|
|
const auto gpu_lower = GpuLower::current();
|
|
|
|
// Stop predicate: stop_index + stop_offset < extent, where
|
|
// stop_index ranges from 0 to (extent + halo), so this can be
|
|
// omitted if extent + halo + stop_offset < extent, i.e., halo +
|
|
// stop_offset <= 0.
|
|
|
|
auto stop_offset_val = stop_offset->as<Int>()->value();
|
|
|
|
  // If it is not a compile-time constant, the condition can't be proven.
|
|
if (!stop_offset_val.has_value()) {
|
|
return false;
|
|
}
|
|
|
|
// Note that when a root domain is halo extended, it is the domain
|
|
// to be predicated, not its merged contig id even if it exists. So,
|
|
// if contig_id does not have root axis info, contig_id is
|
|
// guaranteed to have no halo.
|
|
auto halo_ext = gpu_lower->haloInfo().hasRootAxisInfo(contig_id)
|
|
? gpu_lower->haloInfo().getRootAxisInfo(contig_id).width()
|
|
: 0;
|
|
|
|
if (halo_ext + stop_offset_val.value() > 0) {
|
|
return false;
|
|
}
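  // For example, illustration only: halo_ext == 0 and stop_offset == 0
  // satisfies halo_ext + stop_offset <= 0, so the stop predicate may be
  // omitted here, subject to the parallel-dimension checks below.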
|
|
|
|
// When the domain is parallelized, the parallel dimension must be
|
|
// exact. Otherwise, there would be extra threads/blocks that need
|
|
// to be predicated out.
|
|
if (isParallelTypeThread(contig_id->getParallelType())) {
|
|
if (!gpu_lower->parallelDimensionMap().isExact(
|
|
contig_id->getParallelType())) {
|
|
return false;
|
|
}
|
|
// If the domain has halo, the loop is expanded by the halo
|
|
// extent, so we can't prove the loop extent is the same as the
|
|
// parallel dimension.
|
|
if (halo_ext != 0) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
std::pair<Val*, Val*> hoistPredicates(
|
|
Val* start_index,
|
|
Val* stop_index,
|
|
const std::vector<kir::ForLoop*>& loops,
|
|
kir::ForLoop* unswitch_or_vec_loop,
|
|
IterDomain* predicated_consumer_id,
|
|
TensorView* predicated_consumer_tv,
|
|
TensorDomain* ref_td,
|
|
const std::unordered_map<IterDomain*, Val*>& ref_start_index_map,
|
|
const std::unordered_map<IterDomain*, Val*>& ref_stop_index_map) {
|
|
const std::pair<Val*, Val*> same_indices{start_index, stop_index};
|
|
|
|
if (disableIndexHoisting()) {
|
|
return same_indices;
|
|
}
|
|
|
|
const auto start_is_same_as_stop = stop_index == start_index;
|
|
|
|
Val* hoisted_stop_index = nullptr;
|
|
|
|
if (stop_index->definition() == nullptr) {
|
|
    // If the index doesn't have an expression, nothing to hoist
hoisted_stop_index = stop_index;
|
|
} else {
|
|
bool inserted = false;
|
|
std::tie(hoisted_stop_index, inserted) =
|
|
GpuLower::current()->commonIndexMap().insert(
|
|
predicated_consumer_id,
|
|
predicated_consumer_tv->domain(),
|
|
ref_td,
|
|
ref_stop_index_map,
|
|
loops,
|
|
stop_index);
|
|
}
|
|
|
|
Val* hoisted_start_index = nullptr;
|
|
if (start_is_same_as_stop) {
|
|
hoisted_start_index = hoisted_stop_index;
|
|
} else if (start_index->definition() == nullptr) {
|
|
hoisted_start_index = start_index;
|
|
} else {
|
|
bool inserted = false;
|
|
std::tie(hoisted_start_index, inserted) =
|
|
GpuLower::current()->commonIndexMap().insert(
|
|
predicated_consumer_id,
|
|
predicated_consumer_tv->domain(),
|
|
ref_td,
|
|
ref_start_index_map,
|
|
loops,
|
|
start_index);
|
|
}
|
|
|
|
return {hoisted_start_index, hoisted_stop_index};
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Returns predicates and the concrete (by loop map) root domains they cover
|
|
std::pair<std::vector<RootPredicateInfo>, ReferenceTensor> Index::
|
|
getReferenceRootPredicates(
|
|
TensorView* consumer_tv,
|
|
const std::vector<kir::ForLoop*>& loops,
|
|
kir::ForLoop* unswitch_or_vec_loop,
|
|
bool shift_padding) {
|
|
FUSER_PERF_SCOPE("GpuLower::Lower::Index::getReferenceRootPredicates");
|
|
|
|
const auto gpu_lower = GpuLower::current();
|
|
|
|
const bool is_unswitch = unswitch_or_vec_loop != nullptr;
|
|
|
|
// Nothing needs to be done when padding is not required.
|
|
if (shift_padding && !needsPadding(consumer_tv)) {
|
|
return {{RootPredicateInfo::getFalseInfo()}, ReferenceTensor{}};
|
|
}
|
|
|
|
// Get a reference tensor replayed as existing loop structure
|
|
ReferenceTensor reference = IndexReferenceReplay::getReference(loops);
|
|
|
|
// Generate halo information for reference.
|
|
updateHaloInfoForReference(reference, consumer_tv);
|
|
|
|
const auto ref_2_consumer = indexMapReferenceTo(
|
|
consumer_tv, gpu_lower->caIndexMap(), reference.concrete_to_id);
|
|
|
|
const auto reference_halo_extent_map =
|
|
getReferenceHaloExtentMap(reference, ref_2_consumer);
|
|
|
|
auto db_axis = gpu_lower->doubleBufferInfo().getDoubleBufferAxis(consumer_tv);
|
|
|
|
// Both start and stop positions may need to be predicated. Indexing
|
|
// differs when generating predicates for unswitch.
|
|
// NOTE: If we could find-and-replace KIR nodes, we could just
|
|
// generate one index map, clone it and replace the loop-to-index
|
|
// mappings of unswitched loops for the start predicate.
|
|
auto ref_stop_indexing = getPredicateReferenceIndexing(
|
|
loops, reference, unswitch_or_vec_loop, db_axis, false);
|
|
const auto consumer_stop_indexing = ref_stop_indexing.updateIndexCompute(
|
|
consumer_tv->domain(),
|
|
ref_2_consumer,
|
|
std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), false),
|
|
reference_halo_extent_map);
|
|
const auto& consumer_stop_index_map = consumer_stop_indexing.indexMap();
|
|
|
|
// If not unswitch, share the same indexing map as the stop index
|
|
// map
|
|
const auto& ref_start_indexing = is_unswitch
|
|
? getPredicateReferenceIndexing(
|
|
loops, reference, unswitch_or_vec_loop, db_axis, true)
|
|
: ref_stop_indexing;
|
|
|
|
std::unordered_map<IterDomain*, Val*> consumer_start_index_map;
|
|
if (is_unswitch) {
|
|
const auto consumer_start_indexing = ref_start_indexing.updateIndexCompute(
|
|
consumer_tv->domain(),
|
|
ref_2_consumer,
|
|
std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), false),
|
|
reference_halo_extent_map);
|
|
consumer_start_index_map = consumer_start_indexing.indexMap();
|
|
} else {
|
|
consumer_start_index_map = consumer_stop_index_map;
|
|
}
|
|
|
|
// Get the contiguous ids we need to generate predicates for
|
|
auto contig_id_infos = getPredicateContigIds(consumer_tv);
|
|
|
|
auto non_divisible_splits =
|
|
getNonDivisibleConsumerDomainsToPredicate(consumer_tv);
|
|
contig_id_infos.insert(
|
|
contig_id_infos.end(),
|
|
non_divisible_splits.begin(),
|
|
non_divisible_splits.end());
|
|
|
|
std::vector<RootPredicateInfo> pred_info_vec;
|
|
|
|
for (auto contig_id_entry : contig_id_infos) {
|
|
auto contig_id = contig_id_entry.id;
|
|
    // No predicates needed for broadcasted indices.
if (contig_id->isBroadcast() ||
|
|
gpu_lower->trivialReductionInfo().isDerived(contig_id)) {
|
|
continue;
|
|
}
|
|
|
|
auto root_ids = contig_id_entry.covered_ids;
|
|
|
|
const auto consumer_stop_indexing_it =
|
|
consumer_stop_index_map.find(contig_id);
|
|
|
|
// First condition below happens with Misaligned predicates, where
|
|
// inner-most vectorized loops are not included in the loops
|
|
// parameter. Predicates involving vectorized loops are separately
|
|
// generated in lower_misaligned_vectorization.
|
|
//
|
|
// Second condition is simply to avoid predication on broadcasting axes as
|
|
// it's not required.
|
|
if (consumer_stop_indexing_it == consumer_stop_index_map.end() ||
|
|
consumer_stop_indexing_it->second->isZeroInt()) {
|
|
continue;
|
|
}
|
|
|
|
RootPredicateInfo info;
|
|
|
|
    // Compute offsets for start and stop predicate. For non-shift,
    // non-gather ops, there's only the stop predicate, as indices can
    // never be negative. However, for shift and gather, the index may
    // need to be predicated so that it is >= zero.
|
|
//
|
|
// Furthermore, in case of gather, both producer and consumer
|
|
// positions may need to be predicated, so there can be multiple
|
|
// offset values.
|
|
//
|
|
// The final predicates will look like:
|
|
// (index + start_offset) >= 0 && (index + stop_offset) < extent.
|
|
|
|
std::tie(info.start_offset_, info.stop_offset_) = getStartAndStopOffsets(
|
|
contig_id,
|
|
consumer_tv,
|
|
reference,
|
|
consumer_start_index_map,
|
|
consumer_stop_index_map,
|
|
shift_padding,
|
|
unswitch_or_vec_loop != nullptr,
|
|
contig_id_entry.is_non_divisible_split);
|
|
|
|
auto stop_index = consumer_stop_indexing_it->second;
|
|
auto start_index = consumer_start_index_map.at(contig_id);
|
|
|
|
std::tie(start_index, stop_index) = hoistPredicates(
|
|
start_index,
|
|
stop_index,
|
|
loops,
|
|
unswitch_or_vec_loop,
|
|
contig_id,
|
|
consumer_tv,
|
|
reference.domain,
|
|
ref_start_indexing.indexMap(),
|
|
ref_stop_indexing.indexMap());
|
|
|
|
// Build predicates for start positions as:
|
|
// start_index + start_offset >= 0
|
|
auto start_offset = simplifyStartOffset(info.start_offset_);
|
|
if (start_offset == nullptr) {
|
|
info.start_predicate_ = GpuLower::current()->kernel()->trueVal();
|
|
} else {
|
|
auto offsetted_start_index =
|
|
SimplifyingIrBuilder::addExpr(start_index, start_offset);
|
|
auto start_pred =
|
|
SimplifyingIrBuilder::geExpr(
|
|
offsetted_start_index, GpuLower::current()->kernel()->zeroVal())
|
|
->as<Bool>();
|
|
info.start_predicate_ = start_pred;
|
|
}
|
|
|
|
// Build predicates for stop positions as:
|
|
// stop_index + stop_offset < IterDomain::extent
|
|
auto stop_offset = info.stop_offset_;
|
|
if (canOmitStopPredicate(stop_index, stop_offset, contig_id)) {
|
|
info.stop_predicate_ = GpuLower::current()->kernel()->trueVal();
|
|
} else {
|
|
auto offsetted_stop_index =
|
|
SimplifyingIrBuilder::addExpr(stop_index, stop_offset);
|
|
auto stop_pred = SimplifyingIrBuilder::ltExpr(
|
|
offsetted_stop_index, contig_id->extent())
|
|
->as<Bool>();
|
|
info.stop_predicate_ = stop_pred;
|
|
}
|
|
|
|
for (auto consumer_id : contig_id_entry.covered_ids) {
|
|
info.root_ids_.insert(consumer_id);
|
|
}
|
|
pred_info_vec.emplace_back(info);
|
|
}
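  // To summarize with a hypothetical contiguous id of extent E and computed
  // offsets (s, t): the start predicate is "start_index + s >= 0" (or true
  // when it can be omitted) and the stop predicate is "stop_index + t < E",
  // which is what the loop above stored in each RootPredicateInfo entry.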
|
|
|
|
return {pred_info_vec, reference};
|
|
}
|
|
|
|
bool Index::protectWithMagicZero(
|
|
kir::ForLoop* loop,
|
|
IterDomain* reference_domain,
|
|
Val* ind) {
|
|
bool ref_dom_simple =
|
|
(reference_domain == nullptr ? true
|
|
: reference_domain->definition() != nullptr);
|
|
bool ind_simple =
|
|
(ind == nullptr ? true
|
|
: ind->definition() != nullptr && !ind->isZeroInt());
|
|
return loop->isUnrolled() && (!ref_dom_simple || !ind_simple);
|
|
}
|
|
|
|
RootPredicateInfo RootPredicateInfo::getFalseInfo() {
|
|
RootPredicateInfo info;
|
|
info.start_predicate_ = GpuLower::current()->kernel()->falseVal();
|
|
info.stop_predicate_ = GpuLower::current()->kernel()->falseVal();
|
|
|
|
return info;
|
|
}
|
|
|
|
} // namespace cuda
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|