Files
pytorch/torch/csrc/autograd/functions/basic_ops.cpp
Adam Paszke af21c6b018 Add Node type to JIT IR
Rewrite Type as a class hierarchy

PR comments + rebase fixes
2017-09-05 17:48:55 -04:00

65 lines
2.0 KiB
C++

#include "basic_ops.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/functions/utils.h"
#include "torch/csrc/utils/auto_gpu.h"
namespace torch { namespace autograd {
auto Error::apply(const variable_list& grad_outputs) -> variable_list {
throw std::runtime_error(msg);
};
/// Forward pass of DelayedError: an identity on the input data that defers
/// failure to the backward pass.
///
/// Each non-null input contributes its `data` tensor to the outputs, and a
/// null input contributes a default-constructed (undefined) at::Tensor. The
/// outputs are wrapped with an Error node built from this node's `msg`, so
/// the error is raised only if backward propagation reaches that node.
auto DelayedError::apply(const variable_list& inputs) -> variable_list {
  tensor_list passthrough;
  passthrough.reserve(inputs.size());
  for (const auto& input : inputs) {
    if (input) {
      passthrough.emplace_back(input->data);
    } else {
      passthrough.emplace_back();
    }
  }
  // Factory for the grad_fn attached to every output.
  auto make_error_fn = [&](FunctionFlags flags) {
    return std::make_shared<Error>(msg, std::move(flags));
  };
  return wrap_outputs(inputs, std::move(passthrough), make_error_fn);
}
/// Forward pass of element-wise addition of exactly two inputs.
///
/// The current device is selected from the first input's tensor. When that
/// first tensor is sparse, the operands are swapped so the `+` is evaluated
/// with the second tensor on the left (NOTE(review): presumably because the
/// sparse-on-the-left overload was unavailable; confirm). The result is
/// wrapped with an AddBackward grad_fn.
auto Add::apply(const variable_list& inputs) -> variable_list {
  check_input_variables("Add", inputs, 2);
  auto& lhs = inputs[0]->data;
  auto& rhs = inputs[1]->data;
  AutoGPU guard(lhs);
  at::Tensor sum = lhs.type().isSparse() ? rhs + lhs : lhs + rhs;
  auto make_backward = [&](FunctionFlags flags) {
    return std::make_shared<AddBackward>(std::move(flags));
  };
  return wrap_outputs(inputs, as_tensor_list(std::move(sum)), make_backward);
}
/// Backward pass of Add.
///
/// The derivative of a + b with respect to each operand is 1, so the single
/// incoming gradient is returned unchanged for both inputs.
auto AddBackward::apply(const variable_list& grad_outputs) -> variable_list {
  check_input_variables("AddBackward", grad_outputs, 1);
  const auto& grad = grad_outputs[0];
  return variable_list{grad, grad};
}
/// Forward pass of element-wise multiplication of exactly two inputs.
///
/// Both input variables are saved on the MulBackward grad_fn, presumably
/// for use when computing the gradient. NOTE(review): MulBackward::apply
/// currently throws "not implemented", so backward through Mul fails.
auto Mul::apply(const variable_list& inputs) -> variable_list {
  check_input_variables("Mul", inputs, 2);
  auto& input1 = inputs[0]->data;
  auto& input2 = inputs[1]->data;
  // Pick the device from the first input through AutoGPU's tensor overload,
  // exactly as Add::apply does, instead of hand-rolling the
  // "CUDA ? get_device() : -1" selection here.
  AutoGPU guard(input1);
  auto output = input1 * input2;
  return wrap_outputs(inputs, as_tensor_list(std::move(output)), [&](FunctionFlags f) {
    return std::make_shared<MulBackward>(std::move(f), inputs[0]->save(this), inputs[1]->save(this));
  });
}
/// Backward pass of Mul. Not implemented yet.
///
/// It first validates that exactly one gradient was supplied, then always
/// throws std::runtime_error.
/// TODO: return {grad * input2, grad * input1} using the inputs that
/// Mul::apply saved on this node.
auto MulBackward::apply(const variable_list& grad_outputs) -> variable_list {
  check_input_variables("MulBackward", grad_outputs, 1);
  throw std::runtime_error("MulBackward::apply not implemented");
}
}} // namespace torch::autograd