mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Code Clean] Replace std::runtime_error with TORCH_CHECK (#163264)
Related ISSUE: https://github.com/pytorch/pytorch/issues/148114 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163264 Approved by: https://github.com/albanD, https://github.com/cyyever
This commit is contained in:
@ -245,13 +245,12 @@ static void general_trace_function(
|
||||
tracer::addInputs(
|
||||
node, args[i].name().c_str(), iter->toBoolList().vec());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"unsupported input list type: " + elem_type->str());
|
||||
TORCH_CHECK(false, "unsupported input list type: ", elem_type->str());
|
||||
}
|
||||
} else if (iter->isObject()) {
|
||||
tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
|
||||
} else {
|
||||
throw std::runtime_error("unsupported input type: " + type->str());
|
||||
TORCH_CHECK(false, "unsupported input type: ", type->str());
|
||||
}
|
||||
}
|
||||
graph->insertNode(node);
|
||||
@ -277,16 +276,19 @@ static void general_trace_function(
|
||||
AT_ASSERT(iter->isTensorList());
|
||||
tracer::addOutput(node, iter->toTensorList());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"unsupported output list type: " + elem_type->str());
|
||||
TORCH_CHECK(
|
||||
false, "unsupported output list type: ", elem_type->str());
|
||||
}
|
||||
} else if (type->kind() == TypeKind::ClassType) {
|
||||
AT_ASSERT(iter->isObject());
|
||||
tracer::addOutput(node, iter->toObject());
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"unsupported output type: " + type->str() +
|
||||
", from operator: " + toString(op.operator_name()));
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"unsupported output type: ",
|
||||
type->str(),
|
||||
", from operator: ",
|
||||
toString(op.operator_name()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -11,10 +11,8 @@ void check_single_result(
|
||||
const at::TensorBase& value,
|
||||
const at::TensorBase& result,
|
||||
const std::string& hook_name) {
|
||||
if (!value.defined()) {
|
||||
throw std::runtime_error(
|
||||
"can't replace a empty gradient with a non-empty value");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
value.defined(), "can't replace a empty gradient with a non-empty value");
|
||||
torch::autograd::check_variable_result(value, result, hook_name);
|
||||
}
|
||||
} // namespace
|
||||
|
@ -482,30 +482,31 @@ void check_variable_result(
|
||||
const at::TensorBase& original,
|
||||
const at::TensorBase& result,
|
||||
const std::string& hook_name) {
|
||||
if (!original.options().type_equal(result.options())) {
|
||||
std::stringstream ss;
|
||||
ss << "hook '" << hook_name << "' has changed the type of value (";
|
||||
ss << "was " << original.toString() << " got ";
|
||||
ss << result.toString() << ")";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
TORCH_CHECK(
|
||||
original.options().type_equal(result.options()),
|
||||
"hook '",
|
||||
hook_name,
|
||||
"' has changed the type of value (was ",
|
||||
original.toString(),
|
||||
" got ",
|
||||
result.toString(),
|
||||
")");
|
||||
|
||||
if (original.is_cuda() != result.is_cuda()) {
|
||||
std::stringstream ss;
|
||||
ss << "hook '" << hook_name << "' has changed the type of value";
|
||||
if (original.is_cuda()) {
|
||||
ss << " (was CUDA tensor got CPU tensor)";
|
||||
} else {
|
||||
ss << " (was CPU tensor got CUDA tensor)";
|
||||
}
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
TORCH_CHECK(
|
||||
original.is_cuda() == result.is_cuda(),
|
||||
"hook '",
|
||||
hook_name,
|
||||
"' has changed the type of value (was ",
|
||||
original.is_cuda() ? "CUDA tensor" : "CPU tensor",
|
||||
" got ",
|
||||
result.is_cuda() ? "CUDA tensor" : "CPU tensor",
|
||||
")");
|
||||
|
||||
if (original.sym_sizes().vec() != result.sym_sizes().vec()) {
|
||||
std::stringstream ss;
|
||||
ss << "hook '" << hook_name << "' has changed the size of value";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
TORCH_CHECK(
|
||||
original.sym_sizes().vec() == result.sym_sizes().vec(),
|
||||
"hook '",
|
||||
hook_name,
|
||||
"' has changed the size of value");
|
||||
}
|
||||
|
||||
AutogradContext::AutogradContext(PackedArgs& packed_args) {
|
||||
|
@ -228,30 +228,32 @@ inline variable_list CppNode_apply_functional(
|
||||
}
|
||||
}
|
||||
|
||||
if (num_outputs != num_forward_inputs) {
|
||||
std::string msg("function ");
|
||||
msg += name + " returned an incorrect number of gradients (expected ";
|
||||
msg += std::to_string(num_forward_inputs) + ", got ";
|
||||
msg += std::to_string(num_outputs) + ")";
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
num_outputs == num_forward_inputs,
|
||||
"function ",
|
||||
name,
|
||||
" returned an incorrect number of gradients (expected ",
|
||||
num_forward_inputs,
|
||||
", got ",
|
||||
num_outputs,
|
||||
")");
|
||||
|
||||
variable_list results;
|
||||
results.reserve(num_outputs);
|
||||
for (const auto i : c10::irange(num_outputs)) {
|
||||
if (!is_variable_input_[i]) {
|
||||
if (outputs[i].defined()) {
|
||||
std::string msg("function ");
|
||||
msg += name +
|
||||
" returned a gradient different that is defined at position ";
|
||||
msg += std::to_string(i + 1) +
|
||||
", std the corresponding forward input was not a Variable";
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
outputs[i].defined() == false,
|
||||
"function ",
|
||||
name,
|
||||
" returned a gradient different that is defined at position ",
|
||||
i + 1,
|
||||
", std the corresponding forward input was not a Variable");
|
||||
continue;
|
||||
}
|
||||
results.emplace_back(outputs[i]);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
@ -707,9 +707,8 @@ void GraphTask::mark_as_completed_and_run_post_processing() {
|
||||
}
|
||||
|
||||
void GraphTask::exec_post_processing() {
|
||||
if (!not_ready_.empty()) {
|
||||
throw std::runtime_error("could not compute gradients for some functions");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
not_ready_.empty(), "could not compute gradients for some functions");
|
||||
|
||||
// set the thread_local current_graph_task_ as more callbacks can be installed
|
||||
// by existing final callbacks.
|
||||
@ -1149,12 +1148,13 @@ void Engine::evaluate_function(
|
||||
for (const auto i : c10::irange(num_outputs)) {
|
||||
auto& output = outputs[i];
|
||||
at::OptionalDeviceGuard guard(device_of(output));
|
||||
if (output.defined() && isnan(output)._is_any_true().item<bool>()) {
|
||||
std::stringstream ss;
|
||||
ss << "Function '" << fn.name() << "' returned nan values in its " << i
|
||||
<< "th output.";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
TORCH_CHECK(
|
||||
!output.defined() || !isnan(output)._is_any_true().item<bool>(),
|
||||
"Function '",
|
||||
fn.name(),
|
||||
"' returned nan values in its ",
|
||||
i,
|
||||
"th output.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -1175,7 +1175,7 @@ void Engine::evaluate_function(
|
||||
|
||||
if (it == dependencies.end()) {
|
||||
auto name = next.function->name();
|
||||
throw std::runtime_error(std::string("dependency not found for ") + name);
|
||||
TORCH_CHECK(false, "dependency not found for ", name);
|
||||
} else if (--it->second == 0) {
|
||||
dependencies.erase(it);
|
||||
is_ready = true;
|
||||
|
@ -17,7 +17,7 @@ variable_list Error::apply(variable_list&& inputs) {
|
||||
}
|
||||
|
||||
variable_list Error::apply(variable_list&& inputs) const {
|
||||
throw std::runtime_error(msg);
|
||||
TORCH_CHECK(false, msg);
|
||||
}
|
||||
|
||||
void Error::compiled_args(CompiledNodeArgs& args) const {
|
||||
|
@ -47,7 +47,7 @@ struct UndefinedGradCtor {
|
||||
|
||||
struct NoCtor {
|
||||
Node* operator()(PyObject* args) {
|
||||
throw std::runtime_error("Cannot construct");
|
||||
TORCH_CHECK(false, "Cannot construct");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -184,9 +184,7 @@ inline variable_list CopySlices::apply_impl(
|
||||
// see Note [Thread Safety on Autograd Node]
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
if (!fn) {
|
||||
throw std::runtime_error(ERR_BACKWARD_TWICE);
|
||||
}
|
||||
TORCH_CHECK(fn, ERR_BACKWARD_TWICE);
|
||||
|
||||
auto result =
|
||||
grad.new_empty_strided_symint(base.sym_sizes(), base.sym_strides());
|
||||
@ -252,9 +250,7 @@ variable_list CopySlices::apply_with_saved(
|
||||
|
||||
auto results = variable_list(num_outputs());
|
||||
if (grads[0].defined()) {
|
||||
if (!fn) {
|
||||
throw std::runtime_error(ERR_BACKWARD_TWICE);
|
||||
}
|
||||
TORCH_CHECK(fn, ERR_BACKWARD_TWICE);
|
||||
update_exec_info();
|
||||
|
||||
std::vector<bool> needs_input_grad;
|
||||
|
@ -53,18 +53,22 @@ void check_input_variables(
|
||||
if (required_args == -1) {
|
||||
required_args = args;
|
||||
}
|
||||
if (inputs.size() != static_cast<size_t>(args)) {
|
||||
std::stringstream ss;
|
||||
ss << name << ": expected " << args << " arguments (got " << inputs.size();
|
||||
ss << ")";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
TORCH_CHECK(
|
||||
inputs.size() == static_cast<size_t>(args),
|
||||
name,
|
||||
": expected ",
|
||||
args,
|
||||
" arguments (got ",
|
||||
inputs.size(),
|
||||
")");
|
||||
|
||||
for (const auto i : c10::irange(required_args)) {
|
||||
if (!inputs[i].defined() && !allow_undefined) {
|
||||
std::stringstream ss;
|
||||
ss << name << ": expected Tensor at argument " << i << " (got None)";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
TORCH_CHECK(
|
||||
inputs[i].defined() || allow_undefined,
|
||||
name,
|
||||
": expected Tensor at argument ",
|
||||
i,
|
||||
" (got None)");
|
||||
}
|
||||
}
|
||||
} // namespace torch::autograd
|
||||
|
@ -37,7 +37,8 @@ extern "C" {
|
||||
// https://github.com/pytorch/pytorch/issues/51026
|
||||
__attribute__((weak)) int acc_get_device_type();
|
||||
__attribute__((weak)) int acc_get_device_type() {
|
||||
throw std::runtime_error(
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Dummy implementation of acc_get_device_type is not supposed to be called!");
|
||||
}
|
||||
} // extern "C"
|
||||
|
@ -97,7 +97,7 @@ struct TORCH_API LegacyEvent {
|
||||
case EventKind::MemoryAlloc:
|
||||
return "memory_alloc";
|
||||
}
|
||||
throw std::runtime_error("unknown event kind");
|
||||
TORCH_CHECK(false, "unknown event kind");
|
||||
}
|
||||
|
||||
EventKind kind() const {
|
||||
|
@ -30,7 +30,7 @@ void PyAnomalyMetadata::store_stack() {
|
||||
void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
if (!PyDict_Check(dict())) {
|
||||
throw std::runtime_error("Anomaly metadata is not a python dictionary.");
|
||||
TORCH_CHECK(false, "Anomaly metadata is not a python dictionary.");
|
||||
}
|
||||
PyObject* trace_stack = nullptr;
|
||||
if (PyDict_GetItemStringRef(dict(), ANOMALY_TRACE_KEY, &trace_stack) < 0) {
|
||||
|
@ -261,8 +261,7 @@ PyTypeObject* _initFunctionPyTypeObject(
|
||||
type.tp_traverse = THPCppFunction_traverse;
|
||||
type.tp_clear = THPCppFunction_clear;
|
||||
if (PyType_Ready(&type) < 0) {
|
||||
auto msg = std::string("Unable to instantiate PyTypeObject for ") + name;
|
||||
throw std::runtime_error(msg);
|
||||
TORCH_CHECK(false, "Unable to instantiate PyTypeObject for ", name);
|
||||
}
|
||||
return &type;
|
||||
}
|
||||
|
@ -501,7 +501,7 @@ static void child_atfork() {
|
||||
bool THPEngine_initModule(PyObject* module) {
|
||||
#ifndef _WIN32
|
||||
if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
|
||||
throw std::runtime_error("unable to set pthread_atfork handler");
|
||||
TORCH_CHECK(false, "unable to set pthread_atfork handler");
|
||||
}
|
||||
#endif
|
||||
if (PyType_Ready(&THPEngineType) < 0)
|
||||
|
@ -188,13 +188,15 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
|
||||
}
|
||||
|
||||
// Now the number of gradients should match
|
||||
if (num_outputs != num_forward_inputs) {
|
||||
std::string msg("function ");
|
||||
msg += name() + " returned an incorrect number of gradients (expected ";
|
||||
msg += std::to_string(num_forward_inputs) + ", got ";
|
||||
msg += std::to_string(num_outputs) + ")";
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
num_outputs == num_forward_inputs,
|
||||
"function ",
|
||||
name(),
|
||||
" returned an incorrect number of gradients (expected ",
|
||||
num_forward_inputs,
|
||||
", got ",
|
||||
num_outputs,
|
||||
")");
|
||||
|
||||
// Massage the Python results tuple back into a C++ variable_list
|
||||
return to_variable_list(r.get(), is_variable_input);
|
||||
@ -435,24 +437,24 @@ variable_list PyNode::to_variable_list(
|
||||
PyObject* output = PyTuple_GET_ITEM(outputs, i);
|
||||
bool was_variable = is_variable_input[i];
|
||||
if (!was_variable) {
|
||||
if (output != Py_None) {
|
||||
std::string msg("function ");
|
||||
msg += name() + " returned a gradient different than None at position ";
|
||||
msg += std::to_string(i + 1) +
|
||||
", but the corresponding forward input was not a Variable";
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
output == Py_None,
|
||||
"function ",
|
||||
name(),
|
||||
" returned a gradient different than None at position ",
|
||||
i + 1,
|
||||
", but the corresponding forward input was not a Variable");
|
||||
continue;
|
||||
}
|
||||
if (output == Py_None) {
|
||||
results.emplace_back();
|
||||
} else {
|
||||
if (!THPVariable_Check(output)) {
|
||||
std::string msg("expected Variable or None (got ");
|
||||
msg += THPUtils_typename(output);
|
||||
msg += ")";
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
THPVariable_Check(output),
|
||||
"expected Variable or None (got ",
|
||||
THPUtils_typename(output),
|
||||
")");
|
||||
|
||||
results.emplace_back(THPVariable_Unpack(output));
|
||||
}
|
||||
}
|
||||
|
@ -289,9 +289,7 @@ static variable_list unwrap_variables(PyObject* py_variables) {
|
||||
results[i] = THPVariable_Unpack(item);
|
||||
} else {
|
||||
// this should never happen, but just in case...
|
||||
std::stringstream ss;
|
||||
ss << "expected variable but got " << Py_TYPE(item)->tp_name;
|
||||
throw std::runtime_error(ss.str());
|
||||
TORCH_CHECK(false, "expected variable but got ", Py_TYPE(item)->tp_name);
|
||||
}
|
||||
}
|
||||
return results;
|
||||
@ -308,14 +306,16 @@ static void check_result(PyObject* prev, PyObject* result, PyObject* hook) {
|
||||
|
||||
auto prev_size = PyTuple_GET_SIZE(prev);
|
||||
auto result_size = PyTuple_GET_SIZE(result);
|
||||
if (prev_size != result_size) {
|
||||
std::stringstream ss;
|
||||
auto name = hook_name(hook);
|
||||
ss << "hook '" << name << "' has returned an incorrect number ";
|
||||
ss << "of values (got " << result_size << ", but expected ";
|
||||
ss << prev_size << ")";
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
prev_size == result_size,
|
||||
"hook '",
|
||||
hook_name(hook),
|
||||
"' has returned an incorrect number of values (got ",
|
||||
result_size,
|
||||
", but expected ",
|
||||
prev_size,
|
||||
")");
|
||||
|
||||
for (const auto i : c10::irange(prev_size)) {
|
||||
check_single_result(
|
||||
@ -330,10 +330,9 @@ static void check_single_result(
|
||||
if (_result == Py_None)
|
||||
return;
|
||||
|
||||
if (_original == Py_None) {
|
||||
throw std::runtime_error(
|
||||
"can't replace a None gradient with a non-None value");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
_original != Py_None,
|
||||
"can't replace a None gradient with a non-None value");
|
||||
|
||||
if (!PyObject_IsInstance(_result, THPVariableClass)) {
|
||||
PyErr_Format(
|
||||
|
@ -11,8 +11,8 @@ struct TORCH_API SavedVariableHooks {
|
||||
virtual ~SavedVariableHooks() = default;
|
||||
virtual std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
|
||||
retrieve_unpack_hook_data() const {
|
||||
throw std::runtime_error(
|
||||
"Compiled Autograd only supports python saved tensor hooks ");
|
||||
TORCH_CHECK(
|
||||
false, "Compiled Autograd only supports python saved tensor hooks ");
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -17,8 +17,8 @@ inline std::tuple<
|
||||
std::optional<at::MemoryFormat>>
|
||||
parse_to_conversion(PythonArgs& r, bool allow_copy) {
|
||||
if (r.idx == 0) {
|
||||
if (!allow_copy && !r.isNone(3))
|
||||
throw std::runtime_error(".to() does not accept copy argument");
|
||||
TORCH_CHECK(
|
||||
allow_copy || r.isNone(3), ".to() does not accept copy argument");
|
||||
return std::make_tuple(
|
||||
r.deviceOptional(0),
|
||||
r.scalartypeOptional(1),
|
||||
@ -26,8 +26,8 @@ parse_to_conversion(PythonArgs& r, bool allow_copy) {
|
||||
r.toBool(3),
|
||||
r.memoryformatOptional(4));
|
||||
} else if (r.idx == 1) {
|
||||
if (!allow_copy && !r.isNone(2))
|
||||
throw std::runtime_error(".to() does not accept copy argument");
|
||||
TORCH_CHECK(
|
||||
allow_copy || r.isNone(2), ".to() does not accept copy argument");
|
||||
return std::make_tuple(
|
||||
std::nullopt,
|
||||
r.scalartype(0),
|
||||
@ -36,8 +36,8 @@ parse_to_conversion(PythonArgs& r, bool allow_copy) {
|
||||
r.memoryformatOptional(3));
|
||||
} else {
|
||||
auto tensor = r.tensor(0);
|
||||
if (!allow_copy && !r.isNone(2))
|
||||
throw std::runtime_error(".to() does not accept copy argument");
|
||||
TORCH_CHECK(
|
||||
allow_copy || r.isNone(2), ".to() does not accept copy argument");
|
||||
return std::make_tuple(
|
||||
tensor.device(),
|
||||
tensor.scalar_type(),
|
||||
|
@ -597,10 +597,9 @@ void VariableHooks::_backward(
|
||||
void VariableHooks::requires_grad_(
|
||||
const at::TensorBase& self,
|
||||
bool _requires_grad) const {
|
||||
if (!self.is_leaf() && !_requires_grad) {
|
||||
throw std::runtime_error(
|
||||
autograd::utils::requires_grad_leaf_error(_requires_grad));
|
||||
}
|
||||
TORCH_CHECK(
|
||||
self.is_leaf() || _requires_grad,
|
||||
autograd::utils::requires_grad_leaf_error(_requires_grad));
|
||||
self.set_requires_grad(_requires_grad);
|
||||
}
|
||||
|
||||
@ -624,7 +623,7 @@ const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const {
|
||||
"Can't get base of non-backward view Tensor");
|
||||
return diff_view_meta->get_backward_view().base_;
|
||||
} else {
|
||||
throw std::runtime_error("Can't get base of non-view Tensor");
|
||||
TORCH_CHECK(false, "Can't get base of non-view Tensor");
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user