Revert "[MPS] Add CompileShader method (#141478)"

This reverts commit 0478fee42db16a0477add1d0a644ce713f31a875.

Reverted https://github.com/pytorch/pytorch/pull/141478 on behalf of https://github.com/malfet due to it breaking doctests by trying to run the MPS example on Linux ([comment](https://github.com/pytorch/pytorch/pull/141478#issuecomment-2533351909))
PyTorch MergeBot
2024-12-11 00:37:10 +00:00
parent b94a206414
commit 393cf46f42
7 changed files with 7 additions and 344 deletions
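
The doctest breakage cited in the revert message comes from the _compile_shader docstring in torch/mps/__init__.py (the last file in this diff): its example runs unconditionally when the doctests are collected, so a Linux runner without an MPS device ends up executing the MPS call and fails. Purely as a sketch, assuming the reverted _compile_shader API and torch.backends.mps.is_available() as the check (neither the guard nor a re-land is part of this commit), such an example could be gated like this:

import torch

# Hypothetical guard, not part of this change: only run the MPS-only body on
# hosts that actually have an MPS device, so doctest collection on Linux CI
# machines never executes the shader compilation.
if torch.backends.mps.is_available():
    lib = torch.mps._compile_shader(
        "kernel void full(device float* out, constant float& val, "
        "uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
    )
    x = torch.zeros(16, device="mps")
    lib.full(x, 3.14)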

View File

@@ -13,7 +13,6 @@ typedef void* MTLComputePipelineState_t;
typedef void* MTLComputeCommandEncoder_t;
#endif
#include <c10/util/OptionalArrayRef.h>
#include <functional>
#include <optional>
#include <type_traits>
@@ -81,8 +80,8 @@ class MetalKernelFunction {
uint64_t length,
std::optional<uint64_t> groupSize = std::nullopt);
void dispatch(
c10::ArrayRef<uint64_t> length,
c10::OptionalArrayRef<uint64_t> groupSize = std::nullopt);
std::array<uint64_t, 2> length,
std::optional<std::array<uint64_t, 2>> groupSize = std::nullopt);
private:
MTLComputePipelineState_t cps;
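
For context, the hunk above restores the fixed two-dimensional std::array dispatch overload that the reverted change had generalized into a 1-to-3 dimensional c10::ArrayRef grid. At the Python level, the reverted change exposed the grid size through the kw-only threads argument of a compiled kernel (see the pybind __call__ binding and the tests later in this diff). A minimal usage sketch, assuming the reverted Python binding that this commit removes:

import torch

lib = torch.mps._compile_shader("""
kernel void fill_2d(device float* out,
                    uint2 idx  [[thread_position_in_grid]],
                    uint2 grid [[threads_per_grid]]) {
    // Flatten the 2-D thread position into a linear buffer offset.
    out[idx.y * grid.x + idx.x] = float(idx.x + idx.y);
}
""")
x = torch.zeros(3, 4, device="mps")
# A 2-tuple picks the two-dimensional dispatch; the reverted ArrayRef overload
# also accepted a 3-tuple for a full 3-D grid.
lib.fill_2d(x, threads=(4, 3))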

View File

@@ -962,15 +962,11 @@ void MetalKernelFunction::dispatch(uint64_t length, std::optional<uint64_t> grou
[encoder dispatchThreads:MTLSizeMake(length, 1, 1) threadsPerThreadgroup:MTLSizeMake(group_size_val, 1, 1)];
}
void MetalKernelFunction::dispatch(c10::ArrayRef<uint64_t> length, c10::OptionalArrayRef<uint64_t> group_size) {
TORCH_CHECK(length.size() > 0 && length.size() < 4, "Dispatch dimensions must be between 1 and 3");
TORCH_CHECK(!group_size.has_value() || group_size->size() == length.size(),
"size and group_size must have same number of dimentions");
auto group_size_length = group_size.has_value() ? group_size->size() : 0;
[encoder dispatchThreads:MTLSizeMake(length[0], length.size() > 1 ? length[1] : 1, length.size() == 3 ? length[2] : 1)
threadsPerThreadgroup:MTLSizeMake(group_size_length > 0 ? group_size->at(0) : getMaxThreadsPerThreadgroup(),
group_size_length > 1 ? group_size->at(1) : 1,
group_size_length == 3 ? group_size->at(2) : 1)];
void MetalKernelFunction::dispatch(std::array<uint64_t, 2> length, std::optional<std::array<uint64_t, 2>> group_size) {
auto group_size_val =
group_size.value_or(std::array<uint64_t, 2>{std::min(length[0], getMaxThreadsPerThreadgroup()), 1});
[encoder dispatchThreads:MTLSizeMake(length[0], length[1], 1)
threadsPerThreadgroup:MTLSizeMake(group_size_val[0], group_size_val[1], 1)];
}
void MetalKernelFunction::setArg(unsigned idx, const at::TensorBase& t) {

View File

@@ -12492,89 +12492,6 @@ class TestCommon(TestCase):
cpu_tensor = ones("cpu")
self.assertEqual(mps_tensor.cpu(), cpu_tensor)
class TestMetalLibrary(TestCaseMPS):
def test_metal_arange(self):
x = torch.zeros(12, device="mps", dtype=torch.half)
lib = torch.mps._compile_shader("""
kernel void arange(device half* x, uint idx [[thread_position_in_grid]]) {
x[idx] = idx;
}
""")
lib.arange(x)
self.assertEqual(x, torch.arange(x.numel(), device='mps', dtype=x.dtype))
def test_metal_dispatch_3d(self):
x = torch.empty(12, device="mps")
y = torch.empty_like(x)
z = torch.empty_like(x)
lib = torch.mps._compile_shader("""
kernel void arange_x(device float* x, uint3 idx [[thread_position_in_grid]]) {
x[idx.x + idx.y + idx.z] = idx.x;
}
kernel void arange_y(device float* x, uint3 idx [[thread_position_in_grid]]) {
x[idx.x + idx.y + idx.z] = idx.y;
}
kernel void arange_z(device float* x, uint3 idx [[thread_position_in_grid]]) {
x[idx.x + idx.y + idx.z] = idx.z;
}
""")
# Check that one can enumerate all shaders
self.assertEqual(set(dir(lib)), {f"arange_{i}" for i in ["x", "y", "z"]})
lib.arange_x(x)
lib.arange_y(y, threads=(1, y.numel()))
lib.arange_z(z, threads=(1, 1, z.numel()))
self.assertEqual(x, torch.arange(x.numel(), device='mps', dtype=x.dtype))
self.assertEqual(x, y)
self.assertEqual(x, z)
def test_metal_arange_with_arg(self):
x = torch.zeros(12, device="mps")
lib = torch.mps._compile_shader("""
kernel void arange(device float* x, constant float& start, constant float& step,
uint idx [[thread_position_in_grid]]) {
x[idx] = start + idx * step;
}
""")
lib.arange(x, 3.14, .5)
self.assertEqual(x, torch.arange(3.14, 8.66, .5, device='mps'))
def test_metal_arange_with_arg_and_cast(self):
x = torch.zeros(12, device="mps", dtype=torch.half)
y = torch.zeros(12, device="mps", dtype=torch.half)
lib = torch.mps._compile_shader("""
kernel void arange_all_half(device half* x, constant half2& start_step,
uint idx [[thread_position_in_grid]]) {
x[idx] = start_step.x + idx * start_step.y;
}
kernel void arange_half_float(device half* x, constant half& start, constant float& step,
uint idx [[thread_position_in_grid]]) {
x[idx] = start + idx * step;
}
""")
lib.arange_all_half(x, [3.14, .5], arg_casts="fp16")
lib.arange_half_float(y, 3.14, .5, arg_casts={1: "fp16"})
self.assertEqual(x, torch.arange(3.14, 8.66, .5, device='mps', dtype=x.dtype))
self.assertEqual(x, y)
def test_metal_error_checking(self):
# Syntax error asserts
self.assertRaises(RuntimeError, lambda: torch.mps._compile_shader("Syntax error"))
cpu_tensor = torch.rand(3)
mps_tensor = torch.rand(3, device="mps")
lib = torch.mps._compile_shader("kernel void full(device half* x) { x[0] = 1.0; }")
# Passing CPU tensor asserts
self.assertRaises(RuntimeError, lambda: lib.full(cpu_tensor))
# Passing invalid shader name asserts
self.assertRaises(RuntimeError, lambda: lib.non_existing(mps_tensor))
# Passing no tensors asserts
self.assertRaises(RuntimeError, lambda: lib.full(12))
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the

View File

@@ -1811,9 +1811,6 @@ PyObject* initModule() {
#ifdef USE_CUDA
torch::cuda::initModule(module);
#endif
#ifdef USE_MPS
torch::mps::initModule(module);
#endif
#ifdef USE_XPU
torch::xpu::initModule(module);
#endif

View File

@@ -1,25 +1,16 @@
#define PYBIND11_DETAILED_ERROR_MESSAGES
#include <ATen/ATen.h>
#include <c10/util/CallOnce.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <memory>
// pthread.h is included for tracking bad forks
#ifndef WIN32
#include <pthread.h>
#endif
#ifdef USE_MPS
#include <ATen/native/mps/MetalShaderLibrary.h>
#endif
namespace torch::mps {
namespace {
@@ -287,224 +278,4 @@ PyMethodDef* python_functions() {
return _MPSModule_methods;
}
#ifdef USE_MPS
namespace {
template <typename T = uint64_t>
std::optional<std::vector<T>> optional_vec_from_pyobject(
const py::object& py_value) {
if (py_value.is_none()) {
return std::nullopt;
}
if (py::isinstance<py::int_>(py_value)) {
return std::vector({py_value.cast<T>()});
}
auto vec = py_value.cast<std::vector<T>>();
TORCH_CHECK(vec.size() > 0 && vec.size() < 4);
return vec;
}
struct OptionalArgCaster {
public:
OptionalArgCaster(const py::object& arg) {
if (arg.is_none()) {
} else if (py::isinstance<py::str>(arg)) {
default_cast = arg.cast<std::string>();
} else if (py::isinstance<py::dict>(arg)) {
cast_map = arg.cast<std::unordered_map<unsigned, std::string>>();
} else {
TORCH_CHECK(
false,
"Unexpected caster arg type ",
arg.attr("__class__").attr("__name__").cast<const std::string>());
}
}
template <typename T>
void setValue(
::at::native::mps::MetalKernelFunction& f,
unsigned idx,
const std::vector<T>& values) {
auto cast_str =
cast_map.find(idx) != cast_map.end() ? cast_map[idx] : default_cast;
if (cast_str.size() == 0) {
f.setArg(idx, values);
} else if (cast_str == "fp16") {
std::vector<c10::Half> cast_values(values.begin(), values.end());
f.setArg(idx, cast_values);
} else if (cast_str == "bf16") {
std::vector<c10::BFloat16> cast_values(values.begin(), values.end());
f.setArg(idx, cast_values);
} else if (cast_str == "int32") {
std::vector<int32_t> cast_values(values.begin(), values.end());
f.setArg(idx, cast_values);
} else if (cast_str == "int16") {
std::vector<int16_t> cast_values(values.begin(), values.end());
f.setArg(idx, cast_values);
} else if (cast_str == "int8") {
std::vector<int8_t> cast_values(values.begin(), values.end());
f.setArg(idx, cast_values);
} else if (cast_str == "uint8") {
std::vector<uint8_t> cast_values(values.begin(), values.end());
f.setArg(idx, cast_values);
} else {
TORCH_CHECK(false, "Unsupported cast instruction ", default_cast);
}
}
template <
typename T,
typename = std::enable_if_t<
std::is_same_v<float, T> || std::is_same_v<int64_t, T>>>
void setValue(
::at::native::mps::MetalKernelFunction& f,
unsigned idx,
const T& value) {
auto cast_str =
cast_map.find(idx) != cast_map.end() ? cast_map[idx] : default_cast;
if (cast_str.size() == 0) {
f.setArg(idx, value);
} else if (cast_str == "fp16") {
f.setArg(idx, static_cast<c10::Half>(value));
} else if (cast_str == "bf16") {
f.setArg(idx, static_cast<c10::BFloat16>(value));
} else if (cast_str == "int32") {
f.setArg(idx, static_cast<int32_t>(value));
} else if (cast_str == "int16") {
f.setArg(idx, static_cast<int16_t>(value));
} else if (cast_str == "int8") {
f.setArg(idx, static_cast<int8_t>(value));
} else if (cast_str == "uint8") {
f.setArg(idx, static_cast<uint8_t>(value));
} else {
TORCH_CHECK(false, "Unsupported cast instruction ", default_cast);
}
}
void setValue(
::at::native::mps::MetalKernelFunction& f,
unsigned idx,
const py::object& arg) {
if (py::isinstance<py::tuple>(arg) || py::isinstance<py::list>(arg)) {
auto len = arg.attr("__len__")().cast<uint64_t>();
TORCH_CHECK(
len > 0, "Empty list/tuple can not be an argument to metal kernel")
auto element = arg.attr("__getitem__")(0);
if (py::isinstance<py::int_>(element)) {
auto values = arg.cast<std::vector<int64_t>>();
setValue(f, idx, values);
} else if (py::isinstance<py::float_>(element)) {
auto values = arg.cast<std::vector<float>>();
setValue(f, idx, values);
} else {
TORCH_CHECK(false, "Unexpected argument types");
}
} else if (py::isinstance<py::float_>(arg)) {
auto value = arg.cast<float>();
setValue(f, idx, value);
} else if (py::isinstance<py::int_>(arg)) {
auto value = arg.cast<int64_t>();
setValue(f, idx, value);
} else {
TORCH_CHECK(false, "Unsupported argument type");
}
}
private:
std::string default_cast;
std::unordered_map<unsigned, std::string> cast_map;
};
} // namespace
void initModule(PyObject* module) {
using namespace at::native::mps;
auto m = py::handle(module).cast<py::module>();
py::class_<
DynamicMetalShaderLibrary,
std::shared_ptr<DynamicMetalShaderLibrary>>(m, "_mps_ShaderLibrary")
.def(
"__getattr__",
[](DynamicMetalShaderLibrary& self, const std::string& name) {
return self.getKernelFunction(name);
})
.def("__dir__", [](DynamicMetalShaderLibrary& self) {
return self.getFunctionNames();
});
py::class_<MetalKernelFunction, std::shared_ptr<MetalKernelFunction>>(
m, "_mps_MetalKernel")
.def(
"__call__",
[](MetalKernelFunction& self,
const py::args& args,
const py::object& py_threads,
const py::object& py_group_size,
const py::object& arg_casts) {
auto threads = optional_vec_from_pyobject(py_threads);
auto group_size = optional_vec_from_pyobject(py_group_size);
OptionalArgCaster caster(arg_casts);
self.runCommandBlock([&] {
self.startEncoding();
for (auto idx : c10::irange(args.size())) {
if (THPVariable_Check(args[idx].ptr())) {
auto t = THPVariable_Unpack(args[idx].ptr());
self.setArg(idx, t);
if (!threads) {
threads = {static_cast<uint64_t>(t.numel())};
}
continue;
}
caster.setValue(self, idx, args[idx]);
}
TORCH_CHECK(
threads.has_value() && threads->size() < 4,
"Number of threads is undefined or has wrong dimention");
TORCH_CHECK(
!group_size.has_value() ||
threads->size() == group_size->size());
if (threads->size() == 1) {
if (group_size.has_value()) {
self.dispatch(threads->at(0), group_size->at(0));
} else {
self.dispatch(threads->at(0));
}
} else if (threads->size() == 2) {
if (group_size.has_value()) {
self.dispatch(
{threads->at(0), threads->at(1)},
{group_size->at(0), group_size->at(1)});
} else {
self.dispatch({threads->at(0), threads->at(1)});
}
} else {
if (group_size.has_value()) {
self.dispatch(
{threads->at(0), threads->at(1), threads->at(2)},
{group_size->at(0),
group_size->at(1),
group_size->at(2)});
} else {
self.dispatch(
{threads->at(0), threads->at(1), threads->at(2)});
}
}
});
},
py::kw_only(),
py::arg("threads") = py::none(),
py::arg("group_size") = py::none(),
py::arg("arg_casts") = py::none())
.def_property_readonly(
"max_threads_per_threadgroup",
&MetalKernelFunction::getMaxThreadsPerThreadgroup)
.def_property_readonly(
"thread_execution_width",
&MetalKernelFunction::getThreadExecutionWidth)
.def_property_readonly(
"static_thread_group_memory_length",
&MetalKernelFunction::getStaticThreadGroupMemoryLength);
m.def("_mps_compileShader", [](const std::string& source) {
return std::make_shared<DynamicMetalShaderLibrary>(source);
});
}
#endif /* USE_MPS */
} // namespace torch::mps

View File

@@ -5,6 +5,5 @@
namespace torch::mps {
PyMethodDef* python_functions();
void initModule(PyObject* module);
} // namespace torch::mps

View File

@@ -140,22 +140,6 @@ def recommended_max_memory() -> int:
return torch._C._mps_recommendedMaxMemory()
def _compile_shader(source: str):
r"""Compiles compute shader from source and allows one to invoke kernels
defined there from the comfort of Python runtime
Example::
>>> lib = torch.mps._compile_shader(
... "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
... )
>>> x = torch.zeros(16, device="mps")
>>> lib.full(x, 3.14)
"""
if not hasattr(torch._C, "_mps_compileShader"):
raise RuntimeError("MPS is not available")
return torch._C._mps_compileShader(source)
def is_available() -> bool:
return device_count() > 0