From 8f32adc90a7fee83583c9ba89dbdfabb317e0452 Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Sun, 28 Sep 2025 19:00:49 -0700
Subject: [PATCH] [MPSHooks] Release pending command encoder (#164093)

Release the pending compute command encoder before returning a command
buffer, as subsequent callers are very likely to allocate their own
encoder, which otherwise results in the following runtime error
```
tryCoalescingPreviousComputeCommandEncoderWithConfig:nextEncoderClass:]:1090: failed assertion `A command encoder is already encoding to this command buffer'
```

Added a regression test to `test_mps_extension`.

Please note that `torch::mps::get_command_buffer()` should be called
with the dispatch_queue held, both before and after this change, but
many implementations skip that.

Fixes https://github.com/pytorch/pytorch/issues/163721

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164093
Approved by: https://github.com/atalman, https://github.com/Skylion007
---
 aten/src/ATen/mps/MPSHooks.mm        |  5 ++++-
 test/cpp_extensions/mps_extension.mm | 29 ++++++++++++++++++++++++++++
 test/test_cpp_extensions_jit.py      |  6 ++++++
 3 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm
index a2ec221c1bfe..34fbd31af91d 100644
--- a/aten/src/ATen/mps/MPSHooks.mm
+++ b/aten/src/ATen/mps/MPSHooks.mm
@@ -70,7 +70,10 @@ void MPSHooks::commitStream() const {
 }
 
 void* MPSHooks::getCommandBuffer() const {
-  return at::mps::getDefaultMPSStream()->commandBuffer();
+  auto stream = at::mps::getDefaultMPSStream();
+  // Release the pending computeCommandEncoder, as the extension is likely to allocate a new one
+  stream->endKernelCoalescing();
+  return stream->commandBuffer();
 }
 
 void* MPSHooks::getDispatchQueue() const {
diff --git a/test/cpp_extensions/mps_extension.mm b/test/cpp_extensions/mps_extension.mm
index 882e5c5603e2..30b70a76563d 100644
--- a/test/cpp_extensions/mps_extension.mm
+++ b/test/cpp_extensions/mps_extension.mm
@@ -13,6 +13,11 @@ kernel void add_arrays(device const float* inA,
 {
     result[index] = inA[index] + inB[index];
 }
+
+kernel void add_one(device float* data,
+                    uint index [[thread_position_in_grid]]) {
+    data[index] += 1.0;
+}
 )MPS_ADD_ARRAYS");
 
 at::Tensor get_cpu_add_output(at::Tensor & cpu_input1, at::Tensor & cpu_input2) {
@@ -50,7 +55,31 @@ at::Tensor get_mps_add_output(at::Tensor & mps_input1, at::Tensor & mps_input2)
   return mps_output;
 }
 
+void mps_add_one_new_encoder(const at::Tensor& input) {
+  using namespace at::native::mps;
+  TORCH_CHECK(input.is_mps());
+  TORCH_CHECK(input.numel() > 0);
+
+  @autoreleasepool {
+    auto kernelPSO = lib.getPipelineStateForFunc("add_one");
+    auto serialQueue = torch::mps::get_dispatch_queue();
+
+    dispatch_sync(serialQueue, ^(){
+      auto commandBuffer = torch::mps::get_command_buffer();
+      // Start a compute pass.
+      auto computeEncoder = [commandBuffer computeCommandEncoder];
+      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
+      [computeEncoder setComputePipelineState: kernelPSO];
+      mtl_setArgs(computeEncoder, input);
+      mtl_dispatch1DJob(computeEncoder, kernelPSO, input.numel());
+      [computeEncoder endEncoding];
+      torch::mps::commit();
+    });
+  }
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("get_cpu_add_output", &get_cpu_add_output);
   m.def("get_mps_add_output", &get_mps_add_output);
+  m.def("mps_add_one_new_context", &mps_add_one_new_encoder);
 }
diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py
index fd80c7fa565a..e93167296a00 100644
--- a/test/test_cpp_extensions_jit.py
+++ b/test/test_cpp_extensions_jit.py
@@ -220,6 +220,12 @@ class TestCppExtensionJIT(common.TestCase):
 
         self.assertEqual(cpu_output, mps_output.to("cpu"))
 
+        # Regression test for https://github.com/pytorch/pytorch/issues/163721
+        lib = torch.mps.compile_shader("void kernel noop(device float *x) {}")
+        lib.noop(mps_output)
+        module.mps_add_one_new_context(mps_output)
+        self.assertEqual(cpu_output + 1.0, mps_output.to("cpu"))
+
     def _run_jit_cuda_archflags(self, flags, expected):
         # Compile an extension with given `flags`
         def _check_cuobjdump_output(expected_values, is_ptx=False):
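
Note for out-of-tree extension authors: the `mps_add_one_new_encoder` test above also illustrates the intended call pattern, namely acquiring the command buffer only while the serial queue from `torch::mps::get_dispatch_queue()` is held, encoding into a freshly created compute command encoder, and then committing. The sketch below is a minimal restatement of that pattern outside the test harness; it assumes the extension is built as an `.mm` file that includes `torch/extension.h` and `ATen/native/mps/OperationUtils.h` (as `test/cpp_extensions/mps_extension.mm` does), and `kernelPSO` / `run_custom_kernel` are illustrative names, not part of this patch.

```
#include <torch/extension.h>
#include <ATen/native/mps/OperationUtils.h>

// Sketch only: run a custom 1-D Metal kernel against PyTorch's MPS stream.
// `kernelPSO` is a placeholder for whatever pipeline state the extension built.
void run_custom_kernel(id<MTLComputePipelineState> kernelPSO, const at::Tensor& input) {
  using namespace at::native::mps;
  @autoreleasepool {
    // Hold the dispatch queue for the whole encode/commit sequence, both
    // before and after this patch.
    dispatch_sync(torch::mps::get_dispatch_queue(), ^() {
      // With this patch, get_command_buffer() first calls endKernelCoalescing()
      // on the stream, so creating our own encoder below no longer trips the
      // "A command encoder is already encoding to this command buffer" assertion.
      auto commandBuffer = torch::mps::get_command_buffer();
      auto computeEncoder = [commandBuffer computeCommandEncoder];
      [computeEncoder setComputePipelineState:kernelPSO];
      mtl_setArgs(computeEncoder, input);                        // bind the tensor's MTLBuffer at index 0
      mtl_dispatch1DJob(computeEncoder, kernelPSO, input.numel());
      [computeEncoder endEncoding];
      torch::mps::commit();
    });
  }
}
```

Before the change to `MPSHooks::getCommandBuffer()`, the same code would assert whenever a previous PyTorch MPS op had left a coalesced compute encoder open on the stream, which is exactly the situation the `lib.noop(mps_output)` call in the regression test sets up.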