Mirror of https://github.com/huggingface/transformers.git
Synced 2025-10-20 17:13:56 +08:00

Compare commits: v4.42.4...thomas/acc (21 commits)
Commits:
3df009e4d3
befaac7532
115d62b817
ded4ba6a17
efa1814de6
82c2bebc4a
47b2325c18
8c796b3fb0
0d2c182260
5b84d96f07
9cd5368975
5aec6894df
711b07f7cd
7f64c1e658
33713b311e
12ccfa8dc0
16a8b6bc83
215a85fa2c
0991db29b2
ca03190593
c9baf40937
@@ -123,7 +123,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
         )
@@ -177,23 +177,24 @@ class DecisionTransformerGPT2Attention(nn.Module):
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))

+        normalizer = 1
         if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.tensor(
-                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
+            normalizer *= torch.tensor(value.size(-1) ** 0.5, dtype=query.dtype, device=query.device)

         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
+            normalizer *= float(self.layer_idx + 1)
+
+        attn_weights = attn_weights / normalizer

         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
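The `_attn` hunk above folds the two separate divisions (by `sqrt(head_dim)` when `scale_attn_weights` is set, and by `layer_idx + 1` when `scale_attn_by_inverse_layer_idx` is set) into a single `normalizer`, so the scores tensor is divided only once. A minimal standalone sketch of the equivalence, with invented shapes and values that are not part of the diff:

import torch

# Invented toy values: scores of shape [batch=2, heads=4, q_len=3, k_len=3], head_dim=8, layer_idx=5
torch.manual_seed(0)
attn_weights = torch.randn(2, 4, 3, 3)
head_dim = 8
layer_idx = 5

# Old path: divide the scores twice
old = attn_weights / torch.tensor(head_dim**0.5) / float(layer_idx + 1)

# New path: accumulate both factors into one normalizer and divide once
normalizer = 1
normalizer *= torch.tensor(head_dim**0.5)
normalizer *= float(layer_idx + 1)
new = attn_weights / normalizer

print(torch.allclose(old, new))  # True, up to floating-point rounding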
@@ -214,13 +215,19 @@ class DecisionTransformerGPT2Attention(nn.Module):

         return attn_output, attn_weights

-    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    def _upcast_and_reordered_attn(self, query, key, value, num_heads: int, attention_mask=None, head_mask=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
-        bsz, num_heads, q_seq_len, dk = query.size()
-        _, _, k_seq_len, _ = key.size()
+        bsz_times_num_heads, query_length, _ = query.size()
+        _, _, key_length = key.size()
+        batch_size = bsz_times_num_heads // num_heads

         # Preallocate attn_weights for `baddbmm`
-        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
+        attn_weights = torch.empty(
+            bsz_times_num_heads, query_length, key_length, dtype=torch.float32, device=query.device
+        )
+        if attention_mask is not None:
+            # Apply the attention mask
+            attn_weights.view(batch_size, num_heads, query_length, key_length)[:] = attention_mask

         # Compute Scale Factor
         scale_factor = 1.0
@@ -233,27 +240,21 @@ class DecisionTransformerGPT2Attention(nn.Module):

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
         if is_amp_available:
             with autocast(enabled=False):
-                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+                attn_weights = torch.baddbmm(
+                    attn_weights, query, key, beta=0 if attention_mask is None else 1, alpha=scale_factor
+                )
         else:
-            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+            attn_weights = torch.baddbmm(
+                attn_weights, query, key, beta=0 if attention_mask is None else 1, alpha=scale_factor
+            )

         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
-            query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            mask_value = torch.finfo(attn_weights.dtype).min
-            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-            attn_weights = torch.where(causal_mask, attn_weights, mask_value)
-
-        if attention_mask is not None:
-            # Apply the attention mask
-            attn_weights = attn_weights + attention_mask
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+            # In-place update of `attn_weights`
+            attn_weights.view(batch_size, num_heads, query_length, key_length).masked_fill_(
+                causal_mask, torch.finfo(attn_weights.dtype).min
+            )

         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
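The `_upcast_and_reordered_attn` changes above lean on the semantics of `torch.baddbmm(input, batch1, batch2, beta=..., alpha=...)`, which returns `beta * input + alpha * (batch1 @ batch2)`. Writing the attention mask into the preallocated buffer and passing `beta=1` folds the mask addition into the same fused call that computes the scaled `Q @ K^T`; with no mask, `beta=0` ignores whatever is in the buffer. The sketch below is a rough illustration of that idea with invented shapes, not the code from the diff:

import torch

# Invented shapes: batch=2, heads=3, q_len=4, k_len=4, head_dim=5
bsz, num_heads, q_len, k_len, head_dim = 2, 3, 4, 4, 5
torch.manual_seed(0)
query = torch.randn(bsz * num_heads, q_len, head_dim)
key = torch.randn(bsz * num_heads, head_dim, k_len)           # already transposed
attention_mask = torch.zeros(bsz, 1, 1, k_len)
attention_mask[:, :, :, -1] = torch.finfo(torch.float32).min  # mask out the last key position

scale_factor = 1.0 / head_dim**0.5

# Preallocate the scores buffer and broadcast the mask into it
attn_weights = torch.empty(bsz * num_heads, q_len, k_len)
attn_weights.view(bsz, num_heads, q_len, k_len)[:] = attention_mask

# beta=1 adds the buffer contents (the mask) to alpha * query @ key in one fused call
fused = torch.baddbmm(attn_weights, query, key, beta=1, alpha=scale_factor)

# Reference: separate matmul, scaling and mask addition
ref = (torch.matmul(query, key) * scale_factor).view(bsz, num_heads, q_len, k_len) + attention_mask
print(torch.allclose(fused, ref.view(bsz * num_heads, q_len, k_len)))  # True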
@@ -290,7 +291,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
@@ -315,6 +316,12 @@ class DecisionTransformerGPT2Attention(nn.Module):
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)

+        batch_size, num_heads, query_length, head_dim = query.size()
+        if self.reorder_and_upcast_attn:
+            query = query.to(torch.float32).reshape(-1, query_length, head_dim)
+            key = key.transpose(-1, -2).to(torch.float32).reshape(-1, head_dim, query_length)
+            value = value.reshape(-1, query_length, head_dim)
+
         if layer_past is not None:
             past_key, past_value = layer_past
             key = torch.cat((past_key, key), dim=-2)
@@ -326,7 +333,10 @@ class DecisionTransformerGPT2Attention(nn.Module):
             present = None

         if self.reorder_and_upcast_attn:
-            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
+            attn_output, attn_weights = self._upcast_and_reordered_attn(
+                query, key, value, num_heads=num_heads, attention_mask=attention_mask, head_mask=head_mask
+            )
+            attn_output = attn_output.view(batch_size, num_heads, query_length, head_dim)
         else:
             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

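In the reworked `forward`, the per-head tensors are upcast to float32 and flattened from `[batch, heads, seq, head_dim]` to `[batch * heads, seq, head_dim]` before `_upcast_and_reordered_attn` is called, which is why `num_heads` must now be passed explicitly (it can no longer be read off `query.size()`) and why the returned `attn_output` is viewed back to four dimensions. A toy round trip with invented shapes:

import torch

batch_size, num_heads, query_length, head_dim = 2, 3, 4, 5  # invented shapes
query = torch.randn(batch_size, num_heads, query_length, head_dim)

# Flatten heads into the batch dimension for batched kernels such as torch.bmm / torch.baddbmm
flat = query.reshape(-1, query_length, head_dim)            # [batch * heads, q_len, head_dim]

# ... attention runs on the flattened layout ...

# Restore the per-head layout afterwards
restored = flat.view(batch_size, num_heads, query_length, head_dim)
print(torch.equal(restored, query))  # True: the round trip is lossless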
@@ -134,7 +134,7 @@ class GPT2Attention(nn.Module):
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
         )
@@ -188,23 +188,24 @@ class GPT2Attention(nn.Module):
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))

+        normalizer = 1
         if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.tensor(
-                value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
+            normalizer *= torch.tensor(value.size(-1) ** 0.5, dtype=query.dtype, device=query.device)

         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
+            normalizer *= float(self.layer_idx + 1)
+
+        attn_weights = attn_weights / normalizer

         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
             attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
@@ -225,13 +226,19 @@ class GPT2Attention(nn.Module):

         return attn_output, attn_weights

-    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
+    def _upcast_and_reordered_attn(self, query, key, value, num_heads: int, attention_mask=None, head_mask=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
-        bsz, num_heads, q_seq_len, dk = query.size()
-        _, _, k_seq_len, _ = key.size()
+        bsz_times_num_heads, query_length, _ = query.size()
+        _, _, key_length = key.size()
+        batch_size = bsz_times_num_heads // num_heads

         # Preallocate attn_weights for `baddbmm`
-        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
+        attn_weights = torch.empty(
+            bsz_times_num_heads, query_length, key_length, dtype=torch.float32, device=query.device
+        )
+        if attention_mask is not None:
+            # Apply the attention mask
+            attn_weights.view(batch_size, num_heads, query_length, key_length)[:] = attention_mask

         # Compute Scale Factor
         scale_factor = 1.0
@@ -244,27 +251,21 @@ class GPT2Attention(nn.Module):

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
         if is_amp_available:
             with autocast(enabled=False):
-                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+                attn_weights = torch.baddbmm(
+                    attn_weights, query, key, beta=0 if attention_mask is None else 1, alpha=scale_factor
+                )
         else:
-            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+            attn_weights = torch.baddbmm(
+                attn_weights, query, key, beta=0 if attention_mask is None else 1, alpha=scale_factor
+            )

         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
-            query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            mask_value = torch.finfo(attn_weights.dtype).min
-            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-            attn_weights = torch.where(causal_mask, attn_weights, mask_value)
-
-        if attention_mask is not None:
-            # Apply the attention mask
-            attn_weights = attn_weights + attention_mask
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+            # In-place update of `attn_weights`
+            attn_weights.view(batch_size, num_heads, query_length, key_length).masked_fill_(
+                causal_mask, torch.finfo(attn_weights.dtype).min
+            )

         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -294,14 +295,14 @@ class GPT2Attention(nn.Module):
         """
         Merges attn_head_size dim and num_attn_heads dim into hidden_size
         """
-        tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        tensor = tensor.permute(0, 2, 1, 3)
         new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-        return tensor.view(new_shape)
+        return tensor.reshape(new_shape)

     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
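The `_merge_heads` change swaps `permute(...).contiguous()` plus `view(...)` for `permute(...)` plus `reshape(...)`: `view` raises on a non-contiguous tensor, while `reshape` returns a view when the strides allow it and copies otherwise, so the result is unchanged but the explicit `contiguous()` call goes away. A small sketch, with shapes made up for illustration:

import torch

batch, heads, seq, head_dim = 2, 3, 4, 5   # invented shapes
t = torch.randn(batch, heads, seq, head_dim)

p = t.permute(0, 2, 1, 3)                  # [batch, seq, heads, head_dim], non-contiguous
try:
    p.view(batch, seq, heads * head_dim)   # view needs strides compatible with the new shape
except RuntimeError as e:
    print("view failed:", e)

merged = p.reshape(batch, seq, heads * head_dim)  # reshape copies only when a view is impossible
old_style = t.permute(0, 2, 1, 3).contiguous().view(batch, seq, heads * head_dim)
print(torch.equal(merged, old_style))      # True: same values either way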
@@ -326,6 +327,12 @@ class GPT2Attention(nn.Module):
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)

+        batch_size, num_heads, query_length, head_dim = query.size()
+        if self.reorder_and_upcast_attn:
+            query = query.to(torch.float32).reshape(-1, query_length, head_dim)
+            key = key.transpose(-1, -2).to(torch.float32).reshape(-1, head_dim, query_length)
+            value = value.reshape(-1, query_length, head_dim)
+
         if layer_past is not None:
             past_key, past_value = layer_past
             key = torch.cat((past_key, key), dim=-2)
@@ -337,7 +344,10 @@ class GPT2Attention(nn.Module):
             present = None

         if self.reorder_and_upcast_attn:
-            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
+            attn_output, attn_weights = self._upcast_and_reordered_attn(
+                query, key, value, num_heads=num_heads, attention_mask=attention_mask, head_mask=head_mask
+            )
+            attn_output = attn_output.view(batch_size, num_heads, query_length, head_dim)
         else:
             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

@@ -1102,16 +1112,49 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         )

     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+    def _reorder_cache(
+        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
         beam_idx at every generation step.
         """
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
-        )
+        # Depending on `config.reorder_and_upcast_attn` values, stored past are different:
+        # - True: key [batch_size * num_heads, head_dim, seq_length], value [batch_size * num_heads, seq_length, head_dim]
+        # - False: key/value [batch_size, num_heads, seq_length, head_dim]
+        past_num_dimensions = len(past[0][0].shape)
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
+
+        if past_num_dimensions == 3:
+            batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
+            batch_size = len(beam_idx)
+            num_heads = batch_size_times_num_heads // batch_size
+            return tuple(
+                (
+                    layer_past[0]
+                    .view(batch_size, num_heads, head_dim, seq_length)
+                    .index_select(0, device_to_beam_idx[layer_past[0].device])
+                    .view(batch_size_times_num_heads, head_dim, seq_length),
+                    layer_past[1]
+                    .view(batch_size, num_heads, seq_length, head_dim)
+                    .index_select(0, device_to_beam_idx[layer_past[0].device])
+                    .view(batch_size_times_num_heads, seq_length, head_dim),
+                )
+                for layer_past in past
+            )
+        elif past_num_dimensions == 4:
+            return tuple(
+                tuple(past_state.index_select(0, device_to_beam_idx[past_state.device]) for past_state in layer_past)
+                for layer_past in past
+            )
+        else:
+            raise ValueError(
+                "Past keys and values have to be either 3 or 4 dimensional depending on"
+                f" `config.reorder_and_upcast_attn`, for {past_num_dimensions} dimensions"
+            )


 @add_start_docstrings(
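The new `_reorder_cache` handles both cache layouts: when `config.reorder_and_upcast_attn` is enabled, keys and values are cached flattened as `[batch * heads, ...]`, so beam reordering has to view them back to `[batch, heads, ...]`, index-select along the batch dimension with `beam_idx`, and flatten again; the 4-D branch keeps the previous behaviour. A rough sketch of the 3-D case with invented shapes:

import torch

batch_size, num_heads, head_dim, seq_length = 2, 3, 4, 6   # invented shapes
beam_idx = torch.tensor([1, 1])                            # e.g. both beams continue from batch item 1

# Flattened key cache, as stored when attention works on [batch * heads, ...] tensors
key_cache = torch.randn(batch_size * num_heads, head_dim, seq_length)

reordered = (
    key_cache.view(batch_size, num_heads, head_dim, seq_length)
    .index_select(0, beam_idx)                             # pick rows per beam along the batch dim
    .view(batch_size * num_heads, head_dim, seq_length)
)

# Reordering the unflattened cache directly gives the same result
reference = key_cache.view(batch_size, num_heads, head_dim, seq_length)[beam_idx].reshape(
    batch_size * num_heads, head_dim, seq_length
)
print(torch.equal(reordered, reference))  # True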
@@ -358,8 +358,17 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
                 if key == "layoutlmv2.visual_segment_embedding":
                     # we skip the visual segment embedding as it has a custom initialization scheme
                     continue
-                max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
-                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+
+                slow_dtype = model_slow_init.state_dict()[key].dtype
+                fast_dtype = model_fast_init.state_dict()[key].dtype
+                self.assertEqual(slow_dtype, fast_dtype)
+
+                if fast_dtype == torch.bool:
+                    # torch.BoolTensor should be deterministic
+                    self.assertEqual(torch.all(model_slow_init.state_dict()[key] == model_fast_init.state_dict()[key]))
+                else:
+                    max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -342,8 +342,16 @@ class ModelTesterMixin:
                 model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)

                 for key in model_fast_init.state_dict().keys():
-                    max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
-                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+                    slow_dtype = model_slow_init.state_dict()[key].dtype
+                    fast_dtype = model_fast_init.state_dict()[key].dtype
+                    self.assertEqual(slow_dtype, fast_dtype)
+
+                    if fast_dtype == torch.bool:
+                        # torch.BoolTensor should be deterministic
+                        self.assertEqual(torch.all(model_slow_init.state_dict()[key] == model_fast_init.state_dict()[key]))
+                    else:
+                        max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
+                        self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

     def test_save_load_fast_init_to_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
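Both test hunks add a dtype check and a separate branch for `torch.bool` entries in the state dict because the old `(slow - fast).sum()` comparison cannot be applied to boolean tensors: PyTorch rejects `-` between two bool tensors. A minimal standalone reproduction of the failure mode the new branch avoids (not taken from the test suite):

import torch

a = torch.tensor([True, False, True])
b = torch.tensor([True, True, True])

try:
    (a - b).sum()  # bool subtraction is not defined
except RuntimeError as e:
    print("subtraction on bool tensors fails:", e)

# Equality comparison works, and is what the updated test falls back to for bool buffers
print(torch.all(a == b))  # tensor(False)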