Fix crash on unload torch cpu dll (#67632)

Rebase of https://github.com/pytorch/pytorch/pull/61290 onto the latest pytorch:master.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67632
Approved by: https://github.com/ezyang
This commit is contained in:
David Braun
2022-07-31 21:37:56 +00:00
committed by PyTorch MergeBot
parent 53f56894ae
commit a54c9a419e
12 changed files with 108 additions and 14 deletions

View File

@ -96,6 +96,16 @@ void registerCustomClassMethod(std::unique_ptr<jit::Function> fn) {
customClassMethods().emplace_back(std::move(fn));
}
void deregisterCustomClassMethod(const c10::QualifiedName& name) {
  // Counterpart to registerCustomClassMethod: drop every registered method
  // whose qualified name matches, preserving the order of the survivors.
  auto& registered = customClassMethods();
  for (auto it = registered.begin(); it != registered.end();) {
    if ((*it)->qualname() == name) {
      it = registered.erase(it);
    } else {
      ++it;
    }
  }
}
std::vector<c10::FunctionSchema> customClassSchemasForBCCheck() {
auto& methods = customClassMethods();
return c10::fmap(methods, [](const std::unique_ptr<jit::Function>& fn) {

View File

@ -466,6 +466,10 @@ class TORCH_API OpSchemaRegistry {
static OpSchema&
NewSchema(const string& key, const string& file, const int line);
static void RemoveSchema(const std::string& key) {
  // Erase the schema registered under `key`; a miss is a harmless no-op.
  // Exists so operator schemas can be torn down when their module unloads.
  auto& registry = map();
  registry.erase(key);
}
static const OpSchema* Schema(const string& key) {
auto& m = map();
auto it = m.find(key);
@ -583,10 +587,22 @@ OpSchema::Cost PointwiseCostInference(
#ifndef CAFFE2_NO_OPERATOR_SCHEMA
// Registers the schema for `name` through a function-local static whose
// destructor removes the entry again, so the global registry never holds a
// dangling schema after the defining module (e.g. the torch CPU DLL) is
// unloaded. (The previous always-registered form was removed: keeping both
// #define lines redefines the macro and is ill-formed.)
#define OPERATOR_SCHEMA(name)                                                 \
  EXPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){};   \
  static OpSchema& RegisterOpSchema_##name() {                                \
    struct OpSchemaRegisterer_##name {                                        \
      OpSchemaRegisterer_##name()                                             \
          : schema(OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__)) {} \
      ~OpSchemaRegisterer_##name() {                                          \
        OpSchemaRegistry::RemoveSchema(#name);                                \
      }                                                                       \
      OpSchema& schema;                                                       \
    };                                                                        \
    static OpSchemaRegisterer_##name op_schema_registerer_##name;             \
    return op_schema_registerer_##name.schema;                                \
  }                                                                           \
  static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED =               \
      &RegisterOpSchema_##name()
#else // CAFFE2_NO_OPERATOR_SCHEMA

View File

@ -17,7 +17,7 @@ namespace cuda {
namespace {
class RegisterInterface {
public:
RegisterInterface() {
RegisterInterface() : canFuseNodeIndex_(RegisterProfilingNode(canFuseNode)) {
auto ptr = getFuserInterface();
ptr->fn_compile_n = &compileCudaFusionGroup;
ptr->fn_run_n_s = &runCudaFusionGroup;
@ -27,6 +27,13 @@ class RegisterInterface {
ptr->fn_profile_n = &shouldProfileNode;
ptr->fn_skip_n = &skipNodeKind;
}
~RegisterInterface() {
DeregisterProfilingNode(canFuseNodeIndex_);
}
private:
int canFuseNodeIndex_;
};
static RegisterInterface register_interface_;

View File

@ -48,6 +48,11 @@ void registerFusionBackend(
getFusionBackends()[backend_type] = std::move(ctor);
}
void deregisterFusionBackend(at::Device::Type backend_type) {
  // Drop the kernel constructor for this device type while holding the
  // registry lock; erasing an unregistered backend is a no-op.
  const std::lock_guard<std::mutex> lock(fusionBackendLock());
  auto& backends = getFusionBackends();
  backends.erase(backend_type);
}
bool hasFusionBackend(at::Device::Type backend_type) {
std::lock_guard<std::mutex> guard(fusionBackendLock());
return getFusionBackends().count(backend_type);

View File

@ -46,13 +46,19 @@ using FusedKernelConstructor = std::function<std::shared_ptr<FusedKernel>(
TORCH_API void registerFusionBackend(
at::Device::Type backend_type,
FusedKernelConstructor ctor);
TORCH_API void deregisterFusionBackend(at::Device::Type backend_type);
TORCH_API bool hasFusionBackend(at::Device::Type backend_type);
struct TORCH_API RegisterFusionBackend {
RegisterFusionBackend(
at::Device::Type backend_type,
FusedKernelConstructor ctor) {
FusedKernelConstructor ctor)
: backend_type(backend_type) {
registerFusionBackend(backend_type, std::move(ctor));
}
~RegisterFusionBackend() {
deregisterFusionBackend(backend_type);
}
at::Device::Type backend_type;
};
} // namespace fuser

View File

@ -22,10 +22,18 @@ struct TORCH_API RegisterOperators {
explicit RegisterOperators(std::vector<c10::optional<Operator>> operators) {
  // Register each present operator, recording its schema first so the
  // destructor can deregister exactly what was registered here.
  for (auto& maybe_op : operators) {
    if (!maybe_op) {
      continue;
    }
    registered_schemas.push_back(maybe_op->schema());
    registerOperator(std::move(maybe_op.value()));
  }
}
~RegisterOperators() {
  // Undo every registration performed by the constructor above.
  for (const c10::FunctionSchema& schema : registered_schemas) {
    deregisterOperator(schema);
  }
}
std::vector<c10::FunctionSchema> registered_schemas;
};
} // namespace jit

View File

@ -22,9 +22,15 @@ class ProfileRegistry {
return &profile_registry_;
}
void registerProfileNode(const std::function<bool(const Node*)>& func) {
int registerProfileNode(const std::function<bool(const Node*)>& func) {
std::lock_guard<std::mutex> guard(mutex_);
registry_funcs_.push_back(func);
registry_funcs_[registry_index_] = func;
return registry_index_++;
}
// Removes the predicate registered under `index` — the handle returned by
// registerProfileNode; erasing an unknown handle is a harmless no-op.
// Thread-safe via mutex_.
void deregisterProfileNode(int index) {
std::lock_guard<std::mutex> guard(mutex_);
registry_funcs_.erase(index);
}
bool shouldProfileNode(const Node* node) {
@ -35,7 +41,7 @@ class ProfileRegistry {
return true;
}
for (const auto& func : registry_funcs_) {
if (func(node)) {
if (func.second(node)) {
return true;
}
}
@ -43,14 +49,19 @@ class ProfileRegistry {
}
private:
// Next handle to hand out from registerProfileNode.
int registry_index_{};
// Registered predicates keyed by handle, so individual entries can be
// removed by deregisterProfileNode. (The stale pre-change vector
// declaration of the same name was dropped: a class may not declare the
// member `registry_funcs_` twice.)
std::unordered_map<int, std::function<bool(const Node*)>> registry_funcs_;
std::mutex mutex_;
};
} // namespace
// Registers a profiling predicate with the global registry; the returned
// handle is passed to DeregisterProfilingNode to remove it again.
// NOTE: the stripped diff had left the stale void-returning wrapper (with
// unbalanced braces) above the new one; this is the intended post-change
// pair of definitions.
int RegisterProfilingNode(const std::function<bool(const Node*)>& func) {
  return ProfileRegistry::getRegistry()->registerProfileNode(func);
}

// Removes the predicate previously registered under `index`.
void DeregisterProfilingNode(int index) {
  ProfileRegistry::getRegistry()->deregisterProfileNode(index);
}
bool ShapeSymbolTable::bindSymbolicShapes(

View File

@ -82,7 +82,8 @@ namespace jit {
using ::c10::TensorTypePtr;
using Dimension = int64_t;
// Registers a node-profiling predicate; returns a handle used by
// DeregisterProfilingNode to remove it (e.g. when a module unloads).
// (The stale void-returning declaration was dropped: overloads may not
// differ only in return type.)
TORCH_API int RegisterProfilingNode(const std::function<bool(const Node*)>&);
TORCH_API void DeregisterProfilingNode(int index);
struct ProfilingRecord;

View File

@ -49,6 +49,10 @@ void RegisterCodeGenList::AddStmtFactoryMethod(
stmt_factory_methods_[name] = stmt_factory_method;
}
void RegisterCodeGenList::RemoveStmtFactoryMethod(const std::string& name) {
  // Detach the statement-factory registered under `name`; unknown names
  // are ignored (erasing an absent key is a no-op either way).
  auto entry = stmt_factory_methods_.find(name);
  if (entry != stmt_factory_methods_.end()) {
    stmt_factory_methods_.erase(entry);
  }
}
std::unique_ptr<CodeGen> CreateCodeGen(
const std::string& name,
StmtPtr stmt,

View File

@ -214,6 +214,7 @@ class RegisterCodeGenList {
TORCH_API void AddStmtFactoryMethod(
const std::string& name,
const StmtFactoryMethod& stmt_factory_method);
TORCH_API void RemoveStmtFactoryMethod(const std::string& name);
std::unordered_map<std::string, StmtFactoryMethod> stmt_factory_methods_;
};
@ -221,7 +222,7 @@ class RegisterCodeGenList {
template <class CodeGenType>
class RegisterCodeGen {
public:
explicit RegisterCodeGen(const std::string& name) {
explicit RegisterCodeGen(const std::string& name) : name_(name) {
RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance();
codegen_list.AddStmtFactoryMethod(
name,
@ -235,6 +236,13 @@ class RegisterCodeGen {
return method;
});
}
// Removes this codegen's factory from the global registry when the static
// registration object is destroyed — per this commit's intent, so the
// registry holds no dangling entry after the defining library is unloaded.
~RegisterCodeGen() {
RegisterCodeGenList::GetInstance().RemoveStmtFactoryMethod(name_);
}
private:
std::string name_;
};
TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(

View File

@ -421,6 +421,23 @@ class class_ : public ::torch::detail::class_base {
registerCustomClassMethod(std::move(method));
return method_val;
}
// Wrapper function to force method deregistration on shutdown.
static void registerCustomClassMethod(std::unique_ptr<jit::Function> method) {
  // Capture the qualified name before ownership is handed off below.
  c10::QualifiedName qualified_name = method->qualname();
  torch::registerCustomClassMethod(std::move(method));
  // Function-local static whose destructor runs at static teardown and
  // removes every method this wrapper registered.
  struct CustomClassMethodDeregistrator {
    ~CustomClassMethodDeregistrator() {
      for (const auto& registered_name : names) {
        torch::deregisterCustomClassMethod(registered_name);
      }
    }
    std::vector<c10::QualifiedName> names;
  };
  static CustomClassMethodDeregistrator deregistrator;
  deregistrator.names.emplace_back(std::move(qualified_name));
}
};
/// make_custom_class() is a convenient way to create an instance of a

View File

@ -217,6 +217,7 @@ class TORCH_API class_base {
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method);
TORCH_API void deregisterCustomClassMethod(const c10::QualifiedName& name);
// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
// the ClassType pointer to the Type that describes that custom class,