Nikita Shulga 9f9105a67b [MPS] Write/Invoke Metal shaders from C++ (#141547)
Implemented by introducing `DynamicMetalShaderLibrary` and `MetalShaderFunction`.
Adds unit tests that also serve as an example of how the API works.

Using these primitives, one can compile and dispatch any 1D or 2D shader over an MPS tensor using the following pattern:
```cpp
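// 8x16 float tensor on the MPS device (contiguous, so strides are {16, 1})
auto x = torch::empty({8, 16}, at::device(at::kMPS));
// Metal source compiled at runtime; the raw string keeps the shader readable
DynamicMetalShaderLibrary lib(R"MTL(
  kernel void full(device float* t, constant ulong2& strides, uint2 idx [[thread_position_in_grid]]) {
    t[idx.x*strides.x + idx.y*strides.y] = idx.x + 33.0 * idx.y;
  }
)MTL");
// Look up the kernel by its name in the shader source
auto func = lib.getKernelFunction("full");
func->runCommandBlock([&] {
  func->startEncoding();
  func->setArg(0, x);            // argument 0: the tensor's storage
  func->setArg(1, x.strides());  // argument 1: strides, read as ulong2 by the shader
  func->dispatch({8, 16});       // one thread per element of the 8x16 grid
});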
```
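To make the expected contents of `x` explicit, here is a minimal CPU reference of the `full` kernel's indexing. It is a sketch, not part of the PR: `full_reference` is a hypothetical helper, and it assumes a float buffer of exactly `sizes[0] * sizes[1]` elements whose strides address every element once (e.g. row-major `{16, 1}` for an 8x16 tensor).

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for the `full` Metal kernel above.
// Each grid position (x, y) writes x + 33.0 * y at linear offset
// x * strides[0] + y * strides[1], mirroring the shader body.
std::vector<float> full_reference(std::array<std::size_t, 2> sizes,
                                  std::array<std::uint64_t, 2> strides) {
  // Assumption: the buffer holds sizes[0] * sizes[1] elements and the
  // strides map every (x, y) to a distinct in-range offset.
  std::vector<float> t(sizes[0] * sizes[1], 0.0f);
  for (std::size_t x = 0; x < sizes[0]; ++x) {
    for (std::size_t y = 0; y < sizes[1]; ++y) {
      const std::uint64_t offset = x * strides[0] + y * strides[1];
      assert(offset < t.size() && "stride/size mismatch");
      // Metal evaluates `idx.x + 33.0 * idx.y` in float precision
      t[offset] = static_cast<float>(x) + 33.0f * static_cast<float>(y);
    }
  }
  return t;
}
```

For the contiguous case, `full_reference({8, 16}, {16, 1})` puts `33.0f` at offset 1 (x=0, y=1) and `1.0f` at offset 16 (x=1, y=0). Because the shader takes strides as an argument rather than assuming a layout, passing column-major strides such as `{1, 8}` fills the same buffer at different offsets, which is why `x.strides()` is forwarded to the kernel.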

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141547
Approved by: https://github.com/Skylion007
2024-12-02 23:57:59 +00:00