mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[reapply][JIT] Namespaces for TorchBind (#35254)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35254 Reapply D20541090 with some BC fixes ghstack-source-id: 100733987 Test Plan: buck test mode/dev-nosan //caffe2/torch/fb/predictor/model_repo/tests:ai_infra_representative_model_shard_6_test -- 'RepresentativeModelTest\/ShardedRepresentativeModelTest\.RunModel\/0' Reviewed By: zdevito Differential Revision: D20607111 fbshipit-source-id: 80f148d860571208c93e9308128cd480ff089f74
This commit is contained in:
committed by
Facebook GitHub Bot
parent
17068ba467
commit
618c6214aa
@ -20,6 +20,10 @@ void registerCustomClass(at::ClassTypePtr class_type) {
|
||||
}
|
||||
|
||||
at::ClassTypePtr getCustomClass(const std::string& name) {
|
||||
// BC hack so we can upgrade a binary internally
|
||||
if (name == "__torch__.torch.classes.SentencePiece") {
|
||||
return getCustomClass("__torch__.torch.classes.fb.SentencePiece");
|
||||
}
|
||||
return customClasses().count(name) ? customClasses()[name] : nullptr;
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,7 @@ using internal::convolution2d::createConv2dClampPrePackOpContext;
|
||||
namespace {
|
||||
torch::jit::class_<LinearOpContext> register_packed_linear_op_context_class() {
|
||||
static auto register_linear_op_context_class =
|
||||
torch::jit::class_<LinearOpContext>("LinearOpContext")
|
||||
torch::jit::class_<LinearOpContext>("xnnpack", "LinearOpContext")
|
||||
.def_pickle(
|
||||
[](const c10::intrusive_ptr<LinearOpContext>& op_context)
|
||||
-> SerializationTypeLinearPrePack { // __getstate__
|
||||
@ -36,7 +36,7 @@ torch::jit::class_<LinearOpContext> register_packed_linear_op_context_class() {
|
||||
|
||||
torch::jit::class_<Conv2dOpContext> register_packed_conv2d_op_context_class() {
|
||||
static auto register_conv2d_op_context_class =
|
||||
torch::jit::class_<Conv2dOpContext>("Conv2dOpContext")
|
||||
torch::jit::class_<Conv2dOpContext>("xnnpack", "Conv2dOpContext")
|
||||
.def_pickle(
|
||||
[](const c10::intrusive_ptr<Conv2dOpContext>& op_context)
|
||||
-> SerializationTypeConv2dPrePack { // __getstate__
|
||||
@ -67,14 +67,14 @@ static auto registry =
|
||||
torch::RegisterOperators()
|
||||
.op("prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None, "
|
||||
"float? output_min=None, float? output_max=None) "
|
||||
"-> __torch__.torch.classes.LinearOpContext",
|
||||
"-> __torch__.torch.classes.xnnpack.LinearOpContext",
|
||||
torch::RegisterOperators::options()
|
||||
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
|
||||
.kernel<decltype(createLinearClampPrePackOpContext),
|
||||
createLinearClampPrePackOpContext>(
|
||||
DispatchKey::CPUTensorId))
|
||||
.op("prepacked::linear_clamp_run(Tensor X,"
|
||||
" __torch__.torch.classes.LinearOpContext W_prepack) -> Tensor Y",
|
||||
" __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y",
|
||||
torch::RegisterOperators::options()
|
||||
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
|
||||
.kernel<internal::linear::LinearClampRun>(
|
||||
@ -82,14 +82,14 @@ static auto registry =
|
||||
.op("prepacked::conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, "
|
||||
"int[2] padding, int[2] dilation, int groups, "
|
||||
"float? output_min=None, float? output_max=None) "
|
||||
"-> __torch__.torch.classes.Conv2dOpContext",
|
||||
"-> __torch__.torch.classes.xnnpack.Conv2dOpContext",
|
||||
torch::RegisterOperators::options()
|
||||
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
|
||||
.kernel<decltype(createConv2dClampPrePackOpContext),
|
||||
createConv2dClampPrePackOpContext>(
|
||||
DispatchKey::CPUTensorId))
|
||||
.op("prepacked::conv2d_clamp_run(Tensor X, "
|
||||
"__torch__.torch.classes.Conv2dOpContext W_prepack) -> Tensor Y",
|
||||
"__torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y",
|
||||
torch::RegisterOperators::options()
|
||||
.aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
|
||||
.kernel<internal::convolution2d::Conv2dClampRun>(
|
||||
|
@ -138,7 +138,7 @@ void testClassDerive() {
|
||||
static const auto torchbindSrc = R"JIT(
|
||||
class FooBar1234(Module):
|
||||
__parameters__ = []
|
||||
f : __torch__.torch.classes._TorchScriptTesting_StackString
|
||||
f : __torch__.torch.classes._TorchScriptTesting._StackString
|
||||
training : bool
|
||||
def forward(self: __torch__.FooBar1234) -> str:
|
||||
return (self.f).top()
|
||||
|
@ -66,7 +66,7 @@ struct PickleTester : torch::CustomClassHolder {
|
||||
std::vector<int64_t> vals;
|
||||
};
|
||||
|
||||
static auto test = torch::class_<Foo>("_TorchScriptTesting_Foo")
|
||||
static auto test = torch::class_<Foo>("_TorchScriptTesting", "_Foo")
|
||||
.def(torch::init<int64_t, int64_t>())
|
||||
// .def(torch::init<>())
|
||||
.def("info", &Foo::info)
|
||||
@ -75,7 +75,9 @@ static auto test = torch::class_<Foo>("_TorchScriptTesting_Foo")
|
||||
.def("combine", &Foo::combine);
|
||||
|
||||
static auto testStack =
|
||||
torch::class_<MyStackClass<std::string>>("_TorchScriptTesting_StackString")
|
||||
torch::class_<MyStackClass<std::string>>(
|
||||
"_TorchScriptTesting",
|
||||
"_StackString")
|
||||
.def(torch::init<std::vector<std::string>>())
|
||||
.def("push", &MyStackClass<std::string>::push)
|
||||
.def("pop", &MyStackClass<std::string>::pop)
|
||||
@ -101,7 +103,7 @@ static auto testStack =
|
||||
// clang-format on
|
||||
|
||||
static auto testPickle =
|
||||
torch::class_<PickleTester>("_TorchScriptTesting_PickleTester")
|
||||
torch::class_<PickleTester>("_TorchScriptTesting", "_PickleTester")
|
||||
.def(torch::init<std::vector<int64_t>>())
|
||||
.def_pickle(
|
||||
[](c10::intrusive_ptr<PickleTester> self) { // __getstate__
|
||||
@ -127,10 +129,10 @@ at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) {
|
||||
|
||||
torch::RegisterOperators& register_take_instance() {
|
||||
static auto instance_registry = torch::RegisterOperators().op(
|
||||
torch::RegisterOperators::options()
|
||||
.schema(
|
||||
"_TorchScriptTesting::take_an_instance(__torch__.torch.classes._TorchScriptTesting_PickleTester x) -> Tensor Y")
|
||||
.catchAllKernel<decltype(take_an_instance), &take_an_instance>());
|
||||
torch::RegisterOperators::options()
|
||||
.schema(
|
||||
"_TorchScriptTesting::take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y")
|
||||
.catchAllKernel<decltype(take_an_instance), &take_an_instance>());
|
||||
return instance_registry;
|
||||
}
|
||||
|
||||
@ -146,7 +148,7 @@ void testTorchbindIValueAPI() {
|
||||
auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
|
||||
std::vector<std::string>{"foo", "bar"});
|
||||
m.define(R"(
|
||||
def forward(self, s : __torch__.torch.classes._TorchScriptTesting_StackString):
|
||||
def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString):
|
||||
return s.pop(), s
|
||||
)");
|
||||
|
||||
|
@ -343,7 +343,8 @@ void testLiteInterpreterBuiltinFunction() {
|
||||
namespace {
|
||||
static auto reg =
|
||||
torch::jit::class_<TorchBindLiteInterpreterTestStruct>(
|
||||
"_TorchScriptTesting_LiteInterpreterTest")
|
||||
"_TorchScriptTesting",
|
||||
"_LiteInterpreterTest")
|
||||
.def("get", &TorchBindLiteInterpreterTestStruct::get)
|
||||
.def_pickle(
|
||||
// __getattr__
|
||||
|
@ -35,19 +35,19 @@ class TestCustomOperators(unittest.TestCase):
|
||||
|
||||
def test_no_return_class(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting_Foo(5, 3)
|
||||
val = torch.classes._TorchScriptTesting._Foo(5, 3)
|
||||
return val.info()
|
||||
self.assertEqual(*test_equality(f, lambda x: x))
|
||||
|
||||
def test_constructor_with_args(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting_Foo(5, 3)
|
||||
val = torch.classes._TorchScriptTesting._Foo(5, 3)
|
||||
return val
|
||||
self.assertEqual(*test_equality(f, lambda x: x.info()))
|
||||
|
||||
def test_function_call_with_args(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting_Foo(5, 3)
|
||||
val = torch.classes._TorchScriptTesting._Foo(5, 3)
|
||||
val.increment(1)
|
||||
return val
|
||||
|
||||
@ -55,7 +55,7 @@ class TestCustomOperators(unittest.TestCase):
|
||||
|
||||
def test_function_method_wrong_type(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting_Foo(5, 3)
|
||||
val = torch.classes._TorchScriptTesting._Foo(5, 3)
|
||||
val.increment("asdf")
|
||||
return val
|
||||
|
||||
@ -65,8 +65,8 @@ class TestCustomOperators(unittest.TestCase):
|
||||
@unittest.skip("We currently don't support passing custom classes to custom methods.")
|
||||
def test_input_class_type(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting_Foo(1, 2)
|
||||
val2 = torch.classes._TorchScriptTesting_Foo(2, 3)
|
||||
val = torch.classes._TorchScriptTesting._Foo(1, 2)
|
||||
val2 = torch.classes._TorchScriptTesting._Foo(2, 3)
|
||||
val.combine(val2)
|
||||
return val
|
||||
|
||||
@ -74,14 +74,14 @@ class TestCustomOperators(unittest.TestCase):
|
||||
|
||||
def test_stack_string(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting_StackString(["asdf", "bruh"])
|
||||
val = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
|
||||
return val.pop()
|
||||
self.assertEqual(*test_equality(f, lambda x: x))
|
||||
|
||||
def test_stack_push_pop(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting_StackString(["asdf", "bruh"])
|
||||
val2 = torch.classes._TorchScriptTesting_StackString(["111", "222"])
|
||||
val = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
|
||||
val2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
|
||||
val.push(val2.pop())
|
||||
return val.pop() + val2.pop()
|
||||
self.assertEqual(*test_equality(f, lambda x: x))
|
||||
|
@ -5551,23 +5551,23 @@ def foo(x):
|
||||
return (cmp_key(obj1), cmp_key(obj2))
|
||||
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting_Foo(5, 3)
|
||||
val = torch.classes._TorchScriptTesting._Foo(5, 3)
|
||||
val.increment(1)
|
||||
return val
|
||||
test_equality(f, lambda x: x)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
|
||||
val = torch.classes._TorchScriptTesting_Foo(5, 3)
|
||||
val = torch.classes._TorchScriptTesting._Foo(5, 3)
|
||||
val.increment('foo')
|
||||
|
||||
def f():
|
||||
ss = torch.classes._TorchScriptTesting_StackString(["asdf", "bruh"])
|
||||
ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
|
||||
return ss.pop()
|
||||
test_equality(f, lambda x: x)
|
||||
|
||||
def f():
|
||||
ss1 = torch.classes._TorchScriptTesting_StackString(["asdf", "bruh"])
|
||||
ss2 = torch.classes._TorchScriptTesting_StackString(["111", "222"])
|
||||
ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
|
||||
ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
|
||||
ss1.push(ss2.pop())
|
||||
return ss1.pop() + ss2.pop()
|
||||
test_equality(f, lambda x: x)
|
||||
@ -5575,14 +5575,14 @@ def foo(x):
|
||||
@skipIfRocm
|
||||
def test_torchbind_take_as_arg(self):
|
||||
global StackString # see [local resolution in python]
|
||||
StackString = torch.classes._TorchScriptTesting_StackString
|
||||
StackString = torch.classes._TorchScriptTesting._StackString
|
||||
|
||||
def foo(stackstring):
|
||||
# type: (StackString)
|
||||
stackstring.push("lel")
|
||||
return stackstring
|
||||
|
||||
script_input = torch.classes._TorchScriptTesting_StackString([])
|
||||
script_input = torch.classes._TorchScriptTesting._StackString([])
|
||||
scripted = torch.jit.script(foo)
|
||||
script_output = scripted(script_input)
|
||||
self.assertEqual(script_output.pop(), "lel")
|
||||
@ -5590,7 +5590,7 @@ def foo(x):
|
||||
@skipIfRocm
|
||||
def test_torchbind_return_instance(self):
|
||||
def foo():
|
||||
ss = torch.classes._TorchScriptTesting_StackString(["hi", "mom"])
|
||||
ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
|
||||
return ss
|
||||
|
||||
scripted = torch.jit.script(foo)
|
||||
@ -5606,7 +5606,7 @@ def foo(x):
|
||||
@skipIfRocm
|
||||
def test_torchbind_return_instance_from_method(self):
|
||||
def foo():
|
||||
ss = torch.classes._TorchScriptTesting_StackString(["hi", "mom"])
|
||||
ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
|
||||
clone = ss.clone()
|
||||
ss.pop()
|
||||
return ss, clone
|
||||
@ -5620,8 +5620,8 @@ def foo(x):
|
||||
@skipIfRocm
|
||||
def test_torchbind_take_instance_as_method_arg(self):
|
||||
def foo():
|
||||
ss = torch.classes._TorchScriptTesting_StackString(["mom"])
|
||||
ss2 = torch.classes._TorchScriptTesting_StackString(["hi"])
|
||||
ss = torch.classes._TorchScriptTesting._StackString(["mom"])
|
||||
ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
|
||||
ss.merge(ss2)
|
||||
return ss
|
||||
|
||||
@ -5633,7 +5633,7 @@ def foo(x):
|
||||
@skipIfRocm
|
||||
def test_torchbind_return_tuple(self):
|
||||
def f():
|
||||
val = torch.classes._TorchScriptTesting_StackString(["3", "5"])
|
||||
val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
|
||||
return val.return_a_tuple()
|
||||
|
||||
scripted = torch.jit.script(f)
|
||||
@ -5643,8 +5643,8 @@ def foo(x):
|
||||
@skipIfRocm
|
||||
def test_torchbind_save_load(self):
|
||||
def foo():
|
||||
ss = torch.classes._TorchScriptTesting_StackString(["mom"])
|
||||
ss2 = torch.classes._TorchScriptTesting_StackString(["hi"])
|
||||
ss = torch.classes._TorchScriptTesting._StackString(["mom"])
|
||||
ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
|
||||
ss.merge(ss2)
|
||||
return ss
|
||||
|
||||
@ -5673,7 +5673,7 @@ def foo(x):
|
||||
@skipIfRocm
|
||||
def test_torchbind_lambda_method(self):
|
||||
def foo():
|
||||
ss = torch.classes._TorchScriptTesting_StackString(["mom"])
|
||||
ss = torch.classes._TorchScriptTesting._StackString(["mom"])
|
||||
return ss.top()
|
||||
|
||||
scripted = torch.jit.script(foo)
|
||||
@ -5684,7 +5684,7 @@ def foo(x):
|
||||
class FooBar1234(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(FooBar1234, self).__init__()
|
||||
self.f = torch.classes._TorchScriptTesting_StackString(["3", "4"])
|
||||
self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
|
||||
|
||||
def forward(self):
|
||||
return self.f.top()
|
||||
@ -5701,7 +5701,7 @@ def foo(x):
|
||||
class FooBar4321(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(FooBar4321, self).__init__()
|
||||
self.f = torch.classes._TorchScriptTesting_PickleTester([3, 4])
|
||||
self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
|
||||
|
||||
def forward(self):
|
||||
return self.f.top()
|
||||
@ -5723,7 +5723,7 @@ def foo(x):
|
||||
class TryTracing(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(TryTracing, self).__init__()
|
||||
self.f = torch.classes._TorchScriptTesting_PickleTester([3, 4])
|
||||
self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
|
||||
|
||||
def forward(self):
|
||||
return torch.ops._TorchScriptTesting.take_an_instance(self.f)
|
||||
@ -5736,7 +5736,7 @@ def foo(x):
|
||||
class TryTracingNest(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(TryTracingNest, self).__init__()
|
||||
self.f = torch.classes._TorchScriptTesting_PickleTester([3, 4])
|
||||
self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
|
||||
|
||||
class TryTracing123(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -5751,7 +5751,7 @@ def foo(x):
|
||||
|
||||
@skipIfRocm
|
||||
def test_torchbind_pickle_serialization(self):
|
||||
nt = torch.classes._TorchScriptTesting_PickleTester([3, 4])
|
||||
nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
|
||||
b = io.BytesIO()
|
||||
torch.save(nt, b)
|
||||
b.seek(0)
|
||||
@ -5761,8 +5761,8 @@ def foo(x):
|
||||
|
||||
@skipIfRocm
|
||||
def test_torchbind_instantiate_missing_class(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class IDontExist but it does not exist!'):
|
||||
torch.classes.IDontExist(3, 4, 5)
|
||||
with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class foo.IDontExist but it does not exist!'):
|
||||
torch.classes.foo.IDontExist(3, 4, 5)
|
||||
|
||||
def test_jitter_bug(self):
|
||||
@torch.jit.script
|
||||
|
@ -1,15 +1,25 @@
|
||||
import types
|
||||
import torch._C
|
||||
|
||||
class _ClassNamespace(types.ModuleType):
|
||||
def __init__(self, name):
|
||||
super(_ClassNamespace, self).__init__('torch.classes' + name)
|
||||
self.name = name
|
||||
|
||||
def __getattr__(self, attr):
|
||||
proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
|
||||
if proxy is None:
|
||||
raise RuntimeError('Class {}.{} not registered!'.format(self.name, attr))
|
||||
return proxy
|
||||
|
||||
class _Classes(types.ModuleType):
|
||||
def __init__(self):
|
||||
super(_Classes, self).__init__('torch.classes')
|
||||
|
||||
def __getattr__(self, attr):
|
||||
proxy = torch._C._get_custom_class_python_wrapper(attr)
|
||||
if proxy is None:
|
||||
raise RuntimeError('Class {} not registered!'.format(attr))
|
||||
return proxy
|
||||
def __getattr__(self, name):
|
||||
namespace = _ClassNamespace(name)
|
||||
setattr(self, name, namespace)
|
||||
return namespace
|
||||
|
||||
@property
|
||||
def loaded_libraries(self):
|
||||
|
@ -244,12 +244,16 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
|
||||
<< "Expected classes namespace but got " << classes_tok.text();
|
||||
}
|
||||
L.expect('.');
|
||||
auto ns_tok = L.expect(TK_IDENT);
|
||||
L.expect('.');
|
||||
auto class_tok = L.expect(TK_IDENT);
|
||||
value = getCustomClass(
|
||||
std::string("__torch__.torch.classes.") + class_tok.text());
|
||||
std::string("__torch__.torch.classes.") + ns_tok.text() + "." +
|
||||
class_tok.text());
|
||||
if (!value) {
|
||||
throw ErrorReport(class_tok.range)
|
||||
<< "Unknown custom class type " << class_tok.text()
|
||||
<< "Unknown custom class type "
|
||||
<< ns_tok.text() + "." + class_tok.text()
|
||||
<< ". Please ensure it is registered.";
|
||||
}
|
||||
} else {
|
||||
|
@ -28,20 +28,23 @@ void initPythonCustomClassBindings(PyObject* module) {
|
||||
// code object in turn calls __init__. Rather than calling __init__
|
||||
// directly, we need a wrapper that at least returns the instance
|
||||
// rather than the None return value from __init__
|
||||
m.def("_get_custom_class_python_wrapper", [](const std::string& qualname) {
|
||||
std::string full_qualname = "__torch__.torch.classes." + qualname;
|
||||
auto named_type = getCustomClass(full_qualname);
|
||||
TORCH_CHECK(
|
||||
named_type,
|
||||
"Tried to instantiate class ",
|
||||
qualname,
|
||||
" but it"
|
||||
" does not exist! Ensure that it is registered via torch::jit"
|
||||
"::class_");
|
||||
c10::ClassTypePtr class_type = named_type->cast<ClassType>();
|
||||
return ScriptClass(c10::StrongTypePtr(
|
||||
std::shared_ptr<CompilationUnit>(), std::move(class_type)));
|
||||
});
|
||||
m.def(
|
||||
"_get_custom_class_python_wrapper",
|
||||
[](const std::string& ns, const std::string& qualname) {
|
||||
std::string full_qualname =
|
||||
"__torch__.torch.classes." + ns + "." + qualname;
|
||||
auto named_type = getCustomClass(full_qualname);
|
||||
TORCH_CHECK(
|
||||
named_type,
|
||||
"Tried to instantiate class ",
|
||||
ns + "." + qualname,
|
||||
" but it"
|
||||
" does not exist! Ensure that it is registered via torch::jit"
|
||||
"::class_");
|
||||
c10::ClassTypePtr class_type = named_type->cast<ClassType>();
|
||||
return ScriptClass(c10::StrongTypePtr(
|
||||
std::shared_ptr<CompilationUnit>(), std::move(class_type)));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
|
@ -34,7 +34,7 @@ detail::types<void, Types...> init() {
|
||||
/// calls needed. For example, to register a class named Foo, you might
|
||||
/// create a global variable like so:
|
||||
///
|
||||
/// static auto register_foo = torch::class_<Foo>("Foo")
|
||||
/// static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
|
||||
/// .def("myMethod", &Foo::myMethod)
|
||||
/// .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) {
|
||||
/// // Do something with `self`
|
||||
@ -51,12 +51,16 @@ class class_ {
|
||||
|
||||
public:
|
||||
/// This constructor actually registers the class type.
|
||||
/// String argument `className_` is the name you would like to
|
||||
/// String argument `namespaceName` is an identifier for the
|
||||
/// namespace you would like this class to appear in.
|
||||
/// String argument `className` is the name you would like to
|
||||
/// see this class exposed as in Python and TorchScript. For example, if
|
||||
/// you pass in "MyStack" here, the class will appear as
|
||||
/// `torch.classes.MyStack` in both Python and TorchScript.
|
||||
explicit class_(const std::string& className) : className(std::move(className)) {
|
||||
qualClassName = topModule + "." + parentModule + "." + className;
|
||||
/// you pass `foo` as the namespace name and `Bar` as the className, the
|
||||
/// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
|
||||
explicit class_(const std::string& namespaceName, const std::string& className) {
|
||||
detail::checkValidIdent(namespaceName, "Namespace name");
|
||||
detail::checkValidIdent(className, "Class name");
|
||||
qualClassName = std::string("__torch__.torch.classes.") + namespaceName + "." + className;
|
||||
|
||||
classTypePtr = at::ClassType::create(
|
||||
c10::QualifiedName(qualClassName),
|
||||
@ -230,12 +234,8 @@ class class_ {
|
||||
classTypePtr->addMethod(method.get());
|
||||
}
|
||||
|
||||
std::string className;
|
||||
std::string qualClassName;
|
||||
at::ClassTypePtr classTypePtr;
|
||||
|
||||
const std::string parentModule = "classes";
|
||||
const std::string topModule = "__torch__.torch";
|
||||
};
|
||||
|
||||
/// make_custom_class() is a convenient way to create an instance of a registered
|
||||
|
@ -119,6 +119,20 @@ struct BoxedProxy<void, Func> {
|
||||
}
|
||||
};
|
||||
|
||||
inline bool validIdent(size_t i, char n) {
|
||||
return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
|
||||
}
|
||||
|
||||
inline void checkValidIdent(const std::string& str, const char *type) {
|
||||
for (size_t i = 0; i < str.size(); ++i) {
|
||||
TORCH_CHECK(validIdent(i, str[i]),
|
||||
type,
|
||||
" must be a valid Python/C++ identifier."
|
||||
" Character '", str[i], "' at index ",
|
||||
i, " is illegal.");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
|
||||
|
Reference in New Issue
Block a user