Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Revert "[Inductor UT] Generalize device-bias code in test_flex_attention.py
(#151937)"
This reverts commit 443840080265ce6133121c91d258b619eae151bb. Reverted https://github.com/pytorch/pytorch/pull/151937 on behalf of https://github.com/malfet due to Broke ASAN tests, probably by enabling too many tests https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=asan&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/151937#issuecomment-2835151532))
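For context: the reverted change (#151937) had moved this test file from hard-coded CUDA tensors to the `device` argument that `instantiate_device_type_tests` supplies, so the same test bodies could run on other backends; the revert restores the CUDA-specific forms, including the `large_tensor_test_class` helper visible in the diff below. A minimal, self-contained sketch of the two styles (illustrative class and test names, not the actual test code):

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


class ExampleDeviceGenericTest(TestCase):
    # Generalized style: the harness passes a device string, so one test
    # body can run on cuda, cpu, xpu, ... depending on availability.
    def test_randn(self, device):
        x = torch.randn(8, device=device)
        self.assertEqual(x.device.type, torch.device(device).type)


class ExampleCudaOnlyTest(TestCase):
    # Device-biased style (what this revert restores): the device is baked
    # in, so the test silently assumes a CUDA-capable machine.
    def test_randn_cuda(self):
        x = torch.randn(8, device="cuda")
        self.assertTrue(x.is_cuda)


# The harness generates per-device classes such as ExampleDeviceGenericTestCUDA.
instantiate_device_type_tests(ExampleDeviceGenericTest, globals(), only_for=("cuda",))

if __name__ == "__main__":
    run_tests()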
@@ -41,6 +41,7 @@ from torch.testing._internal.common_device_type import (
     dtypesIfCUDA,
     flex_attention_supported_platform as supported_platform,
     instantiate_device_type_tests,
+    largeTensorTest,
     skipCPUIf,
     skipCUDAIf,
 )
@ -65,6 +66,18 @@ T = TypeVar("T")
|
||||
M = TypeVar("M", bound=Callable)
|
||||
|
||||
|
||||
def large_tensor_test_class(
|
||||
size: str, device: Optional[Union[torch.device, str]] = None
|
||||
) -> Callable[[type[T]], type[T]]:
|
||||
def decorator(cls: type[T]) -> type[T]:
|
||||
for name, method in list(cls.__dict__.items()):
|
||||
if callable(method) and name.startswith("test_"):
|
||||
setattr(cls, name, largeTensorTest(size, device)(method))
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temp_float32_matmul_precision(precision: str):
|
||||
"""
|
||||
@@ -376,6 +389,7 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor):
     )
 
 
+@large_tensor_test_class("2GB", device="cuda")
 class TestFlexAttention(InductorTestCase):
     def setUp(self):
         super().setUp()
@@ -1912,17 +1926,9 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         def causal_mask_slidewindow_mod(b, h, q_idx, kv_idx):
             return (q_idx >= kv_idx) & (q_idx <= kv_idx + window_size)
 
-        mask1 = create_block_mask(
-            causal_mask, 1, None, 512, 512, _compile=False, device=device
-        )
+        mask1 = create_block_mask(causal_mask, 1, None, 512, 512, _compile=False)
         mask2 = create_block_mask(
-            causal_mask_slidewindow_mod,
-            1,
-            None,
-            512,
-            512,
-            _compile=False,
-            device=device,
+            causal_mask_slidewindow_mod, 1, None, 512, 512, _compile=False
         )
 
         def f(q, k, v):
@@ -2579,7 +2585,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
 
     @supported_platform
     @skip_on_cpu
-    def test_eager_backward_strides(self, device):
+    def test_eager_backward_strides(self):
         class Repro(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -2601,8 +2607,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
                 x = torch.nn.attention.flex_attention.flex_attention(q, k, v)
                 return x
 
-        model = Repro().to(device)
-        x = torch.randn((1, 512, 256), device=device, requires_grad=True)
+        model = Repro().cuda()
+        x = torch.randn((1, 512, 256), device="cuda", requires_grad=True)
         out = torch.compile(model, backend="aot_eager", fullgraph=True)(x)
         out.backward(torch.ones_like(out))
 
@@ -2683,7 +2689,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             requires_grad=True,
         )
 
-        block_mask = create_block_mask(mask, None, None, 4096, 4096, device=device)
+        block_mask = create_block_mask(mask, None, None, 4096, 4096)
         # Compile 1st version with q/k/v(seqlen=4096) and block_mask(seqlen=4096)
         torch.compile(flex_attention, dynamic=True, fullgraph=True)(
             make_tensor(), make_tensor(), make_tensor(), block_mask=block_mask
@@ -2692,7 +2698,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         make_tensor2 = functools.partial(
             torch.randn,
             (4, 4, 2048, 64),
-            device=device,
+            device="cuda",
             dtype=torch.float32,
             requires_grad=True,
         )
||||
@ -2709,7 +2715,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
k.grad = None
|
||||
v.grad = None
|
||||
|
||||
block_mask2 = create_block_mask(mask, None, None, 2048, 2048, device=device)
|
||||
block_mask2 = create_block_mask(mask, None, None, 2048, 2048)
|
||||
# Reuse the 1st version with q/k/v(seqlen=2048) and block_mask(seqlen=2048)
|
||||
out2 = torch.compile(flex_attention, dynamic=True, fullgraph=True)(
|
||||
q, k, v, block_mask=block_mask2
|
||||
@@ -2815,11 +2821,11 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
 
     @supported_platform
     @skip_on_cpu
-    def test_strided_backwards(self, device):
+    def test_strided_backwards(self):
         shape = (1, 2, 4096, 64)
-        Q = torch.randn(shape, requires_grad=True, device=device)
-        K = torch.randn(shape, requires_grad=True, device=device)
-        V = torch.randn(shape, requires_grad=True, device=device)
+        Q = torch.randn(shape, requires_grad=True, device="cuda")
+        K = torch.randn(shape, requires_grad=True, device="cuda")
+        V = torch.randn(shape, requires_grad=True, device="cuda")
         func = torch.compile(flex_attention, dynamic=True, fullgraph=True)
 
         K_sliced = K[:, :, :-128]
@@ -3070,7 +3076,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         make_tensor = functools.partial(
             torch.randn,
             (2, 2, 128, 64),
-            device=device,
+            device="cuda",
             dtype=torch.float32,
             requires_grad=True,
         )
||||
@ -3141,7 +3147,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
|
||||
return score_mod
|
||||
|
||||
m = Attention().to(device).eval().to(dtype)
|
||||
m = Attention().cuda().eval().to(dtype)
|
||||
m = torch.compile(m, mode="default", fullgraph=False)
|
||||
|
||||
q = torch.randn(B, H, N, D, device=device, dtype=dtype)
|
||||
@@ -3211,15 +3217,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             return causal_mask & window_mask
 
         sliding_window_causal = torch.nn.attention.flex_attention.create_block_mask(
-            sliding_window_causal,
-            B=None,
-            H=None,
-            Q_LEN=N_CTX,
-            KV_LEN=N_CTX,
-            device=device,
+            sliding_window_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
         )
         global_causal = torch.nn.attention.flex_attention.create_block_mask(
-            global_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX, device=device
+            global_causal, B=None, H=None, Q_LEN=N_CTX, KV_LEN=N_CTX
         )
 
         local_attn = functools.partial(
||||
@ -3267,10 +3268,10 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
def test_mixed_device_error_message(self, device):
|
||||
# Create tensors on different devices
|
||||
cpu_tensor = torch.randn(2, 2, 128, 16, device="cpu")
|
||||
gpu_tensor = torch.randn(2, 2, 128, 16, device=device)
|
||||
cuda_tensor = torch.randn(2, 2, 128, 16, device=device)
|
||||
|
||||
# Use different devices for query, key, and value
|
||||
query, key, value = cpu_tensor, gpu_tensor, cpu_tensor
|
||||
query, key, value = cpu_tensor, cuda_tensor, cpu_tensor
|
||||
|
||||
expected_error_message = (
|
||||
"Expected query, key, and value to have the same device type, "
|
||||
@ -3284,8 +3285,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
@supported_platform
|
||||
@skip_on_cpu
|
||||
def test_captured_wrong_device_error_message(self, device):
|
||||
means = torch.randn(64, 3, device=device)
|
||||
length_scales = torch.logspace(0.001, 0.1, 8, device="cpu")
|
||||
means = torch.randn(64, 3).cuda()
|
||||
length_scales = torch.logspace(0.001, 0.1, 8)
|
||||
|
||||
def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
|
||||
q_pos = means[q_idx]
|
||||
@ -3305,8 +3306,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
@skip_on_cpu
|
||||
def test_cant_lower_error_message(self, device):
|
||||
# We can't lower a 256-element reduction inside a pointwise reduction
|
||||
means = torch.randn(64, 256, device=device)
|
||||
length_scales = torch.logspace(0.001, 0.1, 8, device=device)
|
||||
means = torch.randn(64, 256).cuda()
|
||||
length_scales = torch.logspace(0.001, 0.1, 8).cuda()
|
||||
|
||||
def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
|
||||
q_pos = means[q_idx]
|
||||
@ -3326,8 +3327,8 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
@skip_on_cpu
|
||||
def test_reduction_unrolled(self, device):
|
||||
# We can't lower a 256-element reduction inside a pointwise reduction
|
||||
means = torch.randn(S, 3, device=device)
|
||||
length_scales = torch.logspace(0.001, 0.1, H, device=device)
|
||||
means = torch.randn(S, 3).to(device)
|
||||
length_scales = torch.logspace(0.001, 0.1, H).to(device)
|
||||
|
||||
def euclidean_dist_pos_embed(score, b, h, q_idx, k_idx):
|
||||
q_pos = means[q_idx]
|
||||
@ -3337,7 +3338,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
inv_dist = torch.exp(-dist / scale)
|
||||
return inv_dist * score
|
||||
|
||||
self.run_test(euclidean_dist_pos_embed, torch.bfloat16, device=device)
|
||||
self.run_test(euclidean_dist_pos_embed, torch.bfloat16, device)
|
||||
|
||||
@supported_platform
|
||||
@skip_on_cpu
|
||||
@ -3348,9 +3349,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
expected_error_message = (
|
||||
"ValueError: Q and KV block size must be divisible by BLOCK_M and BLOCK_N."
|
||||
)
|
||||
block_mask = create_block_mask(
|
||||
noop_mask, 1, 8, 128, 128, BLOCK_SIZE=96, device=device
|
||||
)
|
||||
block_mask = create_block_mask(noop_mask, 1, 8, 128, 128, BLOCK_SIZE=96)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, expected_error_message):
|
||||
torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
|
||||
@@ -3417,7 +3416,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
                 H=None,
                 Q_LEN=max_time,
                 KV_LEN=max_time,
-                device=device,
             )
 
             x = torch.compile(
@@ -3430,7 +3428,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
                 )
                 return x
 
-        model = Model(128).to(device)
+        model = Model(128).cuda()
         B, F, T = 16, 256, 12
         for _ in range(5):
             x = torch.randn(B, T, F, device=device)
@@ -3601,7 +3599,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             )
             return y.transpose(1, 2).contiguous().view(B, T, C)
 
-        model = SimpleAttention().to(device)
+        model = SimpleAttention().cuda()
         model.compile(mode="default", dynamic=True)
         sequence_len = 256
 
@@ -3609,9 +3607,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         torch._dynamo.reset()
         for batch_shape in [4, 16, 32]:
             # Create dense mask
-            rand_mask = torch.randint(
-                0, 2, (batch_shape, sequence_len), device=device
-            ).bool()
+            rand_mask = torch.randint(0, 2, (batch_shape, sequence_len)).cuda().bool()
             block_mask = torch.compile(create_block_mask, dynamic=True)(
                 B=batch_shape,
                 BLOCK_SIZE=128,
@@ -3623,7 +3619,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
             )
 
             # Run forward pass
-            x = torch.randn(batch_shape, sequence_len, 512, device=device)
+            x = torch.randn(batch_shape, sequence_len, 512).cuda()
             model(x, block_mask=block_mask)
 
         self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2)
||||
@ -3651,7 +3647,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
block_mask=block_mask,
|
||||
)
|
||||
|
||||
model = SimpleAttention().to(device)
|
||||
model = SimpleAttention().cuda()
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs
|
||||
|
||||
backend = EagerAndRecordGraphs()
|
||||
@ -3660,7 +3656,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
|
||||
|
||||
torch._dynamo.reset()
|
||||
for batch_shape in [4, 16, 32]:
|
||||
x = torch.randn(batch_shape, sequence_len, 512, device=device)
|
||||
x = torch.randn(batch_shape, sequence_len, 512).cuda()
|
||||
model(x)
|
||||
self.assertEqual(len(backend.graphs), 1)
|
||||
self.assertExpectedInline(
|
||||
@ -3687,7 +3683,7 @@ def forward(self, child : torch.Tensor, child_1 : torch.Tensor, child_2 : torch.
|
||||
def causal_mask(b, h, q_idx, kv_idx):
|
||||
return q_idx >= kv_idx
|
||||
|
||||
block_mask = create_block_mask(causal_mask, 1, 1, 128, 128, device=device)
|
||||
block_mask = create_block_mask(causal_mask, 1, 1, 128, 128)
|
||||
|
||||
func = torch.compile(flex_attention, backend=cnt, fullgraph=True)
|
||||
out = func(query, key, value, _squared, block_mask=block_mask)
|
||||
@ -3747,10 +3743,13 @@ class GraphModule(torch.nn.Module):
|
||||
out.sum().backward()
|
||||
|
||||
joint_graph = normalize_gm(aot_graphs[1].print_readable(print_output=False))
|
||||
expected_joint_graph = """\
|
||||
|
||||
self.assertExpectedInline(
|
||||
joint_graph,
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
|
||||
full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
|
||||
full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
|
||||
fw_graph0 = self.fw_graph0
|
||||
joint_graph0 = self.joint_graph0
|
||||
mask_graph0 = self.mask_graph0
|
||||
@@ -3775,15 +3774,9 @@ class GraphModule(torch.nn.Module):
 
     class mask_graph0(torch.nn.Module):
         def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
-            full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False)
+            full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
             return full
-""".replace(  # noqa: B950
-            "GPU_TYPE", torch.device(device).type
-        )
-
-        self.assertExpectedInline(
-            joint_graph,
-            expected_joint_graph,
+""",  # noqa: B950
         )
 
     @supported_platform
||||
@ -3833,7 +3826,7 @@ class GraphModule(torch.nn.Module):
|
||||
flex_attention(q, k, v)
|
||||
|
||||
# compiled cpu support for small embedding size
|
||||
q, k, v = [torch.randn(2, 2, 128, 8, device=device) for _ in range(3)]
|
||||
q, k, v = [torch.randn(2, 2, 128, 8, device="cpu") for _ in range(3)]
|
||||
flex_attention(q, k, v)
|
||||
|
||||
# compiled gpu kernel does not support small embedding size
|
||||
@@ -3905,7 +3898,7 @@ class TestBlockMask(InductorTestCase):
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
 
-        block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048, device=device)
+        block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048)
         self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
         self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
         self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
||||
@ -3916,7 +3909,7 @@ class TestBlockMask(InductorTestCase):
|
||||
self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())
|
||||
|
||||
offset = torch.arange(8, device=device)
|
||||
block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048, device=device)
|
||||
block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048)
|
||||
self.assertEqual(block_mask.sparsity(), 29.1015625)
|
||||
self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
|
||||
self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())
|
||||
@ -3933,7 +3926,7 @@ class TestBlockMask(InductorTestCase):
|
||||
Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE
|
||||
|
||||
block_mask = create_block_mask(
|
||||
noop_mask, B, H, Q_LEN, KV_LEN, BLOCK_SIZE=BLOCK_SIZE, device=device
|
||||
noop_mask, B, H, Q_LEN, KV_LEN, BLOCK_SIZE=BLOCK_SIZE
|
||||
)
|
||||
|
||||
self.assertEqual(block_mask.BLOCK_SIZE, (Q_BLOCK_SIZE, KV_BLOCK_SIZE))
|
||||
@ -3946,7 +3939,7 @@ class TestBlockMask(InductorTestCase):
|
||||
def causal_mask(b, h, q, kv):
|
||||
return (q + (offset[b] * 128)) >= kv
|
||||
|
||||
block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device=device)
|
||||
block_mask = create_block_mask(causal_mask, 4, 2, 512, 512)
|
||||
assert block_mask.kv_num_blocks.shape == (4, 2, 4)
|
||||
assert block_mask.kv_indices.shape == (4, 2, 4, 4)
|
||||
|
||||
@@ -3998,17 +3991,16 @@ class TestBlockMask(InductorTestCase):
 
     @supported_platform
     def test_block_mask_device_change(self, device):
-        device = torch.device(device)
         offset = torch.zeros(8, device=device)
 
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
 
-        block_mask = create_block_mask(causal_mask, 1, 1, 512, 512, device=device)
-        assert block_mask.kv_indices.device.type == device.type
-        assert block_mask.kv_num_blocks.device.type == device.type
-        assert block_mask.q_indices.device.type == device.type
-        assert block_mask.q_num_blocks.device.type == device.type
+        block_mask = create_block_mask(causal_mask, 1, 1, 512, 512)
+        assert block_mask.kv_indices.is_cuda
+        assert block_mask.kv_num_blocks.is_cuda
+        assert block_mask.q_indices.is_cuda
+        assert block_mask.q_num_blocks.is_cuda
 
         block_mask = block_mask.to("cpu")
         assert block_mask.kv_indices.is_cpu
||||
@ -4016,11 +4008,11 @@ class TestBlockMask(InductorTestCase):
|
||||
assert block_mask.q_indices.is_cpu
|
||||
assert block_mask.q_num_blocks.is_cpu
|
||||
|
||||
block_mask = block_mask.to(device)
|
||||
assert block_mask.kv_indices.device.type == device.type
|
||||
assert block_mask.kv_num_blocks.device.type == device.type
|
||||
assert block_mask.q_indices.device.type == device.type
|
||||
assert block_mask.q_num_blocks.device.type == device.type
|
||||
block_mask = block_mask.to("cuda")
|
||||
assert block_mask.kv_indices.is_cuda
|
||||
assert block_mask.kv_num_blocks.is_cuda
|
||||
assert block_mask.q_indices.is_cuda
|
||||
assert block_mask.q_num_blocks.is_cuda
|
||||
|
||||
@supported_platform
|
||||
def test_compiling_create_block_mask(self, device):
|
||||
@ -4030,7 +4022,7 @@ class TestBlockMask(InductorTestCase):
|
||||
return (q >= kv) & (seq[q] == seq[kv])
|
||||
|
||||
block_mask = torch.compile(create_block_mask, fullgraph=True)(
|
||||
mask_mod, 1, 1, 512, 512, device=device
|
||||
mask_mod, 1, 1, 512, 512
|
||||
)
|
||||
self.assertIsInstance(block_mask, BlockMask)
|
||||
self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((1, 1, 4)))
|
||||
@ -4042,27 +4034,21 @@ class TestBlockMask(InductorTestCase):
|
||||
return q >= kv
|
||||
|
||||
torch._dynamo.reset()
|
||||
block_mask = torch.compile(create_block_mask)(
|
||||
mask_mod, 2, 4, 1024, 1024, device=device
|
||||
)
|
||||
block_mask = torch.compile(create_block_mask)(mask_mod, 2, 4, 1024, 1024)
|
||||
self.assertIsInstance(block_mask, BlockMask)
|
||||
self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((2, 4, 8)))
|
||||
self.assertEqual(block_mask.kv_indices.shape, torch.Size((2, 4, 8, 8)))
|
||||
self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 1)
|
||||
|
||||
# automatic dynamic shapes triggered and recompilation.
|
||||
block_mask = torch.compile(create_block_mask)(
|
||||
mask_mod, 4, 8, 2048, 2048, device=device
|
||||
)
|
||||
block_mask = torch.compile(create_block_mask)(mask_mod, 4, 8, 2048, 2048)
|
||||
self.assertIsInstance(block_mask, BlockMask)
|
||||
self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((4, 8, 16)))
|
||||
self.assertEqual(block_mask.kv_indices.shape, torch.Size((4, 8, 16, 16)))
|
||||
self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2)
|
||||
|
||||
# no recompilation.
|
||||
block_mask = torch.compile(create_block_mask)(
|
||||
mask_mod, 6, 16, 3072, 3072, device=device
|
||||
)
|
||||
block_mask = torch.compile(create_block_mask)(mask_mod, 6, 16, 3072, 3072)
|
||||
self.assertIsInstance(block_mask, BlockMask)
|
||||
self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((6, 16, 24)))
|
||||
self.assertEqual(block_mask.kv_indices.shape, torch.Size((6, 16, 24, 24)))
|
||||
@ -4073,7 +4059,7 @@ class TestBlockMask(InductorTestCase):
|
||||
def causal_mask(b, h, q, kv):
|
||||
return q >= kv
|
||||
|
||||
block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
|
||||
block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048)
|
||||
|
||||
def replace_non_printable(s):
|
||||
def replace(c):
|
||||
@ -4114,9 +4100,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
|
||||
def causal_offset_mask(b, h, q, kv):
|
||||
return (q + offset[b] * 128) >= kv
|
||||
|
||||
block_mask = create_block_mask(
|
||||
causal_offset_mask, 8, 1, 2048, 2048, device=device
|
||||
)
|
||||
block_mask = create_block_mask(causal_offset_mask, 8, 1, 2048, 2048)
|
||||
str_block_mask = str(block_mask)
|
||||
self.assertTrue("sparsity=29.10" in str_block_mask)
|
||||
|
||||
@@ -4272,7 +4256,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx
 
-        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
+        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048)
         # manually set q_num_blocks and q_indices to None
         block_mask.q_num_blocks = None
         block_mask.q_indices = None
||||
@ -4341,6 +4325,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
|
||||
lengths[index] += 1
|
||||
return lengths
|
||||
|
||||
device = "cuda"
|
||||
max_seq_len, doc_count = 128, 4
|
||||
SEQ_LEN = max_seq_len
|
||||
|
||||
@@ -4370,7 +4355,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
         )
         for i in range(5):
             lengths = generate_random_lengths(1024 + i, 5)
-            offsets = length_to_offsets(lengths, device)
+            offsets = length_to_offsets(lengths, "cuda")
             doc_ids = _offsets_to_doc_ids_tensor(offsets)
 
             def doc_mask_mod(b, h, q_idx, kv_idx):
@@ -4382,9 +4367,7 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
             q, k, v = (
                 torch.randn(1, 12, 1024 + i, 64, device=device) for _ in range(3)
             )
-            block_mask = create_block_mask(
-                doc_mask_mod, None, None, 1024 + i, 1024 + i, device=device
-            )
+            block_mask = create_block_mask(doc_mask_mod, None, None, 1024 + i, 1024 + i)
             torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
 
     @supported_platform
||||
@ -4449,16 +4432,17 @@ BlockMask(shape=(1,s1,s2048,s2048),ssparsity=46.88%,s
|
||||
)
|
||||
return q, k, v
|
||||
|
||||
block_mask = create_block_mask(mask_mod, None, None, 1024, 1024, device=device)
|
||||
block_mask = create_block_mask(mask_mod, None, None, 1024, 1024)
|
||||
flex_attention_call(*create_inputs(1024), block_mask=block_mask)
|
||||
with self.assertRaisesRegex(ValueError, "block_mask was created for"):
|
||||
flex_attention_call(*create_inputs(2048), block_mask=block_mask)
|
||||
|
||||
block_mask = create_block_mask(mask_mod, None, None, 1023, 1023, device=device)
|
||||
block_mask = create_block_mask(mask_mod, None, None, 1023, 1023)
|
||||
with self.assertRaisesRegex(ValueError, "block_mask was created for"):
|
||||
flex_attention_call(*create_inputs(1024), block_mask=block_mask)
|
||||
|
||||
|
||||
@large_tensor_test_class("2GB", device="cuda")
|
||||
class TestPagedAttention(InductorTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -4908,18 +4892,20 @@ supports_learnable_bias = unittest.skipUnless(
|
||||
|
||||
|
||||
@supports_learnable_bias
|
||||
@large_tensor_test_class("2GB", device="cuda")
|
||||
class TestLearnableBiases(InductorTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.device = "cuda"
|
||||
self.dtype = torch.float32
|
||||
self.atol = 3e-2
|
||||
self.rtol = 3e-2
|
||||
|
||||
def _init_tensors(self, params: Params, device: str):
|
||||
def _init_tensors(self, params: Params):
|
||||
make_tensor = functools.partial(
|
||||
torch.randn,
|
||||
(params.batch_size, params.num_heads, params.seq_length, params.head_dim),
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@ -4980,11 +4966,11 @@ class TestLearnableBiases(InductorTestCase):
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
@common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"])
|
||||
def test_relative_1d_bias(self, device, params, mode: str):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_relative_1d_bias(self, params, mode: str):
|
||||
query, key, value = self._init_tensors(params)
|
||||
bias = torch.randn(
|
||||
2 * params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@ -5012,12 +4998,12 @@ class TestLearnableBiases(InductorTestCase):
|
||||
@common_utils.parametrize(
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
def test_absolute_2d_bias(self, device, params):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_absolute_2d_bias(self, params):
|
||||
query, key, value = self._init_tensors(params)
|
||||
bias = torch.randn(
|
||||
params.seq_length,
|
||||
params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@ -5045,13 +5031,13 @@ class TestLearnableBiases(InductorTestCase):
|
||||
@common_utils.parametrize(
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
def test_head_specific_bias(self, device, params):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_head_specific_bias(self, params):
|
||||
query, key, value = self._init_tensors(params)
|
||||
bias = torch.randn(
|
||||
params.num_heads,
|
||||
params.seq_length,
|
||||
params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@ -5079,14 +5065,14 @@ class TestLearnableBiases(InductorTestCase):
|
||||
@common_utils.parametrize(
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
def test_batch_head_bias(self, device, params):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_batch_head_bias(self, params):
|
||||
query, key, value = self._init_tensors(params)
|
||||
bias = torch.randn(
|
||||
params.batch_size,
|
||||
params.num_heads,
|
||||
params.seq_length,
|
||||
params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@ -5114,11 +5100,11 @@ class TestLearnableBiases(InductorTestCase):
|
||||
@common_utils.parametrize(
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
def test_multiplicative_bias(self, device, params):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_multiplicative_bias(self, params):
|
||||
query, key, value = self._init_tensors(params)
|
||||
bias = torch.randn(
|
||||
params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@@ -5146,12 +5132,12 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_local_window_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_local_window_bias(self, params):
+        query, key, value = self._init_tensors(params)
         window_size = 8
         bias = torch.randn(
             2 * window_size + 1,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
||||
@ -5180,11 +5166,11 @@ class TestLearnableBiases(InductorTestCase):
|
||||
@common_utils.parametrize(
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
def test_global_tokens_bias(self, device, params):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_global_tokens_bias(self, params):
|
||||
query, key, value = self._init_tensors(params)
|
||||
bias = torch.randn(
|
||||
params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@@ -5212,18 +5198,18 @@ class TestLearnableBiases(InductorTestCase):
     @common_utils.parametrize(
         "params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
     )
-    def test_weird_bias(self, device, params):
-        query, key, value = self._init_tensors(params, device=device)
+    def test_weird_bias(self, params):
+        query, key, value = self._init_tensors(params)
         bias = torch.randn(
             params.batch_size,
             params.num_heads,
             4,
             params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,
         )
-        which_bias = torch.tensor(0, device=device)
+        which_bias = torch.tensor(0, device=self.device)
 
         def bias_func(score, b, h, q_idx, kv_idx):
             return score + bias[b, h, which_bias, q_idx]
||||
@ -5248,11 +5234,11 @@ class TestLearnableBiases(InductorTestCase):
|
||||
@common_utils.parametrize(
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
def test_indirect_bias(self, device, params):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_indirect_bias(self, params):
|
||||
query, key, value = self._init_tensors(params)
|
||||
bias = torch.randn(
|
||||
params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@ -5261,7 +5247,7 @@ class TestLearnableBiases(InductorTestCase):
|
||||
0,
|
||||
params.seq_length,
|
||||
(params.seq_length,),
|
||||
device=device,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def bias_func(score, b, h, q_idx, kv_idx):
|
||||
@ -5288,11 +5274,11 @@ class TestLearnableBiases(InductorTestCase):
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
@common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"])
|
||||
def test_symmetric_bias(self, device, params, mode: str):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_symmetric_bias(self, params, mode: str):
|
||||
query, key, value = self._init_tensors(params)
|
||||
bias = torch.randn(
|
||||
params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@ -5324,12 +5310,12 @@ class TestLearnableBiases(InductorTestCase):
|
||||
@common_utils.parametrize(
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
def test_flipped_indexed_bias(self, device, params):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_flipped_indexed_bias(self, params):
|
||||
query, key, value = self._init_tensors(params)
|
||||
bias = torch.randn(
|
||||
params.seq_length,
|
||||
params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@ -5358,11 +5344,11 @@ class TestLearnableBiases(InductorTestCase):
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
@common_utils.parametrize("mode", ["default", "max-autotune-no-cudagraphs"])
|
||||
def test_head_specific_gate(self, device, params, mode: str):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_head_specific_gate(self, params, mode: str):
|
||||
query, key, value = self._init_tensors(params)
|
||||
gate_score = torch.randn(
|
||||
params.num_heads,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@ -5390,18 +5376,18 @@ class TestLearnableBiases(InductorTestCase):
|
||||
@common_utils.parametrize(
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
def test_distinct_biases(self, device, params):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_distinct_biases(self, params):
|
||||
query, key, value = self._init_tensors(params)
|
||||
# Create two separate bias tensors
|
||||
bias1 = torch.randn(
|
||||
params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
bias2 = torch.randn(
|
||||
params.seq_length,
|
||||
device=device,
|
||||
device=self.device,
|
||||
dtype=params.dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
@ -5438,8 +5424,8 @@ class TestLearnableBiases(InductorTestCase):
|
||||
@common_utils.parametrize(
|
||||
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"
|
||||
)
|
||||
def test_relative_1d_bias_only_grad(self, device, params):
|
||||
query, key, value = self._init_tensors(params, device=device)
|
||||
def test_relative_1d_bias_only_grad(self, params):
|
||||
query, key, value = self._init_tensors(params)
|
||||
query = query.detach().requires_grad_(False)
|
||||
key = key.detach().requires_grad_(False)
|
||||
value = value.detach().requires_grad_(False)
|
||||
@@ -5447,7 +5433,7 @@ class TestLearnableBiases(InductorTestCase):
         # Only bias requires gradients
         bias = torch.randn(
             2 * params.seq_length,
-            device=device,
+            device=self.device,
             dtype=params.dtype,
             requires_grad=True,  # Only bias needs gradients
         )
||||
@ -5471,10 +5457,10 @@ class TestLearnableBiases(InductorTestCase):
|
||||
out_eager, out_compiled, out_gold, (bias,), names=["out", "bias"]
|
||||
)
|
||||
|
||||
def test_flex_attention_with_dynamic_max_autotune(self, device):
|
||||
query = torch.randn(2, 16, 512, 64, device=device)
|
||||
key = torch.randn(2, 16, 512, 64, device=device)
|
||||
value = torch.randn(2, 16, 512, 64, device=device)
|
||||
def test_flex_attention_with_dynamic_max_autotune(self):
|
||||
query = torch.randn(2, 16, 512, 64, device=self.device)
|
||||
key = torch.randn(2, 16, 512, 64, device=self.device)
|
||||
value = torch.randn(2, 16, 512, 64, device=self.device)
|
||||
query.requires_grad = True
|
||||
key.requires_grad = True
|
||||
value.requires_grad = True
|
||||
@@ -5488,9 +5474,7 @@ class TestLearnableBiases(InductorTestCase):
             return m >= n
 
         mask_shape = (1, 1, M, N)
-        block_mask = torch.compile(create_block_mask)(
-            causal, *mask_shape, device=device
-        )
+        block_mask = torch.compile(create_block_mask)(causal, *mask_shape, "cuda")
 
         compiled_sdpa = torch.compile(
             flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs"
||||
@ -5511,23 +5495,26 @@ class TestLearnableBiases(InductorTestCase):
|
||||
out.shape, query.shape, f"Expected shape {query.shape}, got {out.shape}"
|
||||
)
|
||||
|
||||
def test_inspect_bug(self, device):
|
||||
def test_inspect_bug(self):
|
||||
# https://github.com/pytorch/pytorch/issues/139374
|
||||
def sliding_window(b, h, q_idx, kv_idx, val):
|
||||
return (q_idx - kv_idx).abs() < val
|
||||
|
||||
sliding_window2 = functools.partial(
|
||||
sliding_window, val=torch.randn((), device=device)
|
||||
sliding_window, val=torch.randn((), device=self.device)
|
||||
)
|
||||
opt_fn = torch.compile(create_block_mask, fullgraph=True)
|
||||
create_block_mask(sliding_window2, None, None, 1024, 1024, device=device)
|
||||
create_block_mask(sliding_window2, None, None, 1024, 1024)
|
||||
# checks that the compile is working
|
||||
opt_fn(sliding_window2, None, None, 1024, 1024, device=device)
|
||||
opt_fn(sliding_window2, None, None, 1024, 1024)
|
||||
|
||||
@supported_platform
|
||||
def test_head_bias_req_grad(self, device):
|
||||
def test_head_bias_req_grad(self):
|
||||
device = self.device
|
||||
B, H, S, D = 1, 4, 256, 64
|
||||
bias = torch.randn(H, device=device, dtype=torch.float16, requires_grad=True)
|
||||
bias = torch.randn(
|
||||
H, device=self.device, dtype=torch.float16, requires_grad=True
|
||||
)
|
||||
|
||||
bias_flex = bias.detach().clone().requires_grad_(True)
|
||||
|
||||
@@ -5558,7 +5545,8 @@ class TestLearnableBiases(InductorTestCase):
         )
 
     @supported_platform
-    def test_comparison_vs_sdpa_with_learnable_bias(self, device):
+    def test_comparison_vs_sdpa_with_learnable_bias(self):
+        device = self.device
         # 1-dimensional bias:
         B, H, S, D = 1, 1, 256, 64
         bias = torch.randn(
|
||||
instantiate_device_type_tests(TestFlexAttention, globals(), only_for=test_device)
|
||||
instantiate_device_type_tests(TestPagedAttention, globals(), only_for=test_device)
|
||||
instantiate_device_type_tests(TestBlockMask, globals(), only_for=("cuda",))
|
||||
instantiate_device_type_tests(TestLearnableBiases, globals(), only_for=("cuda",))
|
||||
|
||||
common_utils.instantiate_parametrized_tests(TestLearnableBiases)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|