feat: force final ln in encoder
src/dalle_mini/model/configuration.py
CHANGED

@@ -60,12 +60,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         # transformer variants
         ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
         ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "deepnet" (same as postln)
-        head_scale=
+        head_scale=False,  # used in NormFormer
         use_cosine_attention=False,  # used in Swin v2
         tau_init=0.05,  # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False,  # used in Deepnet
-        use_glu=
-
+        use_glu=False,  # "GLU Variants Improve Transformer"
+        # parameters that should not be necessary but could affect results
+        force_ln_scale=True,  # force scale in layernorm even when followed by dense layers
+        force_final_ln_encoder=False,  # force layer normalization in encoder final layer even when followed by dense layers
         **kwargs,
     ):
         # text normalizer
@@ -91,7 +93,8 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         self.tau_init = tau_init
         self.use_deepnet_scaling = use_deepnet_scaling
         self.use_glu = use_glu
-        self.
+        self.force_ln_scale = force_ln_scale
+        self.force_final_ln_encoder = force_final_ln_encoder

         # common parameters
         self.encoder_vocab_size = encoder_vocab_size
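As a usage note, here is a minimal sketch of how the two new options could be passed when building a config. It assumes the remaining `DalleBartConfig` arguments all have defaults (as is usual for `PretrainedConfig` subclasses); only the parameter names shown in the diff above come from the source.

```python
# Sketch only: assumes every other DalleBartConfig argument has a usable default.
from dalle_mini.model.configuration import DalleBartConfig

config = DalleBartConfig(
    ln_type="layernorm",
    ln_positions="normformer",
    force_ln_scale=True,          # keep the learned scale even when a dense layer follows
    force_final_ln_encoder=True,  # always add a layernorm at the end of the encoder
)
print(config.force_final_ln_encoder)  # True
```

Per the new comment, these flags should not be strictly necessary, presumably because a layernorm scale (or a final encoder norm) feeding directly into a dense layer can be absorbed by that layer's weights; keeping them configurable makes it possible to check whether the extra parameters affect results.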
src/dalle_mini/model/modeling.py
CHANGED

@@ -378,7 +378,7 @@ class GLU(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(x)
         w = nn.Dense(
             self.ffn_dim,
@@ -403,7 +403,7 @@ class GLU(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -443,7 +443,7 @@ class FFN(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(x)
         x = nn.Dense(
             self.ffn_dim,
@@ -459,7 +459,7 @@ class FFN(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -512,7 +512,7 @@ class FlaxBartEncoderLayer(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
@@ -561,7 +561,7 @@ class FlaxBartEncoderLayer(nn.Module):
         use_scale = (
             self.use_scale
             or self.config.ln_positions == "postln"
-            or self.config.
+            or self.config.force_ln_scale
         )
         hidden_states = norm(
             self.config.ln_type,
@@ -617,7 +617,7 @@ class FlaxBartDecoderLayer(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
@@ -656,7 +656,7 @@ class FlaxBartDecoderLayer(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(hidden_states)
         hidden_states, cross_attn_weights = FlaxBartAttention(
             config=self.config,
@@ -709,7 +709,7 @@ class FlaxBartDecoderLayer(nn.Module):
         use_scale = (
             self.use_scale
             or self.config.ln_positions == "postln"
-            or self.config.
+            or self.config.force_ln_scale
         )
         hidden_states = norm(
             self.config.ln_type,
@@ -761,8 +761,9 @@ class FlaxBartEncoderLayerCollection(nn.Module):
         # or every 6 layers for Swin v2
         # not needed for other models which use layernorm before x-attention
         # ignored args for deepnet which always add a norm with scale
-        add_norm = self.config.
-
+        add_norm = self.config.force_final_ln_encoder or (
+            self.config.ln_positions == "swinv2"
+            and ((i == n_layers - 1) or ((i + 1) % 6 == 0))
         )
         # we don't need to scale the norm for the last layer
         use_scale = i != n_layers - 1