mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
This allows one to do something like this:
```python
import torch
x = torch.ones(10, device="mps")
m = torch.mps._compile_shader("""
   kernel void foo(device float* x, uint idx [[thread_position_in_grid]]) {
     x[idx] += idx;
   }
""")
m.foo(x)
```
More generally, it enables writing custom operators using Metal shaders purely in Python.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141478
Approved by: https://github.com/manuelcandales
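To make the intended effect of the `foo` kernel in the commit message concrete, here is a plain-Python emulation of its per-element update. It is only a sketch: `emulate_foo` is a hypothetical helper, not part of PyTorch, and it assumes the kernel is dispatched with one thread per element of `x`, which the commit message does not state explicitly. It needs neither `torch` nor a Metal device.

```python
def emulate_foo(x):
    """CPU stand-in for the Metal kernel `foo` (hypothetical helper).

    Assumes one thread per element: thread `idx` adds `idx` to x[idx].
    Unlike the real kernel, which updates the buffer in place, this
    returns a new list and leaves the input untouched.
    """
    out = list(x)
    for idx in range(len(out)):  # idx plays the role of thread_position_in_grid
        out[idx] += idx
    return out


x = [1.0] * 10  # mirrors torch.ones(10)
print(emulate_foo(x))
```

Under that assumption, element `idx` of the result is `1 + idx`, so the ten ones become the values 1.0 through 10.0.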
11 lines
173 B
C++
#pragma once

#include <torch/csrc/python_headers.h>

namespace torch::mps {

PyMethodDef* python_functions();
void initModule(PyObject* module);

} // namespace torch::mps