mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PyTorch] Make TypePrinter take const Type& (#69412)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69412 TypePrinter does not need to take ownership of the Type. This helps unblock the following diff to stop refcounting Type singletons. ghstack-source-id: 145671619 Test Plan: CI Reviewed By: suo Differential Revision: D32858525 fbshipit-source-id: df58676938fd20c7bae4a366d70b2067a852282d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7a12b5063e
commit
3b7fc0243c
@ -63,7 +63,7 @@ using ConstTypePtr = std::shared_ptr<const Type>;
|
||||
// c10::nullopt is returned, `annotation_str()` falls through to its default
|
||||
// implementation.
|
||||
using TypePrinter =
|
||||
std::function<c10::optional<std::string>(const ConstTypePtr&)>;
|
||||
std::function<c10::optional<std::string>(const Type&)>;
|
||||
|
||||
struct TORCH_API Type : std::enable_shared_from_this<Type> {
|
||||
private:
|
||||
@ -118,7 +118,7 @@ struct TORCH_API Type : std::enable_shared_from_this<Type> {
|
||||
std::string annotation_str(TypePrinter printer) const {
|
||||
if (printer) {
|
||||
// the printer can return nullopt to fall through to the default impl
|
||||
if (auto renamed = printer(shared_from_this())) {
|
||||
if (auto renamed = printer(*this)) {
|
||||
return *renamed;
|
||||
}
|
||||
}
|
||||
|
@ -9,8 +9,8 @@ namespace c10 {
|
||||
|
||||
TEST(TypeCustomPrinter, Basic) {
|
||||
TypePrinter printer =
|
||||
[](const ConstTypePtr& t) -> c10::optional<std::string> {
|
||||
if (auto tensorType = t->cast<TensorType>()) {
|
||||
[](const Type& t) -> c10::optional<std::string> {
|
||||
if (auto tensorType = t.cast<TensorType>()) {
|
||||
return "CustomTensor";
|
||||
}
|
||||
return c10::nullopt;
|
||||
@ -29,8 +29,8 @@ TEST(TypeCustomPrinter, Basic) {
|
||||
|
||||
TEST(TypeCustomPrinter, ContainedTypes) {
|
||||
TypePrinter printer =
|
||||
[](const ConstTypePtr& t) -> c10::optional<std::string> {
|
||||
if (auto tensorType = t->cast<TensorType>()) {
|
||||
[](const Type& t) -> c10::optional<std::string> {
|
||||
if (auto tensorType = t.cast<TensorType>()) {
|
||||
return "CustomTensor";
|
||||
}
|
||||
return c10::nullopt;
|
||||
@ -53,8 +53,8 @@ TEST(TypeCustomPrinter, ContainedTypes) {
|
||||
|
||||
TEST(TypeCustomPrinter, NamedTuples) {
|
||||
TypePrinter printer =
|
||||
[](const ConstTypePtr& t) -> c10::optional<std::string> {
|
||||
if (auto tupleType = t->cast<TupleType>()) {
|
||||
[](const Type& t) -> c10::optional<std::string> {
|
||||
if (auto tupleType = t.cast<TupleType>()) {
|
||||
// Rewrite only NamedTuples
|
||||
if (tupleType->name()) {
|
||||
return "Rewritten";
|
||||
|
@ -188,9 +188,8 @@ std::pair<IValue, IValue> getFunctionTuple(
|
||||
|
||||
// schema
|
||||
const auto& schema = func.getSchema();
|
||||
auto type_printer =
|
||||
[&](const c10::ConstTypePtr& t) -> c10::optional<std::string> {
|
||||
auto namedType = t->cast<c10::NamedType>();
|
||||
auto type_printer = [&](const c10::Type& t) -> c10::optional<std::string> {
|
||||
auto namedType = t.cast<c10::NamedType>();
|
||||
if (namedType && namedType->name()) {
|
||||
return type_name_uniquer_.getUniqueName(namedType).qualifiedName();
|
||||
}
|
||||
@ -709,9 +708,8 @@ void ScriptModuleSerializer::convertNamedType(
|
||||
std::string qualifier = qualname.prefix();
|
||||
PythonPrint* pp = file_streams_.find(qualifier);
|
||||
|
||||
auto type_printer =
|
||||
[&](const c10::ConstTypePtr& t) -> c10::optional<std::string> {
|
||||
auto namedType = t->cast<c10::NamedType>();
|
||||
auto type_printer = [&](const c10::Type& t) -> c10::optional<std::string> {
|
||||
auto namedType = t.cast<c10::NamedType>();
|
||||
if (namedType && namedType->name()) {
|
||||
return type_name_uniquer_.getUniqueName(namedType).qualifiedName();
|
||||
}
|
||||
|
Reference in New Issue
Block a user