mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit] Reduce refcounting of Types (#65345)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65345 FooType::get() can return a const reference. Inconveniently, converting shared_ptr<FooType> to shared_ptr<Type> requires a copy & refcount bump, so to properly take advantage of this in unshapedType() we need to take a const Type& in isSubtypeOf(), which is good practice anyway -- don't require a shared_ptr if you don't need to take ownership. ghstack-source-id: 140044165 Test Plan: CI perf says c10::unshapedType time decreased from 2.8% to 2.2% during static runtime startup, though I expect this to be generally beneficial. Reviewed By: hlu1 Differential Revision: D31027361 fbshipit-source-id: 676feb81db9f74ad7b8651d8774f4ecb4cfa6ab8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1ae468a484
commit
2d885ab73d
@ -88,7 +88,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
|
||||
Function& m,
|
||||
const std::string& field) {
|
||||
// Allow method-style casts on Tensor types. e.g. x.int()
|
||||
if (value_->type()->isSubtypeOf(TensorType::get())) {
|
||||
if (value_->type()->isSubtypeOf(*TensorType::get())) {
|
||||
if (builtin_cast_method_to_scalar_type().count(field)) {
|
||||
return std::make_shared<TensorCastValue>(
|
||||
builtin_cast_method_to_scalar_type().at(field),
|
||||
@ -202,7 +202,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
|
||||
}
|
||||
|
||||
// Handle calling tolist() on a Tensor.
|
||||
if (value_->type()->isSubtypeOf(TensorType::get()) && field == "tolist") {
|
||||
if (value_->type()->isSubtypeOf(*TensorType::get()) && field == "tolist") {
|
||||
return SpecialFormValue::create(prim::tolist);
|
||||
}
|
||||
|
||||
@ -252,7 +252,7 @@ std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
|
||||
}
|
||||
|
||||
static bool isRecursive(const TypePtr& classType, const TypePtr& attrType) {
|
||||
if (attrType->isSubtypeOf(classType)) {
|
||||
if (attrType->isSubtypeOf(*classType)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -334,7 +334,7 @@ void SimpleValue::setAttr(
|
||||
|
||||
// Check type correctness
|
||||
const auto newType = newValue->type();
|
||||
if (!newType->isSubtypeOf(expectedType)) {
|
||||
if (!newType->isSubtypeOf(*expectedType)) {
|
||||
throw ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
|
||||
<< expectedType->repr_str() << " but got "
|
||||
<< newType->repr_str();
|
||||
@ -390,7 +390,7 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) {
|
||||
TypePtr val_type = val->type();
|
||||
Graph& g = *m.graph();
|
||||
if (val_type->cast<ListType>() || val_type->cast<StringType>() ||
|
||||
val_type->isSubtypeOf(TensorType::get())) {
|
||||
val_type->isSubtypeOf(*TensorType::get())) {
|
||||
return g.insert(aten::len, {val}, {}, loc);
|
||||
} else {
|
||||
throw ErrorReport(loc) << "'" << val_type->repr_str() << "'"
|
||||
@ -415,7 +415,7 @@ SugaredValuePtr SimpleValue::getitem(
|
||||
} else if (auto dict_type = val_type->cast<DictType>()) {
|
||||
return std::make_shared<SimpleValue>(
|
||||
g.insert(aten::__getitem__, {val, idx}, {}, loc));
|
||||
} else if (val_type->isSubtypeOf(TensorType::get())) {
|
||||
} else if (val_type->isSubtypeOf(*TensorType::get())) {
|
||||
return std::make_shared<SimpleValue>(
|
||||
g.insert(aten::select, {val, 0, idx}, {}, loc));
|
||||
} else if (auto class_type = val_type->cast<ClassType>()) {
|
||||
@ -702,7 +702,7 @@ std::shared_ptr<BuiltinFunction> BuiltinFunction::tryCreate(
|
||||
continue;
|
||||
}
|
||||
const auto concrete_type = tryEvalTypeVariables(formal_type, type_env);
|
||||
if (!concrete_type || !self->type()->isSubtypeOf(concrete_type)) {
|
||||
if (!concrete_type || !self->type()->isSubtypeOf(*concrete_type)) {
|
||||
continue;
|
||||
}
|
||||
return std::make_shared<BuiltinFunction>(symbol, self);
|
||||
|
||||
Reference in New Issue
Block a user