mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[iOS][OSS][BE] Add Simulator tests for full JIT (#64851)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64851 ghstack-source-id: 137970229 Test Plan: CircleCI Reviewed By: hanton, cccclai Differential Revision: D30877963 fbshipit-source-id: 7bb8ade1959b85c3902ba9dc0660cdac8f558d64
This commit is contained in:
committed by
Facebook GitHub Bot
parent
fd09e564d6
commit
f159f12fee
25
.circleci/config.yml
generated
25
.circleci/config.yml
generated
@ -1797,10 +1797,6 @@ jobs:
|
||||
no_output_timeout: "30m"
|
||||
command: |
|
||||
set -e
|
||||
if [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
|
||||
echo "Run Build Test is not for full jit, skipping."
|
||||
exit 0
|
||||
fi
|
||||
PROJ_ROOT=/Users/distiller/project
|
||||
PROFILE=PyTorch_CI_2021
|
||||
# run the ruby build script
|
||||
@ -1826,21 +1822,28 @@ jobs:
|
||||
if [ ${IOS_PLATFORM} != "SIMULATOR" ]; then
|
||||
echo "not SIMULATOR build, skip it."
|
||||
exit 0
|
||||
elif [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
|
||||
echo "Run Simulator Tests is not for full jit, skipping."
|
||||
exit 0
|
||||
fi
|
||||
WORKSPACE=/Users/distiller/workspace
|
||||
PROJ_ROOT=/Users/distiller/project
|
||||
source ~/anaconda/bin/activate
|
||||
pip install torch torchvision --progress-bar off
|
||||
#run unit test
|
||||
# use the pytorch nightly build to generate models
|
||||
conda install pytorch torchvision -c pytorch-nightly --yes
|
||||
# generate models for differnet backends
|
||||
cd ${PROJ_ROOT}/ios/TestApp/benchmark
|
||||
mkdir -p ../models
|
||||
python trace_model.py
|
||||
ruby setup.rb
|
||||
if [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
|
||||
ruby setup.rb --lite 1
|
||||
else
|
||||
ruby setup.rb
|
||||
fi
|
||||
cd ${PROJ_ROOT}/ios/TestApp
|
||||
instruments -s -devices
|
||||
fastlane scan
|
||||
if [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
|
||||
fastlane scan --only_testing TestAppTests/TestAppTests/testLiteInterpreter
|
||||
else
|
||||
fastlane scan --only_testing TestAppTests/TestAppTests/testFullJIT
|
||||
fi
|
||||
pytorch_linux_bazel_build:
|
||||
<<: *pytorch_params
|
||||
machine:
|
||||
|
@ -534,10 +534,6 @@
|
||||
no_output_timeout: "30m"
|
||||
command: |
|
||||
set -e
|
||||
if [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
|
||||
echo "Run Build Test is not for full jit, skipping."
|
||||
exit 0
|
||||
fi
|
||||
PROJ_ROOT=/Users/distiller/project
|
||||
PROFILE=PyTorch_CI_2021
|
||||
# run the ruby build script
|
||||
@ -563,21 +559,28 @@
|
||||
if [ ${IOS_PLATFORM} != "SIMULATOR" ]; then
|
||||
echo "not SIMULATOR build, skip it."
|
||||
exit 0
|
||||
elif [ ${BUILD_LITE_INTERPRETER} == 0 ]; then
|
||||
echo "Run Simulator Tests is not for full jit, skipping."
|
||||
exit 0
|
||||
fi
|
||||
WORKSPACE=/Users/distiller/workspace
|
||||
PROJ_ROOT=/Users/distiller/project
|
||||
source ~/anaconda/bin/activate
|
||||
pip install torch torchvision --progress-bar off
|
||||
#run unit test
|
||||
# use the pytorch nightly build to generate models
|
||||
conda install pytorch torchvision -c pytorch-nightly --yes
|
||||
# generate models for differnet backends
|
||||
cd ${PROJ_ROOT}/ios/TestApp/benchmark
|
||||
mkdir -p ../models
|
||||
python trace_model.py
|
||||
ruby setup.rb
|
||||
if [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
|
||||
ruby setup.rb --lite 1
|
||||
else
|
||||
ruby setup.rb
|
||||
fi
|
||||
cd ${PROJ_ROOT}/ios/TestApp
|
||||
instruments -s -devices
|
||||
fastlane scan
|
||||
if [ ${BUILD_LITE_INTERPRETER} == 1 ]; then
|
||||
fastlane scan --only_testing TestAppTests/TestAppTests/testLiteInterpreter
|
||||
else
|
||||
fastlane scan --only_testing TestAppTests/TestAppTests/testFullJIT
|
||||
fi
|
||||
pytorch_linux_bazel_build:
|
||||
<<: *pytorch_params
|
||||
machine:
|
||||
|
@ -3,7 +3,7 @@
|
||||
archiveVersion = 1;
|
||||
classes = {
|
||||
};
|
||||
objectVersion = 52;
|
||||
objectVersion = 50;
|
||||
objects = {
|
||||
|
||||
/* Begin PBXBuildFile section */
|
||||
@ -13,7 +13,6 @@
|
||||
A06D4CBD232F0DB200763E16 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = A06D4CBC232F0DB200763E16 /* Assets.xcassets */; };
|
||||
A06D4CC0232F0DB200763E16 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = A06D4CBE232F0DB200763E16 /* LaunchScreen.storyboard */; };
|
||||
A06D4CC3232F0DB200763E16 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = A06D4CC2232F0DB200763E16 /* main.m */; };
|
||||
A0EA3B0A237FCB72007CEA34 /* TestAppTests.mm in Sources */ = {isa = PBXBuildFile; fileRef = A0EA3B09237FCB72007CEA34 /* TestAppTests.mm */; platformFilter = ios; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXContainerItemProxy section */
|
||||
@ -39,7 +38,6 @@
|
||||
A06D4CC2232F0DB200763E16 /* main.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = main.m; sourceTree = "<group>"; };
|
||||
A0EA3AFF237FCB08007CEA34 /* TestAppTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = TestAppTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
A0EA3B03237FCB08007CEA34 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
|
||||
A0EA3B09237FCB72007CEA34 /* TestAppTests.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = TestAppTests.mm; sourceTree = "<group>"; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
@ -97,7 +95,6 @@
|
||||
A0EA3B00237FCB08007CEA34 /* TestAppTests */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
A0EA3B09237FCB72007CEA34 /* TestAppTests.mm */,
|
||||
A0EA3B03237FCB08007CEA34 /* Info.plist */,
|
||||
);
|
||||
path = TestAppTests;
|
||||
@ -211,7 +208,6 @@
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
A0EA3B0A237FCB72007CEA34 /* TestAppTests.mm in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
|
@ -1,45 +0,0 @@
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#include <torch/script.h>
|
||||
#include <torch/csrc/jit/mobile/function.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/mobile/interpreter.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/mobile/observer.h>
|
||||
#include "ATen/ATen.h"
|
||||
#include "caffe2/core/timer.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
#include "torch/csrc/autograd/grad_mode.h"
|
||||
|
||||
@interface TestAppTests : XCTestCase
|
||||
|
||||
@end
|
||||
|
||||
@implementation TestAppTests {
|
||||
torch::jit::mobile::Module _module;
|
||||
}
|
||||
|
||||
+ (void)setUp {
|
||||
[super setUp];
|
||||
}
|
||||
|
||||
- (void)setUp {
|
||||
[super setUp];
|
||||
NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model"
|
||||
ofType:@"ptl"];
|
||||
XCTAssertTrue([NSFileManager.defaultManager fileExistsAtPath:modelPath],
|
||||
@"model.ptl doesn't exist!");
|
||||
_module = torch::jit::_load_for_mobile(modelPath.UTF8String);
|
||||
}
|
||||
|
||||
- (void)testForward {
|
||||
// _module.eval();
|
||||
c10::InferenceMode mode;
|
||||
std::vector<c10::IValue> inputs;
|
||||
inputs.push_back(torch::ones({1, 3, 224, 224}, at::ScalarType::Float));
|
||||
auto outputTensor = _module.forward(inputs).toTensor();
|
||||
float* outputBuffer = outputTensor.data_ptr<float>();
|
||||
XCTAssertTrue(outputBuffer != nullptr, @"");
|
||||
}
|
||||
|
||||
@end
|
22
ios/TestApp/TestAppTests/TestFullJIT.mm
Normal file
22
ios/TestApp/TestAppTests/TestFullJIT.mm
Normal file
@ -0,0 +1,22 @@
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#include <torch/script.h>
|
||||
|
||||
@interface TestAppTests : XCTestCase
|
||||
|
||||
@end
|
||||
|
||||
@implementation TestAppTests {
|
||||
}
|
||||
|
||||
- (void)testFullJIT {
|
||||
NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model"
|
||||
ofType:@"pt"];
|
||||
auto module = torch::jit::load(modelPath.UTF8String);
|
||||
c10::InferenceMode mode;
|
||||
auto input = torch::ones({1, 3, 224, 224}, at::kFloat);
|
||||
auto outputTensor = module.forward({input}).toTensor();
|
||||
XCTAssertTrue(outputTensor.numel() == 1000);
|
||||
}
|
||||
|
||||
@end
|
24
ios/TestApp/TestAppTests/TestLiteInterpreter.mm
Normal file
24
ios/TestApp/TestAppTests/TestLiteInterpreter.mm
Normal file
@ -0,0 +1,24 @@
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
@interface TestAppTests : XCTestCase
|
||||
|
||||
@end
|
||||
|
||||
@implementation TestAppTests {
|
||||
}
|
||||
|
||||
- (void)testLiteInterpreter {
|
||||
NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model_lite"
|
||||
ofType:@"ptl"];
|
||||
auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
|
||||
c10::InferenceMode mode;
|
||||
auto input = torch::ones({1, 3, 224, 224}, at::kFloat);
|
||||
auto outputTensor = module.forward({input}).toTensor();
|
||||
XCTAssertTrue(outputTensor.numel() == 1000);
|
||||
}
|
||||
|
||||
@end
|
@ -8,7 +8,12 @@ option_parser = OptionParser.new do |opts|
|
||||
opts.on('-t', '--team_id ', 'development team ID') { |value|
|
||||
options[:team_id] = value
|
||||
}
|
||||
opts.on('-l', '--lite ', 'use lite interpreter') { |value|
|
||||
options[:lite] = value
|
||||
}
|
||||
end.parse!
|
||||
puts options.inspect
|
||||
|
||||
puts "Current directory: #{Dir.pwd}"
|
||||
install_path = File.expand_path("../../../build_ios/install")
|
||||
if not Dir.exist? (install_path)
|
||||
@ -37,26 +42,55 @@ targets.each do |target|
|
||||
end
|
||||
end
|
||||
end
|
||||
puts "Installing the testing model..."
|
||||
model_path = File.expand_path("./model.ptl")
|
||||
if not File.exist?(model_path)
|
||||
raise "model.pt can't be found!"
|
||||
end
|
||||
|
||||
group = project.main_group.find_subpath(File.join('TestApp'),true)
|
||||
group.set_source_tree('SOURCE_ROOT')
|
||||
group.files.each do |file|
|
||||
if (file.name.to_s.end_with?(".pt"))
|
||||
if (file.name.to_s.end_with?(".pt") ||
|
||||
file.name.to_s.end_with?(".ptl"))
|
||||
group.remove_reference(file)
|
||||
targets.each do |target|
|
||||
target.resources_build_phase.remove_file_reference(file)
|
||||
end
|
||||
end
|
||||
end
|
||||
model_file_ref = group.new_reference(model_path)
|
||||
targets.each do |target|
|
||||
target.resources_build_phase.add_file_reference(model_file_ref, true)
|
||||
|
||||
file_refs = []
|
||||
# collect models
|
||||
puts "Installing models..."
|
||||
models_dir = File.expand_path("../models")
|
||||
Dir.foreach(models_dir) do |model|
|
||||
if(model.end_with?(".pt") || model.end_with?(".ptl"))
|
||||
model_path = models_dir + "/" + model
|
||||
file_refs.push(group.new_reference(model_path))
|
||||
end
|
||||
end
|
||||
|
||||
targets.each do |target|
|
||||
file_refs.each do |ref|
|
||||
target.resources_build_phase.add_file_reference(ref, true)
|
||||
end
|
||||
end
|
||||
|
||||
# add test files
|
||||
puts "Adding test files..."
|
||||
testTarget = targets[1]
|
||||
testFilePath = File.expand_path('../TestAppTests/')
|
||||
group = project.main_group.find_subpath(File.join('TestAppTests'),true)
|
||||
group.files.each do |file|
|
||||
if (file.path.end_with?(".mm"))
|
||||
file.remove_from_project
|
||||
end
|
||||
end
|
||||
|
||||
if(options[:lite])
|
||||
file = group.new_file("TestLiteInterpreter.mm")
|
||||
testTarget.add_file_references([file])
|
||||
else
|
||||
file = group.new_file("TestFullJIT.mm")
|
||||
testTarget.add_file_references([file])
|
||||
end
|
||||
|
||||
puts "Linking static libraries..."
|
||||
libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a']
|
||||
targets.each do |target|
|
||||
|
@ -5,6 +5,7 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
model = torchvision.models.mobilenet_v2(pretrained=True)
|
||||
model.eval()
|
||||
example = torch.rand(1, 3, 224, 224)
|
||||
traced_script_module = torch.jit.script(model, example)
|
||||
traced_script_module = torch.jit.trace(model, example)
|
||||
optimized_scripted_module = optimize_for_mobile(traced_script_module)
|
||||
exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter("model.ptl")
|
||||
torch.jit.save(optimized_scripted_module, '../models/model.pt')
|
||||
exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter("../models/model_lite.ptl")
|
||||
|
Reference in New Issue
Block a user