Files
pytorch/test/test_mps.py
Yuanyuan Chen fdab48a7c1 Enable all PIE rules on ruff (#165814)
This PR enables all PIE rules on ruff. Some rules from this family were already enabled; the newly added rules are:
```
PIE796  Enum contains duplicate value: {value}
PIE808  Unnecessary start argument in range
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
2025-10-18 07:36:18 +00:00

12937 lines
553 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Owner(s): ["module: mps"]
# ruff: noqa: F841
import io
import sys
import math
import random
import unittest
import warnings
import shutil
import subprocess
import tempfile
import os
import copy
import gc
import threading
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from collections import defaultdict
from torch import inf
from torch.nn import Buffer, Parameter
from torch.testing._internal import opinfo
from torch.testing._internal.common_utils import \
(gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, MACOS_VERSION, IS_CI,
NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests, xfailIf)
from torch.testing._internal.common_mps import mps_ops_modifier, mps_ops_grad_modifier, mps_ops_error_inputs_modifier
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
import torch.backends.mps
from torch.distributions import Uniform, Exponential
from torch.utils._python_dispatch import TorchDispatchMode
from functools import partial
from torch.testing._internal.common_methods_invocations import (
op_db,
UnaryUfuncInfo,
ReductionOpInfo,
SpectralFuncInfo,
BinaryUfuncInfo,
)
from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
import numpy as np
import torch
import torch.utils._pytree as pytree
from itertools import product
import operator
# Each database below starts out as an independent deep copy of `op_db`.
test_consistency_op_db = copy.deepcopy(op_db)
test_error_inputs_op_db = copy.deepcopy(op_db)

# Add bicubic2d_aa to test_consistency_op_db
for op in op_db:
    if op.name == "_upsample_bilinear2d_aa":
        # Clone the bilinear entry, then retarget the clone at the bicubic op.
        op = copy.deepcopy(op)
        op.name = "_upsample_bicubic2d_aa"
        op.op = torch.ops.aten._upsample_bicubic2d_aa
        test_consistency_op_db.append(op)
        break

# Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref`:
# every op that is not one of the ufunc/reduction/spectral infos and has a reference.
_ref_test_ops = tuple(
    candidate
    for candidate in op_db
    if not isinstance(
        candidate, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
    )
    and candidate.ref is not None
)
# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
    print('MPS not available, skipping tests', file=sys.stderr)
    TestCase = NoTest  # noqa: F811
    NNTestCase = NoTest  # noqa: F811

# Value of the `hw.memsize` sysctl key, parsed as an int.
# The query is guarded so that importing this module does not raise on hosts
# where `sysctl` is missing or does not know the key; on such hosts the tests
# are already replaced by `NoTest` above and 0 is used as a placeholder.
try:
    total_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]))
except (OSError, subprocess.CalledProcessError, ValueError):
    total_memory = 0

# Dtypes excluded from MPS_DTYPES below.
MPS_UNSUPPORTED_TYPES = [torch.double, torch.cdouble]
MPS_DTYPES = [t for t in get_all_dtypes() if t not in MPS_UNSUPPORTED_TYPES]

# Determine whether to enable MPS memory leak check (uses same code as CUDA).
TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'
def skipMPSMemoryLeakCheckIf(condition):
    """Return a decorator that sets `fn._do_mps_memory_leak_check` to `not condition`.

    A function whose `_do_mps_memory_leak_check` attribute is already falsy is
    left untouched, so a previously disabled check is never re-enabled.
    The decorated function itself is returned (no wrapper is created).
    """
    def dec(fn):
        already_disabled = not getattr(fn, '_do_mps_memory_leak_check', True)
        if not already_disabled:
            fn._do_mps_memory_leak_check = not condition
        return fn

    return dec
class MpsMemoryLeakCheck:
    """Context manager comparing MPS memory statistics before and after a block.

    On exit (without an exception) an increase in caching-allocator memory that
    is also visible in the driver statistics raises `RuntimeError`; an increase
    seen only by the caching allocator produces a warning.
    """

    def __init__(self, testcase, name=None):
        self.testcase = testcase
        # Fall back to the test id when no explicit name is supplied.
        self.name = name if name is not None else testcase.id()

    def __enter__(self):
        # Performs a gc if required (required if any memory is held)
        if torch.mps.current_allocated_memory() > 0:
            gc.collect()
            torch.mps.empty_cache()

        # Acquires caching allocator and driver statistics before the test is run
        self.caching_allocator_before = torch.mps.current_allocated_memory()
        self.driver_before = torch.mps.driver_allocated_memory()

    def __exit__(self, exc_type, exc_value, traceback):
        # Don't check for leaks if an exception was thrown
        if exc_type is not None:
            return

        # Short-circuits if the caching allocator reports no increase
        if torch.mps.current_allocated_memory() <= self.caching_allocator_before:
            return

        # Validates the discrepancy persists after garbage collection and
        # is confirmed by the driver API
        gc.collect()
        torch.mps.empty_cache()

        # Query memory multiple times to ensure leak was not transient
        for _ in range(3):
            allocator_now = torch.mps.current_allocated_memory()
            driver_now = torch.mps.driver_allocated_memory()
            allocator_grew = allocator_now > self.caching_allocator_before
            driver_grew = driver_now > self.driver_before
            if not allocator_grew and not driver_grew:
                # Leak was false positive, exit loop
                break

        if not allocator_grew:
            return

        # Both messages end with the same before/after statistics.
        stats = (f"Caching allocator allocated memory was {self.caching_allocator_before} "
                 f"and is now reported as {allocator_now}. "
                 f"MPS driver allocated memory was {self.driver_before} and is now {driver_now}.")
        if driver_grew:
            # A caching allocator discrepancy validated by the driver API is a failure
            raise RuntimeError(f"MPS driver API confirmed a leak in {self.name}! " + stats)

        # Just raises a warning if the leak is not validated by the driver API
        warnings.warn("MPS caching allocator reports a memory leak not "
                      f"verified by the driver API in {self.name}! " + stats)
class TestAutocastMPS(TestCase):
    """Tests of `torch.autocast` and `torch.amp.GradScaler` on the MPS device."""

    def test_matmul_autocast(self):
        """Chained `torch.mm` under autocast is float16 and matches the plain result cast to float16."""
        autocast_tensor_A = torch.rand((8, 8), device="mps")
        autocast_tensor_B = torch.rand((8, 8), device="mps")
        tensor_A = autocast_tensor_A.detach().clone()
        tensor_B = autocast_tensor_B.detach().clone()

        with torch.autocast(device_type="mps"):
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B)
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor)

        output_tensor = torch.mm(tensor_A, tensor_B)
        output_tensor = torch.mm(tensor_A, output_tensor)

        self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16")
        self.assertEqual(
            autocast_output_tensor,
            output_tensor.to(torch.float16),
            "Autocast & non-autocast tensors did not match, "
            f"got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}",
        )

    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_dot_product_attention_autocast(self, dtype):
        """SDPA under autocast with a `dtype` value tensor matches the all-float32 computation."""
        # Regression test for https://github.com/pytorch/pytorch/issues/141774
        query = torch.rand(4, 1, 16, 8, dtype=torch.float32, device="mps")
        key = torch.rand(4, 1, 16, 8, dtype=torch.float32, device="mps")
        value = torch.rand(4, 1, 16, 8, dtype=dtype, device="mps")

        with torch.amp.autocast(device_type="mps"):
            y_autocast = F.scaled_dot_product_attention(query, key, value)

        y = F.scaled_dot_product_attention(query, key, value.to(torch.float32))
        self.assertEqual(y.to(y_autocast.dtype), y_autocast)

    def test_conv_transpose3d_autocast_fp32(self):
        """`ConvTranspose3d` output stays float32 under MPS autocast."""
        m = nn.ConvTranspose3d(16, 33, 3, stride=2).to("mps")
        x = torch.randn(20, 16, 10, 50, 100, device="mps")
        with torch.amp.autocast(device_type="mps"):
            y = m(x)
        self.assertEqual(y.dtype, torch.float32)

    def test_conv3d_autocast(self):
        """Two stacked `Conv3d` layers produce float16 output under MPS autocast."""
        # Regression test for https://github.com/pytorch/pytorch/issues/160415
        class Foo(nn.Module):
            def __init__(self):
                super().__init__()
                self.c1 = nn.Conv3d(3, 3, 1)
                self.c2 = nn.Conv3d(3, 3, 1)

            def forward(self, x):
                x = self.c1(x)
                x = self.c2(x)
                return x

        x = torch.randn(2, 3, 4, 4, 4, device="mps")
        model = Foo().to("mps")
        with torch.amp.autocast(device_type="mps"):
            y = model(x)
        self.assertEqual(y.dtype, torch.float16)

    def test_gradscaler_mps(self):
        """Train identical CPU and MPS models with GradScaler and compare parameters."""
        # big model to force chunking/depth in the gradscaler dispatch
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(10, 2048)
                self.fc2 = nn.Linear(2048, 2048)
                self.fc3 = nn.Linear(2048, 2048)
                self.fc4 = nn.Linear(2048, 2048)
                self.fc5 = nn.Linear(2048, 5)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                x = self.relu(self.fc3(x))
                x = self.relu(self.fc4(x))
                return self.fc5(x)

        torch.manual_seed(42)

        def helper(model_cpu, model_mps, dtype, iterations, batch_size, atol=3e-4, rtol=1e-5):
            # Runs `iterations` SGD steps on both models with the same data and
            # compares their parameters after every step.
            optimizer_cpu = torch.optim.SGD(model_cpu.parameters(), lr=0.01)
            optimizer_mps = torch.optim.SGD(model_mps.parameters(), lr=0.01)
            loss_fn = nn.MSELoss()

            input_cpu = torch.randn(batch_size, 10)
            target_cpu = torch.randn(batch_size, 5)
            input_mps = input_cpu.to('mps')
            target_mps = target_cpu.to('mps')

            scaler_cpu = torch.amp.GradScaler(device="cpu")
            scaler_mps = torch.amp.GradScaler(device="mps")
            for _ in range(iterations):
                optimizer_cpu.zero_grad()
                optimizer_mps.zero_grad()

                with torch.amp.autocast(device_type="cpu", dtype=dtype):
                    output_cpu = model_cpu(input_cpu)
                    loss_cpu = loss_fn(output_cpu, target_cpu)
                scaler_cpu.scale(loss_cpu).backward()
                scaler_cpu.step(optimizer_cpu)
                scaler_cpu.update()

                with torch.autocast(device_type="mps", dtype=dtype):
                    output_mps = model_mps(input_mps)
                    loss_mps = loss_fn(output_mps, target_mps)
                scaler_mps.scale(loss_mps).backward()
                scaler_mps.step(optimizer_mps)
                scaler_mps.update()

                for p_cpu, p_mps in zip(model_cpu.parameters(), model_mps.parameters()):
                    self.assertEqual(p_mps.cpu(), p_cpu, rtol=rtol, atol=atol)

        model_cpu = Model().to('cpu')
        model_mps = Model().to('mps')
        # Start both models from the same weights.
        model_mps.load_state_dict(model_cpu.state_dict())

        helper(model_cpu, model_mps, torch.float16, iterations=5, batch_size=4)
        helper(model_cpu, model_mps, torch.bfloat16, iterations=5, batch_size=4)

    def test_non_fast_path_amp_unscale(self):
        """`unscale_` on MPS gradients re-strided with `as_strided` matches the plain CPU run."""
        torch.manual_seed(42)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = nn.Linear(10, 10)
                self.linear2 = nn.Linear(10, 10)

            def forward(self, x):
                x = self.linear1(x)
                x = F.relu(x)
                x = self.linear2(x)
                x = x.mean(dim=1)
                return x

        cpu_model = Model().to("cpu")
        mps_model = copy.deepcopy(cpu_model).to("mps")

        cpu_optimizer = torch.optim.SGD(cpu_model.parameters(), lr=0.01)
        mps_optimizer = torch.optim.SGD(mps_model.parameters(), lr=0.01)
        cpu_scaler = torch.amp.GradScaler(device="cpu")
        mps_scaler = torch.amp.GradScaler(device="mps")

        def helper(model, optimizer, scaler, device, input, target, apply_grad_transform=False):
            # One scaled forward/backward/step on `device`.
            optimizer.zero_grad()
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                output = model(input)
                loss = nn.MSELoss()(output, target)
            scaler.scale(loss).backward()

            if apply_grad_transform:
                # Replace every >=2-D gradient with an all-ones-stride view of itself.
                for p in model.parameters():
                    if p.grad is not None and p.grad.dim() >= 2:
                        p.grad = p.grad.as_strided(p.grad.size(), (1,) * p.grad.dim())

            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()

        # CPU forward/backward pass
        input_cpu = torch.randn(32, 10, device="cpu")
        target_cpu = torch.randn(32, device="cpu")
        helper(cpu_model, cpu_optimizer, cpu_scaler, "cpu", input_cpu, target_cpu)

        # MPS forward/backward pass
        input_mps = input_cpu.to("mps")
        target_mps = target_cpu.to("mps")
        helper(mps_model, mps_optimizer, mps_scaler, "mps", input_mps, target_mps, apply_grad_transform=True)

        updated_linear1_weight_cpu = cpu_model.linear1.weight.detach()
        updated_linear2_weight_cpu = cpu_model.linear2.weight.detach()
        updated_linear1_weight_mps = mps_model.linear1.weight.detach().cpu()
        updated_linear2_weight_mps = mps_model.linear2.weight.detach().cpu()

        self.assertEqual(updated_linear1_weight_cpu, updated_linear1_weight_mps, atol=6e-4, rtol=1e-6)
        self.assertEqual(updated_linear2_weight_cpu, updated_linear2_weight_mps, atol=6e-4, rtol=1e-6)
# Expand TestCase class with Memory Leak Detection on MPS device
class TestCaseMPS(TestCase):
    """`TestCase` variant that can run each test under `MpsMemoryLeakCheck`."""

    # Per-class/per-test switch; `skipMPSMemoryLeakCheckIf` sets the same attribute on functions.
    _do_mps_memory_leak_check = True

    def __init__(self, method_name='runTest'):
        super().__init__(method_name)
        test_method = getattr(self, method_name, None)
        # Wraps the tested method if we should do MPS memory check.
        should_wrap = (
            test_method is not None
            and TEST_MPS_MEM_LEAK_CHECK
            and self._do_mps_memory_leak_check
        )
        if should_wrap:
            self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)

    def assertLeaksNoMpsTensors(self, name=None):
        """Return an `MpsMemoryLeakCheck` for this test, named `name` or `self.id()`."""
        if name is None:
            name = self.id()
        return MpsMemoryLeakCheck(self, name)

    def wrap_with_mps_policy(self, method_name, policy):
        """Rebind `method_name` on this instance to the policy-wrapped version of the method."""
        original = getattr(self, method_name)
        wrapped = super().wrap_method_with_policy(original, policy)
        setattr(self, method_name, wrapped)

    # checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
    def wrap_with_mps_memory_check(self, method):
        """Return `method` wrapped with the `assertLeaksNoMpsTensors` policy."""
        return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)
class TestMemoryLeak(TestCaseMPS):
    """Checks of the leak detector itself and of a copy/cast round trip."""

    def test_mps_memory_leak_detection(self):
        """A wrapped function that keeps an MPS tensor alive must trip the leak check."""
        retained = []

        @self.wrap_with_mps_memory_check
        def no_leak():
            pass

        # Trigger an intentional memory leak
        @self.wrap_with_mps_memory_check
        def leak_gpu0():
            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
            retained.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))

        no_leak()

        # check if a runtime error for memory leak was emitted which would
        # confirm whether memory leak detection worked successfully or not.
        with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
            leak_gpu0()

    def test_copy_cast_no_leak(self):
        """Driver-allocated memory is unchanged by a second mps->cpu->mps cast round trip."""

        def step(x):
            x = x.to(device='cpu', dtype=torch.float32)
            x = x.to(device='mps', dtype=torch.float16)

        def run_step_and_measure(tensor):
            # One round trip, then drop cached blocks and read the driver statistic.
            step(tensor)
            torch.mps.empty_cache()
            return torch.mps.driver_allocated_memory()

        a = torch.randn(128, 128, device='mps', dtype=torch.float16)
        # Warm up / prebuild MPS shaders (otherwise check fails on 13.2)
        driver_before = run_step_and_measure(a)
        driver_after = run_step_and_measure(a)
        self.assertEqual(driver_before, driver_after, f"Detected {driver_after - driver_before} bytes leak of GPU memory")
class TestPixelShuffle(TestCaseMPS):
    """`nn.PixelShuffle` / `nn.PixelUnshuffle` checks on MPS for 1-D to 5-D inputs."""

    def test_pixel_shuffle_unshuffle(self):
        """Exercise success and error paths of pixel (un)shuffle for every input rank from 1 to 5."""

        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
                                                 upscale_factor=None, is_contiguous=True):
            def generate_input():
                # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
                channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
                height = random.randint(5, 10)
                width = random.randint(5, 10)

                if num_input_dims == 1:
                    input = torch.rand(channels, requires_grad=True, device='mps')
                    assert is_contiguous
                elif num_input_dims == 2:
                    input = torch.rand(width, height, requires_grad=True, device='mps').T
                    if is_contiguous:
                        input = input.contiguous()
                else:
                    batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                    input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
                    input = input.transpose(-1, -2)
                    if is_contiguous:
                        input = input.contiguous()

                if not is_contiguous and len(input.reshape(-1)) > 0:
                    assert not input.is_contiguous()

                # Hand back a fresh leaf tensor so gradients accumulate on it.
                input = input.detach().clone()
                input.requires_grad = True
                return input

            # Function to imperatively ensure pixels are shuffled to the correct locations.
            # Used to validate the batch operations in pixel_shuffle.
            def _verify_pixel_shuffle(input, output, upscale_factor):
                r = upscale_factor
                out_dims = (range(output.size(-3)), range(output.size(-2)), range(output.size(-1)))
                for c, h, w in product(*out_dims):
                    src_h = h // r
                    src_w = w // r
                    src_c = (r * (h % r)) + (w % r) + (c * r ** 2)
                    self.assertEqual(output[..., c, h, w], input[..., src_c, src_h, src_w])

            if upscale_factor is None:
                upscale_factor = random.randint(2, 5)
            input = generate_input()

            ps = nn.PixelShuffle(upscale_factor)
            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)

            expect_success = num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0
            if not expect_success:
                with self.assertRaises(RuntimeError):
                    ps(input)
                return

            output = ps(input)
            _verify_pixel_shuffle(input, output, upscale_factor)
            output.backward(output.data)
            self.assertEqual(input.data, input.grad.data)

            # Ensure unshuffle properly inverts shuffle.
            unshuffle_output = pus(output)
            self.assertEqual(input, unshuffle_output)

        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
                                                    downscale_factor=None):
            if downscale_factor is None:
                downscale_factor = random.randint(2, 5)
            channels = random.randint(1, 4)
            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)

            if num_input_dims == 1:
                input = torch.rand(channels, requires_grad=True, device='mps')
            elif num_input_dims == 2:
                input = torch.rand(height, width, requires_grad=True, device='mps')
            else:
                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')

            pus = nn.PixelUnshuffle(downscale_factor)
            with self.assertRaises(RuntimeError):
                pus(input)

        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
            # For 1D - 2D, this is an error case.
            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
            is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
            for is_contiguous in is_contiguous_check:
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
                )

                # Error cases for pixel_unshuffle.
                _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
                _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
                _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
                _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)

        # 1D through 5D inputs, in increasing order of rank.
        for num_input_dims in (1, 2, 3, 4, 5):
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=num_input_dims)
class MPSReluTest(TestCaseMPS):
    """Compares `torch.nn.ReLU` (out-of-place and in-place) against a NumPy reference."""

    def _npRelu(self, np_features):
        """NumPy reference: element-wise max with zero, cast back to the input dtype."""
        zeros = np.zeros(np_features.shape)
        return np.maximum(np_features, zeros).astype(np_features.dtype)

    def testNpRelu(self):
        """Sanity-check the NumPy reference on a hand-computed example."""
        features = np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]])
        expected = np.array([[0., 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]])
        torch.testing.assert_close(expected, self._npRelu(features))

    def _testRelu(self, np_features, device):
        """Out-of-place ReLU on `device` must equal the NumPy reference."""
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        relu = torch.nn.ReLU(inplace=False)
        py_relu_cpu = relu(py_tensor).to("cpu")

        self.assertEqual(np_relu, py_relu_cpu)

    def _testReluInPlace(self, np_features, device):
        """In-place ReLU on `device`: both the result and the mutated input must equal the reference."""
        # The reference is computed before the in-place op runs.
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        relu = torch.nn.ReLU(inplace=True)
        py_relu_cpu = relu(py_tensor).to("cpu")

        self.assertEqual(np_relu, py_relu_cpu)
        # Inplace Relu modifies the initial input and it should match the output of Relu
        self.assertEqual(np_relu, py_tensor.to("cpu"))

    def testNumbersCPU(self):
        """Integer inputs, executed on the CPU."""
        values = [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]
        for t in [np.int32]:
            # Force execution on CPU even if a GPU kernel is available for the type.
            # A fresh array is built for every call.
            self._testRelu(np.array(values).astype(t), device="cpu")
            self._testReluInPlace(np.array(values).astype(t), device="cpu")

    def testNumbersGPU(self):
        """Floating-point inputs (including an empty array), executed on MPS."""
        values = [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]
        for t in [np.float16, np.float32]:
            self._testRelu(np.array(values).astype(t), device="mps")
            self._testReluInPlace(np.array(values).astype(t), device="mps")
            self._testRelu(np.array([]).astype(t), device="mps")
            self._testReluInPlace(np.array([]).astype(t), device="mps")
class MatmulTest(TestCaseMPS):
    """Compares `torch.matmul` on MPS against the CPU result for several operand shapes."""

    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
        """Multiply two random MPS tensors and compare with the same product computed on CPU.

        Args:
            shape_tensor_1: shape passed to `torch.randn` for the left operand.
            shape_tensor_2: shape passed to `torch.randn` for the right operand.
            expand_tensor_1_shape: if truthy, the left operand is `.expand()`-ed to this shape.
            expand_tensor_2_shape: if truthy, the right operand is `.expand()`-ed to this shape.
        """
        tensor1_mps = torch.randn(shape_tensor_1, device="mps")
        if expand_tensor_1_shape:
            tensor1_mps = tensor1_mps.expand(expand_tensor_1_shape)

        tensor2_mps = torch.randn(shape_tensor_2, device="mps")
        if expand_tensor_2_shape:
            tensor2_mps = tensor2_mps.expand(expand_tensor_2_shape)

        tensor1_cpu = tensor1_mps.to("cpu")
        tensor2_cpu = tensor2_mps.to("cpu")

        matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)
        matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)

        self.assertEqual(matmul_cpu, matmul_mps.to("cpu"))

    def test_vector_x_vector(self):
        # uses `dot`
        self._helper(3, 3)

    def test_matrix_x_vector(self):
        # uses `addmv`
        self._helper((3, 4), 4)

    def test_batched_matrix_x_broadcasted_vector(self):
        self._helper((10, 3, 4), 4)

    def test_batched_matrix_x_batched_matrix(self):
        # uses `bmm.out`
        self._helper((10, 3, 4), (10, 4, 5))

    def test_batched_matrix_x_broadcasted_matrix(self):
        self._helper((10, 3, 4), (4, 5))

    def test_large_matmul(self):
        """Half-precision matmul with a 72250-long inner dimension matches the CPU result."""
        # Issue: #141909
        # The operands must live on MPS; without `device="mps"` both products
        # below would be computed on the CPU and the comparison would be vacuous.
        tensor1_mps = torch.randn(1, 1, 72250, dtype=torch.half, device="mps")
        tensor2_mps = torch.randn(1, 72250, 1, dtype=torch.half, device="mps")
        matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)

        tensor1_cpu = tensor1_mps.to("cpu")
        tensor2_cpu = tensor2_mps.to("cpu")
        matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)

        self.assertEqual(matmul_cpu, matmul_mps.to("cpu"))
class MPSLeakyReluTest(TestCaseMPS):
    """Compares `torch.nn.LeakyReLU` forward and backward between CPU and MPS."""

    def _npLeakyRelu(self, np_features, negative_slope=0.1):
        """NumPy reference: max(x, negative_slope * x), cast back to the input dtype."""
        scaled = negative_slope * np_features
        return np.maximum(np_features, scaled).astype(np_features.dtype)

    def testNpLeakyRelu(self):
        """Sanity-check the NumPy reference on a hand-computed example."""
        features = np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]])
        expected = np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
                             [0.1, -0.03, 0.5, -0.07, 0.9]])
        torch.testing.assert_close(expected, self._npLeakyRelu(features, negative_slope=0.1))

    def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
        """Run LeakyReLU forward and backward on CPU and MPS copies of one random tensor."""
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')

        # Only tensors with at least two dims and no zero-sized dim are transposed.
        can_transpose = 0 not in shape and len(shape) >= 2
        if not contiguous and can_transpose:
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            mps_x = mps_x.transpose(0, 1)
            assert not mps_x.is_contiguous()

        cpu_x.requires_grad_()
        mps_x.requires_grad_()

        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass
        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')

        mps_leaky_relu.backward(gradient=mps_grad)
        cpu_leaky_relu.backward(gradient=cpu_grad)

        assert cpu_x.grad is not None  # Check that the grad is well-populated
        self.assertEqual(cpu_x.grad, mps_x.grad)

    def testNumbersCPU(self):
        """Sweep dtypes, shapes (scalar, empty, 1-D to 3-D) and contiguity."""
        dtypes_to_try = [torch.float, torch.half]
        shapes_to_try = [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]
        for t, shape, contiguous in product(dtypes_to_try, shapes_to_try, [True, False]):
            self._testLeakyRelu(shape,
                                dtype=t,
                                negative_slope=0.2,
                                contiguous=contiguous)
class TestAvgPool(TestCaseMPS):
    """Average-pooling checks, with sum/average reference helpers built on `unfold`."""

    def _sum_pool2d(self, x, kernel_size):
        """Sum over non-overlapping `kernel_size` windows of `x` (stride equals the kernel size)."""
        windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        """3-D analogue of `_sum_pool2d`, returning a `(1, numel)` view of the window sums."""
        # Because unfold does not support 3D sliding window we will split tensor to multiple tensors and calculate sum
        h = kernel_size[0]
        # Chunks shorter than `h` along dim 0 are dropped.
        splited_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # sum_pool2d assumes tensor in (1, 1, n, m) view, so unsqueeze two times
        splited_x = [self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for t in splited_x]
        joined_x = torch.cat(splited_x)
        return joined_x.view(1, joined_x.numel())

    def _avg_pool2d(self, x, kernel_size):
        """Reference 2-D average pool: window sums divided by the number of kernel elements."""
        # `math.prod` replaces `reduce(operator.mul, ...)`; `reduce` was never imported here.
        size = math.prod(kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        """Reference 3-D average pool: window sums divided by the number of kernel elements."""
        size = math.prod(kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_avg_pool2d_with_zero_divisor(self):
        """`divisor_override=0` must raise a RuntimeError."""
        self.assertRaisesRegex(RuntimeError, "divisor must be not zero",
                               lambda: F.avg_pool2d(torch.zeros(3, 3, 3), (2, 2), divisor_override=0))

    def test_doubletensor_avg_pool2d_with_divisor(self):
        """`divisor_override` results equal the window sums divided by that divisor."""
        n, m = 3, 3
        input = torch.rand(1, 1, n, m)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool2d_ceil_mode(self):
        """`ceil_mode=True` with padding must not produce NaNs on CPU or MPS."""
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
        y = torch.nn.functional.avg_pool2d(
            x.to('mps'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())

    # Test some cases for avg_pool2d which used to mismatch CPU results.
    # Addresses this issue: https://github.com/pytorch/pytorch/issues/160743
    def test_avg_pool2d_ceil_mode_mismatch(self):
        """`AvgPool2d` with ceil_mode and divisor_override gives the same output on MPS and CPU."""
        sizes = [
            (4, 2, 3),
            (5, 2, 3),
            (50, 2, 3),
            (4, 1, 2, 3),
            (4, 4, 2, 3),
            (2, 2, 4, 6),
            (5, 40, 60),
            (2, 2, 40, 60),
        ]

        kwargs = dict(kernel_size=[1, 3],
                      stride=[2, 3],
                      ceil_mode=True,
                      divisor_override=7)

        for input_size in sizes:
            model = torch.nn.AvgPool2d(**kwargs)
            x = torch.arange(math.prod(input_size), dtype=torch.float).reshape(input_size)
            out_cpu = model(x)
            out_mps = model(x.to("mps"))
            msg = f'{input_size=}, {kwargs=}'
            self.assertEqual(out_mps, out_cpu, msg=msg)
class TestMPS(TestCaseMPS):
def test_exp(self, device="mps", dtype=torch.float):
    """Compare `torch.exp` with `np.exp` on scaled multiples of pi/3.

    The complex scale factors are only used when `dtype` is complex.
    """
    for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
        b = torch.arange(18, dtype=dtype, device=device) / 3 * math.pi
        # Build the scalar on `device` as well, rather than on a hard-coded "mps".
        a = torch.tensor(v, dtype=dtype, device=device) * b
        self.compare_with_numpy(torch.exp, np.exp, a)
@xfailIf(MACOS_VERSION > 15.0)
def test_conv_raises_error(self, device='mps', dtype=torch.float):
    """A `Conv1d` with 65537 output channels must raise NotImplementedError on MPS."""
    conv = nn.Conv1d(1, 65537, 3, padding=1).to('mps')
    x = torch.ones([1, 1, 3])
    x_mps = x.to("mps")

    with self.assertRaises(NotImplementedError):
        conv(x_mps)
@xfailIf(MACOS_VERSION < 15.1)
def test_conv_high_channel_size(self):
    """`F.conv1d` with 65537 output channels gives the same result on MPS and CPU."""
    out_channels = 65537
    weight = torch.randn(out_channels, 1, 1)
    x = torch.ones([1, 1, 1])

    results = {}
    for device in ("cpu", "mps"):
        results[device] = F.conv1d(x.to(device), weight.to(device))

    self.assertEqual(results["cpu"], results["mps"])
def test_triu_inf(self, device="mps", dtype=torch.float):
    """`torch.triu` on a tensor filled with -inf matches CPU for diagonals -1, 0 and 1."""
    for diag in (-1, 0, 1):
        cpu_mask = torch.full((3, 6, 6), float("-inf"))
        mps_mask = cpu_mask.detach().clone().to('mps')

        expected = torch.triu(cpu_mask, diagonal=diag)
        result = torch.triu(mps_mask, diagonal=diag)
        self.assertEqual(expected, result)
def test_exp1(self, device="mps", dtype=torch.float):
    """`torch.exp` on `device` agrees with CPU to within 1e-8 (absolute and relative)."""
    values = [-0.1, 1.0, -0.9, 0.1]
    input = torch.tensor(values, device=device, dtype=dtype)

    on_device = torch.exp(input)
    on_cpu = torch.exp(input.cpu())

    # If exponentWithTensor: MPS call is used on M1 running 14.5 test will fail with
    # Mismatched elements: 3 / 4 (75.0%)
    # Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed)
    # Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed)
    self.assertEqual(on_device, on_cpu, atol=1e-8, rtol=1e-8)
def test_exp_strided_output(self):
    """`exp` of a permuted (non-contiguous) MPS tensor matches the CPU result."""
    base_mps = torch.rand((256, 10), device='mps')
    base_cpu = base_mps.to("cpu")

    # Swap the two dims on both copies before applying exp.
    permuted_mps = base_mps.permute(1, 0)
    permuted_cpu = base_cpu.permute(1, 0)

    self.assertEqual(permuted_mps.exp(), permuted_cpu.exp())
def _testLeakyRelu(self, np_features, negative_slope, device):
    """Compare LeakyReLU forward and backward between CPU and `device`.

    Args:
        np_features: NumPy array used as the input for both runs.
        negative_slope: slope passed to `torch.nn.LeakyReLU`.
        device: device for the non-CPU run (previously ignored in favour of a
            hard-coded 'mps').
    """
    cpu_x = torch.from_numpy(np_features).requires_grad_()
    mps_x = torch.from_numpy(np_features).to(device).requires_grad_()
    relu_op = torch.nn.LeakyReLU(negative_slope)

    cpu_leaky_relu = relu_op(cpu_x)
    mps_leaky_relu = relu_op(mps_x)
    torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

    # test backward pass
    cpu_grad = torch.ones_like(cpu_leaky_relu)
    mps_grad = cpu_grad.to(device)
    cpu_leaky_relu.backward(gradient=cpu_grad)
    mps_leaky_relu.backward(gradient=mps_grad)
    torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))
def testNumbersGPU(self):
    """Run the LeakyReLU comparison on a small float32 matrix with mixed signs."""
    values = [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]
    for t in [np.float32]:
        features = np.array(values).astype(t)
        self._testLeakyRelu(features, negative_slope=0.1, device="mps")
def test_fill(self):
    """`fill_` on a zero tensor gives the same result on MPS and CPU."""
    def helper(val, shape, dtype):
        filled = {}
        for device in ('mps', 'cpu'):
            zeros = torch.zeros(shape, device=device, dtype=dtype)
            filled[device] = zeros.fill_(val)

        self.assertEqual(filled['mps'], filled['cpu'])

    cases = [
        (0, [1024], torch.float32),
        (0.2, [2, 3], torch.float32),
        (0.2 + 0.5j, [2, 3], torch.complex64),
    ]
    for val, shape, dtype in cases:
        helper(val, shape, dtype)
def test_fill_storage_offset(self):
    """`fill_` on indexed views writes through to the parent identically on MPS and CPU."""
    # Case 1: fill row 1 of a (2, 10) tensor with a Python scalar.
    shape = [2, 10]
    val = 0.2
    parent_mps = torch.ones(shape, device="mps")
    row_mps = parent_mps[:][1].fill_(val)
    parent_cpu = torch.ones(shape, device="cpu")
    row_cpu = parent_cpu[:][1].fill_(val)

    self.assertEqual(row_mps, row_cpu)
    self.assertEqual(parent_mps, parent_cpu)

    # Case 2: fill single columns of a (1, 10) tensor with a 0-dim tensor value.
    shape = [1, 10]
    val = 0.0
    parent_mps = torch.ones(shape, device="mps")
    fill_value_mps = torch.tensor(val, device="mps")
    column_mps = parent_mps[:, 9].fill_(fill_value_mps)
    # Regression test for https://github.com/pytorch/pytorch/issues/114692
    parent_mps[:, 5].fill_(fill_value_mps)

    parent_cpu = torch.ones(shape, device="cpu")
    fill_value_cpu = torch.tensor(val, device="cpu")
    column_cpu = parent_cpu[:, 9].fill_(fill_value_cpu)
    parent_cpu[:, 5].fill_(fill_value_cpu)

    self.assertEqual(column_mps.to(device="cpu"), column_cpu)
    self.assertEqual(parent_mps.to(device="cpu"), parent_cpu)
def test_cdist_large(self, device="mps"):
    """`torch.cdist` (p=2) on 100x10 inputs matches `_brute_cdist` for every compute mode."""
    compute_modes = (
        'use_mm_for_euclid_dist_if_necessary',
        'use_mm_for_euclid_dist',
        'donot_use_mm_for_euclid_dist',
    )
    for cm in compute_modes:
        x = torch.randn(100, 10, device=device)
        y = torch.randn(100, 10, device=device)
        got = torch.cdist(x, y, p=2, compute_mode=cm)
        want = self._brute_cdist(x, y, p=2)
        self.assertEqual(want, got)
def test_cdist_large_batch(self, device="mps"):
    """Batched `torch.cdist` (p=2) on 4x3x100x10 inputs matches `_brute_cdist` for every compute mode."""
    compute_modes = (
        'use_mm_for_euclid_dist_if_necessary',
        'use_mm_for_euclid_dist',
        'donot_use_mm_for_euclid_dist',
    )
    for cm in compute_modes:
        x = torch.randn(4, 3, 100, 10, device=device)
        y = torch.randn(4, 3, 100, 10, device=device)
        got = torch.cdist(x, y, p=2, compute_mode=cm)
        want = self._brute_cdist(x, y, p=2)
        self.assertEqual(want, got)
def test_cdist_non_contiguous(self, device="mps"):
    # cdist must match the brute-force reference when either operand, or
    # both, is a non-contiguous transposed view.
    def check(x, y, x_contiguous, y_contiguous, cm):
        actual = torch.cdist(x, y, p=2, compute_mode=cm)
        expected = self._brute_cdist(x, y, p=2)
        # Confirm the inputs really have the layout this case is about.
        self.assertEqual(x.is_contiguous(), x_contiguous)
        self.assertEqual(y.is_contiguous(), y_contiguous)
        self.assertEqual(expected, actual)

    for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
        # Both operands non-contiguous.
        x = torch.randn(5, 7, device=device).mT
        y = torch.randn(5, 3, device=device).mT
        check(x, y, False, False, cm)

        # Only the second operand non-contiguous.
        x = torch.randn(7, 5, device=device)
        y = torch.randn(5, 3, device=device).t()
        check(x, y, True, False, cm)

        # Only the first operand non-contiguous.
        x = torch.randn(5, 7, device=device).t()
        y = torch.randn(3, 5, device=device)
        check(x, y, False, True, cm)
def test_cdist_non_contiguous_batch(self, device="mps"):
    # Batched variant of test_cdist_non_contiguous: the last two dims of one
    # or both operands are transposed views.
    def check(x, y, x_contiguous, y_contiguous, cm):
        actual = torch.cdist(x, y, p=2, compute_mode=cm)
        expected = self._brute_cdist(x, y, p=2)
        # Confirm the inputs really have the layout this case is about.
        self.assertEqual(x.is_contiguous(), x_contiguous)
        self.assertEqual(y.is_contiguous(), y_contiguous)
        self.assertEqual(expected, actual)

    for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
        # Both operands non-contiguous.
        x = torch.randn(4, 3, 2, 5, 7, device=device).mT
        y = torch.randn(4, 3, 2, 5, 3, device=device).mT
        check(x, y, False, False, cm)

        # Only the second operand non-contiguous.
        x = torch.randn(7, 2, 7, 5, device=device)
        y = torch.randn(7, 2, 5, 3, device=device).mT
        check(x, y, True, False, cm)

        # Only the first operand non-contiguous.
        x = torch.randn(4, 5, 7, device=device).mT
        y = torch.randn(4, 3, 5, device=device)
        check(x, y, False, True, cm)
def test_cdist_euclidean_large(self, device="mps"):
    # Forward + backward of Euclidean cdist on a large input must run
    # without raising; no values are compared.
    def _test_euclidean_large_cdist(sizex, sizey=None):
        sizey = sizex if sizey is None else sizey
        x = torch.randn(sizex, device=device, dtype=torch.float)
        y = torch.randn(sizey, device=device, dtype=torch.float)
        eps = 1e-6
        # to avoid extremum
        too_close = ((x - y) < eps).float()
        x = x - (too_close * 2 * eps)
        x.requires_grad_(True)
        y.requires_grad_(True)
        # Do a backward pass to check that it is valid for large
        # matrices
        torch.cdist(x, y, p=2).sum().backward()

    _test_euclidean_large_cdist((2000, 5))
def test_cdist_same_inputs(self, device="mps"):
    # Test to detect issues in cdist gradient calculation
    # When the distances are 0
    sizex = (1, 27, 32)
    for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
        x = torch.randn(sizex, device=device, dtype=torch.float)
        dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
        # y is an exact copy of x, so the diagonal of the distance matrix is 0.
        y = x.clone()
        x.requires_grad = True
        # NOTE(review): `p` is not forwarded to cdist, so every iteration
        # exercises the default norm. Left unchanged because forwarding it
        # may hit a different backward path on MPS -- confirm intent.
        d = torch.cdist(x, y)
        d.backward(dist_grad)
        # Check that the backward pass does not contain invalid
        # values such as nan or inf.  Use a TestCase assertion rather than
        # a bare `assert`, which is stripped when Python runs with -O.
        self.assertTrue(torch.isfinite(x.grad).all())
def _brute_cdist(self, x, y, p=2):
r1 = x.shape[-2]
r2 = y.shape[-2]
if r1 == 0 or r2 == 0:
return torch.empty(r1, r2, device=x.device)
return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
def test_cdist_norm(self, device="mps"):
    # Compare cdist with the brute-force reference over a grid of row
    # counts (r1, r2), feature sizes (m) and norm orders (p).
    # p=2 is included so the Euclidean branch below is actually exercised;
    # previously it was missing from the list and that branch was dead.
    for r1, m, r2, p in itertools.product(
            [3, 4], [2, 3], [4, 6], [0, 1, 2, 1.5, 2.5, float('inf')]):
        x = torch.randn(r1, m, device=device)
        y = torch.randn(r2, m, device=device)
        if p == 2:
            # The Euclidean case is checked under both explicit compute
            # modes with a looser absolute tolerance.
            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
                actual = torch.cdist(x, y, p=2, compute_mode=cm)
                expected = self._brute_cdist(x, y, p=2)
                self.assertEqual(expected, actual, rtol=0, atol=0.02)
        else:
            actual = torch.cdist(x, y, p=p)
            expected = self._brute_cdist(x, y, p=p)
            self.assertEqual(expected, actual)
def test_cdist_norm_batch(self, device="mps"):
    # Batched variant of test_cdist_norm with three leading batch dims.
    # p=2 is included so the Euclidean branch below is actually exercised;
    # previously it was missing from the list and that branch was dead.
    for r1, m, r2, p in itertools.product(
            [3, 4], [2, 3], [4, 6], [0, 2, 3, 1.5, 2.5, float('inf')]):
        x = torch.randn(2, 3, 6, r1, m, device=device)
        y = torch.randn(2, 3, 6, r2, m, device=device)
        if p == 2:
            # The Euclidean case is checked under both explicit compute
            # modes with a looser absolute tolerance.
            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
                actual = torch.cdist(x, y, p=2, compute_mode=cm)
                expected = self._brute_cdist(x, y, p=2)
                self.assertEqual(expected, actual, rtol=0, atol=0.02)
        else:
            actual = torch.cdist(x, y, p=p)
            expected = self._brute_cdist(x, y, p=p)
            self.assertEqual(expected, actual)
def test_mm(self):
    # ones(5, 6) @ ones(6, 5) sums six ones per entry, so every element is 6.
    lhs = torch.ones(5, 6).to("mps")
    rhs = torch.ones(6, 5).to("mps")
    product = torch.mm(lhs, rhs).cpu()
    torch.testing.assert_close(product, torch.full((5, 5), 6.0))
def test_linalg_cross(self):
    # linalg.cross on MPS must match CPU, both when returning a new tensor
    # and when writing into `out=`, for same-shaped and broadcastable inputs.
    def run_case(dtype, x_shape, y_shape):
        device = "mps"
        if dtype in (torch.int32, torch.int64):
            x = torch.randint(0, 99999, x_shape, dtype=dtype, device=device)
            y = torch.randint(0, 99999, y_shape, dtype=dtype, device=device)
        else:
            x = torch.rand(x_shape, dtype=dtype, device=device)
            y = torch.rand(y_shape, dtype=dtype, device=device)
        x_cpu = x.to("cpu")
        y_cpu = y.to("cpu")
        res1 = torch.linalg.cross(x, y, dim=1)
        # Empty tensors that the out= variants are expected to resize.
        res2 = torch.tensor((), dtype=dtype, device=device)
        res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
        res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
        torch.linalg.cross(x, y, dim=1, out=res2)
        torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, res1_cpu)
        self.assertEqual(res2, res2_cpu)

    def helper(dtype):
        run_case(dtype, (100, 3, 100), (100, 3, 100))
        # test for broadcastable inputs
        run_case(dtype, (1, 3, 2), (4, 3, 1))

    for dtype in [torch.int32, torch.int64, torch.float32]:
        helper(dtype)
def test_cross(self):
    # torch.cross along dim=1 must give the same result on MPS and CPU.
    lhs_mps = torch.randn(4, 3, device="mps")
    rhs_mps = torch.randn(4, 3, device="mps")
    lhs_cpu = lhs_mps.to("cpu")
    rhs_cpu = rhs_mps.to("cpu")
    got = torch.cross(lhs_mps, rhs_mps, dim=1)
    want = torch.cross(lhs_cpu, rhs_cpu, dim=1)
    self.assertEqual(got, want)
def test_addmm(self):
    # ones(5, 5) + ones(5, 6) @ ones(6, 5): each entry is 1 + 6 = 7.
    bias = torch.ones(5, 5).to("mps")
    lhs = torch.ones(5, 6).to("mps")
    rhs = torch.ones(6, 5).to("mps")
    result = torch.addmm(bias, lhs, rhs).to("cpu")
    torch.testing.assert_close(result, torch.full((5, 5), 7.0))
def test_bmm(self):
    # Batched matmul on MPS must match CPU in both values and shape.
    lhs_cpu = torch.randn(10, 3, 4)
    rhs_cpu = torch.randn(10, 4, 5)
    lhs_mps = lhs_cpu.detach().clone().to("mps")
    rhs_mps = rhs_cpu.detach().clone().to("mps")

    want = torch.bmm(lhs_cpu, rhs_cpu)
    got = torch.bmm(lhs_mps, rhs_mps)

    self.assertEqual(want, got)
    self.assertEqual(want.size(), got.size())
@xfailIf(MACOS_VERSION < 15.0)
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_large_bmm(self, dtype):
    # bmm over large low-precision batches; one randomly chosen batch is
    # verified against a CPU mm.
    B, M, N = 11, 20064, 128
    lhs = torch.randn(B, M, N, dtype=dtype, device='mps')
    rhs = torch.randn(B, N, M, dtype=dtype, device='mps')
    output_mps = torch.bmm(lhs, rhs)
    # For performance reasons, check only one(non-first) batch for correctness
    # TODO: Check two when https://github.com/pytorch/pytorch/issues/153560 is fixed
    picked = torch.randint(1, B, size=()).item()
    output_cpu = torch.mm(lhs[picked].cpu(), rhs[picked].cpu())
    # Using the low precision comparison for FP16
    if dtype == torch.float16:
        tol = 1e-2
    else:
        tol = None
    self.assertEqual(output_cpu, output_mps[picked], atol=tol, rtol=tol)
@parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_take_along_dim(self, dtype):
    # take_along_dim along dim 0 must match the CPU reference on MPS.
    values_cpu = torch.tensor([[-5.], [0.], [5.]], dtype=dtype)
    index_cpu = torch.tensor([[0], [1], [2]])
    want = torch.take_along_dim(values_cpu, index_cpu, 0)

    values_mps = values_cpu.detach().clone().to('mps')
    index_mps = index_cpu.detach().clone().to('mps')
    got = torch.take_along_dim(values_mps, index_mps, 0)

    self.assertEqual(got, want)
def test_addr(self):
    # ones(5, 10) + outer(ones(5), ones(10)): each entry is 1 + 1 = 2.
    base = torch.ones(5, 10).to("mps")
    col_vec = torch.ones(5).to("mps")
    row_vec = torch.ones(10).to("mps")
    result = torch.addr(base, col_vec, row_vec).to("cpu")
    torch.testing.assert_close(result, torch.full((5, 10), 2.0))
def test_trace(self):
    # trace of a 3x3 matrix must agree between CPU and MPS (value and shape).
    mat_cpu = torch.randn(3, 3)
    mat_mps = mat_cpu.detach().clone().to("mps")

    want = torch.trace(mat_cpu)
    got = torch.trace(mat_mps)

    self.assertEqual(want, got)
    self.assertEqual(want.size(), got.size())
def test_addbmm(self):
    # addbmm (batch-reduced matmul added to a matrix) must match CPU.
    base_cpu = torch.randn(3, 5)
    lhs_cpu = torch.randn(10, 3, 4)
    rhs_cpu = torch.randn(10, 4, 5)

    base_mps = base_cpu.detach().clone().to("mps")
    lhs_mps = lhs_cpu.detach().clone().to("mps")
    rhs_mps = rhs_cpu.detach().clone().to("mps")

    want = torch.addbmm(base_cpu, lhs_cpu, rhs_cpu)
    got = torch.addbmm(base_mps, lhs_mps, rhs_mps)

    self.assertEqual(want, got)
    self.assertEqual(want.size(), got.size())
def test_baddbmm(self):
    # baddbmm with non-trivial alpha/beta must match CPU, including inputs
    # that broadcast against the batched product.
    def helper(input_shape, batch1_shape, batch2_shape):
        base_cpu = torch.randn(input_shape)
        lhs_cpu = torch.randn(batch1_shape)
        rhs_cpu = torch.randn(batch2_shape)
        scale = {"beta": 0.8, "alpha": 1.2}

        base_mps = base_cpu.detach().clone().to("mps")
        lhs_mps = lhs_cpu.detach().clone().to("mps")
        rhs_mps = rhs_cpu.detach().clone().to("mps")

        want = torch.baddbmm(base_cpu, lhs_cpu, rhs_cpu, **scale)
        got = torch.baddbmm(base_mps, lhs_mps, rhs_mps, **scale)

        self.assertEqual(want, got)
        self.assertEqual(want.size(), got.size())

    shape_cases = [
        ((3, 5), (10, 3, 4), (10, 4, 5)),
        ((10, 3, 5), (10, 3, 4), (10, 4, 5)),
        ((1, 77, 77), (8, 77, 64), (8, 64, 77)),
    ]
    for input_shape, batch1_shape, batch2_shape in shape_cases:
        helper(input_shape=input_shape, batch1_shape=batch1_shape, batch2_shape=batch2_shape)
def test_local_scalar_dense_mps(self):
    # .item() on a one-element MPS tensor must return the CPU value.
    scalar_cpu = torch.randn(1)
    scalar_mps = scalar_cpu.to("mps")
    torch.testing.assert_close(scalar_cpu.item(), scalar_mps.item())
def test_linear_1d_weight(self):
    # F.linear with a 1-D weight, and with the equivalent (1, 8) weight,
    # must match the CPU result on MPS.
    device = 'cpu'
    for weight_shape in ([8], [1, 8]):
        projected = torch.rand(weight_shape).to(device)
        x = torch.rand([1, 2, 2, 8]).to(device)
        x_mps = x.to('mps')
        projected_mps = projected.to('mps')
        want = F.linear(x, projected)
        got = F.linear(x_mps, projected_mps)
        self.assertEqual(want, got)
def test_linear_bias(self):
    # nn.Linear with a bias of unusual shape (0-dim, or broadcastable 2-D)
    # must produce the same output on MPS as on CPU.
    def helper(bias_shape):
        x_cpu = torch.randn(2, 2, 2, 64, device="cpu")
        linear = torch.nn.Linear(64, 4, device="cpu")
        custom_bias = torch.randn(bias_shape, dtype=torch.float32, device="cpu")
        linear.bias = torch.nn.Parameter(custom_bias)
        y_cpu = linear(x_cpu)
        # Move the same module to MPS and rerun it on the same input.
        x_mps = x_cpu.to("mps")
        linear.to("mps")
        y_mps = linear(x_mps)
        self.assertEqual(y_cpu, y_mps)

    for bias_shape in ((), (2, 4)):
        helper(bias_shape)
def test_linear_errors(self):
    # F.linear must reject integer weights and mixed CPU<->MPS operands
    # with the expected error messages.
    size = (3, 3)

    # Unsupported dtypes
    with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"):
        inp = torch.rand(size, device='mps')
        int_weight = torch.randint(-10, 10, size, dtype=torch.int8, device='mps')
        F.linear(inp, int_weight)

    # Weights on wrong device
    with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"):
        inp = torch.rand(size, device='mps')
        cpu_weight = torch.rand(size, device='cpu')
        F.linear(inp, cpu_weight)

    # Input on wrong device
    with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"):
        cpu_inp = torch.rand(size, device='cpu')
        weight = torch.rand(size, device='mps')
        F.linear(cpu_inp, weight)
def test_linear_non_contiguous(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/161640
    # A strided (row-sliced) weight must give the same result as its
    # contiguous copy.
    full_weight = torch.randn(12, 8, device='mps')
    strided_weight = full_weight[::2, ::1]
    dense_weight = strided_weight.contiguous()

    inp = torch.randn(2, 8, device='mps')
    out_strided = F.linear(inp, strided_weight)
    out_dense = F.linear(inp, dense_weight)
    self.assertEqual(out_dense, out_strided)
def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
    # Shared driver for the test_linear* cases: builds matching CPU and MPS
    # nn.Linear modules, feeds them the same random input of `shape`, and
    # compares outputs.  With backward_pass=True it also compares first-order
    # gradients and then the results of a second backward through them.
    cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
    mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)

    # Use the same weights and bias as the ones from the cpu
    mps_linear.weight.data = cpu_linear.weight.data.detach().clone().to("mps")

    if bias:
        mps_linear.bias.data = cpu_linear.bias.data.detach().clone().to("mps")

    linear_mps_input = torch.randn(shape).to('mps')
    linear_cpu_input = linear_mps_input.detach().clone().to('cpu')

    if backward_pass:
        linear_mps_input = linear_mps_input.requires_grad_()
        linear_cpu_input = linear_cpu_input.requires_grad_()

    linear_cpu_output = cpu_linear(linear_cpu_input)
    linear_mps_output = mps_linear(linear_mps_input)

    self.assertEqual(linear_cpu_output, linear_mps_output.to('cpu'))
    self.assertEqual(linear_cpu_output.size(), linear_mps_output.size())

    if backward_pass:
        # The upstream gradients themselves require grad so that the second
        # backward below can populate cpu_grad.grad / grad.grad.
        cpu_grad = torch.rand_like(linear_cpu_output, requires_grad=True)
        grad = cpu_grad.detach().to('mps').requires_grad_()

        # create_graph=True keeps the first-order gradients differentiable.
        linear_cpu_output.backward(gradient=cpu_grad, create_graph=True)
        linear_mps_output.backward(gradient=grad, create_graph=True)

        self.assertEqual(linear_cpu_input.grad.size(), linear_mps_input.grad.size())
        self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

        self.assertEqual(cpu_linear.weight.grad.size(), mps_linear.weight.grad.size())
        self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
        if bias:
            self.assertEqual(cpu_linear.bias.grad.size(), mps_linear.bias.grad.size())
            self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)

        # test gradgrad
        x_grad_out = torch.rand_like(linear_cpu_input)
        x_grad_out_mps = x_grad_out.to("mps")
        w_grad_out = torch.rand_like(cpu_linear.weight)
        w_grad_out_mps = w_grad_out.to("mps")

        # Zero detached views of the .grad tensors before the second backward.
        linear_cpu_input.grad.detach().zero_()
        linear_mps_input.grad.detach().zero_()
        cpu_linear.weight.grad.detach().zero_()
        mps_linear.weight.grad.detach().zero_()
        if bias:
            b_grad_out = torch.rand_like(cpu_linear.bias)
            b_grad_out_mps = b_grad_out.to("mps")

            cpu_linear.bias.grad.detach().zero_()
            mps_linear.bias.grad.detach().zero_()

        # Backpropagate through the first-order gradients; retain_graph=True
        # because several backward calls follow one another here.
        linear_cpu_input.grad.backward(x_grad_out, retain_graph=True)
        linear_mps_input.grad.backward(x_grad_out_mps, retain_graph=True)
        cpu_linear.weight.grad.backward(w_grad_out, retain_graph=True)
        mps_linear.weight.grad.backward(w_grad_out_mps, retain_graph=True)
        if bias:
            cpu_linear.bias.grad.backward(b_grad_out, retain_graph=True)
            mps_linear.bias.grad.backward(b_grad_out_mps, retain_graph=True)

        self.assertEqual(cpu_grad.grad, grad.grad)
        self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad)
        self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad)
        if bias:
            self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad)
def test_linear1D(self):
    # 1-D input, with bias, forward only.
    self._linear_helper(2, 3, [2], bias=True, backward_pass=False)
def test_linear1D_backward(self):
    # 1-D input, with bias, forward and backward.
    self._linear_helper(2, 3, [2], bias=True, backward_pass=True)
def test_linear2D(self):
    # 2-D input, with bias, forward only.
    self._linear_helper(2, 3, (4, 2), bias=True, backward_pass=False)
def test_linear2D_backward(self):
    # 2-D input, with bias, forward and backward.
    self._linear_helper(2, 3, (4, 2), bias=True, backward_pass=True)
def test_linear2D_no_bias(self):
    # 2-D input, no bias, forward only.
    self._linear_helper(2, 3, (4, 2), bias=False, backward_pass=False)
def test_linear2D_no_bias_backward(self):
    # 2-D input, no bias, forward and backward.
    self._linear_helper(2, 3, (4, 2), bias=False, backward_pass=True)
def test_linear3D(self):
    # 3-D input, with bias, forward only.
    self._linear_helper(2, 3, (4, 5, 2), bias=True, backward_pass=False)
def test_linear3D_backward(self):
    # 3-D input, with bias, forward and backward.
    self._linear_helper(2, 3, (4, 5, 2), bias=True, backward_pass=True)
def test_linear3D_no_bias(self):
    # 3-D input, no bias, forward only.  bias=False is required here: with
    # bias=True this test was an exact duplicate of test_linear3D and the
    # bias-free 3-D path was never exercised.
    self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=False, backward_pass=False)
def test_linear3D_no_bias_backward(self):
    # 3-D input, no bias, forward and backward.  bias=False is required
    # here: with bias=True this test was an exact duplicate of
    # test_linear3D_backward and the bias-free 3-D path was never exercised.
    self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=False, backward_pass=True)
def test_linear_large(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/122045
    # F.linear without bias and with a very large out_features must match CPU.
    inp_cpu = torch.randn(9, 1024, 1, device='cpu')
    weight_cpu = torch.randn(50304, 1, device='cpu')
    inp_mps = inp_cpu.detach().clone().to('mps')
    weight_mps = weight_cpu.detach().clone().to('mps')

    want = F.linear(inp_cpu, weight_cpu, None)
    got = F.linear(inp_mps, weight_mps, None)

    self.assertEqual(want, got)
def test_uniform(self):
    # Checks torch.distributions.Uniform: sample shapes, log_prob/cdf outside
    # the support, and rsample gradients with respect to both bounds.
    low = torch.zeros(5, 5, requires_grad=True)
    high = (torch.ones(5, 5) * 3).requires_grad_()
    low_1d = torch.zeros(1, requires_grad=True)
    high_1d = (torch.ones(1) * 3).requires_grad_()

    # Sample shape is sample_shape + batch_shape.
    full_dist = Uniform(low, high)
    self.assertEqual(full_dist.sample().size(), (5, 5))
    full_dist = Uniform(low, high)
    self.assertEqual(full_dist.sample((7,)).size(), (7, 5, 5))
    small_dist = Uniform(low_1d, high_1d)
    self.assertEqual(small_dist.sample().size(), (1,))
    small_dist = Uniform(low_1d, high_1d)
    self.assertEqual(small_dist.sample((1,)).size(), (1, 1))
    self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))

    # Check log_prob computation when value outside range
    uniform = Uniform(low_1d, high_1d, validate_args=False)
    above_high = torch.tensor([4.0])
    below_low = torch.tensor([-1.0])
    for outside in (above_high, below_low):
        self.assertEqual(uniform.log_prob(outside).item(), -inf)

    # check cdf computation when value outside range
    self.assertEqual(uniform.cdf(below_low).item(), 0)
    self.assertEqual(uniform.cdf(above_high).item(), 1)

    # Draw the noise once directly, rewind the RNG, and let rsample consume
    # the same noise so the expected gradients are known in closed form.
    saved_state = torch.get_rng_state()
    rand = low.new(low.size()).uniform_()
    torch.set_rng_state(saved_state)
    draw = Uniform(low, high).rsample()
    draw.backward(torch.ones_like(draw))
    self.assertEqual(low.grad, 1 - rand)
    self.assertEqual(high.grad, rand)
    low.grad.zero_()
    high.grad.zero_()
def test_randperm(self, device="mps"):
    # Checks that torch.randperm on `device` yields a valid permutation of
    # range(n) for several sizes and dtypes, that the default dtype is long,
    # that n=0 yields empty tensors, and that a non-contiguous `out=` tensor
    # is accepted.
    rng_device = None
    for n in (5, 100, 50000, 100000):
        for dtype in (torch.long, torch.half, torch.float):
            if n > 2049 and dtype == torch.half:  # Large n for torch.half will raise an exception, do not test here.
                continue
            # NOTE(review): bfloat16 is not in the dtype tuple above, so this
            # check can never trigger as written -- confirm whether bfloat16
            # was meant to be tested.
            if n > 256 and dtype == torch.bfloat16:
                continue
            with torch.random.fork_rng(devices=rng_device):
                res1 = torch.randperm(n, dtype=dtype, device=device)
            res2 = torch.empty(0, dtype=dtype, device=device)
            # NOTE(review): res2 is filled through out= but never asserted on.
            torch.randperm(n, out=res2, dtype=dtype, device=device)
            # Sorting a valid permutation must give exactly 0..n-1.
            self.assertEqual(res1.cpu().sort().values.long(), torch.arange(n, device=device))

    # Default type is long
    for n in (100, 10000):
        self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)

    # randperm of 0 elements is an empty tensor
    res1 = torch.randperm(0)
    # `dtype` here is whatever the last iteration of the loop above left behind.
    res2 = torch.tensor(5, dtype=dtype, device=device)
    torch.randperm(0, out=res2)
    self.assertEqual(res1.numel(), 0)
    self.assertEqual(res2.numel(), 0)

    # Test non-contiguous tensors
    for n in (4, 5, 6, 10, 20):
        non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
        self.assertFalse(non_contiguous_tensor.is_contiguous())
        with torch.random.fork_rng(devices=rng_device):
            res = torch.randperm(n, dtype=torch.long, device=device)
        # NOTE(review): only `res` is verified below; the contents written
        # into non_contiguous_tensor are not checked.
        torch.randperm(n, out=non_contiguous_tensor)
        self.assertEqual(res.cpu().sort().values.long(), torch.arange(n, device=device))
# Test forward maxpool2d
def test_max_pool2d(self):
    # MaxPool2d forward and backward on MPS must match CPU across padding,
    # dilation and ceil_mode settings, with and without returned indices.
    def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):
        # An all-ones input makes every pooling window a tie.
        make_input = torch.ones if test_ties else torch.randn
        cpu_x = make_input(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation,
                                  ceil_mode=ceil_mode, return_indices=return_indices)

        out = pool(x)
        ref_out = pool(cpu_x)
        if return_indices is False:
            y, ref_y = out, ref_out
        else:
            y, idx = out
            ref_y, ref_idx = ref_out

        cpu_grad = torch.ones_like(ref_y)
        grad = cpu_grad.to('mps')

        y.backward(gradient=grad)
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y)
        if return_indices is not False:
            self.assertEqual(idx, ref_idx)
        self.assertEqual(x.grad, cpu_x.grad)

    # (input shape, pooling kwargs).  The (1, 1000, 1, 4) shapes emulate
    # max_pool1d with a 1 x 4 kernel.
    cases = [
        # Test with no batch dimension
        ((8, 4, 4), dict(ks=2)),
        ((2, 8, 4, 4), dict(ks=2)),
        ((1, 1000, 32, 32), dict(ks=4)),
        ((1, 1000, 1, 4), dict(ks=(1, 4))),  # test for max_pool1d
        # Test padding
        ((1, 1000, 32, 32), dict(ks=4, padding=1)),
        ((1, 1000, 1, 4), dict(ks=(1, 4), padding=(0, 1))),  # test for max_pool1d
        # Test dilation
        ((1, 1000, 32, 32), dict(ks=4, dilation=2)),
        ((1, 1000, 1, 4), dict(ks=(1, 4), padding=(0, 2))),  # test for max_pool1d
        # Test ceil mode
        ((1, 1000, 32, 32), dict(ks=4, ceil_mode=True)),
        ((1, 1000, 1, 4), dict(ks=(1, 4), ceil_mode=True)),  # test for max_pool1d
    ]

    for shape, kwargs in cases:
        helper(shape, **kwargs)

    # Test return indices
    for test_ties in [False, True]:
        for shape, kwargs in cases:
            helper(shape, return_indices=True, test_ties=test_ties, **kwargs)
def test_adaptive_avg_pool2d_output_size_one(self):
    # AdaptiveAvgPool2d((1, 1)) must equal the mean over the two spatial
    # dims and return an output with the expected strides.
    def helper(size, memory_format):
        x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)
        if memory_format == 'non_contiguous':
            x = x[::2, ::2, ::2, ::2]
        else:
            x = x.to(memory_format=memory_format)

        out = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
        batch, channels = x.size(0), x.size(1)
        ref_out = x.contiguous().mean((-1, -2)).view((batch, channels, 1, 1))

        out.sum().backward()    # make sure it doesn't crash

        self.assertEqual(out, ref_out)
        c = out.size(1)
        if memory_format == torch.channels_last:
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertEqual(out.stride(), [c, 1, c, c])
        else:
            self.assertTrue(out.is_contiguous())
            self.assertEqual(out.stride(), [c, 1, 1, 1])

    helper((2, 3, 6, 6), torch.contiguous_format)
def test_masked_scatter(self):
    # In-place masked_scatter_ on MPS must match CPU for 2-D to 5-D shapes.
    def helper(shape):
        src_mps = torch.randn(shape, device="mps")
        src_cpu = src_mps.detach().clone().cpu()

        mask_mps = torch.rand(shape, device="mps") < 0.6
        mask_cpu = mask_mps.detach().clone().cpu()

        dst_mps = torch.randn(shape, device="mps")
        dst_cpu = dst_mps.detach().clone().cpu()

        dst_mps.masked_scatter_(mask_mps, src_mps)
        dst_cpu.masked_scatter_(mask_cpu, src_cpu)

        self.assertEqual(dst_mps, dst_cpu)

    for shape in ([2, 5], [10, 10], [5, 10, 3], [10, 5, 10, 3], [10, 5, 10, 3, 20]):
        helper(shape)
def test_masked_fill(self):
    # masked_fill_ on MPS is checked against a reference built element by
    # element on CPU, then run once on a very large tensor.
    device = "mps"
    dtype = torch.float32
    mask_dtype = torch.bool
    num_dest = 10

    dst = torch.zeros(num_dest, dtype=dtype, device=device)
    mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
    val = random.random()
    expected = torch.zeros(num_dest, dtype=dtype)

    mask_cpu = mask.to("cpu")
    dst.masked_fill_(mask, val)
    for i, selected in enumerate(mask_cpu):
        if selected:
            expected[i] = val
    self.assertEqual(dst.to("cpu"), expected, atol=0, rtol=0)

    # Regression test for https://github.com/pytorch/pytorch/issues/143477
    # Allocating 48x25x1024x1024 tensor crashes on MacOS-13
    causal_mask = torch.triu(torch.ones(1024, 1024, device=device), diagonal=1).bool()
    scores = torch.rand(48, 25, 1024, 1024, device=device)
    scores.masked_fill_(causal_mask, 0)
def test_masked_fill__non_contiguous(self):
    # masked_fill_ on a transposed view with an all-False mask must leave
    # the data untouched, exactly as on CPU.
    shape = (3, 5)

    data_mps = torch.randn(shape, device="mps")
    data_cpu = data_mps.detach().clone().cpu()
    mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
    mask_cpu = mask_mps.detach().clone().cpu()

    view_mps = data_mps.T
    view_cpu = data_cpu.T
    neg_inf = float("-inf")

    view_mps.masked_fill_(mask_mps.T, neg_inf)
    view_cpu.masked_fill_(mask_cpu.T, neg_inf)

    self.assertEqual(view_mps, view_cpu)
    # Nothing was selected by the mask, so no -inf may appear.
    self.assertFalse((view_mps == neg_inf).any())
def test_nhwc_operation(self):
    # A channels_last CPU tensor copied to MPS must compare equal to the
    # original.
    def helper(shape, channels_last=False):
        import numpy as np
        np.random.seed(332)
        # Values lie in [128, 256).
        noise = np.random.random_sample(size=shape)
        arr = (256 - 128) * noise + 128
        cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        # This passes
        self.assertEqual(x, cpu_x)

    helper((2, 2, 2, 2), True)
# Test forward batch norm
def test_batch_norm(self):
    # Compares batch norm on MPS against CPU for 3-D, 4-D and 5-D inputs,
    # through both the functional API and the nn.BatchNorm{1,2,3}d modules,
    # checking outputs, running statistics and gradients.
    def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last=False,
               track_running_stats=True, test_module=False):
        # shape: input shape, with channels at index 1.
        # wts: attach a learnable weight and bias.
        # test_module: use the nn.BatchNormNd modules instead of
        #   torch.nn.functional.batch_norm.  Note that `training` is only
        #   passed to the functional call; the module path does not read it.

        import numpy as np
        np.random.seed(332)
        arr = (256 - 128) * np.random.random_sample(size=shape) + 128
        cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
        if (channels_last):
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        # One running statistic per channel.
        mean_shape = [shape[1]]
        cpu_running_mean = None
        cpu_running_var = None
        running_mean = None
        running_var = None
        if (track_running_stats):
            mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
            cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
            var_arr = 32 * np.random.random_sample(size=mean_shape)
            cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
            running_mean = cpu_running_mean.detach().clone().to('mps')
            running_var = cpu_running_var.detach().clone().to('mps')

        weight = None
        cpu_weight = None
        bias = None
        cpu_bias = None
        if (wts):
            cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
            weight = cpu_weight.detach().clone().to('mps').requires_grad_()
            cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
            bias = cpu_bias.detach().clone().to('mps').requires_grad_()

        y = None
        ref_y = None

        if (not test_module):
            y = torch.nn.functional.batch_norm(x, running_mean, running_var,
                                               weight=weight,
                                               bias=bias,
                                               training=training,
                                               momentum=momentum, eps=eps)
            ref_y = torch.nn.functional.batch_norm(cpu_x, cpu_running_mean, cpu_running_var,
                                                   weight=cpu_weight,
                                                   bias=cpu_bias,
                                                   training=training,
                                                   momentum=momentum, eps=eps)

        else:

            batchnorm_op = None
            mps_batchnorm_op = None

            # Pick the module class from the input rank.
            if (len(shape) == 3):
                batchnorm_op = torch.nn.BatchNorm1d(shape[1],
                                                    eps=eps,
                                                    momentum=momentum,
                                                    affine=wts,
                                                    track_running_stats=track_running_stats,
                                                    device='cpu')
                mps_batchnorm_op = torch.nn.BatchNorm1d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='mps')
            elif (len(shape) == 4):
                batchnorm_op = torch.nn.BatchNorm2d(shape[1],
                                                    eps=eps,
                                                    momentum=momentum,
                                                    affine=wts,
                                                    track_running_stats=track_running_stats,
                                                    device='cpu')
                mps_batchnorm_op = torch.nn.BatchNorm2d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='mps')
            elif (len(shape) == 5):
                batchnorm_op = torch.nn.BatchNorm3d(shape[1],
                                                    eps=eps,
                                                    momentum=momentum,
                                                    affine=wts,
                                                    track_running_stats=track_running_stats,
                                                    device='cpu')
                mps_batchnorm_op = torch.nn.BatchNorm3d(shape[1],
                                                        eps=eps,
                                                        momentum=momentum,
                                                        affine=wts,
                                                        track_running_stats=track_running_stats,
                                                        device='mps')

            # Overwrite the modules' state with the tensors prepared above
            # so both devices start from identical values.
            if (track_running_stats):
                batchnorm_op.running_mean = cpu_running_mean
                batchnorm_op.running_var = cpu_running_var
                mps_batchnorm_op.running_mean = running_mean
                mps_batchnorm_op.running_var = running_var
            if (wts):
                batchnorm_op.weight = torch.nn.Parameter(cpu_weight)
                batchnorm_op.bias = torch.nn.Parameter(cpu_bias)
                mps_batchnorm_op.weight = torch.nn.Parameter(weight)
                mps_batchnorm_op.bias = torch.nn.Parameter(bias)

            ref_y = batchnorm_op(cpu_x)
            y = mps_batchnorm_op(x)

        self.assertEqual(y, ref_y)
        if (not test_module):
            self.assertEqual(running_mean, cpu_running_mean)
            self.assertEqual(running_var, cpu_running_var)
        else:
            self.assertEqual(mps_batchnorm_op.running_mean, batchnorm_op.running_mean)
            self.assertEqual(mps_batchnorm_op.running_var, batchnorm_op.running_var)

        # Backward with the same random upstream gradient on both devices.
        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')
        ref_y.backward(gradient=cpu_grad)
        y.backward(gradient=grad)

        self.assertEqual(x.grad, cpu_x.grad)
        if (wts):
            if (not test_module):
                self.assertEqual(weight.grad, cpu_weight.grad)
                self.assertEqual(bias.grad, cpu_bias.grad)
            else:
                self.assertEqual(mps_batchnorm_op.weight.grad, batchnorm_op.weight.grad)
                self.assertEqual(mps_batchnorm_op.bias.grad, batchnorm_op.bias.grad)

    for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
        for test_module in [False, True]:
            for track_running_stats in [True, False]:
                # Only channels_last=False is exercised here.
                for channels_last in [False]:
                    if (channels_last and len(shape) != 4):
                        continue

                    # Running stats must be tracked in eval mode
                    if (track_running_stats):
                        helper(shape, eps=0, momentum=1, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                        helper(shape, eps=3, momentum=0.67, wts=True, training=False, channels_last=channels_last,
                               track_running_stats=track_running_stats, test_module=test_module)
                    # The training=True cases run with or without running stats.
                    helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
                    helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
                    helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
                    helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
                           track_running_stats=track_running_stats, test_module=test_module)
def test_batch_norm_backward(self):
    # Backward through two stacked BatchNorm2d layers, the second with
    # frozen affine parameters, must not crash.
    inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True)
    first_bn = torch.nn.BatchNorm2d(8).to("mps")
    frozen_bn = torch.nn.BatchNorm2d(8).to("mps")
    frozen_bn.weight.requires_grad = False
    frozen_bn.bias.requires_grad = False
    outputs = frozen_bn(first_bn(inputs))
    # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
    outputs.sum().backward()
def test_batch_norm_slices(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/133520
    # Batch norm applied to a slice along the batch dim must match CPU.
    cpu_norm = nn.BatchNorm2d(100, affine=False, device='cpu')
    mps_norm = nn.BatchNorm2d(100, affine=False, device='mps')

    data_cpu = torch.randn(100, 100, 35, 45).to('cpu')
    data_mps = data_cpu.to('mps')

    want = cpu_norm(data_cpu[5:])
    got = mps_norm(data_mps[5:])

    self.assertEqual(want, got)
def test_batch_norm_backward_weight_bias_gradients(self):
    # See issue: https://github.com/pytorch/pytorch/issues/156555
    # Weight and bias gradients of a training-mode BatchNorm1d under an MSE
    # loss must match between CPU and MPS.
    N, C, L = 4, 3, 5
    inputs = torch.randn(N, C, L)
    targets = torch.randn(N, C, L)

    bn_cpu = nn.BatchNorm1d(C, affine=True).cpu().train()
    bn_mps = nn.BatchNorm1d(C, affine=True).to('mps').train()
    # Start both modules from identical parameters and buffers.
    bn_mps.load_state_dict(bn_cpu.state_dict())

    out_cpu = bn_cpu(inputs)
    out_mps = bn_mps(inputs.to('mps'))

    loss_cpu = ((out_cpu - targets) ** 2).mean()
    loss_mps = ((out_mps - targets.to('mps')) ** 2).mean()
    loss_cpu.backward()
    loss_mps.backward()

    tolerances = dict(atol=1e-5, rtol=1e-5)
    self.assertEqual(bn_cpu.weight.grad, bn_mps.weight.grad, **tolerances)
    self.assertEqual(bn_cpu.bias.grad, bn_mps.bias.grad, **tolerances)
def test_layer_norm_backward(self):
    # Backward through two stacked LayerNorm layers on MPS, where the outer
    # layer has its affine parameters frozen, must complete.
    sample = torch.rand(4, 4, device="mps", requires_grad=True)
    inner_ln = torch.nn.LayerNorm(4).to("mps")
    outer_ln = torch.nn.LayerNorm(4).to("mps")
    # Only the outer layer's weight and bias stop requiring gradients.
    for frozen_param in (outer_ln.weight, outer_ln.bias):
        frozen_param.requires_grad = False
    total = outer_ln(inner_ln(sample)).sum()
    # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
    total.backward()
def test_norm(self):
    # Compares torch.norm on MPS against CPU for 1-D, 2-D and 3-D inputs,
    # with the default order, the infinity norm, p=1 and explicit dims.
    def check(mps_t, cpu_t, *args, **kwargs):
        self.assertEqual(torch.norm(mps_t, *args, **kwargs), torch.norm(cpu_t, *args, **kwargs))

    flat_mps = torch.arange(9, dtype=torch.float, device="mps") - 4
    square_mps = flat_mps.reshape((3, 3))
    flat_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
    square_cpu = flat_cpu.reshape((3, 3))

    # Default order first, then the infinity norm, each on the flat and
    # the reshaped tensor.
    for order_args in ((), (float('inf'),)):
        for mps_t, cpu_t in ((flat_mps, flat_cpu), (square_mps, square_cpu)):
            check(mps_t, cpu_t, *order_args)

    rows = [[1, 2, 3], [-1, 1, 4]]
    mat_mps = torch.tensor(rows, dtype=torch.float, device="mps")
    mat_cpu = torch.tensor(rows, dtype=torch.float, device="cpu")
    for reduce_kwargs in ({"dim": 0}, {"dim": 1}, {"p": 1, "dim": 1}):
        check(mat_mps, mat_cpu, **reduce_kwargs)

    cube_mps = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
    cube_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)
    check(cube_mps, cube_cpu, dim=(1, 2))

    # Norm of each leading slice taken separately, compared as a pair.
    per_slice_mps = torch.norm(cube_mps[0, :, :]), torch.norm(cube_mps[1, :, :])
    per_slice_cpu = torch.norm(cube_cpu[0, :, :]), torch.norm(cube_cpu[1, :, :])
    self.assertEqual(per_slice_mps, per_slice_cpu)
def test_linalg_vector_norm(self):
    # Compares torch.linalg.vector_norm on MPS against CPU for ord=0 and
    # for a fractional order (3.5), with and without an explicit dim.
    def check(mps_t, cpu_t, **kwargs):
        mps_result = torch.linalg.vector_norm(mps_t, **kwargs)
        cpu_result = torch.linalg.vector_norm(cpu_t, **kwargs)
        self.assertEqual(mps_result, cpu_result)

    zeros_and_values_mps = torch.tensor([0, 0, 0, 2, 3], dtype=torch.float, device="mps")
    zeros_and_values_cpu = zeros_and_values_mps.detach().clone().cpu()
    check(zeros_and_values_mps, zeros_and_values_cpu, ord=0)

    flat_mps = torch.arange(27, dtype=torch.float, device="mps") - 4
    flat_cpu = torch.arange(27, dtype=torch.float, device="cpu") - 4
    cube_mps = flat_mps.reshape(3, 3, 3)
    cube_cpu = flat_cpu.reshape(3, 3, 3)

    check(flat_mps, flat_cpu, ord=3.5)
    check(cube_mps, cube_cpu, ord=3.5)
    # Reduce along every axis of the 3-D tensor in turn.
    for axis in range(cube_mps.dim()):
        check(cube_mps, cube_cpu, ord=3.5, dim=axis)
def test_linalg_lu_factor_ex(self):
    # Compares torch.linalg.lu_factor_ex on MPS against CPU for a range of
    # matrix sizes, batch shapes and both check_errors settings.
    from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values
    build_matrix = partial(
        make_fullrank_matrices_with_distinct_singular_values, device="cpu", dtype=torch.float32
    )

    def run_lu_factor_ex_test(size, *batch_dims, check_errors, atol=1e-5, rtol=1e-6):
        cpu_matrix = build_matrix(*batch_dims, size, size)
        mps_matrix = cpu_matrix.to('mps')
        # First the matrix as generated, then its .mT view.
        for cpu_operand, mps_operand in ((cpu_matrix, mps_matrix), (cpu_matrix.mT, mps_matrix.mT)):
            expected = torch.linalg.lu_factor_ex(cpu_operand, check_errors=check_errors)
            actual = torch.linalg.lu_factor_ex(mps_operand, check_errors=check_errors)
            self.assertEqual(expected, actual, atol=atol, rtol=rtol)

    # even/odd matrix sizes crossed with even/odd batch sizes
    sweep = itertools.product([True, False], [1, 2, 3, 4], [1, 2, 4])
    for check_errors, size, batch_size in sweep:
        run_lu_factor_ex_test(size, batch_size, check_errors=check_errors)

    # test >3D matrices
    run_lu_factor_ex_test(32, 10, 10, check_errors=False)
    run_lu_factor_ex_test(32, 2, 2, 10, 10, check_errors=True)
    # big matrix check with batch size > 1
    run_lu_factor_ex_test(256, 2, check_errors=False, atol=3e-5, rtol=5e-6)
def test_linalg_solve(self):
    # Compares torch.linalg.solve on MPS against CPU for left and right
    # solves, for the coefficient matrix as generated and for its .mT view.
    from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values
    build_matrix = partial(
        make_fullrank_matrices_with_distinct_singular_values, device="cpu", dtype=torch.float32
    )

    def run_linalg_solve_test(size, *batch_dims):
        cpu_coeff = build_matrix(*batch_dims, size, size)
        mps_coeff = cpu_coeff.to('mps')
        for left in (True, False):
            # Three right-hand sides: columns for a left solve, rows otherwise.
            rhs_tail = (size, 3) if left else (3, size)
            cpu_rhs = torch.randn(*batch_dims, *rhs_tail, device='cpu', dtype=torch.float32)
            mps_rhs = cpu_rhs.to('mps')

            # Solve the system
            cpu_solution = torch.linalg.solve(cpu_coeff, cpu_rhs, left=left)
            mps_solution = torch.linalg.solve(mps_coeff, mps_rhs, left=left)
            self.assertEqual(cpu_solution, mps_solution)

            # Test with transposed matrices
            cpu_solution_t = torch.linalg.solve(cpu_coeff.mT, cpu_rhs, left=left)
            mps_solution_t = torch.linalg.solve(mps_coeff.mT, mps_rhs, left=left)
            self.assertEqual(cpu_solution_t, mps_solution_t)

    # even/odd matrix sizes crossed with even/odd batch sizes
    for size, batch_size in itertools.product([1, 2, 3, 4], [1, 2, 4]):
        run_linalg_solve_test(size, batch_size)

    # test >3D matrices
    run_linalg_solve_test(32, 10, 10)
    run_linalg_solve_test(32, 2, 2, 2, 2, 10, 10)
def test_linalg_solve_singular(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/163962
    # The second row is twice the first, so the matrix is singular and
    # solve() is expected to raise instead of returning a result.
    singular_matrix = torch.tensor([[1.0, 2.0], [2.0, 4.0]], device="mps")
    rhs = torch.rand_like(singular_matrix)
    with self.assertRaisesRegex(RuntimeError, "input matrix is singular"):
        torch.linalg.solve(singular_matrix, rhs)
def test_linalg_solve_with_broadcasting(self):
    # Compares torch.linalg.solve on MPS against CPU when each batch
    # element has a single right-hand side, shaped as a column vector for a
    # left solve and as a row vector for a right solve.
    #
    # `partial` and `torch` come from the module-level imports; only the
    # test-only matrix factory is imported locally.
    from torch.testing._internal.common_utils import (
        make_fullrank_matrices_with_distinct_singular_values
    )
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_arg = partial(make_fullrank, device="cpu", dtype=torch.float32)

    batch_size = 4
    size = 3
    A_cpu = make_arg(batch_size, size, size)
    A_mps = A_cpu.to('mps')

    for left in [True, False]:
        b_cpu = torch.randn(batch_size, size, device='cpu', dtype=torch.float32)
        b_mps = b_cpu.to('mps')
        if left:
            # (batch, size) -> (batch, size, 1)
            b_cpu = b_cpu.unsqueeze(-1)
            b_mps = b_mps.unsqueeze(-1)
        else:
            # (batch, size) -> (batch, 1, size)
            b_cpu = b_cpu.view(batch_size, 1, size)
            b_mps = b_mps.view(batch_size, 1, size)

        X_cpu = torch.linalg.solve(A_cpu, b_cpu, left=left)
        X_mps = torch.linalg.solve(A_mps, b_mps, left=left)
        self.assertEqual(X_cpu, X_mps)

        # Same check with the coefficient matrices as .mT views.
        X_cpu_t = torch.linalg.solve(A_cpu.mT, b_cpu, left=left)
        X_mps_t = torch.linalg.solve(A_mps.mT, b_mps, left=left)
        self.assertEqual(X_cpu_t, X_mps_t)
def test_linalg_det(self):
    # Compares torch.linalg.det on MPS against CPU for several matrix and
    # batch sizes, for the matrices as generated and for their .mT views.
    from torch.testing._internal.common_utils import make_fullrank_matrices_with_distinct_singular_values
    build_matrix = partial(
        make_fullrank_matrices_with_distinct_singular_values, device="cpu", dtype=torch.float32
    )

    def run_det_test(size, *batch_dims):
        cpu_matrix = build_matrix(*batch_dims, size, size)
        mps_matrix = cpu_matrix.to('mps')
        self.assertEqual(torch.linalg.det(cpu_matrix), torch.linalg.det(mps_matrix))

        # non-contiguous matrices
        cpu_transposed = cpu_matrix.mT
        mps_transposed = mps_matrix.mT
        self.assertEqual(torch.linalg.det(cpu_transposed), torch.linalg.det(mps_transposed))

    # even/odd matrix sizes crossed with even/odd batch sizes
    for size, batch_size in itertools.product([2, 3, 4], [1, 2, 4]):
        run_det_test(size, batch_size)

    # test >3D matrices
    run_det_test(32, 10, 10)
    run_det_test(32, 2, 2, 10, 10)
def test_layer_norm(self):
    # Compares LayerNorm forward results and input/weight/bias gradients
    # between CPU and MPS, with and without elementwise affine parameters,
    # for inputs as generated and for their .mT views.
    def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32, non_contiguous=False):
        cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()
        # Work on a private copy so the caller's list is never modified.
        normalized_shape = list(normalized_shape)
        if non_contiguous:
            x = x.mT
            cpu_x = cpu_x.mT
            # The .mT views are non-leaf tensors. Without retain_grad()
            # their .grad stays None after backward(), and the input
            # gradient comparison below would only compare None with None.
            x.retain_grad()
            cpu_x.retain_grad()
            # The last two input dims were swapped, so swap them here too.
            normalized_shape[-1], normalized_shape[-2] = normalized_shape[-2], normalized_shape[-1]

        cpu_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='cpu', dtype=dtype)
        mps_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='mps', dtype=dtype)
        cpu_wt = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
        wt = cpu_wt.detach().clone().to('mps').requires_grad_()
        cpu_bias = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
        bias = cpu_bias.detach().clone().to('mps').requires_grad_()

        if (elementwise_affine):
            # Replace the default parameters with matching random ones.
            cpu_op.weight = torch.nn.Parameter(cpu_wt)
            mps_op.weight = torch.nn.Parameter(wt)
            cpu_op.bias = torch.nn.Parameter(cpu_bias)
            mps_op.bias = torch.nn.Parameter(bias)

        cpu_result = cpu_op(cpu_x)
        result = mps_op(x)

        cpu_grad = torch.randn(cpu_result.shape)
        grad = cpu_grad.to('mps')

        cpu_result.backward(cpu_grad)
        result.backward(grad)

        self.assertEqual(result, cpu_result)
        self.assertEqual(x.grad, cpu_x.grad)
        if (elementwise_affine):
            self.assertEqual(mps_op.weight.grad, cpu_op.weight.grad)
            self.assertEqual(mps_op.bias.grad, cpu_op.bias.grad)

    for (elementwise_affine, non_contiguous) in itertools.product([True, False], [True, False]):
        helper((2, 2, 2, 2), [2, 2], elementwise_affine=elementwise_affine, non_contiguous=non_contiguous)
        helper((2, 3, 4, 5), [4, 5], elementwise_affine=elementwise_affine, non_contiguous=non_contiguous)
        helper((2, 3, 4, 5, 6), [4, 5, 6], elementwise_affine=elementwise_affine, non_contiguous=non_contiguous)

    # Regression test for https://github.com/pytorch/pytorch/issues/96113
    # A default-dtype LayerNorm applied to a float16 input must run.
    torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype=torch.float16))
def test_ifft(self):
    # See: https://github.com/pytorch/pytorch/issues/124096
    # A real signal sent through rfft and then irfft on MPS should come
    # back equal to the original.
    num_samples = 64
    original = torch.rand(num_samples, device=torch.device("mps"))
    spectrum = torch.fft.rfft(original)
    recovered = torch.fft.irfft(spectrum, n=original.shape[0])
    self.assertEqual(recovered, original)
def test_fftfreq(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/135223
    # fftfreq for a large number of bins must match between CPU and MPS.
    num_bins = 10**4
    expected = torch.fft.fftfreq(num_bins, device='cpu')
    actual = torch.fft.fftfreq(num_bins, device='mps')
    self.assertEqual(expected, actual)
def test_instance_norm(self):
    # Compares instance normalization between CPU and MPS, through both the
    # functional API and the InstanceNorm1d/2d/3d modules: forward output,
    # running statistics, and input/weight/bias gradients.
    def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
        import numpy as np

        def to_mps_leaf(cpu_tensor):
            return cpu_tensor.detach().clone().to('mps').requires_grad_()

        # Fixed seed so every helper call sees the same numpy samples.
        np.random.seed(332)
        raw_input = (256 - 128) * np.random.random_sample(size=shape) + 128
        cpu_x = torch.tensor(raw_input, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = to_mps_leaf(cpu_x)

        # One statistic / affine parameter per channel (dim 1 of the input).
        per_channel = [shape[1]]

        cpu_running_mean = cpu_running_var = None
        running_mean = running_var = None
        if track_running_stats:
            raw_mean = (240 - 140) * np.random.random_sample(size=per_channel) + 140
            cpu_running_mean = torch.tensor(raw_mean, device='cpu', dtype=torch.float)
            raw_var = 32 * np.random.random_sample(size=per_channel)
            cpu_running_var = torch.tensor(raw_var, device='cpu', dtype=torch.float)
            running_mean = cpu_running_mean.detach().clone().to('mps')
            running_var = cpu_running_var.detach().clone().to('mps')

        cpu_weight = cpu_bias = None
        weight = bias = None
        if wts:
            cpu_weight = torch.randn(per_channel, device='cpu', dtype=torch.float, requires_grad=True)
            weight = to_mps_leaf(cpu_weight)
            cpu_bias = torch.randn(per_channel, device='cpu', dtype=torch.float, requires_grad=True)
            bias = to_mps_leaf(cpu_bias)

        if test_module:
            # Pick the module class from the input rank.
            module_by_rank = {
                3: torch.nn.InstanceNorm1d,
                4: torch.nn.InstanceNorm2d,
                5: torch.nn.InstanceNorm3d,
            }
            norm_cls = module_by_rank[len(shape)]
            module_kwargs = dict(eps=eps, momentum=momentum, affine=wts,
                                 track_running_stats=track_running_stats)
            instancenorm_op = norm_cls(shape[1], device='cpu', **module_kwargs)
            mps_instancenorm_op = norm_cls(shape[1], device='mps', **module_kwargs)

            if track_running_stats:
                instancenorm_op.running_mean = cpu_running_mean
                instancenorm_op.running_var = cpu_running_var
                mps_instancenorm_op.running_mean = running_mean
                mps_instancenorm_op.running_var = running_var
            if wts:
                instancenorm_op.weight = torch.nn.Parameter(cpu_weight)
                instancenorm_op.bias = torch.nn.Parameter(cpu_bias)
                mps_instancenorm_op.weight = torch.nn.Parameter(weight)
                mps_instancenorm_op.bias = torch.nn.Parameter(bias)

            ref_y = instancenorm_op(cpu_x)
            y = mps_instancenorm_op(x)
        else:
            ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var,
                                                      weight=cpu_weight,
                                                      bias=cpu_bias,
                                                      momentum=momentum, eps=eps)
            y = torch.nn.functional.instance_norm(x, running_mean, running_var,
                                                  weight=weight,
                                                  bias=bias,
                                                  momentum=momentum, eps=eps)

        self.assertEqual(y, ref_y)
        if test_module:
            self.assertEqual(mps_instancenorm_op.running_mean, instancenorm_op.running_mean)
            self.assertEqual(mps_instancenorm_op.running_var, instancenorm_op.running_var)
        else:
            self.assertEqual(running_mean, cpu_running_mean)
            self.assertEqual(running_var, cpu_running_var)

        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')
        ref_y.backward(gradient=cpu_grad)
        y.backward(gradient=grad)

        self.assertEqual(x.grad, cpu_x.grad)
        if wts:
            if test_module:
                self.assertEqual(mps_instancenorm_op.weight.grad, instancenorm_op.weight.grad)
                self.assertEqual(mps_instancenorm_op.bias.grad, instancenorm_op.bias.grad)
            else:
                self.assertEqual(weight.grad, cpu_weight.grad)
                self.assertEqual(bias.grad, cpu_bias.grad)

    # The four eps/momentum/affine combinations exercised for every setting.
    base_configs = (
        dict(eps=1e-05, momentum=0.1, wts=False),
        dict(eps=0, momentum=1.0, wts=False),
        dict(eps=1, momentum=1, wts=True),
        dict(eps=3, momentum=0.67, wts=True),
    )
    sweep = itertools.product(
        [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)],  # shape
        [False, True],                               # test_module
        [True, False],                               # track_running_stats
        [False],                                     # channels_last
    )
    for shape, test_module, track_running_stats, channels_last in sweep:
        if channels_last and len(shape) != 4:
            continue
        common = dict(channels_last=channels_last,
                      track_running_stats=track_running_stats,
                      test_module=test_module)
        # With running statistics tracked, two extra configurations run
        # first, followed by an additional pass over the base ones.
        if track_running_stats:
            helper(shape, eps=0, momentum=1, **common)
            helper(shape, **common)
            for config in base_configs:
                helper(shape, **config, **common)
        for config in base_configs:
            helper(shape, **config, **common)
def test_weight_norm(self):
    # Applies the weight_norm parametrization to Linear, Conv2d and Conv3d
    # layers and compares outputs and gradients between CPU and MPS for
    # several values of `dim`.
    def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim):
        cpu_wrapped = torch.nn.utils.parametrizations.weight_norm(cpu_model, dim=dim)
        mps_wrapped = torch.nn.utils.parametrizations.weight_norm(model, dim=dim)

        cpu_out = cpu_wrapped(cpu_x)
        mps_out = mps_wrapped(x)
        self.assertEqual(cpu_out, mps_out)

        cpu_grad = torch.randn(cpu_out.shape)
        mps_grad = cpu_grad.to('mps')
        cpu_out.backward(gradient=cpu_grad)
        mps_out.backward(gradient=mps_grad)

        cpu_parts = cpu_model.parametrizations.weight
        mps_parts = model.parametrizations.weight
        self.assertEqual(cpu_parts.original0.grad, mps_parts.original0.grad)
        self.assertEqual(cpu_parts.original1.grad, mps_parts.original1.grad)
        self.assertEqual(x.grad, cpu_x.grad)

    def helper(dim, layer='linear', dtype=torch.float32):
        def to_mps_leaf(cpu_tensor):
            return cpu_tensor.detach().clone().to('mps').requires_grad_()

        # Conv variants differ only in module class and (unbatched) input shape.
        conv_variants = {
            'conv': (torch.nn.Conv2d, (3, 5, 5)),
            'conv3d': (torch.nn.Conv3d, (3, 5, 5, 4)),
        }

        if layer == 'linear':
            cpu_x = torch.randn((2, 5), device='cpu', dtype=dtype, requires_grad=True)
            x = to_mps_leaf(cpu_x)

            cpu_weight = torch.randn(10, 5, device='cpu', dtype=dtype, requires_grad=True)
            weight = to_mps_leaf(cpu_weight)
            cpu_bias = torch.randn(10, device='cpu', dtype=dtype, requires_grad=True)
            bias = to_mps_leaf(cpu_bias)

            cpu_linear = torch.nn.Linear(5, 10, device='cpu')
            linear = torch.nn.Linear(5, 10, device='mps')

            # Give both layers the same random parameters.
            with torch.no_grad():
                cpu_linear.weight.copy_(cpu_weight)
                cpu_linear.bias.copy_(cpu_bias)
                linear.weight.copy_(weight)
                linear.bias.copy_(bias)
            validate_weight_norm_equality(linear, cpu_linear, x, cpu_x, dim)
        elif layer in conv_variants:
            conv_cls, input_shape = conv_variants[layer]
            cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
            x = to_mps_leaf(cpu_x)

            cpu_conv = conv_cls(3, 3, 3, device='cpu')
            conv = conv_cls(3, 3, 3, device='mps')

            # The MPS layer takes over the CPU layer's initial parameters.
            with torch.no_grad():
                conv.weight.copy_(cpu_conv.weight)
                conv.bias.copy_(cpu_conv.bias)
            validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)

    # Original note on the conv3d cases: Conv3d is only available from
    # MacOS 13 onwards.
    dims_by_layer = (
        ('linear', (0, 1, -1)),
        ('conv', (0, 1, 2, 3, -1)),
        ('conv3d', (0, 1, 2, 3, 4, -1)),
    )
    for layer, dims in dims_by_layer:
        for dim in dims:
            helper(dim, layer=layer)
# Test conv2d
def test_conv2d_unit(self):
    # Compares functional conv2d between CPU and MPS: forward output plus
    # input, weight and (when present) bias gradients.
    def helper(input_shape, wt_shape,
               stride=1, padding=0,
               dilation=1, groups=1,
               bias_shape=None):
        def to_mps_leaf(cpu_tensor):
            return cpu_tensor.detach().clone().to('mps').requires_grad_()

        cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = to_mps_leaf(cpu_x)
        cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
        wt = to_mps_leaf(cpu_wt)

        has_bias = bias_shape is not None
        cpu_bias = bias = None
        if has_bias:
            cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
            bias = to_mps_leaf(cpu_bias)

        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
        y = torch.nn.functional.conv2d(x, wt, bias=bias, **conv_kwargs)
        ref_y = torch.nn.functional.conv2d(cpu_x, cpu_wt, bias=cpu_bias, **conv_kwargs)

        cpu_grad = torch.ones_like(ref_y)
        grad = cpu_grad.to('mps')
        y.backward(gradient=grad)
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
        self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
        if has_bias:
            self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05)

    # Single strided, padded case: (N, C_in, H, W) with (C_out, C_in, kH, kW).
    helper((1, 3, 64, 64), (64, 3, 4, 4), stride=2, padding=1)

    batch, in_ch, height, width = 4, 16, 32, 32
    out_ch, kernel = 8, 3
    for groups in [1, 2, 4]:
        small = ((batch, in_ch, height, width),
                 (out_ch, in_ch // groups, kernel, kernel),
                 out_ch)
        large = ((batch, in_ch * 2, height * 2, width * 2),
                 (out_ch * 2, (in_ch * 2) // groups, kernel + 2, kernel + 2),
                 out_ch * 2)
        for input_shape, wt_shape, bias_len in (small, large):
            # Without bias first, then with bias. Each configuration is run
            # twice back to back, mirroring the original call sequence.
            for bias_shape in (None, bias_len):
                for _ in range(2):
                    helper(input_shape, wt_shape, bias_shape=bias_shape, groups=groups)
# Test conv transpose 2d
def test_conv_transpose2d(self):
    # Compares functional conv_transpose2d between CPU and MPS over a sweep
    # of stride, padding, output_padding and dilation: forward output plus
    # input and weight gradients.
    def helper(input_shape, wt_shape,
               stride=1, padding=0,
               output_padding=0,
               dilation=1, groups=1,
               bias_shape=None):
        def to_mps_leaf(cpu_tensor):
            return cpu_tensor.detach().clone().to('mps').requires_grad_()

        cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = to_mps_leaf(cpu_x)
        cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
        wt = to_mps_leaf(cpu_wt)

        cpu_bias = bias = None
        if bias_shape is not None:
            cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
            bias = to_mps_leaf(cpu_bias)

        conv_kwargs = dict(stride=stride, padding=padding, output_padding=output_padding,
                           groups=groups, dilation=dilation)
        y = torch.nn.functional.conv_transpose2d(x, wt, bias=bias, **conv_kwargs)
        ref_y = torch.nn.functional.conv_transpose2d(cpu_x, cpu_wt, bias=cpu_bias, **conv_kwargs)

        cpu_grad = torch.randn(ref_y.shape)
        grad = cpu_grad.to('mps')
        y.backward(gradient=grad)
        ref_y.backward(gradient=cpu_grad)

        self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
        self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
        # NOTE: the bias-gradient comparison is disabled in this test; the
        # original carried it only as commented-out code:
        #   self.assertEqual(bias.grad, cpu_bias.grad)

    batch, height, width = 4, 32, 32
    wt_in_ch, wt_out_ch, kernel = 8, 2, 3
    # The input's channel count equals the weight's leading dimension.
    input_shape = (batch, wt_in_ch, height, width)
    wt_shape = (wt_in_ch, wt_out_ch, kernel, kernel)

    sweep = itertools.product([1, 2, 3], [0, 1, 2], [0, 1, 2], [1, 2])
    for stride, padding, output_padding, dilation in sweep:
        # Only combinations with output_padding below both stride and
        # dilation are exercised.
        if not (output_padding < stride and output_padding < dilation):
            continue
        conv_kwargs = dict(stride=stride, padding=padding,
                           output_padding=output_padding, dilation=dilation)
        # Without bias first, then with bias. Each configuration is run
        # twice back to back, mirroring the original call sequence.
        for bias_shape in (None, wt_out_ch):
            for _ in range(2):
                helper(input_shape, wt_shape, bias_shape=bias_shape, **conv_kwargs)
# Test sigmoid
def test_sigmoid(self):
    # Compares Sigmoid forward output and input gradient between CPU and MPS.
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_x = cpu_x.detach().clone().to('mps').requires_grad_()

        activation = torch.nn.Sigmoid()
        mps_y = activation(mps_x)
        cpu_y = activation(cpu_x)

        # Back-propagate a gradient of ones through both results.
        ones_cpu = torch.ones_like(cpu_y)
        ones_mps = ones_cpu.to('mps')
        mps_y.backward(gradient=ones_mps)
        cpu_y.backward(gradient=ones_cpu)

        self.assertEqual(mps_y, cpu_y)
        self.assertEqual(mps_x.grad, cpu_x.grad)

    for shape in ((2, 3, 4, 5), (2, 3, 4), (2, 8, 4, 5)):
        helper(shape)
# Test tanh
def test_tanh(self):
    # Compares Tanh forward output and input gradient between CPU and MPS.
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_x = cpu_x.detach().clone().to('mps').requires_grad_()

        activation = torch.nn.Tanh()
        mps_y = activation(mps_x)
        cpu_y = activation(cpu_x)

        # Back-propagate a gradient of ones through both results.
        ones_cpu = torch.ones_like(cpu_y)
        ones_mps = ones_cpu.to('mps')
        mps_y.backward(gradient=ones_mps)
        cpu_y.backward(gradient=ones_cpu)

        self.assertEqual(mps_y, cpu_y)
        self.assertEqual(mps_x.grad, cpu_x.grad)

    for shape in ((2, 3, 4, 5), (2, 3, 4), (2, 8, 4, 5)):
        helper(shape)
def test_threshold(self):
    # Compares nn.Threshold between CPU and MPS: the forward output always,
    # and the input gradient when the input requires grad.
    def helper(threshold, value, num_elems, inplace=False, requires_grad=True):
        layer = nn.Threshold(threshold=threshold, value=value, inplace=inplace)

        cpu_in = torch.randn(num_elems, requires_grad=requires_grad, dtype=torch.float)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_(requires_grad)

        cpu_out = layer(cpu_in)
        mps_out = layer(mps_in)

        ones_cpu = torch.ones_like(cpu_out)
        ones_mps = ones_cpu.to('mps')

        self.assertEqual(cpu_out, mps_out)

        if not requires_grad:
            return
        cpu_out.backward(gradient=ones_cpu)
        mps_out.backward(gradient=ones_mps)
        self.assertEqual(cpu_in.grad, mps_in.grad)

    cases = (
        dict(threshold=0.1, value=20, num_elems=2),
        dict(threshold=-0.1, value=10, num_elems=10),
        dict(threshold=0.5, value=-15, num_elems=100),
        # The in-place variant is run on an input that does not require grad.
        dict(threshold=1, value=10, num_elems=100, inplace=True, requires_grad=False),
    )
    for case in cases:
        helper(**case)
# Test pow
def test_pow(self):
    # Compares torch.pow between CPU and MPS for the three overloads:
    # tensor**tensor, tensor**scalar and scalar**tensor.
    def helper(shape):
        def random_pair():
            # A CPU float tensor and its copy on MPS.
            cpu_t = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            return cpu_t, cpu_t.detach().clone().to('mps')

        # aten::pow.Tensor_Tensor
        cpu_base, mps_base = random_pair()
        cpu_exponent, mps_exponent = random_pair()
        mps_result = torch.pow(mps_base, mps_exponent)
        cpu_result = torch.pow(cpu_base, cpu_exponent)
        self.assertEqual(mps_result, cpu_result)

        # aten::pow.Tensor_Scalar
        cpu_base, mps_base = random_pair()
        scalar_exponent = random.random()
        mps_result = torch.pow(mps_base, scalar_exponent)
        cpu_result = torch.pow(cpu_base, scalar_exponent)
        self.assertEqual(mps_result, cpu_result)

        # aten::pow.Scalar
        scalar_base = random.random()
        cpu_exponent, mps_exponent = random_pair()
        mps_result = torch.pow(scalar_base, mps_exponent)
        cpu_result = torch.pow(scalar_base, cpu_exponent)
        self.assertEqual(mps_result, cpu_result)

    helper((2, 8, 4, 5))
# Test addcmul
def test_addcmul(self):
    # Compares torch.addcmul between CPU and MPS for floating, integral and
    # mixed operand dtypes.
    def helper(shape, value, xtype=torch.float32, ytype=None, ztype=None):
        def rand_helper(dtype):
            if dtype.is_floating_point:
                return torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            return torch.randint(10, shape, dtype=dtype, device='cpu', requires_grad=False)

        # Operands that have no explicit dtype fall back to xtype.
        operand_dtypes = (
            xtype,
            xtype if ytype is None else ytype,
            xtype if ztype is None else ztype,
        )
        cpu_operands = [rand_helper(dtype) for dtype in operand_dtypes]
        mps_operands = [cpu_t.detach().clone().to('mps') for cpu_t in cpu_operands]

        mps_result = torch.addcmul(*mps_operands, value=value)
        cpu_result = torch.addcmul(*cpu_operands, value=value)
        self.assertEqual(mps_result, cpu_result)

    for shape, value in (((2, 3, 4, 5), 0.1), ((2, 8, 4, 5), 0.1),
                         ((2, 3, 4, 5), 0.2), ((2, 8, 4, 5), 0.2)):
        helper(shape, value)

    # Integral types
    helper((2, 2), 1.0, xtype=torch.int32)
    helper((2, 2), 2.0, xtype=torch.int16)

    # Mixed types
    helper((2, 2), 1.0, xtype=torch.float16, ytype=torch.float32)
    helper((3, 2), 1.0, ytype=torch.float16)
    helper((2, 3), 1.0, ztype=torch.float16)
    helper((2, 2), 1.0, xtype=torch.int32, ytype=torch.int16, ztype=torch.uint8)
    helper((2, 2), 1.0, ytype=torch.int16, ztype=torch.uint8)
# Test addcdiv
def test_addcdiv(self):
    # Compares torch.addcdiv between CPU and MPS, for both the functional
    # form and the out= variant.
    def helper(shape, value):
        def cpu_randn():
            return torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)

        def on_mps(cpu_tensor):
            return cpu_tensor.detach().clone().to('mps')

        cpu_input = cpu_randn()
        cpu_numerator = cpu_randn()
        # clamp to avoid division by 0
        cpu_denominator = cpu_randn().clamp_min_(0.1)
        cpu_out = cpu_randn()

        mps_input = on_mps(cpu_input)
        mps_numerator = on_mps(cpu_numerator)
        mps_denominator = on_mps(cpu_denominator)
        mps_out = on_mps(cpu_out)

        mps_result = torch.addcdiv(mps_input, mps_numerator, mps_denominator, value=value)
        cpu_result = torch.addcdiv(cpu_input, cpu_numerator, cpu_denominator, value=value)
        self.assertEqual(mps_result, cpu_result)

        # test .out variant
        out_variant = torch.addcdiv(mps_input, mps_numerator, mps_denominator, out=mps_out, value=value)
        self.assertEqual(out_variant, cpu_result)

    helper((2, 3, 4, 5), 0.1)
    helper((2, 8, 4, 5), 0.2)
    helper((2, 3, 4, 5), 1.0)  # value of 1 should be ignored internally
def test_addcdiv_transpose(self):
    # Regression test for issue https://github.com/pytorch/pytorch/issues/118115
    # Runs addcdiv over every combination of plain / transposed layout for
    # the three operands.
    def helper(shape, value):
        transposed_shape = shape[::-1]

        def make_operand(use_transpose):
            # Either a plain tensor, or the .t() view of a tensor that was
            # created with the reversed shape.
            if use_transpose:
                return torch.rand(transposed_shape, device="cpu").t()
            return torch.rand(shape, device="cpu")

        for x_t, y_t, z_t in itertools.product((False, True), repeat=3):
            x = make_operand(x_t)
            y = make_operand(y_t)
            z = make_operand(z_t)

            x_mps = x.detach().clone().to(device="mps")
            y_mps = y.detach().clone().to(device="mps")
            z_mps = z.detach().clone().to(device="mps")

            # The CPU reference is computed in place. The MPS copies were
            # taken before this call, so they hold the original values.
            result_cpu = x.addcdiv_(y, z, value=value)
            result_mps = x_mps.addcdiv(y_mps, z_mps, value=value)
            result_mps_out = result_cpu.detach().clone().to('mps')
            torch.addcdiv(x_mps, y_mps, z_mps, out=result_mps_out, value=value)

            self.assertEqual(result_cpu, result_mps)
            self.assertEqual(result_cpu, result_mps_out)

    for shape in ((2, 3), (100, 300)):
        for value in (1.0, 0.2):
            helper(shape, value)
def test_buffer_size_match(self):
    # this test shouldn't cause any crash
    # Multiplies a 1-D tensor with a 3-D tensor via `@` on both devices.
    side = 16
    cpu_vector = torch.rand(side, device='cpu')
    cpu_cube = torch.rand(side, side, side, device='cpu')
    mps_vector = cpu_vector.to('mps')
    mps_cube = cpu_cube.to('mps')
    expected = cpu_vector @ cpu_cube
    actual = mps_vector @ mps_cube
    self.assertEqual(expected, actual)
def test_transpose_inplace(self):
    # In-place transpose_ of a 3x3 tensor must match between CPU and MPS.
    rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    cpu_tensor = torch.tensor(rows, device='cpu')
    mps_tensor = torch.tensor(rows, device='mps')

    for tensor in (cpu_tensor, mps_tensor):
        tensor.transpose_(0, 1)
    self.assertEqual(cpu_tensor, mps_tensor.to('cpu'))
def test_expand_cpu_to_mps_copy(self):
    # https://github.com/pytorch/pytorch/issues/78642
    # Copying an expanded scalar tensor to MPS must preserve its values.
    target_shape = [10]
    copied_to_mps = torch.tensor(1).expand(target_shape).to("mps")
    cpu_reference = torch.tensor(1).expand(target_shape)
    self.assertEqual(cpu_reference, copied_to_mps.cpu())
def test_cpu_to_strided_mps_copy(self):
    # https://github.com/pytorch/pytorch/issues/86975
    # Assigning into a strided slice of an MPS tensor must give the same
    # result whether the source lives on CPU or on MPS.
    mps_device = torch.device("mps")

    def fresh_destination():
        return torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(mps_device)

    filled_from_cpu = fresh_destination()
    cpu_source = torch.Tensor([-1, -1])
    filled_from_cpu[1:, 1] = cpu_source

    filled_from_mps = fresh_destination()
    mps_source = torch.Tensor([-1, -1]).to(mps_device)
    filled_from_mps[1:, 1] = mps_source

    self.assertEqual(filled_from_cpu, filled_from_mps)
def test_view_slice_reshape(self):
    # Arithmetic on a sliced view must match between MPS and CPU.
    mps_source = torch.randn([1, 4, 4], device="mps")
    mps_view = mps_source[0, :1, 1:]
    cpu_source = mps_source.to("cpu")
    cpu_view = cpu_source[0, :1, 1:]

    mps_sum = mps_view + 1
    cpu_sum = cpu_view + 1
    self.assertEqual(mps_sum, cpu_sum)
def test_slice_reshape(self):
    # A slice followed by view/reshape must match between MPS and CPU, both
    # directly and after a further arithmetic op.
    mps_source = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps")
    cpu_source = mps_source.detach().clone().to("cpu")

    mps_view = mps_source[:, 3:].view(2, 3, 4, 1)
    cpu_view = cpu_source[:, 3:].view(2, 3, 4, 1)
    self.assertEqual(mps_view, cpu_view)

    mps_shifted = mps_view + 2
    cpu_shifted = cpu_view + 2
    self.assertEqual(mps_shifted, cpu_shifted)

    # Regression test for https://github.com/pytorch/pytorch/issues/143140
    def slice_and_reshape(tensor):
        return tensor[:, :, :, :3, :3].reshape(18, 1, 3)

    complex_mps = torch.rand(1, 1, 1, 4, 5, 6, dtype=torch.cfloat, device="mps")
    complex_cpu = complex_mps.detach().clone().cpu()
    self.assertEqual(slice_and_reshape(complex_cpu), slice_and_reshape(complex_mps).cpu())
def test_reshape_storage_offset(self):
    # https://github.com/pytorch/pytorch/issues/95883
    # Runs Linear -> cat -> transpose -> mse_loss on MPS and CPU, then
    # compares the losses and the gradients that reach the input tensors.
    B = 4
    T = 1

    lin_cpu = nn.Linear(10, 256)
    lin_mps = nn.Linear(10, 256, device="mps")

    # Use the same weights and bias as the ones from the cpu
    lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
    lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()

    # Keep separate handles on the leaf inputs. x_mps / x_cpu below refer
    # to non-leaf intermediates, whose .grad is None after backward(), so
    # comparing their .grad would only compare None with None.
    input_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
    input_cpu = input_mps.detach().clone().cpu().requires_grad_()
    x_mps = lin_mps(input_mps)
    x_cpu = lin_cpu(input_cpu)

    self.assertEqual(x_mps.shape, (B, T, 256))
    self.assertEqual(x_cpu.shape, (B, T, 256))

    cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
    cls_token_cpu = cls_token_mps.detach().clone().cpu()
    x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
    x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)

    x_mps = x_mps.transpose(0, 1)
    x_cpu = x_cpu.transpose(0, 1)

    target_mps = torch.rand_like(x_mps)
    target_cpu = target_mps.detach().clone().cpu()
    loss_mps = F.mse_loss(x_mps, target_mps)
    loss_cpu = F.mse_loss(x_cpu, target_cpu)
    self.assertEqual(loss_mps, loss_cpu)

    loss_mps.backward()
    loss_cpu.backward()
    # Guard against the comparison silently degenerating to None == None.
    self.assertIsNotNone(input_mps.grad)
    self.assertIsNotNone(input_cpu.grad)
    self.assertEqual(input_mps.grad, input_cpu.grad)
def test_stack_storage_offset(self):
    """``torch.stack`` over sliced inputs must give the same result on CPU and MPS.

    Regression test for https://github.com/pytorch/pytorch/issues/87856
    """
    cpu_row = torch.tensor([[1, 2]])
    mps_row = cpu_row.detach().clone().to("mps")
    stacked_cpu = torch.stack((cpu_row[:, :1], cpu_row[:, -1:]), dim=-1)
    stacked_mps = torch.stack((mps_row[:, :1], mps_row[:, -1:]), dim=-1)
    self.assertEqual(stacked_cpu, stacked_mps)

    mps_vec = torch.tensor([1, 2, 3, 4], device="mps")
    cpu_vec = mps_vec.detach().cpu().detach()
    # Pair the first half with the second half; the second half starts at element 2.
    res_mps = torch.stack((mps_vec[:2], mps_vec[2:]), dim=-1)
    res_cpu = torch.stack((cpu_vec[:2], cpu_vec[2:]), dim=-1)
    self.assertEqual(res_mps, res_cpu)
def test_unsafe_chunk(self):
    """Products of ``unsafe_chunk`` pieces must agree between CPU and MPS.

    Regression test for https://github.com/pytorch/pytorch/issues/91065
    """
    cpu_src = torch.rand(5, dtype=torch.float32, device="cpu")
    cpu_chunks = cpu_src.unsafe_chunk(4, 0)
    cpu_product = cpu_chunks[0] * cpu_chunks[2]

    mps_src = cpu_src.to("mps")
    mps_chunks = mps_src.unsafe_chunk(4, 0)
    mps_product = mps_chunks[0] * mps_chunks[2]

    self.assertEqual(cpu_product, mps_product)
def test_slice_casting(self):
    """Casting a sliced uint8 tensor to bool must match between CPU and MPS."""
    # generate random binary numbers
    probabilities = torch.empty(1, 1, 128, 128).uniform_(0, 1)
    cpu_in = torch.bernoulli(probabilities).to(torch.uint8)
    mps_in = cpu_in.detach().clone().to("mps")

    # check copy_cast(uint8 -> bool) on tensors with storage offset
    window = (slice(None), slice(None), slice(11, 12), slice(None, 12))
    cpu_out = cpu_in[window].to(torch.bool)
    mps_out = mps_in[window].to(torch.bool)
    self.assertEqual(cpu_out, mps_out)
def test_slice_reshape_contg_view(self):
    """Adding a scalar to a (1, 4800, 2) tensor must match between MPS and CPU."""
    # ``torch`` is already imported at module level, so the function-local
    # re-import that used to live here was redundant and has been dropped.
    x_mps = torch.randn(1, 4800, 2, device="mps")
    x_cpu = x_mps.detach().clone().cpu()

    r_mps = x_mps + 2
    r_cpu = x_cpu + 2

    self.assertEqual(r_mps, r_cpu)
def test_contiguous_slice_2d(self):
    """Compare ``slice + 1`` between MPS and CPU for many slices over the first two dims."""
    def helper(shape):
        # For every (i, j) within the first two dimensions, draw fresh random
        # data and check six slice/index combinations on those two dimensions.
        for i in range(shape[0]):
            for j in range(shape[1]):
                t_mps = torch.randn(shape, device="mps")
                t_cpu = t_mps.detach().clone().cpu()

                y_mps = t_mps[i:, :j]
                y_cpu = t_cpu[i:, :j]
                self.assertEqual(y_mps + 1, y_cpu + 1)

                y_mps = t_mps[i:, j]
                y_cpu = t_cpu[i:, j]
                self.assertEqual(y_mps + 1, y_cpu + 1)

                y_mps = t_mps[i, :j]
                y_cpu = t_cpu[i, :j]
                self.assertEqual(y_mps + 1, y_cpu + 1)

                y_mps = t_mps[:i, :j]
                y_cpu = t_cpu[:i, :j]
                self.assertEqual(y_mps + 1, y_cpu + 1)

                y_mps = t_mps[:i, j]
                y_cpu = t_cpu[:i, j]
                self.assertEqual(y_mps + 1, y_cpu + 1)

                y_mps = t_mps[:i, j:]
                y_cpu = t_cpu[:i, j:]
                self.assertEqual(y_mps + 1, y_cpu + 1)

    # Walk shapes of rank 2 through 5 with every extent in {1, 2}. ``l`` is
    # mutated in place: each loop level appends its extent and pops it again
    # at the end of the iteration.
    l = []
    for N in range(1, 3):
        l.append(N)
        for C in range(1, 3):
            l.append(C)
            helper(l)
            for D in range(1, 3):
                l.append(D)
                helper(l)
                for H in range(1, 3):
                    l.append(H)
                    helper(l)
                    for W in range(1, 3):
                        l.append(W)
                        helper(l)
                        l.pop()
                    l.pop()
                l.pop()
            l.pop()
        l.pop()

    # A few larger, hand-picked shapes.
    helper([9, 15, 4])
    helper([9, 3, 2])
    helper([3, 4, 18, 22])
    helper([3, 4, 18, 22, 150])
def test_contiguous_slice_3d(self):
    """Multiplying two sub-slices of a leading-dim slice must match between MPS and CPU."""
    mps_full = torch.randn(2, 3, 3, device="mps")
    cpu_full = mps_full.detach().clone().cpu()

    # Keep only the first entry along dim 0 on both devices.
    mps_head = mps_full[:1]
    cpu_head = cpu_full[:1]

    out = mps_head[:, 0:1, 0:1] * mps_head[:, 1:2, 1:2]
    out_cpu = cpu_head[:, 0:1, 0:1] * cpu_head[:, 1:2, 1:2]
    self.assertEqual(out, out_cpu)
def test_view_slice(self):
    """Element-wise gather through tensor indices must match between CPU and MPS.

    Regression test for https://github.com/pytorch/pytorch/issues/83995
    """
    NUM_SAMPLES = 60
    s = (0, 1)

    X = torch.rand(8000, 3, dtype=torch.float32, device='cpu')
    # The MPS copy must actually live on the MPS device. It was previously
    # created with ``.to("cpu")``, so the test only compared CPU against CPU.
    X_mps = X.detach().clone().to("mps")

    idx = torch.randint(0, X.shape[0], (1,)).repeat(len(s))
    pts = torch.randint(0, X.shape[0], (NUM_SAMPLES, X.shape[1]))
    idx_mps = idx.to("mps")
    pts_mps = pts.to("mps")
    # Overwrite the columns listed in ``s`` with the same repeated random index.
    pts[:, s] = idx
    pts_mps[:, s] = idx_mps

    actual_pts = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float)
    actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps")

    for i in range(NUM_SAMPLES):
        for j in range(X.shape[1]):
            actual_pts_mps[i, j] = X_mps[pts_mps[i, j], j]
            actual_pts[i, j] = X[pts[i, j], j]
            self.assertEqual(actual_pts[i, j], actual_pts_mps[i, j])
def test_slice_scatter(self):
    """Copying into a strided slice of another buffer must leave the source unchanged."""
    rows, cols = shape = (4, 4)
    source = torch.randint(10, shape, device="mps")
    snapshot = source.clone()

    # Destination is every other column of a buffer twice as wide as the source.
    wide_buffer = torch.empty(rows, cols * 2, device="mps")
    wide_buffer[:, ::2].copy_(source)

    torch.testing.assert_close(source, snapshot)
def test_slice(self):
    """Basic row/column slices and single-row indexing must match between CPU and MPS."""
    values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps', dtype=torch.float)

    # First two rows.
    self.assertEqual(cpu_x[:2, :], mps_x[:2, :])

    # First column, kept two-dimensional.
    self.assertEqual(cpu_x[:, :1], mps_x[:, :1])

    # Middle row as a 2-D slice; the MPS result is moved to CPU before comparing.
    self.assertEqual(cpu_x[1:2, :], mps_x[1:2, :].to('cpu'))

    # Middle row through an integer index.
    self.assertEqual(cpu_x[1, :], mps_x[1, :].to('cpu'))
@parametrize("torch_type", arg_values=[torch.float16, torch.float32, torch.bfloat16])
def test_slice_view_api(self, torch_type: torch.dtype):
    """Chains of view-producing ops must behave the same on MPS as on CPU.

    Each case applies ``y_func`` then ``z_func`` (and optionally ``r_func``)
    to a CPU tensor and to its MPS copy, comparing every intermediate result
    as well as the storage offset after ``z_func``.
    """
    def helper(x_tensor, y_func, z_func, r_func=None):
        x_mps = x_tensor.detach().clone().to("mps")

        y = y_func(x_tensor)
        y_mps = y_func(x_mps)
        self.assertEqual(y, y_mps)

        z = z_func(y)
        z_mps = z_func(y_mps)
        self.assertEqual(z, z_mps)
        self.assertEqual(z.storage_offset(), z_mps.storage_offset())

        if r_func:
            r = r_func(z)
            r_mps = r_func(z_mps)
            self.assertEqual(r, r_mps)

    # Skip bfloat16 before MacOS15
    if MACOS_VERSION < 15.0 and torch_type == torch.bfloat16:
        return

    # Tests for previously encountered MPS bugs
    helper(
        torch.randn(4, 4, dtype=torch_type),
        lambda x: x[1],
        lambda y: y.reshape(2, 2),
        lambda z: z + 1
    )
    helper(
        torch.randn(2, 4, dtype=torch_type),
        lambda x: x[1],
        lambda y: y + torch.ones(4, device=y.device)
    )
    helper(
        torch.randn(4, 6, dtype=torch_type),
        lambda x: x[1],
        lambda y: y.reshape(3, 2).t(),
        lambda z: z + 1
    )
    helper(
        # ``reshape`` replaces the deprecated non-inplace ``Tensor.resize``;
        # the element count is unchanged, so the resulting tensor is the same.
        torch.arange(4, dtype=torch_type).reshape(1, 2, 2),
        lambda x: x.permute(2, 0, 1),
        lambda y: y + 1
    )
    helper(
        torch.randn(4, 8, dtype=torch_type),
        lambda x: x.transpose(0, 1).reshape(-1),
        lambda y: y[:2],
        lambda z: z + 1
    )
    helper(
        torch.randn(1, dtype=torch_type),
        lambda x: x.expand(2, 3),
        lambda y: y + torch.ones(2, 3, device=y.device)
    )
def test_slice_reshape_contiguous(self):
    """Reshaping an indexed row must keep values and storage offset in sync with CPU."""
    cpu_mat = torch.randn(4, 4)
    mps_mat = cpu_mat.detach().clone().to("mps")

    cpu_row = cpu_mat[1]
    mps_row = mps_mat[1]
    self.assertEqual(cpu_row, mps_row)

    cpu_block = cpu_row.reshape(2, 2)
    mps_block = mps_row.reshape(2, 2)
    self.assertEqual(cpu_block, mps_block)
    self.assertEqual(cpu_block.storage_offset(), mps_block.storage_offset())
def test_scalar_from_slice_unary(self):
    """``torch.ceil`` on each element obtained by iterating an MPS tensor must match CPU.

    Regression test for https://github.com/pytorch/pytorch/issues/82543
    """
    mps_values = torch.tensor([1.0, 1.2], device="mps")
    for element in mps_values:
        expected = torch.ceil(element.to("cpu"))
        actual = torch.ceil(element)
        self.assertEqual(actual.cpu(), expected)
def test_scalar_from_slice_binary(self):
    """Binary ops between an iterated MPS element and a Python scalar must match CPU.

    Regression test for https://github.com/pytorch/pytorch/issues/82543
    """
    def check(binary_op):
        mps_values = torch.tensor([1.0, 1.2, 2.5, 1.0], device="mps")
        for element in mps_values:
            mps_out = binary_op(element, 1.0)
            cpu_out = binary_op(element.cpu(), 1.0)
            self.assertEqual(mps_out.cpu(), cpu_out)

    for op in (torch.sub, torch.add, torch.not_equal, torch.eq):
        check(op)
def test_slice_contiguous_view(self):
    """Comparisons and ``stack`` on contiguous slice views must match between MPS and CPU.

    Regression test for https://github.com/pytorch/pytorch/issues/77750
    """
    # Maps each operator name to the computation applied to (x, y), where
    # x is the second half and y the first half of the source tensor.
    operations = {
        "<=": lambda x, y: x <= y,
        "<": lambda x, y: x < y,
        ">=": lambda x, y: x >= y,
        # This branch previously evaluated ``>=``, so strict greater-than was
        # never exercised.
        ">": lambda x, y: x > y,
        "==": lambda x, y: x == y,
        "!=": lambda x, y: x != y,
        "stack": lambda x, y: torch.stack((y, x), dim=-1),
    }

    def helper(operator):
        t_mps = torch.tensor([1, 2, 3, 4], device="mps")
        t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")

        # contiguous view
        x_mps = t_mps[2:]  # 3, 4
        y_mps = t_mps[:2]  # 1, 2

        x_cpu = t_cpu[2:]
        y_cpu = t_cpu[:2]

        apply = operations[operator]
        self.assertEqual(apply(x_mps, y_mps), apply(x_cpu, y_cpu))

    for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
        helper(op)
def test_slice_of_slice(self):
    """``ne(0)`` on an indexed-then-unsqueezed element must match between CPU and MPS."""
    data = [0.5, 0.5]
    cpu_vec = torch.tensor(data, device="cpu")
    mps_vec = torch.tensor(data, device="mps")

    # Take the second element, then add a leading dimension back with ``None``.
    cpu_elem = cpu_vec[1][None]
    mps_elem = mps_vec[1][None]

    self.assertEqual(cpu_elem.ne(0), mps_elem.ne(0))
def test_index_storage_offset(self):
    """Scalars indexed out of a CPU tensor must keep their values when moved to MPS.

    Regression test for https://github.com/pytorch/pytorch/issues/78107
    """
    source = torch.tensor([8.2670e-01, -1.0293e+00])
    first_cpu = source[0]
    second_cpu = source[1]

    # Both scalars are views of ``source``: the first has a storage offset of
    # 0 and the second a storage offset of 1. The copy from 'cpu' to 'mps' has
    # to take that offset into account, otherwise the second scalar ends up
    # with the same value as the first.
    first_mps = first_cpu.to('mps')
    second_mps = second_cpu.to('mps')

    self.assertEqual(first_mps > second_mps, first_cpu > second_cpu)
    self.assertEqual(second_mps > first_mps, second_cpu > first_cpu)
def test_flatten(self):
    """``flatten`` with default, ``start_dim=1`` and ``end_dim=1`` must match between CPU and MPS."""
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    for flatten_kwargs in ({}, {"start_dim": 1}, {"end_dim": 1}):
        expected = cpu_x.flatten(**flatten_kwargs)
        actual = mps_x.flatten(**flatten_kwargs).to('cpu')
        self.assertEqual(expected, actual)
# Test repeat
def test_repeat(self):
    """Forward and backward of ``Tensor.repeat`` must match between CPU and MPS."""
    def check(shape, repeats):
        cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

        mps_out = mps_in.repeat(repeats)
        cpu_out = cpu_in.repeat(repeats)

        # Feed the same upstream gradient into both graphs.
        cpu_upstream = torch.randn(cpu_out.shape)
        mps_upstream = cpu_upstream.to('mps')

        mps_out.backward(gradient=mps_upstream)
        cpu_out.backward(gradient=cpu_upstream)

        self.assertEqual(mps_out, cpu_out)
        self.assertEqual(mps_in.grad, cpu_in.grad)

    cases = (
        ((2, 3, 4, 5), (2, 3, 4, 5)),
        ((2, 3, 4), (4, 3, 2, 5, 7, 2)),
        ((3, 4, 5), (2, 3, 4, 5)),
        ((3, 4, 5), (2, 2, 2)),
    )
    for shape, repeats in cases:
        check(shape, repeats)
def test_torch_repeat_interleave(self, device="mps"):
y = torch.tensor([[1, 2], [3, 4]], device=device)
# exercise single argument function signature
temp = y.repeat_interleave(2)
self.assertEqual(torch.Size([8]), temp.size())
for dtype in [torch.int, torch.long]:
lengths = torch.tensor([1, 2], dtype=dtype, device="mps")
output_size = torch.sum(lengths)
a = torch.repeat_interleave(
y,
lengths,
dim=0,
)
self.assertEqual(a.dtype, y.dtype)
self.assertEqual(a.size(), torch.Size([3, 2]))
a_with_output = torch.repeat_interleave(
y,
lengths,
dim=0,
output_size=output_size,
)
self.assertEqual(a_with_output.dtype, y.dtype)
self.assertEqual(a_with_output.size(), torch.Size([3, 2]))
def test_repeat_interleave(self, device="mps"):
x = torch.tensor([0, 1, 2, 3], device=device)
expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device)
self.assertEqual(torch.repeat_interleave(x), expected)
with self.assertRaises(RuntimeError):
torch.repeat_interleave(torch.arange(4, device=device).reshape(2, 2))
with self.assertRaises(RuntimeError):
torch.repeat_interleave(torch.arange(4.0, device=device))
with self.assertRaises(RuntimeError):
torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4], device=device))
y = torch.tensor([[1, 2], [3, 4]], device=device)
y1_v1 = torch.repeat_interleave(y, 2)
y1_v2 = torch.repeat_interleave(y, torch.tensor(2, device=device))
y1_v3 = torch.repeat_interleave(y, torch.tensor([2], device=device))
y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device)
self.assertEqual(y1_v1, y1_expect)
self.assertEqual(y1_v2, y1_expect)
self.assertEqual(y1_v3, y1_expect)
y2 = torch.repeat_interleave(y, 3, dim=1)
y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]], device=device)
self.assertEqual(y2, y2_expect)
y3 = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0)
y3_expect = torch.tensor([[1, 2],
[3, 4],
[3, 4]], device=device)
self.assertEqual(y3, y3_expect)
with self.assertRaises(RuntimeError):
torch.repeat_interleave(y, torch.tensor([1, 2, 3], device=device), dim=0)
with self.assertRaises(RuntimeError):
torch.repeat_interleave(y, torch.arange(9, device=device).reshape(3, 3), dim=0)
# test zero sized dimension
x = torch.zeros((5, 0), device=device)
y = torch.repeat_interleave(x, repeats=3, dim=1)
self.assertEqual(y, x.new_zeros(5, 0, device=device))
x = torch.tensor([], dtype=torch.int64, device=device)
y = torch.repeat_interleave(x, x)
self.assertEqual(y, x)
def test_repeat_interleave_simple(self):
    """``repeat_interleave`` with tensor repeat counts must match between MPS and CPU."""
    def check(shape, dtype=torch.float32, num_repeats=torch.Tensor(), dim=None):
        mps_in = torch.randn(shape, dtype=dtype, device="mps")
        cpu_in = mps_in.detach().clone().cpu()
        cpu_counts = num_repeats.detach().clone().cpu()

        mps_out = torch.repeat_interleave(mps_in, num_repeats, dim)
        cpu_out = torch.repeat_interleave(cpu_in, cpu_counts, dim)

        self.assertEqual(mps_out, cpu_out)

    # Single repeat count applied to a flat input.
    check(shape=3, num_repeats=torch.tensor([100], device="mps"))
    # One repeat count per entry of the chosen dim.
    check(shape=(2, 2), num_repeats=torch.tensor([3, 3], device="mps"), dim=0)
    check(shape=(10, 15, 8), num_repeats=torch.arange(10, device="mps"), dim=0)
    check(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1)
    check(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2)
def test_count_nonzero(self):
    """``torch.count_nonzero`` must match between CPU and MPS for several dtypes and dims."""
    data = [
        [[1, 0, 2], [3, 0, 2], [7, 9, -4]],
        [[0, 2, 3], [3, 2, 1], [2, 0, 0]],
    ]

    for dtype in (torch.int32, torch.int64, torch.float16, torch.float32):
        cpu_x = torch.tensor(data, dtype=dtype)
        mps_x = torch.tensor(data, dtype=dtype).to('mps')

        # All non-zeros (dim=None), then dim=1, then dim=(0, 1)
        for dim in (None, 1, (0, 1)):
            self.assertEqual(
                torch.count_nonzero(cpu_x, dim=dim),
                torch.count_nonzero(mps_x, dim=dim)
            )
def _test_module_empty_input(self, module, inp, check_size=True):
    """Run ``module`` forward and backward on ``inp`` and assert all gradients are zero.

    Args:
        module: Module to call; its trainable parameters are inspected afterwards.
        inp: Input tensor; it is switched to ``requires_grad`` in place.
        check_size: When true, also assert the output size equals the input size.
    """
    inp.requires_grad_(True)
    result = module(inp)
    upstream = torch.rand_like(result)
    result.backward(upstream)

    if check_size:
        self.assertEqual(result.size(), inp.size())

    trainable = (param for param in module.parameters() if param.requires_grad)
    for param in trainable:
        self.assertEqual(param.grad, torch.zeros_like(param.grad))

    self.assertEqual(inp.grad, torch.zeros_like(inp))
# Test dtype casting, with and without simultaneous device change
def test_to(self):
    """Dtype casts, with and without a simultaneous device change, must match CPU results."""
    values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    self.assertEqual(cpu_x.int(), mps_x.int().cpu())
    self.assertEqual(cpu_x.bool(), mps_x.bool().cpu())
    self.assertEqual(cpu_x.float(), mps_x.float().cpu())

    self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
                     torch.tensor(1, dtype=torch.int32))
    self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
                     torch.tensor(1, dtype=torch.int32))
    self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
                     torch.tensor(1.0))
    self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
                     torch.tensor(4, dtype=torch.int32))
    self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
                     torch.tensor(4, dtype=torch.int32))
    self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
                     torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))

    # Cast int8 and uint8 to float and compare results
    # See https://github.com/pytorch/pytorch/issues/80009 for more details
    cpu_byte = torch.tensor([60, 160, 20, 220], dtype=torch.uint8)
    # The signed sample has to be int8. It used to be declared as uint8, so
    # the int8 -> float path was never exercised.
    cpu_char = torch.tensor([60, -60, 20, -120], dtype=torch.int8)
    for x_cpu in [cpu_byte, cpu_char]:
        x_mps = x_cpu.to('mps')
        self.assertEqual(x_mps.to(torch.float32), x_cpu.to(torch.float32))
def test_setitem_scalar(self) -> None:
    """Writing Python scalars into single elements of an MPS tensor must stick."""
    device = 'mps'
    for dtype in [torch.int32, torch.float32, torch.int64]:
        for rows, cols in itertools.product(range(3, 6), repeat=2):
            grid = torch.zeros(rows, cols, dtype=dtype, device=device)
            self.assertEqual(grid.sum(), 0)

            grid[1, 1] = 1
            grid[2, 1] = cols
            grid[1, 2] = rows

            self.assertEqual(grid[1, 1], 1)
            self.assertEqual(grid[1, 2], rows)
            self.assertEqual(grid[2, 1], cols)
            # Only the three written cells contribute to the total.
            self.assertEqual(grid.sum(), 1 + rows + cols)
def test_stride_of_strides(self) -> None:
    """Copying an ``as_strided`` view of an ``as_strided`` view to CPU must give correct data."""
    base = torch.rand(32, 1, device='mps')
    widened = base.as_strided(size=(32, 2), stride=(1, 0))
    # Moving a strided view of a strided tensor to CPU used to crash with a
    # "buffer is not large enough." assert.
    # See https://github.com/pytorch/pytorch/issues/79181#issuecomment-1154683435
    moved = widened.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
    expected = base.to("cpu").as_strided(size=(32, 3), stride=(1, 0))
    self.assertEqual(expected, moved)
def test_type_casting(self):
    """``Tensor.type`` with legacy tensor types must match between CPU and MPS.

    Regression test for https://github.com/pytorch/pytorch/issues/81567
    """
    def check(data, to_dtype):
        cpu_src = torch.tensor(data)
        mps_src = cpu_src.to(torch.device('mps'))

        self.assertEqual(cpu_src.type(to_dtype), mps_src.type(to_dtype))

    sample = [9.0, 3.0, 5.0, 4.0]
    legacy_types = (
        torch.LongTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
    )
    for tensor_type in legacy_types:
        check(sample, tensor_type)
def test_to_casting(self):
    """``Tensor.to(dtype)`` must match between CPU and MPS.

    Regression test for https://github.com/pytorch/pytorch/issues/81567
    """
    def check(data, to_dtype):
        cpu_src = torch.tensor(data)
        mps_src = cpu_src.to(torch.device('mps'))

        self.assertEqual(cpu_src.to(to_dtype), mps_src.to(to_dtype))

    sample = [9.0, 3.0, 5.0, 4.0]
    target_dtypes = (
        torch.int64,
        torch.float,
        torch.int32,
        torch.short,
        torch.half,
        torch.int8,
        torch.uint8,
    )
    for target in target_dtypes:
        check(sample, target)
def test_storage_offset_greater_than_src_nbytes(self):
    """Slice views taken at increasing offsets of one buffer must survive a copy to MPS.

    Regression test for https://github.com/pytorch/pytorch/issues/80844
    """
    n_tensors = 100
    n_tensor_elems = 784
    elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)

    # create a list of contiguous view tensors (view tensor created by the slice op)
    views = [
        elems[n_tensor_elems * i : n_tensor_elems * (i + 1)]
        for i in range(n_tensors - 1)
    ]

    for i, chunk in enumerate(views):
        t = chunk.view(1, n_tensor_elems)
        t_mps = t.to("mps")
        self.assertEqual(t, t_mps.cpu(), f"i={i}")
# See https://github.com/pytorch/pytorch/issues/82427
# and https://github.com/pytorch/pytorch/issues/83692
def test_full_bugs(self):
    """``torch.full`` on MPS must accept a bool fill value and produce correct uint8 data."""
    # Test should not crash
    bool_filled = torch.full((3, 3), True, device='mps')

    # torch.full should work for uint8
    shape, fill = (2, 2), 247
    y_mps = torch.full(shape, fill, device='mps', dtype=torch.uint8)
    y_cpu = torch.full(shape, fill, device='cpu', dtype=torch.uint8)
    self.assertEqual(y_mps, y_cpu)
def test_div_bugs(self):
    """Integer division of 1..10 by 101 must be all zeros for both rounding modes."""
    rounding_modes = ['trunc', 'floor']
    for dtype, mode in itertools.product(integral_types(), rounding_modes):
        numerators = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
        quotients = torch.div(numerators, 101, rounding_mode=mode)
        self.assertEqual(quotients.sum(), 0)
# See https://github.com/pytorch/pytorch/issues/82663
def test_bool_expand(self):
    """Expanding a column and a row of bools to (2, 2) must give different tensors."""
    column = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
    row = torch.tensor([0, 1], dtype=torch.bool, device='mps')
    expanded_column = column.expand(2, 2)
    expanded_row = row.expand(2, 2)
    self.assertFalse(torch.equal(expanded_column, expanded_row))
def test_int_expand(self):
    """Expanding a column and a row of int8 values to (2, 2) must give different tensors."""
    column = torch.tensor([[1], [0]], dtype=torch.int8, device='mps')
    row = torch.tensor([0, 1], dtype=torch.int8, device='mps')
    expanded_column = column.expand(2, 2)
    expanded_row = row.expand(2, 2)
    self.assertFalse(torch.equal(expanded_column, expanded_row))
# Empty unary op should return tensor of the same size
def test_empty_neg(self):
    """Negating an empty MPS tensor must yield a tensor equal to the input."""
    empty = torch.tensor([[]], device='mps')
    negated = -empty
    self.assertEqual(empty, negated)
def _test_unique_scalar_empty(self, dtype, device, f):
    """Check the unique function ``f`` on a 0-d tensor and on a zero-sized tensor.

    Args:
        dtype: Dtype of the inputs that are built.
        device: Device of the inputs and expected values.
        f: Callable accepting ``(tensor, return_inverse=..., return_counts=...)``
            and returning ``(unique, inverse, counts)``.
    """
    full_output = dict(return_inverse=True, return_counts=True)

    # test scalar
    scalar = torch.tensor(0, dtype=dtype, device=device)
    unique, inverse, counts = f(scalar, **full_output)
    self.assertEqual(unique, torch.tensor([0], dtype=dtype, device=device))
    self.assertEqual(inverse, torch.tensor(0, device=device))
    self.assertEqual(counts, torch.tensor([1], device=device))

    # test zero sized tensor
    zero_sized = torch.zeros((0, 0, 3), dtype=dtype, device=device)
    unique, inverse, counts = f(zero_sized, **full_output)
    self.assertEqual(unique, torch.tensor([], dtype=dtype, device=device))
    self.assertEqual(inverse, torch.empty((0, 0, 3), dtype=torch.long, device=device))
    self.assertEqual(counts, torch.tensor([], dtype=torch.long, device=device))
def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
    """Check the unique function ``f`` against known results for every flag combination.

    For each combination of ``return_inverse`` / ``return_counts`` the outputs
    of ``f(x, ...)`` are compared with the expected tensors, and ``f`` is also
    run on ``x`` viewed as ``additional_shape`` with both flags enabled.
    """
    def as_tuple(result):
        # ``f`` returns a bare tensor when no optional output is requested.
        return (result,) if isinstance(result, torch.Tensor) else result

    for return_inverse in [True, False]:
        for return_counts in [True, False]:
            # test with expected
            outputs = as_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
            self.assertEqual(len(outputs), 1 + int(return_inverse) + int(return_counts))
            self.assertEqual(expected_unique, outputs[0])
            if return_inverse:
                self.assertEqual(expected_inverse, outputs[1])
            if return_counts:
                # Counts come last, after the inverse when that is present.
                counts_position = 1 + int(return_inverse)
                self.assertEqual(expected_counts, outputs[counts_position])

            # tests per-element unique on a higher rank tensor.
            reshaped = x.view(additional_shape)
            r_unique, r_inverse, r_counts = f(reshaped, return_inverse=True, return_counts=True)
            self.assertEqual(expected_unique, r_unique)
            self.assertEqual(expected_inverse.view(additional_shape), r_inverse)
            self.assertEqual(expected_counts, r_counts)
def test_unique_all_dtypes(self, device="mps"):
    """Exercise sorted and unsorted ``torch.unique`` / ``Tensor.unique`` for several dtypes.

    Args:
        device: Device on which inputs and expected values are created.
    """
    def ensure_tuple(x):
        # unique() returns a bare tensor when no optional output is requested.
        if isinstance(x, torch.Tensor):
            return (x,)
        return x

    def helper(dtype):
        if dtype is torch.bool:
            x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device)
            expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device)
            expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device)
            expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
        else:
            x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
            expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
            expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
            expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)

        # Same values as ``x``, but stored in every other slot of a buffer twice as long.
        x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x)
        xs = (x, x_sliced)

        # test sorted unique
        fs = (
            lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs),
            lambda x, **kwargs: x.unique(sorted=True, **kwargs),
        )
        for f, sample in product(fs, xs):
            self._test_unique_with_expects(device, dtype, f, sample, expected_unique, expected_inverse, expected_counts, (2, 2, 2))
            self._test_unique_scalar_empty(dtype, device, f)

        # test unsorted unique
        fs = (
            lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs),
            lambda x, **kwargs: x.unique(sorted=False, **kwargs)
        )
        for f, sample in product(fs, xs):
            self._test_unique_scalar_empty(dtype, device, f)
            for return_inverse, return_counts in product((True, False), repeat=2):
                ret = ensure_tuple(f(sample, return_inverse=return_inverse, return_counts=return_counts))
                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                x_list = sample.tolist()
                x_unique_list = ret[0].tolist()
                # Order is unspecified for sorted=False, so compare after sorting.
                self.assertEqual(expected_unique.tolist(), sorted(x_unique_list))
                if return_inverse:
                    x_inverse_list = ret[1].tolist()
                    # Every inverse index must point back at the original value.
                    for i, j in enumerate(x_inverse_list):
                        self.assertEqual(x_list[i], x_unique_list[j])
                if return_counts:
                    count_index = 1 + int(return_inverse)
                    x_counts_list = ret[count_index].tolist()
                    for value, reported in zip(x_unique_list, x_counts_list):
                        self.assertEqual(reported, x_list.count(value))

    # A plain loop: helper() is called only for its assertions.
    for dtype in [torch.float32, torch.int64, torch.int32, torch.int16, torch.uint8]:
        helper(dtype)
def test_unique(self):
    """``torch.unique`` must match between CPU and MPS for all optional-output combinations."""
    def check(cpu_x, return_inverse, return_counts):
        mps_x = cpu_x.detach().clone().to('mps')
        flags = dict(return_inverse=return_inverse, return_counts=return_counts)

        mps_result = torch.unique(mps_x, **flags)
        cpu_result = torch.unique(cpu_x, **flags)

        self.assertEqual(mps_result, cpu_result)

    check(torch.tensor([1, 2, 4, 2, 1]), False, False)
    # Random length-10 inputs, one per flag combination.
    for return_inverse, return_counts in ((False, False), (True, False), (False, True), (True, True)):
        check(torch.randint(3, (10, )), return_inverse, return_counts)
    # Single-element and empty inputs.
    for length in (1, 0):
        check(torch.randint(3, (length, )), True, True)

    # Regression test for https://github.com/pytorch/pytorch/issues/104879
    pair = torch.arange(2, device="mps")
    self.assertEqual(pair.reshape(1, 1, 2).unique(), pair)
def test_unique_consecutive(self):
    """``torch.unique_consecutive`` must match between CPU and MPS along dims 0 and 1."""
    def check(cpu_x, dim, return_inverse, return_counts):
        mps_x = cpu_x.detach().clone().to('mps')
        options = dict(dim=dim, return_inverse=return_inverse, return_counts=return_counts)

        mps_result = torch.unique_consecutive(mps_x, **options)
        cpu_result = torch.unique_consecutive(cpu_x, **options)

        self.assertEqual(mps_result, cpu_result)

    # 1-D inputs along dim 0.
    check(torch.tensor([1, 2, 4, 2, 1]), 0, False, False)
    flag_sequence = ((False, False), (True, False), (False, True), (True, True), (True, True))
    for return_inverse, return_counts in flag_sequence:
        check(torch.randint(3, (10, )), 0, return_inverse, return_counts)
    for length in (1, 0):
        check(torch.randint(3, (length, )), 0, True, True)

    grid = [[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]

    # 2-D inputs along dim 0.
    check(torch.tensor(grid), 0, False, False)
    check(torch.tensor(grid), 0, True, True)
    for rows in (20, 1, 0):
        check(torch.randint(2, (rows, 2)), 0, True, True)

    # 2-D inputs along dim 1.
    check(torch.tensor(grid), 1, False, False)
    check(torch.tensor(grid), 1, True, True)
    for cols in (20, 1, 0):
        check(torch.randint(2, (2, cols)), 1, True, True)
# See https://github.com/pytorch/pytorch/issues/85675
def test_cat_non_contiguous(self):
    """Concatenating non-contiguous slices of a channels-last tensor must match CPU."""
    def split_and_concat(data, dim):
        # Split along dim 2 and glue the halves back together along ``dim``.
        front = data[:, :, :2, :]
        back = data[:, :, 2:, :]
        self.assertFalse(front.is_contiguous())
        self.assertFalse(back.is_contiguous())
        return torch.concat((front, back), dim=dim)

    for dtype in MPS_DTYPES:
        if dtype == torch.bool:
            continue
        cpu_data = torch.arange(48).to(dtype=dtype).reshape(1, 2, 4, 6)
        cpu_data = cpu_data.to(memory_format=torch.channels_last)
        mps_data = cpu_data.to("mps")
        self.assertEqual(cpu_data, mps_data)
        for dim in range(cpu_data.dim()):
            cpu_result = split_and_concat(cpu_data, dim)
            mps_result = split_and_concat(mps_data, dim)
            self.assertEqual(cpu_result, mps_result.to("cpu"))
            # TODO: enable memory format test
            # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())
# Skip if a test needs more memory than the system has.
def _skip_if_exceeds_total_memory(self, required_memory):
    """Skip the current test when ``required_memory`` is larger than ``total_memory``.

    Both quantities are divided by 1024**3 and reported as GiB in the skip message.
    """
    if total_memory >= required_memory:
        return

    gib = 1024**3
    self.skipTest(
        f"Needs {required_memory / gib:0.01f} GiB RAM, "
        f"but only {total_memory / gib:0.01f} GiB is available.")
@parametrize("dtype", MPS_DTYPES)
def test_cat_large_tensor(self, dtype):
    """Concatenate a tensor with more than 2**31 elements on MPS and spot-check it against CPU."""
    a_shape = (1, 11 + (1 << 31), 1)
    b_shape = (1, 100, 1)

    # Assume up to 1% extra overhead memory might be required.
    # NOTE(review): the estimate adds prod(a_shape) twice and never uses
    # b_shape - confirm whether the second term was meant to be b_shape.
    required_memory = 1.01 * (math.prod(a_shape) + math.prod(a_shape)) * dtype.itemsize
    self._skip_if_exceeds_total_memory(required_memory)

    # `a_cpu` is a single element expanded to `a_shape`.
    a_cpu = make_tensor((1,), dtype=dtype, device='cpu').expand(a_shape)
    b_cpu = make_tensor(b_shape, dtype=dtype, device='cpu')

    r_cpu = torch.cat([a_cpu, b_cpu], dim=1)

    # Pick a subset of output elements to compare, because comparing all of
    # them takes too long.
    rand_indices = torch.randint(0, a_cpu.shape[1] + b_cpu.shape[1], (10_000,))
    r_cpu_part0 = r_cpu[:, rand_indices, :].clone()
    r_cpu_part1 = r_cpu[:, -200:, :].clone()
    r_cpu_part2 = r_cpu[:, :200, :].clone()

    # Delete the CPU result to free up memory for the MPS run.
    del r_cpu

    # Move only the underlying storage of `a_cpu` to MPS and re-apply the
    # CPU tensor's size and stride on top of it.
    a_mps = (
        torch.empty(0, dtype=dtype, device='mps')
        .set_(a_cpu.untyped_storage().mps())
        .as_strided(size=a_cpu.size(), stride=a_cpu.stride())
    )
    b_mps = b_cpu.to('mps')

    try:
        r_mps = torch.cat([a_mps, b_mps], dim=1)
    except RuntimeError as e:
        # An "Invalid buffer size" error is treated as a skip, not a failure.
        if "Invalid buffer size" in str(e):
            self.skipTest(f"Exceeds max buffer size for MPS: {str(e)}.")
        raise e

    self.assertEqual(r_mps[:, rand_indices, :], r_cpu_part0)
    self.assertEqual(r_mps[:, -200:, :], r_cpu_part1)
    self.assertEqual(r_mps[:, :200, :], r_cpu_part2)
def test_large_tensor_to_string(self):
    """``str()`` of an int8 MPS tensor with two rows of 2**31 ones must give the summarized repr."""
    shape = (2, 1 << 31)

    # Assume up to 1% extra overhead memory might be required.
    required_memory = 1.01 * 2 * math.prod(shape)
    self._skip_if_exceeds_total_memory(required_memory)

    # NOTE(review): the expected text is compared character for character, so
    # its internal spacing must match torch's tensor repr exactly - confirm
    # against real output before editing these literals.
    self.assertEqual(
        str(torch.ones(shape, dtype=torch.int8, device='mps')),
        (
            "tensor([[1, 1, 1, ..., 1, 1, 1],\n"
            " [1, 1, 1, ..., 1, 1, 1]], device='mps:0', dtype=torch.int8)"
        ),
    )
# See https://github.com/pytorch/pytorch/issues/152701
def test_jacfwd_cat(self):
    """``torch.func.jacfwd`` through ``torch.cat`` on MPS must return a (5, 2) Jacobian."""
    first = torch.rand(2, device="mps")
    second = torch.rand(3, device="mps")

    concat_jacobian = torch.func.jacfwd(lambda a, b: torch.cat((a, b)))
    jacobian = concat_jacobian(first, second)

    self.assertEqual(jacobian.shape, (5, 2))
# See https://github.com/pytorch/pytorch/issues/85967
def test_from_numpy_non_contiguous(self):
    """Building an MPS tensor from a column-sliced NumPy array must match the CPU tensor."""
    # Keep the first two columns of a 3x3 array.
    sliced = np.arange(9).reshape(3, 3)[:, :2]
    from_cpu = torch.tensor(sliced, device="cpu")
    from_mps = torch.tensor(sliced, device="mps")
    self.assertEqual(from_cpu, from_mps.to("cpu"))
# See https://github.com/pytorch/pytorch/issues/86954
def test_copy_non_contiguous(self):
    """Copies of permuted / strided tensors between CPU and MPS must preserve the data."""
    permuted = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
    self.assertFalse(permuted.is_contiguous())
    permuted_mps = permuted.to('mps')
    self.assertFalse(permuted_mps.is_contiguous())
    self.assertEqual(permuted, permuted_mps.to('cpu'))

    # Permute, then slice with a step.
    strided = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
    strided_mps = strided.to('mps')
    self.assertEqual(strided, strided_mps.to('cpu'))

    dst_cpu = torch.full((4, 4, 4, 4), 13, device="cpu")
    dst_mps = torch.full((4, 4, 4, 4), 13, device="mps")
    src = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
    dst_cpu.permute(3, 2, 1, 0)[1::, ::2] = src
    # As dst_mps is on MPS and src on CPU, this dispatches to a copy operator
    dst_mps.permute(3, 2, 1, 0)[1::, ::2] = src
    self.assertEqual(dst_cpu, dst_mps.to('cpu'))
# See https://github.com/pytorch/pytorch/issues/95417
def test_copy_storage_offset(self):
    """Slice assignment with a dtype or device mismatch must write into the right elements."""
    dst_cpu = torch.zeros(5, device="cpu", dtype=torch.float32)
    dst_mps = torch.zeros(5, device="mps", dtype=torch.float32)
    ones_cpu = torch.tensor([1, 1], device="cpu", dtype=torch.int64)
    ones_mps = torch.tensor([1, 1], device="mps", dtype=torch.int64)

    target = slice(2, 4)
    dst_cpu[target] = ones_cpu
    dst_mps[target] = ones_mps  # implicit type casting and copy
    self.assertEqual(dst_cpu, dst_mps)

    dst_cpu[target] = ones_mps  # implicit device moving and copy
    self.assertEqual(dst_cpu, dst_mps)
def test_copy_broadcasting(self):
    """``copy_`` with broadcasting and dtype conversion must match between CPU and MPS."""
    def check(src_shape, dst_shape, src_dtype, dst_dtype):
        cpu_src = torch.randint(0, 127, src_shape).to(src_dtype)
        cpu_dst = torch.randint(0, 127, dst_shape).to(dst_dtype)
        cpu_result = cpu_dst.copy_(cpu_src)
        mps_src = cpu_src.to("mps")
        mps_dst = cpu_dst.to("mps")
        mps_result = mps_dst.copy_(mps_src)
        self.assertEqual(cpu_result, mps_result)

    # (source shape, destination shape) pairs, including zero-sized destinations.
    shape_pairs = (
        ((2, 1), (2, 3)),
        ((2, 1), (2, 2)),
        ((3, 1, 4, 1), (3, 4, 4, 5)),
        ((3,), (2, 3)),
        ((2,), (2, 2)),
        ((4, 1, 5), (3, 4, 4, 5)),
        ((4, 1, 5), (4, 0, 5)),
        ((1, 5), (4, 0, 5)),
        ((3, 1, 0), (3, 5, 0)),
        ((0, 1, 0), (0, 5, 0)),
    )
    test_dtypes = [torch.float32, torch.int32, torch.int16, torch.int8]

    for src_dtype, dst_dtype in itertools.product(test_dtypes, repeat=2):
        for src_shape, dst_shape in shape_pairs:
            check(src_shape, dst_shape, src_dtype, dst_dtype)

    # Regression test for https://github.com/pytorch/pytorch/issues/107867
    self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0)
# See https://github.com/pytorch/pytorch/pull/84742
# and https://github.com/pytorch/pytorch/pull/78319
@parametrize("binop", ['add', 'sub', 'mul', 'div'])
def test_binops_dtype_precedence(self, binop):
    """Check that dtype promotion of `binop` on MPS matches the CPU result.

    Every pair of dtypes and every pair of sample values is tried with
    0-dim tensors, 1-element tensors and torch.full tensors as operands.
    """
    # Example values for the dtypes exercised on the MPS backend
    sample_vals = {
        torch.bool: [False, True],
        torch.int16: [-15, 0, 1, 10],
        torch.int32: [-376, 0, 1, 13],
        torch.int64: [-8, 0, 1, 77],
        torch.float16: [-234.5, 0.0, 1.0, 2.0],
        torch.float32: [-1.0, 0.0, 0.1, 111.99],
    }
    full_shape = (10,)

    def scalar(val, dtype, device):
        return torch.tensor(val, dtype=dtype, device=device)

    def vector(val, dtype, device):
        return torch.tensor([val], dtype=dtype, device=device)

    def filled(val, dtype, device):
        return torch.full(full_shape, val, dtype=dtype, device=device)

    # (lhs builder, rhs builder) in the order the comparisons are performed;
    # the (scalar, filled) combination is deliberately listed twice so the
    # number of comparisons stays the same.
    operand_kinds = [
        (scalar, scalar),
        (vector, vector),
        (scalar, vector),
        (vector, scalar),
        (filled, scalar),
        (scalar, filled),
        (scalar, filled),
    ]

    for dtype1, dtype2 in itertools.product(sample_vals, repeat=2):
        # bool minus bool is generally unsupported, so skip
        if binop == 'sub' and torch.bool in (dtype1, dtype2):
            continue
        for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
            for make_lhs, make_rhs in operand_kinds:
                mps_out = getattr(make_lhs(val1, dtype1, 'mps'), binop)(make_rhs(val2, dtype2, 'mps'))
                cpu_out = getattr(make_lhs(val1, dtype1, 'cpu'), binop)(make_rhs(val2, dtype2, 'cpu'))
                self.assertEqual(mps_out, cpu_out)
def test_xor_non_contigous(self):
    """In-place XOR on a column view, see https://github.com/pytorch/pytorch/issues/145203."""
    mps_vals = torch.randint(-16000, 16000, (10, 2), dtype=torch.int16, device="mps")
    cpu_vals = mps_vals.detach().cpu()
    # Apply the same in-place update to column 0 of both copies
    for vals in (mps_vals, cpu_vals):
        vals[:, 0] ^= 3
    self.assertEqual(mps_vals.cpu(), cpu_vals)
def test_nansum(self):
    """torch.nansum on MPS (plain and out= variants) agrees with a NaN-free CPU sum."""
    def helper(dtype, noncontiguous, dim):
        zero_cpu = torch.zeros((), dtype=dtype)

        # Random magnitude for the generated values
        scale = random.randint(10, 100)
        x_cpu: torch.Tensor = make_tensor(
            (5, 5), dtype=dtype, device='cpu',
            low=-scale, high=scale, noncontiguous=noncontiguous)

        if not dtype.is_floating_point:
            # Integer tensors cannot hold NaN, so the reference is the input itself
            x_no_nan_cpu = x_cpu
        else:
            # Everything below 0.2 * scale becomes NaN in the input and 0 in the reference
            nan_mask_cpu = x_cpu < (0.2 * scale)
            x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
            x_cpu[nan_mask_cpu] = np.nan

        x_mps = x_cpu.to('mps')
        out_mps = torch.empty(0, dtype=dtype, device='mps')
        out_cpu = torch.empty(0, dtype=dtype)
        dim_kwargs = {} if dim is None else {"dim": dim}
        expected = torch.sum(x_no_nan_cpu, **dim_kwargs)

        # Sanity check: CPU nansum equals the sum of the NaN-free reference
        self.assertEqual(expected, torch.nansum(x_cpu, **dim_kwargs))

        # MPS result, then the out= variant on both devices
        result_mps = torch.nansum(x_mps, **dim_kwargs)
        torch.nansum(x_mps, out=out_mps, **dim_kwargs)
        torch.nansum(x_cpu, out=out_cpu, **dim_kwargs)
        self.assertEqual(expected, result_mps)
        self.assertEqual(out_cpu, out_mps)

    dtype_choices = (torch.float16, torch.float32, torch.int32, torch.int64)
    noncontiguous_choices = (True, False)
    dim_choices = (0, 1, None)
    for dtype, noncontiguous, dim in itertools.product(dtype_choices, noncontiguous_choices, dim_choices):
        with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
            helper(dtype, noncontiguous, dim)
def test_cumsum_all_dtypes(self):
    """cumsum with an explicit output dtype matches CPU for each listed dtype."""
    def helper(dtype):
        mps_ones = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
        cpu_ones = torch.tensor([1, 1, 1, 1], device="cpu")
        mps_out = mps_ones.cumsum(0, dtype=dtype)
        cpu_out = cpu_ones.cumsum(0, dtype=dtype)
        self.assertEqual(mps_out.cpu(), cpu_out)

    for dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.float32]:
        helper(dtype)
def test_cumsum_bool(self):
    """cumsum over 2**16 boolean ones agrees between CPU and MPS."""
    flags = torch.ones(2**16, dtype=torch.bool)
    expected = flags.cumsum(0)
    actual = flags.to("mps").cumsum(0)
    self.assertEqual(expected, actual)
def test_cumsum_minus_one_axis(self):
    """cumsum along dim=-1 matches CPU for float and integer inputs."""
    def helper(dtype):
        # Test with axis -1
        if dtype == torch.float32:
            cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
        else:
            # Honour the requested integer dtype so that int16/int32/uint8
            # inputs are really exercised (a float32 input here would make
            # every iteration test the same thing).
            cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=dtype)
        x = cpu_x.detach().clone().to('mps')

        cpu_y = cpu_x.cumsum(-1)
        y = x.cumsum(-1)

        self.assertEqual(y, cpu_y)

    # Plain loop: helper is called only for its assertions
    for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]:
        helper(dtype)
def test_cumprod_all_dtypes(self):
    """cumprod with an explicit output dtype matches CPU for each listed dtype."""
    def helper(dtype):
        mps_ones = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
        cpu_ones = torch.tensor([1, 1, 1, 1], device="cpu")
        mps_out = mps_ones.cumprod(0, dtype=dtype)
        cpu_out = cpu_ones.cumprod(0, dtype=dtype)
        self.assertEqual(mps_out.cpu(), cpu_out)

    for dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.float32]:
        helper(dtype)
def test_cumprod_minus_one_axis(self):
    """cumprod along dim=-1 matches CPU for float and integer inputs."""
    def helper(dtype):
        # Test with axis -1
        if dtype == torch.float32:
            cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
        else:
            # Honour the requested integer dtype so that int16/int32/uint8
            # inputs are really exercised (a float32 input here would make
            # every iteration test the same thing).
            cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=dtype)
        x = cpu_x.detach().clone().to('mps')

        cpu_y = cpu_x.cumprod(-1)
        y = x.cumprod(-1)

        self.assertEqual(y, cpu_y)

    # Plain loop: helper is called only for its assertions
    for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]:
        helper(dtype)
def test_median_int16(self):
    """Global torch.median of a random int16 tensor matches CPU."""
    def helper(shape, dtype):
        cpu_vals = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
        mps_vals = cpu_vals.detach().clone().to('mps')

        got = torch.median(mps_vals)
        want = torch.median(cpu_vals)
        self.assertEqual(got, want)

    helper((2, 8, 4, 5), torch.int16)
def test_activation_checkpoint_does_not_error(self):
    """Forward and backward through torch.utils.checkpoint run on MPS in both modes."""
    from torch.utils.checkpoint import checkpoint

    def fn(x):
        return x.sin().cos().exp()

    # Smoke test only: nothing is asserted beyond "no exception is raised"
    for use_reentrant in (True, False):
        leaf = torch.tensor(1., device="mps", requires_grad=True)
        checkpoint(fn, leaf, use_reentrant=use_reentrant).backward()
def test_as_strided(self):
    """torch.as_strided views, with and without a storage offset, match CPU."""
    view_size, view_stride = (2, 2), (1, 2)

    grid = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    ones_2x2 = [[1.0, 1.0], [1.0, 1.0]]
    cpu_x = torch.tensor(grid, device='cpu')
    ones_mps = torch.tensor(ones_2x2, device='mps')
    x = cpu_x.detach().clone().to('mps').requires_grad_()

    view_cpu = torch.as_strided(cpu_x, view_size, view_stride)
    view_mps = torch.as_strided(x, view_size, view_stride)
    self.assertEqual(view_mps, view_cpu)

    # The strided view must also behave as an operand of a binary op
    sum_cpu = view_cpu + ones_mps.to('cpu')
    sum_mps = view_mps + ones_mps
    self.assertEqual(sum_cpu, sum_mps)

    # Same view geometry at storage offsets 0 and 1
    cpu_x = torch.rand(3, 3, device='cpu')
    mps_x = cpu_x.to('mps')
    cpu_off0 = torch.as_strided(cpu_x, view_size, view_stride, 0)
    mps_off0 = torch.as_strided(mps_x, view_size, view_stride, 0)
    cpu_off1 = torch.as_strided(cpu_x, view_size, view_stride, 1)
    mps_off1 = torch.as_strided(mps_x, view_size, view_stride, 1)
    self.assertEqual(cpu_off0 - cpu_off1, mps_off0 - mps_off1)
def test_unfold(self):
    """unfold(0, 2, 1) over arange(1., 8) matches CPU."""
    seq_cpu = torch.arange(1., 8)
    seq_mps = torch.arange(1., 8, device="mps")
    windows_cpu = seq_cpu.unfold(0, 2, 1)
    windows_mps = seq_mps.unfold(0, 2, 1)
    self.assertEqual(windows_cpu, windows_mps)
def test_unfold_all_devices_and_dtypes(self):
    """unfold on an empty MPS tensor produces the expected shape for each dtype."""
    expected_shape = (0, 1, 1, 0, 3)
    for dt in (torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8):
        empty = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
        self.assertEqual(expected_shape, empty.unfold(2, 3, 2).shape)
def test_unfold_scalars(self):
    """unfold on a 0-dim MPS tensor yields a 1-d tensor of length `size`."""
    scalar = torch.tensor(0.5, device="mps")
    # size == 0 gives an empty result regardless of the step
    for step in (1, 2):
        self.assertEqual(torch.empty(0, device="mps"), scalar.unfold(0, 0, step))
    # size == 1 wraps the scalar into a single-element tensor
    self.assertEqual(torch.tensor([0.5], device="mps"), scalar.unfold(0, 1, 1))
def test_bincount_simple(self):
    """bincount with and without weights matches CPU on a small random input."""
    idx_mps = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps")
    idx_cpu = idx_mps.to("cpu")
    w_mps = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32)
    w_cpu = w_mps.to("cpu")

    # Unweighted counts
    counts_mps = torch.bincount(idx_mps)
    counts_cpu = torch.bincount(idx_cpu)
    self.assertEqual(counts_mps, counts_cpu)

    # Weighted counts through the tensor method
    weighted_mps = idx_mps.bincount(w_mps)
    weighted_cpu = idx_cpu.bincount(w_cpu)
    self.assertEqual(weighted_mps, weighted_cpu)
def test_bincount_reduction(self):
    """torch.bincount on MPS: argument validation, empty input, minlength, weights, strided views."""
    device = "mps"

    def i32(data):
        # int32 tensor on the device under test
        return torch.tensor(data, device=device, dtype=torch.int32)

    # ---- argument validation ----
    # negative values are rejected
    with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
        torch.bincount(i32([1, -1]))
    # input with more than one dimension is rejected
    with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
        torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
    # negative minlength is rejected
    with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
        torch.bincount(torch.tensor([1, 3], device=device),
                       torch.tensor([.2, .2], device=device),
                       minlength=-1)
    # weights with more than one dimension are rejected
    with self.assertRaisesRegex(RuntimeError, '1-d'):
        torch.bincount(i32([1, 0]),
                       torch.tensor([[1., 0.3], [1., 0.3]], device=device, dtype=torch.float))
    # input and weights of different length are rejected
    with self.assertRaisesRegex(RuntimeError, 'same length'):
        torch.bincount(i32([1, 0]),
                       torch.tensor([1., 0.3, 0.5], device=device, dtype=torch.float))

    # ---- empty input ----
    # default minlength
    empty_long = torch.tensor([], device=device, dtype=torch.long)
    self.assertEqual(torch.bincount(empty_long),
                     torch.zeros(0, dtype=torch.long, device=device))
    # explicit minlength
    empty_long = torch.tensor([], device=device, dtype=torch.long)
    self.assertEqual(torch.bincount(empty_long, minlength=10),
                     torch.zeros(10, dtype=torch.long, device=device))

    # ---- tensor method, no weights ----
    uint8_input = torch.tensor([0, 3, 2, 1, 3], dtype=torch.uint8, device=device)
    method_counts = uint8_input.bincount()
    self.assertEqual(
        torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
        method_counts)

    # ---- uint8 input must not overflow (#76979) ----
    edge_values = [0, 1, 2, 3, 255]
    count_uint8 = torch.tensor(edge_values, dtype=torch.uint8, device=device).bincount()
    count_int16 = torch.tensor(edge_values, dtype=torch.int16, device=device).bincount()
    self.assertEqual(count_uint8, count_int16)

    # ---- minlength pads the result ----
    padded_counts = torch.bincount(i32([1, 1, 1, 1]), minlength=5)
    self.assertEqual(
        torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
        padded_counts)

    # ---- weights ----
    float_weighted = torch.bincount(
        i32([0, 1, 1, 1, 4]),
        torch.tensor([.1, .2, .3, .4, .5], device=device))
    self.assertEqual(
        torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), float_weighted)
    int8_weighted = torch.bincount(
        i32([0, 1, 1, 1, 4]),
        torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
    self.assertEqual(
        torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.int32), int8_weighted)

    # ---- non-contiguous inputs and weights ----
    inputs = i32([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]])
    weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
    for col in (0, 1):
        assert not inputs[:, col].is_contiguous(), "Inputs are supposed to be non-contiguous"
        assert not weights[:, col].is_contiguous(), "Weights are supposed to be non-contiguous"
    # strided input, no weights
    self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
    # strided input and strided weights
    weighted_expected = torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)
    self.assertEqual(inputs[:, 1].bincount(weights[:, 1]), weighted_expected)
    # contiguous input, strided weights
    weighted_expected = torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)
    self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]), weighted_expected)

    # ---- column slices of constant tensors ----
    zeros_2col = torch.zeros((32, 2), dtype=torch.int32, device=device)
    self.assertEqual(zeros_2col[:, 0].bincount(), torch.tensor([32]))
    ones_2col = torch.ones((32, 2), dtype=torch.int32, device=device)
    self.assertEqual(ones_2col[:, 0].bincount(), torch.tensor([0, 32]))

    # ---- 100 bins, all weight landing in the last one ----
    many_bins_expected = torch.zeros(100, device=device)
    many_bins_expected[-1] = 50.0
    half_weights = torch.tensor([.5] * 100, device=device)
    many_bins_actual = i32([99] * 100).bincount(half_weights)
    self.assertEqual(many_bins_expected, many_bins_actual)

    # ---- int8 input of ten ones ----
    ones_expected = torch.zeros(2, device=device, dtype=torch.int64)
    ones_expected[1] = 10
    ones_actual = torch.ones(10, dtype=torch.int8, device=device).bincount()
    self.assertEqual(ones_expected, ones_actual)
def test_bincount(self):
    """bincount on 5000 random values matches CPU for several value ranges and dtypes."""
    device = "mps"
    input_size = (5000,)
    weights_mps = torch.randn(input_size, dtype=torch.float, device=device)
    weights_cpu = weights_mps.cpu()

    # (exclusive upper bound of the random values, dtype of the input)
    for upper, dtype in ((50, torch.int8), (500, torch.int32), (2000, torch.int32)):
        vals = torch.randint(upper, input_size, dtype=dtype, device=device)
        self.assertEqual(vals.cpu().bincount(), vals.bincount())
        self.assertEqual(vals.cpu().bincount(weights_cpu), vals.bincount(weights_mps))

    # One large value plus minlength=65536: every element is still counted once
    sparse = torch.zeros([10], dtype=torch.int32, device=device)
    sparse[0] = 35488
    histogram = sparse.bincount(minlength=65536)
    self.assertEqual(torch.sum(histogram), 10)
def test_sum_backward(self):
    """torch.sum over a 3x3 tensor: value and input gradient match CPU."""
    grid = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    cpu_x = torch.tensor(grid, device='cpu', requires_grad=True)
    mps_x = cpu_x.detach().clone().to('mps').requires_grad_()

    total_mps = torch.sum(mps_x)
    total_cpu = torch.sum(cpu_x)

    total_mps.backward()
    total_cpu.backward()

    self.assertEqual(total_mps, total_cpu)
    self.assertEqual(mps_x.grad, cpu_x.grad)
# L1 loss
def test_l1_loss(self):
    """L1Loss forward (and backward for reducing modes) matches CPU."""
    def helper(shape, reduction):
        criterion = torch.nn.L1Loss(reduction=reduction)

        input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        target_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        # forward
        out_cpu = criterion(input_cpu, target_cpu)
        out_mps = criterion(input_mps, target_mps)
        self.assertEqual(out_cpu, out_mps)

        # backward is only checked for the reducing modes
        if reduction == 'none':
            return
        # upstream gradient of 2 so that grad_output > 1
        out_cpu.backward(gradient=torch.full_like(out_cpu, 2))
        out_mps.backward(gradient=torch.full_like(out_mps, 2))
        self.assertEqual(input_cpu.grad, input_mps.grad)

    cases = [
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # a different rank checks that cached graph lookup copes with shape changes
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
    ]
    for shape, reduction in cases:
        helper(shape, reduction)
# Mean Squared Error
def test_mse_loss(self):
    """MSELoss forward (and backward for reducing modes) matches CPU, including empty shapes."""
    def helper(shape, reduction):
        criterion = torch.nn.MSELoss(reduction=reduction)

        input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        target_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        # forward
        out_cpu = criterion(input_cpu, target_cpu)
        out_mps = criterion(input_mps, target_mps)
        self.assertEqual(out_cpu, out_mps)

        # backward is only checked for the reducing modes
        if reduction == 'none':
            return
        # upstream gradient of 2 so that grad_output > 1
        out_cpu.backward(gradient=torch.full_like(out_cpu, 2))
        out_mps.backward(gradient=torch.full_like(out_mps, 2))
        self.assertEqual(input_cpu.grad, input_mps.grad)

    cases = [
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # a different rank checks that cached graph lookup copes with shape changes
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
        # zero-element inputs
        ((3, 3, 0), 'sum'),
        ((3, 3, 0), 'mean'),
        ((3, 3, 0), 'none'),
    ]
    for shape, reduction in cases:
        helper(shape, reduction)
def test_mse_loss_strided_output(self):
    """Unreduced MSELoss on permuted and sliced conv outputs, see pytorch issue 124621."""
    # https://github.com/pytorch/pytorch/issues/124621
    criterion = nn.MSELoss(reduction='none')
    conv_cpu = nn.Sequential(
        nn.Conv1d(3, 3, 1),
    )
    conv_mps = copy.deepcopy(conv_cpu).to("mps")

    x_cpu = torch.randn(128, 10, 3).permute(0, 2, 1)
    # The MPS copy is permuted twice, which brings it back to x_cpu's shape
    x_mps = x_cpu.detach().clone().to("mps").permute(0, 2, 1).permute(0, 2, 1)

    pred_cpu = conv_cpu(x_cpu)
    pred_mps = conv_mps(x_mps)

    # Permute and slice the model outputs before feeding them to the loss
    pred_cpu = pred_cpu.permute(0, 2, 1)[:, :5, :]
    pred_mps = pred_mps.permute(0, 2, 1)[:, :5, :]

    target_cpu = torch.randn(128, 5, 3)
    target_mps = target_cpu.detach().clone().to("mps")

    loss_cpu = criterion(pred_cpu, target_cpu)
    loss_mps = criterion(pred_mps, target_mps)
    self.assertEqual(loss_cpu, loss_mps)
def test_mse_loss_unsupported_types(self):
    """MSELoss works for floating dtypes and raises RuntimeError for the others, on both devices."""
    criterion = nn.MSELoss()
    for dtype in MPS_DTYPES:
        a_mps = torch.tensor([0, 1, 2], dtype=dtype, device='mps')
        a_cpu = torch.tensor([0, 1, 2], dtype=dtype, device='cpu')
        if not dtype.is_floating_point:
            # Non-floating inputs are expected to be rejected on MPS and CPU alike
            self.assertRaises(RuntimeError, lambda: criterion(a_mps, a_mps))
            self.assertRaises(RuntimeError, lambda: criterion(a_cpu, a_cpu))
        else:
            self.assertEqual(criterion(a_mps, a_mps), criterion(a_cpu, a_cpu))
# Binary Cross Entropy
def test_bce_loss_simple(self):
    """BCELoss forward (and backward for reducing modes) matches CPU."""
    def helper(shape, reduction):
        criterion = torch.nn.BCELoss(reduction=reduction)

        # BCELoss needs input and target inside [0, 1]
        raw_input = np.random.random_sample(size=shape).astype(np.float32)
        raw_target = np.random.random_sample(size=shape).astype(np.float32)
        input_cpu = torch.tensor(raw_input, device='cpu', dtype=torch.float, requires_grad=True)
        target_cpu = torch.tensor(raw_target, device='cpu', dtype=torch.float, requires_grad=False)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        # forward
        out_cpu = criterion(input_cpu, target_cpu)
        out_mps = criterion(input_mps, target_mps)
        self.assertEqual(out_cpu, out_mps)

        # backward is only checked for the reducing modes
        if reduction == 'none':
            return
        # upstream gradient of 0.6 so that grad_output != 1
        out_cpu.backward(gradient=torch.full_like(out_cpu, 0.6))
        out_mps.backward(gradient=torch.full_like(out_mps, 0.6))
        self.assertEqual(input_cpu.grad, input_mps.grad)

    cases = [
        ([8, 5, 4], 'none'),
        ([7, 5, 2, 4], 'sum'),
        # a different rank checks that cached graph lookup copes with shape changes
        ([7, 5, 2, 4, 6], 'sum'),
        ([8, 4, 5, 7, 6], 'mean'),
        ([1, 1, 32, 32], 'mean'),
    ]
    for shape, reduction in cases:
        helper(shape, reduction)
def test_bce_loss_always_nonnegative(self):
    """BCELoss is never negative when prediction and target are both all ones or all zeros."""
    for fill in (torch.ones, torch.zeros):
        target = fill(5, device='mps')
        prediction = fill(5, device='mps')
        negatives = (nn.BCELoss()(prediction, target) < 0).sum()
        self.assertEqual(negatives, 0)
def test_bce_loss_size_mismatch(self):
    """BCELoss raises ValueError when input is (25,) and target is (25, 1)."""
    criterion = nn.BCELoss()
    flat = torch.rand(25, device='mps')
    column = torch.rand(25, 1, device='mps')
    with self.assertRaisesRegex(ValueError, r'Using a target size \('):
        criterion(flat, column)
def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
    """BCEWithLogitsLoss equals sigmoid + BCELoss in value and gradient on 1024x256 tensors."""
    rows, cols = 1024, 256
    target = torch.rand(rows, cols, device='mps')

    for reduction in ['none', 'mean', 'sum']:
        # Two independent leaves holding identical values
        output_sig = torch.rand(rows, cols, device='mps') - 0.5
        output_logits = output_sig.detach().clone()
        output_sig.requires_grad = True
        output_logits.requires_grad = True
        weight = torch.rand(cols, device='mps')

        bce = nn.BCELoss(weight, reduction=reduction)
        loss_sig = bce(torch.sigmoid(output_sig), target)
        bce_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)
        loss_logits = bce_logits(output_logits, target)

        self.assertEqual(loss_logits, loss_sig)

        if reduction != 'none':
            # scalar losses need no explicit upstream gradient
            loss_sig.backward()
            loss_logits.backward()
        else:
            grad = torch.rand(rows, cols, device='mps')
            loss_sig.backward(grad)
            loss_logits.backward(grad)

        self.assertEqual(output_sig.grad, output_logits.grad)
def test_bce_with_logits_has_correct_grad_at_zero(self):
    """With zero logits and zero targets, the summed loss has gradient 0.5 everywhere."""
    logits = torch.zeros(3, 1, requires_grad=True, device='mps')
    target = torch.zeros(3, 1, device='mps')
    loss = nn.BCEWithLogitsLoss(reduction='sum')(logits, target)
    loss.backward()
    expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
    self.assertEqual(logits.grad, expected_grad)
def test_bce_with_logits_broadcasts_weights(self):
    """A broadcastable weight gives the same loss as its expanded (16, 4) form."""
    target = torch.rand(16, 4, device='mps')
    logits = torch.rand(16, 4, device='mps') - 0.5

    # weight of shape (4,)
    weight = torch.rand(4, device='mps')
    compact = nn.BCEWithLogitsLoss(weight)(logits, target)
    expanded = nn.BCEWithLogitsLoss(weight.expand(16, 4).contiguous())(logits, target)
    self.assertEqual(compact, expanded)

    # weight of shape (16, 1)
    weight = torch.rand(16, 1, device='mps')
    compact = nn.BCEWithLogitsLoss(weight)(logits, target)
    expanded = nn.BCEWithLogitsLoss(weight.expand(16, 4).contiguous())(logits, target)
    self.assertEqual(compact, expanded)
def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
    """An all-ones pos_weight gives the same loss as passing no pos_weight."""
    target = torch.rand(64, 4, device='mps')
    logits = torch.rand(64, 4, device='mps') - 0.5
    unit_pos_weight = torch.ones(64, 4, device='mps')

    plain = nn.BCEWithLogitsLoss()(logits, target)
    weighted = nn.BCEWithLogitsLoss(pos_weight=unit_pos_weight)(logits, target)
    self.assertEqual(plain, weighted)
def test_bce_with_logits_broadcasts_pos_weights(self):
    """pos_weight of shape (4,), (1, 4) and (64, 4) all give the same loss."""
    target = torch.rand(64, 4, device='mps')
    logits = torch.rand(64, 4, device='mps') - 0.5
    pos_weight = torch.rand(4, device='mps')

    def loss_with(pw):
        return nn.BCEWithLogitsLoss(pos_weight=pw)(logits, target)

    base = loss_with(pos_weight)
    as_row = loss_with(pos_weight.expand(1, 4))
    as_full = loss_with(pos_weight.expand(64, 4))
    self.assertEqual(base, as_row)
    self.assertEqual(base, as_full)
def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
    """With zero logits, zero targets and unit pos_weight, the gradient is 0.5 everywhere."""
    logits = torch.zeros(3, 1, requires_grad=True, device='mps')
    target = torch.zeros(3, 1, device='mps')
    pos_weight = torch.ones(3, 1, device='mps')

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')
    criterion(logits, target).backward()

    expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
    self.assertEqual(logits.grad, expected_grad)
def test_bce_with_logits_stability(self):
    """The loss stays finite for a logit of -120, with and without pos_weight."""
    logits = torch.tensor([0., -120.], device='mps')
    target = torch.tensor([0., 1.], device='mps')
    pos_weight = torch.tensor([1., 1.], device='mps')

    plain_loss = nn.BCEWithLogitsLoss()(logits, target)
    self.assertTrue(torch.isfinite(plain_loss).all().item())

    weighted_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, target)
    self.assertTrue(torch.isfinite(weighted_loss).all().item())
def test_bce_loss_broadcasts_weights(self):
    """For BCELoss, a broadcastable weight gives the same loss as its expanded (16, 4) form."""
    squash = nn.Sigmoid()
    target = torch.rand(16, 4, device='mps')
    logits = torch.rand(16, 4, device='mps') - 0.5

    # weight of shape (4,)
    weight = torch.rand(4, device='mps')
    compact = nn.BCELoss(weight)(squash(logits), target)
    expanded = nn.BCELoss(weight.expand(16, 4).contiguous())(squash(logits), target)
    self.assertEqual(compact, expanded)

    # weight of shape (16, 1)
    weight = torch.rand(16, 1, device='mps')
    compact = nn.BCELoss(weight)(squash(logits), target)
    expanded = nn.BCELoss(weight.expand(16, 4).contiguous())(squash(logits), target)
    self.assertEqual(compact, expanded)
def test_cross_entropy_loss(self):
    """CrossEntropyLoss forward and backward run for float16 predictions on MPS."""
    # Regression test for https://github.com/pytorch/pytorch/issues/116095
    criterion = nn.CrossEntropyLoss()
    logits = torch.randn(3, 5, requires_grad=True, dtype=torch.float16, device='mps')
    labels = torch.ones(3, dtype=torch.long, device='mps')
    # Smoke test only: nothing is asserted beyond "no exception is raised"
    criterion(logits, labels).backward()
def test_log_softmax(self):
    """log_softmax along dim 0: output and input gradient match CPU."""
    data = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    cpu_x = torch.tensor(data, device='cpu', requires_grad=True)
    mps_x = torch.tensor(data, device='mps', requires_grad=True)

    out_cpu = F.log_softmax(cpu_x, dim=0)
    out_mps = F.log_softmax(mps_x, dim=0)
    self.assertEqual(out_cpu, out_mps.to('cpu'))

    # Backpropagate an all-ones upstream gradient on both devices
    ones_cpu = torch.ones_like(out_cpu)
    ones_mps = torch.ones_like(out_cpu).to('mps')
    out_cpu.backward(gradient=ones_cpu)
    out_mps.backward(gradient=ones_mps)
    self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
def test_log_softmax_large_numbers(self):
    """log_softmax along the last dim with magnitudes up to 1e6 matches CPU."""
    magnitudes = [10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0]
    data = [magnitudes, [-m for m in magnitudes]]
    cpu_x = torch.tensor(data, device='cpu', requires_grad=True)
    mps_x = torch.tensor(data, device='mps', requires_grad=True)

    out_cpu = F.log_softmax(cpu_x, dim=-1)
    out_mps = F.log_softmax(mps_x, dim=-1)
    self.assertEqual(out_cpu, out_mps.to('cpu'))

    # Backpropagate an all-ones upstream gradient on both devices
    ones_cpu = torch.ones_like(out_cpu)
    ones_mps = torch.ones_like(out_cpu).to('mps')
    out_cpu.backward(gradient=ones_cpu)
    out_mps.backward(gradient=ones_mps)
    self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
def test_eq(self):
    """Elementwise torch.eq on float tensors that differ in two positions."""
    lhs = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
    rhs = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]

    lhs_mps, rhs_mps = torch.tensor(lhs, device='mps'), torch.tensor(rhs, device='mps')
    lhs_cpu, rhs_cpu = torch.tensor(lhs, device='cpu'), torch.tensor(rhs, device='cpu')

    got = torch.eq(lhs_mps, rhs_mps)
    want = torch.eq(lhs_cpu, rhs_cpu)
    self.assertEqual(want, got.to('cpu'))
def test_signed_vs_unsigned_comparison(self):
    """Comparing a uint8 tensor against the scalar -1 gives the same result on MPS as on CPU."""
    raw = (-1, 2, 3)
    cpu_x = torch.tensor(raw, device='cpu', dtype=torch.uint8)
    mps_x = torch.tensor(raw, device='mps', dtype=torch.uint8)
    # in the comparison of signed vs. unsigned we should always cast to unsigned
    self.assertEqual(cpu_x == -1, mps_x == -1)
    self.assertEqual(cpu_x > -1, mps_x > -1)
    self.assertEqual(cpu_x < -1, mps_x < -1)
def test_eq_int64(self):
    """Elementwise torch.eq on tensors built from Python ints that differ in two positions."""
    lhs = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
    rhs = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]

    lhs_mps, rhs_mps = torch.tensor(lhs, device='mps'), torch.tensor(rhs, device='mps')
    lhs_cpu, rhs_cpu = torch.tensor(lhs, device='cpu'), torch.tensor(rhs, device='cpu')

    got = torch.eq(lhs_mps, rhs_mps)
    want = torch.eq(lhs_cpu, rhs_cpu)
    self.assertEqual(want, got.to('cpu'))
def test_ne(self):
    """Elementwise torch.ne of two random float tensors matches CPU."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')
        got = torch.ne(lhs_mps, rhs_mps)
        want = torch.ne(lhs_cpu, rhs_cpu)
        self.assertEqual(want, got.to('cpu'))

    helper((2, 3, 4, 5))
def test_ne_scalar(self):
    """torch.ne of a random float tensor against the scalar 0.0 matches CPU."""
    def helper(shape):
        vals_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        vals_mps = vals_cpu.detach().clone().to('mps')
        got = torch.ne(vals_mps, 0.0)
        want = torch.ne(vals_cpu, 0.0)
        self.assertEqual(want, got.to('cpu'))

    helper((2, 3, 4, 5))
def test_lt(self):
    """Elementwise torch.lt of two random float tensors matches CPU."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')
        got = torch.lt(lhs_mps, rhs_mps)
        want = torch.lt(lhs_cpu, rhs_cpu)
        self.assertEqual(want, got.to('cpu'))

    helper((2, 3, 4, 5))
def test_lt_scalar(self):
    """torch.lt of a random float tensor against the scalar 0.0 matches CPU."""
    def helper(shape):
        vals_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        vals_mps = vals_cpu.detach().clone().to('mps')
        got = torch.lt(vals_mps, 0.0)
        want = torch.lt(vals_cpu, 0.0)
        self.assertEqual(want, got.to('cpu'))

    helper((2, 3, 4, 5))
def test_le(self):
    """Elementwise torch.le of two random float tensors matches CPU."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')
        got = torch.le(lhs_mps, rhs_mps)
        want = torch.le(lhs_cpu, rhs_cpu)
        self.assertEqual(want, got.to('cpu'))

    helper((2, 3, 4, 5))
def test_le_scalar(self):
    """torch.le of a random float tensor against the scalar 0.0 matches CPU."""
    def helper(shape):
        vals_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        vals_mps = vals_cpu.detach().clone().to('mps')
        got = torch.le(vals_mps, 0.0)
        want = torch.le(vals_cpu, 0.0)
        self.assertEqual(want, got.to('cpu'))

    helper((2, 3, 4, 5))
def test_ge(self):
    """Elementwise torch.ge of two random float tensors matches CPU."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')
        got = torch.ge(lhs_mps, rhs_mps)
        want = torch.ge(lhs_cpu, rhs_cpu)
        self.assertEqual(want, got.to('cpu'))

    helper((2, 3, 4, 5))
def test_ge_scalar(self):
    """torch.ge of a random float tensor against the scalar 0.0 matches CPU."""
    def helper(shape):
        vals_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        vals_mps = vals_cpu.detach().clone().to('mps')
        got = torch.ge(vals_mps, 0.0)
        want = torch.ge(vals_cpu, 0.0)
        self.assertEqual(want, got.to('cpu'))

    helper((2, 3, 4, 5))
def test_gt(self):
    """Elementwise torch.gt of two random float tensors matches CPU."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        lhs_mps = lhs_cpu.detach().clone().to('mps')
        rhs_mps = rhs_cpu.detach().clone().to('mps')
        got = torch.gt(lhs_mps, rhs_mps)
        want = torch.gt(lhs_cpu, rhs_cpu)
        self.assertEqual(want, got.to('cpu'))

    helper((2, 3, 4, 5))
def test_gt_scalar(self):
    """torch.gt of a random float tensor against the scalar 0.0 matches CPU."""
    def helper(shape):
        vals_cpu = torch.randn(shape, device='cpu', dtype=torch.float)
        vals_mps = vals_cpu.detach().clone().to('mps')
        got = torch.gt(vals_mps, 0.0)
        want = torch.gt(vals_cpu, 0.0)
        self.assertEqual(want, got.to('cpu'))

    helper((2, 3, 4, 5))
def test_argmax(self):
    """argmax along dim=1 matches CPU for a fixed integer input and a random float input."""
    # https://github.com/pytorch/pytorch/issues/98191
    ints_cpu = torch.tensor([[0, 1], [2, 1], [1, 0]])
    want = torch.argmax(ints_cpu, dim=1)
    ints_mps = ints_cpu.to(torch.device('mps'))
    got = torch.argmax(ints_mps, dim=1)
    self.assertEqual(want, got)

    # https://github.com/pytorch/pytorch/issues/92311
    floats_mps = torch.randn(10, 2, device='mps', dtype=torch.float32)
    floats_cpu = floats_mps.detach().clone().cpu()
    got = torch.argmax(floats_mps, dim=1)
    want = torch.argmax(floats_cpu, dim=1)
    self.assertEqual(want, got)
# Test forward argmin argmax
def test_argmin_argmax(self):
    """argmin/argmax over the whole tensor and over each dim (with and without keepdim) match CPU."""
    def helper(n, c, h, w, reduction_type, dtype=torch.float32):
        arg_reduce = torch.argmax if reduction_type == "max" else torch.argmin

        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        else:
            # bool can only hold 0/1; every other dtype draws from [0, 50)
            upper = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(upper, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        # Reduction over all elements
        got = arg_reduce(x)
        want = arg_reduce(cpu_x)
        self.assertEqual(got, want)

        # Reduction over each of the four dims, first dropping then keeping the dim
        for dim in range(4):
            for keepdim in (False, True):
                got = arg_reduce(x, dim=dim, keepdim=keepdim)
                want = arg_reduce(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(got, want)

    for reduction_type in ("max", "min"):
        for dtype in (torch.float32, torch.int32, torch.float16, torch.int64):
            helper(2, 8, 4, 4, reduction_type, dtype)
def test_reduction_sum_max_long_val(self):
    # Sum integer values at and just below sys.maxsize on MPS and compare
    # the result with the same reduction on a CPU copy.
    near_max = [sys.maxsize - offset for offset in (0, 10, 5, 18)]
    mps_vals = torch.tensor(near_max, device="mps")
    cpu_vals = mps_vals.detach().clone().cpu()

    self.assertEqual(torch.sum(mps_vals), torch.sum(cpu_vals))
# Test forward max
# Note - don't test grad now
def test_max_el(self):
    # Compares torch.max on MPS against CPU for the whole-tensor reduction,
    # the (values, indices) reduction along each dim, and the out= overload.
    def helper(n, c, h, w, dtype=torch.float32):
        full_shape = (n, c, h, w)

        def reduced_shape(dim, keepdim):
            """Shape of a reduction over `dim`: size 1 if kept, else dropped."""
            dims = list(full_shape)
            if keepdim:
                dims[dim] = 1
            else:
                del dims[dim]
            return tuple(dims)

        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
        else:
            # bool draws from randint(2, ...); every other dtype from randint(50, ...)
            upper = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(upper, full_shape, device='cpu', dtype=dtype, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        # Whole-tensor reduction
        ref_y = torch.max(cpu_x)
        y = torch.max(x)
        self.assertEqual(y, ref_y)

        # Returned (values, indices) along each dim
        for dim, keepdim in itertools.product(range(4), (True, False)):
            vals, idx = torch.max(x, dim=dim, keepdim=keepdim)
            ref_vals, ref_idx = torch.max(cpu_x, dim=dim, keepdim=keepdim)
            self.assertEqual(vals, ref_vals)
            self.assertEqual(idx, ref_idx)

        # out= overload writing into preallocated MPS tensors
        for dim, keepdim in itertools.product(range(4), (False, True)):
            target = reduced_shape(dim, keepdim)
            out_vals = torch.ones(target, device='mps', dtype=dtype)
            out_idx = torch.ones(target, device='mps', dtype=torch.int64)
            torch.max(x, dim=dim, keepdim=keepdim, out=(out_vals, out_idx))
            ref_vals, ref_idx = torch.max(cpu_x, dim=dim, keepdim=keepdim)
            self.assertEqual(out_vals, ref_vals)
            self.assertEqual(out_idx, ref_idx)

    helper(2, 8, 4, 5, torch.float32)
    helper(2, 8, 4, 5, torch.int32)
    # The int64 variant is left disabled:
    # helper(2, 8, 4, 5, torch.int64)
def test_median(self):
    # Compares torch.median on MPS against CPU for int32 and float32 inputs,
    # both as a whole-tensor reduction and along each of the three dims.
    def compare(cpu_x):
        """Run every median variant on `cpu_x` and on its MPS copy."""
        mps_x = cpu_x.detach().clone().to('mps')

        result_cpu = torch.median(cpu_x)
        result_mps = torch.median(mps_x)
        self.assertEqual(result_cpu, result_mps)

        for dim, keepdim in itertools.product(range(3), (True, False)):
            cpu_vals, cpu_idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
            mps_vals, mps_idx = torch.median(mps_x, dim=dim, keepdim=keepdim)
            self.assertEqual(cpu_vals, mps_vals)
            self.assertEqual(cpu_idx, mps_idx)

    def helper_dtype_int32(n1, n2, n3):
        compare(torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32))

    def helper_dtype_float32(n1, n2, n3):
        compare(torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32))

    helper_dtype_int32(10, 10, 10)  # even number of elements
    helper_dtype_int32(3, 3, 3)  # odd number of elements
    helper_dtype_int32(1, 1, 1)
    helper_dtype_int32(1, 2, 3)
    helper_dtype_float32(10, 10, 10)
    helper_dtype_float32(3, 3, 3)
    helper_dtype_float32(1, 1, 1)
def test_any(self):
    # Compares torch.any on MPS against CPU for float, int and bool inputs:
    # the whole-tensor reduction plus every dim with and without keepdim.
    def helper(shape):
        numel = math.prod(shape)

        def filled(dtype):
            """Return arange / ones / zeros tensors of `dtype` with `shape`."""
            return [
                torch.arange(0, numel, dtype=dtype).reshape(shape),
                torch.ones(numel, dtype=dtype).reshape(shape),
                torch.zeros(numel, dtype=dtype).reshape(shape),
            ]

        input_xs = [torch.randn(numel, dtype=torch.float).reshape(shape)]
        input_xs.extend(filled(torch.float))
        input_xs.extend(filled(torch.int))
        # bool inputs are the int patterns converted with .bool()
        input_xs.extend(t.bool() for t in filled(torch.int))

        for cpu_x in input_xs:
            x = cpu_x.detach().clone().to('mps')
            self.assertEqual(torch.any(x), torch.any(cpu_x))

            # Iterate over the dims the input actually has, so any rank is
            # handled (previously dim=3 was guarded only by len(shape) > 2).
            for dim in range(len(shape)):
                for keepdim in (False, True):
                    self.assertEqual(
                        torch.any(x, dim=dim, keepdim=keepdim),
                        torch.any(cpu_x, dim=dim, keepdim=keepdim),
                    )

    helper((1, 1, 1, 1))
    helper((1, 1, 3, 3))
    helper((7, 13))
    helper((2, 8, 4, 5))
def test_reduction_ops_5D(self):
    # any/all reductions over a 5D zeros tensor, followed by all() over a
    # 6D tensor, each compared between MPS and CPU.
    def check(fn, dim):
        shape = (1, 1, 2, 1, 1)
        cpu_out = fn(torch.zeros(shape), dim=dim)
        mps_out = fn(torch.zeros(shape, device="mps"), dim=dim)
        self.assertEqual(cpu_out, mps_out.cpu())

    for fn, dim in itertools.product((torch.any, torch.all), range(4)):
        check(fn, dim)

    # 6D tensor reductions
    # Regression test for https://github.com/pytorch/pytorch/issues/95538
    x = (torch.rand(2, 3, 4, 3, 4, 2, device="mps") - .5).relu()
    self.assertEqual(x.all(), x.cpu().all())
    for dim in range(-5, 6):
        mps_all = x.all(dim=dim)
        cpu_all = x.cpu().all(dim=dim)
        self.assertEqual(mps_all, cpu_all)
def test_all(self):
    # Compares torch.all on MPS against CPU for float, int and bool inputs:
    # the whole-tensor reduction plus every dim with and without keepdim,
    # and finally an empty bool tensor.
    def helper(shape):
        numel = math.prod(shape)

        def filled(dtype):
            """Return arange / ones / zeros tensors of `dtype` with `shape`."""
            return [
                torch.arange(0, numel, dtype=dtype).reshape(shape),
                torch.ones(numel, dtype=dtype).reshape(shape),
                torch.zeros(numel, dtype=dtype).reshape(shape),
            ]

        input_xs = [torch.randn(numel, dtype=torch.float).reshape(shape)]
        input_xs.extend(filled(torch.float))
        input_xs.extend(filled(torch.int))
        # bool inputs are the int patterns converted with .bool()
        input_xs.extend(t.bool() for t in filled(torch.int))

        for cpu_x in input_xs:
            x = cpu_x.detach().clone().to('mps')
            self.assertEqual(torch.all(x), torch.all(cpu_x))

            # Iterate over the dims the input actually has, so any rank is
            # handled (previously dim=3 was guarded only by len(shape) > 2).
            for dim in range(len(shape)):
                for keepdim in (False, True):
                    self.assertEqual(
                        torch.all(x, dim=dim, keepdim=keepdim),
                        torch.all(cpu_x, dim=dim, keepdim=keepdim),
                    )

    helper((1, 1, 1, 1))
    helper((1, 1, 3, 3))
    helper((7, 13))
    helper((2, 8, 4, 5))

    # Empty tensor
    x_cpu = torch.tensor([], dtype=torch.bool)
    x_mps = x_cpu.to("mps")
    self.assertEqual(x_cpu.all(), x_mps.all().cpu())
# Test forward min
def test_min_el(self):
    # Compares torch.min on MPS against CPU: whole-tensor reduction, then for
    # each dim and keepdim setting both the returned (values, indices) pair
    # and the out= overload.
    def helper(n, c, h, w):
        full_shape = (n, c, h, w)
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        y = torch.min(x)
        ref_y = torch.min(cpu_x)
        self.assertEqual(y, ref_y)

        for dim in range(4):
            for keepdim in (False, True):
                # Returned tuple
                vals, idx = torch.min(x, dim=dim, keepdim=keepdim)
                ref_vals, ref_idx = torch.min(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(vals, ref_vals)
                self.assertEqual(idx, ref_idx)

                # out= overload: the reduced dim is size 1 when kept, else dropped
                target = list(full_shape)
                if keepdim:
                    target[dim] = 1
                else:
                    del target[dim]
                target = tuple(target)
                out_vals = torch.ones(target, device='mps', dtype=torch.float)
                out_idx = torch.ones(target, device='mps', dtype=torch.int64)
                torch.min(x, dim=dim, keepdim=keepdim, out=(out_vals, out_idx))
                ref_vals, ref_idx = torch.min(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(out_vals, ref_vals)
                self.assertEqual(out_idx, ref_idx)

    helper(2, 8, 4, 5)
def test_fmin(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/143933
    # fmin of an MPS tensor with a 0-dim tensor created without a device argument.
    threshold = torch.tensor(.5)
    mps_in = torch.rand(32, device="mps")
    cpu_in = mps_in.detach().cpu()

    mps_out = torch.fmin(mps_in, threshold)
    cpu_out = torch.fmin(cpu_in, threshold)
    self.assertEqual(mps_out, cpu_out)
# Test forward sum
def test_sum(self):
    # Compares torch.sum on MPS against CPU for the whole-tensor reduction and
    # for several dim lists, each with and without keepdim.
    def helper(n, c, h, w, dtype=torch.float32):
        if dtype == torch.float32:
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        else:
            # bool draws from randint(2, ...); other dtypes from randint(50, ...)
            upper = 2 if dtype == torch.bool else 50
            cpu_x = torch.randint(upper, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        self.assertEqual(torch.sum(x), torch.sum(cpu_x))

        for dims in ([], [0], [0, 1], [2, 3]):
            for keepdim in (False, True):
                mps_sum = torch.sum(x, dim=dims, keepdim=keepdim)
                cpu_sum = torch.sum(cpu_x, dim=dims, keepdim=keepdim)
                self.assertEqual(mps_sum, cpu_sum)

    helper(2, 8, 4, 5)
    for int_like in (torch.int32, torch.int64, torch.bool):
        helper(2, 8, 4, 5, dtype=int_like)

    # Regression test for https://github.com/pytorch/pytorch/issues/136132
    summed = torch.ones(2, 4, 1, 30, 1, device='mps').sum(dim=-2)
    self.assertEqual(summed.numel(), 8)
    self.assertEqual(summed.max().item(), 30.0)
# Test forward prod
def test_prod(self):
    # Compares torch.prod on MPS against CPU for the whole-tensor reduction and
    # along every dim, with and without keepdim.
    def helper(shape, dtype=torch.float32):
        if dtype == torch.float32:
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()
        else:
            if dtype == torch.bool:
                cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
            else:
                cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

        self.assertEqual(torch.prod(x), torch.prod(cpu_x))

        for dim, _ in enumerate(shape):
            for keepdim in (False, True):
                mps_prod = torch.prod(x, dim=dim, keepdim=keepdim)
                cpu_prod = torch.prod(cpu_x, dim=dim, keepdim=keepdim)
                self.assertEqual(mps_prod, cpu_prod)

    for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
        helper((2, 3), dtype)
# Test forward mean
def test_mean(self):
    # Compares torch.mean on MPS against CPU for the whole-tensor reduction and
    # for several dim lists, each with and without keepdim.
    def helper(n, c, h, w):
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        self.assertEqual(torch.mean(x), torch.mean(cpu_x))

        for dims in ([], [0], [0, 1], [2, 3]):
            for keepdim in (False, True):
                mps_mean = torch.mean(x, dim=dims, keepdim=keepdim)
                cpu_mean = torch.mean(cpu_x, dim=dims, keepdim=keepdim)
                self.assertEqual(mps_mean, cpu_mean)

    helper(2, 8, 4, 5)
# Test std
def test_std(self):
    # Compares torch.std on MPS against CPU, biased and unbiased, for the
    # whole-tensor reduction and several dim lists with and without keepdim.
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        for unbiased in (False, True):
            self.assertEqual(torch.std(x, unbiased=unbiased), torch.std(cpu_x, unbiased=unbiased))

            for dims in ([], [0], [0, 1], [2, 3]):
                for keepdim in (False, True):
                    mps_std = torch.std(x, dim=dims, keepdim=keepdim, unbiased=unbiased)
                    cpu_std = torch.std(cpu_x, dim=dims, keepdim=keepdim, unbiased=unbiased)
                    self.assertEqual(mps_std, cpu_std)

    helper((4, 5, 6, 7))
    # verify if a change in shape of input would cause problems with graph caching
    helper((9, 5, 6, 7))
# Test var
def test_var_simple(self):
    # Compares var on MPS against CPU for every unbiased/keepdim combination,
    # then checks that var over dim 0 of a 0-dim tensor is NaN.
    def helper():
        shape = [2, 3, 4, 5]
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        for unbiased, keepdim in itertools.product([False, True], repeat=2):
            # Tensor-method form with a positional (negative) dim
            last_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
            last_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)
            self.assertEqual(last_dim_var, last_dim_var_cpu)

            # Whole-tensor reduction
            self.assertEqual(torch.var(x, unbiased=unbiased), torch.var(cpu_x, unbiased=unbiased))

            # Dim-list reductions
            for dims in ([], [0], [0, -1], [2, 3]):
                mps_var = torch.var(x, dim=dims, keepdim=keepdim, unbiased=unbiased)
                cpu_var = torch.var(cpu_x, dim=dims, keepdim=keepdim, unbiased=unbiased)
                self.assertEqual(mps_var, cpu_var)

    helper()

    # Regression test for https://github.com/pytorch/pytorch/issues/160738
    scalar_var = torch.var(torch.tensor(3.13, device='mps'), dim=0)
    self.assertTrue(scalar_var.isnan().item())
# Test amax forward and backward
def test_amax(self):
    # Compares torch.amax outputs and input gradients between MPS and CPU
    # for several dim lists, with and without keepdim.
    def check(shape, dim, keepdim):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        result = torch.amax(x, dim=dim, keepdim=keepdim)
        result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)

        # Backpropagate the same random upstream gradient on both devices
        cpu_grad = torch.randn(result_cpu.shape)
        mps_grad = cpu_grad.to('mps')
        result_cpu.backward(gradient=cpu_grad)
        result.backward(gradient=mps_grad)

        self.assertEqual(result, result_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    for dim, keepdim in itertools.product(([], [0], [0, 1], [2, 3]), (False, True)):
        check((2, 8, 4, 5), dim, keepdim)
# Test amin forward and backward
def test_amin(self):
    # Compares torch.amin outputs and input gradients between MPS and CPU
    # for several dim lists, with and without keepdim.
    def check(shape, dim, keepdim):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        result = torch.amin(x, dim=dim, keepdim=keepdim)
        result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)

        # Backpropagate the same random upstream gradient on both devices
        cpu_grad = torch.randn(result_cpu.shape)
        mps_grad = cpu_grad.to('mps')
        result_cpu.backward(gradient=cpu_grad)
        result.backward(gradient=mps_grad)

        self.assertEqual(result, result_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    for dim, keepdim in itertools.product(([], [0], [0, 1], [2, 3]), (False, True)):
        check((2, 8, 4, 5), dim, keepdim)
# Test minimum and maximum
def test_minimum_maximum(self):
    # Elementwise torch.minimum / torch.maximum of two random tensors,
    # compared between CPU and MPS.
    def helper(n, c, h, w):
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        cpu_y = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        mps_x = cpu_x.detach().clone().to('mps')
        mps_y = cpu_y.detach().clone().to('mps')

        for op in (torch.minimum, torch.maximum):
            result_cpu = op(cpu_x, cpu_y)
            result_mps = op(mps_x, mps_y)
            self.assertEqual(result_cpu, result_mps)

    helper(1, 1, 4, 5)
def test_minimum_maximum_nan_propagation(self):
    # Place one NaN in each operand (at different positions) and check that
    # the minimum/maximum results on MPS contain a NaN.
    lhs = torch.rand(32, device="mps")
    rhs = torch.rand(32, device="mps")
    lhs[3] = torch.nan
    rhs[5] = torch.nan

    for op in (torch.minimum, torch.maximum):
        has_nan = op(lhs, rhs).isnan().any().item()
        self.assertTrue(has_nan)
def test_clamp_fp16_fp32(self):
    # Clamp a float32 input with float16 min/max tensors and compare the MPS
    # result with the CPU result.
    cpu_x = torch.randn(10, device='cpu', dtype=torch.float, requires_grad=False)
    x = cpu_x.detach().clone().to('mps')

    dtype = torch.float16

    def bounds(device):
        """Return (min, max) bound tensors of `dtype` on `device`: ones and ones * 10."""
        lower = torch.ones(10, device=device).to(dtype)
        upper = torch.ones(10, device=device).to(dtype) * 10
        return lower, upper

    mps_lower, mps_upper = bounds("mps")
    clamp_result_mps = torch.clamp(x, mps_lower, mps_upper)

    cpu_lower, cpu_upper = bounds("cpu")
    clamp_result_cpu = torch.clamp(cpu_x, cpu_lower, cpu_upper)

    self.assertEqual(clamp_result_mps, clamp_result_cpu)
def test_clamp_nan(self):
    # Clamp a tensor containing NaN with scalar bounds (both, min only,
    # max only) and compare MPS against CPU.
    values = [torch.nan, 1, 2]
    t_mps = torch.tensor(values, device="mps")
    t_cpu = torch.tensor(values, device="cpu")

    for bounds in ({"min": -100, "max": 100}, {"min": -100}, {"max": 100}):
        clamped_mps = torch.clamp(t_mps, **bounds)
        clamped_cpu = torch.clamp(t_cpu, **bounds)
        self.assertEqual(clamped_mps, clamped_cpu)
# Test clamp_min
def test_clamp_min(self):
    # torch.clamp_min with a scalar bound and with a tensor bound,
    # compared between MPS and CPU.
    def helper(n, c, h, w):
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        cpu_bound = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        mps_bound = cpu_bound.detach().clone().to('mps')

        # Scalar lower bound
        self.assertEqual(torch.clamp_min(x, min=5.0), torch.clamp_min(cpu_x, min=5.0))

        # Tensor lower bound
        self.assertEqual(torch.clamp_min(x, min=mps_bound), torch.clamp_min(cpu_x, min=cpu_bound))

    helper(2, 8, 4, 5)
# Test clamp_max
def test_clamp_max(self):
    # torch.clamp_max with a scalar bound and with a tensor bound,
    # compared between MPS and CPU.
    def helper(n, c, h, w):
        cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        cpu_bound = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
        mps_bound = cpu_bound.detach().clone().to('mps')

        # Scalar upper bound
        self.assertEqual(torch.clamp_max(x, max=100.0), torch.clamp_max(cpu_x, max=100.0))

        # Tensor upper bound
        self.assertEqual(torch.clamp_max(x, max=mps_bound), torch.clamp_max(cpu_x, max=cpu_bound))

    helper(2, 8, 4, 5)
# Test clamp
def test_clamp(self):
    # torch.clamp on MPS vs CPU with scalar bounds, tensor bounds, optional
    # (max-only) bounds, strided operands, and the in-place variant.
    def helper(n, c, h, w):
        import numpy as np

        upper_bound = 1000
        half_upper_bound = upper_bound / 2
        size = (n, c, h, w)

        def cpu_mps_pair(arr):
            """Build a CPU float tensor from `arr` plus an MPS copy of it."""
            cpu_t = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=False)
            return cpu_t, cpu_t.detach().clone().to('mps')

        def moved(t):
            """Return a strided view of `t` with dim 0 moved to the end."""
            return t.movedim(0, -1)

        # input values in [0..1000)
        cpu_x, x = cpu_mps_pair(upper_bound * np.random.random_sample(size=size).astype(np.float32))
        # lower bounds in [0..500)
        cpu_min_t, min_t = cpu_mps_pair(half_upper_bound * np.random.random_sample(size=size).astype(np.float32))
        # upper bounds in [500..1000), to ensure max's are greater than mins
        cpu_max_t, max_t = cpu_mps_pair(
            (half_upper_bound * np.random.random_sample(size=size).astype(np.float32)) + half_upper_bound
        )

        # [200..600]: just an arbitrary range between [0..1000]
        self.assertEqual(
            torch.clamp(x, min=200.0, max=600.0),
            torch.clamp(cpu_x, min=200.0, max=600.0),
        )

        # test optional scalar refs and cached graph keys by passing only max
        self.assertEqual(torch.clamp(x, max=600.0), torch.clamp(cpu_x, max=600.0))

        # tensor bounds
        self.assertEqual(
            torch.clamp(x, min=min_t, max=max_t),
            torch.clamp(cpu_x, min=cpu_min_t, max=cpu_max_t),
        )

        # test optional tensor refs and cached graph keys by passing only max
        self.assertEqual(torch.clamp(x, max=max_t), torch.clamp(cpu_x, max=cpu_max_t))

        # test strided x
        self.assertEqual(
            torch.clamp(moved(x), min=200.0, max=600.0),
            torch.clamp(moved(cpu_x), min=200.0, max=600.0),
        )

        # test strided x, min_t, max_t
        self.assertEqual(
            torch.clamp(moved(x), min=moved(min_t), max=moved(max_t)),
            torch.clamp(moved(cpu_x), min=moved(cpu_min_t), max=moved(cpu_max_t)),
        )

        # test strided min_t, max_t
        contiguous_x = moved(x).clone(memory_format=torch.contiguous_format)
        strided_bounds_mps = torch.clamp(contiguous_x, min=moved(min_t), max=moved(max_t))
        contiguous_cpu_x = moved(cpu_x).clone(memory_format=torch.contiguous_format)
        strided_bounds_cpu = torch.clamp(contiguous_cpu_x, min=moved(cpu_min_t), max=moved(cpu_max_t))
        self.assertEqual(strided_bounds_mps, strided_bounds_cpu)

        # test inplace clamping
        x.clamp_(min=200.0, max=600.0)
        cpu_x.clamp_(min=200.0, max=600.0)
        self.assertEqual(cpu_x, x)

    helper(2, 8, 4, 5)
def test_divmode(self):
    # Compares division on MPS against CPU for each rounding mode
    # (None, "floor", "trunc") and for torch.floor_divide.
    def helper(shape, rounding_mode):
        """Check one rounding mode across float32/float16/int32/int64.

        `rounding_mode` is None, "floor", "trunc", or the pseudo-mode
        "floor_divide", which selects torch.floor_divide instead of
        torch.div(..., rounding_mode=...).
        """
        for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
            # Combinations that are not exercised. The substring test means
            # "floor_divide" also matches "floor", so int64 is skipped for both.
            skip_floor_int64 = (
                rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64
            )
            skip_trunc_fp16 = (
                rounding_mode is not None and "trunc" in rounding_mode and dtype == torch.float16
            )
            if skip_floor_int64 or skip_trunc_fp16:
                continue

            if dtype in [torch.float32, torch.float16]:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            else:
                # Integer operands are drawn from randint(-10, 0), so the
                # divisor is never zero.
                cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
                cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)

            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')

            if rounding_mode == "floor_divide":
                result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
                result_div_mps = torch.floor_divide(mps_x, mps_y)
            else:
                result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
                result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
            self.assertEqual(result_div_mps, result_div_cpu)

    helper((2, 8, 4, 5), None)
    helper((2, 8, 4, 5), "floor")
    helper((2, 8, 4, 5), "trunc")
    helper((2, 8, 4, 5), "floor_divide")
def test_rounding(self):
    """Check floor, ceil, trunc and round on MPS against the CPU results."""
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_x = cpu_x.detach().clone().to('mps')

        # Same op order as before: floor, ceil, trunc, round.
        for rounding_op in (torch.floor, torch.ceil, torch.trunc, torch.round):
            expected = rounding_op(cpu_x)
            actual = rounding_op(mps_x)
            self.assertEqual(actual, expected)

    for shape in ((2, 6, 3, 5), (2, 8, 4, 5)):
        helper(shape)
def test_remainder(self):
    """Check `torch.remainder` on MPS against CPU for int32 inputs."""
    # int32 dividend with a 0-dim int32 tensor divisor
    by_device = {}
    for device in ("cpu", "mps"):
        dividend = torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device=device)
        divisor = torch.tensor(2, device=device, dtype=torch.int32)
        by_device[device] = torch.remainder(dividend, divisor)
    self.assertEqual(by_device["cpu"], by_device["mps"])

    # int32 dividend with a negative Python float divisor
    by_device = {}
    for device in ("cpu", "mps"):
        dividend = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device=device)
        by_device[device] = torch.remainder(dividend, -1.5)
    self.assertEqual(by_device["cpu"], by_device["mps"])

    # Regression test for https://github.com/pytorch/pytorch/issues/154171
    # Essentially, remainder over integral types should rely on integer ops
    lhs = torch.tensor(42309891, device='mps')
    rhs = torch.tensor(31, device='mps')
    self.assertEqual(lhs % rhs, torch.tensor(6, device='mps'))
def test_expand(self):
    """Check a stride-0 `as_strided` view of a column on MPS against CPU."""
    def helper(n, c):
        # NOTE: `n` and `c` are accepted but not used by this helper.
        cpu_x = torch.tensor([[1.0], [4.0], [7.0]], device='cpu')
        x = cpu_x.detach().clone().to('mps')

        view_size, view_stride = (3, 4), (1, 0)
        strided_cpu = torch.as_strided(cpu_x, view_size, view_stride)
        strided_mps = torch.as_strided(x, view_size, view_stride)

        self.assertEqual(strided_mps, strided_cpu)

    helper(3, 1)
def test_im2col(self):
    """Check `unfold` with dilation, padding and stride on MPS against CPU."""
    def helper(x):
        unfold_args = dict(kernel_size=(10, 15), dilation=2, padding=5, stride=3)
        return torch.nn.functional.unfold(x, **unfold_args)

    x_cpu = torch.rand(1, 1, 200, 100)
    x = x_cpu.detach().clone().to('mps')
    expected = helper(x_cpu)
    actual = helper(x)
    self.assertEqual(expected, actual)
def test_col2im(self):
    """Check `fold` on MPS against CPU across shapes, dtypes and layouts."""
    def helper(shapes, output_size, kernel_size, padding, stride, contiguous, dtype=torch.float32, test_bool=False):
        # float32 gets tight tolerances; every other dtype gets 1e-2 for both.
        if dtype == torch.float:
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 1e-2, 1e-2

        x_cpu = torch.rand(*shapes, dtype=dtype)
        if test_bool:
            x_cpu = x_cpu > 0.5
        x_mps = x_cpu.clone().to('mps')
        if not contiguous:
            # Swap the last two dims to get a non-contiguous input.
            x_cpu, x_mps = x_cpu.mT, x_mps.mT

        fold_kwargs = dict(
            output_size=output_size,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
        )
        out_cpu = torch.nn.functional.fold(x_cpu, **fold_kwargs)
        out_mps = torch.nn.functional.fold(x_mps, **fold_kwargs)
        self.assertEqual(out_cpu, out_mps, atol=atol, rtol=rtol)

    # Square kernel, scalar padding/stride
    helper((4, 27, 1600), (40, 40), 3, 1, 1, True)
    helper((1, 27, 1600), (40, 40), 3, 1, 1, True)
    helper((27, 1600), (40, 40), 3, 1, 1, True)
    helper((27, 320), (80, 4), 3, 1, 1, True)
    helper((27, 320), (4, 80), 3, 1, 1, True)
    helper((320, 27), (4, 80), 3, 1, 1, False)
    helper((4, 75, 1600), (40, 40), 5, 2, 1, True)
    helper((4, 75, 441), (41, 41), 5, 2, 2, True)
    helper((4, 12, 100), (20, 20), 2, 0, 2, True)
    helper((4, 48, 225), (30, 30), 4, 1, 2, True)
    helper((100, 75), (20, 20), 5, 2, 2, False)

    # Rectangular kernel, per-dimension padding/stride
    helper((4, 15, 1600), (40, 40), (3, 5), (1, 2), (1, 1), True)
    helper((4, 45, 187), (35, 33), (3, 5), (0, 1), (2, 3), True)
    helper((1600, 15), (40, 40), (3, 5), (1, 2), (1, 1), False)

    # Other dtypes and bool input
    helper((20, 15), (2, 10), (3, 5), (1, 2), (1, 1), False, torch.bfloat16)
    helper((20, 15), (2, 10), (3, 5), (1, 2), (1, 1), False, torch.float16)
    helper((20, 15), (2, 10), (3, 5), (1, 2), (1, 1), False, test_bool=True)
def test_select(self):
    """Check several `as_strided` views of a 2-D tensor on MPS against CPU."""
    def helper(n, c):
        cpu_x = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        # (size, stride, extra keyword arguments) for each view under test
        view_specs = (
            ((3, 1), (3, 1), {}),
            ((1, 3), (3, 1), {}),
            ((3, 1), (3, 1), {"storage_offset": 1}),
        )
        for size, stride, extra in view_specs:
            strided_cpu = torch.as_strided(cpu_x, size, stride, **extra)
            strided_mps = torch.as_strided(x, size, stride, **extra)
            self.assertEqual(strided_mps, strided_cpu)

    helper(3, 3)
def test_sort(self):
    """Check `torch.sort` on MPS: `out=` variant, argsort and a fixed example."""
    device = 'mps'
    for SIZE in (4, 2049):
        x = torch.rand(4, SIZE, device=device)
        sorted_vals, sorted_idx = torch.sort(x)

        # The out= overload must fill the given (initially empty) tensors
        # with exactly the same values and indices.
        out_vals = torch.tensor((), device=device)
        out_idx = torch.tensor((), device=device, dtype=torch.long)
        torch.sort(x, out=(out_vals, out_idx))
        self.assertEqual(sorted_vals, out_vals, atol=0, rtol=0)
        self.assertEqual(sorted_idx, out_idx, atol=0, rtol=0)

        # Both argsort spellings must agree with the indices from sort.
        self.assertEqual(torch.argsort(x), sorted_idx)
        self.assertEqual(x.argsort(), sorted_idx)

        descending = torch.tensor((50, 40, 30, 20, 10), device=device)
        ascending = torch.tensor((10, 20, 30, 40, 50), device=device)
        self.assertEqual(torch.sort(descending)[0], ascending, atol=0, rtol=0)
def test_linalg_cholesky(self):
    """Check `torch.linalg.cholesky_ex` on MPS: outputs against CPU, and the error path."""
    from torch.testing._internal.common_utils import random_hermitian_pd_matrix

    def run_cholesky_test(size, *batch_dims, upper=False, check_errors=False):
        if not check_errors:
            # output checks for positive definite matrix
            input_cpu = random_hermitian_pd_matrix(size, *batch_dims, dtype=torch.float32, device="cpu")
            input_mps = input_cpu.to('mps')
            output_cpu = torch.linalg.cholesky_ex(input_cpu, upper=upper)
            output_mps = torch.linalg.cholesky_ex(input_mps, upper=upper)
            self.assertEqual(output_cpu, output_mps, atol=2e-5, rtol=1e-6)
            return

        # expect failure for non-positive definite matrix
        input_mps = torch.eye(size, dtype=torch.float32, device="mps")
        input_mps[0, 0] = -1
        error_msg = r'The factorization could not be completed because the input is not positive-definite'
        with self.assertRaisesRegex(RuntimeError, error_msg):
            torch.linalg.cholesky_ex(input_mps, upper=upper, check_errors=check_errors)

    # test with different even/odd matrix sizes
    matrix_sizes = [1, 2, 3, 4, 8, 17, 64, 128, 154]
    # even/odd batch sizes
    batch_sizes = [1, 2, 4, 8, 16, 17]

    for upper, size, batch_size in itertools.product([True, False], matrix_sizes, batch_sizes):
        run_cholesky_test(size, batch_size, upper=upper)

    # test >3D matrices
    run_cholesky_test(128, 10, 10, upper=False)
    run_cholesky_test(128, 2, 2, 2, 2, 10, 10, upper=True)

    # error reporting for both triangles
    for upper in (False, True):
        run_cholesky_test(32, 2, upper=upper, check_errors=True)
def test_linalg_cholesky_info(self):
    """Check that `cholesky_ex(check_errors=True)` names the failing leading minor on MPS."""
    # non psd matrix with leading minor of order 2 being not positive definite
    rows = [
        [4.0, 1.0, 0.0],
        [1.0, -2.0, 1.0],
        [0.0, 1.0, 3.0],
    ]
    A = torch.tensor(rows, device="mps")
    expected_msg = r'leading minor of order 2 is not positive-definite'
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        torch.linalg.cholesky_ex(A, check_errors=True)
def test_upsample_nearest2d(self):
    """Check `nn.UpsamplingNearest2d` forward and backward on MPS against CPU."""
    def helper(N, C, H, W, memory_format):
        inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
                                requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
        inputCPU.retain_grad()
        inputMPS = inputCPU.detach().to('mps').requires_grad_()

        scale_values = [1, 2, 5, 10, 40]
        for i, j in itertools.product(scale_values, repeat=2):
            # Upsample by scale factor
            by_scale = nn.UpsamplingNearest2d(scale_factor=(i, j))
            outputCPU = by_scale(inputCPU)
            outputMPS = by_scale(inputMPS)
            self.assertEqual(outputCPU, outputMPS)

            # Upsample to the equivalent explicit output size
            by_size = nn.UpsamplingNearest2d((i * H, j * W))
            outputCPU = by_size(inputCPU)
            outputMPS = by_size(inputMPS)
            self.assertEqual(outputCPU, outputMPS)

            # Backward through the size-based result with a constant gradient
            outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
            outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
            self.assertEqual(inputCPU.grad, inputMPS.grad)

    for memory_format in [torch.channels_last, torch.contiguous_format]:
        helper(1, 1, 4, 4, memory_format=memory_format)
        helper(7, 5, 3, 2, memory_format=memory_format)
def test_upsample_bilinear2d(self):
    """Check `nn.UpsamplingBilinear2d` forward and backward on MPS against CPU."""
    def helper(N, C, H, W):
        inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
                                requires_grad=True).reshape(N, C, H, W)
        inputCPU.retain_grad()
        inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()

        scale_values = [1, 2, 5, 10, 40]
        for i, j in itertools.product(scale_values, repeat=2):
            # Upsample by scale factor
            by_scale = nn.UpsamplingBilinear2d(scale_factor=(i, j))
            outputCPU = by_scale(inputCPU)
            outputMPS = by_scale(inputMPS)
            self.assertEqual(outputCPU, outputMPS)

            # Upsample to the equivalent explicit output size
            by_size = nn.UpsamplingBilinear2d((i * H, j * W))
            outputCPU = by_size(inputCPU)
            outputMPS = by_size(inputMPS)
            self.assertEqual(outputCPU, outputMPS)

            # Backward through the size-based result with a constant gradient
            outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
            outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
            self.assertEqual(inputCPU.grad, inputMPS.grad)

    helper(1, 1, 4, 4)
    helper(7, 5, 3, 2)
def test_interpolate(self):
    """Check `nn.functional.interpolate` forward and backward on MPS against CPU."""
    def helper(shape, output_size, scales, mode, align_corners=False):
        inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        inputCPU.retain_grad()
        inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()

        # A scale factor takes precedence over an explicit output size.
        interp_kwargs = {"mode": mode}
        if scales is not None:
            interp_kwargs["scale_factor"] = scales
        else:
            interp_kwargs["size"] = output_size
        # align_corners is used for 2D interpolation only
        if align_corners is True and len(shape) > 3 and mode == 'bilinear':
            interp_kwargs["align_corners"] = align_corners

        outputCPU = nn.functional.interpolate(inputCPU, **interp_kwargs)
        outputMPS = nn.functional.interpolate(inputMPS, **interp_kwargs)
        self.assertEqual(outputCPU, outputMPS)

        # backward pass (chose 0.6 just to have the grad_output != 1)
        outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
        outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
        self.assertEqual(inputCPU.grad, inputMPS.grad)

    # 1D interpolation
    for mode in ['nearest', 'nearest-exact']:
        helper([2, 3, 4], [3], None, mode)  # downsample with size
        helper([2, 3, 4], [6], None, mode)  # upsample with size
        helper([2, 3, 4], None, [0.6], mode)  # downsample with scale factor
        helper([2, 3, 4], None, [1.7], mode)  # upsample with scale factor

    # 2D interpolation
    for mode in ['nearest', 'nearest-exact', 'bilinear']:
        helper([2, 3, 4, 5], [3, 4], None, mode)  # downsample with size
        helper([2, 3, 4, 5], [6, 7], None, mode)  # upsample with size
        helper([2, 3, 4, 5], None, [0.6, 0.7], mode)  # downsample with scale factor
        helper([2, 3, 4, 5], None, [1.4, 1.7], mode)  # upsample with scale factor

    # align_corners=True
    helper([2, 3, 4, 5], [3, 4], None, 'bilinear', True)
    helper([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True)

    # Regression test for https://github.com/pytorch/pytorch/issues/144245
    inp = torch.tensor([[[1.]], [[2]], [[4]]], device='mps')
    for align_corners in [True, False]:
        def interp(x):
            return F.interpolate(x, 3, mode='linear', align_corners=align_corners)

        mps_out = interp(inp).cpu()
        cpu_out = interp(inp.cpu())
        self.assertEqual(mps_out, cpu_out)
# Test concat forward
def test_cat1(self):
    """Check `torch.cat` along dim 1 on MPS against CPU, including empty inputs."""
    def helper(shape_x, shape_y, shape_z):
        cpu_parts = []
        mps_parts = []
        for shape in (shape_x, shape_y, shape_z):
            cpu_t = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            cpu_parts.append(cpu_t)
            mps_parts.append(cpu_t.detach().clone().to('mps'))

        cat = torch.cat(mps_parts, dim=1)
        cat_cpu = torch.cat(cpu_parts, dim=1)
        self.assertEqual(cat, cat_cpu)

    # Regular inputs
    helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
    helper([2, 2, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5])
    # Zero-sized leading dimension on every input
    helper([0, 2, 4, 5], [0, 3, 4, 5], [0, 5, 4, 5])
    # A 1-D empty tensor in each position
    helper([2, 2, 6, 5], [0], [2, 5, 6, 5])
    helper([0], [2, 3, 6, 5], [2, 5, 6, 5])
    helper([2, 3, 4, 5], [2, 5, 4, 5], [0])
    # Zero-sized concatenation dimension
    helper([2, 2, 6, 5], [2, 0, 6, 5], [2, 5, 6, 5])
    helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5])
    helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 0, 6, 5])
# Test stack forward
def test_stack(self):
    """Check `torch.stack` along dim 1 on MPS against CPU for several dtypes."""
    # All shapes must be same
    def helper(shape, dtype=torch.float32):
        def make_pair():
            # float32 inputs are random normal and track gradients; bool inputs
            # are drawn from {0, 1}; every other dtype is drawn from [0, 50).
            if dtype == torch.float32:
                cpu_t = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
                return cpu_t.detach().clone().to('mps').requires_grad_(), cpu_t
            upper = 2 if dtype == torch.bool else 50
            cpu_t = torch.randint(upper, shape, device='cpu', dtype=dtype, requires_grad=False)
            return cpu_t.detach().clone().to('mps'), cpu_t

        x, cpu_x = make_pair()
        y, cpu_y = make_pair()
        z, cpu_z = make_pair()

        stack = torch.stack([x, y, z], dim=1)
        stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)
        self.assertEqual(stack, stack_cpu)

    helper([2, 8, 4, 5])
    for dtype in (torch.float16, torch.int32, torch.int64, torch.bool):
        helper([2, 8, 4, 5], dtype=dtype)
    # Empty test - Currently failing! Empty tensor not handled!
    # helper([0, 2, 4, 5])
# Test abs
def test_abs(self):
    """Check `torch.abs` on MPS against CPU."""
    def helper(shape):
        reference_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        device_input = reference_input.detach().clone().to('mps')

        actual = torch.abs(device_input)
        expected = torch.abs(reference_input)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_angle(self):
    """Check `torch.angle` on MPS against CPU for real and complex dtypes."""
    def helper(shape, dtype):
        reference_input = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
        reference_input.flatten()[0] = torch.nan  # Test that NaN is propagated correctly
        device_input = reference_input.detach().clone().to('mps')

        actual = torch.angle(device_input)
        expected = torch.angle(reference_input)
        self.assertEqual(actual, expected)

    for dtype in (torch.float16, torch.float32, torch.complex64):
        helper((2, 8, 4, 5), dtype)
def test_log(self):
    """Check `torch.log` on MPS against CPU."""
    def helper(shape):
        reference_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        device_input = reference_input.detach().clone().to('mps')

        actual = torch.log(device_input)
        expected = torch.log(reference_input)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_log_ten(self):
    """Check `torch.log10` on MPS against CPU."""
    def helper(shape):
        reference_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        device_input = reference_input.detach().clone().to('mps')

        actual = torch.log10(device_input)
        expected = torch.log10(reference_input)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_log_two(self):
    """Check `torch.log2` on MPS against CPU."""
    def helper(shape):
        reference_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        device_input = reference_input.detach().clone().to('mps')

        actual = torch.log2(device_input)
        expected = torch.log2(reference_input)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
@parametrize("dtype", {torch.float, torch.half, torch.bfloat16})
def test_log1p(self, dtype):
    """Check `torch.log1p` on MPS against CPU near zero and over a wider range."""
    eps = torch.finfo(dtype).eps
    arange_specs = (
        # Small values
        (-10.0 * eps, 10.0 * eps, 1e-2 * eps),
        # Fallback to log
        (-1.0, 2.0, 1e-4),
    )
    for start, stop, step in arange_specs:
        cpu_x = torch.arange(start, stop, step, dtype=dtype, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')
        log_result = torch.log1p(x)
        log_result_cpu = torch.log1p(cpu_x)
        self.assertEqual(log_result, log_result_cpu, atol=0, rtol=2e-7)
def test_logaddexp(self):
    """Check `torch.logaddexp` on MPS against CPU."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        lhs_mps = lhs_cpu.detach().clone().to('mps')

        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        rhs_mps = rhs_cpu.detach().clone().to('mps')

        actual = torch.logaddexp(lhs_mps, rhs_mps)
        expected = torch.logaddexp(lhs_cpu, rhs_cpu)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_logaddexp2(self):
    """Check `torch.logaddexp2` on MPS against CPU."""
    def helper(shape):
        lhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        lhs_mps = lhs_cpu.detach().clone().to('mps')

        rhs_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        rhs_mps = rhs_cpu.detach().clone().to('mps')

        actual = torch.logaddexp2(lhs_mps, rhs_mps)
        expected = torch.logaddexp2(lhs_cpu, rhs_cpu)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
def test_logsumexp(self):
    """Check `torch.logsumexp` over the last dim on MPS against CPU."""
    def helper(shape):
        reference_input = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        device_input = reference_input.detach().clone().to('mps')

        reduce_dim = -1
        actual = torch.logsumexp(device_input, reduce_dim)
        expected = torch.logsumexp(reference_input, reduce_dim)
        self.assertEqual(actual, expected)

    helper((2, 8, 4, 5))
# Test concat forward
def test_cat2(self):
    """Check `torch.cat` along dim 1 on MPS against CPU for three inputs."""
    def helper1(shape_x, shape_y, shape_z, shape_w):
        # Four-input variant. NOTE: nothing in this test calls it at present.
        cpu_parts = []
        mps_parts = []
        for shape in (shape_x, shape_y, shape_z, shape_w):
            cpu_t = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            cpu_parts.append(cpu_t)
            mps_parts.append(cpu_t.detach().clone().to('mps'))

        cat = torch.cat(mps_parts, dim=1)
        cat_cpu = torch.cat(cpu_parts, dim=1)
        self.assertEqual(cat, cat_cpu)

    def helper(shape_x, shape_y, shape_z):
        cpu_parts = []
        mps_parts = []
        for shape in (shape_x, shape_y, shape_z):
            cpu_t = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            cpu_parts.append(cpu_t)
            mps_parts.append(cpu_t.detach().clone().to('mps'))

        cat = torch.cat(mps_parts, dim=1)
        cat_cpu = torch.cat(cpu_parts, dim=1)
        self.assertEqual(cat, cat_cpu)

    helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5])
    helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
    # Empty test - Currently failing! Empty tensor not handled!
    # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
# Test isnan
def test_isnan(self):
    """Check `torch.isnan` on MPS against CPU with one NaN-filled slice."""
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        nan_row = random.randrange(0, shape[0])
        # fill one randomly selected slice along dim 0 with NaN
        cpu_x.index_put_(indices=[torch.tensor([nan_row])], values=torch.tensor(float('nan')))
        x = cpu_x.detach().clone().to('mps')

        actual = torch.isnan(x)
        expected = torch.isnan(cpu_x)
        self.assertEqual(actual, expected)

    helper((8, 2, 4, 5))
# Test reciprocal
def test_reciprocal(self):
    """Check `torch.reciprocal` forward and backward on MPS against CPU."""
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        out_mps = torch.reciprocal(x)
        out_cpu = torch.reciprocal(cpu_x)

        # Backpropagate a gradient of ones on both devices.
        ones_cpu = torch.ones_like(out_cpu)
        ones_mps = ones_cpu.to('mps')
        out_mps.backward(gradient=ones_mps)
        out_cpu.backward(gradient=ones_cpu)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 8, 4, 5))
# Test sqrt
def test_sqrt(self):
    """Check `torch.sqrt` forward/backward on MPS against CPU, plus complex half."""
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        out_mps = torch.sqrt(x)
        out_cpu = torch.sqrt(cpu_x)

        # Backpropagate a gradient of ones on both devices.
        ones_cpu = torch.ones_like(out_cpu)
        ones_mps = ones_cpu.to('mps')
        out_mps.backward(gradient=ones_mps)
        out_cpu.backward(gradient=ones_cpu)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 8, 4, 5))

    # Test complex half: the direct result must match the one computed via
    # complex float and cast back.
    x = torch.rand(8, device='mps', dtype=torch.chalf)
    direct_half = x.sqrt()
    via_float = x.cfloat().sqrt().chalf()
    self.assertEqual(direct_half, via_float)
# Test selu, elu, celu
def test_elu(self):
    """Check ELU, CELU and SELU forward and backward on MPS against CPU."""
    def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
        cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_()
        x = cpu_x.detach().clone().to('mps').requires_grad_(True)

        activations = [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]
        for activation_func in activations:
            out_mps = activation_func(x)
            out_cpu = activation_func(cpu_x)

            cpu_grad = torch.randn(out_cpu.shape)
            grad = cpu_grad.to('mps')
            out_mps.backward(gradient=grad)
            out_cpu.backward(gradient=cpu_grad)

            self.assertEqual(out_mps, out_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

    # Sweep memory format, shape and alpha
    memory_formats = [torch.channels_last, torch.contiguous_format]
    shapes = [(2, 8, 4, 5)]
    alphas = [0.000001, 1.0, 2.3, 0.34, 23]
    for memory_format, shape, alpha in itertools.product(memory_formats, shapes, alphas):
        helper(shape, alpha, memory_format)
def test_elu_strided_output(self):
    """Check `F.elu` on a transposed (non-contiguous) input on MPS against CPU."""
    # https://github.com/pytorch/pytorch/issues/124834
    elu_input = torch.randn(1, 1024, 500)
    alpha = float(1)
    inplace = False

    transposed = elu_input.transpose(1, 2)
    expected = F.elu(transposed.to('cpu'), alpha, inplace)
    actual = F.elu(transposed.to('mps'), alpha, inplace)
    self.assertEqual(expected, actual)
# Test glu
def test_glu(self):
    """Check `nn.GLU` forward and backward on MPS against CPU along every dim."""
    def helper(shape, dim=0):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        glu = torch.nn.GLU(dim=dim)
        out_mps = glu(x)
        out_cpu = glu(cpu_x)

        cpu_grad = torch.randn(out_cpu.shape)
        grad = cpu_grad.to('mps')
        out_mps.backward(gradient=grad)
        out_cpu.backward(gradient=cpu_grad)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    for shape in [[4], (2, 4), (2, 8, 4, 6)]:
        for dim, _ in enumerate(shape):
            helper(shape, dim)
# Test softplus
def test_softplus(self):
    """Check `nn.Softplus` forward and backward on MPS against CPU."""
    def helper(shape, beta, threshold, dtype):
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        softplus = torch.nn.Softplus(beta=beta, threshold=threshold)
        out_mps = softplus(x)
        out_cpu = softplus(cpu_x)

        # The upstream gradient takes its shape from the MPS output.
        cpu_grad = torch.randn(out_mps.shape)
        grad = cpu_grad.to('mps')
        out_mps.backward(gradient=grad)
        out_cpu.backward(gradient=cpu_grad)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    # Test empty shape too
    shapes = [(), (2, 3), (10, 10), (2, 3, 4, 5)]
    betas = [0.5, 1, 2, 3, 4]
    thresholds = [0.5, 20, 30, 40, 50]
    dtypes = [torch.float16, torch.float32]
    for shape, beta, threshold, dtype in product(shapes, betas, thresholds, dtypes):
        helper(shape, beta, threshold, dtype)
# Test silu
def test_silu(self):
    """Check `nn.SiLU` forward and backward on MPS against CPU."""
    def helper(shape, contiguous=True):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
        x = cpu_x.detach().clone().to('mps')

        can_transpose = 0 not in shape and len(shape) >= 2
        if not contiguous and can_transpose:
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            x = x.transpose(0, 1)
            assert not x.is_contiguous()

        cpu_x.requires_grad_()
        x.requires_grad_()

        silu = torch.nn.SiLU()
        out_mps = silu(x)
        out_cpu = silu(cpu_x)

        cpu_grad = torch.randn(out_cpu.shape)
        grad = cpu_grad.to('mps')
        out_mps.backward(gradient=grad)
        out_cpu.backward(gradient=cpu_grad)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    # Test empty shape too
    for shape, contiguous in itertools.product([[], (2, 3), (2, 8, 4, 5)], [True, False]):
        helper(shape, contiguous)
def test_cast_mps_to_cpu(self):
    """Check a combined device+dtype cast from MPS to CPU against a CPU-only cast."""
    def helper(src_dtype, dst_dtype):
        source = torch.rand((1, 3, 128, 128), dtype=src_dtype)
        on_mps = source.to('mps')
        round_tripped = on_mps.to('cpu', dtype=dst_dtype)

        # needs to match the initial Tensor
        expected = source.to(dtype=dst_dtype)
        self.assertEqual(round_tripped, expected)

    for src_dtype, dst_dtype in ((torch.half, torch.float), (torch.float, torch.half)):
        helper(src_dtype, dst_dtype)
def test_cast_mps_to_mps(self):
    """Check dtype casts performed on MPS against the same casts on CPU."""
    def helper(src_dtype, dst_dtype):
        input_cpu = torch.rand((1, 3, 128, 128), dtype=src_dtype)
        input_mps = input_cpu.to('mps')

        output_mps = input_mps.to(dtype=dst_dtype)
        output_cpu = input_cpu.to(dtype=dst_dtype)
        self.assertEqual(output_mps.cpu(), output_cpu)

    dtype_pairs = (
        (torch.half, torch.float),
        (torch.float, torch.half),
        (torch.half, torch.long),
        (torch.float, torch.int),
    )
    for src_dtype, dst_dtype in dtype_pairs:
        helper(src_dtype, dst_dtype)
def test_avg_pool2d_count_include_pad(self):
    """Check `nn.AvgPool2d` with ceil_mode and count_include_pad on MPS against CPU."""
    cpu_x = torch.randn((1, 3, 9, 9), device='cpu', dtype=torch.float, requires_grad=True)
    x = cpu_x.detach().clone().to('mps').requires_grad_()

    pool = torch.nn.AvgPool2d(
        kernel_size=(3, 3),
        padding=(1, 1),
        stride=(1, 1),
        ceil_mode=True,
        count_include_pad=True,
    )
    ref_y = pool(cpu_x)
    y = pool(x)
    self.assertEqual(y, ref_y)

    # Backward with the same random upstream gradient on both devices
    cpu_grad = torch.randn(ref_y.shape)
    grad = cpu_grad.to('mps')
    ref_y.backward(gradient=cpu_grad)
    y.backward(gradient=grad)
    self.assertEqual(x.grad, cpu_x.grad)
# Test adaptive avg pool2d - when the input size is a multiple of output size
# Not testing for channels last right now
def test_adaptive_avg_pool2d_simple(self):
    """Check `nn.AdaptiveAvgPool2d` forward and backward on MPS against CPU.

    Covers output sizes smaller than, equal to and larger than the input
    size. Every case passes `channels_last=False`.
    """
    def helper(input_shape, out_shape, channels_last):
        cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
        if (channels_last):
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        avg_result = torch.nn.AdaptiveAvgPool2d(out_shape)(x)
        avg_result_cpu = torch.nn.AdaptiveAvgPool2d(out_shape)(cpu_x)

        cpu_grad = torch.randn(avg_result_cpu.shape)
        grad = cpu_grad.to('mps')

        avg_result.backward(gradient=grad)
        avg_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(avg_result, avg_result_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 2, 4, 4), (2, 2), False)
    helper((2, 2, 9, 9), (3, 3), False)
    helper((2, 2, 9, 9), (9, 9), False)
    helper((2, 2, 16, 16), (2, 2), False)
    helper((2, 2, 16, 16), (2, 16), False)

    helper((2, 16, 16), (4, 4), False)

    # Output shape larger than input shape
    helper((2, 2, 4, 4), (8, 8), False)
    helper((2, 2, 2, 2), (4, 4), False)
    helper((2, 2, 3, 3), (9, 9), False)
    helper((2, 2, 2, 2), (16, 16), False)
    helper((2, 2, 2, 16), (16, 16), False)

    helper((2, 4, 4), (16, 16), False)

    # Best-effort case: 7 is not a multiple of the input size 3. Any exception
    # is deliberately ignored here.
    # NOTE(review): assertion failures are Exceptions too, so this case can
    # never fail the test - confirm whether that is still intended.
    try:
        helper((2, 2, 3, 3), (7, 7), False)
    except Exception:
        pass
# Test adaptive max pool2d - when the input size is a multiple of output size
# Not testing for channels last right now
def test_adaptive_max_pool2d_simple(self):
    """Check `nn.AdaptiveMaxPool2d` values, indices and gradients on MPS against CPU."""
    def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
        if dtype in [torch.float16, torch.float32]:
            cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
        else:
            cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        pool = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)
        mps_out = pool(x)
        cpu_out = pool(cpu_x)

        # With return_indices the module yields a (values, indices) pair.
        max_indices, max_indices_cpu = None, None
        if return_indices:
            max_result, max_indices = mps_out
            max_result_cpu, max_indices_cpu = cpu_out
        else:
            max_result, max_result_cpu = mps_out, cpu_out

        cpu_grad = torch.randn(max_result_cpu.shape)
        grad = cpu_grad.to('mps')
        max_result.backward(gradient=grad)
        max_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(max_result, max_result_cpu)
        if return_indices:
            self.assertEqual(max_indices, max_indices_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    shape_cases = (
        ((2, 2, 4, 4), (2, 2)),
        ((2, 2, 9, 9), (3, 3)),
        ((2, 2, 9, 9), (9, 9)),
        ((2, 2, 16, 16), (2, 2)),
        ((2, 2, 16, 16), (2, 16)),
        ((2, 16, 16), (4, 4)),
    )
    for dtype in [torch.float32]:
        for return_indices in [False, True]:
            for input_shape, out_shape in shape_cases:
                helper(input_shape, out_shape, return_indices, dtype)
def test_gelu_simple(self):
    """Check `nn.GELU` forward and backward on MPS against CPU; integral inputs must raise."""
    def helper(shape, dtype=torch.float, contiguous=True):
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        x = cpu_x.detach().clone().to('mps')

        can_transpose = 0 not in shape and len(shape) >= 2
        if not contiguous and can_transpose:
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            x = x.transpose(0, 1)
            assert not x.is_contiguous()

        cpu_x.requires_grad_()
        x.requires_grad_()

        gelu = torch.nn.GELU()
        gelu_result = gelu(x)
        # GELU is not supported on CPU, so cast it to float
        gelu_result_cpu = gelu(cpu_x.to(torch.float))

        cpu_grad = torch.ones_like(gelu_result_cpu)
        grad = cpu_grad.to('mps')
        gelu_result.backward(gradient=grad)
        gelu_result_cpu.backward(gradient=cpu_grad)

        if dtype == torch.float:
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 1e-2, 1e-2
        self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol)

        assert x.grad is not None  # Check that the grad is well-populated
        self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

    # Test empty shape too
    shapes = [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]
    for dtype in [torch.float, torch.half]:
        for shape in shapes:
            for contiguous in [True, False]:
                helper(shape, dtype, contiguous)

    # Test that gelu would raise an assert for integral types
    for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
        with self.assertRaises(RuntimeError):
            torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps"))
def test_mish_simple(self):
    """Check `nn.Mish` forward and backward on MPS against CPU."""
    def helper(shape, dtype=torch.float, contiguous=True):
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        x = cpu_x.detach().clone().to('mps')

        can_transpose = 0 not in shape and len(shape) >= 2
        if not contiguous and can_transpose:
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            x = x.transpose(0, 1)
            assert not x.is_contiguous()

        cpu_x.requires_grad_()
        x.requires_grad_()

        mish = torch.nn.Mish()
        mish_result = mish(x)
        mish_result_cpu = mish(cpu_x)

        cpu_grad = torch.ones_like(mish_result_cpu)
        grad = cpu_grad.to('mps')
        mish_result.backward(gradient=grad)
        mish_result_cpu.backward(gradient=cpu_grad)

        if dtype == torch.float:
            atol, rtol = 1e-5, 1e-3
        else:
            atol, rtol = 1e-2, 1e-2
        self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol)

        assert x.grad is not None  # Check that the grad is well-populated
        self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)

    # Test empty shape too
    shapes = [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]
    for dtype in [torch.float, torch.half]:
        for shape in shapes:
            for contiguous in [True, False]:
                helper(shape, dtype, contiguous)
def test_gelu(self):
    # Checks F.gelu on CPU and MPS against the closed-form definition
    # x * Phi(x), where Phi is the standard normal CDF, for contiguous and
    # strided inputs.
    def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
        # dtype in which the reference value is computed on the CPU
        ref_dtype = {
            torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double
        }[dtype]
        if atol is None and rtol is None:
            # Same float tolerances that test_gelu_simple uses when comparing
            # MPS GELU results with the CPU
            atol, rtol = 1e-5, 1e-3

        def _gelu_ref(x):
            # Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
            return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

        for d in ['cpu', 'mps']:
            X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)
            if not contiguous:
                # Taking every other column yields a strided view
                X = X[:, ::2]
            # The previous version compared X with itself and never ran GELU
            res = F.gelu(X)
            ref = _gelu_ref(X.detach().cpu().to(ref_dtype))
            self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)

    for n in [1, 5, 10]:
        for m in [1, 5, 10]:
            _test_gelu(n, m, torch.float32, True)
            _test_gelu(n, m, torch.float32, False)

    # Test multi threaded
    num_threads = torch.get_num_threads()
    torch.set_num_threads(4)
    try:
        _test_gelu(32, 32, torch.float32, False)
    finally:
        # Always restore the thread count that was active before the test
        torch.set_num_threads(num_threads)
def test_gelu_tanh(self):
    # The tanh approximation of GELU must give the same result on MPS and CPU.
    def helper(shape):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float)
        mps_in = ref_in.detach().clone().to('mps')

        mps_out = F.gelu(mps_in, approximate='tanh')
        ref_out = F.gelu(ref_in, approximate='tanh')

        self.assertEqual(mps_out, ref_out)

    helper((2, 8, 4, 5))
# Test hardtanh
def test_hardtanh(self):
    # Hardtanh on MPS is compared with the CPU in the forward pass and, for
    # the out-of-place variant only, in the backward pass.
    def helper(shape, min_val, max_val, inplace=False):
        # Gradients are only tracked (and checked) when not running in-place
        track_grad = not inplace
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=track_grad)
        x = cpu_x.detach().clone().to('mps').requires_grad_(track_grad)

        out_mps = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(x)
        out_cpu = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(cpu_x)
        self.assertEqual(out_mps, out_cpu)

        if not track_grad:
            return

        cpu_grad = torch.randn(out_cpu.shape)
        mps_grad = cpu_grad.to('mps')
        out_mps.backward(gradient=mps_grad)
        out_cpu.backward(gradient=cpu_grad)
        self.assertEqual(x.grad, cpu_x.grad)

    # Test empty shape too
    bounds = [(-1, 1), (-2, -1), (3, 4)]
    for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
        for min_val, max_val in bounds:
            for inplace in (False, True):
                helper(shape, min_val, max_val, inplace=inplace)
def test_hardswish(self):
    # Hardswish forward/backward parity between MPS and CPU, plus the error
    # raised when the in-place variant is applied to an input requiring grad.
    def helper(shape, inplace=False, requires_grad=True):
        act = nn.Hardswish(inplace=inplace)

        input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=requires_grad)
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)

        if inplace and requires_grad:  # check that both raise runtime error
            for bad_input in (input_cpu, input_mps):
                self.assertRaises(RuntimeError, act, bad_input)
            return

        output_cpu = act(input_cpu)
        output_mps = act(input_mps)
        self.assertEqual(output_cpu, output_mps)

        if not requires_grad:
            return

        cpu_grad = torch.ones_like(output_cpu)
        mps_grad = cpu_grad.to('mps')
        output_cpu.backward(gradient=cpu_grad)
        output_mps.backward(gradient=mps_grad)
        self.assertEqual(input_cpu.grad, input_mps.grad)

    for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
        for requires_grad, inplace in itertools.product([False, True], [False, True]):
            helper(shape, inplace=inplace, requires_grad=requires_grad)
def test_transpose_2D(self):
    # Transposing a 3x3 matrix on MPS must match the CPU result.
    values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    cpu_transpose = torch.transpose(cpu_x, 0, 1)
    mps_transpose = torch.transpose(mps_x, 0, 1)
    self.assertEqual(cpu_transpose, mps_transpose.to('cpu'))
def test_transpose_3D(self):
    # Every pair of dims of a 2x2x3 tensor is transposed on both devices and
    # the results are compared on the CPU.
    values = torch.arange(1.0, 13.0).reshape(2, 2, 3).tolist()
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    for dim0, dim1 in [(0, 1), (0, 2), (1, 2)]:
        expected = torch.transpose(cpu_x, dim0, dim1)
        actual = torch.transpose(mps_x, dim0, dim1).to('cpu')
        self.assertEqual(expected, actual)
def test_transpose_4D(self):
    # Transposes a 2x2x2x3 tensor holding 1..24 along several dim pairs and
    # compares the MPS result (copied back to the CPU) with the CPU result.
    values = torch.arange(1.0, 25.0).reshape(2, 2, 2, 3).tolist()
    cpu_x = torch.tensor(values, device='cpu')
    mps_x = torch.tensor(values, device='mps')

    dim_pairs = [(0, 1), (0, 2), (0, 3), (3, 1), (3, 2), (1, 2)]
    for dim0, dim1 in dim_pairs:
        expected = torch.transpose(cpu_x, dim0, dim1)
        actual = torch.transpose(mps_x, dim0, dim1).to('cpu')
        self.assertEqual(expected, actual)
# Test sign
def test_sign(self):
    # torch.sign forward and backward on MPS must match the CPU.
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        sign_result = torch.sign(x)
        sign_result_cpu = torch.sign(cpu_x)

        cpu_grad = torch.ones_like(sign_result_cpu)
        grad = cpu_grad.to('mps')

        sign_result.backward(gradient=grad)
        sign_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(sign_result, sign_result_cpu)
        # Backward was executed above, so also verify that the input
        # gradients agree between the two devices.
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 8, 4, 5))
def test_signbit(self):
    # torch.signbit must agree between MPS and CPU for integer and float inputs.
    def helper(shape, dtype):
        # Values are drawn as floats and then cast to the dtype under test
        ref_in = torch.randn(shape, device='cpu').to(dtype)
        mps_in = ref_in.clone().to('mps')

        mps_out = torch.signbit(mps_in)
        ref_out = torch.signbit(ref_in)

        self.assertEqual(mps_out, ref_out)

    for dtype in (torch.int, torch.float, torch.int64):
        helper((2, 8, 4, 5), dtype)
# Test neg
def test_neg(self):
    # torch.neg forward and backward on MPS must match the CPU.
    def helper(shape):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        neg_result = torch.neg(x)
        neg_result_cpu = torch.neg(cpu_x)

        cpu_grad = torch.ones_like(neg_result_cpu)
        grad = cpu_grad.to('mps')

        neg_result.backward(gradient=grad)
        neg_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(neg_result, neg_result_cpu)
        # Backward was executed above, so also verify that the input
        # gradients agree between the two devices.
        self.assertEqual(x.grad, cpu_x.grad)

    helper((2, 8, 4, 5))
def test_neg_strided_input(self):
    # See https://github.com/pytorch/pytorch/issues/98074#issuecomment-1496088337
    base = torch.arange(18.0, device='mps').reshape(2, 3, 3)
    # Permuting and then indexing the last dim produces a strided view
    strided = base.permute(1, 0, 2)[..., 1]
    # A value plus its own negation must be exactly zero everywhere
    residual = strided + strided.neg()
    self.assertEqual(residual.abs().max().item(), 0.0)
# Test index add
def test_index_add(self):
    # torch.index_add on MPS is compared with the CPU for several dims,
    # index lists, alpha values and dtypes.
    def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dtype=torch.int32):
        cpu_x = torch.randn(shape, device='cpu', dtype=x_dtype, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        cpu_source = torch.randn(source_shape, device='cpu', dtype=x_dtype, requires_grad=False)
        source = cpu_source.detach().clone().to('mps')

        result_mps = torch.index_add(x, dim=dim, index=idx, source=source, alpha=alpha)
        result_cpu = torch.index_add(cpu_x, dim=dim, index=cpu_idx, source=cpu_source, alpha=alpha)
        self.assertEqual(result_mps, result_cpu)

    # (shape, dim, index, source_shape, alpha)
    cases = [
        ((2, 8, 4, 5), 0, [0, 1, 0], (3, 8, 4, 5), 5),
        ((8, 8, 4, 5), 0, [7], (1, 8, 4, 5), 6.0),
        ((2, 8, 4, 5), 1, [0, 3, 7], (2, 3, 4, 5), 5),
        ((2, 8, 4, 5), 2, [3, 0], (2, 8, 2, 5), 3.0),
        ((2, 8, 4, 5), 3, [2, 3, 0], (2, 8, 4, 3), 4),
        ((2, 3, 3), -1, [1, 2], (2, 3, 2), 6.0),
        # test result dim=1
        ((2,), 0, [1], (1,), 6.0),
        (2, 0, 1, 1, 6),
    ]
    for case in cases:
        helper(*case)

    # test float16
    helper((2,), 0, [1], (1,), 6.0, x_dtype=torch.float16)
def test_index_64bit(self):
    """ Test that index operations work for 4Gb+ tensors """
    # Cleanup memory
    gc.collect()
    torch.mps.empty_cache()
    # Check that index operations work for 4+GB tensors
    x = torch.rand(16000, 67120, device="mps")
    try:
        self.assertGreater(x.element_size() * x.numel(), 2**32)
        idx = torch.arange(0, 2, device="mps")
        x_sampled = x[:, idx]
        self.assertEqual(x[:, 0], x_sampled[:, 0])
    finally:
        # Reclaim memory after running the test, even if an assertion above
        # failed, so the large tensor does not stay alive for later tests
        del x
        gc.collect()
        torch.mps.empty_cache()
def test_mm_large(self):
    """ Test that MM works for matrices with index larger than 32K """
    lhs = torch.rand(10, 1, device="mps")
    rhs = torch.rand(1, 32769, device="mps")
    # This used to crash with:
    # error: subRange.start (24576) is not less than length of dimension[0] (16384)
    # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
    sliced_product = torch.mm(lhs, rhs[:, 16384:32768])
    self.assertNotEqual(sliced_product.abs().max().item(), 0.0)

    def compare_mm(m, n, k, dtype=torch.float):
        # Multiply on MPS, then repeat the same product on CPU copies
        a = torch.rand(m, n, device="mps", dtype=dtype)
        b = torch.rand(n, k, device="mps", dtype=dtype)
        mps_result = torch.mm(a, b).cpu()
        cpu_result = torch.mm(a.cpu(), b.cpu())
        self.assertEqual(mps_result, cpu_result)

    # Used to produce incorrect results with MPS on M1 running MacOS 14.3, but correct with Metal
    compare_mm(1024, 1, 32769)
    # one more time, but with dimensions inverted
    # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
    compare_mm(32769, 1, 1025)
    # Test bfloat16 mm
    compare_mm(1024, 1, 32769, torch.bfloat16)
@unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test")
@unittest.skipIf(IS_CI, "May be fixes https://github.com/pytorch/pytorch/issues/149999")
def test_copy_large(self):
    """ Test that copy of 4Gb+ tensors works """
    # 2**30 + 11 float32 elements occupy slightly more than 4 GiB
    numel = 2**30 + 11
    host = torch.ones((numel,), dtype=torch.float32)
    on_device = host.to(device="mps")
    one = torch.tensor(1.0, device="mps")
    self.assertTrue(torch.all(on_device == one))
    # Drop the references to the large buffers explicitly
    del on_device
    del host
# Test flip
def test_flip(self):
    # torch.flip on MPS must match the CPU for regular and degenerate inputs.
    def helper(shape, dims):
        ref_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        mps_in = ref_in.detach().clone().to('mps')

        mps_out = torch.flip(mps_in, dims=dims)
        ref_out = torch.flip(ref_in, dims=dims)

        self.assertEqual(mps_out, ref_out)

    # (shape, dims)
    cases = [
        ((2, 8, 4, 5), [0]),
        ((8, 8, 4, 5), [0, 1]),
        ((2, 8, 4, 5), (0, 1, 2, 3)),
        ((2, 3, 3), (-1,)),
        # empty dims
        ((2, 8, 4, 5), []),
        # input.numel() == 1
        ((1,), (0,)),
        # input.numel() == 0
        ((0,), (0,)),
        # none of dims that needs to be flipped
        ((1, 3), [0]),
    ]
    for shape, dims in cases:
        helper(shape, dims)
# Test index select
def test_index_select(self):
    # torch.index_select on MPS must match the CPU for various dims and
    # index lists, including scalar input and an empty index.
    def helper(shape, dim, index, idx_dtype=torch.int32):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        selected_mps = torch.index_select(x, dim=dim, index=idx)
        selected_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)

        self.assertEqual(selected_mps, selected_cpu)

    # (shape, dim, index)
    cases = [
        ((2, 8, 4, 5), 0, [1]),
        ((8, 8, 4, 5), 0, [0, 3, 2, 7, 6]),
        ((2, 8, 4, 5), 1, [0, 3, 2, 7, 6]),
        ((2, 8, 4, 5), 2, [3, 0, 1]),
        ((2, 8, 4, 5), 3, [2, 3, 0]),
        ((2, 3, 3), -1, [1, 2]),
        ((), 0, [0]),
        ((5), 0, []),
    ]
    for shape, dim, index in cases:
        helper(shape, dim, index)
def test_index_copy_non_contiguous(self):
    # In-place index_copy_ into a transposed destination must give the same
    # result on MPS and CPU.
    def helper(shape, dim, index):
        dest_cpu = torch.randn(shape)
        dest = dest_cpu.clone().to('mps')

        # Transpose both destinations; dims 0 and 1 trade places, so the
        # requested dim is remapped accordingly (other dims are unchanged).
        dest_cpu = dest_cpu.transpose(0, 1)
        dest = dest.transpose(0, 1)
        dim = {0: 1, 1: 0}.get(dim, dim)

        # The source matches the destination except along `dim`, where it
        # has one slice per index entry
        src_shape = list(dest_cpu.shape)
        src_shape[dim] = len(index)
        src_cpu = torch.randn(src_shape)
        src = src_cpu.clone().to('mps')

        idx_cpu = torch.tensor(index, dtype=torch.long)
        idx_mps = idx_cpu.clone().to('mps')

        dest_cpu.index_copy_(dim, idx_cpu, src_cpu)
        dest.index_copy_(dim, idx_mps, src)

        self.assertEqual(dest, dest_cpu)

    # (shape, dim, index)
    test_cases = [
        ((2, 8, 4, 5), 0, [1]),
        ((8, 8, 4, 5), 0, [0, 3, 2, 7, 6]),
        ((2, 8, 4, 5), 1, [0, 3, 2, 7, 6]),
        ((2, 8, 4, 5), 2, [3, 0, 1]),
        ((2, 8, 4, 5), 3, [2, 3, 0]),
        ((2, 3, 3), -1, [1, 2]),
    ]
    for shape, dim, index in test_cases:
        helper(shape, dim, index)
def test_index_select_scalar(self):
    # index_select on a 0-dim tensor: a single index must match the CPU and
    # an empty index must raise.
    def helper(value, dim, index, idx_dtype=torch.int32):
        cpu_x = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        selected_mps = torch.index_select(x, dim=dim, index=idx)
        selected_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)

        self.assertEqual(selected_mps, selected_cpu)

    helper(22, 0, [0])

    expected_error = "Index to scalar can have only 1 value"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        helper(22, 0, [])
def test_embedding_dense_backward(self):
    # Gradients flowing through an nn.Embedding built with max_norm must
    # match between MPS and a CPU embedding built from the same initial weight.
    def forward_backward(embedding, proj, indices):
        # weight must be cloned for this to be differentiable
        a = embedding.weight.clone() @ proj.t()
        a.retain_grad()
        # modifies weight in-place
        b = embedding(indices) @ proj.t()
        b.retain_grad()
        out = (a.unsqueeze(0) + b)
        loss = out.sigmoid().prod()
        loss.backward()
        return a, b

    def helper(n, d, m, idx):
        embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')
        # Snapshot the initial weight before the forward pass touches it
        embedding_weight = embeddingMPS.weight.detach().cpu()
        W_MPS = torch.randn((m, d), requires_grad=True, device='mps')
        idx_MPS = torch.tensor(idx, device='mps')
        a_MPS, b_MPS = forward_backward(embeddingMPS, W_MPS, idx_MPS)

        embeddingCPU = nn.Embedding(n, d, max_norm=True, _weight=embedding_weight)
        W_CPU = W_MPS.to('cpu')
        idx_CPU = torch.tensor(idx)
        a_CPU, b_CPU = forward_backward(embeddingCPU, W_CPU, idx_CPU)

        self.assertEqual(b_CPU.grad, b_MPS.grad)
        self.assertEqual(a_CPU.grad, a_MPS.grad)

    helper(3, 5, 7, [0, 1, 2])
    helper(3, 6, 7, [0, 1, 2])  # verify if changes in shape would cause cached graph lookup problems
    helper(3, 5, 7, 2)  # test scalar index
# Test pytorch gather
def test_gather(self):
    # torch.gather forward and backward on MPS must match the CPU.
    def helper(shape, dim, idx_shape, idx_dtype=torch.int64):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        random_indices = np.random.randint(0, shape[dim], idx_shape)
        cpu_idx = torch.tensor(random_indices, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        gathered_mps = torch.gather(x, dim=dim, index=idx)
        gathered_cpu = torch.gather(cpu_x, dim=dim, index=cpu_idx)

        # The gather output has the shape of the index tensor
        cpu_grad = torch.randn(idx_shape, device='cpu', dtype=torch.float)
        mps_grad = cpu_grad.to('mps')
        gathered_mps.backward(gradient=mps_grad)
        gathered_cpu.backward(gradient=cpu_grad)

        self.assertEqual(gathered_mps, gathered_cpu)
        self.assertEqual(cpu_x.grad, x.grad)

    # (shape, dim, idx_shape)
    cases = [
        ((6, 3, 3), 0, (3, 3, 3)),
        ((2, 3, 3, 3), 0, (10, 3, 3, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5)),
        ((2, 8, 4, 5), 0, (10, 6, 3, 2)),
        ((8, 8, 4, 5), 0, (6, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (6, 7, 2, 3)),
        ((2, 8, 4, 5), 1, (2, 5, 3, 4)),
        ((2, 8, 4, 5), 2, (1, 8, 10, 3)),
        ((2, 8, 4, 5), 3, (2, 5, 3, 12)),
    ]
    for shape, dim, idx_shape in cases:
        helper(shape, dim, idx_shape)
# Test pytorch gather for scalar input
def test_gather_scalar(self):
    # torch.gather on a 0-dim input with a one-element index: forward and
    # backward on MPS must match the CPU.
    cpu_x = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
    x = cpu_x.detach().clone().to('mps').requires_grad_()

    cpu_idx = torch.tensor([0], device='cpu', dtype=torch.int64)
    idx = cpu_idx.detach().clone().to('mps')

    gathered_mps = torch.gather(x, dim=0, index=idx)
    gathered_cpu = torch.gather(cpu_x, dim=0, index=cpu_idx)

    cpu_grad = torch.randn([1], device='cpu', dtype=torch.float)
    mps_grad = cpu_grad.to('mps')
    gathered_mps.backward(gradient=mps_grad)
    gathered_cpu.backward(gradient=cpu_grad)

    self.assertEqual(gathered_mps, gathered_cpu)
    self.assertEqual(cpu_x.grad, x.grad)
# Test pytorch scatter_add and scatter
def test_scatter_add(self):
    # torch.scatter_add (do_add=True) and torch.scatter (do_add=False) on MPS
    # are compared with the CPU; gradients are compared only when the index
    # and source shapes coincide.
    def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
        src = cpu_src.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        if do_add:
            idx_np = np.random.randint(0, shape[dim], idx_shape)
        else:
            # Fixed 5x3 index table: row r holds r, r + 1, r + 2
            idx_np = np.array([[row, row + 1, row + 2] for row in range(5)])

        cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        scatter_op = torch.scatter_add if do_add else torch.scatter
        scatter_result = scatter_op(x, dim=dim, index=idx, src=src)
        scatter_result_cpu = scatter_op(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)

        check_grads = idx_shape == src_shape
        if check_grads:
            cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_grad = cpu_grad.to('mps')
            scatter_result.backward(gradient=mps_grad)
            scatter_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(scatter_result, scatter_result_cpu)
        if check_grads:
            self.assertEqual(cpu_x.grad, x.grad)
            self.assertEqual(cpu_src.grad, src.grad)

    # (shape, dim, idx_shape, src_shape)
    add_cases = [
        ((2, 3), 0, (5, 3), (5, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2)),
        ((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3)),
        ((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3)),
        ((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8)),
        ((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6)),
        ((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6)),
    ]
    for case in add_cases:
        helper(*case)

    # Test scatter src
    for case in [((8, 3), 0, (5, 3), (5, 3)), ((10, 3), 0, (5, 3), (5, 8))]:
        helper(*case, do_add=False)
# Test pytorch scatter_add and scatter for scalar input
def test_scatter_add_scalar(self):
    # scatter_add / scatter on 0-dim input and source: forward and backward
    # on MPS must match the CPU.
    def helper(idx_dtype=torch.int64, do_add=True):
        cpu_x = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_src = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
        src = cpu_src.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        cpu_idx = torch.tensor([0], device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        scatter_op = torch.scatter_add if do_add else torch.scatter
        scatter_result = scatter_op(x, dim=0, index=idx, src=src)
        scatter_result_cpu = scatter_op(cpu_x, dim=0, index=cpu_idx, src=cpu_src)

        cpu_grad = torch.tensor(1.2, device='cpu', dtype=torch.float)
        mps_grad = cpu_grad.to('mps')
        scatter_result.backward(gradient=mps_grad)
        scatter_result_cpu.backward(gradient=cpu_grad)

        self.assertEqual(scatter_result, scatter_result_cpu)
        self.assertEqual(cpu_x.grad, x.grad)
        self.assertEqual(cpu_src.grad, src.grad)

    helper()
    helper(do_add=False)
# Test pytorch scatter_reduce
def test_scatter_reduce(self):
    # torch.scatter with the "add" and "multiply" reductions on MPS must
    # match the CPU (forward only).
    def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
        src = cpu_src.detach().clone().to('mps').requires_grad_()

        # Indices should be taken from range of axis along which gathering is done
        random_indices = np.random.randint(0, shape[dim], idx_shape)
        cpu_idx = torch.tensor(random_indices, device='cpu', dtype=idx_dtype)
        idx = cpu_idx.detach().clone().to('mps')

        scatter_result = torch.scatter(x, dim=dim, index=idx, src=src, reduce=reduce_str)
        scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src, reduce=reduce_str)

        self.assertEqual(scatter_result, scatter_result_cpu)

    # (shape, dim, idx_shape, src_shape)
    cases = [
        ((2, 3), 0, (5, 3), (5, 3)),
        ((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)),
        ((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2)),
        ((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5)),
        ((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2)),
        ((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3)),
        ((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3)),
        ((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8)),
        ((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6)),
        ((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6)),
    ]
    # Only these two reductions are exercised here; "sum", "prod", "amax"
    # and "amin" are not.
    for reduce_type in ["add", "multiply"]:
        for shape, dim, idx_shape, src_shape in cases:
            helper(shape, dim, idx_shape, src_shape, reduce_str=reduce_type)
def test_is_nonzero(self):
    # torch.is_nonzero on single-element MPS tensors of float, bool and int type.
    cases = [
        ([0.], False),
        ([1.5], True),
        ([False], False),
        ([3], True),
    ]
    for values, expected in cases:
        check = self.assertTrue if expected else self.assertFalse
        check(torch.is_nonzero(torch.tensor(values).to('mps')))
# Test triu
def test_triu(self):
    # torch.triu forward and backward on MPS must match the CPU for several
    # diagonals, and a transposed input must keep the CPU strides.
    def helper(shape, diag=0):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        out_mps = torch.triu(x, diag)
        out_cpu = torch.triu(cpu_x, diag)

        cpu_grad = torch.randn(out_cpu.shape)
        mps_grad = cpu_grad.to('mps')
        out_mps.backward(gradient=mps_grad)
        out_cpu.backward(gradient=cpu_grad)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    for diag in (0, 1, 2, 3, -1, -2, -3):
        helper((2, 8, 4, 5), diag=diag)

    # Test a transposed input: values and strides must both match the CPU
    def transposed_triu(device):
        return torch.arange(9.0, device=device).reshape(3, 3).t().triu()

    x_mps = transposed_triu('mps')
    x_cpu = transposed_triu('cpu')
    self.assertEqual(x_cpu, x_mps)
    self.assertEqual(x_cpu.stride(), x_mps.stride())
# Test inverse
def test_inverse(self):
    # torch.linalg.inv of random square matrices on MPS must match the CPU
    # within the given tolerances.
    def helper(n, atol=1e-5, rtol=1e-6):
        matrix_cpu = torch.randn(n, n, device='cpu')
        matrix_mps = matrix_cpu.to('mps')

        inverse_cpu = torch.linalg.inv(matrix_cpu)
        inverse_mps = torch.linalg.inv(matrix_mps)

        self.assertEqual(inverse_cpu, inverse_mps, atol=atol, rtol=rtol)

    for size in (2, 6, 3, 8):
        helper(size)
    # The large matrix is compared with a looser absolute tolerance
    helper(1025, atol=1e-4)
# Test tril
def test_tril(self):
    # torch.tril forward and backward on MPS must match the CPU, including
    # inputs filled with inf, -inf and nan.
    def helper(shape, diag=0):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        out_mps = torch.tril(x, diag)
        out_cpu = torch.tril(cpu_x, diag)

        cpu_grad = torch.randn(out_cpu.shape)
        mps_grad = cpu_grad.to('mps')
        out_mps.backward(gradient=mps_grad)
        out_cpu.backward(gradient=cpu_grad)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    for diag in [0, 1, 2, 3, -1, -2, -3]:
        helper((2, 8, 4, 5), diag=diag)

    def helper_nans_infs(value, diag_vals=(0, 1, -2)):
        """For nans and infs"""
        filled_shape = (2, 2, 5, 5)
        mps_tensor = torch.full(filled_shape, value, device="mps")
        cpu_tensor = torch.full(filled_shape, value, device="cpu")
        for diag in diag_vals:
            mps_result = torch.tril(mps_tensor, diagonal=diag)
            cpu_result = torch.tril(cpu_tensor, diagonal=diag)
            self.assertEqual(mps_result, cpu_result, f"Mismatch for diag={diag}")

    for special in ("inf", "-inf", "nan"):
        helper_nans_infs(float(special))
# test eye
def test_eye(self):
    # torch.eye on MPS must match the CPU for square, rectangular and empty
    # shapes across all listed dtypes.
    def helper(n, m, dtype):
        if n == m:
            # Square case: also exercise the single-size overload
            cpu_result = torch.eye(n, dtype=dtype, device='cpu')
            result = torch.eye(n, dtype=dtype, device='mps')
            self.assertEqual(result, cpu_result)

        # The dtype is now forwarded here as well; previously rectangular
        # shapes were always created with the default dtype, so the dtype
        # loop below had no effect on them.
        cpu_result = torch.eye(n, m, dtype=dtype, device='cpu')
        result = torch.eye(n, m, dtype=dtype, device='mps')
        self.assertEqual(result, cpu_result)

    for dtype in [torch.bool, torch.float16, torch.float32, torch.uint8, torch.int16, torch.int32, torch.int64]:
        helper(2, 2, dtype)
        helper(2, 3, dtype)
        helper(0, 2, dtype)
        helper(0, 0, dtype)
        helper(3, 8, dtype)
        helper(8, 3, dtype)
# Test diag
def test_diag(self):
    # torch.diag on MPS must match the CPU for 1-D and 2-D inputs and a
    # range of diagonals. Only the forward result is compared; the backward
    # pass is not run by this test.
    def helper(shape, diag=0):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        out_mps = torch.diag(x, diag)
        out_cpu = torch.diag(cpu_x, diag)

        self.assertEqual(out_mps, out_cpu)

    shapes = [(5, 5), (5, 6), (6, 5), (5,), (6,)]
    diagonals = [0, 1, 2, 3, 4, -1, -2, -3, -4]
    for shape, diag in itertools.product(shapes, diagonals):
        helper(shape, diag=diag)
# Test linspace
def test_linspace(self):
    # torch.linspace on MPS is compared with numpy's linspace cast to the
    # dtype under test.
    def helper(start, end, steps, dtype=torch.float32):
        expected = torch.tensor(np.linspace(start, end, steps), dtype=dtype)
        actual = torch.linspace(start, end, steps, dtype=dtype, device='mps')
        self.assertEqual(expected, actual)

    # (start, end, steps): increasing, constant, decreasing and empty ranges
    ranges = [(2, 5, 10), (2, 2, 10), (5, 2, 10), (2, 2, 0)]
    for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]:
        for start, end, steps in ranges:
            helper(start, end, steps, dtype)
# Test arange
def test_arange(self):
    # torch.arange on MPS is compared with numpy's arange and, for
    # bfloat16, with torch.arange on the CPU.
    # (arange arguments, numpy dtype of the expected array)
    cases = [
        ((10,), None),
        ((7, 1, -1), None),
        ((1, 2, .3), np.float32),
        ((6.3,), np.float32),
    ]
    for args, np_dtype in cases:
        self.assertEqual(np.arange(*args, dtype=np_dtype), torch.arange(*args, device='mps'))

    bf16_mps = torch.arange(1.2, 10.3, device='mps', dtype=torch.bfloat16)
    bf16_cpu = torch.arange(1.2, 10.3, device='cpu', dtype=torch.bfloat16)
    self.assertEqual(bf16_mps, bf16_cpu)
def test_arange_empty(self):
    # An empty arange written into an existing empty `out` tensor must give
    # the same result on MPS and CPU.
    empty_mps = torch.tensor([], device="mps")
    empty_cpu = torch.tensor([], device="cpu")

    result_mps = torch.arange(0, 0, 1, out=empty_mps)
    result_cpu = torch.arange(0, 0, 1, out=empty_cpu)

    self.assertEqual(result_mps, result_cpu)
# Test range
def test_range(self):
    # torch.range includes the end point, unlike numpy's arange, so the
    # expected arrays below are built to contain the last value explicitly.
    self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
    self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
    self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
    # Non-integral end point: the values 0..6 are expected. This assertion
    # previously called torch.arange, which test_arange already covers.
    self.assertEqual(np.arange(6.3, dtype=np.float32), torch.range(0, 6.3, device='mps'))
# Test softmax
def test_softmax(self):
    # Softmax forward and backward on MPS must match the CPU for 4-D and
    # 5-D inputs along several dims, and for a 0-dim input.
    def helper(shape, dim, channels_last=False):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
        if channels_last:
            cpu_x = cpu_x.to(memory_format=torch.channels_last)
            cpu_x.retain_grad()
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        out_mps = F.softmax(x, dim=dim)
        out_cpu = F.softmax(cpu_x, dim=dim)

        # Currently NOT testing backward for channels last backward
        check_backward = not channels_last
        if check_backward:
            cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_grad = cpu_grad.to('mps')
            out_mps.backward(gradient=mps_grad)
            out_cpu.backward(gradient=cpu_grad)

        self.assertEqual(out_mps, out_cpu)
        if check_backward:
            self.assertEqual(x.grad, cpu_x.grad)

    def helper2(dim):
        # Scalar (0-dim) input
        cpu_x = torch.tensor(1.23, device='cpu', dtype=torch.float, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        out_mps = F.softmax(x, dim=dim)
        out_cpu = F.softmax(cpu_x, dim=dim)

        cpu_grad = torch.tensor(2.34, device='cpu', dtype=torch.float)
        mps_grad = cpu_grad.to('mps')
        out_mps.backward(gradient=mps_grad)
        out_cpu.backward(gradient=cpu_grad)

        self.assertEqual(out_mps, out_cpu)
        self.assertEqual(x.grad, cpu_x.grad)

    helper2(0)

    shapes = [(2, 4, 8, 5), (3, 4, 6, 7, 2)]
    dims = [0, 1, 2, 3, -1, -2, -3]
    for channels_last, shape in itertools.product([False], shapes):
        # The channels_last case is only applied to 4-D shapes
        if channels_last and len(shape) != 4:
            continue
        for dim in dims:
            helper(shape, dim, channels_last)
def test_nan_to_num(self):
    # nan, +inf and -inf must be replaced by the requested values on MPS
    # exactly as on the CPU; a finite value is included alongside them.
    replacements = dict(nan=2.0, posinf=1.0, neginf=-1.0)

    cpu_in = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
    mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

    cpu_out = torch.nan_to_num(cpu_in, **replacements)
    mps_out = torch.nan_to_num(mps_in, **replacements)

    self.assertEqual(mps_out, cpu_out)
# Test where
def test_where(self):
    # torch.where forward and backward on MPS must match the CPU for equal
    # and broadcast shapes; the `out=` variant must resize its output.
    def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):
        cpu_cond = torch.randint(2, shape, device='cpu', dtype=cond_dtype, requires_grad=False)
        cond = cpu_cond.detach().clone().to('mps')

        cpu_x = torch.randn(x_shape, device='cpu', dtype=x_dtype, requires_grad=True)
        x = cpu_x.detach().clone().to('mps').requires_grad_()

        cpu_y = torch.randn(y_shape, device='cpu', dtype=x_dtype, requires_grad=True)
        y = cpu_y.detach().clone().to('mps').requires_grad_()

        cpu_out = torch.where(cpu_cond, cpu_x, cpu_y)
        out = torch.where(cond, x, y)

        cpu_grad = torch.randn(cpu_out.shape)
        mps_grad = cpu_grad.to('mps')
        cpu_out.backward(gradient=cpu_grad)
        out.backward(gradient=mps_grad)

        self.assertEqual(out, cpu_out)
        self.assertEqual(x.grad, cpu_x.grad)
        self.assertEqual(y.grad, cpu_y.grad)

    # Condition and both operands share one shape
    for shape in ([(0, 3), [], (2, 3), (9,)]):
        helper(shape, shape, shape)

    # (cond shape, x shape, y shape) for the broadcasting cases
    broadcast_cases = [
        ((2, 3, 1), (2, 3, 4), (2, 1, 4)),
        ((2, 1, 1), (2, 3, 4), (1, 3, 4)),
        ((1, 1, 1), (1, 1, 4), (2, 3, 1)),
        ([], (1, 1, 4), (2, 3, 1)),
        ([], (2, 3, 4), []),
        ((5, 2, 3), (2, 3), (2, 3)),
        ((2, 3), (5, 2, 3), (2, 3)),
        ((2, 3), (2, 3), (5, 2, 3)),
        ((2, 3), (5, 2, 3), (6, 5, 2, 3)),
    ]
    for cond_shape, x_shape, y_shape in broadcast_cases:
        helper(cond_shape, x_shape, y_shape)

    # Test that output is correctly resized
    # TODO: Remove me when out OpInfo testing is enabled on MPS
    output = torch.tensor(0.0, device="mps")
    cond = torch.randint(2, (3, 3), dtype=torch.bool, device="mps")
    inp = torch.rand(3, 3, device="mps")
    other = torch.rand(3, 3, device="mps")
    out = torch.where(cond, inp, other, out=output)
    # The returned tensor must be the very object passed as `out`
    self.assertEqual(id(out), id(output))
    self.assertEqual(out.shape, (3, 3))
# Test normal
def test_normal(self):
    # Exercises the torch.normal overloads on MPS (scalar/tensor mean and
    # std, with and without `out`). Only output sizes are checked, not the
    # sampled values.
    def helper(shape, mean=0.0, std=1.0, dtype=torch.float):
        # Scalar mean and std with an explicit size
        mps_out = torch.normal(mean, std, shape, device='mps', dtype=dtype)

        # Constant-filled mean and std tensors of the requested shape
        mean_array = np.ones(shape) * mean
        cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=dtype, requires_grad=False)
        mean_tensor = cpu_mean_tensor.detach().clone().to('mps')

        std_array = np.ones(shape) * std
        cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=dtype, requires_grad=False)
        std_tensor = cpu_std_tensor.detach().clone().to('mps')

        # test out
        for mean_arg, std_arg in [(mean_tensor, std), (mean, std_tensor), (mean_tensor, std_tensor)]:
            mps_out = torch.zeros(shape, device='mps', dtype=dtype)
            torch.normal(mean_arg, std_arg, out=mps_out)

        # test without out
        mps_out = torch.normal(mean_tensor, std)
        self.assertEqual(mps_out.size(), mean_tensor.size())

        mps_out = torch.normal(mean, std_tensor)
        self.assertEqual(mps_out.size(), std_tensor.size())

        inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
        mps_out = torch.normal(mean_tensor, std_tensor)
        self.assertEqual(mps_out.size(), inferred_shape)

    helper((2, 3, 4, 5, 6))
    helper((100, 100), 2.5, 1.2)
    helper((10, 10), 2.5, 1.2, dtype=torch.bfloat16)

    # Test invalid inputs
    with self.assertRaises(TypeError):
        helper((10, 10), 10, 11, dtype=torch.int32)
def test_bernoulli(self):
    """Check ``torch.bernoulli`` on MPS for p = 0.5, 0 and 1, plus in-place sampling on other dtypes."""
    dims = (10, 10)
    ones = torch.ones(dims, device='mps')
    zeros = torch.zeros(dims, device='mps')

    # p = 0.5: mean and variance cannot be pinned down reliably, so only
    # require that the draw is not a constant tensor.
    coin_flips = torch.bernoulli(ones * 0.5)
    self.assertNotEqual(coin_flips.to('cpu').mean(), 0.)
    self.assertNotEqual(coin_flips.to('cpu').std() ** 2, 0.)

    # p = 0 must yield all zeros
    self.assertEqual(torch.bernoulli(zeros), zeros)

    # p = 1 must yield all ones
    self.assertEqual(torch.bernoulli(ones), ones)

    # In-place variant on other dtypes: both 0 and 1 must appear in the output.
    for dtype in [torch.float16, torch.int8, torch.int16, torch.int32, torch.int64]:
        drawn = torch.zeros(dims, device='mps', dtype=dtype).bernoulli(0.5)
        distinct = drawn.unique()
        self.assertEqual(distinct, torch.arange(2, device='mps', dtype=dtype))
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_dropout(self, dtype):
    """Validate forward and backward of ``torch.native_dropout`` on MPS.

    Covers several shapes, drop probabilities (including 0 and 1) and both
    training and evaluation mode.
    """
    def check(shape, p, train):
        x = torch.randn(shape, device='mps', dtype=dtype, requires_grad=True)
        dropped, keep_mask = torch.native_dropout(x, p, train=train)
        observed_p = 1 - (keep_mask.sum() / keep_mask.numel())

        if train:
            # Empirical drop rate is close to p; dropped entries are zero and
            # kept entries are rescaled by 1 / (1 - p).
            self.assertEqual(observed_p, p, atol=1e-2, rtol=1e-2)
            self.assertTrue((dropped[keep_mask.logical_not()] == 0).all())
            self.assertEqual(dropped[keep_mask], x[keep_mask] / (1 - p))
        else:
            # Evaluation mode is the identity with an all-true mask.
            self.assertEqual(dropped, x)
            self.assertTrue(keep_mask.all())

        upstream_grad = torch.randn_like(dropped)
        dropped.backward(upstream_grad)

        if train:
            scale = 0 if p == 1 else 1 / (1 - p)
            expected_grad = upstream_grad * keep_mask * scale
        else:
            expected_grad = upstream_grad
        self.assertEqual(x.grad, expected_grad)

    shapes = [
        (100_000,),
        (100, 1000),
        (10, 100, 100),
        (10, 10, 10, 10, 10),
    ]
    drop_probabilities = [0, 0.34, 0.78, 1]
    for case in itertools.product(shapes, drop_probabilities, [False, True]):
        check(*case)
def test_mps_generator(self):
    """Check seeding and state save/restore on an explicitly created MPS ``torch.Generator``."""
    gen = torch.Generator(device='mps')

    # Two draws made right after seeding with the same value must agree.
    gen.manual_seed(999)
    first_draw = torch.randn(5, device='mps', generator=gen)
    gen.manual_seed(999)
    second_draw = torch.randn(5, device='mps', generator=gen)
    self.assertEqual(first_draw, second_draw)

    # Snapshot the state, then advance the generator by one more draw.
    saved_state = gen.get_state()
    reference = torch.randn(5, device='mps', generator=gen)
    # The generator has moved on, so this draw must differ from the previous one.
    self.assertNotEqual(reference, second_draw)

    # Rewinding to the snapshot must reproduce the reference draw.
    gen.set_state(saved_state)
    replayed = torch.randn(5, device='mps', generator=gen)
    self.assertEqual(replayed, reference)
@serialTest()
def test_default_mps_generator(self):
    """Check seeding, offset tracking and state restore of the default MPS generator."""
    # Seed through the global entry point ...
    torch.manual_seed(230)
    first_draw = torch.randn(5, device='mps')

    # ... and through the MPS-specific one; the same seed must give the same draw.
    torch.mps.manual_seed(230)
    second_draw = torch.randn(5, device='mps')
    self.assertEqual(first_draw, second_draw)

    # Snapshot the default generator's state, then advance it by one more draw.
    saved_state = torch.mps.get_rng_state()
    reference = torch.randn(5, device='mps')
    self.assertNotEqual(reference, second_draw)

    # Two randn calls have happened since the last seeding, so the offset is 2.
    default_generator = torch.mps._get_default_mps_generator()
    self.assertEqual(default_generator.get_offset(), 2)

    # Restoring the snapshot must reproduce the reference draw.
    torch.mps.set_rng_state(saved_state)
    replayed = torch.randn(5, device='mps')
    self.assertEqual(replayed, reference)
def test_device_synchronize(self):
    """Interleave ``torch.mps.synchronize()`` with allocation, forward and backward work.

    There are no value assertions; the test passes if every call returns.
    """
    deconv = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
    deconv = deconv.to(device='mps', dtype=torch.float)

    features = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
    torch.mps.synchronize()

    upsampled = deconv(features)
    torch.mps.synchronize()

    upsampled.backward(torch.randn_like(upsampled))
    torch.mps.synchronize()
@serialTest()
def test_mps_allocator_module(self):
    """Check that the ``torch.mps`` memory counters start at zero and grow after an allocation."""
    # Start from a clean slate: collect garbage and drop cached blocks.
    gc.collect()
    torch.mps.empty_cache()

    # With nothing live and the cache emptied, no memory may be reported as allocated.
    allocated_before = torch.mps.current_allocated_memory()
    self.assertEqual(allocated_before, 0)

    # Total reported by the Metal driver before the allocation.
    driver_before = torch.mps.driver_allocated_memory()

    # 8M float32 elements; per the original comment this is intended to
    # force allocation of a new Metal heap.
    x = torch.ones(1024 * 1024 * 8, device="mps")

    allocated_after = torch.mps.current_allocated_memory()
    driver_after = torch.mps.driver_allocated_memory()

    # Both counters must have increased.
    self.assertGreater(allocated_after, allocated_before)
    self.assertGreater(driver_after, driver_before)
def test_mps_allocator_stats(self):
    """Check that ``torch.mps.recommended_max_memory()`` reports a positive value."""
    max_memory = torch.mps.recommended_max_memory()
    # Printed in GiB for whoever reads the test log.
    print(f"Recommended Max Memory : {max_memory / 1024 ** 3} GB")
    self.assertGreater(max_memory, 0)
# to verify this test, run XCode Instruments "Metal System Trace" or "Logging" tool,
# press record, then run this python test, and press stop. Next expand
# the os_signposts->PyTorchMPS and check if events or intervals are logged
# like this example:
# "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
def test_mps_profiler_module(self):
    """Drive the MPS profiler through its context-manager and start/stop entry points.

    There are no assertions; the test passes if profiling runs without raising.
    """
    def fresh_input():
        return torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)

    # "event" mode via the context manager
    with torch.mps.profiler.profile(mode="event", wait_until_completed=False):
        # Some work for the profiler to capture
        deconv = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        deconv = deconv.to(device='mps', dtype=torch.float)
        result = deconv(fresh_input())

    # "interval" mode via explicit start/stop
    torch.mps.profiler.start(mode="interval", wait_until_completed=True)
    result = deconv(fresh_input())
    torch.mps.profiler.stop()
def test_mps_event_module(self):
    """Time a transposed convolution with two ``torch.mps.Event`` objects and expect a positive duration."""
    start = torch.mps.Event(enable_timing=True)
    start.record()

    # Work to be timed
    deconv = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
    deconv = deconv.to(device='mps', dtype=torch.float)
    features = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
    features = deconv(features)

    end = torch.mps.Event(enable_timing=True)
    end.record()

    self.assertGreater(start.elapsed_time(end), 0.0)
def test_generic_event(self):
    """Same timing check as the MPS event test, using the device-generic ``torch.Event`` API."""
    start = torch.Event('mps', enable_timing=True)
    start.record()

    # Work to be timed
    deconv = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
    deconv = deconv.to(device='mps', dtype=torch.float)
    features = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
    features = deconv(features)

    end = torch.Event('mps', enable_timing=True)
    end.record()

    self.assertGreater(start.elapsed_time(end), 0.0)
def test_generic_device_synchronize(self):
    """Check ``torch.Event`` synchronize/query and ``torch.accelerator.synchronize()`` around async copies."""
    event = torch.Event('mps')

    lhs = torch.randn(1000)
    rhs = torch.randn(1000)
    expected = lhs + rhs

    # Non-blocking host-to-device copies, fenced by the event before use.
    lhs_dev = lhs.to("mps", non_blocking=True)
    rhs_dev = rhs.to("mps", non_blocking=True)
    event.record()
    event.synchronize()

    actual = lhs_dev + rhs_dev
    event.record()
    torch.accelerator.synchronize()

    # After the accelerator-wide synchronize the event must report completion.
    self.assertTrue(event.query())
    self.assertEqual(actual.cpu(), expected)
def test_jit_save_load(self):
    """Round-trip a scripted module holding an MPS tensor attribute through an in-memory buffer."""
    module = torch.nn.Module()
    module.x = torch.rand(3, 3, device='mps')

    stream = io.BytesIO()
    scripted = torch.jit.script(module)
    torch.jit.save(scripted, stream)

    # Rewind so the load starts from the beginning of the buffer.
    stream.seek(0)
    restored = torch.jit.load(stream)
    self.assertEqual(restored.x, module.x)
# Test random_, random_.to and random_.from
def test_random(self):
    """Smoke-test ``torch.randint`` and ``Tensor.random_`` on MPS across dtypes."""
    def helper(shape, low, high, dtype=torch.int32):
        sample = torch.randint(low, high, shape, dtype=dtype, device='mps')
        as_float = sample.float()
        # Mean and std cannot be checked reliably; only require that the
        # values are not constant.
        self.assertNotEqual(as_float.mean().item(), 0.)
        self.assertNotEqual(sample.float().std().item(), 0.)

    dims = [100, 100]
    helper(dims, 0, 10)
    helper(dims, 23, 89)
    helper(dims, 23, 89, dtype=torch.float32)
    helper(dims, 23, 89, dtype=torch.int64)
    helper(dims, 0, 2, dtype=torch.bool)

    # In-place random_ with no explicit bounds must produce a non-zero maximum.
    random_dtypes = [torch.bool, torch.int8, torch.uint8, torch.int32, torch.float16, torch.float32]
    for dtype in random_dtypes:
        filled = torch.empty(10, 10, dtype=dtype, device='mps')
        filled.random_()
        self.assertNotEqual(filled.max().item(), 0)
def test_random_5d(self):
    """Check that 5-D random tensors do not repeat along the leading dimension.

    See https://github.com/pytorch/pytorch/issues/147624 / FB16550905
    """
    dims = (2, 3, 4, 5, 6)

    uniform = torch.rand(dims, device="mps")
    self.assertNotEqual(uniform[0], uniform[1])

    # The normal distribution must not be affected by the same problem.
    zero_mean = torch.zeros(dims, device="mps")
    unit_std = torch.ones(dims, device="mps")
    gaussian = torch.normal(zero_mean, unit_std)
    self.assertNotEqual(gaussian[0], gaussian[1])
# Test exponential
@unittest.skip("This does not test anything")
def test_exponential(self):
    """Print sample vs. theoretical mean/variance of ``exponential_`` on MPS (skipped: asserts nothing)."""
    def helper(shape, lambda_, dtype=torch.float32):
        sample = torch.zeros(shape, device='mps', dtype=dtype)
        sample.exponential_(lambda_)
        # An exponential with rate lambda has mean 1/lambda and variance 1/lambda^2.
        print(sample.to('cpu').float().mean(), 1 / lambda_)
        print(sample.to('cpu').float().std() ** 2, 1 / (lambda_**2))

    for dtype in [torch.float32, torch.float16]:
        for rate in (2, 1, 3, 0.5):
            helper([100, 100], rate, dtype)
def test_exponential_1(self):
rate = torch.randn(5, 5).abs().requires_grad_()
rate_1d = torch.randn(1).abs().requires_grad_()
self.assertEqual(Exponential(rate).sample().size(), (5, 5))
self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_exponential_nonzero(self, dtype):
    """Check that ``exponential_`` on MPS never produces an exact zero, over 100 draws of 32,000 values."""
    for _ in range(100):
        sample = torch.empty(32_000, device="mps", dtype=dtype)
        sample = sample.exponential_()
        self.assertTrue((sample != 0).all())
# Test add and sub
def test_add_sub(self):
    """Compare add/sub (in-place and out-of-place, with ``alpha``) on MPS against CPU."""
    def helper(shape, alpha, op_name, inplace):
        if op_name == "add":
            functional, in_place_method = torch.add, torch.Tensor.add_
        elif op_name == "sub":
            functional, in_place_method = torch.sub, torch.Tensor.sub_
        op = in_place_method if inplace else functional

        def scalar_pair(dtype):
            # A fresh 0-dim tensor on CPU plus an independent MPS copy. It is
            # rebuilt before each use because in-place ops mutate their first operand.
            on_cpu = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
            return on_cpu, on_cpu.detach().clone().to('mps')

        for dtype in [torch.float16, torch.float32]:
            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            mps_x = cpu_x.detach().clone().to('mps')
            cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
            mps_y = cpu_y.detach().clone().to('mps')

            cpu_out = op(cpu_x, cpu_y, alpha=alpha)
            mps_out = op(mps_x, mps_y, alpha=alpha)
            # fp16 isn't accurate when alpha is passed
            # TODO: remove or fix 'tol' when we fix problems with fp16
            tol = 2e-3 if dtype is torch.float16 else None
            self.assertEqual(mps_out, cpu_out, rtol=tol, atol=tol)

            # Scalar as the primary operand. Skipped for in-place ops on a
            # non-scalar `y`, because an in-place output cannot be broadcast.
            if cpu_y.shape == () or not inplace:
                cpu_s, mps_s = scalar_pair(dtype)
                self.assertEqual(op(cpu_s, cpu_y), op(mps_s, mps_y))

            # Scalar as the secondary operand.
            cpu_s, mps_s = scalar_pair(dtype)
            self.assertEqual(op(cpu_x, cpu_s), op(mps_x, mps_s), rtol=tol, atol=tol)

    cases = (
        ((), 0.0),
        ((2, 8, 4, 5), 0.0),
        ((2, 8, 4, 5), 0.1),
        ((2, 8, 4, 5), 1.0),
        ((2, 8, 3, 5), 0.1),
        ((2, 8, 3, 5), 0.2),
    )
    for op_name, inplace in product(["add", "sub"], [True, False]):
        for shape, alpha in cases:
            helper(shape, alpha, op_name, inplace)

    # float32 tensor with an integer alpha
    # See https://github.com/pytorch/pytorch/issues/143932
    x = torch.rand(32, device='mps', dtype=torch.float32)
    y = torch.arange(32, device='mps', dtype=torch.int32)
    self.assertEqual(torch.add(x, y, alpha=2).cpu(), torch.add(x.cpu(), y.cpu(), alpha=2))
    self.assertEqual(torch.add(x, 3, alpha=2).cpu(), torch.add(x.cpu(), 3, alpha=2))
    # Regression test for https://github.com/pytorch/pytorch/issues/160208
    self.assertEqual(torch.add(y, x, alpha=2).cpu(), torch.add(y.cpu(), x.cpu(), alpha=2))
# Test add on scalar (0-dim) tensors
def test_add_scalars(self):
    """Compare ``torch.add`` on 0-dim tensors (MPS vs CPU) and check int32 tensor + Python scalar."""
    def helper(alpha):
        for dtype in [torch.float16, torch.float32]:
            cpu_lhs = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
            mps_lhs = cpu_lhs.detach().clone().to('mps')
            cpu_rhs = torch.tensor(3.4, device='cpu', dtype=dtype, requires_grad=False)
            mps_rhs = cpu_rhs.detach().clone().to('mps')

            expected = torch.add(cpu_lhs, cpu_rhs, alpha=alpha)
            actual = torch.add(mps_lhs, mps_rhs, alpha=alpha)
            # fp16 isn't accurate when alpha is passed
            tol = 1e-3 if dtype is torch.float16 else None
            self.assertEqual(actual, expected, rtol=tol, atol=tol)

    for alpha in (1.0, 0.0, 0.1, 0.2):
        helper(alpha)

    # int32 tensor + Python scalar
    # see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534
    x = torch.ones(4, dtype=torch.int32, device='mps')
    int_expected = torch.full((4,), 2, dtype=torch.int32, device='mps')
    self.assertEqual(x + 1, int_expected)
    float_expected = torch.full((4,), 2.5, device='mps')
    self.assertTrue(torch.equal(x + 1.5, float_expected))
def test_types_binary_op(self):
    """Check mixed-dtype multiplication (float32 by bool, float32 by int64) on MPS against CPU."""
    def product_on(device, factors):
        ramp = torch.arange(5, dtype=torch.float32, device=device)
        return ramp * torch.tensor(factors, device=device)

    # Float * Bool
    bool_factors = [True, False, True, False, True]
    cpu_x = product_on("cpu", bool_factors)
    mps_x = product_on("mps", bool_factors)
    self.assertEqual(cpu_x, mps_x)

    # Float * Int64
    int_factors = [1, 0, 1, 0, 1]
    cpu_y = product_on("cpu", int_factors)
    mps_y = product_on("mps", int_factors)
    self.assertEqual(cpu_y, mps_y)
def test_unary_ops(self):
    """Compare unary ops on MPS against CPU for contiguous, strided-slice and reshaped-view inputs."""
    def helper(shape, op):
        def float_pair():
            # Random float32 CPU tensor plus an independent MPS copy.
            on_cpu = torch.randn(shape, device='cpu', dtype=torch.float32, requires_grad=False)
            return on_cpu, on_cpu.detach().clone().to('mps')

        # contiguous float input
        cpu_x, mps_x = float_pair()
        self.assertEqual(op(cpu_x), op(mps_x))

        # integer inputs
        for int_dtype in [torch.int32, torch.int16]:
            cpu_x = torch.randint(0, 1000, shape, device='cpu', dtype=int_dtype, requires_grad=False)
            mps_x = cpu_x.to('mps')
            self.assertEqual(op(cpu_x), op(mps_x), rtol=1e-4, atol=1e-4)

        # strided slice along dim 1
        cpu_x, mps_x = float_pair()
        self.assertEqual(op(cpu_x[:, ::2, :, :]), op(mps_x[:, ::2, :, :]))

        # view that merges the last two dimensions
        cpu_x, mps_x = float_pair()
        merged_dims = [*shape[:-2], shape[-1] * shape[-2]]
        self.assertEqual(op(cpu_x.view(*merged_dims)), op(mps_x.view(*merged_dims)))

    helper((2, 8, 4, 5), torch.exp)
    for op in (torch.exp2, torch.expm1, torch.log, torch.cos, torch.erfinv):
        helper((2, 8, 3, 5), op)
def test_non_dense_in_storage_unary_ops(self):
    """Compare unary ops on every-other-element 1-D views, MPS vs CPU."""
    def helper(op):
        # float input
        cpu_x = torch.randn(100, device='cpu', dtype=torch.float32, requires_grad=False)
        mps_x = cpu_x.detach().clone().to('mps')
        self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]))

        # integer inputs
        for int_dtype in [torch.int32, torch.int16, torch.int8]:
            cpu_x = torch.randint(127, device='cpu', size=(100,), dtype=int_dtype, requires_grad=False)
            mps_x = cpu_x.to('mps')
            self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]), rtol=1e-4, atol=1e-4)

    for op in (torch.exp, torch.exp2, torch.expm1, torch.log, torch.cos):
        helper(op)
def test_unary_ops_storage_offset_strided(self):
    """Check unary ops on MPS with storage offsets and with non-contiguous ``out=`` tensors."""
    def helper(shape, op, inplace, dtype=torch.float32):
        # Operate on a row view, i.e. an input with a non-zero storage offset.
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')
        mps_row = op(mps_x[1])
        cpu_row = op(cpu_x[1])
        self.assertEqual(mps_row, cpu_row)

        # The remaining checks pass out=, so they only run for out-of-place ops.
        if inplace:
            return

        # Transposed output tensor
        # See https://github.com/pytorch/pytorch/issues/100764
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')
        cpu_y = torch.empty(shape, device='cpu', dtype=dtype).t()
        mps_y = cpu_y.detach().clone().to('mps')
        op(cpu_x, out=cpu_y)
        op(mps_x, out=mps_y)
        self.assertEqual(mps_y, cpu_y)

        # Non-contiguous but dense input and output with similar strides
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype).mT
        mps_x = cpu_x.to('mps')
        cpu_y = torch.empty_like(cpu_x)
        mps_y = cpu_y.to('mps')
        op(cpu_x, out=cpu_y)
        op(mps_x, out=mps_y)
        self.assertEqual(mps_y, cpu_y)

        # Sliced input and output with similar strides
        doubled = (2, shape[0] * 2, shape[1] * 2)
        mps_x, mps_y = torch.randn(doubled, device='mps', dtype=dtype).unbind(0)
        op(mps_x[::2, ::2], out=mps_y[::2, ::2])
        self.assertEqual(mps_y[::2, ::2], op(mps_x[::2, ::2].contiguous()))

    for out_of_place_op in (torch.exp, torch.cos, torch.neg, torch.tanh):
        helper((5, 5), out_of_place_op, False)
    helper((5, 5), torch.tanh_, True)
    helper((5, 5), lambda x, **kwargs: torch.round(x, decimals=2, **kwargs), False)
def test_atan2(self):
    """Compare ``torch.atan2`` on MPS against CPU for small, large 1-D and 2-D inputs."""
    def helper(shape):
        y_cpu = torch.randn(shape)
        y_mps = y_cpu.detach().clone().to("mps")
        x_cpu = torch.randn(shape)
        x_mps = x_cpu.detach().clone().to("mps")

        expected = torch.atan2(y_cpu, x_cpu)
        actual = torch.atan2(y_mps, x_mps)
        self.assertEqual(expected, actual.to("cpu"))

    for shape in (4, 10000, (10000, 40)):
        helper(shape)
@unittest.skip("This does not test anything")
def test_multinomial(self):
    """Print ``torch.multinomial`` samples or statistics on MPS (skipped: asserts nothing)."""
    # Test with num_dist = 1
    def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True):
        cpu_probs = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False)
        mps_probs = cpu_probs.detach().clone().to('mps')
        drawn = torch.multinomial(mps_probs, num_samples, replacement=replacement)
        if replacement:
            # Compare "real" with theoretical values
            print(drawn.to('cpu').float().mean(), compare_mean)
            print(drawn.to('cpu').float().std() ** 2, compare_var)
        else:
            print(drawn.to('cpu'))

    # TODO: Add tests for data types
    uniform_mean = (0 + 1 + 2 + 3 + 4) / 5
    uniform_var = (6 - 2 * 2)
    helper(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4) / 2, (12.5 - 3.5 ** 2), 100000)
    helper(np.array([[.2, .2, .2, .2, .2]]), uniform_mean, uniform_var, 10000)
    helper(np.array([[1, 1, 1, 1, 1]]), uniform_mean, uniform_var, 10000)
    helper(np.array([1, 1, 1, 1, 1]), uniform_mean, uniform_var, 10000)
    helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)
def test_non_contiguous_sampling_variation(self):
    """Check that ``torch.multinomial`` on a transposed MPS tensor yields more than one distinct value."""
    torch.manual_seed(42)
    # The transpose makes the probability tensor non-contiguous.
    probs = torch.tensor([[.25, .1], [.25, .1], [.25, .1], [.25, .7]]).T.to("mps")

    seen = set()
    for _ in range(200):
        seen.add(torch.multinomial(probs, 1).flatten()[0].item())

    # A single repeated value would indicate that sampling from the
    # non-contiguous tensor is broken.
    self.assertNotEqual(len(seen), 1)
def test_cumsum_dim_check(self):
    """Check negative-dim aliasing and out-of-range dim errors for ``cumsum`` on MPS."""
    x = torch.rand((3, 3), device="mps")

    # Negative dims alias their positive counterparts.
    for positive_dim, negative_dim in ((1, -1), (0, -2)):
        self.assertEqual(x.cumsum(positive_dim), x.cumsum(negative_dim))

    # Dims outside [-2, 1] are rejected for a 2-D tensor.
    with self.assertRaises(IndexError):
        x.cumsum(2)
    with self.assertRaises(IndexError):
        x.cumsum(-3)
def test_cumprod_dim_check(self):
    """Check negative-dim aliasing and out-of-range dim errors for ``cumprod`` on MPS."""
    x = torch.rand((3, 3), device="mps")

    # Negative dims alias their positive counterparts.
    for positive_dim, negative_dim in ((1, -1), (0, -2)):
        self.assertEqual(x.cumprod(positive_dim), x.cumprod(negative_dim))

    # Dims outside [-2, 1] are rejected for a 2-D tensor.
    with self.assertRaises(IndexError):
        x.cumprod(2)
    with self.assertRaises(IndexError):
        x.cumprod(-3)
def test_do_sync_thrice_its_all_right(self):
    """Call ``torch.mps.synchronize()`` three times in a row without deadlocking.

    Regression test for
    https://github.com/pytorch/pytorch/commit/9bc9d4cdb4355a385a7d7959f07d04d1648d6904
    which caused sync calls to deadlock.
    """
    ones = torch.ones(1024, device='mps')
    zeros = torch.zeros(1024, device='mps')
    # Every element is the float just below 1.0, so the sum stays under numel().
    x = torch.nextafter(ones, zeros)

    sync_count = 3
    for _ in range(sync_count):
        torch.mps.synchronize()

    self.assertLess(x.sum().item(), x.numel())
@parametrize("dtype", [torch.int32, torch.int64, torch.int16, torch.int8, torch.uint8])
def test_inplace_bitwise_not(self, dtype):
    """Check in-place ``bitwise_not_`` on a strided view, MPS against CPU.

    The same ``arange(64)`` tensor is built on both devices, every other
    element is inverted in place through a ``[::2]`` view, and the MPS
    result must match the CPU one.
    """
    # Start with bitwise not here (reported by @qqaatw)
    # The unpacking order follows the device list, so each name matches the
    # device its tensor lives on.
    x_cpu, x_mps = [torch.arange(64, device=device, dtype=dtype) for device in ["cpu", "mps"]]
    for x in [x_mps, x_cpu]:
        x[::2].bitwise_not_()
    self.assertEqual(x_mps.cpu(), x_cpu)
def test_empty_posneginf(self):
    """Check that ``torch.isposinf`` and ``torch.isneginf`` accept an empty MPS tensor.

    This is a crash check: each op must return an empty result instead of
    failing on zero-element input.
    """
    input_tensor = torch.empty(0, device="mps")
    out_pos = torch.isposinf(input_tensor)
    # The negative-infinity op gets its own call so that both are exercised.
    out_neg = torch.isneginf(input_tensor)
    self.assertEqual(out_pos.numel(), 0)
    self.assertEqual(out_neg.numel(), 0)
def test_empty_dot(self):
    """Check that ``dot`` of two empty MPS vectors runs and matches the CPU result."""
    lhs = torch.rand((0), device="mps")
    rhs = torch.rand((0), device="mps")

    on_mps = lhs.dot(rhs)
    on_cpu = lhs.cpu().dot(rhs.cpu())
    self.assertEqual(on_mps, on_cpu)
class TestLargeTensors(TestCaseMPS):
    """Tests on MPS tensors whose element count is at or beyond the 32-bit range.

    The two ``64bit`` tests skip themselves when the recommended maximum
    memory is below 16 GB.
    """

    @serialTest()
    def test_64bit_binops(self):
        """Broadcasted add followed by sin on a 5000 x 1024 x 1024 result, checked on a slice."""
        if torch.mps.recommended_max_memory() < 16_000_000_000:
            raise unittest.SkipTest("Needs at least 16Gb of RAM")
        a = torch.rand(1, 1024, 1024, dtype=torch.float16, device='mps')
        b = torch.rand(5000, 1, 1, dtype=torch.float16, device='mps')
        rc = (a + b).sin()
        # Only the last two batches are compared; the CPU reference is
        # computed for that slice alone.
        slice_idx = -2
        rc_slice = rc[slice_idx:]
        rc_slice_cpu = (a.cpu() + b.cpu()[slice_idx:]).sin()
        self.assertEqual(rc_slice, rc_slice_cpu)

    @serialTest()
    def test_64bit_index_select(self):
        """Index a batch out of an 11 x 20000 x 20000 half tensor and check one element."""
        if torch.mps.recommended_max_memory() < 16_000_000_000:
            raise unittest.SkipTest("Needs at least 16Gb of RAM")
        B, N = 11, 20000
        # Bound up front so the cleanup below is valid even if allocation fails.
        x = y = None
        try:
            x = torch.empty(B, N, N, dtype=torch.float16, device='mps')
            # Fill batch i with the constant i so the selected batch is identifiable.
            for i in range(B):
                x[i] = 1.0 * i
            batch_idx = torch.tensor([9], device='mps')
            y = x[batch_idx]
            self.assertEqual(y[0, 1, 2].item(), 9.0)
        finally:
            # Reclaim memory after running the test, including when it fails.
            del y
            del x
            gc.collect()
            torch.mps.empty_cache()

    @serialTest()
    def test_rand_2b_raises(self):
        """``randint`` must raise for more than INT32_MAX elements and succeed at exactly INT32_MAX."""
        int32_max = torch.iinfo(torch.int32).max
        with self.assertRaises(RuntimeError):
            # This used to crash with NDArray dimension length > INT_MAX
            x = torch.randint(0, 10, (int32_max + 1,), dtype=torch.int8, device='mps')
        x = torch.randint(0, 10, (int32_max,), dtype=torch.int8, device='mps')
        self.assertEqual(x.numel(), int32_max)
        del x
class TestLogical(TestCaseMPS):
    """Logical, min/max, ``isin`` and bit-shift ops on MPS, checked against CPU results."""

    def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
        """Build a tensor from ``x``; a thin wrapper over ``torch.tensor`` defaulting to CPU."""
        return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)

    def _check_binary_logical_op(self, op):
        """Compare the binary logical ``op`` on MPS against CPU.

        Runs the operand pairs shared by the and/or/xor tests: integer,
        float (with grad) and bool vectors, plus a vector against 0-dim
        integer and bool tensors.
        """
        wrap = self._wrap_tensor
        operand_pairs = [
            (wrap([1, 1, 0, 0]), wrap([1, 0, 0, 1])),
            (
                wrap([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
                wrap([1, 0, 0, 1], dtype=torch.float),
            ),
            (wrap([True, True, False, False]), wrap([True, False, False, True])),
            (wrap((1, 0, 1, 0)), wrap(1)),
            (wrap((1, 0, 1, 0)), wrap(0)),
            (wrap((1, 0, 1, 0)), wrap(True)),
            (wrap((1, 0, 1, 0)), wrap(False)),
        ]
        for cpu_x, cpu_other in operand_pairs:
            mps_x = cpu_x.detach().clone().to('mps')
            mps_other = cpu_other.detach().clone().to('mps')
            result = op(mps_x, mps_other)
            result_cpu = op(cpu_x, cpu_other)
            self.assertEqual(result, result_cpu)

    def test_logical_not(self):
        """``torch.logical_not`` on MPS matches CPU for vector and 0-dim inputs."""
        def helper(x):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            result = torch.logical_not(x)
            result_cpu = torch.logical_not(cpu_x)

            self.assertEqual(result, result_cpu)

        helper(self._wrap_tensor([1, 1, 0, 0]))
        helper(self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True))
        helper(self._wrap_tensor([True, True, False, False]))
        helper(self._wrap_tensor(1))
        helper(self._wrap_tensor(0))
        helper(self._wrap_tensor(True))
        helper(self._wrap_tensor(False))

    def test_logical_and(self):
        """``torch.logical_and`` on MPS matches CPU."""
        self._check_binary_logical_op(torch.logical_and)

    def test_logical_or(self):
        """``torch.logical_or`` on MPS matches CPU."""
        self._check_binary_logical_op(torch.logical_or)

    def test_logical_xor(self):
        """``torch.logical_xor`` on MPS matches CPU."""
        self._check_binary_logical_op(torch.logical_xor)

    @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool])
    def test_min_max(self, dtype):
        """Full-tensor ``max()``/``min()`` on MPS match CPU over 10 random 30 x 15 tensors."""
        for _ in range(10):
            if dtype == torch.float32 or dtype == torch.float16:
                x = torch.randn((30, 15), device='mps', dtype=dtype)
            else:
                x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
            x_cpu = x.to("cpu")

            y = x.max()
            y_cpu = x_cpu.max()
            self.assertEqual(y, y_cpu)

            z = x.min()
            z_cpu = x_cpu.min()
            self.assertEqual(z, z_cpu)

    @parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_min_max_nan_propagation(self, dtype):
        """max/amax/min/amin on an input containing NaN give the same result on MPS as on CPU."""
        cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu", dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')

        cpu_max = torch.max(cpu_x)
        mps_max = torch.max(mps_x).to('cpu')

        cpu_amax = torch.amax(cpu_x)
        mps_amax = torch.amax(mps_x).to('cpu')

        cpu_min = torch.min(cpu_x)
        mps_min = torch.min(mps_x).to('cpu')

        cpu_amin = torch.amin(cpu_x)
        mps_amin = torch.amin(mps_x).to('cpu')

        self.assertEqual(cpu_max, mps_max)
        self.assertEqual(cpu_amax, mps_amax)
        self.assertEqual(cpu_min, mps_min)
        self.assertEqual(cpu_amin, mps_amin)

    def test_isin(self):
        """``torch.isin`` on MPS matches CPU across shapes, dtypes, ``invert`` and scalar variants."""
        def helper(dtype):
            shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),
                      ([5], [10]), ([0], [5]), ([5], [0])]
            for shape_tuple in shapes:
                for inverted in [True, False]:
                    if dtype.is_floating_point:
                        # Half is not supported for CPU isin. Compute reference in FP32
                        # NOTE(review): the MPS copies below are made from these
                        # float32 tensors, so the float16/bfloat16 cases also run
                        # in float32 on MPS. Confirm whether that is intended.
                        A = torch.randn(size=shape_tuple[0], device='cpu', dtype=torch.float32)
                        B = torch.randn(size=shape_tuple[1], device='cpu', dtype=torch.float32)
                    else:
                        A = torch.randint(0, 100, size=shape_tuple[0], device='cpu', dtype=dtype)
                        B = torch.randint(0, 100, size=shape_tuple[1], device='cpu', dtype=dtype)

                    A_mps = A.detach().clone().to('mps')
                    B_mps = B.detach().clone().to('mps')

                    # The boolean reference is compared as-is; no dtype
                    # conversion is applied to it.
                    cpu_ref = torch.isin(A, B, invert=inverted)
                    mps_out = torch.isin(A_mps, B_mps, invert=inverted)
                    self.assertEqual(mps_out, cpu_ref)

        tested_dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8]
        for dtype in tested_dtypes:
            helper(dtype)

        # Mixed dtypes (see https://github.com/pytorch/pytorch/issues/151443 )
        x = torch.arange(4.0, device="mps")
        y = torch.tensor([1, 3], device="mps", dtype=torch.float16)
        self.assertEqual(torch.isin(x, y), torch.tensor([False, True, False, True], device="mps"))

        # Tensor.Scalar variant (aliases to eq), not covered by OpInfo
        self.assertEqual(torch.isin(x, 2.0), torch.tensor([False, False, True, False], device="mps"))
        self.assertEqual(torch.isin(x, 1.0, invert=True), torch.tensor([True, False, True, True], device="mps"))
        self.assertEqual(torch.isin(x, 8.0), torch.tensor([False, False, False, False], device="mps"))
        # Scalar.Tensor variant (aliases to Scalar.Scalar), not covered by OpInfo
        self.assertEqual(torch.isin(2.0, x), torch.tensor(True, device="mps"))

    def test_isin_asserts(self):
        """``torch.isin`` must raise when its two tensors live on different devices."""
        C = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
        D = torch.randn(size=[1, 4], device='cpu', dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, 'Expected elements.is_mps()*'):
            torch.isin(C, D)

    @parametrize("dtype", [torch.int32, torch.int64, torch.int16, torch.int8, torch.uint8, torch.bool])
    def test_shifts(self, dtype):
        """Left/right shifts on MPS match CPU, including dtype extremes and scalar-by-tensor shifts."""
        x = make_tensor(256, device="mps", dtype=dtype)
        if dtype is not torch.bool:
            # Plant the dtype's extreme values so the boundaries are covered.
            x[3] = torch.iinfo(dtype).max
            x[5] = torch.iinfo(dtype).min
        x_cpu = x.cpu()
        self.assertEqual((x >> 3).cpu(), x_cpu >> 3)
        self.assertEqual((x << 1).cpu(), x_cpu << 1)
        # Regression test for https://github.com/pytorch/pytorch/issues/147889
        # Shift amounts are clamped to [0, 8] before shifting a Python scalar by the tensor.
        x = x.clamp(0, 8)
        x_cpu = x.cpu()
        self.assertEqual((4095 >> x).cpu(), 4095 >> x_cpu)
        self.assertEqual((257 << x).cpu(), 257 << x_cpu)
class TestSmoothL1Loss(TestCaseMPS):
    """``F.smooth_l1_loss`` on MPS compared against CPU, forward and (optionally) backward."""

    @parametrize("reduction", ["none", "mean", "sum"])
    @parametrize("requires_grad", [False, True])
    def test_smooth_l1_loss(self, reduction, requires_grad):
        """Check loss values, and input gradients when ``requires_grad`` is set, for several shapes."""
        def helper(sizes):
            # CPU
            input_cpu = torch.randn(*sizes, requires_grad=requires_grad)
            target_cpu = torch.randn(*sizes)

            # MPS
            input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
            target_mps = target_cpu.detach().clone().to('mps')

            smooth_l1_loss_cpu = F.smooth_l1_loss(input_cpu, target_cpu, beta=1.0, reduction=reduction)
            smooth_l1_loss_mps = F.smooth_l1_loss(input_mps, target_mps, beta=1.0, reduction=reduction)

            self.assertEqual(smooth_l1_loss_cpu, smooth_l1_loss_mps)

            if requires_grad:
                if reduction == "none":
                    # An unreduced loss needs an explicit upstream gradient.
                    # Ones are used so the resulting input gradients are
                    # non-trivial and the comparison below is meaningful.
                    grad_cpu = torch.ones_like(smooth_l1_loss_cpu)
                    grad_mps = grad_cpu.to('mps')

                    smooth_l1_loss_cpu.backward(grad_cpu)
                    smooth_l1_loss_mps.backward(grad_mps)
                else:
                    smooth_l1_loss_cpu.backward()
                    smooth_l1_loss_mps.backward()

                self.assertEqual(input_cpu.grad, input_mps.grad.to("cpu"))

        helper((2, 3, 4))
        helper((8, 5))
        helper((3, ))
        # Shape with a zero-sized dimension
        helper((3, 3, 0))
class TestNLLLoss(TestCaseMPS):
def test_nll_loss_mismatched_batch(self, device='mps'):
    """``F.nll_loss`` must raise ``ValueError`` when target and input batch sizes differ."""
    logits = torch.randn((10, 3), requires_grad=True, device=device)
    # A matching target would have size (10,); this one has only 3 entries.
    short_target = torch.zeros((3,), dtype=torch.int64, device=device)
    with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
        F.nll_loss(logits, short_target)
def test_nll_loss_out_of_bounds_ignore_index(self):
    """``F.nll_loss`` with out-of-range and default ``ignore_index`` gives the same results on MPS as on CPU."""
    def losses_on(device):
        x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
                         0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
        # Target containing 255, which is not a valid class index here.
        t1 = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
        # Target containing -100, the value skipped when ignore_index is left at its default.
        t2 = torch.tensor([0, 1, 1, 0, -100, 2], dtype=torch.int64, device=device)

        collected = []
        for reduction in ['mean', 'none']:
            collected.extend((
                # out of bound ignore_index
                F.nll_loss(x, t1, ignore_index=255, reduction=reduction),
                # default ignore_index
                F.nll_loss(x, t2, reduction=reduction),
            ))
        return collected

    output_cpu = losses_on(device='cpu')
    output_mps = losses_on(device='mps')

    for cpu, mps in zip(output_cpu, output_mps):
        self.assertEqual(cpu, mps)
def test_nll_loss_invalid_target_dim(self):
    """``F.nll_loss`` must reject a 2-D target for 2-D input on both CPU and MPS."""
    def check_on(device):
        x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
                         0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
        # The target has an extra dimension: (6, 2) instead of (6,).
        bad_target = torch.zeros((6, 2), dtype=torch.int64, device=device)
        with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
            F.nll_loss(x, bad_target)

    for device in ('cpu', 'mps'):
        check_on(device=device)
def test_nll_loss_invalid_weights(self):
def _test_nll_loss_invalid_weights(device):
x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
t = torch.tensor([0, 1, 2, 1, 1, 2], dtype=torch.int64, device=device)
invalid_weights = [
torch.zeros(4, device=device),
torch.zeros((1, 3), device=device),
]
msg = "weight tensor should be defined either for all 3 classes or no classes"
for weight in invalid_weights:
with self.assertRaisesRegex(RuntimeError, msg):
F.nll_loss(x, t, weight=weight)
_test_nll_loss_invalid_weights(device='cpu')
_test_nll_loss_invalid_weights(device='mps')
def _nll_loss_helper(self, input_size, reduction, expected):
# CPU
input = torch.rand(input_size, requires_grad=True, device='cpu')
num_channels = input_size[1]
target_size = (input_size[0], ) + tuple(input_size[2:])
target = torch.randint(num_channels, target_size, device='cpu')
weights = torch.randn(num_channels)
# MPS
input_mps = input.detach().clone().to('mps').requires_grad_()
target_mps = target.detach().clone().to('mps')
weights_mps = weights.to("mps")
output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction)
output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction)
self.assertEqual(output_cpu, output_mps.to('cpu'))
output_cpu.sum().backward()
output_mps.sum().backward()
self.assertEqual(input.grad, input_mps.grad.to('cpu'))
def _nll_loss_1d_helper(self, input_size, reduction):
# CPU
input = torch.rand(input_size, requires_grad=True, device='cpu')
num_channels = input_size[0]
target = torch.randint(num_channels, [], device='cpu')
# MPS
input_mps = input.detach().clone().to('mps').requires_grad_()
target_mps = target.detach().clone().to('mps')
output_cpu = F.nll_loss(input, target, reduction=reduction)
output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
self.assertEqual(output_cpu, output_mps.to('cpu'))
output_cpu.sum().backward()
output_mps.sum().backward()
self.assertEqual(input.grad, input_mps.grad.to('cpu'))
def test_nll_loss_1d(self, device='cpu'):
self._nll_loss_1d_helper([10], "none")
self._nll_loss_1d_helper([10], "mean")
self._nll_loss_1d_helper([10], "sum")
def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))
self._nll_loss_helper([2, 3, 1, 7], "none", torch.empty([2, 1, 7], device=device))
self._nll_loss_helper([2, 3, 5, 1], "none", torch.empty([2, 5, 1], device=device))
self._nll_loss_helper([2, 3, 5, 7, 1], "none", torch.empty([2, 5, 7, 1], device=device))
def test_nll_loss_empty_tensor_reduction_mean(self, device='cpu'):
nan = torch.tensor(float('nan'), device=device)
self._nll_loss_helper([1, 3], "mean", nan)
self._nll_loss_helper([1, 3, 5, 7], "mean", nan)
self._nll_loss_helper([2, 3, 1, 7], "mean", nan)
self._nll_loss_helper([2, 3, 5, 1], "mean", nan)
self._nll_loss_helper([2, 3, 5, 7, 1], "mean", nan)
def test_nll_loss_empty_tensor_reduction_sum(self, device='cpu'):
zero = torch.tensor(0, device=device)
self._nll_loss_helper([1, 3], "sum", zero)
self._nll_loss_helper([1, 3, 5, 7], "sum", zero)
self._nll_loss_helper([2, 3, 1, 7], "sum", zero)
self._nll_loss_helper([2, 3, 5, 1], "sum", zero)
self._nll_loss_helper([2, 3, 5, 7, 1], "sum", zero)
def test_nll_loss_byte_target_matches_long(self, device='cpu'):
N, C = 10, 4
input = torch.randn(N, C, device=device, requires_grad=True)
target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
def compute_result_and_gradient(reduction, target_dtype):
result, grad = {}, {}
for dev in ['cpu', 'mps']:
input_dev = input.to(dev)
input_ = input_dev.detach()
input_.requires_grad_()
target_dev = target.to(dev)
prob = F.log_softmax(input_, dim=-1)
loss = nn.NLLLoss(reduction=reduction)
result[dev] = loss(prob, target_dev.to(target_dtype))
result[dev].sum().backward()
grad[dev] = input_.grad
return result, grad
for reduction in ["none", "mean", "sum"]:
result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)
self.assertEqual(result_long['mps'].to('cpu'), result_long['cpu'])
self.assertEqual(grad_long['mps'].to('cpu'), grad_long['cpu'])
def test_nll_loss_backward(self):
# Copy-n-pasted from similar test_torchinductor.py test
# Used to crash with `error: 'mps.divide' op requires the same element type for all operands and results`
labels = (
torch.zeros([5], dtype=torch.int64, device="mps"),
torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64, device="mps"),
)
for label in labels:
inp = torch.rand(5, 5, device="mps", dtype=torch.half)
grad_out = torch.empty((), device=inp.device, dtype=inp.dtype)
total_weight = torch.tensor(1.0, device=inp.device)
torch.ops.aten.nll_loss_backward(grad_out, inp, label, None, 1, -100, total_weight)
class TestTopK(TestCase):
    """Compare ``torch.topk`` on the MPS device with the CPU implementation."""

    def _test_topk(self, shape, largest):
        """Check ``torch.topk`` values and indices along every dim for every valid ``k``.

        Args:
            shape: tensor shape, either a tuple of ints or a single int
                (treated as a one-dimensional tensor of that length).
            largest: forwarded to ``torch.topk``.
        """
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        # Normalise a bare int to a 1-tuple so both shape forms share one loop.
        dims = shape if isinstance(shape, tuple) else (shape,)
        for curr_dim, dim_size in enumerate(dims):
            # k runs over 1..dim_size inclusive. The int-shape path used to stop
            # at dim_size - 1, so shape 1 was never checked and shape 2 missed k == 2.
            for k in range(1, dim_size + 1):
                topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest)
                topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest)
                self.assertEqual(topk_values, topk_values_cpu)
                self.assertEqual(topk_indices, topk_indices_cpu)

    def test_topk(self):
        """Run ``_test_topk`` over empty and non-empty shapes, for largest and smallest."""
        largest_vals = [True, False]
        shapes = [
            # Zero Element Tensors
            0,
            (1, 0),
            (0, 1),
            (1, 0, 1),
            # Multiple Element Tensors
            1,
            2,
            (5, 1),
            (1, 5),
            (5, 9, 7, 4),
        ]

        for shape in shapes:
            for largest_val in largest_vals:
                with self.subTest(shape=shape, largest_val=largest_val):
                    self._test_topk(shape, largest_val)

    def test_topk_gt_4d(self):
        """topk on a 5-D tensor either succeeds or fails with the known-issue message."""
        a = torch.ones(5, 4, 3, 2, 1, dtype=torch.float).to('mps')
        try:
            torch.ops.aten.topk(a, k=5, dim=0)
        except Exception as e:
            # Any failure must be the documented MPSGraph limitation, nothing else.
            e_string = str(e)
            self.assertEqual(e_string, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890")
class TestNNMPS(NNTestCase):
    """Module-level and convolution checks.

    Covers ``requires_grad_`` / ``zero_grad`` bookkeeping, loading of legacy
    checkpoints, rejection of invalid convolution arguments, and several
    MPS convolution / group-norm regression tests.
    """

    def _create_basic_net(self):
        """Build a layer, a net that owns such a layer, and a ``Sequential`` holding the net twice.

        Returns:
            tuple: ``(layer, net, sequential)``.
        """

        class Layer(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer_dummy_param = Parameter(torch.empty(3, 5))
                self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7))

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l1 = Layer()
                self.dummy_param = Parameter(torch.empty(3, 5))
                self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1))

        layer = Layer()
        net = Net()
        return layer, net, nn.Sequential(net, net)

    def test_requires_grad_(self):
        """``Module.requires_grad_`` toggles parameters, leaves buffers alone, returns the module."""
        seq = self._create_basic_net()[-1]

        # Sanity-check the fixture before relying on it.
        buffers = list(seq.buffers())
        assert len(buffers) > 0, 'invalid test'
        assert all(not buf.requires_grad for buf in buffers), 'invalid test'
        params = list(seq.parameters())
        assert len(params) > 0, 'invalid test'
        assert all(param.requires_grad for param in params), 'invalid test'

        for flag in (False, True):
            self.assertIs(seq.requires_grad_(flag), seq)
            for param in seq.parameters():
                self.assertEqual(param.requires_grad, flag)
            for buf in seq.buffers():
                self.assertFalse(buf.requires_grad)

    def test_module_backcompat(self):
        """A pickled legacy ``Linear`` module still loads and maps (2, 3) to (2, 5)."""
        from torch.serialization import SourceChangeWarning

        checkpoint = download_file('https://download.pytorch.org/test_data/linear.pt')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', SourceChangeWarning)
            # weights_only=False as this is a legacy use case that loads a module
            legacy = torch.load(checkpoint, weights_only=False)
        sample = torch.randn(2, 3, dtype=torch.float)
        self.assertEqual(legacy(sample).size(), (2, 5))

    def test_conv_backcompat(self):
        """A ``Conv2d`` pickled by PyTorch 1.0.1 on Python 2 still loads and runs."""
        from torch.serialization import SourceChangeWarning

        # This file was generated by running on PyTorch 1.0.1 on Python 2:
        #
        #     import torch
        #     from torch import nn
        #     m = nn.Conv2d(1, 1, 1)
        #     torch.save(m, 'legacy_conv2d.pt')
        #
        # NB: This Pickle also contains some Unicode data!
        checkpoint = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', SourceChangeWarning)
            # weights_only=False as this is a legacy use case that loads a module
            legacy = torch.load(checkpoint, encoding='utf-8', weights_only=False)
        sample = torch.randn((1, 1, 1, 1), dtype=torch.float)
        self.assertEqual(legacy(sample).size(), (1, 1, 1, 1))

    def test_conv_expand(self):
        """``conv2d`` with an expanded (stride-0) kernel must not crash; no values are checked."""
        device = 'mps'
        image = torch.rand(2, 3, 16, 16, device=device)
        base_kernel = torch.rand(1, 1, 3, 11, device=device)
        expanded_kernel = base_kernel.expand(-1, 3, -1, -1)
        # The test should not crash
        F.conv2d(image, expanded_kernel, groups=1, padding=0, stride=1)

    def test_permute(self):
        """``permute`` gives the same values and sizes on MPS as on CPU."""
        cpu_mat = torch.randn(5, 5)
        mps_mat = cpu_mat.to('mps')

        permuted_cpu = cpu_mat.permute(1, 0)
        permuted_mps = mps_mat.permute(1, 0)

        self.assertEqual(permuted_cpu, permuted_mps)
        self.assertEqual(permuted_cpu.size(), permuted_mps.size())

    # Printing of non_contiguous should not crash
    def test_print_non_contiguous(self):
        """``str()`` of a ``nonzero()`` result (and its contiguous copy) is non-empty."""
        # print(obj) is equivalent to calling `x=str(obj); print(x)`
        # Use assertTrue in case to make sure non-empty string is returned
        indices = torch.ones(100, 100, device='mps').nonzero()
        self.assertTrue(str(indices))
        indices_contig = torch.ones(100, 100, device='mps').nonzero().contiguous()
        self.assertTrue(str(indices_contig))

    def test_zero_grad(self):
        """``zero_grad`` resets grads to ``None`` by default and to zeros with ``set_to_none=False``."""
        batch = torch.randn(2, 5, requires_grad=True)
        linear = nn.Linear(5, 5)
        for param in linear.parameters():
            param.requires_grad = False
        linear.zero_grad()

        # Only the weight participates in autograd for the first round.
        linear.weight.requires_grad = True
        linear.zero_grad()
        self.assertIsNone(linear.weight.grad)  # uninitialized grad

        linear(batch).sum().backward()
        self.assertIsNotNone(linear.weight.grad)
        self.assertGreater(linear.weight.grad.data.abs().sum(), 0)
        linear.zero_grad()
        self.assertIsNone(linear.weight.grad)

        # Second round: the bias joins in as well.
        linear.bias.requires_grad = True
        linear.zero_grad()
        self.assertIsNone(linear.weight.grad)
        self.assertIsNone(linear.bias.grad)
        linear(batch).sum().backward()
        self.assertIsNotNone(linear.weight.grad)
        self.assertIsNotNone(linear.bias.grad)
        self.assertGreater(linear.weight.grad.data.abs().sum(), 0)
        self.assertGreater(linear.bias.grad.data.abs().sum(), 0)

        # Force set to zeros.
        linear.zero_grad(set_to_none=False)
        self.assertEqual(linear.weight.grad.data, linear.weight.data.clone().zero_())
        self.assertEqual(linear.bias.grad.data, linear.bias.data.clone().zero_())

        linear.zero_grad()
        self.assertIsNone(linear.weight.grad)
        self.assertIsNone(linear.bias.grad)

    def test_no_grad(self):
        """Outputs produced under ``torch.no_grad()`` do not require grad and cannot backprop."""
        for dtype in (torch.bfloat16, torch.float, torch.double):
            conv = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
            sample = torch.randn(1, 2, 10, 10).to(dtype)
            sample_copy = sample.clone()

            tracked = conv(sample)
            self.assertTrue(tracked.requires_grad)
            tracked.backward(torch.ones(1, 5, 10, 10))

            with torch.no_grad():
                untracked = conv(sample_copy)
                self.assertFalse(untracked.requires_grad)
                with self.assertRaises(RuntimeError):
                    untracked.backward(torch.ones(1, 5, 10, 10))

    def test_invalid_conv1d(self):
        """``Conv1d`` rejects a kernel larger than the input and a negative stride."""
        for dtype in (torch.bfloat16, torch.float, torch.double):
            oversized = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype)
            signal = torch.randn(1, 3, 4).to(dtype)
            too_big_kernel = (
                r'Calculated padded input size per channel: \(4\). '
                r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'
            )
            with self.assertRaisesRegex(RuntimeError, too_big_kernel):
                oversized(signal)

            # Negative stride check
            backwards = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype)
            signal = torch.randn(1, 3, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
                backwards(signal)

    def test_conv2d_discontiguous_weight(self):
        """``conv2d`` handles a non-contiguous weight (with and without MKLDNN)."""
        # Test for https://github.com/pytorch/pytorch/issues/55781
        ones = torch.ones(64, 16, 16, 16)
        strided_weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2)[:, :, :, ::2]
        self.assertFalse(strided_weight.is_contiguous())
        result = torch.nn.functional.conv2d(ones, strided_weight, None)
        if torch.backends.mkldnn.is_available():
            # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
            with torch.backends.mkldnn.flags(enabled=False):
                result_no_mkldnn = torch.nn.functional.conv2d(ones, strided_weight, None)
                self.assertEqual(result, result_no_mkldnn)
        self.assertEqual(result.sum(), 4186112.)

    def test_invalid_conv2d(self):
        """``Conv2d`` rejects bad kernel/stride combinations and mixed CPU/MPS operands."""
        non_positive_stride = 'non-positive stride is not supported'
        for dtype in (torch.bfloat16, torch.float, torch.double):
            dilated = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
            tiny = torch.empty(1, 1, 4, 4).to(dtype)
            with self.assertRaises(RuntimeError):
                dilated(tiny)

            oversized = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True)
            pixel = torch.randn(1, 3, 1, 1)
            too_big_kernel = (
                r'Calculated padded input size per channel: \(1 x 1\). '
                r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'
            )
            with self.assertRaisesRegex(RuntimeError, too_big_kernel):
                oversized(pixel)

            # Negative stride check
            backwards = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype)
            patch = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, non_positive_stride):
                backwards(patch)

            # Zero stride check
            frozen = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype)
            patch = torch.randn(1, 3, 4, 4).to(dtype)
            with self.assertRaisesRegex(RuntimeError, non_positive_stride):
                frozen(patch)

            # Input and weights on different devices
            with self.assertRaisesRegex(RuntimeError, 'must be on the same device'):
                torch.conv2d(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 3, 3, device='mps'))
            with self.assertRaisesRegex(
                RuntimeError,
                'Input type \\(MPSFloatType\\) and weight type \\(torch\\.FloatTensor\\) should be the same',
            ):
                torch.conv2d(torch.rand(1, 3, 32, 32, device='mps'), torch.rand(1, 3, 3, 3))

    def test_conv2d_valid_padding(self, device='mps'):
        """``padding='valid'`` produces the same result as the default (no padding)."""
        # Test F.conv2d padding='valid' is the same as no padding
        signal = torch.rand(1, 1, 1, 10, device=device).to(torch.float)
        kernel = torch.rand(1, 1, 1, 4, device=device).to(torch.float)
        unpadded = F.conv2d(signal, kernel)
        valid_padded = F.conv2d(signal, kernel, padding='valid')
        self.assertEqual(unpadded.to('cpu'), valid_padded.to('cpu'))

    def test_conv2d_backward_collision(self):
        """Backward through two convs with equal output shapes but different kernels must not crash."""
        # Test for https://github.com/pytorch/pytorch/issues/112998
        x = torch.rand(1, 1, 10, 10, device="mps", requires_grad=True)
        conv_k3 = nn.Conv2d(1, 1, 3, stride=2, padding=1).to("mps")
        conv_k4 = nn.Conv2d(1, 1, 4, stride=2, padding=1).to("mps")
        out_k3 = conv_k3(x)
        out_k4 = conv_k4(x)
        self.assertEqual(out_k3.shape, out_k4.shape)
        out_k3.sum().backward()
        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
        out_k4.sum().backward()

    def test_conv3d_backward_collision(self):
        """3-D variant of ``test_conv2d_backward_collision``."""
        # Conv3D is only available from MacOS 13.2 onwards
        x = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True)
        conv_k3 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps")
        conv_k4 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps")
        out_k3 = conv_k3(x)
        out_k4 = conv_k4(x)
        self.assertEqual(out_k3.shape, out_k4.shape)
        out_k3.sum().backward()
        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
        out_k4.sum().backward()

    # Regression test for https://github.com/pytorch/pytorch/issues/141471
    def test_conv3d_channels_last_3d(self):
        """``Conv3d`` on a ``channels_last_3d`` input matches the CPU result."""
        conv_cpu = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0), device="cpu")
        conv_mps = copy.deepcopy(conv_cpu).to("mps")

        volume_cpu = torch.randn(20, 16, 10, 50, 100, device="cpu").to(memory_format=torch.channels_last_3d)
        volume_mps = volume_cpu.detach().clone().to("mps")

        out_cpu = conv_cpu(volume_cpu)
        out_mps = conv_mps(volume_mps)

        self.assertEqual(out_cpu, out_mps)

    def test_gemm_permute_transpose(self):
        """view -> permute -> transpose (attention-style head split) matches between CPU and MPS."""
        batch_size = 32
        n = 20
        hidden = 768
        num_attention_heads = 12
        attention_head_size = hidden // num_attention_heads

        def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
            # Split the last dim into (heads, head_size), then move heads ahead of the sequence dim.
            head_split_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
            return x.view(head_split_shape).permute(0, 2, 1, 3)

        def attention2(key, *, workaround=False, device):
            # `workaround` and `device` are accepted but not used in the computation.
            return transpose_for_scores(key).transpose(-1, -2)

        hidden_cpu = torch.randn(batch_size, n, hidden)
        hidden_mps = hidden_cpu.detach().clone().to("mps")

        keys_cpu = attention2(hidden_cpu, device="cpu")
        keys_mps = attention2(hidden_mps, device="mps")

        self.assertEqual(keys_cpu, keys_mps.to("cpu"))

    def test_group_norm_backward(self, device='mps'):
        """Gradients through repeated ``GroupNorm`` plus upsampling must not contain NaN."""
        # See https://github.com/pytorch/pytorch/issues/88331 for more detail
        shape = [1, 4, 16, 16]
        x = torch.full(shape, 7.0, device=device)

        target = torch.ones((1, 3, 128, 128), device=device)

        conv_in = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
        conv_out = nn.Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
        norm = nn.GroupNorm(32, 128, eps=1e-6, affine=True, device=device)

        with torch.enable_grad():
            x = x.detach().requires_grad_()
            hidden = conv_in(5.5 * x)
            # Three residual applications of the same GroupNorm.
            for _ in range(3):
                hidden = hidden + norm(hidden)
            hidden = F.interpolate(hidden, scale_factor=8.0, mode="nearest")
            hidden = conv_out(norm(hidden))

            loss = (hidden - target).norm(dim=-1).sum()
            grad = -torch.autograd.grad(loss, x)[0]
            self.assertFalse(grad.detach().isnan().any().item(), 'NaN gradients returned by autograd')
# def test_conv2d_same_padding(self, device='mps'):
# x = torch.rand(1, 1, 10, 11, device=device)
# y = torch.rand(1, 1, 4, 5, device=device)
# expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
# actual = F.conv2d(x, y, padding='same')
# self.assertEqual(expect.to('cpu'), actual.to('cpu'))
# # With dilation
# y = torch.rand(1, 1, 3, 4, device=device)
# expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
# actual = F.conv2d(x, y, padding='same', dilation=2)
# self.assertEqual(expect, actual)
# # Dilation with asymmetric padding
# y = torch.rand(1, 1, 4, 4, device=device)
# expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
# actual = F.conv2d(x, y, padding='same', dilation=3)
# self.assertEqual(expect, actual)
class TestPad(TestCaseMPS):
    """Compare padding ops (constant, reflection, replication, circular) on MPS with CPU."""

    def test_constant_pad(self):
        """Negative ``ConstantPad2d`` padding and ``F.pad`` on a 10-D input match CPU."""
        cropper = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
        cpu_in = torch.randn(1, 16, 16, 16)
        mps_in = cpu_in.detach().clone().to("mps")
        self.assertEqual(cropper(cpu_in), cropper(mps_in).to("cpu"))

        # Arbitrary input dimensions
        pad = (1, 1, 0, 0, 0, 0)
        value = 3.5
        cpu_in = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3))
        mps_in = cpu_in.detach().clone().to("mps")
        padded_cpu = F.pad(cpu_in, pad=pad, value=value)
        padded_mps = F.pad(mps_in, pad=pad, value=value)
        self.assertEqual(padded_cpu, padded_mps.to("cpu"))

    def test_circular_pad(self):
        """Circular padding followed by ``conv2d`` matches CPU."""
        # https://github.com/pytorch/pytorch/issues/80856
        kernel_cpu = torch.ones(3, 3, 9, 9)
        kernel_mps = kernel_cpu.detach().clone().to("mps")

        image_cpu = torch.rand(1, 3, 32, 32)
        image_mps = image_cpu.detach().clone().to("mps")

        wrapped_cpu = F.pad(image_cpu, (2, 2, 2, 2), mode='circular')
        wrapped_mps = F.pad(image_mps, (2, 2, 2, 2), mode='circular')

        conv_cpu = F.conv2d(wrapped_cpu, kernel_cpu)
        conv_mps = F.conv2d(wrapped_mps, kernel_mps)

        self.assertEqual(conv_cpu, conv_mps.cpu())

    def test_constant_pad_4d_warning(self):
        """Constant padding of a 6-D tensor along its third dimension matches CPU."""
        amounts = [0, 0, 0, 0, 0, 0, 1, 0]
        cpu_in = torch.rand((1, 2, 2, 2, 1, 1))
        mps_in = cpu_in.detach().clone().to('mps')
        padded_cpu = F.pad(cpu_in, amounts)
        padded_mps = F.pad(mps_in, amounts)
        self.assertEqual(padded_cpu, padded_mps)

    def test_pad(self):
        """Forward and backward parity of the ``nn.*Pad{1,2,3}d`` modules."""

        def helper(shape, padding, op, value=0):
            cpu_in = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            cpu_in.retain_grad()
            mps_in = cpu_in.detach().clone().to('mps').requires_grad_()

            # Only the constant-pad modules take a fill value.
            takes_value = op in [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]
            pad_module = op(padding, value) if takes_value else op(padding)

            cpu_out = pad_module(cpu_in)
            mps_out = pad_module(mps_in)
            self.assertEqual(cpu_out, mps_out)

            # backward pass (chose 0.6 just to have the grad_output != 1)
            cpu_out.backward(gradient=torch.full_like(cpu_out, 0.6))
            mps_out.backward(gradient=torch.full_like(mps_out, 0.6))
            self.assertEqual(cpu_in.grad, mps_in.grad)

        # (input shape, padding, module class), executed in order
        cases = [
            # 1D Padding
            ((2, 4, 3), 2, nn.ReflectionPad1d),
            # verify if a change in shape of input would cause problems with graph caching
            ((2, 4, 4), (1, 3), nn.ReflectionPad1d),
            # Replication 1D
            ((2, 1, 6), 3, nn.ReplicationPad1d),
            # Constant Pad 1D
            ((2, 3, 4), 2, nn.ConstantPad1d),
            # Constant Pad 1D with single dimension input
            (16, (1, 2), nn.ConstantPad1d),

            # 2D Padding
            ((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d),
            # verify if a change in shape of input would cause problems with graph caching
            ((2, 4, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d),
            # this should make the padding (2, 2, 2, 2)
            ((2, 1, 6, 8), 2, nn.ReplicationPad2d),
            # verify if a change in shape of padding would cause problems with graph caching
            ((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d),
            # Constant Pad 2D
            ((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d),
            # input size < pad size
            ((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d),
            # pad dims < input dims
            ((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d),
            # pad dims == input dims
            ((1, 3), (0, 2, 0, 1), nn.ConstantPad2d),
            # input.numel() == 0 but output.numel() > 0
            ((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d),
            # pad dims < input dims - 2
            ((1, 2, 3, 4), (1, 2), nn.ConstantPad2d),

            # 3D Padding
            ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d),
            # verify if a change in shape of padding would cause problems with graph caching
            ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d),
            # case where input_d == pad_front/back for ReplicationPad3d
            ((3, 4, 5, 6, 7), (1, 2, 3, 4, 5, 6), nn.ReplicationPad3d),
            # Constant Pad 3D
            ((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d),
            # input size < pad size
            ((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d),
            # check the workaround for the right padding bug in Monterey
            ((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d),
        ]
        for shape, padding, op in cases:
            helper(shape, padding, op)

    def test_constant_pad_nd_preserves_memory_format(self):
        """``constant_pad_nd`` keeps contiguous and channels-last layouts (tensors here are on CPU)."""
        contiguous_in = torch.rand((1, 2, 5, 3))
        contiguous_out = torch.constant_pad_nd(contiguous_in, [1, 2], 0.5)
        self.assertTrue(contiguous_out.is_contiguous(memory_format=torch.contiguous_format))

        channels_last_in = contiguous_in.contiguous(memory_format=torch.channels_last)
        channels_last_out = torch.constant_pad_nd(channels_last_in, [1, 2], 0.5)
        self.assertTrue(channels_last_out.is_contiguous(memory_format=torch.channels_last))

    def test_constant_pad_nd_with_empty_pad(self):
        """An empty pad list leaves the tensor unchanged."""
        # Empty constant pad is no-op
        # See https://github.com/pytorch/pytorch/issues/161066
        original = torch.randn((2, 3, 4), device="mps")
        padded = torch.constant_pad_nd(original, [])
        self.assertEqual(padded, original)
class TestLinalgMPS(TestCaseMPS):
    """Linear-algebra checks on MPS: addmm/addr, matrix_rank, pinv and quantised matmuls."""

    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
        """Check ``f(t, m, v, alpha=, beta=)`` against its ``out=`` variant and a NumPy reference.

        Args:
            f: callable such as ``torch.addmm``.
            t, m, v: tensor operands passed to ``f`` in this order.
            alpha: scale of ``m @ v``; defaults to 1.2 when ``None``.
            beta: scale of ``t``; defaults to 0.8 when ``None``.
            transpose_out: when true, the ``out=`` tensor is given a transposed memory layout.
        """
        dtype = t.dtype
        numpy_dtype = dtype
        if alpha is None:
            alpha = 1.2
        if beta is None:
            beta = 0.8

        functional = f(t, m, v, alpha=alpha, beta=beta)

        # NaN-filled destination, so values left unwritten by `out=` would show up.
        into_out = torch.full_like(functional, math.nan)
        if transpose_out:
            into_out = into_out.t().clone(memory_format=torch.contiguous_format).t()
        f(t, m, v, alpha=alpha, beta=beta, out=into_out)

        m_np = m.to(numpy_dtype).cpu().numpy()
        v_np = v.to(numpy_dtype).cpu().numpy()
        reference = alpha * (m_np @ v_np)
        if beta != 0:
            reference += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
        reference = torch.from_numpy(reference).to(dtype)

        self.assertEqual(functional, into_out)
        self.assertEqual(functional, reference)

    def test_addmm(self, device="mps", dtype=torch.float32):
        """``torch.addmm`` for plain, ``beta=0`` with NaN bias, and transposed-layout operands."""

        def maybe_transpose(cond, m):
            # Same values, transposed memory layout when `cond` is true.
            if cond:
                return m.t().clone(memory_format=torch.contiguous_format).t()
            return m

        bias = torch.randn(10, 25, device=device).to(dtype)
        lhs = torch.randn(10, 50, device=device).to(dtype)
        rhs = torch.randn(50, 25, device=device).to(dtype)
        self._test_addmm_addmv(torch.addmm, bias, lhs, rhs)

        # Test beta=0, M=nan
        bias = torch.full((10, 25), math.nan, device=device).to(dtype)
        lhs = torch.randn(10, 50, device=device).to(dtype)
        rhs = torch.randn(50, 25, device=device).to(dtype)
        self._test_addmm_addmv(torch.addmm, bias, lhs, rhs, beta=0)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            bias = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            lhs = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            rhs = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            self._test_addmm_addmv(torch.addmm, bias, lhs, rhs, transpose_out=t4)

    def _test_addr(self, f, t, m, v, alpha=None, beta=None):
        """Check ``f(t, m, v, alpha=, beta=)`` against ``alpha * outer(m, v) + beta * t`` via NumPy.

        ``alpha`` defaults to 1.2 and ``beta`` to 0.8 when ``None``; the ``beta * t``
        term is skipped when ``beta`` is 0.
        """
        dtype = t.dtype
        numpy_dtype = dtype
        if alpha is None:
            alpha = 1.2
        if beta is None:
            beta = 0.8

        actual = f(t, m, v, alpha=alpha, beta=beta)

        m_np = m.to(numpy_dtype).cpu().numpy()
        v_np = v.to(numpy_dtype).cpu().numpy()
        reference = alpha * np.outer(m_np, v_np)
        if beta != 0:
            reference += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
        reference = torch.from_numpy(reference).to(dtype)

        self.assertEqual(actual, reference)

    def test_addr(self, device="mps", dtype=torch.float32):
        """``torch.addr`` for a random bias and for ``beta=0`` with a NaN bias."""
        bias = torch.randn(10, 25, device=device).to(dtype)
        col = torch.randn(10, device=device).to(dtype)
        row = torch.randn(25, device=device).to(dtype)
        self._test_addr(torch.addr, bias, col, row)

        # Test beta=0, M=nan
        bias = torch.full((10, 25), math.nan, device=device).to(dtype)
        col = torch.randn(10, device=device).to(dtype)
        row = torch.randn(25, device=device).to(dtype)
        self._test_addr(torch.addr, bias, col, row, beta=0)

    def test_matrix_rank(self, device="mps", dtype=torch.float32):
        """``torch.linalg.matrix_rank`` self-consistency, NumPy parity and ``out=`` variant."""
        matrix_rank = torch.linalg.matrix_rank

        def run_test(shape0, shape1, batch):
            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
            rank_a = matrix_rank(a)

            self.assertEqual(rank_a, matrix_rank(a.mH))
            aaH = torch.matmul(a, a.mH)
            rank_aaH = matrix_rank(aaH)
            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
            self.assertEqual(rank_aaH, rank_aaH_hermitian)
            aHa = torch.matmul(a.mH, a)
            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))

            # check against NumPy
            a_np = a.cpu().numpy()
            self.assertEqual(rank_a, np.linalg.matrix_rank(a_np))
            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))

            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))

            # hermitian flag for NumPy was added in 1.14.0
            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
                self.assertEqual(rank_aaH_hermitian,
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
                self.assertEqual(matrix_rank(aaH, 0.01, True),
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))

            # check out= variant
            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
            ans = matrix_rank(a, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, rank_a)

        shapes = (3, 13)
        batches = ((), (0, ), (4, ), (3, 5, ))
        svd_missing = "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."
        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
            # escape only when NotImplementedError of downstream function is raised
            # TODO: remove this once the required function is implemented
            try:
                run_test(shape0, shape1, batch)
            except NotImplementedError as e:
                # Re-raise inside the regex check: any other NotImplementedError still fails the test.
                with self.assertRaisesRegex(NotImplementedError, svd_missing):
                    raise e

    def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
        """``torch.linalg.pinv`` against the pseudo-inverse identities and ``np.linalg.pinv``."""
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        tol = {"atol": precision, "rtol": precision}

        def run_test_main(A, hermitian):
            # Testing against definition for pseudo-inverses
            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
            np_A = A.cpu().numpy()
            np_A_pinv = A_pinv.cpu().numpy()
            if A.numel() > 0:
                self.assertEqual(A, np_A @ np_A_pinv @ np_A, **tol)
                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, **tol)
                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), **tol)
                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), **tol)
            else:
                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))

            # Check out= variant
            out = torch.empty_like(A_pinv)
            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
            self.assertEqual(ans, out)
            self.assertEqual(ans, A_pinv)

        def run_test_numpy(A, hermitian):
            # Check against NumPy output
            # Test float rcond, and specific value for each matrix
            rconds = [float(torch.rand(1)), ]
            # Test different types of rcond tensor
            for rcond_type in MPS_DTYPES:
                # TODO: Figure out why it's not supported for complex
                # Skip test for bfloat16 as numpy does not support the type
                if rcond_type.is_complex or rcond_type == torch.bfloat16:
                    continue
                rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
            # Test broadcasting of rcond
            if A.ndim > 2:
                rconds.append(torch.rand(A.shape[-3], device=device))
            for rcond in rconds:
                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
                self.assertEqual(actual, torch_rtol, **tol)
                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
                self.assertEqual(actual, expected, **tol)

        eigh_missing = "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."

        def run_unless_eigh_missing(check, A):
            # escape only when NotImplementedError of downstream function is raised
            # TODO: remove this once the required function is implemented
            try:
                check(A, True)
            except NotImplementedError as e:
                with self.assertRaisesRegex(NotImplementedError, eigh_missing):
                    raise e

        general_sizes = [
            (5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
            (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
            (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
            (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3),  # zero numel matrices
        ]
        for sizes in general_sizes:
            A = torch.randn(*sizes, dtype=dtype, device=device)
            run_test_main(A, False)
            run_test_numpy(A, False)

        # Check hermitian = True
        hermitian_sizes = [
            (5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
            (0, 0), (3, 0, 0),  # zero numel square matrices
        ]
        for sizes in hermitian_sizes:
            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
            run_unless_eigh_missing(run_test_main, A)
            run_unless_eigh_missing(run_test_numpy, A)

    @parametrize("m", [1, 32, 64])
    @parametrize("n", [48, 64])
    @parametrize("q_group", [32, 64, 128, 256])
    @parametrize("num_groups", [1, 2])
    def test__int4_mm(self, m, n, q_group, num_groups):
        """int4 weight-packed matmul stays within 5% mean relative error of ``torch.mm``."""
        k = q_group * num_groups
        inner_k_tiles = 2

        torch.manual_seed(1)
        a_f32 = torch.rand((m, k), device="mps")
        b_f32 = torch.rand((k, n), device="mps")

        # Quantise and pack the weight once; it is reused for every activation dtype below.
        b_int32, scales_and_zeros_f32 = _group_quantize_tensor(
            b_f32, n_bit=4, q_group_size=q_group
        )
        scales_and_zeros_f32 = scales_and_zeros_f32.to("mps")
        b_int4pack = torch._convert_weight_to_int4pack(b_int32, inner_k_tiles)

        for dtype in (torch.float16, torch.float32, torch.bfloat16):
            a = a_f32.to(dtype=dtype)
            b = b_f32.to(dtype=dtype)
            scales_and_zeros = scales_and_zeros_f32.to(dtype=dtype)
            ref = torch.mm(a, b)
            res = torch._weight_int4pack_mm(a, b_int4pack, q_group, scales_and_zeros)

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertLess(mean_err, 0.05)

    @parametrize("m", [1, 32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [32, 64])
    def test__int8_mm(self, m, k, n):
        """int8 weight-packed matmul stays within 5% mean relative error of ``torch.mm``."""
        torch.manual_seed(1)
        a_f32 = torch.rand((m, k), device="mps")
        b_f32 = torch.rand((n, k), device="mps")

        # Per-channel quantisation; the third return value is not needed here.
        b_int8pack, scales_f32, _ = _dynamically_quantize_per_channel(
            b_f32, -128, 127, torch.int8
        )

        for dtype in (torch.float16, torch.float32, torch.bfloat16):
            a = a_f32.to(dtype=dtype)
            b = b_f32.to(dtype=dtype)
            scales = scales_f32.to(dtype=dtype)
            res = torch._weight_int8pack_mm(a, b_int8pack, scales)
            ref = torch.mm(a, b.transpose(0, 1))

            mean_err = ((res - ref).abs() / ref).mean()
            self.assertLess(mean_err, 0.05)
class TestSDPA(TestCaseMPS):
def _compare_tensors(self, y, ref):
    """Assert that the mean relative error of ``y`` w.r.t. ``ref`` is below 1%."""
    # Floor the denominator at 1e-6 so zero entries in ``ref`` do not divide by zero.
    floor = torch.tensor([1e-6], device=ref.device, dtype=ref.dtype)
    safe_ref = torch.maximum(ref.abs(), floor)
    rel_err = (y - ref).abs() / safe_ref
    self.assertLess(rel_err.mean().item(), 0.01)
def _test_sdpa_no_mask(
    self,
    is_causal: bool,
    dtype: torch.dtype,
    L: int = 1,
    S: int = 72,
    NH: int = 32,
    HS: int = 128,
    requires_grad: bool = False
):
    """Run unmasked SDPA on MPS and on CPU and compare the outputs.

    Query is ``[1, NH, L, HS]``; key and value are ``[1, NH, S, HS]``.
    When ``requires_grad`` is set and grad mode is enabled, the gradient
    of the query is compared as well.
    """
    torch.manual_seed(1729)
    math_backend = [torch.nn.attention.SDPBackend.MATH]
    with torch.nn.attention.sdpa_kernel(math_backend):
        q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad)
        k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
        v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")

        # Separate CPU leaf so the reference backward fills its own ``.grad``.
        q_cpu = q.detach().cpu().requires_grad_(requires_grad)
        k_cpu, v_cpu = k.cpu(), v.cpu()

        y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
        y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal)
        self._compare_tensors(y.cpu(), y_ref)

        if not (requires_grad and torch.is_grad_enabled()):
            return

        y.sum().backward()
        y_ref.sum().backward()
        self._compare_tensors(q.grad.cpu(), q_cpu.grad)
def test_sdpa_no_mask_no_causal_fp32(self):
    """Unmasked, non-causal SDPA in float32 with the default sizes."""
    self._test_sdpa_no_mask(is_causal=False, dtype=torch.float32)
def test_sdpa_no_mask_no_causal_fp16(self):
    """Unmasked, non-causal SDPA in float16 with the default sizes."""
    self._test_sdpa_no_mask(is_causal=False, dtype=torch.float16)
def test_sdpa_no_mask_causal_fp32(self):
    """Unmasked, causal SDPA in float32 with the default sizes."""
    self._test_sdpa_no_mask(is_causal=True, dtype=torch.float32)
def test_sdpa_no_mask_causal_fp16(self):
    """Unmasked, causal SDPA in float16 with the default sizes."""
    self._test_sdpa_no_mask(is_causal=True, dtype=torch.float16)
def test_sdpa_no_mask_causal_fp16_L7(self):
    """Unmasked, causal SDPA in float16 with a query length of 7."""
    self._test_sdpa_no_mask(is_causal=True, dtype=torch.float16, L=7)
def test_sdpa_no_mask_causal_fp16_L7_S17(self):
    """Unmasked, causal SDPA in float16 with L=7 and S=17."""
    self._test_sdpa_no_mask(is_causal=True, dtype=torch.float16, L=7, S=17)
def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self):
    """Unmasked, causal SDPA in float16 with L=7, S=17, NH=23 and HS=121."""
    self._test_sdpa_no_mask(is_causal=True, dtype=torch.float16, L=7, S=17, NH=23, HS=121)
def test_sdpa_no_mask_no_causal_fp32_grad(self):
    """Unmasked float32 SDPA with a grad-requiring query, with and without grad mode."""
    common = dict(is_causal=False, dtype=torch.float32, requires_grad=True)
    # First with autograd enabled: forward and query gradient are compared.
    self._test_sdpa_no_mask(**common)
    # Then under no_grad: the helper skips its backward section.
    with torch.no_grad():
        self._test_sdpa_no_mask(**common)
def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128):
    """Run SDPA with an explicit boolean mask on MPS and CPU and compare.

    The mask is a single row (row 42) of an ``S x S`` lower-triangular
    matrix, selected with an index tensor.

    NOTE(review): the row index 42 is hard-coded; a caller in this file
    passes S=17, for which 42 is past the last row of the mask. Confirm
    that this is intended.
    """
    torch.manual_seed(1729)
    causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
    math_backend = [torch.nn.attention.SDPBackend.MATH]
    with torch.nn.attention.sdpa_kernel(math_backend):
        i = 42

        q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
        k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
        v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")

        input_pos = torch.tensor([i], dtype=torch.int32, device='mps')
        mask = causal_mask[None, None, input_pos]

        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )
        y_ref = F.scaled_dot_product_attention(
            q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False
        )

        self._compare_tensors(y.cpu(), y_ref)
def test_sdpa_mask_fp32(self):
    """Masked SDPA in float32, executed twice in a row."""
    # Test twice to catch https://github.com/pytorch/pytorch/issues/148194
    for _ in range(2):
        self._test_sdpa_mask(torch.float32)
def test_sdpa_mask_fp16(self):
    """Masked SDPA in float16 with the default sizes."""
    self._test_sdpa_mask(dtype=torch.float16)
def test_sdpa_mask_fp16_L6(self):
    """Masked SDPA in float16 with a query length of 6."""
    self._test_sdpa_mask(dtype=torch.float16, L=6)
def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self):
    """Masked SDPA in float16 with S=17, NH=23 and HS=121.

    NOTE(review): the test name says L6 but the call uses L=7; confirm
    which of the two is intended.
    """
    self._test_sdpa_mask(dtype=torch.float16, L=7, S=17, NH=23, HS=121)
# Regression test from: https://github.com/pytorch/pytorch/issues/156707
@parametrize("dtype", [torch.float16, torch.float32])
def test_sdpa_full_mask(self, dtype):
    """SDPA with a boolean mask whose first query row is entirely ``False``."""
    # Inputs are created on the CPU and moved to MPS for the second call.
    q, k, v = (torch.randn(1, 1, 2, 4, dtype=dtype) for _ in range(3))
    mask = torch.tensor([[[[False, False], [True, True]]]], dtype=torch.bool)

    out_cpu = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    out_mps = F.scaled_dot_product_attention(
        q.to('mps'), k.to('mps'), v.to('mps'), attn_mask=mask.to('mps')
    )

    self._compare_tensors(out_mps.cpu(), out_cpu)
@parametrize("dtype", [torch.float16, torch.float32])
def test_sdpa_3d_input(self, dtype):
    """SDPA on 3-D ``(heads, seq, embed)`` inputs with an all-ones mask."""
    head_num, seq_len, embed_dim = 16, 16, 80
    qkv_shape = (head_num, seq_len, embed_dim)

    q = torch.randn(*qkv_shape, dtype=dtype)
    k = torch.randn(*qkv_shape, dtype=dtype)
    v = torch.randn(*qkv_shape, dtype=dtype)
    attention_mask = torch.ones(1, seq_len, seq_len, dtype=dtype)

    def attend(device):
        # The mask is passed positionally, exactly as the fourth argument.
        return F.scaled_dot_product_attention(
            q.to(device),
            k.to(device),
            v.to(device),
            attention_mask.to(device),
            dropout_p=0.0
        )

    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
        y = attend("mps")
        y_ref = attend("cpu")

        self._compare_tensors(y.cpu(), y_ref)
@parametrize("dtype", [torch.float16, torch.float32])
def test_sdpa_no_mask_5d(
    self,
    dtype: torch.dtype,
    B: int = 2,
    extra: int = 3,
    NH: int = 4,
    L: int = 10,
    HS: int = 16,
    requires_grad: bool = False
):
    """Unmasked SDPA on 5-D ``(B, extra, NH, L, HS)`` inputs, MPS vs. CPU.

    When ``requires_grad`` is set and grad mode is enabled, the gradient of
    the query is compared too.
    """
    torch.manual_seed(1729)
    q = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps", requires_grad=requires_grad)
    k = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")
    v = torch.randn(B, extra, NH, L, HS, dtype=dtype, device="mps")

    # Use a dedicated CPU leaf for the reference. ``q.cpu()`` returns a fresh
    # non-leaf tensor on each call, so its ``.grad`` is never populated and the
    # reference backward would accumulate into the MPS ``q`` instead.
    q_cpu = q.detach().cpu().requires_grad_(requires_grad)
    k_cpu = k.cpu()
    v_cpu = v.cpu()

    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=False)
        self._compare_tensors(y.cpu(), y_ref)

        if requires_grad and torch.is_grad_enabled():
            y.sum().backward()
            y_ref.sum().backward()
            self._compare_tensors(q.grad.cpu(), q_cpu.grad)
@parametrize('dtype', [torch.float16, torch.float32])
def test_sdpa_mask_5d(
    self,
    dtype: torch.dtype,
    B: int = 2,
    extra: int = 3,
    NH: int = 4,
    L: int = 10,
    HS: int = 16
):
    """Masked SDPA on 5-D ``(B, extra, NH, L, HS)`` inputs, MPS vs. CPU.

    The mask is an ``L x L`` lower-triangular boolean matrix with two
    leading singleton dimensions.
    """
    torch.manual_seed(1729)
    shape = (B, extra, NH, L, HS)
    q = torch.randn(*shape, dtype=dtype, device="mps")
    k = torch.randn(*shape, dtype=dtype, device="mps")
    v = torch.randn(*shape, dtype=dtype, device="mps")

    lower_tri = torch.tril(torch.ones(L, L, dtype=torch.bool, device="mps"))
    mask = lower_tri.unsqueeze(0).unsqueeze(0)

    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )
        y_ref = F.scaled_dot_product_attention(
            q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False
        )
        self._compare_tensors(y.cpu(), y_ref)
@parametrize("dtype", [torch.float16, torch.float32])
@parametrize("is_causal", [True, False])
def test_sdpa_enable_gqa(self, dtype, is_causal):
    """SDPA with ``enable_gqa=True``: 32 query heads against 16 key/value heads."""
    q_heads = 32
    key_heads = 16
    L = 7
    S = 17
    HS = 23

    q = torch.randn([2, q_heads, L, HS], dtype=dtype, device="mps")
    k = torch.randn([2, key_heads, S, HS], dtype=dtype, device="mps")
    v = torch.randn([2, key_heads, S, HS], dtype=dtype, device="mps")

    sdpa_kwargs = dict(dropout_p=0.0, is_causal=is_causal, enable_gqa=True)

    # The CPU reference is computed before entering the backend context.
    y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), **sdpa_kwargs)

    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
        y = F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
        self._compare_tensors(y.cpu(), y_ref)
@serialTest()
def test_sdpa_fp32_no_memory_leak(self):
    """Call float32 SDPA 100 times and compare the last memory reading to the first."""
    mib = 1024 * 1024

    def read_memory_mib():
        # (currently allocated, driver allocated), both converted to MiB.
        return (torch.mps.current_allocated_memory() / mib,
                torch.mps.driver_allocated_memory() / mib)

    batch_size, seq_len, num_heads, head_dim = 4, 128, 8, 64
    shape = (batch_size, num_heads, seq_len, head_dim)
    query = torch.randn(*shape, device="mps", dtype=torch.float32)
    key = torch.randn(*shape, device="mps", dtype=torch.float32)
    value = torch.randn(*shape, device="mps", dtype=torch.float32)

    memory_footprints = []
    for _ in range(100):
        # The result stays bound while memory is sampled, as in every iteration.
        output = F.scaled_dot_product_attention(query, key, value)
        memory_footprints.append(read_memory_mib())

    # 5 MB different maximum allowed value(could be decreased even more)
    # NOTE(review): with rtol=1 the accepted difference is atol + |expected|,
    # which is looser than the 5 MB stated above -- confirm this is intended.
    torch.testing.assert_close(memory_footprints[-1], memory_footprints[0], atol=5, rtol=1)
def generate_qkv(self, batch: int, NH: int, q_len: int, s_len: int, head_dim: int, layout: str, dtype: torch.dtype):
    """Create random q/k/v tensors on MPS with the requested q/k memory layout.

    ``q`` and ``k`` always have logical shape ``(batch, NH, len, head_dim)``;
    ``layout`` selects how that shape is obtained ("contiguous", "mT",
    "transpose_seq_head" or "permute"). ``v`` is always allocated directly
    in its final shape.

    Raises:
        ValueError: if ``layout`` is not one of the names above.
    """
    def rand(*shape):
        return torch.randn(*shape, dtype=dtype, device="mps")

    if layout == "contiguous":
        q = rand(batch, NH, q_len, head_dim)
        k = rand(batch, NH, s_len, head_dim)
    elif layout == "mT":
        # Transpose head dimension and length
        q = rand(batch, NH, head_dim, q_len).mT
        k = rand(batch, NH, head_dim, s_len).mT
    elif layout == "transpose_seq_head":
        # Transpose length and number of heads
        q = rand(batch, q_len, NH, head_dim).transpose(1, 2)
        k = rand(batch, s_len, NH, head_dim).transpose(1, 2)
    elif layout == "permute":
        # Permute head dimension and length
        q = rand(batch, head_dim, NH, q_len).permute(0, 2, 3, 1)
        k = rand(batch, head_dim, NH, s_len).permute(0, 2, 3, 1)
    else:
        raise ValueError(f"Unknown layout: {layout}")

    v = rand(batch, NH, s_len, head_dim)
    return q, k, v
def run_fast_attention_test(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    with_mask: bool,
    dropout_p: float = 0.0,
    is_causal: bool = False,
):
    """Run SDPA on the given tensors and on their CPU copies, then compare.

    With ``with_mask`` a boolean mask of shape
    ``(q.shape[0], q.shape[1], q_len, s_len)`` is used in which only the
    second half of the key positions is ``True``.
    """
    q_len = q.shape[2]
    s_len = k.shape[2]

    attn_mask = None
    if with_mask:
        attn_mask = torch.zeros(q.shape[0], q.shape[1], q_len, s_len,
                                dtype=torch.bool, device=q.device)
        attn_mask[..., s_len // 2:] = True

    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )
        y_ref = F.scaled_dot_product_attention(
            q.cpu(),
            k.cpu(),
            v.cpu(),
            attn_mask=attn_mask.cpu() if with_mask else None,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

    self._compare_tensors(y.cpu(), y_ref)
@parametrize("dtype", [torch.float16, torch.float32])
@parametrize("layout", ["contiguous", "mT", "transpose_seq_head", "permute"])
@parametrize("head_dim", [64, 96, 128])  # 64, 96, 128 are for the fast kernel
@parametrize("with_mask", [True, False])
def test_fast_vector_attention(self, dtype: torch.dtype, layout: str, head_dim: int, with_mask: bool):
    """SDPA with a short query (4) and short key sequence (16) for every layout."""
    torch.manual_seed(1729)
    q, k, v = self.generate_qkv(
        batch=1,
        NH=2,
        q_len=4,  # <8 so that vector fast is eligible
        s_len=16,  # smaller than 1024 so that we use the onepass variant
        head_dim=head_dim,
        layout=layout,
        dtype=dtype,
    )
    self.run_fast_attention_test(q, k, v, with_mask)
@parametrize("dtype", [torch.float32])  # float16 underflows sometimes, which leads to flaky tests
@parametrize("layout", ["contiguous", "mT", "transpose_seq_head", "permute"])
@parametrize("with_mask", [True, False])
def test_fast_vector_attention_2pass(self, dtype: torch.dtype, layout: str, with_mask: bool):
    """SDPA with query length 8 and a key sequence of 1024 for every layout."""
    torch.manual_seed(1729)
    q, k, v = self.generate_qkv(
        batch=1,
        NH=32,
        q_len=8,
        s_len=1024,  # large enough to trigger the twopass path
        head_dim=64,  # supported head dimension for vector attention
        layout=layout,
        dtype=dtype,
    )
    self.run_fast_attention_test(q, k, v, with_mask)
@unittest.skip("Full attention fast kernel not implemented yet")
@parametrize("dtype", [torch.float16, torch.float32])
@parametrize("layout", ["contiguous", "mT"])
@parametrize("head_dim", [64, 80, 128])  # 64, 80, 128 are for the fast kernel
@parametrize("with_mask", [True, False])
def test_fast_full_attention(self, dtype: torch.dtype, layout: str, head_dim: int, with_mask: bool):
    """SDPA with query length 32 and key length 16 (currently skipped)."""
    torch.manual_seed(1729)
    q, k, v = self.generate_qkv(
        batch=1,
        NH=2,
        q_len=32,  # threshold to trigger full fast attention path
        s_len=16,
        head_dim=head_dim,
        layout=layout,
        dtype=dtype,
    )
    self.run_fast_attention_test(q, k, v, with_mask)
class TestSDPAMetaDispatchMode(TorchDispatchMode):
    """
    Dispatch mode that re-runs every
    ``_scaled_dot_product_attention_math_for_mps`` call on meta tensors and
    asserts, through the owning test case, that the shapes, strides and
    dtypes of the outputs match those of the real call.
    """

    def __init__(self, test):
        # ``test`` is the test-case instance whose ``assertEqual`` is used.
        self.test = test
        super().__init__()

    def __torch_dispatch__(self, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        real_out = func(*args, **kwargs)

        # Every other operator passes through untouched.
        if func != torch.ops.aten._scaled_dot_product_attention_math_for_mps.default:
            return real_out

        meta_args, meta_kwargs = pytree.tree_map_only(
            torch.Tensor, lambda t: t.to(device="meta"), (args, kwargs)
        )
        meta_out = func(*meta_args, **meta_kwargs)

        def metadata_of(out):
            # Reduce each tensor leaf to (shape, stride, dtype); keep other leaves as is.
            leaves = pytree.tree_flatten(out)[0]
            summary = []
            for leaf in leaves:
                if isinstance(leaf, torch.Tensor):
                    summary.append((leaf.shape, leaf.stride(), leaf.dtype))
                else:
                    summary.append(leaf)
            return summary

        # Format the output so that we only look at the tensor metadata
        self.test.assertEqual(metadata_of(real_out), metadata_of(meta_out))
        return real_out
def create_sdpa_meta_test():
    """
    Creates a new class which takes every test in TestSDPA and adds the
    TestSDPAMetaDispatchMode context in order to test the
    scaled_dot_product_attention_for_mps meta kernel. This allows us to test all
    the branches for the sdpa op. If there are changes to the sdpa kernel
    without changing the meta kernel, a torch.compile guard will catch the issue
    but not necessarily export.

    Every callable ``test_*`` attribute of TestSDPA is exposed on the new class
    as ``<name>_meta``; all other attributes are copied over unchanged.
    """
    # Local import: the module only imports ``partial`` from functools.
    import functools

    orig_test_cls = TestSDPA

    new_test_cls = type(f"{orig_test_cls.__name__}Meta", orig_test_cls.__bases__, {})
    new_test_cls.__qualname__ = new_test_cls.__name__

    def wrap_in_meta_mode(fn, new_name):
        # Building the wrapper in its own function gives each wrapper a private
        # ``fn``. Defining it directly inside the loop would make every wrapper
        # read the loop variable at call time, i.e. run the last test found.
        #
        # ``functools.wraps`` also copies the attributes that decorators attached
        # to the original function (its ``__dict__``), so markers such as skip
        # flags survive the wrapping.
        # NOTE(review): presumably this is also what lets the later
        # ``instantiate_parametrized_tests`` call expand the wrapped tests --
        # confirm against common_utils.
        @functools.wraps(fn)
        def new_fn(self, *args, **kwargs):
            with TestSDPAMetaDispatchMode(self):
                fn(self, *args, **kwargs)

        new_fn.__name__ = new_name
        return new_fn

    for name in dir(orig_test_cls):
        if name.startswith("test_"):
            fn = getattr(orig_test_cls, name)
            if not callable(fn):
                setattr(new_test_cls, name, getattr(orig_test_cls, name))
                continue

            new_name = f"{name}_meta"
            setattr(new_test_cls, new_name, wrap_in_meta_mode(fn, new_name))

        elif not hasattr(new_test_cls, name):
            setattr(new_test_cls, name, getattr(orig_test_cls, name))

    return new_test_cls
# Build the meta-kernel variant of TestSDPA at import time and run the
# parametrized-test instantiation over the generated class.
TestSDPAMeta = create_sdpa_meta_test()
instantiate_parametrized_tests(TestSDPAMeta)
class TestGatherScatter(TestCaseMPS):
    """Slicing, casting and in-place column updates compared between MPS and CPU."""

    def test_slicing_with_step(self):
        """Strided slice assignment (https://github.com/pytorch/pytorch/issues/78886)."""
        # Slicing with step
        results = {}
        for dev in ("mps", "cpu"):
            x = torch.zeros(10, dtype=torch.float32, device=dev)
            x[::2] = 1.0
            results[dev] = x

        self.assertEqual(results["cpu"], results["mps"])

    def test_cast_gather_scatter(self):
        """uint8 -> long -> float -> scaled float must agree after each step."""
        for _ in range(50):
            input = np.random.randint(0, 255, size=(5, 5, 4), dtype=np.uint8)
            with torch.no_grad():
                s = torch.tensor(input, dtype=torch.uint8, device="mps").unsqueeze(0)
                s_cpu = torch.tensor(input, dtype=torch.uint8, device="cpu").unsqueeze(0)

                s, s_cpu = s.long(), s_cpu.long()
                self.assertEqual(s.cpu(), s_cpu)

                s, s_cpu = s.float(), s_cpu.float()
                self.assertEqual(s.cpu(), s_cpu)

                s /= 255
                s_cpu /= 255
                self.assertEqual(s.cpu(), s_cpu)

    def test_slicing_replace_column(self):
        """Overwrite the first column (https://github.com/pytorch/pytorch/issues/78074)."""
        def _helper(tensor_data):
            x_cpu = torch.tensor(tensor_data)
            x_mps = x_cpu.to('mps')

            x_cpu[:, 0] = 7
            x_mps[:, 0] = 7

            self.assertEqual(x_cpu, x_mps)

        # Matrices with 2, 3 and 4 rows of three columns each.
        cases = (
            [[1, 2, 3], [4, 5, 6]],
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
        )
        for data in cases:
            _helper(data)

    def test_inplace_scatter(self):
        """In-place column add (https://github.com/pytorch/pytorch/issues/79672)."""
        mps_dev = torch.device("mps")
        cpu_dev = torch.device("cpu")
        a_mps = torch.ones((2, 2),).to(mps_dev)
        b_mps = torch.ones((2, 2),).to(mps_dev)
        a_cpu = torch.ones((2, 2),).to(cpu_dev)
        b_cpu = torch.ones((2, 2),).to(cpu_dev)

        # Augmented assignment on a column slice.
        a_mps[:, 0] += b_mps[:, 0]
        a_cpu[:, 0] += b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)

        # Same update spelled as an explicit read-add-write.
        a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0]
        a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
        self.assertEqual(a_cpu, a_mps)
# These tests were taken from test/test_view_ops.py.
# They are a subset of those tests, as currently only this subset is working.
# This whole `class` will be removed when we add generic device testing. No
# tests have been added beyond what is already part of test_view_ops.py.
class TestViewOpsMPS(TestCaseMPS):
exact_dtype = True
def test_permute_slicing(self):
    """Permute + scale, then ``zeros_like`` and ``fill_`` on a slice, MPS vs. CPU."""
    # test the fix for crash reported in
    # https://github.com/pytorch/pytorch/issues/94190
    cpu_x = torch.randn([3, 2, 2]).float()
    mps_x = cpu_x.detach().clone().to('mps')

    order = (2, 0, 1)
    cpu_out = cpu_x.permute(order) * 2.0
    mps_out = mps_x.permute(order) * 2.0

    # this print caused a crash prior to fix PR#94259
    print(torch.zeros_like(mps_out))

    # test the fix for fill_scalar_mps() mentioned in issue #94190
    self.assertEqual(torch.zeros_like(cpu_out), torch.zeros_like(mps_out))
    self.assertEqual(cpu_x[:, 1, :].fill_(1), mps_x[:, 1, :].fill_(1))
def is_view_of(self, base, other):
    """Return True when ``other`` is a view whose ``_base`` is exactly ``base``.

    Both tensors must live on the same device. On MPS the data pointers of
    the two untyped storages must match as well.
    """
    if not other._is_view():
        return False
    if other is base or other._base is not base:
        return False
    if base.device != other.device:
        return False

    # Note: only validates storage on native device types
    # because some accelerators, like XLA, do not expose storage
    if base.device.type == 'mps':
        return base.untyped_storage().data_ptr() == other.untyped_storage().data_ptr()

    return True
# Returns true if v1 and v2 are views of the same base
def is_view_of_same_base(self, v1, v2):
    """Return True when ``v1`` is a view and ``v2`` is a view of ``v1``'s base."""
    comparable = v1._is_view() and v1 is not v2
    if not comparable:
        return False
    return self.is_view_of(v1._base, v2)
# Performs transpose if contiguous=True, else returns the input tensor as is
def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
if contiguous:
return x
else:
return x.transpose(dim0, dim1)
def test_diagonal_view(self, device="mps"):
    """``torch.diagonal`` must return a view that writes through to its base.

    Covers the main diagonal of a 2-D tensor and an offset diagonal over
    dims 1 and 2 of a 3-D tensor.
    """
    t = torch.ones((5, 5), device=device)
    v = torch.diagonal(t)
    self.assertTrue(self.is_view_of(t, v))

    v[0] = 0
    self.assertEqual(t[0, 0], v[0])

    # Honour the ``device`` argument here too instead of hard-coding "mps".
    t = torch.ones((3, 3, 3), device=device)
    v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
    self.assertTrue(self.is_view_of(t, v))

    v[0, 0] = 0
    self.assertEqual(t[0, 0, 1], v[0, 0])
def test_select_view(self, device="mps") -> None:
    """``select`` must return a view that writes through to its base."""
    base = torch.ones((5, 5), device=device)
    row = base.select(0, 2)
    self.assertTrue(self.is_view_of(base, row))

    row[0] = 0
    self.assertEqual(base[2, 0], row[0])
def test_unbind_view(self, device="mps") -> None:
    """Every tensor returned by ``torch.unbind`` must be a view of the input."""
    base = torch.zeros((5, 5), device=device)
    for row_idx, row in enumerate(torch.unbind(base)):
        self.assertTrue(self.is_view_of(base, row))

        # Write a distinct value per row so mix-ups between rows would show.
        row[0] = row_idx + 1
        self.assertEqual(base[row_idx, 0], row[0])
def test_expand_view(self, device="mps") -> None:
    """``expand`` must return a view that writes through to its base."""
    base = torch.ones((5, 1), device=device)
    expanded = base.expand(5, 5)
    self.assertTrue(self.is_view_of(base, expanded))

    expanded[2, 2] = 0
    self.assertEqual(base[2, 0], expanded[2, 2])
def test_expand_as_view(self, device="mps"):
    """``expand_as`` must return a view that writes through to its base."""
    base = torch.ones((5, 1), device=device)
    template = torch.empty((5, 5), device=device)
    expanded = base.expand_as(template)
    self.assertTrue(self.is_view_of(base, expanded))

    expanded[2, 2] = 0
    self.assertEqual(base[2, 0], expanded[2, 2])
def test_narrow_view(self, device="mps"):
    """``torch.narrow`` must return a view that writes through to its base."""
    base = torch.ones((5, 5), device=device)
    window = torch.narrow(base, 1, 2, 2)
    self.assertTrue(self.is_view_of(base, window))

    window[0, 0] = 0
    self.assertEqual(base[0, 2], window[0, 0])
def test_permute_view(self, device="mps") -> None:
    """``permute`` must return a view that writes through to its base."""
    base = torch.ones((5, 5), device=device)
    permuted = base.permute(1, 0)
    self.assertTrue(self.is_view_of(base, permuted))

    permuted[0, 1] = 0
    self.assertEqual(base[1, 0], permuted[0, 1])
def test_transpose_view(self, device="mps"):
    """``swapdims``, ``swapaxes`` and ``transpose`` must all return write-through views."""
    transposers = (torch.swapdims, torch.swapaxes, torch.transpose)
    for transposer in transposers:
        base = torch.ones((5, 5), device=device)
        swapped = transposer(base, 0, 1)
        self.assertTrue(self.is_view_of(base, swapped))

        swapped[0, 1] = 0
        self.assertEqual(base[1, 0], swapped[0, 1])
def test_transpose_inplace_view(self, device="mps"):
    """In-place ``swapdims_``, ``swapaxes_`` and ``transpose_`` keep the view relation."""
    for inplace_op in ("swapdims_", "swapaxes_", "transpose_"):
        base = torch.ones(5, 5, device=device)
        # Start from a view of the base and transpose that view in place.
        view = base.view_as(base)
        view = getattr(view, inplace_op)(0, 1)
        self.assertTrue(self.is_view_of(base, view))

        view[0, 1] = 0
        self.assertEqual(base[1, 0], view[0, 1])
def test_t_view(self, device="mps"):
    """``Tensor.t`` must return a view that writes through to its base."""
    base = torch.ones((5, 5), device=device)
    transposed = base.t()
    self.assertTrue(self.is_view_of(base, transposed))

    transposed[0, 1] = 0
    self.assertEqual(base[1, 0], transposed[0, 1])
def test_inplace_view_add(self):
    """Add to a reshaped row slice (https://github.com/pytorch/pytorch/issues/96153)."""
    def second_row_as_2x3(dev):
        return torch.ones((2, 6,), device=dev)[1].reshape(2, 3)

    t_mps = second_row_as_2x3('mps')
    t_cpu = second_row_as_2x3('cpu')
    t_mps = t_mps + 1
    t_cpu = t_cpu + 1
    self.assertEqual(t_mps, t_cpu)
def test_t_inplace_view(self, device="mps"):
    """In-place ``t_`` on a view keeps it a view of the original base."""
    base = torch.ones(5, 5, device=device)
    view = base.view_as(base)
    view = view.t_()
    self.assertTrue(self.is_view_of(base, view))

    view[0, 1] = 0
    self.assertEqual(base[1, 0], view[0, 1])
def test_T_view(self, device="mps"):
    """The ``T``, ``H``, ``mT`` and ``mH`` attributes must be write-through views."""
    for attr_name in ("T", "H", "mT", "mH"):
        base = torch.ones((5, 5), device=device)
        view = getattr(base, attr_name)
        self.assertTrue(self.is_view_of(base, view))

        view[0, 1] = 0
        self.assertEqual(base[1, 0], view[0, 1])
def test_unfold_view(self, device="mps"):
    """``unfold`` must return a view that writes through to its base."""
    base = torch.ones(10, device=device)
    windows = base.unfold(0, 3, 2)
    self.assertTrue(self.is_view_of(base, windows))

    # With size 3 and step 2, window 1 starts at element 2 of the base.
    windows[1, 0] = 0
    self.assertEqual(base[2], windows[1, 0])
def test_squeeze_view(self, device="mps"):
    """``torch.squeeze`` must return a view whose ``_base`` is the input."""
    base = torch.ones(5, 1, 5, device=device)
    squeezed = torch.squeeze(base)
    self.assertTrue(self.is_view_of(base, squeezed))

    squeezed[0, 1] = 0
    self.assertIs(base, squeezed._base)
def test_squeeze_inplace_view(self, device="mps"):
    """In-place ``squeeze_`` on a view keeps the original tensor as ``_base``."""
    base = torch.ones(5, 5, device=device)
    view = base.view_as(base)
    view = view.squeeze_()
    self.assertTrue(self.is_view_of(base, view))

    view[0, 1] = 0
    self.assertIs(base, view._base)
def test_unsqueeze_view(self, device="mps"):
    """``torch.unsqueeze`` must return a view that writes through to its base."""
    base = torch.ones(5, 5, device=device)
    unsqueezed = torch.unsqueeze(base, 1)
    self.assertTrue(self.is_view_of(base, unsqueezed))

    unsqueezed[0, 0, 1] = 0
    self.assertEqual(base[0, 1], unsqueezed[0, 0, 1])
def test_unsqueeze_inplace_view(self, device="mps"):
    """In-place ``unsqueeze_`` on a view keeps it a view of the original base."""
    base = torch.ones(5, 5, device=device)
    view = base.view_as(base)
    view = view.unsqueeze_(1)
    self.assertTrue(self.is_view_of(base, view))

    view[0, 0, 1] = 0
    self.assertEqual(base[0, 1], view[0, 0, 1])
def test_as_strided_view(self, device="mps"):
    """``torch.as_strided`` must return a view that writes through to its base."""
    base = torch.ones(5, 5, device=device)
    flat = torch.as_strided(base, (25,), (1,))
    self.assertTrue(self.is_view_of(base, flat))

    # Flat index 6 corresponds to element [1, 1] of the 5x5 base.
    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
def test_as_strided_inplace_view(self, device="mps"):
    """In-place ``as_strided_`` on a view keeps it a view of the original base."""
    base = torch.ones(5, 5, device=device)
    view = base.view_as(base)
    view = view.as_strided_((25,), (1,))
    self.assertTrue(self.is_view_of(base, view))

    # Flat index 6 corresponds to element [1, 1] of the 5x5 base.
    view[6] = 0
    self.assertEqual(base[1, 1], view[6])
def test_view_view(self, device="mps"):
    """``Tensor.view`` must return a view that writes through to its base."""
    base = torch.ones(5, 5, device=device)
    flat = base.view(25)
    self.assertTrue(self.is_view_of(base, flat))

    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
def test_view_as_view(self, device="mps"):
    """``view_as`` must return a view that writes through to its base."""
    base = torch.ones(5, 5, device=device)
    # The template is created without a device argument.
    template = torch.empty((25,))
    flat = base.view_as(template)
    self.assertTrue(self.is_view_of(base, flat))

    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
def test_contiguous_self(self, device="mps"):
    """``contiguous`` on an already contiguous tensor returns the same object."""
    original = torch.ones(5, 5, device=device)
    self.assertIs(original.contiguous(), original)
def test_contiguous_nonview(self, device="mps"):
    """``contiguous`` on a transposed tensor must produce an independent copy."""
    base = torch.ones(5, 5, device=device)
    copy = base.t().contiguous()
    self.assertFalse(self.is_view_of(base, copy))

    # Writing to the copy must leave the base untouched.
    copy[0, 0] = 0
    self.assertNotEqual(base[0, 0], copy[0, 0])
def test_reshape_view(self, device="mps"):
    """``torch.reshape`` of a contiguous tensor must return a write-through view."""
    base = torch.ones(5, 5, device=device)
    flat = torch.reshape(base, (25,))
    self.assertTrue(self.is_view_of(base, flat))

    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
def test_reshape_as_view(self, device="mps"):
    """``reshape_as`` of a contiguous tensor must return a write-through view."""
    base = torch.ones(5, 5, device=device)
    template = torch.empty((25,), device=device)
    flat = base.reshape_as(template)
    self.assertTrue(self.is_view_of(base, flat))

    flat[6] = 0
    self.assertEqual(base[1, 1], flat[6])
def test_reshape_nonview(self, device="mps"):
    """Reshaping a transposed tensor to 1-D must produce an independent copy."""
    base = torch.ones(5, 5, device=device)
    copy = torch.reshape(base.t(), (25,))
    self.assertFalse(self.is_view_of(base, copy))

    # Writing to the copy must leave the base untouched.
    copy[6] = 0
    self.assertNotEqual(base[1, 1], copy[6])
def test_flatten_view(self, device="mps"):
    """``flatten`` must return write-through views whenever it can avoid a copy."""
    def check_write_through(base, view):
        # Zero the first element through the view and read it back from the base.
        first_in_base = (0,) * base.ndim
        first_in_view = (0,) * view.ndim
        view[first_in_view] = 0
        self.assertEqual(base[first_in_base], view[first_in_view])

    t = torch.ones(1, 2, 3, 4, device=device)
    v = t.flatten()
    self.assertTrue(self.is_view_of(t, v))
    check_write_through(t, v)

    # zero-dimensional tensor
    t = torch.tensor(1, device=device)
    v = t.flatten()
    check_write_through(t, v)
    self.assertTrue(self.is_view_of(t, v))

    t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
    v = t.flatten(0, 1)
    check_write_through(t, v)
    self.assertTrue(self.is_view_of_same_base(t, v))

    # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
    t = torch.ones(720, device=device) \
        .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
    #               [--1--|---2---|-3-] [--1--|----2---|-3-]
    v1 = t.flatten(0, 1)
    v2 = v1.flatten(1, 3)
    v3 = v2.flatten(2, 2)

    for flattened in (v1, v2, v3):
        check_write_through(t, flattened)
        self.assertTrue(self.is_view_of_same_base(t, flattened))
def test_flatten_nonview(self, device="mps"):
    """``flatten`` must copy when the requested dims cannot be merged as a view."""
    def assert_is_nonview(base, copy):
        first_in_base = (0,) * base.ndim
        first_in_copy = (0,) * copy.ndim
        self.assertFalse(copy._is_view())
        # Writing to the copy must leave the base untouched.
        copy[first_in_copy] = 0
        self.assertNotEqual(base[first_in_base], copy[first_in_copy])

    t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
    assert_is_nonview(t, t.flatten(1, 3))

    t = torch.ones(2, 2, device=device).T
    assert_is_nonview(t, t.flatten())

    # flatten returns the original object if start_dim=end_dim
    t = torch.ones(2, 2, device=device)
    self.assertIs(t, t.flatten(1, 1))
def test_basic_indexing_slice_view(self, device="mps"):
    """Indexing with plain slices must return a write-through view."""
    base = torch.ones(5, 5, device=device)
    corner = base[:2, :3]
    self.assertTrue(self.is_view_of(base, corner))

    corner[0, 0] = 0
    self.assertEqual(base[0, 0], corner[0, 0])
def test_basic_indexing_ellipses_view(self, device="mps"):
    """Indexing with an ellipsis plus a slice must return a write-through view."""
    base = torch.ones(5, 5, device=device)
    first_cols = base[..., :2]
    self.assertTrue(self.is_view_of(base, first_cols))

    first_cols[0, 0] = 0
    self.assertEqual(base[0, 0], first_cols[0, 0])
def test_basic_indexing_newaxis_view(self, device="mps"):
    """Indexing with ``None``, a slice and an integer must return a write-through view."""
    base = torch.ones(5, 5, device=device)
    picked = base[None, :2, 3]
    self.assertTrue(self.is_view_of(base, picked))

    picked[0, 0] = 0
    self.assertEqual(base[0, 3], picked[0, 0])
def test_chunk_view(self, device="mps"):
    """Every chunk returned by ``torch.chunk`` must be a view of the input."""
    base = torch.zeros(3, 3, device=device)
    for chunk_idx, chunk in enumerate(torch.chunk(base, 3)):
        self.assertTrue(self.is_view_of(base, chunk))

        # Write a distinct value per chunk so mix-ups between chunks would show.
        chunk[0, 0] = chunk_idx + 1
        self.assertEqual(base[chunk_idx, 0], chunk[0, 0])
def test_split_view(self, device="mps"):
    """Every piece returned by ``torch.split`` must be a view of the input."""
    base = torch.zeros(3, 3, device=device)
    for piece_idx, piece in enumerate(torch.split(base, [1, 1, 1])):
        self.assertTrue(self.is_view_of(base, piece))

        # Write a distinct value per piece so mix-ups between pieces would show.
        piece[0, 0] = piece_idx + 1
        self.assertEqual(base[piece_idx, 0], piece[0, 0])
def test_movedim_view(self, device="mps"):
    """``movedim`` and ``moveaxis`` must return write-through views."""
    def run_test(device, op):
        base = torch.zeros(3, 3, device=device)
        moved = op(base)

        self.assertTrue(self.is_view_of(base, moved))

        # Randomly change values in output
        # and verify that original is changed
        # as well.
        for _ in range(3):
            idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2)
            moved[idx_1, idx_2] = random.random()
            self.assertEqual(base[idx_2, idx_1], moved[idx_1, idx_2])

    # Tuple form first, then scalar form, for each function.
    argument_sets = (
        dict(source=(0, 1), destination=(1, 0)),
        dict(source=0, destination=1),
    )
    for fn in [torch.movedim, torch.moveaxis]:
        for dims in argument_sets:
            run_test(device, partial(fn, **dims))
# Testing that the generated view_copy kernel and its derivative are implemented correctly
def test_view_copy(self, device="mps"):
    """``torch.view_copy`` must match ``view`` in value and gradient without being a view."""
    a = torch.randn(4, device=device, requires_grad=True)
    a_ref = a.detach().clone().requires_grad_()

    via_view = a_ref.view(2, 2)
    via_copy = torch.view_copy(a, (2, 2))

    # view_copy ops don't preserve view relationship
    self.assertTrue(self.is_view_of(a_ref, via_view))
    self.assertFalse(self.is_view_of(a, via_copy))

    via_copy.sum().backward()
    via_view.sum().backward()

    # forward and backward give the same shape + result
    self.assertEqual(via_copy, via_view)
    self.assertEqual(a.grad, a_ref.grad)
def test_view_copy_out(self, device="mps"):
    """``out=`` variants of ``diagonal_copy`` and ``split_copy`` must match the functional forms."""
    # Single-output case.
    matrix = torch.randn(2, 2, device=device)
    diag_out = torch.empty(2, device=device)

    torch.diagonal_copy(matrix, out=diag_out)
    self.assertEqual(torch.diagonal_copy(matrix), diag_out)

    # Multi-output case: ``out`` is a tuple of tensors.
    vector = torch.randn(4, device=device)
    out1 = torch.empty(2, device=device)
    out2 = torch.empty(2, device=device)

    torch.split_copy(vector, 2, out=(out1, out2))
    expected1, expected2 = torch.split_copy(vector, 2)

    self.assertEqual(expected1, out1)
    self.assertEqual(expected2, out2)
def test_detached_view_copy(self, device="mps"):
    """Copy a detached element slice to the device (https://github.com/pytorch/pytorch/issues/86052)."""
    source = torch.arange(2)
    # .detach() makes y not a view, but contig tensor
    # with non-zero offset
    detached = source[1].detach()
    on_device = detached.to(device)
    self.assertEqual(detached, on_device.cpu())
def test_empty_reshape(self, device="mps"):
    """Reshaping a zero-element tensor keeps the data pointer and rejects ``-1`` inference."""
    x = torch.randn(0, 6, device=device)
    target = (1, 0, 6, 1, 1)
    self.assertEqual(target, x.reshape(*target).shape)
    # should be viewable -- i.e. data_ptr is the same.
    self.assertEqual(x.data_ptr(), x.reshape(*target).data_ptr())

    # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
    self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))
def test_expand(self, device="mps"):
    """Sizes and values produced by ``expand`` and ``expand_as`` for several inputs."""
    tensor = torch.rand(1, 8, 1, device=device)
    tensor2 = torch.rand(5, device=device)
    template = torch.rand(4, 8, 5, device=device)
    target = template.size()

    # Each source must reach the target size through all three call forms.
    for source in (tensor, tensor2):
        self.assertEqual(source.expand_as(template).size(), target)
        self.assertEqual(source.expand(4, 8, 5).size(), target)
        self.assertEqual(source.expand(target).size(), target)

    # test double expand
    self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))

    # test non-contiguous
    noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
    self.assertFalse(noncontig.is_contiguous())
    self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))

    # make sure it's compatible with unsqueeze
    expanded = tensor2.expand(1, 1, 5)
    unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
    self.assertEqual(expanded, unsqueezed)
    self.assertEqual(expanded.stride(), unsqueezed.stride())

    # test -1 as target size
    self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
    self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))

    # test expanding empty to empty
    empty = torch.zeros(0, device=device)
    self.assertEqual(empty.expand((0,)), torch.zeros(0, device=device))
def test_view_empty(self, device="mps"):
    """``view`` on a zero-element tensor yields the requested shape."""
    empty = torch.randn(0, 6, device=device)
    viewed = empty.view(1, 0, 6, 1, 1)
    self.assertEqual((1, 0, 6, 1, 1), viewed.shape)
def test_reshape(self, device="mps"):
    """``reshape`` / ``reshape_as``: data-pointer sharing, copies, scalars, empties and errors."""
    # Contiguous input: reshape shares the data pointer.
    x = torch.randn(3, 3, device=device)
    self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
    self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
    self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
    self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))

    # Non-contiguous slice: flattening gets a different data pointer.
    y = torch.randn(4, 4, 4, device=device)[:, 0, :]
    # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
    if device != "meta":
        self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
    self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
    self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())

    # Zero-dimensional input.
    s = torch.randn((), device=device)
    self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
    self.assertEqual(s.reshape(-1).shape, (1,))
    self.assertRaises(RuntimeError, lambda: s.reshape(2))

    # Empty input.
    empty = torch.tensor([], device=device)
    self.assertEqual(empty, empty.reshape(-1))
    self.assertEqual(empty, empty.reshape([0]))
    # TODO: fix these once we have multi-dimensional empty tensors
    self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
    self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
    self.assertRaises(RuntimeError, lambda: empty.reshape(1))

    # reshape_as; the first two templates are created without a device argument.
    x = torch.randn(3, 3, device=device)
    self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
    self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
    self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device)))
def test_narrow(self, device="mps"):
    """``narrow`` with positive and negative ``dim`` / ``start`` on a 3x3 matrix.

    NOTE(review): ``device`` is accepted but unused; the tensors are created
    without a device argument.
    """
    x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])

    # (dim, start, length) -> expected rows/columns
    cases = (
        ((0, 0, 1), [[0, 1, 2]]),
        ((0, 0, 2), [[0, 1, 2], [3, 4, 5]]),
        ((0, 1, 1), [[3, 4, 5]]),
        ((0, -1, 1), [[6, 7, 8]]),
        ((0, -2, 2), [[3, 4, 5], [6, 7, 8]]),
        ((0, -3, 3), [[0, 1, 2], [3, 4, 5], [6, 7, 8]]),
        ((-1, -1, 1), [[2], [5], [8]]),
        ((-2, -1, 1), [[6, 7, 8]]),
    )
    for narrow_args, expected in cases:
        self.assertEqual(x.narrow(*narrow_args), torch.tensor(expected))
def test_narrow_tensor(self, device="mps"):
    """Check ``Tensor.narrow`` when ``start`` is given as a tensor.

    A 0-dim integer tensor is accepted; a floating-point scalar tensor or a
    tensor with one or more dimensions must raise. The input is created on
    ``device`` so the op runs on that backend.
    """
    x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], device=device)
    self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
    # Invalid `start` tensors: float scalar, 1-element vector, 2-element vector
    invalid_starts = (torch.tensor(0.), torch.tensor([0]), torch.tensor([0, 1]))
    for start in invalid_starts:
        with self.assertRaises(Exception):
            x.narrow(0, start, 1)
def test_t(self, device="mps"):
    """Check ``Tensor.t()`` for dense and sparse tensors of rank 0 through 3.

    All tensors here are created without a ``device`` argument; the
    ``device`` parameter is accepted but not used.
    """
    # 0D and 1D tensors: t() leaves the tensor unchanged
    for dense in (torch.randn(()), torch.arange(4)):
        self.assertEqual(dense, dense.t())
        sparse = dense.to_sparse()
        self.assertEqual(sparse, sparse.t())

    # 2D tensors: t() is the same as swapping the two dimensions
    matrix = torch.rand((2, 2))
    self.assertEqual(matrix.t(), matrix.transpose(0, 1))
    sparse_matrix = matrix.to_sparse()
    self.assertEqual(sparse_matrix.t(), sparse_matrix.transpose(0, 1))

    # 3D tensors: t() is rejected, with a different message for sparse input
    cube = torch.rand((2, 2, 2))
    with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
        cube.t()
    sparse_cube = cube.to_sparse()
    with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
        sparse_cube.t()
def test_split(self, device="mps"):
    """Check ``Tensor.split`` with a fixed split size and with explicit section sizes.

    Each returned piece must have the expected size and must equal the
    matching ``narrow`` slice of the source tensor. Inputs are created on
    ``device`` so the ops run on that backend.
    """
    def check(tensor, split_arg, dim, target_sizes):
        pieces = tensor.split(split_arg, dim)
        # zip() below would silently ignore missing or extra pieces
        self.assertEqual(len(pieces), len(target_sizes))
        start = 0
        for target_size, piece in zip(target_sizes, pieces):
            self.assertEqual(piece.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), piece, atol=0, rtol=0)
            start += target_size[dim]

    # Fixed split size; the last piece holds the remainder
    tensor = torch.rand(7, 4, device=device)
    check(tensor, 3, 0, ([3, 4], [3, 4], [1, 4]))

    # Variable sections split
    tensor = torch.randn(20, 10, device=device)
    check(tensor, [5, 5, 10], 0, ([5, 10], [5, 10], [10, 10]))
    check(tensor, [2, 2, 6], 1, ([20, 2], [20, 2], [20, 6]))
def test_chunk(self, device="mps"):
    """Check ``Tensor.chunk`` piece sizes/contents and its rejection of non-positive counts.

    The input is created on ``device`` so the ops run on that backend.
    """
    tensor = torch.rand(4, 7, device=device)
    num_chunks = 3
    dim = 1
    target_sizes = ([4, 3], [4, 3], [4, 1])
    pieces = tensor.chunk(num_chunks, dim)
    # zip() below would silently ignore missing or extra pieces
    self.assertEqual(len(pieces), len(target_sizes))
    start = 0
    for target_size, piece in zip(target_sizes, pieces):
        self.assertEqual(piece.size(), target_size)
        self.assertEqual(tensor.narrow(dim, start, target_size[dim]), piece,
                         atol=0, rtol=0)
        start += target_size[dim]

    # Invalid chunk sizes
    error_regex = 'chunk expects.*greater than 0'
    for bad_count in (0, -2):
        with self.assertRaisesRegex(RuntimeError, error_regex):
            tensor.chunk(bad_count)
def test_unsqueeze(self, device="mps") -> None:
    """Check ``unsqueeze`` / ``unsqueeze_`` on contiguous and non-contiguous tensors.

    The input is created on ``device`` so the ops run on that backend.
    """
    x = torch.randn(2, 3, 4, device=device)
    y = x.unsqueeze(1)
    self.assertEqual(y, x.view(2, 1, 3, 4))
    y = x.clone().unsqueeze_(2)
    self.assertEqual(y, x.view(2, 3, 1, 4))

    # Selecting along dim 1 yields a non-contiguous (2, 4) view
    x = x[:, 1]
    self.assertFalse(x.is_contiguous())
    y = x.unsqueeze(1)
    self.assertEqual(y, x.contiguous().view(2, 1, 4))
    y = x.clone().unsqueeze_(2)
    self.assertEqual(y, x.contiguous().view(2, 4, 1))
# unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
def test_big_transpose(self, device="mps"):
    """Compare a materialized transpose of a 456x789 tensor with NumPy's transpose."""
    source = torch.rand(456, 789, device=device)
    materialized = source.t().contiguous()
    reference = torch.from_numpy(source.cpu().numpy().transpose())
    self.assertEqual(materialized, reference)
def test_T(self, device="mps"):
    """Check that ``.T`` reverses all dimensions and leaves a 1-D tensor unchanged."""
    cube = torch.randn(2, 3, 4, device=device)
    reversed_dims = cube.permute(2, 1, 0)
    self.assertEqual(reversed_dims, cube.T)

    vector = torch.randn(10, device=device)
    self.assertEqual(vector, vector.T)
def test_transposes(self, device="mps", dtype=torch.float32):
    """Compare ``T``, ``H``, ``mT``, ``mH`` and ``adjoint()`` with an explicit transpose (+ conj)."""
    for attr in ("T", "H", "mT", "mH", "adjoint"):
        # The matrix variants (mT, mH, adjoint) are also exercised on a 3-D shape
        if attr[0] == "m" or attr == "adjoint":
            shapes = ((2, 3), (2, 3, 4))
        else:
            shapes = ((2, 3),)
        conjugating = attr[-1] == "H" or attr == "adjoint"
        for shape in shapes:
            a = make_tensor(shape, device=device, dtype=dtype)
            actual = getattr(a, attr)
            if attr == "adjoint":
                # adjoint is a method, the others are attributes
                actual = actual()
            expected = a
            if a.ndim != 0:
                expected = expected.transpose(-2, -1)
            if conjugating:
                expected = expected.conj()
            self.assertEqual(expected, actual)
def test_transposes_errors(self, device="mps", dtype=torch.float32):
    """Check that ``H``, ``mT``, ``mH`` and ``adjoint()`` reject unsupported ranks."""
    for attr in ("H", "mT", "mH", "adjoint"):
        # H is tried on 1-D and 3-D input; the others only on 1-D input
        if attr == "H":
            shapes = ((2,), (2, 3, 4))
        else:
            shapes = ((2,),)
        for shape in shapes:
            a = make_tensor(shape, device=device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
                result = getattr(a, attr)
                if attr == "adjoint":
                    # adjoint is a method, so the error surfaces on the call
                    result = result()
def test_python_types(self, device="mps"):
    """Check that Python ``int``/``bool`` dtype arguments match the torch dtypes."""
    # Both float tensors are requested as torch.float32
    float_a = torch.randn((1, 2), device=device, dtype=torch.float32)
    float_b = torch.randn((1, 2), device=device, dtype=torch.float32)
    self.assertEqual(float_a.dtype, float_b.dtype)

    # Python `int` vs torch.int64
    int_torch = torch.arange(10, 20, dtype=torch.int64, device=device)
    int_python = torch.arange(10, 20, dtype=int, device=device)
    self.assertEqual(int_torch.dtype, int_python.dtype)

    # Python `bool` vs torch.bool
    bool_torch = torch.tensor([True, False], dtype=torch.bool, device=device)
    bool_python = torch.tensor([True, False], dtype=bool, device=device)
    self.assertEqual(bool_torch.dtype, bool_python.dtype)
# TODO: is resize best put in test_view_ops?
def test_resize_as_preserves_strides(self, device="mps"):
    """Check that ``resize_as_`` with the tensor itself keeps a transposed tensor's strides.

    The tensor is created without a ``device`` argument; ``device`` is unused.
    """
    transposed = torch.empty(2, 3).t()
    strides_before = transposed.stride()
    transposed.resize_as_(transposed)
    self.assertEqual(transposed.stride(), strides_before)
def test_memory_format_resize_as(self, device="mps"):
    """Check that ``resize_as_(..., memory_format=preserve_format)`` adopts the template's layout."""
    def check(shape, memory_format, device="mps"):
        template = torch.randn(shape, device=device).contiguous(memory_format=memory_format)
        target = torch.randn(template.numel(), device=device)
        target.resize_as_(template, memory_format=torch.preserve_format)
        self.assertTrue(target.is_contiguous(memory_format=memory_format))

    cases = (
        ((10, 3, 32, 32), torch.channels_last),
        ((3, 10, 3, 32, 32), torch.channels_last_3d),
    )
    for shape, memory_format in cases:
        check(shape, memory_format, device="mps")
def test_memory_format_resize_(self, device="mps"):
    """Check that ``resize_`` with an explicit memory format yields that layout."""
    def check(shape, numel, memory_format, device="mps"):
        target = torch.randn(numel, device=device)
        target.resize_(shape, memory_format=memory_format)
        self.assertTrue(target.is_contiguous(memory_format=memory_format))

    cases = (
        ((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last),
        ((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d),
    )
    for shape, numel, memory_format in cases:
        check(shape, numel, memory_format, device="mps")
# TODO: OpInfo this
def _test_atleast(self, device, torch_fn):
    """Run gradcheck and gradgradcheck for ``torch_fn`` on inputs of rank 0 through 4.

    All inputs are double-precision tensors created without a ``device``
    argument; the ``device`` parameter is not used.
    """
    def check(*inputs):
        # First- and second-order numerical gradient checks on the same inputs
        gradcheck(torch_fn, inputs)
        gradgradcheck(torch_fn, inputs)

    # 0-dim
    scalar = torch.tensor(0.5, dtype=torch.double, requires_grad=True)
    check(scalar)

    # 1-dim
    vector = torch.rand(4, dtype=torch.double, requires_grad=True)
    check(vector)

    # 2,3,4-dim, passed together with the two inputs above in a single call
    rank2 = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
    rank3 = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
    rank4 = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)
    check(scalar, vector, rank2, rank3, rank4)
def test_atleast_gradient(self, device="mps"):
    """Run the shared gradient checks for ``atleast_1d``, ``atleast_2d`` and ``atleast_3d``."""
    for torch_fn in (torch.atleast_1d, torch.atleast_2d, torch.atleast_3d):
        self._test_atleast(device, torch_fn)
def test_view(self, device="mps"):
    """Check ``view``/``view_as`` size handling, -1 inference, empty tensors and error cases."""
    flat = torch.rand(15, device=device)
    template = torch.rand(3, 5, device=device)
    empty = torch.empty(0, device=device)
    expected_size = template.size()

    # Every spelling of "view as 3x5" must produce the template's size
    self.assertEqual(flat.view_as(template).size(), expected_size)
    self.assertEqual(flat.view(3, 5).size(), expected_size)
    self.assertEqual(flat.view(torch.Size([3, 5])).size(), expected_size)
    self.assertEqual(flat.view(-1, 5).size(), expected_size)
    self.assertEqual(flat.view(3, -1).size(), expected_size)

    # Writing through a view must not raise
    view_5x3 = flat.view(5, 3)
    view_5x3.fill_(random.uniform(0, 1))

    # Views of an empty tensor
    self.assertEqual(empty.view_as(empty), empty)
    self.assertEqual(empty.view(0), empty)
    self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
    self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)

    # test size inference with empty tensors
    self.assertEqual(empty.view(-1).size(), torch.Size([0]))
    self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))

    # -1 cannot be inferred when another requested dimension is 0
    ambiguous_msg = r"because the unspecified dimension size -1 can be any value"
    for ambiguous_shape in ((-1, 0), (3, 0, -1, 0)):
        with self.assertRaisesRegex(RuntimeError, ambiguous_msg):
            empty.view(*ambiguous_shape)

    # Shapes incompatible with 15 elements, or with more than one -1
    for bad_shape in ((15, 0), (7, -1), (15, -1, -1)):
        with self.assertRaises(RuntimeError):
            flat.view(*bad_shape)
def test_contiguous(self, device="mps"):
    """Check that the stride of a size-1 dimension does not affect ``is_contiguous``."""
    x = torch.randn(1, 16, 5, 5, device=device)
    self.assertTrue(x.is_contiguous())
    # Replace only the dim-0 stride; size[0] is 1, so the tensor stays contiguous
    altered_stride = [20, *x.stride()[1:]]
    x.set_(x.storage(), 0, x.size(), altered_stride)
    self.assertTrue(x.is_contiguous())
def test_resize_mps_dtypes(self, device="mps"):
    """Check that ``resize_`` to (2, 2) works for every dtype in ``MPS_DTYPES``."""
    new_shape = (2, 2)
    for dtype in MPS_DTYPES:
        tensor = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype, device=device)
        tensor.resize_(new_shape)
        self.assertEqual(new_shape, tensor.shape)
def test_resize_as_mps_dtypes(self, device="mps"):
    """Check that ``resize_as_`` adopts the template's shape for every dtype in ``MPS_DTYPES``."""
    for dtype in MPS_DTYPES:
        target = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype, device=device)
        template = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dtype, device=device)
        target.resize_as_(template)
        self.assertEqual(template.shape, target.shape)
def test_resize_overflow(self, device="mps"):
    """Check that ``resize_`` to an enormous shape raises an overflow error.

    The tensor is a float64 scalar created without a ``device`` argument;
    ``device`` is unused.
    """
    x = torch.empty((), dtype=torch.float64)
    cases = (
        ('Storage size calculation overflowed', [2, 4, 2**29, 2**29]),
        ('overflow', [8, 8, 2**29, 2**29]),
    )
    for message_regex, huge_shape in cases:
        with self.assertRaisesRegex(RuntimeError, message_regex):
            x.resize_(huge_shape)
def test_view_all_dtypes_and_devices(self, device="mps"):
    """Check that a 3x2 tensor can be viewed as 6 elements for float and bool dtypes."""
    for dtype in (torch.float, torch.bool):
        tensor = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype, device=device)
        flattened = tensor.view(6)
        self.assertEqual(flattened.shape, [6])
class TestConvolutionMPS(TestCaseMPS):
def test_conv1d_all_strides_paddings(self):
    """Compare Conv1d output on CPU and MPS for strides and paddings in 1..3."""
    # https://github.com/pytorch/pytorch/issues/82921
    def check(stride, padding):
        inp_cpu = torch.randn(1, 57, 40)
        conv_cpu = nn.Conv1d(57, 20, stride=stride, padding=padding, kernel_size=3, bias=False)
        # Deep copy so both devices use identical weights
        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
        out_cpu = conv_cpu(inp_cpu)
        inp_mps = inp_cpu.to(device='mps')
        out_mps = conv_mps(inp_mps)
        self.assertEqual(out_cpu, out_mps.cpu())

    for stride, padding in itertools.product(range(1, 4), range(1, 4)):
        check(stride, padding)
def test_conv1d_channels_last(self):
    """Compare Conv1d on CPU and MPS for an input obtained by permuting a (128, 176, 1) view."""
    # https://github.com/pytorch/pytorch/issues/81557
    model = torch.nn.Conv1d(1, 128, 3)
    inp_cpu = torch.arange((128 * 176), dtype=torch.float32).view(128, 176, 1).permute(0, 2, 1)
    expected = model(inp_cpu)

    inp_mps = inp_cpu.detach().clone().to("mps")
    # The same module is moved to MPS after the CPU result has been computed
    model_on_mps = model.to("mps")
    actual = model_on_mps(inp_mps)
    self.assertEqual(expected, actual.cpu(), rtol=2.6e-05, atol=2e-04)
def test_conv_transpose_1d_all_strides(self):
    """Compare ConvTranspose1d output on CPU and MPS for strides 1, 2 and 3."""
    # https://github.com/pytorch/pytorch/issues/82711
    def check(stride):
        inp_cpu = torch.ones(1, 1, 2)
        deconv_cpu = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=1, stride=stride, bias=False, padding=1)
        # The weight is overwritten with an all-ones (1, 1, 2) tensor
        deconv_cpu.weight.data = torch.ones(1, 1, 2)
        deconv_mps = copy.deepcopy(deconv_cpu).to(device='mps')
        out_cpu = deconv_cpu(inp_cpu)
        inp_mps = inp_cpu.to(device='mps')
        out_mps = deconv_mps(inp_mps)
        self.assertEqual(out_cpu, out_mps.cpu())

    for stride in (1, 2, 3):
        check(stride)
def test_conv_transpose_1d_nn_functional(self):
    """Compare the functional ``conv_transpose1d`` on CPU and MPS with identical inputs."""
    # https://github.com/pytorch/pytorch/issues/82563
    inp = torch.rand((1, 512, 1245), dtype=torch.float32)
    weight = torch.rand((512, 256, 16), dtype=torch.float32)
    bias = torch.rand((256), dtype=torch.float32)

    def run_on(device):
        return torch.nn.functional.conv_transpose1d(
            inp.to(device), weight.to(device), bias.to(device), stride=8, padding=4)

    out_cpu = run_on('cpu')
    out_mps = run_on('mps')
    self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04)
def test_conv_backward_1d_channels_last(self):
    """Compare Conv1d forward/backward on CPU and MPS, and check grad-layout independence on MPS."""
    def check(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
        # https://github.com/pytorch/pytorch/issues/84511
        cpu_conv = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
        mps_conv = torch.nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
        # Give the MPS module the same parameters as the CPU module
        mps_conv.weight.data = cpu_conv.weight.data.detach().clone().to("mps").requires_grad_(True)
        mps_conv.bias.data = cpu_conv.bias.data.detach().clone().to("mps").requires_grad_(True)

        raw = torch.rand(shape, dtype=torch.float32)
        inp_cpu = raw.permute(0, 2, 1).contiguous().requires_grad_(True)
        inp_mps = raw.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)

        out_cpu = cpu_conv(inp_cpu)
        out_mps = mps_conv(inp_mps)
        self.assertEqual(out_cpu, out_mps)

        out_cpu.sum().backward()
        out_mps.sum().backward()
        self.assertEqual(cpu_conv.weight.grad, mps_conv.weight.grad, rtol=2.6e-05, atol=2e-04)
        self.assertEqual(inp_cpu.grad, inp_mps.grad)

    for simple_shape in ((1, 176, 1), (2, 12, 1), (3, 176, 1), (4, 376, 1)):
        check(shape=simple_shape)
    check(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1)
    check(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)

    # Regression test for https://github.com/pytorch/pytorch/issues/140902
    # And https://github.com/pytorch/pytorch/issues/142344 (adding grad for input)
    ic, oc, ks, f = 2, 5, 3, 7
    conv = torch.nn.Conv1d(ic, oc, kernel_size=ks, padding=1).to("mps")
    inp = torch.rand(1, ic, f, device="mps", requires_grad=True)
    out = conv(inp)
    grad_dense = torch.rand(1, oc, f, device="mps")
    # Same values as grad_dense, stored in a transposed (1, f, oc) buffer
    grad_transposed = torch.empty(1, f, oc, device="mps").transpose(1, 2)
    grad_transposed[:] = grad_dense

    # It does not matter whether grad_in contiguous, or channels last, results should equal to each other
    wrt = (inp, conv.weight, conv.bias)
    grads_dense = torch.autograd.grad((out,), wrt, (grad_dense,), retain_graph=True)
    grads_transposed = torch.autograd.grad((out,), wrt, (grad_transposed,), retain_graph=True)
    for idx in range(3):
        self.assertEqual(grads_dense[idx], grads_transposed[idx])
def test_conv1d_contiguous(self):
    """Compare Conv1d output shape and values on CPU and MPS for an all-ones input."""
    model = torch.nn.Conv1d(1, 128, 3)
    inp_cpu = torch.ones(128, 1, 176)
    expected = model(inp_cpu)

    inp_mps = inp_cpu.detach().clone().to("mps")
    # The same module is moved to MPS after the CPU result has been computed
    model_on_mps = model.to("mps")
    actual = model_on_mps(inp_mps)

    self.assertEqual(expected.shape, actual.shape)
    self.assertEqual(expected, actual.cpu())
def test_conv2d_all_strides_paddings(self):
    """Compare Conv2d forward and backward on CPU and MPS across strides and memory formats.

    For each input/weight memory-format combination the helper sweeps
    ``stride=(strideX, strideY)`` over 1..3 x 1..3 and checks the forward
    output plus the weight, bias and input gradients.
    """
    # https://github.com/pytorch/pytorch/issues/83180
    def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
        x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
        x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()

        # NOTE(review): `permute_data` currently has no effect. The previous code called
        # `x.permute(0, 2, 3, 1)` on both inputs and discarded the results; `permute`
        # returns a new view and does not modify the tensor. Actually feeding a permuted
        # input would change the channel dimension seen by the convolution, so the
        # intended behavior needs to be decided before this flag is wired up.

        for strideX in range(1, 4):
            for strideY in range(1, 4):
                # N is used as in_channels; every call below passes N == C, which is what
                # makes this consistent with the (N, C, H, W) input.
                conv_cpu = torch.nn.Conv2d(
                    in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
                conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()

                conv_mps = torch.nn.Conv2d(
                    in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
                # Give the MPS module the same parameters as the CPU module
                conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
                conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()

                res_cpu = conv_cpu(x_cpu)
                res_mps = conv_mps(x_mps)
                self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)

                # backward() returns None; the old code stored that result and then
                # compared None with None, which could never fail. Gradients are
                # checked directly instead.
                res_cpu.sum().backward()
                res_mps.sum().backward()

                self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
                self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
                # x_*.grad accumulates over the stride sweep identically on both devices
                self.assertEqual(x_cpu.grad, x_mps.grad)

    for mem_format_input in [torch.contiguous_format, torch.channels_last]:
        for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
            for permute_data in [True, False]:
                helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
                helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
                helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
def test_conv_transpose_2d_strided(self):
    """Compare ConvTranspose2d forward output on CPU and MPS for two kernel/stride setups."""
    def check(module_cpu, memory_format):
        module_mps = copy.deepcopy(module_cpu).requires_grad_()
        # Replace the copied parameters' data with MPS clones of the CPU parameters
        module_mps.weight.data = module_cpu.weight.data.detach().clone().to("mps").requires_grad_()
        module_mps.bias.data = module_cpu.bias.data.detach().clone().to("mps").requires_grad_()

        inp_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
        inp_mps = inp_cpu.detach().clone().to("mps")

        out_cpu = module_cpu(inp_cpu)
        out_mps = module_mps(inp_mps)
        self.assertEqual(out_cpu, out_mps)

    for input_format in (torch.contiguous_format, torch.channels_last):
        # With square kernels and equal stride
        square = nn.ConvTranspose2d(16, 33, 3, stride=2).requires_grad_()
        check(square, input_format)
        # non-square kernels and unequal stride and with padding
        rectangular = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)).requires_grad_()
        check(rectangular, input_format)
def test_conv_transpose_2d_specified_output(self):
    """Compare a Conv2d downsample followed by ConvTranspose2d with ``output_size`` on CPU and MPS."""
    def paired_layers(layer_cls):
        # Build the same layer on CPU and MPS and copy the CPU parameters to the MPS one
        cpu_layer = layer_cls(16, 16, 3, stride=2, padding=1)
        mps_layer = layer_cls(16, 16, 3, stride=2, padding=1, device="mps")
        mps_layer.weight.data = cpu_layer.weight.data.detach().clone().to("mps").requires_grad_()
        mps_layer.bias.data = cpu_layer.bias.data.detach().clone().to("mps").requires_grad_()
        return cpu_layer, mps_layer

    inp_cpu = torch.randn(1, 16, 12, 12)
    inp_mps = inp_cpu.detach().clone().to("mps")

    down_cpu, down_mps = paired_layers(nn.Conv2d)
    up_cpu, up_mps = paired_layers(nn.ConvTranspose2d)

    hidden_cpu = down_cpu(inp_cpu)
    hidden_mps = down_mps(inp_mps)
    self.assertEqual(hidden_cpu, hidden_mps)
    self.assertEqual(hidden_cpu.size(), hidden_mps.size())

    # output_size requests the original input size from the transposed convolution
    restored_cpu = up_cpu(hidden_cpu, output_size=inp_cpu.size())
    restored_mps = up_mps(hidden_mps, output_size=inp_mps.size())
    self.assertEqual(restored_cpu, restored_mps)
    self.assertEqual(restored_cpu.size(), restored_mps.size())
def test_conv2d_single_stride(self):
    """Compare Conv2d output on CPU and MPS for scalar strides 1, 2 and 3."""
    inp_cpu = torch.randn(2, 2, 3, 6)
    inp_mps = inp_cpu.to(device='mps')
    for stride in (1, 2, 3):
        conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=stride)
        # Deep copy so both devices use identical parameters
        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
        out_cpu = conv_cpu(inp_cpu)
        out_mps = conv_mps(inp_mps)
        self.assertEqual(out_cpu, out_mps.cpu(), rtol=1e-03, atol=1e-05)
def test_conv3d_single_stride(self):
    """Compare Conv3d output on CPU and MPS for scalar strides 1, 2 and 3.

    The input is a 4-D tensor of shape (2, 2, 3, 6).
    """
    # Conv3d is only available from MacOS 13.2 onwards
    inp_cpu = torch.randn(2, 2, 3, 6)
    inp_mps = inp_cpu.to(device='mps')
    for stride in (1, 2, 3):
        conv_cpu = torch.nn.Conv3d(in_channels=2, out_channels=2, kernel_size=2, stride=stride)
        # Deep copy so both devices use identical parameters
        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
        out_cpu = conv_cpu(inp_cpu)
        out_mps = conv_mps(inp_mps)
        self.assertEqual(out_cpu, out_mps.cpu(), rtol=1e-03, atol=1e-05)
def test_grid_sample(self):
    """Check ``F.grid_sample`` on MPS against hard-coded reference outputs.

    The loops at the bottom evaluate a fixed 1x1x2x5 input and a fixed grid on
    the "mps" device for each combination of interpolation mode, padding mode
    and ``align_corners``, and compare the result with a literal ground truth.
    """
    # NOTE(review): `test` (and the `test_shape` helper nested inside it) is defined here
    # but is not called anywhere in this method, so the randomized CPU-vs-MPS
    # forward/backward comparison it implements does not run -- confirm whether
    # that is intentional.
    def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
        def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
            for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
                # grid_dim_contig_order specifies the dimension order that can
                # make grid to be contiguous.
                # i.e., grid.permute(grid_dim_contig_order) is contiguous.
                # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
                #       initialized with contiguous tensor of shape [N, 2, H, W]
                #       and permuted to [N, H, W, 2] afterwards.
                grid_shape = [N, H, W, 2]
                grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
                # Inverse of grid_dim_contig_order: maps the init layout back to [N, H, W, 2]
                grid_fwd_permute = [None, None, None, None]
                for i, d in enumerate(grid_dim_contig_order):
                    grid_fwd_permute[d] = i

                def get_grid(device='cpu', data=None):
                    # Returns an [N, H, W, 2] grid whose memory layout follows
                    # grid_dim_contig_order; copies `data` when given, else random values.
                    if data is not None:
                        assert list(data.shape) == grid_shape
                        data = data.permute(grid_dim_contig_order).to(device)
                    else:
                        data = torch.randn(grid_init_shape, device=device)
                    grid = data.permute(grid_fwd_permute)
                    assert grid.permute(grid_dim_contig_order).is_contiguous()
                    return grid

                # The input is created as (C, N, IH, IW) and transposed, so it is an
                # (N, C, IH, IW) view of a differently ordered buffer.
                input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
                grid_cpu = get_grid().requires_grad_()
                out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                        align_corners=align_corners)
                self.assertEqual(out_cpu.size(), torch.Size([N, C, H, W]))

                gradients = torch.randn_like(out_cpu)
                out_cpu.backward(gradients)

                # Compare against unvectorized CPU fallback

                # NOTE [ grid_sample CPU fallback ]
                # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
                # 32-bit floats. So we also have a fallback that is used only for float tensors
                # requiring 64-bit indexing. That requires too much memory to run on CI, so we
                # also export the fallback and test it here to ensure feature parity with
                # the vectorized version.
                input_fallback = input_cpu.float().detach_().requires_grad_()
                grid_fallback = grid_cpu.float().detach_().requires_grad_()
                out_fallback = torch._grid_sampler_2d_cpu_fallback(
                    input_fallback, grid_fallback,
                    F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
                    F.GRID_SAMPLE_PADDING_MODES[padding_mode],
                    align_corners)
                self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)

                out_fallback.backward(gradients.float())
                if input_requires_grad:
                    self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
                self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)

                # Same input and grid on MPS, built with the same transposed layout
                input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requires_grad)
                grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_()
                out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
                self.assertEqual(out_cpu, out_mps)
                out_mps.backward(gradients.to("mps"))
                if input_requires_grad:
                    self.assertEqual(input_cpu.grad, input_mps.grad)
                self.assertEqual(grid_cpu.grad, grid_mps.grad, atol=5e-5, rtol=0)

                # check that zero-dimensional input strides don't error out
                base_input = torch.randn(N, C, 1, IW)
                input_cpu = base_input.expand_as(input_mps).requires_grad_(input_requires_grad)
                out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                        align_corners=align_corners)

                input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad)
                out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
                self.assertEqual(out_cpu, out_mps)

        # test same size output
        test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)

        # test larger output
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(IH + 1, 12)
        W = random.randint(IW + 1, 12)
        test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

        # test smaller output
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(2, IH)
        W = random.randint(2, IW)
        test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

        # test 1x1 input
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = 1
        IW = 1
        H = random.randint(2, 5)
        W = random.randint(2, 5)
        test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

        # testing empty grid
        N = random.randint(2, 8)
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        W = random.randint(3, IW + 2)
        test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)

        # testing empty channel
        N = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(3, IH + 2)
        W = random.randint(3, IW + 2)
        test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)

        # testing empty batch
        C = random.randint(2, 8)
        IH = random.randint(2, 8)
        IW = random.randint(2, 8)
        H = random.randint(3, IH + 2)
        W = random.randint(3, IW + 2)
        test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)

    # NOTE(review): only 'bilinear'/'nearest' and 'zeros'/'reflection' are iterated, so the
    # 'border' and 'bicubic' ground-truth branches below are never reached.
    for mode in ('bilinear', 'nearest'):
        for padding_mode in ('zeros', 'reflection'):
            for align_corners in (True, False):
                # test known input
                input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5)
                grid = torch.tensor(
                    [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
                     [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2)
                # Select the literal expected output for this (mode, padding_mode, align_corners)
                if mode == 'bilinear':
                    if padding_mode == 'zeros':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                 [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
                                 [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'border':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                 [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                 [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'reflection':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                 [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                 [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5)
                    else:
                        raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                elif mode == 'nearest':
                    if padding_mode == 'zeros':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[0., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0., 8., 5., 7., 0.],
                                 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'border':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'reflection':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[1., 8., 5., 7., 9.],
                                 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
                    else:
                        raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                elif mode == 'bicubic':
                    if padding_mode == 'zeros':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
                                 [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
                                 [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'border':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
                                 [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
                                 [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5)
                    elif padding_mode == 'reflection':
                        if align_corners:
                            groundtruth = torch.tensor(
                                [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
                                 [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5)
                        else:
                            groundtruth = torch.tensor(
                                [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
                                 [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5)
                    else:
                        raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                else:
                    raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
                output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
                                       align_corners=align_corners)
                self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
                                 msg=f"groundtruth comparison failed for mode={mode}, "
                                 f"padding_mode={padding_mode}")
class TestAdvancedIndexing(TestCaseMPS):
# Torch dtypes that the indexing tests in this class iterate over.
supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
# NumPy dtypes listed in the same order as `supported_dtypes`.
supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
def test_nonzero_no_warning(self):
    """Check that neither ``torch.nonzero`` nor ``Tensor.nonzero`` emits a warning on MPS."""
    sample = torch.randn((2, 2), device="mps")
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including ones that would normally be shown once
        warnings.simplefilter("always")
        torch.nonzero(sample)
        sample.nonzero()
        self.assertEqual(len(caught), 0)
def test_nonzero(self):
    """Compare every ``nonzero`` calling convention on MPS with NumPy, per supported dtype."""
    def random_zero_one(shape, dtype, device):
        # Random tensor of 0s and 1s in the requested dtype
        if dtype == torch.bfloat16:
            # windows does not work for bfloat16 randing
            return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
        return torch.randint(2, shape, device=device, dtype=dtype)

    def check_dtype(dtype):
        device = "mps"
        shapes = [
            torch.Size((12,)),
            torch.Size((12, 1)),
            torch.Size((1, 12)),
            torch.Size((6, 2)),
            torch.Size((3, 2, 2)),
            torch.Size((5, 5, 5)),
        ]

        for shape in shapes:
            tensor = random_zero_one(shape, dtype, device)

            # Function form, method form, and out= form
            via_function = torch.nonzero(tensor, as_tuple=False)
            via_method = tensor.nonzero(as_tuple=False)
            via_out = torch.empty([], dtype=torch.long, device=device)
            via_out = via_out.resize_(0)
            torch.nonzero(tensor, out=via_out)

            # NumPy reference, stacked to one row of indices per nonzero element
            if dtype != torch.bfloat16:
                np_array = tensor.cpu().numpy()
            else:
                np_array = tensor.float().cpu().numpy()
            reference = torch.from_numpy(np.stack(np_array.nonzero())).t()

            for result in (via_function, via_method, via_out):
                self.assertEqual(result.cpu(), reference, atol=0, rtol=0)

            # as_tuple=True results, stacked into the same layout as the reference
            tuple_function = torch.nonzero(tensor, as_tuple=True)
            tuple_method = tensor.nonzero(as_tuple=True)
            stacked_function = torch.stack(tuple_function).t().cpu()
            stacked_method = torch.stack(tuple_method).t().cpu()
            self.assertEqual(stacked_function, reference, atol=0, rtol=0)
            self.assertEqual(stacked_method, reference, atol=0, rtol=0)

    for dtype in self.supported_dtypes:
        check_dtype(dtype)
def test_nonzero_astuple_out(self):
    """Check ``nonzero`` with ``as_tuple`` and ``out=`` together, eagerly, scripted and traced."""
    device = "mps"
    t = torch.randn((3, 3, 3), device=device)
    out = torch.empty([], dtype=torch.long, device=device).resize_(0)

    # as_tuple=True cannot be combined with out=
    with self.assertRaises(RuntimeError):
        torch.nonzero(t, as_tuple=True, out=out)

    # as_tuple=False with out= behaves like the plain out= call
    explicit_flag = torch.nonzero(t, as_tuple=False, out=out)
    default_flag = torch.nonzero(t, out=out)
    self.assertEqual(explicit_flag, default_flag)

    # Verifies that JIT script cannot handle the as_tuple kwarg
    # See Issue https://github.com/pytorch/pytorch/issues/45499.
    def _foo(t):
        tuple_result = torch.nonzero(t, as_tuple=True)
        nontuple_result = torch.nonzero(t, as_tuple=False)
        out = torch.empty_like(nontuple_result)
        torch.nonzero(t, as_tuple=False, out=out)
        return tuple_result, nontuple_result, out

    with self.assertRaises(RuntimeError):
        scripted_foo = torch.jit.script(_foo)

    # Verifies that JIT tracing works fine
    traced_foo = torch.jit.trace(_foo, t)
    traced_tuple, traced_nontuple, traced_out = traced_foo(t)
    eager_tuple = torch.nonzero(t, as_tuple=True)
    eager_nontuple = torch.nonzero(t)

    self.assertEqual(traced_tuple, eager_tuple)
    self.assertEqual(traced_nontuple, eager_nontuple)
    self.assertEqual(traced_out, eager_nontuple)
def test_nonzero_discontiguous(self):
    """Check ``nonzero`` with a strided input and with contiguous/strided ``out`` tensors."""
    device = "mps"
    shape = (4, 4)
    dense = torch.randint(2, shape, device=device)
    # Same values stored in every other column of a wider buffer
    strided = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(dense)

    from_dense = dense.nonzero(as_tuple=False)
    from_strided = strided.nonzero(as_tuple=False)
    self.assertEqual(from_dense, from_strided, atol=0, rtol=0)

    contiguous_out = torch.empty_like(from_dense)
    ptr_before = contiguous_out.data_ptr()
    # expect dst3 storage to be reused
    torch.nonzero(dense, out=contiguous_out)
    self.assertEqual(ptr_before, contiguous_out.data_ptr())
    self.assertEqual(from_dense, contiguous_out, atol=0, rtol=0)

    # discontiguous out
    strided_out = torch.empty(from_dense.size(0), from_dense.size(1) * 2, dtype=torch.long, device=device)[:, ::2]
    ptr_before = strided_out.data_ptr()
    strides_before = strided_out.stride()
    torch.nonzero(dense, out=strided_out)
    self.assertEqual(ptr_before, strided_out.data_ptr())
    self.assertEqual(from_dense, strided_out, atol=0, rtol=0)
    self.assertEqual(strides_before, strided_out.stride())
def test_nonzero_non_diff(self):
    """Check that the result of ``nonzero`` does not require grad even if the input does."""
    source = torch.randn(10, requires_grad=True, device="mps")
    indices = source.nonzero()
    self.assertFalse(indices.requires_grad)
def test_nonzero_multi_threading(self):
    """Run torch.nonzero on the same MPS tensor from two threads at once."""
    # Test that MPS doesn't crash if nonzero called concurrently
    # See https://github.com/pytorch/pytorch/issues/100285
    x = torch.rand(3, 3, device="mps")
    workers = [threading.Thread(target=torch.nonzero, args=(x,)) for _ in range(2)]
    for worker in workers:
        worker.start()
    # Join both threads so the concurrent calls complete inside this test
    # instead of continuing to run while later tests execute.
    for worker in workers:
        worker.join()
def test_sliced_view_cast(self):
    """Smoke test: reinterpret a sliced float16 tensor as float32 and add to it."""
    # This used to crash on MacOS Sequoia
    # See https://github.com/pytorch/pytorch/issues/137800
    half = torch.rand(16, 16, device='mps', dtype=torch.float16)
    first_two_cols = half[:, 0:2]
    y = first_two_cols.view(torch.float32) + 1
def test_masked_select(self):
    """masked_select on MPS must agree with CPU for the same data and threshold mask."""
    cpu_vals = torch.randn(3, 4)
    mps_vals = cpu_vals.to("mps")
    # Each mask is computed on the device of the tensor it selects from.
    cpu_picked = torch.masked_select(cpu_vals, cpu_vals.ge(0.5))
    mps_picked = torch.masked_select(mps_vals, mps_vals.ge(0.5))
    self.assertEqual(cpu_picked, mps_picked)
# examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
def test_indexing_get(self):
    """Paired integer-list indexing reads the same elements on CPU and MPS."""
    def check(dtype):
        cpu = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
        mps = cpu.detach().clone().to("mps")
        row_ix, col_ix = [0, 1, 2], [0, 1, 0]
        picked_cpu = cpu[row_ix, col_ix]
        picked_mps = mps[row_ix, col_ix]
        self.assertEqual(picked_cpu, picked_mps, str(dtype))

    for dtype in self.supported_dtypes:
        check(dtype)
def test_indexing_select_corners(self):
    """Select the four corner elements of a 4x3 matrix with 2-D index tensors."""
    def check(dtype):
        cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        mps = cpu.detach().clone().to("mps")
        row_ix = torch.tensor([[0, 0], [3, 3]])
        col_ix = torch.tensor([[0, 2], [0, 2]])
        row_ix_mps = row_ix.detach().clone().to("mps")
        col_ix_mps = col_ix.detach().clone().to("mps")
        corners_cpu = cpu[row_ix, col_ix]
        corners_mps = mps[row_ix_mps, col_ix_mps]
        self.assertEqual(corners_cpu, corners_mps, str(dtype))

    for dtype in self.supported_dtypes:
        check(dtype)
# FIXME: uint8 fails for this testcase, needs further debugging
def test_slicing_using_advanced_index_for_column(self):
    """Compare a plain column slice and a list-indexed column selection on CPU vs MPS."""
    def check(dtype):
        cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        mps = cpu.detach().clone().to("mps")

        # basic slicing on both dimensions
        self.assertEqual(cpu[1:4, 1:3], mps[1:4, 1:3], str(dtype))

        # using advanced index for column
        self.assertEqual(cpu[1:4, [1, 2]], mps[1:4, [1, 2]], str(dtype))

    # FIXME: use supported_dtypes once uint8 is fixed
    for dtype in (torch.float32, torch.float16, torch.int64, torch.int32, torch.int16):
        check(dtype)
def test_boolean_array_indexing(self):
    """Boolean-mask indexing (``x[x > 5]``) selects the same values on CPU and MPS."""
    def check(dtype):
        cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
        mps = cpu.detach().clone().to("mps")
        above_cpu = cpu[cpu > 5]
        above_mps = mps[mps > 5]
        self.assertEqual(above_cpu, above_mps, str(dtype))

    for current in self.supported_dtypes:
        check(current)
def test_advanced_indexing_3D_get(self):
    """Mixed list/int/slice indexing of a 3-D tensor reads the same values on CPU and MPS."""
    whole = slice(None)
    # Tuple forms of x[[1, 2], 3, :], x[[0, 2], :, :] and x[:, [1, 0], [1]].
    patterns = (
        ([1, 2], 3, whole),
        ([0, 2], whole, whole),
        (whole, [1, 0], [1]),
    )

    def check(cpu_tensor):
        mps_tensor = cpu_tensor.detach().clone().to("mps")
        for pattern in patterns:
            self.assertEqual(cpu_tensor[pattern], mps_tensor[pattern])

    fixed = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                           [0.5, 0.6, 0.7, 0.8],
                           [0.9, 1.0, 1.1, 1.2],
                           [1.3, 1.4, 1.5, 1.6]],
                          [[2.0, 2.1, 2.2, 2.3],
                           [2.4, 2.5, 2.6, 2.7],
                           [2.8, 2.9, 3.0, 3.1],
                           [3.2, 3.3, 3.4, 3.5]],
                          [[4.0, 4.1, 4.2, 4.3],
                           [4.4, 4.5, 4.6, 4.7],
                           [4.8, 4.9, 5.0, 5.1],
                           [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
    check(fixed)

    for position, np_dtype in enumerate(self.supported_np_dtypes):
        # torch.randn / torch.rand don't work with all dtypes
        # Generate input data for all dtypes in NumPy, then move it to torch
        samples = np.random.random_sample(size=[3, 4, 4]).astype(np_dtype)
        check(torch.tensor(samples, device='cpu', dtype=self.supported_dtypes[position]))
def test_advanced_indexing_3D_put(self):
    """Assign a one-element view through mixed 3-D indices and compare CPU with MPS."""
    whole = slice(None)
    # Tuple forms of x[[1, 2], 3, :], x[[0, 2], :, :] and x[:, [1, 0], [1]].
    patterns = (
        ([1, 2], 3, whole),
        ([0, 2], whole, whole),
        (whole, [1, 0], [1]),
    )

    def check(cpu_tensor):
        dtype = cpu_tensor.dtype
        mps_tensor = cpu_tensor.detach().clone().to("mps")

        # The assigned value is a view (second element only) of a two-element tensor.
        fill_cpu = torch.tensor([88, 99], dtype=dtype, device="cpu")[1:]
        fill_mps = torch.tensor([88, 99], dtype=dtype, device="mps")[1:]

        # Writes are cumulative: each pattern is applied on top of the previous one.
        for pattern in patterns:
            cpu_tensor[pattern] = fill_cpu
            mps_tensor[pattern] = fill_mps
            self.assertEqual(cpu_tensor, mps_tensor)

    fixed = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                           [0.5, 0.6, 0.7, 0.8],
                           [0.9, 1.0, 1.1, 1.2],
                           [1.3, 1.4, 1.5, 1.6]],
                          [[2.0, 2.1, 2.2, 2.3],
                           [2.4, 2.5, 2.6, 2.7],
                           [2.8, 2.9, 3.0, 3.1],
                           [3.2, 3.3, 3.4, 3.5]],
                          [[4.0, 4.1, 4.2, 4.3],
                           [4.4, 4.5, 4.6, 4.7],
                           [4.8, 4.9, 5.0, 5.1],
                           [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
    check(fixed)

    for position, np_dtype in enumerate(self.supported_np_dtypes):
        # torch.randn / torch.rand don't work with all dtypes
        # Generate input data for all dtypes in NumPy, then move it to torch
        samples = np.random.random_sample(size=[3, 4, 4]).astype(np_dtype)
        check(torch.tensor(samples, device='cpu', dtype=self.supported_dtypes[position]))
def test_index_put_with_view_indices(self):
    """Accumulating index_put_ with indices taken from a transposed view matches CPU."""
    def check(dtype):
        def accumulate_on(dev):
            target = torch.zeros([5, 3], device=dev, dtype=dtype)
            pairs = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device=dev)
            ones = torch.ones(pairs.shape[0], device=dev, dtype=dtype)
            # tuple(pairs.t()) yields one index tensor per dimension, each a view.
            target.index_put_(tuple(pairs.t()), ones, accumulate=True)
            return target

        self.assertEqual(accumulate_on("cpu"), accumulate_on("mps"))

    for dtype in (torch.int32, torch.float):
        check(dtype)
# tests from 'test_indexing.py'
def test_advancedindex_big(self, device="mps"):
    """List indexing into a 123344-element arange returns the indexed values themselves."""
    picks = [0, 123, 44488, 68807, 123343]
    reference = torch.arange(0, 123344, dtype=torch.int, device=device)
    expected = torch.tensor(picks, dtype=torch.int)
    self.assertEqual(reference[picks, ], expected)
def test_set_item_to_scalar_tensor(self, device="mps"):
    """After writing a 0-dim tensor into a column, its gradient equals rows * value."""
    rows = random.randint(1, 10)
    cols = random.randint(1, 10)
    target = torch.randn([rows, cols], device=device)
    fill = 1.0
    scalar = torch.tensor(fill, requires_grad=True, device=device)
    target[:, 0] = scalar
    total = target.sum()
    total.backward()
    self.assertEqual(scalar.grad, rows * fill)
def test_single_int(self, device="mps"):
    """Indexing with one integer drops the leading dimension."""
    cube = torch.randn(5, 7, 3, device=device)
    picked = cube[4]
    self.assertEqual(picked.shape, (7, 3))
def test_multiple_int(self, device="mps"):
    """Each integer index removes the dimension it is applied to."""
    cube = torch.randn(5, 7, 3, device=device)
    one_int = cube[4]
    two_ints = cube[4, :, 1]
    self.assertEqual(one_int.shape, (7, 3))
    self.assertEqual(two_ints.shape, (7,))
def test_none(self, device="mps"):
    """Each ``None`` in an index inserts a new axis of size 1 at that position."""
    cube = torch.randn(5, 7, 3, device=device)
    leading = cube[None]
    second = cube[:, None]
    second_and_third = cube[:, None, None]
    trailing = cube[..., None]
    self.assertEqual(leading.shape, (1, 5, 7, 3))
    self.assertEqual(second.shape, (5, 1, 7, 3))
    self.assertEqual(second_and_third.shape, (5, 1, 1, 7, 3))
    self.assertEqual(trailing.shape, (5, 7, 3, 1))
def test_step(self, device="mps"):
    """Stepped slices of arange(10) pick the expected elements."""
    seq = torch.arange(10, device=device)
    self.assertEqual(seq[::1], seq)
    stepped = (
        (seq[::2], [0, 2, 4, 6, 8]),
        (seq[::3], [0, 3, 6, 9]),
        (seq[::11], [0]),
        (seq[1:6:2], [1, 3, 5]),
    )
    for sliced, wanted in stepped:
        self.assertEqual(sliced.tolist(), wanted)
def test_step_assignment(self, device="mps"):
    """Assigning through a stepped slice writes only the selected positions of row 0."""
    grid = torch.zeros(4, 4, device=device)
    odd_values = torch.tensor([3., 4.], device=device)
    grid[0, 1::2] = odd_values
    self.assertEqual(grid[0].tolist(), [0, 3, 0, 4])
    # every other row must be untouched
    self.assertEqual(grid[1:].sum(), 0)
def test_bool_indices(self, device="mps"):
    """Bool masks select rows; uint8 masks give the same result and two recorded warnings."""
    cube = torch.randn(5, 7, 3, device=device)
    row_mask = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
    self.assertEqual(cube[row_mask].shape, (3, 7, 3))
    self.assertEqual(cube[row_mask], torch.stack([cube[0], cube[2], cube[3]]))

    flags = torch.tensor([True, False, True], dtype=torch.bool, device=device)
    bool_mask = torch.tensor([True, False, False], dtype=torch.bool, device=device)
    byte_mask = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
    with warnings.catch_warnings(record=True) as caught:
        # The uint8 mask is used for indexing exactly twice; the count below relies on it.
        self.assertEqual(flags[bool_mask].shape, flags[byte_mask].shape)
        self.assertEqual(flags[bool_mask], flags[byte_mask])
        self.assertEqual(flags[bool_mask], torch.tensor([True], dtype=torch.bool, device=device))
        self.assertEqual(len(caught), 2)
def test_bool_indices_accumulate(self, device="mps"):
    """Accumulating through an all-False bool mask leaves the tensor unchanged."""
    zeros = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
    none_selected = zeros > 0
    target = torch.ones(size=(10, 10), device=device)
    target.index_put_((none_selected, ), target[none_selected], accumulate=True)
    self.assertEqual(target, torch.ones(size=(10, 10), device=device))
def test_multiple_bool_indices(self, device="mps"):
    """Two bool masks separated by a slice produce a (3, 7) result."""
    cube = torch.randn(5, 7, 3, device=device)
    # note: these broadcast together and are transposed to the first dim
    first_dim_mask = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
    last_dim_mask = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
    picked = cube[first_dim_mask, :, last_dim_mask]
    self.assertEqual(picked.shape, (3, 7))
def test_byte_mask(self, device="mps"):
    """A ByteTensor mask selects rows and two warnings are recorded for two uses."""
    cube = torch.randn(5, 7, 3, device=device)
    byte_mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
    with warnings.catch_warnings(record=True) as caught:
        # The byte mask is used for indexing exactly twice; the count below relies on it.
        self.assertEqual(cube[byte_mask].shape, (3, 7, 3))
        self.assertEqual(cube[byte_mask], torch.stack([cube[0], cube[2], cube[3]]))
        self.assertEqual(len(caught), 2)

    # A comparison mask with no True entries yields an empty tensor.
    single = torch.tensor([1.], device=device)
    self.assertEqual(single[single == 0], torch.tensor([], device=device))
def test_byte_mask_accumulate(self, device="mps"):
    """Accumulating through an all-zero uint8 mask changes nothing and records two warnings."""
    byte_mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
    target = torch.ones(size=(10, 10), device=device)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        target.index_put_((byte_mask, ), target[byte_mask], accumulate=True)
        self.assertEqual(target, torch.ones(size=(10, 10), device=device))
        self.assertEqual(len(caught), 2)
def test_index_put_accumulate_expanded_values(self, device="mps"):
    """Accumulating index_put_ with values that need expanding matches the CPU result."""
    def run_case(base, index_list, value_list):
        # The same base tensors are reused for every value, so the updates accumulate.
        base_dev = base.to(device)
        index_list_dev = [ix.to(device) for ix in index_list]
        for vals in value_list:
            got = base_dev.index_put_(index_list_dev, vals.to(device), accumulate=True)
            want = base.index_put_(index_list, vals, accumulate=True)
            self.assertEqual(got.cpu(), want)

    # 2-D target: 0-d and 1-element values
    run_case(
        torch.zeros((5, 2)),
        [
            torch.tensor([0, 1, 2, 3]),
            torch.tensor([1, ]),
        ],
        [torch.tensor(1.0), torch.tensor([1.0, ])],
    )

    # 3-D target: 1-d and 2-d values
    run_case(
        torch.zeros(4, 3, 2),
        [
            torch.tensor([0, ]),
            torch.arange(3)[:, None],
            torch.arange(2)[None, :],
        ],
        [torch.tensor([-1.0, -2.0]), torch.tensor([[-1.0, -2.0], ])],
    )
def test_index_put_accumulate_non_contiguous(self, device="mps"):
    """Accumulating into a non-contiguous view matches CPU and keeps the view non-contiguous."""
    base_cpu = torch.zeros((5, 2, 2))
    base_dev = base_cpu.to(device)
    view_dev = base_dev[:, 0, :]
    view_cpu = base_cpu[:, 0, :]
    for view in (view_dev, view_cpu):
        self.assertFalse(view.is_contiguous())

    index_list = [torch.tensor([0, 1]), ]
    index_list_dev = [ix.to(device) for ix in index_list]
    update = torch.randn(2, 2)
    got = view_dev.index_put_(index_list_dev, update.to(device), accumulate=True)
    want = view_cpu.index_put_(index_list, update, accumulate=True)

    for view in (view_dev, view_cpu):
        self.assertFalse(view.is_contiguous())
    self.assertEqual(got.cpu(), want)
def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
    """Accumulating index_put_ with a ``[None, tensor]`` index list matches the CPU result.

    The values are now moved to ``device`` rather than a hard-coded "mps", so
    every tensor in the device-side call honours the ``device`` parameter.
    """
    # TODO: replace with a better solution.
    # Currently, here using torchscript to put None into indices.
    # on C++ it gives indices as a list of 2 optional tensors: first is null and
    # the second is a valid tensor.
    @torch.jit.script
    def func(x, i, v):
        idx = [None, i]
        x.index_put_(idx, v, accumulate=True)
        return x

    n = 4
    t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
    t_dev = t.to(device)
    indices = torch.tensor([1, 0])
    indices_dev = indices.to(device)

    # Both calls mutate t / t_dev in place, so the second value accumulates on the first.
    for value in (torch.tensor(10.0), torch.tensor([1.0, 2.0])):
        out_dev = func(t_dev, indices_dev, value.to(device))
        out_cpu = func(t, indices, value)
        self.assertEqual(out_dev.cpu(), out_cpu)
def test_index_put_accumulate_duplicate_indices(self, device="mps"):
    """Accumulating index_put with heavily duplicated indices matches a Python-side sum.

    Compared with the previous version: the inner loop no longer rebinds the
    outer loop variable, the local no longer shadows the builtin ``input``,
    and ``device`` is used instead of a hard-coded "mps".
    """
    for length in range(1, 128):
        # generate indices by random walk, this will create indices with
        # lots of duplicates interleaved with each other
        delta = torch.empty(length, dtype=torch.float32, device=device).uniform_(-1, 1)
        indices = delta.cumsum(0).long().to(device)

        # abs for int64 is not supported on mps, fallback on 'cpu' to calculate it
        base = torch.randn(indices.cpu().abs().max().to(device) + 1, device=device)
        values = torch.randn(indices.size(0), device=device)
        output = base.index_put((indices,), values, accumulate=True)

        # Reference result: accumulate one element at a time in plain Python.
        expected = base.tolist()
        for position, value in zip(indices.tolist(), values.tolist()):
            expected[position] += value

        self.assertEqual(output, expected)
def test_index_put_deterministic(self, device="mps"):
    """Run index_put_ with duplicate indices for every accumulate/deterministic combination.

    With neither accumulate nor deterministic mode, the helper is expected to
    fail its equality check at some point within its repeated runs.
    """
    def helper(dtype, accumulate, deterministic, num_tests=128):
        if accumulate:
            expected = torch.tensor([233, 187, 360], device=device, dtype=dtype)
        else:
            expected = torch.tensor([38, 37, 39], device=device, dtype=dtype)
        t_idx = torch.tensor(
            [0, 0, 0, 0, 2, 2, 1, 0, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 2, 2,
             0, 0, 2, 1, 2, 1, 0, 0, 2, 0, 2, 1, 1, 2, 2, 0, 2, 1, 0, 2]
        )
        try:
            torch.use_deterministic_algorithms(deterministic)
            for _ in range(num_tests):
                target = torch.zeros(3, dtype=dtype, device=device)
                source = torch.arange(len(t_idx), device=device, dtype=dtype)
                target.index_put_((t_idx,), source, accumulate=accumulate)
                self.assertEqual(target, expected)
        finally:
            # Always restore the global flag, including when an assertion fails.
            torch.use_deterministic_algorithms(False)

    for accumulate, deterministic in product((False, True), (False, True)):
        dtype = torch.float if accumulate else torch.long
        if accumulate or deterministic:
            helper(dtype, accumulate, deterministic)
        else:
            with self.assertRaisesRegex(AssertionError, "Tensor-likes are not equal!"):
                helper(dtype, accumulate, deterministic)
def test_multiple_byte_mask(self, device="mps"):
    """Two ByteTensor masks in one index give a (3, 7) result and two recorded warnings."""
    cube = torch.randn(5, 7, 3, device=device)
    # note: these broadcast together and are transposed to the first dim
    first_dim_mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
    last_dim_mask = torch.ByteTensor([1, 1, 1]).to(device)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        self.assertEqual(cube[first_dim_mask, :, last_dim_mask].shape, (3, 7))
        self.assertEqual(len(caught), 2)
def test_byte_mask2d(self, device="mps"):
    """A 2-D mask over the first two dims yields one row per True entry."""
    cube = torch.randn(5, 7, 3, device=device)
    selector = torch.randn(5, 7, device=device)
    true_count = (selector > 0).sum()
    picked = cube[selector > 0]
    self.assertEqual(picked.shape, (true_count, 3))
def test_jit_indexing(self, device="mps"):
    """Scripted mask assignment and scripted slice assignment produce the same tensor."""
    def fill_by_mask(x):
        x[x < 50] = 1.0
        return x

    def fill_by_slice(x):
        x[0:50] = 1.0
        return x

    scripted = [torch.jit.script(fill_by_mask), torch.jit.script(fill_by_slice)]
    data = torch.arange(100, device=device, dtype=torch.float)
    # First 50 entries become 1.0, the rest keep their arange values.
    ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
    for scripted_fn in scripted:
        out = scripted_fn(data.detach().clone())
        self.assertEqual(out, ref)
def test_int_indices(self, device="mps"):
    """Integer-list indices replace the indexed dim with the shape of the list."""
    cube = torch.randn(5, 7, 3, device=device)
    first_dim = cube[[0, 4, 2]]
    second_dim = cube[:, [0, 4, 2]]
    second_dim_nested = cube[:, [[0, 1], [4, 3]]]
    self.assertEqual(first_dim.shape, (3, 7, 3))
    self.assertEqual(second_dim.shape, (5, 3, 3))
    self.assertEqual(second_dim_nested.shape, (5, 2, 2, 3))
def test_index_put_src_datatype(self):
    """Accumulating index_put_ returns a tensor of the source's shape for float and int32."""
    def check(device, dtype):
        target = torch.ones(3, 2, 4, device=device, dtype=dtype)
        update = torch.ones(3, 2, 4, device=device, dtype=dtype)
        # the index tensor is created without a device argument
        order = (torch.tensor([0, 2, 1]),)
        result = target.index_put_(order, update, accumulate=True)
        self.assertEqual(result.shape, target.shape)

    for dtype in (torch.float, torch.int32):
        check(device="mps", dtype=dtype)
def test_index_src_datatype(self):
    """List indexing and list-index assignment preserve shape across several dtypes."""
    def check(device, dtype):
        if dtype is torch.bool:
            # Bool sources are built as uint8 ones and converted by comparison.
            source = torch.ones(3, 2, 4, device=device, dtype=torch.uint8) == 1
        else:
            source = torch.ones(3, 2, 4, device=device, dtype=dtype)

        # test index
        gathered = source[[0, 2, 1], :, :]
        self.assertEqual(gathered.shape, source.shape)

        # test index_put, no accum
        source[[0, 2, 1], :, :] = gathered
        self.assertEqual(gathered.shape, source.shape)

    for dtype in (torch.float, torch.float16, torch.long, torch.bool):
        check(device="mps", dtype=dtype)
def test_int_indices2d(self, device="mps"):
    """2-D row/column index tensors pick the four corners of a 4x3 matrix."""
    # From the NumPy indexing example
    grid = torch.arange(0, 12, device=device).view(4, 3)
    row_ix = torch.tensor([[0, 0], [3, 3]], device=device)
    col_ix = torch.tensor([[0, 2], [0, 2]], device=device)
    corners = grid[row_ix, col_ix]
    self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])
def test_int_indices_broadcast(self, device="mps"):
    """A column-shaped row index broadcasts against a 1-D column index."""
    # From the NumPy indexing example
    grid = torch.arange(0, 12, device=device).view(4, 3)
    row_ix = torch.tensor([0, 3], device=device)
    col_ix = torch.tensor([0, 2], device=device)
    corners = grid[row_ix[:, None], col_ix]
    self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])
def test_empty_index(self, device="mps"):
    """Empty index tensors and all-False masks select nothing and assign nothing."""
    original = torch.arange(0, 12, device=device).view(4, 3)
    no_rows = torch.tensor([], dtype=torch.long, device=device)
    self.assertEqual(original[no_rows].numel(), 0)

    # empty assignment should have no effect but not throw an exception
    copy_ = original.clone()
    copy_[no_rows] = -1
    self.assertEqual(original, copy_)

    all_false = torch.zeros(4, 3, device=device).bool()
    copy_[all_false] = -1
    self.assertEqual(original, copy_)
def test_empty_ndim_index(self, device="mps"):
    """Multi-dimensional empty index tensors yield correspondingly shaped empty results."""
    vec = torch.randn(5, device=device)
    empty_0x2 = torch.empty(0, 2, dtype=torch.int64, device=device)
    self.assertEqual(torch.empty(0, 2, device=device), vec[empty_0x2])

    four_d = torch.randn(2, 3, 4, 5, device=device)
    empty_0x6 = torch.empty(0, 6, dtype=torch.int64, device=device)
    self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device),
                     four_d[:, empty_0x6])

    zero_cols = torch.empty(10, 0, device=device)
    self.assertEqual(zero_cols[[1, 2]].shape, (2, 0))
    self.assertEqual(zero_cols[[], []].shape, (0,))
    # a non-empty index into the size-0 dimension must be rejected
    with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
        zero_cols[:, [0, 1]]
def test_empty_ndim_index_bool(self, device="mps"):
    """Indexing a 1-D tensor with an empty (0, 2) uint8 mask raises IndexError."""
    vec = torch.randn(5, device=device)
    with self.assertRaises(IndexError):
        vec[torch.empty(0, 2, dtype=torch.uint8, device=device)]
def test_empty_slice(self, device="mps"):
    """An empty slice of a strided view keeps the expected shape, strides and contiguity."""
    four_d = torch.randn(2, 3, 4, 5, device=device)
    last_fixed = four_d[:, :, :, 1]
    emptied = last_fixed[:, 1:1, :]
    self.assertEqual((2, 0, 4), emptied.shape)
    # this isn't technically necessary, but matches NumPy stride calculations.
    self.assertEqual((60, 20, 5), emptied.stride())
    self.assertTrue(emptied.is_contiguous())
def test_empty_reduce(self, device="mps"):
    """Reductions over a (0, 3) tensor return NaN, zero, or raise, depending on the op."""
    empty = torch.rand(0, 3, device=device)

    # mean/median style reductions have no defined value here
    for reduce_to_nan in (empty.mean, empty.nanmean, empty.median, empty.nanmedian):
        self.assertTrue(reduce_to_nan().isnan())

    # counting and summing reductions give zero
    for reduce_to_zero in (empty.count_nonzero, empty.sum, empty.nansum):
        self.assertEqual(reduce_to_zero(), 0)

    # amax/amin raise: RuntimeError for the full reduction, IndexError along dim 0
    for extreme in (empty.amax, empty.amin):
        self.assertRaises(RuntimeError, extreme)
        self.assertRaises(IndexError, extreme, dim=0)
def test_index_getitem_copy_bools_slices(self, device="mps"):
    """Scalar bool/uint8 indices return new storage; ``None`` and ``...`` share storage."""
    byte_one = torch.tensor(1, dtype=torch.uint8, device=device)
    byte_zero = torch.tensor(0, dtype=torch.uint8, device=device)

    candidates = [torch.randn(2, 3, device=device), torch.tensor(3., device=device)]
    for a in candidates:
        base_ptr = a.data_ptr()
        empty_like_a = torch.empty(0, *a.shape)
        # truthy index: result lives at a different address
        self.assertNotEqual(base_ptr, a[True].data_ptr())
        # falsy index: result is empty with a leading 0 dimension
        self.assertEqual(empty_like_a, a[False])
        self.assertNotEqual(base_ptr, a[byte_one].data_ptr())
        self.assertEqual(torch.empty(0, *a.shape), a[byte_zero])
        # None / Ellipsis: same address as the source
        self.assertEqual(base_ptr, a[None].data_ptr())
        self.assertEqual(base_ptr, a[...].data_ptr())
def test_index_setitem_bools_slices(self, device="mps"):
    """Assignment via scalar bool/uint8, ``None`` and ``...`` indices writes or is a no-op."""
    byte_one = torch.tensor(1, dtype=torch.uint8, device=device)
    byte_zero = torch.tensor(0, dtype=torch.uint8, device=device)

    candidates = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]
    for a in candidates:
        # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
        # (some of these ops already prefix a 1 to the size)
        minus_one = torch.ones_like(a) * -1
        minus_one_padded = minus_one.unsqueeze(0).unsqueeze(0)

        # Python bool indices: True writes, False leaves the tensor untouched
        a[True] = minus_one_padded
        self.assertEqual(a, minus_one)
        a[False] = 5
        self.assertEqual(a, minus_one)

        # uint8 scalar tensors behave the same way
        a[byte_one] = minus_one_padded * 2
        self.assertEqual(a, minus_one * 2)
        a[byte_zero] = 5
        self.assertEqual(a, minus_one * 2)

        a[None] = minus_one_padded * 3
        self.assertEqual(a, minus_one * 3)
        a[...] = minus_one_padded * 4
        self.assertEqual(a, minus_one * 4)

        # a full slice is only invalid on the 0-dim tensor
        if a.dim() == 0:
            with self.assertRaises(IndexError):
                a[:] = minus_one_padded * 5
def test_index_scalar_with_bool_mask(self, device="mps"):
    """0-dim uint8 and bool masks give equal results and dtypes on 0-dim tensors."""
    byte_mask = torch.tensor(True, dtype=torch.uint8, device=device)
    bool_mask = torch.tensor(True, dtype=torch.bool, device=device)
    scalars = (
        torch.tensor(1, device=device),
        torch.tensor(True, dtype=torch.bool, device=device),
    )
    for a in scalars:
        self.assertEqual(a[byte_mask], a[bool_mask])
        self.assertEqual(a[byte_mask].dtype, a[bool_mask].dtype)
def test_setitem_expansion_error(self, device="mps"):
    """Assigning a value with a non-1 leading prefix through a bool index raises RuntimeError."""
    tensor_true = torch.tensor(True, device=device)
    a = torch.randn(2, 3, device=device)
    # check prefix with non-1s doesn't work
    oversized = a.expand(torch.Size([5, 1]) + a.size())
    # NumPy: ValueError
    for truthy_index in (True, tensor_true):
        with self.assertRaises(RuntimeError):
            a[truthy_index] = oversized
def test_getitem_scalars(self, device="mps"):
    """0-dim integer tensors index like Python ints; 0-dim tensors reject int/slice indices."""
    idx0 = torch.tensor(0, dtype=torch.int64, device=device)
    idx1 = torch.tensor(1, dtype=torch.int64, device=device)

    # non-scalar indexed with scalars
    matrix = torch.randn(2, 3, device=device)
    self.assertEqual(matrix[0], matrix[idx0])
    self.assertEqual(matrix[0][1], matrix[idx0][idx1])
    self.assertEqual(matrix[0, 1], matrix[idx0, idx1])
    self.assertEqual(matrix[0, idx1], matrix[idx0, 1])

    # indexing by a scalar should slice (not copy)
    self.assertEqual(matrix[0, 1].data_ptr(), matrix[idx0, idx1].data_ptr())
    row1_ptr = matrix[1].data_ptr()
    self.assertEqual(row1_ptr, matrix[idx1.int()].data_ptr())
    self.assertEqual(matrix[1].data_ptr(), matrix[idx1.short()].data_ptr())

    # scalar indexed with scalar
    scalar = torch.randn((), device=device)
    with self.assertRaises(IndexError):
        scalar[:]
    with self.assertRaises(IndexError):
        scalar[idx0]
    self.assertEqual(scalar, scalar[...])
def test_setitem_scalars(self, device="mps"):
    """Assignment through a 0-dim integer tensor index behaves like a Python int index."""
    # note: this index tensor is created without a device argument
    idx0 = torch.tensor(0, dtype=torch.int64)

    # non-scalar indexed with scalars
    matrix = torch.randn(2, 3, device=device)
    via_int = matrix.clone()
    via_tensor = matrix.clone()
    new_row = torch.randn(3, device=device)
    via_int[0] = new_row
    via_tensor[idx0] = new_row
    self.assertEqual(via_int, via_tensor)
    matrix[1, idx0] = 7.7
    self.assertEqual(7.7, matrix[1, 0])

    # scalar indexed with scalars
    scalar = torch.randn((), device=device)
    with self.assertRaises(IndexError):
        scalar[:] = 8.8
    with self.assertRaises(IndexError):
        scalar[idx0] = 8.8
    scalar[...] = 9.9
    self.assertEqual(9.9, scalar)
def test_basic_advanced_combined(self, device="mps"):
    """Slice + list indexing reads like a slice, returns a copy, and still assigns in place."""
    # From the NumPy indexing example
    grid = torch.arange(0, 12, device=device).view(4, 3)
    self.assertEqual(grid[1:2, 1:3], grid[1:2, [1, 2]])
    self.assertEqual(grid[1:2, 1:3].tolist(), [[4, 5]])

    # Check that it is a copy
    snapshot = grid.clone()
    grid[1:2, [1, 2]].zero_()
    self.assertEqual(grid, snapshot)

    # But assignment should modify the original
    snapshot = grid.clone()
    grid[1:2, [1, 2]] = 0
    self.assertNotEqual(grid, snapshot)
def test_int_assignment(self, device="mps"):
    """Assigning a scalar or a row tensor to an integer-indexed row overwrites that row."""
    # scalar broadcast into the row
    grid = torch.arange(0, 4, device=device).view(2, 2)
    grid[1] = 5
    self.assertEqual(grid.tolist(), [[0, 1], [5, 5]])

    # tensor of matching length written element-wise
    grid = torch.arange(0, 4, device=device).view(2, 2)
    replacement = torch.arange(5, 7, device=device)
    grid[1] = replacement
    self.assertEqual(grid.tolist(), [[0, 1], [5, 6]])
def test_byte_tensor_assignment(self, device="mps"):
    """Assigning through a ByteTensor mask writes masked rows and records one warning."""
    grid = torch.arange(0., 16, device=device).view(4, 4)
    byte_mask = torch.ByteTensor([True, False, True, False]).to(device)
    row_values = torch.tensor([3., 4., 5., 6.], device=device)

    with warnings.catch_warnings(record=True) as caught:
        grid[byte_mask] = row_values
        self.assertEqual(len(caught), 1)

    # rows 0 and 2 were masked in; rows 1 and 3 keep their arange values
    self.assertEqual(grid[0], row_values)
    self.assertEqual(grid[1], torch.arange(4., 8, device=device))
    self.assertEqual(grid[2], row_values)
    self.assertEqual(grid[3], torch.arange(12., 16, device=device))
def test_variable_slicing(self, device="mps"):
    """Slice bounds taken from an IntTensor behave like the plain ints 0 and 1."""
    grid = torch.arange(0, 16, device=device).view(4, 4)
    bounds = torch.IntTensor([0, 1]).to(device)
    start, stop = bounds
    self.assertEqual(grid[start:stop], grid[0:1])
def test_ellipsis_tensor(self, device="mps"):
    """An index tensor combined with ``...`` selects columns or rows of a 3x3 matrix."""
    grid = torch.arange(0, 9, device=device).view(3, 3)
    picks = torch.tensor([0, 2], device=device)
    columns = grid[..., picks]
    rows = grid[picks, ...]
    self.assertEqual(columns.tolist(), [[0, 2],
                                        [3, 5],
                                        [6, 8]])
    self.assertEqual(rows.tolist(), [[0, 1, 2],
                                     [6, 7, 8]])
def test_invalid_index(self, device="mps"):
    """String slice bounds are rejected with a TypeError mentioning 'slice indices'."""
    grid = torch.arange(0, 16, device=device).view(4, 4)
    with self.assertRaisesRegex(TypeError, 'slice indices'):
        grid["0":"1"]
def test_out_of_bound_index(self, device="mps"):
    """Out-of-range integer indices raise IndexError naming the index, dimension and size."""
    block = torch.arange(0, 100, device=device).view(2, 5, 10)
    whole = slice(None)
    # (expected message, index tuple); the last entry is the tuple form of x[:, :, 12]
    cases = (
        ('index 5 is out of bounds for dimension 1 with size 5', (0, 5)),
        ('index 4 is out of bounds for dimension 0 with size 2', (4, 5)),
        ('index 15 is out of bounds for dimension 2 with size 10', (0, 1, 15)),
        ('index 12 is out of bounds for dimension 2 with size 10', (whole, whole, 12)),
    )
    for message, index in cases:
        with self.assertRaisesRegex(IndexError, message):
            block[index]
def test_zero_dim_index(self, device="mps"):
    """A 0-dim tensor equals its ``item()`` and rejects integer indexing.

    The previous helper wrapped the failing expression in a leftover debug
    ``print``; since ``x[0]`` is expected to raise while being evaluated, the
    print added nothing and has been dropped.
    """
    x = torch.tensor(10, device=device)
    self.assertEqual(x, x.item())

    with self.assertRaisesRegex(IndexError, 'invalid index'):
        x[0]
def test_cpu_indices(self, device="mps"):
    """An index tensor created without a device works for both index_put_ and index."""
    cpu_index = torch.tensor([0, 1])
    zeros = torch.zeros(2, device=device)
    target = torch.ones(10, device=device)
    target[cpu_index] = zeros  # index_put_
    expected = torch.ones(10, device=device)
    expected[:2] = 0
    self.assertEqual(target, expected, atol=0, rtol=0)

    gathered = target[cpu_index]  # index
    self.assertEqual(gathered, torch.zeros(2, device=device), atol=0, rtol=0)
def test_nextafter(self, device="mps"):
    """nextafter on the device moves each value in the same direction as the CPU result."""
    for dtype in (torch.float16, torch.float32):
        start = torch.tensor([1, -1, 0, 0, 2, -2], device=device, dtype=dtype)
        toward = torch.tensor([2, -2, -1, 1, -3, 3], device=device, dtype=dtype)
        stepped_dev = torch.nextafter(start, toward)
        stepped_cpu = torch.nextafter(start.cpu(), toward.cpu())
        # greater is broken on MPS, see https://github.com/pytorch/pytorch/issues/125051
        # so both comparisons are done on CPU copies.
        moved_up_dev = stepped_dev.cpu() > start.cpu()
        moved_up_cpu = stepped_cpu > start.cpu()
        self.assertEqual(moved_up_dev, moved_up_cpu)
class TestRNNMPS(TestCaseMPS):
    """CPU-vs-device parity tests for nn.LSTM and input-validation tests for RNN cells."""

    def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False,
                     seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False):
        """Run one LSTM configuration on CPU and on ``device`` and compare the results.

        The forward outputs and final states are always compared. When
        ``backward`` is true, gradients w.r.t. the input, the initial states
        and every parameter are compared as well, for several combinations of
        which outputs contribute to the loss.
        """
        rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            bidirectional=bidirectional,
            batch_first=batch_first,
            device="cpu"
        )
        bidirectional_mul = 2 if bidirectional else 1

        # Only the input layout depends on batch_first; hx/cx have the same
        # shape in both cases and are drawn in the same order as before.
        if batch_first:
            input_shape = (batch_size, seq_len, input_size)
        else:
            input_shape = (seq_len, batch_size, input_size)
        state_shape = (num_layers * bidirectional_mul, batch_size, hidden_size)
        input = torch.randn(*input_shape, device="cpu", dtype=dtype, requires_grad=backward)
        hx = torch.randn(*state_shape, device="cpu", dtype=dtype, requires_grad=backward)
        cx = torch.randn(*state_shape, device="cpu", dtype=dtype, requires_grad=backward)

        cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))

        rnn = rnn.to(device)
        input = input.to(device)
        hx = hx.to(device)
        cx = cx.to(device)
        output, (hn, cn) = rnn(input, (hx, cx))

        self.assertEqual(cpu_output, output)
        self.assertEqual(cpu_hn, hn)
        self.assertEqual(cpu_cn, cn)

        def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
            """Forward on ``device`` and return the output plus all gradients of a scalar loss."""
            rnn = rnn.to(device)
            inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)

            output, (hx_out, cx_out) = rnn(inp, (hx, cx))
            assert output_grad_presented or states_grad_presented, "At least some outputs must be used"

            # Build the loss only from the outputs selected by the two flags.
            f = 0
            if output_grad_presented:
                f = f + 3 * output.sum()
            if states_grad_presented:
                f = f + (hx_out * cx_out).sum()

            param_names, params = zip(*rnn.named_parameters())
            # retain_graph so the second grad call below can reuse the graph
            param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))

            input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
            return output, param_grads, input_grad, hx_grad, cx_grad

        if backward:
            grad_cases = [
                dict(output_grad_presented=True, states_grad_presented=True),
                dict(output_grad_presented=False, states_grad_presented=True),
                dict(output_grad_presented=True, states_grad_presented=False),
            ]

            for grad_case in grad_cases:
                cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
                    get_backward_results(rnn, "cpu", input, hx, cx, **grad_case)
                mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
                    get_backward_results(rnn, device, input, hx, cx, **grad_case)

                self.assertEqual(cpu_hx_grad, mps_hx_grad)
                self.assertEqual(cpu_cx_grad, mps_cx_grad)
                self.assertEqual(cpu_output, mps_output)
                self.assertEqual(cpu_input_grad, mps_input_grad)

                for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
                    self.assertEqual(cpu_weight_grad, mps_weight_grad,
                                     f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")

    # Keyword-argument sets for _lstm_helper: every combination of
    # batch_first, bias=False and bidirectional.
    LSTM_TEST_CASES = [
        {},  # default
        dict(batch_first=True),
        dict(bias=False),
        dict(bidirectional=True),
        dict(batch_first=True, bias=False),
        dict(bidirectional=True, bias=False),
        dict(bidirectional=True, batch_first=True),
        dict(bidirectional=True, batch_first=True, bias=False)
    ]

    def test_lstm_forward(self, device="mps", dtype=torch.float32):
        """Forward parity for 1, 2 and 5 layers over all LSTM_TEST_CASES."""
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)

    def test_lstm_backward(self, device="mps", dtype=torch.float32):
        """Forward and gradient parity for 1, 2 and 5 layers over all LSTM_TEST_CASES."""
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)

    def test_RNN_cell_no_broadcasting(self):
        """RNN/GRU/LSTM cells must raise RuntimeError instead of broadcasting mismatched sizes."""
        def test(cell_module, input, hx, input_size, hidden_size):
            cell = cell_module(input_size, hidden_size, device='mps')
            self.assertRaises(RuntimeError, lambda: cell(input, hx))

        def test_all(hidden_size, bad_hx, good_hx, input_size, input):
            test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
            test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
            # For LSTM the bad tensor is tried in both the h and c positions.
            test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
            test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)

        hidden_size = 20
        input_size = 10
        input = torch.randn(3, input_size, device='mps')
        bad_hx = torch.randn(1, hidden_size, device='mps')
        good_hx = torch.randn(3, hidden_size, device='mps')

        # Test hidden/input batch size broadcasting
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test hx's hidden_size vs module's hidden_size broadcasting
        # Created on MPS like every other tensor here, so the expected error
        # can only come from the size mismatch and not from mixing devices.
        bad_hx = torch.randn(3, 1, device='mps')
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test input's input_size vs module's input_size broadcasting
        bad_input = torch.randn(3, 1, device='mps')
        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)

    def test_LSTM_cell(self):
        """Smoke test: six chained LSTMCell steps followed by backward, with and without bias."""
        # this is just a smoke test; these modules are implemented through
        # autograd so no Jacobian test is needed
        for bias in (True, False):
            input = torch.randn(3, 10, device='mps')
            hx = torch.randn(3, 20, device='mps')
            cx = torch.randn(3, 20, device='mps')
            lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
            for _ in range(6):
                hx, cx = lstm(input, (hx, cx))

            (hx + cx).sum().backward()

    def test_LSTM_cell_forward_input_size(self):
        """An input with 11 features must be rejected by a cell built for 10."""
        input = torch.randn(3, 11, device='mps')
        hx = torch.randn(3, 20, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))

    def test_LSTM_cell_forward_hidden_size(self):
        """A state of width 21 must be rejected in either the h or the c position."""
        input = torch.randn(3, 10, device='mps')
        hx = torch.randn(3, 21, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
        self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))
class TestFallbackWarning(TestCase):
    """Tests for import-time warnings and for ops without an MPS implementation."""

    # TODO: Remove once test_testing.py is running on MPS devices
    def test_no_warning_on_import(self):
        # A fresh interpreter importing torch with all warnings enabled must
        # print nothing on stdout or stderr.
        out = subprocess.check_output(
            [sys.executable, "-W", "always", "-c", "import torch"],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),
        ).decode("utf-8")
        self.assertEqual(out, "")

    def _get_not_implemented_op(self):
        """Return ``(fn, args, kwargs, source)`` for an op with no MPS kernel.

        ``source`` is the same call spelled as Python source so that it can be
        embedded in a script executed by a child interpreter.
        """
        # This can be changed once we actually implement 'lcm'
        # Should return fn, args, kwargs, string_version
        return (torch.lcm,
                [torch.tensor([1], device='mps'), torch.tensor([2], device='mps')], {},
                "torch.lcm(torch.tensor([1], device='mps'), torch.tensor([2], device='mps'))")

    def test_error_on_not_implemented(self):
        # Without the fallback enabled the op must raise NotImplementedError.
        fn, args, kwargs, _ = self._get_not_implemented_op()

        with self.assertRaisesRegex(NotImplementedError, "not currently implemented for the MPS device"):
            fn(*args, **kwargs)

    def test_warn_on_not_implemented_with_fallback(self):
        # Runs the unimplemented op in a child interpreter with
        # PYTORCH_ENABLE_MPS_FALLBACK=1.  The child exits with 1 if importing
        # torch warned, and with 2 if the op did not emit exactly one warning.
        _, _, _, op = self._get_not_implemented_op()
        script = f"""
import os
# MUST happen before pytorch's import
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import warnings

with warnings.catch_warnings(record=True) as w:
    import torch

if len(w) > 0:
    print(w)
    exit(1)

# This should run just fine and raise warning about perf
with warnings.catch_warnings(record=True) as w:
    {op}

if len(w) != 1:
    print(w)
    exit(2)
"""
        try:
            subprocess.check_output(
                [sys.executable, '-W', 'always', '-c', script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),
            )
        except subprocess.CalledProcessError as e:
            # Decode once so every failure message shows readable text
            # rather than a bytes repr.
            child_output = e.output.decode("utf-8", errors="replace")
            if e.returncode == 1:
                self.fail("There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set." +
                          child_output)
            elif e.returncode == 2:
                self.fail("There wasn't exactly one warning when running not implemented op with "
                          f"PYTORCH_ENABLE_MPS_FALLBACK set. {child_output}")
            else:
                self.fail("Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set. " +
                          child_output)
class TestNoRegression(TestCase):
    """Small regression checks for MPS tensors."""

    def test_assert_close(self):
        one = torch.ones(1, device="mps")
        zero = torch.zeros(1, device="mps")
        pos_inf = one / zero
        # Computed but not asserted on; see the TODO below.
        not_a_number = zero / zero

        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
            torch.testing.assert_close(one, pos_inf)

        # TODO: The NaN test is failing when all the tests in test_mps are run
        # together but passes when run separately. There seems to be memory
        # corruption which needs to be fixed for this test to be enabled.
        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
        #     torch.testing.assert_close(a, nan)

    def test_double_error(self):
        unsupported = "the MPS framework doesn't support float64"

        # Direct construction as float64 must be rejected ...
        with self.assertRaisesRegex(TypeError, unsupported):
            torch.ones(2, dtype=torch.float64, device="mps")

        # ... and so must converting an existing MPS tensor.
        single = torch.ones(2, device="mps")
        with self.assertRaisesRegex(TypeError, unsupported):
            single.double()

    def test_legacy_constructor(self):
        # Only checks that the legacy `Tensor.new` call does not raise.
        torch.ones(2, device="mps").new(1)

    def test_serialization_map_location(self):
        # Each row: (kwargs for torch.rand, kwargs for torch.load, expected device type)
        round_trips = (
            # Ensures that cpu Tensor can be loaded on mps
            ({}, {"map_location": "mps"}, "mps"),
            # Ensures that mps Tensors can be loaded on mps
            ({"device": "mps"}, {}, "mps"),
            # Ensures that mps Tensors can be loaded on cpu
            ({"device": "mps"}, {"map_location": "cpu"}, "cpu"),
            # Ensures that `mps:0` Tensors can be loaded on mps
            ({"device": "mps:0"}, {"map_location": "mps:0"}, "mps"),
        )
        for rand_kwargs, load_kwargs, expected_type in round_trips:
            with tempfile.NamedTemporaryFile() as f:
                saved = torch.rand(2, **rand_kwargs)
                torch.save(saved, f)

                f.seek(0)
                loaded = torch.load(f, **load_kwargs)

                self.assertEqual(saved, loaded)
                self.assertEqual(loaded.device.type, expected_type)
# Dtypes passed as `allowed_dtypes` to the `@ops` decorator of
# `TestConsistency.test_output_grad_match`.
MPS_GRAD_DTYPES = [torch.float32, torch.float16]
def transform_opinfo_sample_to_cpu(sample):
    """Transforms opinfo.core.SampleInput from MPS to CPU

    Every tensor visited by ``sample.transform`` is replaced by a detached CPU
    tensor that keeps the original's ``requires_grad`` flag and conjugate bit;
    non-tensor values are passed through unchanged.  A ``device`` keyword
    argument equal to ``"mps:0"`` is rewritten to ``"cpu"``.

    Returns the transformed sample produced by ``sample.transform``.
    """
    def transform_sample(x):
        if not isinstance(x, torch.Tensor):
            return x
        # Always start from the detached tensor so the CPU copy is a fresh
        # leaf and never stays attached to the autograd graph of ``x``.
        detached = x.detach()
        if x.is_conj():
            # conj() -> cpu() -> conj() keeps the result flagged as conjugated.
            cpu_copy = detached.conj().cpu().conj()
        else:
            cpu_copy = detached.cpu()
        return cpu_copy.requires_grad_(x.requires_grad)

    cpu_sample = sample.transform(transform_sample)

    # Transform kwargs `device="mps:0"` to `device="cpu"`
    if cpu_sample.kwargs.get("device", "") == "mps:0":
        cpu_sample.kwargs["device"] = "cpu"
    return cpu_sample
class TestConsistency(TestCaseMPS):
    """Compares operator results (and gradients) computed on MPS with CPU.

    The OpInfo-driven tests run each sample on both devices and compare the
    outputs using per-op tolerances from ``_compute_tolerances`` plus a few
    inline overrides.
    """

    # TODO: This is only used while some ops are being added.
    # This list should contain all ops and dtypes eventually
    # This can be generated automatically in the `new_mps_allowlist.txt` file
    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
    # You most likely do NOT want to modify this manually

    # Ops that get relaxed tolerances for bfloat16 only.
    BF16_LOW_PRECISION_LIST = {
        'nn.functional.linear',
        'nn.functional.gaussian_nll_loss',
    }
    # Ops that get relaxed tolerances for both float16 and bfloat16.
    FP16_LOW_PRECISION_LIST = {
        'add', 'sub', 'div', 'addcdiv',
        '__rdiv__', '__rmul__',
        'nn.functional.huber_loss',
        'true_divide', 'kron',
        'gradient', 'var', 'std', 'std_mean', 'ldexp',
        'linalg.vector_norm', 'lerp',
        'addr', 'var_mean',
        'var_mean_unbiased',
        'acosh', 'asinh', 'asin',
        'masked.std',
        'nn.functional.avg_pool2d',  # NS: Only for backward pass
        'nn.functional.normalize',
        'nn.functional.triplet_margin_loss',
        'nn.functional.triplet_margin_with_distance_loss',
        'nn.functional.batch_norm',
        # NOTE: nn.functional.group_norm is here because 1 ULP difference in the mean
        # output from the forward pass (tolerable) blew up into 8 ULP difference from
        # the backward pass, and MPS uses fp16 accumulation anyway.
        'nn.functional.group_norm',
        'nn.functional.instance_norm',
        'round', 'xlogy', 'addcmul',
        'nn.functional.cross_entropy',
        'nn.functional.binary_cross_entropy',
        'nn.functional.nll_loss',
        'nn.functional.max_pool2d',
        'nn.functional.gelu',
        'nn.functional.glu',
        '_native_batch_norm_legit',
        '_batch_norm_with_update',
        'native_batch_norm',
        'softmax',
        '_softmax_backward_data',
        'log_softmax',
        'masked.softmax',
        'masked.log_softmax',
        'masked.softmin',
        'nn.functional.kl_div',
        'nn.functional.softmin',
        'cross', 'linalg.cross',
        'prod', 'masked.prod',
        'nextafter',
        'native_layer_norm',
        'nn.functional.layer_norm',
        'nn.functional.interpolate',
        'nn.functional.upsample_nearest',
        'norm', 'masked.normalize',
        'arange', 'linspace',
        'special.xlog1py',

        # CPU accumulates sequentially, but GPU does in parallel
        '_unsafe_masked_index_put_accumulate',
    }

    # Ops that get relaxed tolerances for float32 and complex64.
    FP32_LOW_PRECISION_LIST = {
        # conv2d, conv_transpose2d and conv_transpose3d results have a very small
        # difference compared to CPU/CUDA, so we use lower precision on FP32
        'nn.functional.conv2d',
        'nn.functional.conv_transpose2d',
        'nn.functional.conv_transpose3d',
        'matmul', '__rmatmul__',
        'linalg.multi_dot',
        'addbmm',
    }

    def _compute_tolerances(self, op, dtype):
        """Return the ``(atol, rtol)`` override for ``op`` at ``dtype``.

        The checks are ordered and the first match wins.  ``(None, None)`` is
        returned when no override applies.
        """
        if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype in [torch.float32, torch.complex64]:
            return (1e-4, 3e-5)

        if op.name in self.FP16_LOW_PRECISION_LIST and dtype in [torch.float16, torch.bfloat16]:
            return (2e-2, 1e-2) if dtype == torch.float16 else (5e-2, 5e-2)

        if op.name in self.BF16_LOW_PRECISION_LIST and dtype == torch.bfloat16:
            return (5e-2, 5e-2)

        if op.name in ['nn.functional.conv_transpose1d',
                       'nn.functional.conv_transpose2d',
                       'nn.functional.conv_transpose3d',
                       '__rmatmul__', 'addbmm', 'addmv',
                       'baddbmm', 'cov', 'matmul', 'mv'] and dtype in [torch.float16, torch.bfloat16]:
            return (5e-2, 5e-2) if dtype == torch.float16 else (5e-2, 1e-1)
        if op.name == "masked.mean":
            return (7e-4, 2e-3)
        if op.name == "native_layer_norm":
            return (1e-4, 1.3e-5)
        if op.name in ['fft.rfftn', 'fft.hfftn', 'fft.hfft2', 'fft.fft', 'fft.fftn', 'fft.rfft']:
            # TODO: Investigate why this is needed
            # See https://github.com/pytorch/pytorch/issues/120237
            return (3e-5, 3e-5)
        # TODO: Rounding is broken for linspace, see https://github.com/pytorch/pytorch/issues/137635
        if op.name == 'linspace' and dtype in [torch.int8, torch.uint8, torch.int32, torch.int16, torch.int64]:
            return (1.0, 0.0)
        return (None, None)

    def _build_cpu_and_mps_args(self, op, dtype, mps_sample):
        """Return ``(cpu_args, cpu_kwargs, mps_args, mps_kwargs)`` for one sample.

        The CPU arguments come from ``transform_opinfo_sample_to_cpu``; the
        op-specific fix-ups shared by the forward and gradient tests are
        applied to both argument lists here.
        """
        cpu_sample = transform_opinfo_sample_to_cpu(mps_sample)

        cpu_args = [cpu_sample.input] + list(cpu_sample.args)
        cpu_kwargs = cpu_sample.kwargs
        mps_args = [mps_sample.input] + list(mps_sample.args)
        mps_kwargs = mps_sample.kwargs

        # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
        if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
            mps_args[1] = cpu_args[1]

        # Order of ops in index_put is not guaranteed, which can lead to large errors if inputs are
        # not normalized
        if op.name == "_unsafe_masked_index_put_accumulate" and dtype in [torch.bfloat16, torch.float16]:
            mps_args[3] = F.normalize(mps_args[3])
            cpu_args[3] = F.normalize(cpu_args[3])

        return cpu_args, cpu_kwargs, mps_args, mps_kwargs

    # Used for accept mode only
    NEW_ALLOW_LIST = defaultdict(list)
    NEW_ALLOW_LIST_GRAD = defaultdict(list)

    @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES)
    def test_output_match(self, device, dtype, op):
        # Forward-only comparison of MPS results against CPU for every sample.
        self.assertEqual(device, "mps:0")
        include_conjugated_inputs = dtype.is_complex and op.test_conjugated_samples

        def get_samples():
            return op.sample_inputs(
                device,
                dtype,
                requires_grad=(dtype.is_floating_point or dtype.is_complex),
                include_conjugated_inputs=include_conjugated_inputs,
                set_seed=True,
            )

        for mps_sample in get_samples():
            #
            # Forward check
            #
            cpu_args, cpu_kwargs, mps_args, mps_kwargs = self._build_cpu_and_mps_args(op, dtype, mps_sample)

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                cpu_out = op(*cpu_args, **cpu_kwargs)
                mps_out = op(*mps_args, **mps_kwargs)

            atol, rtol = self._compute_tolerances(op, dtype)
            if (op.name == "nn.functional.interpolate" and dtype == torch.uint8 and
               cpu_kwargs.get("mode") == "bilinear" and
               cpu_kwargs.get("recompute_scale_factor") is True and
               cpu_kwargs.get("scale_factor") == 1.7):
                # For 1/3, 2/3 scale factors results will not match CPU ones
                # As MPS compute scales in floats, but CPU always used doubles, which results
                # in slight numerical differences
                atol, rtol = 1, 0

            if op.name in ["_upsample_bilinear2d_aa", "_upsample_bicubic2d_aa"] and cpu_kwargs.get("scale_factors") == [1.7, 0.9]:
                # Similar to the above, float vs double precision results in slight error
                atol, rtol = 2e-5, 2e-6

            if op.name in ["grid_sampler_3d", "asinh"]:
                atol, rtol = 1e-4, 1e-4

            if op.name == "kthvalue":
                self.assertEqual(cpu_out[0], mps_out[0], atol=atol, rtol=rtol)
                # kthvalue is non-deterministic if input has repeated values
                # so instead of comparing indices, check that the MPS indices
                # point at the MPS values.
                dim = cpu_args[2] if len(cpu_args) > 2 else -1
                keep_dim = cpu_args[3] if len(cpu_args) > 3 else False
                values = torch.gather(mps_sample.input, dim, mps_out[1] if keep_dim else mps_out[1].unsqueeze(dim))
                self.assertEqual(values if keep_dim else values.squeeze(dim), mps_out[0])
                continue

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

    @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
    def test_output_grad_match(self, device, dtype, op):
        # Compares forward outputs and then input gradients between MPS and CPU.
        self.assertEqual(device, "mps:0")

        def get_samples():
            return op.sample_inputs(
                device,
                dtype,
                requires_grad=(dtype.is_floating_point or dtype.is_complex),
                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
                set_seed=False,
            )

        for mps_sample in get_samples():
            #
            # Forward check
            #
            cpu_args, cpu_kwargs, mps_args, mps_kwargs = self._build_cpu_and_mps_args(op, dtype, mps_sample)

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                cpu_out = op(*cpu_args, **cpu_kwargs)
                mps_out = op(*mps_args, **mps_kwargs)

            if op.name == "unique" and cpu_kwargs["sorted"] is False:
                continue

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name in ["renorm", "norm", "linalg.norm"] and dtype == torch.float16:
                atol = 7e-4
                rtol = 1.5e-3

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

            #
            # Backward check
            #
            cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
            mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)

            def req_grad(t):
                return isinstance(t, torch.Tensor) and t.requires_grad

            diff_cpu_out = tuple(t for t in cpu_out if req_grad(t))
            diff_mps_out = tuple(t for t in mps_out if req_grad(t))
            diff_cpu_arg = tuple(t for t in pytree.tree_leaves((cpu_args, cpu_kwargs)) if req_grad(t))
            diff_mps_arg = tuple(t for t in pytree.tree_leaves((mps_args, mps_kwargs)) if req_grad(t))
            self.assertEqual(len(diff_cpu_out), len(diff_mps_out))
            self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg))

            if len(diff_cpu_out) == 0:
                continue
            # rand_like does not work with certain dtypes, so cast to double and cast back
            cpu_grad_outputs = tuple(torch.rand_like(t, dtype=torch.double).to(dtype=t.dtype) for t in diff_cpu_out)
            mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs)

            # Compare computed gradients with cpu given random grad_output vector
            # Sometimes when the derivative is 0, we just don't bother creating the graph
            # allow_unused is needed in those cases.
            cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
            mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)

            # The tolerances chosen for the forward check carry over to the
            # backward check unless one of the overrides below replaces them.
            if (
                op.name == "nn.functional.pad"
                and op.variant_test_name in ["replicate", "reflect"]
                and dtype == torch.float16
            ):
                atol = 1e-5
                rtol = 1.5e-3
            if op.name == "nn.functional.unfold" and dtype == torch.float16:
                atol, rtol = 1e-3, 1e-3
            # Order of ops in unsafe_masked_index backward is not guaranteed
            # which leads to larger errors
            if op.name == "_unsafe_masked_index" and dtype == torch.float16:
                atol, rtol = 3e-3, 3e-3
            if op.name == "logcumsumexp":
                atol, rtol = 4e-3, 1e-3
            if op.name == "nn.functional.max_pool3d" and dtype == torch.float16:
                # In a few cases where stride is smaller than kernel size,
                # several output grad elements of similar magnitudes get summed
                # together, introducing significant error for float16.
                atol, rtol = 5e-3, 5e-3
            if op.name == "nn.functional.embedding_bag" and dtype == torch.float16:
                atol, rtol = 5e-3, 5e-3
            self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)

    # The CPU impl of grid_sampler_3d gives a large amount of error for half
    # precision types. So instead of testing MPS-vs-CPU outputs, test
    # full-vs-half precision dtypes for MPS.
    @dtypes(torch.float16, torch.bfloat16)
    def test_grid_sampler_3d_half_precision(self, device, dtype):
        op = next((op for op in test_consistency_op_db if op.name == "grid_sampler_3d"), None)
        include_conjugated_inputs = dtype.is_complex and op.test_conjugated_samples

        def get_samples():
            return op.sample_inputs(
                device,
                dtype,
                requires_grad=(dtype.is_floating_point or dtype.is_complex),
                include_conjugated_inputs=include_conjugated_inputs,
                set_seed=True,
            )

        for half_sample in get_samples():
            half_input = half_sample.input
            half_grid, mode, padding_mode, align_corners = half_sample.args

            # float32 copies of the same sample serve as the reference.
            full_input = half_input.to(torch.float).detach()
            full_grid = half_grid.to(torch.float).detach()

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                half_out = op(half_input, half_grid, mode, padding_mode, align_corners)
                full_out = op(full_input, full_grid, mode, padding_mode, align_corners)

            atol, rtol = 1e-4, 1e-4

            self.assertEqual(half_out, full_out.to(dtype), atol=atol, rtol=rtol)

    def test_grid_sampler_3d_nan(self, device):
        # A NaN grid coordinate must give the same result on MPS as on CPU.
        input = torch.ones(1, 1, 3, 3, 3)
        grid_nan = torch.tensor([[[[[torch.nan, 1., 1.], [1., 1., 1.]]]]])
        out_cpu = torch.grid_sampler_3d(input, grid_nan, 0, 0, True)
        out_mps = torch.grid_sampler_3d(input.to(device), grid_nan.to(device), 0, 0, True)
        self.assertEqual(out_mps, out_cpu)

    def test_fmax_mixed_dtypes(self, device):
        # Regression testing for https://github.com/pytorch/pytorch/issues/149951
        # fmax and fmin are implemented as binary metal shaders and they were implemented
        # with the assumption that both args have the same dtype
        #
        # The inputs are created on CPU so the left-hand side of each
        # comparison is a genuine CPU reference; creating them on `device`
        # would compare MPS results with themselves.
        x = torch.rand((3, 3), dtype=torch.float32)
        x_int = torch.randint(-10, 10, (3, 3), dtype=torch.int8)
        y = torch.rand((3, 3), dtype=torch.float16)
        x_mps, x_int_mps, y_mps = (t.to(device) for t in (x, x_int, y))
        for op in [torch.fmax, torch.fmin]:
            self.assertEqual(op(x, y), op(x_mps, y_mps).cpu())
            self.assertEqual(op(x_int, y), op(x_int_mps, y_mps).cpu())
            # Stride
            self.assertEqual(op(x.t(), y), op(x_mps.t(), y_mps).cpu())
            # Broadcast
            self.assertEqual(op(x, y[0]), op(x_mps, y_mps[0]).cpu())
class TestErrorInputs(TestCase):
    """Runs each OpInfo's error inputs on MPS and checks the raised error."""

    _ignore_not_implemented_error = True

    @ops(
        mps_ops_error_inputs_modifier(
            [op for op in test_error_inputs_op_db if op.error_inputs_func is not None]
        ),
        dtypes=OpDTypes.none
    )
    def test_error_inputs(self, device, op):
        self.assertEqual(device, "mps:0")

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        for error_input in op.error_inputs(device, set_seed=False):
            sample = error_input.sample_input
            call_args = [sample.input, *sample.args]
            call_kwargs = sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            is_tensor_split = op.name == "tensor_split"
            if is_tensor_split and isinstance(call_args[1], torch.Tensor):
                call_args[1] = call_args[1].cpu()

            with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                op(*call_args, **call_kwargs)
class TestComplex(TestCase):
    """Binary-op checks mixing real, half and complex operands on MPS."""

    def test_tensor_scalar_binops(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/119088
        def to_cpu(value):
            if isinstance(value, torch.Tensor):
                return value.cpu()
            return value

        # Allocate tensors on mps
        with torch.device("mps"):
            operands = [torch.rand(2, dtype=dt) for dt in (torch.float, torch.half, torch.cfloat)]
        self.assertTrue(all(t.device.type == "mps" for t in operands))
        # Add scalars (created outside the device context above)
        operands += [7, 3.14, 2 + 3j, torch.tensor(4 + 5j, dtype=torch.chalf)]

        # Iterate over all permutations of types(int, float, complex, half) and ops (excluding div)
        for x, y in itertools.product(operands, repeat=2):
            for op_name in ("__add__", "__sub__", "__mul__"):
                x_cpu = to_cpu(x)
                y_cpu = to_cpu(y)
                res = getattr(x, op_name)(y)
                res_cpu = getattr(x_cpu, op_name)(y_cpu)
                self.assertEqual(to_cpu(res), res_cpu, f"{op_name}({x}, {y}) produces different results {res} vs {res_cpu}")
# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
@skipIfSlowGradcheckEnv
class TestCommon(TestCase):
    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        if not IS_CI:
            return

        # Assure no opinfo entry has dynamic_dtypes
        dynamic_ops = [op for op in op_db if opinfo.utils.is_dynamic_dtype_set(op)]
        err_msg = (
            "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
            "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
        )
        # One line per offending operator.
        err_msg += "".join("\n" + opinfo.utils.str_format_dynamic_dtype(op) for op in dynamic_ops)

        assert not dynamic_ops, err_msg

    # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while
    # MPS still requires some fairly heavy special casing in the test framework.
    # When MPS becomes more consistent, this can probably be merged with that test using
    # `@dtypesIfMPS(torch.float32)`, but for now, the assertions themselves need to be loosened
    @suppress_warnings
    # MPS only supports float32
    @ops(_ref_test_ops, allowed_dtypes=(torch.float32,))
    def test_numpy_ref_mps(self, device, dtype, op):
        # Unlike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS
        # does not support float64 Tensors.
        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        for sample_input in op.reference_inputs(device, dtype, set_seed=False):
            self.compare_with_reference(op, op.ref, sample_input)

    @dtypes(*get_all_dtypes())
    def test_tensor_creation(self, device, dtype):
        def make_ones(target):
            return torch.ones((2, 2), dtype=dtype, device=target)

        if dtype in MPS_DTYPES:
            on_mps = make_ones(device)
            on_cpu = make_ones("cpu")
            self.assertEqual(on_mps.cpu(), on_cpu)
        else:
            # Dtypes outside MPS_DTYPES must be rejected at creation time.
            with self.assertRaises(TypeError):
                make_ones(device)
class TestMetalLibrary(TestCaseMPS):
    # Tests for `torch.mps.compile_shader`: compiling Metal source from Python,
    # dispatching its kernels on MPS tensors, argument passing and error paths.
    # The contents of the multi-line shader strings are kept verbatim.

    def test_metal_arange(self):
        # Each thread writes its own grid index; the result must equal arange.
        x = torch.zeros(12, device="mps", dtype=torch.half)
        lib = torch.mps.compile_shader("""
kernel void arange(device half* x, uint idx [[thread_position_in_grid]]) {
x[idx] = idx;
}
""")
        lib.arange(x)
        self.assertEqual(x, torch.arange(x.numel(), device='mps', dtype=x.dtype))

    def test_metal_dispatch_3d(self):
        # The same fill is dispatched along the x, y and z grid axes in turn;
        # all three outputs must be identical.
        x = torch.empty(12, device="mps")
        y = torch.empty_like(x)
        z = torch.empty_like(x)
        lib = torch.mps.compile_shader("""
kernel void arange_x(device float* x, uint3 idx [[thread_position_in_grid]]) {
x[idx.x + idx.y + idx.z] = idx.x;
}
kernel void arange_y(device float* x, uint3 idx [[thread_position_in_grid]]) {
x[idx.x + idx.y + idx.z] = idx.y;
}
kernel void arange_z(device float* x, uint3 idx [[thread_position_in_grid]]) {
x[idx.x + idx.y + idx.z] = idx.z;
}
""")
        # Check that one can enumerate all shaders
        self.assertEqual(set(dir(lib)), {f"arange_{i}" for i in ["x", "y", "z"]})
        lib.arange_x(x)
        lib.arange_y(y, threads=(1, y.numel()))
        lib.arange_z(z, threads=(1, 1, z.numel()))
        self.assertEqual(x, torch.arange(x.numel(), device='mps', dtype=x.dtype))
        self.assertEqual(x, y)
        self.assertEqual(x, z)

    def test_metal_arange_with_arg(self, start=3.14, step=.5):
        # `start` and `step` are forwarded to the kernel as scalar arguments.
        # The sibling tests below call this method with a tensor `step`.
        x = torch.zeros(12, device="mps")
        lib = torch.mps.compile_shader("""
kernel void arange(device float* x, constant float& start, constant float& step,
uint idx [[thread_position_in_grid]]) {
x[idx] = start + idx * step;
}
""")
        lib.arange(x, start, step)
        # NOTE(review): the expected value hard-codes the end (8.66) and the
        # step (.5) instead of deriving them from the parameters.
        self.assertEqual(x, torch.arange(start, 8.66, .5, device='mps'))

    def test_metal_arange_with_arg_and_scalar_tensor(self):
        # Same check with `step` given as a 0-dim tensor.
        self.test_metal_arange_with_arg(step=torch.tensor(.5))

    def test_metal_arange_with_arg_and_scalar_tensor_float64(self):
        # Same check with `step` given as a float64 0-dim tensor.
        self.test_metal_arange_with_arg(step=torch.tensor(.5, dtype=torch.float64))

    def test_metal_arange_with_arg_and_cast(self):
        # Exercises `arg_casts`, once as a single string and once as a dict
        # keyed by argument position; both kernels must give the same result.
        x = torch.zeros(12, device="mps", dtype=torch.half)
        y = torch.zeros(12, device="mps", dtype=torch.half)
        lib = torch.mps.compile_shader("""
kernel void arange_all_half(device half* x, constant half2& start_step,
uint idx [[thread_position_in_grid]]) {
x[idx] = start_step.x + idx * start_step.y;
}
kernel void arange_half_float(device half* x, constant half& start, constant float& step,
uint idx [[thread_position_in_grid]]) {
x[idx] = start + idx * step;
}
""")
        lib.arange_all_half(x, [3.14, .5], arg_casts="fp16")
        lib.arange_half_float(y, 3.14, .5, arg_casts={1: "fp16"})
        self.assertEqual(x, torch.arange(3.14, 8.66, .5, device='mps', dtype=x.dtype))
        self.assertEqual(x, y)

    def test_metal_error_checking(self):
        # Syntax error asserts
        self.assertRaises(SyntaxError, lambda: torch.mps.compile_shader("Syntax error"))
        cpu_tensor = torch.rand(3)
        mps_tensor = torch.rand(3, device="mps")
        lib = torch.mps.compile_shader("kernel void full(device half* x) { x[0] = 1.0; }")
        # Passing CPU tensor asserts
        self.assertRaises(RuntimeError, lambda: lib.full(cpu_tensor))
        # Passing invalid shader name asserts
        self.assertRaises(RuntimeError, lambda: lib.non_existing(mps_tensor))
        # Passing no tensors asserts
        self.assertRaises(RuntimeError, lambda: lib.full(12))
        # Exceeding thread group size asserts
        max_thread_group_size = lib.full.max_threads_per_threadgroup
        self.assertRaises(ValueError, lambda: lib.full(mps_tensor, group_size=max_thread_group_size + 5))
        self.assertRaises(ValueError, lambda: lib.full(mps_tensor, threads=(3, max_thread_group_size),
                                                       group_size=(3, max_thread_group_size)))

    def test_metal_include(self):
        # Checks that includes embedding works
        lib = torch.mps.compile_shader("#include <c10/metal/special_math.h>")
        self.assertIsNotNone(lib)

    @parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int64])
    def test_reduction_utils(self, dtype):
        # Checks `c10::metal::simd_sum` and `c10::metal::simd_argmax` against
        # `Tensor.sum` and `Tensor.max` on a 28-element tensor.
        from torch._inductor.codegen.mps import DTYPE_TO_METAL
        lib = torch.mps.compile_shader(f"""
#include <c10/metal/reduction_utils.h>
kernel void do_sum(device {DTYPE_TO_METAL[dtype]}* out,
constant {DTYPE_TO_METAL[dtype]}* inp,
uint idx [[thread_position_in_grid]]) {{
out[idx] = c10::metal::simd_sum(inp[idx]);
}}
kernel void do_max(device {DTYPE_TO_METAL[dtype]}* out0,
device int* out1,
constant {DTYPE_TO_METAL[dtype]}* inp,
uint idx [[thread_position_in_grid]]) {{
auto rc = c10::metal::simd_argmax(inp[idx]);
out0[idx] = rc.first;
out1[idx] = rc.second;
}}
""")
        x = torch.testing.make_tensor(28, device="mps", dtype=dtype)
        y = torch.empty_like(x)
        z0 = torch.empty_like(x)
        z1 = torch.empty_like(x, dtype=torch.int32)
        lib.do_sum(y, x)
        lib.do_max(z0, z1, x)
        x_sum = x.sum()
        x_max, x_max_idx = x.max(dim=0)
        # Every output element is expected to hold the reduction over all inputs.
        max_err = (y - x_sum).abs().max().item()
        self.assertLess(max_err, 1e-2 if dtype == torch.float16 else 1e-5,
                        f"results are {y}, but all elements should have been {x_sum.item()}")
        self.assertTrue((z0 == x_max).all().item(),
                        f"results are {z0}, but all elements should have been {x_max.item()}")
        self.assertTrue((z1 == x_max_idx).all().item(),
                        f"results are {z1}, but all elements should have been {x_max_idx.item()}")
        # Test nan propagation
        if not dtype.is_floating_point:
            return

        # A single NaN must win the max and its index must be reported.
        idx = 25
        x[idx] = torch.nan
        lib.do_max(z0, z1, x)
        self.assertTrue(z0.isnan().all().item(), f"results are {z0}, but all elements shold have been nan")
        self.assertTrue((z1 == idx).all().item(), f"results are {z1}, but all elements shold have been {idx}")

    @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16])
    def test_atomic_add(self, dtype):
        # Odd-indexed threads add their increment to slot 3 and even-indexed
        # threads to slot 4, so x[3] = 3 + (1 + 3 + ... + 15) = 67 and
        # x[4] = 4 + (0 + 2 + ... + 14) = 60.
        from torch._inductor.codegen.mps import DTYPE_TO_METAL
        mdtype = DTYPE_TO_METAL[dtype]
        lib = torch.mps.compile_shader(f"""
#include <c10/metal/atomic.h>
using namespace c10::metal;
kernel void atomic_add(device AtomicType<{mdtype}>::type* out,
constant {mdtype}* inc,
uint idx [[thread_position_in_grid]]) {{
AtomicType<{mdtype}>::atomic_add(out, idx & 1 ? 3 : 4, inc[idx]);
}}
""")
        x = torch.arange(16, device="mps", dtype=dtype)
        y = torch.arange(16, device="mps", dtype=dtype)
        lib.atomic_add(x, y)
        self.assertEqual(x[3], 67)
        self.assertEqual(x[4], 60)

    def test_argument_buffers(self):
        # Passes a tuple of 64 tensors as one kernel argument (the `Inputs`
        # struct) and checks the kernel's element-wise sum against a CPU-side loop.
        lib = torch.mps.compile_shader("""
constant constexpr auto nbuffers = 64;
struct Inputs {
metal::array<device float *, nbuffers> args;
};
kernel void sum_all(device float* output, constant Inputs& inputs, uint idx [[thread_position_in_grid]]) {
auto rc = inputs.args[0][idx];
for(auto i = 1; i < nbuffers; ++i) {
rc += inputs.args[i][idx];
}
output[idx] = rc;
}
""")
        inputs = torch.rand(64, 32, device="mps").unbind(0)
        output = torch.empty_like(inputs[0])
        lib.sum_all(output, inputs)
        correct = torch.zeros_like(inputs[0])
        for inp in inputs:
            correct += inp
        self.assertEqual(correct, output)

    @unittest.skipIf(not torch.mps.profiler.is_metal_capture_enabled(), "Set MTL_CAPTURE_ENABLED and try again")
    def test_metal_capture(self):
        # Records a GPU trace around one kernel launch and checks that the
        # capture directory appears with more than three entries.
        lib = torch.mps.compile_shader("kernel void full(device float* x, uint idx [[thread_position_in_grid]]) { x[idx] = 1.0; }")
        mps_tensor = torch.rand(32, device="mps")
        # Random five-digit suffix for the capture name.
        capture_name = f"lib_full{''.join(random.choice('0123456789') for i in range(5))}"
        capture_dirname = f"0000-{capture_name}.gputrace"
        # Remove any leftover directory with the same name before capturing.
        if os.path.exists(capture_dirname):
            shutil.rmtree(capture_dirname)
        with torch.mps.profiler.metal_capture(capture_name):
            self.assertTrue(torch.mps.profiler.is_capturing_metal())
            lib.full(mps_tensor)
        self.assertEqual(mps_tensor.sum().item(), 32.0)
        self.assertTrue(os.path.exists(capture_dirname), f"Capture file {capture_dirname} has not been generated")
        # List the contents first so the directory can be removed before the
        # final assertion runs.
        capture_listdir = os.listdir(capture_dirname)
        shutil.rmtree(capture_dirname)
        self.assertGreater(len(capture_listdir), 3,
                           f"Capture file {capture_dirname} contains only metadata, i.e. {capture_listdir}")
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the
# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
# to achieve this.
# Device-generic test classes, instantiated for the "mps" device only.
instantiate_device_type_tests(TestConsistency, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
# Test classes registered through `instantiate_parametrized_tests`.
instantiate_parametrized_tests(TestAutocastMPS)
instantiate_parametrized_tests(TestLogical)
instantiate_parametrized_tests(TestMPS)
instantiate_parametrized_tests(TestSDPA)
instantiate_parametrized_tests(TestSmoothL1Loss)
instantiate_parametrized_tests(TestMetalLibrary)

# Run the suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()