Adding Custom Rules to Device Propagation (#66973)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66973

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D32497549

Pulled By: Gamrix

fbshipit-source-id: 5732682c0b39709f76cf218490e5f5136c0d83f8
This commit is contained in:
John Clow
2021-11-18 16:25:21 -08:00
committed by Facebook GitHub Bot
parent 77db720c65
commit 71a031e954
3 changed files with 85 additions and 2 deletions

View File

@ -14,6 +14,7 @@
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <sstream>
#include <string>
@ -2302,6 +2303,15 @@ OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
}
}
std::vector<std::shared_ptr<Operator>> OperatorSet::getOps() const {
std::vector<std::shared_ptr<Operator>> result;
for (const auto& kv : ops) {
auto ops_for_symbol = kv.second;
result.insert(result.end(), ops_for_symbol.begin(), ops_for_symbol.end());
}
return result;
}
bool Node::isMemberOf(const OperatorSet& os) const {
auto it = os.ops.find(kind());
if (it == os.ops.end()) {