Adding Custom Rules to Device Propagation (#66973)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66973

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D32175958

Pulled By: Gamrix

fbshipit-source-id: 26a9ef41e10a171be6a8779a4e6014e2e7e3f2c1
This commit is contained in:
John Clow
2021-11-04 18:57:19 -07:00
committed by Facebook GitHub Bot
parent d04389e6f0
commit 853298481b
3 changed files with 85 additions and 2 deletions

View File

@ -14,6 +14,7 @@
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <sstream>
#include <string>
@ -2303,6 +2304,15 @@ OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
}
}
// Flattens the per-symbol operator map into a single list containing
// every operator in this set. Order follows the (unspecified) map
// iteration order.
std::vector<std::shared_ptr<Operator>> OperatorSet::getOps() const {
  std::vector<std::shared_ptr<Operator>> result;
  for (const auto& kv : ops) {
    // Bind by const reference: the original `auto` copied the whole
    // vector<shared_ptr<Operator>> per symbol, paying an atomic
    // refcount bump for every operator just to read it.
    const auto& ops_for_symbol = kv.second;
    result.insert(result.end(), ops_for_symbol.begin(), ops_for_symbol.end());
  }
  return result;
}
bool Node::isMemberOf(const OperatorSet& os) const {
auto it = os.ops.find(kind());
if (it == os.ops.end()) {

View File

@ -1581,6 +1581,7 @@ TORCH_API std::vector<Node*> findAllNodes(
struct OperatorSet {
OperatorSet(std::initializer_list<const char*> sig_literals);
std::vector<std::shared_ptr<Operator>> getOps() const;
private:
friend struct Node;
@ -1613,6 +1614,12 @@ struct OperatorMap {
std::make_pair(op, val));
}
// Registers `val` under every operator contained in `op_set`,
// delegating to the single-operator insert overload.
void insert(const OperatorSet& op_set, T val) {
  const auto all_ops = op_set.getOps();
  for (const auto& single_op : all_ops) {
    insert(single_op, val);
  }
}
void insert(
std::initializer_list<std::pair<std::shared_ptr<Operator>, T>> v) {
for (auto& el : v) {

View File

@ -5,8 +5,10 @@
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dtype_analysis.h>
#include <torch/csrc/jit/passes/utils/op_registry.h>
#include <torch/library.h>
#include <algorithm>
#include <memory>
@ -20,6 +22,10 @@ namespace {
using Tensor = at::Tensor;
using ScalarType = at::ScalarType;
// ----------------------------------------------------------------------------------
// Metatensor Inference for Dtype
// ----------------------------------------------------------------------------------
std::unique_ptr<Stack> MTensorArgumentCreator(Node* n) {
auto stack = std::make_unique<std::vector<IValue>>();
for (Value* inp : n->inputs()) {
@ -139,13 +145,57 @@ bool tryApplyDtypeMetaTensor(Node* n) {
return setDtype(n->output(), return_tensor->scalar_type());
}
// ----------------------------------------------------------------------------------
// Custom Rules for Dtype
// ----------------------------------------------------------------------------------
using DtypePropRule = std::function<bool(Node*)>;
// Function to propagate dtype information for a node
// Returns true if the dtype information was changed
// Propagates the first input's dtype to every tensor output of `n`,
// but only when no tensor input with a known dtype disagrees with it.
// Inputs whose dtype is unset (e.g. optional args) are ignored.
// Returns true iff at least one output's dtype was changed.
bool setIfAllDtypeMatch(Node* n) {
  TORCH_INTERNAL_ASSERT(n->inputs().size() >= 1);
  auto first_tensor_type = n->inputs().at(0)->type()->cast<TensorType>();
  TORCH_INTERNAL_ASSERT(first_tensor_type, "Expecting a tensor type");
  auto target_dtype = first_tensor_type->scalarType();
  if (!target_dtype.has_value()) {
    // The first input's dtype is unknown; nothing to propagate.
    return false;
  }

  // Any known tensor-input dtype that differs blocks propagation.
  for (auto input : n->inputs()) {
    auto input_tensor_type = input->type()->cast<TensorType>();
    if (!input_tensor_type) {
      continue; // non-tensor argument
    }
    auto input_dtype = input_tensor_type->scalarType();
    // A missing dtype (e.g. an optional arg) is allowed; only a
    // conflicting known dtype aborts.
    if (input_dtype.has_value() && input_dtype != target_dtype) {
      return false;
    }
  }

  bool changed = false;
  for (auto output : n->outputs()) {
    if (output->type()->cast<TensorType>()) {
      changed |= setDtype(output, target_dtype.value());
    }
  }
  return changed;
}
// DtypePropagationPass is an analysis pass that walks through a graph in
// topological order and forward-propagates dtypes (ScalarTypes) from graph
// inputs (expressed in input_descriptors) to all output tensor nodes in the
// graph.
struct DtypePropagationPass {
DtypePropagationPass(std::shared_ptr<Graph> graph)
: graph_(std::move(graph)) {}
// Takes ownership of `graph` and eagerly builds the custom dtype-rule
// registry so lookups during the pass don't pay construction cost.
explicit DtypePropagationPass(std::shared_ptr<Graph> graph)
: graph_(std::move(graph)) {
buildDtypeRuleRegistry();
}
// returns true if at least one node has its scalar type set on a tensor node
bool run() {
@ -240,8 +290,24 @@ struct DtypePropagationPass {
// Runs dtype propagation for a single aten node: a custom rule
// registered for this operator takes precedence; otherwise we fall
// back to metatensor-based dtype inference.
bool processAtenOps(Node* n) {
  GRAPH_DEBUG("processAtenOps");
  GRAPH_DEBUG("case = ", n->kind(), " ", *n);
  // Custom Rule Matching
  auto custom_rule = dtype_prop_registry_->find(n->getOperator());
  if (custom_rule) {
    // Invoke through the lookup result directly; no need to copy the
    // std::function into a local first.
    return (*custom_rule)(n);
  }
  return tryApplyDtypeMetaTensor(n);
}
// Populates dtype_prop_registry_ with the custom dtype rules. Both op
// sets currently map to the same rule (setIfAllDtypeMatch: propagate
// the first input's dtype when all tensor inputs agree).
// NOTE(review): the op-set factory functions appear to come from
// passes/utils/op_registry.h (included above) -- confirm their exact
// contents there.
void buildDtypeRuleRegistry() {
// building a registry for all of the custom dtype rules
dtype_prop_registry_ = std::make_unique<OperatorMap<DtypePropRule>>();
dtype_prop_registry_->insert(
*nn_ops_first_input_preserving(), setIfAllDtypeMatch);
dtype_prop_registry_->insert(
*ops_one_tensor_in_shape_transform(), setIfAllDtypeMatch);
}
std::unique_ptr<OperatorMap<DtypePropRule>> dtype_prop_registry_;
std::shared_ptr<Graph> graph_;
};