Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Add dynamic shape support to AOT driver & compiler (#72995)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72995

Add the ability to specify input dimensions that need to be dynamic. Example: if dim 115 can be dynamic in input sizes "1,115;1", then specify dynamic_dims as "115".

Also recompile and update the CI models and some asm code, as the old ones don't compile with the compiler changes in context.cpp.

Test Plan:
- Compiles and runs the BI Bytedoc model with and without dynamic inputs.

Reviewed By: ZolotukhinM

Differential Revision: D34233121

fbshipit-source-id: 35095e549ebd6d3bec98b9abb3f0764366a0ff6f
(cherry picked from commit 33166a9f9ac9194b5df0a35280b57708df255ebd)
Committed by: PyTorch MergeBot
Parent: 5a7778c9a6
Commit: ac97e953b4
@@ -36,6 +36,10 @@ C10_DEFINE_string(
     "Input memory format. "
     "If multiple inputs needed, use semicolon to separate. "
     "Supported values: contiguous, channels_last");
+C10_DEFINE_string(
+    dynamic_dims,
+    "",
+    "Comma separated dimensions of input tensors that can be dynamic");
 C10_DEFINE_string(method_name, "forward", "The name of the method.");
 C10_DEFINE_string(
     output_llvm,
@@ -68,6 +72,7 @@ c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
   method_spec.insert("sizes", FLAGS_input_dims);
   method_spec.insert("types", FLAGS_input_types);
   method_spec.insert("memory_formats", FLAGS_input_memory_formats);
+  method_spec.insert("dynamic_sizes", FLAGS_dynamic_dims);
   method_spec.insert("asmfile", FLAGS_output_llvm);
   method_spec.insert("model_name", FLAGS_model_name);
   method_spec.insert("model_version", FLAGS_model_version);
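For reference, a minimal sketch of the method spec this driver builds, using the example from the commit message ("1,115;1" sizes with dynamic dim 115). The literal values are illustrative stand-ins for the FLAGS_* strings, and the header choice is approximate:

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>

// Sketch only: mirrors createCompileSpec() above with hypothetical literal
// values in place of the command-line flags. "sizes" separates inputs with
// ';' and dims with ','; "dynamic_sizes" lists dim values that may vary.
c10::Dict<c10::IValue, c10::IValue> exampleCompileSpec() {
  c10::Dict<c10::IValue, c10::IValue> method_spec(
      c10::StringType::get(), c10::AnyType::get());
  method_spec.insert("sizes", std::string("1,115;1"));      // two inputs
  method_spec.insert("types", std::string("float;float"));  // illustrative
  method_spec.insert("dynamic_sizes", std::string("115"));  // dim 115 dynamic
  method_spec.insert("asmfile", std::string("model.s"));    // hypothetical
  method_spec.insert("model_name", std::string("bytedoc")); // hypothetical
  method_spec.insert("model_version", std::string("v1"));
  return method_spec;
}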
@@ -23,7 +23,9 @@ c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
     const std::string& method_name,
     const std::string& model_name,
     const std::string& input_shapes,
-    const std::string& input_types) {
+    const std::string& input_types,
+    const std::string& memory_formats,
+    const std::string& dynamic_sizes) {
   c10::Dict<c10::IValue, c10::IValue> method_spec(
       c10::StringType::get(), c10::AnyType::get());
 
@@ -33,6 +35,8 @@ c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
   method_spec.insert("model_version", "v1");
   method_spec.insert("asmfile", "fake_nnc_model.s");
   method_spec.insert("arch", "x86-64");
+  method_spec.insert("memory_formats", memory_formats);
+  method_spec.insert("dynamic_sizes", dynamic_sizes);
 
   c10::Dict<c10::IValue, c10::IValue> compile_spec(
       c10::StringType::get(), c10::AnyType::get());
@@ -63,7 +67,7 @@ REGISTER_NNC_KERNEL(
 
 TEST(NNCBackendTest, AOTCompileThenExecute) {
   torch::jit::Module m("m");
-  auto param = torch::ones({});
+  auto param = torch::ones({1});
   m.register_parameter("param", param, false);
   m.define(R"(
     def forward(self, input):
@@ -77,7 +81,7 @@ TEST(NNCBackendTest, AOTCompileThenExecute) {
 
   // Compile the model with NNC.
   auto compile_spec = create_compile_spec(
-      "forward", "_add_kernel_nnc_fake_model", "4,4", "float");
+      "forward", "_add_kernel_nnc_fake_model", "4,4", "float", "", "");
   auto any_dict_ty =
       c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
   auto frozen_m = torch::jit::freeze_module(m.clone());
@@ -43,9 +43,18 @@ std::vector<int64_t> getConstSizes(const BufPtr b) {
 
 // Construct input-specs vector from the inputs of the original graph
 std::vector<mobile::nnc::InputSpec> toInputSpecs(
-    const std::shared_ptr<Graph>& g) {
+    const std::shared_ptr<tensorexpr::TensorExprKernel>& kernel) {
+  const std::shared_ptr<Graph>& g = kernel->graph();
   std::vector<mobile::nnc::InputSpec> specs;
-  for (auto v : g->inputs()) {
+
+  // Graph inputs include scalar values for symbolic shapes, for which we
+  // don't need input specs. These scalar values come last among the graph
+  // inputs
+  auto num_inputs =
+      g->inputs().size() - kernel->getSymbolicShapeInputs().size();
+
+  for (int i = 0; i < num_inputs; i++) {
+    auto v = g->inputs()[i];
     const auto& t = v->type();
     mobile::nnc::InputSpec spec;
     TORCH_CHECK(t->kind() == TypeKind::TensorType, "Unsupported input type");
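A standalone sketch of the indexing convention the new toInputSpecs relies on: symbolic scalar inputs trail the tensor inputs, so specs are built only for the leading portion. Names here are illustrative, not part of the PyTorch API:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Illustration: graph inputs are ordered [tensor inputs..., sym scalars...],
// so the number of tensor inputs needing specs is total minus symbolic count.
std::vector<std::string> specsForTensorInputs(
    const std::vector<std::string>& graph_inputs,
    size_t num_symbolic_scalars) {
  assert(num_symbolic_scalars <= graph_inputs.size());
  size_t num_tensor_inputs = graph_inputs.size() - num_symbolic_scalars;
  return std::vector<std::string>(
      graph_inputs.begin(), graph_inputs.begin() + num_tensor_inputs);
}

int main() {
  // Two tensor inputs followed by one symbolic-shape scalar.
  auto specs = specsForTensorInputs({"x", "y", "SS_1"}, 1);
  assert(specs.size() == 2); // specs cover only "x" and "y"
}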
@@ -120,7 +129,7 @@ std::unique_ptr<Function> compileMethod(
     const std::vector<at::ScalarType>& types) {
   auto func = std::make_unique<Function>();
   func->set_name(method_name);
-  func->set_input_specs(toInputSpecs(kernel->graph()));
+  func->set_input_specs(toInputSpecs(kernel));
 
   auto params = c10::impl::GenericList(c10::AnyType::get());
   auto const_descriptors = kernel->getConstantDescriptors();
@@ -177,18 +186,33 @@ std::pair<std::unique_ptr<Function>, const std::string> aotCompile(
     std::shared_ptr<Graph>& g,
     const std::vector<std::vector<int64_t>>& sizes,
     const std::vector<at::ScalarType>& types,
-    const std::string& kernel_func_name) {
+    const std::string& kernel_func_name,
+    const std::vector<int64_t>& symbolic_ind) {
   GRAPH_DEBUG("Input sizes ", sizes);
   GRAPH_DEBUG("Input types ", types);
   GRAPH_DEBUG("Method name ", method_name);
   GRAPH_DEBUG("Kernel func name ", kernel_func_name);
+  GRAPH_DEBUG("Symbolic indices ", symbolic_ind);
 
-  std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
-      std::make_shared<tensorexpr::TensorExprKernel>(
-          TensorExprKernel(g, kernel_func_name));
+  std::shared_ptr<tensorexpr::TensorExprKernel> kernel;
+  std::vector<torch::jit::StrideInput> stride_desc = {
+      torch::jit::StrideInput::TENSOR_CONT};
+  std::unordered_map<
+      const torch::jit::Value*,
+      std::vector<torch::jit::StrideInput>>
+      symbolic_strides;
+  if (!symbolic_ind.empty()) {
+    for (auto i : g->inputs()) {
+      symbolic_strides[i] = stride_desc;
+    }
+    for (auto o : g->outputs()) {
+      symbolic_strides[o] = stride_desc;
+    }
+  }
+  kernel = std::make_shared<tensorexpr::TensorExprKernel>(TensorExprKernel(
+      g, kernel_func_name, {}, symbolic_ind, false, symbolic_strides));
 
   const std::string compiled_assembly = kernel->getCodeText();
 
   auto func = compileMethod(kernel, method_name, sizes, types);
   return std::make_pair(std::move(func), compiled_assembly);
 }
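StrideInput::TENSOR_CONT pins every input and output to a contiguous layout here, since concrete strides cannot be baked into the kernel once shapes are symbolic. As a refresher, a standalone sketch (standard library only) of the row-major strides that a contiguous tensor implies:

#include <cassert>
#include <cstdint>
#include <vector>

// Row-major (contiguous) strides: the innermost dim has stride 1 and each
// outer stride is the product of all inner sizes.
std::vector<int64_t> contiguousStrides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * sizes[i + 1];
  }
  return strides;
}

int main() {
  // A {1, 115} input (the commit's example) has strides {115, 1}, whether or
  // not the 115 extent is later treated as symbolic.
  assert(contiguousStrides({1, 115}) == (std::vector<int64_t>{115, 1}));
}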
@@ -271,6 +295,17 @@ std::vector<at::MemoryFormat> parseInputMemoryFormats(
   return memFormats;
 }
 
+std::vector<int64_t> parseInputDynamicShapes(
+    const std::string& dynamic_dims_s) {
+  std::vector<std::string> dynamic_dims_list = split(',', dynamic_dims_s);
+  std::vector<int64_t> dynamic_dims;
+  dynamic_dims.reserve(dynamic_dims_list.size());
+  for (const auto& dim : dynamic_dims_list) {
+    dynamic_dims.push_back(c10::stoi(dim));
+  }
+  return dynamic_dims;
+}
+
 std::string getNncKernelId(
     const std::string& model_name,
     const std::string& model_version,
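The new parseInputDynamicShapes is a plain comma-separated integer parser (split and c10::stoi are torch-internal helpers). A self-contained equivalent using only the standard library:

#include <cassert>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Standalone equivalent of parseInputDynamicShapes: "1,115" -> {1, 115}.
std::vector<int64_t> parseDynamicDims(const std::string& s) {
  std::vector<int64_t> dims;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, ',')) {
    dims.push_back(std::stoll(item));
  }
  return dims;
}

int main() {
  assert(parseDynamicDims("115") == (std::vector<int64_t>{115}));
  assert(parseDynamicDims("1,115") == (std::vector<int64_t>{1, 115}));
  assert(parseDynamicDims("").empty()); // no dynamic dims requested
}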
@@ -288,9 +323,12 @@ std::string getNncKernelFuncName(
   return "nnc_" + model_name + "_" + model_version + "_" + method_name;
 }
 
-std::shared_ptr<Graph> preprocessGraphPasses(
+// Preprocess the graph and returns the processed graph and
+// symbolic values if dynamic input shapes are specified
+std::pair<std::shared_ptr<Graph>, std::vector<int64_t>> preprocessGraphPasses(
     std::shared_ptr<Graph>& graph,
-    const std::vector<c10::optional<at::Tensor>>& example_inputs) {
+    const std::vector<c10::optional<at::Tensor>>& example_inputs,
+    const std::vector<int64_t>& dynamic_sizes) {
   GRAPH_DEBUG("Before preprocessing graph passes: ", *graph);
   torch::jit::RemoveTensorMutation(graph);
   torch::jit::EliminateDeadCode(graph->block());
@@ -321,8 +359,12 @@ std::shared_ptr<Graph> preprocessGraphPasses(
   RemoveTensorMutation(graph);
   EliminateDeadCode(graph);
   LowerAllTuples(graph);
+
+  auto sym_val =
+      torch::jit::tensorexpr::makeShapesSymbolic(graph, dynamic_sizes);
+
   GRAPH_DEBUG("After preprocessing graph passes: ", *graph);
-  return graph;
+  return std::make_pair(graph, sym_val);
 }
 
 std::vector<c10::optional<at::Tensor>> generateExampleInputs(
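makeShapesSymbolic (a tensorexpr helper) rewrites input shapes so that dimensions matching one of dynamic_sizes become symbolic, and returns the symbol values that aotCompile later receives as symbolic_ind. A rough standalone illustration of that matching, with made-up negative ids standing in for whatever representation the real pass uses — an assumption for illustration only:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Illustration only: replace every dim equal to a requested dynamic size
// with a fresh symbolic id (negative, to stay disjoint from real sizes),
// and return the ids so the caller can thread them through compilation.
std::pair<std::vector<int64_t>, std::vector<int64_t>> makeSymbolicSketch(
    std::vector<int64_t> shape,
    const std::vector<int64_t>& dynamic_sizes) {
  std::vector<int64_t> sym_ids;
  int64_t next_id = -1;
  for (auto& dim : shape) {
    for (auto dyn : dynamic_sizes) {
      if (dim == dyn) {
        dim = next_id;
        sym_ids.push_back(next_id--);
        break;
      }
    }
  }
  return {std::move(shape), sym_ids};
}

int main() {
  // The commit's example: sizes "1,115" with dynamic dim value 115.
  auto result = makeSymbolicSketch({1, 115}, {115});
  assert((result.first == std::vector<int64_t>{1, -1}));
  assert(result.second.size() == 1);
}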
@@ -335,8 +377,7 @@ std::vector<c10::optional<at::Tensor>> generateExampleInputs(
     const auto dtype = at::dtype(inputTypes[i]);
     const auto memory_format = inputMemoryFormats[i];
     example_inputs.emplace_back(
-        at::rand(inputShapes[i], at::TensorOptions(dtype))
-            .contiguous(memory_format));
+        at::rand(inputShapes[i]).to(dtype).contiguous(memory_format));
   }
   return example_inputs;
 }
@@ -364,6 +405,8 @@ c10::IValue preprocess(
 
   auto sizes = parseInputShapes(*method_spec.at("sizes").toString());
   auto types = parseInputTypes(*method_spec.at("types").toString());
+  auto dynamic_sizes =
+      parseInputDynamicShapes(*method_spec.at("dynamic_sizes").toString());
 
   std::string memory_formats_str = method_spec.contains("memory_formats")
       ? (*method_spec.at("memory_formats").toString()).string()
@@ -374,12 +417,20 @@ c10::IValue preprocess(
       : parseInputMemoryFormats(memory_formats_str);
 
   auto example_inputs = generateExampleInputs(sizes, types, memory_formats);
-  graph = preprocessGraphPasses(graph, example_inputs);
+  auto preprocessed =
+      preprocessGraphPasses(graph, example_inputs, dynamic_sizes);
 
   auto kernel_func_name =
       getNncKernelFuncName(model_name, model_version, method_name);
+  auto processed_graph = preprocessed.first;
+  auto sym_values = preprocessed.second;
   auto compiled = torch::jit::mobile::nnc::aotCompile(
-      method_name, graph, sizes, types, kernel_func_name);
+      method_name,
+      processed_graph,
+      sizes,
+      types,
+      kernel_func_name,
+      sym_values);
   writeOutputLlvmAssembly(compiled.second, asmfile_name);
   auto func = std::move(compiled.first);
   func->set_nnc_kernel_id(
@@ -41,7 +41,17 @@ c10::IValue InputSpec::serialize() const {
 }
 
 bool InputSpec::validate(const at::Tensor& input) const {
-  return input.sizes() == sizes_ && input.scalar_type() == dtype_;
+  if (sizes_.size() != input.sizes().size() || input.scalar_type() != dtype_) {
+    return false;
+  }
+  auto spec_sizes = sizes_;
+  for (int i = 0; i < spec_sizes.size(); i++) {
+    // InputSpec size 0 means that the dimension is dynamic
+    if (spec_sizes[i] != 0 && spec_sizes[i] != input.sizes()[i]) {
+      return false;
+    }
+  }
+  return true;
 }
 
 OutputSpec::OutputSpec(const c10::IValue& value) {
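The new validate treats a recorded size of 0 as a wildcard for a dynamic dimension: rank and dtype must still match exactly, but a 0 entry accepts any extent. A standalone sketch of just the shape part:

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the shape logic of InputSpec::validate: 0 in the spec means
// "dynamic, any size"; every other entry must match exactly.
bool shapeMatches(
    const std::vector<int64_t>& spec_sizes,
    const std::vector<int64_t>& input_sizes) {
  if (spec_sizes.size() != input_sizes.size()) {
    return false; // rank must match even when some dims are dynamic
  }
  for (size_t i = 0; i < spec_sizes.size(); i++) {
    if (spec_sizes[i] != 0 && spec_sizes[i] != input_sizes[i]) {
      return false;
    }
  }
  return true;
}

int main() {
  // Spec {1, 0}: second dim is dynamic.
  assert(shapeMatches({1, 0}, {1, 115}));
  assert(shapeMatches({1, 0}, {1, 7}));
  assert(!shapeMatches({1, 0}, {2, 115}));    // static dim mismatch
  assert(!shapeMatches({1, 0}, {1, 115, 3})); // rank mismatch
}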
@@ -136,6 +146,14 @@ Function::Function(const c10::IValue& value) {
 
   // memory_plan_
   memory_plan_ = MemoryPlan(dict.at("memory_plan"));
+
+  // symbolic shape positions
+  for (const auto& sym_shape_pos :
+       dict.at("sym_shape_pos").toTupleRef().elements()) {
+    auto sym_shape_elements = sym_shape_pos.toTupleRef().elements();
+    sym_shape_positions_.emplace_back(
+        sym_shape_elements[0].toInt(), sym_shape_elements[1].toInt());
+  }
 }
 
 c10::IValue Function::serialize() const {
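Each serialized sym_shape_pos entry is a pair of integers; judging by how they are consumed elsewhere, they plausibly record (input index, dim index) so the runtime can read a dynamic extent off the actual input at call time. A hypothetical sketch of that lookup — the pair meaning is an assumption, not confirmed by this diff:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical: if each position is (input_idx, dim_idx), the concrete
// value of a symbolic shape is just the matching dim of a real input.
std::vector<int64_t> resolveSymShapes(
    const std::vector<std::pair<int64_t, int64_t>>& sym_shape_positions,
    const std::vector<std::vector<int64_t>>& input_sizes) {
  std::vector<int64_t> values;
  values.reserve(sym_shape_positions.size());
  for (const auto& pos : sym_shape_positions) {
    values.push_back(input_sizes[pos.first][pos.second]);
  }
  return values;
}

int main() {
  // One dynamic dim: dim 1 of input 0, currently of extent 115.
  auto vals = resolveSymShapes({{0, 1}}, {{1, 115}});
  assert(vals == (std::vector<int64_t>{115}));
}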
@@ -185,18 +203,20 @@ void Function::init_execution_state() const {
   ExecutionState state;
   memory_plan_.allocate(&state);
 
-  // The arguments vector consists of 4 sections: inputs, outputs, parameters
-  // and buffers.
+  // The arguments vector consists of 5 sections: inputs, symbolic shapes,
+  // outputs, parameters and buffers.
   auto input_args = input_specs_.size();
+  auto sym_shape_args = sym_shape_positions_.size();
   auto output_args = output_specs_.size();
   auto param_args = parameters_.size();
   auto buffer_args = state.preallocations_.size();
 
   auto& arguments = state.arguments_;
-  arguments.reserve(input_args + output_args + param_args + buffer_args);
+  arguments.reserve(
+      input_args + sym_shape_args + output_args + param_args + buffer_args);
 
   // Keep empty slots to fill in inputs/outputs pointers at execution time.
-  arguments.resize(input_args + output_args);
+  arguments.resize(input_args + sym_shape_args + output_args);
 
   // Fill in parameters as untyped raw pointers.
   // The underlying storage of the parameters should be owned by `parameters_`,
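With symbolic shapes, the packed argument vector gains a section, giving the order: inputs, symbolic shape values, outputs, parameters, buffers. A standalone sketch of the offset arithmetic the reserve/resize calls above imply (the struct and counts are illustrative):

#include <cassert>
#include <cstddef>
#include <vector>

// Section offsets in the arguments vector, mirroring init_execution_state():
// [inputs | sym shapes | outputs | params | buffers].
struct ArgLayout {
  size_t inputs, sym_shapes, outputs, params, buffers;
  size_t sym_shapes_at() const { return inputs; }
  size_t outputs_at() const { return inputs + sym_shapes; }
  size_t params_at() const { return inputs + sym_shapes + outputs; }
  size_t total() const {
    return inputs + sym_shapes + outputs + params + buffers;
  }
};

int main() {
  ArgLayout layout{/*inputs=*/2, /*sym_shapes=*/1, /*outputs=*/1,
                   /*params=*/3, /*buffers=*/2};
  std::vector<void*> arguments;
  arguments.reserve(layout.total());
  // Inputs, sym shapes and outputs stay as empty slots to be filled at run
  // time; params and buffers are appended afterwards.
  arguments.resize(layout.inputs + layout.sym_shapes + layout.outputs);
  assert(arguments.size() == 4);
  assert(layout.params_at() == 4);
}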
@@ -233,7 +253,7 @@ c10::impl::GenericList Function::run(
 
   // Fill in input tensors.
   TORCH_CHECK(
-      input_specs_.size() == inputs.size(),
+      input_specs_.size() == (inputs.size() + sym_shape_positions_.size()),
       "Input size doesn't match the spec, expect: ",
       input_specs_.size(),
       " actual: ",