ChrisMcCormick committed on
Commit
7e1eb73
·
verified ·
1 Parent(s): de0ded9

Adding source


Testing out if this works!

__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ Subspace Decoder Package
5
+
6
+ A Transformer decoder implementation with Multi-Head Latent Attention (MLA)
7
+ and decomposed MLP layers for efficient parameter usage.
8
+ """
9
+
10
+ # Import all the main classes from models
11
+ from .models import (
12
+ SharedSpaceDecoderConfig,
13
+ SharedSpaceDecoderPreTrainedModel,
14
+ SharedSpaceDecoderModel,
15
+ SharedSpaceDecoderForCausalLM,
16
+ )
17
+
18
+ __version__ = "0.1.0"
19
+
20
+ __all__ = [
21
+ "SharedSpaceDecoderConfig",
22
+ "SharedSpaceDecoderPreTrainedModel",
23
+ "SharedSpaceDecoderModel",
24
+ "SharedSpaceDecoderForCausalLM",
25
+ ]
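+
+ # Typical usage sketch (the package/import name below is assumed, not defined by this commit):
+ #
+ #   from subspace_decoder import SharedSpaceDecoderConfig, SharedSpaceDecoderForCausalLM
+ #   config = SharedSpaceDecoderConfig()
+ #   model = SharedSpaceDecoderForCausalLM(config)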
layers/feedforward.py ADDED
@@ -0,0 +1,196 @@
1
+ """# ▂▂▂▂▂▂▂▂▂▂▂▂
2
+
3
+ # `feedforward.py`
4
+
5
+ Regarding dropout:
6
+
7
+ - I don't see it applied to the MoE in DeepSeek-V3, [here](https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py).
8
+
9
+ - I don't see it applied in [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L140)
10
+
11
+ Norms:
12
+
13
+ * nn.RMSNorm [here](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html)
14
+
15
+ ## FFN
16
+ """
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from ..models.configuration_shared_subspace_decoder import SharedSpaceDecoderConfig
22
+
23
+
24
+ def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
25
+ """
26
+ Create a normalization layer based on the config norm_type.
27
+
28
+ Args:
29
+ hidden_size: The dimension to normalize over
30
+ config: Configuration containing norm_type and epsilon values
31
+
32
+ Returns:
33
+ Either a LayerNorm or RMSNorm layer
34
+ """
35
+ if config.norm_type == "layernorm":
36
+ return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
37
+ elif config.norm_type == "rmsnorm":
38
+ return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
39
+ else:
40
+ # This should be caught by config validation, but being defensive
41
+ raise ValueError(f"Unknown norm_type: {config.norm_type}")
42
+
43
+
44
+ # TODO - Find a shared place to put this.
45
+ class DeepseekV3RMSNorm(nn.Module):
46
+ def __init__(self, hidden_size, eps=1e-6):
47
+ """
48
+ DeepseekV3RMSNorm is equivalent to T5LayerNorm
49
+ """
50
+ super().__init__()
51
+ self.weight = nn.Parameter(torch.ones(hidden_size))
52
+ self.variance_epsilon = eps
53
+
54
+ def forward(self, hidden_states):
55
+ input_dtype = hidden_states.dtype
56
+ hidden_states = hidden_states.to(torch.float32)
57
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
58
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
59
+ return self.weight * hidden_states.to(input_dtype)
60
+
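+ # Illustrative check of the normalization above (not part of the module): with the
+ # default weight of ones, the output equals x / sqrt(mean(x**2, dim=-1) + eps).
+ #
+ #   norm = DeepseekV3RMSNorm(hidden_size=4)
+ #   x = torch.randn(2, 3, 4)
+ #   expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
+ #   assert torch.allclose(norm(x), expected, atol=1e-6)
+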
61
+ class SubspaceFeedForward(nn.Module):
62
+ """
63
+ Feed-forward block for SharedSpaceDecoder.
64
+
65
+ Implements SwiGLU:
66
+ FFN(x) = W_out( Swish(W_in(x)) ⊙ W_gate(x) ) + residual
67
+
68
+ Supports both dense and decomposed MLP variants.
69
+
70
+ Dense:
71
+ - W_in: Linear(hidden_dim → intermediate_dim)
72
+ - W_gate: Linear(hidden_dim → intermediate_dim)
73
+ - W_out: Linear(intermediate_dim → hidden_dim)
74
+
75
+ Decomposed:
76
+ - W_in_shared: Linear(hidden_dim → rank, bias=False)
77
+ - W_in_shared_norm: RMSNorm
78
+ - W_in: Linear(rank → intermediate_dim)
79
+ - W_gate_shared: Linear(hidden_dim → rank, bias=False)
80
+ - W_gate_shared_norm: RMSNorm
81
+ - W_gate: Linear(rank → intermediate_dim)
82
+ - W_out: Linear(intermediate_dim → rank, bias=False)
83
+ - W_out_shared: Linear(rank → hidden_dim)
84
+
85
+ Residual, dropout, and post-norm are handled inside the block.
86
+ """
87
+
88
+ def __init__(self, config, layer_idx):
89
+ super().__init__()
90
+
91
+
92
+ #dropout_prob = config.hidden_dropout_prob # TODO - Style -- don't define variables if only used once.
93
+
94
+ # Determine whether this is a dense or decomposed layer.
95
+ # It's dense if either:
96
+ # - ffn_decompose is disabled (no dense layers at all)
97
+ # - ffn_decompose is enabled, but this is one of the early dense layers.
98
+ self.is_dense = (not config.ffn_decompose) or (layer_idx < config.num_dense_layers)
99
+
100
+ hidden_dim = config.hidden_size
101
+ intermediate_dim = config.intermediate_size # TODO - Find something shorter, and use the same name.
102
+
103
+ # If it's one of the dense layers,
104
+ if self.is_dense:
105
+ # === Dense FFN Projections ===
106
+ self.W_in = nn.Linear(hidden_dim, intermediate_dim)
107
+ self.W_gate = nn.Linear(hidden_dim, intermediate_dim)
108
+ self.W_out = nn.Linear(intermediate_dim, hidden_dim)
109
+
110
+ # Define weights for the decomposed version.
111
+ else:
112
+ rank = config.ffn_rank
113
+
114
+ print("hidden_dim:", hidden_dim)
115
+ print("rank:", rank)
116
+
117
+ # === Input Projections ===
118
+ self.W_in_shared = nn.Linear(hidden_dim, rank, bias=False)
119
+ self.W_in_shared_norm = create_norm_layer(rank, config)
120
+ self.W_in = nn.Linear(rank, intermediate_dim, bias=True)
121
+
122
+ # === Gate Projections ===
123
+ self.W_gate_shared = nn.Linear(hidden_dim, rank, bias=False)
124
+ self.W_gate_shared_norm = create_norm_layer(rank, config)
125
+ self.W_gate = nn.Linear(rank, intermediate_dim, bias=True)
126
+
127
+ # === Output Projection ===
128
+ self.W_out = nn.Linear(intermediate_dim, rank, bias=False)
129
+ # TODO - Could experiment with this.
130
+ #self.W_out_shared_layernorm = DeepseekV3RMSNorm(rank, eps=config.eps)
131
+ self.W_out_shared = nn.Linear(rank, hidden_dim, bias=True)
132
+
133
+ # See notes on dropout
134
+ #self.dropout = nn.Dropout(config.hidden_dropout_prob)
135
+
136
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
137
+ # === Tensor Dimension Symbols ===
138
+ # B: batch_size — number of samples in the batch
139
+ # T: seq_len — number of tokens per sample
140
+ # D: hidden_dim — model embedding size
141
+ # R: ffn_rank — latent shared subspace dimension
142
+ # D_ff: intermediate_size — FFN hidden dimension
143
+
144
+ # =========================
145
+ # Gated Feedforward
146
+ # =========================
147
+
148
+ if self.is_dense:
149
+ # =============
150
+ # Dense
151
+ # =============
152
+
153
+ # Input: x [B, T, D]
154
+ # Output: x_proj [B, T, D_ff]
155
+ x_proj = self.W_in(x)
156
+
157
+ # Output: gate [B, T, D_ff]
158
+ gate = self.W_gate(x)
159
+
160
+ # SwiGLU nonlinearity
161
+ x = F.silu(x_proj) * gate # [B, T, D_ff]
162
+
163
+ # See notes on dropout
164
+ #x = self.dropout(x)
165
+
166
+ # Output: x [B, T, D]
167
+ x = self.W_out(x)
168
+
169
+ else:
170
+ # ==================
171
+ # Decomposed
172
+ # ==================
173
+
174
+ # Input: x [B, T, D]
175
+ # Output: x_proj [B, T, D_ff]
176
+ x_proj = self.W_in(self.W_in_shared_norm(self.W_in_shared(x)))
177
+
178
+ # Input: x [B, T, D]
179
+ # Output: gate [B, T, D_ff]
180
+ gate = self.W_gate(self.W_gate_shared_norm(self.W_gate_shared(x)))
181
+
182
+ # SwiGLU nonlinearity
183
+ x = F.silu(x_proj) * gate # [B, T, D_ff]
184
+
185
+ # See notes on dropout
186
+ #x = self.dropout(x)
187
+
188
+ # Output: x [B, T, D]
189
+ x = self.W_out_shared(self.W_out(x))
190
+
191
+
192
+ return x
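+
+ # Rough parameter comparison for the two variants above (illustrative numbers, not
+ # taken from any config in this repo). With hidden_dim=512, intermediate_dim=2048,
+ # rank=128, and ignoring biases:
+ #
+ #   dense:      3 * 512 * 2048                          ~= 3.15M weights
+ #   decomposed: 3 * (512 * 128) + 3 * (128 * 2048)      ~= 0.98M weights
+ #
+ # The decomposed FFN trades one extra matmul (and norm) per projection for roughly
+ # a 3x reduction in FFN parameters at this rank.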
193
+
194
+
195
+
196
+
layers/mla.py ADDED
@@ -0,0 +1,619 @@
1
+ """# ▂▂▂▂▂▂▂▂▂▂▂▂
2
+
3
+ # `mla.py`
4
+
5
+ Based on: https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
6
+
7
+ ## RotaryEmbedding
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from typing import Optional
14
+
15
+ from ..models.configuration_shared_subspace_decoder import SharedSpaceDecoderConfig
16
+
17
+
18
+ def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
19
+ """
20
+ Create a normalization layer based on the config norm_type.
21
+
22
+ If `hidden_size` is `None`, this returns an identity layer.
23
+
24
+ Args:
25
+ hidden_size: The dimension to normalize over
26
+ config: Configuration containing norm_type and epsilon values
27
+
28
+ Returns:
29
+ Either a LayerNorm or RMSNorm layer
30
+ """
31
+ if hidden_size is None:
32
+ return nn.Identity()
33
+ elif config.norm_type == "layernorm":
34
+ return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
35
+ elif config.norm_type == "rmsnorm":
36
+ return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
37
+ else:
38
+ # This should be caught by config validation, but being defensive
39
+ raise ValueError(f"Unknown norm_type: {config.norm_type}")
40
+
41
+
42
+ # TODO - Find a shared place to put this.
43
+ class DeepseekV3RMSNorm(nn.Module):
44
+ def __init__(self, hidden_size, eps=1e-6):
45
+ """
46
+ DeepseekV3RMSNorm is equivalent to T5LayerNorm
47
+ """
48
+ super().__init__()
49
+ self.weight = nn.Parameter(torch.ones(hidden_size))
50
+ self.variance_epsilon = eps
51
+
52
+ def forward(self, hidden_states):
53
+ input_dtype = hidden_states.dtype
54
+ hidden_states = hidden_states.to(torch.float32)
55
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
56
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
57
+ return self.weight * hidden_states.to(input_dtype)
58
+
59
+
60
+ # Helper used when applying RoPE below; it is called twice in `forward`
+ # (once for the queries and once for the keys).
64
+ def rotate_half(x):
65
+ """Rotates half the hidden dims of the input."""
66
+ x1 = x[..., : x.shape[-1] // 2]
67
+ x2 = x[..., x.shape[-1] // 2 :]
68
+ return torch.cat((-x2, x1), dim=-1)
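+ # For example (illustrative): rotate_half(torch.tensor([1., 2., 3., 4.])) returns
+ # tensor([-3., -4., 1., 2.]): the second half is negated and moved to the front,
+ # followed by the original first half.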
69
+
70
+ class RotaryEmbedding(nn.Module):
71
+ """Precompute RoPE embeddings and store them as buffers."""
72
+
73
+ def __init__(self, config: SharedSpaceDecoderConfig) -> None:
74
+ super().__init__()
75
+
76
+ dim = config.rope_dims
77
+ seq_len = config.max_position_embeddings
78
+
79
+ # ------------------------------
80
+ # Compute inverse frequencies
81
+ # ------------------------------
82
+ # Shape: [dim // 2]
83
+ # inv_freq[i] = 1 / (theta^(2i / dim))
84
+ inv_freq = 1.0 / (
85
+ config.rope_theta
86
+ ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
87
+ )
88
+
89
+ # ------------------------------
90
+ # Apply RoPE scaling if configured
91
+ # ------------------------------
92
+ if config.rope_scaling is not None:
93
+ scaling_type = config.rope_scaling.get("type", "linear")
94
+ scaling_factor = config.rope_scaling.get("factor", 1.0)
95
+
96
+ if scaling_type == "linear":
97
+ # Linear scaling: divide frequencies by scaling factor
98
+ inv_freq = inv_freq / scaling_factor
99
+ elif scaling_type == "dynamic":
100
+ # Dynamic scaling: adjust based on sequence length
101
+ # This is a simplified implementation
102
+ inv_freq = inv_freq / scaling_factor
103
+ else:
104
+ print(f"Warning: Unknown RoPE scaling type '{scaling_type}', using linear scaling")
105
+ inv_freq = inv_freq / scaling_factor
106
+
107
+ # ------------------------------
108
+ # Compute position indices
109
+ # ------------------------------
110
+ # Shape: [seq_len]
111
+ t = torch.arange(seq_len, dtype=torch.float32)
112
+
113
+ # ------------------------------
114
+ # Outer product: [seq_len, dim // 2]
115
+ # Each row i contains: t[i] * inv_freq
116
+ # ------------------------------
117
+ freqs = torch.outer(t, inv_freq)
118
+
119
+ # ------------------------------
120
+ # Duplicate the frequencies for the two halves: [seq_len, dim]
+ # (This matches the rotate_half convention of splitting the vector into a first
+ # and second half, rather than interleaving sin/cos pairs.)
122
+ # ------------------------------
123
+ emb = torch.cat((freqs, freqs), dim=-1)
124
+
125
+ # ------------------------------
126
+ # Register cos/sin as buffers
127
+ # - Stored in float32
128
+ # - Will be moved to correct device/dtype via model.to(...)
129
+ # - Not saved with state_dict (persistent=False)
130
+ # ------------------------------
131
+ self.register_buffer("cos", emb.cos(), persistent=False)
132
+ self.register_buffer("sin", emb.sin(), persistent=False)
133
+
134
+ def forward(self, position_ids: torch.LongTensor) -> None:
+ """Unused; the precomputed `cos`/`sin` buffers are read directly as attributes."""
+ return None
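+
+ # Sketch of how the buffers are consumed (assumed convention, mirroring the
+ # `position_embeddings` tuple expected by MultiheadLatentAttention.forward):
+ #
+ #   rope = RotaryEmbedding(config)
+ #   cos = rope.cos[:seq_len]   # [seq_len, rope_dims]
+ #   sin = rope.sin[:seq_len]   # [seq_len, rope_dims]
+ #   out = attn_layer(hidden_states, (cos, sin), attention_mask=None)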
137
+
138
+ """## MLA"""
139
+
140
+ class MultiheadLatentAttention(nn.Module):
141
+ """
142
+ A variant of MLA with:
143
+ - Simplified RoPE handling:
144
+ - A portion of the head dimensions are used for position information.
145
+ - Same number of queries as keys. (no MQA)
146
+ - Optional output subspace
147
+ """
148
+
149
+ def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int):
150
+ super().__init__()
151
+
152
+ self.config = config
153
+
154
+ # Used to determine if this layer is dense or uses latents.
155
+ self.layer_idx = layer_idx
156
+ self.attention_dropout_prob = config.attention_dropout_prob
157
+
158
+ self.num_heads = config.num_attention_heads
159
+
160
+ self.rope_theta = config.rope_theta
161
+ self.rope_dims = config.rope_dims
162
+ self.nope_dims = config.nope_dims
163
+
164
+ self.q_shared_dim = config.q_shared_dim
165
+ self.kv_shared_dim = config.kv_shared_dim
166
+ self.o_shared_dim = config.o_shared_dim
167
+
168
+ self.qk_private_dim = config.qk_private_dim
169
+ self.vo_private_dim = config.vo_private_dim
170
+
171
+ self.hidden_size = config.hidden_size
172
+
173
+ # =========================
174
+ # Input Projections
175
+ # =========================
176
+
177
+ # If this is one of the dense layers,
178
+ if self.layer_idx < config.num_dense_layers:
179
+
180
+ # =========================
181
+ # Dense Attention
182
+ # =========================
183
+
184
+ # No latent projections.
185
+ self.latent_spaces = False
186
+
187
+ # Define the standard QKV projection
188
+ self.qkv_proj = nn.Linear(
189
+ config.hidden_size,
190
+ self.num_heads * (self.qk_private_dim * 2 + self.vo_private_dim),
191
+ bias=config.attention_bias,
192
+ )
193
+
194
+ # Dense output projection
195
+ self.o_proj = nn.Linear(
196
+ self.num_heads * self.vo_private_dim,
197
+ config.hidden_size,
198
+ bias=config.attention_bias,
199
+ )
200
+
201
+ # If we're past the dense layers,
202
+ else:
203
+
204
+ # =========================
205
+ # Latent Attention
206
+ # =========================
207
+
208
+ # Use latent projections.
209
+ self.latent_spaces = True
210
+
211
+ # Input latent projections
212
+
213
+ print("config.q_shared_dim", config.q_shared_dim)
214
+
215
+ # If we're using a shared query subspace,
216
+ if config.q_shared_dim is not None:
217
+ # Set a flag that we'll check in `forward`.
218
+ self.query_shared = True
219
+
220
+ self.q_shared_proj = nn.Linear(
221
+ config.hidden_size,
222
+ self.q_shared_dim,
223
+ bias=config.attention_bias,
224
+ )
225
+
226
+ self.q_shared_norm = create_norm_layer(self.q_shared_dim, config)
227
+
228
+ else:
229
+ print("Using identity for shared projection.")
230
+ # Set a flag that we'll check in `forward`.
231
+ self.query_shared = False
232
+
233
+ self.q_shared_dim = config.hidden_size
234
+
235
+ #print("Updated self.q_shared_dim to", self.q_shared_dim)
236
+
237
+ # Use identity.
238
+ self.q_shared_proj = nn.Identity()
239
+ self.q_shared_norm = nn.Identity()
240
+
241
+ # If we're using a shared key/value subspace,
242
+ if config.kv_shared_dim is not None:
243
+ # Set a flag that we'll check in `forward`.
244
+ self.keyvalue_shared = True
245
+
246
+ self.kv_shared_proj = nn.Linear(
247
+ config.hidden_size,
248
+ self.kv_shared_dim,
249
+ bias=config.attention_bias,
250
+ )
251
+
252
+ self.kv_shared_norm = create_norm_layer(self.kv_shared_dim, config)
253
+
254
+ else:
255
+ # Set a flag that we'll check in `forward`.
256
+ self.keyvalue_shared = False
257
+
258
+ self.kv_shared_dim = config.hidden_size
259
+
260
+ # Use identity.
261
+ self.kv_shared_proj = nn.Identity()
262
+ self.kv_shared_norm = nn.Identity()
263
+
264
+ #print("config.q_shared_dim", config.q_shared_dim)
265
+ #print("self.qk_private_dim", self.qk_private_dim)
266
+
267
+ # Query heads
268
+ self.q_private_proj = nn.Linear(
269
+ self.q_shared_dim,
270
+ self.num_heads * self.qk_private_dim,
271
+ bias=False # TODO
272
+ )
273
+
274
+ # Key and Value heads, concatenated
275
+ self.kv_private_proj = nn.Linear(
276
+ self.kv_shared_dim,
277
+ self.num_heads * (self.qk_private_dim + self.vo_private_dim),
278
+ bias=False,
279
+ )
280
+
281
+ # Use output subspace if o_shared_dim is specified
282
+ self.output_subspace = config.o_shared_dim is not None
283
+
284
+ # If we're using an output subspace,
285
+ if self.output_subspace:
286
+
287
+ # ==========================
288
+ # Output Subspace
289
+ # ==========================
290
+
291
+ self.o_shared_dim = config.o_shared_dim
292
+
293
+ # Per-head output projections
294
+ # (Similar to original W^O, but projects the scored value vectors
295
+ # into a latent space instead of back to the model)
296
+ self.o_private_proj = nn.Linear(
297
+ self.num_heads * self.vo_private_dim,
298
+ self.o_shared_dim,
299
+ bias=False
300
+ )
301
+
302
+ # Norm layer between o_private_proj and o_shared_proj
303
+ # Note: In previous ViT experiments, this norm step hurt performance, but was beneficial
304
+ # in the DeepSeekV3 experiments.
305
+ # However, we're making it configurable so it can be tested in different contexts.
306
+ self.o_private_norm = create_norm_layer(self.o_shared_dim, config)
307
+
308
+ # Shared output projection
309
+ # The head outputs from `o_private_proj` are first summed together (across
310
+ # heads) in the latent space.
311
+ # Then we project their combined outputs (a single vector per token)
312
+ # back to model space via `o_shared_proj`.
313
+ self.o_shared_proj = nn.Linear(
314
+ self.o_shared_dim,
315
+ self.hidden_size,
316
+ bias=config.attention_bias
317
+ )
318
+ else:
319
+ # Dense output projection
320
+ self.o_proj = nn.Linear(
321
+ self.num_heads * self.vo_private_dim,
322
+ config.hidden_size,
323
+ bias=config.attention_bias,
324
+ )
325
+
326
+ # Softmax scaling factor.
327
+ self.softmax_scale = self.qk_private_dim ** (-0.5)
328
+
329
+
330
+ def forward(
331
+ self,
332
+ hidden_states: torch.Tensor,
333
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
334
+ attention_mask: Optional[torch.Tensor],
335
+ #past_key_value: Optional[Cache] = None, # TODO - Can I remove this?
336
+ #cache_position: Optional[torch.LongTensor] = None, # TODO - Can I remove this?
337
+ **kwargs,
338
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
339
+ # === Tensor Dimension Symbols ===
340
+ # B: batch_size — number of samples in the batch
341
+ # T: seq_len — number of tokens per sample
342
+ # H: n_heads — number of attention heads
343
+ # D: hidden_dim — model embedding size
344
+ # Dv: vo_private_dim - per-head value/output projection dimension
345
+ # Dr: rope_dims - The first Dr dimensions receive rope.
346
+ # Cq: q_shared_dim - query shared subspace size
347
+ # Ckv: kv_shared_dim - key-value shared subspace size
348
+ # Co: o_shared_dim - output shared subspace size
349
+
350
+ # Input token embeddings
351
+ # hidden_states: [B, T, D]
352
+ B, T = hidden_states.shape[:2]
353
+ H = self.num_heads
354
+ Dq = self.qk_private_dim # per-head dim for Q and K
355
+ Dv = self.vo_private_dim # per-head dim for V/O
356
+
357
+ Dc_q, Dc_kv = self.q_shared_dim, self.kv_shared_dim
358
+
359
+ # ==============================
360
+ # QKV Head Projections
361
+ # ==============================
362
+ # Project tokens into per-head query, key, and value vectors
363
+
364
+ # If this layer uses latent projections,
365
+ if self.latent_spaces:
366
+
367
+ # ================================
368
+ # Shared Space Projections
369
+ # ================================
370
+
371
+ # Project token embeddings into shared latents
372
+ # Input:
373
+ # hidden_states [B, T, D]
374
+ # q_shared_proj [D, Cq]
375
+ # kv_shared_proj [D, Ckv]
376
+ # Output:
377
+ # q_shared [B, T, Cq]
378
+ # kv_shared [B, T, Ckv]
379
+
380
+ # If we're using a shared query subspace,
381
+ if self.query_shared:
382
+ q_shared = self.q_shared_proj(hidden_states)
383
+
384
+ # Normalize latent vectors, shapes unchanged.
385
+ q_shared = self.q_shared_norm(q_shared)
386
+ # Otherwise,
387
+ else:
388
+ # Use the hidden states
389
+ q_shared = hidden_states
390
+
391
+ # If we're using a shared key/value subspace,
392
+ if self.keyvalue_shared:
393
+
394
+ # Project token embeddings into shared subspace.
395
+ kv_shared = self.kv_shared_proj(hidden_states)
396
+
397
+ # Normalize latent vectors, shapes unchanged.
398
+ kv_shared = self.kv_shared_norm(kv_shared)
399
+ # Otherwise,
400
+ else:
401
+ # Use the hidden states
402
+ kv_shared = hidden_states
403
+
404
+ # ======================================
405
+ # Per-Head (Private) Projections
406
+ # ======================================
407
+
408
+ # Project query latents onto query heads.
409
+ # Input:
410
+ # q_shared [B, T, Cq]
411
+ # q_private_proj [Cq, H*Dh]
412
+ # Output:
413
+ # queries [B, T, H*Dh]
414
+ queries = self.q_private_proj(q_shared)
415
+
416
+ # Project key/value latents onto key and value heads.
417
+ # The key and value heads are all concatenated, each head occupies
418
+ # Dh columns of the kv_private_proj. This yields the key and value
419
+ # vectors concatenated in the same way.
420
+ #
421
+ # Input:
422
+ # kv_shared [B, T, Ckv]
423
+ # kv_private_proj [Ckv, 2*H*Dh]
424
+ # Output:
425
+ # keysvalues [B, T, 2*H*Dh]
426
+ keysvalues = self.kv_private_proj(kv_shared)
427
+
428
+ # Split into key and value tensors
429
+ # Each: [B, T, H * Dh]
430
+ keys, values = keysvalues.chunk(2, dim=-1)
431
+
432
+ # If this is a dense attention layer (no latent projections),
433
+ else:
434
+
435
+ # ====================
436
+ # Standard MHA
437
+ # ====================
438
+
439
+ # Standard QKV projection
440
+ # Input:
441
+ # hidden_states [B, T, D]
442
+ # qkv_proj [D, 3*H*Dh]
443
+ # Output:
444
+ # querieskeysvalues [B, T, 3*H*Dh]
445
+ querieskeysvalues = self.qkv_proj(hidden_states)
446
+
447
+ # Separate query, key, and value vectors
448
+ # Each: [B, T, H * Dh]
449
+ queries, keys, values = querieskeysvalues.chunk(3, dim=-1)
450
+
451
+ # Split up queries so that there's just one per row.
452
+ # Same for keys and values.
453
+ #
454
+ # Inputs:
455
+ # Each [B, T, H*Dh]
456
+ # Output:
457
+ # Each [B, H, T, Dh]
458
+ queries = queries.view(B, T, H, Dq).transpose(1, 2)
459
+ keys = keys.view(B, T, H, Dq).transpose(1, 2)
460
+ values = values.view(B, T, H, Dv).transpose(1, 2)
461
+
462
+ # ==================
463
+ # RoPE
464
+ # ==================
465
+ # Apply rotary position embeddings to the first `self.rope_dims` of
466
+ # each head.
467
+ # The slice operations are free, but the concatenation is
468
+ # not, because the outputs of the rotation operation are new data
469
+ # occupying different memory. Still considered the best option,
470
+ # though.
471
+
472
+ # 1. Unpack the precomputed cosine and sine embeddings
473
+ # Position embeddings is a tuple of
474
+ # (cos [seq_len, rope_dims],
475
+ # sin [seq_len, rope_dims])
476
+ cos, sin = position_embeddings
477
+
478
+ # 2. Split the query and key heads into the part to rotate and the part
479
+ # to pass through (early columns get position info, later ones don't)
480
+ #
481
+ # (Using queries as example)
482
+ # Inputs:
483
+ # queries [B, H, T, Dh] Dh = rope_dims + not_rope_dims
484
+ # Outputs:
485
+ # q_rope [B, H, T, Dr]
486
+ # q_pass [B, H, T, Dh-Dr]
487
+ q_rope, q_pass = queries[..., :self.rope_dims], queries[..., self.rope_dims:]
488
+ k_rope, k_pass = keys[..., :self.rope_dims], keys[..., self.rope_dims:]
489
+
490
+ # 3. Apply the rotary embedding to the designated slice
491
+ #
492
+ # To broadcast cos and sin across the batch and head dimensions, we unsqueeze them.
493
+ # Shape change: [T, Dr] -> [1, 1, T, Dr]
494
+ cos = cos.unsqueeze(0).unsqueeze(0)
495
+ sin = sin.unsqueeze(0).unsqueeze(0)
496
+
497
+ #print("q_rope.shape[-1] // 2:", (q_rope.shape[-1] // 2))
498
+ #print("x1 = x[..., :x.shape[-1] // 2 ].shape:", q_rope[..., :q_rope.shape[-1] // 2 ].shape)
499
+ #print("sin/cos.shape:", cos.shape)
500
+ #print("q_rope.shape:", q_rope.shape)
501
+ #print("(q_rope * cos).shape:", (q_rope * cos).shape)
502
+ #print("rotate_half(q_rope).shape:", rotate_half(q_rope).shape)
503
+ #print("(rotate_half(q_rope) * sin).shape:", (rotate_half(q_rope) * sin).shape)
504
+ """
505
+ In this example batch_size = 2, num_heads = 8, seq_len = 65, rope_dims = 16
+
+ q_rope.shape[-1] // 2: 8
+ x1 = x[..., :x.shape[-1] // 2 ].shape: torch.Size([2, 8, 65, 8])
+
+ sin/cos.shape: torch.Size([1, 1, 65, 16]) # After double unsqueeze.
+ q_rope.shape: torch.Size([2, 8, 65, 16])
512
+
513
+ (q_rope * cos).shape: torch.Size([2, 8, 65, 16])
514
+
515
+ rotate_half(q_rope).shape: torch.Size([2, 8, 65, 16])
516
+ (rotate_half(q_rope) * sin).shape: torch.Size([2, 8, 65, 16])
517
+ """
518
+
519
+
520
+ # Let's walk through the queries as the example.
521
+ # What does rotate half do?
522
+ # dim -1 is the row vectors, the queries
523
+ #
524
+ # Step 1: Split the vector in half.
525
+ # "q_rope.shape[-1] // 2" <- How much to select. Half the length of the q_rope vector
526
+ # x1 = x[..., :x.shape[-1] // 2 ] # Select the first half of the vector.
527
+ # x2 = x[..., x.shape[-1] // 2:] # Select the second half.
528
+ #
529
+ # Step 2:
530
+ # - Apply negative to the values in the second half.
531
+ # - Reverse the order of the halves.
532
+ # return torch.cat((-x2, x1), dim=-1)
533
+ #
534
+ # ---- (q_rope * cos) ----
535
+ # Element-wise multiply the values in each `cos` vector with the
536
+ # corresponding (i.e., same sequence position) `q_rope` vector.
537
+ #
538
+ # Inputs:
539
+ # q_rope [B, H, T, Dr]
540
+ # cos [1, 1, T, Dr]
541
+ #
542
+ # Outputs:
543
+ # x [B, H, T, Dr]
544
+ #
545
+ # ---- (rotate_half(q_rope)) ----
+ # Pairs dimension i with dimension i + Dr/2: the second half is negated and
+ # swapped to the front, i.e. [x1, x2] -> [-x2, x1].
+ #
+ # Inputs:
+ #    q_rope     [B, H, T, Dr]
+ #
+ # Outputs:
+ #    rot_q_rope [B, H, T, Dr]
+ #
+ # ---- rotated * sin ----
+ # Supplies the cross terms of the rotation, so that each pair
+ # (x_i, x_{i+Dr/2}) is rotated by the angle theta_i = t * inv_freq[i]:
+ #    q_rot[i]        = x_i * cos(theta_i) - x_{i+Dr/2} * sin(theta_i)
+ #    q_rot[i+Dr/2]   = x_{i+Dr/2} * cos(theta_i) + x_i * sin(theta_i)
556
+ q_rotated = (q_rope * cos) + (rotate_half(q_rope) * sin)
557
+ k_rotated = (k_rope * cos) + (rotate_half(k_rope) * sin)
558
+
559
+ # 4. Concatenate the rotated and pass-through parts back together
560
+ # Input (each): [B, H, T, Dr] and [B, H, T, Dq-Dr]
561
+ # Output (each): [B, H, T, Dq]
562
+ queries = torch.cat((q_rotated, q_pass), dim=-1)
563
+ keys = torch.cat((k_rotated, k_pass), dim=-1)
564
+
565
+ # ===================
566
+ # Attention
567
+ # ===================
568
+ # The tensors (queries, keys, values) now have shape [B, H, T, Dq]
569
+ # and are ready for the attention score calculation.
570
+
571
+ # Only apply dropout during training.
572
+ # self.training is a pytorch flag.
573
+ if self.training:
574
+ dropout_p = self.attention_dropout_prob
575
+ else:
576
+ dropout_p = 0.0
577
+
578
+ # Call SDPA / Flash Attention
579
+ # https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
580
+ attn_output = F.scaled_dot_product_attention(
581
+ queries,
582
+ keys,
583
+ values,
584
+ attn_mask=None, # attention_mask,
585
+ dropout_p=dropout_p,
586
+ scale=self.softmax_scale,
587
+ is_causal=True, # This is a decoder - apply causal masking
588
+ )
589
+
590
+ # Reshape output back to [B, T, H * Dv] from [B, H, T, Dv]
591
+ attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, H * Dv)
592
+
593
+ # =========================
594
+ # Output Projection
595
+ # =========================
596
+
597
+ # If we are using an output latent projection,
598
+ if self.latent_spaces and self.output_subspace:
599
+
600
+ # Project the attention output into the output latent space.
601
+ # This is analogous to the W^O matrix in standard attention but
602
+ # projects to an intermediate latent dimension.
603
+ attn_output = self.o_private_proj(attn_output)
604
+
605
+ # Apply normalization to the output latents
606
+ attn_output = self.o_private_norm(attn_output)
607
+
608
+ # Re-project the output latent representation back to model space.
609
+ attn_output = self.o_shared_proj(attn_output)
610
+
611
+ # If this is a dense layer,
612
+ else:
613
+ # Project the values back into model space.
614
+ attn_output = self.o_proj(attn_output)
615
+
616
+ # -----------------------------------------
617
+
618
+ return attn_output
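+
+ # Minimal shape-check sketch (assumed values, just to illustrate the interface;
+ # `rope` is a RotaryEmbedding instance as above):
+ #
+ #   cfg = SharedSpaceDecoderConfig(num_dense_layers=0, q_shared_dim=128,
+ #                                  kv_shared_dim=128, qk_private_dim=32,
+ #                                  vo_private_dim=32, rope_dims=16, nope_dims=16)
+ #   attn = MultiheadLatentAttention(cfg, layer_idx=0)
+ #   x = torch.randn(2, 65, cfg.hidden_size)
+ #   cos, sin = rope.cos[:65], rope.sin[:65]
+ #   out = attn(x, (cos, sin), attention_mask=None)   # -> [2, 65, cfg.hidden_size]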
619
+
layers/task_heads.py ADDED
@@ -0,0 +1,196 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from typing import Optional, Union
6
+
7
+ from transformers.modeling_outputs import CausalLMOutputWithPast
8
+
9
+ from ..models.configuration_shared_subspace_decoder import SharedSpaceDecoderConfig
10
+ from ..models.modeling_shared_subspace_decoder import (
11
+ SharedSpaceDecoderPreTrainedModel,
12
+ SharedSpaceDecoderModel,
13
+ DeepseekV3RMSNorm
14
+ )
15
+
16
+ def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
17
+ """
18
+ Create a normalization layer based on the config norm_type.
19
+
20
+ Args:
21
+ hidden_size: The dimension to normalize over
22
+ config: Configuration containing norm_type and epsilon values
23
+
24
+ Returns:
25
+ Either a LayerNorm or RMSNorm layer
26
+ """
27
+ if config.norm_type == "layernorm":
28
+ return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
29
+ elif config.norm_type == "rmsnorm":
30
+ from ..models.modeling_shared_subspace_decoder import DeepseekV3RMSNorm
31
+ return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
32
+ else:
33
+ # This should be caught by config validation, but being defensive
34
+ raise ValueError(f"Unknown norm_type: {config.norm_type}")
35
+
36
+
37
+ class SharedSpaceDecoderForCausalLM(SharedSpaceDecoderPreTrainedModel):
38
+ """
39
+ Subspace Decoder model with a causal language modeling head.
40
+
41
+ This model wraps SharedSpaceDecoderModel and adds:
42
+ - A language modeling head that projects hidden states to vocabulary logits
43
+ - Support for computing cross-entropy loss for language modeling
44
+ - Proper HuggingFace compatibility for causal language modeling tasks
45
+ - Decoder-specific initialization strategies
46
+
47
+ The model can be used for:
48
+ - Text generation
49
+ - Language modeling pretraining
50
+ - Fine-tuning on downstream tasks
51
+ """
52
+
53
+ def __init__(self, config: SharedSpaceDecoderConfig) -> None:
54
+ super().__init__(config)
55
+
56
+ # Initialize the base decoder model
57
+ self.model = SharedSpaceDecoderModel(config)
58
+
59
+ # Final layer norm before the language modeling head
60
+ self.norm = create_norm_layer(config.hidden_size, config)
61
+
62
+ # Language modeling head
63
+ # Projects from hidden_size to vocab_size to get logits for each token
64
+ self.lm_head = nn.Linear(
65
+ config.hidden_size,
66
+ config.vocab_size,
67
+ bias=False # Following common practice in modern LMs
68
+ )
69
+
70
+ # Initialize weights with decoder-specific strategy
71
+ # Note: tie_weights() will be called automatically by post_init() if config.tie_word_embeddings=True
72
+ self.post_init()
73
+
74
+ def _init_weights(self, module: nn.Module) -> None:
75
+ """
76
+ Decoder-specific weight initialization with special handling for language modeling head.
77
+
78
+ Key differences from encoder initialization:
79
+ - Language modeling head gets specialized initialization for stability
80
+ - Configurable normalization layers (LayerNorm or RMSNorm) are properly handled
81
+ - Weight tying considerations for embedding/lm_head relationship
82
+ """
83
+
84
+ # Use the base class initialization for most modules
85
+ super()._init_weights(module)
86
+
87
+ # Special handling for language modeling head
88
+ if module is self.lm_head:
89
+ # Use smaller initialization for the language modeling head
90
+ # This helps with training stability in autoregressive generation
91
+ # Common practice is to use std=initializer_range or smaller
92
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
93
+
94
+ # If weight tying is not used, we might want even smaller init
95
+ if self.model.vocab_proj is not None:
96
+ # For vocab subspace models where weights aren't tied,
97
+ # use a smaller scale to prevent initial logits from being too large
98
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range * 0.5)
99
+
100
+ def get_input_embeddings(self):
101
+ """Return the input embedding layer for compatibility with HuggingFace."""
102
+ return self.model.vocab_embed
103
+
104
+ def set_input_embeddings(self, value):
105
+ """Set the input embedding layer for compatibility with HuggingFace."""
106
+ self.model.vocab_embed = value
107
+
108
+ def get_output_embeddings(self):
109
+ """Return the output embedding layer (lm_head) for compatibility."""
110
+ return self.lm_head
111
+
112
+ def set_output_embeddings(self, new_embeddings):
113
+ """Set the output embedding layer for compatibility."""
114
+ self.lm_head = new_embeddings
115
+
116
+ def tie_weights(self):
117
+ """
118
+ Tie the input and output embedding weights.
119
+
120
+ This method sets the language modeling head's weight to be the same as
121
+ the input embedding weight. This reduces the number of parameters and
122
+ is a common practice in modern language models.
123
+
124
+ Note: For vocab subspace models, we need to handle the case where
125
+ input embeddings go through a projection layer.
126
+ """
127
+ # Only tie when embeddings live in model space (no vocab_proj)
128
+ if getattr(self.model, "vocab_proj", None) is None:
129
+ # Use HF utility for correct tying/cloning semantics
130
+ self._tie_or_clone_weights(self.lm_head, self.model.vocab_embed)
131
+ # else: leave untied for subspace case
132
+
133
+
134
+ def forward(
135
+ self,
136
+ input_ids: torch.LongTensor,
137
+ attention_mask: Optional[torch.Tensor] = None,
138
+ labels: Optional[torch.LongTensor] = None,
139
+ **kwargs,
140
+ ) -> Union[CausalLMOutputWithPast, tuple]:
141
+ """
142
+ Forward pass for causal language modeling.
143
+
144
+ Args:
145
+ input_ids: Token ids of shape [batch_size, seq_len]
146
+ attention_mask: Attention mask of shape [batch_size, seq_len]
147
+ (1 for real tokens, 0 for padding)
148
+ labels: Ground truth token ids for computing loss. Same shape as input_ids.
149
+ If provided, loss will be computed. Typically input_ids shifted by 1.
150
+
151
+ Returns:
152
+ CausalLMOutputWithPast containing:
153
+ - logits: Prediction logits of shape [batch_size, seq_len, vocab_size]
154
+ - loss: Cross-entropy loss if labels provided, else None
155
+ - hidden_states: Final layer hidden states [batch_size, seq_len, hidden_size]
156
+ """
157
+
158
+ # Run the base decoder model
159
+ # This applies all the transformer layers with causal attention
160
+ hidden_states = self.model(
161
+ input_ids=input_ids,
162
+ attention_mask=attention_mask,
163
+ **kwargs
164
+ )
165
+
166
+ # Apply final layer normalization
167
+ # This normalizes the final hidden states before the language modeling head
168
+ hidden_states = self.norm(hidden_states)
169
+
170
+ # Project to vocabulary logits
171
+ # Shape: [batch_size, seq_len, vocab_size]
172
+ logits = self.lm_head(hidden_states)
173
+
174
+ # Compute loss if labels are provided
175
+ # Previously, we had custom loss computation here, but now we use the
176
+ # standard HuggingFace loss function.
177
+ loss = None
178
+ if labels is not None:
179
+ # Flatten the tokens
180
+ loss = self.loss_function(
181
+ logits,
182
+ labels,
183
+ vocab_size=self.config.vocab_size,
184
+ **kwargs,
185
+ )
186
+
187
+ # Return in HuggingFace format
188
+ return CausalLMOutputWithPast(
189
+ loss=loss,
190
+ logits=logits,
191
+ past_key_values=None, # Not implementing KV cache yet
192
+ #hidden_states=hidden_states,
193
+ hidden_states=hidden_states if kwargs.get("output_hidden_states", False) else None,
194
+ attentions=None,
195
+ )
196
+
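+ # Illustrative training-step sketch (tokenizer/optimizer/batch are assumed, not
+ # defined in this file). The HuggingFace loss function shifts the labels
+ # internally, so `labels` can simply be a copy of `input_ids`:
+ #
+ #   model = SharedSpaceDecoderForCausalLM(config)
+ #   out = model(input_ids=batch["input_ids"],
+ #               attention_mask=batch["attention_mask"],
+ #               labels=batch["input_ids"])
+ #   out.loss.backward()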
models/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ Shared Subspace Decoder Models
5
+
6
+ This module contains the implementation of the Shared Subspace Decoder architecture,
7
+ including Multi-Head Latent Attention (MLA) and decomposed MLP layers.
8
+ """
9
+
10
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
11
+
12
+ from .configuration_shared_subspace_decoder import SharedSpaceDecoderConfig
13
+ from .modeling_shared_subspace_decoder import (
14
+ SharedSpaceDecoderPreTrainedModel,
15
+ SharedSpaceDecoderModel,
16
+ )
17
+
18
+ # Import from task_heads in layers directory
19
+ from ..layers.task_heads import SharedSpaceDecoderForCausalLM
20
+
21
+ # Register the configuration class with AutoConfig
22
+ AutoConfig.register("shared_subspace_decoder", SharedSpaceDecoderConfig)
23
+
24
+ # Register the model classes with AutoModel
25
+ AutoModel.register(SharedSpaceDecoderConfig, SharedSpaceDecoderModel)
26
+ AutoModelForCausalLM.register(SharedSpaceDecoderConfig, SharedSpaceDecoderForCausalLM)
27
+
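+ # With the registrations above, the generic Auto* entry points should resolve this
+ # model type, e.g. (usage sketch):
+ #
+ #   cfg = AutoConfig.for_model("shared_subspace_decoder")
+ #   model = AutoModelForCausalLM.from_config(cfg)
+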
28
+ __all__ = [
29
+ "SharedSpaceDecoderConfig",
30
+ "SharedSpaceDecoderPreTrainedModel",
31
+ "SharedSpaceDecoderModel",
32
+ "SharedSpaceDecoderForCausalLM",
33
+ ]
models/configuration_shared_subspace_decoder.py ADDED
@@ -0,0 +1,329 @@
1
+ """# `configuration_shared_subspace_decoder.py`
2
+
3
+ #### `*Config`
4
+ """
5
+
6
+ from typing import Optional
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+ from transformers.configuration_utils import PretrainedConfig
12
+ from transformers.modeling_utils import PreTrainedModel
13
+
14
+ """`def make_shorthand`"""
15
+
16
+ def make_shorthand(model_cfg):
17
+ """
18
+ Takes an instance of the decoder `*Config` and constructs a shorthand
+ name for the model based on its settings.
20
+ """
21
+
22
+ dense_str = str(model_cfg.num_dense_layers) + "mha + "
23
+
24
+ if model_cfg.o_shared_dim is not None:
25
+ o_str = "." + str(model_cfg.o_shared_dim)
26
+ else:
27
+ o_str = ""
28
+
29
+ # If no output subspace is used, the dimension will show as -1.
30
+ attn_str = (
31
+ dense_str
32
+ + "mla."
33
+ + str(model_cfg.q_shared_dim)
34
+ + "."
35
+ + str(model_cfg.kv_shared_dim)
36
+ + o_str
37
+ )
38
+
39
+ # MLP Configuration
40
+ if model_cfg.ffn_decompose:
41
+ dense_str = (
42
+ str(model_cfg.num_dense_layers)
43
+ + "mlp."
44
+ + str(model_cfg.intermediate_size)
45
+ + " + "
46
+ )
47
+
48
+ mlp_str = (
49
+ dense_str
50
+ + str(model_cfg.num_hidden_layers - model_cfg.num_dense_layers)
51
+ + "dcmp."
52
+ + "x"
53
+ + str(model_cfg.intermediate_size)
54
+ + "."
55
+ + str(model_cfg.ffn_rank)
56
+ )
57
+ else:
58
+ mlp_str = "mlp." + str(model_cfg.intermediate_size)
59
+
60
+ # Assemble string
61
+ shorthand = (
62
+ f"{attn_str} - {mlp_str} - "
63
+ f"h{model_cfg.hidden_size} - l{model_cfg.num_hidden_layers}"
64
+ )
65
+
66
+ """
67
+ The run name includes training settings
68
+
69
+ run_name = (
70
+ f"{config['stats']['total_elements']} - "
71
+ f"{attn_str} - {mlp_str} - "
72
+ f"h{model_cfg.hidden_size} - l{model_cfg.num_hidden_layers} - "
73
+ f"bs{ptrain_cfg['train_batch_size']} - lr{lr_str} - "
74
+ f"seq{ptrain_cfg['max_seq_length']}"
75
+ )
76
+ """
77
+
78
+ return shorthand
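+
+ # For example (hypothetical settings, not a provided config): with
+ # num_dense_layers=2, q_shared_dim=64, kv_shared_dim=64, o_shared_dim=None,
+ # ffn_decompose=False, intermediate_size=3072, hidden_size=512, and
+ # num_hidden_layers=12, this returns:
+ #
+ #   "2mha + mla.64.64 - mlp.3072 - h512 - l12"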
79
+
80
+
81
+ class SharedSpaceDecoderConfig(PretrainedConfig):
82
+ r"""
83
+ Configuration class for SharedSpaceDecoderConfig.
84
+
85
+ Extends the HuggingFace `PretrainedConfig` to support architectural
86
+ variations including:
87
+ - Multi-Head Latent Attention (MLA)
88
+ - Decomposed MLPs (low-rank FFNs)
89
+ - Flexible attention backends (eager, flash, sdpa)
90
+ - Explicit shared subspaces for Q, K, V, and O projections
91
+
92
+ This config does not infer any defaults based on `hidden_size`. All
93
+ dimensions and ranks must be explicitly specified. If required values are
94
+ missing, a `ValueError` is raised during initialization.
95
+
96
+ ----------------------
97
+ Core Model Parameters:
98
+ ----------------------
99
+ - vocab_size (`int`) — Vocabulary size.
100
+ - hidden_size (`int`) — Model hidden dimension.
101
+ - num_hidden_layers (`int`) — Number of transformer blocks.
102
+ - intermediate_size (`int`) — Feed-forward hidden dimension.
103
+ - hidden_act (`str`) — Activation function.
104
+ - hidden_dropout_prob (`float`) — Dropout after projections and FFNs.
105
+ - attention_dropout_prob (`float`) — Dropout applied to attention scores.
106
+ - max_position_embeddings (`int`) — Max sequence length.
107
+ - initializer_range (`float`) — Stddev of weight init.
108
+
109
+ - layer_norm_eps (`float`) — Epsilon for LayerNorm.
110
+ - rms_norm_eps (`float`) — Epsilon for RMSNorm.
111
+
112
+ - classifier_dropout (`float` or None) — Dropout for final classifier.
113
+
114
+ - vocab_subspace
115
+ - vocab_rank
116
+
117
+ ----------------------------------
118
+ Multi-Head Latent Attention (MLA):
119
+ ----------------------------------
120
+ - num_attention_heads (`int`) — Number of attention heads.
121
+
122
+ - q_shared_dim (`int`) — Rank of the shared query subspace.
123
+ - kv_shared_dim (`int`) — Rank of the shared key/value subspace.
124
+
125
+ - output_subspace (`bool`) — Whether to use a shared latent subspace for output projections.
126
+ - o_shared_dim (`int`) — Rank of the shared output subspace (required if `output_subspace=True`).
127
+ - qk_private_dim (`int`) — Query/key private dimension per head.
128
+ - vo_private_dim (`int`) — Value/output private dimension per head.
129
+
130
+ - rope_dims (`int`) — Number of head dimensions carrying RoPE.
131
+ - nope_dims (`int`) — Non-positional encoding dimensions.
132
+ - rope_theta (`float`) — Base frequency used for RoPE.
133
+ - rope_scaling (`dict` or None) — HF-style scaling dict for RoPE.
134
+ - attention_bias (`bool`) — Whether to include bias terms in Q/K/V projections.
135
+ - num_dense_layers (`int`) — Number of leading layers that do not use
136
+ subspaces for attention or FFNs.
137
+ - attention_backend (`str`) — Must be one of `"eager"`, `"flash_attention_2"`, or `"sdpa"`.
138
+
139
+ ----------------------
140
+ Decomposed MLP (Low-Rank FFN):
141
+ ----------------------
142
+ - ffn_decompose (`bool`) — Whether to enable low-rank FFNs.
143
+ - ffn_rank (`int`) — Rank of the shared FFN latent space (required if `ffn_decompose=True`).
144
+
145
+ ----------------------
146
+ Validation Behavior:
147
+ ----------------------
148
+ Raises `ValueError` at init time if:
149
+ - FFN decomposition is enabled without specifying `ffn_rank`.
150
+ - An unknown `attention_backend` is provided.
151
+ """
152
+
153
+ model_type = "shared_subspace_decoder"
154
+
155
+ def __init__(
156
+ self,
157
+
158
+ # === Core Model ===
159
+ vocab_size: int = 30522,
160
+ hidden_size: int = 512,
161
+ num_hidden_layers: int = 12,
162
+
163
+ intermediate_size: int = 3072,
164
+
165
+ hidden_dropout_prob=0.1,
166
+ attention_dropout_prob=0.1,
167
+ max_position_embeddings: int = 2048,
168
+ initializer_range=0.02,
169
+ layer_norm_eps=1e-12,
170
+ rms_norm_eps=1e-6, # Their default, but confirm in config.
171
+ norm_type="layernorm", # Choice between "layernorm" and "rmsnorm"
172
+ classifier_dropout=None,
173
+
174
+ vocab_subspace=False,
175
+ vocab_rank=None,
176
+ tie_word_embeddings=True,
177
+
178
+ # === Multi-Head Latent Attention ===
179
+ num_attention_heads: int = 16,
180
+ rope_dims: int = 16,
181
+
182
+ q_shared_dim: int = None,
183
+ kv_shared_dim: int = None,
184
+
185
+ o_shared_dim=None, # If None, no output subspace is used
186
+
187
+ # Private head dimensions
188
+ qk_private_dim: int = None, # Query/key private dimension per head
189
+ vo_private_dim: int = None, # Value/output private dimension per head
190
+ nope_dims: int = None, # Non-positional encoding dimensions
191
+
192
+ attention_backend="eager",
193
+ rope_theta=10000.0,
194
+ rope_scaling=None,
195
+ attention_bias=False,
196
+
197
+ # === MLA Composition ===
198
+ num_dense_layers=12, # dense MHA layers before MLA starts
199
+
200
+ # === Decomposed MLP ===
201
+ ffn_decompose=False,
202
+ ffn_rank=None,
203
+ **kwargs
204
+ ) -> None:
205
+ super().__init__(**kwargs)
206
+
207
+
208
+
209
+ # === Core Model ===
210
+ self.vocab_size = vocab_size
211
+ self.hidden_size = hidden_size
212
+ self.num_hidden_layers = num_hidden_layers
213
+ self.intermediate_size = intermediate_size
214
+ self.hidden_dropout_prob = hidden_dropout_prob
215
+ self.attention_dropout_prob = attention_dropout_prob
216
+ self.max_position_embeddings = max_position_embeddings
217
+ self.initializer_range = initializer_range
218
+ self.layer_norm_eps = layer_norm_eps
219
+ self.rms_norm_eps = rms_norm_eps
220
+ self.norm_type = norm_type
221
+ self.classifier_dropout = classifier_dropout
222
+
223
+ self.vocab_subspace = vocab_subspace
224
+ self.vocab_rank = vocab_rank
225
+ self.tie_word_embeddings = tie_word_embeddings
226
+
227
+ # === MLA ===
228
+ self.num_attention_heads = num_attention_heads
229
+ self.rope_dims = rope_dims
230
+
231
+ self.q_shared_dim = q_shared_dim
232
+ self.kv_shared_dim = kv_shared_dim
233
+ self.o_shared_dim = o_shared_dim
234
+
235
+ # Private head dimensions
236
+ self.qk_private_dim = qk_private_dim
237
+ self.vo_private_dim = vo_private_dim
238
+ self.nope_dims = nope_dims
239
+ self.rope_theta = rope_theta
240
+ self.rope_scaling = rope_scaling
241
+ self.attention_bias = attention_bias
242
+ self.num_dense_layers = num_dense_layers
243
+
244
+ # === Decomposed FFN ===
245
+ self.ffn_decompose = ffn_decompose
246
+ self.ffn_rank = ffn_rank
247
+
248
+ # === Attention backend ===
249
+ self.attention_backend = attention_backend
250
+
251
+ # === Validation ===
252
+ # TODO - Somewhere during training these get instantiated with bad
253
+ # values...
254
+ #self._validate()
255
+
256
+ #print(f" > SubEnc *Config.init: {make_shorthand(self)}\n")
257
+
258
+
259
+ def _validate(self):
260
+ # === Model ===
261
+ if self.num_dense_layers > self.num_hidden_layers:
262
+ raise ValueError("`num_dense_layers` must be <= `num_hidden_layers`")
263
+ if self.vocab_subspace and self.vocab_rank is None:
264
+ raise ValueError("`vocab_rank` must be set when `vocab_subspace=True`")
265
+
266
+ # === MLA Validation ===
267
+ # At least one of q_shared_dim or kv_shared_dim must be set if we have subspace layers
268
+ if self.num_dense_layers < self.num_hidden_layers and self.q_shared_dim is None and self.kv_shared_dim is None:
269
+ raise ValueError("At least one of q_shared_dim or kv_shared_dim must be set when there are subspace layers")
270
+
271
+ # Validate that private dimensions are set
272
+ if self.qk_private_dim is None or self.vo_private_dim is None:
273
+ raise ValueError("Must set qk_private_dim and vo_private_dim")
274
+ if self.nope_dims is None:
275
+ raise ValueError("Must set nope_dims")
276
+
277
+ # === Decomposed FFN ===
278
+ if self.ffn_decompose and self.ffn_rank is None:
279
+ raise ValueError("`ffn_rank` must be set when `ffn_decompose=True`")
280
+ if self.ffn_decompose and self.num_dense_layers >= self.num_hidden_layers:
281
+ raise ValueError("`ffn_decompose` was set but `num_dense_layers` is >= number of layers")
282
+
283
+ # === Attention Backend ===
284
+ valid_backends = ["eager", "flash_attention_2", "sdpa"]
285
+ if self.attention_backend not in valid_backends:
286
+ raise ValueError(f"Unknown attention backend: {self.attention_backend}, options are {valid_backends}")
287
+
288
+ # === Norm Type ===
289
+ valid_norm_types = ["layernorm", "rmsnorm"]
290
+ if self.norm_type not in valid_norm_types:
291
+ raise ValueError(f"Unknown norm type: {self.norm_type}, options are {valid_norm_types}")
292
+
293
+ #### `get_config`
294
+
295
+ import json
296
+
297
+ def get_config(filename):
298
+
299
+ # Load the config file.
300
+ with open(filename) as f:
301
+ full_cfg = json.load(f)
302
+
303
+ # Strict key check on the model configuration.
304
+
305
+ # Get the list of keys allowed / required by `*Config`
306
+ valid_keys = SharedSpaceDecoderConfig.__init__.__code__.co_varnames
307
+ # Remove `self` and `kwargs`
308
+ valid_keys = set(valid_keys) - {"self", "kwargs"}
309
+
310
+ # Compare the set of keys in the json file vs `*Config`
311
+ extra_keys = set(full_cfg["model"]) - valid_keys
312
+ missing_keys = valid_keys - set(full_cfg["model"])
313
+
314
+ # If there any in the `json` that aren't in `*Config`,
315
+ if extra_keys:
316
+ # List them for the user.
317
+ raise ValueError(f"Unknown keys in config: {sorted(extra_keys)}")
318
+
319
+ # If the json config is missing required keys,
320
+ if missing_keys:
321
+ # List them for the user.
322
+ raise ValueError(f"config json is missing: {sorted(missing_keys)}")
323
+
324
+ # Will raise TypeError, by design, if required args are missing
325
+ # The asterisks unpack the dictionary into a list of keywords as though
326
+ # all of the settings were writting out individually.
327
+ model_cfg = SharedSpaceDecoderConfig(**full_cfg["model"])
328
+
329
+ return full_cfg, model_cfg
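+
+ # Usage sketch (the JSON layout is inferred from this function; the file name and
+ # any sections other than "model" are hypothetical):
+ #
+ #   {
+ #     "model": { "vocab_size": 30522, "hidden_size": 512, ... },
+ #     ...
+ #   }
+ #
+ #   full_cfg, model_cfg = get_config("pretrain_config.json")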
models/modeling_shared_subspace_decoder.py ADDED
@@ -0,0 +1,390 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ modeling_shared_subspace_decoder.py
5
+
6
+ SharedSpaceDecoder model implementation for HuggingFace Transformers.
7
+ """
8
+
9
+ from typing import Optional
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ from transformers.configuration_utils import PretrainedConfig
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
17
+
18
+ from ..layers.mla import MultiheadLatentAttention, RotaryEmbedding
19
+ from ..layers.feedforward import SubspaceFeedForward
20
+ from .configuration_shared_subspace_decoder import SharedSpaceDecoderConfig
21
+
22
+ """`RMSNorm`
23
+
24
+ From:
25
+ https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py
26
+
27
+ TODO - May not need?
28
+ """
29
+
30
+ class DeepseekV3RMSNorm(nn.Module):
31
+ def __init__(self, hidden_size, eps=1e-6):
32
+ """
33
+ DeepseekV3RMSNorm is equivalent to T5LayerNorm
34
+ """
35
+ super().__init__()
36
+ self.weight = nn.Parameter(torch.ones(hidden_size))
37
+ self.variance_epsilon = eps
38
+
39
+ def forward(self, hidden_states):
40
+ input_dtype = hidden_states.dtype
41
+ hidden_states = hidden_states.to(torch.float32)
42
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
43
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
44
+ return self.weight * hidden_states.to(input_dtype)
45
+
46
+ def create_norm_layer(hidden_size: int, config: SharedSpaceDecoderConfig) -> nn.Module:
47
+ """
48
+ Create a normalization layer based on the config norm_type.
49
+
50
+ Args:
51
+ hidden_size: The dimension to normalize over
52
+ config: Configuration containing norm_type and epsilon values
53
+
54
+ Returns:
55
+ Either a LayerNorm or RMSNorm layer
56
+ """
57
+ if config.norm_type == "layernorm":
58
+ return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
59
+ elif config.norm_type == "rmsnorm":
60
+ return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
61
+ else:
62
+ # This should be caught by config validation, but being defensive
63
+ raise ValueError(f"Unknown norm_type: {config.norm_type}")
64
+
65
+ """#### *PreTrainedModel"""
66
+
67
+ class SharedSpaceDecoderPreTrainedModel(PreTrainedModel):
68
+ """
69
+ The **PreTrainedModel object:
70
+ - Is instantiated (indirectly) whenever one of the concrete model classes below is built.
+ - Initializes:
+ - Nothing of its own; it supplies `config_class`, `base_model_prefix`, and the weight-init hook.
+ - Provides access to the HuggingFace save/load machinery (`from_pretrained`, `save_pretrained`, `post_init`).
+ - Executes `_init_weights` on every submodule when `post_init()` is called.
75
+ """
76
+
77
+ config_class = SharedSpaceDecoderConfig
78
+ base_model_prefix = "model"
79
+
80
+ def _init_weights(self, module: nn.Module) -> None:
81
+ """Weight initialization hook used by :class:`PreTrainedModel`.
82
+
83
+ ``PreTrainedModel.post_init`` will recursively apply this function to
84
+ every submodule right after construction. HuggingFace models override
85
+ it so that creating a model from scratch yields the same initialization
86
+ as ``from_pretrained`` when no checkpoint is supplied.
87
+
88
+ This decoder-specific initialization strategy includes:
89
+ - Proper handling of configurable normalization layers (LayerNorm or RMSNorm)
90
+ - Special initialization for language modeling heads
91
+ - Considerations for causal attention and autoregressive modeling
92
+ - Support for both dense and decomposed vocabulary embeddings
93
+ """
94
+
95
+ if isinstance(module, nn.Linear):
96
+ # Standard linear layer initialization
97
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
98
+ if module.bias is not None:
99
+ module.bias.data.zero_()
100
+
101
+ elif isinstance(module, nn.Embedding):
102
+ # Initialize embeddings with normal distribution
103
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
104
+ if module.padding_idx is not None:
105
+ module.weight.data[module.padding_idx].zero_()
106
+
107
+ elif isinstance(module, DeepseekV3RMSNorm):
108
+ # RMSNorm initialization: weight to 1.0, no bias term
109
+ module.weight.data.fill_(1.0)
110
+
111
+ elif isinstance(module, nn.LayerNorm):
112
+ # LayerNorm initialization: bias to 0, weight to 1.0
113
+ module.bias.data.zero_()
114
+ module.weight.data.fill_(1.0)
115
+
116
+ """# ▂▂▂▂▂▂▂▂▂▂▂▂
117
+
118
+ # Classes
119
+ """
120
+
121
+ """#### `*Layer`"""
122
+
123
+ class SharedSpaceDecoderLayer(nn.Module):
124
+ """
125
+ The **Layer object:
126
+ - Is instantiated by :class:`SharedSpaceDecoderModel` for each
127
+ Transformer block in the decoder.
128
+ - Initializes:
129
+ - ``self_attn`` – multi-head latent attention implementing either
130
+ dense or latent projections depending on the configuration.
131
+ - ``ffn`` – a :class:`SubspaceFeedForward` block.
132
+ - RMSNorm layers for pre-attention and pre-FFN normalization.
133
+ - Provides access to the attention and feed-forward submodules via the
134
+ attributes ``self_attn`` and ``ffn``.
135
+ - Executes a single decoder block in :meth:`forward`.
136
+ """
137
+
138
+ def __init__(self, config: SharedSpaceDecoderConfig, layer_idx: int) -> None:
139
+
140
+ super().__init__()
141
+
142
+ # Norm applied prior to attention.
143
+ self.attn_input_norm = create_norm_layer(config.hidden_size, config)
144
+
145
+ # Attention block
146
+ self.self_attn = MultiheadLatentAttention(config, layer_idx)
147
+
148
+ # Norm applied prior to FFN
149
+ self.ffn_input_norm = create_norm_layer(config.hidden_size, config)
150
+
151
+ # Feed-forward network used after attention
152
+ self.ffn = SubspaceFeedForward(config, layer_idx)
153
+
154
+ def forward(
155
+ self,
156
+ hidden_states: torch.Tensor,
157
+ position_embeddings: tuple[torch.Tensor, torch.Tensor], # RoPE embeddings
158
+ attention_mask: Optional[torch.Tensor],
159
+ ) -> torch.Tensor:
160
+
161
+ # ========================
162
+ # Self Attention
163
+ # ========================
164
+ residual_strm = hidden_states
165
+
166
+ # Normalize the hidden states to create the input to attention.
167
+ attn_input = self.attn_input_norm(hidden_states)
168
+
169
+ # Evaluate
170
+ attn_output = self.self_attn(
171
+ attn_input,
172
+ position_embeddings,
173
+ attention_mask,
174
+ )
175
+
176
+ # Add the attention output (the residual) back to the non-normalized
177
+ # hidden_states.
178
+ hidden_states = residual_strm + attn_output
179
+
180
+ # ===========================
181
+ # Feed-Forward Network
182
+ # ===========================
183
+ residual_strm = hidden_states
184
+
185
+ # Normalize the updated hidden states prior to the FFN
186
+ ffn_input = self.ffn_input_norm(hidden_states)
187
+
188
+ # Evaluate
189
+ ffn_output = self.ffn(ffn_input)
190
+
191
+ # Add the output to the un-normalized hidden states.
192
+ hidden_states = residual_strm + ffn_output
193
+
194
+ return hidden_states
195
+
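+ # In equation form, one block applies the standard pre-norm residual wiring:
+ #
+ #   h = h + Attn(Norm(h), RoPE, mask)
+ #   h = h + FFN(Norm(h))
+ #
+ # Shape-level sketch (illustrative; assumes a fully populated config):
+ #
+ #   layer = SharedSpaceDecoderLayer(config, layer_idx=0)
+ #   out = layer(hidden_states, (R_cos, R_sin), attention_mask)
+ #   assert out.shape == hidden_states.shape               # [batch, seq_len, hidden_size]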
196
+ """#### *Model"""
197
+
198
+ class SharedSpaceDecoderModel(SharedSpaceDecoderPreTrainedModel):
199
+ """
200
+ The **Model object:
201
+ - Initializes:
202
+ - The vocabulary embeddings (and optional decomposition)
203
+ - Position embeddings (calculated in RotaryEmbedding)
204
+ - All of the **Layer objects.
205
+ - Provides interface to vocab embeddings.
206
+ - Executes the whole decoder model in `forward` with causal attention.
207
+
208
+ This is the base decoder without the language modeling head.
209
+ Use SharedSpaceDecoderForCausalLM for language modeling tasks.
210
+ """
211
+
212
+ def __init__(self, config: SharedSpaceDecoderConfig) -> None:
213
+ super().__init__(config)
214
+
215
+ # ============================
216
+ # Vocabulary Embeddings
217
+ # ============================
218
+ # Decomposing the vocabulary (if enabled) defines a shared projection
219
+ # which constrains the model to store semantic information (and
220
+ # whatever other static token knowledge) into a limited set of
221
+ # feature directions.
222
+
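+ # Worked example of the parameter savings (illustrative numbers, not taken
+ # from any config): with vocab_size=32_000, hidden_size=768, vocab_rank=128,
+ # a dense embedding table costs 32_000 * 768 ≈ 24.6M weights, while the
+ # decomposed form costs 32_000 * 128 + 128 * 768 ≈ 4.2M for the same interface.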
223
+ # If we're decomposing the token embeddings,
224
+ # TODO - Rename to vocab_subspace.
225
+ if config.vocab_subspace:
226
+
227
+ # Create the embedding table. Vocabulary embeddings are learned
228
+ # in a lower dimensional latent space.
229
+ self.vocab_embed = nn.Embedding(
230
+ config.vocab_size, # Number of tokens
231
+ config.vocab_rank # Subspace dimension
232
+ )
233
+
234
+ # Create a shared up-projection.
235
+ # Selected token latents will be projected up to model size.
236
+ # vocab_proj has shape [vocab_rank x model_size]
237
+ self.vocab_proj = nn.Linear(
238
+ config.vocab_rank, # Size of latents
239
+ config.hidden_size, # Model size
240
+ bias=False
241
+ )
242
+
243
+ # Otherwise, for a dense vocabulary,
244
+ else:
245
+ # Create the dense embedding table in model space.
246
+ self.vocab_embed = nn.Embedding(
247
+ config.vocab_size, # Number of tokens
248
+ config.hidden_size # Model size
249
+ )
250
+
251
+ self.vocab_proj = None
252
+
253
+ # =====================
254
+ # RoPE Embeddings
255
+ # =====================
256
+
257
+ # Pre-computes the table of RoPE embeddings, leaving them in
258
+ # GPU memory.
259
+ self.rope = RotaryEmbedding(config)
260
+
261
+ # ===================
262
+ # Create Layers
263
+ # ===================
264
+
265
+ layers = []
266
+
267
+ # For each layer,
268
+ for i in range(config.num_hidden_layers):
269
+ # Create a **Layer, providing the config and indicating its number.
270
+ layers.append(
271
+ SharedSpaceDecoderLayer(
272
+ config,
273
+ layer_idx = i
274
+ )
275
+ )
276
+
277
+ # Wrap in torch ModuleList
278
+ self.layers = nn.ModuleList(layers)
279
+
280
+ # Run HuggingFace post-initialization (applies _init_weights to all submodules, etc.).
281
+ self.post_init()
282
+
283
+ # Agents: Do not define boilerplate helpers, e.g., get/set_input_embeddings
284
+
285
+
286
+ def embed(self, input_ids: torch.LongTensor) -> torch.Tensor:
287
+ """
288
+ Return token embeddings for input ids.
289
+ This will perform the up projection to model space if the vocabulary is
290
+ decomposed.
291
+
292
+ input_ids have shape [batch_size, seq_len]
293
+ """
294
+
295
+ # If the vocabulary is decomposed,
296
+ if self.vocab_proj is not None:
297
+
298
+ # Retrieve the latents
299
+ # input_ids: [batch_size, seq_len]
300
+ # x: [batch_size, seq_len, latent_dim]
301
+ x = self.vocab_embed(input_ids)
302
+
303
+ # Project the latents back to model space and return.
304
+ return self.vocab_proj(x)
305
+
306
+ # If the vocabulary is dense,
307
+ else:
308
+ # Just return the embeddings.
309
+ return self.vocab_embed(input_ids)
310
+
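+ # Shape trace for embed() (illustrative, writing vocab_rank as r and
+ # hidden_size as d):
+ #
+ #   input_ids           : [batch_size, seq_len]
+ #   vocab_embed(ids)    : [batch_size, seq_len, r]    (decomposed path)
+ #   vocab_proj(latents) : [batch_size, seq_len, d]
+ #
+ # The dense path returns [batch_size, seq_len, d] directly from the lookup.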
311
+ def forward(
312
+ self,
313
+ input_ids: torch.LongTensor,
314
+ attention_mask: Optional[torch.Tensor] = None,
315
+ **kwargs,
316
+ ) -> torch.Tensor:
317
+ """
318
+ Run the full decoder stack with causal attention.
319
+
320
+ Inputs:
321
+ input_ids [batch_size, seq_len]
322
+ attention_mask [batch_size, seq_len] - 1 for real tokens, 0 for padding
323
+
324
+ Returns:
325
+ Final decoder layer output [batch_size, seq_len, model_size]
326
+ """
327
+
328
+ # Retrieve the token embeddings for this sequence.
329
+ # These are model_size, regardless of whether the vocab is decomposed.
330
+ hidden_states = self.embed(input_ids)
331
+
332
+ # Retrieve the rotary position embeddings for all of the positions in
333
+ # our current input sequence.
334
+
335
+ seq_len = hidden_states.size(1)
336
+
337
+ # Retrieves just the ones necessary for the sequence length of the
338
+ # input. These are vectors, two per token. Their length is the
339
+ # number of head dimensions we're applying RoPE to.
340
+ # Input
341
+ # cos: [max_seq_len, rope_dims]
342
+ # sin: [max_seq_len, rope_dims]
343
+ # Outputs:
344
+ # R_cos [seq_len, rope_dims]
345
+ # R_sin [seq_len, rope_dims]
346
+ R_cos = self.rope.cos[:seq_len]
347
+ R_sin = self.rope.sin[:seq_len]
348
+
349
+
350
+ # ===============================
351
+ # Attention Mask Conversion
352
+ # ===============================
353
+
354
+ """
355
+ use_sdpa_attention_masks = (
356
+ self.attn_implementation == "sdpa"
357
+ and self.position_embedding_type == "absolute"
358
+ and head_mask is None
359
+ and not output_attentions
360
+ )
361
+ """
362
+
363
+ # Expand the attention mask
364
+ #if use_sdpa_attention_masks and attention_mask.dim() == 2:
365
+ if attention_mask is not None and attention_mask.dim() == 2:
366
+ # Expand the attention mask for SDPA.
367
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
368
+ extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
369
+ attention_mask,
370
+ hidden_states.dtype,
371
+ tgt_len = seq_len
372
+ )
373
+ attention_mask = extended_attention_mask
374
+
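+ # Illustrative example of the expansion: a padding mask of [[1, 1, 0]] becomes
+ # an additive mask of shape [1, 1, 3, 3] with 0.0 for visible keys and the
+ # dtype's most negative value for the padded key, so softmax gives it ~zero
+ # weight. (The causal part of the mask is presumably applied inside the
+ # attention module, which lives in layers/mla.py.)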
375
+
376
+ # Run the model!
377
+
378
+ # For each decoder layer,
379
+ for layer_i, layer in enumerate(self.layers):
380
+
381
+ # Evaluate the layer
382
+ hidden_states = layer(
383
+ hidden_states, # Token embeddings
384
+ (R_cos, R_sin), # Rope embeddings, passed as a tuple.
385
+ attention_mask, # Attn mask
386
+ )
387
+
388
+ # Return the final output of the decoder stack.
389
+ return hidden_states
390
+
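+ # End-to-end usage sketch (illustrative; assumes a config populated with the
+ # fields referenced above, e.g. vocab_size, hidden_size, vocab_subspace,
+ # vocab_rank, num_hidden_layers, norm_type):
+ #
+ #   config = SharedSpaceDecoderConfig(vocab_size=32_000, hidden_size=768, ...)
+ #   model = SharedSpaceDecoderModel(config)
+ #   input_ids = torch.randint(0, config.vocab_size, (2, 16))
+ #   attention_mask = torch.ones(2, 16, dtype=torch.long)
+ #   hidden = model(input_ids, attention_mask)             # [2, 16, config.hidden_size]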