mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix crash on unload torch cpu dll (#67632)
Trying to rebase https://github.com/pytorch/pytorch/pull/61290 into latest pytorch:master Pull Request resolved: https://github.com/pytorch/pytorch/pull/67632 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
53f56894ae
commit
a54c9a419e
@ -96,6 +96,16 @@ void registerCustomClassMethod(std::unique_ptr<jit::Function> fn) {
|
||||
customClassMethods().emplace_back(std::move(fn));
|
||||
}
|
||||
|
||||
void deregisterCustomClassMethod(const c10::QualifiedName& name) {
|
||||
auto& methods = customClassMethods();
|
||||
methods.erase(
|
||||
std::remove_if(
|
||||
methods.begin(),
|
||||
methods.end(),
|
||||
[&name](const auto& method) { return method->qualname() == name; }),
|
||||
methods.end());
|
||||
}
|
||||
|
||||
std::vector<c10::FunctionSchema> customClassSchemasForBCCheck() {
|
||||
auto& methods = customClassMethods();
|
||||
return c10::fmap(methods, [](const std::unique_ptr<jit::Function>& fn) {
|
||||
|
||||
@ -466,6 +466,10 @@ class TORCH_API OpSchemaRegistry {
|
||||
static OpSchema&
|
||||
NewSchema(const string& key, const string& file, const int line);
|
||||
|
||||
static void RemoveSchema(const std::string& key) {
|
||||
map().erase(key);
|
||||
}
|
||||
|
||||
static const OpSchema* Schema(const string& key) {
|
||||
auto& m = map();
|
||||
auto it = m.find(key);
|
||||
@ -583,10 +587,22 @@ OpSchema::Cost PointwiseCostInference(
|
||||
|
||||
#ifndef CAFFE2_NO_OPERATOR_SCHEMA
|
||||
|
||||
#define OPERATOR_SCHEMA(name) \
|
||||
EXPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \
|
||||
static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \
|
||||
&OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__)
|
||||
#define OPERATOR_SCHEMA(name) \
|
||||
EXPORT_IF_NOT_MSVC void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(){}; \
|
||||
static OpSchema& RegisterOpSchema_##name() { \
|
||||
struct OpSchemaRegisterer_##name { \
|
||||
OpSchemaRegisterer_##name() \
|
||||
: schema(OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__)) {} \
|
||||
~OpSchemaRegisterer_##name() { \
|
||||
OpSchemaRegistry::RemoveSchema(#name); \
|
||||
} \
|
||||
OpSchema& schema; \
|
||||
}; \
|
||||
static OpSchemaRegisterer_##name op_schema_registerer_##name; \
|
||||
return op_schema_registerer_##name.schema; \
|
||||
} \
|
||||
static OpSchema* C10_ANONYMOUS_VARIABLE(name) CAFFE2_UNUSED = \
|
||||
&RegisterOpSchema_##name()
|
||||
|
||||
#else // CAFFE2_NO_OPERATOR_SCHEMA
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ namespace cuda {
|
||||
namespace {
|
||||
class RegisterInterface {
|
||||
public:
|
||||
RegisterInterface() {
|
||||
RegisterInterface() : canFuseNodeIndex_(RegisterProfilingNode(canFuseNode)) {
|
||||
auto ptr = getFuserInterface();
|
||||
ptr->fn_compile_n = &compileCudaFusionGroup;
|
||||
ptr->fn_run_n_s = &runCudaFusionGroup;
|
||||
@ -27,6 +27,13 @@ class RegisterInterface {
|
||||
ptr->fn_profile_n = &shouldProfileNode;
|
||||
ptr->fn_skip_n = &skipNodeKind;
|
||||
}
|
||||
|
||||
~RegisterInterface() {
|
||||
DeregisterProfilingNode(canFuseNodeIndex_);
|
||||
}
|
||||
|
||||
private:
|
||||
int canFuseNodeIndex_;
|
||||
};
|
||||
|
||||
static RegisterInterface register_interface_;
|
||||
|
||||
@ -48,6 +48,11 @@ void registerFusionBackend(
|
||||
getFusionBackends()[backend_type] = std::move(ctor);
|
||||
}
|
||||
|
||||
void deregisterFusionBackend(at::Device::Type backend_type) {
|
||||
std::lock_guard<std::mutex> guard(fusionBackendLock());
|
||||
getFusionBackends().erase(backend_type);
|
||||
}
|
||||
|
||||
bool hasFusionBackend(at::Device::Type backend_type) {
|
||||
std::lock_guard<std::mutex> guard(fusionBackendLock());
|
||||
return getFusionBackends().count(backend_type);
|
||||
|
||||
@ -46,13 +46,19 @@ using FusedKernelConstructor = std::function<std::shared_ptr<FusedKernel>(
|
||||
TORCH_API void registerFusionBackend(
|
||||
at::Device::Type backend_type,
|
||||
FusedKernelConstructor ctor);
|
||||
TORCH_API void deregisterFusionBackend(at::Device::Type backend_type);
|
||||
TORCH_API bool hasFusionBackend(at::Device::Type backend_type);
|
||||
struct TORCH_API RegisterFusionBackend {
|
||||
RegisterFusionBackend(
|
||||
at::Device::Type backend_type,
|
||||
FusedKernelConstructor ctor) {
|
||||
FusedKernelConstructor ctor)
|
||||
: backend_type(backend_type) {
|
||||
registerFusionBackend(backend_type, std::move(ctor));
|
||||
}
|
||||
~RegisterFusionBackend() {
|
||||
deregisterFusionBackend(backend_type);
|
||||
}
|
||||
at::Device::Type backend_type;
|
||||
};
|
||||
|
||||
} // namespace fuser
|
||||
|
||||
@ -22,10 +22,18 @@ struct TORCH_API RegisterOperators {
|
||||
explicit RegisterOperators(std::vector<c10::optional<Operator>> operators) {
|
||||
for (c10::optional<Operator>& o : operators) {
|
||||
if (o) {
|
||||
registered_schemas.push_back(o.value().schema());
|
||||
registerOperator(std::move(o.value()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~RegisterOperators() {
|
||||
for (const auto& s : registered_schemas)
|
||||
deregisterOperator(s);
|
||||
}
|
||||
|
||||
std::vector<c10::FunctionSchema> registered_schemas;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
||||
@ -22,9 +22,15 @@ class ProfileRegistry {
|
||||
return &profile_registry_;
|
||||
}
|
||||
|
||||
void registerProfileNode(const std::function<bool(const Node*)>& func) {
|
||||
int registerProfileNode(const std::function<bool(const Node*)>& func) {
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
registry_funcs_.push_back(func);
|
||||
registry_funcs_[registry_index_] = func;
|
||||
return registry_index_++;
|
||||
}
|
||||
|
||||
void deregisterProfileNode(int index) {
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
registry_funcs_.erase(index);
|
||||
}
|
||||
|
||||
bool shouldProfileNode(const Node* node) {
|
||||
@ -35,7 +41,7 @@ class ProfileRegistry {
|
||||
return true;
|
||||
}
|
||||
for (const auto& func : registry_funcs_) {
|
||||
if (func(node)) {
|
||||
if (func.second(node)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -43,14 +49,19 @@ class ProfileRegistry {
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::function<bool(const Node*)>> registry_funcs_;
|
||||
int registry_index_{};
|
||||
std::unordered_map<int, std::function<bool(const Node*)>> registry_funcs_;
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void RegisterProfilingNode(const std::function<bool(const Node*)>& func) {
|
||||
ProfileRegistry::getRegistry()->registerProfileNode(func);
|
||||
int RegisterProfilingNode(const std::function<bool(const Node*)>& func) {
|
||||
return ProfileRegistry::getRegistry()->registerProfileNode(func);
|
||||
}
|
||||
|
||||
void DeregisterProfilingNode(int index) {
|
||||
ProfileRegistry::getRegistry()->deregisterProfileNode(index);
|
||||
}
|
||||
|
||||
bool ShapeSymbolTable::bindSymbolicShapes(
|
||||
|
||||
@ -82,7 +82,8 @@ namespace jit {
|
||||
using ::c10::TensorTypePtr;
|
||||
using Dimension = int64_t;
|
||||
|
||||
TORCH_API void RegisterProfilingNode(const std::function<bool(const Node*)>&);
|
||||
TORCH_API int RegisterProfilingNode(const std::function<bool(const Node*)>&);
|
||||
TORCH_API void DeregisterProfilingNode(int index);
|
||||
|
||||
struct ProfilingRecord;
|
||||
|
||||
|
||||
@ -49,6 +49,10 @@ void RegisterCodeGenList::AddStmtFactoryMethod(
|
||||
stmt_factory_methods_[name] = stmt_factory_method;
|
||||
}
|
||||
|
||||
void RegisterCodeGenList::RemoveStmtFactoryMethod(const std::string& name) {
|
||||
stmt_factory_methods_.erase(name);
|
||||
}
|
||||
|
||||
std::unique_ptr<CodeGen> CreateCodeGen(
|
||||
const std::string& name,
|
||||
StmtPtr stmt,
|
||||
|
||||
@ -214,6 +214,7 @@ class RegisterCodeGenList {
|
||||
TORCH_API void AddStmtFactoryMethod(
|
||||
const std::string& name,
|
||||
const StmtFactoryMethod& stmt_factory_method);
|
||||
TORCH_API void RemoveStmtFactoryMethod(const std::string& name);
|
||||
|
||||
std::unordered_map<std::string, StmtFactoryMethod> stmt_factory_methods_;
|
||||
};
|
||||
@ -221,7 +222,7 @@ class RegisterCodeGenList {
|
||||
template <class CodeGenType>
|
||||
class RegisterCodeGen {
|
||||
public:
|
||||
explicit RegisterCodeGen(const std::string& name) {
|
||||
explicit RegisterCodeGen(const std::string& name) : name_(name) {
|
||||
RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance();
|
||||
codegen_list.AddStmtFactoryMethod(
|
||||
name,
|
||||
@ -235,6 +236,13 @@ class RegisterCodeGen {
|
||||
return method;
|
||||
});
|
||||
}
|
||||
|
||||
~RegisterCodeGen() {
|
||||
RegisterCodeGenList::GetInstance().RemoveStmtFactoryMethod(name_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
TORCH_API std::unique_ptr<CodeGen> CreateCodeGen(
|
||||
|
||||
@ -421,6 +421,23 @@ class class_ : public ::torch::detail::class_base {
|
||||
registerCustomClassMethod(std::move(method));
|
||||
return method_val;
|
||||
}
|
||||
|
||||
// Wrapper function to force method deregistration on shutdown.
|
||||
static void registerCustomClassMethod(std::unique_ptr<jit::Function> method) {
|
||||
c10::QualifiedName method_name = method->qualname();
|
||||
torch::registerCustomClassMethod(std::move(method));
|
||||
|
||||
struct CustomClassMethodDeregistrator {
|
||||
~CustomClassMethodDeregistrator() {
|
||||
for (const auto& name : names)
|
||||
torch::deregisterCustomClassMethod(name);
|
||||
}
|
||||
std::vector<c10::QualifiedName> names;
|
||||
};
|
||||
|
||||
static CustomClassMethodDeregistrator deregistator;
|
||||
deregistator.names.emplace_back(std::move(method_name));
|
||||
}
|
||||
};
|
||||
|
||||
/// make_custom_class() is a convenient way to create an instance of a
|
||||
|
||||
@ -217,6 +217,7 @@ class TORCH_API class_base {
|
||||
|
||||
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
|
||||
TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method);
|
||||
TORCH_API void deregisterCustomClassMethod(const c10::QualifiedName& name);
|
||||
|
||||
// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
|
||||
// the ClassType pointer to the Type that describes that custom class,
|
||||
|
||||
Reference in New Issue
Block a user