Compare commits

...

1 commit

Commit 254a955f20: Autoupdate min_lrs for ReduceLROnPlateau if possible, fixes #104361 (#137637)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137637
Approved by: https://github.com/albanD
ghstack-source-id: 9a7b26eadb1b3e5d6749cd98434ba7c4c6b5e402
Date: 2024-10-09 19:18:56 -07:00
5 changed files with 447 additions and 25 deletions

View File

@@ -989,14 +989,15 @@ class TestFlexAttention(InductorTestCase):
self.run_test(composed_score_mod, dtype)
@supported_platform
@expectedFailure # TODO: Remove this after supporting compiled flex attention with training bias
@common_utils.parametrize("dtype", test_dtypes)
def test_captured_buffers(self, dtype: torch.dtype):
head_offset = torch.rand(H, device="cuda", dtype=dtype)
def test_captured_buffers_req_grad(self, dtype: torch.dtype):
head_offset = torch.rand(8, device="cuda", dtype=dtype, requires_grad=True)
def score_mod(score, b, h, m, n):
return score + head_offset[h]
self.run_test(score_mod, dtype)
self.run_test(score_mod, dtype, 4, 8, 128, 128)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@@ -2068,6 +2069,242 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
f"Ref error: {ref_error}, Flex Error: {flex_error}",
)
@supported_platform
def test_head_bias_req_grad(self):
B, H, S, D = 1, 4, 256, 64
bias = torch.randn(H, device="cuda", dtype=torch.float16, requires_grad=True)
bias_flex = bias.detach().clone().requires_grad_(True)
def head_bias(score, b, h, q_idx, kv_idx):
return score + bias_flex[h]
bias_sdpa_ref = bias.detach().clone().requires_grad_(True)
implicit_bias_sdpa_ref = bias_sdpa_ref
implicit_bias_sdpa_ref = implicit_bias_sdpa_ref.view(H, 1, 1).expand(H, S, S)
bias_sdpa_gold = (
bias.detach().clone().to(dtype=torch.float64).requires_grad_(True)
)
implicit_bias_sdpa_gold = bias_sdpa_gold
implicit_bias_sdpa_gold = implicit_bias_sdpa_gold.view(H, 1, 1).expand(H, S, S)
self._test_learnable_bias_inner(
B,
H,
S,
D,
head_bias,
bias_flex,
implicit_bias_sdpa_ref,
bias_sdpa_ref,
implicit_bias_sdpa_gold,
bias_sdpa_gold,
)
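The equivalence this test relies on can be checked in isolation: adding bias[h] inside a flex-attention score_mod should match handing SDPA the broadcast bias.view(H, 1, 1).expand(H, S, S) as an explicit additive attn_mask. A minimal sketch, assuming the public torch.nn.attention.flex_attention import path, a CUDA device, and loose tolerances (the PR's real harness is _test_learnable_bias_inner below):

import torch
from torch.nn.attention.flex_attention import flex_attention  # assumed import path

B, H, S, D = 1, 4, 128, 16
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
bias = torch.randn(H, device="cuda")

def head_bias(score, b, h, q_idx, kv_idx):
    # Per-head additive bias, indexed with the scalar head-index tensor.
    return score + bias[h]

out_flex = flex_attention(q, k, v, score_mod=head_bias)
# The same bias materialized as an explicit (H, S, S) additive attention mask.
out_sdpa = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=bias.view(H, 1, 1).expand(H, S, S)
)
torch.testing.assert_close(out_flex, out_sdpa, atol=2e-3, rtol=2e-3)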
@supported_platform
def test_comparison_vs_sdpa_with_learnable_bias(self):
# 1-dimensional bias:
B, H, S, D = 1, 1, 256, 64
bias = torch.randn(
2 * S, device="cuda", dtype=torch.float16, requires_grad=True
)
bias_flex = bias.detach().clone().requires_grad_(True)
def rel_pos_1d(score, b, h, q_idx, kv_idx):
return score + bias_flex[q_idx + kv_idx]
bias_indices = torch.arange(S)[:, None] + torch.arange(S)
bias_sdpa_ref = bias.detach().clone().requires_grad_(True)
implicit_bias_sdpa_ref = bias_sdpa_ref[bias_indices]
bias_sdpa_gold = (
bias.detach().clone().to(dtype=torch.float64).requires_grad_(True)
)
implicit_bias_sdpa_gold = bias_sdpa_gold[bias_indices]
self._test_learnable_bias_inner(
B,
H,
S,
D,
rel_pos_1d,
bias_flex,
implicit_bias_sdpa_ref,
bias_sdpa_ref,
implicit_bias_sdpa_gold,
bias_sdpa_gold,
)
# 2-dimensional bias:
B, H, S, D = 1, 1, 256, 64
bias = torch.randn(S, S, device="cuda", dtype=torch.float16, requires_grad=True)
bias_flex = bias.detach().clone().requires_grad_(True)
def rel_pos_2d(score, b, h, q_idx, kv_idx):
return score + bias_flex[q_idx, kv_idx]
bias_sdpa_ref = bias.detach().clone().requires_grad_(True)
implicit_bias_sdpa_ref = bias_sdpa_ref
bias_sdpa_gold = (
bias.detach().clone().to(dtype=torch.float64).requires_grad_(True)
)
implicit_bias_sdpa_gold = bias_sdpa_gold
self._test_learnable_bias_inner(
B,
H,
S,
D,
rel_pos_2d,
bias_flex,
implicit_bias_sdpa_ref,
bias_sdpa_ref,
implicit_bias_sdpa_gold,
bias_sdpa_gold,
)
# 2-dimensional bias + index multiple
B, H, S, D = 1, 1, 256, 64
bias = torch.randn(S, S, device="cuda", dtype=torch.float16, requires_grad=True)
bias_flex = bias.detach().clone().requires_grad_(True)
def rel_pos_2d(score, b, h, q_idx, kv_idx):
return score + bias_flex[q_idx][kv_idx]
bias_sdpa_ref = bias.detach().clone().requires_grad_(True)
implicit_bias_sdpa_ref = bias_sdpa_ref
bias_sdpa_gold = (
bias.detach().clone().to(dtype=torch.float64).requires_grad_(True)
)
implicit_bias_sdpa_gold = bias_sdpa_gold
self._test_learnable_bias_inner(
B,
H,
S,
D,
rel_pos_2d,
bias_flex,
implicit_bias_sdpa_ref,
bias_sdpa_ref,
implicit_bias_sdpa_gold,
bias_sdpa_gold,
)
# 2-dimensional bias + transposed:
B, H, S, D = 1, 1, 256, 64
bias = torch.randn(S, S, device="cuda", dtype=torch.float16, requires_grad=True)
bias_flex = bias.detach().clone().requires_grad_(True)
def rel_pos_2d_transposed(score, b, h, q_idx, kv_idx):
return score + bias_flex[kv_idx, q_idx]
bias_sdpa_ref = bias.detach().clone().requires_grad_(True)
implicit_bias_sdpa_ref = bias_sdpa_ref.transpose(-1, -2)
bias_sdpa_gold = (
bias.detach().clone().to(dtype=torch.float64).requires_grad_(True)
)
implicit_bias_sdpa_gold = bias_sdpa_gold.transpose(-1, -2)
self._test_learnable_bias_inner(
B,
H,
S,
D,
rel_pos_2d_transposed,
bias_flex,
implicit_bias_sdpa_ref,
bias_sdpa_ref,
implicit_bias_sdpa_gold,
bias_sdpa_gold,
)
# 3-dimensional bias + transposed
B, H, S, D = 4, 8, 256, 64
bias = torch.randn(
H, S, S, device="cuda", dtype=torch.float16, requires_grad=True
)
bias_flex = bias.detach().clone().requires_grad_(True)
def rel_pos_3d_transposed(score, b, h, q_idx, kv_idx):
return score + bias_flex[h, kv_idx, q_idx]
bias_sdpa_ref = bias.detach().clone().requires_grad_(True)
implicit_bias_sdpa_ref = bias_sdpa_ref.transpose(-1, -2)
bias_sdpa_gold = (
bias.detach().clone().to(dtype=torch.float64).requires_grad_(True)
)
implicit_bias_sdpa_gold = bias_sdpa_gold.transpose(-1, -2)
self._test_learnable_bias_inner(
B,
H,
S,
D,
rel_pos_3d_transposed,
bias_flex,
implicit_bias_sdpa_ref,
bias_sdpa_ref,
implicit_bias_sdpa_gold,
bias_sdpa_gold,
)
def _test_learnable_bias_inner(
self,
B,
H,
S,
D,
score_mod,
bias_flex,
implicit_bias_sdpa_ref,
bias_sdpa_ref,
implicit_bias_sdpa_gold,
bias_sdpa_gold,
):
make_tensor = functools.partial(
torch.ones,
(B, H, S, D),
device="cuda",
dtype=torch.float16,
requires_grad=True,
)
q_ref, k_ref, v_ref = make_tensor(), make_tensor(), make_tensor()
q_gold, k_gold, v_gold = query_key_value_clones(
q_ref, k_ref, v_ref, torch.float64
)
q_flex, k_flex, v_flex = query_key_value_clones(q_ref, k_ref, v_ref)
out_ref = torch.nn.functional.scaled_dot_product_attention(
q_ref, k_ref, v_ref, attn_mask=implicit_bias_sdpa_ref
)
out_ref.sum().backward()
out_gold = torch.nn.functional.scaled_dot_product_attention(
q_gold, k_gold, v_gold, attn_mask=implicit_bias_sdpa_gold
)
out_gold.sum().backward()
out_flex = flex_attention(q_flex, k_flex, v_flex, score_mod=score_mod)
out_flex.sum().backward()
name = score_mod.__name__
for ref, flex, gold in [
(out_ref, out_flex, out_gold),
(q_ref.grad, q_flex.grad, q_gold.grad),
(k_ref.grad, k_flex.grad, k_gold.grad),
(v_ref.grad, v_flex.grad, v_gold.grad),
(bias_sdpa_ref.grad, bias_flex.grad, bias_sdpa_gold.grad),
]:
ref_error = rmse(ref, gold)
flex_error = rmse(flex, gold)
self.assertTrue(
ref_error * 1.2 >= flex_error,
f"{name} -> Ref error: {ref_error}, Flex eager Error: {flex_error}",
)
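For reference, the tolerance check above scores both the float16 SDPA reference and flex attention by RMSE against the float64 gold run, and only requires flex to be within 20% of the reference's error. The rmse helper is defined elsewhere in the test file; a stand-in consistent with that usage, assuming the real helper is plain root-mean-square error, would be:

import torch

def rmse(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Root-mean-square error, computed in float64 so the metric itself
    # adds no extra rounding error (assumed equivalent of the test helper).
    return torch.sqrt(torch.mean((a.to(torch.float64) - b.to(torch.float64)) ** 2))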
@supported_platform
def test_causal_block_non_divisible(self):
def mask_mod(b, h, q, kv):

View File

@@ -2405,6 +2405,60 @@ class TestLRScheduler(TestCase):
scheduler2.load_state_dict(state_dict_loaded)
self.assertEqual(scheduler2.state_dict(), state_dict)
@parametrize("min_lr", ["scalar", "list"])
def test_add_param_group_does_not_break_reduce_lr_on_plateau(self, min_lr):
epochs = 20
for param_group in self.opt.param_groups:
param_group["lr"] = 0.5
targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
metrics = [1] * 7 + [0.6] + [0.5] * 12
scheduler = ReduceLROnPlateau(
self.opt,
mode="min",
threshold_mode="rel",
threshold=0.1,
patience=5,
cooldown=5,
min_lr=0 if min_lr == "scalar" else [1e-5, 1e-4],
)
for epoch in range(epochs):
# Point is to test the use case in #104361
if epoch == 8:
param = torch.nn.Parameter(torch.rand(2, 3))
self.opt.add_param_group({"params": [param], "lr": 0.05})
if min_lr == "list":
scheduler.min_lrs.append(1e-6)
self.opt.step()
scheduler.step(metrics[epoch])
for param_group, target in zip(self.opt.param_groups, targets):
self.assertEqual(
target[epoch],
param_group["lr"],
msg="LR is wrong in epoch {}: expected {}, got {}".format(
epoch, target[epoch], param_group["lr"]
),
atol=1e-5,
rtol=0,
)
def test_add_param_group_errors_reduce_lr_on_plateau(self):
scheduler = ReduceLROnPlateau(
self.opt,
mode="min",
threshold_mode="rel",
threshold=1e-5,
patience=0,
cooldown=0,
min_lr=[1e-5, 1e-4],
)
param = torch.nn.Parameter(torch.rand(2, 3))
self.opt.add_param_group({"params": [param], "lr": 0.05})
self.opt.step()
scheduler.step(1)
with self.assertRaisesRegex(RuntimeError, "The number of param groups in the"):
self.opt.step()
scheduler.step(1.3)
@parametrize(
"LRClass",
[

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.utils._pytree as pytree
@@ -49,6 +49,78 @@ __all__ = ["trace_wrapped"]
# compiled autograd do we inline into the function.
if not torch._running_with_deploy():
"""torch.library.custom_op does not work with torch.deploy/multipy"""
@torch.library.custom_op("FlexAttentionLib::zeros_and_scatter", mutates_args=()) # type: ignore[misc]
def zeros_and_scatter(
shape: List[int],
indices: List[Tensor],
vals: Tensor,
) -> Tensor:
"""Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass"""
grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype)
return torch.ops.aten.index_put(grad, indices, vals, accumulate=True)
@zeros_and_scatter.register_fake # type: ignore[misc]
def _(
shape: List[int],
indices: List[Tensor],
vals: Tensor,
) -> Tensor:
return vals.new_empty(shape)
@zeros_and_scatter.register_vmap # type: ignore[misc]
def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def]
"""The batching rule is special in that it returns a tensor that is not batched"""
indices_indims = indims[1]
expanded_indices = []
for idx, idx_indim in zip(indices, indices_indims):
# The index is not being batched; we should expand it to match value.
if idx_indim is None:
expanded_indices.append(idx.expand(value.shape))
else:
# The index is part of the vmap batch; it should be the same size as value.
assert idx.shape == value.shape
expanded_indices.append(idx)
out = torch.ops.FlexAttentionLib.zeros_and_scatter(
shape,
expanded_indices,
value,
)
return out, None
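zeros_and_scatter is a scatter-add into a fresh zeros tensor, so gradient contributions that land on duplicate indices accumulate instead of overwriting each other. A minimal sketch of the eager semantics, assuming this change is applied and the op has been registered by importing the module (i.e. not running under torch.deploy):

import torch
import torch._dynamo._trace_wrapped_higher_order_op  # registers FlexAttentionLib::zeros_and_scatter

vals = torch.tensor([1.0, 2.0, 3.0])
idx = torch.tensor([0, 2, 2])  # note the duplicate index 2
out = torch.ops.FlexAttentionLib.zeros_and_scatter([4], [idx], vals)
print(out)  # tensor([1., 0., 5., 0.]) -- contributions to index 2 are summed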
class ModIndex(torch.autograd.Function):
generate_vmap_rule = True
@staticmethod
def forward(x: Tensor, indices: List[Tensor]) -> Tensor:
return torch.ops.aten.index(x, indices)
@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
x, indices = inputs
ctx.save_for_backward(*indices)
ctx.input_shape = x.shape
@staticmethod
def backward(ctx, gradOut): # type: ignore[no-untyped-def]
indices = ctx.saved_tensors
return (
torch.ops.FlexAttentionLib.zeros_and_scatter(
ctx.input_shape,
indices,
gradOut,
),
None,
)
mod_index = ModIndex.apply
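mod_index returns the same values as plain advanced indexing in the forward, but routes the backward through zeros_and_scatter, the op for which a custom lowering can then be registered. A gradient sanity check, assuming this change is applied so mod_index is importable from torch._dynamo._trace_wrapped_higher_order_op:

import torch
from torch._dynamo._trace_wrapped_higher_order_op import mod_index

x = torch.randn(5, requires_grad=True)
idx = torch.tensor([0, 3, 3])

y = mod_index(x, [idx])  # forward: same values as x[idx]
y.sum().backward()
# The backward scatter-adds 1.0 at every gathered position, so the
# duplicated index 3 receives a gradient of 2.0.
torch.testing.assert_close(x.grad, torch.tensor([1.0, 0.0, 0.0, 2.0, 0.0]))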
class TransformGetItemToIndex(TorchFunctionMode):
# This is needed since we want to support calling
# A[q_idx], where q_idx is a scalar tensor in score_mod.
@@ -66,7 +138,7 @@ class TransformGetItemToIndex(TorchFunctionMode):
if func == torch.Tensor.__getitem__:
index_args = pytree.tree_leaves(args[1])
if all(isinstance(x, torch.Tensor) for x in index_args):
return torch.ops.aten.index(args[0], index_args)
return mod_index(args[0], index_args)
return func(*args, **(kwargs or {}))
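With this change, any tensor[q_idx]-style lookup made while the mode is active is rerouted through mod_index rather than plain __getitem__, which is what lets gradients flow back into captured buffers. A small usage sketch, assuming this change is applied and the class is importable as it is in the backward implementation:

import torch
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

bias = torch.randn(4, requires_grad=True)
h = torch.tensor(2)  # scalar index tensor, as score_mod receives it

with TransformGetItemToIndex():
    val = bias[h]  # dispatched to mod_index instead of plain __getitem__

val.backward()
print(bias.grad)  # tensor([0., 0., 1., 0.]) -- one-hot gradient at index 2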

View File

@@ -125,7 +125,9 @@ class FlexAttentionBackwardHOP(HigherOrderOperator):
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
if not all(
isinstance(buf, torch.Tensor)
for buf in score_mod_other_buffers + mask_mod_other_buffers
@@ -578,16 +580,15 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
block_mask: Tuple[Any, ...],
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple[Any, ...],
mask_mod_other_buffers: Tuple[Any, ...],
*score_mod_other_buffers: Tuple[Any, ...],
) -> Tuple[torch.Tensor, torch.Tensor]:
any_buffer_requires_grad = any(
buffer.requires_grad
for buffer in score_mod_other_buffers + mask_mod_other_buffers
buffer.requires_grad for buffer in mask_mod_other_buffers
)
assert (
not any_buffer_requires_grad
), "Captured buffers that require grad are not yet supported."
), "Captured buffers from mask mod that require grad are not yet supported."
ctx._fw_graph = fw_graph
ctx._joint_graph = joint_graph
ctx._mask_graph = block_mask[-1]
@@ -654,9 +655,15 @@ class FlexAttentionAutogradOp(torch.autograd.Function):
mask_mod_other_buffers = tuple(
other_buffers[ctx._score_mod_other_buffers_len :]
)
# We have asserted that other_buffers do not require grad in the forward
none_grads = [None] * 7
grad_query, grad_key, grad_value = flex_attention_backward(
# We have asserted that mask_mod_other_buffers do not require grad,
# but score_mod_other_buffers can require grad.
none_grads = [None] * 6
(
grad_query,
grad_key,
grad_value,
grad_score_mod_captured,
) = flex_attention_backward(
query,
key,
value,
@@ -684,7 +691,7 @@
score_mod_other_buffers,
mask_mod_other_buffers,
)
return grad_query, grad_key, grad_value, *none_grads
return grad_query, grad_key, grad_value, *none_grads, *grad_score_mod_captured
@flex_attention.py_impl(DispatchKey.Autograd)
@@ -725,8 +732,8 @@ def flex_attention_autograd(
block_mask,
scale,
kernel_options,
score_mod_other_buffers,
mask_mod_other_buffers,
*score_mod_other_buffers,
)
return out, logsumexp
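The reason the forward now splats *score_mod_other_buffers instead of passing the tuple through unchanged is that torch.autograd.Function only treats tensors given as direct positional arguments as differentiable inputs; tensors hidden inside a tuple receive no gradient. A toy illustration of that behaviour (the Function names here are illustrative only, not part of the PR):

import torch

class AddFromTuple(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, buffers):  # buffers is a tuple -> opaque to autograd
        return x + buffers[0]

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None

class AddFromArgs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, *buffers):  # buffers splatted -> tracked by autograd
        return x + buffers[0]

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, grad_out  # gradient now reaches the captured buffer

x = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

AddFromTuple.apply(x, (b,)).sum().backward()
print(b.grad)  # None: the tuple hid b from autograd

AddFromArgs.apply(x, b).sum().backward()
print(b.grad)  # tensor([1., 1., 1.]): gradient flowed to the captured buffer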
@@ -750,13 +757,19 @@ def sdpa_dense_backward(
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple,
mask_mod_other_buffers: Tuple,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
# Get outputs before calling repeat interleave
actual_grad_query = torch.empty_like(query)
actual_grad_key = torch.empty_like(key)
actual_grad_value = torch.empty_like(value)
actual_grad_score_mod_captured = [
torch.empty_like(buffer) if buffer.requires_grad else None
for buffer in score_mod_other_buffers
]
Bq, Bkv = query.size(0), key.size(0)
if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)):
@@ -817,7 +830,7 @@ def sdpa_dense_backward(
out_dims=out_dims,
)
with TransformGetItemToIndex():
grad_scores, *_ = joint_score_mod(
grad_scores, _, _, _, _, *grad_score_mod_captured = joint_score_mod(
scores, b, h, m, n, grad_score_mod, *score_mod_other_buffers
)
grad_scores = grad_scores * scale
@@ -858,8 +871,19 @@ def sdpa_dense_backward(
actual_grad_query.copy_(grad_query)
actual_grad_key.copy_(grad_key)
actual_grad_value.copy_(grad_value)
score_mod_other_buffer_grads = [
actual_grad.copy_(grad) if actual_grad is not None else actual_grad
for actual_grad, grad in zip(
actual_grad_score_mod_captured, grad_score_mod_captured
)
]
return actual_grad_query, actual_grad_key, actual_grad_value
return (
actual_grad_query,
actual_grad_key,
actual_grad_value,
tuple(score_mod_other_buffer_grads),
)
def trace_flex_attention_backward(
@@ -878,7 +902,9 @@ def trace_flex_attention_backward(
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
"""We already have the forward graph and joint graph from the forward pass, so we create a proxy attach both graphs"""
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
@@ -964,7 +990,9 @@ def flex_attention_backward_proxy_torch_dispatch_mode(
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
assert mode is not None, "Mode should always be enabled for python fallback key"
return trace_flex_attention_backward(
mode,
@@ -1002,7 +1030,9 @@ def flex_attention_backward_functionalize(
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
"""Defines the functionalization rules for the flex_attention operator.
Write now we are unwrapping each tensor and then redispatching to the next,
@@ -1040,7 +1070,12 @@ def flex_attention_backward_functionalize(
functional_fw_graph = ctx.functionalize(fw_graph)
functional_joint_graph = ctx.functionalize(joint_graph)
grad_query, grad_key, grad_value = flex_attention_backward(
(
grad_query,
grad_key,
grad_value,
grad_score_mod_captured,
) = flex_attention_backward(
query_unwrapped,
key_unwrapped,
value_unwrapped,
@@ -1057,7 +1092,7 @@ def flex_attention_backward_functionalize(
mask_mod_other_buffers_unwrapped,
)
return ctx.wrap_tensors((grad_query, grad_key, grad_value)) # type: ignore[return-value,arg-type]
return ctx.wrap_tensors((grad_query, grad_key, grad_value, grad_score_mod_captured)) # type: ignore[return-value,arg-type]
@flex_attention_backward.py_impl(FakeTensorMode)
@@ -1077,12 +1112,20 @@ def flex_attention_backward_fake_tensor_mode(
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
with mode:
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
return grad_query, grad_key, grad_value
grad_score_mod_captured = tuple(
[
torch.empty_like(buffer) if buffer.requires_grad else None
for buffer in score_mod_other_buffers
]
)
return grad_query, grad_key, grad_value, grad_score_mod_captured
flex_attention_backward.py_impl(DispatchKey.Autograd)(

View File

@@ -1318,8 +1318,10 @@ class ReduceLROnPlateau(LRScheduler):
raise ValueError(
f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}"
)
self.default_min_lr = None
self.min_lrs = list(min_lr)
else:
self.default_min_lr = min_lr
self.min_lrs = [min_lr] * len(optimizer.param_groups)
self.patience = patience
@@ -1375,6 +1377,20 @@ class ReduceLROnPlateau(LRScheduler):
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
def _reduce_lr(self, epoch):
if len(self.optimizer.param_groups) != len(self.min_lrs):
if self.default_min_lr is None:
raise RuntimeError(
"The number of param groups in the `optimizer` "
f"({len(self.optimizer.param_groups)}) differs "
f"from when `ReduceLROnPlateau` was initialized "
f"({len(self.min_lrs)}), usually due to a new "
"param group being added to the optimizer. Please "
"modify the `min_lrs` field to match the length "
"of the `optimizer` param groups."
)
else:
self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)
for i, param_group in enumerate(self.optimizer.param_groups):
old_lr = float(param_group["lr"])
new_lr = max(old_lr * self.factor, self.min_lrs[i])
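Taken together, the scalar-min_lr case now re-derives min_lrs when the optimizer gains a param group, while an explicit list still raises until min_lrs is updated by hand. A short usage sketch of the new behaviour, assuming this change is applied (it mirrors the added tests above):

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = ReduceLROnPlateau(opt, mode="min", patience=0, min_lr=1e-5)  # scalar min_lr

# Add a param group after the scheduler was constructed (the #104361 use case).
opt.add_param_group({"params": [torch.nn.Parameter(torch.rand(2, 3))], "lr": 0.05})

opt.step()
scheduler.step(1.0)   # first metric becomes the best so far
opt.step()
scheduler.step(1.0)   # no improvement -> _reduce_lr extends min_lrs from the
                      # scalar default and reduces the lr of both param groups
print(scheduler.min_lrs)                    # [1e-05, 1e-05]
print([g["lr"] for g in opt.param_groups])  # [0.05, 0.005]

# With min_lr given as a list instead, the same sequence raises a RuntimeError
# until scheduler.min_lrs is extended to match the optimizer's param groups.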