#include <c10/util/irange.h>
#include <torch/script.h>
#include <torch/cuda.h>

#include "op.h"

#include <memory>
#include <string>
#include <vector>

#include <iostream>

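// Helpers shared by the tests below: a parameter-predicate checker and a
// generic "look up an operator in the JIT registry and run it" wrapper.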
namespace helpers {

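// Applies `predicate` to every parameter of `module` and asserts that it
// holds.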
template <typename Predicate>
void check_all_parameters(
    const torch::jit::Module& module,
    Predicate predicate) {
  for (at::Tensor parameter : module.parameters()) {
    AT_ASSERT(predicate(parameter));
  }
}

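// Looks up `op_name` in the JIT operator registry, pushes `args` onto an
// interpreter stack, runs the operation, and pops the single result as
// `Result`.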
template <class Result, class... Args>
Result get_operator_from_registry_and_execute(
    const char* op_name,
    Args&&... args) {
  auto& ops = torch::jit::getAllOperatorsFor(
      torch::jit::Symbol::fromQualString(op_name));
  TORCH_INTERNAL_ASSERT(ops.size() == 1);

  auto& op = ops.front();
  TORCH_INTERNAL_ASSERT(op->schema().name() == op_name);

  torch::jit::Stack stack;
  torch::jit::push(stack, std::forward<Args>(args)...);
  op->getOperation()(stack);

  TORCH_INTERNAL_ASSERT(stack.size() == 1);
  return torch::jit::pop(stack).to<Result>();
}

} // namespace helpers

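// Runs "custom::op" through the registry helper and checks the result
// against a direct call to custom_op() (declared in op.h).
//
// For context, a minimal sketch of how such an op might be registered
// (the actual registration lives in op.cpp; this is illustrative, not a
// copy of it):
//
//   TORCH_LIBRARY(custom, m) {
//     m.def("op", custom_op);
//   }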
void get_operator_from_registry_and_execute() {
  std::vector<torch::Tensor> output =
      helpers::get_operator_from_registry_and_execute<std::vector<torch::Tensor>>(
          "custom::op", torch::ones(5), 2.0, 3);

  const auto manual = custom_op(torch::ones(5), 2.0, 3);

  TORCH_INTERNAL_ASSERT(output.size() == 3);
  for (const auto i : c10::irange(output.size())) {
    TORCH_INTERNAL_ASSERT(output[i].allclose(torch::ones(5) * 2));
    TORCH_INTERNAL_ASSERT(output[i].allclose(manual[i]));
  }
}

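// Runs "custom::op_with_autograd" forward and backward, first without and
// then with the optional tensor argument, checking outputs and gradients.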
void get_autograd_operator_from_registry_and_execute() {
  torch::Tensor x = torch::randn({5, 5}, torch::requires_grad());
  torch::Tensor y = torch::randn({5, 5}, torch::requires_grad());
  torch::Tensor z = torch::randn({5, 5}, torch::requires_grad());

  torch::Tensor output =
      helpers::get_operator_from_registry_and_execute<torch::Tensor>(
          "custom::op_with_autograd", x, 2, y, std::optional<torch::Tensor>());

  TORCH_INTERNAL_ASSERT(output.allclose(x + 2 * y + x * y));

  auto go = torch::ones({}, torch::requires_grad());
  output.sum().backward(go, false, true);

  TORCH_INTERNAL_ASSERT(torch::allclose(x.grad(), y + torch::ones({5, 5})));
  TORCH_INTERNAL_ASSERT(torch::allclose(y.grad(), x + torch::ones({5, 5}) * 2));

  // Test with optional argument.
  at::zero_(x.mutable_grad());
  at::zero_(y.mutable_grad());
  output = helpers::get_operator_from_registry_and_execute<torch::Tensor>(
      "custom::op_with_autograd", x, 2, y, z);

  TORCH_INTERNAL_ASSERT(output.allclose(x + 2 * y + x * y + z));

  go = torch::ones({}, torch::requires_grad());
  output.sum().backward(go, false, true);

  TORCH_INTERNAL_ASSERT(torch::allclose(x.grad(), y + torch::ones({5, 5})));
  TORCH_INTERNAL_ASSERT(torch::allclose(y.grad(), x + torch::ones({5, 5}) * 2));
  TORCH_INTERNAL_ASSERT(torch::allclose(z.grad(), torch::ones({5, 5})));
}

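// The same autograd operator, but dispatched below the autograd key, so no
// graph is recorded; only the forward result is checked.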
void get_autograd_operator_from_registry_and_execute_in_nograd_mode() {
  at::AutoDispatchBelowAutograd guard;

  torch::Tensor x = torch::randn({5, 5}, torch::requires_grad());
  torch::Tensor y = torch::randn({5, 5}, torch::requires_grad());

  torch::Tensor output =
      helpers::get_operator_from_registry_and_execute<torch::Tensor>(
          "custom::op_with_autograd", x, 2, y, std::optional<torch::Tensor>());

  TORCH_INTERNAL_ASSERT(output.allclose(x + 2 * y + x * y));
}

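// Loads a serialized TorchScript module that calls the custom op and runs
// its forward() on a ones tensor.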
void load_serialized_module_with_custom_op_and_execute(
    const std::string& path_to_exported_script_module) {
  torch::jit::Module module =
      torch::jit::load(path_to_exported_script_module);
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones(5));
  auto output = module.forward(inputs).toTensor();

  AT_ASSERT(output.allclose(torch::ones(5) + 1));
}

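// Checks that forward() on the deserialized module rejects the wrong number
// of arguments, a wrongly typed argument, and a missing argument, each with
// the expected error message.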
void test_argument_checking_for_serialized_modules(
    const std::string& path_to_exported_script_module) {
  torch::jit::Module module =
      torch::jit::load(path_to_exported_script_module);

  try {
    module.forward({torch::jit::IValue(1), torch::jit::IValue(2)});
    AT_ASSERT(false);
  } catch (const c10::Error& error) {
    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("Expected at most 2 argument(s) for operator 'forward', "
                  "but received 3 argument(s)") == 0);
  }

  try {
    module.forward({torch::jit::IValue(5)});
    AT_ASSERT(false);
  } catch (const c10::Error& error) {
    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("forward() Expected a value of type 'Tensor' "
                  "for argument 'input' but instead found type 'int'") == 0);
  }

  try {
    module.forward({});
    AT_ASSERT(false);
  } catch (const c10::Error& error) {
    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("forward() is missing value for argument 'input'") == 0);
  }
}

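// Moves the module's parameters CPU -> CUDA -> CPU, verifying the device of
// every parameter after each move. Only called from main() when a CUDA
// device is available.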
void test_move_to_device(const std::string& path_to_exported_script_module) {
  torch::jit::Module module =
      torch::jit::load(path_to_exported_script_module);

  helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
    return tensor.device().is_cpu();
  });

  module.to(torch::kCUDA);

  helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
    return tensor.device().is_cuda();
  });

  module.to(torch::kCPU);

  helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
    return tensor.device().is_cpu();
  });
}

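// Converts the module's parameters to float16 and then to double, verifying
// the dtype of every parameter after each conversion.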
void test_move_to_dtype(const std::string& path_to_exported_script_module) {
  torch::jit::Module module =
      torch::jit::load(path_to_exported_script_module);

  module.to(torch::kFloat16);

  helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
    return tensor.dtype() == torch::kFloat16;
  });

  module.to(torch::kDouble);

  helpers::check_all_parameters(module, [](const torch::Tensor& tensor) {
    return tensor.dtype() == torch::kDouble;
  });
}

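// Entry point. Takes a single argument: the path to an exported TorchScript
// module that uses the custom op. For example (assuming such a module has
// been exported to model.pt by an accompanying export script):
//
//   ./test_custom_ops model.pt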
int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: test_custom_ops <path-to-exported-script-module>\n";
    return -1;
  }
  const std::string path_to_exported_script_module = argv[1];

  get_operator_from_registry_and_execute();
  get_autograd_operator_from_registry_and_execute();
  get_autograd_operator_from_registry_and_execute_in_nograd_mode();
  load_serialized_module_with_custom_op_and_execute(
      path_to_exported_script_module);
  test_argument_checking_for_serialized_modules(path_to_exported_script_module);
  test_move_to_dtype(path_to_exported_script_module);

  if (torch::cuda::device_count() > 0) {
    test_move_to_device(path_to_exported_script_module);
  }

  std::cout << "ok\n";
}