Compare commits

...

4 Commits

Author SHA1 Message Date
4402750a68 Add sequoia ops 2025-05-07 13:40:12 -07:00
21fb7457a9 Add macos15 check 2025-05-07 13:33:50 -07:00
a29349b633 Remove MPSGraph usage from cat_mps_out 2025-05-07 13:30:53 -07:00
51dd0770e4 Synchronize mps backend in the timer 2025-05-05 09:36:25 -07:00
4 changed files with 56 additions and 2 deletions

View File

@ -100,6 +100,7 @@ MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
const TensorBase& input,
bool includesInt64 = false);
MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray);
MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil);
// The MPSShape could vary based on memory format

View File

@ -467,7 +467,7 @@ MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* stride
offset:t.storage_offset() * t.element_size()
descriptor:srcTensorDesc] autorelease];
if (strides != nil) {
srcNDArray = [srcNDArray arrayViewWithShape:sizes strides:strides];
srcNDArray = getStridedMPSNDArray(t, srcNDArray);
}
return srcNDArray;
}
@ -476,7 +476,7 @@ MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes, const I
return getMPSNDArray(t, getMPSShape(sizes.empty() ? t.sizes() : sizes), strides.empty() ? nil : getMPSShape(strides));
}
static MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray) {
MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray) {
auto strides = src.strides();
auto sizes = src.sizes();
auto nStrides = strides.size();

View File

@ -5,6 +5,7 @@
#include <ATen/native/TensorShape.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
@ -137,6 +138,44 @@ TORCH_IMPL_FUNC(topk_out_mps)
}
}
#include <chrono>
// Concatenates `inputs` along `dimension` into `out` using MPSNDArray slice
// views and the MPSNDArrayIdentity copy kernel, avoiding MPSGraph entirely.
// Callers gate this on macOS 15+ (MPSNDArrayIdentity / arrayViewWithDescriptor
// require Sequoia). `inputs` are the materialized, non-empty cat operands;
// `out` is the preallocated result tensor.
static void cat_out_mps_new(std::vector<Tensor>& inputs, int64_t dimension, const Tensor& out) {
  using namespace mps;
  auto device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  // MPSNDArray orders dimensions innermost-first, the reverse of PyTorch's
  // sizes(), so translate the cat dimension accordingly.
  auto mpsdimension = out.dim() - dimension - 1;
  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      // End any pending kernel coalescing first, then grab the command buffer
      // and encoder; fetching the encoder before endKernelCoalescing() (as the
      // original did) risks encoding on an encoder that was just ended.
      // NOTE(review): confirm MPSStream encoder lifetime semantics.
      mpsStream->endKernelCoalescing();
      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
      auto computeEncoder = mpsStream->commandEncoder();
      auto outNDArray = getMPSNDArray(out, out.sizes(), out.strides());
      // The identity kernel is stateless across inputs: allocate it once and
      // autorelease it. This file is compiled without ARC (see the explicit
      // autorelease usage elsewhere), so the previous bare alloc/init inside
      // the loop leaked one MPSNDArrayIdentity per input tensor.
      auto identity = [[[MPSNDArrayIdentity alloc] initWithDevice:device] autorelease];
      int64_t offset_index = 0;
      for (size_t i = 0; i < inputs.size(); ++i) {
        auto& tensor = inputs[i];
        auto ndarray = getMPSNDArray(tensor, tensor.sizes(), tensor.strides());
        auto length = tensor.size(dimension);
        // Fresh descriptor per iteration: restrict the output view to the
        // subrange this input occupies along the cat dimension.
        auto desc = outNDArray.descriptor;
        [desc sliceDimension:mpsdimension
                withSubrange:{.start = static_cast<NSUInteger>(offset_index),
                              .length = static_cast<NSUInteger>(length)}];
        offset_index += length;
        auto slice_out = [outNDArray arrayViewWithDescriptor:desc];
        // Copy this input into its slice of the output.
        [identity encodeToCommandEncoder:computeEncoder
                           commandBuffer:commandBuffer
                            sourceArrays:@[ ndarray ]
                        destinationArray:slice_out];
      }
    }
  });
}
TORCH_IMPL_FUNC(cat_out_mps)
(const ITensorListRef& inputs,
int64_t dimension,
@ -151,6 +190,7 @@ TORCH_IMPL_FUNC(cat_out_mps)
if (out.numel() == 0) {
return;
}
auto materialized_inputs = inputs.materialize();
auto out_dtype = at::native::result_type(inputs);
@ -246,6 +286,13 @@ TORCH_IMPL_FUNC(cat_out_mps)
return;
}
bool is_macos_15_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
if (is_macos_15_or_newer) {
cat_out_mps_new(input_tensors, dimension, out);
return;
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
std::vector<MPSGraphTensor*> inputTensors_;
@ -318,7 +365,9 @@ TORCH_IMPL_FUNC(cat_out_mps)
for (auto& inputPlaceholder : inputPlaceholders) {
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
}
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
}
}

View File

@ -21,6 +21,10 @@ elif torch.xpu.is_available():
def timer() -> float:
torch.xpu.synchronize()
return timeit.default_timer()
elif torch.mps.is_available():
    def timer() -> float:
        # Drain all queued MPS kernels before reading the clock so the
        # measurement reflects completed GPU work, mirroring the CUDA/XPU
        # branches above.
        torch.mps.synchronize()
        return timeit.default_timer()
elif torch._C._get_privateuse1_backend_name() != "privateuseone":
privateuse1_device_handler = getattr(torch, torch._C._get_privateuse1_backend_name(), None) \
if torch._C._get_privateuse1_backend_name() != "cpu" else None