Adding Custom Rules to Device Propagation (#66973)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66973

Test Plan: Imported from OSS

Reviewed By: driazati

Differential Revision: D32175958

Pulled By: Gamrix

fbshipit-source-id: 26a9ef41e10a171be6a8779a4e6014e2e7e3f2c1
Committed by: Facebook GitHub Bot
Parent: d04389e6f0
Commit: 853298481b
@@ -14,6 +14,7 @@
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <sstream>
#include <string>
@@ -2303,6 +2304,15 @@ OperatorSet::OperatorSet(std::initializer_list<const char*> sig_literals) {
  }
}

std::vector<std::shared_ptr<Operator>> OperatorSet::getOps() const {
  std::vector<std::shared_ptr<Operator>> result;
  for (const auto& kv : ops) {
    auto ops_for_symbol = kv.second;
    result.insert(result.end(), ops_for_symbol.begin(), ops_for_symbol.end());
  }
  return result;
}

bool Node::isMemberOf(const OperatorSet& os) const {
  auto it = os.ops.find(kind());
  if (it == os.ops.end()) {
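For orientation, here is a minimal sketch (not code from this commit) of how an OperatorSet built from schema literals might be queried with the new getOps() accessor and the Node::isMemberOf() helper shown above. The example set and the two helper functions are hypothetical:

// Hypothetical set; any valid aten schema strings could be listed here.
const OperatorSet example_ops{
    "aten::relu(Tensor self) -> Tensor",
    "aten::dropout(Tensor input, float p, bool train) -> Tensor",
};

// Membership check on a node, using the helper added in this hunk.
bool isExampleOp(Node* n) {
  return n->isMemberOf(example_ops);
}

// getOps() flattens the internal symbol-to-overloads map into one vector.
size_t countExampleOverloads() {
  return example_ops.getOps().size();
}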
@@ -1581,6 +1581,7 @@ TORCH_API std::vector<Node*> findAllNodes(

struct OperatorSet {
  OperatorSet(std::initializer_list<const char*> sig_literals);
  std::vector<std::shared_ptr<Operator>> getOps() const;

 private:
  friend struct Node;

@@ -1613,6 +1614,12 @@ struct OperatorMap {
        std::make_pair(op, val));
  }

  void insert(const OperatorSet& op_set, T val) {
    for (auto& op : op_set.getOps()) {
      insert(op, val);
    }
  }

  void insert(
      std::initializer_list<std::pair<std::shared_ptr<Operator>, T>> v) {
    for (auto& el : v) {
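A brief sketch of how the new insert(const OperatorSet&, T) overload associates one value with every operator in a set, assuming the OperatorMap interface shown here plus the find() accessor used later in this diff. The registry name, the no-op rule, and example_ops (from the previous sketch) are illustrative only:

using DtypePropRule = std::function<bool(Node*)>;

OperatorMap<DtypePropRule> registry;
// One call maps every operator in the set to the same rule.
registry.insert(example_ops, [](Node* n) { return false; });

// Later, at analysis time (mirrors processAtenOps further down):
// if (auto rule = registry.find(node->getOperator())) {
//   bool changed = (*rule)(node);
// }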
@@ -5,8 +5,10 @@
#include <c10/util/ArrayRef.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dtype_analysis.h>
#include <torch/csrc/jit/passes/utils/op_registry.h>
#include <torch/library.h>
#include <algorithm>
#include <memory>
@@ -20,6 +22,10 @@ namespace {
using Tensor = at::Tensor;
using ScalarType = at::ScalarType;

// ----------------------------------------------------------------------------------
// Metatensor Inference for Dtype
// ----------------------------------------------------------------------------------

std::unique_ptr<Stack> MTensorArgumentCreator(Node* n) {
  auto stack = std::make_unique<std::vector<IValue>>();
  for (Value* inp : n->inputs()) {
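The helpers in this "Metatensor Inference for Dtype" section build meta-device arguments for a node and run the operator without real data so the output dtype can be read back. As a rough, self-contained illustration of that idea in plain ATen (not the pass's actual code path; the function name is made up):

#include <ATen/ATen.h>

// Meta tensors carry shape and dtype but no storage, so running an op on
// them computes only metadata such as the promoted output dtype.
at::ScalarType inferAddDtype() {
  at::Tensor a = at::empty({2, 3}, at::TensorOptions().device(at::kMeta).dtype(at::kFloat));
  at::Tensor b = at::empty({2, 3}, at::TensorOptions().device(at::kMeta).dtype(at::kDouble));
  at::Tensor out = at::add(a, b); // no real computation happens
  return out.scalar_type();       // kDouble, via standard type promotion
}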
@@ -139,13 +145,57 @@ bool tryApplyDtypeMetaTensor(Node* n) {
  return setDtype(n->output(), return_tensor->scalar_type());
}

// ----------------------------------------------------------------------------------
// Custom Rules for Dtype
// ----------------------------------------------------------------------------------
using DtypePropRule = std::function<bool(Node*)>;
// Function to propagate dtype information for a node
// Returns true if the dtype information was changed

bool setIfAllDtypeMatch(Node* n) {
  // Sets all tensor outputs to the dtype of the first input
  // only if all inputs are the same dtype, otherwise do nothing
  TORCH_INTERNAL_ASSERT(n->inputs().size() >= 1);
  auto first_arg = n->inputs().at(0);
  auto tensor_type = first_arg->type()->cast<TensorType>();
  TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
  auto scalar_type = tensor_type->scalarType();
  if (!scalar_type.has_value()) {
    return false;
  }
  for (auto arg : n->inputs()) {
    tensor_type = arg->type()->cast<TensorType>();
    if (!tensor_type) {
      continue;
    }
    auto arg_scalar_type = tensor_type->scalarType();

    if (!arg_scalar_type.has_value()) { // Allow None for optional args
      continue;
    }
    if (arg_scalar_type != scalar_type) {
      return false;
    }
  }

  bool changed = false;
  for (auto output : n->outputs()) {
    if (output->type()->cast<TensorType>()) {
      changed |= setDtype(output, scalar_type.value());
    }
  }
  return changed;
}

// DtypePropagationPass is an analysis pass that walks through a graph in
// topological order and forward propagate Dtypes (ScalarTypes) from graph
// inputs (expressed in input_descriptors) to all output tensor nodes in the
// graph.
struct DtypePropagationPass {
-  DtypePropagationPass(std::shared_ptr<Graph> graph)
-      : graph_(std::move(graph)) {}
+  explicit DtypePropagationPass(std::shared_ptr<Graph> graph)
+      : graph_(std::move(graph)) {
+    buildDtypeRuleRegistry();
+  }

  // returns true if at least one node has its scalar type set on a tensor node
  bool run() {
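To restate the rule above in isolation: setIfAllDtypeMatch commits a dtype only when every input whose dtype is known agrees with the first input's dtype. A condensed sketch of that decision over plain optional scalar types (illustrative only, not the pass's code):

#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>
#include <vector>

// Returns the dtype to propagate, or nullopt when the first input's dtype is
// unknown or any known input dtype disagrees with it.
c10::optional<c10::ScalarType> agreedDtype(
    const std::vector<c10::optional<c10::ScalarType>>& input_dtypes) {
  if (input_dtypes.empty() || !input_dtypes.front().has_value()) {
    return c10::nullopt;
  }
  const auto first = *input_dtypes.front();
  for (const auto& st : input_dtypes) {
    if (st.has_value() && *st != first) {
      return c10::nullopt; // a known dtype disagrees: do nothing
    }
    // inputs with unknown dtype (e.g. optional args) are tolerated
  }
  return first;
}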
@@ -240,8 +290,24 @@ struct DtypePropagationPass {
  bool processAtenOps(Node* n) {
    GRAPH_DEBUG("processAtenOps");
    GRAPH_DEBUG("case = ", n->kind(), " ", *n);
    // Custom Rule Matching
    if (auto prop_fn = dtype_prop_registry_->find(n->getOperator())) {
      DtypePropRule rule = *prop_fn;
      return rule(n);
    }
    return tryApplyDtypeMetaTensor(n);
  }

  void buildDtypeRuleRegistry() {
    // building a registry for all of the custom dtype rules
    dtype_prop_registry_ = std::make_unique<OperatorMap<DtypePropRule>>();

    dtype_prop_registry_->insert(
        *nn_ops_first_input_preserving(), setIfAllDtypeMatch);
    dtype_prop_registry_->insert(
        *ops_one_tensor_in_shape_transform(), setIfAllDtypeMatch);
  }
  std::unique_ptr<OperatorMap<DtypePropRule>> dtype_prop_registry_;
  std::shared_ptr<Graph> graph_;
};
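buildDtypeRuleRegistry() wires the two shipped operator sets to setIfAllDtypeMatch. As a hedged sketch of how a further custom rule could be registered through the same mechanism (the comparison_ops() set and the rule below are hypothetical and not part of this change; setDtype is the file's existing helper):

// Hypothetical rule: mark every tensor output as Bool, e.g. for comparison ops.
bool setOutputToBool(Node* n) {
  bool changed = false;
  for (auto output : n->outputs()) {
    if (output->type()->cast<TensorType>()) {
      changed |= setDtype(output, at::ScalarType::Bool);
    }
  }
  return changed;
}

// Inside buildDtypeRuleRegistry(), alongside the existing insertions:
//   dtype_prop_registry_->insert(*comparison_ops(), setOutputToBool);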