Back out "Make TorchScript Preserve Fully Qualified Class Name for Python Exceptions"

Summary: as title

Test Plan:
```
buck run mode/opt-split-dwarf -c=python.package_style=inplace //ai_infra/distributed_ai/pyper_test_framework/templates:pyper_release_v2 -- --model inline_cvr_post_imp_deterministic_shrunk_pyper_release_v2 --cluster TSCTestCluster --hpc_identity oncall_pyper_oncall --stage prod_offline_training --test_module training_platform
...
############## Start inline_cvr_post_imp_model Test Results Analysis ##############
I1226 22:03:56.789000 3346280 test_driver.py:139  UNKNOWN     ] Test finished in 808.2743511786684 seconds.
+-------------------------+---------+------------------------+-----------------+
| Test Case               | Status  | Message                | Model Entity ID |
+-------------------------+---------+------------------------+-----------------+
| SmallWorld_release_test | Success | finished successfully. | 987987491       |
+-------------------------+---------+------------------------+-----------------+
I1226 22:03:56.790000 3346280 test_driver.py:143  UNKNOWN     ] test_run_id: 3d085f61-28d1-411d-bd27-940ea2554b23 use this id to find your run in scuba pyper_test_framework
I1226 22:03:56.792000 3346280 test_driver.py:160  UNKNOWN     ] Calling cleanup
I1226 22:03:56.792000 3346280 training_platform_test_launcher.py:385  UNKNOWN     ] Stopping launched jobs 1
I1226 22:03:59.563122 3346280 ClientSingletonManager.cpp:100] Shutting down Manifold ClientSingletonManager
```

Reviewed By: seemethere

Differential Revision: D33325936

fbshipit-source-id: 64414bf7061ad77e8ac12eb8abafee4043e0fa1e
This commit is contained in:
Bo Wu
2021-12-27 09:10:24 -08:00
committed by Facebook GitHub Bot
parent 4ae71c8d34
commit bf610f08b0
14 changed files with 168 additions and 427 deletions

View File

