diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.mm b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.mm
index 2df73527d58e..fbb7abe87b52 100644
--- a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.mm
+++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.mm
@@ -14,6 +14,24 @@
 #include
 #include
 
+// This is a utility macro that can be used to throw an exception when a CoreML
+// API function produces an NSError. The exception will contain a message with
+// useful info extracted from the NSError.
+#define COREML_THROW_IF_ERROR(error, preamble) \
+  do { \
+    if C10_LIKELY(error) { \
+      throw c10::Error( \
+          {__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, \
+          c10::str( \
+              preamble, \
+              " Error details: ", \
+              " Localized_description: ", error.localizedDescription.UTF8String, \
+              " Domain: ", error.domain.UTF8String, \
+              " Code: ", error.code, \
+              " User Info: ", error.userInfo.description.UTF8String)); \
+    } \
+  } while (false)
+
 @implementation PTMCoreMLFeatureProvider {
   NSUInteger _coremlVersion;
   std::vector _specs;
@@ -173,10 +191,8 @@ static NSString* gModelCacheDirectory = @"";
 
       // remove cached models if compalition failed.
       [self cleanup];
-      TORCH_CHECK(
-          false,
-          "Error compiling the MLModel",
-          [error localizedDescription].UTF8String);
+
+      COREML_THROW_IF_ERROR(error, "Error compiling the MLModel file!");
       return NO;
     }
     if (@available(iOS 12.0, macOS 10.14, *)) {
@@ -201,10 +217,7 @@ static NSString* gModelCacheDirectory = @"";
       observer->onExitCompileModel(instance_key, false, true);
     }
 
-    TORCH_CHECK(
-        false,
-        "Error loading the MLModel",
-        error.localizedDescription.UTF8String);
+    COREML_THROW_IF_ERROR(error, "Error loading the MLModel file!");
   }
 
   if (observer) {
@@ -240,12 +253,8 @@ static NSString* gModelCacheDirectory = @"";
       [_mlModel predictionFromFeatures:inputFeature options:options error:&error];
 
-  if (error) {
-    TORCH_CHECK(
-        false,
-        "Error running the prediction",
-        error.localizedDescription.UTF8String);
-  }
+
+  COREML_THROW_IF_ERROR(error, "Error running CoreML inference!");
 
   ++_inferences;
 
   if (observer) {
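
For context, a minimal sketch of how the new macro reads at a call site. This is illustrative only: compileModelOrThrow is a hypothetical helper that is not part of this patch, and it assumes COREML_THROW_IF_ERROR and the c10 headers already pulled in by PTMCoreMLExecutor.mm are in scope.

    #import <CoreML/CoreML.h>

    #include <c10/util/Exception.h>

    // Hypothetical call site (not in the patch), assuming COREML_THROW_IF_ERROR
    // from the change above is visible in this translation unit.
    static NSURL* compileModelOrThrow(NSURL* modelURL) {
      NSError* error = nil;
      // +[MLModel compileModelAtURL:error:] returns nil and fills `error` on failure.
      NSURL* compiledURL = [MLModel compileModelAtURL:modelURL error:&error];
      // On failure this throws a c10::Error whose message carries the NSError's
      // localizedDescription, domain, code, and userInfo; on success it is a no-op.
      COREML_THROW_IF_ERROR(error, "Error compiling the MLModel file!");
      return compiledURL;
    }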