Add and test training in lite interpreter. (#32359)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32359

Test Plan: Imported from OSS

Differential Revision: D19450614

Pulled By: iseeyuan

fbshipit-source-id: 6bafff39d7880a5b7fb9cd70c33a4e584812be12
Author: Martin Yuan
Date: 2020-03-03 23:31:03 -08:00
Committed by: Facebook Github Bot
Parent: 2ba74b741e
Commit: f097ca503d
9 changed files with 127 additions and 7 deletions
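
Taken together, these changes let a model exported for the lite interpreter be trained with the regular C++ optimizer API. A minimal end-to-end sketch of what the commit enables, condensed from the new test below (the shapes, learning rate, and choice of L1 loss are illustrative, not prescribed by the commit):

    #include <torch/csrc/jit/mobile/import.h>
    #include <torch/csrc/jit/mobile/module.h>
    #include <torch/torch.h>

    // Sketch: export a scripted module to bytecode, reload it with the lite
    // interpreter, and run a few SGD steps against an L1 loss.
    void train_on_device_sketch(torch::jit::script::Module& m) {
      std::stringstream ss;
      m._save_for_mobile(ss);
      torch::jit::mobile::Module bc = torch::jit::_load_for_mobile(ss);

      // parameters() is new in this commit (torch/csrc/jit/mobile/module.cpp).
      std::vector<at::Tensor> params = bc.parameters();
      torch::optim::SGD opt(params, torch::optim::SGDOptions(/*lr=*/0.1));

      auto x = torch::ones({1});
      auto target = 3 * torch::ones({1});
      for (int step = 0; step < 10; ++step) {
        opt.zero_grad();
        auto out = bc.forward({x}).toTensor();
        auto loss = torch::l1_loss(out, target);
        loss.backward();  // works because _aten ops now dispatch to VariableType
        opt.step();
      }
    }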

File: caffe2/CMakeLists.txt

@@ -484,12 +484,13 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/mobile/type_parser.cpp
   )
-  if (NOT INTERN_DISABLE_MOBILE_INTERP)
+  if (NOT INTERN_DISABLE_MOBILE_INTERP AND BUILD_CAFFE2_MOBILE)
     set (MOBILE_SRCS
       ${TORCH_SRC_DIR}/csrc/jit/mobile/function.cpp
       ${TORCH_SRC_DIR}/csrc/jit/mobile/import.cpp
       ${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
       ${TORCH_SRC_DIR}/csrc/jit/mobile/register_mobile_ops.cpp
+      ${TORCH_SRC_DIR}/csrc/jit/mobile/register_mobile_autograd.cpp
       ${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
     )
     list (APPEND TORCH_SRCS ${MOBILE_SRCS})

File: test/backward_compatibility/check_backward_compatibility.py

@@ -52,7 +52,6 @@ white_list = [
     ('prim::MMTreeReduce', datetime.date(2020, 3, 1)),
     ('prim::Constant', datetime.date(2020, 3, 1)),
     ('_prim::TupleUnpack', datetime.date(2020, 3, 1)),
-    ('_aten::format', datetime.date(2020, 3, 1)),
     ('aten::random_', datetime.date(2020, 3, 1)),
     ('quantized::add_(scalar_)?(relu_)?out', datetime.date(2020, 3, 1)),
     ('quantized::cat_(relu_)?out', datetime.date(2020, 3, 1)),
@@ -67,6 +66,7 @@ white_list = [
     ('aten::ones_like', datetime.date(2020, 3, 15)),
     ('aten::randint_like', datetime.date(2020, 3, 15)),
     ('aten::zeros_like', datetime.date(2020, 3, 15)),
+    ('_aten', datetime.date(2020, 4, 1)),
 ]

File: test/cpp/jit/test_lite_interpreter.cpp

@@ -4,6 +4,8 @@
 #include <torch/csrc/jit/mobile/import.h>
 #include <torch/csrc/jit/mobile/module.h>
 #include <torch/csrc/jit/serialization/import.h>
+#include <torch/torch.h>
+#include <c10/core/TensorOptions.h>

 // Tests go in torch::jit
 namespace torch {
@@ -207,5 +209,67 @@ void testLiteInterpreterWrongMethodName() {
   ASSERT_THROWS_WITH(bc.run_method("forward", inputs), "is not defined");
 }

+void testLiteInterpreterParams() {
+  script::Module m("m");
+  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
+  m.define(R"(
+    def forward(self, x):
+      b = 1.0
+      return self.foo * x + b
+  )");
+  double learning_rate = 0.1, momentum = 0.1;
+  int n_epochs = 10;
+  // Initial model: y = x + 1; training should drive it toward y = 2x + 1.
+  std::vector<std::pair<Tensor, Tensor>> trainData{
+      {1 * torch::ones({1}), 3 * torch::ones({1})},
+  };
+  // Reference run: train the same module with the full JIT.
+  std::stringstream ms;
+  m.save(ms);
+  auto mm = load(ms);
+  std::vector<::at::Tensor> parameters;
+  for (auto parameter : mm.parameters()) {
+    parameters.emplace_back(parameter);
+  }
+  ::torch::optim::SGD optimizer(
+      parameters,
+      ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
+  for (int epoch = 0; epoch < n_epochs; ++epoch) {
+    for (auto& data : trainData) {
+      auto source = data.first, targets = data.second;
+      optimizer.zero_grad();
+      std::vector<IValue> train_inputs{source};
+      auto output = mm.forward(train_inputs).toTensor();
+      auto loss = ::torch::l1_loss(output, targets);
+      loss.backward();
+      optimizer.step();
+    }
+  }
+  // Lite-interpreter run: identical loop over the bytecode copy.
+  std::stringstream ss;
+  m._save_for_mobile(ss);
+  mobile::Module bc = _load_for_mobile(ss);
+  std::vector<::at::Tensor> bc_parameters = bc.parameters();
+  ::torch::optim::SGD bc_optimizer(
+      bc_parameters,
+      ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
+  for (int epoch = 0; epoch < n_epochs; ++epoch) {
+    for (auto& data : trainData) {
+      auto source = data.first, targets = data.second;
+      bc_optimizer.zero_grad();
+      std::vector<IValue> train_inputs{source};
+      auto output = bc.forward(train_inputs).toTensor();
+      auto loss = ::torch::l1_loss(output, targets);
+      loss.backward();
+      bc_optimizer.step();
+    }
+  }
+  // Both paths should arrive at the same learned parameter.
+  AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
+}
+
 } // namespace jit
 } // namespace torch

File: test/cpp/jit/tests.h

@@ -85,7 +85,8 @@ namespace jit {
   _(MobileTypeParser) \
   _(LiteInterpreterPrim) \
   _(LiteInterpreterLoadOrigJit) \
-  _(LiteInterpreterWrongMethodName)
+  _(LiteInterpreterWrongMethodName) \
+  _(LiteInterpreterParams)

 #define TH_FORALL_TESTS_CUDA(_) \
   _(ArgumentSpec) \
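
For reference, each entry in this X-macro list feeds both a test declaration and a test registration elsewhere in the harness. A sketch of the assumed expansion pattern (the `DECLARE_JIT_TEST` helper name is illustrative):

    // One list entry, e.g. _(LiteInterpreterParams), is consumed like this:
    #define TH_FORALL_TESTS(_) \
      _(LiteInterpreterParams)

    #define DECLARE_JIT_TEST(name) void test##name();
    TH_FORALL_TESTS(DECLARE_JIT_TEST)  // declares void testLiteInterpreterParams();
    #undef DECLARE_JIT_TEST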

File: tools/autograd/templates/VariableType.cpp

@@ -31,9 +31,13 @@ using namespace torch::autograd::generated;
 namespace torch { namespace autograd {
 namespace VariableType {
-namespace {
+// Comment out the anonymous namespace so that the generated functions can be
+// accessed from outside this file (by register_mobile_autograd.cpp). Once the
+// mobile op registration is merged, the anonymous namespace will be restored.
+// namespace {
 ${type_derived_method_definitions}
-}
+// }
 }

 namespace {
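
With the anonymous namespace commented out, the generated kernels gain external linkage, so a separate translation unit can name them and take their addresses. That is exactly what the new register_mobile_autograd.cpp (the last file in this commit) relies on; a minimal sketch of the pattern, assuming the usual codegen signature:

    #include <ATen/ATen.h>

    // Forward declaration of a codegen-emitted autograd kernel; legal to
    // reference from another file now that it has external linkage.
    namespace torch { namespace autograd { namespace VariableType {
    at::Tensor mul_Tensor(const at::Tensor& self, const at::Tensor& other);
    }}}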

File: tools/build_variables.bzl

@@ -189,6 +189,7 @@ libtorch_sources = [
     "torch/csrc/jit/mobile/import.cpp",
     "torch/csrc/jit/mobile/module.cpp",
     "torch/csrc/jit/mobile/register_mobile_ops.cpp",
+    "torch/csrc/jit/mobile/register_mobile_autograd.cpp",
     "torch/csrc/jit/mobile/interpreter.cpp",
     "torch/csrc/jit/mobile/type_parser.cpp",
     "torch/csrc/jit/tensorexpr/codegen.cpp",

File: torch/csrc/jit/mobile/module.cpp

@@ -52,6 +52,25 @@ Function* Module::find_method(const std::string& basename) const {
   AT_ERROR("Method '", basename, "' is not defined.");
 }

+namespace {
+// Depth-first walk over an object's slots, collecting every tensor,
+// including those owned by nested submodule objects.
+void slot_params_recurse(
+    const c10::intrusive_ptr<c10::ivalue::Object>& obj,
+    std::vector<at::Tensor>* params) {
+  for (const auto& slot : obj->slots()) {
+    if (slot.isTensor()) {
+      params->emplace_back(slot.toTensor());
+    } else if (slot.isObject()) {
+      slot_params_recurse(slot.toObject(), params);
+    }
+  }
+}
+} // namespace
+
+const std::vector<at::Tensor> Module::parameters() const {
+  std::vector<at::Tensor> params;
+  slot_params_recurse(object_, &params);
+  return params;
+}
+
 } // namespace mobile
-} // namespace torch
 } // namespace jit
+} // namespace torch
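
Because slot_params_recurse descends into object-valued slots, parameters() flattens tensors depth-first across nested submodules, and the result can be handed straight to an optimizer. A minimal usage sketch (the stream `ss` is assumed to hold a module saved via _save_for_mobile):

    #include <torch/csrc/jit/mobile/import.h>
    #include <torch/csrc/jit/mobile/module.h>
    #include <torch/torch.h>
    #include <sstream>

    void optimize_params_sketch(std::stringstream& ss) {
      torch::jit::mobile::Module bc = torch::jit::_load_for_mobile(ss);
      // Leaf tensors, depth-first; suitable as optimizer parameters.
      std::vector<at::Tensor> params = bc.parameters();
      torch::optim::SGD opt(params, torch::optim::SGDOptions(/*lr=*/0.1));
    }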

File: torch/csrc/jit/mobile/module.h

@@ -28,7 +28,9 @@ class TORCH_API Module {
   }
   Function* find_method(const std::string& basename) const;
   std::string name() {return object_->name();}
+  const std::vector<at::IValue>& slots() const {return object_->slots();}
+  const std::vector<at::Tensor> parameters() const;
  private:
   c10::intrusive_ptr<c10::ivalue::Object> object_;
   std::shared_ptr<CompilationUnit> cu_;
 };

File: torch/csrc/jit/mobile/register_mobile_autograd.cpp (new file)

@@ -0,0 +1,28 @@
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/ATen.h>
+#include <ATen/core/stack.h>
+#include <ATen/TypeDefault.h>
+
+using Stack = std::vector<c10::IValue>;
+using at::Tensor;
+using at::Scalar;
+
+// Forward declarations of the autograd kernels generated into
+// VariableType*.cpp; visible here because the anonymous namespace in the
+// template is commented out.
+namespace torch {
+namespace autograd {
+namespace VariableType {
+Tensor mul_Tensor(const Tensor& self, const Tensor& other);
+Tensor add_Scalar(const Tensor& self, Scalar other, Scalar alpha);
+}
+}
+}
+
+namespace {
+// Route variable (autograd) dispatch for these mobile ops to the generated
+// VariableType kernels, so backward() works through the lite interpreter.
+static auto registry = torch::RegisterOperators().op(
+    "_aten::add.Scalar",
+    torch::RegisterOperators::options().kernel(
+        c10::DispatchKey::VariableTensorId,
+        &torch::autograd::VariableType::add_Scalar)
+).op(
+    "_aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
+    torch::RegisterOperators::options().kernel(
+        c10::DispatchKey::VariableTensorId,
+        &torch::autograd::VariableType::mul_Tensor)
+    .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
+);
+} // namespace
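
Additional differentiable ops for the lite interpreter would follow the same recipe: forward-declare the generated kernel, then register it under the variable dispatch key. A hypothetical sketch for `_aten::sub.Tensor` (assuming codegen emits `sub_Tensor` with this signature; it will not link unless that definition actually exists):

    #include <ATen/core/op_registration/op_registration.h>
    #include <ATen/ATen.h>

    namespace torch { namespace autograd { namespace VariableType {
    // Hypothetical: assumed to be emitted by the autograd codegen.
    at::Tensor sub_Tensor(const at::Tensor& self, const at::Tensor& other, at::Scalar alpha);
    }}}

    namespace {
    static auto extra_registry = torch::RegisterOperators().op(
        "_aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor",
        torch::RegisterOperators::options()
            .kernel(c10::DispatchKey::VariableTensorId,
                    &torch::autograd::VariableType::sub_Tensor)
            .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA));
    } // namespace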