diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index c79f98ce6428..2642c256f312 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -41,6 +41,7 @@ from torch.testing._internal.common_device_type import (
     dtypesIfCUDA,
     flex_attention_supported_platform as supported_platform,
     instantiate_device_type_tests,
+    largeTensorTest,
     skipCPUIf,
     skipCUDAIf,
 )
@@ -65,6 +66,18 @@ T = TypeVar("T")
 M = TypeVar("M", bound=Callable)
 
 
+def large_tensor_test_class(
+    size: str, device: Optional[Union[torch.device, str]] = None
+) -> Callable[[type[T]], type[T]]:
+    def decorator(cls: type[T]) -> type[T]:
+        for name, method in list(cls.__dict__.items()):
+            if callable(method) and name.startswith("test_"):
+                setattr(cls, name, largeTensorTest(size, device)(method))
+        return cls
+
+    return decorator
+
+
 @contextmanager
 def temp_float32_matmul_precision(precision: str):
     """
@@ -376,6 +389,7 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor):
     )
 
 
+@large_tensor_test_class("2GB", device="cuda")
 class TestFlexAttention(InductorTestCase):
     def setUp(self):
         super().setUp()
@@ -1912,17 +1926,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         def causal_mask_slidewindow_mod(b, h, q_idx, kv_idx):
             return (q_idx >= kv_idx) & (q_idx <= kv_idx + window_size)
 
-        mask1 = create_block_mask(
-            causal_mask, 1, None, 512, 512, _compile=False, device=device
-        )
+        mask1 = create_block_mask(causal_mask, 1, None, 512, 512, _compile=False)
         mask2 = create_block_mask(
-            causal_mask_slidewindow_mod,
-            1,
-            None,
-            512,
-            512,
-            _compile=False,
-            device=device,
+            causal_mask_slidewindow_mod, 1, None, 512, 512, _compile=False
         )
 
         def f(q, k, v):
@@ -2579,7 +2585,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
 
     @supported_platform
     @skip_on_cpu
-    def test_eager_backward_strides(self, device):
+    def test_eager_backward_strides(self):
         class Repro(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -2601,8 +2607,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
                 x = torch.nn.attention.flex_attention.flex_attention(q, k, v)
                 return x
 
-        model = Repro().to(device)
-        x = torch.randn((1, 512, 256), device=device, requires_grad=True)
+        model = Repro().cuda()
+        x = torch.randn((1, 512, 256), device="cuda", requires_grad=True)
         out = torch.compile(model, backend="aot_eager", fullgraph=True)(x)
         out.backward(torch.ones_like(out))
 
@@ -2683,7 +2689,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             requires_grad=True,
         )
 
-        block_mask = create_block_mask(mask, None, None, 4096, 4096, device=device)
+        block_mask = create_block_mask(mask, None, None, 4096, 4096)
         # Compile 1st version with q/k/v(seqlen=4096) and block_mask(seqlen=4096)
         torch.compile(flex_attention, dynamic=True, fullgraph=True)(
             make_tensor(), make_tensor(), make_tensor(), block_mask=block_mask
@@ -2692,7 +2698,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         make_tensor2 = functools.partial(
             torch.randn,
             (4, 4, 2048, 64),
-            device=device,
+            device="cuda",
             dtype=torch.float32,
             requires_grad=True,
         )
@@ -2709,7 +2715,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         k.grad = None
         v.grad = None
 
-        block_mask2 = create_block_mask(mask, None, None, 2048, 2048, device=device)
+        block_mask2 = create_block_mask(mask, None, None, 2048, 2048)
         # Reuse the 1st version with q/k/v(seqlen=2048) and block_mask(seqlen=2048)
         out2 = torch.compile(flex_attention, dynamic=True, fullgraph=True)(
             q, k, v, block_mask=block_mask2
@@ -2815,11 +2821,11 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
 
     @supported_platform
     @skip_on_cpu
-    def test_strided_backwards(self, device):
+    def test_strided_backwards(self):
         shape = (1, 2, 4096, 64)
-        Q = torch.randn(shape, requires_grad=True, device=device)
-        K = torch.randn(shape, requires_grad=True, device=device)
-        V = torch.randn(shape, requires_grad=True, device=device)
+        Q = torch.randn(shape, requires_grad=True, device="cuda")
+        K = torch.randn(shape, requires_grad=True, device="cuda")
+        V = torch.randn(shape, requires_grad=True, device="cuda")
         func = torch.compile(flex_attention, dynamic=True, fullgraph=True)
 
         K_sliced = K[:, :, :-128]
@@ -3070,7 +3076,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         make_tensor = functools.partial(
             torch.randn,
             (2, 2, 128, 64),
-            device=device,
+            device="cuda",
             dtype=torch.float32,
             requires_grad=True,
         )
@@ -3141,7 +3147,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
 
             return score_mod
 
-        m = Attention().to(device).eval().to(dtype)
+        m = Attention().cuda().eval().to(dtype)
         m = torch.compile(m, mode="default", fullgraph=False)
 
         q = torch.randn(B, H, N, D, device=device, dtype=dtype)
@@ -3211,15 +3217,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             return causal_mask & window_mask
 
         sliding_window_causal = torch.nn.attention.flex_attention.create_block_mask(
-            sliding_window_causal,
-            B=None,
-            H=None,
-            Q_LEN=N_CTX,
-            KV_LEN=N_CTX,
-            device=device,
+            sliding_window_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
         )
         global_causal = torch.nn.attention.flex_attention.create_block_mask(
-            global_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX, device=device
+            global_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
         )
 
         local_attn = functools.partial(
@@ -3267,10 +3268,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
     def test_mixed_device_error_message(self, device):
         # Create tensors on different devices
         cpu_tensor = torch.randn(2, 2, 128, 16, device="cpu")
-        gpu_tensor = torch.randn(2, 2, 128, 16, device=device)
+        cuda_tensor = torch.randn(2, 2, 128, 16, device=device)
 
         # Use different devices for query, key, and value
-        query, key, value = cpu_tensor, gpu_tensor, cpu_tensor
+        query, key, value = cpu_tensor, cuda_tensor, cpu_tensor
 
         expected_error_message = (
             "Expected query, key, and value to have the same device type, "
@@ -3284,8 +3285,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
     @supported_platform
     @skip_on_cpu
     def test_captured_wrong_device_error_message(self, device):
-        means = torch.randn(64, 3, device=device)
-        length_scales = torch.logspace(0.001, 0.1, 8, device="cpu")
+        means = torch.randn(64, 3).cuda()
+        length_scales = torch.logspace(0.001, 0.1, 8)
 
         def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
             q_pos = means[q_idx]
@@ -3305,8 +3306,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
     @skip_on_cpu
     def test_cant_lower_error_message(self, device):
         # We can't lower a 256-element reduction inside a pointwise reduction
-        means = torch.randn(64, 256, device=device)
-        length_scales = torch.logspace(0.001, 0.1, 8, device=device)
+        means = torch.randn(64, 256).cuda()
+        length_scales = torch.logspace(0.001, 0.1, 8).cuda()
 
         def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
             q_pos = means[q_idx]
@@ -3326,8 +3327,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
     @skip_on_cpu
     def test_reduction_unrolled(self, device):
         # We can't lower a 256-element reduction inside a pointwise reduction
-        means = torch.randn(S, 3, device=device)
-        length_scales = torch.logspace(0.001, 0.1, H, device=device)
+        means = torch.randn(S, 3).to(device)
+        length_scales = torch.logspace(0.001, 0.1, H).to(device)
 
         def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
             q_pos = means[q_idx]
@@ -3337,7 +3338,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             inv_dist = torch.exp(-dist / scale)
             return inv_dist * score
 
-        self.run_test(euclidean_dist_pos_embed, torch.bfloat16, device=device)
+        self.run_test(euclidean_dist_pos_embed, torch.bfloat16, device)
 
     @supported_platform
     @skip_on_cpu
@@ -3348,9 +3349,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         expected_error_message = (
             "ValueError: Q and KV block size must be divisible by BLOCK_M and BLOCK_N."
         )
-        block_mask = create_block_mask(
-            noop_mask, 1, 8, 128, 128, BLOCK_SIZE=96, device=device
-        )
+        block_mask = create_block_mask(noop_mask, 1, 8, 128, 128, BLOCK_SIZE=96)
 
         with self.assertRaisesRegex(RuntimeError, expected_error_message):
             torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
@@ -3417,7 +3416,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
                 H=None,
                 Q_LEN=max_time,
                 KV_LEN=max_time,
-                device=device,
             )
 
             x = torch.compile(
@@ -3430,7 +3428,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             )
             return x
 
-        model = Model(128).to(device)
+        model = Model(128).cuda()
         B, F, T = 16, 256, 12
         for _ in range(5):
             x = torch.randn(B, T, F, device=device)
@@ -3601,7 +3599,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
                 )
                 return y.transpose(1, 2).contiguous().view(B, T, C)
 
-        model = SimpleAttention().to(device)
+        model = SimpleAttention().cuda()
         model.compile(mode="default", dynamic=True)
         sequence_len = 256
 
@@ -3609,9 +3607,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         torch._dynamo.reset()
         for batch_shape in [4, 16, 32]:
             # Create dense mask
-            rand_mask = torch.randint(
-                0, 2, (batch_shape, sequence_len), device=device
-            ).bool()
+            rand_mask = torch.randint(0, 2, (batch_shape, sequence_len)).cuda().bool()
             block_mask = torch.compile(create_block_mask, dynamic=True)(
                 B=batch_shape,
                 BLOCK_SIZE=128,
@@ -3623,7 +3619,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             )
 
             # Run forward pass
-            x = torch.randn(batch_shape, sequence_len, 512, device=device)
+            x = torch.randn(batch_shape, sequence_len, 512).cuda()
             model(x, block_mask=block_mask)
 
         self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2)
@@ -3651,7 +3647,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
                     block_mask=block_mask,
                 )
 
-        model = SimpleAttention().to(device)
+        model = SimpleAttention().cuda()
         from torch._dynamo.testing import EagerAndRecordGraphs
 
         backend = EagerAndRecordGraphs()
@@ -3660,7 +3656,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
 
         torch._dynamo.reset()
         for batch_shape in [4, 16, 32]:
-            x = torch.randn(batch_shape, sequence_len, 512, device=device)
+            x = torch.randn(batch_shape, sequence_len, 512).cuda()
            model(x)
        self.assertEqual(len(backend.graphs), 1)
        self.assertExpectedInline(
@@ -3687,7 +3683,7 @@ def forward(self, child : torch.Tensor, child_1 : torch.Tensor, child_2 : torch.
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx
 
-        block_mask = create_block_mask(causal_mask, 1, 1, 128, 128, device=device)
+        block_mask = create_block_mask(causal_mask, 1, 1, 128, 128)
 
         func = torch.compile(flex_attention, backend=cnt, fullgraph=True)
         out = func(query, key, value, _squared, block_mask=block_mask)
@@ -3747,10 +3743,13 @@ class GraphModule(torch.nn.Module):
         out.sum().backward()
 
         joint_graph = normalize_gm(aot_graphs[1].print_readable(print_output=False))
-        expected_joint_graph = """\
+
+        self.assertExpectedInline(
+            joint_graph,
+            """\
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
-        full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+        full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
 
         fw_graph0 = self.fw_graph0
         joint_graph0 = self.joint_graph0
         mask_graph0 = self.mask_graph0
@@ -3775,15 +3774,9 @@ class GraphModule(torch.nn.Module):
 
     class mask_graph0(torch.nn.Module):
         def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
-            full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+            full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
             return full
-""".replace(  # noqa: B950
-            "GPU_TYPE", torch.device(device).type
-        )
-
-        self.assertExpectedInline(
-            joint_graph,
-            expected_joint_graph,
+""",  # noqa: B950
         )
 
     @supported_platform
@@ -3833,7 +3826,7 @@ class GraphModule(torch.nn.Module):
         flex_attention(q, k, v)
 
         # compiled cpu support for small embedding size
-        q, k, v = [torch.randn(2, 2, 128, 8, device=device) for _ in range(3)]
+        q, k, v = [torch.randn(2, 2, 128, 8, device="cpu") for _ in range(3)]
         flex_attention(q, k, v)
 
         # compiled gpu kernel does not support small embedding size
@@ -3905,7 +3898,7 @@ class TestBlockMask(InductorTestCase):
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
 
-        block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048, device=device)
+        block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048)
         self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
         self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
         self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
@@ -3916,7 +3909,7 @@ class TestBlockMask(InductorTestCase):
         self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())
 
         offset = torch.arange(8, device=device)
-        block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048, device=device)
+        block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048)
         self.assertEqual(block_mask.sparsity(), 29.1015625)
         self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
         self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())
@@ -3933,7 +3926,7 @@ class TestBlockMask(InductorTestCase):
             Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE
 
         block_mask = create_block_mask(
-            noop_mask, B, H, Q_LEN, KV_LEN, BLOCK_SIZE=BLOCK_SIZE, device=device
+            noop_mask, B, H, Q_LEN, KV_LEN, BLOCK_SIZE=BLOCK_SIZE
         )
         self.assertEqual(block_mask.BLOCK_SIZE, (Q_BLOCK_SIZE, KV_BLOCK_SIZE))
 
@@ -3946,7 +3939,7 @@ class TestBlockMask(InductorTestCase):
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
 
-        block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device=device)
+        block_mask = create_block_mask(causal_mask, 4, 2, 512, 512)
         assert block_mask.kv_num_blocks.shape == (4, 2, 4)
         assert block_mask.kv_indices.shape == (4, 2, 4, 4)
 
@@ -3998,17 +3991,16 @@ class TestBlockMask(InductorTestCase):
     @supported_platform
     def test_block_mask_device_change(self, device):
-        device = torch.device(device)
         offset = torch.zeros(8, device=device)
 
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
 
-        block_mask = create_block_mask(causal_mask, 1, 1, 512, 512, device=device)
-        assert block_mask.kv_indices.device.type == device.type
-        assert block_mask.kv_num_blocks.device.type == device.type
-        assert block_mask.q_indices.device.type == device.type
-        assert block_mask.q_num_blocks.device.type == device.type
+        block_mask = create_block_mask(causal_mask, 1, 1, 512, 512)
+        assert block_mask.kv_indices.is_cuda
+        assert block_mask.kv_num_blocks.is_cuda
+        assert block_mask.q_indices.is_cuda
+        assert block_mask.q_num_blocks.is_cuda
 
         block_mask = block_mask.to("cpu")
         assert block_mask.kv_indices.is_cpu
@@ -4016,11 +4008,11 @@ class TestBlockMask(InductorTestCase):
         assert block_mask.q_indices.is_cpu
         assert block_mask.q_num_blocks.is_cpu
 
-        block_mask = block_mask.to(device)
-        assert block_mask.kv_indices.device.type == device.type
-        assert block_mask.kv_num_blocks.device.type == device.type
-        assert block_mask.q_indices.device.type == device.type
-        assert block_mask.q_num_blocks.device.type == device.type
+        block_mask = block_mask.to("cuda")
+        assert block_mask.kv_indices.is_cuda
+        assert block_mask.kv_num_blocks.is_cuda
+        assert block_mask.q_indices.is_cuda
+        assert block_mask.q_num_blocks.is_cuda
 
     @supported_platform
     def test_compiling_create_block_mask(self, device):
@@ -4030,7 +4022,7 @@ class TestBlockMask(InductorTestCase):
             return (q >= kv) & (seq[q] == seq[kv])
 
         block_mask = torch.compile(create_block_mask, fullgraph=True)(
-            mask_mod, 1, 1, 512, 512, device=device
+            mask_mod, 1, 1, 512, 512
         )
         self.assertIsInstance(block_mask, BlockMask)
         self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((1, 1, 4)))
@@ -4042,27 +4034,21 @@ class TestBlockMask(InductorTestCase):
             return q >= kv
 
         torch._dynamo.reset()
-        block_mask = torch.compile(create_block_mask)(
-            mask_mod, 2, 4, 1024, 1024, device=device
-        )
+        block_mask = torch.compile(create_block_mask)(mask_mod, 2, 4, 1024, 1024)
         self.assertIsInstance(block_mask, BlockMask)
         self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((2, 4, 8)))
         self.assertEqual(block_mask.kv_indices.shape, torch.Size((2, 4, 8, 8)))
         self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 1)
 
         # automatic dynamic shapes triggered and recompilation.
-        block_mask = torch.compile(create_block_mask)(
-            mask_mod, 4, 8, 2048, 2048, device=device
-        )
+        block_mask = torch.compile(create_block_mask)(mask_mod, 4, 8, 2048, 2048)
         self.assertIsInstance(block_mask, BlockMask)
         self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((4, 8, 16)))
         self.assertEqual(block_mask.kv_indices.shape, torch.Size((4, 8, 16, 16)))
         self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2)
 
         # no recompilation.
-        block_mask = torch.compile(create_block_mask)(
-            mask_mod, 6, 16, 3072, 3072, device=device
-        )
+        block_mask = torch.compile(create_block_mask)(mask_mod, 6, 16, 3072, 3072)
         self.assertIsInstance(block_mask, BlockMask)
         self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((6, 16, 24)))
         self.assertEqual(block_mask.kv_indices.shape, torch.Size((6, 16, 24, 24)))
@@ -4073,7 +4059,7 @@ class TestBlockMask(InductorTestCase):
         def causal_mask(b, h, q, kv):
             return q >= kv
 
-        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
+        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048)
 
         def replace_non_printable(s):
             def replace(c):
@@ -4114,9 +4100,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
         def causal_offset_mask(b, h, q, kv):
             return (q + offset[b] * 128) >= kv
 
-        block_mask = create_block_mask(
-            causal_offset_mask, 8, 1, 2048, 2048, device=device
-        )
+        block_mask = create_block_mask(causal_offset_mask, 8, 1, 2048, 2048)
         str_block_mask = str(block_mask)
         self.assertTrue("sparsity=29.10" in str_block_mask)
@@ -4272,7 +4256,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx
 
-        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
+        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048)
         # manually set q_num_blocks and q_indices to None
         block_mask.q_num_blocks = None
         block_mask.q_indices = None
@@ -4341,6 +4325,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
                 lengths[index] += 1
             return lengths
 
+        device = "cuda"
         max_seq_len, doc_count = 128, 4
         SEQ_LEN = max_seq_len
 
@@ -4370,7 +4355,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
         )
         for i in range(5):
             lengths = generate_random_lengths(1024 + i, 5)
-            offsets = length_to_offsets(lengths, device)
+            offsets = length_to_offsets(lengths, "cuda")
             doc_ids = _offsets_to_doc_ids_tensor(offsets)
 
             def doc_mask_mod(b, h, q_idx, kv_idx):
@@ -4382,9 +4367,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
             q, k, v = (
                 torch.randn(1, 12, 1024 + i, 64, device=device) for _ in range(3)
             )
-            block_mask = create_block_mask(
-                doc_mask_mod, None, None, 1024 + i, 1024 + i, device=device
-            )
+            block_mask = create_block_mask(doc_mask_mod, None, None, 1024 + i, 1024 + i)
             torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
 
     @supported_platform
@@ -4449,16 +4432,17 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
             )
             return q, k, v
 
-        block_mask = create_block_mask(mask_mod, None, None, 1024, 1024, device=device)
+        block_mask = create_block_mask(mask_mod, None, None, 1024, 1024)
         flex_attention_call(*create_inputs(1024), block_mask=block_mask)
         with self.assertRaisesRegex(ValueError, "block_mask was created for"):
             flex_attention_call(*create_inputs(2048), block_mask=block_mask)
 
-        block_mask = create_block_mask(mask_mod, None, None, 1023, 1023, device=device)
+        block_mask = create_block_mask(mask_mod, None, None, 1023, 1023)
         with self.assertRaisesRegex(ValueError, "block_mask was created for"):
             flex_attention_call(*create_inputs(1024), block_mask=block_mask)
 
 
+@large_tensor_test_class("2GB", device="cuda")
 class TestPagedAttention(InductorTestCase):
     def setUp(self):
         super().setUp()
@@ -4908,18 +4892,20 @@ supports_learnable_bias = unittest.skipUnless(
 
 
 @supports_learnable_bias
+@large_tensor_test_class("2GB", device="cuda")
 class TestLearnableBiases(InductorTestCase):
     def setUp(self):
         super().setUp()
+        self.device = "cuda"
         self.dtype = torch.float32
         self.atol = 3e-2
         self.rtol = 3e-2
 
-    def _init_tensors(self, params: Params, device: str):
+    def _init_tensors(self, params: Params):
         make_tensor = functools.partial(
             torch.randn,
             (params.batch_size, params.num_heads, params.seq_length, params.head_dim),
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -4980,11 +4966,11 @@ class TestLearnableBiases(InductorTestCase):
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
     @common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"])
-    def test_relative_1d_bias(self, device, params, mode: str):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_relative_1d_bias(self, params, mode: str):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             2 * params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5012,12 +4998,12 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_absolute_2d_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_absolute_2d_bias(self, params):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             params.seq_length,
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5045,13 +5031,13 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_head_specific_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_head_specific_bias(self, params):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             params.num_heads,
             params.seq_length,
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5079,14 +5065,14 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
    )
-    def test_batch_head_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_batch_head_bias(self, params):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             params.batch_size,
             params.num_heads,
             params.seq_length,
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5114,11 +5100,11 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_multiplicative_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_multiplicative_bias(self, params):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5146,12 +5132,12 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_local_window_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_local_window_bias(self, params):
+        query, key, value = self._init_tensors(params)
         window_size = 8
         bias = torch.randn(
             2 * window_size + 1,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5180,11 +5166,11 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_global_tokens_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_global_tokens_bias(self, params):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5212,18 +5198,18 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_weird_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_weird_bias(self, params):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             params.batch_size,
             params.num_heads,
             4,
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
-        which_bias = torch.tensor(0, device=device)
+        which_bias = torch.tensor(0, device=self.device)
 
         def bias_func(score, b, h, q_idx, kv_idx):
             return score + bias[b, h, which_bias, q_idx]
@@ -5248,11 +5234,11 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_indirect_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_indirect_bias(self, params):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5261,7 +5247,7 @@ class TestLearnableBiases(InductorTestCase):
             0,
             params.seq_length,
             (params.seq_length,),
-            device=device,
+            device=self.device,
         )
 
         def bias_func(score, b, h, q_idx, kv_idx):
@@ -5288,11 +5274,11 @@ class TestLearnableBiases(InductorTestCase):
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
     @common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"])
-    def test_symmetric_bias(self, device, params, mode: str):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_symmetric_bias(self, params, mode: str):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5324,12 +5310,12 @@ class TestLearnableBiases(InductorTestCase):
-    def test_flipped_indexed_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_flipped_indexed_bias(self, params):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             params.seq_length,
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5358,11 +5344,11 @@ class TestLearnableBiases(InductorTestCase):
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
     @common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"])
-    def test_head_specific_gate(self, device, params, mode: str):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_head_specific_gate(self, params, mode: str):
+        query, key, value = self._init_tensors(params)
         gate_score = torch.randn(
             params.num_heads,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5390,18 +5376,18 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_distinct_biases(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_distinct_biases(self, params):
+        query, key, value = self._init_tensors(params)
         # Create two separate bias tensors
         bias1 = torch.randn(
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
         bias2 = torch.randn(
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
@@ -5438,8 +5424,8 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_relative_1d_bias_only_grad(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_relative_1d_bias_only_grad(self, params):
+        query, key, value = self._init_tensors(params)
         query = query.detach().requires_grad_(False)
         key = key.detach().requires_grad_(False)
         value = value.detach().requires_grad_(False)
@@ -5447,7 +5433,7 @@ class TestLearnableBiases(InductorTestCase):
         # Only bias requires gradients
         bias = torch.randn(
             2 * params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,  # Only bias needs gradients
         )
@@ -5471,10 +5457,10 @@ class TestLearnableBiases(InductorTestCase):
             out_eager, out_compiled, out_gold, (bias,), names=["out", "bias"]
         )
 
-    def test_flex_attention_with_dynamic_max_autotune(self, device):
-        query = torch.randn(2, 16, 512, 64, device=device)
-        key = torch.randn(2, 16, 512, 64, device=device)
-        value = torch.randn(2, 16, 512, 64, device=device)
+    def test_flex_attention_with_dynamic_max_autotune(self):
+        query = torch.randn(2, 16, 512, 64, device=self.device)
+        key = torch.randn(2, 16, 512, 64, device=self.device)
+        value = torch.randn(2, 16, 512, 64, device=self.device)
         query.requires_grad = True
         key.requires_grad = True
         value.requires_grad = True
@@ -5488,9 +5474,7 @@ class TestLearnableBiases(InductorTestCase):
             return m >= n
 
         mask_shape = (1, 1, M, N)
-        block_mask = torch.compile(create_block_mask)(
-            causal, *mask_shape, device=device
-        )
+        block_mask = torch.compile(create_block_mask)(causal, *mask_shape, "cuda")
 
         compiled_sdpa = torch.compile(
             flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs"
@@ -5511,23 +5495,26 @@ class TestLearnableBiases(InductorTestCase):
             out.shape, query.shape, f"Expected shape {query.shape}, got {out.shape}"
         )
 
-    def test_inspect_bug(self, device):
+    def test_inspect_bug(self):
         # https://github.com/pytorch/pytorch/issues/139374
         def sliding_window(b, h, q_idx, kv_idx, val):
             return (q_idx - kv_idx).abs() < val
 
         sliding_window2 = functools.partial(
-            sliding_window, val=torch.randn((), device=device)
+            sliding_window, val=torch.randn((), device=self.device)
         )
         opt_fn = torch.compile(create_block_mask, fullgraph=True)
-        create_block_mask(sliding_window2, None, None, 1024, 1024, device=device)
+        create_block_mask(sliding_window2, None, None, 1024, 1024)
         # checks that the compile is working
-        opt_fn(sliding_window2, None, None, 1024, 1024, device=device)
+        opt_fn(sliding_window2, None, None, 1024, 1024)
 
     @supported_platform
-    def test_head_bias_req_grad(self, device):
+    def test_head_bias_req_grad(self):
+        device = self.device
         B, H, S, D = 1, 4, 256, 64
-        bias = torch.randn(H, device=device, dtype=torch.float16, requires_grad=True)
+        bias = torch.randn(
+            H, device=self.device, dtype=torch.float16, requires_grad=True
+        )
         bias_flex = bias.detach().clone().requires_grad_(True)
 
@@ -5558,7 +5545,8 @@ class TestLearnableBiases(InductorTestCase):
         )
 
     @supported_platform
-    def test_comparison_vs_sdpa_with_learnable_bias(self, device):
+    def test_comparison_vs_sdpa_with_learnable_bias(self):
+        device = self.device
         # 1-dimensional bias:
         B, H, S, D = 1, 1, 256, 64
         bias = torch.randn(
@@ -5771,7 +5759,8 @@ class TestLearnableBiases(InductorTestCase):
 instantiate_device_type_tests(TestFlexAttention, globals(), only_for=test_device)
 instantiate_device_type_tests(TestPagedAttention, globals(), only_for=test_device)
 instantiate_device_type_tests(TestBlockMask, globals(), only_for=("cuda",))
-instantiate_device_type_tests(TestLearnableBiases, globals(), only_for=("cuda",))
+
+common_utils.instantiate_parametrized_tests(TestLearnableBiases)
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
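
Note on the new `large_tensor_test_class` helper added at the top of this diff: it wraps every `test_*` method defined directly on the decorated class with `largeTensorTest`, so each test is gated on the required device memory. Because it iterates `cls.__dict__`, inherited test methods are not wrapped. Below is a minimal, self-contained sketch of the same class-decorator pattern; `require_memory` and `memory_test_class` are hypothetical stand-ins for `largeTensorTest` and the helper above, not PyTorch APIs.

import functools
import unittest


def require_memory(size: str):
    # Hypothetical stand-in for largeTensorTest: tag the test with its memory
    # requirement. A real implementation would skip the test when the target
    # device cannot allocate `size`.
    def deco(fn):
        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            return fn(self, *args, **kwargs)

        wrapper._required_memory = size
        return wrapper

    return deco


def memory_test_class(size: str):
    # Same shape as large_tensor_test_class in the diff: wrap every test_*
    # method defined directly on the class. Iterate over a snapshot of
    # cls.__dict__, since setattr mutates the class while we loop.
    def decorator(cls):
        for name, method in list(cls.__dict__.items()):
            if callable(method) and name.startswith("test_"):
                setattr(cls, name, require_memory(size)(method))
        return cls

    return decorator


@memory_test_class("2GB")
class ExampleTest(unittest.TestCase):
    def test_something(self):
        self.assertTrue(True)


if __name__ == "__main__":
    # The class decorator applied the per-method tag at class creation time.
    assert ExampleTest.test_something._required_memory == "2GB"
    unittest.main()

Applying the decorator at the class level keeps the per-test `largeTensorTest` gating in one place instead of repeating it on dozens of methods, which is presumably why the diff attaches it to `TestFlexAttention`, `TestPagedAttention`, and `TestLearnableBiases` wholesale.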