Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[MPS] Remove all pre-MacOS14 logic (#159912)
Delete older enums, checks for MacOS-13.3+ for int64 support, etc.

Fixes: https://github.com/pytorch/pytorch/issues/159275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159912
Approved by: https://github.com/manuelcandales
Committed by: PyTorch MergeBot
Parent: c71950907d
Commit: d10e9e4781
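For context, a minimal sketch of the pattern this commit deletes, assembled from the hunks below (the comments are explanatory assumptions, not code from the PR):

// Sketch only, assembled from the hunks below; not the full ATen sources.
// With macOS 14 as the new floor, MacOSVersion drops every pre-14.4 value:
enum class MacOSVersion : uint32_t {
  MACOS_VER_14_4_PLUS = 0,
  MACOS_VER_15_0_PLUS,
  MACOS_VER_15_1_PLUS,
  MACOS_VER_15_2_PLUS,
};
// Consequently a gate such as
//   is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)
// is always true, so call sites delete the check outright: bfloat16 and
// complex support become unconditional, and the int64 downcasting macro
// MPS_CHECK_INT64_OP_SUPPORTED is removed along with all of its callers.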
@@ -43,7 +43,6 @@ TensorBase empty_mps(
   int64_t nelements = c10::multiply_integers(size);
   auto dtype = dtype_or_default(dtype_opt);
   TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED);
-  TORCH_CHECK_TYPE(dtype != ScalarType::BFloat16 || is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_14_0_PLUS), "MPS BFloat16 is only supported on MacOS 14 or newer");

   auto dtype_meta = scalarTypeToTypeMeta(dtype);
@@ -18,11 +18,7 @@ namespace at::mps {

 // Helper enum to check if a MPSGraph op is supported in a given macOS version
 enum class MacOSVersion : uint32_t {
-  MACOS_VER_13_1_PLUS = 0,
-  MACOS_VER_13_2_PLUS,
-  MACOS_VER_13_3_PLUS,
-  MACOS_VER_14_0_PLUS,
-  MACOS_VER_14_4_PLUS,
+  MACOS_VER_14_4_PLUS = 0,
   MACOS_VER_15_0_PLUS,
   MACOS_VER_15_1_PLUS,
   MACOS_VER_15_2_PLUS,
@@ -32,11 +32,11 @@ MPSDevice::~MPSDevice() {

 MPSDevice::MPSDevice() : _mtl_device(nil) {
   // Check that MacOS 13.0+ version of MPS framework is available
-  // Create the MPSGraph and check method introduced in 13.0
+  // Create the MPSGraph and check method introduced in 14.0
   // which is used by MPS backend.
   id mpsCD = NSClassFromString(@"MPSGraph");

-  if ([mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == NO) {
+  if ([mpsCD instancesRespondToSelector:@selector(HermiteanToRealFFTWithTensor:axes:descriptor:name:)] == NO) {
     return;
   }
@@ -66,24 +66,12 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
          isOperatingSystemAtLeastVersion:{.majorVersion = major, .minorVersion = minor, .patchVersion = 0}];
     }
   };
-  static bool _macos_13_1_plus = is_os_version_at_least(13, 1);
-  static bool _macos_13_2_plus = is_os_version_at_least(13, 2);
-  static bool _macos_13_3_plus = is_os_version_at_least(13, 3);
-  static bool _macos_14_0_plus = is_os_version_at_least(14, 0);
   static bool _macos_14_4_plus = is_os_version_at_least(14, 4);
   static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
   static bool _macos_15_1_plus = is_os_version_at_least(15, 1);
   static bool _macos_15_2_plus = is_os_version_at_least(15, 2);

   switch (version) {
-    case MacOSVersion::MACOS_VER_13_1_PLUS:
-      return _macos_13_1_plus;
-    case MacOSVersion::MACOS_VER_13_2_PLUS:
-      return _macos_13_2_plus;
-    case MacOSVersion::MACOS_VER_13_3_PLUS:
-      return _macos_13_3_plus;
-    case MacOSVersion::MACOS_VER_14_0_PLUS:
-      return _macos_14_0_plus;
     case MacOSVersion::MACOS_VER_14_4_PLUS:
       return _macos_14_4_plus;
     case MacOSVersion::MACOS_VER_15_0_PLUS:
@@ -34,7 +34,7 @@ bool MPSHooks::isOnMacOSorNewer(unsigned major, unsigned minor) const {
     case 14:
       switch (minor) {
         case 0:
-          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
+          return true;
         case 4:
           return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
         default:
@@ -42,19 +42,7 @@ bool MPSHooks::isOnMacOSorNewer(unsigned major, unsigned minor) const {
           return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
       }
     case 13:
-      switch (minor) {
-        case 0:
-          return true;
-        case 1:
-          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS);
-        case 2:
-          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
-        case 3:
-          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-        default:
-          TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+");
-          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-      }
+      return true;
     default:
       TORCH_WARN("Checking for unexpected MacOS ", major, ".", minor, " returning false");
       return false;
@@ -88,14 +88,8 @@ std::string getArrayRefString(const IntArrayRef s);
 // use has_storage() on the returned tensor to determine if src actually is a view
 Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
 Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
-MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
-                               MPSGraphTensor* inputTensor,
-                               const TensorBase& input,
-                               bool includesInt64 = false);
-MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
-                                 MPSGraphTensor* inputTensor,
-                                 const TensorBase& input,
-                                 bool includesInt64 = false);
+MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);
+MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);

 MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray);
 MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
@@ -435,14 +429,6 @@ inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(M
 // Common math operations
 MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);

-#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
-  if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
-    TORCH_WARN_ONCE( \
-        "MPS: no support for int64 for ", \
-        op_name, \
-        ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
-  }
-
 /**
  * Returns distance from lowest to highest element offset in given tensor.
  */
@@ -618,10 +604,6 @@ inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds,
   runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
 }

-inline bool supportsComplex() {
-  return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
-}
-
 // MPS yet to support double types, but starting from MacOS 14, supports bfloat16
 inline bool supportedFloatingType(ScalarType dtype) {
   return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
@@ -633,7 +615,7 @@ inline bool supportedFloatingType(const TensorBase& t) {

 inline bool supportedFloatingOrComplexType(ScalarType dtype) {
   if (dtype == kComplexFloat || dtype == kComplexHalf) {
-    return supportsComplex();
+    return true;
   }
   return supportedFloatingType(dtype);
 }
@@ -641,11 +623,6 @@ inline bool supportedFloatingOrComplexType(const TensorBase& t) {
   return supportedFloatingOrComplexType(t.scalar_type());
 }

-inline void checkSupportsBFloat16() {
-  TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
-                   "MPS bfloat16 type is supported on MacOS 14.0 or newer.");
-}
-
 inline bool needsGather(const TensorBase& t) {
   static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
   return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
@@ -89,10 +89,6 @@ void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds,
   mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE);
 }

-static inline void checkSupportsComplex() {
-  TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer.");
-}
-
 MPSDataType getMPSDataType(ScalarType scalar_type) {
   switch (scalar_type) {
     case ScalarType::Float:
@@ -100,7 +96,6 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
     case ScalarType::Half:
       return MPSDataTypeFloat16;
     case ScalarType::BFloat16:
-      checkSupportsBFloat16();
       return MPSDataTypeBFloat16;
     case ScalarType::Int:
       return MPSDataTypeInt32;
@@ -119,10 +114,8 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
                   "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
                   "Please use float32 instead.")
     case ScalarType::ComplexHalf:
-      checkSupportsComplex();
       return MPSDataTypeComplexFloat16;
     case ScalarType::ComplexFloat:
-      checkSupportsComplex();
       return MPSDataTypeComplexFloat32;
     // Unsigned types
     case ScalarType::UInt64:
@@ -140,16 +133,10 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
 // #issue 104398441 sortWithTensor and argsortWithTensor has support of
 // Int32, Half and Float32 types. These utilities are to help cast to these
 // types.
-MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
-                               MPSGraphTensor* inputTensor,
-                               const TensorBase& input,
-                               bool includesInt64) {
+MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) {
   MPSDataType dataType = getMPSDataType(input.scalar_type());
-  bool condition =
-      (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
-  if (includesInt64) {
-    condition = condition && (dataType != MPSDataTypeInt64);
-  }
+  bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) &&
+      (dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64);
   if (condition) {
     dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
     return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
@@ -160,16 +147,10 @@ MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
 // #issue 104398441 sortWithTensor and argsortWithTensor has support of
 // Int32, Half and Float32 types. These utilities are to help cast from these
 // types.
-MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
-                                 MPSGraphTensor* inputTensor,
-                                 const TensorBase& input,
-                                 bool includesInt64) {
+MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) {
   MPSDataType dataType = getMPSDataType(input.scalar_type());
-  bool condition =
-      (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
-  if (includesInt64) {
-    condition = condition && (dataType != MPSDataTypeInt64);
-  }
+  bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) &&
+      (dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64);
   if (condition) {
     inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
   }
@@ -186,7 +167,6 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
     case ScalarType::Half:
       return MPSDataTypeFloat16;
     case ScalarType::BFloat16:
-      checkSupportsBFloat16();
       return MPSDataTypeBFloat16;
     case ScalarType::Int:
       return MPSDataTypeInt32;
@@ -201,13 +181,11 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
     case ScalarType::Bool:
       return MPSDataTypeBool;
     case ScalarType::ComplexHalf:
-      checkSupportsComplex();
       return MPSDataTypeComplexFloat16;
     // This is an intentional fallthrough supporting ComplexDouble for Scalar
     // types as they are casted to Complex64 currently.
     case ScalarType::ComplexDouble:
     case ScalarType::ComplexFloat:
-      checkSupportsComplex();
       return MPSDataTypeComplexFloat32;
     // Unsigned types
     case ScalarType::UInt64:
@@ -267,7 +245,6 @@ std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) {
     case ScalarType::Half:
       return "half";
     case ScalarType::BFloat16:
-      checkSupportsBFloat16();
       return "bfloat";
     case ScalarType::Int:
       return "int";
@@ -879,9 +856,7 @@ id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
   MTLCompileOptions* options = compile_options;
   if (!options) {
     options = [[MTLCompileOptions new] autorelease];
-    // Need 3.0 for atomic oprations, 3.1 introduces bfloat support
-    [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
-                                                                                        : MTLLanguageVersion3_0];
+    [options setLanguageVersion:MTLLanguageVersion3_1];
     if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) {
       options.mathMode = fast_math ? MTLMathModeFast : MTLMathModeSafe;
       options.mathFloatingPointFunctions =
@@ -48,28 +48,11 @@ typedef MPSGraphTensor* (^BinaryOpBlock)(BinaryOpCachedGraph*, MPSGraphTensor*,
 #define BinaryOpFn(graph, primary, secondary) \
   MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary)

-static inline Tensor legacy_complex_as_view(const Tensor& t) {
-  // Convert non-complex types (and cdouble CPU scalars) to cfloat
-  if (!isComplexType(t.scalar_type()) || t.scalar_type() == kComplexDouble) {
-    return at::view_as_real(t.to(kMPS, kComplexFloat));
-  }
-  return at::view_as_real(t.dim() != 0 ? t : t.to(kMPS));
-}
-
 static void binaryOpTensor(const Tensor& self,
                            const Tensor& other,
                            const Tensor& output_,
                            std::string op_name,
                            BinaryOpBlock binaryBlock) {
-  TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) &&
-                (self.scalar_type() == ScalarType::Long ||
-                 (other.scalar_type() == ScalarType::Long &&
-                  (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))),
-              "MPS: ",
-              op_name,
-              " op with int64 input is supported natively starting from macOS 13.2");
-  TORCH_CHECK_TYPE(!isComplexType(self.scalar_type()) || mps::supportsComplex(),
-                   "Complex types are supported starting from MacOS 14.0+");
   MPSStream* mpsStream = getCurrentMPSStream();

   const bool is_self_scalar = self.dim() == 0;
@@ -51,9 +51,6 @@ inline void dot_check(const Tensor& self, const Tensor& other) {
 } // namespace mps

 Tensor dot_mps(const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || self.scalar_type() != ScalarType::Long,
-              "MPS: dot op doesn't support int64 input on MacOS13")
-
   using namespace mps;
   using CachedGraph = MPSBinaryCachedGraph;

@@ -124,7 +124,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
                                     IntArrayRef dilation,
                                     int64_t groups,
                                     std::optional<IntArrayRef> input_shape) {
-  const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
   const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
   Tensor input_t = input_t_;
   bool is3DConv = input_t.dim() == 5;
@@ -132,9 +131,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
     input_t = input_t.contiguous();
   }

-  TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer),
-              "Conv3D is only supported on MPS for MacOS_13_2 or newer");
-
   TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");

   using namespace at::native::mps;
@@ -60,7 +60,6 @@ static void copy_cast_mps(at::Tensor& dst,
     outputTensor = [mpsGraph castTensor:outputTensor toType:dstDType name:@"cast"];
   }
   if (needs_conj) {
-    TORCH_CHECK(supportsComplex(), "MPS complex tensors conjugation needs MacOS14+");
     outputTensor = [mpsGraph conjugateWithTensor:outputTensor name:nil];
   }

@@ -275,24 +274,7 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
     // for GPU to GPU copies we only encode to stream's command buffer (no flushing)
     stream->copy(sourceBuffer, destBuffer, src.nbytes(), src_byte_offset, dst_byte_offset, profile_id);
   } else {
-    // Simulate cast to Complex on older MacOS by initializing real and imag parts
-    if (dst_.is_complex() && !supportsComplex()) {
-      if (!src.is_complex()) {
-        at::real(dst_).copy_(src);
-        at::imag(dst_).fill_(0);
-      } else if (src.is_conj() || dst_.is_conj()) {
-        // One cannot take view of conjugated tensor, but for some reason real and imag views are fine
-        // Use this to implement a conjugation
-        at::real(dst_).copy_(at::real(src));
-        if (src.is_conj() != dst_.is_conj()) {
-          at::imag(dst_).copy_(at::neg(at::imag(src)));
-        } else {
-          at::imag(dst_).copy_(at::imag(src));
-        }
-      } else {
-        at::view_as_real(dst_).copy_(at::view_as_real(src));
-      }
-    } else if (dst_byte_offset) {
+    if (dst_byte_offset) {
       auto maybeCastedSource =
           at::empty(dst_.sizes(), dst_.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
       auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource);
@@ -87,7 +87,6 @@ Tensor& random_mps_impl(Tensor& self,
       case kFloat:
         return MPSDataTypeFloat32;
       case kBFloat16: {
-        checkSupportsBFloat16();
         return MPSDataTypeBFloat16;
       }
       default:
@@ -88,7 +88,6 @@ using namespace mps;

 // TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237
 Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) {
-  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
   auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" +
       std::to_string(normalization) + ":" + std::to_string(onesided);
   @autoreleasepool {
@@ -129,7 +128,6 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
                          int64_t normalization,
                          int64_t last_dim_size,
                          Tensor& out) {
-  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
   auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
       std::to_string(normalization) + ":" + std::to_string(last_dim_size);
   @autoreleasepool {
@@ -155,7 +153,6 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
 }

 Tensor& _fft_c2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) {
-  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
   auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
       std::to_string(normalization) + ":" + std::to_string(forward);
   @autoreleasepool {
@@ -127,15 +127,6 @@ Tensor grid_sampler_2d_mps(const Tensor& input,
                            int64_t interpolation_mode,
                            int64_t padding_mode,
                            bool align_corners) {
-  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS)) {
-    TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.2. ",
-                    "Falling back on CPU. This may have performance implications.");
-
-    return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners)
-        .clone()
-        .to("mps");
-  }
-
   auto in_size = input.sizes();
   auto grid_size = grid.sizes();
   auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());
@@ -353,14 +353,7 @@ static Tensor& nonzero_out_native_mps(const Tensor& self, Tensor& out_) {
 }

 Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
-  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
-    TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ",
-                    "Falling back on CPU. This may have performance implications.");
-    Tensor out_fallback = nonzero_fallback(self);
-    at::native::resize_output(out_, out_fallback.sizes());
-    out_.copy_(out_fallback);
-    return out_;
-  } else if (self.is_complex()) {
+  if (self.is_complex()) {
     TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes. ",
                     "Falling back on CPU. This may have performance implications.");
     Tensor out_fallback = nonzero_fallback(self);
@@ -445,11 +438,7 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
 }

 Tensor nonzero_mps(const Tensor& self) {
-  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
-    TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ",
-                    "Falling back on CPU. This may have performance implications.");
-    return nonzero_fallback(self);
-  } else if (self.is_complex()) {
+  if (self.is_complex()) {
     TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes ",
                     "Falling back on CPU. This may have performance implications.");
     return nonzero_fallback(self);
@@ -152,8 +152,6 @@ static void reduction_out_mps(const Tensor& input_t,
                               const Tensor& output_t,
                               MPSReductionType reduction_type,
                               const std::string& func_name) {
-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name);
   // NS: TODO: get rid of all those shenanigans and just call reduction_op with view tensor
   bool canSqueezeLastDim = true;
   IntArrayRef input_shape = input_t.sizes();
@@ -236,12 +234,10 @@ static void reduction_out_mps(const Tensor& input_t,
     MPSGraphTensor* castInputTensor = inputTensor;
     MPSDataType inputCastType = MPSDataTypeInvalid;
     if (dtype.has_value() &&
-        (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt ||
-         (dtype.value() == kLong && macOS13_3_plus))) {
+        (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || dtype.value() == kLong)) {
       inputCastType = getMPSDataType(dtype.value());
     } else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
-               inputScalarType != kComplexFloat && inputScalarType != kComplexHalf &&
-               (inputScalarType != kLong || !macOS13_3_plus)) {
+               inputScalarType != kComplexFloat && inputScalarType != kComplexHalf && inputScalarType != kLong) {
       inputCastType = getMPSDataType(kFloat);
     }

@@ -615,9 +611,6 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
 }

 static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, nanmedian ? "nanmedian" : "median");
-
   IntArrayRef input_shape = input_t.sizes();
   int64_t num_in_elements = c10::multiply_integers(input_shape);

@@ -634,8 +627,7 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
   auto medianCachedGraph =
       LookUpOrCreateCachedGraph<MedianCachedGraph>(medianKey, [&](auto mpsGraph, auto newCachedGraph) {
         MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
-        MPSGraphTensor* castInputTensor =
-            castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+        MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);

         MPSGraphTensor* reshapedTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil];

@@ -693,9 +685,6 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
 }

 static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction_type, const std::string& func_name) {
-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max");
-
   using CachedGraph = MPSUnaryCachedGraph;

   IntArrayRef input_shape = input_t.sizes();
@@ -713,8 +702,7 @@ static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction
     MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);

     MPSGraphTensor* castOutputTensor = nil;
-    MPSGraphTensor* castInputTensor =
-        castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+    MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);

     NSArray<NSNumber*>* axes = getTensorAxes(input_t);
     if (reduction_type == MPSReductionType::MAX) {
@@ -749,9 +737,6 @@ static void min_max_out_mps(const Tensor& input_t,
                             const Tensor& indices_t,
                             MPSReductionType reduction_type,
                             const std::string& func_name) {
-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max_out");
-
   if (output_t.numel() == 0) {
     return;
   }
@@ -789,8 +774,7 @@ static void min_max_out_mps(const Tensor& input_t,
   auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
     MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
     MPSGraphTensor* outputTensor = nil;
-    MPSGraphTensor* castInputTensor =
-        castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+    MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);

     if (reduction_type == MPSReductionType::MAX) {
       outputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
@@ -896,9 +880,6 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
                                   const std::string& func_name) {
   using CachedGraph = MPSUnaryCachedGraph;

-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "argmax_argmin_out");
-
   int64_t dim_ = -1;

   if (dim.has_value()) {
@@ -953,7 +934,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t,

     MPSGraphTensor* castInputTensor = inputTensor;
     if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
-        (inputScalarType != kLong || !macOS13_3_plus)) {
+        inputScalarType != kLong) {
       castInputTensor = castMPSTensor(mpsGraph, inputTensor, kFloat);
     }
     if (reduction_type == MPSReductionType::MAX) {
@@ -1282,9 +1263,6 @@ static void all_any_common_impl_mps(const Tensor& input_t,
     return;
   }

-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, op_name);
-
   int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
   native::zero_numel_check_dims(input_t, dim_, op_name.c_str());

@@ -1303,7 +1281,7 @@ static void all_any_common_impl_mps(const Tensor& input_t,
   auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
     auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);

-    auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+    auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
     // reductionOrWithTensor:axis: will throw an internal assert if number of dimentions is more than 4
     // See https://github.com/pytorch/pytorch/issues/95538
     MPSGraphTensor* outputTensor = nil;
@@ -1369,14 +1347,11 @@ TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
     return;
   }

-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_all_out");
-
   @autoreleasepool {
     std::string key = std::string("any_all_out_mps:") + getTensorsStringKey(input_t);
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
-      auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+      auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
       // reductionOrWithTensor:axes: will throw an internal assert if number of dimentions is more than 4
       // See https://github.com/pytorch/pytorch/issues/95538
       if (input_t.dim() > 4) {
@@ -1420,14 +1395,11 @@ TORCH_IMPL_FUNC(all_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
     return;
   }

-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_all_out");
-
   @autoreleasepool {
     std::string key = std::string("all_all_out_mps:") + getTensorsStringKey(input_t);
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
-      auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+      auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
       // reductionAndWithTensor:axes: will throw an internal assert if number of dimentions is more than 4
       // See https://github.com/pytorch/pytorch/issues/95538
       if (input_t.ndimension() > 4) {
@@ -1512,9 +1484,6 @@ static void median_out_mps_common(const Tensor& input_t,
                                   Tensor& indices,
                                   const std::string& func_name,
                                   bool nanmedian) {
-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");
-
   int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
   native::zero_numel_check_dims(input_t, dim_, "max()");

@@ -1585,8 +1554,7 @@ static void median_out_mps_common(const Tensor& input_t,
       getTensorsStringKey(indices);
   auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
     MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
-    MPSGraphTensor* castInputTensor =
-        castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+    MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);

     MPSGraphTensor* effectiveLengthTensor = nil;
     if (nanmedian) {
@@ -129,16 +129,8 @@ void computeRepeatIndices(const index_t* repeat_ptr,
   });
 }

-Tensor repeat_interleave_mps(const Tensor& repeat_, std::optional<int64_t> output_size) {
+Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
   Tensor output;
-  Tensor repeat = repeat_;
-  if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
-    // #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output,
-    // which currently doesn't support int64_t as input. Casting internally the indices to int32_t.
-    TORCH_WARN_ONCE(
-        "MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
-    repeat = repeat.to(kInt);
-  }
   AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
     output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
   });
@@ -23,125 +23,6 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
 #include <ATen/native/mps/ScanKernel_metallib.h>
 #endif

-// Generic scan implementation that handles both simple scans and scans with indices
-static void scan_mps_impl(const Tensor& self,
-                          const std::vector<Tensor>& outputs,
-                          int64_t dim,
-                          const std::string& op_name) {
-  if (outputs[0].numel() == 0) {
-    return;
-  }
-
-  const int64_t ndim = self.dim();
-  const int64_t wrapped_dim = maybe_wrap_dim(dim, ndim);
-
-  // Calculate dimensions for scan operation
-  int64_t row_size = self.size(wrapped_dim);
-  auto sizes = self.sizes();
-
-  bool is_innermost = (wrapped_dim == ndim - 1);
-
-  // Check if all tensors are contiguous
-  bool is_contiguous = self.is_contiguous();
-  for (const auto& output : outputs) {
-    is_contiguous = is_contiguous && output.is_contiguous();
-  }
-
-  uint32_t num_rows, num_orows, num_irows, num_threads;
-
-  if (is_innermost) {
-    // Treat all outer dimensions as a single dimension
-    num_rows = self.numel() / row_size;
-    num_threads = num_rows;
-  } else {
-    // Treat all outer dimensions (i.e. dim_ < dim) as one
-    num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + wrapped_dim);
-    // Treat all inner dimensions (i.e. dim > dimension) as one
-    num_irows = c10::multiply_integers(sizes.begin() + wrapped_dim + 1, sizes.end());
-    num_threads = num_orows * num_irows;
-  }
-
-  MPSStream* mpsStream = getCurrentMPSStream();
-  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
-    @autoreleasepool {
-      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
-
-      // Choose kernel based on contiguity and dimension
-      std::string kernel_name;
-      if (is_contiguous) {
-        kernel_name =
-            op_name + "_contiguous_" + (is_innermost ? "innermost_" : "outer_") + scalarToMetalTypeString(self);
-      } else {
-        kernel_name = op_name + "_strided_" + scalarToMetalTypeString(self);
-      }
-
-      id<MTLComputePipelineState> scanPSO = lib.getPipelineStateForFunc(kernel_name);
-
-      // this function call is a no-op if MPS Profiler is not enabled
-      getMPSProfiler().beginProfileKernel(scanPSO, op_name, [&]() {
-        std::vector<Tensor> all_tensors = {self};
-        all_tensors.insert(all_tensors.end(), outputs.begin(), outputs.end());
-        return all_tensors;
-      }());
-
-      [computeEncoder setComputePipelineState:scanPSO];
-
-      // Set input tensor
-      mtl_setBuffer(computeEncoder, self, 0);
-
-      // Set output tensors
-      for (size_t i = 0; i < outputs.size(); ++i) {
-        mtl_setBuffer(computeEncoder, outputs[i], i + 1);
-      }
-
-      if (is_contiguous) {
-        // Contiguous kernels
-        if (is_innermost) {
-          if (outputs.size() == 1) {
-            // Simple scan
-            mtl_setArgs<2>(computeEncoder, num_rows, static_cast<uint32_t>(row_size));
-          } else {
-            // Scan with indices
-            mtl_setArgs<3>(computeEncoder, num_rows, static_cast<uint32_t>(row_size));
-          }
-        } else {
-          if (outputs.size() == 1) {
-            // Simple scan
-            mtl_setArgs<2>(computeEncoder, num_orows, num_irows, static_cast<uint32_t>(row_size));
-          } else {
-            // Scan with indices
-            mtl_setArgs<3>(computeEncoder, num_orows, num_irows, static_cast<uint32_t>(row_size));
-          }
-        }
-      } else {
-        // Strided kernels - pass full tensor information
-        if (outputs.size() == 1) {
-          // Simple scan
-          mtl_setArgs<2>(computeEncoder,
-                         self.sizes(),
-                         self.strides(),
-                         outputs[0].strides(),
-                         static_cast<uint32_t>(self.ndimension()),
-                         static_cast<uint32_t>(wrapped_dim));
-        } else {
-          // Scan with indices
-          mtl_setArgs<3>(computeEncoder,
-                         self.sizes(),
-                         self.strides(),
-                         outputs[0].strides(),
-                         outputs[1].strides(),
-                         static_cast<uint32_t>(self.ndimension()),
-                         static_cast<uint32_t>(wrapped_dim));
-        }
-      }
-
-      mtl_dispatch1DJob(computeEncoder, scanPSO, num_threads);
-
-      getMPSProfiler().endProfileKernel(scanPSO);
-    }
-  });
-}
-
 // Utility function to get 2D grid dimensions for dispatch
 static std::pair<uint32_t, uint32_t> get_2d_grid_dims(const IntArrayRef& shape, const int64_t dim) {
   size_t grid_x = 1;
@@ -375,19 +256,11 @@ static void scan_with_indices_mps_impl(const Tensor& self,
 } // namespace mps

 void cummax_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
-  if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
-    mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax");
-  } else {
-    mps::scan_mps_impl(self, {values, indices}, dim, "cummax");
-  }
+  mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax");
 }

 void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
-  if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
-    mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin");
-  } else {
-    mps::scan_mps_impl(self, {values, indices}, dim, "cummin");
-  }
+  mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin");
 }

 Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) {
@@ -402,11 +275,7 @@ Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) {
     return result;
   }

-  if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
-    mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp");
-  } else {
-    mps::scan_mps_impl(self, {result}, wrap_dim, "logcumsumexp");
-  }
+  mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp");
   return result;
 }

@@ -26,9 +26,6 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
  const Tensor& indices) {
   using namespace mps;

-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
-  MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "sort_stable_out");
-
   if (self.numel() == 0) {
     return;
   }
@@ -55,8 +52,7 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
   auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
     newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);

-    MPSGraphTensor* castInputTensor =
-        castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
+    MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self);
     MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
                                                        axis:(NSInteger)dim
                                                  descending:(BOOL)descending
@@ -297,9 +297,6 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,

   const auto common_type = at::result_type(elements, test_elements);
   TORCH_CHECK(elements.is_mps() && test_elements.is_mps());
-  TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || supportedFloatingType(common_type),
-              "isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: ",
-              common_type);

   @autoreleasepool {
     std::string key = op_name + getTensorsStringKey({elements, test_elements}) + std::to_string(invert);
@@ -208,28 +208,12 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
 }

 Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
-  if (mps::supportsComplex()) {
-    mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
-      auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
-      auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
-      return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
-    });
-    return output;
-  } else {
-    TORCH_CHECK(!self.is_complex(), "MPS does not support angle with complex input on macOS13")
-    mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
-      // On macOS 13 with non-complex input, realPartOfTensor and imaginaryPartOfTensor are
-      // not available, and NaN is not propagated correctly:
-      auto imagPart = [mpsGraph constantWithScalar:0.0 shape:inputTensor.shape dataType:inputTensor.dataType];
-      auto result = [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:inputTensor name:nil];
-      auto nanMask = [mpsGraph isNaNWithTensor:inputTensor name:nil];
-      return [mpsGraph selectWithPredicateTensor:nanMask
-                             truePredicateTensor:inputTensor
-                            falsePredicateTensor:result
-                                            name:nil];
-    });
-    return output;
-  }
+  mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
+    auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
+    auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
+    return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
+  });
+  return output;
 }

 Tensor angle_mps(const Tensor& self) {
@@ -362,7 +346,6 @@ static void cumulative_op_impl(const Tensor& self,
                                const Tensor& result,
                                MPSCumulativeOpType cumulativeOpType,
                                const std::string& op_name) {
-  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
   auto nDims = self.dim();
   auto wrapped_dim = maybe_wrap_dim(dim, nDims);
   TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()),
@@ -381,11 +364,6 @@ static void cumulative_op_impl(const Tensor& self,
   bool castInputData = (isIntegralType(input.scalar_type(), true) && input.scalar_type() != ScalarType::Int &&
                         input.scalar_type() != ScalarType::Long);

-  TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
-              "MPS does not support ",
-              op_name,
-              " op with int64 input. Support has been added in macOS 13.3");
-
   mps::unary_op(
       input, result, op_name + std::to_string(dim), ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
         if (castInputData) {
@@ -440,17 +418,9 @@ TORCH_IMPL_FUNC(sgn_out_mps)(const Tensor& self, const Tensor& output) {

 Tensor& conj_physical_out_mps(const Tensor& self, Tensor& result) {
   TORCH_CHECK(self.is_complex());
-  if (!mps::supportsComplex()) {
-    if (!result.is_same_size(self)) {
-      result.resize_(self.sizes());
-    }
-    at::real(result).copy_(at::real(self));
-    at::imag(result).copy_(at::neg(at::imag(self)));
-  } else {
-    mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
-      return [mpsGraph conjugateWithTensor:inputTensor name:nil];
-    });
-  }
+  mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
+    return [mpsGraph conjugateWithTensor:inputTensor name:nil];
+  });
   return result;
 }
