[nnc][aot_compiler] Memory formats args to aot_compiler (#72873)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72873

Test Plan: Imported from OSS

Reviewed By: priyaramani

Differential Revision: D34250984

Pulled By: IvanKobzarev

fbshipit-source-id: e723ee64b024883eef78853e1b185b7040cafb09
(cherry picked from commit e9908df045acf33aa3cd0aec6784f15421236787)
Committed by: PyTorch MergeBot
Parent: 41ad221751
Commit: c32b74cecb
@@ -30,6 +30,12 @@ C10_DEFINE_string(
     "If multiple inputs needed, use semicolon to separate "
     "the dtype of different tensors."
     "Supported dtypes: float, int64, uint8");
+C10_DEFINE_string(
+    input_memory_formats,
+    "",
+    "Input memory format."
+    "If multiple inputs needed, use semicolon to separate."
+    "Supported values: contiguous, channels_last");
 C10_DEFINE_string(method_name, "forward", "The name of the method.");
 C10_DEFINE_string(
     output_llvm,
@@ -61,6 +67,7 @@ c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
       c10::StringType::get(), c10::AnyType::get());
   method_spec.insert("sizes", FLAGS_input_dims);
   method_spec.insert("types", FLAGS_input_types);
+  method_spec.insert("memory_formats", FLAGS_input_memory_formats);
   method_spec.insert("asmfile", FLAGS_output_llvm);
   method_spec.insert("model_name", FLAGS_model_name);
   method_spec.insert("model_version", FLAGS_model_version);
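For orientation, an illustrative (hypothetical) picture of the compile-spec entries this function ends up holding for a two-input model; the raw flag strings are stored verbatim here and only parsed later in preprocess():

// Hypothetical method_spec contents (values come straight from the flags):
//   "sizes"          -> "1,3,224,224;2,2"
//   "types"          -> "float;float"
//   "memory_formats" -> "channels_last;contiguous"
//   "asmfile"        -> <FLAGS_output_llvm>
//   "model_name"     -> <FLAGS_model_name>
//   "model_version"  -> <FLAGS_model_version>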
@@ -79,6 +86,7 @@ int main(int argc, char** argv) {
       " --model_version=<model version>"
       " --input_dims=<input dimensions like '1,3,224,224;2,2'>"
       " --input_types=<input dtypes like 'float;float'>"
+      " --input_memory_formats=<input memory formats like 'channels_last;contiguous'>"
       " [--method_name=<method name>]"
       " [--output_llvm=<llvm assembly output file path>]"
       " [--output_model=<output model file path>]");
@@ -93,10 +101,14 @@ int main(int argc, char** argv) {
   CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
   CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
   CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
+  const auto dims_size = split(';', FLAGS_input_dims).size();
   CAFFE_ENFORCE(
-      split(';', FLAGS_input_dims).size() ==
-          split(';', FLAGS_input_types).size(),
+      dims_size == split(';', FLAGS_input_types).size(),
       "Number of input_dims and input_types should be the same");
+  const auto mem_formats_size = split(';', FLAGS_input_memory_formats).size();
+  CAFFE_ENFORCE(
+      mem_formats_size == 0 || mem_formats_size == dims_size,
+      "Number of input_memory_formats should be 0 (default contiguous) or the same as number of input_dims");
   if (FLAGS_output_llvm.empty()) {
     FLAGS_output_llvm =
         FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
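The flag format is positional: entry i of each semicolon-separated list describes input i, and --input_memory_formats may be left empty to default every input to contiguous. A minimal standalone sketch of that invariant; splitSemicolons is a hypothetical stand-in for the split helper used above:

#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for split(';', ...); returns no entries for "".
static std::vector<std::string> splitSemicolons(const std::string& s) {
  std::vector<std::string> out;
  std::stringstream ss(s);
  for (std::string item; std::getline(ss, item, ';');) {
    out.push_back(item);
  }
  return out;
}

int main() {
  // Two inputs: input_dims, input_types and input_memory_formats each carry
  // two semicolon-separated entries, so the checks above would pass.
  auto dims = splitSemicolons("1,3,224,224;2,2");
  auto types = splitSemicolons("float;float");
  auto formats = splitSemicolons("channels_last;contiguous");
  bool ok = dims.size() == types.size() &&
      (formats.empty() || formats.size() == dims.size());
  return ok ? 0 : 1;
}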
@@ -253,6 +253,24 @@ std::vector<at::ScalarType> parseInputTypes(
   return scalarTypes;
 }
 
+std::vector<at::MemoryFormat> parseInputMemoryFormats(
+    const std::string& input_memory_format_str) {
+  std::vector<std::string> memFormatsStr = split(';', input_memory_format_str);
+  std::vector<at::MemoryFormat> memFormats;
+  for (const auto& memFormatStr : memFormatsStr) {
+    at::MemoryFormat memFormat;
+    if (memFormatStr == "contiguous") {
+      memFormat = at::MemoryFormat::Contiguous;
+    } else if (memFormatStr == "channels_last") {
+      memFormat = at::MemoryFormat::ChannelsLast;
+    } else {
+      CAFFE_THROW("Unsupported memory format: ", memFormatStr);
+    }
+    memFormats.push_back(memFormat);
+  }
+  return memFormats;
+}
+
 std::string getNncKernelId(
     const std::string& model_name,
     const std::string& model_version,
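As a usage note (not part of the patch), a flag value of "channels_last;contiguous" parses to the vector below, while any other token aborts via CAFFE_THROW:

#include <ATen/ATen.h>
#include <vector>

// Hypothetical expected result of parseInputMemoryFormats("channels_last;contiguous").
const std::vector<at::MemoryFormat> kExpectedFormats = {
    at::MemoryFormat::ChannelsLast, at::MemoryFormat::Contiguous};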
@@ -309,12 +327,16 @@ std::shared_ptr<Graph> preprocessGraphPasses(
 
 std::vector<c10::optional<at::Tensor>> generateExampleInputs(
     const std::vector<std::vector<int64_t>>& inputShapes,
-    const std::vector<at::ScalarType>& inputTypes) {
+    const std::vector<at::ScalarType>& inputTypes,
+    const std::vector<at::MemoryFormat>& inputMemoryFormats) {
   std::vector<c10::optional<at::Tensor>> example_inputs;
   example_inputs.reserve(inputShapes.size());
   for (int i = 0; i < inputShapes.size(); ++i) {
+    const auto dtype = at::dtype(inputTypes[i]);
+    const auto memory_format = inputMemoryFormats[i];
     example_inputs.emplace_back(
-        at::rand(inputShapes[i]).to(at::dtype(inputTypes[i])));
+        at::rand(inputShapes[i], at::TensorOptions(dtype))
+            .contiguous(memory_format));
   }
   return example_inputs;
 }
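For context, a standalone sketch (assuming an ATen build; not part of the patch) of what this construction yields for a single 1x3x224x224 float input requested in channels_last:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Same recipe as generateExampleInputs above, applied to one input.
  auto t = at::rand({1, 3, 224, 224}, at::dtype(at::kFloat))
               .contiguous(at::MemoryFormat::ChannelsLast);
  // Sizes stay NCHW, but the strides become NHWC-style: {150528, 1, 672, 3}.
  std::cout << t.is_contiguous(at::MemoryFormat::ChannelsLast) << "\n";  // prints 1
  return 0;
}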
@@ -343,7 +365,15 @@ c10::IValue preprocess(
   auto sizes = parseInputShapes(*method_spec.at("sizes").toString());
   auto types = parseInputTypes(*method_spec.at("types").toString());
-
-  auto example_inputs = generateExampleInputs(sizes, types);
+  std::string memory_formats_str = method_spec.contains("memory_formats")
+      ? (*method_spec.at("memory_formats").toString()).string()
+      : "";
+  auto memory_formats = memory_formats_str.empty()
+      ? std::vector<at::MemoryFormat>(
+            sizes.size(), at::MemoryFormat::Contiguous)
+      : parseInputMemoryFormats(memory_formats_str);
+
+  auto example_inputs = generateExampleInputs(sizes, types, memory_formats);
   graph = preprocessGraphPasses(graph, example_inputs);
 
   auto kernel_func_name =