diff --git a/binaries/speed_benchmark_torch.cc b/binaries/speed_benchmark_torch.cc index e4eee10636e3..ea523898b51e 100644 --- a/binaries/speed_benchmark_torch.cc +++ b/binaries/speed_benchmark_torch.cc @@ -180,35 +180,48 @@ class vkRunner final : public Runner { virtual c10::IValue run( T& module, const std::vector& inputs) override { - // Upload the input tensor(s) to GPU memory. - inputs_.clear(); - inputs_.reserve(inputs.size()); - for (const auto& input : inputs) { - if (input.isTensor()) { - inputs_.emplace_back(input.toTensor().vulkan()); - } - else if (input.isList()) { - const c10::List input_as_list = input.toList(); - c10::List input_vk_list; - input_vk_list.reserve(input_as_list.size()); - for (int i=0; i < input_as_list.size(); ++i) { - const c10::IValue element = input_as_list.get(i); - if (element.isTensor()) { - input_vk_list.emplace_back(element.toTensor().vulkan()); - } - else { - CAFFE_THROW("Input of type c10::List must only contain Tensors!"); - } + + if (inputs_.size() == 0) { + // Upload the input tensor(s) to GPU memory. + inputs_.clear(); + inputs_.reserve(inputs.size()); + for (const auto& input : inputs) { + if (input.isTensor()) { + inputs_.emplace_back(at::rand(input.toTensor().sizes()).vulkan()); + } + else if (input.isTensorList()) { + const c10::List input_as_list = input.toTensorList(); + c10::List input_vk_list; + input_vk_list.reserve(input_as_list.size()); + for (int i=0; i < input_as_list.size(); ++i) { + const at::Tensor element = input_as_list.get(i); + input_vk_list.emplace_back(at::rand(element.sizes()).vulkan()); + } + inputs_.emplace_back(c10::IValue(input_vk_list)); + } + else { + CAFFE_THROW("Inputs must only contain IValues of type c10::Tensor or c10::TensorList!"); } - inputs_.emplace_back(c10::IValue(input_vk_list)); - } - else { - CAFFE_THROW("Inputs must only contain IValues of type c10::Tensor or c10::List!"); } } // Run, and download the output tensor to system memory. - return module.forward(inputs_).toTensor().cpu(); + c10::IValue output = module.forward(inputs_); + if (output.isTensor()) { + return output.toTensor().cpu(); + } + else if (output.isTensorList()) { + return output.toTensorList().get(0).cpu(); + } + else if (output.isList()) { + return output.toList().get(0).toTensor().cpu(); + } + else if (output.isTuple()) { + return output.toTuple()->elements()[0].toTensor().cpu(); + } + else { + CAFFE_THROW("Outputs must only be either c10::Tensor or c10::TensorList!"); + }; } private: