mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Vulkan] Generate ShaderInfos Directly via Codegen in gen_vulkan_spv (#91911)
@bypass-github-export-checks Before this change, we have the data members which make up a ```ShaderInfo``` sitting in ```spv.h/.cpp``` in an unorganized manner. This diff makes the change such that the ```ShaderInfo```s are initialized directly in spv.h/.cpp Now spv.h looks like ``` #pragma once #include <stdint.h> #include <vector> #include <string> #include <ATen/native/vulkan/api/Types.h> #include <ATen/native/vulkan/api/vk_api.h> namespace at { namespace native { namespace vulkan { namespace api { struct ShaderInfo; } // namespace api extern const api::ShaderInfo adaptive_avg_pool2d_spv; ... extern const api::ShaderInfo conv2d_pw_2x2_spv; } // namespace vulkan } // namespace native } // namespace at ``` (Full File: P557399150) and spv.cpp looks like ``` #include <ATen/native/vulkan/spv.h> #include <ATen/native/vulkan/api/Shader.h> namespace at { namespace native { namespace vulkan { namespace { const uint32_t adaptive_avg_pool2d_spv_bin[] = { 119734787, ... }; ... const uint32_t conv2d_pw_2x2_spv_bin[] = { 119734787, ... }; } // namespace const api::ShaderInfo adaptive_avg_pool2d_spv( "vulkan.adaptive_avg_pool2d", adaptive_avg_pool2d_spv_bin, 3204, {VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER}, std::vector<uint32_t>(), api::StorageType::UNKNOWN, api::StorageType::UNKNOWN ); ... const api::ShaderInfo conv2d_pw_2x2_spv( "vulkan.conv2d_pw_2x2", conv2d_pw_2x2_spv_bin, 7736, {VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER}, {2, 2, 1}, api::StorageType::TEXTURE_2D, api::StorageType::TEXTURE_2D ); } // namespace vulkan } // namespace native } // namespace at ``` (Full File: P584237146) Differential Revision: [D41354313](https://our.internmc.facebook.com/intern/diff/D41354313/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/91911 Approved by: https://github.com/mcr229
This commit is contained in:
committed by
PyTorch MergeBot
parent
776fef9ecc
commit
28eb3c8faf
@ -16,17 +16,7 @@
|
||||
}
|
||||
#else
|
||||
#include <ATen/native/vulkan/spv.h>
|
||||
#define VK_KERNEL(name) \
|
||||
::at::native::vulkan::api::ShaderInfo { \
|
||||
CONCAT_LITERALS(vulkan., name), name##_spv, name##_spv_len, \
|
||||
name##_spv_layout \
|
||||
}
|
||||
#define VK_SHADER(name) \
|
||||
::at::native::vulkan::api::ShaderInfo { \
|
||||
CONCAT_LITERALS(vulkan., name), name##_spv, name##_spv_len, \
|
||||
name##_spv_layout, name##_spv_tile_size, name##_spv_bias_storage_type, \
|
||||
name##_spv_weight_storage_type, \
|
||||
}
|
||||
#define VK_KERNEL(name) ::at::native::vulkan::name##_spv
|
||||
#endif /* USE_VULKAN_SHADERC_RUNTIME */
|
||||
|
||||
namespace at {
|
||||
|
@ -296,13 +296,13 @@ static api::ShaderInfo get_shader(
|
||||
|
||||
switch (method) {
|
||||
case Conv2dSlidingWindow:
|
||||
shader = VK_SHADER(quantized_conv2d);
|
||||
shader = VK_KERNEL(quantized_conv2d);
|
||||
break;
|
||||
case Conv2dDepthwise:
|
||||
shader = VK_SHADER(quantized_conv2d_dw);
|
||||
shader = VK_KERNEL(quantized_conv2d_dw);
|
||||
break;
|
||||
case Conv2dPointwise:
|
||||
shader = VK_SHADER(quantized_conv2d_pw_2x2);
|
||||
shader = VK_KERNEL(quantized_conv2d_pw_2x2);
|
||||
break;
|
||||
// todo fail for quantized transposed conv
|
||||
}
|
||||
@ -310,29 +310,29 @@ static api::ShaderInfo get_shader(
|
||||
}
|
||||
|
||||
if (transposed) {
|
||||
shader = VK_SHADER(conv_transpose2d);
|
||||
shader = VK_KERNEL(conv_transpose2d);
|
||||
return shader;
|
||||
}
|
||||
|
||||
switch (method) {
|
||||
case Conv2dSlidingWindow:
|
||||
shader = VK_SHADER(conv2d);
|
||||
shader = VK_KERNEL(conv2d);
|
||||
break;
|
||||
case Conv2dDepthwise:
|
||||
shader = VK_SHADER(conv2d_dw);
|
||||
shader = VK_KERNEL(conv2d_dw);
|
||||
if (kernel_size.size() == 4 && kernel_size[2] == 3 &&
|
||||
kernel_size[3] == 3) {
|
||||
// 1x1 refers to the output tile size
|
||||
shader = VK_SHADER(conv2d_dw_3x3);
|
||||
shader = VK_KERNEL(conv2d_dw_3x3);
|
||||
}
|
||||
if (kernel_size.size() == 4 && kernel_size[2] == 5 &&
|
||||
kernel_size[3] == 5) {
|
||||
// 1x1 refers to the output tile size
|
||||
shader = VK_SHADER(conv2d_dw_5x5);
|
||||
shader = VK_KERNEL(conv2d_dw_5x5);
|
||||
}
|
||||
break;
|
||||
case Conv2dPointwise:
|
||||
shader = VK_SHADER(conv2d_pw_2x2);
|
||||
shader = VK_KERNEL(conv2d_pw_2x2);
|
||||
break;
|
||||
}
|
||||
return shader;
|
||||
|
@ -70,6 +70,7 @@ storageTypeToEnum = {
|
||||
"TEXTURE_2D" : "api::StorageType::TEXTURE_2D",
|
||||
"TEXTURE_3D" : "api::StorageType::TEXTURE_3D",
|
||||
"BUFFER" : "api::StorageType::BUFFER",
|
||||
"": "api::StorageType::UNKNOWN",
|
||||
}
|
||||
|
||||
def determineDescriptorType(lineStr):
|
||||
@ -165,64 +166,67 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
|
||||
h += "#include <vector>\n"
|
||||
h += "#include <string>\n"
|
||||
h += "#include <ATen/native/vulkan/api/Types.h>\n"
|
||||
h += "#include <ATen/native/vulkan/api/vk_api.h>"
|
||||
h += "#include <ATen/native/vulkan/api/vk_api.h>\n"
|
||||
|
||||
nsbegin = "\nnamespace at {\nnamespace native {\nnamespace vulkan {\n"
|
||||
nsend = "\n}\n}\n} //namespace at::native::vulkan\n"
|
||||
nsbegin = "namespace at {\nnamespace native {\nnamespace vulkan {\n"
|
||||
nsend = "} // namespace vulkan\n} // namespace native\n} // namespace at\n"
|
||||
|
||||
h += nsbegin
|
||||
|
||||
cpp = "#include <ATen/native/vulkan/{}>".format(H_NAME)
|
||||
# Forward declaration of ShaderInfo
|
||||
h += "namespace api {\nstruct ShaderInfo;\n} // namespace api\n"
|
||||
|
||||
cpp = "#include <ATen/native/vulkan/{}>\n".format(H_NAME)
|
||||
cpp += "#include <ATen/native/vulkan/api/Shader.h>\n"
|
||||
cpp += nsbegin
|
||||
|
||||
shader_info_bin_code = []
|
||||
shader_info_cpp_code = []
|
||||
shader_info_h_code = []
|
||||
|
||||
for spvPath, srcPath in spvPaths.items():
|
||||
name = getName(spvPath)
|
||||
name_len = name + "_len"
|
||||
h += "extern const uint32_t {}[];\n".format(name)
|
||||
h += "extern const uint32_t {};\n".format(name_len)
|
||||
|
||||
shader_info = getShaderInfo(srcPath)
|
||||
name_layout = name + "_layout"
|
||||
h += "extern const std::vector<VkDescriptorType> {};\n".format(name_layout)
|
||||
|
||||
cpp += "const uint32_t " + name + "[] = {\n"
|
||||
sizeBytes = 0
|
||||
print("spvPath:{}".format(spvPath))
|
||||
with open(spvPath, 'rb') as f:
|
||||
for word in array.array('I', f.read()):
|
||||
cpp += "{},\n".format(word)
|
||||
sizeBytes += 4
|
||||
cpp += "};\n"
|
||||
cpp += "const uint32_t {} = {};\n".format(name_len, sizeBytes)
|
||||
next_bin = array.array('I', f.read())
|
||||
sizeBytes = 4 * len(next_bin)
|
||||
shader_info_bin_code.append(
|
||||
"const uint32_t {}_bin[] = {{\n {}\n}};".format(
|
||||
name,
|
||||
",\n ".join(str(x) for x in next_bin),
|
||||
)
|
||||
)
|
||||
|
||||
# Add layout
|
||||
cpp += "const std::vector<VkDescriptorType> {} = {{\n".format(name_layout)
|
||||
for descriptor in shader_info.layouts:
|
||||
cpp += " {},\n".format(descriptor)
|
||||
cpp += "};\n"
|
||||
shader_info = getShaderInfo(srcPath)
|
||||
|
||||
# Add tile size
|
||||
if (len(shader_info.tile_size) > 0):
|
||||
name_tile_size = name + "_tile_size"
|
||||
h += "extern const std::vector<uint32_t> {};\n".format(name_tile_size)
|
||||
cpp += "const std::vector<uint32_t> {} = {{\n".format(name_tile_size)
|
||||
for s in shader_info.tile_size:
|
||||
cpp += " {},\n".format(s)
|
||||
cpp += "};\n"
|
||||
tile_size = (
|
||||
"{{{}}}".format(", ".join(str(x) for x in shader_info.tile_size))
|
||||
if (len(shader_info.tile_size) > 0)
|
||||
else "std::vector<uint32_t>()"
|
||||
)
|
||||
|
||||
# Add weight type
|
||||
if (shader_info.weight_storage_type != ""):
|
||||
name_weight_storage_type = name + "_weight_storage_type"
|
||||
h += "extern const api::StorageType {};\n".format(name_weight_storage_type)
|
||||
cpp += "const api::StorageType {} = \n".format(name_weight_storage_type)
|
||||
cpp += " {};\n".format(storageTypeToEnum[shader_info.weight_storage_type])
|
||||
shader_info_args = [
|
||||
"\"vulkan.{}\"".format(name.replace("_spv", "")),
|
||||
"{}_bin".format(name),
|
||||
str(sizeBytes),
|
||||
"{{{}}}".format(", ".join(shader_info.layouts)),
|
||||
tile_size,
|
||||
storageTypeToEnum[shader_info.weight_storage_type],
|
||||
storageTypeToEnum[shader_info.bias_storage_type],
|
||||
]
|
||||
|
||||
# Add bias type
|
||||
if (shader_info.bias_storage_type != ""):
|
||||
name_bias_storage_type = name + "_bias_storage_type"
|
||||
h += "extern const api::StorageType {};\n".format(name_bias_storage_type)
|
||||
cpp += "const api::StorageType {} = \n".format(name_bias_storage_type)
|
||||
cpp += " {};\n".format(storageTypeToEnum[shader_info.bias_storage_type])
|
||||
shader_info_h_code.append("extern const api::ShaderInfo {};".format(name))
|
||||
shader_info_cpp_code.append(
|
||||
"const api::ShaderInfo {}(\n {}\n);".format(
|
||||
name,
|
||||
",\n ".join(shader_info_args),
|
||||
),
|
||||
)
|
||||
|
||||
cpp += "namespace {{\n{}\n}} // namespace\n".format("\n".join(shader_info_bin_code))
|
||||
cpp += "{}\n".format("\n".join(shader_info_cpp_code))
|
||||
h += "{}\n".format("\n".join(shader_info_h_code))
|
||||
|
||||
cpp += nsend
|
||||
h += nsend
|
||||
|
Reference in New Issue
Block a user