Compare commits

...

1 Commits

View File

@ -9667,10 +9667,12 @@ class TestSDPA(TestCaseMPS):
memory_footprints = []
for _ in range(100):
output = F.scaled_dot_product_attention(query, key, value)
# syncronize to wait for the GPU computation to return
torch.mps.synchronize()
current_mem, driver_mem = get_mps_memory_usage()
memory_footprints.append((current_mem, driver_mem))
# 5 MB different maximum allowed value(could be decreased even more)
torch.testing.assert_close(memory_footprints[-1], memory_footprints[0], atol=5, rtol=1)
# 1 kB different maximum allowed value
torch.testing.assert_close(memory_footprints[-1], memory_footprints[0], atol=1e-3, rtol=1e-3)
def generate_qkv(self, batch: int, NH: int, q_len: int, s_len: int, head_dim: int, layout: str, dtype: torch.dtype):
if layout == "contiguous":