This change introduces `DynamicMetalShaderLibrary` and `MetalShaderFunction`, and adds unit tests that also serve as an example of how the API works.

Using these primitives, one can compile and dispatch any 1D or 2D shader over an MPS tensor using the following pattern:
```cpp
auto x = torch::empty({8, 16}, at::device(at::kMPS));
DynamicMetalShaderLibrary lib(R"MTL(
  kernel void full(device float* t,
                   constant ulong2& strides,
                   uint2 idx [[thread_position_in_grid]]) {
    t[idx.x*strides.x + idx.y*strides.y] = idx.x + 33.0 * idx.y;
  }
)MTL");
auto func = lib.getKernelFunction("full");
func->runCommandBlock([&] {
  func->startEncoding();
  func->setArg(0, x);
  func->setArg(1, x.strides());
  func->dispatch({8, 16});
});
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141547
Approved by: https://github.com/Skylion007
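For checking results, it may help to have a CPU reference for what the `full` kernel in the example above is expected to write. The sketch below is not part of this change: `full_reference` is a hypothetical helper, and it assumes that `dispatch({8, 16})` maps `idx.x` to the range 0..7 and `idx.y` to 0..15, and that the tensor is contiguous row-major, so its strides are `{16, 1}`.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for the `full` kernel (not part of the PR).
// Assumes a contiguous row-major tensor of shape {rows, cols}, so the
// strides are {cols, 1}, and a grid in which idx.x walks the rows and
// idx.y walks the columns.
std::vector<float> full_reference(std::uint64_t rows, std::uint64_t cols) {
  assert(rows > 0 && cols > 0);
  const std::array<std::uint64_t, 2> strides = {cols, 1};
  std::vector<float> t(rows * cols);
  for (std::uint64_t x = 0; x < rows; ++x) {
    for (std::uint64_t y = 0; y < cols; ++y) {
      // Same indexing and value expression as the Metal kernel body.
      t[x * strides[0] + y * strides[1]] =
          static_cast<float>(x) + 33.0f * static_cast<float>(y);
    }
  }
  return t;
}
```

Under those assumptions, `full_reference(8, 16)` yields 128 values that could be compared against `x.cpu()` after the command block completes.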