[ATen-vulkan] Implement global shader registry (#121088)

Differential Revision: D54447700

## Context

This changeset updates Vulkan SPIR-V codegen to introduce a global SPIR-V shader registry and register shaders dynamically at static initialization time. This change makes it possible to define and link custom shader libraries to the ATen-Vulkan runtime.

Before:

* `gen_vulkan_spv.py` generated two files, `spv.h` and `spv.cpp` which would contain the definition and initialization of Vulkan shader registry variables.

After:

* Introduce the `ShaderRegistry` class in `api/`, which encapsulates functionality of the `ShaderRegistry` class previously defined in the generated `spv.h` file
* Introduce a global shader registry (defined as a static variable in the `api::shader_registry() function`
* Define a `ShaderRegisterInit` class (taking inspiration from `TorchLibraryInit`) that allows for dynamic shader registration
* `gen_vulkan_spv.py` now only generates `spv.cpp`, which defines a static `ShaderRegisterInit` instance that triggers registration of the compiled shaders to the global shader registry.

Benefits:

* Cleaner code base; we no longer have `ShaderRegistry` defined in a generated file, and don't need a separate implementation file (`impl/Registry.*`) to handle shader lookup. All that logic now lives under `api/ShaderRegistry.*`
* Makes it possible to compile and link separate shader libraries, providing similar flexibility as defining and linking custom ATen operators

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121088
Approved by: https://github.com/manuelcandales, https://github.com/jorgep31415
This commit is contained in:
Stephen Jia
2024-03-05 03:56:57 +00:00
committed by PyTorch MergeBot
parent c3c618c750
commit ffe45a8188
10 changed files with 287 additions and 221 deletions

View File

@ -0,0 +1,61 @@
#include <ATen/native/vulkan/api/ShaderRegistry.h>
namespace at {
namespace native {
namespace vulkan {
namespace api {
bool ShaderRegistry::has_shader(const std::string& shader_name) {
const ShaderListing::const_iterator it = listings_.find(shader_name);
return it != listings_.end();
}
bool ShaderRegistry::has_dispatch(const std::string& op_name) {
const Registry::const_iterator it = registry_.find(op_name);
return it != registry_.end();
}
void ShaderRegistry::register_shader(ShaderInfo&& shader_info) {
if (has_shader(shader_info.kernel_name)) {
VK_THROW(
"Shader with name ", shader_info.kernel_name, "already registered");
}
listings_.emplace(shader_info.kernel_name, shader_info);
}
void ShaderRegistry::register_op_dispatch(
const std::string& op_name,
const DispatchKey key,
const std::string& shader_name) {
if (!has_dispatch(op_name)) {
registry_.emplace(op_name, Dispatcher());
}
const Dispatcher::const_iterator it = registry_[op_name].find(key);
if (it != registry_[op_name].end()) {
registry_[op_name][key] = shader_name;
} else {
registry_[op_name].emplace(key, shader_name);
}
}
const ShaderInfo& ShaderRegistry::get_shader_info(
const std::string& shader_name) {
const ShaderListing::const_iterator it = listings_.find(shader_name);
VK_CHECK_COND(
it != listings_.end(),
"Could not find ShaderInfo with name ",
shader_name);
return it->second;
}
ShaderRegistry& shader_registry() {
static ShaderRegistry registry;
return registry;
}
} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

View File

@ -0,0 +1,84 @@
#pragma once
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
#ifdef USE_VULKAN_API
#include <ATen/native/vulkan/api/Shader.h>
#include <string>
#include <unordered_map>
#define VK_KERNEL(shader_name) \
::at::native::vulkan::api::shader_registry().get_shader_info(#shader_name)
namespace at {
namespace native {
namespace vulkan {
namespace api {
enum class DispatchKey : int8_t {
CATCHALL,
ADRENO,
MALI,
OVERRIDE,
};
class ShaderRegistry final {
using ShaderListing = std::unordered_map<std::string, ShaderInfo>;
using Dispatcher = std::unordered_map<DispatchKey, std::string>;
using Registry = std::unordered_map<std::string, Dispatcher>;
ShaderListing listings_;
Dispatcher dispatcher_;
Registry registry_;
public:
/*
* Check if the registry has a shader registered under the given name
*/
bool has_shader(const std::string& shader_name);
/*
* Check if the registry has a dispatch registered under the given name
*/
bool has_dispatch(const std::string& op_name);
/*
* Register a ShaderInfo to a given shader name
*/
void register_shader(ShaderInfo&& shader_info);
/*
* Register a dispatch entry to the given op name
*/
void register_op_dispatch(
const std::string& op_name,
const DispatchKey key,
const std::string& shader_name);
/*
* Given a shader name, return the ShaderInfo which contains the SPIRV binary
*/
const ShaderInfo& get_shader_info(const std::string& shader_name);
};
class ShaderRegisterInit final {
using InitFn = void();
public:
ShaderRegisterInit(InitFn* init_fn) {
init_fn();
};
};
// The global shader registry is retrieved using this function, where it is
// declared as a static local variable.
ShaderRegistry& shader_registry();
} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at
#endif /* USE_VULKAN_API */

View File

@ -10,6 +10,7 @@
#include <ATen/native/vulkan/api/Resource.h>
#include <ATen/native/vulkan/api/Runtime.h>
#include <ATen/native/vulkan/api/Shader.h>
#include <ATen/native/vulkan/api/ShaderRegistry.h>
#include <ATen/native/vulkan/api/Tensor.h>
#include <ATen/native/vulkan/api/Utils.h>

View File

@ -3,12 +3,6 @@
#ifdef USE_VULKAN_API
#include <ATen/native/vulkan/api/api.h>
#include <ATen/native/vulkan/impl/Registry.h>
#define VK_KERNEL(shader_name) \
::at::native::vulkan::get_shader_info(#shader_name)
#define VK_LOOKUP_KERNEL(op_name) \
::at::native::vulkan::look_up_shader_info(#op_name)
namespace at {
namespace native {

View File

@ -1,88 +0,0 @@
#ifdef USE_VULKAN_API
#include <ATen/native/vulkan/api/Shader.h>
#include <ATen/native/vulkan/impl/Registry.h>
#include <ATen/native/vulkan/spv.h>
namespace at {
namespace native {
namespace vulkan {
const api::ShaderInfo& get_shader_info(const std::string& shader_name) {
const ShaderListing::const_iterator shader_infos_iterator =
get_shader_infos().find(shader_name);
VK_CHECK_COND(
shader_infos_iterator != get_shader_infos().end(),
"Could not get ShaderInfo named ",
shader_name);
return shader_infos_iterator->second;
}
const api::ShaderInfo& look_up_shader_info(const std::string& op_name) {
const ShaderRegistry::iterator registry_iterator =
get_shader_registry().find(op_name);
VK_CHECK_COND(
registry_iterator != get_shader_registry().end(),
"Could not look up ShaderInfo for ",
op_name,
" in shader registry");
const RegistryKeyMap& registry_key_map = registry_iterator->second;
// Look for "override" and "catchall" keys
for (const std::string key : {"override", "catchall"}) {
const RegistryKeyMap::const_iterator registry_key_iterator =
registry_key_map.find(key);
if (registry_key_iterator != registry_key_map.end()) {
const ShaderListing::const_iterator shader_infos_iterator =
get_shader_infos().find(registry_key_iterator->second);
VK_CHECK_COND(
shader_infos_iterator != get_shader_infos().end(),
"Could not get ShaderInfo named ",
registry_key_iterator->second,
" (listed under ",
op_name,
" -> ",
key,
" in shader registry)");
return shader_infos_iterator->second;
}
}
VK_CHECK_COND(
false,
"Could not look up ShaderInfo for ",
op_name,
" with a valid key in shader registry");
}
void set_registry_override(
const std::string& op_name,
const std::string& shader_name) {
const ShaderRegistry::iterator registry_iterator =
get_shader_registry().find(op_name);
VK_CHECK_COND(
registry_iterator != get_shader_registry().end(),
"Could not look up ShaderInfo for ",
op_name,
" in shader registry");
VK_CHECK_COND(
get_shader_infos().find(shader_name) != get_shader_infos().end(),
"Could not get ShaderInfo named ",
shader_name);
registry_iterator->second["override"] = shader_name;
}
} // namespace vulkan
} // namespace native
} // namespace at
#endif // USE_VULKAN_API

View File

@ -1,33 +0,0 @@
#pragma once
#ifdef USE_VULKAN_API
#include <string>
namespace at {
namespace native {
namespace vulkan {
namespace api {
// Forward declaration of ShaderInfo
struct ShaderInfo;
} // namespace api
/**
* Get the shader with a given name
*/
const api::ShaderInfo& get_shader_info(const std::string& shader_name);
/**
* Look up which shader to use for a given op in the shader registry
*/
const api::ShaderInfo& look_up_shader_info(const std::string& op_name);
void set_registry_override(
const std::string& op_name,
const std::string& shader_name);
} // namespace vulkan
} // namespace native
} // namespace at
#endif // USE_VULKAN_API

View File

@ -10,11 +10,6 @@
#include <ATen/native/vulkan/impl/Common.h>
#include <ATen/native/vulkan/ops/Convert.h>
#define VK_KERNEL(shader_name) \
::at::native::vulkan::get_shader_info(#shader_name)
#define VK_LOOKUP_KERNEL(op_name) \
::at::native::vulkan::look_up_shader_info(#op_name)
namespace at {
namespace native {
namespace vulkan {

View File

@ -319,7 +319,7 @@ static api::ShaderInfo get_shader(
switch (method) {
case Conv2dSlidingWindow:
shader = VK_LOOKUP_KERNEL(conv2d);
shader = VK_KERNEL(conv2d);
break;
case Conv2dDepthwise:
shader = VK_KERNEL(conv2d_dw);
@ -335,7 +335,7 @@ static api::ShaderInfo get_shader(
}
break;
case Conv2dPointwise:
shader = VK_LOOKUP_KERNEL(conv2d_pw);
shader = VK_KERNEL(conv2d_pw_output_tile_2x2);
break;
}
return shader;

View File

@ -144,6 +144,9 @@ def get_glsl_paths():
],
)
def spv_shader_library():
pass
# @lint-ignore BUCKRESTRICTEDSYNTAX
IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build
@ -700,6 +703,43 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple
apple_sdks = apple_sdks,
)
def vulkan_spv_shader_library(name, spv_filegroup):
genrule_cmd = [
"$(exe //xplat/caffe2/tools:gen_aten_vulkan_spv_bin)",
"--glsl-paths $(location {})".format(spv_filegroup),
"--output-path $OUT --env FLOAT_IMAGE_FORMAT={}".format(get_glsl_image_format()),
"--glslc-path=$(exe //xplat/caffe2/fb/vulkan/dotslash:glslc)",
"--tmp-dir-path=$TMP",
]
genrule_name = "gen_{}_cpp".format(name)
fb_xplat_genrule(
name = "gen_{}_cpp".format(name),
outs = {
"{}.cpp".format(name): ["spv.cpp"],
},
cmd = " ".join(genrule_cmd),
default_outs = ["."],
labels = ["uses_dotslash"],
)
fb_xplat_cxx_library(
name = name,
srcs = [
":{}[{}.cpp]".format(genrule_name, name),
],
# Static initialization is used to register shaders to the global shader registry,
# therefore link_whole must be True to make sure unused symbols are not discarded.
# @lint-ignore BUCKLINT: Avoid `link_whole=True`
link_whole = True,
# Define a soname that can be used for dynamic loading in Java, Python, etc.
soname = "lib{}.$(ext)".format(name),
visibility = ["PUBLIC"],
exported_deps = [
"//xplat/caffe2:torch_vulkan_api",
],
)
def copy_metal(name, apple_sdks = None):
cmd = []
cmd_exe = []

View File

@ -513,115 +513,127 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
# C++ File Generation #
#########################
cpp_template = """
#include <ATen/native/vulkan/api/ShaderRegistry.h>
#include <stdint.h>
#include <vector>
def gen_cpp_files(
spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None:
h = "#pragma once\n"
h += "#include <ATen/native/vulkan/api/Types.h>\n"
h += "#include <ATen/native/vulkan/api/vk_api.h>\n"
h += "#include <string>\n"
h += "#include <unordered_map>\n"
using namespace at::native::vulkan;
nsbegin = "namespace at {\nnamespace native {\nnamespace vulkan {\n"
nsend = "} // namespace vulkan\n} // namespace native\n} // namespace at\n"
namespace at {{
namespace native {{
namespace vulkan {{
anon_ns_begin = "namespace {\n"
anon_ns_end = "} // namespace\n"
namespace {{
h += nsbegin
{spv_bin_arrays}
# Forward declaration of ShaderInfo
h += "namespace api {\nstruct ShaderInfo;\n} // namespace api\n"
h += "typedef std::unordered_map<std::string, api::ShaderInfo> ShaderListing;\n"
h += "typedef std::unordered_map<std::string, std::string> RegistryKeyMap;\n"
h += "typedef std::unordered_map<std::string, RegistryKeyMap> ShaderRegistry;\n"
h += "extern const ShaderListing shader_infos;\n"
h += "extern ShaderRegistry shader_registry;\n"
h += "inline const ShaderListing& get_shader_infos() {\n return shader_infos;\n}\n"
h += (
"inline ShaderRegistry& get_shader_registry() {\n return shader_registry;\n}\n"
}}
static void register_fn() {{
{register_shader_infos}
{shader_info_registry}
}}
static const api::ShaderRegisterInit register_shaders(&register_fn);
}}
}}
}}
"""
def generateSpvBinStr(spvPath: str, name: str) -> Tuple[int, str]:
with open(spvPath, "rb") as fr:
next_bin = array.array("I", fr.read())
sizeBytes = 4 * len(next_bin)
spv_bin_str = "const uint32_t {}_bin[] = {{\n{}\n}};".format(
name,
textwrap.indent(",\n".join(str(x) for x in next_bin), " "),
)
return sizeBytes, spv_bin_str
def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> str:
tile_size = (
f"{{{', '.join(str(x) for x in shader_info.tile_size)}}}"
if (len(shader_info.tile_size) > 0)
else "std::vector<uint32_t>()"
)
h += nsend
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
cpp = "#include <ATen/native/vulkan/api/Shader.h>\n"
cpp += f"#include <ATen/native/vulkan/{CPP_H_NAME}>\n"
cpp += "#include <stdint.h>\n"
cpp += "#include <vector>\n"
cpp += nsbegin
shader_info_args = [
f'"{name}"',
f"{name}_bin",
str(sizeBytes),
shader_info_layouts,
tile_size,
storageTypeToEnum[shader_info.weight_storage_type],
storageTypeToEnum[shader_info.bias_storage_type],
]
shader_info_bin_code = []
shader_info_cpp_code = []
shader_info_registry_code = []
shader_info_str = textwrap.indent(
"api::shader_registry().register_shader(\n api::ShaderInfo(\n{args}));\n".format(
args=textwrap.indent(",\n".join(shader_info_args), " "),
),
" ",
)
return shader_info_str
def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
if shader_info.register_for is None:
return ""
(op_name, registry_keys) = shader_info.register_for
for registry_key in registry_keys:
shader_dispatch_str = textwrap.indent(
f'api::shader_registry().register_op_dispatch("{op_name}", api::DispatchKey::{registry_key.upper()}, "{name}");',
" ",
)
return shader_dispatch_str
def genCppFiles(
spv_files: Dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None:
spv_bin_strs = []
register_shader_info_strs = []
shader_registry_strs = []
for spvPath, srcPath in spv_files.items():
name = getName(spvPath).replace("_spv", "")
with open(spvPath, "rb") as fr:
next_bin = array.array("I", fr.read())
sizeBytes = 4 * len(next_bin)
shader_info_bin_code.append(
"const uint32_t {}_bin[] = {{\n{}\n}};".format(
name,
textwrap.indent(",\n".join(str(x) for x in next_bin), " "),
),
)
sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
spv_bin_strs.append(spv_bin_str)
shader_info = getShaderInfo(srcPath)
tile_size = (
f"{{{', '.join(str(x) for x in shader_info.tile_size)}}}"
if (len(shader_info.tile_size) > 0)
else "std::vector<uint32_t>()"
)
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
shader_info_args = [
f'"vulkan.{name}"',
f"{name}_bin",
str(sizeBytes),
shader_info_layouts,
tile_size,
storageTypeToEnum[shader_info.weight_storage_type],
storageTypeToEnum[shader_info.bias_storage_type],
]
shader_info_cpp_code.append(
textwrap.indent(
'{{"{}",\n api::ShaderInfo(\n{})}}'.format(
name,
textwrap.indent(",\n".join(shader_info_args), " "),
),
" ",
),
register_shader_info_strs.append(
generateShaderInfoStr(shader_info, name, sizeBytes)
)
if shader_info.register_for is not None:
(op_name, registry_keys) = shader_info.register_for
for registry_key in registry_keys:
shader_info_registry_code.append(
textwrap.indent(
f'{{"{op_name}", {{{{"{registry_key}", "{name}"}}}}}}',
" ",
),
)
shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))
cpp += anon_ns_begin
cpp += "\n".join(shader_info_bin_code) + "\n"
cpp += anon_ns_end
spv_bin_arrays = "\n".join(spv_bin_strs)
register_shader_infos = "\n".join(register_shader_info_strs)
shader_info_registry = "\n".join(shader_registry_strs)
cpp += "const ShaderListing shader_infos = {{\n{}}};\n".format(
",\n".join(shader_info_cpp_code),
cpp = cpp_template.format(
spv_bin_arrays=spv_bin_arrays,
register_shader_infos=register_shader_infos,
shader_info_registry=shader_info_registry,
)
cpp += "ShaderRegistry shader_registry = {{\n{}}};\n".format(
",\n".join(shader_info_registry_code),
)
cpp += nsend
with open(cpp_header_path, "w") as fw:
fw.write(h)
with open(cpp_src_file_path, "w") as fw:
fw.write(cpp)
@ -672,7 +684,7 @@ def main(argv: List[str]) -> int:
shader_generator = SPVGenerator(options.glsl_paths, env, options.glslc_path)
output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)
gen_cpp_files(
genCppFiles(
output_spv_files,
f"{options.output_path}/{CPP_H_NAME}",
f"{options.output_path}/{CPP_SRC_NAME}",