@ -1,159 +0,0 @@
/*
* We have a python unit test for exceptions in test/jit/test_exception.py .
* Add a CPP version here to verify that excepted exception types thrown from
* C++. This is hard to test in python code since C++ exceptions will be
* translated to python exceptions.
*/
#include <gtest/gtest.h>
#include <pybind11/embed.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/jit.h>
#include <iostream>
#include <stdexcept>
namespace torch {
namespace jit {
namespace py = pybind11;
TEST(TestException, TestAssertion) {
std::string pythonCode = R"PY(
def foo():
raise AssertionError("An assertion failed")
)PY";
auto cu_ptr = torch::jit::compile(pythonCode);
torch::jit::GraphFunction* gf =
(torch::jit::GraphFunction*)&cu_ptr->get_function("foo");
std::cerr << "Graph is\n" << *gf->graph() << std::endl;
bool is_jit_exception = false;
std::string message;
c10::optional<std::string> exception_class;
try {
cu_ptr->run_method("foo");
} catch (JITException& e) {
is_jit_exception = true;
message = e.what();
exception_class = e.getPythonClassName();
}
EXPECT_TRUE(is_jit_exception);
EXPECT_FALSE(exception_class);
EXPECT_TRUE(
message.find("RuntimeError: AssertionError: An assertion failed") !=
std::string::npos);
}
struct MyPythonExceptionValue : public torch::jit::SugaredValue {
explicit MyPythonExceptionValue(const py::object& exception_class) {
qualified_name_ =
(py::str(py::getattr(exception_class, "__module__", py::str(""))) +
py::str(".") +
py::str(py::getattr(exception_class, "__name__", py::str(""))))
.cast<std::string>();
}
std::string kind() const override {
return "My Python exception";
}
// Simplified from PythonExceptionValue::call
std::shared_ptr<torch::jit::SugaredValue> call(
const torch::jit::SourceRange& loc,
torch::jit::GraphFunction& caller,
at::ArrayRef<torch::jit::NamedValue> args,
at::ArrayRef<torch::jit::NamedValue> kwargs,
size_t n_binders) override {
TORCH_CHECK(args.size() == 1);
Value* error_message = args.at(0).value(*caller.graph());
Value* qualified_class_name =
insertConstant(*caller.graph(), qualified_name_, loc);
return std::make_shared<ExceptionMessageValue>(
error_message, qualified_class_name);
}
private:
std::string qualified_name_;
};
class SimpleResolver : public torch::jit::Resolver {
public:
explicit SimpleResolver() {}
std::shared_ptr<torch::jit::SugaredValue> resolveValue(
const std::string& name,
torch::jit::GraphFunction& m,
const torch::jit::SourceRange& loc) override {
// follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is
// a python extension. We can not add that as a cpp_binary's dep)
if (name == "SimpleValueError") {
py::object obj = py::globals()["SimpleValueError"];
return std::make_shared<MyPythonExceptionValue>(obj);
}
TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'");
}
torch::jit::TypePtr resolveType(
const std::string& name,
const torch::jit::SourceRange& loc) override {
return nullptr;
}
};
/*
* - The python source code parsing for TorchScript here is learned from
* torch::jit::compile.
* - The code only parses one Def. If there are multiple in the code, those
* except the first one are skipped.
*/
TEST(TestException, TestCustomException) {
py::scoped_interpreter guard{};
py::exec(R"PY(
class SimpleValueError(ValueError):
def __init__(self, message):
super(SimpleValueError, self).__init__(message)
)PY");
std::string pythonCode = R"PY(
def foo():
raise SimpleValueError("An assertion failed")
)PY";
torch::jit::Parser p(
std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1));
auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false));
std::cerr << "Def is:\n" << def << std::endl;
auto cu = std::make_shared<torch::jit::CompilationUnit>();
(void)cu->define(
c10::nullopt,
{},
{},
{def},
// class PythonResolver is defined in
// torch/csrc/jit/python/script_init.cpp. It's not in a header file so I
// can not use it. Create a SimpleResolver insteand
{std::make_shared<SimpleResolver>()},
nullptr);
torch::jit::GraphFunction* gf =
(torch::jit::GraphFunction*)&cu->get_function("foo");
std::cerr << "Graph is\n" << *gf->graph() << std::endl;
bool is_jit_exception = false;
c10::optional<std::string> exception_class;
std::string message;
try {
cu->run_method("foo");
} catch (JITException& e) {
is_jit_exception = true;
exception_class = e.getPythonClassName();
message = e.what();
}
EXPECT_TRUE(is_jit_exception);
EXPECT_EQ("__main__.SimpleValueError", *exception_class);
EXPECT_TRUE(
message.find("__main__.SimpleValueError: An assertion failed") !=
std::string::npos);
}
} // namespace jit
} // namespace torch

View File

@ -1,8 +0,0 @@
r"""
Define exceptions used in test_exception.py. We define them in a
separate file on purpose to make sure the fully qualified exception class name
is captured correctly in suce cases.
"""
class MyKeyError(KeyError):
def __init__(self, msg):
super(KeyError, self).__init__(msg)

View File

