Compare commits

...

21 Commits

SHA1 Message Date
3df009e4d3 add torch.all in test 2022-07-31 22:25:08 +02:00
befaac7532 make fix-copies 2022-07-30 09:20:44 +02:00
115d62b817 Improve test to support bool 2022-07-30 09:13:59 +02:00
ded4ba6a17 Improve test to support bool 2022-07-30 08:49:11 +02:00
efa1814de6 Requires to reshape 2022-07-30 08:25:40 +02:00
82c2bebc4a make style 2022-07-30 07:43:12 +02:00
47b2325c18 Woops 2022-07-30 07:42:23 +02:00
8c796b3fb0 Value don't need to be casted to fp32 2022-07-29 20:54:19 +02:00
0d2c182260 Try this fancy way of doing it 2022-07-29 20:49:18 +02:00
5b84d96f07 More optimization 2022-07-29 20:40:45 +02:00
9cd5368975 Hopefully this works 2022-07-29 20:29:20 +02:00
5aec6894df Woops 2022-07-29 19:07:07 +02:00
711b07f7cd Woops 2022-07-29 19:04:04 +02:00
7f64c1e658 Woops 2022-07-29 18:56:12 +02:00
33713b311e Woops 2022-07-29 18:51:20 +02:00
12ccfa8dc0 Revert back changes 2022-07-29 18:48:26 +02:00
16a8b6bc83 Woops 2022-07-29 18:44:31 +02:00
215a85fa2c Woops 2022-07-29 18:42:25 +02:00
0991db29b2 make style 2022-07-29 18:39:40 +02:00
ca03190593 Try to improve gpt2 performances 2022-07-29 18:38:01 +02:00
c9baf40937 Try to improve gpt2 performances 2022-07-29 18:13:26 +02:00
4 changed files with 141 additions and 71 deletions

View File

@@ -123,7 +123,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
@@ -177,23 +177,24 @@ class DecisionTransformerGPT2Attention(nn.Module):
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
normalizer = 1
if self.scale_attn_weights:
attn_weights = attn_weights / torch.tensor(
value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
)
normalizer *= torch.tensor(value.size(-1) ** 0.5, dtype=query.dtype, device=query.device)
# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)
normalizer *= float(self.layer_idx + 1)
attn_weights = attn_weights / normalizer
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
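A quick equivalence check (illustrative only, with assumed head_dim and layer_idx): folding the sqrt(head_dim) and layer-index scalings into a single normalizer divides attn_weights once instead of twice, without changing the result.

import torch

attn_weights = torch.randn(2, 4, 5, 5)  # assumed [batch, heads, q_len, k_len]
head_dim, layer_idx = 64, 3             # assumed example values

two_divisions = attn_weights / head_dim**0.5 / float(layer_idx + 1)
one_division = attn_weights / (head_dim**0.5 * float(layer_idx + 1))
assert torch.allclose(two_divisions, one_division)
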
@@ -214,13 +215,19 @@ class DecisionTransformerGPT2Attention(nn.Module):
return attn_output, attn_weights
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
def _upcast_and_reordered_attn(self, query, key, value, num_heads: int, attention_mask=None, head_mask=None):
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
bsz, num_heads, q_seq_len, dk = query.size()
_, _, k_seq_len, _ = key.size()
bsz_times_num_heads, query_length, _ = query.size()
_, _, key_length = key.size()
batch_size = bsz_times_num_heads // num_heads
# Preallocate attn_weights for `baddbmm`
attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
attn_weights = torch.empty(
bsz_times_num_heads, query_length, key_length, dtype=torch.float32, device=query.device
)
if attention_mask is not None:
# Apply the attention mask
attn_weights.view(batch_size, num_heads, query_length, key_length)[:] = attention_mask
# Compute Scale Factor
scale_factor = 1.0
@@ -233,27 +240,21 @@ class DecisionTransformerGPT2Attention(nn.Module):
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
if is_amp_available:
with autocast(enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
attn_weights = torch.baddbmm(
attn_weights, query, key, beta=0 if attention_mask is None else 1, alpha=scale_factor
)
else:
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
attn_weights = torch.baddbmm(
attn_weights, query, key, beta=0 if attention_mask is None else 1, alpha=scale_factor
)
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
# In-place update of `attn_weights`
attn_weights.view(batch_size, num_heads, query_length, key_length).masked_fill_(
causal_mask, torch.finfo(attn_weights.dtype).min
)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
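An illustrative sketch (assumed shapes, not the PR code) of the torch.baddbmm identity this hunk leans on, out = beta * input + alpha * (batch1 @ batch2): preloading attn_weights with the additive attention mask and passing beta=1 gives the same result as doing the matmul first and adding the mask afterwards.

import torch

batch_heads, q_len, k_len, head_dim = 8, 5, 5, 16  # assumed example shapes
q = torch.randn(batch_heads, q_len, head_dim)
k = torch.randn(batch_heads, head_dim, k_len)
additive_mask = torch.zeros(batch_heads, q_len, k_len)
additive_mask[:, :, -2:] = torch.finfo(torch.float32).min  # hide the last two key positions

scale = 1.0 / head_dim**0.5
fused = torch.baddbmm(additive_mask, q, k, beta=1, alpha=scale)  # mask folded into the matmul
separate = torch.bmm(q, k) * scale + additive_mask               # mask added afterwards
assert torch.allclose(fused, separate)
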
@@ -290,7 +291,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
@@ -315,6 +316,12 @@ class DecisionTransformerGPT2Attention(nn.Module):
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
batch_size, num_heads, query_length, head_dim = query.size()
if self.reorder_and_upcast_attn:
query = query.to(torch.float32).reshape(-1, query_length, head_dim)
key = key.transpose(-1, -2).to(torch.float32).reshape(-1, head_dim, query_length)
value = value.reshape(-1, query_length, head_dim)
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
@@ -326,7 +333,10 @@ class DecisionTransformerGPT2Attention(nn.Module):
present = None
if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
attn_output, attn_weights = self._upcast_and_reordered_attn(
query, key, value, num_heads=num_heads, attention_mask=attention_mask, head_mask=head_mask
)
attn_output = attn_output.view(batch_size, num_heads, query_length, head_dim)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
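An illustrative round trip (assumed shapes) for the reshapes the forward pass now performs when reorder_and_upcast_attn is set: flattening [batch, heads, seq, head_dim] projections into the 3D layout that torch.bmm / torch.baddbmm expect, and checking that the scores match the plain 4D torch.matmul path.

import torch

batch_size, num_heads, query_length, head_dim = 2, 4, 5, 16  # assumed example shapes
query = torch.randn(batch_size, num_heads, query_length, head_dim)
key = torch.randn(batch_size, num_heads, query_length, head_dim)

flat_query = query.reshape(-1, query_length, head_dim)                # [batch*heads, q_len, head_dim]
flat_key = key.transpose(-1, -2).reshape(-1, head_dim, query_length)  # [batch*heads, head_dim, k_len]

scores_3d = torch.bmm(flat_query, flat_key).view(batch_size, num_heads, query_length, query_length)
scores_4d = torch.matmul(query, key.transpose(-1, -2))
assert torch.allclose(scores_3d, scores_4d)
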

View File

@@ -134,7 +134,7 @@ class GPT2Attention(nn.Module):
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
@@ -188,23 +188,24 @@ class GPT2Attention(nn.Module):
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
normalizer = 1
if self.scale_attn_weights:
attn_weights = attn_weights / torch.tensor(
value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
)
normalizer *= torch.tensor(value.size(-1) ** 0.5, dtype=query.dtype, device=query.device)
# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)
normalizer *= float(self.layer_idx + 1)
attn_weights = attn_weights / normalizer
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
@@ -225,13 +226,19 @@ class GPT2Attention(nn.Module):
return attn_output, attn_weights
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
def _upcast_and_reordered_attn(self, query, key, value, num_heads: int, attention_mask=None, head_mask=None):
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
bsz, num_heads, q_seq_len, dk = query.size()
_, _, k_seq_len, _ = key.size()
bsz_times_num_heads, query_length, _ = query.size()
_, _, key_length = key.size()
batch_size = bsz_times_num_heads // num_heads
# Preallocate attn_weights for `baddbmm`
attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
attn_weights = torch.empty(
bsz_times_num_heads, query_length, key_length, dtype=torch.float32, device=query.device
)
if attention_mask is not None:
# Apply the attention mask
attn_weights.view(batch_size, num_heads, query_length, key_length)[:] = attention_mask
# Compute Scale Factor
scale_factor = 1.0
@@ -244,27 +251,21 @@ class GPT2Attention(nn.Module):
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
if is_amp_available:
with autocast(enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
attn_weights = torch.baddbmm(
attn_weights, query, key, beta=0 if attention_mask is None else 1, alpha=scale_factor
)
else:
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
attn_weights = torch.baddbmm(
attn_weights, query, key, beta=0 if attention_mask is None else 1, alpha=scale_factor
)
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
# In-place update of `attn_weights`
attn_weights.view(batch_size, num_heads, query_length, key_length).masked_fill_(
causal_mask, torch.finfo(attn_weights.dtype).min
)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -294,14 +295,14 @@ class GPT2Attention(nn.Module):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
tensor = tensor.permute(0, 2, 1, 3).contiguous()
tensor = tensor.permute(0, 2, 1, 3)
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
return tensor.reshape(new_shape)
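A small illustrative check (not part of the diff) of why _merge_heads can drop .contiguous(): .view() refuses to operate on the permuted, non-contiguous tensor, while .reshape() accepts it and only copies when it has to.

import torch

batch, heads, seq, head_dim = 2, 4, 5, 16  # assumed example shapes
tensor = torch.randn(batch, heads, seq, head_dim).permute(0, 2, 1, 3)  # [batch, seq, heads, head_dim]
new_shape = tensor.size()[:-2] + (heads * head_dim,)

merged = tensor.reshape(new_shape)  # fine on the non-contiguous permuted tensor
try:
    tensor.view(new_shape)          # raises because the permuted tensor is not contiguous
except RuntimeError as exc:
    print(f"view failed as expected: {exc}")
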
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
@@ -326,6 +327,12 @@ class GPT2Attention(nn.Module):
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
batch_size, num_heads, query_length, head_dim = query.size()
if self.reorder_and_upcast_attn:
query = query.to(torch.float32).reshape(-1, query_length, head_dim)
key = key.transpose(-1, -2).to(torch.float32).reshape(-1, head_dim, query_length)
value = value.reshape(-1, query_length, head_dim)
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
@@ -337,7 +344,10 @@ class GPT2Attention(nn.Module):
present = None
if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
attn_output, attn_weights = self._upcast_and_reordered_attn(
query, key, value, num_heads=num_heads, attention_mask=attention_mask, head_mask=head_mask
)
attn_output = attn_output.view(batch_size, num_heads, query_length, head_dim)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
@@ -1102,16 +1112,49 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
def _reorder_cache(
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
# Depending on the value of `config.reorder_and_upcast_attn`, the stored past key/value tensors have different shapes:
# - True: key [batch_size * num_heads, head_dim, seq_length], value [batch_size * num_heads, seq_length, head_dim]
# - False: key/value [batch_size, num_heads, seq_length, head_dim]
past_num_dimensions = len(past[0][0].shape)
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
}
if past_num_dimensions == 3:
batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
batch_size = len(beam_idx)
num_heads = batch_size_times_num_heads // batch_size
return tuple(
(
layer_past[0]
.view(batch_size, num_heads, head_dim, seq_length)
.index_select(0, device_to_beam_idx[layer_past[0].device])
.view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1]
.view(batch_size, num_heads, seq_length, head_dim)
.index_select(0, device_to_beam_idx[layer_past[0].device])
.view(batch_size_times_num_heads, seq_length, head_dim),
)
for layer_past in past
)
elif past_num_dimensions == 4:
return tuple(
tuple(past_state.index_select(0, device_to_beam_idx[past_state.device]) for past_state in layer_past)
for layer_past in past
)
else:
raise ValueError(
"Past keys and values have to be either 3 or 4 dimensional depending on"
f" `config.reorder_and_upcast_attn`, for {past_num_dimensions} dimensions"
)
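An illustrative walk-through (assumed shapes, hypothetical beam_idx) of what the 3-dimensional branch above does to a flattened [batch_size * num_heads, head_dim, seq_length] key: temporarily restore the batch dimension, reorder it with index_select, then flatten back.

import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 16, 5  # assumed example shapes
flat_key = torch.randn(batch_size * num_heads, head_dim, seq_length)
beam_idx = torch.tensor([1, 0])  # hypothetical beam reordering

reordered = (
    flat_key.view(batch_size, num_heads, head_dim, seq_length)
    .index_select(0, beam_idx)
    .view(batch_size * num_heads, head_dim, seq_length)
)
# Same result as indexing the unflattened layout directly
reference = flat_key.view(batch_size, num_heads, head_dim, seq_length)[beam_idx]
assert torch.equal(reordered, reference.reshape(batch_size * num_heads, head_dim, seq_length))
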
@add_start_docstrings(

View File

@@ -358,8 +358,17 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
if key == "layoutlmv2.visual_segment_embedding":
# we skip the visual segment embedding as it has a custom initialization scheme
continue
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
slow_dtype = model_slow_init.state_dict()[key].dtype
fast_dtype = model_fast_init.state_dict()[key].dtype
self.assertEqual(slow_dtype, fast_dtype)
if fast_dtype == torch.bool:
# torch.BoolTensor should be deterministic
self.assertTrue(torch.all(model_slow_init.state_dict()[key] == model_fast_init.state_dict()[key]), msg=f"{key} not identical")
else:
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -342,8 +342,16 @@ class ModelTesterMixin:
model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
for key in model_fast_init.state_dict().keys():
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
slow_dtype = model_slow_init.state_dict()[key].dtype
fast_dtype = model_fast_init.state_dict()[key].dtype
self.assertEqual(slow_dtype, fast_dtype)
if fast_dtype == torch.bool:
# torch.BoolTensor should be deterministic
self.assertTrue(torch.all(model_slow_init.state_dict()[key] == model_fast_init.state_dict()[key]), msg=f"{key} not identical")
else:
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
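A minimal sketch (illustrative, not the test itself) of why bool entries get their own branch: subtracting two bool tensors raises a RuntimeError in PyTorch, so the test compares them with equality plus torch.all instead of a summed difference.

import torch

slow_param = torch.tensor([True, False, True])
fast_param = torch.tensor([True, False, True])

try:
    _ = slow_param - fast_param  # subtraction with two bool tensors is not supported
except RuntimeError as exc:
    print(f"bool subtraction failed: {exc}")

# Exact comparison used instead
assert torch.all(slow_param == fast_param)
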
def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()