[nnc][aot_compiler] Memory formats args to aot_compiler (#72873)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72873

Test Plan: Imported from OSS

Reviewed By: priyaramani

Differential Revision: D34250984

Pulled By: IvanKobzarev

fbshipit-source-id: e723ee64b024883eef78853e1b185b7040cafb09
(cherry picked from commit e9908df045acf33aa3cd0aec6784f15421236787)
Author: Ivan Kobzarev
Date: 2022-02-16 10:32:06 -08:00
Committed by: PyTorch MergeBot
Parent: 41ad221751
Commit: c32b74cecb
2 changed files with 47 additions and 5 deletions


@@ -30,6 +30,12 @@ C10_DEFINE_string(
     "If multiple inputs needed, use semicolon to separate "
     "the dtype of different tensors."
     "Supported dtypes: float, int64, uint8");
+C10_DEFINE_string(
+    input_memory_formats,
+    "",
+    "Input memory format."
+    "If multiple inputs needed, use semicolon to separate."
+    "Supported values: contiguous, channels_last");
 C10_DEFINE_string(method_name, "forward", "The name of the method.");
 C10_DEFINE_string(
     output_llvm,
@@ -61,6 +67,7 @@ c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
       c10::StringType::get(), c10::AnyType::get());
   method_spec.insert("sizes", FLAGS_input_dims);
   method_spec.insert("types", FLAGS_input_types);
+  method_spec.insert("memory_formats", FLAGS_input_memory_formats);
   method_spec.insert("asmfile", FLAGS_output_llvm);
   method_spec.insert("model_name", FLAGS_model_name);
   method_spec.insert("model_version", FLAGS_model_version);
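
Note (not part of the diff): the compile spec carries the raw flag string under the new "memory_formats" key, alongside "sizes" and "types"; preprocess() in the second file reads it back out and falls back to contiguous inputs when the string is empty.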
@@ -79,6 +86,7 @@ int main(int argc, char** argv) {
       " --model_version=<model version>"
       " --input_dims=<input dimensions like '1,3,224,224;2,2'>"
       " --input_types=<input dtypes like 'float;float'>"
+      " --input_memory_formats=<input memory formats like 'channels_last;contiguous'>"
       " [--method_name=<method name>]"
       " [--output_llvm=<llvm assembly output file path>]"
       " [--output_model=<output model file path>]");
@@ -93,10 +101,14 @@
   CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
   CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
   CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
+  const auto dims_size = split(';', FLAGS_input_dims).size();
   CAFFE_ENFORCE(
-      split(';', FLAGS_input_dims).size() ==
-          split(';', FLAGS_input_types).size(),
+      dims_size == split(';', FLAGS_input_types).size(),
       "Number of input_dims and input_types should be the same");
+  const auto mem_formats_size = split(';', FLAGS_input_memory_formats).size();
+  CAFFE_ENFORCE(
+      mem_formats_size == 0 || mem_formats_size == dims_size,
+      "Number of input_memory_formats should be 0 (default contiguous) or the same as number of input_dims");
   if (FLAGS_output_llvm.empty()) {
     FLAGS_output_llvm =
         FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
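
For illustration, a flag set exercising the new option could look like the following (model path and values are made up for the example; every flag shown is defined in this file):

  --model=model.pt --model_name=model --model_version=v1 --input_dims="1,3,224,224;1,3,224,224" --input_types="float;float" --input_memory_formats="channels_last;contiguous"

Leaving --input_memory_formats empty keeps the previous behavior: the new CAFFE_ENFORCE accepts a count of zero, and preprocess() below then defaults every input to contiguous.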


@@ -253,6 +253,24 @@ std::vector<at::ScalarType> parseInputTypes(
   return scalarTypes;
 }
 
+std::vector<at::MemoryFormat> parseInputMemoryFormats(
+    const std::string& input_memory_format_str) {
+  std::vector<std::string> memFormatsStr = split(';', input_memory_format_str);
+  std::vector<at::MemoryFormat> memFormats;
+  for (const auto& memFormatStr : memFormatsStr) {
+    at::MemoryFormat memFormat;
+    if (memFormatStr == "contiguous") {
+      memFormat = at::MemoryFormat::Contiguous;
+    } else if (memFormatStr == "channels_last") {
+      memFormat = at::MemoryFormat::ChannelsLast;
+    } else {
+      CAFFE_THROW("Unsupported memory format: ", memFormatStr);
+    }
+    memFormats.push_back(memFormat);
+  }
+  return memFormats;
+}
+
 std::string getNncKernelId(
     const std::string& model_name,
     const std::string& model_version,
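
Quick reference, not part of the diff: "contiguous;channels_last" parses to {at::MemoryFormat::Contiguous, at::MemoryFormat::ChannelsLast}; any other token (e.g. "channels_last_3d") hits the CAFFE_THROW branch.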
@@ -309,12 +327,16 @@ std::shared_ptr<Graph> preprocessGraphPasses(
 
 std::vector<c10::optional<at::Tensor>> generateExampleInputs(
     const std::vector<std::vector<int64_t>>& inputShapes,
-    const std::vector<at::ScalarType>& inputTypes) {
+    const std::vector<at::ScalarType>& inputTypes,
+    const std::vector<at::MemoryFormat>& inputMemoryFormats) {
   std::vector<c10::optional<at::Tensor>> example_inputs;
   example_inputs.reserve(inputShapes.size());
   for (int i = 0; i < inputShapes.size(); ++i) {
+    const auto dtype = at::dtype(inputTypes[i]);
+    const auto memory_format = inputMemoryFormats[i];
     example_inputs.emplace_back(
-        at::rand(inputShapes[i]).to(at::dtype(inputTypes[i])));
+        at::rand(inputShapes[i], at::TensorOptions(dtype))
+            .contiguous(memory_format));
   }
   return example_inputs;
 }
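
To make the new tensor construction concrete, here is a minimal standalone sketch (an illustrative example against ATen, not part of the diff) that mirrors the channels_last path above:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Mirrors the example-input construction for one 1x3x224x224 float input
  // requested with --input_memory_formats=channels_last.
  at::Tensor t = at::rand({1, 3, 224, 224}, at::TensorOptions(at::kFloat))
                     .contiguous(at::MemoryFormat::ChannelsLast);
  // channels_last keeps the NCHW sizes but uses NHWC strides: [150528, 1, 672, 3].
  std::cout << t.sizes() << " " << t.strides() << std::endl;
  return 0;
}
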
@@ -343,7 +365,15 @@ c10::IValue preprocess(
 
   auto sizes = parseInputShapes(*method_spec.at("sizes").toString());
   auto types = parseInputTypes(*method_spec.at("types").toString());
-  auto example_inputs = generateExampleInputs(sizes, types);
+  std::string memory_formats_str = method_spec.contains("memory_formats")
+      ? (*method_spec.at("memory_formats").toString()).string()
+      : "";
+  auto memory_formats = memory_formats_str.empty()
+      ? std::vector<at::MemoryFormat>(
+            sizes.size(), at::MemoryFormat::Contiguous)
+      : parseInputMemoryFormats(memory_formats_str);
+  auto example_inputs = generateExampleInputs(sizes, types, memory_formats);
 
   graph = preprocessGraphPasses(graph, example_inputs);
   auto kernel_func_name =