Replace std::runtime_error with TORCH_CHECK (#159344)

Fixes part of #148114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159344
Approved by: https://github.com/cyyever, https://github.com/albanD
This commit is contained in:
zeshengzong
2025-09-16 09:00:02 +00:00
committed by PyTorch MergeBot
parent 9aca0ba027
commit e3783a9575
9 changed files with 99 additions and 117 deletions

View File

@ -62,7 +62,7 @@ static void setSignalHandler(
std::ostringstream oss; std::ostringstream oss;
oss << "An error occurred while setting handler for " << strsignal(signal) oss << "An error occurred while setting handler for " << strsignal(signal)
<< "."; << ".";
throw std::runtime_error(oss.str()); TORCH_CHECK(false, oss.str());
} }
} }
@ -141,29 +141,32 @@ static PyObject* THPModule_errorIfAnyWorkerFails(
continue; continue;
if (infop.si_code == CLD_EXITED && if (infop.si_code == CLD_EXITED &&
infop.si_status != EXIT_SUCCESS) { // exit with error infop.si_status != EXIT_SUCCESS) { // exit with error
std::ostringstream oss; auto error_msg = fmt::format(
oss << "DataLoader worker (pid " << worker_pid << ") exited " "DataLoader worker (pid {}) exited unexpectedly with exit code {}. "
<< "unexpectedly with exit code " << infop.si_status << ". " "Details are lost due to multiprocessing. Rerunning with "
<< "Details are lost due to multiprocessing. Rerunning with " "num_workers=0 may give better error trace.",
<< "num_workers=0 may give better error trace."; worker_pid,
infop.si_status);
// This is necessary. Otherwise, the runtime error will kill the other // This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again. // workers, and trigger this again.
pid_set.clear(); pid_set.clear();
throw std::runtime_error(oss.str()); TORCH_CHECK(false, error_msg);
} else if ( } else if (
infop.si_code == CLD_KILLED || infop.si_code == CLD_KILLED ||
infop.si_code == CLD_DUMPED) { // killed by signal infop.si_code == CLD_DUMPED) { // killed by signal
std::ostringstream oss; auto error_msg = fmt::format(
oss << "DataLoader worker (pid " << worker_pid << ") is killed " "DataLoader worker (pid {}) is killed by signal: {}. ",
<< "by signal: " << strsignal(infop.si_status) << ". "; worker_pid,
strsignal(infop.si_status));
if (infop.si_status == SIGBUS) { if (infop.si_status == SIGBUS) {
oss << "It is possible that dataloader's workers are out of shared memory. " error_msg +=
<< "Please try to raise your shared memory limit."; "It is possible that dataloader's workers are out of shared memory. "
"Please try to raise your shared memory limit.";
} }
// This is necessary. Otherwise, the runtime error will kill the other // This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again. // workers, and trigger this again.
pid_set.clear(); pid_set.clear();
throw std::runtime_error(oss.str()); TORCH_CHECK(false, error_msg);
} }
} }
} }

View File

