diff --git a/binaries/speed_benchmark_torch.cc b/binaries/speed_benchmark_torch.cc index e317d74ca249..bb160d6adb6f 100644 --- a/binaries/speed_benchmark_torch.cc +++ b/binaries/speed_benchmark_torch.cc @@ -184,7 +184,27 @@ class vkRunner final : public Runner { inputs_.clear(); inputs_.reserve(inputs.size()); for (const auto& input : inputs) { - inputs_.emplace_back(input.toTensor().vulkan()); + if (input.isTensor()) { + inputs_.emplace_back(input.toTensor().vulkan()); + } + else if (input.isList()) { + const c10::List input_as_list = input.toList(); + c10::List input_vk_list; + input_vk_list.reserve(input_as_list.size()); + for (int i=0; i < input_as_list.size(); ++i) { + const c10::IValue element = input_as_list.get(i); + if (element.isTensor()) { + input_vk_list.emplace_back(element.toTensor().vulkan()); + } + else { + CAFFE_THROW("Input of type c10::List must only contain Tensors!"); + } + } + inputs_.emplace_back(c10::IValue(input_vk_list)); + } + else { + CAFFE_THROW("Inputs must only contain IValues of type c10::Tensor or c10::List!"); + } } // Run, and download the output tensor to system memory.