pytorch/torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Michael Andreas Dagitses 67d313a032 turn on -Werror=unused-function in our Bazel CPU build
Summary:
We also fix any existing violations. Note that we only do this for the CPU
build: nvcc is considered a C++ toolchain, but it does not support the same
flags, so adding them to the GPU build will cause nvcc errors.
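
As a minimal illustration of what the flag now rejects (hypothetical snippet,
not code from this PR):

    // With -Werror=unused-function, an unused file-local function becomes a
    // hard compile error instead of a warning.
    static int unused_helper() { return 42; }  // error: defined but not used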

Test Plan: Built locally, rely on CI to confirm.

Reviewers: malfet

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79154

Approved by: https://github.com/seemethere, https://github.com/osalpekar, https://github.com/albanD
2022-06-10 18:30:08 +00:00


#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
#include <torch/csrc/jit/codegen/cuda/ir_cloner.h>
#include <torch/csrc/jit/codegen/cuda/ir_interface_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/transform_iter.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>
#include <torch/csrc/jit/codegen/cuda/transform_view.h>
#include <c10/util/irange.h>
#include <sstream>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace {
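// Compares two scalar Vals that already have matching ValType and DataType by
// dispatching on the concrete scalar class (Bool, Double, Int, NamedScalar)
// and delegating to the type-specific sameAs().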
class ScalarCheck : OptInConstDispatch {
public:
static bool sameAs(const Val* v1, const Val* v2) {
if (v1 == v2)
return true;
if (v1->getValType() != v2->getValType())
return false;
if (v1->getDataType() != v2->getDataType())
return false;
ScalarCheck sc(v1, v2);
return sc.same_;
}
private:
void handle(const Bool* b) final {
same_ = v1_->as<Bool>()->sameAs(v2_->as<Bool>());
}
void handle(const Double* d) final {
same_ = v1_->as<Double>()->sameAs(v2_->as<Double>());
}
void handle(const Int* i) final {
same_ = v1_->as<Int>()->sameAs(v2_->as<Int>());
}
void handle(const NamedScalar* ns) final {
same_ = v1_->as<NamedScalar>()->sameAs(v2_->as<NamedScalar>());
}
ScalarCheck(const Val* _v1, const Val* _v2) : v1_(_v1), v2_(_v2) {
OptInConstDispatch::handle(v1_);
}
private:
const Val* v1_ = nullptr;
const Val* v2_ = nullptr;
bool same_ = false;
};
} // namespace
bool areEqualScalars(Val* v1, Val* v2) {
return ScalarCheck::sameAs(v1, v2);
}
Bool::Bool(IrBuilderPasskey passkey)
: Val(passkey, ValType::Scalar, DataType::Bool),
maybe_value_{c10::nullopt} {}
Bool::Bool(IrBuilderPasskey passkey, bool value)
: Val(passkey, ValType::Scalar, DataType::Bool), maybe_value_{value} {}
Bool::Bool(IrBuilderPasskey passkey, c10::optional<bool> value)
: Val(passkey, ValType::Scalar, DataType::Bool), maybe_value_{value} {}
Bool::Bool(const Bool* src, IrCloner* ir_cloner)
: Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
bool Bool::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<Bool>()) {
return false;
}
const auto other_bool = other->as<Bool>();
if (isConst() && other_bool->isConst()) {
return *value() == *(other_bool->value());
}
return false;
}
Double::Double(IrBuilderPasskey passkey)
: Val(passkey, ValType::Scalar, DataType::Double),
maybe_value_{c10::nullopt} {}
Double::Double(IrBuilderPasskey passkey, ScalarType value)
: Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {}
Double::Double(IrBuilderPasskey passkey, c10::optional<ScalarType> value)
: Val(passkey, ValType::Scalar, DataType::Double), maybe_value_{value} {}
Double::Double(const Double* src, IrCloner* ir_cloner)
: Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
bool Double::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<Double>()) {
return false;
}
const auto other_double = other->as<Double>();
if (isConst() && other_double->isConst())
return *value() == *(other_double->value());
return false;
}
Int::Int(IrBuilderPasskey passkey)
: Val(passkey, ValType::Scalar, DataType::Int),
maybe_value_{c10::nullopt} {}
Int::Int(IrBuilderPasskey passkey, ScalarType value)
: Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{value} {}
Int::Int(IrBuilderPasskey passkey, c10::optional<ScalarType> value)
: Val(passkey, ValType::Scalar, DataType::Int), maybe_value_{value} {}
Int::Int(const Int* src, IrCloner* ir_cloner)
: Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
bool Int::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<Int>()) {
return false;
}
const auto other_int = other->as<Int>();
if (isConst() && other_int->isConst()) {
return *value() == *(other_int->value());
}
return false;
}
ComplexDouble::ComplexDouble(IrBuilderPasskey passkey)
: Val(passkey, ValType::Scalar, DataType::ComplexDouble),
maybe_value_{c10::nullopt} {}
ComplexDouble::ComplexDouble(IrBuilderPasskey passkey, ScalarType value)
: Val(passkey, ValType::Scalar, DataType::ComplexDouble),
maybe_value_{value} {}
ComplexDouble::ComplexDouble(
IrBuilderPasskey passkey,
c10::optional<ScalarType> value)
: Val(passkey, ValType::Scalar, DataType::ComplexDouble),
maybe_value_{value} {}
ComplexDouble::ComplexDouble(const ComplexDouble* src, IrCloner* ir_cloner)
: Val(src, ir_cloner), maybe_value_(src->maybe_value_) {}
bool ComplexDouble::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<ComplexDouble>()) {
return false;
}
const auto other_complex = other->as<ComplexDouble>();
if (isConst() && other_complex->isConst())
return *value() == *(other_complex->value());
return false;
}
UnaryOp::UnaryOp(IrBuilderPasskey passkey, UnaryOpType type, Val* out, Val* in)
: Expr(passkey, ExprType::UnaryOp),
unary_op_type_{type},
out_{out},
in_{in} {
addOutput(out);
addInput(in);
}
UnaryOp::UnaryOp(const UnaryOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
unary_op_type_(src->unary_op_type_),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)) {}
bool UnaryOp::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<UnaryOp>()) {
return false;
}
const auto other_op = other->as<UnaryOp>();
if (getUnaryOpType() != other_op->getUnaryOpType())
return false;
return Expr::sameAs(other);
}
BinaryOp::BinaryOp(
IrBuilderPasskey passkey,
BinaryOpType type,
Val* out,
Val* lhs,
Val* rhs)
: Expr(passkey, ExprType::BinaryOp),
binary_op_type_{type},
out_{out},
lhs_{lhs},
rhs_{rhs} {
addOutput(out);
addInput(lhs);
addInput(rhs);
}
BinaryOp::BinaryOp(const BinaryOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
binary_op_type_(src->binary_op_type_),
out_(ir_cloner->clone(src->out_)),
lhs_(ir_cloner->clone(src->lhs_)),
rhs_(ir_cloner->clone(src->rhs_)) {}
bool BinaryOp::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<BinaryOp>()) {
return false;
}
const auto other_op = other->as<BinaryOp>();
if (getBinaryOpType() != other_op->getBinaryOpType())
return false;
return Expr::sameAs(other);
}
TernaryOp::TernaryOp(
IrBuilderPasskey passkey,
TernaryOpType type,
Val* out,
Val* in1,
Val* in2,
Val* in3)
: Expr(passkey, ExprType::TernaryOp),
ternary_op_type_{type},
out_{out},
in1_{in1},
in2_{in2},
in3_{in3} {
addOutput(out);
addInput(in1);
addInput(in2);
addInput(in3);
}
TernaryOp::TernaryOp(const TernaryOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
ternary_op_type_(src->ternary_op_type_),
out_(ir_cloner->clone(src->out_)),
in1_(ir_cloner->clone(src->in1_)),
in2_(ir_cloner->clone(src->in2_)),
in3_(ir_cloner->clone(src->in3_)) {}
bool TernaryOp::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<TernaryOp>()) {
return false;
}
const auto other_op = other->as<TernaryOp>();
if (getTernaryOpType() != other_op->getTernaryOpType())
return false;
return Expr::sameAs(other);
}
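// Broadcast of "in" into "out". is_broadcast_dims has one flag per output root
// axis indicating whether that axis is a newly introduced broadcast dimension.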
BroadcastOp::BroadcastOp(
IrBuilderPasskey passkey,
Val* out,
Val* in,
std::vector<bool> is_broadcast_dims)
: Expr(passkey, ExprType::BroadcastOp),
out_(out),
in_(in),
is_broadcast_dims_(std::move(is_broadcast_dims)) {
// clang-tidy complains that out_ may be null.
TORCH_INTERNAL_ASSERT(out_ != nullptr);
TORCH_INTERNAL_ASSERT(in_ != nullptr);
auto out_type = out->getValType().value();
auto in_type = in->getValType().value();
TORCH_INTERNAL_ASSERT(
(out_type == ValType::TensorView && in_type == ValType::TensorView) ||
(out_type == ValType::TensorIndex && in_type == ValType::TensorIndex),
"Cannot braodcast a non-tensor object.");
addOutput(out);
addInput(in);
if (!out->isA<TensorView>() || !in->isA<TensorView>()) {
return;
}
passkey.ir_container_->registerExpr(exprPasskey(), this);
// This is a generic check that root dims of a consumer and producer match.
// Maybe we shouldn't relegate it to this constructor.
const auto c_tv = out_->as<TensorView>();
const auto p_tv = in_->as<TensorView>();
const auto& c_root = c_tv->getRootDomain();
const auto& p_root = p_tv->getMaybeRFactorDomain();
const auto root_p2c =
PairwiseRootDomainMap(p_tv, c_tv)
.mapProducerToConsumer(p_tv->domain(), c_tv->domain());
for (auto id : p_root) {
if (root_p2c.find(id) == root_p2c.end()) {
TORCH_INTERNAL_ASSERT(
id->isReduction() || id->isStride(),
"Invalid broadcast op: ",
id,
". Non-reduction input dim does't match to output.");
}
}
std::unordered_set<IterDomain*> c_mapped;
for (auto pair_entry : root_p2c) {
c_mapped.insert(pair_entry.second);
}
for (const auto i : c10::irange(c_root.size())) {
const auto c_id = c_root[i];
if (c_mapped.find(c_id) != c_mapped.end()) {
continue;
}
TORCH_INTERNAL_ASSERT(
c_id->isBroadcast() && is_broadcast_dims_[i],
"Invalid broadcast op: ",
c_id,
". Non-broadcasted output dim isn't matched from input.");
}
}
BroadcastOp::BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)),
is_broadcast_dims_(src->is_broadcast_dims_) {}
bool BroadcastOp::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<BroadcastOp>()) {
return false;
}
const auto other_op = other->as<BroadcastOp>();
if (getBroadcastDimFlags() != other_op->getBroadcastDimFlags()) {
return false;
}
return Expr::sameAs(other);
}
ReductionOp::ReductionOp(
IrBuilderPasskey passkey,
BinaryOpType reduction_op_type,
Val* init,
Val* out,
Val* in,
bool is_allreduce,
ExprType expr_type)
: Expr(passkey, expr_type),
reduction_op_type_(reduction_op_type),
init_(init),
out_(out),
in_(in),
is_allreduce_(is_allreduce) {
TORCH_CHECK(
out->getValType().value() == ValType::TensorView ||
out->getValType().value() == ValType::TensorIndex);
TORCH_INTERNAL_ASSERT(
(in->getValType() == ValType::TensorView &&
out->getValType() == ValType::TensorView) ||
(in->getValType() == ValType::TensorIndex &&
out->getValType() == ValType::TensorIndex),
"Reduction operation was created that does not have tensor inputs and outputs.");
if (in->isA<TensorView>()) {
TORCH_INTERNAL_ASSERT(
TensorDomain::noReductions(
in->as<TensorView>()->getMaybeRFactorDomain())
.size() == out->as<TensorView>()->getRootDomain().size(),
"Reduction operation created with mismatched domains.");
}
TORCH_INTERNAL_ASSERT(
init->isConstScalar(),
"Tried to create a reduction operation whith an initial value that isn't a constant.");
addOutput(out);
addInput(in);
}
GroupedReductionOp::GroupedReductionOp(
IrBuilderPasskey passkey,
std::vector<BinaryOpType> reduction_op_types,
std::vector<Val*> init_vals,
std::vector<Val*> outputs,
std::vector<Val*> inputs,
bool is_fused,
ExprType expr_type)
: Expr(passkey, expr_type),
reduction_op_types_(std::move(reduction_op_types)),
init_vals_(std::move(init_vals)),
is_allreduce_(is_fused) {
for (auto out : outputs) {
addOutput(out);
}
for (auto in : inputs) {
addInput(in);
}
}
GroupedReductionOp::GroupedReductionOp(
const GroupedReductionOp* src,
IrCloner* ir_cloner)
: Expr(src, ir_cloner),
reduction_op_types_(src->reduction_op_types_),
init_vals_(ir_cloner->clone(src->init_vals_)),
is_allreduce_(src->is_allreduce_) {}
bool GroupedReductionOp::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
auto grouped_rop = dynamic_cast<const GroupedReductionOp*>(other);
if (grouped_rop == nullptr) {
return false;
}
if (!Expr::sameAs(other) ||
getReductionOpTypes() != grouped_rop->getReductionOpTypes()) {
return false;
}
for (const auto i : c10::irange(numReductions())) {
if (!initVal(i)->sameAs(grouped_rop->initVal(i))) {
return false;
}
}
return true;
}
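// Welford reduction: produces avg/var/N outputs from avg/var/N inputs, seeded
// by the corresponding init values. Most of this constructor is validation of
// the value types of the outputs, inits, and inputs.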
WelfordOp::WelfordOp(
IrBuilderPasskey passkey,
Val* out_avg,
Val* out_var,
Val* out_N,
Val* init_avg,
Val* init_var,
Val* init_N,
Val* in_avg,
Val* in_var,
Val* in_N,
bool is_fused)
: Expr(passkey, ExprType::WelfordOp),
out_avg_(out_avg),
out_var_(out_var),
out_N_(out_N),
init_avg_(init_avg),
init_var_(init_var),
init_N_(init_N),
in_avg_(in_avg),
in_var_(in_var == nullptr ? in_avg->container()->zeroVal() : in_var),
in_N_(in_N),
is_allreduce_(is_fused) {
// Check output type
TORCH_INTERNAL_ASSERT(
out_avg->getValType().value() == ValType::TensorView ||
out_avg->getValType().value() == ValType::TensorIndex);
TORCH_INTERNAL_ASSERT(
out_var->getValType().value() == ValType::TensorView ||
out_var->getValType().value() == ValType::TensorIndex);
TORCH_INTERNAL_ASSERT(
out_N->getValType().value() == ValType::TensorView ||
out_N->getValType().value() == ValType::TensorIndex);
// check initial value
TORCH_INTERNAL_ASSERT(init_N->getValType().value() == ValType::Scalar);
if (!init_N->isZeroInt()) {
// When the initial count is zero, no initial variance or average is needed.
// An initial value with a count of 1 is uncommon enough that we push the
// responsibility of creating all-zero var tensors to the user.
TORCH_INTERNAL_ASSERT(
init_avg &&
(init_avg->getValType().value() == ValType::TensorView ||
init_avg->getValType().value() == ValType::TensorIndex));
TORCH_INTERNAL_ASSERT(
init_var &&
(init_var->getValType().value() == ValType::TensorView ||
init_var->getValType().value() == ValType::TensorIndex));
}
TORCH_INTERNAL_ASSERT(
in_avg &&
(in_avg->getValType().value() == ValType::TensorView ||
in_avg->getValType().value() == ValType::TensorIndex),
in_avg->getValType().value());
// check input
TORCH_INTERNAL_ASSERT(
in_N->getValType().value() == ValType::Scalar ||
in_N->getValType().value() == ValType::TensorView ||
in_N->getValType().value() == ValType::TensorIndex);
if (!in_N->isOneInt()) {
// When the input is a single value, only the avg input is required;
// the var part is implicitly 0 and codegen will handle that.
TORCH_INTERNAL_ASSERT(
in_var &&
(in_var->getValType().value() == ValType::TensorView ||
in_var->getValType().value() == ValType::TensorIndex));
} else {
TORCH_INTERNAL_ASSERT(
in_var == nullptr || in_var->isZeroInt(),
"Invalid var input, which must be either nullptr or scalar zero when the N input is one.");
}
addOutput(out_avg_);
addOutput(out_var_);
addOutput(out_N_);
addInput(in_avg_);
// Previously in_var_ was allowed to be null
TORCH_INTERNAL_ASSERT(
in_var_ != nullptr, "Welford var input nullptr not allowed");
addInput(in_var_);
addInput(in_N_);
}
WelfordOp::WelfordOp(const WelfordOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
out_avg_(ir_cloner->clone(src->out_avg_)),
out_var_(ir_cloner->clone(src->out_var_)),
out_N_(ir_cloner->clone(src->out_N_)),
init_avg_(src->init_avg_ ? ir_cloner->clone(src->init_avg_) : nullptr),
init_var_(src->init_var_ ? ir_cloner->clone(src->init_var_) : nullptr),
init_N_(ir_cloner->clone(src->init_N_)),
in_avg_(ir_cloner->clone(src->in_avg_)),
in_var_(src->in_var_ ? ir_cloner->clone(src->in_var_) : nullptr),
in_N_(ir_cloner->clone(src->in_N_)),
is_allreduce_(src->is_allreduce_) {}
namespace {
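// Two possibly-null Vals are considered the same when both are null, or both
// are non-null and compare equal via sameAs().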
inline bool sameOptionalVal(Val* a, Val* b) {
return ((a == nullptr && b == nullptr)) || ((a && b) && (a->sameAs(b)));
}
} // namespace
bool WelfordOp::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (auto other_wop = dynamic_cast<const WelfordOp*>(other)) {
return in_avg_->sameAs(other_wop->in_avg_) &&
sameOptionalVal(in_var_, other_wop->in_var_) &&
in_N_->sameAs(other_wop->in_N_) &&
sameOptionalVal(init_avg_, other_wop->init_avg_) &&
sameOptionalVal(init_var_, other_wop->init_var_) &&
init_N_->sameAs(other_wop->init_N_);
}
return false;
}
std::vector<Val*> WelfordOp::getInitVals() const {
std::vector<Val*> init_vals({init_avg_, init_var_, init_N_});
return init_vals;
}
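// Matrix-multiply-accumulate: out = mma(in_a, in_b), accumulated on top of
// init. Note that init is stored as a member but not registered as an input.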
MmaOp::MmaOp(
IrBuilderPasskey passkey,
Val* out,
Val* in_a,
Val* in_b,
Val* init)
: Expr(passkey, ExprType::MmaOp),
out_(out),
in_a_(in_a),
in_b_(in_b),
init_(init) {
// Check output type
TORCH_INTERNAL_ASSERT(
out->getValType().value() == ValType::TensorView ||
out->getValType().value() == ValType::TensorIndex);
TORCH_INTERNAL_ASSERT(
in_a->getValType().value() == ValType::TensorView ||
in_a->getValType().value() == ValType::TensorIndex,
in_a->getValType().value());
TORCH_INTERNAL_ASSERT(
in_b->getValType().value() == ValType::TensorView ||
in_b->getValType().value() == ValType::TensorIndex,
in_b->getValType().value());
addOutput(out);
addInput(in_a);
addInput(in_b);
}
MmaOp::MmaOp(
IrBuilderPasskey passkey,
Val* out,
Val* in_a,
Val* in_b,
Val* init,
MmaOptions options)
: MmaOp(passkey, out, in_a, in_b, init) {
options_ = options;
}
MmaOp::MmaOp(const MmaOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
out_(ir_cloner->clone(src->out_)),
in_a_(ir_cloner->clone(src->in_a_)),
in_b_(ir_cloner->clone(src->in_b_)),
init_(ir_cloner->clone(src->init_)),
options_(src->options_) {}
bool MmaOp::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (auto other_mma = dynamic_cast<const MmaOp*>(other)) {
return out_->sameAs(other_mma->out_) && in_a_->sameAs(other_mma->in_a_) &&
in_b_->sameAs(other_mma->in_b_) && init_->sameAs(other_mma->init_) &&
options_ == other_mma->options_;
}
return false;
}
ReductionOp::ReductionOp(const ReductionOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
reduction_op_type_(src->reduction_op_type_),
init_(ir_cloner->clone(src->init_)),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)),
is_allreduce_(src->is_allreduce_) {}
bool ReductionOp::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<ReductionOp>()) {
return false;
}
const auto other_op = other->as<ReductionOp>();
// Note that init is not part of input vals, so it must be checked separately.
return (
Expr::sameAs(other) &&
getReductionOpType() == other_op->getReductionOpType() &&
init()->sameAs(other_op->init()));
}
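// Transpose of a TensorView according to new2old, where new2old[i] is the
// position in the input root domain that becomes output position i.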
TransposeOp::TransposeOp(
IrBuilderPasskey passkey,
TensorView* out,
TensorView* in,
std::vector<int> new2old)
: Expr(passkey, ExprType::TransposeOp),
out_(out),
in_(in),
new2old_(std::move(new2old)) {
// Sanity check of the input parameters. Maybe not necessary as they
// should already be checked in the transpose() function.
TORCH_INTERNAL_ASSERT(
!in->hasRFactor(), "Transposing rFactor tensors is not supported.");
TORCH_INTERNAL_ASSERT(
TensorDomain::noReductions(in->getRootDomain()).size() ==
out->getRootDomain().size());
TORCH_INTERNAL_ASSERT(new2old_.size() == out->getRootDomain().size());
// Make sure the entries of new2old are unique and range from 0 to
// N-1, where N == new2old.size().
std::set<int> old_positions(new2old_.begin(), new2old_.end());
TORCH_INTERNAL_ASSERT(old_positions.size() == new2old_.size());
// old_positions is sorted, so the first entry must be 0.
TORCH_INTERNAL_ASSERT(
*(old_positions.begin()) == 0,
"Invalid new2old vector detected: ",
new2old_);
// The last entry must be N-1, since old_positions is sorted, starts
// with 0, and its length is N.
TORCH_INTERNAL_ASSERT(
*(old_positions.rbegin()) == (int)(new2old_.size() - 1),
"Invalid new2old vector detected: ",
new2old_);
addOutput(out);
addInput(in);
}
TransposeOp::TransposeOp(const TransposeOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)),
new2old_(src->new2old_) {}
ShiftOp::ShiftOp(
IrBuilderPasskey passkey,
Val* out,
Val* in,
std::vector<int> offsets,
std::vector<int> pad_width)
: Expr(passkey, ExprType::ShiftOp),
out_(out),
in_(in),
offsets_(std::move(offsets)),
pad_width_(std::move(pad_width)) {
// clang-tidy complains that out_ may be null.
TORCH_INTERNAL_ASSERT(out_ != nullptr);
TORCH_INTERNAL_ASSERT(in_ != nullptr);
auto out_type = out->getValType().value();
auto in_type = in->getValType().value();
TORCH_INTERNAL_ASSERT(
out_type == ValType::TensorView && in_type == ValType::TensorView,
"Cannot shift a non-tensor object.");
TORCH_INTERNAL_ASSERT(
offsets_.size() ==
TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain())
.size(),
"Invalid offset vector: ",
offsets_);
TORCH_INTERNAL_ASSERT(
pad_width_.size() ==
TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain())
.size(),
"Invalid padding width vector: ",
pad_width_);
addOutput(out);
addInput(in);
}
ShiftOp::ShiftOp(const ShiftOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)),
offsets_(src->offsets_),
pad_width_(src->pad_width_) {}
bool ShiftOp::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<ShiftOp>()) {
return false;
}
const auto other_op = other->as<ShiftOp>();
if (offsets() != other_op->offsets()) {
return false;
}
return Expr::sameAs(other);
}
GatherOp::GatherOp(
IrBuilderPasskey passkey,
Val* out,
Val* in,
std::vector<int> window_shape,
std::vector<std::vector<int>> pad_width)
: Expr(passkey, ExprType::GatherOp),
out_(out),
in_(in),
window_shape_(std::move(window_shape)),
pad_width_(std::move(pad_width)) {
// clang-tidy complains that out_ may be null.
TORCH_INTERNAL_ASSERT(out_ != nullptr);
TORCH_INTERNAL_ASSERT(in_ != nullptr);
auto out_type = out->getValType().value();
auto in_type = in->getValType().value();
TORCH_INTERNAL_ASSERT(
out_type == ValType::TensorView && in_type == ValType::TensorView,
"Cannot shift a non-tensor object.");
const auto ndims =
TensorDomain::noReductions(in_->as<TensorView>()->getRootDomain()).size();
TORCH_INTERNAL_ASSERT(
window_shape_.size() == ndims,
"Invalid window_shape vector: ",
window_shape_);
TORCH_INTERNAL_ASSERT(
pad_width_.size() == ndims, "Invalid pad_width vector: ", pad_width_);
for (const auto& pad : pad_width_) {
TORCH_INTERNAL_ASSERT(
pad.size() == 2, "Padding size for each axis must have two Int vals.");
}
addOutput(out);
addInput(in);
}
GatherOp::GatherOp(const GatherOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)),
window_shape_(src->window_shape_),
pad_width_(src->pad_width_) {}
bool GatherOp::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<GatherOp>()) {
return false;
}
const auto other_op = other->as<GatherOp>();
if (windowShape() != other_op->windowShape() ||
padWidth() != other_op->padWidth()) {
return false;
}
return Expr::sameAs(other);
}
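// Returns the position, within the output domain, of the gather window axis
// that corresponds to the given input axis. Window axes are appended after
// the input axes, hence the offset by windowShape().size().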
int GatherOp::gatherAxis(int axis) const {
if (axis < 0) {
axis += out()->as<TensorView>()->nDims();
}
TORCH_INTERNAL_ASSERT(
axis >= 0 && axis < (int)windowShape().size(), "Invalid axis: ", axis);
return int(windowShape().size()) + axis;
}
ViewAsScalar::ViewAsScalar(
IrBuilderPasskey passkey,
Val* out,
Val* in,
IterDomain* vector_id,
Val* index)
: Expr(passkey, ExprType::ViewAsScalar),
out_(out),
in_(in),
vector_id_(vector_id),
index_(index) {
addOutput(out);
addInput(in);
}
ViewAsScalar::ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)),
vector_id_(ir_cloner->clone(src->vector_id_)),
index_(ir_cloner->clone(src->index_)) {}
ViewOp::ViewOp(IrBuilderPasskey passkey, TensorView* out, TensorView* in)
: Expr(passkey, ExprType::ViewOp), out_(out), in_(in) {
addOutput(out);
addInput(in);
}
ViewOp::ViewOp(const ViewOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)) {}
LoadStoreOp::LoadStoreOp(
IrBuilderPasskey passkey,
LoadStoreOpType op_type,
Val* out,
Val* in)
: Expr(passkey, ExprType::LoadStoreOp),
load_store_type_(op_type),
out_(out),
in_(in) {
addOutput(out);
addInput(in);
}
LoadStoreOp::LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
load_store_type_(src->load_store_type_),
out_(ir_cloner->clone(src->out_)),
in_(ir_cloner->clone(src->in_)) {}
IterDomain::IterDomain(
IrBuilderPasskey passkey,
Val* start,
Val* extent,
ParallelType parallel_type,
IterType iter_type,
bool is_rfactor_domain,
bool is_padded_dimension,
c10::optional<int64_t> padded_to_size,
bool is_mma_swizzled)
: IterDomain(
passkey,
start,
extent,
nullptr,
parallel_type,
iter_type,
is_rfactor_domain,
is_padded_dimension,
padded_to_size,
is_mma_swizzled) {}
IterDomain::IterDomain(
IrBuilderPasskey passkey,
Val* start,
Val* extent,
Val* stop_offset,
ParallelType parallel_type,
IterType iter_type,
bool is_rfactor_domain,
bool is_padded_dimension,
c10::optional<int64_t> padded_to_size,
bool is_mma_swizzled)
: Val(passkey, ValType::IterDomain, DataType::Int),
start_(start),
extent_(extent),
stop_offset_(
stop_offset == nullptr ? passkey.ir_container_->zeroVal()
: stop_offset),
parallel_type_(parallel_type),
iter_type_(iter_type),
is_rfactor_domain_(is_rfactor_domain),
is_padded_dimension_(is_padded_dimension),
padded_to_size_(padded_to_size),
is_mma_swizzled_(is_mma_swizzled) {
TORCH_CHECK(
!(isRFactorProduct() && isBroadcast()),
"IterDomain cannot be both a broadcast and rfactor domain.");
TORCH_INTERNAL_ASSERT(
extent->isAnInt(),
"Cannot create an iter domain over an extent that is not an int but received ",
extent,
" .");
TORCH_INTERNAL_ASSERT(
start->isAnInt(),
"Cannot create an iter domain with a start that is not an int but received ",
start,
" .");
}
IterDomain::IterDomain(const IterDomain* src, IrCloner* ir_cloner)
: Val(src, ir_cloner),
start_(ir_cloner->clone(src->start_)),
extent_(ir_cloner->clone(src->extent_)),
stop_offset_(ir_cloner->clone(src->stop_offset_)),
parallel_type_(src->parallel_type_),
iter_type_(src->iter_type_),
is_rfactor_domain_(src->is_rfactor_domain_),
is_padded_dimension_(src->is_padded_dimension_),
padded_to_size_(src->padded_to_size_),
is_mma_swizzled_(src->is_mma_swizzled_) {}
bool IterDomain::sameAs(const Statement* other) const {
if (other == this) {
return true;
}
if (!other->isA<IterDomain>()) {
return false;
}
const IterDomain* other_id = other->as<IterDomain>();
bool is_same = isReduction() == other_id->isReduction() &&
getParallelType() == other_id->getParallelType() &&
isVectorComponent() == other_id->isVectorComponent();
is_same = is_same && ScalarCheck::sameAs(extent(), other_id->extent());
is_same = is_same && ScalarCheck::sameAs(start(), other_id->start());
is_same =
is_same && ScalarCheck::sameAs(stopOffset(), other_id->stopOffset());
return is_same;
}
// Returns a new IterDomain matching the properties of this one, except that
// is_rfactor_domain_ is set to false.
IterDomain* IterDomain::cloneWithoutRFactor() const {
auto cloned = IrBuilder::create<IterDomain>(
ir_container_,
start(),
extent(),
stopOffset(),
getParallelType(),
getIterType(),
false,
is_padded_dimension_,
padded_to_size_,
is_mma_swizzled_);
return cloned;
}
std::vector<IterDomain*> IterDomain::clone(
const std::vector<IterDomain*>& domains) {
std::vector<IterDomain*> cloned_domains;
std::transform(
domains.begin(),
domains.end(),
std::back_inserter(cloned_domains),
[](auto id) { return id->cloneWithoutRFactor(); });
return cloned_domains;
}
// Merging does not propagate the start and stop values of the input
// domains to the merged output domain. The actual range of the
// domains is enforced by predicates. Note that since only root
// domains have valid start and stop, it's not possible to use contiguous
// predication.
IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
TORCH_CHECK(
!outer->extent()->isZeroInt() && !inner->extent()->isZeroInt(),
"Merging IterDomains with ending values that are 0 is not supported at this time.");
TORCH_CHECK(
outer->isReduction() == inner->isReduction() ||
(!outer->isReduction() && inner->extent()->isOneInt()) ||
(outer->extent()->isOneInt() && !inner->isReduction()),
"Merging IterDomains requires that their iteration types match.");
TORCH_CHECK(
(outer->isGather() && inner->isGather()) ||
(!outer->isGather() && !inner->isGather()),
"Merging gather and non-gather domains is not supported.");
Val* merged_id_size = mul(outer->extent(), inner->extent());
IterType itype = outer->getIterType();
if (outer->isBroadcast() && inner->isBroadcast()) {
if (outer->getIterType() == IterType::BroadcastWithStride ||
inner->getIterType() == IterType::BroadcastWithStride) {
itype = IterType::BroadcastWithStride;
} else {
itype = IterType::BroadcastWithoutStride;
}
} else if (outer->isBroadcast() || inner->isBroadcast()) {
itype = IterType::Iteration;
}
// Merging a trivial reduction with an iter domain is fine; just make the
// result an iter domain.
if ((outer->isReduction() || inner->isReduction()) &&
(!outer->isReduction() || !inner->isReduction())) {
itype = IterType::Iteration;
}
IterDomain* merged_id = IrBuilder::create<IterDomain>(
outer->container(),
outer->container()->zeroVal(),
merged_id_size->as<Int>(),
outer->getParallelType(),
itype);
IrBuilder::create<Merge>(outer->container(), merged_id, outer, inner);
return merged_id;
}
// Neither the outer nor the inner domain inherits the start and stop
// values, as those cannot be split across the two domains. The access
// range is enforced by predicates.
std::pair<IterDomain*, IterDomain*> IterDomain::split(
IterDomain* in,
Val* factor,
bool inner_split,
Val* start_offset,
Val* stop_offset) {
TORCH_CHECK(
!in->extent()->isZeroInt(),
"Splitting IterDomains with ending values that are 0 is not supported at this time.");
TORCH_CHECK(factor->isAnInt(), "Cannot split by non-integer value ", factor);
if (factor->getValType() == ValType::Scalar) {
TORCH_CHECK(
factor->isConstScalar() ||
(FusionGuard::getCurFusion() == factor->fusion() &&
factor->isFusionInput()),
factor,
" is not a constant nor an input. It must be one or the other to be used in a split.",
" If you want a symbolic split based on a thread dimension please use IterDomain::split(IterDomain*, ParallelType);");
} else if (factor->getValType() == ValType::NamedScalar) {
TORCH_CHECK(
factor->as<NamedScalar>()->getParallelDim() != c10::nullopt,
"Splitting a dimension by a named scalar is only supported on block or grid dimensions but received ",
factor);
}
// outer loop size
Val* remainder =
ceilDiv(Split::extent(in->extent(), start_offset, stop_offset), factor);
if ((start_offset != nullptr && !start_offset->isZeroInt()) ||
(stop_offset != nullptr && !stop_offset->isZeroInt())) {
TORCH_INTERNAL_ASSERT(
in->definition() == nullptr,
"Partial split is only allowed with root domains");
}
// outer loop IterDomain
IterDomain* ido = IrBuilder::create<IterDomain>(
in->container(),
in->container()->zeroVal(),
inner_split ? remainder->as<Int>() : factor,
in->getParallelType(),
in->getIterType());
// inner loop IterDomain
IterDomain* idi = IrBuilder::create<IterDomain>(
in->container(),
in->container()->zeroVal(),
inner_split ? factor : remainder->as<Int>(),
in->getParallelType(),
in->getIterType());
IrBuilder::create<Split>(
in->container(),
ido,
idi,
in,
factor,
inner_split,
start_offset,
stop_offset);
return {ido, idi};
}
std::pair<IterDomain*, IterDomain*> IterDomain::split(
IterDomain* in,
Val* factor,
bool inner_split,
bool trim_out_of_bounds) {
auto start_offset = trim_out_of_bounds ? in->start() : nullptr;
auto stop_offset = trim_out_of_bounds ? in->stopOffset() : nullptr;
return IterDomain::split(in, factor, inner_split, start_offset, stop_offset);
}
std::pair<IterDomain*, IterDomain*> IterDomain::stridedSplit(int factor) {
// Use partial split so that only valid values are retained
auto split_out = IterDomain::split(
this, IrBuilder::create<Int>(container(), factor), true, true);
split_out.second->iter_type_ = IterType::Stride;
split_out.first->is_rfactor_domain_ = true;
split_out.second->is_rfactor_domain_ = true;
return split_out;
}
// TODO: We should change the parallelize interface to be on TensorView, or at
// least vectorize should be done on TensorView. This would let us check that
// we don't vectorize to the left of the computeAt domain, and could allow some
// simple validation of vectorize, as its inputs are rightmost and contiguous.
void IterDomain::parallelize(ParallelType t) {
if (parallel_type_ == t) {
// No op, don't do any more checks, it was already set to this value.
return;
}
if (t == ParallelType::Unroll || isParallelTypeVectorize(t)) {
TORCH_CHECK(
start()->isZeroInt() && extent()->isConstScalar(),
"Vectorization, unrolling, and unswitching are only supported with start = 0 and extent as a const int, but got ",
"a start of ",
start(),
" and extent ",
extent(),
" .");
}
if (isMmaSwizzled()) {
// Mma swizzled axes represent data representation within a warp
// so only allow updates that keep the parallelization within
// a warp.
// Note && TODO: this check is actually used to allow indexing path
// to make copies of the iterdomains. We might eventually just want
// to lock these parallel types and not allowing any changes once
// they are swizzled.
TORCH_CHECK(
t == ParallelType::Vectorize || t == ParallelType::TIDx ||
t == ParallelType::Serial,
"Parallel type other than serial, tidx, vectorize not allowed for mma swizzled ids");
}
parallel_type_ = t;
}
bool IterDomain::maybePartial() const {
return !start()->isZeroInt() || !stopOffset()->isZeroInt();
}
Val* IterDomain::stopOffset() const {
return stop_offset_;
}
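// Actual stop position of the domain: the full extent unless a stop offset is
// set, in which case it is extent() - stopOffset().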
Val* IterDomain::stop() const {
if (stopOffset()->isZeroInt()) {
return extent();
}
return sub(extent(), stopOffset());
}
TensorDomain::TensorDomain(
IrBuilderPasskey passkey,
std::vector<IterDomain*> root_domain,
std::vector<bool> contiguity)
: Val(passkey, ValType::TensorDomain, DataType::Null),
root_domain_(std::move(root_domain)),
contiguity_(
contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
: std::move(contiguity)) {
TORCH_CHECK(
contiguity_.size() == getMaybeRFactorDomain().size(),
"Invalid contiguity information provided, incorrect size. Recieved vector of size ",
contiguity_.size(),
" but needed one of size ",
root_domain_.size());
// Just due to clang-tidy, correct value set in resetDomains
has_nontrivial_reduction_ = false;
domain_ = root_domain_;
resetDomains();
}
TensorDomain::TensorDomain(
IrBuilderPasskey passkey,
std::vector<IterDomain*> root_domain,
std::vector<IterDomain*> domain,
std::vector<bool> contiguity)
: Val(passkey, ValType::TensorDomain, DataType::Null),
root_domain_(std::move(root_domain)),
domain_(std::move(domain)),
contiguity_(
contiguity.empty() ? std::vector<bool>(root_domain_.size(), false)
: std::move(contiguity)) {
TORCH_CHECK(
contiguity_.size() == getMaybeRFactorDomain().size(),
"Invalid contiguity information provided, incorrect size. Recieved vector of size ",
contiguity_.size(),
" but needed one of size ",
root_domain_.size());
std::vector<Val*> domain_vals(domain_.begin(), domain_.end());
auto inps = IterVisitor::getInputsTo(domain_vals);
// Validate that the root domain consists of all inputs to domain
// Uncertain if this will hold for RFactor
std::unordered_set<Val*> root_vals(root_domain_.begin(), root_domain_.end());
std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) {
TORCH_INTERNAL_ASSERT(
root_vals.find(inp) != root_vals.end(),
"Invalid tensor domain, ",
inp,
" is an input of domain, but it is not found in the root domain.");
});
// Just due to clang-tidy, correct value set in resetDomains
has_nontrivial_reduction_ = false;
resetDomains();
}
TensorDomain::TensorDomain(
IrBuilderPasskey passkey,
std::vector<IterDomain*> root_domain,
std::vector<IterDomain*> rfactor_domain,
std::vector<IterDomain*> domain,
std::vector<bool> contiguity)
: Val(passkey, ValType::TensorDomain, DataType::Null),
root_domain_(std::move(root_domain)),
domain_(std::move(domain)),
rfactor_domain_(std::move(rfactor_domain)),
contiguity_(
contiguity.empty() ? std::vector<bool>(rfactor_domain_.size(), false)
: std::move(contiguity)) {
TORCH_CHECK(
contiguity_.size() == getMaybeRFactorDomain().size(),
"Invalid contiguity information provided, incorrect size. Recieved vector of size ",
contiguity_.size(),
" but needed one of size ",
getMaybeRFactorDomain().size());
auto inps = IterVisitor::getInputsTo(
std::vector<Val*>(domain_.begin(), domain_.end()));
// Validate that the root domain consists of all inputs to domain
// Uncertain if this will hold for RFactor
std::unordered_set<Val*> root_vals(root_domain_.begin(), root_domain_.end());
std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) {
TORCH_INTERNAL_ASSERT(
root_vals.find(inp) != root_vals.end(),
"Invalid tensor domain, ",
inp,
" is an input of domain, but it is not found in the root domain.");
});
inps = IterVisitor::getInputsTo(
std::vector<Val*>(rfactor_domain_.begin(), rfactor_domain_.end()));
std::for_each(inps.begin(), inps.end(), [root_vals](Val* inp) {
TORCH_INTERNAL_ASSERT(
root_vals.find(inp) != root_vals.end(),
"Invalid tensor domain, ",
inp,
" is an input of the rfactor domain, but it is not found in the root domain.");
});
// Just due to clang-tidy, correct value set in resetDomains
has_nontrivial_reduction_ = false;
resetDomains();
}
TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner)
: Val(src, ir_cloner),
root_domain_(ir_cloner->clone(src->root_domain_)),
domain_(ir_cloner->clone(src->domain_)),
no_bcast_domain_(ir_cloner->clone(src->no_bcast_domain_)),
no_reduction_domain_(ir_cloner->clone(src->no_reduction_domain_)),
rfactor_domain_(ir_cloner->clone(src->rfactor_domain_)),
contiguity_(src->contiguity()),
has_nontrivial_reduction_(src->has_nontrivial_reduction_) {}
bool TensorDomain::hasBlockBroadcast() const {
return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
return id->isBroadcast() && id->isThreadDim();
});
}
bool TensorDomain::hasGridBroadcast() const {
return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
return id->isBroadcast() && id->isBlockDim();
});
}
bool TensorDomain::operator==(const TensorDomain& other) const {
// Checks equality of each class field. Should not be necessary to
// check no_bcast_domain_ and no_reduction_domain_ as they are just
// derived from domain_.
return root_domain_ == other.root_domain_ && domain_ == other.domain_ &&
rfactor_domain_ == other.rfactor_domain_ &&
contiguity_ == other.contiguity_;
}
bool TensorDomain::sameAs(const Statement* const other) const {
if (this == other) {
return true;
}
if (!other->isA<TensorDomain>()) {
return false;
}
const TensorDomain* other_td = other->as<TensorDomain>();
if (nDims() != other_td->nDims()) {
return false;
}
if (getRootDomain().size() != other_td->getRootDomain().size()) {
return false;
}
if (getRFactorDomain().size() != other_td->getRFactorDomain().size()) {
return false;
}
for (const auto i : c10::irange(nDims())) {
if (!(axis(i)->sameAs(other_td->axis(i)))) {
return false;
}
}
for (const auto i : c10::irange(getRootDomain().size())) {
if (!(getRootDomain()[i]->sameAs(other_td->getRootDomain()[i]))) {
return false;
}
}
for (const auto i : c10::irange(getRFactorDomain().size())) {
if (!(getRFactorDomain()[i]->sameAs(other_td->getRFactorDomain()[i]))) {
return false;
}
}
return true;
}
bool TensorDomain::sameAs(
const std::vector<IterDomain*>& lhs,
const std::vector<IterDomain*>& rhs) {
if (lhs.size() != rhs.size())
return false;
size_t i = 0;
for (auto td_lhs : lhs) {
if (!td_lhs->sameAs(rhs[i++]))
return false;
}
return true;
}
void TensorDomain::setContiguity(const std::vector<bool>& contig) {
TORCH_INTERNAL_ASSERT(
getMaybeRFactorDomain().size() == contig.size(),
"Invalid contiguity vector: ",
contig);
contiguity_ = contig;
}
bool TensorDomain::hasReduction() const {
return has_nontrivial_reduction_;
}
bool TensorDomain::hasBlockReduction() const {
return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
return id->isReduction() && id->isThreadDim();
});
}
bool TensorDomain::hasGridReduction() const {
return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
return id->isReduction() && id->isBlockDim();
});
}
bool TensorDomain::hasBroadcast() const {
return no_bcast_domain_.size() != domain_.size();
}
bool TensorDomain::hasRFactor() const {
return !rfactor_domain_.empty();
}
bool TensorDomain::hasViewLikeRFactor() const {
if (!hasRFactor()) {
// Can't have view like rfactor if there is no rfactor domain
return false;
}
// If there's an rfactor domain and no rfactor product is a reduction, this is
// a view-like rfactor.
return std::none_of(
getMaybeRFactorDomain().begin(),
getMaybeRFactorDomain().end(),
[](IterDomain* id) {
return id->isReduction() && id->isRFactorProduct();
});
}
bool TensorDomain::hasVectorize() const {
return std::any_of(domain_.begin(), domain_.end(), [](IterDomain* id) {
return id->getParallelType() == ParallelType::Vectorize ||
id->getParallelType() == ParallelType::MisalignedVectorize;
});
}
c10::optional<unsigned int> TensorDomain::getReductionAxis() const {
auto it = std::find_if(domain_.begin(), domain_.end(), [](const auto& id) {
return id->isReduction();
});
if (it == domain_.end()) {
return c10::optional<unsigned int>();
} else {
return c10::optional<unsigned int>(std::distance(domain_.begin(), it));
}
}
// i here is an int, as we want to accept negative values and ::size_type can
// be a uint.
IterDomain* TensorDomain::axis(int i) const {
TORCH_INTERNAL_ASSERT(
nDims() > 0, "Tried to access an axis in a 0-dim domain");
if (i < 0)
i += nDims();
TORCH_CHECK(
i >= 0 && (unsigned int)i < nDims(),
"Tried to access axis ",
i,
" in domain ",
this);
return domain_[i];
}
size_t TensorDomain::posOf(IterDomain* id) const {
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to find an axis in a 0-dim domain");
size_t i = 0;
while (i < domain_.size()) {
if (domain_[i] == id)
return i;
i++;
}
TORCH_CHECK(false, "Provided id is not part of this domain.");
}
size_t TensorDomain::rootPosOf(IterDomain* id) const {
TORCH_INTERNAL_ASSERT(
root_domain_.size() > 0, "Tried to find an axis in a 0-dim root domain");
auto it = std::find(root_domain_.begin(), root_domain_.end(), id);
TORCH_INTERNAL_ASSERT(
it != root_domain_.end(), "Provided id is not part of root domain.");
return std::distance(root_domain_.begin(), it);
}
void TensorDomain::split(
int axis_,
Val* factor,
bool inner_split,
bool trim_out_of_bounds) {
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do split on a 0-dim domain");
if (axis_ < 0)
axis_ += nDims();
TORCH_INTERNAL_ASSERT(
axis_ >= 0 && (unsigned int)axis_ < nDims(),
"Tried to split on axis outside TensorDomain's range.");
IterDomain* id = axis(axis_);
// partial split is only allowed with root domains
if (trim_out_of_bounds) {
TORCH_INTERNAL_ASSERT(
std::find(getRootDomain().begin(), getRootDomain().end(), id) !=
getRootDomain().end(),
"Partial split is only allowed with root domains");
}
TORCH_INTERNAL_ASSERT(
!id->isMmaSwizzled(),
"Further transformation on warp mapped id's not allowed.");
auto split_ids =
IterDomain::split(id, factor, inner_split, trim_out_of_bounds);
domain_.erase(domain_.begin() + axis_);
domain_.insert(domain_.begin() + axis_, split_ids.second);
domain_.insert(domain_.begin() + axis_, split_ids.first);
resetDomains();
}
// Merge "axis" and "axis+1" into 1 dimension
void TensorDomain::merge(int axis_o, int axis_i) {
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain");
if (axis_o < 0)
axis_o += nDims();
if (axis_i < 0)
axis_i += nDims();
TORCH_CHECK(
axis_o >= 0 && (unsigned int)axis_o < nDims() && axis_i >= 0 &&
(unsigned int)axis_i < nDims(),
"Invalid merge detected, either one or both axes are outside of TensorView's range.");
TORCH_CHECK(
axis_o != axis_i,
"Invalid merge detected, axes provided are the same axis.");
if (axis_o > axis_i) {
auto tmp = axis_i;
axis_i = axis_o;
axis_o = tmp;
}
IterDomain* first = axis(axis_o);
IterDomain* second = axis(axis_i);
TORCH_INTERNAL_ASSERT(
!first->isMmaSwizzled() && !second->isMmaSwizzled(),
"Further transformation on warp mapped id's not allowed.");
IterDomain* merged_id = IterDomain::merge(first, second);
domain_.erase(domain_.begin() + axis_i);
domain_.erase(domain_.begin() + axis_o);
domain_.insert(domain_.begin() + axis_o, merged_id);
resetDomains();
}
// Reorder axes according to map[old_pos] = new_pos
void TensorDomain::reorder(const std::unordered_map<int, int>& old2new_) {
TORCH_INTERNAL_ASSERT(
!(nDims() == 0 && old2new_.size() > 0),
"Tried to reorder a 0-dim domain");
domain_ = orderedAs(domain_, old2new_);
resetDomains();
}
std::vector<IterDomain*> TensorDomain::orderedAs(
const std::vector<IterDomain*>& dom,
const std::unordered_map<int, int>& old2new_) {
TORCH_INTERNAL_ASSERT(
!(dom.size() == 0 && old2new_.size() > 0),
"Tried to reorder a 0-dim domain");
// Even though these checks are already in TensorView, we want to redo them as
// we can enter this function from other places, not through TensorView.
auto new2old = ir_utils::normalizeOld2New(old2new_, dom.size());
std::vector<IterDomain*> reordered_domain;
std::transform(
new2old.begin(),
new2old.end(),
std::back_inserter(reordered_domain),
[dom](int i) -> IterDomain* { return dom[i]; });
return reordered_domain;
}
std::vector<IterDomain*> TensorDomain::noReductions(
const std::vector<IterDomain*>& td) {
size_t size_out = 0;
for (auto id : td) {
if (!id->isReduction() && !id->isStride()) {
size_out++;
}
}
std::vector<IterDomain*> noReductionDomain(size_out);
int it = 0;
for (auto id : td) {
if (!id->isReduction() && !id->isStride()) {
noReductionDomain[it++] = id;
}
}
return noReductionDomain;
}
std::vector<IterDomain*> TensorDomain::noBroadcasts(
const std::vector<IterDomain*>& td) {
size_t size_out = 0;
for (auto id : td)
if (!id->isBroadcast())
size_out++;
std::vector<IterDomain*> noBroadcastDomain(size_out);
int it = 0;
for (auto id : td)
if (!id->isBroadcast())
noBroadcastDomain[it++] = id;
return noBroadcastDomain;
}
bool TensorDomain::hasBroadcast(const std::vector<IterDomain*>& td) {
for (auto id : td)
if (id->isBroadcast())
return true;
return false;
}
bool TensorDomain::hasReduction(const std::vector<IterDomain*>& td) {
for (auto id : td)
if (id->isReduction())
return true;
return false;
}
bool TensorDomain::hasNontrivialReduction(const std::vector<IterDomain*>& td) {
for (auto id : td) {
if (id->isReduction() && !id->isTrivialReduction()) {
return true;
}
}
return false;
}
TensorDomain* TensorDomain::view(
const std::vector<std::shared_ptr<ViewTransform>>& transforms) {
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to view transform a 0-dim domain");
return transformView(this, transforms);
}
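// Flattens the dimensions in [start_dim, end_dim] (inclusive, negative
// indices allowed) into a single iter domain by cloning the root domain and
// repeatedly merging the cloned ids into one rfactor id.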
TensorDomain* TensorDomain::flatten(int64_t start_dim, int64_t end_dim) {
if (start_dim < 0) {
start_dim += nDims();
}
if (end_dim < 0) {
end_dim += nDims();
}
std::vector<IterDomain*> new_root_domain;
auto inp_domain = noReductions(getMaybeRFactorDomain());
new_root_domain.reserve(inp_domain.size());
for (auto id : inp_domain) {
new_root_domain.push_back(id->cloneWithoutRFactor());
}
std::vector<IterDomain*> rfactor_domain;
rfactor_domain.reserve(new_root_domain.size() - (end_dim - start_dim));
for (auto i : c10::irange(start_dim)) {
rfactor_domain.push_back(new_root_domain[i]);
}
IterDomain* merged_id = new_root_domain[start_dim];
for (auto i : c10::irange(start_dim + 1, end_dim + 1)) {
IterDomain* new_merged_id = IrBuilder::create<IterDomain>(
merged_id->container(),
merged_id->container()->zeroVal(),
mul(merged_id->extent(), new_root_domain[i]->extent()),
ParallelType::Serial,
IterType::Iteration,
true);
IrBuilder::create<Merge>(new_merged_id, merged_id, new_root_domain[i]);
merged_id = new_merged_id;
}
rfactor_domain.push_back(merged_id);
for (auto i : c10::irange(end_dim + 1, nDims())) {
rfactor_domain.push_back(new_root_domain[i]);
}
return IrBuilder::create<TensorDomain>(
new_root_domain,
rfactor_domain,
rfactor_domain,
std::vector<bool>(rfactor_domain.size(), true));
}
// TODO: Rfactor a Welford
// The returned pair is ordered such that the second domain is the consumer of
// the first.
std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
const std::vector<int>& axes_) {
return TransformRFactor::runReplay(this, axes_);
}
Split::Split(
IrBuilderPasskey passkey,
IterDomain* outer,
IterDomain* inner,
IterDomain* in,
Val* factor,
bool inner_split,
Val* start_offset,
Val* stop_offset)
: Expr(passkey, ExprType::Split),
outer_{outer},
inner_{inner},
in_{in},
factor_{factor},
inner_split_{inner_split},
start_offset_{
start_offset != nullptr ? start_offset
: passkey.ir_container_->zeroVal()},
stop_offset_{
stop_offset != nullptr ? stop_offset
: passkey.ir_container_->zeroVal()} {
TORCH_INTERNAL_ASSERT(
factor_->isAnInt(),
"Attempted to create a Split node with a non-integer factor.");
addOutput(outer);
addOutput(inner);
addInput(in);
// TODO add factor as an input, need to check Split::Split during validation
// and need to check BestEffortReplay::findFirstMismatchedID addInput(factor);
}
Split::Split(const Split* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
outer_(ir_cloner->clone(src->outer_)),
inner_(ir_cloner->clone(src->inner_)),
in_(ir_cloner->clone(src->in_)),
factor_(ir_cloner->clone(src->factor_)),
inner_split_(src->inner_split_),
start_offset_(ir_cloner->clone(src->start_offset_)),
stop_offset_(ir_cloner->clone(src->stop_offset_)) {}
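// Effective extent of the input domain for a (possibly partial) split: the
// raw extent minus any non-zero start and stop offsets.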
Val* Split::extent(Val* in_extent, Val* start_offset, Val* stop_offset) {
TORCH_INTERNAL_ASSERT(in_extent != nullptr);
if (start_offset != nullptr && !start_offset->isZeroInt()) {
in_extent = sub(in_extent, start_offset);
}
if (stop_offset != nullptr && !stop_offset->isZeroInt()) {
in_extent = sub(in_extent, stop_offset);
}
return in_extent;
}
bool Split::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<Split>()) {
return false;
}
return Expr::sameAs(other) &&
factor()->sameAs(other->as<Split>()->factor()) &&
innerSplit() == other->as<Split>()->innerSplit() &&
startOffset()->sameAs(other->as<Split>()->startOffset()) &&
stopOffset()->sameAs(other->as<Split>()->stopOffset());
}
Merge::Merge(
IrBuilderPasskey passkey,
IterDomain* out,
IterDomain* outer,
IterDomain* inner)
: Expr(passkey, ExprType::Merge), out_{out}, outer_{outer}, inner_{inner} {
addOutput(out);
addInput(outer);
addInput(inner);
}
Merge::Merge(const Merge* src, IrCloner* ir_cloner)
: Expr(src, ir_cloner),
out_(ir_cloner->clone(src->out_)),
outer_(ir_cloner->clone(src->outer_)),
inner_(ir_cloner->clone(src->inner_)) {}
bool Merge::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<Merge>()) {
return false;
}
return Expr::sameAs(other);
}
NamedScalar::NamedScalar(
IrBuilderPasskey passkey,
std::string name,
DataType dtype)
: Val(passkey, ValType::NamedScalar, dtype), name_(std::move(name)) {}
NamedScalar::NamedScalar(const NamedScalar* src, IrCloner* ir_cloner)
: Val(src, ir_cloner), name_(src->name_) {}
bool NamedScalar::sameAs(const Statement* other) const {
if (this == other) {
return true;
}
if (!other->isA<NamedScalar>()) {
return false;
}
return other->as<NamedScalar>()->name().compare(name()) == 0;
}
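// Returns the named scalar that holds the launch dimension of the given
// thread parallel type, e.g. blockDim.x for ParallelType::TIDx.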
NamedScalar* NamedScalar::getParallelDim(ParallelType p_type) {
TORCH_INTERNAL_ASSERT(
isParallelTypeThread(p_type),
"Cannot get parallel dim of non thread type, received: ",
p_type);
TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr);
std::string parallel_dim = stringifyThreadSize(p_type);
return IrBuilder::create<NamedScalar>(parallel_dim, DataType::Int);
}
NamedScalar* NamedScalar::getParallelIndex(ParallelType p_type) {
TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr);
std::string parallel_ind = stringifyThread(p_type);
return IrBuilder::create<NamedScalar>(parallel_ind, DataType::Int);
}
c10::optional<ParallelType> NamedScalar::getParallelDim() const {
if (stringifyThreadSize(ParallelType::TIDx).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::TIDx);
} else if (stringifyThreadSize(ParallelType::TIDy).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::TIDy);
} else if (stringifyThreadSize(ParallelType::TIDz).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::TIDz);
} else if (stringifyThreadSize(ParallelType::BIDx).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::BIDx);
} else if (stringifyThreadSize(ParallelType::BIDy).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::BIDy);
} else if (stringifyThreadSize(ParallelType::BIDz).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::BIDz);
}
return c10::nullopt;
}
c10::optional<ParallelType> NamedScalar::getParallelIndex() const {
if (stringifyThread(ParallelType::TIDx).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::TIDx);
} else if (stringifyThread(ParallelType::TIDy).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::TIDy);
} else if (stringifyThread(ParallelType::TIDz).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::TIDz);
} else if (stringifyThread(ParallelType::BIDx).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::BIDx);
} else if (stringifyThread(ParallelType::BIDy).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::BIDy);
} else if (stringifyThread(ParallelType::BIDz).compare(name()) == 0) {
return c10::optional<ParallelType>(ParallelType::BIDz);
}
return c10::nullopt;
}
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch