Revert "[MPS] Add CompileShader method (#141478)"
This reverts commit 0478fee42db16a0477add1d0a644ce713f31a875. Reverted https://github.com/pytorch/pytorch/pull/141478 on behalf of https://github.com/malfet due to Broke doctests by trying to run the MPS example on Linux ([comment](https://github.com/pytorch/pytorch/pull/141478#issuecomment-2533351909))
@@ -13,7 +13,6 @@ typedef void* MTLComputePipelineState_t;
 typedef void* MTLComputeCommandEncoder_t;
 #endif

-#include <c10/util/OptionalArrayRef.h>
 #include <functional>
 #include <optional>
 #include <type_traits>
@@ -81,8 +80,8 @@ class MetalKernelFunction {
       uint64_t length,
       std::optional<uint64_t> groupSize = std::nullopt);
   void dispatch(
-      c10::ArrayRef<uint64_t> length,
-      c10::OptionalArrayRef<uint64_t> groupSize = std::nullopt);
+      std::array<uint64_t, 2> length,
+      std::optional<std::array<uint64_t, 2>> groupSize = std::nullopt);

  private:
   MTLComputePipelineState_t cps;
@@ -962,15 +962,11 @@ void MetalKernelFunction::dispatch(uint64_t length, std::optional<uint64_t> grou
   [encoder dispatchThreads:MTLSizeMake(length, 1, 1) threadsPerThreadgroup:MTLSizeMake(group_size_val, 1, 1)];
 }

-void MetalKernelFunction::dispatch(c10::ArrayRef<uint64_t> length, c10::OptionalArrayRef<uint64_t> group_size) {
-  TORCH_CHECK(length.size() > 0 && length.size() < 4, "Dispatch dimensions must be at most 3 and non-empty");
-  TORCH_CHECK(!group_size.has_value() || group_size->size() == length.size(),
-              "size and group_size must have same number of dimensions");
-  auto group_size_length = group_size.has_value() ? group_size->size() : 0;
-  [encoder dispatchThreads:MTLSizeMake(length[0], length.size() > 1 ? length[1] : 1, length.size() == 3 ? length[2] : 1)
-      threadsPerThreadgroup:MTLSizeMake(group_size_length > 0 ? group_size->at(0) : getMaxThreadsPerThreadgroup(),
-                                        group_size_length > 1 ? group_size->at(1) : 1,
-                                        group_size_length == 3 ? group_size->at(2) : 1)];
+void MetalKernelFunction::dispatch(std::array<uint64_t, 2> length, std::optional<std::array<uint64_t, 2>> group_size) {
+  auto group_size_val =
+      group_size.value_or(std::array<uint64_t, 2>{std::min(length[0], getMaxThreadsPerThreadgroup()), 1});
+  [encoder dispatchThreads:MTLSizeMake(length[0], length[1], 1)
+      threadsPerThreadgroup:MTLSizeMake(group_size_val[0], group_size_val[1], 1)];
 }

 void MetalKernelFunction::setArg(unsigned idx, const at::TensorBase& t) {
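For orientation, a minimal Python sketch of how these dispatch overloads surfaced before the revert; it assumes an MPS-enabled build where the reverted torch.mps._compile_shader is still present (hence the guards) and mirrors the tests later on this page. The length of the `threads=` tuple selects the 1-, 2- or 3-dimensional dispatch path bound further down.

import torch

# Sketch only: when `threads=` is omitted, the binding falls back to a 1D
# grid sized by the first tensor argument's numel().
if torch.backends.mps.is_available() and hasattr(torch.mps, "_compile_shader"):
    lib = torch.mps._compile_shader("""
        kernel void arange(device float* x, uint3 idx [[thread_position_in_grid]]) {
            x[idx.x + idx.y + idx.z] = idx.x + idx.y + idx.z;
        }
    """)
    x = torch.zeros(8, device="mps")
    lib.arange(x)                     # 1D dispatch over x.numel() threads
    lib.arange(x, threads=(1, 1, 8))  # 3D dispatch over the same 8 threads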
@@ -12492,89 +12492,6 @@ class TestCommon(TestCase):
         cpu_tensor = ones("cpu")
         self.assertEqual(mps_tensor.cpu(), cpu_tensor)

-class TestMetalLibrary(TestCaseMPS):
-    def test_metal_arange(self):
-        x = torch.zeros(12, device="mps", dtype=torch.half)
-        lib = torch.mps._compile_shader("""
-            kernel void arange(device half* x, uint idx [[thread_position_in_grid]]) {
-                x[idx] = idx;
-            }
-        """)
-        lib.arange(x)
-        self.assertEqual(x, torch.arange(x.numel(), device='mps', dtype=x.dtype))
-
-    def test_metal_dispatch_3d(self):
-        x = torch.empty(12, device="mps")
-        y = torch.empty_like(x)
-        z = torch.empty_like(x)
-        lib = torch.mps._compile_shader("""
-            kernel void arange_x(device float* x, uint3 idx [[thread_position_in_grid]]) {
-                x[idx.x + idx.y + idx.z] = idx.x;
-            }
-
-            kernel void arange_y(device float* x, uint3 idx [[thread_position_in_grid]]) {
-                x[idx.x + idx.y + idx.z] = idx.y;
-            }
-
-            kernel void arange_z(device float* x, uint3 idx [[thread_position_in_grid]]) {
-                x[idx.x + idx.y + idx.z] = idx.z;
-            }
-        """)
-
-        # Check that one can enumerate all shaders
-        self.assertEqual(set(dir(lib)), {f"arange_{i}" for i in ["x", "y", "z"]})
-
-        lib.arange_x(x)
-        lib.arange_y(y, threads=(1, y.numel()))
-        lib.arange_z(z, threads=(1, 1, z.numel()))
-
-        self.assertEqual(x, torch.arange(x.numel(), device='mps', dtype=x.dtype))
-        self.assertEqual(x, y)
-        self.assertEqual(x, z)
-
-    def test_metal_arange_with_arg(self):
-        x = torch.zeros(12, device="mps")
-        lib = torch.mps._compile_shader("""
-            kernel void arange(device float* x, constant float& start, constant float& step,
-                               uint idx [[thread_position_in_grid]]) {
-                x[idx] = start + idx * step;
-            }
-        """)
-        lib.arange(x, 3.14, .5)
-        self.assertEqual(x, torch.arange(3.14, 8.66, .5, device='mps'))
-
-    def test_metal_arange_with_arg_and_cast(self):
-        x = torch.zeros(12, device="mps", dtype=torch.half)
-        y = torch.zeros(12, device="mps", dtype=torch.half)
-        lib = torch.mps._compile_shader("""
-            kernel void arange_all_half(device half* x, constant half2& start_step,
-                                        uint idx [[thread_position_in_grid]]) {
-                x[idx] = start_step.x + idx * start_step.y;
-            }
-
-            kernel void arange_half_float(device half* x, constant half& start, constant float& step,
-                                          uint idx [[thread_position_in_grid]]) {
-                x[idx] = start + idx * step;
-            }
-        """)
-        lib.arange_all_half(x, [3.14, .5], arg_casts="fp16")
-        lib.arange_half_float(y, 3.14, .5, arg_casts={1: "fp16"})
-        self.assertEqual(x, torch.arange(3.14, 8.66, .5, device='mps', dtype=x.dtype))
-        self.assertEqual(x, y)
-
-    def test_metal_error_checking(self):
-        # Syntax error asserts
-        self.assertRaises(RuntimeError, lambda: torch.mps._compile_shader("Syntax error"))
-        cpu_tensor = torch.rand(3)
-        mps_tensor = torch.rand(3, device="mps")
-        lib = torch.mps._compile_shader("kernel void full(device half* x) { x[0] = 1.0; }")
-        # Passing CPU tensor asserts
-        self.assertRaises(RuntimeError, lambda: lib.full(cpu_tensor))
-        # Passing invalid shader name asserts
-        self.assertRaises(RuntimeError, lambda: lib.non_existing(mps_tensor))
-        # Passing no tensors asserts
-        self.assertRaises(RuntimeError, lambda: lib.full(12))
-

 # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
 # This requires mps to be properly registered in the device generic test framework which is not the
@@ -1811,9 +1811,6 @@ PyObject* initModule() {
 #ifdef USE_CUDA
   torch::cuda::initModule(module);
 #endif
-#ifdef USE_MPS
-  torch::mps::initModule(module);
-#endif
 #ifdef USE_XPU
   torch::xpu::initModule(module);
 #endif
@@ -1,25 +1,16 @@
-#define PYBIND11_DETAILED_ERROR_MESSAGES
-
 #include <ATen/ATen.h>
 #include <c10/util/CallOnce.h>
-#include <pybind11/pytypes.h>
 #include <torch/csrc/Generator.h>
 #include <torch/csrc/THP.h>
 #include <torch/csrc/python_headers.h>
-#include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/python_numbers.h>
 #include <torch/csrc/utils/python_strings.h>
-#include <memory>

 // pthread.h is included for tracking bad forks
 #ifndef WIN32
 #include <pthread.h>
 #endif

-#ifdef USE_MPS
-#include <ATen/native/mps/MetalShaderLibrary.h>
-#endif
-
 namespace torch::mps {

 namespace {
@@ -287,224 +278,4 @@ PyMethodDef* python_functions() {
   return _MPSModule_methods;
 }

-#ifdef USE_MPS
-namespace {
-template <typename T = uint64_t>
-std::optional<std::vector<T>> optional_vec_from_pyobject(
-    const py::object& py_value) {
-  if (py_value.is_none()) {
-    return std::nullopt;
-  }
-  if (py::isinstance<py::int_>(py_value)) {
-    return std::vector({py_value.cast<T>()});
-  }
-  auto vec = py_value.cast<std::vector<T>>();
-  TORCH_CHECK(vec.size() > 0 && vec.size() < 4);
-  return vec;
-}
-
-struct OptionalArgCaster {
- public:
-  OptionalArgCaster(const py::object& arg) {
-    if (arg.is_none()) {
-    } else if (py::isinstance<py::str>(arg)) {
-      default_cast = arg.cast<std::string>();
-    } else if (py::isinstance<py::dict>(arg)) {
-      cast_map = arg.cast<std::unordered_map<unsigned, std::string>>();
-    } else {
-      TORCH_CHECK(
-          false,
-          "Unexpected caster arg type ",
-          arg.attr("__class__").attr("__name__").cast<const std::string>());
-    }
-  }
-  template <typename T>
-  void setValue(
-      ::at::native::mps::MetalKernelFunction& f,
-      unsigned idx,
-      const std::vector<T>& values) {
-    auto cast_str =
-        cast_map.find(idx) != cast_map.end() ? cast_map[idx] : default_cast;
-    if (cast_str.size() == 0) {
-      f.setArg(idx, values);
-    } else if (cast_str == "fp16") {
-      std::vector<c10::Half> cast_values(values.begin(), values.end());
-      f.setArg(idx, cast_values);
-    } else if (cast_str == "bf16") {
-      std::vector<c10::BFloat16> cast_values(values.begin(), values.end());
-      f.setArg(idx, cast_values);
-    } else if (cast_str == "int32") {
-      std::vector<int32_t> cast_values(values.begin(), values.end());
-      f.setArg(idx, cast_values);
-    } else if (cast_str == "int16") {
-      std::vector<int16_t> cast_values(values.begin(), values.end());
-      f.setArg(idx, cast_values);
-    } else if (cast_str == "int8") {
-      std::vector<int8_t> cast_values(values.begin(), values.end());
-      f.setArg(idx, cast_values);
-    } else if (cast_str == "uint8") {
-      std::vector<uint8_t> cast_values(values.begin(), values.end());
-      f.setArg(idx, cast_values);
-    } else {
-      TORCH_CHECK(false, "Unsupported cast instruction ", default_cast);
-    }
-  }
-
-  template <
-      typename T,
-      typename = std::enable_if_t<
-          std::is_same_v<float, T> || std::is_same_v<int64_t, T>>>
-  void setValue(
-      ::at::native::mps::MetalKernelFunction& f,
-      unsigned idx,
-      const T& value) {
-    auto cast_str =
-        cast_map.find(idx) != cast_map.end() ? cast_map[idx] : default_cast;
-    if (cast_str.size() == 0) {
-      f.setArg(idx, value);
-    } else if (cast_str == "fp16") {
-      f.setArg(idx, static_cast<c10::Half>(value));
-    } else if (cast_str == "bf16") {
-      f.setArg(idx, static_cast<c10::BFloat16>(value));
-    } else if (cast_str == "int32") {
-      f.setArg(idx, static_cast<int32_t>(value));
-    } else if (cast_str == "int16") {
-      f.setArg(idx, static_cast<int16_t>(value));
-    } else if (cast_str == "int8") {
-      f.setArg(idx, static_cast<int8_t>(value));
-    } else if (cast_str == "uint8") {
-      f.setArg(idx, static_cast<uint8_t>(value));
-    } else {
-      TORCH_CHECK(false, "Unsupported cast instruction ", default_cast);
-    }
-  }
-
-  void setValue(
-      ::at::native::mps::MetalKernelFunction& f,
-      unsigned idx,
-      const py::object& arg) {
-    if (py::isinstance<py::tuple>(arg) || py::isinstance<py::list>(arg)) {
-      auto len = arg.attr("__len__")().cast<uint64_t>();
-      TORCH_CHECK(
-          len > 0, "Empty list/tuple cannot be an argument to a metal kernel")
-      auto element = arg.attr("__getitem__")(0);
-      if (py::isinstance<py::int_>(element)) {
-        auto values = arg.cast<std::vector<int64_t>>();
-        setValue(f, idx, values);
-      } else if (py::isinstance<py::float_>(element)) {
-        auto values = arg.cast<std::vector<float>>();
-        setValue(f, idx, values);
-      } else {
-        TORCH_CHECK(false, "Unexpected argument types");
-      }
-    } else if (py::isinstance<py::float_>(arg)) {
-      auto value = arg.cast<float>();
-      setValue(f, idx, value);
-    } else if (py::isinstance<py::int_>(arg)) {
-      auto value = arg.cast<int64_t>();
-      setValue(f, idx, value);
-    } else {
-      TORCH_CHECK(false, "Unsupported argument type");
-    }
-  }
-
- private:
-  std::string default_cast;
-  std::unordered_map<unsigned, std::string> cast_map;
-};
-
-} // namespace
-
-void initModule(PyObject* module) {
-  using namespace at::native::mps;
-  auto m = py::handle(module).cast<py::module>();
-  py::class_<
-      DynamicMetalShaderLibrary,
-      std::shared_ptr<DynamicMetalShaderLibrary>>(m, "_mps_ShaderLibrary")
-      .def(
-          "__getattr__",
-          [](DynamicMetalShaderLibrary& self, const std::string& name) {
-            return self.getKernelFunction(name);
-          })
-      .def("__dir__", [](DynamicMetalShaderLibrary& self) {
-        return self.getFunctionNames();
-      });
-  py::class_<MetalKernelFunction, std::shared_ptr<MetalKernelFunction>>(
-      m, "_mps_MetalKernel")
-      .def(
-          "__call__",
-          [](MetalKernelFunction& self,
-             const py::args& args,
-             const py::object& py_threads,
-             const py::object& py_group_size,
-             const py::object& arg_casts) {
-            auto threads = optional_vec_from_pyobject(py_threads);
-            auto group_size = optional_vec_from_pyobject(py_group_size);
-            OptionalArgCaster caster(arg_casts);
-            self.runCommandBlock([&] {
-              self.startEncoding();
-              for (auto idx : c10::irange(args.size())) {
-                if (THPVariable_Check(args[idx].ptr())) {
-                  auto t = THPVariable_Unpack(args[idx].ptr());
-                  self.setArg(idx, t);
-                  if (!threads) {
-                    threads = {static_cast<uint64_t>(t.numel())};
-                  }
-                  continue;
-                }
-                caster.setValue(self, idx, args[idx]);
-              }
-              TORCH_CHECK(
-                  threads.has_value() && threads->size() < 4,
-                  "Number of threads is undefined or has wrong dimension");
-              TORCH_CHECK(
-                  !group_size.has_value() ||
-                  threads->size() == group_size->size());
-              if (threads->size() == 1) {
-                if (group_size.has_value()) {
-                  self.dispatch(threads->at(0), group_size->at(0));
-                } else {
-                  self.dispatch(threads->at(0));
-                }
-              } else if (threads->size() == 2) {
-                if (group_size.has_value()) {
-                  self.dispatch(
-                      {threads->at(0), threads->at(1)},
-                      {group_size->at(0), group_size->at(1)});
-                } else {
-                  self.dispatch({threads->at(0), threads->at(1)});
-                }
-              } else {
-                if (group_size.has_value()) {
-                  self.dispatch(
-                      {threads->at(0), threads->at(1), threads->at(2)},
-                      {group_size->at(0),
-                       group_size->at(1),
-                       group_size->at(2)});
-                } else {
-                  self.dispatch(
-                      {threads->at(0), threads->at(1), threads->at(2)});
-                }
-              }
-            });
-          },
-          py::kw_only(),
-          py::arg("threads") = py::none(),
-          py::arg("group_size") = py::none(),
-          py::arg("arg_casts") = py::none())
-      .def_property_readonly(
-          "max_threads_per_threadgroup",
-          &MetalKernelFunction::getMaxThreadsPerThreadgroup)
-      .def_property_readonly(
-          "thread_execution_width",
-          &MetalKernelFunction::getThreadExecutionWidth)
-      .def_property_readonly(
-          "static_thread_group_memory_length",
-          &MetalKernelFunction::getStaticThreadGroupMemoryLength);
-  m.def("_mps_compileShader", [](const std::string& source) {
-    return std::make_shared<DynamicMetalShaderLibrary>(source);
-  });
-}
-#endif /* USE_MPS */
-
 } // namespace torch::mps
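For reference, a hedged sketch of the `arg_casts` semantics that OptionalArgCaster implements above: a bare string applies one cast to every non-tensor argument, while a dict maps argument index to a cast name ("fp16", "bf16", "int32", "int16", "int8" or "uint8"); tensor arguments bypass the caster entirely. It assumes a build where the reverted binding is present.

import torch

# Hedged usage sketch, mirroring the arg_casts tests earlier on this page.
if torch.backends.mps.is_available() and hasattr(torch.mps, "_compile_shader"):
    lib = torch.mps._compile_shader("""
        kernel void scale(device half* x, constant half& alpha,
                          uint idx [[thread_position_in_grid]]) {
            x[idx] *= alpha;
        }
    """)
    x = torch.ones(4, device="mps", dtype=torch.half)
    lib.scale(x, 2.0, arg_casts="fp16")       # cast every scalar arg to half
    lib.scale(x, 2.0, arg_casts={1: "fp16"})  # cast only argument index 1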
@@ -5,6 +5,5 @@
 namespace torch::mps {

 PyMethodDef* python_functions();
-void initModule(PyObject* module);

 } // namespace torch::mps
@@ -140,22 +140,6 @@ def recommended_max_memory() -> int:
     return torch._C._mps_recommendedMaxMemory()


-def _compile_shader(source: str):
-    r"""Compiles a compute shader from source and allows one to invoke kernels
-    defined there from the Python runtime.
-    Example::
-
-        >>> lib = torch.mps._compile_shader(
-        ...     "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
-        ... )
-        >>> x = torch.zeros(16, device="mps")
-        >>> lib.full(x, 3.14)
-    """
-    if not hasattr(torch._C, "_mps_compileShader"):
-        raise RuntimeError("MPS is not available")
-    return torch._C._mps_compileShader(source)
-
-
 def is_available() -> bool:
     return device_count() > 0

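Given the stated revert reason, the docstring example executing on a Linux host without MPS, a guarded variant of that example (an illustration, not code from this diff) would be a no-op when no MPS device is available:

import torch

# Illustration: same example as the removed docstring, but gated on device
# availability so a doctest runner on a non-MPS host skips the MPS calls.
if torch.backends.mps.is_available():
    lib = torch.mps._compile_shader(
        "kernel void full(device float* out, constant float& val,"
        " uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
    )
    x = torch.zeros(16, device="mps")
    lib.full(x, 3.14)  # every element of x becomes 3.14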