[Vulkan] Generate ShaderInfos Directly via Codegen in gen_vulkan_spv (#91911)

@bypass-github-export-checks

Before this change, we have the data members which make up a ```ShaderInfo``` sitting in ```spv.h/.cpp``` in an unorganized manner. This diff makes the change such that the ```ShaderInfo```s are initialized directly in spv.h/.cpp

Now spv.h looks like
```
#pragma once
#include <stdint.h>
#include <vector>
#include <string>
#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/api/vk_api.h>
namespace at {
namespace native {
namespace vulkan {
namespace api {
struct ShaderInfo;
} // namespace api
extern const api::ShaderInfo adaptive_avg_pool2d_spv;
...
extern const api::ShaderInfo conv2d_pw_2x2_spv;
} // namespace vulkan
} // namespace native
} // namespace at
```
(Full File: P557399150)
and spv.cpp looks like
```
#include <ATen/native/vulkan/spv.h>
#include <ATen/native/vulkan/api/Shader.h>
namespace at {
namespace native {
namespace vulkan {
namespace {
const uint32_t adaptive_avg_pool2d_spv_bin[] = {
  119734787,
  ...
};
...
const uint32_t conv2d_pw_2x2_spv_bin[] = {
  119734787,
  ...
};
} // namespace
const api::ShaderInfo adaptive_avg_pool2d_spv(
  "vulkan.adaptive_avg_pool2d",
  adaptive_avg_pool2d_spv_bin,
  3204,
  {VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER},
  std::vector<uint32_t>(),
  api::StorageType::UNKNOWN,
  api::StorageType::UNKNOWN
);
...
const api::ShaderInfo conv2d_pw_2x2_spv(
  "vulkan.conv2d_pw_2x2",
  conv2d_pw_2x2_spv_bin,
  7736,
  {VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER},
  {2, 2, 1},
  api::StorageType::TEXTURE_2D,
  api::StorageType::TEXTURE_2D
);
} // namespace vulkan
} // namespace native
} // namespace at

```
(Full File: P584237146)

Differential Revision: [D41354313](https://our.internmc.facebook.com/intern/diff/D41354313/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91911
Approved by: https://github.com/mcr229
This commit is contained in:
salilsdesai
2023-01-09 18:07:59 -08:00
committed by PyTorch MergeBot
parent 776fef9ecc
commit 28eb3c8faf
3 changed files with 57 additions and 63 deletions

View File

@ -16,17 +16,7 @@
}
#else
#include <ATen/native/vulkan/spv.h>
#define VK_KERNEL(name) \
::at::native::vulkan::api::ShaderInfo { \
CONCAT_LITERALS(vulkan., name), name##_spv, name##_spv_len, \
name##_spv_layout \
}
#define VK_SHADER(name) \
::at::native::vulkan::api::ShaderInfo { \
CONCAT_LITERALS(vulkan., name), name##_spv, name##_spv_len, \
name##_spv_layout, name##_spv_tile_size, name##_spv_bias_storage_type, \
name##_spv_weight_storage_type, \
}
#define VK_KERNEL(name) ::at::native::vulkan::name##_spv
#endif /* USE_VULKAN_SHADERC_RUNTIME */
namespace at {

View File

@ -296,13 +296,13 @@ static api::ShaderInfo get_shader(
switch (method) {
case Conv2dSlidingWindow:
shader = VK_SHADER(quantized_conv2d);
shader = VK_KERNEL(quantized_conv2d);
break;
case Conv2dDepthwise:
shader = VK_SHADER(quantized_conv2d_dw);
shader = VK_KERNEL(quantized_conv2d_dw);
break;
case Conv2dPointwise:
shader = VK_SHADER(quantized_conv2d_pw_2x2);
shader = VK_KERNEL(quantized_conv2d_pw_2x2);
break;
// todo fail for quantized transposed conv
}
@ -310,29 +310,29 @@ static api::ShaderInfo get_shader(
}
if (transposed) {
shader = VK_SHADER(conv_transpose2d);
shader = VK_KERNEL(conv_transpose2d);
return shader;
}
switch (method) {
case Conv2dSlidingWindow:
shader = VK_SHADER(conv2d);
shader = VK_KERNEL(conv2d);
break;
case Conv2dDepthwise:
shader = VK_SHADER(conv2d_dw);
shader = VK_KERNEL(conv2d_dw);
if (kernel_size.size() == 4 && kernel_size[2] == 3 &&
kernel_size[3] == 3) {
// 1x1 refers to the output tile size
shader = VK_SHADER(conv2d_dw_3x3);
shader = VK_KERNEL(conv2d_dw_3x3);
}
if (kernel_size.size() == 4 && kernel_size[2] == 5 &&
kernel_size[3] == 5) {
// 1x1 refers to the output tile size
shader = VK_SHADER(conv2d_dw_5x5);
shader = VK_KERNEL(conv2d_dw_5x5);
}
break;
case Conv2dPointwise:
shader = VK_SHADER(conv2d_pw_2x2);
shader = VK_KERNEL(conv2d_pw_2x2);
break;
}
return shader;

View File

@ -70,6 +70,7 @@ storageTypeToEnum = {
"TEXTURE_2D" : "api::StorageType::TEXTURE_2D",
"TEXTURE_3D" : "api::StorageType::TEXTURE_3D",
"BUFFER" : "api::StorageType::BUFFER",
"": "api::StorageType::UNKNOWN",
}
def determineDescriptorType(lineStr):
@ -165,64 +166,67 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
h += "#include <vector>\n"
h += "#include <string>\n"
h += "#include <ATen/native/vulkan/api/Types.h>\n"
h += "#include <ATen/native/vulkan/api/vk_api.h>"
h += "#include <ATen/native/vulkan/api/vk_api.h>\n"
nsbegin = "\nnamespace at {\nnamespace native {\nnamespace vulkan {\n"
nsend = "\n}\n}\n} //namespace at::native::vulkan\n"
nsbegin = "namespace at {\nnamespace native {\nnamespace vulkan {\n"
nsend = "} // namespace vulkan\n} // namespace native\n} // namespace at\n"
h += nsbegin
cpp = "#include <ATen/native/vulkan/{}>".format(H_NAME)
# Forward declaration of ShaderInfo
h += "namespace api {\nstruct ShaderInfo;\n} // namespace api\n"
cpp = "#include <ATen/native/vulkan/{}>\n".format(H_NAME)
cpp += "#include <ATen/native/vulkan/api/Shader.h>\n"
cpp += nsbegin
shader_info_bin_code = []
shader_info_cpp_code = []
shader_info_h_code = []
for spvPath, srcPath in spvPaths.items():
name = getName(spvPath)
name_len = name + "_len"
h += "extern const uint32_t {}[];\n".format(name)
h += "extern const uint32_t {};\n".format(name_len)
shader_info = getShaderInfo(srcPath)
name_layout = name + "_layout"
h += "extern const std::vector<VkDescriptorType> {};\n".format(name_layout)
cpp += "const uint32_t " + name + "[] = {\n"
sizeBytes = 0
print("spvPath:{}".format(spvPath))
with open(spvPath, 'rb') as f:
for word in array.array('I', f.read()):
cpp += "{},\n".format(word)
sizeBytes += 4
cpp += "};\n"
cpp += "const uint32_t {} = {};\n".format(name_len, sizeBytes)
next_bin = array.array('I', f.read())
sizeBytes = 4 * len(next_bin)
shader_info_bin_code.append(
"const uint32_t {}_bin[] = {{\n {}\n}};".format(
name,
",\n ".join(str(x) for x in next_bin),
)
)
# Add layout
cpp += "const std::vector<VkDescriptorType> {} = {{\n".format(name_layout)
for descriptor in shader_info.layouts:
cpp += " {},\n".format(descriptor)
cpp += "};\n"
shader_info = getShaderInfo(srcPath)
# Add tile size
if (len(shader_info.tile_size) > 0):
name_tile_size = name + "_tile_size"
h += "extern const std::vector<uint32_t> {};\n".format(name_tile_size)
cpp += "const std::vector<uint32_t> {} = {{\n".format(name_tile_size)
for s in shader_info.tile_size:
cpp += " {},\n".format(s)
cpp += "};\n"
tile_size = (
"{{{}}}".format(", ".join(str(x) for x in shader_info.tile_size))
if (len(shader_info.tile_size) > 0)
else "std::vector<uint32_t>()"
)
# Add weight type
if (shader_info.weight_storage_type != ""):
name_weight_storage_type = name + "_weight_storage_type"
h += "extern const api::StorageType {};\n".format(name_weight_storage_type)
cpp += "const api::StorageType {} = \n".format(name_weight_storage_type)
cpp += " {};\n".format(storageTypeToEnum[shader_info.weight_storage_type])
shader_info_args = [
"\"vulkan.{}\"".format(name.replace("_spv", "")),
"{}_bin".format(name),
str(sizeBytes),
"{{{}}}".format(", ".join(shader_info.layouts)),
tile_size,
storageTypeToEnum[shader_info.weight_storage_type],
storageTypeToEnum[shader_info.bias_storage_type],
]
# Add bias type
if (shader_info.bias_storage_type != ""):
name_bias_storage_type = name + "_bias_storage_type"
h += "extern const api::StorageType {};\n".format(name_bias_storage_type)
cpp += "const api::StorageType {} = \n".format(name_bias_storage_type)
cpp += " {};\n".format(storageTypeToEnum[shader_info.bias_storage_type])
shader_info_h_code.append("extern const api::ShaderInfo {};".format(name))
shader_info_cpp_code.append(
"const api::ShaderInfo {}(\n {}\n);".format(
name,
",\n ".join(shader_info_args),
),
)
cpp += "namespace {{\n{}\n}} // namespace\n".format("\n".join(shader_info_bin_code))
cpp += "{}\n".format("\n".join(shader_info_cpp_code))
h += "{}\n".format("\n".join(shader_info_h_code))
cpp += nsend
h += nsend