mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[vulkan] Allow benchmark binary to handle non-single tensor inputs/outputs for Vulkan models (#73109)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73109. This change updates the Vulkan model runner in `speed_benchmark_torch` so that it can generate inputs for, and handle outputs from, models whose input/output types are something other than a single tensor. Each input element is processed according to its type. Test Plan: Imported from OSS. Reviewed By: mikaylagawarecki. Differential Revision: D34354839. Pulled By: SS-JIA. fbshipit-source-id: 993e55372d2664fa7eddb16146deba264727f399 (cherry picked from commit 4a140202acb336412676ac090a38d7b93ae49898)
This commit is contained in:
committed by
PyTorch MergeBot
parent
bdc8b3f3e8
commit
52175307e2
@ -180,35 +180,48 @@ class vkRunner final : public Runner<T> {
|
||||
virtual c10::IValue run(
|
||||
T& module,
|
||||
const std::vector<c10::IValue>& inputs) override {
|
||||
// Upload the input tensor(s) to GPU memory.
|
||||
inputs_.clear();
|
||||
inputs_.reserve(inputs.size());
|
||||
for (const auto& input : inputs) {
|
||||
if (input.isTensor()) {
|
||||
inputs_.emplace_back(input.toTensor().vulkan());
|
||||
}
|
||||
else if (input.isList()) {
|
||||
const c10::List<c10::IValue> input_as_list = input.toList();
|
||||
c10::List<at::Tensor> input_vk_list;
|
||||
input_vk_list.reserve(input_as_list.size());
|
||||
for (int i=0; i < input_as_list.size(); ++i) {
|
||||
const c10::IValue element = input_as_list.get(i);
|
||||
if (element.isTensor()) {
|
||||
input_vk_list.emplace_back(element.toTensor().vulkan());
|
||||
}
|
||||
else {
|
||||
CAFFE_THROW("Input of type c10::List must only contain Tensors!");
|
||||
}
|
||||
|
||||
if (inputs_.size() == 0) {
|
||||
// Upload the input tensor(s) to GPU memory.
|
||||
inputs_.clear();
|
||||
inputs_.reserve(inputs.size());
|
||||
for (const auto& input : inputs) {
|
||||
if (input.isTensor()) {
|
||||
inputs_.emplace_back(at::rand(input.toTensor().sizes()).vulkan());
|
||||
}
|
||||
else if (input.isTensorList()) {
|
||||
const c10::List<at::Tensor> input_as_list = input.toTensorList();
|
||||
c10::List<at::Tensor> input_vk_list;
|
||||
input_vk_list.reserve(input_as_list.size());
|
||||
for (int i=0; i < input_as_list.size(); ++i) {
|
||||
const at::Tensor element = input_as_list.get(i);
|
||||
input_vk_list.emplace_back(at::rand(element.sizes()).vulkan());
|
||||
}
|
||||
inputs_.emplace_back(c10::IValue(input_vk_list));
|
||||
}
|
||||
else {
|
||||
CAFFE_THROW("Inputs must only contain IValues of type c10::Tensor or c10::TensorList!");
|
||||
}
|
||||
inputs_.emplace_back(c10::IValue(input_vk_list));
|
||||
}
|
||||
else {
|
||||
CAFFE_THROW("Inputs must only contain IValues of type c10::Tensor or c10::List!");
|
||||
}
|
||||
}
|
||||
|
||||
// Run, and download the output tensor to system memory.
|
||||
return module.forward(inputs_).toTensor().cpu();
|
||||
c10::IValue output = module.forward(inputs_);
|
||||
if (output.isTensor()) {
|
||||
return output.toTensor().cpu();
|
||||
}
|
||||
else if (output.isTensorList()) {
|
||||
return output.toTensorList().get(0).cpu();
|
||||
}
|
||||
else if (output.isList()) {
|
||||
return output.toList().get(0).toTensor().cpu();
|
||||
}
|
||||
else if (output.isTuple()) {
|
||||
return output.toTuple()->elements()[0].toTensor().cpu();
|
||||
}
|
||||
else {
|
||||
CAFFE_THROW("Outputs must only be either c10::Tensor or c10::TensorList!");
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
|
Reference in New Issue
Block a user