[CUDA] Fix missing __syncthreads in MultiMarginLoss backward (#158994)

It turns out the issue in #158921 is detectable with a simple unit test, and adding the missing sync fixes it.
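
For illustration only, here is a minimal CUDA sketch of the kind of race a missing __syncthreads() causes in a block-wide shared-memory reduction. This is not the actual MultiMarginLoss backward kernel; the kernel name, sizes, and reduction shape are hypothetical. The point is that every thread writes a partial value to shared memory, and without a barrier before the block-wide result is read back, some threads can consume stale or partially reduced values.

// Hypothetical sketch (not the MultiMarginLoss kernel): a shared-memory
// reduction whose correctness depends on the __syncthreads() barriers shown.
#include <cstdio>

__global__ void scale_by_block_sum(const float* in, float* out, int n) {
    __shared__ float partial[128];      // one slot per thread in the block
    int tid = threadIdx.x;

    // Each thread contributes its element to shared memory.
    partial[tid] = (tid < n) ? in[tid] : 0.0f;
    __syncthreads();                    // without this, the reduction below
                                        // may read slots not yet written

    // Naive tree reduction over shared memory; each step also needs a barrier.
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            partial[tid] += partial[tid + stride];
        }
        __syncthreads();
    }

    // Every thread uses the block-wide sum; if a barrier above were missing,
    // some threads would scale by a partially reduced (wrong) value.
    if (tid < n) {
        out[tid] = in[tid] * partial[0];
    }
}

int main() {
    const int n = 128;
    float h_in[n], h_out[n];
    for (int i = 0; i < n; ++i) h_in[i] = 1.0f;

    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    scale_by_block_sum<<<1, n>>>(d_in, d_out, n);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

    // With all-ones input, every output element should equal n (here 128).
    printf("out[0] = %f (expected %f)\n", h_out[0], (float)n);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

The unit test added below catches this class of bug the same way: it runs the CUDA backward pass and compares the resulting gradient against the CPU reference.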

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158994
Approved by: https://github.com/malfet, https://github.com/Skylion007

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Authored by eqy on 2025-07-24 20:47:24 +00:00
Committed by PyTorch MergeBot
Parent: 13398dab79
Commit: 8573a2beda
2 changed files with 20 additions and 0 deletions

@@ -9291,6 +9291,25 @@ class TestNNDeviceType(NNTestCase):
         y = torch.ones(10, 0, device=device).type(torch.long)
         mod(x, y)
 
+    @onlyCUDA
+    @dtypes(torch.float, torch.double)
+    def test_MarginLoss_race(self, device, dtype):
+        loss = torch.nn.MultiMarginLoss().to(device)
+        batch = 1
+        classes = 128
+        x = torch.randn(batch, classes, requires_grad=True, device=device, dtype=dtype)
+        y = torch.randint(low=0, high=classes, size=(batch,), device=device, dtype=torch.long)
+        x_cpu = x.detach().clone().cpu()
+        y_cpu = y.detach().clone().cpu()
+        out = loss(x, y)
+        out.backward()
+        x_cpu = x.detach().clone().cpu()
+        x_cpu.requires_grad = True
+        y_cpu = y.detach().clone().cpu()
+        out_cpu = loss.cpu()(x_cpu, y_cpu)
+        out_cpu.backward()
+        self.assertEqual(x_cpu.grad, x.grad.cpu())
+
     @onlyCUDA
     def test_MarginLoss_warnings(self, device):
         model = torch.nn.Linear(128, 22, device=device)