[MPS] Remove all pre-MacOS14 logic (#159912)

Delete the obsolete enum values, the macOS 13.3+ checks for int64 support, etc.

Fixes https://github.com/pytorch/pytorch/issues/159275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159912
Approved by: https://github.com/manuelcandales
Nikita Shulga
2025-08-05 22:27:30 -07:00
committed by PyTorch MergeBot
parent c71950907d
commit d10e9e4781
20 changed files with 42 additions and 393 deletions

View File

@ -43,7 +43,6 @@ TensorBase empty_mps(
int64_t nelements = c10::multiply_integers(size);
auto dtype = dtype_or_default(dtype_opt);
TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED);
TORCH_CHECK_TYPE(dtype != ScalarType::BFloat16 || is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_14_0_PLUS), "MPS BFloat16 is only supported on MacOS 14 or newer");
auto dtype_meta = scalarTypeToTypeMeta(dtype);
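Note: with the macOS-14 baseline, only the float64 rejection remains here; bfloat16 allocation needs no OS gate. A hedged usage sketch (assumes a libtorch build with the MPS backend):

auto t = torch::empty({2, 3}, torch::device(torch::kMPS).dtype(torch::kBFloat16));
// No version probe: only ScalarType::Double still trips TORCH_CHECK_TYPE above.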

View File

@ -18,11 +18,7 @@ namespace at::mps {
// Helper enum to check if an MPSGraph op is supported in a given macOS version
enum class MacOSVersion : uint32_t {
MACOS_VER_13_1_PLUS = 0,
MACOS_VER_13_2_PLUS,
MACOS_VER_13_3_PLUS,
MACOS_VER_14_0_PLUS,
MACOS_VER_14_4_PLUS,
MACOS_VER_14_4_PLUS = 0,
MACOS_VER_15_0_PLUS,
MACOS_VER_15_1_PLUS,
MACOS_VER_15_2_PLUS,
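Because the underlying values run sequentially from the first enumerator, dropping the pre-14.4 entries means re-pinning the first survivor at 0. A minimal self-contained sketch of the resulting enum (the static_assert is illustrative, not from the diff):

#include <cstdint>
enum class MacOSVersion : uint32_t {
  MACOS_VER_14_4_PLUS = 0,
  MACOS_VER_15_0_PLUS,
  MACOS_VER_15_1_PLUS,
  MACOS_VER_15_2_PLUS,
};
static_assert(static_cast<uint32_t>(MacOSVersion::MACOS_VER_15_2_PLUS) == 3,
              "values stay dense and zero-based after the removal");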

View File

@ -32,11 +32,11 @@ MPSDevice::~MPSDevice() {
MPSDevice::MPSDevice() : _mtl_device(nil) {
// Check that MacOS 14.0+ version of MPS framework is available
// Create the MPSGraph and check method introduced in 13.0
// Create the MPSGraph and check method introduced in 14.0
// which is used by MPS backend.
id mpsCD = NSClassFromString(@"MPSGraph");
if ([mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == NO) {
if ([mpsCD instancesRespondToSelector:@selector(HermiteanToRealFFTWithTensor:axes:descriptor:name:)] == NO) {
return;
}
@ -66,24 +66,12 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
isOperatingSystemAtLeastVersion:{.majorVersion = major, .minorVersion = minor, .patchVersion = 0}];
}
};
static bool _macos_13_1_plus = is_os_version_at_least(13, 1);
static bool _macos_13_2_plus = is_os_version_at_least(13, 2);
static bool _macos_13_3_plus = is_os_version_at_least(13, 3);
static bool _macos_14_0_plus = is_os_version_at_least(14, 0);
static bool _macos_14_4_plus = is_os_version_at_least(14, 4);
static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
static bool _macos_15_1_plus = is_os_version_at_least(15, 1);
static bool _macos_15_2_plus = is_os_version_at_least(15, 2);
switch (version) {
case MacOSVersion::MACOS_VER_13_1_PLUS:
return _macos_13_1_plus;
case MacOSVersion::MACOS_VER_13_2_PLUS:
return _macos_13_2_plus;
case MacOSVersion::MACOS_VER_13_3_PLUS:
return _macos_13_3_plus;
case MacOSVersion::MACOS_VER_14_0_PLUS:
return _macos_14_0_plus;
case MacOSVersion::MACOS_VER_14_4_PLUS:
return _macos_14_4_plus;
case MacOSVersion::MACOS_VER_15_0_PLUS:
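The constructor asks the Objective-C runtime whether MPSGraph instances implement a selector that first shipped with the required OS, rather than parsing a version string. A hedged sketch of the pattern (helper name hypothetical; the class and selector are the ones in the diff):

static bool graphHasMacOS14API() {
  id mpsCD = NSClassFromString(@"MPSGraph");
  // The Hermitean-to-real FFT op arrived with the macOS 14 MPSGraph, so its
  // presence implies the new baseline.
  return [mpsCD instancesRespondToSelector:@selector(HermiteanToRealFFTWithTensor:axes:descriptor:name:)];
}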

View File

@ -34,7 +34,7 @@ bool MPSHooks::isOnMacOSorNewer(unsigned major, unsigned minor) const {
case 14:
switch (minor) {
case 0:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
return true;
case 4:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
default:
@ -42,19 +42,7 @@ bool MPSHooks::isOnMacOSorNewer(unsigned major, unsigned minor) const {
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
}
case 13:
switch (minor) {
case 0:
return true;
case 1:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS);
case 2:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
case 3:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
default:
TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+");
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
}
return true;
default:
TORCH_WARN("Checking for unexpected MacOS ", major, ".", minor, " returning false");
return false;
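With 14.0 as the floor, every 13.x query and the 14.0 query collapse to constants while the finer-grained checks stay dynamic. A hypothetical condensation of the switch above (sketch only; it omits the warning in the default branch):

static bool versionAtLeast(unsigned major, unsigned minor) {
  if (major == 13 || (major == 14 && minor == 0)) {
    return true;  // always satisfied once the backend requires macOS 14
  }
  return false;  // placeholder: 14.4 and 15.x keep their runtime checks
}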

View File

@ -88,14 +88,8 @@ std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64 = false);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);
MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray);
MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
@ -435,14 +429,6 @@ inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(M
// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
TORCH_WARN_ONCE( \
"MPS: no support for int64 for ", \
op_name, \
", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
}
/**
* Returns distance from lowest to highest element offset in given tensor.
*/
@ -618,10 +604,6 @@ inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds,
runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
}
inline bool supportsComplex() {
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
}
// MPS does not yet support double types, but starting from MacOS 14 it supports bfloat16
inline bool supportedFloatingType(ScalarType dtype) {
return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
@ -633,7 +615,7 @@ inline bool supportedFloatingType(const TensorBase& t) {
inline bool supportedFloatingOrComplexType(ScalarType dtype) {
if (dtype == kComplexFloat || dtype == kComplexHalf) {
return supportsComplex();
return true;
}
return supportedFloatingType(dtype);
}
@ -641,11 +623,6 @@ inline bool supportedFloatingOrComplexType(const TensorBase& t) {
return supportedFloatingOrComplexType(t.scalar_type());
}
inline void checkSupportsBFloat16() {
TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
"MPS bfloat16 type is supported on MacOS 14.0 or newer.");
}
inline bool needsGather(const TensorBase& t) {
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
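Since complex support is unconditional on the new baseline, the dtype gates reduce to pure type checks with no OS probe. A self-contained sketch of the simplified helpers:

#include <c10/core/ScalarType.h>
inline bool supportedFloatingType(c10::ScalarType dtype) {
  return dtype == c10::kFloat || dtype == c10::kHalf || dtype == c10::kBFloat16;
}
inline bool supportedFloatingOrComplexType(c10::ScalarType dtype) {
  if (dtype == c10::kComplexFloat || dtype == c10::kComplexHalf) {
    return true;  // no supportsComplex() probe left to consult
  }
  return supportedFloatingType(dtype);
}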

View File

@ -89,10 +89,6 @@ void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds,
mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE);
}
static inline void checkSupportsComplex() {
TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer.");
}
MPSDataType getMPSDataType(ScalarType scalar_type) {
switch (scalar_type) {
case ScalarType::Float:
@ -100,7 +96,6 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
case ScalarType::Half:
return MPSDataTypeFloat16;
case ScalarType::BFloat16:
checkSupportsBFloat16();
return MPSDataTypeBFloat16;
case ScalarType::Int:
return MPSDataTypeInt32;
@ -119,10 +114,8 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
"Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
"Please use float32 instead.")
case ScalarType::ComplexHalf:
checkSupportsComplex();
return MPSDataTypeComplexFloat16;
case ScalarType::ComplexFloat:
checkSupportsComplex();
return MPSDataTypeComplexFloat32;
// Unsigned types
case ScalarType::UInt64:
@ -140,16 +133,10 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
// #issue 104398441 sortWithTensor and argsortWithTensor have support for
// Int32, Int64, Half and Float32 types. These utilities help cast to these
// types.
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64) {
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) {
MPSDataType dataType = getMPSDataType(input.scalar_type());
bool condition =
(dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
if (includesInt64) {
condition = condition && (dataType != MPSDataTypeInt64);
}
bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) &&
(dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64);
if (condition) {
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
@ -160,16 +147,10 @@ MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
// #issue 104398441 sortWithTensor and argsortWithTensor have support for
// Int32, Int64, Half and Float32 types. These utilities help cast from these
// types.
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64) {
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) {
MPSDataType dataType = getMPSDataType(input.scalar_type());
bool condition =
(dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
if (includesInt64) {
condition = condition && (dataType != MPSDataTypeInt64);
}
bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) &&
(dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64);
if (condition) {
inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
}
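Both helpers now share one cast rule: with int64 native, only types outside {Int32, Int64, Float16, Float32} are folded, float-tagged ones to Float32 and the rest to Int32. A hedged condensation (helper name hypothetical):

static MPSDataType foldToIHFType(MPSDataType dataType) {
  const bool passesThrough = dataType == MPSDataTypeInt32 || dataType == MPSDataTypeInt64 ||
      dataType == MPSDataTypeFloat16 || dataType == MPSDataTypeFloat32;
  if (passesThrough) {
    return dataType;  // includes int64, which previously needed macOS 13.3+
  }
  // e.g. BFloat16 carries MPSDataTypeFloatBit -> Float32; Int16 or Bool -> Int32
  return (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
}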
@ -186,7 +167,6 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
case ScalarType::Half:
return MPSDataTypeFloat16;
case ScalarType::BFloat16:
checkSupportsBFloat16();
return MPSDataTypeBFloat16;
case ScalarType::Int:
return MPSDataTypeInt32;
@ -201,13 +181,11 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
case ScalarType::Bool:
return MPSDataTypeBool;
case ScalarType::ComplexHalf:
checkSupportsComplex();
return MPSDataTypeComplexFloat16;
// This is an intentional fallthrough supporting ComplexDouble for Scalar
// types as they are cast to Complex64 currently.
case ScalarType::ComplexDouble:
case ScalarType::ComplexFloat:
checkSupportsComplex();
return MPSDataTypeComplexFloat32;
// Unsigned types
case ScalarType::UInt64:
@ -267,7 +245,6 @@ std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) {
case ScalarType::Half:
return "half";
case ScalarType::BFloat16:
checkSupportsBFloat16();
return "bfloat";
case ScalarType::Int:
return "int";
@ -879,9 +856,7 @@ id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
MTLCompileOptions* options = compile_options;
if (!options) {
options = [[MTLCompileOptions new] autorelease];
// Need 3.0 for atomic operations, 3.1 introduces bfloat support
[options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
: MTLLanguageVersion3_0];
[options setLanguageVersion:MTLLanguageVersion3_1];
if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) {
options.mathMode = fast_math ? MTLMathModeFast : MTLMathModeSafe;
options.mathFloatingPointFunctions =
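Metal 3.1, the first language version with bfloat, becomes the unconditional floor; only the macOS 15 math-mode API keeps a runtime gate. A hedged sketch of the option setup around a standard compile call (device, src, and fast_math are assumed to be in scope):

MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion3_1];  // bfloat needs >= 3.1
if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) {
  options.mathMode = fast_math ? MTLMathModeFast : MTLMathModeSafe;
}
NSError* error = nil;
id<MTLLibrary> library = [device newLibraryWithSource:src options:options error:&error];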

View File

@ -48,28 +48,11 @@ typedef MPSGraphTensor* (^BinaryOpBlock)(BinaryOpCachedGraph*, MPSGraphTensor*,
#define BinaryOpFn(graph, primary, secondary) \
MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary)
static inline Tensor legacy_complex_as_view(const Tensor& t) {
// Convert non-complex types (and cdouble CPU scalars) to cfloat
if (!isComplexType(t.scalar_type()) || t.scalar_type() == kComplexDouble) {
return at::view_as_real(t.to(kMPS, kComplexFloat));
}
return at::view_as_real(t.dim() != 0 ? t : t.to(kMPS));
}
static void binaryOpTensor(const Tensor& self,
const Tensor& other,
const Tensor& output_,
std::string op_name,
BinaryOpBlock binaryBlock) {
TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) &&
(self.scalar_type() == ScalarType::Long ||
(other.scalar_type() == ScalarType::Long &&
(self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))),
"MPS: ",
op_name,
" op with int64 input is supported natively starting from macOS 13.2");
TORCH_CHECK_TYPE(!isComplexType(self.scalar_type()) || mps::supportsComplex(),
"Complex types are supported starting from MacOS 14.0+");
MPSStream* mpsStream = getCurrentMPSStream();
const bool is_self_scalar = self.dim() == 0;

View File

@ -51,9 +51,6 @@ inline void dot_check(const Tensor& self, const Tensor& other) {
} // namespace mps
Tensor dot_mps(const Tensor& self, const Tensor& other) {
TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || self.scalar_type() != ScalarType::Long,
"MPS: dot op doesn't support int64 input on MacOS13")
using namespace mps;
using CachedGraph = MPSBinaryCachedGraph;

View File

@ -124,7 +124,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
IntArrayRef dilation,
int64_t groups,
std::optional<IntArrayRef> input_shape) {
const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
Tensor input_t = input_t_;
bool is3DConv = input_t.dim() == 5;
@ -132,9 +131,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
input_t = input_t.contiguous();
}
TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer),
"Conv3D is only supported on MPS for MacOS_13_2 or newer");
TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");
using namespace at::native::mps;
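A hedged usage sketch of the now-unconditional Conv3D path (libtorch with the MPS backend assumed):

auto x = torch::rand({1, 3, 8, 16, 16}, torch::device(torch::kMPS));  // 5-D input -> is3DConv
auto w = torch::rand({4, 3, 3, 3, 3}, torch::device(torch::kMPS));
auto y = torch::conv3d(x, w);  // no MacOS_13_2 gate left, only the floating-type check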

View File

@ -60,7 +60,6 @@ static void copy_cast_mps(at::Tensor& dst,
outputTensor = [mpsGraph castTensor:outputTensor toType:dstDType name:@"cast"];
}
if (needs_conj) {
TORCH_CHECK(supportsComplex(), "MPS complex tensors conjugation needs MacOS14+");
outputTensor = [mpsGraph conjugateWithTensor:outputTensor name:nil];
}
@ -275,24 +274,7 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
// for GPU to GPU copies we only encode to stream's command buffer (no flushing)
stream->copy(sourceBuffer, destBuffer, src.nbytes(), src_byte_offset, dst_byte_offset, profile_id);
} else {
// Simulate cast to Complex on older MacOS by initializing real and imag parts
if (dst_.is_complex() && !supportsComplex()) {
if (!src.is_complex()) {
at::real(dst_).copy_(src);
at::imag(dst_).fill_(0);
} else if (src.is_conj() || dst_.is_conj()) {
// One cannot take view of conjugated tensor, but for some reason real and imag views are fine
// Use this to implement a conjugation
at::real(dst_).copy_(at::real(src));
if (src.is_conj() != dst_.is_conj()) {
at::imag(dst_).copy_(at::neg(at::imag(src)));
} else {
at::imag(dst_).copy_(at::imag(src));
}
} else {
at::view_as_real(dst_).copy_(at::view_as_real(src));
}
} else if (dst_byte_offset) {
if (dst_byte_offset) {
auto maybeCastedSource =
at::empty(dst_.sizes(), dst_.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource);

View File

@ -87,7 +87,6 @@ Tensor& random_mps_impl(Tensor& self,
case kFloat:
return MPSDataTypeFloat32;
case kBFloat16: {
checkSupportsBFloat16();
return MPSDataTypeBFloat16;
}
default:

View File

@ -88,7 +88,6 @@ using namespace mps;
// TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237
Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) {
TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" +
std::to_string(normalization) + ":" + std::to_string(onesided);
@autoreleasepool {
@ -129,7 +128,6 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
int64_t normalization,
int64_t last_dim_size,
Tensor& out) {
TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
std::to_string(normalization) + ":" + std::to_string(last_dim_size);
@autoreleasepool {
@ -155,7 +153,6 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
}
Tensor& _fft_c2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) {
TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
std::to_string(normalization) + ":" + std::to_string(forward);
@autoreleasepool {
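A hedged usage sketch: the FFT entry points now assume complex support instead of checking for it (libtorch with the MPS backend assumed):

auto signal = torch::rand({64}, torch::device(torch::kMPS));
auto spectrum = torch::fft::rfft(signal);                // _fft_r2c_mps_out, no TORCH_CHECK
auto roundtrip = torch::fft::irfft(spectrum, /*n=*/64);  // _fft_c2r_mps_out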

View File

@ -127,15 +127,6 @@ Tensor grid_sampler_2d_mps(const Tensor& input,
int64_t interpolation_mode,
int64_t padding_mode,
bool align_corners) {
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS)) {
TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.2. ",
"Falling back on CPU. This may have performance implications.");
return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners)
.clone()
.to("mps");
}
auto in_size = input.sizes();
auto grid_size = grid.sizes();
auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());

View File

@ -353,14 +353,7 @@ static Tensor& nonzero_out_native_mps(const Tensor& self, Tensor& out_) {
}
Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ",
"Falling back on CPU. This may have performance implications.");
Tensor out_fallback = nonzero_fallback(self);
at::native::resize_output(out_, out_fallback.sizes());
out_.copy_(out_fallback);
return out_;
} else if (self.is_complex()) {
if (self.is_complex()) {
TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes. ",
"Falling back on CPU. This may have performance implications.");
Tensor out_fallback = nonzero_fallback(self);
@ -445,11 +438,7 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
}
Tensor nonzero_mps(const Tensor& self) {
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ",
"Falling back on CPU. This may have performance implications.");
return nonzero_fallback(self);
} else if (self.is_complex()) {
if (self.is_complex()) {
TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes ",
"Falling back on CPU. This may have performance implications.");
return nonzero_fallback(self);
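A hedged usage sketch of the remaining behavior, native everywhere with a CPU fallback only for complex inputs:

auto t = torch::tensor({0, 1, 0, 2}, torch::device(torch::kMPS));
auto idx = torch::nonzero(t);  // native MPS kernel, no OS check; yields [[1], [3]]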

View File

@ -152,8 +152,6 @@ static void reduction_out_mps(const Tensor& input_t,
const Tensor& output_t,
MPSReductionType reduction_type,
const std::string& func_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name);
// NS: TODO: get rid of all those shenanigans and just call reduction_op with view tensor
bool canSqueezeLastDim = true;
IntArrayRef input_shape = input_t.sizes();
@ -236,12 +234,10 @@ static void reduction_out_mps(const Tensor& input_t,
MPSGraphTensor* castInputTensor = inputTensor;
MPSDataType inputCastType = MPSDataTypeInvalid;
if (dtype.has_value() &&
(dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt ||
(dtype.value() == kLong && macOS13_3_plus))) {
(dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || dtype.value() == kLong)) {
inputCastType = getMPSDataType(dtype.value());
} else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
inputScalarType != kComplexFloat && inputScalarType != kComplexHalf &&
(inputScalarType != kLong || !macOS13_3_plus)) {
inputScalarType != kComplexFloat && inputScalarType != kComplexHalf && inputScalarType != kLong) {
inputCastType = getMPSDataType(kFloat);
}
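A hedged condensation of the two branches above: kLong joins the pass-through set, so only the remaining narrow or exotic dtypes are promoted to float (helper name hypothetical):

#include <c10/core/ScalarType.h>
static bool reductionKeepsType(c10::ScalarType t) {
  return t == c10::kInt || t == c10::kHalf || t == c10::kFloat || t == c10::kLong ||
      t == c10::kComplexFloat || t == c10::kComplexHalf;
}
// e.g. kShort or kBool inputs still get inputCastType = getMPSDataType(kFloat)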
@ -615,9 +611,6 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
}
static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, nanmedian ? "nanmedian" : "median");
IntArrayRef input_shape = input_t.sizes();
int64_t num_in_elements = c10::multiply_integers(input_shape);
@ -634,8 +627,7 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
auto medianCachedGraph =
LookUpOrCreateCachedGraph<MedianCachedGraph>(medianKey, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
MPSGraphTensor* reshapedTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil];
@ -693,9 +685,6 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
}
static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction_type, const std::string& func_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max");
using CachedGraph = MPSUnaryCachedGraph;
IntArrayRef input_shape = input_t.sizes();
@ -713,8 +702,7 @@ static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* castOutputTensor = nil;
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
NSArray<NSNumber*>* axes = getTensorAxes(input_t);
if (reduction_type == MPSReductionType::MAX) {
@ -749,9 +737,6 @@ static void min_max_out_mps(const Tensor& input_t,
const Tensor& indices_t,
MPSReductionType reduction_type,
const std::string& func_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max_out");
if (output_t.numel() == 0) {
return;
}
@ -789,8 +774,7 @@ static void min_max_out_mps(const Tensor& input_t,
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
if (reduction_type == MPSReductionType::MAX) {
outputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
@ -896,9 +880,6 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
const std::string& func_name) {
using CachedGraph = MPSUnaryCachedGraph;
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "argmax_argmin_out");
int64_t dim_ = -1;
if (dim.has_value()) {
@ -953,7 +934,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
MPSGraphTensor* castInputTensor = inputTensor;
if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
(inputScalarType != kLong || !macOS13_3_plus)) {
inputScalarType != kLong) {
castInputTensor = castMPSTensor(mpsGraph, inputTensor, kFloat);
}
if (reduction_type == MPSReductionType::MAX) {
@ -1282,9 +1263,6 @@ static void all_any_common_impl_mps(const Tensor& input_t,
return;
}
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, op_name);
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
native::zero_numel_check_dims(input_t, dim_, op_name.c_str());
@ -1303,7 +1281,7 @@ static void all_any_common_impl_mps(const Tensor& input_t,
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
// reductionOrWithTensor:axis: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
MPSGraphTensor* outputTensor = nil;
@ -1369,14 +1347,11 @@ TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
return;
}
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_all_out");
@autoreleasepool {
std::string key = std::string("any_all_out_mps:") + getTensorsStringKey(input_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
// reductionOrWithTensor:axes: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
if (input_t.dim() > 4) {
@ -1420,14 +1395,11 @@ TORCH_IMPL_FUNC(all_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
return;
}
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_all_out");
@autoreleasepool {
std::string key = std::string("all_all_out_mps:") + getTensorsStringKey(input_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
// reductionAndWithTensor:axes: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
if (input_t.ndimension() > 4) {
@ -1512,9 +1484,6 @@ static void median_out_mps_common(const Tensor& input_t,
Tensor& indices,
const std::string& func_name,
bool nanmedian) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
native::zero_numel_check_dims(input_t, dim_, "max()");
@ -1585,8 +1554,7 @@ static void median_out_mps_common(const Tensor& input_t,
getTensorsStringKey(indices);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
MPSGraphTensor* effectiveLengthTensor = nil;
if (nanmedian) {

View File

@ -129,16 +129,8 @@ void computeRepeatIndices(const index_t* repeat_ptr,
});
}
Tensor repeat_interleave_mps(const Tensor& repeat_, std::optional<int64_t> output_size) {
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
Tensor output;
Tensor repeat = repeat_;
if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
// #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output,
// which currently doesn't support int64_t as input. Casting internally the indices to int32_t.
TORCH_WARN_ONCE(
"MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
repeat = repeat.to(kInt);
}
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
});
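A hedged usage sketch: int64 repeat counts now reach AT_DISPATCH_INDEX_TYPES unconverted instead of being downcast to int32 first:

auto repeats = torch::tensor({1, 2, 3}, torch::device(torch::kMPS).dtype(torch::kLong));
auto idx = torch::repeat_interleave(repeats);  // [0, 1, 1, 2, 2, 2], computed in int64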

View File

@ -23,125 +23,6 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/ScanKernel_metallib.h>
#endif
// Generic scan implementation that handles both simple scans and scans with indices
static void scan_mps_impl(const Tensor& self,
const std::vector<Tensor>& outputs,
int64_t dim,
const std::string& op_name) {
if (outputs[0].numel() == 0) {
return;
}
const int64_t ndim = self.dim();
const int64_t wrapped_dim = maybe_wrap_dim(dim, ndim);
// Calculate dimensions for scan operation
int64_t row_size = self.size(wrapped_dim);
auto sizes = self.sizes();
bool is_innermost = (wrapped_dim == ndim - 1);
// Check if all tensors are contiguous
bool is_contiguous = self.is_contiguous();
for (const auto& output : outputs) {
is_contiguous = is_contiguous && output.is_contiguous();
}
uint32_t num_rows, num_orows, num_irows, num_threads;
if (is_innermost) {
// Treat all outer dimensions as a single dimension
num_rows = self.numel() / row_size;
num_threads = num_rows;
} else {
// Treat all outer dimensions (i.e. dim_ < dim) as one
num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + wrapped_dim);
// Treat all inner dimensions (i.e. dim > dimension) as one
num_irows = c10::multiply_integers(sizes.begin() + wrapped_dim + 1, sizes.end());
num_threads = num_orows * num_irows;
}
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
// Choose kernel based on contiguity and dimension
std::string kernel_name;
if (is_contiguous) {
kernel_name =
op_name + "_contiguous_" + (is_innermost ? "innermost_" : "outer_") + scalarToMetalTypeString(self);
} else {
kernel_name = op_name + "_strided_" + scalarToMetalTypeString(self);
}
id<MTLComputePipelineState> scanPSO = lib.getPipelineStateForFunc(kernel_name);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(scanPSO, op_name, [&]() {
std::vector<Tensor> all_tensors = {self};
all_tensors.insert(all_tensors.end(), outputs.begin(), outputs.end());
return all_tensors;
}());
[computeEncoder setComputePipelineState:scanPSO];
// Set input tensor
mtl_setBuffer(computeEncoder, self, 0);
// Set output tensors
for (size_t i = 0; i < outputs.size(); ++i) {
mtl_setBuffer(computeEncoder, outputs[i], i + 1);
}
if (is_contiguous) {
// Contiguous kernels
if (is_innermost) {
if (outputs.size() == 1) {
// Simple scan
mtl_setArgs<2>(computeEncoder, num_rows, static_cast<uint32_t>(row_size));
} else {
// Scan with indices
mtl_setArgs<3>(computeEncoder, num_rows, static_cast<uint32_t>(row_size));
}
} else {
if (outputs.size() == 1) {
// Simple scan
mtl_setArgs<2>(computeEncoder, num_orows, num_irows, static_cast<uint32_t>(row_size));
} else {
// Scan with indices
mtl_setArgs<3>(computeEncoder, num_orows, num_irows, static_cast<uint32_t>(row_size));
}
}
} else {
// Strided kernels - pass full tensor information
if (outputs.size() == 1) {
// Simple scan
mtl_setArgs<2>(computeEncoder,
self.sizes(),
self.strides(),
outputs[0].strides(),
static_cast<uint32_t>(self.ndimension()),
static_cast<uint32_t>(wrapped_dim));
} else {
// Scan with indices
mtl_setArgs<3>(computeEncoder,
self.sizes(),
self.strides(),
outputs[0].strides(),
outputs[1].strides(),
static_cast<uint32_t>(self.ndimension()),
static_cast<uint32_t>(wrapped_dim));
}
}
mtl_dispatch1DJob(computeEncoder, scanPSO, num_threads);
getMPSProfiler().endProfileKernel(scanPSO);
}
});
}
// Utility function to get 2D grid dimensions for dispatch
static std::pair<uint32_t, uint32_t> get_2d_grid_dims(const IntArrayRef& shape, const int64_t dim) {
size_t grid_x = 1;
@ -375,19 +256,11 @@ static void scan_with_indices_mps_impl(const Tensor& self,
} // namespace mps
void cummax_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax");
} else {
mps::scan_mps_impl(self, {values, indices}, dim, "cummax");
}
mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax");
}
void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin");
} else {
mps::scan_mps_impl(self, {values, indices}, dim, "cummin");
}
mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin");
}
Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) {
@ -402,11 +275,7 @@ Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) {
return result;
}
if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp");
} else {
mps::scan_mps_impl(self, {result}, wrap_dim, "logcumsumexp");
}
mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp");
return result;
}
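A hedged usage sketch: the cumulative ops now always land on the Metal scan kernels, with no MPSGraph fallback branch (libtorch with the MPS backend assumed):

auto t = torch::rand({4, 8}, torch::device(torch::kMPS));
auto [values, indices] = torch::cummax(t, /*dim=*/1);  // scan_with_indices_mps_impl
auto lse = torch::logcumsumexp(t, /*dim=*/1);          // scan_simple_mps_impl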

View File

@ -26,9 +26,6 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
const Tensor& indices) {
using namespace mps;
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "sort_stable_out");
if (self.numel() == 0) {
return;
}
@ -55,8 +52,7 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self);
MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
axis:(NSInteger)dim
descending:(BOOL)descending

View File

@ -297,9 +297,6 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,
const auto common_type = at::result_type(elements, test_elements);
TORCH_CHECK(elements.is_mps() && test_elements.is_mps());
TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || supportedFloatingType(common_type),
"isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: ",
common_type);
@autoreleasepool {
std::string key = op_name + getTensorsStringKey({elements, test_elements}) + std::to_string(invert);

View File

@ -208,28 +208,12 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
}
Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
if (mps::supportsComplex()) {
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
});
return output;
} else {
TORCH_CHECK(!self.is_complex(), "MPS does not support angle with complex input on macOS13")
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
// On macOS 13 with non-complex input, realPartOfTensor and imaginaryPartOfTensor are
// not available, and NaN is not propagated correctly:
auto imagPart = [mpsGraph constantWithScalar:0.0 shape:inputTensor.shape dataType:inputTensor.dataType];
auto result = [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:inputTensor name:nil];
auto nanMask = [mpsGraph isNaNWithTensor:inputTensor name:nil];
return [mpsGraph selectWithPredicateTensor:nanMask
truePredicateTensor:inputTensor
falsePredicateTensor:result
name:nil];
});
return output;
}
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
});
return output;
}
Tensor angle_mps(const Tensor& self) {
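The unified graph computes angle(z) = atan2(Im z, Re z); for real input the imaginary part is zero, giving 0 for positive values and pi for negative ones, and NaNs propagate natively on macOS 14+. A self-contained numeric check of that identity:

#include <cmath>
#include <cstdio>
int main() {
  std::printf("%f\n", std::atan2(4.0, 3.0));   // angle(3+4i)  ~ 0.927295
  std::printf("%f\n", std::atan2(0.0, -2.0));  // angle(-2+0i) =  3.141593
}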
@ -362,7 +346,6 @@ static void cumulative_op_impl(const Tensor& self,
const Tensor& result,
MPSCumulativeOpType cumulativeOpType,
const std::string& op_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
auto nDims = self.dim();
auto wrapped_dim = maybe_wrap_dim(dim, nDims);
TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()),
@ -381,11 +364,6 @@ static void cumulative_op_impl(const Tensor& self,
bool castInputData = (isIntegralType(input.scalar_type(), true) && input.scalar_type() != ScalarType::Int &&
input.scalar_type() != ScalarType::Long);
TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
"MPS does not support ",
op_name,
" op with int64 input. Support has been added in macOS 13.3");
mps::unary_op(
input, result, op_name + std::to_string(dim), ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
if (castInputData) {
@ -440,17 +418,9 @@ TORCH_IMPL_FUNC(sgn_out_mps)(const Tensor& self, const Tensor& output) {
Tensor& conj_physical_out_mps(const Tensor& self, Tensor& result) {
TORCH_CHECK(self.is_complex());
if (!mps::supportsComplex()) {
if (!result.is_same_size(self)) {
result.resize_(self.sizes());
}
at::real(result).copy_(at::real(self));
at::imag(result).copy_(at::neg(at::imag(self)));
} else {
mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
return [mpsGraph conjugateWithTensor:inputTensor name:nil];
});
}
mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
return [mpsGraph conjugateWithTensor:inputTensor name:nil];
});
return result;
}