[PyTorch] Make TypePrinter take const Type& (#69412)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69412

TypePrinter does not need to take ownership of the Type.

This helps unblock the following diff to stop refcounting Type singletons.
ghstack-source-id: 145671619

Test Plan: CI

Reviewed By: suo

Differential Revision: D32858525

fbshipit-source-id: df58676938fd20c7bae4a366d70b2067a852282d
This commit is contained in:
Scott Wolchok
2021-12-14 23:11:29 -08:00
committed by Facebook GitHub Bot
parent 7a12b5063e
commit 3b7fc0243c
3 changed files with 12 additions and 14 deletions

View File

@ -63,7 +63,7 @@ using ConstTypePtr = std::shared_ptr<const Type>;
// c10::nullopt is returned, `annotation_str()` falls through to its default
// implementation.
using TypePrinter =
std::function<c10::optional<std::string>(const ConstTypePtr&)>;
std::function<c10::optional<std::string>(const Type&)>;
struct TORCH_API Type : std::enable_shared_from_this<Type> {
private:
@ -118,7 +118,7 @@ struct TORCH_API Type : std::enable_shared_from_this<Type> {
std::string annotation_str(TypePrinter printer) const {
if (printer) {
// the printer can return nullopt to fall through to the default impl
if (auto renamed = printer(shared_from_this())) {
if (auto renamed = printer(*this)) {
return *renamed;
}
}

View File

@ -9,8 +9,8 @@ namespace c10 {
TEST(TypeCustomPrinter, Basic) {
TypePrinter printer =
[](const ConstTypePtr& t) -> c10::optional<std::string> {
if (auto tensorType = t->cast<TensorType>()) {
[](const Type& t) -> c10::optional<std::string> {
if (auto tensorType = t.cast<TensorType>()) {
return "CustomTensor";
}
return c10::nullopt;
@ -29,8 +29,8 @@ TEST(TypeCustomPrinter, Basic) {
TEST(TypeCustomPrinter, ContainedTypes) {
TypePrinter printer =
[](const ConstTypePtr& t) -> c10::optional<std::string> {
if (auto tensorType = t->cast<TensorType>()) {
[](const Type& t) -> c10::optional<std::string> {
if (auto tensorType = t.cast<TensorType>()) {
return "CustomTensor";
}
return c10::nullopt;
@ -53,8 +53,8 @@ TEST(TypeCustomPrinter, ContainedTypes) {
TEST(TypeCustomPrinter, NamedTuples) {
TypePrinter printer =
[](const ConstTypePtr& t) -> c10::optional<std::string> {
if (auto tupleType = t->cast<TupleType>()) {
[](const Type& t) -> c10::optional<std::string> {
if (auto tupleType = t.cast<TupleType>()) {
// Rewrite only NamedTuples
if (tupleType->name()) {
return "Rewritten";

View File

@ -188,9 +188,8 @@ std::pair<IValue, IValue> getFunctionTuple(
// schema
const auto& schema = func.getSchema();
auto type_printer =
[&](const c10::ConstTypePtr& t) -> c10::optional<std::string> {
auto namedType = t->cast<c10::NamedType>();
auto type_printer = [&](const c10::Type& t) -> c10::optional<std::string> {
auto namedType = t.cast<c10::NamedType>();
if (namedType && namedType->name()) {
return type_name_uniquer_.getUniqueName(namedType).qualifiedName();
}
@ -709,9 +708,8 @@ void ScriptModuleSerializer::convertNamedType(
std::string qualifier = qualname.prefix();
PythonPrint* pp = file_streams_.find(qualifier);
auto type_printer =
[&](const c10::ConstTypePtr& t) -> c10::optional<std::string> {
auto namedType = t->cast<c10::NamedType>();
auto type_printer = [&](const c10::Type& t) -> c10::optional<std::string> {
auto namedType = t.cast<c10::NamedType>();
if (namedType && namedType->name()) {
return type_name_uniquer_.getUniqueName(namedType).qualifiedName();
}