mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
363 lines
8.9 KiB
Python
Executable File
363 lines
8.9 KiB
Python
Executable File
#!/usr/bin/env python
|
|
################################################################################
|
|
# Copyright 2018-2025 Intel Corporation
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
################################################################################
|
|
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
import datetime
|
|
import xml.etree.ElementTree as ET
|
|
|
|
|
|
def template(body, banner):
|
|
return """\
|
|
%s
|
|
// DO NOT EDIT, AUTO-GENERATED
|
|
// Use this script to update the file: scripts/%s
|
|
|
|
// clang-format off
|
|
|
|
%s""" % (
|
|
banner,
|
|
os.path.basename(__file__),
|
|
body
|
|
)
|
|
|
|
def header(body):
|
|
return (
|
|
"""\
|
|
#ifndef ONEAPI_DNNL_DNNL_DEBUG_H
|
|
#define ONEAPI_DNNL_DNNL_DEBUG_H
|
|
|
|
/// @file
|
|
/// Debug capabilities
|
|
|
|
#include "oneapi/dnnl/dnnl_config.h"
|
|
#include "oneapi/dnnl/dnnl_types.h"
|
|
|
|
#ifdef __cplusplus
|
|
extern "C" {
|
|
#endif
|
|
|
|
%s
|
|
const char DNNL_API *dnnl_runtime2str(unsigned v);
|
|
const char DNNL_API *dnnl_fmt_kind2str(dnnl_format_kind_t v);
|
|
|
|
#ifdef __cplusplus
|
|
}
|
|
#endif
|
|
|
|
#endif
|
|
"""
|
|
% body
|
|
)
|
|
|
|
|
|
def source(body):
|
|
return (
|
|
"""\
|
|
#include <assert.h>
|
|
|
|
#include "oneapi/dnnl/dnnl_debug.h"
|
|
#include "oneapi/dnnl/dnnl_types.h"
|
|
|
|
#include "common/c_types_map.hpp"
|
|
|
|
%s
|
|
"""
|
|
% body
|
|
)
|
|
|
|
|
|
def header_benchdnn(body):
|
|
return (
|
|
"""\
|
|
#ifndef DNNL_DEBUG_HPP
|
|
#define DNNL_DEBUG_HPP
|
|
|
|
#include "oneapi/dnnl/dnnl.h"
|
|
|
|
%s
|
|
/* status */
|
|
const char *status2str(dnnl_status_t status);
|
|
|
|
/* data type */
|
|
const char *dt2str(dnnl_data_type_t dt);
|
|
|
|
/* format */
|
|
const char *fmt_tag2str(dnnl_format_tag_t tag);
|
|
|
|
/* encoding */
|
|
const char *sparse_encoding2str(dnnl_sparse_encoding_t encoding);
|
|
|
|
/* engine kind */
|
|
const char *engine_kind2str(dnnl_engine_kind_t kind);
|
|
|
|
/* scratchpad mode */
|
|
const char *scratchpad_mode2str(dnnl_scratchpad_mode_t mode);
|
|
|
|
/* fpmath mode */
|
|
const char *fpmath_mode2str(dnnl_fpmath_mode_t mode);
|
|
|
|
/* accumulation mode */
|
|
const char *accumulation_mode2str(dnnl_accumulation_mode_t mode);
|
|
|
|
/* rounding mode */
|
|
const char *rounding_mode2str(dnnl_rounding_mode_t mode);
|
|
|
|
#endif
|
|
"""
|
|
% body
|
|
)
|
|
|
|
|
|
def source_benchdnn(body):
|
|
return (
|
|
"""\
|
|
#include <assert.h>
|
|
#include <stdio.h>
|
|
#include <string.h>
|
|
|
|
#include "oneapi/dnnl/dnnl_debug.h"
|
|
|
|
#include "dnnl_debug.hpp"
|
|
|
|
#include "src/common/z_magic.hpp"
|
|
|
|
%s
|
|
|
|
const char *status2str(dnnl_status_t status) {
|
|
return dnnl_status2str(status);
|
|
}
|
|
|
|
const char *dt2str(dnnl_data_type_t dt) {
|
|
return dnnl_dt2str(dt);
|
|
}
|
|
|
|
const char *fmt_tag2str(dnnl_format_tag_t tag) {
|
|
return dnnl_fmt_tag2str(tag);
|
|
}
|
|
|
|
const char *sparse_encoding2str(dnnl_sparse_encoding_t encoding) {
|
|
return dnnl_sparse_encoding2str(encoding);
|
|
}
|
|
|
|
const char *engine_kind2str(dnnl_engine_kind_t kind) {
|
|
return dnnl_engine_kind2str(kind);
|
|
}
|
|
|
|
const char *scratchpad_mode2str(dnnl_scratchpad_mode_t mode) {
|
|
return dnnl_scratchpad_mode2str(mode);
|
|
}
|
|
|
|
const char *fpmath_mode2str(dnnl_fpmath_mode_t mode) {
|
|
return dnnl_fpmath_mode2str(mode);
|
|
}
|
|
|
|
const char *accumulation_mode2str(dnnl_accumulation_mode_t mode) {
|
|
return dnnl_accumulation_mode2str(mode);
|
|
}
|
|
|
|
const char *rounding_mode2str(dnnl_rounding_mode_t mode) {
|
|
return dnnl_rounding_mode2str(mode);
|
|
}
|
|
"""
|
|
% body.rstrip()
|
|
)
|
|
|
|
|
|
def maybe_skip(enum):
|
|
return enum in (
|
|
"dnnl_memory_extra_flags_t",
|
|
"dnnl_normalization_flags_t",
|
|
"dnnl_query_t",
|
|
"dnnl_rnn_cell_flags_t",
|
|
"dnnl_stream_flags_t",
|
|
"dnnl_format_kind_t",
|
|
)
|
|
|
|
|
|
def enum_abbrev(enum):
|
|
def_enum = re.sub(r"^dnnl_", "", enum)
|
|
def_enum = re.sub(r"_t$", "", def_enum)
|
|
return {
|
|
"dnnl_data_type_t": "dt",
|
|
"dnnl_format_tag_t": "fmt_tag",
|
|
"dnnl_primitive_kind_t": "prim_kind",
|
|
"dnnl_engine_kind_t": "engine_kind",
|
|
}.get(enum, def_enum)
|
|
|
|
|
|
def sanitize_value(v):
|
|
if "undef" in v:
|
|
return "undef"
|
|
if "any" in v:
|
|
return "any"
|
|
v = v.split("dnnl_fpmath_mode_")[-1]
|
|
v = v.split("dnnl_accumulation_mode_")[-1]
|
|
v = v.split("dnnl_rounding_mode_")[-1]
|
|
v = v.split("dnnl_scratchpad_mode_")[-1]
|
|
v = v.split("dnnl_")[-1]
|
|
return v
|
|
|
|
|
|
def func_to_str_decl(enum, is_header=False):
|
|
abbrev = enum_abbrev(enum)
|
|
return "const char %s*dnnl_%s2str(%s v)" % (
|
|
"DNNL_API " if is_header else "",
|
|
abbrev,
|
|
enum,
|
|
)
|
|
|
|
|
|
def func_to_str(enum, values):
|
|
indent = " "
|
|
abbrev = enum_abbrev(enum)
|
|
func = ""
|
|
func += func_to_str_decl(enum) + " {\n"
|
|
for v in values:
|
|
func += '%sif (v == %s) return "%s";\n' % (indent, v, sanitize_value(v))
|
|
if (enum == "dnnl_primitive_kind_t"):
|
|
func += '%sif (v == dnnl::impl::primitive_kind::sdpa) return "sdpa";\n' % indent
|
|
func += '%sassert(!"unknown %s");\n' % (indent, abbrev)
|
|
func += '%sreturn "unknown %s";\n}\n' % (indent, abbrev)
|
|
return func
|
|
|
|
|
|
def str_to_func_decl(enum, is_header=False, is_dnnl=True):
|
|
attr = "DNNL_API " if is_header and is_dnnl else ""
|
|
prefix = "dnnl_" if is_dnnl else ""
|
|
abbrev = enum_abbrev(enum)
|
|
return "%s %s%sstr2%s(const char *str)" % (enum, attr, prefix, abbrev)
|
|
|
|
|
|
def str_to_func(enum, values, is_dnnl=True):
|
|
indent = " "
|
|
abbrev = enum_abbrev(enum)
|
|
func = ""
|
|
func += str_to_func_decl(enum, is_dnnl=is_dnnl) + " {\n"
|
|
func += """#define CASE(_case) do { \\
|
|
if (!strcmp(STRINGIFY(_case), str) \\
|
|
|| !strcmp("dnnl_" STRINGIFY(_case), str)) \\
|
|
return CONCAT2(dnnl_, _case); \\
|
|
} while (0)
|
|
"""
|
|
special_values = []
|
|
for v in values:
|
|
if "last" in v:
|
|
continue
|
|
if "undef" in v:
|
|
v_undef = v
|
|
special_values.append(v)
|
|
continue
|
|
if "any" in v:
|
|
special_values.append(v)
|
|
continue
|
|
func += "%sCASE(%s);\n" % (indent, sanitize_value(v))
|
|
func += "#undef CASE\n"
|
|
for v in special_values:
|
|
v_short = re.search(r"(any|undef)", v).group()
|
|
func += """%sif (!strcmp("%s", str) || !strcmp("%s", str))
|
|
return %s;
|
|
""" % (
|
|
indent,
|
|
v_short,
|
|
v,
|
|
v,
|
|
)
|
|
if enum != "dnnl_format_tag_t":
|
|
func += (
|
|
'%sprintf("Error: %s ' % (indent, abbrev)
|
|
+ '`%s` is not supported.\\n", str);\n'
|
|
)
|
|
func += '%sassert(!"unknown %s");\n' % (indent, abbrev)
|
|
func += "%sreturn %s;\n}\n" % (
|
|
indent,
|
|
v_undef if enum != "dnnl_format_tag_t" else "dnnl_format_tag_last",
|
|
)
|
|
return func
|
|
|
|
|
|
def generate(ifile, banners):
|
|
h_body, s_body = "", ""
|
|
h_benchdnn_body, s_benchdnn_body = "", ""
|
|
root = ET.parse(ifile).getroot()
|
|
for v_enum in root.findall("Enumeration"):
|
|
enum = v_enum.attrib["name"]
|
|
if maybe_skip(enum):
|
|
continue
|
|
values = [v_value.attrib["name"] for v_value in v_enum.findall("EnumValue")]
|
|
|
|
h_body += func_to_str_decl(enum, is_header=True) + ";\n"
|
|
s_body += func_to_str(enum, values) + "\n"
|
|
|
|
if enum in ["dnnl_format_tag_t", "dnnl_data_type_t", "dnnl_sparse_encoding_t"]:
|
|
h_benchdnn_body += (
|
|
str_to_func_decl(enum, is_header=True, is_dnnl=False) + ";\n"
|
|
)
|
|
s_benchdnn_body += str_to_func(enum, values, is_dnnl=False) + "\n"
|
|
|
|
bodies = [
|
|
header(h_body),
|
|
source(s_body),
|
|
header_benchdnn(h_benchdnn_body),
|
|
source_benchdnn(s_benchdnn_body),
|
|
]
|
|
return [template(b, y) for b, y in zip(bodies, banners)]
|
|
|
|
|
|
def usage():
|
|
print(
|
|
"""\
|
|
%s types.xml
|
|
|
|
Generates oneDNN debug header and source files with enum to string mapping.
|
|
Input types.xml file can be obtained with CastXML[1]:
|
|
$ castxml --castxml-cc-gnu-c clang --castxml-output=1 \\
|
|
-Iinclude -Ibuild/include include/oneapi/dnnl/dnnl_types.h -o types.xml
|
|
|
|
[1] https://github.com/CastXML/CastXML"""
|
|
% sys.argv[0]
|
|
)
|
|
sys.exit(1)
|
|
|
|
for arg in sys.argv:
|
|
if "-help" in arg:
|
|
usage()
|
|
|
|
script_root = os.path.dirname(os.path.realpath(__file__))
|
|
|
|
ifile = sys.argv[1] if len(sys.argv) > 1 else usage()
|
|
|
|
file_paths = (
|
|
"%s/../include/oneapi/dnnl/dnnl_debug.h" % script_root,
|
|
"%s/../src/common/dnnl_debug_autogenerated.cpp" % script_root,
|
|
"%s/../tests/benchdnn/dnnl_debug.hpp" % script_root,
|
|
"%s/../tests/benchdnn/dnnl_debug_autogenerated.cpp" % script_root,
|
|
)
|
|
|
|
banners = []
|
|
for file_path in file_paths:
|
|
with open(file_path, "r") as f:
|
|
m = re.match(r'^/\*+\n(\*.*\n)+\*+/\n', f.read())
|
|
banners.append('' if m == None else m.group(0))
|
|
|
|
for file_path, file_body in zip(file_paths, generate(ifile, banners)):
|
|
with open(file_path, "w") as f:
|
|
f.write(file_body)
|