pytorch/torch/csrc/mps/Module.h
Nikita Shulga 95b17f6346 [MPS] Add CompileShader method (#141478)
This allows one to do something like the following:
```python
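import torch

# Allocate a 10-element float32 tensor on the MPS device
x = torch.ones(10, device="mps")
# Compile Metal source at runtime; each `kernel` function becomes callable on `m`
m = torch.mps._compile_shader("""
   kernel void foo(device float* x, uint idx [[thread_position_in_grid]]) {
     x[idx] += idx;
   }
""")
# Dispatch the `foo` kernel on x
m.foo(x)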
```

More generally, it enables writing custom operators with Metal shaders purely in Python.
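To make the kernel's semantics concrete, the sketch below is a sequential CPU model of `foo`: the thread at grid position `idx` adds `idx` to element `idx`. It assumes that `m.foo(x)` launches one thread per element of `x`; the snippet passes no explicit grid size, so this is an inference rather than documented behavior. `simulate_foo` is a hypothetical helper name used only for illustration and is not part of `torch.mps`.

```python
def simulate_foo(x):
    """Hypothetical CPU model of the `foo` Metal kernel.

    Assumes one GPU thread per element; the thread at grid position `idx`
    performs `x[idx] += idx`. Returns a new list and leaves `x` unchanged.
    """
    out = list(x)
    # Each loop iteration stands in for one GPU thread. The threads write to
    # disjoint elements, so executing them in any order gives the same result.
    for idx in range(len(out)):
        out[idx] += idx
    return out


# Mirrors `x = torch.ones(10, device="mps")` from the example above.
x = [1.0] * 10
expected = simulate_foo(x)
print(expected)  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
```

On a machine with MPS, and under the one-thread-per-element assumption, `x.cpu().tolist()` after `m.foo(x)` should match `expected`, which gives a simple way to check a custom kernel against a plain-Python reference.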
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141478
Approved by: https://github.com/manuelcandales
2024-12-11 02:00:51 +00:00


#pragma once
#include <torch/csrc/python_headers.h>
namespace torch::mps {
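// Returns the table of Python methods provided by the MPS backend
PyMethodDef* python_functions();
// Registers additional MPS bindings on the given Python module
void initModule(PyObject* module);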
} // namespace torch::mps