[VK-API][Op Redesign][3/n] Expose new Context and Resource APIs (#121060)

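Summary: Renames `Context::submit_compute_prologue` to `get_descriptor_set` and `Context::submit_compute_epilogue` to `register_shader_dispatch`. Both now record into the Context's own `cmd_` rather than taking a `CommandBuffer&` argument, and `register_shader_dispatch` additionally takes the `ShaderInfo` so that the output-tile-size adjustment of the global workgroup size happens inside it instead of in `submit_compute_job`. For use in the next diff.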

Test Plan: sc

Differential Revision: D54397862

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121060
Approved by: https://github.com/SS-JIA
This commit is contained in:
Jorge Pineda
2024-03-04 22:26:07 +00:00
committed by PyTorch MergeBot
parent 70c23a51ac
commit eba28a6f91
2 changed files with 30 additions and 29 deletions

View File

@@ -51,8 +51,7 @@ Context::~Context() {
}
}
DescriptorSet Context::submit_compute_prologue(
CommandBuffer& command_buffer,
DescriptorSet Context::get_descriptor_set(
const ShaderInfo& shader_descriptor,
const utils::uvec3& local_workgroup_size) {
VkDescriptorSetLayout shader_layout =
@@ -66,21 +65,34 @@ DescriptorSet Context::submit_compute_prologue(
shader_cache().retrieve(shader_descriptor),
local_workgroup_size});
command_buffer.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);
cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);
return descriptor_pool().get_descriptor_set(
shader_layout, shader_descriptor.kernel_layout);
}
void Context::submit_compute_epilogue(
CommandBuffer& command_buffer,
void Context::register_shader_dispatch(
const DescriptorSet& descriptors,
PipelineBarrier& pipeline_barrier,
const ShaderInfo& shader_descriptor,
const utils::uvec3& global_workgroup_size) {
command_buffer.bind_descriptors(descriptors.get_bind_handle());
command_buffer.insert_barrier(pipeline_barrier);
// Adjust the global workgroup size based on the output tile size
const utils::uvec3 effective_global_wg = {
utils::div_up(
global_workgroup_size.data[0u],
shader_descriptor.out_tile_size.data[0u]),
utils::div_up(
global_workgroup_size.data[1u],
shader_descriptor.out_tile_size.data[1u]),
utils::div_up(
global_workgroup_size.data[2u],
shader_descriptor.out_tile_size.data[2u]),
};
command_buffer.dispatch(global_workgroup_size);
cmd_.bind_descriptors(descriptors.get_bind_handle());
cmd_.insert_barrier(pipeline_barrier);
cmd_.dispatch(effective_global_wg);
}
void Context::submit_cmd_to_gpu(VkFence fence_handle, const bool final_use) {
@@ -164,12 +176,13 @@ namespace {
void memcpy_to_buffer(const VulkanBuffer& src, VulkanBuffer& dst) {
MemoryMap dst_mapping(dst, MemoryAccessType::WRITE);
MemoryMap src_mapping(src, api::MemoryAccessType::READ);
MemoryMap src_mapping(src, MemoryAccessType::READ);
src_mapping.invalidate();
void* dst_ptr = dst_mapping.template data<void>();
void* src_ptr = src_mapping.template data<void>();
// @lint-ignore CLANGTIDY facebook-security-vulnerable-memcpy
memcpy(dst_ptr, src_ptr, src.mem_size());
}

View File

@@ -168,19 +168,14 @@ class Context final {
}
}
private:
DescriptorSet submit_compute_prologue(
CommandBuffer&,
DescriptorSet get_descriptor_set(const ShaderInfo&, const utils::uvec3&);
void register_shader_dispatch(
const DescriptorSet&,
PipelineBarrier&,
const ShaderInfo&,
const utils::uvec3&);
void submit_compute_epilogue(
CommandBuffer&,
const DescriptorSet&,
PipelineBarrier&,
const utils::uvec3&);
public:
template <class S, class D>
bool submit_copy(
PipelineBarrier&,
@@ -502,23 +497,16 @@ inline bool Context::submit_compute_job(
// Factor out template parameter independent code to minimize code bloat.
DescriptorSet descriptor_set =
submit_compute_prologue(cmd_, shader, local_work_group_size);
get_descriptor_set(shader, local_work_group_size);
detail::bind(
descriptor_set,
std::index_sequence_for<Arguments...>{},
std::forward<Arguments>(arguments)...);
// Adjust the global workgroup size based on the output tile size
const utils::uvec3 effective_global_wg = {
utils::div_up(global_work_group.data[0u], shader.out_tile_size.data[0u]),
utils::div_up(global_work_group.data[1u], shader.out_tile_size.data[1u]),
utils::div_up(global_work_group.data[2u], shader.out_tile_size.data[2u]),
};
// Factor out template parameter independent code to minimize code bloat.
submit_compute_epilogue(
cmd_, descriptor_set, pipeline_barrier, effective_global_wg);
register_shader_dispatch(
descriptor_set, pipeline_barrier, shader, global_work_group);
#ifdef USE_VULKAN_GPU_DIAGNOSTICS
if (enable_op_profiling_) {