Make TorchScript Preserve Fully Qualified Class Name for Python Exceptions (#70339)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70339

When a python program is translated to TorchScript, the python exception type is dropped. This makes users's life hard when they need to categorize errors based more than only exception message.

Here we make the change so when we raise a python exception, we record the fully qualified class name for the exception. Later on when the TorchScript is interpreted, a special exception CustomJITException is thrown. User can get the python class name from CustomJITException::getPythonClassName .

Note that, this diff does not customize the mapping from C++ exception to Python exception. It's left to the users to do whatever mapping they want.

Code under scripts/shunting are just my own experimental code. I can split them out if requested.
ghstack-source-id: 146221879

Test Plan: buck test mode/opt //caffe2/test:jit

Reviewed By: gmagogsfm

Differential Revision: D33282878

fbshipit-source-id: 910f67a764519f1053a48589d1a34df69001525d
This commit is contained in:
Shunting Zhang
2021-12-24 00:24:10 -08:00
committed by Facebook GitHub Bot
parent ab4f9862a3
commit 911d527b87
14 changed files with 427 additions and 168 deletions

View File

@ -914,8 +914,11 @@ std::shared_ptr<SugaredValue> PythonExceptionValue::call(
->insertNode(caller.graph()->createTuple(message_values))
->output();
}
Value* qualified_class_name =
insertConstant(*caller.graph(), exception_class_qualified_name_, loc);
return std::make_shared<ExceptionMessageValue>(error_message);
return std::make_shared<ExceptionMessageValue>(
error_message, qualified_class_name);
}
bool isNamedTupleClass(const py::object& obj) {