mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: `IValue::toString()` creates a new `c10::intrusive_ptr` (analogous to `std::shared_ptr`), and the immediate `->string()` access incurs an atomic reference increment/decrement. Both operations can be skipped by calling `IValue::toStringRef()` instead.
Test Plan: CI.
Reviewed By: jaybean-dev
Differential Revision: D39605242
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85437
Approved by: https://github.com/jfix71
218 lines
6.8 KiB
C++
#include <torch/csrc/jit/ir/constants.h>
|
|
|
|
#include <ATen/core/functional.h>
|
|
#include <torch/csrc/autograd/variable.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/runtime/custom_operator.h>
|
|
#include <torch/csrc/jit/runtime/operator.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
|
|
// A tensor may be embedded as a graph constant only when it is
// serializable: it must own real storage (opaque tensors, e.g. MKLdnn,
// do not), must not be a nested tensor, and must not carry gradient
// history (gradients are mutable and cannot be serialized).
bool insertableTensor(const at::Tensor& ten) {
  if (ten.requires_grad()) {
    return false;
  }
  return ten.has_storage() && !ten.is_nested();
}
|
|
|
|
bool insertableIValue(const IValue& ivalue) {
|
|
if (ivalue.isInt() || ivalue.isNone() || ivalue.isBool() ||
|
|
ivalue.isDouble() || ivalue.isComplexDouble() || ivalue.isString() ||
|
|
ivalue.isDevice() || ivalue.isEnum()) {
|
|
return true;
|
|
}
|
|
if (ivalue.isTensor()) {
|
|
return insertableTensor(ivalue.toTensor());
|
|
}
|
|
if (ivalue.isList() || ivalue.isTuple()) {
|
|
c10::ArrayRef<IValue> elems;
|
|
if (ivalue.isTuple()) {
|
|
elems = ivalue.toTupleRef().elements();
|
|
} else {
|
|
elems = ivalue.toListRef();
|
|
}
|
|
return std::all_of(elems.begin(), elems.end(), [](const IValue& tup_elem) {
|
|
return insertableIValue(tup_elem);
|
|
});
|
|
}
|
|
if (ivalue.isGenericDict()) {
|
|
const auto& dict = ivalue.toGenericDict();
|
|
return std::all_of(dict.begin(), dict.end(), [](const auto& entry) {
|
|
return insertableIValue(entry.key()) && insertableIValue(entry.value());
|
|
});
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
// Throwing wrapper around tryInsertConstant: inserts `val` as a constant
// in `g`, or throws constant_not_supported_error if the value kind has
// no constant representation.
Value* insertConstant(
    Graph& g,
    const IValue& val,
    c10::optional<SourceRange> loc,
    c10::optional<ScopePtr> scope) {
  if (auto result = tryInsertConstant(g, val, std::move(loc), std::move(scope))) {
    return *result;
  }
  throw constant_not_supported_error(
      "Unsupported value kind: " + val.tagKind());
}
|
|
|
|
// IValue -> Constant node
|
|
c10::optional<Value*> tryInsertConstant(
|
|
Graph& g,
|
|
const IValue& val,
|
|
c10::optional<SourceRange> loc,
|
|
c10::optional<ScopePtr> scope) {
|
|
Node* n = g.create(prim::Constant);
|
|
if (val.isTensor()) {
|
|
at::Tensor ref = val.toTensor();
|
|
if (!insertableTensor(val.toTensor())) {
|
|
n->destroy();
|
|
return c10::nullopt;
|
|
}
|
|
if (!ref.defined()) {
|
|
n->destroy();
|
|
return g.insertNode(g.createNone())->output();
|
|
}
|
|
TORCH_INTERNAL_ASSERT(!ref.requires_grad());
|
|
n->output()->inferTypeFrom(
|
|
ref); // note: before t_ because of std::move(ref)
|
|
n->t_(attr::value, std::move(ref));
|
|
} else if (val.isInt()) {
|
|
n->i_(attr::value, val.toInt());
|
|
n->output()->setType(IntType::get());
|
|
} else if (val.isDouble()) {
|
|
n->f_(attr::value, val.toDouble());
|
|
n->output()->setType(FloatType::get());
|
|
} else if (val.isComplexDouble()) {
|
|
n->c_(attr::value, val.toComplexDouble());
|
|
n->output()->setType(ComplexType::get());
|
|
} else if (val.isBool()) {
|
|
n->i_(attr::value, val.toBool());
|
|
n->output()->setType(BoolType::get());
|
|
} else if (val.isList()) {
|
|
bool fast_path_list =
|
|
val.isBoolList() || val.isIntList() || val.isDoubleList();
|
|
if (fast_path_list || insertableIValue(val)) {
|
|
n->ival_(attr::value, val);
|
|
n->output()->setType(val.type());
|
|
} else {
|
|
n->destroy();
|
|
return c10::nullopt;
|
|
}
|
|
} else if (val.isString()) {
|
|
n->s_(attr::value, val.toStringRef());
|
|
n->output()->setType(StringType::get());
|
|
} else if (val.isDevice()) {
|
|
std::stringstream ss;
|
|
ss << val.toDevice();
|
|
n->s_(attr::value, ss.str());
|
|
n->output()->setType(DeviceObjType::get());
|
|
} else if (val.isStream()) {
|
|
auto stream = val.toStream();
|
|
n->i_(attr::value, stream.pack());
|
|
n->output()->setType(StreamObjType::get());
|
|
} else if (val.isNone()) {
|
|
n->output()->setType(NoneType::get());
|
|
} else if (val.isTuple()) {
|
|
if (insertableIValue(val)) {
|
|
n->ival_(attr::value, val);
|
|
n->output()->setType(val.type());
|
|
} else {
|
|
n->destroy();
|
|
return c10::nullopt;
|
|
};
|
|
} else if (val.isObject()) {
|
|
const auto& ref = val.toObjectRef();
|
|
// see: [Constant Object Weak CompilationUnit Reference]
|
|
if (!ref.type()->is_module() &&
|
|
(ref.is_weak_compilation_ref() ||
|
|
ref.is_empty_strong_compilation_ref())) {
|
|
n->ival_(attr::value, val);
|
|
n->output()->setType(val.type());
|
|
} else {
|
|
n->destroy();
|
|
return c10::nullopt;
|
|
}
|
|
} else if ((val.isGenericDict() && insertableIValue(val)) || (val.isEnum())) {
|
|
n->ival_(attr::value, val);
|
|
n->output()->setType(val.type());
|
|
} else {
|
|
n->destroy();
|
|
return c10::nullopt;
|
|
}
|
|
if (loc)
|
|
n->setSourceRange(*loc);
|
|
if (scope)
|
|
n->setScope(*scope);
|
|
return g.insertNode(n)->output();
|
|
}
|
|
|
|
// Constant node -> IValue
//
// Recovers the IValue stored on a prim::Constant node's output. Returns
// c10::nullopt for non-constant values (and for function constants, which
// carry no runtime value); throws std::runtime_error for constant types
// that have no literal representation. The dispatch order below mirrors
// the attribute encodings written by tryInsertConstant.
c10::optional<IValue> toIValue(const Value* v) {
  if (v->node()->kind() != prim::Constant || v->type()->cast<FunctionType>()) {
    return c10::nullopt;
  }
  const Node* node = v->node();
  const TypePtr& type = v->type();
  if (type->isSubtypeOf(*TensorType::get())) {
    return node->t(attr::value);
  }
  if (type->isSubtypeOf(*BoolType::get())) {
    return (bool)node->i(attr::value);
  }
  // Numbers may be stored as an int, float, or complex attribute;
  // dispatch on the attribute kind actually present on the node.
  if (type->isSubtypeOf(*NumberType::get())) {
    switch (node->kindOf(attr::value)) {
      case AttributeKind::i:
        return node->i(attr::value);
      case AttributeKind::f:
        return node->f(attr::value);
      case AttributeKind::c:
        return node->c(attr::value);
      default:
        break; // fall through to the remaining checks
    }
  }
  if (type->cast<ListType>() &&
      node->kindOf(attr::value) == AttributeKind::ival) {
    const auto& list = node->ival(attr::value);
    TORCH_INTERNAL_ASSERT(list.isList());
    return list;
  }
  if (type->cast<DictType>() &&
      node->kindOf(attr::value) == AttributeKind::ival) {
    const auto& dict = node->ival(attr::value);
    TORCH_INTERNAL_ASSERT(dict.isGenericDict());
    return dict;
  }
  if (type->cast<TupleType>() &&
      node->kindOf(attr::value) == AttributeKind::ival) {
    const auto& tup = node->ival(attr::value);
    TORCH_INTERNAL_ASSERT(tup.isTuple());
    return tup;
  }
  if (type == StringType::get()) {
    return node->s(attr::value);
  }
  if (type == DeviceObjType::get()) {
    // Devices were serialized as strings; reparse.
    return c10::Device(node->s(attr::value));
  }
  if (type == StreamObjType::get()) {
    return c10::Stream::unpack(node->i(attr::value));
  }
  if (node->mustBeNone()) {
    return IValue();
  }
  if (type->cast<EnumType>()) {
    return node->ival(attr::value);
  }
  if (type->cast<ClassType>() && !type->is_module()) {
    return node->ival(attr::value);
  }
  std::stringstream ss;
  ss << "constant literal not supported for: " << type->str();
  throw std::runtime_error(ss.str());
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|