mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add and test training in lite interpreter. (#32359)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32359 Test Plan: Imported from OSS Differential Revision: D19450614 Pulled By: iseeyuan fbshipit-source-id: 6bafff39d7880a5b7fb9cd70c33a4e584812be12
This commit is contained in:
committed by
Facebook Github Bot
parent
2ba74b741e
commit
f097ca503d
@ -484,12 +484,13 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
|||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/type_parser.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/type_parser.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
if (NOT INTERN_DISABLE_MOBILE_INTERP)
|
if (NOT INTERN_DISABLE_MOBILE_INTERP AND BUILD_CAFFE2_MOBILE)
|
||||||
set (MOBILE_SRCS
|
set (MOBILE_SRCS
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/function.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/function.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/import.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/import.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/register_mobile_ops.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/register_mobile_ops.cpp
|
||||||
|
${TORCH_SRC_DIR}/csrc/jit/mobile/register_mobile_autograd.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
|
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
|
||||||
)
|
)
|
||||||
list (APPEND TORCH_SRCS ${MOBILE_SRCS})
|
list (APPEND TORCH_SRCS ${MOBILE_SRCS})
|
||||||
|
@ -52,7 +52,6 @@ white_list = [
|
|||||||
('prim::MMTreeReduce', datetime.date(2020, 3, 1)),
|
('prim::MMTreeReduce', datetime.date(2020, 3, 1)),
|
||||||
('prim::Constant', datetime.date(2020, 3, 1)),
|
('prim::Constant', datetime.date(2020, 3, 1)),
|
||||||
('_prim::TupleUnpack', datetime.date(2020, 3, 1)),
|
('_prim::TupleUnpack', datetime.date(2020, 3, 1)),
|
||||||
('_aten::format', datetime.date(2020, 3, 1)),
|
|
||||||
('aten::random_', datetime.date(2020, 3, 1)),
|
('aten::random_', datetime.date(2020, 3, 1)),
|
||||||
('quantized::add_(scalar_)?(relu_)?out', datetime.date(2020, 3, 1)),
|
('quantized::add_(scalar_)?(relu_)?out', datetime.date(2020, 3, 1)),
|
||||||
('quantized::cat_(relu_)?out', datetime.date(2020, 3, 1)),
|
('quantized::cat_(relu_)?out', datetime.date(2020, 3, 1)),
|
||||||
@ -67,6 +66,7 @@ white_list = [
|
|||||||
('aten::ones_like', datetime.date(2020, 3, 15)),
|
('aten::ones_like', datetime.date(2020, 3, 15)),
|
||||||
('aten::randint_like', datetime.date(2020, 3, 15)),
|
('aten::randint_like', datetime.date(2020, 3, 15)),
|
||||||
('aten::zeros_like', datetime.date(2020, 3, 15)),
|
('aten::zeros_like', datetime.date(2020, 3, 15)),
|
||||||
|
('_aten', datetime.date(2020, 4, 1)),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
#include <torch/csrc/jit/mobile/import.h>
|
#include <torch/csrc/jit/mobile/import.h>
|
||||||
#include <torch/csrc/jit/mobile/module.h>
|
#include <torch/csrc/jit/mobile/module.h>
|
||||||
#include <torch/csrc/jit/serialization/import.h>
|
#include <torch/csrc/jit/serialization/import.h>
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include <c10/core/TensorOptions.h>
|
||||||
|
|
||||||
// Tests go in torch::jit
|
// Tests go in torch::jit
|
||||||
namespace torch {
|
namespace torch {
|
||||||
@ -207,5 +209,67 @@ void testLiteInterpreterWrongMethodName() {
|
|||||||
ASSERT_THROWS_WITH(bc.run_method("forward", inputs), "is not defined");
|
ASSERT_THROWS_WITH(bc.run_method("forward", inputs), "is not defined");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void testLiteInterpreterParams() {
|
||||||
|
script::Module m("m");
|
||||||
|
m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
|
||||||
|
m.define(R"(
|
||||||
|
def forward(self, x):
|
||||||
|
b = 1.0
|
||||||
|
return self.foo * x + b
|
||||||
|
)");
|
||||||
|
|
||||||
|
double learning_rate = 0.1, momentum = 0.1;
|
||||||
|
int n_epoc = 10;
|
||||||
|
// init: y = x + 1;
|
||||||
|
// target: y = 2 x + 1
|
||||||
|
std::vector<std::pair<Tensor, Tensor>> trainData{
|
||||||
|
{1 * torch::ones({1}), 3 * torch::ones({1})},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Reference: Full jit
|
||||||
|
std::stringstream ms;
|
||||||
|
m.save(ms);
|
||||||
|
auto mm = load(ms);
|
||||||
|
// mm.train();
|
||||||
|
std::vector<::at::Tensor> parameters;
|
||||||
|
for (auto parameter : mm.parameters()) {
|
||||||
|
parameters.emplace_back(parameter);
|
||||||
|
}
|
||||||
|
::torch::optim::SGD optimizer(
|
||||||
|
parameters,
|
||||||
|
::torch::optim::SGDOptions(learning_rate).momentum(momentum));
|
||||||
|
for (int epoc = 0; epoc < n_epoc; ++epoc) {
|
||||||
|
for (auto &data : trainData) {
|
||||||
|
auto source = data.first, targets = data.second;
|
||||||
|
optimizer.zero_grad();
|
||||||
|
std::vector<IValue> train_inputs{source};
|
||||||
|
auto output = mm.forward(train_inputs).toTensor();
|
||||||
|
auto loss = ::torch::l1_loss(output, targets);
|
||||||
|
loss.backward();
|
||||||
|
optimizer.step();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
m._save_for_mobile(ss);
|
||||||
|
mobile::Module bc = _load_for_mobile(ss);
|
||||||
|
std::vector<::at::Tensor> bc_parameters = bc.parameters();
|
||||||
|
::torch::optim::SGD bc_optimizer(
|
||||||
|
bc_parameters,
|
||||||
|
::torch::optim::SGDOptions(learning_rate).momentum(momentum));
|
||||||
|
for (int epoc = 0; epoc < n_epoc; ++epoc) {
|
||||||
|
for (auto &data : trainData) {
|
||||||
|
auto source = data.first, targets = data.second;
|
||||||
|
bc_optimizer.zero_grad();
|
||||||
|
std::vector<IValue> train_inputs{source};
|
||||||
|
auto output = bc.forward(train_inputs).toTensor();
|
||||||
|
auto loss = ::torch::l1_loss(output, targets);
|
||||||
|
loss.backward();
|
||||||
|
bc_optimizer.step();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
|
||||||
|
}
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -85,7 +85,8 @@ namespace jit {
|
|||||||
_(MobileTypeParser) \
|
_(MobileTypeParser) \
|
||||||
_(LiteInterpreterPrim) \
|
_(LiteInterpreterPrim) \
|
||||||
_(LiteInterpreterLoadOrigJit) \
|
_(LiteInterpreterLoadOrigJit) \
|
||||||
_(LiteInterpreterWrongMethodName)
|
_(LiteInterpreterWrongMethodName) \
|
||||||
|
_(LiteInterpreterParams)
|
||||||
|
|
||||||
#define TH_FORALL_TESTS_CUDA(_) \
|
#define TH_FORALL_TESTS_CUDA(_) \
|
||||||
_(ArgumentSpec) \
|
_(ArgumentSpec) \
|
||||||
|
@ -31,9 +31,13 @@ using namespace torch::autograd::generated;
|
|||||||
namespace torch { namespace autograd {
|
namespace torch { namespace autograd {
|
||||||
|
|
||||||
namespace VariableType {
|
namespace VariableType {
|
||||||
namespace {
|
// Comment the anonymous namespace so that the generated functions
|
||||||
|
// can be accessed from outside of the files (register_mobile_autograd.cpp).
|
||||||
|
// Later when we merge the mobile op registration the anonymous namespace
|
||||||
|
// will be restored.
|
||||||
|
// namespace {
|
||||||
${type_derived_method_definitions}
|
${type_derived_method_definitions}
|
||||||
}
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -189,6 +189,7 @@ libtorch_sources = [
|
|||||||
"torch/csrc/jit/mobile/import.cpp",
|
"torch/csrc/jit/mobile/import.cpp",
|
||||||
"torch/csrc/jit/mobile/module.cpp",
|
"torch/csrc/jit/mobile/module.cpp",
|
||||||
"torch/csrc/jit/mobile/register_mobile_ops.cpp",
|
"torch/csrc/jit/mobile/register_mobile_ops.cpp",
|
||||||
|
"torch/csrc/jit/mobile/register_mobile_autograd.cpp",
|
||||||
"torch/csrc/jit/mobile/interpreter.cpp",
|
"torch/csrc/jit/mobile/interpreter.cpp",
|
||||||
"torch/csrc/jit/mobile/type_parser.cpp",
|
"torch/csrc/jit/mobile/type_parser.cpp",
|
||||||
"torch/csrc/jit/tensorexpr/codegen.cpp",
|
"torch/csrc/jit/tensorexpr/codegen.cpp",
|
||||||
|
@ -52,6 +52,25 @@ Function* Module::find_method(const std::string& basename) const {
|
|||||||
AT_ERROR("Method '", basename, "' is not defined.");
|
AT_ERROR("Method '", basename, "' is not defined.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
void slot_params_recurse(
|
||||||
|
const c10::intrusive_ptr<c10::ivalue::Object>& obj,
|
||||||
|
std::vector<at::Tensor>* params) {
|
||||||
|
for (const auto& slot : obj->slots()) {
|
||||||
|
if (slot.isTensor()) {
|
||||||
|
params->emplace_back(slot.toTensor());
|
||||||
|
} else if (slot.isObject()) {
|
||||||
|
slot_params_recurse(slot.toObject(), params);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
const std::vector<at::Tensor> Module::parameters() const {
|
||||||
|
std::vector<at::Tensor> params;
|
||||||
|
slot_params_recurse(object_, ¶ms);
|
||||||
|
return params;
|
||||||
|
}
|
||||||
} // namespace mobile
|
} // namespace mobile
|
||||||
} // namespace torch
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
|
@ -28,7 +28,9 @@ class TORCH_API Module {
|
|||||||
}
|
}
|
||||||
Function* find_method(const std::string& basename) const;
|
Function* find_method(const std::string& basename) const;
|
||||||
std::string name() {return object_->name();}
|
std::string name() {return object_->name();}
|
||||||
private:
|
const std::vector<at::IValue>& slots() const {return object_->slots();}
|
||||||
|
const std::vector<at::Tensor> parameters() const;
|
||||||
|
private:
|
||||||
c10::intrusive_ptr<c10::ivalue::Object> object_;
|
c10::intrusive_ptr<c10::ivalue::Object> object_;
|
||||||
std::shared_ptr<CompilationUnit> cu_;
|
std::shared_ptr<CompilationUnit> cu_;
|
||||||
};
|
};
|
||||||
|
28
torch/csrc/jit/mobile/register_mobile_autograd.cpp
Normal file
28
torch/csrc/jit/mobile/register_mobile_autograd.cpp
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
#include <ATen/core/op_registration/op_registration.h>
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <ATen/core/stack.h>
|
||||||
|
#include <ATen/TypeDefault.h>
|
||||||
|
|
||||||
|
using Stack = std::vector<c10::IValue>;
|
||||||
|
using at::Tensor;
|
||||||
|
using at::Scalar;
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace autograd {
|
||||||
|
namespace VariableType {
|
||||||
|
Tensor mul_Tensor(const Tensor &self, const Tensor &other);
|
||||||
|
Tensor add_Scalar(const Tensor &self, Scalar other, Scalar alpha);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
static auto registry = torch::RegisterOperators().op(
|
||||||
|
"_aten::add.Scalar",
|
||||||
|
torch::RegisterOperators::options().kernel(c10::DispatchKey::VariableTensorId, &torch::autograd::VariableType::add_Scalar)
|
||||||
|
).op(
|
||||||
|
"_aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
|
||||||
|
torch::RegisterOperators::options().kernel(c10::DispatchKey::VariableTensorId, &torch::autograd::VariableType::mul_Tensor)
|
||||||
|
.aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)
|
||||||
|
);
|
||||||
|
}
|
Reference in New Issue
Block a user