Files
pytorch/torch/csrc/jit/passes/quantization/insert_observers.cpp
Aaron Gokaslan 3916d7a575 Apply modernize-use-emplace to aten, c10, torch (#91077)
Apply clang-tidy check modernize-use-emplace. This is slightly more efficient by using an inplace constructor and is the recommended style in parts of the codebase covered by clang-tidy. This just manually applies the check to rest of the codebase. Pinging @ezyang as this is related to my other PRs he reviewed like #89000

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91077
Approved by: https://github.com/ezyang
2022-12-19 07:49:56 +00:00

1722 lines
64 KiB
C++

#include <c10/util/irange.h>
#include <torch/csrc/jit/passes/quantization/insert_observers.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/inline_fork_wait.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <memory>
#include <regex>
#include <stack>
#include <string>
namespace torch {
namespace jit {
using ModuleQConfigMap = std::unordered_map<ModulePtr, c10::optional<QConfig>>;
namespace {
// Hash functor for c10::optional<QConfig>, used as the hasher of
// QConfigTypePtrMap. An empty optional hashes to 0; otherwise the hashes of
// the two modules held in the QConfig tuple are combined, with the second one
// weighted so that swapping the two modules changes the result.
struct OptionalQConfigHash {
  inline size_t operator()(const c10::optional<QConfig>& qconfig_opt) const {
    if (!qconfig_opt.has_value()) {
      return 0;
    }
    constexpr int kSecondWeight = 7;
    const size_t first_hash = std::hash<Module>()(std::get<0>(*qconfig_opt));
    const size_t second_hash = std::hash<Module>()(std::get<1>(*qconfig_opt));
    return first_hash + kSecondWeight * second_hash;
  }
};
using QConfigTypePtrMap =
std::unordered_map<c10::optional<QConfig>, TypePtr, OptionalQConfigHash>;
using NameModuleVector = std::vector<std::pair<std::string, Module>>;
using OptionalModuleVector = std::vector<c10::optional<Module>>;
using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
using graph_rewrite_helper::PatternInfo;
using graph_rewrite_helper::replaceConvolutionWithAtenConv;
// helper functions
// Recursively records in `map` the qconfig that applies to `module` and to
// every module below it, keyed by the module's underlying object pointer.
//
// `key` is the dotted path of `module` from the root ("" for the root itself).
// When `qconfig_dict` has an entry for `key`, that entry wins; otherwise the
// module inherits `parent_qconfig`. Whatever was chosen for this module is
// handed down as the parent qconfig of each of its children.
void fillQConfigMap(
    const Module& module,
    const QConfigDict& qconfig_dict,
    ModuleQConfigMap& map,
    const std::string& key = "",
    const c10::optional<QConfig>& parent_qconfig = c10::nullopt) {
  c10::optional<QConfig> resolved;
  const auto entry = qconfig_dict.find(key);
  if (entry == qconfig_dict.end()) {
    GRAPH_DEBUG("Inheriting qconfig from parent module:", key);
    resolved = parent_qconfig;
  } else {
    GRAPH_DEBUG("Got module config for key:", key);
    resolved = entry->second;
  }
  map[module._ivalue()] = resolved;

  for (const NameModule& child : module.named_children()) {
    // The root contributes no prefix; deeper levels are joined with '.'.
    const std::string child_key =
        key.empty() ? child.name : key + "." + child.name;
    fillQConfigMap(
        child.value._ivalue(), qconfig_dict, map, child_key, resolved);
  }
}
// Selects the observer module for `v` out of `qconfig`: element 1 of the tuple
// when isWeight(v) holds, element 0 otherwise.
Module getObserverModuleFor(Value* v, const QConfig& qconfig) {
  if (isWeight(v)) {
    return std::get<1>(qconfig);
  }
  return std::get<0>(qconfig);
}
// helper classes
class ModuleCloneHelper {
public:
/**
 * Clones `module` while honouring the per-module qconfig map. This handles
 * the case where two module instances share one ClassType but are configured
 * with different QConfigs.
 *
 * The code is copied and modified from
 * https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/module.cpp
 *
 * @param module             module to clone
 * @param module_qconfig_map qconfig assigned to each module instance
 * @param inplace            when true, the clone shares non-module attribute
 *                           values (e.g. tensors) with the original instead
 *                           of deep-copying them
 * @return the cloned module
 */
Module clone(
    const Module& module,
    const ModuleQConfigMap& module_qconfig_map,
    bool inplace = false) {
  // Fresh bookkeeping for this top-level clone: the deepcopy memo and the
  // table of ClassTypes cloned so far, keyed by (original type, qconfig).
  IValue::HashAliasedIValueMap deepcopy_memo;
  std::unordered_map<TypePtr, QConfigTypePtrMap> cloned_types;
  return clone_impl(
      module, module_qconfig_map, cloned_types, inplace, deepcopy_memo);
}
private:
// Recursive worker behind clone().
//
// Creates a new module object in the same compilation unit as `module`,
// copies every attribute slot (recursing into sub-modules), and - only when
// the (ClassType, qconfig) pair has not been cloned before - also copies the
// constants and methods and replays __getstate__/__setstate__.
//
// `type_remap` maps an original type and a qconfig to the cloned type and is
// shared (by reference) across the whole recursion.
//
// NOTE(review): `memo` is taken by value, so memo entries added during a
// recursive call are not visible to the caller or to sibling sub-modules -
// confirm this is intended.
Module clone_impl(
    const Module& module,
    const ModuleQConfigMap& module_qconfig_map,
    std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap,
    bool inplace,
    IValue::HashAliasedIValueMap memo) {
  auto qconfig = module_qconfig_map.at(module._ivalue());
  auto type = module.type();

  // ClassTypes may be shared between instances, and that sharing has to be
  // preserved by the clone. A cloned type is reusable only if this exact
  // (type, qconfig) combination has already been cloned.
  bool type_already_cloned = false;
  const auto by_type = type_remap.find(type);
  if (by_type != type_remap.end()) {
    type_already_cloned =
        by_type->second.find(qconfig) != by_type->second.end();
  }

  Module r;
  if (!type_already_cloned) {
    // First time we see this combination: create a new module (and with it a
    // new ClassType) and remember the type for later instances.
    r = Module(*type->name(), module._ivalue()->compilation_unit(), true);
    type_remap[type][qconfig] = r.type();
  } else {
    // Reuse the ClassType cloned earlier.
    r = Module(
        module._ivalue()->compilation_unit(),
        type_remap.at(type).at(qconfig)->cast<ClassType>());
  }

  // Copy slots. Module-valued slots are cloned recursively.
  const size_t num_attrs = type->numAttributes();
  for (const auto slot : c10::irange(num_attrs)) {
    IValue slot_value = module._ivalue()->getSlot(slot);
    std::string attr_name = type->getAttributeName(slot);
    TypePtr attr_type = type->getAttribute(slot);

    if (!attr_type->is_module()) {
      // Plain attribute: shared when inplace, deep-copied otherwise.
      r.register_attribute(
          attr_name,
          attr_type,
          inplace ? slot_value : slot_value.deepcopy(memo),
          type->is_parameter(slot),
          type->is_buffer(slot));
      continue;
    }

    const Module child(slot_value.toObject());
    Module child_clone =
        clone_impl(child, module_qconfig_map, type_remap, inplace, memo);
    // NOTE: the attribute is set on the object by hand rather than through
    // register_module because the slot may be declared with a module
    // interface type while still holding a Module object; register_module
    // would not set up the slot type correctly in that case. An interface
    // type is kept as-is and shared by the cloned instance in the same
    // compilation unit, since it only contains a list of function schemas.
    TypePtr slot_type = attr_type;
    if (attr_type->cast<ClassType>()) {
      slot_type = child_clone.type();
    }
    r.type()->addOrCheckAttribute(attr_name, slot_type);
    r._ivalue()->setAttr(attr_name, child_clone._ivalue());
  }

  // Constants and methods belong to the ClassType, so they are only copied
  // when the type itself was newly created above.
  if (type_already_cloned) {
    return r;
  }

  for (const auto c : c10::irange(type->numConstants())) {
    r.type()->addConstant(type->getConstantName(c), type->getConstant(c));
  }
  // Clone methods, remapping module types to their cloned counterparts.
  for (Function* method : type->methods()) {
    clone_method(module, r, *method, module_qconfig_map, type_remap);
  }
  // Execute __setstate__(__getstate__()) to initialize custom class members.
  if (auto setstate = r.find_method("__setstate__")) {
    auto getstate = r.find_method("__getstate__");
    TORCH_INTERNAL_ASSERT(getstate, "expect __getstate__");
    auto state = (*getstate)(Stack{});
    (*setstate)(Stack{state});
  }
  return r;
}
// Rewrites, inside `block`, the type of every module instance that is the
// receiver of a prim::CallMethod or prim::GetAttr node, using `type_remap_fn`
// together with the qconfig recorded for the invoked module. Nested blocks
// and graph-valued node attributes are processed recursively.
//
// `self` is the value representing the owning module inside this graph; its
// own type is remapped by the caller, not here. Passing a module as any other
// argument of the method is rejected, because deciding the QConfig for such a
// module would need a more comprehensive analysis. `target` is not used
// directly; it is only forwarded to the recursive calls.
void remapTypes(
    Block* block,
    Value* self,
    const Module& source,
    Module& target,
    const ModuleQConfigMap& module_qconfig_map,
    const std::function<TypePtr(TypePtr, c10::optional<QConfig>)>&
        type_remap_fn) {
  // Index 0 is %self, so the check starts at 1.
  const auto block_inputs = block->inputs();
  for (size_t idx = 1; idx < block_inputs.size(); ++idx) {
    TORCH_CHECK(
        !block_inputs[idx]->type()->cast<ClassType>(),
        "We don't support quantizing methods that has Object as arguments");
  }

  for (Node* node : block->nodes()) {
    // Remap the type of the module instance the node operates on.
    const bool touches_module_instance =
        node->kind() == prim::CallMethod || node->kind() == prim::GetAttr;
    if (touches_module_instance) {
      auto invoked = getInvokedModuleOpt(source, node, self);
      if (invoked.has_value()) {
        Value* instance = node->inputs()[0];
        auto invoked_qconfig = module_qconfig_map.at(invoked->_ivalue());
        instance->setType(type_remap_fn(instance->type(), invoked_qconfig));
      }
    }

    // Outputs are not remapped: module types are fixed up at the CallMethod
    // that consumes them, and modules returned from methods or functions are
    // not supported.
    for (Block* nested : node->blocks()) {
      remapTypes(
          nested, self, source, target, module_qconfig_map, type_remap_fn);
    }

    for (Symbol attr_sym : node->attributeNames()) {
      switch (node->kindOf(attr_sym)) {
        case AttributeKind::g:
          remapTypes(
              node->g(attr_sym).get(),
              source,
              target,
              module_qconfig_map,
              type_remap_fn);
          break;
        case AttributeKind::gs:
          for (const auto& subgraph : node->gs(attr_sym)) {
            remapTypes(
                subgraph.get(),
                source,
                target,
                module_qconfig_map,
                type_remap_fn);
          }
          break;
        default:
          break;
      }
    }
  }
}
// Graph-level entry point for type remapping: runs the block-level overload
// on the graph's top-level block, treating the graph's first input as %self.
void remapTypes(
    Graph* graph,
    const Module& source,
    Module& target,
    const ModuleQConfigMap& module_qconfig_map,
    const std::function<TypePtr(TypePtr, c10::optional<QConfig>)>&
        type_remap_fn) {
  Block* const top_block = graph->block();
  Value* const self_value = graph->inputs()[0];
  remapTypes(
      top_block, self_value, source, target, module_qconfig_map, type_remap_fn);
}
// Copies `method` (a method of `source`) onto `target`'s type.
//
// The method graph is duplicated, module types inside it are remapped through
// `type_remap`, %self is retyped to `target`'s type, and the resulting
// function is registered in `target`'s compilation unit under
// <target type name>.<method name> with a correspondingly remapped schema.
void clone_method(
    const Module& source,
    Module& target,
    const Function& method,
    const ModuleQConfigMap& module_qconfig_map,
    const std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap) {
  // Looks up the cloned type for (type, qconfig); any type that has no entry
  // is returned unchanged.
  auto type_remap_fn = [&](TypePtr type_ptr,
                           const c10::optional<QConfig>& qconfig) -> TypePtr {
    const auto by_type = type_remap.find(type_ptr);
    if (by_type == type_remap.end()) {
      return type_ptr;
    }
    const auto by_qconfig = by_type->second.find(qconfig);
    if (by_qconfig == by_type->second.end()) {
      return type_ptr;
    }
    return by_qconfig->second;
  };

  auto graph_copy = toGraphFunction(method).graph()->copy();
  remapTypes(
      graph_copy.get(), source, target, module_qconfig_map, type_remap_fn);
  // remap self
  graph_copy->inputs()[0]->setType(target.type());

  // Only %self is supported as a Module among the function's arguments, so
  // every schema type is remapped with the qconfig of `source`.
  auto schema_type_remap_fn = [&](TypePtr type_ptr) {
    return type_remap_fn(type_ptr, module_qconfig_map.at(source._ivalue()));
  };
  auto remapped_schema =
      method.getSchema().cloneWithRemappedTypes(schema_type_remap_fn);

  const auto qualified_method_name =
      c10::QualifiedName(*target.type()->name(), method.name());
  auto new_fn = target._ivalue()->compilation_unit()->create_function(
      qualified_method_name, graph_copy);
  target.type()->addMethod(new_fn);
  new_fn->setSchema(std::move(remapped_schema));
}
};
class InsertObserversHelper {
public:
// Creates a helper for one observer-insertion pass.
// `map` gives the qconfig for each module instance; it is stored in the
// reference member module_qconfig_map_ (not copied), so it has to stay alive
// for as long as this helper is used.
// `quant_type` is the type of quantization this pass is run for.
explicit InsertObserversHelper(
const ModuleQConfigMap& map,
QuantType quant_type)
: module_qconfig_map_(map), quant_type_(quant_type) {}
// TODO: replace (module, method_name) with graph?
// preprocess to clean up the graph from tracing
void preprocess(Module& module, const std::string& method_name);
// Fill the map between the caller input/output to input/output
// of called graph, this is used to navigate through the graph
// to find the observer for a given value
void fillBoundaryValueMap(Module& module, const std::string& method_name);
// analyze the graph and record necessary information that can
// be used in insert observers
void analyze(Module& module, const std::string& method_name);
void removeActivationObservers();
/**
* Recursively insert observers for the method, also we'll process
* the nodes in the graph in the order of execution of these nodes
* since we need the context information to decide whether we want to
* observe/quantize a value or not, we don't want to observe a value multiple
* times.
*
* argument: is_entry_point means whether the current method is the forward
* method of the top level module.
*
* Since we want to insert observers in the call site instead of in the called
* graph, we'll postpone inserting observer to caller as much as possible, if
* we know the current method is the outer most method, then
* we will insert all observers in the graph instead of postpone this to the
* parent, note that this assumes we don't have recursive method
* calls
*
* returns a tuple of vectors of observer modules for input and output, these
* are used for inserting observers for the input/output values
* since we need to insert these values at call site.
* And a vector of indexes of outputs that indicates whether the output value
* is already observed or not, this is used for propagating the observed
* property of a value through CallMethods, because we should skip inserting
* observers for ops that don't require observation
*/
std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
insertObservers(
Module& module,
const std::string& method_name,
bool is_entry_point = false,
std::unordered_set<Value*> graph_observed_values =
std::unordered_set<Value*>());
// Records whether a reset-observers method should be inserted, and sets the
// name used for that method to "reset_observers_" followed by `method_name`.
void setInsertResetObserverMethod(
    bool insert_reset_observer_method,
    const std::string& method_name) {
  insert_reset_observer_method_ = insert_reset_observer_method;
  std::string reset_name("reset_observers_");
  reset_name += method_name;
  reset_observer_method_name_ = std::move(reset_name);
}
private:
std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
insertObserversFor(
Block* block,
script::Module& module,
// this is a reference because when we insert observer for a value
// in one block it is also observed in another block, we don't want to
// insert multiple observers for the same value
std::unordered_set<Value*>& block_observed_values,
bool is_entry_point = false,
bool is_user_defined_function = false);
// Record v as "ready for observation" by storing it in values_to_observe.
// If v is a part of a delayed observation pattern, record v's descendant
// (per delay rules) instead. The observers are inserted at a later stage
// by reading the state created by this function.
void recordObserved(
Value* v,
const Module& observer_module,
std::unordered_map<Value*, Module>& values_to_observe,
std::unordered_set<Value*>& block_observed_values);
ModuleMethodVector getInvokedMethods(
Module& module,
const std::string& method_name);
bool valueNeedsToBeQuantized(Value* v, const QConfig& qconfig);
// True when `v` is marked as observed, either in the caller-supplied
// `block_observed_values` set or in this helper's own observed_values_ set.
bool isObserved(
    Value* v,
    const std::unordered_set<Value*>& block_observed_values) {
  if (block_observed_values.find(v) != block_observed_values.end()) {
    return true;
  }
  return observed_values_.find(v) != observed_values_.end();
}
// Fill the map from value to the corresponding observer module
// this map is used in insertObservers to actually insert
// observers to the module
void fillValueObserverMap(Module& module, const std::string& method_name);
// Clone observer module and add it to the original module,
// and insert a call to observer forward function
void insertObserverFor(
Value* v,
Module& module,
const Module& observer_module,
NameModuleVector& observer_name_and_modules);
void insertObserverResetMinMax(
Module& module,
const NameModuleVector& observer_name_and_modules);
// Uses the state created by fillBoundaryValueMap and fillValueObserverMap
// to return an observer configured for a value, if it is needed.
c10::optional<Module> getObserverFor(Value* v);
// Uses the state created by fillPassThroughValueMap to propagate the observed
// property, which should pass through from inputs to outputs.
void propagateObservedProperty(
Value* output,
std::unordered_set<Value*>& block_observed_values);
// Decides whether the outputs of node `n` should be observed.
//
// For ops such as cat/add/mul the output is only observed when the inputs
// are observed. The checks run in this order:
//   1. any use of any output is quantizable (e.g. cat followed by a linear
//      op)                                        -> observe
//   2. `n` is a single-input propagate-quant op   -> observe iff input 0 is
//      observed
//   3. `n` is a binary propagate-quant op         -> observe iff inputs 0 and
//      1 are both observed
//   4. anything else                              -> observe
bool shouldObserve(
    Node* n,
    const std::unordered_set<Value*>& block_observed_values,
    QuantType quant_type) {
  for (Value* out : n->outputs()) {
    for (const auto& use : out->uses()) {
      if (useQuantizable(use, quant_type)) {
        return true;
      }
    }
  }

  const auto input_observed = [&](size_t idx) {
    return isObserved(n->input(idx), block_observed_values);
  };

  if (isPropagateQuantSingleInputOp(n)) {
    return input_observed(0);
  }
  if (isPropagateQuantBinaryOp(n)) {
    // Both inputs have to be tensors that are observed. One check is
    // deliberately left out here: that input 1 is not a scalar. A scalar
    // tensor input of add/mul is not observed under the current rules, so
    // requiring input 1 to be observed already covers that case.
    return input_observed(0) && input_observed(1);
  }
  return true;
}
void delayObservingValuesInPattern(Graph& graph, const PatternInfo& pattern);
// Find and mark known patterns such as conv-relu (and others) where
// we should not insert observers in the middle of the pattern.
void addValuesToDelayObservation(
const Module& module,
const std::string& method_name);
// Fill the map from values to the list of values that can pass the observed
// property to it
void fillPassThroughValueMap(const std::shared_ptr<Graph>& graph);
// Returns the flag stored by setInsertResetObserverMethod (false by default),
// i.e. whether a reset-observers method was requested.
bool insertResetObserverMethod() {
return insert_reset_observer_method_;
}
const ModuleQConfigMap& module_qconfig_map_;
// Values we want to delay observation, used to delay the observation for
// values in the middle of the ops that are supposed to be fused, e.g.
// the output value of conv in the conv - relu pattern
// the key is the intermediate output, e.g. output of conv
// the value is the value we want to observe, e.g. output of relu
//
// example, assuming we want to delay conv-relu:
// %x1 = conv(%x0)
// %x2 = relu(%x1)
//
// delay_observation_map_ = {
// %x1: %x2,
// }
std::unordered_map<Value*, Value*> delay_observation_map_;
std::unordered_set<Graph*> visited_graph_of_observer_map_;
// Map of value to observer module configured for that value.
std::unordered_map<Value*, Module> observer_for_value_;
// Map from values from callsite into the values in the CallMethod graph
// key of the map is the value from caller graph, and the value of the map
// is the list of values in the callee graph (the graph
// corresponding to the called method),
// the reason it is a set is that a value in the caller graph
// can both correspond to the output of one callee graph and input of another
// callee graph.
//
// example:
// // top level module
// %x1 = conv(%x0)
// %x2 = prim::CallFunction(%foo, %x1)
//
// // graph of %foo
// %y2 = conv(%y1)
// return %y2
//
// boundary_value_map = {
// // current module's output values to corresponding return values from
// // subgraph
// %x2: %y2,
// // current module's input values to corresponding input value to subgraph
// %x1: %y1,
// }
std::unordered_map<Value*, std::unordered_set<Value*>> boundary_value_map_;
std::unordered_set<Value*> observed_values_;
// This is used for the observed values to pass through the ops like flatten,
// so that output value of flatten does not need to be observed
// key is the output of the op, value is a vector of values that need
// to be observed in order to pass the observed property to the output
//
// example:
// %x1 = flatten(%x0) // pass_through
// %x2 = conv(%x1) // not pass_through
//
// pass_through_value_map_ = {
// %x1: [%x0],
// }
std::unordered_map<Value*, std::vector<Value*>> pass_through_value_map_;
// Unique id generator for observer module, used for generating
// unique observer names when we insert observer module, we
// record the current unique id used to avoid incrementing from 0
// every time to find a unique id.
int uid_ = 0;
// Set of observer forward call nodes
std::unordered_set<Node*> observer_nodes_;
// Map from block to a vector of observer name and observer modules we
// want to add to the module instance that has the block
std::unordered_map<Block*, NameModuleVector> block_observer_map_;
// Type of quantization for this pass.
QuantType quant_type_ = QuantType::STATIC;
// These are the IR patterns we match to skip inserting observers.
// They are compiled once on construction and used repeatedly within
// the pass.
// nn.Linear + nn.ReLU
const PatternInfo nn_linear_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%input, %linear, %relu):
%first_output = prim::CallMethod[name="forward"](%linear, %input)
%second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
return (%second_output) )",
{is_linear_module, is_relu_module});
// nn.Linear + F.relu
const PatternInfo nn_linear_f_relu = PatternInfo::parse_from_str(
R"(
graph(%input, %linear, %relu, %inplace):
%first_output = prim::CallMethod[name="forward"](%linear, %input)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{is_linear_module, is_functional_relu});
// nn.Linear + aten::relu
const PatternInfo nn_linear_aten_relu = PatternInfo::parse_from_str(
R"(
graph(%input, %linear, %relu):
%first_output = prim::CallMethod[name="forward"](%linear, %input)
%second_output = aten::relu(%first_output)
return (%second_output) )",
{is_linear_module});
// nn.Linear + aten::relu_
const PatternInfo nn_linear_aten_relu_ = PatternInfo::parse_from_str(
R"(
graph(%input, %linear, %relu):
%first_output = prim::CallMethod[name="forward"](%linear, %input)
%second_output = aten::relu_(%first_output)
return (%second_output) )",
{is_linear_module});
// aten::linear + nn.ReLU
const PatternInfo aten_linear_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%input, %weight, %bias, %relu):
%first_output = aten::linear(%input, %weight, %bias)
%second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
return (%second_output) )",
{is_relu_module});
// aten::linear + F.relu
const PatternInfo aten_linear_f_relu = PatternInfo::parse_from_str(
R"(
graph(%input, %weight, %bias, %relu, %inplace):
%first_output = aten::linear(%input, %weight, %bias)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{is_functional_relu});
// aten::linear + aten::relu
const PatternInfo aten_linear_aten_relu = PatternInfo::parse_from_str(
R"(
graph(%input, %weight, %bias):
%first_output = aten::linear(%input, %weight, %bias)
%second_output = aten::relu(%first_output)
return (%second_output) )");
// aten::linear + aten::relu_
const PatternInfo aten_linear_aten_relu_ = PatternInfo::parse_from_str(
R"(
graph(%input, %weight, %bias):
%first_output = aten::linear(%input, %weight, %bias)
%second_output = aten::relu_(%first_output)
return (%second_output) )");
const PatternInfo nn_conv1d_f_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv, %relu, %inplace):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{is_conv1d_module, is_functional_relu});
const PatternInfo nn_conv1d_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv, %relu):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
return (%second_output) )",
{is_conv1d_module, is_relu_module});
const PatternInfo nn_conv1d_aten_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = aten::relu(%first_output)
return (%second_output) )",
{is_conv1d_module});
const PatternInfo nn_conv1d_aten_relu_ = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = aten::relu_(%first_output)
return (%second_output) )",
{is_conv1d_module});
const PatternInfo nn_conv2d_f_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv, %relu, %inplace):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{is_conv2d_module, is_functional_relu});
const PatternInfo nn_conv2d_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv, %relu):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
return (%second_output) )",
{is_conv2d_module, is_relu_module});
const PatternInfo nn_conv2d_aten_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = aten::relu(%first_output)
return (%second_output) )",
{is_conv2d_module});
const PatternInfo nn_conv2d_aten_relu_ = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = aten::relu_(%first_output)
return (%second_output) )",
{is_conv2d_module});
const PatternInfo nn_conv3d_f_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv, %relu, %inplace):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{is_conv3d_module, is_functional_relu});
const PatternInfo nn_conv3d_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv, %relu):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
return (%second_output) )",
{is_conv3d_module, is_relu_module});
const PatternInfo nn_conv3d_aten_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %conv, %input):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = aten::relu(%first_output)
return (%second_output) )",
{is_conv3d_module});
const PatternInfo nn_conv3d_aten_relu_ = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %conv):
%first_output = prim::CallMethod[name="forward"](%conv, %input)
%second_output = aten::relu_(%first_output)
return (%second_output) )",
{is_conv3d_module});
const PatternInfo add_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %a, %b, %alpha, %relu):
%first_output = aten::add(%a, %b, %alpha)
%second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
return (%second_output) )",
{aten_add_alpha_is_one, is_relu_module});
const PatternInfo add_f_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %a, %b, %alpha, %relu, %inplace):
%first_output = aten::add(%a, %b, %alpha)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{aten_add_alpha_is_one, is_functional_relu});
const PatternInfo inplace_add_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %a, %b, %alpha, %relu):
%first_output = aten::add_(%a, %b, %alpha)
%second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
return (%second_output) )",
{aten_add_alpha_is_one, is_relu_module});
const PatternInfo inplace_add_f_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %a, %b, %alpha, %relu, %inplace):
%first_output = aten::add_(%a, %b, %alpha)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{aten_add_alpha_is_one, is_functional_relu});
const PatternInfo add_aten_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b, %alpha):
%first_output = aten::add(%a, %b, %alpha)
%second_output = aten::relu(%first_output)
return (%second_output) )");
const PatternInfo add_aten_relu_ = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b, %alpha):
%first_output = aten::add(%a, %b, %alpha)
%second_output = aten::relu_(%first_output)
return (%second_output) )");
const PatternInfo inplace_add_aten_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b, %alpha):
%first_output = aten::add_(%a, %b, %alpha)
%second_output = aten::relu(%first_output)
return (%second_output) )");
const PatternInfo inplace_add_aten_relu_ = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b, %alpha):
%first_output = aten::add_(%a, %b, %alpha)
%second_output = aten::relu_(%first_output)
return (%second_output) )");
const PatternInfo nn_bn2d_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %batchnorm, %relu):
%first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
%second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
return (%second_output) )",
{is_batchnorm2d_module, is_relu_module});
const PatternInfo nn_bn2d_f_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %batchnorm, %relu, %inplace):
%first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{is_batchnorm2d_module, is_functional_relu});
const PatternInfo nn_bn2d_aten_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %batchnorm):
%first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
%second_output = aten::relu(%first_output)
return (%second_output) )",
{is_batchnorm2d_module});
const PatternInfo nn_bn2d_aten_relu_ = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %batchnorm):
%first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
%second_output = aten::relu_(%first_output)
return (%second_output) )",
{is_batchnorm2d_module});
const PatternInfo nn_bn3d_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %batchnorm, %relu):
%first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
%second_output = prim::CallMethod[name="forward\\d*"](%relu, %first_output)
return (%second_output) )",
{is_batchnorm3d_module, is_relu_module});
const PatternInfo nn_bn3d_f_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %batchnorm, %relu, %inplace):
%first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{is_batchnorm3d_module, is_functional_relu});
const PatternInfo nn_bn3d_aten_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %batchnorm):
%first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
%second_output = aten::relu(%first_output)
return (%second_output) )",
{is_batchnorm3d_module});
const PatternInfo nn_bn3d_aten_relu_ = PatternInfo::parse_from_str(
R"(
graph(%self, %input, %batchnorm):
%first_output = prim::CallMethod[name="forward"](%batchnorm, %input)
%second_output = aten::relu_(%first_output)
return (%second_output) )",
{is_batchnorm3d_module});
const PatternInfo mul_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %a, %b, %relu):
%first_output = aten::mul(%a, %b)
%second_output = prim::CallMethod[name="forward"](%relu, %first_output)
return (%second_output) )",
{is_relu_module});
const PatternInfo mul_f_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %a, %b, %relu, %inplace):
%first_output = aten::mul(%a, %b)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{is_functional_relu});
const PatternInfo inplace_mul_nn_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %a, %b, %relu):
%first_output = aten::mul_(%a, %b)
%second_output = prim::CallMethod[name="forward"](%relu, %first_output)
return (%second_output) )",
{is_relu_module});
const PatternInfo inplace_mul_f_relu = PatternInfo::parse_from_str(
R"(
graph(%self, %a, %b, %relu, %inplace):
%first_output = aten::mul_(%a, %b)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )",
{is_functional_relu});
const PatternInfo mul_aten_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
%first_output = aten::mul(%a, %b)
%second_output = aten::relu(%first_output)
return (%second_output) )");
const PatternInfo mul_aten_relu_ = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
%first_output = aten::mul(%a, %b)
%second_output = aten::relu_(%first_output)
return (%second_output) )");
const PatternInfo inplace_mul_aten_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
%first_output = aten::mul_(%a, %b)
%second_output = aten::relu(%first_output)
return (%second_output) )");
const PatternInfo inplace_mul_aten_relu_ = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
%first_output = aten::mul_(%a, %b)
%second_output = aten::relu_(%first_output)
return (%second_output) )");
const std::vector<std::reference_wrapper<const PatternInfo>> delay_patterns =
{
nn_linear_f_relu, nn_linear_nn_relu,
nn_linear_aten_relu, nn_linear_aten_relu_,
aten_linear_f_relu, aten_linear_nn_relu,
aten_linear_aten_relu, aten_linear_aten_relu_,
nn_conv1d_f_relu, nn_conv1d_nn_relu,
nn_conv1d_aten_relu, nn_conv1d_aten_relu_,
nn_conv2d_f_relu, nn_conv2d_nn_relu,
nn_conv2d_aten_relu, nn_conv2d_aten_relu_,
nn_conv3d_f_relu, nn_conv3d_nn_relu,
nn_conv3d_aten_relu, nn_conv3d_aten_relu_,
add_nn_relu, add_f_relu,
inplace_add_nn_relu, inplace_add_f_relu,
add_aten_relu, add_aten_relu_,
inplace_add_aten_relu, inplace_add_aten_relu_,
nn_bn2d_nn_relu, nn_bn2d_f_relu,
nn_bn2d_aten_relu, nn_bn2d_aten_relu_,
nn_bn3d_nn_relu, nn_bn3d_f_relu,
nn_bn3d_aten_relu, nn_bn3d_aten_relu_,
mul_nn_relu, mul_f_relu,
inplace_mul_nn_relu, inplace_mul_f_relu,
mul_aten_relu, mul_aten_relu_,
inplace_mul_aten_relu, inplace_mul_aten_relu_,
};
// Defaults to false.
// NOTE(review): presumably toggled through setInsertResetObserverMethod and
// read by insertResetObserverMethod(); those accessors are not visible in
// this part of the file -- confirm.
bool insert_reset_observer_method_{false};
// Name of the module method that insertObserverResetMinMax looks up (and
// creates when missing) to hold the calls to each observer's
// reset_min_max_vals.
std::string reset_observer_method_name_;
};
// Returns one (module, method name) entry for every prim::CallMethod node in
// the graph of `method_name` whose callee module can be resolved, walking
// nested blocks as well. Nodes recorded in observer_nodes_ are ignored.
ModuleMethodVector InsertObserversHelper::getInvokedMethods(
    Module& module,
    const std::string& method_name) {
  auto graph = module.get_method(method_name).graph();
  ModuleMethodVector collected;

  std::stack<Block*> pending;
  pending.push(graph->block());
  while (!pending.empty()) {
    Block* current = pending.top();
    pending.pop();
    for (Node* node : current->nodes()) {
      // Calls that were inserted for observers are not user method calls.
      if (observer_nodes_.count(node) != 0) {
        continue;
      }
      if (node->kind() == prim::CallMethod) {
        auto callee = getInvokedModuleOpt(module, node, graph->inputs()[0]);
        if (callee.has_value()) {
          collected.emplace_back(*callee, node->s(attr::name));
        }
      }
      // Nested blocks are visited later through the explicit stack.
      for (Block* nested : node->blocks()) {
        pending.push(nested);
      }
    }
  }
  return collected;
}
// Registers a deep copy of `observer_module` on `module` under a fresh
// "_observer_<uid>" attribute name and rewires the graph so that every use of
// `v` goes through the observer's forward call. The (name, observer) pair is
// appended to `observer_name_and_modules`. Does nothing when `v` is already in
// observed_values_.
void InsertObserversHelper::insertObserverFor(
    Value* v,
    Module& module,
    const Module& observer_module,
    NameModuleVector& observer_name_and_modules) {
  if (observed_values_.count(v) != 0) {
    return;
  }
  GRAPH_DEBUG("Inserting observer for:", v->debugName());
  Module observer = observer_module.deepcopy();

  // Keep drawing uids until the attribute name is unused on the module.
  std::string observer_name;
  do {
    observer_name = "_observer_" + c10::to_string(uid_++);
  } while (module.hasattr(observer_name));

  module.register_module(observer_name, observer);
  observer_name_and_modules.emplace_back(observer_name, observer);

  auto* owning_graph = v->owningGraph();
  // GetAttr node that fetches the observer submodule from `self`, placed
  // right after the node producing `v`.
  Node* get_observer =
      owning_graph->createGetAttr(owning_graph->inputs()[0], observer_name)
          ->insertAfter(v->node());
  get_observer->output()->setDebugName(observer_name);

  // Everything below is inserted right after the GetAttr node.
  WithInsertPoint guard(get_observer->next());
  // Match (observer, v) against the schema of the observer's forward.
  MatchedSchema forward_schema = matchSchema(
      observer.get_method("forward").function().getSchema(),
      v->node()->sourceRange(),
      *owning_graph,
      {get_observer->output(), v},
      {});
  Node* forward_call =
      owning_graph->insertMethodCall("forward", forward_schema)->node();
  Value* observed = forward_call->output();
  observed->copyMetadata(v);
  // Redirect all users of `v` to the observed value ...
  v->replaceAllUsesWith(observed);
  // ... which also rewrote the observer call's own input, so restore it.
  forward_call->replaceInput(1, v);
  observer_nodes_.emplace(forward_call);
  observed_values_.insert(observed);
}
// Appends, for every (name, observer) pair, a call to the observer's
// `reset_min_max_vals` method to the module method named
// reset_observer_method_name_. That method is created first (taking only
// `self` and returning None) when the module does not have it yet.
// No-op when `observer_name_and_modules` is empty.
void InsertObserversHelper::insertObserverResetMinMax(
    Module& module,
    const NameModuleVector& observer_name_and_modules) {
  if (observer_name_and_modules.empty()) {
    return;
  }

  const bool has_reset_method =
      module.find_method(reset_observer_method_name_).has_value();
  if (!has_reset_method) {
    // Build a graph of the form `graph(%self): return None`.
    std::shared_ptr<Graph> new_graph = std::make_shared<Graph>();
    Value* self_input = new_graph->addInput("self");
    Node* none_node = new_graph->createNone();
    new_graph->insertNode(none_node);
    new_graph->registerOutput(none_node->output());
    self_input->setType(module._ivalue()->type());

    // Register the graph as a function in the module's compilation unit and
    // attach it to the module type as a method.
    const auto qualified_name = c10::QualifiedName(
        *(module.type()->name()), reset_observer_method_name_);
    auto new_fn = module._ivalue()->compilation_unit()->create_function(
        qualified_name, new_graph);
    auto self_arg = c10::Argument("self", module.type());
    auto output_arg = c10::Argument("none", none_node->output()->type());
    auto schema = c10::FunctionSchema(
        reset_observer_method_name_, "", {self_arg}, {output_arg});
    new_fn->setSchema(std::move(schema));
    module.type()->addMethod(new_fn);
  }

  auto reset_graph = module.get_method(reset_observer_method_name_).graph();
  Value* self = reset_graph->inputs()[0];
  for (const auto& entry : observer_name_and_modules) {
    const std::string& attr_name = entry.first;
    const Module& observer = entry.second;
    // self.<attr_name>.reset_min_max_vals()
    Value* observer_value = reset_graph->insertGetAttr(self, attr_name);
    MatchedSchema reset_schema = matchSchema(
        observer.get_method("reset_min_max_vals").function().getSchema(),
        observer_value->node()->sourceRange(),
        *reset_graph,
        {observer_value},
        {});
    reset_graph->insertMethodCall("reset_min_max_vals", reset_schema);
  }
}
// Finds every occurrence of `pattern` in `graph` that passes all of the
// pattern's filters and records, in delay_observation_map_, that the value
// matched by %first_output should be observed at the value matched by
// %second_output instead.
void InsertObserversHelper::delayObservingValuesInPattern(
    Graph& graph,
    const PatternInfo& pattern) {
  const Graph& pattern_graph = *pattern.pattern_graph;
  const std::unordered_map<std::string, Value*>& vmap = pattern.vmap;

  const auto& matches = findPatternMatches(pattern_graph, graph);
  for (const auto& match : matches) {
    const auto accepts = [&](const MatchFilter& filter) {
      return filter(match, vmap);
    };
    const bool passes_filters =
        std::all_of(pattern.filters.begin(), pattern.filters.end(), accepts);
    if (passes_filters) {
      auto delayed_from = match.values_map.at(vmap.at("first_output"));
      auto delayed_to = match.values_map.at(vmap.at("second_output"));
      GRAPH_DEBUG(
          "Delay observation for value in function pattern:",
          delayed_from->debugName(),
          " to ",
          delayed_to->debugName());
      delay_observation_map_[delayed_from] = delayed_to;
    }
  }
}
void InsertObserversHelper::addValuesToDelayObservation(
const Module& module,
const std::string& method_name) {
Method method = module.get_method(method_name);
auto graph = method.graph();
for (const auto& pattern : delay_patterns) {
delayObservingValuesInPattern(*graph, pattern);
}
}
void InsertObserversHelper::fillPassThroughValueMap(
const std::shared_ptr<Graph>& graph) {
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
if (userDefinedCallFunction(n)) {
auto g = getCallFunctionGraph(n);
blocks_to_visit.push(g->block());
}
for (auto* output : n->outputs()) {
for (auto* input : getPassThroughInputs(output)) {
pass_through_value_map_[output].push_back(input);
}
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
}
// Populates boundary_value_map_ for `method_name` and, recursively, for every
// method it invokes. The map links
//   - each output of a CallMethod / user defined CallFunction node to the
//     corresponding output of the called graph,
//   - each argument of such a node to the corresponding input of the called
//     graph, and
//   - each output of a prim::If node to the matching output of each subblock.
void InsertObserversHelper::fillBoundaryValueMap(
    Module& module,
    const std::string& method_name) {
  // Handle invoked methods first.
  for (auto& callee : getInvokedMethods(module, method_name)) {
    fillBoundaryValueMap(std::get<0>(callee), std::get<1>(callee));
  }

  auto graph = module.get_method(method_name).graph();
  auto* self = graph->inputs()[0];
  std::stack<Block*> pending;
  pending.push(graph->block());
  while (!pending.empty()) {
    Block* current = pending.top();
    pending.pop();
    for (Node* node : current->nodes()) {
      const bool is_method_call = node->kind() == prim::CallMethod;
      if (is_method_call || userDefinedCallFunction(node)) {
        std::shared_ptr<Graph> callee_graph;
        // Offset of the call-site inputs relative to the callee graph
        // inputs: the first input of CallFunction is the function itself,
        // while the callee graph starts with the actual arguments.
        size_t input_offset = 0;
        if (is_method_call) {
          auto callee_module = getInvokedModuleOpt(module, node, self);
          if (!callee_module.has_value()) {
            continue;
          }
          auto resolved = *callee_module;
          callee_graph = resolved.get_method(node->s(attr::name)).graph();
        } else {
          callee_graph = getCallFunctionGraph(node);
          input_offset = 1;
        }

        // call-site output -> callee graph output
        for (size_t idx = 0; idx < callee_graph->outputs().size(); ++idx) {
          auto* callee_output = callee_graph->outputs()[idx];
          GRAPH_DEBUG(
              "Boundary Map[return]:",
              node->output(idx)->debugName(),
              " -> ",
              callee_output->debugName());
          boundary_value_map_[node->output(idx)].insert(callee_output);
        }
        // call-site argument -> callee graph input
        for (size_t idx = 0; idx < callee_graph->inputs().size(); ++idx) {
          auto* call_arg = node->input(idx + input_offset);
          auto* callee_input = callee_graph->inputs()[idx];
          GRAPH_DEBUG(
              "Boundary Map[input]:",
              call_arg->debugName(),
              " -> ",
              callee_input->debugName());
          boundary_value_map_[call_arg].insert(callee_input);
        }
      } else if (node->kind() == prim::If) {
        for (Block* branch : node->blocks()) {
          pending.push(branch);
          // if output -> output of this branch at the same position
          for (Value* if_output : node->outputs()) {
            Value* branch_output = branch->outputs()[if_output->offset()];
            GRAPH_DEBUG(
                "Boundary Map[if_output]:",
                if_output->debugName(),
                " -> ",
                branch_output->debugName());
            boundary_value_map_[if_output].insert(branch_output);
          }
        }
      } else {
        for (Block* nested : node->blocks()) {
          pending.push(nested);
        }
      }
    }
  }
}
// Applies the graph clean-up passes needed before analysis to `method_name`
// and to every method it invokes.
void InsertObserversHelper::preprocess(
    Module& module,
    const std::string& method_name) {
  // Children are preprocessed before the parent: preprocess mutates the
  // graph, which might affect passes like fillBoundaryValueMap.
  for (auto& callee : getInvokedMethods(module, method_name)) {
    preprocess(std::get<0>(callee), std::get<1>(callee));
  }

  auto graph = module.get_method(method_name).graph();
  // Inline fork-wait calls.
  InlineForkWait(graph);
  // Fuse decomposed linear into aten::linear.
  FuseLinear(graph);
  replaceConvolutionWithAtenConv(graph);
  RemoveListMutation(graph);
}
// Fills the internal state that insertObservers later relies on to pick the
// correct observer: delayed observations, the value -> observer map and the
// pass-through value map. Invoked methods are analyzed before the method
// itself.
void InsertObserversHelper::analyze(
    Module& module,
    const std::string& method_name) {
  for (auto& callee : getInvokedMethods(module, method_name)) {
    analyze(std::get<0>(callee), std::get<1>(callee));
  }

  addValuesToDelayObservation(module, method_name);
  fillValueObserverMap(module, method_name);

  auto method_graph = module.get_method(method_name).graph();
  fillPassThroughValueMap(method_graph);
}
// Decides whether `v` should get an observer under `qconfig`.
// Returns false for conv/linear biases, for values that are neither a Tensor
// nor a list of Tensors, and for values flagged by isEmbeddingBagNonInput.
// Otherwise returns true when the producer is quantizable (static
// quantization only) or when any use of `v` is quantizable.
bool InsertObserversHelper::valueNeedsToBeQuantized(
    Value* v,
    const QConfig& qconfig) {
  if (isBiasOfConvOrLinear(v)) {
    return false;
  }
  const bool is_tensor_or_tensor_list =
      v->type()->isSubtypeOf(*TensorType::get()) ||
      v->type()->isSubtypeOf(*ListType::ofTensors());
  if (!is_tensor_or_tensor_list) {
    return false;
  }
  if (isEmbeddingBagNonInput(v)) {
    return false;
  }

  // For dynamic quantization observers are only inserted at the input of the
  // quantizable function, so the producer check is static-only.
  if (quant_type_ == QuantType::STATIC) {
    Node* producer = v->node();
    if (!isWeightOnlyStaticQuantOp(producer) &&
        (nodeQuantizable(producer) || isPropagateQuantOp(producer))) {
      return true;
    }
  } else if (quant_type_ == QuantType::DYNAMIC) {
    // Look at the dtype of the observer module: Fp16 inputs that are not
    // weights are not observed for dynamic quantization.
    Module observer_module = getObserverModuleFor(v, qconfig);
    auto scalar_type = observer_module.attr("dtype");
    if (scalar_type == at::ScalarType::Half && !isWeight(v)) {
      return false;
    }
  }

  // Finally, check whether any consumer of the value is quantizable.
  const auto& uses = v->uses();
  return std::any_of(uses.begin(), uses.end(), [this](const Use& use) {
    return useQuantizable(use, quant_type_);
  });
}
// Drops every entry of observer_for_value_ whose key is not a weight
// (according to isWeight), leaving only weight observers.
//
// Entries are erased while iterating: unordered_map::erase(iterator) returns
// the iterator following the removed element and only invalidates iterators
// to the erased element, so no temporary list of iterators is needed.
void InsertObserversHelper::removeActivationObservers() {
  for (auto it = observer_for_value_.begin();
       it != observer_for_value_.end();) {
    if (isWeight(it->first)) {
      ++it;
    } else {
      it = observer_for_value_.erase(it);
    }
  }
}
// Records in observer_for_value_ the observer module for every graph input
// and node output of `method_name` that valueNeedsToBeQuantized accepts.
// Each graph is processed at most once; nothing is recorded when the module
// has no qconfig.
void InsertObserversHelper::fillValueObserverMap(
    Module& module,
    const std::string& method_name) {
  auto graph = module.get_method(method_name).graph();
  // insert() reports whether the graph was newly added.
  const bool first_visit =
      visited_graph_of_observer_map_.insert(graph.get()).second;
  if (!first_visit) {
    return;
  }

  auto qconfig_opt = module_qconfig_map_.at(module._ivalue());
  if (!qconfig_opt) {
    return;
  }
  auto qconfig = *qconfig_opt;

  // Shared by the graph-input and node-output passes below.
  const auto record_if_needed = [&](Value* candidate) {
    if (valueNeedsToBeQuantized(candidate, qconfig)) {
      GRAPH_DEBUG("Recording observer for ", candidate->debugName());
      GRAPH_DUMP("In graph:", candidate->owningGraph());
      observer_for_value_[candidate] = getObserverModuleFor(candidate, qconfig);
    }
  };

  for (auto* graph_input : graph->inputs()) {
    record_if_needed(graph_input);
  }

  std::stack<Block*> pending;
  pending.push(graph->block());
  while (!pending.empty()) {
    Block* current = pending.top();
    pending.pop();
    for (Node* node : current->nodes()) {
      for (Value* node_output : node->outputs()) {
        record_if_needed(node_output);
      }
      for (Block* nested : node->blocks()) {
        pending.push(nested);
      }
    }
  }
}
// Looks up the observer module configured for `v`. A direct entry in
// observer_for_value_ wins; otherwise the values linked to `v` through
// boundary_value_map_ are searched recursively. All observers found that way
// must compare equal, otherwise a TORCH_CHECK fails. Returns nullopt when no
// observer is configured.
c10::optional<Module> InsertObserversHelper::getObserverFor(Value* v) {
  const auto direct = observer_for_value_.find(v);
  if (direct != observer_for_value_.end()) {
    auto observer = direct->second;
    GRAPH_DEBUG("Got observer module config for:", v->debugName());
    return observer;
  }

  c10::optional<Module> found;
  const auto boundary = boundary_value_map_.find(v);
  if (boundary != boundary_value_map_.end()) {
    for (Value* linked : boundary->second) {
      GRAPH_DEBUG(
          "Going through boundary map:",
          v->debugName(),
          " --> ",
          linked->debugName());
      GRAPH_DUMP("From graph:", v->owningGraph());
      GRAPH_DUMP("To graph:", linked->owningGraph());
      auto candidate = getObserverFor(linked);
      if (!candidate) {
        continue;
      }
      if (!found) {
        found = candidate;
      } else {
        // Every linked value has to be configured with the same observer.
        TORCH_CHECK(
            *candidate == *found,
            "Expecting all values in the graph only configured with one observer");
      }
    }
  }
  GRAPH_DEBUG(
      "Observer module config for ", v->debugName(), ":", found.has_value());
  return found;
}
// Entry point for observer insertion: forwards the top-level block of
// `method_name`'s graph to insertObserversFor and returns its result.
// `graph_observed_values` is taken by value, so the caller's set is not
// modified.
std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
InsertObserversHelper::insertObservers(
    Module& module,
    const std::string& method_name,
    bool is_entry_point,
    std::unordered_set<Value*> graph_observed_values) {
  auto method_graph = module.get_method(method_name).graph();
  Block* top_block = method_graph->block();
  return insertObserversFor(
      top_block, module, graph_observed_values, is_entry_point);
}
// Schedules `observer_module` for `v`, or for the value `v` is delayed to
// according to delay_observation_map_, and marks that value as observed in
// `block_observed_values`.
void InsertObserversHelper::recordObserved(
    Value* v,
    const Module& observer_module,
    std::unordered_map<Value*, Module>& values_to_observe,
    std::unordered_set<Value*>& block_observed_values) {
  const auto delayed = delay_observation_map_.find(v);
  Value* target =
      (delayed == delay_observation_map_.end()) ? v : delayed->second;
  values_to_observe[target] = observer_module;
  block_observed_values.insert(target);
}
// Inserts observers for the values in `block` (recursing into called
// methods / functions and into `prim::If` subblocks) and reports to the call
// site what it still has to observe.
//
// Parameters:
//   block                    - block to process.
//   module                   - module owning the graph of `block`.
//   block_observed_values    - values already known to be observed; updated
//                              in place as values get observed.
//   is_entry_point           - true for the forward graph of the top level
//                              module, where observers may also be inserted
//                              for graph inputs/outputs.
//   is_user_defined_function - true when `block` belongs to a user defined
//                              function; inserting observers there is
//                              rejected with a TORCH_CHECK.
//
// Returns a tuple of
//   0: observer (or nullopt) for each block input,
//   1: observer (or nullopt) for each block output,
//   2: indices of the block outputs that are already observed.
// Items 0 and 1 are left empty when `is_entry_point` is true.
std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
InsertObserversHelper::insertObserversFor(
    Block* block,
    script::Module& module,
    std::unordered_set<Value*>& block_observed_values,
    bool is_entry_point,
    bool is_user_defined_function) {
  // input/output values, used to skip inserting observers
  // for input and output of the block and the owning graph,
  // we have to insert the observers at call site because
  // the graph itself can be shared
  std::unordered_set<Value*> inputs_outputs;
  // list of observer modules for input values
  std::vector<c10::optional<Module>> block_input_observers;
  // list of observer modules for output values
  std::vector<c10::optional<Module>> block_output_observers;

  // if the current block is the block for entry point graph(the forward graph
  // of the top level module), we can insert observers in the block directly
  if (!is_entry_point) {
    auto* graph = block->owningGraph();
    // graph inputs/outputs
    for (auto list : {graph->inputs(), graph->outputs()}) {
      for (auto* v : list) {
        inputs_outputs.insert(v);
      }
    }
    // block outputs
    for (auto* v : block->outputs()) {
      inputs_outputs.insert(v);
    }

    for (auto* v : block->inputs()) {
      block_input_observers.emplace_back(getObserverFor(v));
    }

    for (auto* v : block->outputs()) {
      // we need explicitly skip the values that are already observed
      // this might happen in subblocks for `if` since
      // these subblock has access to all values before the `if` node
      if (!isObserved(v, block_observed_values)) {
        block_output_observers.emplace_back(getObserverFor(v));
      } else {
        block_output_observers.emplace_back(c10::nullopt);
      }
    }
  }

  // This means the block is been processed before, we just
  // need to attach observer modules and construct the information
  // needed by call site here
  bool visited = block_observer_map_.count(block);
  if (visited) {
    // instance clone of observer module and setAttr
    for (const auto& observer_attrs : block_observer_map_.at(block)) {
      const auto& name = std::get<0>(observer_attrs);
      const auto& observer = std::get<1>(observer_attrs);
      module._ivalue()->setAttr(name, observer.deepcopy()._ivalue());
    }
  }
  // NB: Why do we need to process the graph even if it's visited?
  // Reason is `block_observed_values` can
  // change depending on where the method is called, and
  // outputs that's been observed(third item of the returned result)
  // can change depending on that, so for each graph we'll need to go through
  // the whole process of inserting observers, the observers inserted in this
  // block won't change, but the information we return to the caller will change
  // based on `block_observed_values`

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(block);
  auto* self = block->owningGraph()->inputs()[0];
  // We first construct a map from value to the module, then
  // insert observers for them later, this is to avoid interference
  // of the inserted observers with the analysis to decide where
  // to insert observers, also we only insert observers for
  // "intermediate values" that is not the input/output of the
  // graph
  std::unordered_map<Value*, Module> values_to_observe;

  for (auto* v : block->inputs()) {
    if (!inputs_outputs.count(v) && !values_to_observe.count(v)) {
      if (auto observer_opt = getObserverFor(v)) {
        recordObserved(
            v, *observer_opt, values_to_observe, block_observed_values);
      }
    }
  }
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      // Observer calls inserted earlier are never observed themselves.
      if (observer_nodes_.count(n)) {
        continue;
      }
      if (n->kind() == prim::CallMethod || userDefinedCallFunction(n)) {
        script::Module m;
        std::shared_ptr<Graph> g;
        // Offset between call-site inputs and callee graph inputs: 0 for
        // CallMethod, 1 for CallFunction. Assigned in both branches below.
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        size_t input_offset;
        bool is_udf_for_subblock = is_user_defined_function;
        if (n->kind() == prim::CallMethod) {
          auto m_opt = getInvokedModuleOpt(module, n, self);
          if (!m_opt.has_value()) {
            continue;
          }
          m = *m_opt;
          g = m.get_method(n->s(attr::name)).graph();
          input_offset = 0;
        } else { // CallFunction
          m = module;
          g = getCallFunctionGraph(n);
          input_offset = 1;
          is_udf_for_subblock = true;
        }

        // Translate "already observed at the call site" into the callee's
        // own input values.
        std::unordered_set<Value*> callee_observed_inputs;
        for (auto i = 0U; i < g->inputs().size(); ++i) {
          auto* node_input = n->input(i + input_offset);
          if (isObserved(node_input, block_observed_values)) {
            callee_observed_inputs.insert(g->inputs()[i]);
          }
        }
        auto* subblock = g->block();
        auto info_from_callee = insertObserversFor(
            subblock, m, callee_observed_inputs, false, is_udf_for_subblock);
        auto input_observers = std::get<0>(info_from_callee);
        auto output_observers = std::get<1>(info_from_callee);
        auto callee_observed_outputs = std::get<2>(info_from_callee);
        // Outputs the callee already observes are observed here as well.
        for (auto idx : callee_observed_outputs) {
          block_observed_values.insert(n->outputs()[idx]);
        }
        // Observe, at the call site, the inputs the callee asked for.
        for (auto i = 0U; i < g->inputs().size(); ++i) {
          auto* node_input = n->input(i + input_offset);
          if (input_observers[i] && !inputs_outputs.count(node_input) &&
              !isObserved(node_input, block_observed_values)) {
            recordObserved(
                node_input,
                *input_observers[i],
                values_to_observe,
                block_observed_values);
          }
        }
        // Same for the outputs of the call.
        for (auto i = 0U; i < n->outputs().size(); ++i) {
          if (output_observers[i] && !inputs_outputs.count(n->output(i)) &&
              !isObserved(n->output(i), block_observed_values)) {
            recordObserved(
                n->output(i),
                *output_observers[i],
                values_to_observe,
                block_observed_values);
          }
        }
      } else if (n->kind() == prim::If) {
        // a vector recording whether each output is observed or not
        std::vector<bool> aggregated_output_observe_state;
        for (Block* subblock : n->blocks()) {
          if (alwaysRaisesException(subblock)) {
            continue;
          }
          // subblock has access to all the values in the scope of prim::If,
          // so subblock_observed_values == block_observed_values
          auto info_from_subblock =
              insertObserversFor(subblock, module, block_observed_values);
          // subblock for prim::If doesn't have inputs
          auto output_observers = std::get<1>(info_from_subblock);
          auto subblock_observed_outputs = std::get<2>(info_from_subblock);

          // We'll insert output observer for each subblock, and in the end
          // we will check if output of subblocks are quantized consistently
          for (size_t i = 0; i < subblock->outputs().size(); ++i) {
            Value* output = subblock->outputs()[i];
            if (output_observers[i] && !inputs_outputs.count(output) &&
                !isObserved(output, block_observed_values)) {
              recordObserved(
                  output,
                  *output_observers[i],
                  values_to_observe,
                  block_observed_values);
            }
          }
          for (auto idx : subblock_observed_outputs) {
            block_observed_values.insert(subblock->outputs()[idx]);
          }

          std::vector<bool> subblock_output_observe_state;
          for (size_t i = 0; i < subblock->outputs().size(); ++i) {
            Value* output = subblock->outputs()[i];
            subblock_output_observe_state.push_back(
                isObserved(output, block_observed_values));
          }
          if (aggregated_output_observe_state.size() > 0) {
            TORCH_CHECK(
                aggregated_output_observe_state ==
                    subblock_output_observe_state,
                "branches for `if` should return values that are observed "
                "consistently, if node:",
                *n);
          } else {
            aggregated_output_observe_state = subblock_output_observe_state;
          }
        }
        // mark the output of if as observed.
        // The aggregated state stays empty when every subblock was skipped
        // above (all of them always raise), so the index is bounds-checked
        // instead of reading past the end of the vector.
        for (size_t i = 0; i < n->outputs().size(); ++i) {
          if (i < aggregated_output_observe_state.size() &&
              aggregated_output_observe_state[i]) {
            block_observed_values.insert(n->output(i));
          }
        }
      } else if (n->kind() == prim::Loop) {
        TORCH_WARN_ONCE(
            "prim::Loop is not yet supported in quantization, "
            "please make sure nothing needs to be quantized in the "
            "loop");
      }
      for (Value* v : n->outputs()) {
        propagateObservedProperty(v, block_observed_values);
        if (!inputs_outputs.count(v) && !isObserved(v, block_observed_values)) {
          auto observer_opt = getObserverFor(v);
          // If the node is one of the propagate quant node, e.g.
          // aten::cat, we should observe its output only
          // if the input of the node is observed
          if (observer_opt &&
              shouldObserve(n, block_observed_values, quant_type_)) {
            recordObserved(
                v, *observer_opt, values_to_observe, block_observed_values);
          }
        }
      }
    }
  }
  // Indices of block outputs that are observed, reported to the call site.
  std::vector<size_t> output_idxs;
  for (auto i = 0U; i < block->outputs().size(); ++i) {
    if (isObserved(block->outputs()[i], block_observed_values)) {
      output_idxs.push_back(i);
    }
  }

  // Observers are physically inserted only the first time a block is seen;
  // later visits reuse block_observer_map_ (handled above).
  if (!visited) {
    NameModuleVector observer_name_and_modules;
    for (const auto& item : values_to_observe) {
      auto* v = item.first;
      auto observer = item.second;
      TORCH_CHECK(
          !is_user_defined_function,
          "Inserting observers for user defined functions is not "
          "supported right now");
      insertObserverFor(v, module, observer, observer_name_and_modules);
    }
    if (insertResetObserverMethod()) {
      insertObserverResetMinMax(module, observer_name_and_modules);
    }
    block_observer_map_[block] = observer_name_and_modules;
  }
  return std::make_tuple(
      block_input_observers, block_output_observers, output_idxs);
}
// Marks `output` as observed in `block_observed_values` when it has an entry
// in pass_through_value_map_ and every value in that entry is already
// observed (globally or within the block).
void InsertObserversHelper::propagateObservedProperty(
    Value* output,
    std::unordered_set<Value*>& block_observed_values) {
  const auto entry = pass_through_value_map_.find(output);
  if (entry == pass_through_value_map_.end()) {
    return;
  }

  // The stored vector is always non-empty, so `true` is never returned
  // merely because there was nothing to check.
  bool every_input_observed = true;
  for (Value* input : entry->second) {
    const bool observed =
        observed_values_.count(input) || block_observed_values.count(input);
    if (!observed) {
      every_input_observed = false;
      break;
    }
  }
  if (!every_input_observed) {
    return;
  }

  GRAPH_DEBUG("Pass through observed property in node:", *output->node());
  // Propagates the observed property through all ops that do not require
  // observation themselves.
  block_observed_values.insert(output);
}
} // namespace
// Clones `input_module` (or reuses it when `inplace` is set), then
// preprocesses, analyzes and instruments `method_name` with observers
// according to `qconfig_dict`. Returns the instrumented module.
Module InsertObservers(
    Module& input_module,
    const std::string& method_name,
    const QConfigDict& qconfig_dict,
    bool inplace,
    QuantType quant_type) {
  ModuleQConfigMap qconfigs_before_clone;
  fillQConfigMap(input_module, qconfig_dict, qconfigs_before_clone);
  ModuleCloneHelper clone_helper;
  Module observed_module =
      clone_helper.clone(input_module, qconfigs_before_clone, inplace);
  SwapFunctionalLinear(observed_module);

  // Cloning changes the module types, so the qconfig map has to be rebuilt
  // for the clone.
  ModuleQConfigMap qconfigs_after_clone;
  fillQConfigMap(observed_module, qconfig_dict, qconfigs_after_clone);
  GRAPH_DEBUG("Quant type:", quant_type);

  InsertObserversHelper helper(qconfigs_after_clone, quant_type);
  helper.preprocess(observed_module, method_name);
  helper.fillBoundaryValueMap(observed_module, method_name);
  // analyze must run after fillBoundaryValueMap: it needs the boundary value
  // mapping to trace through the calls.
  helper.analyze(observed_module, method_name);
  helper.insertObservers(
      observed_module, method_name, /* is_entry_point */ true);
  return observed_module;
}
// On-device PTQ variant of InsertObservers: clones `method_name` into a new
// method named "observe_<method_name>" on the (cloned) module and inserts the
// observers into that copy. A reset method for the observers is requested
// through setInsertResetObserverMethod. Returns the instrumented module.
Module InsertObserversForOnDevicePTQ(
    Module& input_module,
    const std::string& method_name,
    const QConfigDict& qconfig_dict,
    bool inplace,
    QuantType quant_type) {
  ModuleQConfigMap qconfigs_before_clone;
  fillQConfigMap(input_module, qconfig_dict, qconfigs_before_clone);
  ModuleCloneHelper clone_helper;
  Module observed_module =
      clone_helper.clone(input_module, qconfigs_before_clone, inplace);
  std::shared_ptr<Graph> method_graph =
      observed_module.get_method(method_name).graph();
  SwapFunctionalLinear(method_graph);

  std::string observer_method_name = "observe_" + method_name;
  cloneMethod(observed_module, method_name, observer_method_name);

  // Cloning changes the module types, so the qconfig map has to be rebuilt
  // for the clone.
  ModuleQConfigMap qconfigs_after_clone;
  fillQConfigMap(observed_module, qconfig_dict, qconfigs_after_clone);
  GRAPH_DEBUG("Quant type:", quant_type);
  InsertObserversHelper helper(qconfigs_after_clone, quant_type);

  // Open question: it is not clear whether the list-mutation removal done by
  // preprocess is needed here.
  helper.preprocess(observed_module, observer_method_name);

  // The graph is expected to be inlined, so the call-boundary part of this
  // step should have no effect; it does, however, also handle `if` blocks.
  // Open questions:
  //  - `if` blocks do not seem to be really handled in JIT quantization;
  //    should finding an observable value inside one be guarded against?
  //  - A side effect of inlining is multiple GetAttr nodes for the same
  //    attribute, and thus potentially multiple observers observing the same
  //    value, which also increases the size of the packed param struct. This
  //    is not expected to be a common pattern, and the current quant workflow
  //    does not prevent it either, since things are inlined during
  //    insert quant-dequant anyway.
  helper.fillBoundaryValueMap(observed_module, observer_method_name);
  // analyze must run after fillBoundaryValueMap: it needs the boundary value
  // mapping to trace through the calls.
  helper.analyze(observed_module, observer_method_name);

  // Dynamic quantization keeps only the weight observers.
  if (quant_type == QuantType::DYNAMIC) {
    helper.removeActivationObservers();
  }
  helper.setInsertResetObserverMethod(true, method_name);
  helper.insertObservers(
      observed_module, observer_method_name, /* is_entry_point */ true);
  return observed_module;
}
} // namespace jit
} // namespace torch