mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ATen-vulkan] Implement global shader registry (#121088)
Differential Revision: D54447700 ## Context This changeset updates Vulkan SPIR-V codegen to introduce a global SPIR-V shader registry and register shaders dynamically at static initialization time. This change makes it possible to define and link custom shader libraries to the ATen-Vulkan runtime. Before: * `gen_vulkan_spv.py` generated two files, `spv.h` and `spv.cpp` which would contain the definition and initialization of Vulkan shader registry variables. After: * Introduce the `ShaderRegistry` class in `api/`, which encapsulates functionality of the `ShaderRegistry` class previously defined in the generated `spv.h` file * Introduce a global shader registry (defined as a static variable in the `api::shader_registry() function` * Define a `ShaderRegisterInit` class (taking inspiration from `TorchLibraryInit`) that allows for dynamic shader registration * `gen_vulkan_spv.py` now only generates `spv.cpp`, which defines a static `ShaderRegisterInit` instance that triggers registration of the compiled shaders to the global shader registry. Benefits: * Cleaner code base; we no longer have `ShaderRegistry` defined in a generated file, and don't need a separate implementation file (`impl/Registry.*`) to handle shader lookup. All that logic now lives under `api/ShaderRegistry.*` * Makes it possible to compile and link separate shader libraries, providing similar flexibility as defining and linking custom ATen operators Pull Request resolved: https://github.com/pytorch/pytorch/pull/121088 Approved by: https://github.com/manuelcandales, https://github.com/jorgep31415
This commit is contained in:
committed by
PyTorch MergeBot
parent
c3c618c750
commit
ffe45a8188
61
aten/src/ATen/native/vulkan/api/ShaderRegistry.cpp
Normal file
61
aten/src/ATen/native/vulkan/api/ShaderRegistry.cpp
Normal file
@ -0,0 +1,61 @@
|
||||
#include <ATen/native/vulkan/api/ShaderRegistry.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace vulkan {
|
||||
namespace api {
|
||||
|
||||
bool ShaderRegistry::has_shader(const std::string& shader_name) {
|
||||
const ShaderListing::const_iterator it = listings_.find(shader_name);
|
||||
return it != listings_.end();
|
||||
}
|
||||
|
||||
bool ShaderRegistry::has_dispatch(const std::string& op_name) {
|
||||
const Registry::const_iterator it = registry_.find(op_name);
|
||||
return it != registry_.end();
|
||||
}
|
||||
|
||||
void ShaderRegistry::register_shader(ShaderInfo&& shader_info) {
|
||||
if (has_shader(shader_info.kernel_name)) {
|
||||
VK_THROW(
|
||||
"Shader with name ", shader_info.kernel_name, "already registered");
|
||||
}
|
||||
listings_.emplace(shader_info.kernel_name, shader_info);
|
||||
}
|
||||
|
||||
void ShaderRegistry::register_op_dispatch(
|
||||
const std::string& op_name,
|
||||
const DispatchKey key,
|
||||
const std::string& shader_name) {
|
||||
if (!has_dispatch(op_name)) {
|
||||
registry_.emplace(op_name, Dispatcher());
|
||||
}
|
||||
const Dispatcher::const_iterator it = registry_[op_name].find(key);
|
||||
if (it != registry_[op_name].end()) {
|
||||
registry_[op_name][key] = shader_name;
|
||||
} else {
|
||||
registry_[op_name].emplace(key, shader_name);
|
||||
}
|
||||
}
|
||||
|
||||
const ShaderInfo& ShaderRegistry::get_shader_info(
|
||||
const std::string& shader_name) {
|
||||
const ShaderListing::const_iterator it = listings_.find(shader_name);
|
||||
|
||||
VK_CHECK_COND(
|
||||
it != listings_.end(),
|
||||
"Could not find ShaderInfo with name ",
|
||||
shader_name);
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ShaderRegistry& shader_registry() {
|
||||
static ShaderRegistry registry;
|
||||
return registry;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace vulkan
|
||||
} // namespace native
|
||||
} // namespace at
|
84
aten/src/ATen/native/vulkan/api/ShaderRegistry.h
Normal file
84
aten/src/ATen/native/vulkan/api/ShaderRegistry.h
Normal file
@ -0,0 +1,84 @@
|
||||
#pragma once
|
||||
|
||||
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
|
||||
|
||||
#ifdef USE_VULKAN_API
|
||||
|
||||
#include <ATen/native/vulkan/api/Shader.h>
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#define VK_KERNEL(shader_name) \
|
||||
::at::native::vulkan::api::shader_registry().get_shader_info(#shader_name)
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace vulkan {
|
||||
namespace api {
|
||||
|
||||
enum class DispatchKey : int8_t {
|
||||
CATCHALL,
|
||||
ADRENO,
|
||||
MALI,
|
||||
OVERRIDE,
|
||||
};
|
||||
|
||||
class ShaderRegistry final {
|
||||
using ShaderListing = std::unordered_map<std::string, ShaderInfo>;
|
||||
using Dispatcher = std::unordered_map<DispatchKey, std::string>;
|
||||
using Registry = std::unordered_map<std::string, Dispatcher>;
|
||||
|
||||
ShaderListing listings_;
|
||||
Dispatcher dispatcher_;
|
||||
Registry registry_;
|
||||
|
||||
public:
|
||||
/*
|
||||
* Check if the registry has a shader registered under the given name
|
||||
*/
|
||||
bool has_shader(const std::string& shader_name);
|
||||
|
||||
/*
|
||||
* Check if the registry has a dispatch registered under the given name
|
||||
*/
|
||||
bool has_dispatch(const std::string& op_name);
|
||||
|
||||
/*
|
||||
* Register a ShaderInfo to a given shader name
|
||||
*/
|
||||
void register_shader(ShaderInfo&& shader_info);
|
||||
|
||||
/*
|
||||
* Register a dispatch entry to the given op name
|
||||
*/
|
||||
void register_op_dispatch(
|
||||
const std::string& op_name,
|
||||
const DispatchKey key,
|
||||
const std::string& shader_name);
|
||||
|
||||
/*
|
||||
* Given a shader name, return the ShaderInfo which contains the SPIRV binary
|
||||
*/
|
||||
const ShaderInfo& get_shader_info(const std::string& shader_name);
|
||||
};
|
||||
|
||||
class ShaderRegisterInit final {
|
||||
using InitFn = void();
|
||||
|
||||
public:
|
||||
ShaderRegisterInit(InitFn* init_fn) {
|
||||
init_fn();
|
||||
};
|
||||
};
|
||||
|
||||
// The global shader registry is retrieved using this function, where it is
|
||||
// declared as a static local variable.
|
||||
ShaderRegistry& shader_registry();
|
||||
|
||||
} // namespace api
|
||||
} // namespace vulkan
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif /* USE_VULKAN_API */
|
@ -10,6 +10,7 @@
|
||||
#include <ATen/native/vulkan/api/Resource.h>
|
||||
#include <ATen/native/vulkan/api/Runtime.h>
|
||||
#include <ATen/native/vulkan/api/Shader.h>
|
||||
#include <ATen/native/vulkan/api/ShaderRegistry.h>
|
||||
#include <ATen/native/vulkan/api/Tensor.h>
|
||||
#include <ATen/native/vulkan/api/Utils.h>
|
||||
|
||||
|
@ -3,12 +3,6 @@
|
||||
#ifdef USE_VULKAN_API
|
||||
|
||||
#include <ATen/native/vulkan/api/api.h>
|
||||
#include <ATen/native/vulkan/impl/Registry.h>
|
||||
|
||||
#define VK_KERNEL(shader_name) \
|
||||
::at::native::vulkan::get_shader_info(#shader_name)
|
||||
#define VK_LOOKUP_KERNEL(op_name) \
|
||||
::at::native::vulkan::look_up_shader_info(#op_name)
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
@ -1,88 +0,0 @@
|
||||
#ifdef USE_VULKAN_API
|
||||
|
||||
#include <ATen/native/vulkan/api/Shader.h>
|
||||
#include <ATen/native/vulkan/impl/Registry.h>
|
||||
#include <ATen/native/vulkan/spv.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace vulkan {
|
||||
|
||||
const api::ShaderInfo& get_shader_info(const std::string& shader_name) {
|
||||
const ShaderListing::const_iterator shader_infos_iterator =
|
||||
get_shader_infos().find(shader_name);
|
||||
|
||||
VK_CHECK_COND(
|
||||
shader_infos_iterator != get_shader_infos().end(),
|
||||
"Could not get ShaderInfo named ",
|
||||
shader_name);
|
||||
|
||||
return shader_infos_iterator->second;
|
||||
}
|
||||
|
||||
const api::ShaderInfo& look_up_shader_info(const std::string& op_name) {
|
||||
const ShaderRegistry::iterator registry_iterator =
|
||||
get_shader_registry().find(op_name);
|
||||
|
||||
VK_CHECK_COND(
|
||||
registry_iterator != get_shader_registry().end(),
|
||||
"Could not look up ShaderInfo for ",
|
||||
op_name,
|
||||
" in shader registry");
|
||||
|
||||
const RegistryKeyMap& registry_key_map = registry_iterator->second;
|
||||
|
||||
// Look for "override" and "catchall" keys
|
||||
for (const std::string key : {"override", "catchall"}) {
|
||||
const RegistryKeyMap::const_iterator registry_key_iterator =
|
||||
registry_key_map.find(key);
|
||||
if (registry_key_iterator != registry_key_map.end()) {
|
||||
const ShaderListing::const_iterator shader_infos_iterator =
|
||||
get_shader_infos().find(registry_key_iterator->second);
|
||||
|
||||
VK_CHECK_COND(
|
||||
shader_infos_iterator != get_shader_infos().end(),
|
||||
"Could not get ShaderInfo named ",
|
||||
registry_key_iterator->second,
|
||||
" (listed under ",
|
||||
op_name,
|
||||
" -> ",
|
||||
key,
|
||||
" in shader registry)");
|
||||
|
||||
return shader_infos_iterator->second;
|
||||
}
|
||||
}
|
||||
|
||||
VK_CHECK_COND(
|
||||
false,
|
||||
"Could not look up ShaderInfo for ",
|
||||
op_name,
|
||||
" with a valid key in shader registry");
|
||||
}
|
||||
|
||||
void set_registry_override(
|
||||
const std::string& op_name,
|
||||
const std::string& shader_name) {
|
||||
const ShaderRegistry::iterator registry_iterator =
|
||||
get_shader_registry().find(op_name);
|
||||
|
||||
VK_CHECK_COND(
|
||||
registry_iterator != get_shader_registry().end(),
|
||||
"Could not look up ShaderInfo for ",
|
||||
op_name,
|
||||
" in shader registry");
|
||||
|
||||
VK_CHECK_COND(
|
||||
get_shader_infos().find(shader_name) != get_shader_infos().end(),
|
||||
"Could not get ShaderInfo named ",
|
||||
shader_name);
|
||||
|
||||
registry_iterator->second["override"] = shader_name;
|
||||
}
|
||||
|
||||
} // namespace vulkan
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif // USE_VULKAN_API
|
@ -1,33 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef USE_VULKAN_API
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace vulkan {
|
||||
namespace api {
|
||||
// Forward declaration of ShaderInfo
|
||||
struct ShaderInfo;
|
||||
} // namespace api
|
||||
|
||||
/**
|
||||
* Get the shader with a given name
|
||||
*/
|
||||
const api::ShaderInfo& get_shader_info(const std::string& shader_name);
|
||||
|
||||
/**
|
||||
* Look up which shader to use for a given op in the shader registry
|
||||
*/
|
||||
const api::ShaderInfo& look_up_shader_info(const std::string& op_name);
|
||||
|
||||
void set_registry_override(
|
||||
const std::string& op_name,
|
||||
const std::string& shader_name);
|
||||
|
||||
} // namespace vulkan
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif // USE_VULKAN_API
|
@ -10,11 +10,6 @@
|
||||
#include <ATen/native/vulkan/impl/Common.h>
|
||||
#include <ATen/native/vulkan/ops/Convert.h>
|
||||
|
||||
#define VK_KERNEL(shader_name) \
|
||||
::at::native::vulkan::get_shader_info(#shader_name)
|
||||
#define VK_LOOKUP_KERNEL(op_name) \
|
||||
::at::native::vulkan::look_up_shader_info(#op_name)
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace vulkan {
|
||||
|
@ -319,7 +319,7 @@ static api::ShaderInfo get_shader(
|
||||
|
||||
switch (method) {
|
||||
case Conv2dSlidingWindow:
|
||||
shader = VK_LOOKUP_KERNEL(conv2d);
|
||||
shader = VK_KERNEL(conv2d);
|
||||
break;
|
||||
case Conv2dDepthwise:
|
||||
shader = VK_KERNEL(conv2d_dw);
|
||||
@ -335,7 +335,7 @@ static api::ShaderInfo get_shader(
|
||||
}
|
||||
break;
|
||||
case Conv2dPointwise:
|
||||
shader = VK_LOOKUP_KERNEL(conv2d_pw);
|
||||
shader = VK_KERNEL(conv2d_pw_output_tile_2x2);
|
||||
break;
|
||||
}
|
||||
return shader;
|
||||
|
@ -144,6 +144,9 @@ def get_glsl_paths():
|
||||
],
|
||||
)
|
||||
|
||||
def spv_shader_library():
|
||||
pass
|
||||
|
||||
# @lint-ignore BUCKRESTRICTEDSYNTAX
|
||||
IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build
|
||||
|
||||
@ -700,6 +703,43 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple
|
||||
apple_sdks = apple_sdks,
|
||||
)
|
||||
|
||||
def vulkan_spv_shader_library(name, spv_filegroup):
|
||||
genrule_cmd = [
|
||||
"$(exe //xplat/caffe2/tools:gen_aten_vulkan_spv_bin)",
|
||||
"--glsl-paths $(location {})".format(spv_filegroup),
|
||||
"--output-path $OUT --env FLOAT_IMAGE_FORMAT={}".format(get_glsl_image_format()),
|
||||
"--glslc-path=$(exe //xplat/caffe2/fb/vulkan/dotslash:glslc)",
|
||||
"--tmp-dir-path=$TMP",
|
||||
]
|
||||
|
||||
genrule_name = "gen_{}_cpp".format(name)
|
||||
fb_xplat_genrule(
|
||||
name = "gen_{}_cpp".format(name),
|
||||
outs = {
|
||||
"{}.cpp".format(name): ["spv.cpp"],
|
||||
},
|
||||
cmd = " ".join(genrule_cmd),
|
||||
default_outs = ["."],
|
||||
labels = ["uses_dotslash"],
|
||||
)
|
||||
|
||||
fb_xplat_cxx_library(
|
||||
name = name,
|
||||
srcs = [
|
||||
":{}[{}.cpp]".format(genrule_name, name),
|
||||
],
|
||||
# Static initialization is used to register shaders to the global shader registry,
|
||||
# therefore link_whole must be True to make sure unused symbols are not discarded.
|
||||
# @lint-ignore BUCKLINT: Avoid `link_whole=True`
|
||||
link_whole = True,
|
||||
# Define a soname that can be used for dynamic loading in Java, Python, etc.
|
||||
soname = "lib{}.$(ext)".format(name),
|
||||
visibility = ["PUBLIC"],
|
||||
exported_deps = [
|
||||
"//xplat/caffe2:torch_vulkan_api",
|
||||
],
|
||||
)
|
||||
|
||||
def copy_metal(name, apple_sdks = None):
|
||||
cmd = []
|
||||
cmd_exe = []
|
||||
|
@ -513,115 +513,127 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
|
||||
# C++ File Generation #
|
||||
#########################
|
||||
|
||||
cpp_template = """
|
||||
#include <ATen/native/vulkan/api/ShaderRegistry.h>
|
||||
#include <stdint.h>
|
||||
#include <vector>
|
||||
|
||||
def gen_cpp_files(
|
||||
spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
|
||||
) -> None:
|
||||
h = "#pragma once\n"
|
||||
h += "#include <ATen/native/vulkan/api/Types.h>\n"
|
||||
h += "#include <ATen/native/vulkan/api/vk_api.h>\n"
|
||||
h += "#include <string>\n"
|
||||
h += "#include <unordered_map>\n"
|
||||
using namespace at::native::vulkan;
|
||||
|
||||
nsbegin = "namespace at {\nnamespace native {\nnamespace vulkan {\n"
|
||||
nsend = "} // namespace vulkan\n} // namespace native\n} // namespace at\n"
|
||||
namespace at {{
|
||||
namespace native {{
|
||||
namespace vulkan {{
|
||||
|
||||
anon_ns_begin = "namespace {\n"
|
||||
anon_ns_end = "} // namespace\n"
|
||||
namespace {{
|
||||
|
||||
h += nsbegin
|
||||
{spv_bin_arrays}
|
||||
|
||||
# Forward declaration of ShaderInfo
|
||||
h += "namespace api {\nstruct ShaderInfo;\n} // namespace api\n"
|
||||
h += "typedef std::unordered_map<std::string, api::ShaderInfo> ShaderListing;\n"
|
||||
h += "typedef std::unordered_map<std::string, std::string> RegistryKeyMap;\n"
|
||||
h += "typedef std::unordered_map<std::string, RegistryKeyMap> ShaderRegistry;\n"
|
||||
h += "extern const ShaderListing shader_infos;\n"
|
||||
h += "extern ShaderRegistry shader_registry;\n"
|
||||
h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n"
|
||||
h += (
|
||||
"inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
|
||||
}}
|
||||
|
||||
static void register_fn() {{
|
||||
|
||||
{register_shader_infos}
|
||||
|
||||
{shader_info_registry}
|
||||
|
||||
}}
|
||||
|
||||
static const api::ShaderRegisterInit register_shaders(®ister_fn);
|
||||
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def generateSpvBinStr(spvPath: str, name: str) -> Tuple[int, str]:
|
||||
with open(spvPath, "rb") as fr:
|
||||
next_bin = array.array("I", fr.read())
|
||||
sizeBytes = 4 * len(next_bin)
|
||||
spv_bin_str = "const uint32_t {}_bin[] = {{\n{}\n}};".format(
|
||||
name,
|
||||
textwrap.indent(",\n".join(str(x) for x in next_bin), " "),
|
||||
)
|
||||
|
||||
return sizeBytes, spv_bin_str
|
||||
|
||||
|
||||
def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> str:
|
||||
tile_size = (
|
||||
f"{{{', '.join(str(x) for x in shader_info.tile_size)}}}"
|
||||
if (len(shader_info.tile_size) > 0)
|
||||
else "std::vector<uint32_t>()"
|
||||
)
|
||||
|
||||
h += nsend
|
||||
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
|
||||
|
||||
cpp = "#include <ATen/native/vulkan/api/Shader.h>\n"
|
||||
cpp += f"#include <ATen/native/vulkan/{CPP_H_NAME}>\n"
|
||||
cpp += "#include <stdint.h>\n"
|
||||
cpp += "#include <vector>\n"
|
||||
cpp += nsbegin
|
||||
shader_info_args = [
|
||||
f'"{name}"',
|
||||
f"{name}_bin",
|
||||
str(sizeBytes),
|
||||
shader_info_layouts,
|
||||
tile_size,
|
||||
storageTypeToEnum[shader_info.weight_storage_type],
|
||||
storageTypeToEnum[shader_info.bias_storage_type],
|
||||
]
|
||||
|
||||
shader_info_bin_code = []
|
||||
shader_info_cpp_code = []
|
||||
shader_info_registry_code = []
|
||||
shader_info_str = textwrap.indent(
|
||||
"api::shader_registry().register_shader(\n api::ShaderInfo(\n{args}));\n".format(
|
||||
args=textwrap.indent(",\n".join(shader_info_args), " "),
|
||||
),
|
||||
" ",
|
||||
)
|
||||
|
||||
return shader_info_str
|
||||
|
||||
|
||||
def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
|
||||
if shader_info.register_for is None:
|
||||
return ""
|
||||
|
||||
(op_name, registry_keys) = shader_info.register_for
|
||||
for registry_key in registry_keys:
|
||||
shader_dispatch_str = textwrap.indent(
|
||||
f'api::shader_registry().register_op_dispatch("{op_name}", api::DispatchKey::{registry_key.upper()}, "{name}");',
|
||||
" ",
|
||||
)
|
||||
|
||||
return shader_dispatch_str
|
||||
|
||||
|
||||
def genCppFiles(
|
||||
spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
|
||||
) -> None:
|
||||
spv_bin_strs = []
|
||||
register_shader_info_strs = []
|
||||
shader_registry_strs = []
|
||||
|
||||
for spvPath, srcPath in spv_files.items():
|
||||
name = getName(spvPath).replace("_spv", "")
|
||||
|
||||
with open(spvPath, "rb") as fr:
|
||||
next_bin = array.array("I", fr.read())
|
||||
sizeBytes = 4 * len(next_bin)
|
||||
shader_info_bin_code.append(
|
||||
"const uint32_t {}_bin[] = {{\n{}\n}};".format(
|
||||
name,
|
||||
textwrap.indent(",\n".join(str(x) for x in next_bin), " "),
|
||||
),
|
||||
)
|
||||
sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
|
||||
spv_bin_strs.append(spv_bin_str)
|
||||
|
||||
shader_info = getShaderInfo(srcPath)
|
||||
|
||||
tile_size = (
|
||||
f"{{{', '.join(str(x) for x in shader_info.tile_size)}}}"
|
||||
if (len(shader_info.tile_size) > 0)
|
||||
else "std::vector<uint32_t>()"
|
||||
)
|
||||
|
||||
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
|
||||
|
||||
shader_info_args = [
|
||||
f'"vulkan.{name}"',
|
||||
f"{name}_bin",
|
||||
str(sizeBytes),
|
||||
shader_info_layouts,
|
||||
tile_size,
|
||||
storageTypeToEnum[shader_info.weight_storage_type],
|
||||
storageTypeToEnum[shader_info.bias_storage_type],
|
||||
]
|
||||
|
||||
shader_info_cpp_code.append(
|
||||
textwrap.indent(
|
||||
'{{"{}",\n api::ShaderInfo(\n{})}}'.format(
|
||||
name,
|
||||
textwrap.indent(",\n".join(shader_info_args), " "),
|
||||
),
|
||||
" ",
|
||||
),
|
||||
register_shader_info_strs.append(
|
||||
generateShaderInfoStr(shader_info, name, sizeBytes)
|
||||
)
|
||||
|
||||
if shader_info.register_for is not None:
|
||||
(op_name, registry_keys) = shader_info.register_for
|
||||
for registry_key in registry_keys:
|
||||
shader_info_registry_code.append(
|
||||
textwrap.indent(
|
||||
f'{{"{op_name}", {{{{"{registry_key}", "{name}"}}}}}}',
|
||||
" ",
|
||||
),
|
||||
)
|
||||
shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))
|
||||
|
||||
cpp += anon_ns_begin
|
||||
cpp += "\n".join(shader_info_bin_code) + "\n"
|
||||
cpp += anon_ns_end
|
||||
spv_bin_arrays = "\n".join(spv_bin_strs)
|
||||
register_shader_infos = "\n".join(register_shader_info_strs)
|
||||
shader_info_registry = "\n".join(shader_registry_strs)
|
||||
|
||||
cpp += "const ShaderListing shader_infos = {{\n{}}};\n".format(
|
||||
",\n".join(shader_info_cpp_code),
|
||||
cpp = cpp_template.format(
|
||||
spv_bin_arrays=spv_bin_arrays,
|
||||
register_shader_infos=register_shader_infos,
|
||||
shader_info_registry=shader_info_registry,
|
||||
)
|
||||
cpp += "ShaderRegistry shader_registry = {{\n{}}};\n".format(
|
||||
",\n".join(shader_info_registry_code),
|
||||
)
|
||||
cpp += nsend
|
||||
|
||||
with open(cpp_header_path, "w") as fw:
|
||||
fw.write(h)
|
||||
with open(cpp_src_file_path, "w") as fw:
|
||||
fw.write(cpp)
|
||||
|
||||
@ -672,7 +684,7 @@ def main(argv: List[str]) -> int:
|
||||
shader_generator = SPVGenerator(options.glsl_paths, env, options.glslc_path)
|
||||
output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)
|
||||
|
||||
gen_cpp_files(
|
||||
genCppFiles(
|
||||
output_spv_files,
|
||||
f"{options.output_path}/{CPP_H_NAME}",
|
||||
f"{options.output_path}/{CPP_SRC_NAME}",
|
||||
|
Reference in New Issue
Block a user