[reapply][JIT] Namespaces for TorchBind (#35254)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35254

Reapply D20541090, which moves TorchBind custom classes into two-level namespaces: registration changes from `torch::class_<T>("ClassName")` to `torch::class_<T>("namespace", "ClassName")`, and classes are exposed as `torch.classes.namespace.ClassName` in Python and TorchScript. This reapplication includes some backward-compatibility (BC) fixes, notably an alias in `getCustomClass` for an already-serialized class name.
ghstack-source-id: 100733987
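
At a glance, the API change being reapplied: TorchBind registration now takes a
(namespace, className) pair instead of a single flat name. A minimal sketch,
mirroring the updated doc comment in torch/custom_class.h below (the `Foo`
struct here is illustrative, not part of this diff):

#include <torch/custom_class.h>

// Illustrative stand-in for any class deriving from torch::CustomClassHolder.
struct Foo : torch::CustomClassHolder {
  int64_t myMethod() { return 42; }
};

// Before this diff: torch::class_<Foo>("Foo")      -> torch.classes.Foo
// After this diff:  (namespace, className) pair    -> torch.classes.myclasses.Foo
static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
    .def("myMethod", &Foo::myMethod);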

Test Plan: buck test mode/dev-nosan //caffe2/torch/fb/predictor/model_repo/tests:ai_infra_representative_model_shard_6_test -- 'RepresentativeModelTest\/ShardedRepresentativeModelTest\.RunModel\/0'

Reviewed By: zdevito

Differential Revision: D20607111

fbshipit-source-id: 80f148d860571208c93e9308128cd480ff089f74
Author: James Reed
Date: 2020-03-24 00:34:43 -07:00
Committed by: Facebook GitHub Bot
Parent: 17068ba467
Commit: 618c6214aa

12 changed files with 116 additions and 78 deletions

View File

@@ -20,6 +20,10 @@ void registerCustomClass(at::ClassTypePtr class_type) {
 }
 
 at::ClassTypePtr getCustomClass(const std::string& name) {
+  // BC hack so we can upgrade a binary internally
+  if (name == "__torch__.torch.classes.SentencePiece") {
+    return getCustomClass("__torch__.torch.classes.fb.SentencePiece");
+  }
   return customClasses().count(name) ? customClasses()[name] : nullptr;
 }
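
The hunk above hard-codes a one-off backward-compatibility alias so binaries
that serialized the old flat name keep resolving after SentencePiece moved into
the `fb` namespace. If more legacy names needed this, the same idea generalizes
to a lookup table; a hypothetical sketch (this alias map is not part of the
diff):

#include <string>
#include <unordered_map>

// Hypothetical generalization: route legacy flat names to their new
// namespaced registrations before consulting the registry.
static const std::unordered_map<std::string, std::string>& legacyAliases() {
  static const std::unordered_map<std::string, std::string> aliases = {
      {"__torch__.torch.classes.SentencePiece",
       "__torch__.torch.classes.fb.SentencePiece"},
  };
  return aliases;
}

at::ClassTypePtr getCustomClassWithAliases(const std::string& name) {
  auto it = legacyAliases().find(name);
  return it != legacyAliases().end() ? getCustomClass(it->second)
                                     : getCustomClass(name);
}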

View File

@@ -17,7 +17,7 @@ using internal::convolution2d::createConv2dClampPrePackOpContext;
 namespace {
 
 torch::jit::class_<LinearOpContext> register_packed_linear_op_context_class() {
   static auto register_linear_op_context_class =
-      torch::jit::class_<LinearOpContext>("LinearOpContext")
+      torch::jit::class_<LinearOpContext>("xnnpack", "LinearOpContext")
           .def_pickle(
               [](const c10::intrusive_ptr<LinearOpContext>& op_context)
                   -> SerializationTypeLinearPrePack { // __getstate__
@@ -36,7 +36,7 @@ torch::jit::class_<LinearOpContext> register_packed_linear_op_context_class() {
 torch::jit::class_<Conv2dOpContext> register_packed_conv2d_op_context_class() {
   static auto register_conv2d_op_context_class =
-      torch::jit::class_<Conv2dOpContext>("Conv2dOpContext")
+      torch::jit::class_<Conv2dOpContext>("xnnpack", "Conv2dOpContext")
           .def_pickle(
               [](const c10::intrusive_ptr<Conv2dOpContext>& op_context)
                   -> SerializationTypeConv2dPrePack { // __getstate__
@@ -67,14 +67,14 @@ static auto registry =
     torch::RegisterOperators()
         .op("prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None, "
             "float? output_min=None, float? output_max=None) "
-            "-> __torch__.torch.classes.LinearOpContext",
+            "-> __torch__.torch.classes.xnnpack.LinearOpContext",
             torch::RegisterOperators::options()
                 .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
                 .kernel<decltype(createLinearClampPrePackOpContext),
                         createLinearClampPrePackOpContext>(
                     DispatchKey::CPUTensorId))
         .op("prepacked::linear_clamp_run(Tensor X,"
-            " __torch__.torch.classes.LinearOpContext W_prepack) -> Tensor Y",
+            " __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y",
             torch::RegisterOperators::options()
                 .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
                 .kernel<internal::linear::LinearClampRun>(
@@ -82,14 +82,14 @@ static auto registry =
         .op("prepacked::conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, "
             "int[2] padding, int[2] dilation, int groups, "
             "float? output_min=None, float? output_max=None) "
-            "-> __torch__.torch.classes.Conv2dOpContext",
+            "-> __torch__.torch.classes.xnnpack.Conv2dOpContext",
             torch::RegisterOperators::options()
                 .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
                 .kernel<decltype(createConv2dClampPrePackOpContext),
                         createConv2dClampPrePackOpContext>(
                     DispatchKey::CPUTensorId))
         .op("prepacked::conv2d_clamp_run(Tensor X, "
-            "__torch__.torch.classes.Conv2dOpContext W_prepack) -> Tensor Y",
+            "__torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y",
             torch::RegisterOperators::options()
                 .aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)
                 .kernel<internal::convolution2d::Conv2dClampRun>(
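
Note how the schema strings spell the class type: operators that take or return
a TorchBind class reference it by the fully qualified
__torch__.torch.classes.<namespace>.<ClassName> path, which is exactly the form
the SchemaTypeParser change further down learns to parse. A sketch of the same
pattern with illustrative names (reusing the `Foo` sketch from the summary;
`myops::consume` is hypothetical):

// Hypothetical operator whose schema mentions a namespaced TorchBind class,
// following the prepacked:: registrations above.
at::Tensor consume(const c10::intrusive_ptr<Foo>& x) {
  return at::zeros({1});
}

static auto consume_registry = torch::RegisterOperators().op(
    torch::RegisterOperators::options()
        .schema(
            "myops::consume(__torch__.torch.classes.myclasses.Foo x) -> Tensor")
        .catchAllKernel<decltype(consume), &consume>());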

View File

@@ -138,7 +138,7 @@ void testClassDerive() {
 static const auto torchbindSrc = R"JIT(
 class FooBar1234(Module):
   __parameters__ = []
-  f : __torch__.torch.classes._TorchScriptTesting_StackString
+  f : __torch__.torch.classes._TorchScriptTesting._StackString
   training : bool
   def forward(self: __torch__.FooBar1234) -> str:
     return (self.f).top()

View File

@@ -66,7 +66,7 @@ struct PickleTester : torch::CustomClassHolder {
   std::vector<int64_t> vals;
 };
 
-static auto test = torch::class_<Foo>("_TorchScriptTesting_Foo")
+static auto test = torch::class_<Foo>("_TorchScriptTesting", "_Foo")
                        .def(torch::init<int64_t, int64_t>())
                        // .def(torch::init<>())
                        .def("info", &Foo::info)
@@ -75,7 +75,9 @@ static auto test = torch::class_<Foo>("_TorchScriptTesting_Foo")
                        .def("combine", &Foo::combine);
 
 static auto testStack =
-    torch::class_<MyStackClass<std::string>>("_TorchScriptTesting_StackString")
+    torch::class_<MyStackClass<std::string>>(
+        "_TorchScriptTesting",
+        "_StackString")
         .def(torch::init<std::vector<std::string>>())
         .def("push", &MyStackClass<std::string>::push)
         .def("pop", &MyStackClass<std::string>::pop)
@@ -101,7 +103,7 @@ static auto testStack =
 // clang-format on
 
 static auto testPickle =
-    torch::class_<PickleTester>("_TorchScriptTesting_PickleTester")
+    torch::class_<PickleTester>("_TorchScriptTesting", "_PickleTester")
         .def(torch::init<std::vector<int64_t>>())
         .def_pickle(
             [](c10::intrusive_ptr<PickleTester> self) { // __getstate__
@@ -127,10 +129,10 @@ at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) {
 torch::RegisterOperators& register_take_instance() {
   static auto instance_registry = torch::RegisterOperators().op(
-      torch::RegisterOperators::options()
-          .schema(
-              "_TorchScriptTesting::take_an_instance(__torch__.torch.classes._TorchScriptTesting_PickleTester x) -> Tensor Y")
-          .catchAllKernel<decltype(take_an_instance), &take_an_instance>());
+      torch::RegisterOperators::options()
+          .schema(
+              "_TorchScriptTesting::take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y")
+          .catchAllKernel<decltype(take_an_instance), &take_an_instance>());
   return instance_registry;
 }
@@ -146,7 +148,7 @@ void testTorchbindIValueAPI() {
   auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
       std::vector<std::string>{"foo", "bar"});
   m.define(R"(
-    def forward(self, s : __torch__.torch.classes._TorchScriptTesting_StackString):
+    def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString):
       return s.pop(), s
   )");
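
The testPickle hunk above shows only the __getstate__ half of def_pickle; the
full API takes a __getstate__/__setstate__ lambda pair. A minimal sketch with
an illustrative class (not from this diff):

#include <torch/custom_class.h>

// Illustrative pickleable TorchBind class.
struct Counter : torch::CustomClassHolder {
  int64_t val;
  explicit Counter(int64_t v) : val(v) {}
};

static auto register_counter =
    torch::class_<Counter>("myclasses", "Counter")
        .def(torch::init<int64_t>())
        .def_pickle(
            // __getstate__: reduce the object to IValue-convertible state
            [](const c10::intrusive_ptr<Counter>& self) -> int64_t {
              return self->val;
            },
            // __setstate__: rebuild an instance from that state
            [](int64_t state) -> c10::intrusive_ptr<Counter> {
              return c10::make_intrusive<Counter>(state);
            });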

View File

@@ -343,7 +343,8 @@ void testLiteInterpreterBuiltinFunction() {
 namespace {
 static auto reg =
     torch::jit::class_<TorchBindLiteInterpreterTestStruct>(
-        "_TorchScriptTesting_LiteInterpreterTest")
+        "_TorchScriptTesting",
+        "_LiteInterpreterTest")
         .def("get", &TorchBindLiteInterpreterTestStruct::get)
         .def_pickle(
             // __getattr__

View File

@@ -35,19 +35,19 @@ class TestCustomOperators(unittest.TestCase):
 
     def test_no_return_class(self):
         def f():
-            val = torch.classes._TorchScriptTesting_Foo(5, 3)
+            val = torch.classes._TorchScriptTesting._Foo(5, 3)
             return val.info()
         self.assertEqual(*test_equality(f, lambda x: x))
 
     def test_constructor_with_args(self):
         def f():
-            val = torch.classes._TorchScriptTesting_Foo(5, 3)
+            val = torch.classes._TorchScriptTesting._Foo(5, 3)
             return val
         self.assertEqual(*test_equality(f, lambda x: x.info()))
 
     def test_function_call_with_args(self):
         def f():
-            val = torch.classes._TorchScriptTesting_Foo(5, 3)
+            val = torch.classes._TorchScriptTesting._Foo(5, 3)
             val.increment(1)
             return val
@@ -55,7 +55,7 @@ class TestCustomOperators(unittest.TestCase):
 
     def test_function_method_wrong_type(self):
         def f():
-            val = torch.classes._TorchScriptTesting_Foo(5, 3)
+            val = torch.classes._TorchScriptTesting._Foo(5, 3)
             val.increment("asdf")
             return val
@@ -65,8 +65,8 @@ class TestCustomOperators(unittest.TestCase):
     @unittest.skip("We currently don't support passing custom classes to custom methods.")
     def test_input_class_type(self):
         def f():
-            val = torch.classes._TorchScriptTesting_Foo(1, 2)
-            val2 = torch.classes._TorchScriptTesting_Foo(2, 3)
+            val = torch.classes._TorchScriptTesting._Foo(1, 2)
+            val2 = torch.classes._TorchScriptTesting._Foo(2, 3)
             val.combine(val2)
             return val
@@ -74,14 +74,14 @@ class TestCustomOperators(unittest.TestCase):
 
     def test_stack_string(self):
         def f():
-            val = torch.classes._TorchScriptTesting_StackString(["asdf", "bruh"])
+            val = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
             return val.pop()
         self.assertEqual(*test_equality(f, lambda x: x))
 
     def test_stack_push_pop(self):
         def f():
-            val = torch.classes._TorchScriptTesting_StackString(["asdf", "bruh"])
-            val2 = torch.classes._TorchScriptTesting_StackString(["111", "222"])
+            val = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
+            val2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
             val.push(val2.pop())
             return val.pop() + val2.pop()
         self.assertEqual(*test_equality(f, lambda x: x))

View File

@@ -5551,23 +5551,23 @@ def foo(x):
             return (cmp_key(obj1), cmp_key(obj2))
 
         def f():
-            val = torch.classes._TorchScriptTesting_Foo(5, 3)
+            val = torch.classes._TorchScriptTesting._Foo(5, 3)
             val.increment(1)
             return val
         test_equality(f, lambda x: x)
 
         with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
-            val = torch.classes._TorchScriptTesting_Foo(5, 3)
+            val = torch.classes._TorchScriptTesting._Foo(5, 3)
             val.increment('foo')
 
         def f():
-            ss = torch.classes._TorchScriptTesting_StackString(["asdf", "bruh"])
+            ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
             return ss.pop()
         test_equality(f, lambda x: x)
 
         def f():
-            ss1 = torch.classes._TorchScriptTesting_StackString(["asdf", "bruh"])
-            ss2 = torch.classes._TorchScriptTesting_StackString(["111", "222"])
+            ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
+            ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
             ss1.push(ss2.pop())
             return ss1.pop() + ss2.pop()
         test_equality(f, lambda x: x)
@@ -5575,14 +5575,14 @@ def foo(x):
     @skipIfRocm
     def test_torchbind_take_as_arg(self):
         global StackString  # see [local resolution in python]
-        StackString = torch.classes._TorchScriptTesting_StackString
+        StackString = torch.classes._TorchScriptTesting._StackString
 
         def foo(stackstring):
             # type: (StackString)
             stackstring.push("lel")
             return stackstring
 
-        script_input = torch.classes._TorchScriptTesting_StackString([])
+        script_input = torch.classes._TorchScriptTesting._StackString([])
         scripted = torch.jit.script(foo)
         script_output = scripted(script_input)
         self.assertEqual(script_output.pop(), "lel")
@@ -5590,7 +5590,7 @@ def foo(x):
     @skipIfRocm
     def test_torchbind_return_instance(self):
         def foo():
-            ss = torch.classes._TorchScriptTesting_StackString(["hi", "mom"])
+            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
             return ss
 
         scripted = torch.jit.script(foo)
@@ -5606,7 +5606,7 @@ def foo(x):
     @skipIfRocm
     def test_torchbind_return_instance_from_method(self):
         def foo():
-            ss = torch.classes._TorchScriptTesting_StackString(["hi", "mom"])
+            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
             clone = ss.clone()
             ss.pop()
             return ss, clone
@@ -5620,8 +5620,8 @@ def foo(x):
     @skipIfRocm
     def test_torchbind_take_instance_as_method_arg(self):
         def foo():
-            ss = torch.classes._TorchScriptTesting_StackString(["mom"])
-            ss2 = torch.classes._TorchScriptTesting_StackString(["hi"])
+            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
+            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
             ss.merge(ss2)
             return ss
@@ -5633,7 +5633,7 @@ def foo(x):
     @skipIfRocm
     def test_torchbind_return_tuple(self):
         def f():
-            val = torch.classes._TorchScriptTesting_StackString(["3", "5"])
+            val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
             return val.return_a_tuple()
 
         scripted = torch.jit.script(f)
@@ -5643,8 +5643,8 @@ def foo(x):
     @skipIfRocm
     def test_torchbind_save_load(self):
         def foo():
-            ss = torch.classes._TorchScriptTesting_StackString(["mom"])
-            ss2 = torch.classes._TorchScriptTesting_StackString(["hi"])
+            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
+            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
             ss.merge(ss2)
             return ss
@@ -5673,7 +5673,7 @@ def foo(x):
     @skipIfRocm
     def test_torchbind_lambda_method(self):
         def foo():
-            ss = torch.classes._TorchScriptTesting_StackString(["mom"])
+            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
             return ss.top()
 
         scripted = torch.jit.script(foo)
@@ -5684,7 +5684,7 @@ def foo(x):
         class FooBar1234(torch.nn.Module):
             def __init__(self):
                 super(FooBar1234, self).__init__()
-                self.f = torch.classes._TorchScriptTesting_StackString(["3", "4"])
+                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
 
             def forward(self):
                 return self.f.top()
@@ -5701,7 +5701,7 @@ def foo(x):
         class FooBar4321(torch.nn.Module):
             def __init__(self):
                 super(FooBar4321, self).__init__()
-                self.f = torch.classes._TorchScriptTesting_PickleTester([3, 4])
+                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
 
             def forward(self):
                 return self.f.top()
@@ -5723,7 +5723,7 @@ def foo(x):
         class TryTracing(torch.nn.Module):
             def __init__(self):
                 super(TryTracing, self).__init__()
-                self.f = torch.classes._TorchScriptTesting_PickleTester([3, 4])
+                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
 
             def forward(self):
                 return torch.ops._TorchScriptTesting.take_an_instance(self.f)
@@ -5736,7 +5736,7 @@ def foo(x):
         class TryTracingNest(torch.nn.Module):
             def __init__(self):
                 super(TryTracingNest, self).__init__()
-                self.f = torch.classes._TorchScriptTesting_PickleTester([3, 4])
+                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
 
         class TryTracing123(torch.nn.Module):
             def __init__(self):
@@ -5751,7 +5751,7 @@ def foo(x):
     @skipIfRocm
     def test_torchbind_pickle_serialization(self):
-        nt = torch.classes._TorchScriptTesting_PickleTester([3, 4])
+        nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
         b = io.BytesIO()
         torch.save(nt, b)
         b.seek(0)
@@ -5761,8 +5761,8 @@ def foo(x):
     @skipIfRocm
     def test_torchbind_instantiate_missing_class(self):
-        with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class IDontExist but it does not exist!'):
-            torch.classes.IDontExist(3, 4, 5)
+        with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class foo.IDontExist but it does not exist!'):
+            torch.classes.foo.IDontExist(3, 4, 5)
 
     def test_jitter_bug(self):
         @torch.jit.script

View File

@@ -1,15 +1,25 @@
 import types
 
 import torch._C
 
+class _ClassNamespace(types.ModuleType):
+    def __init__(self, name):
+        super(_ClassNamespace, self).__init__('torch.classes' + name)
+        self.name = name
+
+    def __getattr__(self, attr):
+        proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
+        if proxy is None:
+            raise RuntimeError('Class {}.{} not registered!'.format(self.name, attr))
+        return proxy
+
 class _Classes(types.ModuleType):
     def __init__(self):
         super(_Classes, self).__init__('torch.classes')
 
-    def __getattr__(self, attr):
-        proxy = torch._C._get_custom_class_python_wrapper(attr)
-        if proxy is None:
-            raise RuntimeError('Class {} not registered!'.format(attr))
-        return proxy
+    def __getattr__(self, name):
+        namespace = _ClassNamespace(name)
+        setattr(self, name, namespace)
+        return namespace
 
     @property
     def loaded_libraries(self):

View File

@@ -244,12 +244,16 @@ std::pair<TypePtr, c10::optional<AliasInfo>> SchemaTypeParser::parseType() {
           << "Expected classes namespace but got " << classes_tok.text();
     }
     L.expect('.');
+    auto ns_tok = L.expect(TK_IDENT);
+    L.expect('.');
     auto class_tok = L.expect(TK_IDENT);
     value = getCustomClass(
-        std::string("__torch__.torch.classes.") + class_tok.text());
+        std::string("__torch__.torch.classes.") + ns_tok.text() + "." +
+        class_tok.text());
     if (!value) {
       throw ErrorReport(class_tok.range)
-          << "Unknown custom class type " << class_tok.text()
+          << "Unknown custom class type "
+          << ns_tok.text() + "." + class_tok.text()
           << ". Please ensure it is registered.";
     }
   } else {
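
In other words, the parser now consumes two identifiers after
`__torch__.torch.classes` (namespace, then class name) and joins them before
the registry lookup. Conceptually, using a name registered earlier in this
diff:

// The two-token spelling resolves through the same registry lookup; a null
// result is what raises the "Unknown custom class type" error above.
at::ClassTypePtr cls =
    getCustomClass("__torch__.torch.classes.xnnpack.LinearOpContext");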

View File

@@ -28,20 +28,23 @@ void initPythonCustomClassBindings(PyObject* module) {
   // code object in turn calls __init__. Rather than calling __init__
   // directly, we need a wrapper that at least returns the instance
   // rather than the None return value from __init__
-  m.def("_get_custom_class_python_wrapper", [](const std::string& qualname) {
-    std::string full_qualname = "__torch__.torch.classes." + qualname;
-    auto named_type = getCustomClass(full_qualname);
-    TORCH_CHECK(
-        named_type,
-        "Tried to instantiate class ",
-        qualname,
-        " but it"
-        " does not exist! Ensure that it is registered via torch::jit"
-        "::class_");
-    c10::ClassTypePtr class_type = named_type->cast<ClassType>();
-    return ScriptClass(c10::StrongTypePtr(
-        std::shared_ptr<CompilationUnit>(), std::move(class_type)));
-  });
+  m.def(
+      "_get_custom_class_python_wrapper",
+      [](const std::string& ns, const std::string& qualname) {
+        std::string full_qualname =
+            "__torch__.torch.classes." + ns + "." + qualname;
+        auto named_type = getCustomClass(full_qualname);
+        TORCH_CHECK(
+            named_type,
+            "Tried to instantiate class ",
+            ns + "." + qualname,
+            " but it"
+            " does not exist! Ensure that it is registered via torch::jit"
+            "::class_");
+        c10::ClassTypePtr class_type = named_type->cast<ClassType>();
+        return ScriptClass(c10::StrongTypePtr(
+            std::shared_ptr<CompilationUnit>(), std::move(class_type)));
+      });
 }
 
 } // namespace jit

View File

@@ -34,7 +34,7 @@ detail::types<void, Types...> init() {
 /// calls needed. For example, to register a class named Foo, you might
 /// create a global variable like so:
 ///
-///     static auto register_foo = torch::class_<Foo>("Foo")
+///     static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
 ///       .def("myMethod", &Foo::myMethod)
 ///       .def("lambdaMethod", [](const c10::intrusive_ptr<Foo>& self) {
 ///         // Do something with `self`
@@ -51,12 +51,16 @@ class class_ {
  public:
   /// This constructor actually registers the class type.
-  /// String argument `className_` is the name you would like to
+  /// String argument `namespaceName` is an identifier for the
+  /// namespace you would like this class to appear in.
+  /// String argument `className` is the name you would like to
   /// see this class exposed as in Python and TorchScript. For example, if
-  /// you pass in "MyStack" here, the class will appear as
-  /// `torch.classes.MyStack` in both Python and TorchScript.
-  explicit class_(const std::string& className) : className(std::move(className)) {
-    qualClassName = topModule + "." + parentModule + "." + className;
+  /// you pass `foo` as the namespace name and `Bar` as the className, the
+  /// class will appear as `torch.classes.foo.Bar` in Python and TorchScript
+  explicit class_(const std::string& namespaceName, const std::string& className) {
+    detail::checkValidIdent(namespaceName, "Namespace name");
+    detail::checkValidIdent(className, "Class name");
+    qualClassName = std::string("__torch__.torch.classes.") + namespaceName + "." + className;
     classTypePtr = at::ClassType::create(
         c10::QualifiedName(qualClassName),
@@ -230,12 +234,8 @@ class class_ {
     classTypePtr->addMethod(method.get());
   }
 
-  std::string className;
   std::string qualClassName;
   at::ClassTypePtr classTypePtr;
-
-  const std::string parentModule = "classes";
-  const std::string topModule = "__torch__.torch";
 };
 
 /// make_custom_class() is a convenient way to create an instance of a registered
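
With the topModule/parentModule constants gone, the constructor assembles the
qualified name inline, so torch::class_<Foo>("myclasses", "Foo") yields
"__torch__.torch.classes.myclasses.Foo". The make_custom_class() helper named
in the trailing comment is exercised in the testTorchbindIValueAPI hunk
earlier; a minimal usage sketch (assuming the illustrative Counter registration
from before):

// Construct a registered custom class directly as an IValue, as in
// testTorchbindIValueAPI above.
c10::IValue obj = torch::make_custom_class<Counter>(int64_t(7));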

View File

@@ -119,6 +119,20 @@ struct BoxedProxy<void, Func> {
   }
 };
 
+inline bool validIdent(size_t i, char n) {
+  return isalpha(n) || n == '_' || (i > 0 && isdigit(n));
+}
+
+inline void checkValidIdent(const std::string& str, const char *type) {
+  for (size_t i = 0; i < str.size(); ++i) {
+    TORCH_CHECK(validIdent(i, str[i]),
+                type,
+                " must be a valid Python/C++ identifier."
+                " Character '", str[i], "' at index ",
+                i, " is illegal.");
+  }
+}
+
 } // namespace detail
 
 TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
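
validIdent accepts letters and '_' anywhere and digits only past index 0, so
namespace and class names must be plain Python/C++ identifiers. Illustrative
failures (these registrations are hypothetical):

// Both would now fail the TORCH_CHECK at registration time instead of
// creating an attribute that Python could never address:
// torch::class_<Foo>("my-ns", "Foo");  // '-' at index 2 is illegal
// torch::class_<Foo>("ns", "0Foo");    // digit at index 0 is illegal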