Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65345

FooType::get() can return a const reference. Inconveniently, converting
shared_ptr<FooType> to shared_ptr<Type> requires a copy & refcount bump, so to
properly take advantage of this in unshapedType() we need to take a const Type&
in isSubtypeOf(), which is good practice anyway -- don't require a shared_ptr
if you don't need to take ownership.

ghstack-source-id: 140044165

Test Plan: CI. Perf says c10::unshapedType time decreased from 2.8% to 2.2%
during static runtime startup, though I expect this to be generally beneficial.

Reviewed By: hlu1

Differential Revision: D31027361

fbshipit-source-id: 676feb81db9f74ad7b8651d8774f4ecb4cfa6ab8
481 lines
16 KiB
C++
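For illustration only, a minimal sketch of the isSubtypeOf() calling convention this diff enables; the isTensorValue helper and the Value* argument are assumptions for illustration and are not part of the file listed below.

#include <torch/csrc/jit/ir/ir.h>

namespace {
// Before this change, isSubtypeOf(const TypePtr&) required converting
// shared_ptr<TensorType> to shared_ptr<Type>, copying the pointer and bumping
// its refcount:
//   v->type()->isSubtypeOf(c10::TensorType::get());
// After it, the cached singleton can be passed by const reference with no
// copy, which is the form the prim::Constant case in the file below uses:
bool isTensorValue(const torch::jit::Value* v) {
  return v->type()->isSubtypeOf(*c10::TensorType::get());
}
} // namespace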
#include <torch/csrc/jit/codegen/cuda/shape_inference.h>

#include <c10/core/ScalarType.h>
#include <torch/csrc/jit/codegen/cuda/instrumentation.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/ExpandUtils.h>
#include <ATen/core/jit_type.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

bool hasTypeAndDevice(const TensorTypePtr& op) {
  return op->device().has_value() && op->scalarType().has_value();
}

/* NaiveTypePropagator
 * Populates type/device tags on tensors; this is a transitional module that
 * covers for the absence of type inference in the codegen CUDA fuser.
 *
 * Only operations supported by codegen are covered, and the focus is on
 * propagating concrete types.
 * It does NOT handle aliases (not supported in codegen anyway); type promotion
 * is not guaranteed to be consistent with PyTorch (it serves the needs of
 * codegen instead).
 */
class NaiveTypePropagator {
 public:
  NaiveTypePropagator(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void PropagateOnBlock(Block* block) {
    for (Node* node : block->nodes()) {
      PropagateOnNode(node);
    }
  }

  void PropagateOnNode(Node* node) {
    switch (node->kind()) {
      // Constant:
      case prim::Constant: {
        if (node->output()->type()->isSubtypeOf(*TensorType::get())) {
          node->output()->inferTypeFrom(node->t(attr::value));
        }
        break;
      }
      // unary operations that forward meta info:
      case aten::neg:
      case aten::bitwise_not:
      case aten::abs:
      case aten::log:
      case aten::log10:
      case aten::log1p:
      case aten::log2:
      case aten::lgamma:
      case aten::exp:
      case aten::expm1:
      case aten::erf:
      case aten::erfc:
      case aten::cos:
      case aten::acos:
      case aten::cosh:
      case aten::sin:
      case aten::asin:
      case aten::sinh:
      case aten::tan:
      case aten::atan:
      case aten::sqrt:
      case aten::rsqrt:
      case aten::ceil:
      case aten::floor:
      case aten::round:
      case aten::trunc:
      case aten::frac:
      case aten::reciprocal:
      case aten::relu:
      case aten::sigmoid:
      case aten::threshold:
      case aten::softplus:
      case aten::clamp:
      case aten::gelu:
      case aten::gelu_backward:
      case aten::silu:
      case aten::tanh: {
        TORCH_CHECK(
            hasTypeAndDevice(node->input(0)->type()->cast<TensorType>()),
            "Type and device propagation has failed, or was not provided enough information.");
        node->output()->setType(node->input(0)->type()->cast<TensorType>());
        break;
      }
      // TODO: rand_like should support cast.
      case aten::rand_like: {
        TORCH_CHECK(
            hasTypeAndDevice(node->input(0)->type()->cast<TensorType>()),
            "Type and device propagation has failed, or was not provided enough information.");
        node->output()->setType(node->input(0)->type()->cast<TensorType>());
        break;
      }
      // binary operations that forward meta info and broadcast shape:
      case aten::mul:
      case aten::div:
      case aten::atan2:
      // TODO: double check type casting logic for min/max/pow
      case aten::min:
      case aten::max:
      case aten::pow:
      case aten::remainder:
      case aten::fmod:
      case aten::lerp:
      // add/sub can be ternary ops; the third argument contributes to
      // neither type promotion nor shape.
      case aten::add:
      case aten::sub: {
        const auto promoted_type = binary_broadcast_type(
            node->input(0)->type()->cast<TensorType>(),
            node->input(1)->type()->cast<TensorType>());
        node->output()->setType(promoted_type);
        break;
      }
      // Type can be int or bool for "and" and "or"; if both are bool the
      // result is bool, if both are int it is int, otherwise it would have
      // errored out earlier.
      case aten::__and__:
      case aten::__or__: {
        const auto promoted_type = binary_broadcast_type(
            node->input(0)->type()->cast<TensorType>(),
            node->input(1)->type()->cast<TensorType>(),
            node->input(0)->type()->cast<TensorType>()->scalarType() ==
                    at::ScalarType::Bool
                ? at::ScalarType::Bool
                : at::ScalarType::Int);
        node->output()->setType(promoted_type);
        break;
      }
      // Real int ops
      case aten::__xor__:
      case aten::__lshift__:
      case aten::__rshift__: {
        const auto promoted_type = binary_broadcast_type(
            node->input(0)->type()->cast<TensorType>(),
            node->input(1)->type()->cast<TensorType>(),
            at::ScalarType::Int);
        node->output()->setType(promoted_type);
        break;
      }
      case aten::lt:
      case aten::le:
      case aten::gt:
      case aten::ge:
      case aten::ne:
      case aten::eq: {
        const auto promoted_type = binary_broadcast_type(
            node->input(0)->type()->cast<TensorType>(),
            node->input(1)->type()->cast<TensorType>(),
            at::ScalarType::Bool);
        node->output()->setType(promoted_type);
        break;
      }
      case aten::where: {
        const auto promoted_type = binary_broadcast_type(
            node->input(1)->type()->cast<TensorType>(),
            node->input(2)->type()->cast<TensorType>());
        node->output()->setType(promoted_type);
        break;
      }
      case aten::addcmul: {
        auto promoted_type = binary_broadcast_type(
            node->input(1)->type()->cast<TensorType>(),
            node->input(2)->type()->cast<TensorType>());
        promoted_type = binary_broadcast_type(
            promoted_type, node->input(0)->type()->cast<TensorType>());
        node->output()->setType(promoted_type);
        break;
      }
      case aten::dropout: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        node->output()->setType(out_type);
        break;
      }
      case aten::instance_norm:
      case aten::batch_norm: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        node->output()->setType(out_type);
        break;
      }
      case aten::_batch_norm_impl_index_backward: {
        auto grad_input_type = node->input(1)->type()->cast<TensorType>();
        TORCH_CHECK(
            hasTypeAndDevice(grad_input_type),
            "Type and device propagation has failed, or was not provided enough information.");
        node->output(0)->setType(grad_input_type);

        // TODO: double check with type promotion
        auto mean_rstd_type = TensorType::create(
            *grad_input_type->scalarType(),
            *grad_input_type->device(),
            c10::nullopt,
            c10::nullopt);

        node->output(1)->setType(mean_rstd_type);
        node->output(2)->setType(mean_rstd_type);

        break;
      }
      case aten::_batch_norm_impl_index: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        TORCH_CHECK(
            hasTypeAndDevice(out_type),
            "Type and device propagation has failed, or was not provided enough information.");
        node->output(0)->setType(out_type);

        auto mean_rstd_type = TensorType::create(
            *out_type->scalarType(),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);

        node->output(1)->setType(mean_rstd_type);
        node->output(2)->setType(mean_rstd_type);
        // TODO: not that it matters, but mark the right type here;
        // node->output(3)->setType(out_type->withScalarType());
        node->output(3)->setType(out_type);
        node->output(4)->setType(IntType::get());

        break;
      }
      case aten::native_batch_norm: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        TORCH_CHECK(
            hasTypeAndDevice(out_type),
            "Type and device propagation has failed, or was not provided enough information.");
        node->output(0)->setType(out_type);

        auto mean_rstd_type = TensorType::create(
            *out_type->scalarType(),
            *out_type->device(),
            c10::nullopt,
            c10::nullopt);

        node->output(1)->setType(mean_rstd_type);
        node->output(2)->setType(mean_rstd_type);

        break;
      }
      case aten::native_batch_norm_backward: {
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto out_mask_list = constant_as<c10::List<bool>>(node->input(9));
        TORCH_INTERNAL_ASSERT(
            out_mask_list.has_value(), "output mask for batch_norm_backward");
        std::vector<int> output_mask;
        for (const auto value : out_mask_list->vec()) {
          output_mask.emplace_back(static_cast<int>(value));
        }

        if (output_mask[0]) {
          auto in_type = node->input(1)->type()->cast<TensorType>();
          node->output(0)->setType(in_type);
        }

        if (output_mask[1]) {
          auto weight_type = node->input(2)->type()->cast<TensorType>();
          node->output(1)->setType(weight_type);
        }

        if (output_mask[2]) {
          auto weight_type = node->input(2)->type()->cast<TensorType>();
          auto bias_type = TensorType::create(
              *weight_type->scalarType(),
              *weight_type->device(),
              *weight_type->dim(),
              output_mask[2]);
          node->output(2)->setType(bias_type);
        }
        break;
      }
      case aten::layer_norm: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        node->output()->setType(out_type);
        break;
      }
      case aten::native_layer_norm: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        TORCH_CHECK(
            hasTypeAndDevice(out_type),
            "Type and device propagation has failed, or was not provided enough information.");
        node->output(0)->setType(out_type);

        auto mean_rstd_type = TensorType::create(
            *out_type->scalarType(), *out_type->device(), c10::nullopt, false);

        node->output(1)->setType(mean_rstd_type);
        node->output(2)->setType(mean_rstd_type);

        break;
      }
      case aten::native_layer_norm_backward: {
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
        auto out_mask_list = constant_as<c10::List<bool>>(node->input(7));
        TORCH_INTERNAL_ASSERT(
            out_mask_list.has_value(), "output mask for layer_norm_backward");
        std::vector<int> output_mask;
        for (const auto value : out_mask_list->vec()) {
          output_mask.emplace_back(static_cast<int>(value));
        }

        if (output_mask[0]) {
          auto out_type = node->input(0)->type()->cast<TensorType>();
          node->output(0)->setType(out_type);
        }

        if (output_mask[1] &&
            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            !node->input(5)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          auto weight_type = node->input(5)->type()->cast<TensorType>();
          node->output(1)->setType(weight_type);
        }

        if (output_mask[2] &&
            // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
            !node->input(6)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
          auto bias_type = node->input(6)->type()->cast<TensorType>();
          node->output(2)->setType(bias_type);
        }
        break;
      }
      case aten::softmax: {
        auto out_type = node->input(0)->type()->cast<TensorType>();

        // accept dtype input to `aten::softmax` node
        if (!node->input(2)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          if (auto opt_ivalue = toIValue(node->input(2))) {
            out_type = out_type->withScalarType(opt_ivalue->toScalarType());
          }
        }
        node->output()->setType(out_type);
        break;
      }
      case aten::_softmax_backward_data: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        node->output()->setType(out_type);
        break;
      }
      case aten::mean:
      case aten::sum: {
        auto out_type = node->input(0)->type()->cast<TensorType>();

        // accept dtype input to `aten::sum` node
        if (!node->input(3)->type()->isSubtypeOf(
                static_cast<c10::TypePtr>(NoneType::get()))) {
          if (auto opt_ivalue = toIValue(node->input(3))) {
            out_type = out_type->withScalarType(opt_ivalue->toScalarType());
          }
        }
        const auto dims = constant_as<c10::List<int64_t>>(node->input(1));
        const auto keepdim = constant_as<bool>(node->input(2));
        TORCH_CHECK(
            dims.has_value() && keepdim.has_value(),
            "Shape inference cannot handle options.");
        node->output()->setType(
            unary_reduce_type(out_type, dims->vec(), keepdim.value()));
        break;
      }
      case aten::sum_to_size:
      case aten::_grad_sum_to_size: {
        auto out_type = node->input(0)->type()->cast<TensorType>();
        node->output()->setType(out_type->withDim(c10::nullopt));
        break;
      }
      case aten::type_as: {
        const auto type0 = node->input(0)->type()->cast<TensorType>();
        const auto type1 = node->input(1)->type()->cast<TensorType>();
        TORCH_CHECK(
            type0 != nullptr && type1 != nullptr &&
                type1->scalarType().has_value(),
            "input to type_as needs to be a tensor");
        node->output()->setType(type0->withScalarType(type1->scalarType()));
        break;
      }
      case aten::to: {
        const auto type0 = node->input(0)->type()->cast<TensorType>();
        const auto out_dtype = toIValue(node->input(1));
        TORCH_CHECK(out_dtype, "No output type specified");
        node->output()->setType(
            type0->withScalarType(out_dtype->toScalarType()));
        break;
      }
      case prim::add_optional: {
        const auto type0 = node->input(0)->type()->cast<TensorType>();
        const auto type1 = node->input(1)->type()->cast<TensorType>();
        TORCH_CHECK(type0 != nullptr);
        if (type1 != nullptr) {
          node->output()->setType(type0);
        } else {
          const auto promoted_type = binary_broadcast_type(type0, type1);
          node->output()->setType(promoted_type);
        }
        break;
      }
      default:
        TORCH_CHECK(
            false,
            "type inference failed, unrecognized operation encountered:",
            node->kind().toDisplayString());
        // TODO: generate a proper error log, as this probably means something
        // unexpected happened.
        break;
    }
  }

  void run() {
    PropagateOnBlock(graph_->block());
  }

 protected:
  TensorTypePtr unary_reduce_type(
      const TensorTypePtr& op,
      const std::vector<int64_t>& dims,
      bool keepdim) {
    TORCH_CHECK(
        hasTypeAndDevice(op),
        "Type and device propagation has failed, or was not provided enough information.");
    return TensorType::create(
        *op->scalarType(), *op->device(), c10::nullopt, c10::nullopt);
  }

  // TODO: we should comply with codegen type promotion.
  TensorTypePtr binary_broadcast_type(
      TensorTypePtr const& op0,
      TensorTypePtr const& op1,
      c10::optional<at::ScalarType> scalar_type = c10::nullopt) {
    TORCH_CHECK(
        op0 != nullptr || op1 != nullptr,
        "Scalar operations on binary broadcast type, not supported yet.");

    if (op0 != nullptr && op1 != nullptr) {
      TORCH_CHECK(
          hasTypeAndDevice(op0) && hasTypeAndDevice(op1),
          "Type and device propagation has failed, or was not provided enough information.");
      auto promoted_scalar_type = scalar_type.has_value()
          ? *scalar_type
          : c10::promoteTypes(*op0->scalarType(), *op1->scalarType());

      return TensorType::create(
          promoted_scalar_type, *op0->device(), c10::nullopt, c10::nullopt);
    } else {
      auto ptr = (op0 != nullptr) ? op0 : op1;
      TORCH_CHECK(
          hasTypeAndDevice(ptr),
          "Type and device propagation has failed, or was not provided enough information.");
      return TensorType::create(
          scalar_type.has_value() ? *scalar_type : *ptr->scalarType(),
          *ptr->device(),
          c10::nullopt,
          c10::nullopt);
    }
  }
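  // Illustrative note (an observation about the code above, not upstream
  // documentation): when no explicit scalar_type is passed, the result dtype
  // follows c10::promoteTypes, e.g. Float x Int -> Float; size and stride
  // information is always dropped (c10::nullopt) in the returned type.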

 private:
  std::shared_ptr<Graph> graph_;
};

} // namespace

void TypePropagate(std::shared_ptr<Graph>& graph) {
  FUSER_PERF_SCOPE("TypePropagate");
  NaiveTypePropagator(graph).run();
}
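// Sketch of a typical call site (assumed for illustration, not taken from this
// file): after the CUDA fuser outlines a fusion group into its own subgraph,
// the subgraph is run through TypePropagate so every Value carries a scalar
// type and device before lowering, e.g.:
//   std::shared_ptr<Graph> fusion_graph = fusion_node->g(attr::Subgraph);
//   TypePropagate(fusion_graph);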

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch