Update modeling_qwen.py, fix logn bug
modeling_qwen.py CHANGED (+6 -5)
@@ -177,7 +177,8 @@ class QWenAttention(nn.Module):
             config.hidden_size, self.projection_size, bias=not config.no_bias
         )
 
-        if self.use_flash_attn and flash_attn_unpadded_func is not None:
+        self.is_fp32 = not(config.bf16 or config.fp16)
+        if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
             self.core_attention_flash = FlashSelfAttention(
                 causal=True, attention_dropout=config.attn_pdrop
             )
@@ -371,12 +372,12 @@ class QWenAttention(nn.Module):
         if self.use_logn_attn and not self.training:
            if self.logn_tensor.device != query.device:
                self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
-            seq_start = key.size(
-            seq_end = key.size(
+            seq_start = key.size(1) - query.size(1)
+            seq_end = key.size(1)
             logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
             query = query * logn_tensor.expand_as(query)
 
-        if self.use_flash_attn and flash_attn_unpadded_func is not None:
+        if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
             q, k, v = query, key, value
             context_layer = self.core_attention_flash(q, k, v)
@@ -397,7 +398,7 @@ class QWenAttention(nn.Module):
         attn_output = self.c_proj(context_layer)
         outputs = (attn_output, present)
         if output_attentions:
-            if self.use_flash_attn and flash_attn_unpadded_func is not None:
+            if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
                 raise ValueError("Cannot output attentions while using flash-attn")
             else:
                 outputs += (attn_weight,)
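For context on the logn fix: during incremental decoding, query holds only the newly generated positions while key holds the cached prefix plus those positions, so the log-n scaling factors have to be sliced by absolute position, which is what seq_start = key.size(1) - query.size(1) and seq_end = key.size(1) do. A minimal sketch of that indexing follows; the shapes, max_positions, seq_length, and the way logn_tensor is built here are illustrative assumptions, not lines taken from modeling_qwen.py.

import math
import torch

# Illustrative shapes only: batch=1, 2 heads, head_dim=4,
# a 6-token cached prefix plus 1 newly decoded token.
batch, num_heads, head_dim = 1, 2, 4
cached_len, new_len = 6, 1

# Assumed logn table, indexed by absolute position: shape (1, max_positions, 1, 1).
max_positions = 32
seq_length = 8  # assumed training context length used as the log base
logn_list = [
    math.log(i, seq_length) if i > seq_length else 1.0
    for i in range(1, max_positions + 1)
]
logn_tensor = torch.tensor(logn_list).view(1, max_positions, 1, 1)

# Layout assumed to be (batch, seq, heads, head_dim):
# query covers only the new token(s); key covers cache + new tokens.
query = torch.randn(batch, new_len, num_heads, head_dim)
key = torch.randn(batch, cached_len + new_len, num_heads, head_dim)

# The fixed indexing: take the scaling factors for the absolute positions
# of the new tokens, i.e. the last query.size(1) positions of the key sequence.
seq_start = key.size(1) - query.size(1)  # 6
seq_end = key.size(1)                    # 7
scale = logn_tensor[:, seq_start:seq_end, :, :]
query = query * scale.expand_as(query)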
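The other change threads an is_fp32 flag through the FlashAttention guards, since the flash-attn kernels operate on fp16/bf16 inputs and a pure-fp32 model should fall back to the standard attention path. Below is a hypothetical helper (not part of modeling_qwen.py) restating that guard; it only assumes the bf16/fp16 config fields that appear in the diff.

from types import SimpleNamespace

def can_use_flash_attn(config, flash_attn_available: bool) -> bool:
    # Mirrors the added condition: skip FlashAttention when the model
    # runs in fp32, because the kernels expect fp16/bf16 tensors.
    is_fp32 = not (config.bf16 or config.fp16)
    return flash_attn_available and not is_fp32

# A bf16 model with flash-attn installed uses the fast path; fp32 does not.
print(can_use_flash_attn(SimpleNamespace(bf16=True, fp16=False), True))   # True
print(can_use_flash_attn(SimpleNamespace(bf16=False, fp16=False), True))  # False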