mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adding Custom Rules to Device Propagation (#66973)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66973 Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D32497549 Pulled By: Gamrix fbshipit-source-id: 5732682c0b39709f76cf218490e5f5136c0d83f8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
77db720c65
commit
71a031e954
@ -14,6 +14,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@ -2302,6 +2303,15 @@ OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Operator>> OperatorSet::getOps() const {
|
||||
std::vector<std::shared_ptr<Operator>> result;
|
||||
for (const auto& kv : ops) {
|
||||
auto ops_for_symbol = kv.second;
|
||||
result.insert(result.end(), ops_for_symbol.begin(), ops_for_symbol.end());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool Node::isMemberOf(const OperatorSet& os) const {
|
||||
auto it = os.ops.find(kind());
|
||||
if (it == os.ops.end()) {
|
||||
|
Reference in New Issue
Block a user