Revert "[Inductor UT] Generalize device-bias code in test_flex_attention.py (#151937)"

This reverts commit 443840080265ce6133121c91d258b619eae151bb.

Reverted https://github.com/pytorch/pytorch/pull/151937 on behalf of https://github.com/malfet due to Broke ASAN tests, probably by enabling too many tests https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=asan&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/151937#issuecomment-2835151532))
Author: PyTorch MergeBot
Date:   2025-04-28 12:56:49 +00:00
parent 0b6ea0b959
commit 9c864f9b0f


@@ -41,6 +41,7 @@ from torch.testing._internal.common_device_type import (
dtypesIfCUDA,
flex_attention_supported_platform as supported_platform,
instantiate_device_type_tests,
largeTensorTest,
skipCPUIf,
skipCUDAIf,
)
@@ -65,6 +66,18 @@ T = TypeVar("T")
M = TypeVar("M", bound=Callable)
def large_tensor_test_class(
size: str, device: Optional[Union[torch.device, str]] = None
) -> Callable[[type[T]], type[T]]:
def decorator(cls: type[T]) -> type[T]:
for name, method in list(cls.__dict__.items()):
if callable(method) and name.startswith("test_"):
setattr(cls, name, largeTensorTest(size, device)(method))
return cls
return decorator
@contextmanager
def temp_float32_matmul_precision(precision: str):
"""
@@ -376,6 +389,7 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor):
)
@large_tensor_test_class("2GB", device="cuda")
class TestFlexAttention(InductorTestCase):
def setUp(self):
super().setUp()
@@ -1912,17 +1926,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
def causal_mask_slidewindow_mod(b, h, q_idx, kv_idx):
return (q_idx >= kv_idx) & (q_idx <= kv_idx + window_size)
mask1 = create_block_mask(
causal_mask, 1, None, 512, 512, _compile=False, device=device
)
mask1 = create_block_mask(causal_mask, 1, None, 512, 512, _compile=False)
mask2 = create_block_mask(
causal_mask_slidewindow_mod,
1,
None,
512,
512,
_compile=False,
device=device,
causal_mask_slidewindow_mod, 1, None, 512, 512, _compile=False
)
def f(q, k, v):
@@ -2579,7 +2585,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
@supported_platform
@skip_on_cpu
def test_eager_backward_strides(self, device):
def test_eager_backward_strides(self):
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -2601,8 +2607,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
x = torch.nn.attention.flex_attention.flex_attention(q, k, v)
return x
model = Repro().to(device)
x = torch.randn((1, 512, 256), device=device, requires_grad=True)
model = Repro().cuda()
x = torch.randn((1, 512, 256), device="cuda", requires_grad=True)
out = torch.compile(model, backend="aot_eager", fullgraph=True)(x)
out.backward(torch.ones_like(out))
@@ -2683,7 +2689,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
requires_grad=True,
)
block_mask = create_block_mask(mask, None, None, 4096, 4096, device=device)
block_mask = create_block_mask(mask, None, None, 4096, 4096)
# Compile 1st version with q/k/v(seqlen=4096) and block_mask(seqlen=4096)
torch.compile(flex_attention, dynamic=True, fullgraph=True)(
make_tensor(), make_tensor(), make_tensor(), block_mask=block_mask
@@ -2692,7 +2698,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
make_tensor2 = functools.partial(
torch.randn,
(4, 4, 2048, 64),
device=device,
device="cuda",
dtype=torch.float32,
requires_grad=True,
)
@@ -2709,7 +2715,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
k.grad = None
v.grad = None
block_mask2 = create_block_mask(mask, None, None, 2048, 2048, device=device)
block_mask2 = create_block_mask(mask, None, None, 2048, 2048)
# Reuse the 1st version with q/k/v(seqlen=2048) and block_mask(seqlen=2048)
out2 = torch.compile(flex_attention, dynamic=True, fullgraph=True)(
q, k, v, block_mask=block_mask2
@@ -2815,11 +2821,11 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
@supported_platform
@skip_on_cpu
def test_strided_backwards(self, device):
def test_strided_backwards(self):
shape = (1, 2, 4096, 64)
Q = torch.randn(shape, requires_grad=True, device=device)
K = torch.randn(shape, requires_grad=True, device=device)
V = torch.randn(shape, requires_grad=True, device=device)
Q = torch.randn(shape, requires_grad=True, device="cuda")
K = torch.randn(shape, requires_grad=True, device="cuda")
V = torch.randn(shape, requires_grad=True, device="cuda")
func = torch.compile(flex_attention, dynamic=True, fullgraph=True)
K_sliced = K[:, :, :-128]
@@ -3070,7 +3076,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
make_tensor = functools.partial(
torch.randn,
(2, 2, 128, 64),
device=device,
device="cuda",
dtype=torch.float32,
requires_grad=True,
)
@@ -3141,7 +3147,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
return score_mod
m = Attention().to(device).eval().to(dtype)
m = Attention().cuda().eval().to(dtype)
m = torch.compile(m, mode="default", fullgraph=False)
q = torch.randn(B, H, N, D, device=device, dtype=dtype)
@@ -3211,15 +3217,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
return causal_mask & window_mask
sliding_window_causal = torch.nn.attention.flex_attention.create_block_mask(
sliding_window_causal,
B=None,
H=None,
Q_LEN=N_CTX,
KV_LEN=N_CTX,
device=device,
sliding_window_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
)
global_causal = torch.nn.attention.flex_attention.create_block_mask(
global_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX, device=device
global_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
)
local_attn = functools.partial(
@@ -3267,10 +3268,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
def test_mixed_device_error_message(self, device):
# Create tensors on different devices
cpu_tensor = torch.randn(2, 2, 128, 16, device="cpu")
gpu_tensor = torch.randn(2, 2, 128, 16, device=device)
cuda_tensor = torch.randn(2, 2, 128, 16, device=device)
# Use different devices for query, key, and value
query, key, value = cpu_tensor, gpu_tensor, cpu_tensor
query, key, value = cpu_tensor, cuda_tensor, cpu_tensor
expected_error_message = (
"Expected query, key, and value to have the same device type, "
@@ -3284,8 +3285,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
@supported_platform
@skip_on_cpu
def test_captured_wrong_device_error_message(self, device):
means = torch.randn(64, 3, device=device)
length_scales = torch.logspace(0.001, 0.1, 8, device="cpu")
means = torch.randn(64, 3).cuda()
length_scales = torch.logspace(0.001, 0.1, 8)
def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
q_pos = means[q_idx]
@@ -3305,8 +3306,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
@skip_on_cpu
def test_cant_lower_error_message(self, device):
# We can't lower a 256-element reduction inside a pointwise reduction
means = torch.randn(64, 256, device=device)
length_scales = torch.logspace(0.001, 0.1, 8, device=device)
means = torch.randn(64, 256).cuda()
length_scales = torch.logspace(0.001, 0.1, 8).cuda()
def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
q_pos = means[q_idx]
@@ -3326,8 +3327,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
@skip_on_cpu
def test_reduction_unrolled(self, device):
# We can't lower a 256-element reduction inside a pointwise reduction
means = torch.randn(S, 3, device=device)
length_scales = torch.logspace(0.001, 0.1, H, device=device)
means = torch.randn(S, 3).to(device)
length_scales = torch.logspace(0.001, 0.1, H).to(device)
def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
q_pos = means[q_idx]
@@ -3337,7 +3338,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
inv_dist = torch.exp(-dist / scale)
return inv_dist * score
self.run_test(euclidean_dist_pos_embed, torch.bfloat16, device=device)
self.run_test(euclidean_dist_pos_embed, torch.bfloat16, device)
@supported_platform
@skip_on_cpu
@@ -3348,9 +3349,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
expected_error_message = (
"ValueError: Q and KV block size must be divisible by BLOCK_M and BLOCK_N."
)
block_mask = create_block_mask(
noop_mask, 1, 8, 128, 128, BLOCK_SIZE=96, device=device
)
block_mask = create_block_mask(noop_mask, 1, 8, 128, 128, BLOCK_SIZE=96)
with self.assertRaisesRegex(RuntimeError, expected_error_message):
torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
@@ -3417,7 +3416,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
H=None,
Q_LEN=max_time,
KV_LEN=max_time,
device=device,
)
x = torch.compile(
@@ -3430,7 +3428,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
)
return x
model = Model(128).to(device)
model = Model(128).cuda()
B, F, T = 16, 256, 12
for _ in range(5):
x = torch.randn(B, T, F, device=device)
@@ -3601,7 +3599,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
)
return y.transpose(1, 2).contiguous().view(B, T, C)
model = SimpleAttention().to(device)
model = SimpleAttention().cuda()
model.compile(mode="default", dynamic=True)
sequence_len = 256
@@ -3609,9 +3607,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
torch._dynamo.reset()
for batch_shape in [4, 16, 32]:
# Create dense mask
rand_mask = torch.randint(
0, 2, (batch_shape, sequence_len), device=device
).bool()
rand_mask = torch.randint(0, 2, (batch_shape, sequence_len)).cuda().bool()
block_mask = torch.compile(create_block_mask, dynamic=True)(
B=batch_shape,
BLOCK_SIZE=128,
@@ -3623,7 +3619,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
)
# Run forward pass
x = torch.randn(batch_shape, sequence_len, 512, device=device)
x = torch.randn(batch_shape, sequence_len, 512).cuda()
model(x, block_mask=block_mask)
self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2)
@@ -3651,7 +3647,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
block_mask=block_mask,
)
model = SimpleAttention().to(device)
model = SimpleAttention().cuda()
from torch._dynamo.testing import EagerAndRecordGraphs
backend = EagerAndRecordGraphs()
@@ -3660,7 +3656,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
torch._dynamo.reset()
for batch_shape in [4, 16, 32]:
x = torch.randn(batch_shape, sequence_len, 512, device=device)
x = torch.randn(batch_shape, sequence_len, 512).cuda()
model(x)
self.assertEqual(len(backend.graphs), 1)
self.assertExpectedInline(
@@ -3687,7 +3683,7 @@ def forward(self, child : torch.Tensor, child_1 : torch.Tensor, child_2 : torch.
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_block_mask(causal_mask, 1, 1, 128, 128, device=device)
block_mask = create_block_mask(causal_mask, 1, 1, 128, 128)
func = torch.compile(flex_attention, backend=cnt, fullgraph=True)
out = func(query, key, value, _squared, block_mask=block_mask)
@@ -3747,10 +3743,13 @@ class GraphModule(torch.nn.Module):
out.sum().backward()
joint_graph = normalize_gm(aot_graphs[1].print_readable(print_output=False))
expected_joint_graph = """\
self.assertExpectedInline(
joint_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
fw_graph0 = self.fw_graph0
joint_graph0 = self.joint_graph0
mask_graph0 = self.mask_graph0
@@ -3775,15 +3774,9 @@ class GraphModule(torch.nn.Module):
class mask_graph0(torch.nn.Module):
def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
return full
""".replace( # noqa: B950
"GPU_TYPE", torch.device(device).type
)
self.assertExpectedInline(
joint_graph,
expected_joint_graph,
""", # noqa: B950
)
@supported_platform
@@ -3833,7 +3826,7 @@ class GraphModule(torch.nn.Module):
flex_attention(q, k, v)
# compiled cpu support for small embedding size
q, k, v = [torch.randn(2, 2, 128, 8, device=device) for _ in range(3)]
q, k, v = [torch.randn(2, 2, 128, 8, device="cpu") for _ in range(3)]
flex_attention(q, k, v)
# compiled gpu kernel does not support small embedding size
@@ -3905,7 +3898,7 @@ class TestBlockMask(InductorTestCase):
def causal_mask(b, h, q, kv):
return (q + (offset[b] * 128)) >= kv
block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048, device=device)
block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048)
self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
@@ -3916,7 +3909,7 @@ class TestBlockMask(InductorTestCase):
self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())
offset = torch.arange(8, device=device)
block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048, device=device)
block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048)
self.assertEqual(block_mask.sparsity(), 29.1015625)
self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())
@@ -3933,7 +3926,7 @@ class TestBlockMask(InductorTestCase):
Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE
block_mask = create_block_mask(
noop_mask, B, H, Q_LEN, KV_LEN, BLOCK_SIZE=BLOCK_SIZE, device=device
noop_mask, B, H, Q_LEN, KV_LEN, BLOCK_SIZE=BLOCK_SIZE
)
self.assertEqual(block_mask.BLOCK_SIZE, (Q_BLOCK_SIZE, KV_BLOCK_SIZE))
@@ -3946,7 +3939,7 @@ class TestBlockMask(InductorTestCase):
def causal_mask(b, h, q, kv):
return (q + (offset[b] * 128)) >= kv
block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device=device)
block_mask = create_block_mask(causal_mask, 4, 2, 512, 512)
assert block_mask.kv_num_blocks.shape == (4, 2, 4)
assert block_mask.kv_indices.shape == (4, 2, 4, 4)
@@ -3998,17 +3991,16 @@ class TestBlockMask(InductorTestCase):
@supported_platform
def test_block_mask_device_change(self, device):
device = torch.device(device)
offset = torch.zeros(8, device=device)
def causal_mask(b, h, q, kv):
return (q + (offset[b] * 128)) >= kv
block_mask = create_block_mask(causal_mask, 1, 1, 512, 512, device=device)
assert block_mask.kv_indices.device.type == device.type
assert block_mask.kv_num_blocks.device.type == device.type
assert block_mask.q_indices.device.type == device.type
assert block_mask.q_num_blocks.device.type == device.type
block_mask = create_block_mask(causal_mask, 1, 1, 512, 512)
assert block_mask.kv_indices.is_cuda
assert block_mask.kv_num_blocks.is_cuda
assert block_mask.q_indices.is_cuda
assert block_mask.q_num_blocks.is_cuda
block_mask = block_mask.to("cpu")
assert block_mask.kv_indices.is_cpu
@@ -4016,11 +4008,11 @@ class TestBlockMask(InductorTestCase):
assert block_mask.q_indices.is_cpu
assert block_mask.q_num_blocks.is_cpu
block_mask = block_mask.to(device)
assert block_mask.kv_indices.device.type == device.type
assert block_mask.kv_num_blocks.device.type == device.type
assert block_mask.q_indices.device.type == device.type
assert block_mask.q_num_blocks.device.type == device.type
block_mask = block_mask.to("cuda")
assert block_mask.kv_indices.is_cuda
assert block_mask.kv_num_blocks.is_cuda
assert block_mask.q_indices.is_cuda
assert block_mask.q_num_blocks.is_cuda
@supported_platform
def test_compiling_create_block_mask(self, device):
@@ -4030,7 +4022,7 @@ class TestBlockMask(InductorTestCase):
return (q >= kv) & (seq[q] == seq[kv])
block_mask = torch.compile(create_block_mask, fullgraph=True)(
mask_mod, 1, 1, 512, 512, device=device
mask_mod, 1, 1, 512, 512
)
self.assertIsInstance(block_mask, BlockMask)
self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((1, 1, 4)))
@@ -4042,27 +4034,21 @@ class TestBlockMask(InductorTestCase):
return q >= kv
torch._dynamo.reset()
block_mask = torch.compile(create_block_mask)(
mask_mod, 2, 4, 1024, 1024, device=device
)
block_mask = torch.compile(create_block_mask)(mask_mod, 2, 4, 1024, 1024)
self.assertIsInstance(block_mask, BlockMask)
self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((2, 4, 8)))
self.assertEqual(block_mask.kv_indices.shape, torch.Size((2, 4, 8, 8)))
self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 1)
# automatic dynamic shapes triggered and recompilation.
block_mask = torch.compile(create_block_mask)(
mask_mod, 4, 8, 2048, 2048, device=device
)
block_mask = torch.compile(create_block_mask)(mask_mod, 4, 8, 2048, 2048)
self.assertIsInstance(block_mask, BlockMask)
self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((4, 8, 16)))
self.assertEqual(block_mask.kv_indices.shape, torch.Size((4, 8, 16, 16)))
self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2)
# no recompilation.
block_mask = torch.compile(create_block_mask)(
mask_mod, 6, 16, 3072, 3072, device=device
)
block_mask = torch.compile(create_block_mask)(mask_mod, 6, 16, 3072, 3072)
self.assertIsInstance(block_mask, BlockMask)
self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((6, 16, 24)))
self.assertEqual(block_mask.kv_indices.shape, torch.Size((6, 16, 24, 24)))
@@ -4073,7 +4059,7 @@ class TestBlockMask(InductorTestCase):
def causal_mask(b, h, q, kv):
return q >= kv
block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048)
def replace_non_printable(s):
def replace(c):
@@ -4114,9 +4100,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
def causal_offset_mask(b, h, q, kv):
return (q + offset[b] * 128) >= kv
block_mask = create_block_mask(
causal_offset_mask, 8, 1, 2048, 2048, device=device
)
block_mask = create_block_mask(causal_offset_mask, 8, 1, 2048, 2048)
str_block_mask = str(block_mask)
self.assertTrue("sparsity=29.10" in str_block_mask)
@@ -4272,7 +4256,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048)
# manually set q_num_blocks and q_indices to None
block_mask.q_num_blocks = None
block_mask.q_indices = None
@@ -4341,6 +4325,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
lengths[index] += 1
return lengths
device = "cuda"
max_seq_len, doc_count = 128, 4
SEQ_LEN = max_seq_len
@@ -4370,7 +4355,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
)
for i in range(5):
lengths = generate_random_lengths(1024 + i, 5)
offsets = length_to_offsets(lengths, device)
offsets = length_to_offsets(lengths, "cuda")
doc_ids = _offsets_to_doc_ids_tensor(offsets)
def doc_mask_mod(b, h, q_idx, kv_idx):
@@ -4382,9 +4367,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
q, k, v = (
torch.randn(1, 12, 1024 + i, 64, device=device) for _ in range(3)
)
block_mask = create_block_mask(
doc_mask_mod, None, None, 1024 + i, 1024 + i, device=device
)
block_mask = create_block_mask(doc_mask_mod, None, None, 1024 + i, 1024 + i)
torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
@supported_platform
@@ -4449,16 +4432,17 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
)
return q, k, v
block_mask = create_block_mask(mask_mod, None, None, 1024, 1024, device=device)
block_mask = create_block_mask(mask_mod, None, None, 1024, 1024)
flex_attention_call(*create_inputs(1024), block_mask=block_mask)
with self.assertRaisesRegex(ValueError, "block_mask was created for"):
flex_attention_call(*create_inputs(2048), block_mask=block_mask)
block_mask = create_block_mask(mask_mod, None, None, 1023, 1023, device=device)
block_mask = create_block_mask(mask_mod, None, None, 1023, 1023)
with self.assertRaisesRegex(ValueError, "block_mask was created for"):
flex_attention_call(*create_inputs(1024), block_mask=block_mask)
@large_tensor_test_class("2GB", device="cuda")
class TestPagedAttention(InductorTestCase):
def setUp(self):
super().setUp()
@@ -4908,18 +4892,20 @@ supports_learnable_bias = unittest.skipUnless(
@supports_learnable_bias
@large_tensor_test_class("2GB", device="cuda")
class TestLearnableBiases(InductorTestCase):
def setUp(self):
super().setUp()
self.device = "cuda"
self.dtype = torch.float32
self.atol = 3e-2
self.rtol = 3e-2
def _init_tensors(self, params: Params, device: str):
def _init_tensors(self, params: Params):
make_tensor = functools.partial(
torch.randn,
(params.batch_size, params.num_heads, params.seq_length, params.head_dim),
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -4980,11 +4966,11 @@ class TestLearnableBiases(InductorTestCase):
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
@common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"])
def test_relative_1d_bias(self, device, params, mode: str):
query, key, value = self._init_tensors(params, device=device)
def test_relative_1d_bias(self, params, mode: str):
query, key, value = self._init_tensors(params)
bias = torch.randn(
2 * params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5012,12 +4998,12 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_absolute_2d_bias(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_absolute_2d_bias(self, params):
query, key, value = self._init_tensors(params)
bias = torch.randn(
params.seq_length,
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5045,13 +5031,13 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_head_specific_bias(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_head_specific_bias(self, params):
query, key, value = self._init_tensors(params)
bias = torch.randn(
params.num_heads,
params.seq_length,
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5079,14 +5065,14 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_batch_head_bias(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_batch_head_bias(self, params):
query, key, value = self._init_tensors(params)
bias = torch.randn(
params.batch_size,
params.num_heads,
params.seq_length,
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5114,11 +5100,11 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_multiplicative_bias(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_multiplicative_bias(self, params):
query, key, value = self._init_tensors(params)
bias = torch.randn(
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5146,12 +5132,12 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_local_window_bias(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_local_window_bias(self, params):
query, key, value = self._init_tensors(params)
window_size = 8
bias = torch.randn(
2 * window_size + 1,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5180,11 +5166,11 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_global_tokens_bias(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_global_tokens_bias(self, params):
query, key, value = self._init_tensors(params)
bias = torch.randn(
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5212,18 +5198,18 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_weird_bias(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_weird_bias(self, params):
query, key, value = self._init_tensors(params)
bias = torch.randn(
params.batch_size,
params.num_heads,
4,
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
which_bias = torch.tensor(0, device=device)
which_bias = torch.tensor(0, device=self.device)
def bias_func(score, b, h, q_idx, kv_idx):
return score + bias[b, h, which_bias, q_idx]
@@ -5248,11 +5234,11 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_indirect_bias(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_indirect_bias(self, params):
query, key, value = self._init_tensors(params)
bias = torch.randn(
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5261,7 +5247,7 @@ class TestLearnableBiases(InductorTestCase):
0,
params.seq_length,
(params.seq_length,),
device=device,
device=self.device,
)
def bias_func(score, b, h, q_idx, kv_idx):
@@ -5288,11 +5274,11 @@ class TestLearnableBiases(InductorTestCase):
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
@common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"])
def test_symmetric_bias(self, device, params, mode: str):
query, key, value = self._init_tensors(params, device=device)
def test_symmetric_bias(self, params, mode: str):
query, key, value = self._init_tensors(params)
bias = torch.randn(
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5324,12 +5310,12 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_flipped_indexed_bias(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_flipped_indexed_bias(self, params):
query, key, value = self._init_tensors(params)
bias = torch.randn(
params.seq_length,
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5358,11 +5344,11 @@ class TestLearnableBiases(InductorTestCase):
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
@common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"])
def test_head_specific_gate(self, device, params, mode: str):
query, key, value = self._init_tensors(params, device=device)
def test_head_specific_gate(self, params, mode: str):
query, key, value = self._init_tensors(params)
gate_score = torch.randn(
params.num_heads,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5390,18 +5376,18 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_distinct_biases(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_distinct_biases(self, params):
query, key, value = self._init_tensors(params)
# Create two separate bias tensors
bias1 = torch.randn(
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
bias2 = torch.randn(
params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True,
)
@@ -5438,8 +5424,8 @@ class TestLearnableBiases(InductorTestCase):
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
)
def test_relative_1d_bias_only_grad(self, device, params):
query, key, value = self._init_tensors(params, device=device)
def test_relative_1d_bias_only_grad(self, params):
query, key, value = self._init_tensors(params)
query = query.detach().requires_grad_(False)
key = key.detach().requires_grad_(False)
value = value.detach().requires_grad_(False)
@@ -5447,7 +5433,7 @@ class TestLearnableBiases(InductorTestCase):
# Only bias requires gradients
bias = torch.randn(
2 * params.seq_length,
device=device,
device=self.device,
dtype=params.dtype,
requires_grad=True, # Only bias needs gradients
)
@@ -5471,10 +5457,10 @@ class TestLearnableBiases(InductorTestCase):
out_eager, out_compiled, out_gold, (bias,), names=["out", "bias"]
)
def test_flex_attention_with_dynamic_max_autotune(self, device):
query = torch.randn(2, 16, 512, 64, device=device)
key = torch.randn(2, 16, 512, 64, device=device)
value = torch.randn(2, 16, 512, 64, device=device)
def test_flex_attention_with_dynamic_max_autotune(self):
query = torch.randn(2, 16, 512, 64, device=self.device)
key = torch.randn(2, 16, 512, 64, device=self.device)
value = torch.randn(2, 16, 512, 64, device=self.device)
query.requires_grad = True
key.requires_grad = True
value.requires_grad = True
@@ -5488,9 +5474,7 @@ class TestLearnableBiases(InductorTestCase):
return m >= n
mask_shape = (1, 1, M, N)
block_mask = torch.compile(create_block_mask)(
causal, *mask_shape, device=device
)
block_mask = torch.compile(create_block_mask)(causal, *mask_shape, "cuda")
compiled_sdpa = torch.compile(
flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs"
@@ -5511,23 +5495,26 @@ class TestLearnableBiases(InductorTestCase):
out.shape, query.shape, f"Expected shape {query.shape}, got {out.shape}"
)
def test_inspect_bug(self, device):
def test_inspect_bug(self):
# https://github.com/pytorch/pytorch/issues/139374
def sliding_window(b, h, q_idx, kv_idx, val):
return (q_idx - kv_idx).abs() < val
sliding_window2 = functools.partial(
sliding_window, val=torch.randn((), device=device)
sliding_window, val=torch.randn((), device=self.device)
)
opt_fn = torch.compile(create_block_mask, fullgraph=True)
create_block_mask(sliding_window2, None, None, 1024, 1024, device=device)
create_block_mask(sliding_window2, None, None, 1024, 1024)
# checks that the compile is working
opt_fn(sliding_window2, None, None, 1024, 1024, device=device)
opt_fn(sliding_window2, None, None, 1024, 1024)
@supported_platform
def test_head_bias_req_grad(self, device):
def test_head_bias_req_grad(self):
device = self.device
B, H, S, D = 1, 4, 256, 64
bias = torch.randn(H, device=device, dtype=torch.float16, requires_grad=True)
bias = torch.randn(
H, device=self.device, dtype=torch.float16, requires_grad=True
)
bias_flex = bias.detach().clone().requires_grad_(True)
@@ -5558,7 +5545,8 @@ class TestLearnableBiases(InductorTestCase):
)
@supported_platform
def test_comparison_vs_sdpa_with_learnable_bias(self, device):
def test_comparison_vs_sdpa_with_learnable_bias(self):
device = self.device
# 1-dimensional bias:
B, H, S, D = 1, 1, 256, 64
bias = torch.randn(
@@ -5771,7 +5759,8 @@ class TestLearnableBiases(InductorTestCase):
instantiate_device_type_tests(TestFlexAttention, globals(), only_for=test_device)
instantiate_device_type_tests(TestPagedAttention, globals(), only_for=test_device)
instantiate_device_type_tests(TestBlockMask, globals(), only_for=("cuda",))
instantiate_device_type_tests(TestLearnableBiases, globals(), only_for=("cuda",))
common_utils.instantiate_parametrized_tests(TestLearnableBiases)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests