mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit 4146be192ead477360a2763c5005e46a9485c3bf. Reverted https://github.com/pytorch/pytorch/pull/108480 on behalf of https://github.com/huydhn due to Sorry for reverting this, but this is needed to keep trunk green after https://github.com/pytorch/pytorch/pull/108479 was reverted. Both will need to be relanded ([comment](https://github.com/pytorch/pytorch/pull/108480#issuecomment-1707067595))
788 lines
27 KiB
C++
788 lines
27 KiB
C++
#include <torch/csrc/jit/frontend/sugared_value.h>
|
|
|
|
#include <c10/util/irange.h>
|
|
#include <torch/csrc/jit/frontend/schema_matching.h>
|
|
#include <torch/csrc/jit/frontend/tree_views.h>
|
|
#include <torch/csrc/jit/ir/ir.h>
|
|
#include <torch/csrc/jit/passes/constant_propagation.h>
|
|
|
|
namespace torch::jit {
|
|
|
|
struct NoneValue : SugaredValue {
|
|
NoneValue() = default;
|
|
std::string kind() const override {
|
|
return "None";
|
|
}
|
|
};
|
|
|
|
std::shared_ptr<SugaredValue> PrintValue::call(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
at::ArrayRef<NamedValue> args,
|
|
at::ArrayRef<NamedValue> kwargs,
|
|
size_t n_binders) {
|
|
auto& g = *m.graph();
|
|
if (!kwargs.empty())
|
|
throw ErrorReport(loc) << "print doesn't accept any keyword arguments";
|
|
|
|
std::vector<Value*> lowered_inputs = toValues(*m.graph(), args);
|
|
g.insertNode(g.create(prim::Print, lowered_inputs, 0)->setSourceRange(loc));
|
|
return std::make_shared<NoneValue>();
|
|
}
|
|
|
|
static const std::unordered_map<std::string, at::ScalarType>&
|
|
builtin_cast_method_to_scalar_type() {
|
|
static std::unordered_map<std::string, at::ScalarType> mapping = {
|
|
{"byte", at::kByte},
|
|
{"char", at::kChar},
|
|
{"double", at::kDouble},
|
|
{"float", at::kFloat},
|
|
{"cfloat", at::kComplexFloat},
|
|
{"cdouble", at::kComplexDouble},
|
|
{"int", at::kInt},
|
|
{"long", at::kLong},
|
|
{"short", at::kShort},
|
|
{"half", at::kHalf}};
|
|
return mapping;
|
|
}
|
|
|
|
std::shared_ptr<SugaredValue> BuiltinFunction::call(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
at::ArrayRef<NamedValue> args,
|
|
at::ArrayRef<NamedValue> kwargs,
|
|
size_t n_binders) {
|
|
return std::make_shared<SimpleValue>(
|
|
emitBuiltinCall(loc, *m.graph(), symbol, args, kwargs, self));
|
|
}
|
|
|
|
// Hash functor for enum (class) keys: hashes an enumerator to its underlying
// integral value. Some older gcc/clang standard libraries cannot use enums as
// unordered-container keys with the default hasher, so maps keyed by an enum
// pass this explicitly.
// https://stackoverflow.com/questions/18837857/cant-use-enum-class-as-unordered-map-key
struct EnumClassHash {
  template <typename T>
  std::size_t operator()(T t) const {
    const auto as_size = static_cast<std::size_t>(t);
    return as_size;
  }
};
|
|
|
|
bool SimpleValue::hasAttr(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
const std::string& field) {
|
|
auto class_type = value_->type()->cast<ClassType>();
|
|
if (!class_type) {
|
|
throw ErrorReport(loc) << "hasattr's first argument must be an object, got "
|
|
<< value_->type()->repr_str() << " instead";
|
|
}
|
|
|
|
return class_type->hasMethod(field) || class_type->hasAttribute(field) ||
|
|
class_type->hasConstant(field);
|
|
}
|
|
|
|
// support syntax sugar for x.foo(y, z) by allowing x.foo to return a
|
|
// callable value that will resolve to foo(x, y, z) when called.
|
|
std::shared_ptr<SugaredValue> SimpleValue::attr(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
const std::string& field) {
|
|
// Allow method-style casts on Tensor types. e.g. x.int()
|
|
if (value_->type()->isSubtypeOf(*TensorType::get())) {
|
|
if (builtin_cast_method_to_scalar_type().count(field)) {
|
|
return std::make_shared<TensorCastValue>(
|
|
builtin_cast_method_to_scalar_type().at(field),
|
|
NamedValue(loc, "self", value_));
|
|
}
|
|
}
|
|
// accessing properties of Tensor and Device that are implemented as
|
|
// prim:: or aten:: operators
|
|
using PropertiesLookup = std::unordered_map<
|
|
TypeKind,
|
|
std::unordered_map<std::string, std::string>,
|
|
EnumClassHash>;
|
|
static const PropertiesLookup builtin_properties = {
|
|
{TypeKind::OptionalType,
|
|
{
|
|
{"unchecked_unwrap_optional", "prim"},
|
|
}},
|
|
{TypeKind::TensorType,
|
|
{
|
|
{"dtype", "prim"},
|
|
{"device", "prim"},
|
|
{"grad", "prim"},
|
|
{"data", "prim"},
|
|
{"shape", "prim"},
|
|
{"is_cuda", "prim"},
|
|
{"is_cpu", "prim"},
|
|
{"is_xla", "prim"},
|
|
{"is_xpu", "prim"},
|
|
{"is_sparse", "prim"},
|
|
{"is_sparse_csr", "prim"},
|
|
{"is_mkldnn", "prim"},
|
|
{"is_mps", "prim"},
|
|
{"is_mtia", "prim"},
|
|
{"is_quantized", "prim"},
|
|
{"is_vulkan", "prim"},
|
|
{"is_ipu", "prim"},
|
|
{"is_meta", "prim"},
|
|
{"is_leaf", "aten"},
|
|
{"is_nested", "prim"},
|
|
{"requires_grad", "prim"},
|
|
{"layout", "prim"},
|
|
{"T", "prim"},
|
|
{"H", "prim"},
|
|
{"mT", "aten"},
|
|
{"mH", "aten"},
|
|
{"is_ort", "prim"},
|
|
{"itemsize", "prim"},
|
|
{"nbytes", "prim"},
|
|
{"ndim", "prim"},
|
|
{"name", "prim"},
|
|
{"real", "aten"},
|
|
{"imag", "aten"},
|
|
{"retains_grad", "aten"},
|
|
}},
|
|
{TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
|
|
auto kind = value_->type()->kind();
|
|
auto types_for_builtin = builtin_properties.find(kind);
|
|
if (types_for_builtin != builtin_properties.end()) {
|
|
auto builtin_entry = types_for_builtin->second.find(field);
|
|
if (builtin_entry != types_for_builtin->second.end()) {
|
|
// A builtin was found, add it to the graph
|
|
auto the_namespace = builtin_entry->second;
|
|
auto r = m.graph()->insert(
|
|
Symbol::fromQualString(the_namespace + "::" + field), {value_});
|
|
return std::make_shared<SimpleValue>(r);
|
|
}
|
|
}
|
|
|
|
// accessing fields of named tuples
|
|
if (auto tuple_type = value_->type()->cast<TupleType>()) {
|
|
if (tuple_type->schema()) {
|
|
auto attrs = tuple_type->schema()->arguments();
|
|
for (const auto i : c10::irange(attrs.size())) {
|
|
if (attrs[i].name() == field) {
|
|
auto idx = m.graph()->insertConstant(IValue(static_cast<int64_t>(i)));
|
|
auto out_type = tuple_type->elements().at(i);
|
|
auto r = m.graph()
|
|
->insertNode(
|
|
m.graph()->createTupleIndex(value_, idx, out_type))
|
|
->output();
|
|
return std::make_shared<SimpleValue>(r);
|
|
}
|
|
}
|
|
}
|
|
} else if (auto awaitType = value_->type()->cast<AwaitType>()) {
|
|
auto elType = awaitType->getElementType();
|
|
auto& g = *m.graph();
|
|
auto v = g.insert(prim::awaitable_wait, {value_}, {}, loc);
|
|
auto sv = std::make_shared<SimpleValue>(v);
|
|
return sv->attr(loc, m, field);
|
|
} else if (auto classType = value_->type()->cast<ClassType>()) {
|
|
// This is a class, emit the proper attribute lookup
|
|
if (classType->findMethod(field)) {
|
|
return std::make_shared<MethodValue>(getValue(), field);
|
|
}
|
|
if (classType->hasAttribute(field)) {
|
|
auto& g = *m.graph();
|
|
auto n = g.insertNode(g.createGetAttr(value_, field));
|
|
return std::make_shared<SimpleValue>(n->output());
|
|
}
|
|
// Check and see if it's a getter attribute.
|
|
auto prop = classType->getProperty(field);
|
|
if (prop) {
|
|
return MethodValue(value_, prop->getter->name())
|
|
.call(loc, m, {}, {}, /*n_binders=*/1);
|
|
}
|
|
} else if (auto iface = value_->type()->cast<InterfaceType>()) {
|
|
// accessing methods of interfaces
|
|
if (iface->getMethod(field)) {
|
|
return std::make_shared<MethodValue>(getValue(), field);
|
|
}
|
|
} else if (auto enum_type = value_->type()->cast<EnumType>()) {
|
|
// Handle access to Enum's `name` and `value` attribute.
|
|
auto& g = *m.graph();
|
|
|
|
if (field == "name") {
|
|
auto n = g.insertNode(g.createEnumName(value_));
|
|
return std::make_shared<SimpleValue>(n->output());
|
|
}
|
|
|
|
if (field == "value") {
|
|
auto n = g.insertNode(g.createEnumValue(value_));
|
|
return std::make_shared<SimpleValue>(n->output());
|
|
}
|
|
}
|
|
|
|
// none of the more-specific cases worked, so see if this is a builtin method
|
|
// If field is a type, then call the aten::to op
|
|
if (field == "type") {
|
|
if (auto builtin = BuiltinFunction::tryCreate(
|
|
Symbol::aten("to"), NamedValue(loc, "self", value_))) {
|
|
return builtin;
|
|
}
|
|
}
|
|
|
|
if (auto builtin = BuiltinFunction::tryCreate(
|
|
Symbol::aten(field), NamedValue(loc, "self", value_))) {
|
|
return builtin;
|
|
}
|
|
|
|
// Handle calling tolist() on a Tensor.
|
|
if (value_->type()->isSubtypeOf(*TensorType::get()) && field == "tolist") {
|
|
return SpecialFormValue::create(prim::tolist);
|
|
}
|
|
|
|
// Handle calling __getitem__() directly on a Tensor, it needs special
|
|
// handling because desired method name (`__getitem__`) doesn't match `aten`
|
|
// operator name of `aten::index`.
|
|
if (value_->type()->isSubtypeOf(*TensorType::get()) &&
|
|
field == "__getitem__") {
|
|
return SpecialFormValue::create(aten::index);
|
|
}
|
|
|
|
ErrorReport report(loc);
|
|
report << "'" << value_->type()->repr_str()
|
|
<< "' object has no attribute or method '" << field << "'.";
|
|
if (auto classType = value_->type()->cast<ClassType>()) {
|
|
if (classType->isUnresolvedClassAttribute(field)) {
|
|
report
|
|
<< " '" << field
|
|
<< "' is defined as a class attribute which currently is not"
|
|
" supported. Consider converting this to an instance attribute.";
|
|
} else {
|
|
report << " Did you forget to initialize an attribute in __init__()?";
|
|
}
|
|
}
|
|
throw report;
|
|
}
|
|
|
|
std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
const c10::optional<size_t>& size_hint) {
|
|
static const auto make_simple_value =
|
|
[](Value* v) -> std::shared_ptr<SugaredValue> {
|
|
return std::make_shared<SimpleValue>(v);
|
|
};
|
|
if (value_->type()->kind() == TypeKind::TupleType) {
|
|
auto outputs = createTupleUnpack(value_);
|
|
return fmap(outputs, make_simple_value);
|
|
} else if (value_->type()->kind() == TypeKind::ListType) {
|
|
if (!size_hint) {
|
|
throw ErrorReport(loc)
|
|
<< "cannot statically infer the expected size of a "
|
|
<< "list in this context";
|
|
}
|
|
auto graph = value_->owningGraph();
|
|
Node* unpack =
|
|
graph->insertNode(graph->createListUnpack(value_, *size_hint));
|
|
return fmap(unpack->outputs(), make_simple_value);
|
|
} else if (value_->type()->kind() == TypeKind::AnyTupleType) {
|
|
throw ErrorReport(loc)
|
|
<< "Provided tuple is not fully defined/refined including its element types, please provide a value of type like Tuple[int, int]";
|
|
}
|
|
throw ErrorReport(loc) << value_->type()->repr_str()
|
|
<< " cannot be used as a tuple";
|
|
}
|
|
|
|
static bool isRecursive(const TypePtr& classType, const TypePtr& attrType) {
|
|
if (attrType->isSubtypeOf(*classType)) {
|
|
return true;
|
|
}
|
|
|
|
// Recursively check contained types. We need to do this because a user may do
|
|
// A -> B -> A.
|
|
for (const auto& type : attrType->containedTypes()) {
|
|
if (isRecursive(classType, type)) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
void SimpleValue::setAttr(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
const std::string& field,
|
|
Value* newValue) {
|
|
const auto classType = value_->type()->cast<ClassType>();
|
|
if (!classType) {
|
|
throw ErrorReport(loc) << "Tried to set an attribute: " << field
|
|
<< " on a non-class: " << value_->type()->repr_str();
|
|
}
|
|
auto expectedType = classType->findAttribute(field);
|
|
if (!expectedType) {
|
|
// If we are still compiling the __init__ method for this class, then
|
|
// setting an unknown attribute adds it to the class's definition.
|
|
|
|
// We are initializing if:
|
|
const auto isInitializing =
|
|
// 1. The method we're currently inserting into is an init method
|
|
// TODO this can be a qualified name check
|
|
m.name() == "__init__" &&
|
|
// 2. The `self` arg matches this value's type (i.e. we are in the init
|
|
// method for this class, not some other class)
|
|
!m.graph()->inputs().empty() &&
|
|
m.graph()->inputs().at(0)->type() == classType;
|
|
|
|
if (isInitializing) {
|
|
if (isRecursive(classType, newValue->type())) {
|
|
throw ErrorReport(loc)
|
|
<< "Assignment to attribute '" << field
|
|
<< "' cannot be of a type that contains class "
|
|
<< "'" << classType->repr_str() << "'.\n"
|
|
<< "Classes that recursively contain instances of themselves"
|
|
<< " are not yet supported";
|
|
}
|
|
|
|
classType->addAttribute(field, newValue->type());
|
|
expectedType = newValue->type();
|
|
|
|
const auto insertPoint = m.graph()->insertPoint();
|
|
const auto topLevelBlock = m.graph()->block();
|
|
if (insertPoint->owningBlock() != topLevelBlock) {
|
|
throw ErrorReport(loc)
|
|
<< "First assignment cannot be in a control-flow block. "
|
|
<< "Initialize the field at the top level first";
|
|
}
|
|
} else {
|
|
// Check and see if it's a setter attribute.
|
|
auto prop = classType->getProperty(field);
|
|
if (prop && prop->setter) {
|
|
MethodValue(value_, prop->setter->name())
|
|
.call(loc, m, {newValue}, {}, /*n_binders=*/1);
|
|
return;
|
|
}
|
|
|
|
if (prop && !prop->setter) {
|
|
throw ErrorReport(loc) << "Tried to set read-only attribute: " << field;
|
|
}
|
|
|
|
throw ErrorReport(loc)
|
|
<< "Tried to set nonexistent attribute: " << field
|
|
<< ". Did you forget to initialize it in __init__()?";
|
|
}
|
|
}
|
|
|
|
AT_ASSERT(expectedType);
|
|
|
|
// Check type correctness
|
|
const auto newType = newValue->type();
|
|
if (!newType->isSubtypeOf(*expectedType)) {
|
|
throw ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
|
|
<< expectedType->repr_str() << " but got "
|
|
<< newType->repr_str();
|
|
}
|
|
|
|
auto& g = *m.graph();
|
|
g.insertNode(g.createSetAttr(value_, field, newValue));
|
|
}
|
|
|
|
std::shared_ptr<SugaredValue> SimpleValue::call(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
at::ArrayRef<NamedValue> args,
|
|
at::ArrayRef<NamedValue> kwargs,
|
|
size_t n_binders) {
|
|
// allow our 'fake' closures to be called, used for fork serialization
|
|
// at the moment, but can be expanded later
|
|
Node* self = getValue()->node();
|
|
if (self->kind() == prim::TupleConstruct && self->inputs().size() == 2 &&
|
|
self->inputs().at(0)->node()->kind() == prim::Closure) {
|
|
std::shared_ptr<Graph> graph =
|
|
self->inputs().at(0)->node()->g(attr::Subgraph);
|
|
Value* context = self->inputs().at(1);
|
|
AT_ASSERT(context->node()->kind() == prim::TupleConstruct);
|
|
|
|
// fork nodes are emitted in their own block but we do not simplify
|
|
// tuple construction across blocks. To ensure we clean up the tuple
|
|
// construct create another copy of the tuple construct in the fork block
|
|
Value* close_context =
|
|
m.graph()
|
|
->insertNode(m.graph()->createTuple(context->node()->inputs()))
|
|
->output();
|
|
// TODO this needs to go in `m`s compilation unit
|
|
auto cu = std::make_shared<CompilationUnit>();
|
|
auto fn = cu->create_function(QualifiedName("anon"), graph);
|
|
auto ret = StrongFunctionPtr(std::move(cu), fn);
|
|
|
|
std::vector<NamedValue> ctx_inputs = {close_context};
|
|
ctx_inputs.insert(ctx_inputs.end(), args.begin(), args.end());
|
|
return FunctionValue(ret).call(loc, m, ctx_inputs, kwargs, n_binders);
|
|
}
|
|
|
|
if (auto class_type = getValue()->type()->cast<ClassType>()) {
|
|
return attr(loc, m, "__call__")->call(loc, m, args, kwargs, n_binders);
|
|
}
|
|
|
|
return SugaredValue::call(loc, m, args, kwargs, n_binders);
|
|
}
|
|
|
|
// Emits `len(value)` as aten::len for lists, strings and tensors. Every other
// type is reported as not iterable.
Value* SimpleValue::len(const SourceRange& loc, GraphFunction& m) {
  Value* val = getValue();
  const TypePtr val_type = val->type();

  const bool has_builtin_len = val_type->cast<ListType>() ||
      val_type->cast<StringType>() ||
      val_type->isSubtypeOf(*TensorType::get());
  if (!has_builtin_len) {
    throw ErrorReport(loc) << "'" << val_type->repr_str() << "'"
                           << " object is not iterable";
  }
  return m.graph()->insert(aten::len, {val}, {}, loc);
}
|
|
|
|
// Lowers `value[idx]`:
//  - List, String and Dict values emit aten::__getitem__.
//  - Tensors select along dimension 0 via aten::select.
//  - A module class indexed under a type hint emits
//    prim::ModuleContainerIndex typed as the hint. The contents (e.g. of a
//    ModuleDict) were already checked against the hinted interface during IR
//    generation.
//  - Other classes dispatch to their own `__getitem__` method.
// Any other type is reported as not subscriptable.
SugaredValuePtr SimpleValue::getitem(
    const SourceRange& loc,
    GraphFunction& m,
    Value* idx,
    TypePtr type_hint) {
  Value* val = getValue();
  TypePtr val_type = val->type();
  Graph& g = *m.graph();

  const bool is_builtin_container = val_type->cast<ListType>() ||
      val_type->cast<StringType>() || val_type->cast<DictType>();
  if (is_builtin_container) {
    return std::make_shared<SimpleValue>(
        g.insert(aten::__getitem__, {val, idx}, {}, loc));
  }

  if (val_type->isSubtypeOf(*TensorType::get())) {
    return std::make_shared<SimpleValue>(
        g.insert(aten::select, {val, 0, idx}, {}, loc));
  }

  auto class_type = val_type->cast<ClassType>();
  if (!class_type) {
    throw ErrorReport(loc) << "'" << val_type->repr_str() << "'"
                           << " object is not subscriptable";
  }

  if (class_type->is_module() && type_hint) {
    Value* indexed = g.insert(prim::ModuleContainerIndex, {val, idx}, {}, loc);
    indexed->setType(type_hint);
    return std::make_shared<SimpleValue>(indexed);
  }

  // Fall back to the class's own __getitem__.
  return attr(loc, m, "__getitem__")->call(loc, m, {idx}, {}, 1);
}
|
|
|
|
// Produces the iterable used by a `for` loop over this value:
//  - lists, strings and tensors iterate directly;
//  - dicts iterate over aten::keys(value);
//  - tuples are unpacked into a SugaredTupleValue of their elements.
// Any other type is reported as not iterable.
SugaredValuePtr SimpleValue::iter(const SourceRange& loc, GraphFunction& m) {
  Value* value = getValue();
  const TypePtr type = value->type();

  if (type->cast<ListType>() || type->cast<StringType>() ||
      type->cast<TensorType>()) {
    return std::make_shared<SimpleValue>(value);
  }

  if (type->cast<DictType>()) {
    Value* keys = m.graph()->insert(aten::keys, {value}, {}, loc);
    return std::make_shared<SimpleValue>(keys);
  }

  if (!type->cast<TupleType>()) {
    throw ErrorReport(loc) << "'" << type->repr_str() << "'"
                           << " object is not iterable";
  }

  auto unpacked = createTupleUnpack(value);
  std::vector<SugaredValuePtr> elements;
  elements.reserve(unpacked.size());
  for (Value* element : unpacked) {
    elements.emplace_back(std::make_shared<SimpleValue>(element));
  }
  return std::make_shared<SugaredTupleValue>(elements);
}
|
|
|
|
// Builds a `range(...)` value from 1 to 3 int arguments, following Python:
//   range(end)              -> start = 0, step = 1
//   range(start, end)       -> step = 1
//   range(start, end, step)
// `static_len`, when provided, is the statically known iteration count.
RangeValue::RangeValue(
    const SourceRange& loc,
    GraphFunction& m,
    std::vector<Value*> inputs,
    c10::optional<int64_t> static_len) {
  for (const auto arg_index : c10::irange(inputs.size())) {
    const auto arg_type = inputs[arg_index]->type();
    if (!arg_type->cast<IntType>()) {
      throw ErrorReport(loc)
          << "all inputs of range must be ints, found " << arg_type->repr_str()
          << " in argument " << c10::guts::to_string(arg_index);
    }
  }

  Graph& g = *m.graph();
  switch (inputs.size()) {
    case 0:
      throw ErrorReport(loc) << "range expected at least 1 arguments, got 0";
    case 1:
      end_ = inputs[0];
      start_ = g.insertConstant(0, loc);
      step_ = g.insertConstant(1, loc);
      // A range with only `end` makes len() and getitem() trivial.
      has_only_end_ = true;
      break;
    case 2:
    case 3:
      start_ = inputs[0];
      end_ = inputs[1];
      step_ = inputs.size() == 3 ? inputs[2] : g.insertConstant(1, loc);
      has_only_end_ = false;
      break;
    default:
      throw ErrorReport(loc) << "range expected at most 3 arguments, got "
                             << inputs.size();
  }

  static_len_ = static_len;
}
|
|
|
|
// Iterating over a range yields the RangeValue itself.
// (The stray `;` that followed this definition has been removed: it was an
// empty declaration flagged by -Wextra-semi.)
SugaredValuePtr RangeValue::iter(const SourceRange& loc, GraphFunction& m) {
  return shared_from_this();
}
|
|
|
|
// Number of iterations of the range:
//  - the statically known length as a constant, when available;
//  - `end` itself for single-argument ranges;
//  - otherwise aten::__range_length(start, end, step), computed at runtime.
Value* RangeValue::len(const SourceRange& loc, GraphFunction& m) {
  if (static_len_.has_value()) {
    return insertConstant(*m.graph(), static_len_.value(), loc);
  }
  if (!has_only_end_) {
    return m.graph()->insert(
        aten::__range_length, {start_, end_, step_}, {}, loc);
  }
  return end_;
}
|
|
|
|
// Value produced by iteration `idx`. For range(end) this is `idx` itself.
// Otherwise it is aten::__derive_index(idx, start, step).
SugaredValuePtr RangeValue::getitem(
    const SourceRange& loc,
    GraphFunction& m,
    Value* idx,
    TypePtr type_hint) {
  if (!has_only_end_) {
    Graph& g = *m.graph();
    Value* derived =
        g.insert(aten::__derive_index, {idx, start_, step_}, {}, loc);
    return std::make_shared<SimpleValue>(derived);
  }
  return std::make_shared<SimpleValue>(idx);
}
|
|
|
|
std::vector<SugaredValuePtr> IterableTree::get_base_iterables() {
|
|
std::vector<SugaredValuePtr> base_iters{};
|
|
|
|
for (SugaredValuePtr& sv : children_) {
|
|
if (auto iv = std::dynamic_pointer_cast<IterableTree>(sv)) {
|
|
std::vector<SugaredValuePtr> child_iters = iv->get_base_iterables();
|
|
// merge child iters with the base_iters
|
|
base_iters.insert(
|
|
base_iters.end(),
|
|
std::make_move_iterator(child_iters.begin()),
|
|
std::make_move_iterator(child_iters.end()));
|
|
|
|
} else {
|
|
// IterableTree leaves, either SimpleValue or RangeValue
|
|
base_iters.emplace_back(sv);
|
|
}
|
|
}
|
|
return base_iters;
|
|
}
|
|
|
|
// Runtime trip count for a tree of iterables: the minimum len() over all leaf
// iterables, computed with prim::min on a list of their lengths. Only valid
// when the tree has no static unroll length.
Value* IterableTree::len(const SourceRange& loc, GraphFunction& m) {
  TORCH_INTERNAL_ASSERT(!unroll_length_);
  Graph& g = *m.graph();

  const std::vector<SugaredValuePtr> leaves = get_base_iterables();
  std::vector<Value*> leaf_lengths;
  leaf_lengths.reserve(leaves.size());
  for (const auto& leaf : leaves) {
    leaf_lengths.push_back(leaf->len(loc, m));
  }

  Value* lengths_list =
      g.insertNode(g.createList(IntType::get(), leaf_lengths))->output();
  return g.insert(prim::min, {lengths_list}, {}, loc);
}
|
|
|
|
// Indexes every child with the same `idx` (in child order) and bundles the
// results into a SugaredTupleValue.
SugaredValuePtr IterableTree::getitem(
    const SourceRange& loc,
    GraphFunction& m,
    Value* idx,
    TypePtr type_hint) {
  std::vector<SugaredValuePtr> per_child;
  per_child.reserve(children_.size());
  for (const auto& child : children_) {
    per_child.push_back(child->getitem(loc, m, idx));
  }
  return std::make_shared<SugaredTupleValue>(per_child);
}
|
|
|
|
void IterableTree::addChild(
|
|
const SourceRange& range,
|
|
GraphFunction& m,
|
|
const SugaredValuePtr& iter_value) {
|
|
c10::optional<int64_t> child_len = iter_value->staticLen();
|
|
if (children_.empty()) {
|
|
unroll_length_ = child_len;
|
|
} else {
|
|
if ((unroll_length_ && !child_len) || (child_len && !unroll_length_)) {
|
|
throw ErrorReport(range)
|
|
<< "Can not iterate over a module list or tuple with a value "
|
|
"that does not have a statically determinable length\n";
|
|
}
|
|
if (unroll_length_ && child_len) {
|
|
// iterables run for the minimum length of all its leaves
|
|
unroll_length_ = std::min(*child_len, *unroll_length_);
|
|
} else {
|
|
unroll_length_ = c10::nullopt;
|
|
}
|
|
}
|
|
children_.push_back(iter_value);
|
|
}
|
|
|
|
std::shared_ptr<SugaredValue> MagicMethod::call(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
at::ArrayRef<NamedValue> args,
|
|
at::ArrayRef<NamedValue> kwargs,
|
|
size_t n_binders) {
|
|
if (!args.empty()) {
|
|
Value* self = args[0].value(*m.graph());
|
|
if (auto class_ptr = self->type()->cast<ClassType>()) {
|
|
return SimpleValue(self)
|
|
.attr(loc, m, desugared_name_)
|
|
->call(loc, m, args.slice(1), kwargs, n_binders);
|
|
}
|
|
}
|
|
TORCH_INTERNAL_ASSERT(base_value_);
|
|
return base_value_->call(loc, m, args, kwargs, n_binders);
|
|
}
|
|
|
|
std::shared_ptr<SugaredValue> ClassValue::call(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
// note: names for args will be 'argument 0', 'argument 1', etc..
|
|
at::ArrayRef<NamedValue> args,
|
|
at::ArrayRef<NamedValue> kwargs,
|
|
size_t n_binders) {
|
|
AT_ASSERT(n_binders <= 1);
|
|
|
|
// Generate a new object of the right type, then call `__init__` on it
|
|
auto& g = *m.graph();
|
|
auto self = g.insertNode(g.createObject(type_))->output();
|
|
self->node()->setSourceRange(loc);
|
|
if (!type_->findMethod("__init__")) {
|
|
throw ErrorReport(loc) << "Class " << type_->name()->name()
|
|
<< " does not have an __init__ function defined";
|
|
}
|
|
|
|
// Call the init function
|
|
MethodValue(self, "__init__").call(loc, m, args, kwargs, n_binders);
|
|
|
|
return std::make_shared<SimpleValue>(self);
|
|
}
|
|
|
|
std::shared_ptr<SugaredValue> ClassValue::attr(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
const std::string& field) {
|
|
// Allow import_source.cpp to resolve calls to a submodule's
|
|
// hooks. Edge case because normally you wouldn't allow a module to
|
|
// call functions of a submodule
|
|
if (Function* hook = type_->findHook(field)) {
|
|
return std::make_shared<FunctionValue>(hook);
|
|
}
|
|
|
|
if (field != "__new__") {
|
|
throw ErrorReport(loc) << "Tried to lookup unknown attribute on class "
|
|
<< type_->annotation_str();
|
|
}
|
|
return SpecialFormValue::create(prim::CreateObject);
|
|
}
|
|
|
|
std::shared_ptr<SugaredValue> NamedTupleConstructor::call(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
at::ArrayRef<NamedValue> args,
|
|
at::ArrayRef<NamedValue> kwargs,
|
|
size_t n_binders) {
|
|
auto& g = *m.graph();
|
|
|
|
auto schema = type_->schema();
|
|
TORCH_INTERNAL_ASSERT(schema);
|
|
auto qualname = type_->name();
|
|
auto matched_schema = matchSchema(*schema, loc, g, args, kwargs);
|
|
|
|
auto self =
|
|
g.insertNode(
|
|
g.createTuple(matched_schema.inputs, type_)->setSourceRange(loc))
|
|
->output();
|
|
self->setType(type_);
|
|
|
|
return std::make_shared<SimpleValue>(self);
|
|
}
|
|
|
|
// Tries to bind `self` to a builtin operator named `symbol`.
//  - Without `self`, the builtin is returned as soon as any overload exists.
//  - With `self`, the first overload is chosen whose `self` parameter can be
//    matched against self's type (after resolving type variables) and whose
//    resolved type self's type is a subtype of.
// Returns nullptr when no overload qualifies.
std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
    Symbol symbol,
    c10::optional<NamedValue> self) {
  for (const std::shared_ptr<Operator>& op : getAllOperatorsFor(symbol)) {
    if (!self) {
      return std::make_shared<BuiltinFunction>(symbol, nullptr);
    }

    const auto& op_schema = op->schema();
    const auto self_index = op_schema.argumentIndexWithName("self");
    if (!self_index) {
      continue;
    }

    std::unordered_map<std::string, TypePtr> type_env;
    TypePtr formal_type = op_schema.arguments().at(*self_index).type();
    if (!matchTypeVariables(formal_type, self->type(), type_env).success()) {
      continue;
    }

    const auto concrete_type = tryEvalTypeVariables(formal_type, type_env);
    if (concrete_type && self->type()->isSubtypeOf(*concrete_type)) {
      return std::make_shared<BuiltinFunction>(symbol, self);
    }
  }
  return nullptr;
}
|
|
|
|
std::shared_ptr<SugaredValue> SugaredEnumClass::attr(
|
|
const SourceRange& loc,
|
|
GraphFunction& m,
|
|
const std::string& field) {
|
|
const auto& names_values = enum_type_->enumNamesValues();
|
|
auto it = std::find_if(
|
|
names_values.begin(),
|
|
names_values.end(),
|
|
[&field](const at::EnumNameValue& nv) { return nv.first == field; });
|
|
if (it == names_values.end()) {
|
|
throw ErrorReport(loc) << enum_type_->repr_str() << "'"
|
|
<< " has no attribute '" << field << "'";
|
|
}
|
|
auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
|
|
enum_type_, it->first, it->second);
|
|
return std::make_shared<SimpleValue>(
|
|
m.graph()->insertConstant(IValue(enum_holder), loc));
|
|
}
|
|
|
|
// Iterating over an Enum class yields its members in declaration-table order.
// They are materialized as a single constant list of EnumHolders.
SugaredValuePtr SugaredEnumClass::iter(
    const SourceRange& loc,
    GraphFunction& m) {
  const auto& members = enum_type_->enumNamesValues();
  auto member_list = c10::impl::GenericList(enum_type_);
  member_list.reserve(members.size());
  for (const auto& [member_name, member_value] : members) {
    auto holder = c10::make_intrusive<at::ivalue::EnumHolder>(
        enum_type_, member_name, member_value);
    member_list.emplace_back(std::move(holder));
  }

  Value* list_constant = m.graph()->insertConstant(member_list, loc);
  return std::make_shared<SimpleValue>(list_constant);
}
|
|
|
|
} // namespace torch::jit
|