Compare commits

...

7 Commits

Author SHA1 Message Date
f0621b73f1 [lint] use lintrunner to format aten/src/ATen/native/mps/operations/Pooling.mm by (lintrunner init -> lintrunner -a) 2025-10-22 20:41:29 +08:00
d088d19b41 Update aten/src/ATen/native/mps/operations/Pooling.mm
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-10-17 12:37:31 +08:00
41f8694999 Refactor Pooling.mm for improved readability 2025-10-16 19:43:13 +08:00
1520a7e7f6 Merge branch 'pytorch:main' into linhaifeng/dev 2025-10-16 19:37:58 +08:00
d45dfe5b25 Merge branch 'pytorch:main' into linhaifeng/dev 2025-10-16 17:01:07 +08:00
a7368610c0 Merge branch 'pytorch:main' into linhaifeng/dev 2025-10-16 15:02:35 +08:00
ea1facbd54 [Fix][MPS]Fix Objective-C memory leaks in MPS operations
- Add autorelease calls to MPSGraph descriptor objects in Pooling.mm
  - Follows the same pattern as PR #154765 for Linear.mm
  - Prevents memory leaks in MPS backend operations
2025-10-16 15:01:33 +08:00

View File

@ -166,20 +166,20 @@ static void pool2d_template(const Tensor& input,
MPSShape* gradOutputShape = is_backward_pass ? getMPSShape(grad_output, memory_format) : nullptr;
auto cachedGraph = LookUpOrCreateCachedGraph<PoolingCachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
MPSGraphPooling2DOpDescriptor* desc = [MPSGraphPooling2DOpDescriptor
descriptorWithKernelWidth:kW
kernelHeight:kH
strideInX:dW
strideInY:dH
dilationRateInX:dilationW
dilationRateInY:dilationH
paddingLeft:padW
paddingRight:ceil_mode ? padW * dW : padW
paddingTop:padH
paddingBottom:ceil_mode ? padH * dH : padH
paddingStyle:MPSGraphPaddingStyleExplicit
dataLayout:memory_format == MemoryFormat::ChannelsLast ? MPSGraphTensorNamedDataLayoutNHWC
: MPSGraphTensorNamedDataLayoutNCHW];
auto desc = [[MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:kW
kernelHeight:kH
strideInX:dW
strideInY:dH
dilationRateInX:dilationW
dilationRateInY:dilationH
paddingLeft:padW
paddingRight:ceil_mode ? padW * dW : padW
paddingTop:padH
paddingBottom:ceil_mode ? padH * dH : padH
paddingStyle:MPSGraphPaddingStyleExplicit
dataLayout:memory_format == MemoryFormat::ChannelsLast
? MPSGraphTensorNamedDataLayoutNHWC
: MPSGraphTensorNamedDataLayoutNCHW] autorelease];
desc.ceilMode = (padW == 0 && padH == 0) ? ceil_mode : false;
if (has_indices) {
desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D;