mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* Improve Variable interface * Address comments from @apaszke and @colesbury * string ::operator= is not noexcept * Remove ir.h from tracer_state.h to improve build times * Make Variable a struct and pack SavedVariable fields * Implement as_variable_ref * grad_fn_ptr() -> grad_fn_unsafe() * Reduce hackiness of set_type hack * Include variable.h and edge.h in tracer_state.h because it uses them * class Variable -> struct Variable because Windows cant even * Make Variable::output_nr uint32_t instead of int * Add comment about tracing state * Replaced more static_cast<Variable&> and improve docs * Remove SavedVariable destructor and construct members in init list * Clarify docs for Variable * Variable::set_version -> set_version_counter
57 lines
1.6 KiB
C++
57 lines
1.6 KiB
C++
#include "torch/csrc/autograd/functions/utils.h"
|
|
|
|
#include "torch/csrc/autograd/edge.h"
|
|
#include "torch/csrc/autograd/function.h"
|
|
#include "torch/csrc/autograd/variable.h"
|
|
|
|
#include <sstream>
|
|
#include <vector>
|
|
|
|
namespace torch { namespace autograd {
|
|
|
|
// Converts every tensor in `outputs` into a Variable and returns them in the
// same order. An undefined tensor yields a default-constructed entry.
//
// If any_variable_requires_grad(inputs) is false, each defined output is
// wrapped with requires_grad=false and `ctr` is never called.
//
// Otherwise `ctr` is invoked once with get_next_functions(inputs) to build the
// gradient function, and each defined output is wrapped with an Edge that
// refers to that function and to the value of its `num_inputs` counter at that
// moment. The counter is bumped for every output, defined or not, so the
// index stored in each Edge matches the output's position in `outputs`.
variable_list wrap_outputs(const variable_list& inputs, tensor_list&& outputs,
                           function_constructor ctr) {
  variable_list wrapped;
  wrapped.reserve(outputs.size());

  // No input requires grad: no gradient function is needed at all.
  if (!any_variable_requires_grad(inputs)) {
    for (auto& tensor : outputs) {
      if (!tensor.defined()) {
        wrapped.emplace_back();
        continue;
      }
      wrapped.push_back(make_variable(tensor, /*requires_grad=*/false));
    }
    return wrapped;
  }

  // At least one input requires grad: attach every output to one shared
  // gradient function.
  auto grad_fn = ctr(get_next_functions(inputs));
  for (auto& tensor : outputs) {
    if (!tensor.defined()) {
      // Undefined outputs still occupy an input slot on grad_fn.
      ++grad_fn->num_inputs;
      wrapped.emplace_back();
      continue;
    }
    wrapped.push_back(
        make_variable(tensor, Edge(grad_fn, grad_fn->num_inputs++)));
  }
  return wrapped;
}
|
|
|
|
void check_input_variables(const char* name, const variable_list& inputs, int args, int required_args) {
|
|
if (required_args == -1) {
|
|
required_args = args;
|
|
}
|
|
if (inputs.size() != (size_t)args) {
|
|
std::stringstream ss;
|
|
ss << name << ": expected " << args << " arguments (got " << inputs.size();
|
|
ss << ")";
|
|
throw std::runtime_error(ss.str());
|
|
}
|
|
for (int i = 0; i < required_args; ++i) {
|
|
if (!inputs[i].defined()) {
|
|
std::stringstream ss;
|
|
ss << name << ": expected Variable at argument " << i << " (got None)";
|
|
throw std::runtime_error(ss.str());
|
|
}
|
|
}
|
|
}
|
|
}}
|