mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Back out "Make TorchScript Preserve Fully Qualified Class Name for Python Exceptions"
Summary: as title Test Plan: ``` buck run mode/opt-split-dwarf -c=python.package_style=inplace //ai_infra/distributed_ai/pyper_test_framework/templates:pyper_release_v2 -- --model inline_cvr_post_imp_deterministic_shrunk_pyper_release_v2 --cluster TSCTestCluster --hpc_identity oncall_pyper_oncall --stage prod_offline_training --test_module training_platform ... ############## Start inline_cvr_post_imp_model Test Results Analysis ############## I1226 22:03:56.789000 3346280 test_driver.py:139 UNKNOWN ] Test finished in 808.2743511786684 seconds. +-------------------------+---------+------------------------+-----------------+ | Test Case | Status | Message | Model Entity ID | +-------------------------+---------+------------------------+-----------------+ | SmallWorld_release_test | Success | finished successfully. | 987987491 | +-------------------------+---------+------------------------+-----------------+ I1226 22:03:56.790000 3346280 test_driver.py:143 UNKNOWN ] test_run_id: 3d085f61-28d1-411d-bd27-940ea2554b23 use this id to find your run in scuba pyper_test_framework I1226 22:03:56.792000 3346280 test_driver.py:160 UNKNOWN ] Calling cleanup I1226 22:03:56.792000 3346280 training_platform_test_launcher.py:385 UNKNOWN ] Stopping launched jobs 1 I1226 22:03:59.563122 3346280 ClientSingletonManager.cpp:100] Shutting down Manifold ClientSingletonManager ``` Reviewed By: seemethere Differential Revision: D33325936 fbshipit-source-id: 64414bf7061ad77e8ac12eb8abafee4043e0fa1e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4ae71c8d34
commit
bf610f08b0
@ -1,159 +0,0 @@
|
||||
/*
|
||||
* We have a python unit test for exceptions in test/jit/test_exception.py .
|
||||
* Add a CPP version here to verify that excepted exception types thrown from
|
||||
* C++. This is hard to test in python code since C++ exceptions will be
|
||||
* translated to python exceptions.
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <pybind11/embed.h>
|
||||
#include <torch/csrc/jit/frontend/parser.h>
|
||||
#include <torch/csrc/jit/frontend/resolver.h>
|
||||
#include <torch/csrc/jit/runtime/jit_exception.h>
|
||||
#include <torch/jit.h>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
TEST(TestException, TestAssertion) {
|
||||
std::string pythonCode = R"PY(
|
||||
def foo():
|
||||
raise AssertionError("An assertion failed")
|
||||
)PY";
|
||||
auto cu_ptr = torch::jit::compile(pythonCode);
|
||||
torch::jit::GraphFunction* gf =
|
||||
(torch::jit::GraphFunction*)&cu_ptr->get_function("foo");
|
||||
std::cerr << "Graph is\n" << *gf->graph() << std::endl;
|
||||
|
||||
bool is_jit_exception = false;
|
||||
std::string message;
|
||||
c10::optional<std::string> exception_class;
|
||||
try {
|
||||
cu_ptr->run_method("foo");
|
||||
} catch (JITException& e) {
|
||||
is_jit_exception = true;
|
||||
message = e.what();
|
||||
exception_class = e.getPythonClassName();
|
||||
}
|
||||
EXPECT_TRUE(is_jit_exception);
|
||||
EXPECT_FALSE(exception_class);
|
||||
EXPECT_TRUE(
|
||||
message.find("RuntimeError: AssertionError: An assertion failed") !=
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
struct MyPythonExceptionValue : public torch::jit::SugaredValue {
|
||||
explicit MyPythonExceptionValue(const py::object& exception_class) {
|
||||
qualified_name_ =
|
||||
(py::str(py::getattr(exception_class, "__module__", py::str(""))) +
|
||||
py::str(".") +
|
||||
py::str(py::getattr(exception_class, "__name__", py::str(""))))
|
||||
.cast<std::string>();
|
||||
}
|
||||
|
||||
std::string kind() const override {
|
||||
return "My Python exception";
|
||||
}
|
||||
|
||||
// Simplified from PythonExceptionValue::call
|
||||
std::shared_ptr<torch::jit::SugaredValue> call(
|
||||
const torch::jit::SourceRange& loc,
|
||||
torch::jit::GraphFunction& caller,
|
||||
at::ArrayRef<torch::jit::NamedValue> args,
|
||||
at::ArrayRef<torch::jit::NamedValue> kwargs,
|
||||
size_t n_binders) override {
|
||||
TORCH_CHECK(args.size() == 1);
|
||||
Value* error_message = args.at(0).value(*caller.graph());
|
||||
Value* qualified_class_name =
|
||||
insertConstant(*caller.graph(), qualified_name_, loc);
|
||||
return std::make_shared<ExceptionMessageValue>(
|
||||
error_message, qualified_class_name);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string qualified_name_;
|
||||
};
|
||||
|
||||
class SimpleResolver : public torch::jit::Resolver {
|
||||
public:
|
||||
explicit SimpleResolver() {}
|
||||
|
||||
std::shared_ptr<torch::jit::SugaredValue> resolveValue(
|
||||
const std::string& name,
|
||||
torch::jit::GraphFunction& m,
|
||||
const torch::jit::SourceRange& loc) override {
|
||||
// follows toSugaredValue (toSugaredValue is defined in caffe2:_C which is
|
||||
// a python extension. We can not add that as a cpp_binary's dep)
|
||||
if (name == "SimpleValueError") {
|
||||
py::object obj = py::globals()["SimpleValueError"];
|
||||
return std::make_shared<MyPythonExceptionValue>(obj);
|
||||
}
|
||||
TORCH_CHECK(false, "resolveValue: can not resolve '", name, "{}'");
|
||||
}
|
||||
|
||||
torch::jit::TypePtr resolveType(
|
||||
const std::string& name,
|
||||
const torch::jit::SourceRange& loc) override {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* - The python source code parsing for TorchScript here is learned from
|
||||
* torch::jit::compile.
|
||||
* - The code only parses one Def. If there are multiple in the code, those
|
||||
* except the first one are skipped.
|
||||
*/
|
||||
TEST(TestException, TestCustomException) {
|
||||
py::scoped_interpreter guard{};
|
||||
py::exec(R"PY(
|
||||
class SimpleValueError(ValueError):
|
||||
def __init__(self, message):
|
||||
super(SimpleValueError, self).__init__(message)
|
||||
)PY");
|
||||
|
||||
std::string pythonCode = R"PY(
|
||||
def foo():
|
||||
raise SimpleValueError("An assertion failed")
|
||||
)PY";
|
||||
|
||||
torch::jit::Parser p(
|
||||
std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1));
|
||||
auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false));
|
||||
std::cerr << "Def is:\n" << def << std::endl;
|
||||
auto cu = std::make_shared<torch::jit::CompilationUnit>();
|
||||
(void)cu->define(
|
||||
c10::nullopt,
|
||||
{},
|
||||
{},
|
||||
{def},
|
||||
// class PythonResolver is defined in
|
||||
// torch/csrc/jit/python/script_init.cpp. It's not in a header file so I
|
||||
// can not use it. Create a SimpleResolver insteand
|
||||
{std::make_shared<SimpleResolver>()},
|
||||
nullptr);
|
||||
torch::jit::GraphFunction* gf =
|
||||
(torch::jit::GraphFunction*)&cu->get_function("foo");
|
||||
std::cerr << "Graph is\n" << *gf->graph() << std::endl;
|
||||
bool is_jit_exception = false;
|
||||
c10::optional<std::string> exception_class;
|
||||
std::string message;
|
||||
try {
|
||||
cu->run_method("foo");
|
||||
} catch (JITException& e) {
|
||||
is_jit_exception = true;
|
||||
exception_class = e.getPythonClassName();
|
||||
message = e.what();
|
||||
}
|
||||
EXPECT_TRUE(is_jit_exception);
|
||||
EXPECT_EQ("__main__.SimpleValueError", *exception_class);
|
||||
EXPECT_TRUE(
|
||||
message.find("__main__.SimpleValueError: An assertion failed") !=
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
@ -1,8 +0,0 @@
|
||||
r"""
|
||||
Define exceptions used in test_exception.py. We define them in a
|
||||
separate file on purpose to make sure the fully qualified exception class name
|
||||
is captured correctly in suce cases.
|
||||
"""
|
||||
class MyKeyError(KeyError):
|
||||
def __init__(self, msg):
|
||||
super(KeyError, self).__init__(msg)
|
||||
@ -1,175 +0,0 @@
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
r"""
|
||||
Test TorchScript exception handling.
|
||||
"""
|
||||
class TestException(TestCase):
|
||||
def test_assertions(self):
|
||||
cu = torch.jit.CompilationUnit('''
|
||||
def foo(cond):
|
||||
assert bool(cond), "hi"
|
||||
return 0
|
||||
''')
|
||||
|
||||
cu.foo(torch.tensor(1))
|
||||
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
||||
cu.foo(torch.tensor(0))
|
||||
|
||||
@torch.jit.script
|
||||
def foo(cond):
|
||||
assert bool(cond), "hi"
|
||||
|
||||
foo(torch.tensor(1))
|
||||
# we don't currently validate the name of the exception
|
||||
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
||||
foo(torch.tensor(0))
|
||||
|
||||
def test_pyop_exception_message(self):
|
||||
class Foo(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
super(Foo, self).__init__()
|
||||
self.conv = nn.Conv2d(1, 10, kernel_size=5)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
foo = Foo()
|
||||
# testing that the correct error message propagates
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
|
||||
foo(torch.ones([123])) # wrong size
|
||||
|
||||
def test_builtin_error_messsage(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
|
||||
@torch.jit.script
|
||||
def close_match(x):
|
||||
return x.masked_fill(True)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
|
||||
"supported in TorchScript"):
|
||||
@torch.jit.script
|
||||
def unknown_op(x):
|
||||
torch.set_anomaly_enabled(True)
|
||||
return x
|
||||
|
||||
def test_exceptions(self):
|
||||
cu = torch.jit.CompilationUnit('''
|
||||
def foo(cond):
|
||||
if bool(cond):
|
||||
raise ValueError(3)
|
||||
return 1
|
||||
''')
|
||||
|
||||
cu.foo(torch.tensor(0))
|
||||
with self.assertRaisesRegex(torch.jit.Error, "3"):
|
||||
cu.foo(torch.tensor(1))
|
||||
|
||||
def foo(cond):
|
||||
a = 3
|
||||
if bool(cond):
|
||||
raise ArbitraryError(a, "hi")
|
||||
if 1 == 2:
|
||||
raise ArbitraryError
|
||||
return a
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
|
||||
torch.jit.script(foo)
|
||||
|
||||
def exception_as_value():
|
||||
a = Exception()
|
||||
print(a)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
|
||||
torch.jit.script(exception_as_value)
|
||||
|
||||
@torch.jit.script
|
||||
def foo_no_decl_always_throws():
|
||||
raise RuntimeError("Hi")
|
||||
|
||||
# function that has no declared type but always throws set to None
|
||||
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
|
||||
self.assertTrue(str(output_type) == "NoneType")
|
||||
|
||||
@torch.jit.script
|
||||
def foo_decl_always_throws():
|
||||
# type: () -> Tensor
|
||||
raise Exception("Hi")
|
||||
|
||||
output_type = next(foo_decl_always_throws.graph.outputs()).type()
|
||||
self.assertTrue(str(output_type) == "Tensor")
|
||||
|
||||
def foo():
|
||||
raise 3 + 4
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
|
||||
torch.jit.script(foo)
|
||||
|
||||
# a escapes scope
|
||||
@torch.jit.script
|
||||
def foo():
|
||||
if 1 == 1:
|
||||
a = 1
|
||||
else:
|
||||
if 1 == 1:
|
||||
raise Exception("Hi")
|
||||
else:
|
||||
raise Exception("Hi")
|
||||
return a
|
||||
self.assertEqual(foo(), 1)
|
||||
|
||||
@torch.jit.script
|
||||
def tuple_fn():
|
||||
raise RuntimeError("hello", "goodbye")
|
||||
|
||||
with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
|
||||
tuple_fn()
|
||||
|
||||
@torch.jit.script
|
||||
def no_message():
|
||||
raise RuntimeError
|
||||
|
||||
with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
|
||||
no_message()
|
||||
|
||||
def test_python_op_exception(self):
|
||||
@torch.jit.ignore
|
||||
def python_op(x):
|
||||
raise Exception("bad!")
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x):
|
||||
return python_op(x)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"):
|
||||
fn(torch.tensor(4))
|
||||
|
||||
def test_dict_expansion_raises_error(self):
|
||||
def fn(self):
|
||||
d = {"foo": 1, "bar": 2, "baz": 3}
|
||||
return {**d}
|
||||
|
||||
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
|
||||
"Dict expansion "):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_custom_python_exception(self):
|
||||
class MyValueError(ValueError):
|
||||
def __init__(self, msg):
|
||||
super(MyValueError, self).__init__(msg)
|
||||
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
raise MyValueError("test custom exception")
|
||||
|
||||
with self.assertRaisesRegex(torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"):
|
||||
fn()
|
||||
|
||||
def test_custom_python_exception_defined_elsewhere(self):
|
||||
from jit.myexception import MyKeyError
|
||||
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
raise MyKeyError("This is a user defined key error")
|
||||
with self.assertRaisesRegex(torch.jit.Error, "jit.myexception.MyKeyError: This is a user defined key error"):
|
||||
fn()
|
||||
148
test/test_jit.py
148
test/test_jit.py
@ -73,7 +73,6 @@ from jit.test_batch_mm import TestBatchMM # noqa: F401
|
||||
from jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401
|
||||
from jit.test_dce import TestDCE # noqa: F401
|
||||
from jit.test_sparse import TestSparse # noqa: F401
|
||||
from jit.test_exception import TestException # noqa: F401
|
||||
|
||||
# Torch
|
||||
from torch import Tensor
|
||||
@ -12994,6 +12993,153 @@ dedent """
|
||||
|
||||
self.checkScript(dedent(code), (101,))
|
||||
|
||||
def test_pyop_exception_message(self):
|
||||
class Foo(torch.jit.ScriptModule):
|
||||
def __init__(self):
|
||||
super(Foo, self).__init__()
|
||||
self.conv = nn.Conv2d(1, 10, kernel_size=5)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
foo = Foo()
|
||||
# testing that the correct error message propagates
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
|
||||
foo(torch.ones([123])) # wrong size
|
||||
|
||||
def test_builtin_error_messsage(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
|
||||
@torch.jit.script
|
||||
def close_match(x):
|
||||
return x.masked_fill(True)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
|
||||
"supported in TorchScript"):
|
||||
@torch.jit.script
|
||||
def unknown_op(x):
|
||||
torch.set_anomaly_enabled(True)
|
||||
return x
|
||||
|
||||
def test_exceptions(self):
|
||||
cu = torch.jit.CompilationUnit('''
|
||||
def foo(cond):
|
||||
if bool(cond):
|
||||
raise ValueError(3)
|
||||
return 1
|
||||
''')
|
||||
|
||||
cu.foo(torch.tensor(0))
|
||||
with self.assertRaisesRegex(torch.jit.Error, "3"):
|
||||
cu.foo(torch.tensor(1))
|
||||
|
||||
def foo(cond):
|
||||
a = 3
|
||||
if bool(cond):
|
||||
raise ArbitraryError(a, "hi")
|
||||
if 1 == 2:
|
||||
raise ArbitraryError
|
||||
return a
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
|
||||
torch.jit.script(foo)
|
||||
|
||||
def exception_as_value():
|
||||
a = Exception()
|
||||
print(a)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
|
||||
torch.jit.script(exception_as_value)
|
||||
|
||||
@torch.jit.script
|
||||
def foo_no_decl_always_throws():
|
||||
raise RuntimeError("Hi")
|
||||
|
||||
# function that has no declared type but always throws set to None
|
||||
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
|
||||
self.assertTrue(str(output_type) == "NoneType")
|
||||
|
||||
@torch.jit.script
|
||||
def foo_decl_always_throws():
|
||||
# type: () -> Tensor
|
||||
raise Exception("Hi")
|
||||
|
||||
output_type = next(foo_decl_always_throws.graph.outputs()).type()
|
||||
self.assertTrue(str(output_type) == "Tensor")
|
||||
|
||||
def foo():
|
||||
raise 3 + 4
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
|
||||
torch.jit.script(foo)
|
||||
|
||||
# a escapes scope
|
||||
@torch.jit.script
|
||||
def foo():
|
||||
if 1 == 1:
|
||||
a = 1
|
||||
else:
|
||||
if 1 == 1:
|
||||
raise Exception("Hi")
|
||||
else:
|
||||
raise Exception("Hi")
|
||||
return a
|
||||
self.assertEqual(foo(), 1)
|
||||
|
||||
@torch.jit.script
|
||||
def tuple_fn():
|
||||
raise RuntimeError("hello", "goodbye")
|
||||
|
||||
with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
|
||||
tuple_fn()
|
||||
|
||||
@torch.jit.script
|
||||
def no_message():
|
||||
raise RuntimeError
|
||||
|
||||
with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
|
||||
no_message()
|
||||
|
||||
def test_assertions(self):
|
||||
cu = torch.jit.CompilationUnit('''
|
||||
def foo(cond):
|
||||
assert bool(cond), "hi"
|
||||
return 0
|
||||
''')
|
||||
|
||||
cu.foo(torch.tensor(1))
|
||||
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
||||
cu.foo(torch.tensor(0))
|
||||
|
||||
@torch.jit.script
|
||||
def foo(cond):
|
||||
assert bool(cond), "hi"
|
||||
|
||||
foo(torch.tensor(1))
|
||||
# we don't currently validate the name of the exception
|
||||
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
||||
foo(torch.tensor(0))
|
||||
|
||||
def test_python_op_exception(self):
|
||||
@torch.jit.ignore
|
||||
def python_op(x):
|
||||
raise Exception("bad!")
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x):
|
||||
return python_op(x)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"):
|
||||
fn(torch.tensor(4))
|
||||
|
||||
def test_dict_expansion_raises_error(self):
|
||||
def fn(self):
|
||||
d = {"foo": 1, "bar": 2, "baz": 3}
|
||||
return {**d}
|
||||
|
||||
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
|
||||
"Dict expansion "):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_module_parameters_and_buffers(self):
|
||||
weights = torch.randn(10, 10)
|
||||
bias = torch.randn(10)
|
||||
|
||||
@ -977,7 +977,7 @@ def is_scripting() -> bool:
|
||||
|
||||
|
||||
# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
|
||||
def _qualified_name(obj, mangle_name=True) -> str:
|
||||
def _qualified_name(obj) -> str:
|
||||
# This special case allows us to override the qualified name on a type.
|
||||
# It's currently used in conjunction with tracing, where we create a
|
||||
# fake module to filter only supported attributes. However, since this
|
||||
@ -1026,16 +1026,13 @@ def _qualified_name(obj, mangle_name=True) -> str:
|
||||
module_name = module_name.replace("<", "_")
|
||||
module_name = module_name.replace(">", "_")
|
||||
|
||||
# The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
|
||||
# does not need mangle the python class name.
|
||||
if mangle_name:
|
||||
# __main__ is a builtin module, so rewrite it to "__torch__".
|
||||
if module_name == "__main__":
|
||||
module_name = "__torch__"
|
||||
else:
|
||||
# Everything else gets a "__torch__" prefix to avoid name collisions
|
||||
# with the names of user values.
|
||||
module_name = "__torch__." + module_name
|
||||
# __main__ is a builtin module, so rewrite it to "__torch__".
|
||||
if module_name == "__main__":
|
||||
module_name = "__torch__"
|
||||
else:
|
||||
# Everything else gets a "__torch__" prefix to avoid name collisions
|
||||
# with the names of user values.
|
||||
module_name = "__torch__." + module_name
|
||||
|
||||
if "." in name:
|
||||
raise RuntimeError(f"Could not get qualified name for class '{name}': "
|
||||
|
||||
@ -2469,14 +2469,12 @@ struct to_ir {
|
||||
void emitRaise(const Raise& raise) {
|
||||
auto sv = emitSugaredExpr(raise.expr(), 1);
|
||||
Value* error_message = nullptr;
|
||||
Value* qualified_class_name = nullptr;
|
||||
|
||||
if (auto exception_instance =
|
||||
std::dynamic_pointer_cast<ExceptionMessageValue>(sv)) {
|
||||
// The typical case, an instance of the exception class was thrown:
|
||||
// raise RuntimeError("error")
|
||||
error_message = exception_instance->getValue();
|
||||
qualified_class_name = exception_instance->getQualifiedClassName();
|
||||
} else if (
|
||||
auto exception_class = std::dynamic_pointer_cast<ExceptionValue>(sv)) {
|
||||
// A bare exception was thrown so add an empty message. e.g.
|
||||
@ -2493,11 +2491,7 @@ struct to_ir {
|
||||
error_message = graph->insert(aten::str, {error_message});
|
||||
}
|
||||
|
||||
graph->insert(
|
||||
prim::RaiseException,
|
||||
{error_message, qualified_class_name},
|
||||
{},
|
||||
raise.range());
|
||||
graph->insert(prim::RaiseException, {error_message}, {}, raise.range());
|
||||
exit_blocks.insert(environment_stack->block());
|
||||
}
|
||||
|
||||
|
||||
@ -741,10 +741,7 @@ struct SimpleSelf : public Self {
|
||||
// This is not a SimpleValue so it can not pass through the code paths that
|
||||
// expect a SimpleValue as a sugared value.
|
||||
struct TORCH_API ExceptionMessageValue : public SugaredValue {
|
||||
explicit ExceptionMessageValue(
|
||||
Value* value,
|
||||
Value* qualified_class_name = nullptr)
|
||||
: value_(value), qualified_class_name_(qualified_class_name) {}
|
||||
explicit ExceptionMessageValue(Value* value) : value_(value) {}
|
||||
|
||||
std::string kind() const override {
|
||||
return "exception message";
|
||||
@ -754,14 +751,7 @@ struct TORCH_API ExceptionMessageValue : public SugaredValue {
|
||||
return value_;
|
||||
}
|
||||
|
||||
// qualified python class name
|
||||
Value* getQualifiedClassName() {
|
||||
return qualified_class_name_;
|
||||
}
|
||||
|
||||
private:
|
||||
Value* value_;
|
||||
Value* qualified_class_name_;
|
||||
};
|
||||
|
||||
struct TORCH_API ExceptionValue : public SugaredValue {
|
||||
|
||||
@ -14,12 +14,7 @@ void tupleIndex(Stack& stack) {
|
||||
}
|
||||
|
||||
void raiseException(Stack& stack) {
|
||||
c10::optional<std::string> qualified_class_name =
|
||||
pop(stack).toOptional<std::string>();
|
||||
std::string message;
|
||||
pop(stack, message);
|
||||
|
||||
throw JITException(message, qualified_class_name);
|
||||
throw JITException(pop(stack).toStringRef());
|
||||
}
|
||||
|
||||
void is(Stack& stack) {
|
||||
|
||||
@ -914,11 +914,8 @@ std::shared_ptr<SugaredValue> PythonExceptionValue::call(
|
||||
->insertNode(caller.graph()->createTuple(message_values))
|
||||
->output();
|
||||
}
|
||||
Value* qualified_class_name =
|
||||
insertConstant(*caller.graph(), exception_class_qualified_name_, loc);
|
||||
|
||||
return std::make_shared<ExceptionMessageValue>(
|
||||
error_message, qualified_class_name);
|
||||
return std::make_shared<ExceptionMessageValue>(error_message);
|
||||
}
|
||||
|
||||
bool isNamedTupleClass(const py::object& obj) {
|
||||
|
||||
@ -328,12 +328,7 @@ struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue {
|
||||
struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
|
||||
explicit PythonExceptionValue(const py::object& exception_class)
|
||||
: ExceptionValue(
|
||||
py::str(py::getattr(exception_class, "__name__", py::str("")))),
|
||||
exception_class_qualified_name_(
|
||||
py::str(py::module::import("torch._jit_internal")
|
||||
.attr("_qualified_name")(
|
||||
exception_class,
|
||||
/*mangle_name=*/false))) {}
|
||||
py::str(py::getattr(exception_class, "__name__", py::str("")))) {}
|
||||
|
||||
std::string kind() const override {
|
||||
return "Python exception";
|
||||
@ -345,9 +340,6 @@ struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
|
||||
at::ArrayRef<NamedValue> args,
|
||||
at::ArrayRef<NamedValue> kwargs,
|
||||
size_t n_binders) override;
|
||||
|
||||
private:
|
||||
std::string exception_class_qualified_name_;
|
||||
};
|
||||
|
||||
// Python Slice class.
|
||||
|
||||
@ -714,19 +714,10 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
}
|
||||
throw;
|
||||
}
|
||||
auto* jit_exception = dynamic_cast<JITException*>(&e);
|
||||
bool is_jit_exception = dynamic_cast<JITException*>(&e);
|
||||
// Janky af. See https://github.com/pytorch/pytorch/issues/54612
|
||||
auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
|
||||
|
||||
c10::optional<std::string> python_class_name;
|
||||
if (jit_exception) {
|
||||
python_class_name = jit_exception->getPythonClassName();
|
||||
}
|
||||
handleError(
|
||||
ExceptionMessage(e),
|
||||
(bool)jit_exception,
|
||||
not_implemented_error,
|
||||
python_class_name);
|
||||
handleError(ExceptionMessage(e), is_jit_exception, not_implemented_error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -745,18 +736,15 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
|
||||
void handleError(
|
||||
const ExceptionMessage& msg,
|
||||
bool is_jit_exception,
|
||||
c10::NotImplementedError* not_implemented_error,
|
||||
c10::optional<std::string> python_class_name) {
|
||||
c10::NotImplementedError* not_implemented_error) {
|
||||
std::ostringstream ss;
|
||||
std::string class_name =
|
||||
python_class_name ? *python_class_name : "RuntimeError";
|
||||
ss << "The following operation failed in the TorchScript interpreter.\n";
|
||||
formatStackTrace(ss);
|
||||
ss << class_name << ": " << msg << "\n";
|
||||
ss << "RuntimeError: " << msg << "\n";
|
||||
if (future_) {
|
||||
future_->setError(std::make_exception_ptr(Future::FutureError(ss.str())));
|
||||
} else if (is_jit_exception) {
|
||||
throw JITException(ss.str(), python_class_name);
|
||||
throw JITException(ss.str());
|
||||
} else if (not_implemented_error) {
|
||||
throw c10::NotImplementedError(
|
||||
ss.str(),
|
||||
|
||||
@ -3,11 +3,7 @@
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
JITException::JITException(
|
||||
const std::string& msg,
|
||||
c10::optional<std::string> python_class_name)
|
||||
: std::runtime_error(msg),
|
||||
python_class_name_(std::move(python_class_name)) {}
|
||||
JITException::JITException(const std::string& msg) : std::runtime_error(msg) {}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
@ -2,24 +2,13 @@
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#include <c10/util/Optional.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <string>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
struct TORCH_API JITException : public std::runtime_error {
|
||||
explicit JITException(
|
||||
const std::string& msg,
|
||||
c10::optional<std::string> python_class_name = c10::nullopt);
|
||||
|
||||
c10::optional<std::string> getPythonClassName() const {
|
||||
return python_class_name_;
|
||||
}
|
||||
|
||||
private:
|
||||
c10::optional<std::string> python_class_name_;
|
||||
explicit JITException(const std::string& msg);
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
||||
@ -406,8 +406,7 @@ static const OperatorGeneratorArgs opGenArgs[] = {
|
||||
numToTensorScalar,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
TORCH_SELECTIVE_SCHEMA(
|
||||
"prim::RaiseException(str msg, str? cls=None) -> ()"),
|
||||
TORCH_SELECTIVE_SCHEMA("prim::RaiseException(str msg) -> ()"),
|
||||
raiseException,
|
||||
aliasAnalysisFromSchema()),
|
||||
OperatorGeneratorArgs(
|
||||
|
||||
Reference in New Issue
Block a user