@ -1,175 +0,0 @@
from torch.testing._internal.common_utils import TestCase
import torch
from torch import nn
r"""
Test TorchScript exception handling.
"""
class TestException(TestCase):
def test_assertions(self):
cu = torch.jit.CompilationUnit('''
def foo(cond):
assert bool(cond), "hi"
return 0
''')
cu.foo(torch.tensor(1))
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
cu.foo(torch.tensor(0))
@torch.jit.script
def foo(cond):
assert bool(cond), "hi"
foo(torch.tensor(1))
# we don't currently validate the name of the exception
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
foo(torch.tensor(0))
def test_pyop_exception_message(self):
class Foo(torch.jit.ScriptModule):
def __init__(self):
super(Foo, self).__init__()
self.conv = nn.Conv2d(1, 10, kernel_size=5)
@torch.jit.script_method
def forward(self, x):
return self.conv(x)
foo = Foo()
# testing that the correct error message propagates
with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
foo(torch.ones([123])) # wrong size
def test_builtin_error_messsage(self):
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
@torch.jit.script
def close_match(x):
return x.masked_fill(True)
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
"supported in TorchScript"):
@torch.jit.script
def unknown_op(x):
torch.set_anomaly_enabled(True)
return x
def test_exceptions(self):
cu = torch.jit.CompilationUnit('''
def foo(cond):
if bool(cond):
raise ValueError(3)
return 1
''')
cu.foo(torch.tensor(0))
with self.assertRaisesRegex(torch.jit.Error, "3"):
cu.foo(torch.tensor(1))
def foo(cond):
a = 3
if bool(cond):
raise ArbitraryError(a, "hi")
if 1 == 2:
raise ArbitraryError
return a
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
torch.jit.script(foo)
def exception_as_value():
a = Exception()
print(a)
with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
torch.jit.script(exception_as_value)
@torch.jit.script
def foo_no_decl_always_throws():
raise RuntimeError("Hi")
# function that has no declared type but always throws set to None
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
self.assertTrue(str(output_type) == "NoneType")
@torch.jit.script
def foo_decl_always_throws():
# type: () -> Tensor
raise Exception("Hi")
output_type = next(foo_decl_always_throws.graph.outputs()).type()
self.assertTrue(str(output_type) == "Tensor")
def foo():
raise 3 + 4
with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
torch.jit.script(foo)
# a escapes scope
@torch.jit.script
def foo():
if 1 == 1:
a = 1
else:
if 1 == 1:
raise Exception("Hi")
else:
raise Exception("Hi")
return a
self.assertEqual(foo(), 1)
@torch.jit.script
def tuple_fn():
raise RuntimeError("hello", "goodbye")
with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
tuple_fn()
@torch.jit.script
def no_message():
raise RuntimeError
with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
no_message()
def test_python_op_exception(self):
@torch.jit.ignore
def python_op(x):
raise Exception("bad!")
@torch.jit.script
def fn(x):
return python_op(x)
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"):
fn(torch.tensor(4))
def test_dict_expansion_raises_error(self):
def fn(self):
d = {"foo": 1, "bar": 2, "baz": 3}
return {**d}
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
"Dict expansion "):
torch.jit.script(fn)
def test_custom_python_exception(self):
class MyValueError(ValueError):
def __init__(self, msg):
super(MyValueError, self).__init__(msg)
@torch.jit.script
def fn():
raise MyValueError("test custom exception")
with self.assertRaisesRegex(torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"):
fn()
def test_custom_python_exception_defined_elsewhere(self):
from jit.myexception import MyKeyError
@torch.jit.script
def fn():
raise MyKeyError("This is a user defined key error")
with self.assertRaisesRegex(torch.jit.Error, "jit.myexception.MyKeyError: This is a user defined key error"):
fn()

View File

@ -73,7 +73,6 @@ from jit.test_batch_mm import TestBatchMM # noqa: F401
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401
from jit.test_dce import TestDCE # noqa: F401
from jit.test_sparse import TestSparse # noqa: F401
from jit.test_exception import TestException # noqa: F401
# Torch
from torch import Tensor
@ -12994,6 +12993,153 @@ dedent """
self.checkScript(dedent(code), (101,))
def test_pyop_exception_message(self):
class Foo(torch.jit.ScriptModule):
def __init__(self):
super(Foo, self).__init__()
self.conv = nn.Conv2d(1, 10, kernel_size=5)
@torch.jit.script_method
def forward(self, x):
return self.conv(x)
foo = Foo()
# testing that the correct error message propagates
with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
foo(torch.ones([123])) # wrong size
def test_builtin_error_messsage(self):
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
@torch.jit.script
def close_match(x):
return x.masked_fill(True)
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
"supported in TorchScript"):
@torch.jit.script
def unknown_op(x):
torch.set_anomaly_enabled(True)
return x
def test_exceptions(self):
cu = torch.jit.CompilationUnit('''
def foo(cond):
if bool(cond):
raise ValueError(3)
return 1
''')
cu.foo(torch.tensor(0))
with self.assertRaisesRegex(torch.jit.Error, "3"):
cu.foo(torch.tensor(1))
def foo(cond):
a = 3
if bool(cond):
raise ArbitraryError(a, "hi")
if 1 == 2:
raise ArbitraryError
return a
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
torch.jit.script(foo)
def exception_as_value():
a = Exception()
print(a)
with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
torch.jit.script(exception_as_value)
@torch.jit.script
def foo_no_decl_always_throws():
raise RuntimeError("Hi")
# function that has no declared type but always throws set to None
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
self.assertTrue(str(output_type) == "NoneType")
@torch.jit.script
def foo_decl_always_throws():
# type: () -> Tensor
raise Exception("Hi")
output_type = next(foo_decl_always_throws.graph.outputs()).type()
self.assertTrue(str(output_type) == "Tensor")
def foo():
raise 3 + 4
with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
torch.jit.script(foo)
# a escapes scope
@torch.jit.script
def foo():
if 1 == 1:
a = 1
else:
if 1 == 1:
raise Exception("Hi")
else:
raise Exception("Hi")
return a
self.assertEqual(foo(), 1)
@torch.jit.script
def tuple_fn():
raise RuntimeError("hello", "goodbye")
with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
tuple_fn()
@torch.jit.script
def no_message():
raise RuntimeError
with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
no_message()
def test_assertions(self):
cu = torch.jit.CompilationUnit('''
def foo(cond):
assert bool(cond), "hi"
return 0
''')
cu.foo(torch.tensor(1))
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
cu.foo(torch.tensor(0))
@torch.jit.script
def foo(cond):
assert bool(cond), "hi"
foo(torch.tensor(1))
# we don't currently validate the name of the exception
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
foo(torch.tensor(0))
def test_python_op_exception(self):
@torch.jit.ignore
def python_op(x):
raise Exception("bad!")
@torch.jit.script
def fn(x):
return python_op(x)
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"):
fn(torch.tensor(4))
def test_dict_expansion_raises_error(self):
def fn(self):
d = {"foo": 1, "bar": 2, "baz": 3}
return {**d}
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
"Dict expansion "):
torch.jit.script(fn)
def test_module_parameters_and_buffers(self):
weights = torch.randn(10, 10)
bias = torch.randn(10)

View File

@ -977,7 +977,7 @@ def is_scripting() -> bool:
# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
def _qualified_name(obj, mangle_name=True) -> str:
def _qualified_name(obj) -> str:
# This special case allows us to override the qualified name on a type.
# It's currently used in conjunction with tracing, where we create a
# fake module to filter only supported attributes. However, since this
@ -1026,16 +1026,13 @@ def _qualified_name(obj, mangle_name=True) -> str:
module_name = module_name.replace("<", "_")
module_name = module_name.replace(">", "_")
# The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
# does not need mangle the python class name.
if mangle_name:
# __main__ is a builtin module, so rewrite it to "__torch__".
if module_name == "__main__":
module_name = "__torch__"
else:
# Everything else gets a "__torch__" prefix to avoid name collisions
# with the names of user values.
module_name = "__torch__." + module_name
# __main__ is a builtin module, so rewrite it to "__torch__".
if module_name == "__main__":
module_name = "__torch__"
else:
# Everything else gets a "__torch__" prefix to avoid name collisions
# with the names of user values.
module_name = "__torch__." + module_name
if "." in name:
raise RuntimeError(f"Could not get qualified name for class '{name}': "

View File

@ -2469,14 +2469,12 @@ struct to_ir {
void emitRaise(const Raise& raise) {
auto sv = emitSugaredExpr(raise.expr(), 1);
Value* error_message = nullptr;
Value* qualified_class_name = nullptr;
if (auto exception_instance =
std::dynamic_pointer_cast<ExceptionMessageValue>(sv)) {
// The typical case, an instance of the exception class was thrown:
// raise RuntimeError("error")
error_message = exception_instance->getValue();
qualified_class_name = exception_instance->getQualifiedClassName();
} else if (
auto exception_class = std::dynamic_pointer_cast<ExceptionValue>(sv)) {
// A bare exception was thrown so add an empty message. e.g.
@ -2493,11 +2491,7 @@ struct to_ir {
error_message = graph->insert(aten::str, {error_message});
}
graph->insert(
prim::RaiseException,
{error_message, qualified_class_name},
{},
raise.range());
graph->insert(prim::RaiseException, {error_message}, {}, raise.range());
exit_blocks.insert(environment_stack->block());
}

View File

@ -741,10 +741,7 @@ struct SimpleSelf : public Self {
// This is not a SimpleValue so it can not pass through the code paths that
// expect a SimpleValue as a sugared value.
struct TORCH_API ExceptionMessageValue : public SugaredValue {
explicit ExceptionMessageValue(
Value* value,
Value* qualified_class_name = nullptr)
: value_(value), qualified_class_name_(qualified_class_name) {}
explicit ExceptionMessageValue(Value* value) : value_(value) {}
std::string kind() const override {
return "exception message";
@ -754,14 +751,7 @@ struct TORCH_API ExceptionMessageValue : public SugaredValue {
return value_;
}
// qualified python class name
Value* getQualifiedClassName() {
return qualified_class_name_;
}
private:
Value* value_;
Value* qualified_class_name_;
};
struct TORCH_API ExceptionValue : public SugaredValue {

View File

@ -14,12 +14,7 @@ void tupleIndex(Stack& stack) {
}
void raiseException(Stack& stack) {
c10::optional<std::string> qualified_class_name =
pop(stack).toOptional<std::string>();
std::string message;
pop(stack, message);
throw JITException(message, qualified_class_name);
throw JITException(pop(stack).toStringRef());
}
void is(Stack& stack) {

View File

@ -914,11 +914,8 @@ std::shared_ptr<SugaredValue> PythonExceptionValue::call(
->insertNode(caller.graph()->createTuple(message_values))
->output();
}
Value* qualified_class_name =
insertConstant(*caller.graph(), exception_class_qualified_name_, loc);
return std::make_shared<ExceptionMessageValue>(
error_message, qualified_class_name);
return std::make_shared<ExceptionMessageValue>(error_message);
}
bool isNamedTupleClass(const py::object& obj) {

View File

@ -328,12 +328,7 @@ struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue {
struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
explicit PythonExceptionValue(const py::object& exception_class)
: ExceptionValue(
py::str(py::getattr(exception_class, "__name__", py::str("")))),
exception_class_qualified_name_(
py::str(py::module::import("torch._jit_internal")
.attr("_qualified_name")(
exception_class,
/*mangle_name=*/false))) {}
py::str(py::getattr(exception_class, "__name__", py::str("")))) {}
std::string kind() const override {
return "Python exception";
@ -345,9 +340,6 @@ struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
at::ArrayRef<NamedValue> args,
at::ArrayRef<NamedValue> kwargs,
size_t n_binders) override;
private:
std::string exception_class_qualified_name_;
};
// Python Slice class.

View File

@ -714,19 +714,10 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
}
throw;
}
auto* jit_exception = dynamic_cast<JITException*>(&e);
bool is_jit_exception = dynamic_cast<JITException*>(&e);
// Janky af. See https://github.com/pytorch/pytorch/issues/54612
auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
c10::optional<std::string> python_class_name;
if (jit_exception) {
python_class_name = jit_exception->getPythonClassName();
}
handleError(
ExceptionMessage(e),
(bool)jit_exception,
not_implemented_error,
python_class_name);
handleError(ExceptionMessage(e), is_jit_exception, not_implemented_error);
return false;
}
}
@ -745,18 +736,15 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
void handleError(
const ExceptionMessage& msg,
bool is_jit_exception,
c10::NotImplementedError* not_implemented_error,
c10::optional<std::string> python_class_name) {
c10::NotImplementedError* not_implemented_error) {
std::ostringstream ss;
std::string class_name =
python_class_name ? *python_class_name : "RuntimeError";
ss << "The following operation failed in the TorchScript interpreter.\n";
formatStackTrace(ss);
ss << class_name << ": " << msg << "\n";
ss << "RuntimeError: " << msg << "\n";
if (future_) {
future_->setError(std::make_exception_ptr(Future::FutureError(ss.str())));
} else if (is_jit_exception) {
throw JITException(ss.str(), python_class_name);
throw JITException(ss.str());
} else if (not_implemented_error) {
throw c10::NotImplementedError(
ss.str(),

View File

@ -3,11 +3,7 @@
namespace torch {
namespace jit {
JITException::JITException(
const std::string& msg,
c10::optional<std::string> python_class_name)
: std::runtime_error(msg),
python_class_name_(std::move(python_class_name)) {}
JITException::JITException(const std::string& msg) : std::runtime_error(msg) {}
} // namespace jit
} // namespace torch

View File

@ -2,24 +2,13 @@
#include <stdexcept>
#include <c10/util/Optional.h>
#include <torch/csrc/Export.h>
#include <string>
namespace torch {
namespace jit {
struct TORCH_API JITException : public std::runtime_error {
explicit JITException(
const std::string& msg,
c10::optional<std::string> python_class_name = c10::nullopt);
c10::optional<std::string> getPythonClassName() const {
return python_class_name_;
}
private:
c10::optional<std::string> python_class_name_;
explicit JITException(const std::string& msg);
};
} // namespace jit

View File

@ -406,8 +406,7 @@ static const OperatorGeneratorArgs opGenArgs[] = {
numToTensorScalar,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"prim::RaiseException(str msg, str? cls=None) -> ()"),
TORCH_SELECTIVE_SCHEMA("prim::RaiseException(str msg) -> ()"),
raiseException,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(