[Pytorch Edge] Extend Tracer to Custom Classes (#67004)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67004
New version because the other one was impossible to rebase. Trace custom classes.

Test Plan: CI.

Reviewed By: dhruvbird

Differential Revision: D31818978

fbshipit-source-id: daa22ccb153e32685bcca43a303ba9e21042d052

committed by Facebook GitHub Bot
parent 34ee5b11ff
commit 6c22b96082
@@ -1,5 +1,5 @@
 #include <torch/custom_class.h>
-
+#include <ATen/record_function.h>
 #include <ATen/core/jit_type.h>
 #include <ATen/core/function_schema.h>
 #include <ATen/core/functional.h>
@@ -9,6 +9,14 @@
 
 namespace torch {
 
+namespace detail {
+
+void record_custom_class(std::string name) {
+  RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::CUSTOM_CLASS, name, {});
+}
+
+} // namespace detail
+
 std::unordered_map<std::string, at::ClassTypePtr>& customClasses() {
   static std::unordered_map<std::string, at::ClassTypePtr> customClasses;
   return customClasses;
@@ -25,8 +33,12 @@ void registerCustomClass(at::ClassTypePtr class_type) {
   customClasses()[name] = std::move(class_type);
 }
 
-at::ClassTypePtr getCustomClass(const std::string& name) {
-  return customClasses().count(name) ? customClasses()[name] : nullptr;
+at::ClassTypePtr getCustomClass(const std::string& class_name) {
+  auto ret = customClasses().count(class_name) ? customClasses()[class_name] : nullptr;
+  if (ret) {
+    RECORD_CUSTOM_CLASS(class_name);
+  }
+  return ret;
 }
 
 const std::unordered_set<std::string> getAllCustomClassesNames() {
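For illustration (not part of the diff): with this change, a successful class lookup fires a CUSTOM_CLASS-scoped record event that any attached RecordFunction callback can observe. A minimal sketch, assuming a tracer build with ENABLE_RECORD_KERNEL_FUNCTION_DTYPE defined and a class registered under the purely illustrative name "MyStackClass":

    #include <torch/custom_class.h>

    void lookup_example() {
      // On a hit, getCustomClass() now fires RECORD_CUSTOM_CLASS(class_name),
      // so any callback registered for the CUSTOM_CLASS scope observes the
      // lookup. On a miss it returns nullptr and records nothing.
      at::ClassTypePtr cls = torch::getCustomClass("MyStackClass");
      if (cls) {
        // the class is registered; the lookup itself was just traced
      }
    }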
@@ -91,6 +91,21 @@ constexpr bool schema_allowlist_check(string_view schema) {
 #endif
 }
 
+// Returns true iff the given custom class name is on the allowlist
+// and should be registered
+constexpr bool custom_class_allowlist_check(string_view custom_class_name) {
+#if !defined(TORCH_CUSTOM_CLASS_ALLOWLIST)
+  // If the TORCH_CUSTOM_CLASS_ALLOWLIST parameter is not defined,
+  // all custom classes are to be registered
+  (void)custom_class_name;
+  return true;
+#else
+  return op_allowlist_contains(
+      C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST),
+      custom_class_name);
+#endif
+}
+
 // schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST.
 // Add this API to pass arbitrary allowlist.
 constexpr bool op_allowlist_contains_name_in_schema(string_view allowlist, string_view schema) {
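A compile-time sketch of the new check (not from the diff; the class names are made up, and it assumes op_allowlist_contains is reachable as c10::impl::op_allowlist_contains, the same helper the code above calls unqualified):

    #include <ATen/core/op_registration/op_allowlist.h>

    // Suppose the generated header contained:
    //   #define TORCH_CUSTOM_CLASS_ALLOWLIST Conv2dOpContext;LinearOpContext
    // The semicolon-separated list is then searched entirely at compile time:
    static_assert(c10::impl::op_allowlist_contains(
        "Conv2dOpContext;LinearOpContext", "LinearOpContext"), "on the list");
    static_assert(!c10::impl::op_allowlist_contains(
        "Conv2dOpContext;LinearOpContext", "AbsentClass"), "not on the list");

When the macro is left undefined, every name passes the check and all custom classes are registered, so non-selective builds are unaffected.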
@@ -27,6 +27,8 @@ enum class C10_API_ENUM RecordScope : uint8_t {
   TORCHSCRIPT_FUNCTION,
   // Kernel Function dtype Tag
   KERNEL_FUNCTION_DTYPE,
+  // Torchbind custom class,
+  CUSTOM_CLASS,
   // Kernel Function dtype Tag
   LITE_INTERPRETER,
   // User defined scope (e.g. with record_function())
@@ -36,11 +36,16 @@ class SelectiveBuilder:
     # of the kernel function implementation itself.
     kernel_metadata: Dict[str, List[str]]
 
+    # A set of all the custom torch bind classes used by the selected models
+    # Stored as a set internally to remove duplicates proactively, but written
+    # as a list to yamls
+    custom_classes: Set[str]
+
     # If true, then fragments for all dtypes for all kernel functions
-    # are included. This is typically set when any one of the
+    # are included as well as all custom classes. This is typically set when any one of the
     # operator lists is generated from a mechanism other than
     # tracing based selective build.
-    include_all_kernel_dtypes: bool
+    include_all_non_op_selectives: bool
 
     @staticmethod
     def get_nop_selector() -> 'SelectiveBuilder':
@@ -49,11 +54,12 @@ class SelectiveBuilder:
     @staticmethod
     def from_yaml_dict(data: Dict[str, object]) -> 'SelectiveBuilder':
         valid_top_level_keys = {
-            'include_all_kernel_dtypes',
+            'include_all_non_op_selectives',
             'include_all_operators',
             'debug_info',
            'operators',
             'kernel_metadata',
+            'custom_classes',
         }
         top_level_keys = set(data.keys())
         if len(top_level_keys - valid_top_level_keys) > 0:
@@ -84,15 +90,19 @@ class SelectiveBuilder:
         for (k, v) in kernel_metadata_dict.items():
             kernel_metadata[str(k)] = list(map(lambda dtype: str(dtype), v))
 
-        include_all_kernel_dtypes = data.get('include_all_kernel_dtypes', False)
-        assert isinstance(include_all_kernel_dtypes, bool)
+        custom_classes = data.get('custom_classes', [])
+        custom_classes = set(custom_classes)  # type: ignore[arg-type]
+
+        include_all_non_op_selectives = data.get('include_all_non_op_selectives', False)
+        assert isinstance(include_all_non_op_selectives, bool)
 
         return SelectiveBuilder(
             include_all_operators,
             debug_info,
             operators,
             kernel_metadata,
-            include_all_kernel_dtypes,
+            custom_classes,  # type: ignore[arg-type]
+            include_all_non_op_selectives,
         )
 
     @staticmethod
@@ -121,7 +131,7 @@ class SelectiveBuilder:
         }
         return SelectiveBuilder.from_yaml_dict({
             'operators': operators,
-            'include_all_kernel_dtypes': True,
+            'include_all_non_op_selectives': True,
         })
 
     def is_operator_selected(self, name: str) -> bool:
@@ -184,14 +194,14 @@ class SelectiveBuilder:
         return base_op.include_all_overloads and base_op.is_root_operator
 
     def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool:
-        if self.include_all_operators or self.include_all_kernel_dtypes:
+        if self.include_all_operators or self.include_all_non_op_selectives:
             return True
 
         return kernel_tag in self.kernel_metadata and dtype in self.kernel_metadata[kernel_tag]
 
     def to_dict(self) -> Dict[str, object]:
         ret: Dict[str, object] = {
-            'include_all_kernel_dtypes': self.include_all_kernel_dtypes,
+            'include_all_non_op_selectives': self.include_all_non_op_selectives,
             'include_all_operators': self.include_all_operators,
         }
         operators = {}
@@ -204,6 +214,8 @@ class SelectiveBuilder:
 
         ret['kernel_metadata'] = {k: sorted(list(v)) for (k, v) in self.kernel_metadata.items()}
 
+        ret['custom_classes'] = sorted(self.custom_classes)
+
         return ret
 
 
@@ -226,13 +238,15 @@ def combine_selective_builders(lhs: SelectiveBuilder, rhs: SelectiveBuilder) ->
     debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)
     operators = merge_operator_dicts(lhs.operators, rhs.operators)
     kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata)
-    include_all_kernel_dtypes = lhs.include_all_kernel_dtypes or rhs.include_all_kernel_dtypes
+    include_all_non_op_selectives = lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives
+    custom_classes = lhs.custom_classes.union(rhs.custom_classes)
     return SelectiveBuilder(
         include_all_operators,
         debug_info,
         operators,
         kernel_metadata,
-        include_all_kernel_dtypes,
+        custom_classes,
+        include_all_non_op_selectives,
     )
 
 
@@ -52,7 +52,7 @@ def get_selected_kernel_dtypes_code(
     # dtypes are selected (i.e. both cases).
     #
     body = "return true;"
-    if selective_builder.include_all_operators is False and selective_builder.include_all_kernel_dtypes is False:
+    if selective_builder.include_all_operators is False and selective_builder.include_all_non_op_selectives is False:
         body_parts = []
         for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
             conditions = list(map(lambda x: 'scalar_type == at::ScalarType::' + x, dtypes))
@@ -76,10 +76,16 @@ def write_selected_mobile_ops(
         selective_builder: SelectiveBuilder,
 ) -> None:
     root_ops = extract_root_operators(selective_builder)
+    custom_classes = selective_builder.custom_classes
     with open(output_file_path, "wb") as out_file:
         body_parts = [selected_mobile_ops_preamble]
+        # This condition checks if we are in selective build.
+        # if these lists are not defined the corresponding selective build macros trivially return the item in question was selected
         if not selective_builder.include_all_operators:
             body_parts.append("#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n")
+        # This condition checks if we are in tracing based selective build
+        if selective_builder.include_all_non_op_selectives is False:
+            body_parts.append("#define TORCH_CUSTOM_CLASS_ALLOWLIST " + (";".join(sorted(custom_classes))) + ";\n\n")
 
         body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
         header_contents = "".join(body_parts)
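For a tracing-based selective build, the generated header then carries both gates. A sketch of the emitted excerpt, following the string-building code above; the operator and class names here are illustrative, not from the diff:

    // selected_mobile_ops.h (illustrative excerpt)
    #define TORCH_OPERATOR_WHITELIST aten::add.Tensor;aten::relu;

    #define TORCH_CUSTOM_CLASS_ALLOWLIST Conv2dOpContext;LinearOpContext;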
torch/csrc/jit/mobile/model_tracer/CustomClassTracer.cpp (new file, 13 lines)
@@ -0,0 +1,13 @@
+#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+CustomClassTracer::custom_classes_type& CustomClassTracer::getLoadedClasses() {
+  static custom_classes_type loaded_classes;
+  return loaded_classes;
+}
+
+} // namespace mobile
+} // namespace jit
+} // namespace torch
torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h (new file, 52 lines)
@@ -0,0 +1,52 @@
+#pragma once
+
+#include <ATen/record_function.h>
+#include <map>
+#include <set>
+#include <string>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+
+/* The CustomClassTracer class handles the attachment and removal of a recording
+ * callback that traces the invocation of code that handles loading custom
+ * classes on mobile.
+ *
+ * You can get the set of used custom classes using
+ * getLoadedClasses().
+ *
+ * Note: This class is not thread safe or re-entrant, and should not be used
+ * across multiple threads of execution.
+ *
+ */
+struct CustomClassTracer final {
+  at::CallbackHandle handle_;
+  /* These are the custom class names (constant
+   * character string) which shows up in code.
+   */
+  typedef std::set<std::string> custom_classes_type;
+
+  CustomClassTracer() {
+    auto recorder_cb = [](const at::RecordFunction& fn)
+        -> std::unique_ptr<at::ObserverContext> {
+      std::string name = fn.name().str();
+      getLoadedClasses().insert(name);
+      return nullptr;
+    };
+
+    handle_ =
+        at::addGlobalCallback(at::RecordFunctionCallback(recorder_cb)
+                                  .scopes({at::RecordScope::CUSTOM_CLASS}));
+  }
+
+  static custom_classes_type& getLoadedClasses();
+
+  ~CustomClassTracer() {
+    at::removeCallback(handle_);
+  }
+};
+
+} // namespace mobile
+} // namespace jit
+} // namespace torch
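A usage sketch (not part of the diff), mirroring what trace_run() in TracerRunner.cpp further down does: keep a CustomClassTracer alive while loading a model, then harvest the observed class names. The model path is a placeholder:

    #include <torch/csrc/jit/mobile/import.h>
    #include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
    #include <iostream>

    void trace_classes_example() {
      // Constructing the tracer attaches a global CUSTOM_CLASS-scoped callback.
      torch::jit::mobile::CustomClassTracer tracer;

      // Loading a mobile module triggers getCustomClass() for any TorchBind
      // classes it references, which the callback records.
      auto module = torch::jit::_load_for_mobile("/tmp/model.ptl");

      for (const auto& name :
           torch::jit::mobile::CustomClassTracer::getLoadedClasses()) {
        std::cout << "loaded custom class: " << name << std::endl;
      }
    } // ~CustomClassTracer removes the callback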
@@ -1,4 +1,5 @@
 #include <ATen/Functions.h>
+#include <ATen/core/dispatch/Dispatcher.h>
 #include <ATen/core/dispatch/ObservedOperators.h>
 #include <c10/core/ScalarType.h>
 #include <c10/util/Exception.h>
@@ -9,6 +10,8 @@
 #include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
 #include <torch/csrc/jit/mobile/model_tracer/TracerRunner.h>
 #include <torch/csrc/jit/mobile/parse_operators.h>
+#include <torch/csrc/jit/mobile/runtime_compatibility.h>
+#include <torch/csrc/jit/runtime/operator.h>
 #include <torch/script.h>
 
 namespace torch {
@@ -92,6 +95,84 @@ void consume_tensor(const at::Tensor& t) {
   c.copy_(t.cpu());
 }
 
+std::unordered_map<std::string, c10::FunctionSchema>
+_get_runtime_ops_and_schema() {
+  std::unordered_map<std::string, c10::FunctionSchema> result;
+
+  // Grab the jit operators
+  auto nonDispatcherOperators = torch::jit::getAllOperators();
+  for (const auto& full_op : nonDispatcherOperators) {
+    auto op = full_op->schema();
+    auto op_name = op.name();
+    if (!op.overload_name().empty()) {
+      op_name += ("." + op.overload_name());
+    }
+    result.emplace(op_name, op);
+  }
+
+  // Grab the dispatcher operators
+  auto dispatcherOperators = c10::Dispatcher::singleton().getAllOpNames();
+  for (auto& op : dispatcherOperators) {
+    // grab schema
+    const auto op_handle = c10::Dispatcher::singleton().findOp(op);
+    if (op_handle->hasSchema()) {
+      auto op_name = op.name;
+      if (!op.overload_name.empty()) {
+        op_name += ("." + op.overload_name);
+      }
+      result.emplace(op_name, op_handle->schema());
+    }
+  }
+
+  return result;
+}
+
+/**
+ * For the vast majority of usecases the instrumentation in getCustomClass will
+ * catch any custom classes referenced by a model. There are however, niche
+ * situations that avoid the getCustomClass instrumentation due to some nuances
+ * of mobile model deserialization. To get around that we can search through all
+ * the used ops, and inspect their schemas to search for any referenced classes.
+ * Example schema: prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None,
+ * Scalar? output_min=None, Scalar? output_max=None) ->
+ * __torch__.torch.classes.xnnpack.LinearOpContext"
+ */
+void recordCustomClassesFromOpSchemas(
+    std::set<std::string>& root_ops,
+    std::set<std::string>& traced_ops,
+    std::set<std::string>& loaded_classes) {
+  std::set<std::string> ops;
+  ops.insert(root_ops.begin(), root_ops.end());
+  ops.insert(traced_ops.begin(), traced_ops.end());
+  auto ops_and_schemas = _get_runtime_ops_and_schema();
+
+  auto record_if_class = [&](std::string type_name) {
+    // All custom class types start with __torch__ not sure if this is by
+    // chance or guaranteed
+    if (type_name.find("__torch__") != std::string::npos) {
+      // The name of a customClassType here is its fully qualified name, but
+      // in registration only the class name is used so only record that
+      loaded_classes.insert(type_name.substr(type_name.find_last_of('.') + 1));
+    }
+  };
+
+  for (auto& op_name : ops) {
+    // This check is only necessary because of GPU models.
+    // Certain models can only run on a specific backend say metal.
+    // Those ops will be present in the models root ops, but likely
+    // not the tracer on linux
+    if (ops_and_schemas.find(op_name) != ops_and_schemas.end()) {
+      auto& schema = ops_and_schemas.at(op_name);
+      for (auto& arg : schema.arguments()) {
+        record_if_class(arg.type()->annotation_str());
+      }
+      for (auto& ret : schema.returns()) {
+        record_if_class(ret.type()->annotation_str());
+      }
+    }
+  }
+}
+
 void run_model(
     const std::string& input_module_path,
     std::set<std::string>& root_ops,
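The string handling in record_if_class can be exercised standalone. A sketch (not from the diff) using the return type from the example schema in the comment above:

    #include <iostream>
    #include <string>

    int main() {
      std::string type_name = "__torch__.torch.classes.xnnpack.LinearOpContext";
      // Same reduction as record_if_class: keep only the text after the final
      // '.', since registration sees just the bare class name.
      if (type_name.find("__torch__") != std::string::npos) {
        std::cout << type_name.substr(type_name.find_last_of('.') + 1)
                  << std::endl; // prints "LinearOpContext"
      }
      return 0;
    }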
@@ -101,7 +182,6 @@ void run_model(
   // This is needed so that we can load any TorchBind objects (custom classes)
   // that this model refers to so that any operators being called from those
   // TorchBind objects can be traced by the model tracer.
-  //
   torch::jit::mobile::MobileModelRunner module_runner(input_module_path, 0);
   root_ops = module_runner.get_root_operators();
   std::cout << "Got " << root_ops.size() << " Root Operators." << std::endl;
@@ -178,10 +258,12 @@ TracerResult trace_run(const std::string& input_module_path) {
 
   torch::jit::mobile::OperatorCallTracer op_tracer;
   torch::jit::mobile::KernelDTypeTracer kdtype_tracer;
+  torch::jit::mobile::CustomClassTracer custom_class_tracer;
 
   call_setup_methods();
 
-  std::set<std::string> root_ops, traced_operators, enabled_backends;
+  std::set<std::string> root_ops, traced_operators, enabled_backends,
+      loaded_classes;
   torch::jit::mobile::KernelDTypeTracer::kernel_tags_type called_kernel_tags;
 
   using torch::jit::MobileModuleLoadOptions;
@@ -192,13 +274,21 @@ TracerResult trace_run(const std::string& input_module_path) {
   run_model(input_module_path, root_ops, enabled_backends, called_kernel_tags);
 
   traced_operators = op_tracer.getCalledOperators();
+  recordCustomClassesFromOpSchemas(root_ops, traced_operators, loaded_classes);
   called_kernel_tags.insert(
       kdtype_tracer.getCalledKernelTags().begin(),
       kdtype_tracer.getCalledKernelTags().end());
   traced_operators.insert(
       always_included_traced_ops.begin(), always_included_traced_ops.end());
+  loaded_classes.insert(
+      custom_class_tracer.getLoadedClasses().begin(),
+      custom_class_tracer.getLoadedClasses().end());
   TracerResult tracer_result = {
-      root_ops, traced_operators, called_kernel_tags, enabled_backends};
+      root_ops,
+      traced_operators,
+      called_kernel_tags,
+      loaded_classes,
+      enabled_backends};
 
   return tracer_result;
 }
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <ATen/core/ivalue.h>
+#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
 #include <torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h>
 
 namespace torch {
@@ -17,6 +18,7 @@ struct TracerResult {
   std::set<std::string> root_ops;
   std::set<std::string> traced_operators;
   KernelDTypeTracer::kernel_tags_type called_kernel_tags;
+  CustomClassTracer::custom_classes_type loaded_classes;
   std::set<std::string> enabled_backends;
 };
 
|
@ -135,7 +135,7 @@ int main(int argc, char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
yaml_out << "include_all_kernel_dtypes: true" << std::endl;
|
||||
yaml_out << "include_all_non_op_selectives: true" << std::endl;
|
||||
yaml_out << "operators:" << std::endl;
|
||||
printOpsYAML(
|
||||
yaml_out,
|
||||
|
@@ -8,6 +8,29 @@
 
 namespace torch {
 
+namespace detail {
+/**
+ * In the Facebook internal build (using BUCK), this macro is enabled by
+ * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
+ * binary.
+ */
+#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
+TORCH_API void record_custom_class(std::string name);
+
+/**
+ * Record an instance of a custom class being loaded
+ * grab portion of string after final '.' from qualified name
+ * as this seemingly aligns with how users name their custom classes
+ * example: __torch__.torch.classes.xnnpack.Conv2dOpContext
+ */
+#define RECORD_CUSTOM_CLASS(NAME) \
+  auto name = std::string(NAME); \
+  detail::record_custom_class(name.substr(name.find_last_of(".") + 1));
+#else
+#define RECORD_CUSTOM_CLASS(NAME)
+#endif
+} // namespace detail
+
 /// This struct is used to represent default values for arguments
 /// when registering methods for custom classes.
 /// static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
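An expansion sketch for the new macro, using the qualified name from the comment in the hunk above: with ENABLE_RECORD_KERNEL_FUNCTION_DTYPE defined,

    RECORD_CUSTOM_CLASS("__torch__.torch.classes.xnnpack.Conv2dOpContext");

expands to roughly

    auto name = std::string("__torch__.torch.classes.xnnpack.Conv2dOpContext");
    detail::record_custom_class(name.substr(name.find_last_of(".") + 1));

so only "Conv2dOpContext" is recorded. With the flag undefined, the macro expands to nothing and the getCustomClass() lookup path pays no tracing cost.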