[Pytorch Edge] Extend Tracer to Custom Classes (#67004)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67004

New version of this change; the previous PR could not be rebased.

Trace the custom (TorchBind) classes a model loads and feed that set into selective build.

Test Plan: CI.

Reviewed By: dhruvbird

Differential Revision: D31818978

fbshipit-source-id: daa22ccb153e32685bcca43a303ba9e21042d052
Author: Jacob Szwejbka
Date: 2021-10-26 11:36:14 -07:00
Committed by: Facebook GitHub Bot
Parent: 34ee5b11ff
Commit: 6c22b96082
11 changed files with 248 additions and 19 deletions

View File

@ -1,5 +1,5 @@
#include <torch/custom_class.h>
#include <ATen/record_function.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/functional.h>
@ -9,6 +9,14 @@
namespace torch {
namespace detail {
void record_custom_class(std::string name) {
RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::CUSTOM_CLASS, name, {});
}
} // namespace detail
std::unordered_map<std::string, at::ClassTypePtr>& customClasses() {
static std::unordered_map<std::string, at::ClassTypePtr> customClasses;
return customClasses;
@ -25,8 +33,12 @@ void registerCustomClass(at::ClassTypePtr class_type) {
customClasses()[name] = std::move(class_type);
}
at::ClassTypePtr getCustomClass(const std::string& name) {
return customClasses().count(name) ? customClasses()[name] : nullptr;
at::ClassTypePtr getCustomClass(const std::string& class_name) {
auto ret = customClasses().count(class_name) ? customClasses()[class_name] : nullptr;
if (ret) {
RECORD_CUSTOM_CLASS(class_name);
}
return ret;
}
const std::unordered_set<std::string> getAllCustomClassesNames() {
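For illustration (not part of this diff): a minimal sketch of how the new instrumentation can be observed. A RecordFunction callback scoped to CUSTOM_CLASS fires whenever getCustomClass() resolves a registered class, provided the build flag that enables RECORD_CUSTOM_CLASS is on. The function name and the class name "MyStackClass" below are hypothetical.

#include <ATen/record_function.h>
#include <torch/custom_class.h>
#include <iostream>
#include <memory>

// Sketch: observe custom-class lookups via the new CUSTOM_CLASS scope.
void observe_custom_class_lookups() {
  auto handle = at::addGlobalCallback(
      at::RecordFunctionCallback(
          [](const at::RecordFunction& fn)
              -> std::unique_ptr<at::ObserverContext> {
            std::cout << "custom class loaded: " << fn.name().str() << std::endl;
            return nullptr;
          })
          .scopes({at::RecordScope::CUSTOM_CLASS}));

  // "MyStackClass" is a made-up name; in practice the lookup happens
  // during mobile model deserialization.
  torch::getCustomClass("MyStackClass");

  at::removeCallback(handle);
}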

View File

@ -91,6 +91,21 @@ constexpr bool schema_allowlist_check(string_view schema) {
#endif
}
// Returns true iff the given custom class name is on the allowlist
// and should be registered
constexpr bool custom_class_allowlist_check(string_view custom_class_name) {
#if !defined(TORCH_CUSTOM_CLASS_ALLOWLIST)
// If the TORCH_CUSTOM_CLASS_ALLOWLIST parameter is not defined,
// all custom classes are to be registered
(void)custom_class_name;
return true;
#else
return op_allowlist_contains(
C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST),
custom_class_name);
#endif
}
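A sketch of the intended behavior (assumption: these helpers live in c10::impl, alongside the neighboring allowlist functions in this header). If the selective-build codegen emitted "#define TORCH_CUSTOM_CLASS_ALLOWLIST Conv2dOpContext;LinearOpContext;" before this header was preprocessed, the check becomes a compile-time membership test; with no macro defined, every class passes.

// Assumes the allowlist macro above was defined before including this header.
static_assert(
    c10::impl::custom_class_allowlist_check("Conv2dOpContext"),
    "class on the allowlist should be kept");
static_assert(
    !c10::impl::custom_class_allowlist_check("SomeUnusedClass"),
    "class not on the allowlist should be filtered out");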
// schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST.
// Add this API to pass arbitrary allowlist.
constexpr bool op_allowlist_contains_name_in_schema(string_view allowlist, string_view schema) {

View File

@ -27,6 +27,8 @@ enum class C10_API_ENUM RecordScope : uint8_t {
TORCHSCRIPT_FUNCTION,
// Kernel Function dtype Tag
KERNEL_FUNCTION_DTYPE,
// TorchBind custom class
CUSTOM_CLASS,
// Lite interpreter function execution
LITE_INTERPRETER,
// User defined scope (e.g. with record_function())

View File

@ -36,11 +36,16 @@ class SelectiveBuilder:
# of the kernel function implementation itself.
kernel_metadata: Dict[str, List[str]]
# A set of all the custom TorchBind classes used by the selected models.
# Stored as a set internally to remove duplicates proactively, but written
# out as a list to YAML.
custom_classes: Set[str]
# If true, then fragments for all dtypes for all kernel functions
# are included. This is typically set when any one of the
# are included as well as all custom classes. This is typically set when any one of the
# operator lists is generated from a mechanism other than
# tracing based selective build.
include_all_kernel_dtypes: bool
include_all_non_op_selectives: bool
@staticmethod
def get_nop_selector() -> 'SelectiveBuilder':
@ -49,11 +54,12 @@ class SelectiveBuilder:
@staticmethod
def from_yaml_dict(data: Dict[str, object]) -> 'SelectiveBuilder':
valid_top_level_keys = {
'include_all_kernel_dtypes',
'include_all_non_op_selectives',
'include_all_operators',
'debug_info',
'operators',
'kernel_metadata',
'custom_classes',
}
top_level_keys = set(data.keys())
if len(top_level_keys - valid_top_level_keys) > 0:
@ -84,15 +90,19 @@ class SelectiveBuilder:
for (k, v) in kernel_metadata_dict.items():
kernel_metadata[str(k)] = list(map(lambda dtype: str(dtype), v))
include_all_kernel_dtypes = data.get('include_all_kernel_dtypes', False)
assert isinstance(include_all_kernel_dtypes, bool)
custom_classes = data.get('custom_classes', [])
custom_classes = set(custom_classes) # type: ignore[arg-type]
include_all_non_op_selectives = data.get('include_all_non_op_selectives', False)
assert isinstance(include_all_non_op_selectives, bool)
return SelectiveBuilder(
include_all_operators,
debug_info,
operators,
kernel_metadata,
include_all_kernel_dtypes,
custom_classes, # type: ignore[arg-type]
include_all_non_op_selectives,
)
@staticmethod
@ -121,7 +131,7 @@ class SelectiveBuilder:
}
return SelectiveBuilder.from_yaml_dict({
'operators': operators,
'include_all_kernel_dtypes': True,
'include_all_non_op_selectives': True,
})
def is_operator_selected(self, name: str) -> bool:
@ -184,14 +194,14 @@ class SelectiveBuilder:
return base_op.include_all_overloads and base_op.is_root_operator
def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool:
if self.include_all_operators or self.include_all_kernel_dtypes:
if self.include_all_operators or self.include_all_non_op_selectives:
return True
return kernel_tag in self.kernel_metadata and dtype in self.kernel_metadata[kernel_tag]
def to_dict(self) -> Dict[str, object]:
ret: Dict[str, object] = {
'include_all_kernel_dtypes': self.include_all_kernel_dtypes,
'include_all_non_op_selectives': self.include_all_non_op_selectives,
'include_all_operators': self.include_all_operators,
}
operators = {}
@ -204,6 +214,8 @@ class SelectiveBuilder:
ret['kernel_metadata'] = {k: sorted(list(v)) for (k, v) in self.kernel_metadata.items()}
ret['custom_classes'] = sorted(self.custom_classes)
return ret
@ -226,13 +238,15 @@ def combine_selective_builders(lhs: SelectiveBuilder, rhs: SelectiveBuilder) ->
debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)
operators = merge_operator_dicts(lhs.operators, rhs.operators)
kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata)
include_all_kernel_dtypes = lhs.include_all_kernel_dtypes or rhs.include_all_kernel_dtypes
include_all_non_op_selectives = lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives
custom_classes = lhs.custom_classes.union(rhs.custom_classes)
return SelectiveBuilder(
include_all_operators,
debug_info,
operators,
kernel_metadata,
include_all_kernel_dtypes,
custom_classes,
include_all_non_op_selectives,
)

View File

@ -52,7 +52,7 @@ def get_selected_kernel_dtypes_code(
# dtypes are selected (i.e. both cases).
#
body = "return true;"
if selective_builder.include_all_operators is False and selective_builder.include_all_kernel_dtypes is False:
if selective_builder.include_all_operators is False and selective_builder.include_all_non_op_selectives is False:
body_parts = []
for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
conditions = list(map(lambda x: 'scalar_type == at::ScalarType::' + x, dtypes))
@ -76,10 +76,16 @@ def write_selected_mobile_ops(
selective_builder: SelectiveBuilder,
) -> None:
root_ops = extract_root_operators(selective_builder)
custom_classes = selective_builder.custom_classes
with open(output_file_path, "wb") as out_file:
body_parts = [selected_mobile_ops_preamble]
# This condition checks if we are in selective build.
# If these lists are not defined, the corresponding selective build macros
# trivially report that the item in question was selected.
if not selective_builder.include_all_operators:
body_parts.append("#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n")
# This condition checks if we are in tracing based selective build
if selective_builder.include_all_non_op_selectives is False:
body_parts.append("#define TORCH_CUSTOM_CLASS_ALLOWLIST " + (";".join(sorted(custom_classes))) + ";\n\n")
body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
header_contents = "".join(body_parts)
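For illustration, a sketch of what the generated selected_mobile_ops.h could contain for a tracing-based build whose model uses two custom classes (the operator and class names below are made up, not taken from a real model):

// Hypothetical excerpt of a generated selected_mobile_ops.h:
#define TORCH_OPERATOR_WHITELIST aten::add.Tensor;aten::conv2d;prepacked::conv2d_clamp_run;

#define TORCH_CUSTOM_CLASS_ALLOWLIST Conv2dOpContext;LinearOpContext;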

View File

@ -0,0 +1,13 @@
#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
namespace torch {
namespace jit {
namespace mobile {
CustomClassTracer::custom_classes_type& CustomClassTracer::getLoadedClasses() {
static custom_classes_type loaded_classes;
return loaded_classes;
}
} // namespace mobile
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,52 @@
#pragma once
#include <ATen/record_function.h>
#include <map>
#include <set>
#include <string>
namespace torch {
namespace jit {
namespace mobile {
/* The CustomClassTracer class installs (and later removes) a recording
 * callback that traces the code paths which load custom classes on mobile.
*
* You can get the set of used custom classes using
* getLoadedClasses().
*
* Note: This class is not thread safe or re-entrant, and should not be used
* across multiple threads of execution.
*
*/
struct CustomClassTracer final {
at::CallbackHandle handle_;
/* These are the custom class names (constant
 * character strings) that show up in code.
 */
typedef std::set<std::string> custom_classes_type;
CustomClassTracer() {
auto recorder_cb = [](const at::RecordFunction& fn)
-> std::unique_ptr<at::ObserverContext> {
std::string name = fn.name().str();
getLoadedClasses().insert(name);
return nullptr;
};
handle_ =
at::addGlobalCallback(at::RecordFunctionCallback(recorder_cb)
.scopes({at::RecordScope::CUSTOM_CLASS}));
}
static custom_classes_type& getLoadedClasses();
~CustomClassTracer() {
at::removeCallback(handle_);
}
};
} // namespace mobile
} // namespace jit
} // namespace torch
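A minimal usage sketch (mirroring how TracerRunner uses this class below; the surrounding function is hypothetical): keep a tracer alive while the model is loaded and run, then read back the accumulated class names.

#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
#include <iostream>

void dump_loaded_custom_classes() {
  {
    // The RecordFunction callback is installed for the lifetime of this object.
    torch::jit::mobile::CustomClassTracer tracer;
    // ... load and run the model under trace here ...
  }
  // The backing set is a function-local static, so it outlives the tracer.
  for (const auto& name :
       torch::jit::mobile::CustomClassTracer::getLoadedClasses()) {
    std::cout << name << std::endl;
  }
}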

View File

@ -1,4 +1,5 @@
#include <ATen/Functions.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/dispatch/ObservedOperators.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
@ -9,6 +10,8 @@
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/mobile/runtime_compatibility.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/script.h>
namespace torch {
@ -92,6 +95,84 @@ void consume_tensor(const at::Tensor& t) {
c.copy_(t.cpu());
}
std::unordered_map<std::string, c10::FunctionSchema>
_get_runtime_ops_and_schema() {
std::unordered_map<std::string, c10::FunctionSchema> result;
// Grab the jit operators
auto nonDispatcherOperators = torch::jit::getAllOperators();
for (const auto& full_op : nonDispatcherOperators) {
auto op = full_op->schema();
auto op_name = op.name();
if (!op.overload_name().empty()) {
op_name += ("." + op.overload_name());
}
result.emplace(op_name, op);
}
// Grab the dispatcher operators
auto dispatcherOperators = c10::Dispatcher::singleton().getAllOpNames();
for (auto& op : dispatcherOperators) {
// grab schema
const auto op_handle = c10::Dispatcher::singleton().findOp(op);
if (op_handle->hasSchema()) {
auto op_name = op.name;
if (!op.overload_name.empty()) {
op_name += ("." + op.overload_name);
}
result.emplace(op_name, op_handle->schema());
}
}
return result;
}
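The keys in this map follow the "name.overload" convention; a small sketch of a lookup (the operator name is only an example, and this fragment would live in the same .cpp):

// Example lookup: overloads are keyed as "<ns>::<name>.<overload>".
void print_one_schema_example() {
  auto ops_and_schemas = _get_runtime_ops_and_schema();
  auto it = ops_and_schemas.find("aten::add.Tensor");
  if (it != ops_and_schemas.end()) {
    std::cout << it->second << std::endl; // FunctionSchema is printable
  }
}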
/**
 * For the vast majority of use cases, the instrumentation in getCustomClass
 * will catch any custom classes referenced by a model. There are, however,
 * niche situations that bypass the getCustomClass instrumentation due to
 * nuances of mobile model deserialization. To get around that, we can walk
 * all the used ops and inspect their schemas for any referenced classes.
 * Example schema: prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None,
 * Scalar? output_min=None, Scalar? output_max=None) ->
 * __torch__.torch.classes.xnnpack.LinearOpContext
 */
void recordCustomClassesFromOpSchemas(
std::set<std::string>& root_ops,
std::set<std::string>& traced_ops,
std::set<std::string>& loaded_classes) {
std::set<std::string> ops;
ops.insert(root_ops.begin(), root_ops.end());
ops.insert(traced_ops.begin(), traced_ops.end());
auto ops_and_schemas = _get_runtime_ops_and_schema();
auto record_if_class = [&](std::string type_name) {
// All custom class types start with __torch__ (not sure whether this is
// by chance or guaranteed).
if (type_name.find("__torch__") != std::string::npos) {
// The type name here is fully qualified, but registration only uses the
// bare class name, so record just that.
loaded_classes.insert(type_name.substr(type_name.find_last_of('.') + 1));
}
};
for (auto& op_name : ops) {
// This check is only necessary because of GPU models.
// Certain models can only run on a specific backend, say Metal.
// Those ops will be present in the model's root ops, but likely
// not in the tracer running on Linux.
if (ops_and_schemas.find(op_name) != ops_and_schemas.end()) {
auto& schema = ops_and_schemas.at(op_name);
for (auto& arg : schema.arguments()) {
record_if_class(arg.type()->annotation_str());
}
for (auto& ret : schema.returns()) {
record_if_class(ret.type()->annotation_str());
}
}
}
}
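Concretely, for the example schema in the comment above, record_if_class sees the return type annotation __torch__.torch.classes.xnnpack.LinearOpContext and keeps only the text after the final '.'. A standalone sketch of that string handling (the function name is hypothetical):

#include <cassert>
#include <string>

void example_class_name_extraction() {
  // Same find/substr logic as record_if_class above.
  std::string type_name = "__torch__.torch.classes.xnnpack.LinearOpContext";
  if (type_name.find("__torch__") != std::string::npos) {
    std::string recorded = type_name.substr(type_name.find_last_of('.') + 1);
    assert(recorded == "LinearOpContext");
  }
}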
void run_model(
const std::string& input_module_path,
std::set<std::string>& root_ops,
@ -101,7 +182,6 @@ void run_model(
// This is needed so that we can load any TorchBind objects (custom classes)
// that this model refers to so that any operators being called from those
// TorchBind objects can be traced by the model tracer.
//
torch::jit::mobile::MobileModelRunner module_runner(input_module_path, 0);
root_ops = module_runner.get_root_operators();
std::cout << "Got " << root_ops.size() << " Root Operators." << std::endl;
@ -178,10 +258,12 @@ TracerResult trace_run(const std::string& input_module_path) {
torch::jit::mobile::OperatorCallTracer op_tracer;
torch::jit::mobile::KernelDTypeTracer kdtype_tracer;
torch::jit::mobile::CustomClassTracer custom_class_tracer;
call_setup_methods();
std::set<std::string> root_ops, traced_operators, enabled_backends;
std::set<std::string> root_ops, traced_operators, enabled_backends,
loaded_classes;
torch::jit::mobile::KernelDTypeTracer::kernel_tags_type called_kernel_tags;
using torch::jit::MobileModuleLoadOptions;
@ -192,13 +274,21 @@ TracerResult trace_run(const std::string& input_module_path) {
run_model(input_module_path, root_ops, enabled_backends, called_kernel_tags);
traced_operators = op_tracer.getCalledOperators();
recordCustomClassesFromOpSchemas(root_ops, traced_operators, loaded_classes);
called_kernel_tags.insert(
kdtype_tracer.getCalledKernelTags().begin(),
kdtype_tracer.getCalledKernelTags().end());
traced_operators.insert(
always_included_traced_ops.begin(), always_included_traced_ops.end());
loaded_classes.insert(
custom_class_tracer.getLoadedClasses().begin(),
custom_class_tracer.getLoadedClasses().end());
TracerResult tracer_result = {
root_ops, traced_operators, called_kernel_tags, enabled_backends};
root_ops,
traced_operators,
called_kernel_tags,
loaded_classes,
enabled_backends};
return tracer_result;
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
#include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
namespace torch {
@ -17,6 +18,7 @@ struct TracerResult {
std::set<std::string> root_ops;
std::set<std::string> traced_operators;
KernelDTypeTracer::kernel_tags_type called_kernel_tags;
CustomClassTracer::custom_classes_type loaded_classes;
std::set<std::string> enabled_backends;
};
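A sketch of a hypothetical caller consuming the new field, assuming trace_run() is exposed next to TracerResult in this header (the model path and function name below are made up):

#include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
#include <iostream>

void print_traced_custom_classes() {
  auto result = torch::jit::mobile::trace_run("/tmp/model_to_trace.ptl");
  for (const auto& cls : result.loaded_classes) {
    std::cout << "custom class: " << cls << std::endl;
  }
}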

View File

@ -135,7 +135,7 @@ int main(int argc, char* argv[]) {
}
}
yaml_out << "include_all_kernel_dtypes: true" << std::endl;
yaml_out << "include_all_non_op_selectives: true" << std::endl;
yaml_out << "operators:" << std::endl;
printOpsYAML(
yaml_out,

View File

@ -8,6 +8,29 @@
namespace torch {
namespace detail {
/**
* In the Facebook internal build (using BUCK), this macro is enabled by
* passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
* binary.
*/
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
TORCH_API void record_custom_class(std::string name);
/**
 * Records an instance of a custom class being loaded. Grabs the portion of
 * the qualified name after the final '.', since that matches how custom
 * classes are named when registered.
 * Example: __torch__.torch.classes.xnnpack.Conv2dOpContext
 */
#define RECORD_CUSTOM_CLASS(NAME) \
auto name = std::string(NAME); \
detail::record_custom_class(name.substr(name.find_last_of(".") + 1));
#else
#define RECORD_CUSTOM_CLASS(NAME)
#endif
} // namespace detail
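For example (as invoked from getCustomClass, i.e. from code inside namespace torch): with the flag enabled, the call below records just the trailing class name; with it disabled, the macro expands to nothing.

// Records the string "Conv2dOpContext" when
// ENABLE_RECORD_KERNEL_FUNCTION_DTYPE is defined; otherwise a no-op.
RECORD_CUSTOM_CLASS("__torch__.torch.classes.xnnpack.Conv2dOpContext");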
/// This struct is used to represent default values for arguments
/// when registering methods for custom classes.
/// static auto register_foo = torch::class_<Foo>("myclasses", "Foo")