@ -67,7 +67,8 @@ static PyObject* THPDevice_pynew(
auto as_device = r.device(0); // this works, because device can take strings auto as_device = r.device(0); // this works, because device can take strings
if (as_device.has_index()) { if (as_device.has_index()) {
auto device_type = r.string(0); auto device_type = r.string(0);
throw std::runtime_error( TORCH_CHECK(
false,
"type (string) must not include an index because index " "type (string) must not include an index because index "
"was passed explicitly: " + "was passed explicitly: " +
device_type); device_type);

View File

@ -10,6 +10,7 @@
#include <unordered_map> #include <unordered_map>
#include <variant> #include <variant>
#include <vector> #include <vector>
#include <c10/util/Exception.h>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
@ -190,7 +191,7 @@ inline std::string_view printEnum(const ArgumentKind& e) {
case ArgumentKind::POSITIONAL: return "POSITIONAL"; case ArgumentKind::POSITIONAL: return "POSITIONAL";
case ArgumentKind::KEYWORD: return "KEYWORD"; case ArgumentKind::KEYWORD: return "KEYWORD";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
@ -198,7 +199,7 @@ inline void parseEnum(std::string_view s, ArgumentKind& t) {
if (s == "UNKNOWN") { t = ArgumentKind::UNKNOWN; return; } if (s == "UNKNOWN") { t = ArgumentKind::UNKNOWN; return; }
if (s == "POSITIONAL") { t = ArgumentKind::POSITIONAL; return; } if (s == "POSITIONAL") { t = ArgumentKind::POSITIONAL; return; }
if (s == "KEYWORD") { t = ArgumentKind::KEYWORD; return; } if (s == "KEYWORD") { t = ArgumentKind::KEYWORD; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
enum class Layout { enum class Layout {
@ -223,7 +224,7 @@ inline std::string_view printEnum(const Layout& e) {
case Layout::_mkldnn: return "_mkldnn"; case Layout::_mkldnn: return "_mkldnn";
case Layout::Strided: return "Strided"; case Layout::Strided: return "Strided";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
@ -236,7 +237,7 @@ inline void parseEnum(std::string_view s, Layout& t) {
if (s == "SparseBsc") { t = Layout::SparseBsc; return; } if (s == "SparseBsc") { t = Layout::SparseBsc; return; }
if (s == "_mkldnn") { t = Layout::_mkldnn; return; } if (s == "_mkldnn") { t = Layout::_mkldnn; return; }
if (s == "Strided") { t = Layout::Strided; return; } if (s == "Strided") { t = Layout::Strided; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
enum class MemoryFormat { enum class MemoryFormat {
@ -255,7 +256,7 @@ inline std::string_view printEnum(const MemoryFormat& e) {
case MemoryFormat::ChannelsLast3d: return "ChannelsLast3d"; case MemoryFormat::ChannelsLast3d: return "ChannelsLast3d";
case MemoryFormat::PreserveFormat: return "PreserveFormat"; case MemoryFormat::PreserveFormat: return "PreserveFormat";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
@ -265,7 +266,7 @@ inline void parseEnum(std::string_view s, MemoryFormat& t) {
if (s == "ChannelsLast") { t = MemoryFormat::ChannelsLast; return; } if (s == "ChannelsLast") { t = MemoryFormat::ChannelsLast; return; }
if (s == "ChannelsLast3d") { t = MemoryFormat::ChannelsLast3d; return; } if (s == "ChannelsLast3d") { t = MemoryFormat::ChannelsLast3d; return; }
if (s == "PreserveFormat") { t = MemoryFormat::PreserveFormat; return; } if (s == "PreserveFormat") { t = MemoryFormat::PreserveFormat; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
enum class ScalarType { enum class ScalarType {
@ -312,7 +313,7 @@ inline std::string_view printEnum(const ScalarType& e) {
case ScalarType::FLOAT8E4M3FNUZ: return "FLOAT8E4M3FNUZ"; case ScalarType::FLOAT8E4M3FNUZ: return "FLOAT8E4M3FNUZ";
case ScalarType::FLOAT8E5M2FNUZ: return "FLOAT8E5M2FNUZ"; case ScalarType::FLOAT8E5M2FNUZ: return "FLOAT8E5M2FNUZ";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
@ -336,7 +337,7 @@ inline void parseEnum(std::string_view s, ScalarType& t) {
if (s == "FLOAT8E5M2") { t = ScalarType::FLOAT8E5M2; return; } if (s == "FLOAT8E5M2") { t = ScalarType::FLOAT8E5M2; return; }
if (s == "FLOAT8E4M3FNUZ") { t = ScalarType::FLOAT8E4M3FNUZ; return; } if (s == "FLOAT8E4M3FNUZ") { t = ScalarType::FLOAT8E4M3FNUZ; return; }
if (s == "FLOAT8E5M2FNUZ") { t = ScalarType::FLOAT8E5M2FNUZ; return; } if (s == "FLOAT8E5M2FNUZ") { t = ScalarType::FLOAT8E5M2FNUZ; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -453,7 +454,7 @@ inline std::string_view printEnum(const SymExprHint::Tag& e) {
case SymExprHint::Tag::AS_BOOL: return "AS_BOOL"; case SymExprHint::Tag::AS_BOOL: return "AS_BOOL";
case SymExprHint::Tag::AS_FLOAT: return "AS_FLOAT"; case SymExprHint::Tag::AS_FLOAT: return "AS_FLOAT";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
@ -461,7 +462,7 @@ inline void parseEnum(std::string_view s, SymExprHint::Tag& t) {
if (s == "AS_INT") { t = SymExprHint::Tag::AS_INT; return; } if (s == "AS_INT") { t = SymExprHint::Tag::AS_INT; return; }
if (s == "AS_BOOL") { t = SymExprHint::Tag::AS_BOOL; return; } if (s == "AS_BOOL") { t = SymExprHint::Tag::AS_BOOL; return; }
if (s == "AS_FLOAT") { t = SymExprHint::Tag::AS_FLOAT; return; } if (s == "AS_FLOAT") { t = SymExprHint::Tag::AS_FLOAT; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -559,14 +560,14 @@ inline std::string_view printEnum(const SymInt::Tag& e) {
case SymInt::Tag::AS_EXPR: return "AS_EXPR"; case SymInt::Tag::AS_EXPR: return "AS_EXPR";
case SymInt::Tag::AS_INT: return "AS_INT"; case SymInt::Tag::AS_INT: return "AS_INT";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
inline void parseEnum(std::string_view s, SymInt::Tag& t) { inline void parseEnum(std::string_view s, SymInt::Tag& t) {
if (s == "AS_EXPR") { t = SymInt::Tag::AS_EXPR; return; } if (s == "AS_EXPR") { t = SymInt::Tag::AS_EXPR; return; }
if (s == "AS_INT") { t = SymInt::Tag::AS_INT; return; } if (s == "AS_INT") { t = SymInt::Tag::AS_INT; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -637,14 +638,14 @@ inline std::string_view printEnum(const SymFloat::Tag& e) {
case SymFloat::Tag::AS_EXPR: return "AS_EXPR"; case SymFloat::Tag::AS_EXPR: return "AS_EXPR";
case SymFloat::Tag::AS_FLOAT: return "AS_FLOAT"; case SymFloat::Tag::AS_FLOAT: return "AS_FLOAT";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
inline void parseEnum(std::string_view s, SymFloat::Tag& t) { inline void parseEnum(std::string_view s, SymFloat::Tag& t) {
if (s == "AS_EXPR") { t = SymFloat::Tag::AS_EXPR; return; } if (s == "AS_EXPR") { t = SymFloat::Tag::AS_EXPR; return; }
if (s == "AS_FLOAT") { t = SymFloat::Tag::AS_FLOAT; return; } if (s == "AS_FLOAT") { t = SymFloat::Tag::AS_FLOAT; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -715,14 +716,14 @@ inline std::string_view printEnum(const SymBool::Tag& e) {
case SymBool::Tag::AS_EXPR: return "AS_EXPR"; case SymBool::Tag::AS_EXPR: return "AS_EXPR";
case SymBool::Tag::AS_BOOL: return "AS_BOOL"; case SymBool::Tag::AS_BOOL: return "AS_BOOL";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
inline void parseEnum(std::string_view s, SymBool::Tag& t) { inline void parseEnum(std::string_view s, SymBool::Tag& t) {
if (s == "AS_EXPR") { t = SymBool::Tag::AS_EXPR; return; } if (s == "AS_EXPR") { t = SymBool::Tag::AS_EXPR; return; }
if (s == "AS_BOOL") { t = SymBool::Tag::AS_BOOL; return; } if (s == "AS_BOOL") { t = SymBool::Tag::AS_BOOL; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -865,14 +866,14 @@ inline std::string_view printEnum(const SymIntArgument::Tag& e) {
case SymIntArgument::Tag::AS_NAME: return "AS_NAME"; case SymIntArgument::Tag::AS_NAME: return "AS_NAME";
case SymIntArgument::Tag::AS_INT: return "AS_INT"; case SymIntArgument::Tag::AS_INT: return "AS_INT";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
inline void parseEnum(std::string_view s, SymIntArgument::Tag& t) { inline void parseEnum(std::string_view s, SymIntArgument::Tag& t) {
if (s == "AS_NAME") { t = SymIntArgument::Tag::AS_NAME; return; } if (s == "AS_NAME") { t = SymIntArgument::Tag::AS_NAME; return; }
if (s == "AS_INT") { t = SymIntArgument::Tag::AS_INT; return; } if (s == "AS_INT") { t = SymIntArgument::Tag::AS_INT; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -943,14 +944,14 @@ inline std::string_view printEnum(const SymFloatArgument::Tag& e) {
case SymFloatArgument::Tag::AS_NAME: return "AS_NAME"; case SymFloatArgument::Tag::AS_NAME: return "AS_NAME";
case SymFloatArgument::Tag::AS_FLOAT: return "AS_FLOAT"; case SymFloatArgument::Tag::AS_FLOAT: return "AS_FLOAT";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
inline void parseEnum(std::string_view s, SymFloatArgument::Tag& t) { inline void parseEnum(std::string_view s, SymFloatArgument::Tag& t) {
if (s == "AS_NAME") { t = SymFloatArgument::Tag::AS_NAME; return; } if (s == "AS_NAME") { t = SymFloatArgument::Tag::AS_NAME; return; }
if (s == "AS_FLOAT") { t = SymFloatArgument::Tag::AS_FLOAT; return; } if (s == "AS_FLOAT") { t = SymFloatArgument::Tag::AS_FLOAT; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -1021,14 +1022,14 @@ inline std::string_view printEnum(const SymBoolArgument::Tag& e) {
case SymBoolArgument::Tag::AS_NAME: return "AS_NAME"; case SymBoolArgument::Tag::AS_NAME: return "AS_NAME";
case SymBoolArgument::Tag::AS_BOOL: return "AS_BOOL"; case SymBoolArgument::Tag::AS_BOOL: return "AS_BOOL";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
inline void parseEnum(std::string_view s, SymBoolArgument::Tag& t) { inline void parseEnum(std::string_view s, SymBoolArgument::Tag& t) {
if (s == "AS_NAME") { t = SymBoolArgument::Tag::AS_NAME; return; } if (s == "AS_NAME") { t = SymBoolArgument::Tag::AS_NAME; return; }
if (s == "AS_BOOL") { t = SymBoolArgument::Tag::AS_BOOL; return; } if (s == "AS_BOOL") { t = SymBoolArgument::Tag::AS_BOOL; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -1135,14 +1136,14 @@ inline std::string_view printEnum(const OptionalTensorArgument::Tag& e) {
case OptionalTensorArgument::Tag::AS_TENSOR: return "AS_TENSOR"; case OptionalTensorArgument::Tag::AS_TENSOR: return "AS_TENSOR";
case OptionalTensorArgument::Tag::AS_NONE: return "AS_NONE"; case OptionalTensorArgument::Tag::AS_NONE: return "AS_NONE";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
inline void parseEnum(std::string_view s, OptionalTensorArgument::Tag& t) { inline void parseEnum(std::string_view s, OptionalTensorArgument::Tag& t) {
if (s == "AS_TENSOR") { t = OptionalTensorArgument::Tag::AS_TENSOR; return; } if (s == "AS_TENSOR") { t = OptionalTensorArgument::Tag::AS_TENSOR; return; }
if (s == "AS_NONE") { t = OptionalTensorArgument::Tag::AS_NONE; return; } if (s == "AS_NONE") { t = OptionalTensorArgument::Tag::AS_NONE; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -1769,7 +1770,7 @@ inline std::string_view printEnum(const Argument::Tag& e) {
case Argument::Tag::AS_OPTIONAL_TENSOR: return "AS_OPTIONAL_TENSOR"; case Argument::Tag::AS_OPTIONAL_TENSOR: return "AS_OPTIONAL_TENSOR";
case Argument::Tag::AS_COMPLEX: return "AS_COMPLEX"; case Argument::Tag::AS_COMPLEX: return "AS_COMPLEX";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
@ -1801,7 +1802,7 @@ inline void parseEnum(std::string_view s, Argument::Tag& t) {
if (s == "AS_SYM_FLOATS") { t = Argument::Tag::AS_SYM_FLOATS; return; } if (s == "AS_SYM_FLOATS") { t = Argument::Tag::AS_SYM_FLOATS; return; }
if (s == "AS_OPTIONAL_TENSOR") { t = Argument::Tag::AS_OPTIONAL_TENSOR; return; } if (s == "AS_OPTIONAL_TENSOR") { t = Argument::Tag::AS_OPTIONAL_TENSOR; return; }
if (s == "AS_COMPLEX") { t = Argument::Tag::AS_COMPLEX; return; } if (s == "AS_COMPLEX") { t = Argument::Tag::AS_COMPLEX; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -2127,7 +2128,7 @@ inline std::string_view printEnum(const ConstantValue::Tag& e) {
case ConstantValue::Tag::AS_STRING: return "AS_STRING"; case ConstantValue::Tag::AS_STRING: return "AS_STRING";
case ConstantValue::Tag::AS_BOOL: return "AS_BOOL"; case ConstantValue::Tag::AS_BOOL: return "AS_BOOL";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
@ -2137,7 +2138,7 @@ inline void parseEnum(std::string_view s, ConstantValue::Tag& t) {
if (s == "AS_FLOAT") { t = ConstantValue::Tag::AS_FLOAT; return; } if (s == "AS_FLOAT") { t = ConstantValue::Tag::AS_FLOAT; return; }
if (s == "AS_STRING") { t = ConstantValue::Tag::AS_STRING; return; } if (s == "AS_STRING") { t = ConstantValue::Tag::AS_STRING; return; }
if (s == "AS_BOOL") { t = ConstantValue::Tag::AS_BOOL; return; } if (s == "AS_BOOL") { t = ConstantValue::Tag::AS_BOOL; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -2465,7 +2466,7 @@ inline std::string_view printEnum(const InputSpec::Tag& e) {
case InputSpec::Tag::TOKEN: return "TOKEN"; case InputSpec::Tag::TOKEN: return "TOKEN";
case InputSpec::Tag::CONSTANT_INPUT: return "CONSTANT_INPUT"; case InputSpec::Tag::CONSTANT_INPUT: return "CONSTANT_INPUT";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
@ -2477,7 +2478,7 @@ inline void parseEnum(std::string_view s, InputSpec::Tag& t) {
if (s == "CUSTOM_OBJ") { t = InputSpec::Tag::CUSTOM_OBJ; return; } if (s == "CUSTOM_OBJ") { t = InputSpec::Tag::CUSTOM_OBJ; return; }
if (s == "TOKEN") { t = InputSpec::Tag::TOKEN; return; } if (s == "TOKEN") { t = InputSpec::Tag::TOKEN; return; }
if (s == "CONSTANT_INPUT") { t = InputSpec::Tag::CONSTANT_INPUT; return; } if (s == "CONSTANT_INPUT") { t = InputSpec::Tag::CONSTANT_INPUT; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }
@ -2851,7 +2852,7 @@ inline std::string_view printEnum(const OutputSpec::Tag& e) {
case OutputSpec::Tag::TOKEN: return "TOKEN"; case OutputSpec::Tag::TOKEN: return "TOKEN";
case OutputSpec::Tag::PARAMETER_MUTATION: return "PARAMETER_MUTATION"; case OutputSpec::Tag::PARAMETER_MUTATION: return "PARAMETER_MUTATION";
default: default:
throw std::runtime_error("Unknown enum value"); TORCH_CHECK(false, "Unknown enum value");
} }
} }
@ -2864,7 +2865,7 @@ inline void parseEnum(std::string_view s, OutputSpec::Tag& t) {
if (s == "USER_INPUT_MUTATION") { t = OutputSpec::Tag::USER_INPUT_MUTATION; return; } if (s == "USER_INPUT_MUTATION") { t = OutputSpec::Tag::USER_INPUT_MUTATION; return; }
if (s == "TOKEN") { t = OutputSpec::Tag::TOKEN; return; } if (s == "TOKEN") { t = OutputSpec::Tag::TOKEN; return; }
if (s == "PARAMETER_MUTATION") { t = OutputSpec::Tag::PARAMETER_MUTATION; return; } if (s == "PARAMETER_MUTATION") { t = OutputSpec::Tag::PARAMETER_MUTATION; return; }
throw std::runtime_error("Unknown enum value: " + std::string{s}); TORCH_CHECK(false, "Unknown enum value: " + std::string{s});
} }

View File

@ -131,9 +131,8 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
size(0), size(0),
default_scalar(0) { default_scalar(0) {
auto space = fmt.find(' '); auto space = fmt.find(' ');
if (space == std::string::npos) { TORCH_CHECK(
throw std::runtime_error("FunctionParameter(): missing type: " + fmt); space != std::string::npos, "FunctionParameter(): missing type: " + fmt);
}
auto type_str = fmt.substr(0, space); auto type_str = fmt.substr(0, space);
@ -154,10 +153,9 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
auto name_str = fmt.substr(space + 1); auto name_str = fmt.substr(space + 1);
auto it = type_map.find(type_str); auto it = type_map.find(type_str);
if (it == type_map.end()) { TORCH_CHECK(
throw std::runtime_error( it != type_map.end(),
"FunctionParameter(): invalid type string: " + type_str); "FunctionParameter(): invalid type string: " + type_str);
}
type_ = it->second; type_ = it->second;
auto eq = name_str.find('='); auto eq = name_str.find('=');
@ -1145,7 +1143,7 @@ auto FunctionParameter::_check(
case ParameterType::DISPATCH_KEY_SET: case ParameterType::DISPATCH_KEY_SET:
return py::isinstance<c10::DispatchKeySet>(py::handle(obj)); return py::isinstance<c10::DispatchKeySet>(py::handle(obj));
default: default:
throw std::runtime_error("unknown parameter type"); TORCH_CHECK(false, "unknown parameter type");
} }
} }
@ -1202,7 +1200,7 @@ std::string FunctionParameter::type_name() const {
case ParameterType::DISPATCH_KEY_SET: case ParameterType::DISPATCH_KEY_SET:
return "DispatchKeySet"; return "DispatchKeySet";
default: default:
throw std::runtime_error("unknown parameter type"); TORCH_CHECK(false, "unknown parameter type");
} }
} }
@ -1324,10 +1322,8 @@ void FunctionParameter::set_default_str(const std::string& str) {
} }
if (type_ == ParameterType::TENSOR || if (type_ == ParameterType::TENSOR ||
type_ == ParameterType::DISPATCH_KEY_SET) { type_ == ParameterType::DISPATCH_KEY_SET) {
if (str != "None") { TORCH_CHECK(
throw std::runtime_error( str == "None", "default value for Tensor must be none, got: " + str);
"default value for Tensor must be none, got: " + str);
}
} else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) { } else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
default_int = atol(str.c_str()); default_int = atol(str.c_str());
} else if (type_ == ParameterType::BOOL) { } else if (type_ == ParameterType::BOOL) {
@ -1351,16 +1347,14 @@ void FunctionParameter::set_default_str(const std::string& str) {
default_intlist = parse_intlist_args(str, size); default_intlist = parse_intlist_args(str, size);
} }
} else if (type_ == ParameterType::FLOAT_LIST) { } else if (type_ == ParameterType::FLOAT_LIST) {
if (str != "None") { TORCH_CHECK(str == "None", "Defaults not supported for float[]");
throw std::runtime_error("Defaults not supported for float[]");
}
} else if (type_ == ParameterType::SCALARTYPE) { } else if (type_ == ParameterType::SCALARTYPE) {
if (str == "None") { if (str == "None") {
default_scalartype = at::ScalarType::Undefined; default_scalartype = at::ScalarType::Undefined;
} else if (str == "torch.int64") { } else if (str == "torch.int64") {
default_scalartype = at::ScalarType::Long; default_scalartype = at::ScalarType::Long;
} else { } else {
throw std::runtime_error("invalid default value for ScalarType: " + str); TORCH_CHECK(false, "invalid default value for ScalarType: " + str);
} }
} else if (type_ == ParameterType::LAYOUT) { } else if (type_ == ParameterType::LAYOUT) {
if (str == "None") { if (str == "None") {
@ -1370,16 +1364,12 @@ void FunctionParameter::set_default_str(const std::string& str) {
} else if (str == "torch.sparse_coo") { } else if (str == "torch.sparse_coo") {
default_layout = at::Layout::Sparse; default_layout = at::Layout::Sparse;
} else { } else {
throw std::runtime_error("invalid default value for layout: " + str); TORCH_CHECK(false, "invalid default value for layout: " + str);
} }
} else if (type_ == ParameterType::DEVICE) { } else if (type_ == ParameterType::DEVICE) {
if (str != "None") { TORCH_CHECK(str == "None", "invalid device: " + str);
throw std::runtime_error("invalid device: " + str);
}
} else if (type_ == ParameterType::STREAM) { } else if (type_ == ParameterType::STREAM) {
if (str != "None") { TORCH_CHECK(str == "None", "invalid stream: " + str);
throw std::runtime_error("invalid stream: " + str);
}
} else if (type_ == ParameterType::STRING) { } else if (type_ == ParameterType::STRING) {
if (str != "None") { if (str != "None") {
default_string = parse_string_literal(str); default_string = parse_string_literal(str);
@ -1408,7 +1398,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
} else if (type_ == ParameterType::QSCHEME) { // NOLINT } else if (type_ == ParameterType::QSCHEME) { // NOLINT
// throw std::runtime_error("ParameterType::QSCHEME"); // throw std::runtime_error("ParameterType::QSCHEME");
} else { } else {
throw std::runtime_error("unknown parameter type"); TORCH_CHECK(false, "unknown parameter type");
} }
default_value = str; default_value = str;
} }
@ -1423,7 +1413,7 @@ FunctionSignature::FunctionSignature(const std::string& fmt, int index)
deprecated(false) { deprecated(false) {
auto open_paren = fmt.find('('); auto open_paren = fmt.find('(');
if (open_paren == std::string::npos) { if (open_paren == std::string::npos) {
throw std::runtime_error("missing opening parenthesis: " + fmt); TORCH_CHECK(false, "missing opening parenthesis: " + fmt);
} }
name = fmt.substr(0, open_paren); name = fmt.substr(0, open_paren);
@ -1445,12 +1435,9 @@ FunctionSignature::FunctionSignature(const std::string& fmt, int index)
break; break;
} }
} }
if (offset == std::string::npos) { TORCH_CHECK(
throw std::runtime_error("missing closing parenthesis: " + fmt); offset != std::string::npos, "missing closing parenthesis: " + fmt);
} TORCH_CHECK(offset != last_offset, "malformed signature: " + fmt);
if (offset == last_offset) {
throw std::runtime_error("malformed signature: " + fmt);
}
auto param_str = fmt.substr(last_offset, offset - last_offset); auto param_str = fmt.substr(last_offset, offset - last_offset);
last_offset = next_offset; last_offset = next_offset;

View File

@ -120,7 +120,7 @@ inline bool THPUtils_unpackBool(PyObject* obj) {
} else if (obj == Py_False) { } else if (obj == Py_False) {
return false; return false;
} else { } else {
throw std::runtime_error("couldn't convert python object to boolean"); TORCH_CHECK(false, "couldn't convert python object to boolean");
} }
} }
@ -199,13 +199,11 @@ inline c10::DeviceIndex THPUtils_unpackDeviceIndex(PyObject* obj) {
if (value == -1 && PyErr_Occurred()) { if (value == -1 && PyErr_Occurred()) {
throw python_error(); throw python_error();
} }
if (overflow != 0) { TORCH_CHECK(overflow == 0, "Overflow when unpacking DeviceIndex");
throw std::runtime_error("Overflow when unpacking DeviceIndex"); TORCH_CHECK(
} value <= std::numeric_limits<c10::DeviceIndex>::max() &&
if (value > std::numeric_limits<c10::DeviceIndex>::max() || value >= std::numeric_limits<c10::DeviceIndex>::min(),
value < std::numeric_limits<c10::DeviceIndex>::min()) { "Overflow when unpacking DeviceIndex");
throw std::runtime_error("Overflow when unpacking DeviceIndex");
}
return (c10::DeviceIndex)value; return (c10::DeviceIndex)value;
} }

View File

@ -101,7 +101,7 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) {
at::convert<at::Float8_e8m0fnu, double>(THPUtils_unpackDouble(obj)); at::convert<at::Float8_e8m0fnu, double>(THPUtils_unpackDouble(obj));
break; break;
default: default:
throw std::runtime_error("store_scalar: invalid type"); TORCH_CHECK(false, "store_scalar: invalid type");
} }
} }
@ -165,7 +165,7 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) {
return PyFloat_FromDouble( return PyFloat_FromDouble(
at::convert<double, at::Float8_e8m0fnu>(*(at::Float8_e8m0fnu*)data)); at::convert<double, at::Float8_e8m0fnu>(*(at::Float8_e8m0fnu*)data));
default: default:
throw std::runtime_error("load_scalar: invalid type"); TORCH_CHECK(false, "load_scalar: invalid type");
} }
} }

View File

@ -26,12 +26,10 @@ inline std::string THPUtils_unpackString(PyObject* obj) {
if (PyUnicode_Check(obj)) { if (PyUnicode_Check(obj)) {
Py_ssize_t size = 0; Py_ssize_t size = 0;
const char* data = PyUnicode_AsUTF8AndSize(obj, &size); const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
if (!data) { TORCH_CHECK(data, "error unpacking string as utf-8");
throw std::runtime_error("error unpacking string as utf-8");
}
return std::string(data, (size_t)size); return std::string(data, (size_t)size);
} }
throw std::runtime_error("unpackString: expected bytes or unicode object"); TORCH_CHECK(false, "unpackString: expected bytes or unicode object");
} }
// Unpacks PyBytes (PyString) or PyUnicode as std::string_view // Unpacks PyBytes (PyString) or PyUnicode as std::string_view
@ -50,12 +48,10 @@ inline std::string_view THPUtils_unpackStringView(PyObject* obj) {
if (PyUnicode_Check(obj)) { if (PyUnicode_Check(obj)) {
Py_ssize_t size = 0; Py_ssize_t size = 0;
const char* data = PyUnicode_AsUTF8AndSize(obj, &size); const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
if (!data) { TORCH_CHECK(data, "error unpacking string as utf-8");
throw std::runtime_error("error unpacking string as utf-8");
}
return std::string_view(data, (size_t)size); return std::string_view(data, (size_t)size);
} }
throw std::runtime_error("unpackString: expected bytes or unicode object"); TORCH_CHECK(false, "unpackString: expected bytes or unicode object");
} }
inline PyObject* THPUtils_packString(const char* str) { inline PyObject* THPUtils_packString(const char* str) {

View File

@ -689,7 +689,7 @@ Tensor legacy_sparse_tensor_generic_ctor_new(
return new_with_sizes( return new_with_sizes(
options, scalar_type, deviceOptional, r.symintlist(0)); options, scalar_type, deviceOptional, r.symintlist(0));
} }
throw std::runtime_error("new(): invalid arguments"); TORCH_CHECK(false, "new(): invalid arguments");
} }
// NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs // NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs
@ -808,7 +808,7 @@ static Tensor legacy_tensor_generic_ctor_new(
return legacy_new_from_sequence( return legacy_new_from_sequence(
options, scalar_type, deviceOptional, r.pyobject(0)); options, scalar_type, deviceOptional, r.pyobject(0));
} }
throw std::runtime_error("new(): invalid arguments"); TORCH_CHECK(false, "new(): invalid arguments");
} }
// Handles ONLY torch.Tensor // Handles ONLY torch.Tensor
@ -1072,7 +1072,7 @@ static Tensor sparse_compressed_tensor_ctor_worker(
values.options().layout(layout).pinned_memory(pin_memory)) values.options().layout(layout).pinned_memory(pin_memory))
.set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1)); .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1));
} }
throw std::runtime_error(name + ": invalid arguments"); TORCH_CHECK(false, name + ": invalid arguments");
} }
Tensor sparse_compressed_tensor_ctor( Tensor sparse_compressed_tensor_ctor(
@ -1274,7 +1274,7 @@ Tensor sparse_coo_tensor_ctor(
inferred_options.dtype(inferred_scalar_type).layout(at::kSparse)) inferred_options.dtype(inferred_scalar_type).layout(at::kSparse))
.set_requires_grad(r.toBool(ARG_REQUIRES_GRAD2)); .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD2));
} }
throw std::runtime_error("sparse_coo_tensor(): invalid arguments"); TORCH_CHECK(false, "sparse_coo_tensor(): invalid arguments");
} }
void _validate_sparse_coo_tensor_args( void _validate_sparse_coo_tensor_args(
@ -1504,7 +1504,7 @@ Tensor tensor_ctor(
new_tensor.set_requires_grad(args_requires_grad); new_tensor.set_requires_grad(args_requires_grad);
return new_tensor; return new_tensor;
} }
throw std::runtime_error("tensor(): invalid arguments"); TORCH_CHECK(false, "tensor(): invalid arguments");
} }
Tensor as_tensor( Tensor as_tensor(
@ -1523,7 +1523,7 @@ Tensor as_tensor(
/*copy_numpy=*/false, /*copy_numpy=*/false,
/*type_inference=*/type_inference); /*type_inference=*/type_inference);
} }
throw std::runtime_error("tensor(): invalid arguments"); TORCH_CHECK(false, "tensor(): invalid arguments");
} }
Tensor new_tensor( Tensor new_tensor(
@ -1561,7 +1561,7 @@ Tensor new_tensor(
new_tensor.set_requires_grad(args_requires_grad); new_tensor.set_requires_grad(args_requires_grad);
return new_tensor; return new_tensor;
} }
throw std::runtime_error("new_tensor(): invalid arguments"); TORCH_CHECK(false, "new_tensor(): invalid arguments");
} }
Tensor tensor_frombuffer( Tensor tensor_frombuffer(

View File

@ -9,32 +9,32 @@
namespace torch::utils { namespace torch::utils {
PyObject* tensor_to_numpy(const at::Tensor&, bool) { PyObject* tensor_to_numpy(const at::Tensor&, bool) {
throw std::runtime_error("PyTorch was compiled without NumPy support"); TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
} }
at::Tensor tensor_from_numpy( at::Tensor tensor_from_numpy(
PyObject* obj, PyObject* obj,
bool warn_if_not_writeable /*=true*/) { bool warn_if_not_writeable /*=true*/) {
throw std::runtime_error("PyTorch was compiled without NumPy support"); TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
} }
bool is_numpy_available() { bool is_numpy_available() {
throw std::runtime_error("PyTorch was compiled without NumPy support"); TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
} }
bool is_numpy_int(PyObject* obj) { bool is_numpy_int(PyObject* obj) {
throw std::runtime_error("PyTorch was compiled without NumPy support"); TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
} }
bool is_numpy_scalar(PyObject* obj) { bool is_numpy_scalar(PyObject* obj) {
throw std::runtime_error("PyTorch was compiled without NumPy support"); TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
} }
at::Tensor tensor_from_cuda_array_interface( at::Tensor tensor_from_cuda_array_interface(
PyObject* obj, PyObject* obj,
std::optional<c10::Device> device_opt) { std::optional<c10::Device> device_opt) {
throw std::runtime_error("PyTorch was compiled without NumPy support"); TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
} }
void warn_numpy_not_writeable() { void warn_numpy_not_writeable() {
throw std::runtime_error("PyTorch was compiled without NumPy support"); TORCH_CHECK(false, "PyTorch was compiled without NumPy support");
} }
// No-op stubs. // No-op stubs.
@ -215,9 +215,7 @@ void warn_numpy_not_writeable() {
at::Tensor tensor_from_numpy( at::Tensor tensor_from_numpy(
PyObject* obj, PyObject* obj,
bool warn_if_not_writeable /*=true*/) { bool warn_if_not_writeable /*=true*/) {
if (!is_numpy_available()) { TORCH_CHECK(is_numpy_available(), "Numpy is not available");
throw std::runtime_error("Numpy is not available");
}
TORCH_CHECK_TYPE( TORCH_CHECK_TYPE(
PyArray_Check(obj), PyArray_Check(obj),
"expected np.ndarray (got ", "expected np.ndarray (got ",
@ -385,9 +383,7 @@ bool is_numpy_scalar(PyObject* obj) {
at::Tensor tensor_from_cuda_array_interface( at::Tensor tensor_from_cuda_array_interface(
PyObject* obj, PyObject* obj,
std::optional<c10::Device> device_opt) { std::optional<c10::Device> device_opt) {
if (!is_numpy_available()) { TORCH_CHECK(is_numpy_available(), "Numpy is not available");
throw std::runtime_error("Numpy is not available");
}
auto cuda_dict = auto cuda_dict =
THPObjectPtr(PyObject_GetAttrString(obj, "__cuda_array_interface__")); THPObjectPtr(PyObject_GetAttrString(obj, "__cuda_array_interface__"));
TORCH_INTERNAL_ASSERT(cuda_dict); TORCH_INTERNAL_ASSERT(cuda_dict);