Fix C++20 build (#112333)

Currently C++20 fails because of incorrect template initialization order. This PR adjusted the order of theses classes and a constructor to address the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112333
Approved by: https://github.com/albanD
This commit is contained in:
cyy
2024-02-13 05:10:15 +00:00
committed by PyTorch MergeBot
parent 2bda6b4cb8
commit 47a2e6b6b8
4 changed files with 107 additions and 107 deletions

View File

@ -36,7 +36,7 @@ TEST(MemoryDAGTest, Basic) {
t->makePointerTo(e, a);
t->makePointerTo(e, f);
auto dag = std::make_unique<MemoryDAG>(std::move(t));
auto dag = std::move(*t).createMemoryDAG();
/**
* Test mayAlias()
@ -69,7 +69,7 @@ TEST(MemoryDAGTest, Basic) {
auto c = t->makeFreshValue(cValue);
t->addToContainedElements(a, c);
auto dag = std::make_unique<MemoryDAG>(std::move(t));
auto dag = std::move(*t).createMemoryDAG();
EXPECT_TRUE(dag->mayContainAlias(a, b));
EXPECT_TRUE(dag->mayContainAlias(b, a));
@ -99,7 +99,7 @@ TEST(MemoryDAGTest, Basic) {
auto d = t->makeFreshValue(dValue);
t->addToContainedElements(b, d);
auto dag = std::make_unique<MemoryDAG>(std::move(t));
auto dag = std::move(*t).createMemoryDAG();
EXPECT_TRUE(dag->mayContainAlias(b, d));
EXPECT_TRUE(dag->mayContainAlias(d, b));
@ -126,7 +126,7 @@ TEST(MemoryDAGTest, Basic) {
t->addToContainedElements(f, e);
auto dag = std::make_unique<MemoryDAG>(std::move(t));
auto dag = std::move(*t).createMemoryDAG();
for (auto elem : {a, b, c, d}) {
EXPECT_FALSE(dag->mayContainAlias(f, elem));
EXPECT_FALSE(dag->mayContainAlias(e, elem));

View File

@ -228,7 +228,7 @@ AliasDb::AliasDb(
writeRegistry_(std::make_unique<AliasDb::WriteRegistry>()) {
analyze(graph_);
memoryDAG_ = std::make_unique<MemoryDAG>(std::move(memoryDAGBuilder_));
memoryDAG_ = std::move(*memoryDAGBuilder_).createMemoryDAG();
memoryDAGBuilder_ = nullptr; // to make further access a hard error
memoryDAG_->setWildcards(

View File

@ -19,40 +19,52 @@ typedef c10::SparseBitVector<256> MemoryLocations;
namespace torch {
namespace jit {
struct Element;
struct Value;
class MemoryDAG;
using AliasTypeSet = std::vector<TypePtr>;
/**
* Helper to build up the points-to graph.
*
* We separate the "building" into a different class because it allows us to
* cache internally to MemoryDAG without worrying about how the DAG structure
* is mutated.
*/
class TORCH_API MemoryDAGBuilder {
public:
MemoryDAGBuilder() = default;
MemoryDAGBuilder(const MemoryDAGBuilder&) = delete;
MemoryDAGBuilder& operator=(const MemoryDAGBuilder&) = delete;
// `Element` represents a vertex in the points-to graph. It represents
// anything that could have an aliasing relationship--mostly IR
// `Value`s, but also wildcards or the type inside a container (e.g. `T`
// in `List[T]`)
struct Element {
Element(const Value* value_, unsigned index_);
// wildcard constructor
explicit Element(unsigned index_);
// Index into the owning DAG's bit vector that represents this element.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
unsigned index;
// All elements that this element *may* point to. It's possible to have
// multiple elements that you might point to due to control flow/complex ops
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
MemoryLocations pointsTo;
// Backreference for points-to.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
MemoryLocations pointedFrom;
// Elements can contain other elements (e.g. List[Tensor])
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
MemoryLocations containedElements;
// The values that this element corresponds to. May be empty if this element
// doesn't represent a first-class value.
// This is for debug information only.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::unordered_set<const Value*> values;
private:
// Make `from` point at `to`.
void makePointerTo(Element* from, Element* to);
void addToContainedElements(Element* contained, Element* container);
friend class MemoryDAG;
// We memoize the results of `getMemoryLocations` to speed up queries.
// A nullopt means that this cache is not yet populated. Since `MemoryDAG` is
// immutable, this cache should never need to be invalidated.
mutable c10::optional<MemoryLocations> cachedMemoryLocations_;
// Make a fresh Element (i.e. an Element that doesn't point to anything) and
// return it.
Element* makeFreshValue(const Value* v);
friend MemoryDAG;
private:
// `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses
// the map to construct the `MemoryDAG`
std::vector<std::unique_ptr<Element>> indexToElementMap_;
mutable c10::optional<MemoryLocations> cachedAllContainedMemoryLocations_;
};
// class MemoryDAG
@ -72,8 +84,8 @@ class TORCH_API MemoryDAGBuilder {
// which memory locations an element may point to.
class TORCH_API MemoryDAG {
public:
explicit MemoryDAG(std::unique_ptr<MemoryDAGBuilder> builder)
: indexToElementMap_(std::move(builder->indexToElementMap_)) {}
explicit MemoryDAG(std::vector<std::unique_ptr<Element>> indexToElementMap)
: indexToElementMap_(std::move(indexToElementMap)) {}
// explicitly delete copy constructor because otherwise windows build is
// confused for an exported class see
// https://stackoverflow.com/a/51033485/105137
@ -127,49 +139,38 @@ class TORCH_API MemoryDAG {
std::vector<std::unique_ptr<Element>> indexToElementMap_;
};
// `Element` represents a vertex in the points-to graph. It represents
// anything that could have an aliasing relationship--mostly IR
// `Value`s, but also wildcards or the type inside a container (e.g. `T`
// in `List[T]`)
struct Element {
Element(const Value* value_, unsigned index_);
// wildcard constructor
explicit Element(unsigned index_);
/**
* Helper to build up the points-to graph.
*
* We separate the "building" into a different class because it allows us to
* cache internally to MemoryDAG without worrying about how the DAG structure
* is mutated.
*/
class TORCH_API MemoryDAGBuilder {
public:
MemoryDAGBuilder() = default;
MemoryDAGBuilder(const MemoryDAGBuilder&) = delete;
MemoryDAGBuilder& operator=(const MemoryDAGBuilder&) = delete;
// Index into the owning DAG's bit vector that represents this element.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
unsigned index;
// All elements that this element *may* point to. It's possible to have
// multiple elements that you might point to due to control flow/complex ops
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
MemoryLocations pointsTo;
// Backreference for points-to.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
MemoryLocations pointedFrom;
// Elements can contain other elements (e.g. List[Tensor])
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
MemoryLocations containedElements;
// The values that this element corresponds to. May be empty if this element
// doesn't represent a first-class value.
// This is for debug information only.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::unordered_set<const Value*> values;
private:
// Make `from` point at `to`.
void makePointerTo(Element* from, Element* to);
friend class MemoryDAG;
// We memoize the results of `getMemoryLocations` to speed up queries.
// A nullopt means that this cache is not yet populated. Since `MemoryDAG` is
// immutable, this cache should never need to be invalidated.
mutable c10::optional<MemoryLocations> cachedMemoryLocations_;
void addToContainedElements(Element* contained, Element* container);
mutable c10::optional<MemoryLocations> cachedAllContainedMemoryLocations_;
std::unique_ptr<MemoryDAG> createMemoryDAG() && {
return std::make_unique<MemoryDAG>(std::move(indexToElementMap_));
}
// Make a fresh Element (i.e. an Element that doesn't point to anything) and
// return it.
Element* makeFreshValue(const Value* v);
friend MemoryDAG;
private:
// `MemoryDAGBuilder` builds up `indexToElementMap_`, then uses
// the map to construct the `MemoryDAG`
std::vector<std::unique_ptr<Element>> indexToElementMap_;
};
} // namespace jit
} // namespace torch

View File

@ -240,7 +240,6 @@ class TORCH_API StaticRuntimeMetadata : public torch::CustomClassHolder {
///
class MemoryPlanner;
class StaticNodeInfo;
class ProcessedFunction;
class ProcessedNode;
class StaticRuntime;
@ -259,6 +258,42 @@ struct TORCH_API SROperatorObserver {
};
#endif
class TORCH_API ProcessedFunction {
public:
ProcessedFunction(
Node* node,
bool enable_out_variant,
bool check_memory_overlap);
enum class Kind : uint8_t {
kOutVariant,
kNativeFunction,
kInterpreterFallback,
};
void run(ProcessedNode* pnode) const {
return f_(pnode);
}
Kind kind() const {
return kind_;
}
bool checkMemoryOverlap() const {
return check_memory_overlap_;
}
size_t num_outputs() const {
return num_outputs_;
}
private:
SROperator f_;
Kind kind_{ProcessedFunction::Kind::kOutVariant};
bool check_memory_overlap_{false};
size_t num_outputs_{0};
};
// A `BlockInfo` instance stores all of the shared state that each
// `BlockRunner` will need to access. Most of this information is
// read-only and shared between threads.
@ -778,42 +813,6 @@ class TORCH_API BlockRunner {
std::vector<ProcessedNode> nodes_;
};
class TORCH_API ProcessedFunction {
public:
ProcessedFunction(
Node* node,
bool enable_out_variant,
bool check_memory_overlap);
enum class Kind : uint8_t {
kOutVariant,
kNativeFunction,
kInterpreterFallback,
};
void run(ProcessedNode* pnode) const {
return f_(pnode);
}
Kind kind() const {
return kind_;
}
bool checkMemoryOverlap() const {
return check_memory_overlap_;
}
size_t num_outputs() const {
return num_outputs_;
}
private:
SROperator f_;
Kind kind_{ProcessedFunction::Kind::kOutVariant};
bool check_memory_overlap_{false};
size_t num_outputs_{0};
};
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_API StaticNodeInfo {
public: