pinthoz committed
Commit 0cd6681 · verified · 1 parent: ed09cf3

Fix sidebar and GPT2 tokens

attention_app/isa.py CHANGED
@@ -3,15 +3,16 @@ import numpy as np
 import nltk
 from typing import List, Dict, Tuple, Optional
 
+# Ensure nltk data is downloaded
 # Ensure nltk data is downloaded
 try:
     nltk.data.find('tokenizers/punkt')
-except LookupError:
+except (LookupError, OSError):
     nltk.download('punkt')
 
 try:
     nltk.data.find('tokenizers/punkt_tab')
-except LookupError:
+except (LookupError, OSError):
     nltk.download('punkt_tab')
 
 def get_sentence_boundaries(text: str, tokens: List[str], tokenizer, inputs) -> Tuple[List[str], List[int]]:
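The punkt and punkt_tab guards above are identical apart from the resource name; a compact equivalent, shown purely as an illustration (the helper ensure_nltk_data is not part of the commit), assuming only that nltk.data.find raises LookupError for a missing resource:

import nltk

def ensure_nltk_data(resources=("punkt", "punkt_tab")):
    # Catching OSError as well mirrors the commit: it also covers corrupted or
    # partially downloaded NLTK data, not just a missing resource.
    for name in resources:
        try:
            nltk.data.find(f"tokenizers/{name}")
        except (LookupError, OSError):
            nltk.download(name)

ensure_nltk_data()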
attention_app/server/main.py CHANGED
@@ -149,12 +149,13 @@ def server(input, output, session):
             print(f"ERROR in compute_all: {e}")
             traceback.print_exc()
             cached_result.set(None)
-            await session.send_custom_message('stop_loading', {})
         finally:
             running.set(False)
 
 
 
+
+
     @output
     @render.ui
     def preview_text():
@@ -171,10 +172,11 @@ def server(input, output, session):
         att_received_norm = (attention_received - attention_received.min()) / (attention_received.max() - attention_received.min() + 1e-10)
         token_html = []
         for i, (tok, att_recv, recv_norm) in enumerate(zip(tokens, attention_received, att_received_norm)):
+            clean_tok = tok.replace("##", "").replace("Ġ", "")
             opacity = 0.2 + (recv_norm * 0.6)
             bg_color = f"rgba(59, 130, 246, {opacity})" # Keep blue for attention
-            tooltip = f"Token: {tok}\nAttention Received: {att_recv:.3f}"
-            token_html.append(f'<span class="token-viz" style="background:{bg_color};" title="{tooltip}">{tok}</span>')
+            tooltip = f"Token: {clean_tok}&#10;Attention Received: {att_recv:.3f}"
+            token_html.append(f'<span class="token-viz" style="background:{bg_color};" title="{tooltip}">{clean_tok}</span>')
         html = '<div class="token-viz-container">' + ''.join(token_html) + '</div>'
         legend_html = '''
         <div style="display:flex;gap:12px;margin-top:8px;font-size:9px;color:#6b7280;">
@@ -220,7 +222,7 @@ def server(input, output, session):
         try: tree_root_idx = int(input.tree_root_token())
         except: tree_root_idx = 0
 
-        clean_tokens = [t.replace("##", "") if t.startswith("##") else t for t in tokens]
+        clean_tokens = [t.replace("##", "") if t.startswith("##") else t.replace("Ġ", "") for t in tokens]
 
         return ui.div(
             {"class": "dashboard-stack gpt2-layout"},
@@ -813,8 +815,7 @@ def server(input, output, session):
         return ui.div(
             {"class": "card"},
             ui.h4("Global Attention Metrics"),
-            get_metrics_display(res),
-            ui.tags.script("$('#loading_spinner').hide(); $('#generate_all').prop('disabled', false).css('opacity', '1'); $('#dashboard-container').removeClass('content-hidden').addClass('content-visible');")
+            get_metrics_display(res)
         )
 
     def dashboard_layout_helper(is_gpt2, num_layers, num_heads, clean_tokens):
@@ -1034,7 +1035,7 @@ def server(input, output, session):
         res = cached_result.get()
         if not res: return []
         tokens = res[0]
-        return [t.replace("##", "") if t.startswith("##") else t for t in tokens]
+        return [t.replace("##", "") if t.startswith("##") else t.replace("Ġ", "") for t in tokens]
 
     @reactive.effect
     def update_selectors():
@@ -1053,7 +1054,7 @@ def server(input, output, session):
     def dashboard_content():
         config = current_layout_config.get()
         if not config:
-            return ui.HTML("<script>$('#loading_spinner').hide(); $('#generate_all').prop('disabled', false).css('opacity', '1');</script>")
+            return ui.HTML("<script>$('#generate_all').html('Generate All').prop('disabled', false).css('opacity', '1');</script>")
 
         is_gpt2, num_layers, num_heads = config
 
@@ -1087,8 +1088,11 @@ def server(input, output, session):
         y_flat = y.flatten().tolist()
         scores = np.nan_to_num(matrix.flatten(), nan=0.0).tolist()
 
+        # Clean tokens for display in hover_texts
+        cleaned_sentences = [s.replace("Ġ", "").replace("##", "") for s in sentences]
+
         hover_texts = [
-            f"Target ← {sentences[int(r)][:60]}...<br>Source → {sentences[int(c)][:60]}...<br>ISA = {s:.4f}"
+            f"Target ← {cleaned_sentences[int(r)][:60]}...<br>Source → {cleaned_sentences[int(c)][:60]}...<br>ISA = {s:.4f}"
             for r, c, s in zip(y_flat, x_flat, scores)
         ]
 
@@ -1130,7 +1134,7 @@ def server(input, output, session):
             customdata=customdata
         ))
 
-        labels = [s[:30] + "..." if len(s) > 30 else s for s in sentences]
+        labels = [s[:30].replace("Ġ", "").replace("##", "") + "..." if len(s) > 30 else s.replace("Ġ", "").replace("##", "") for s in sentences]
 
         fig.update_layout(
             xaxis=dict(
@@ -1171,10 +1175,24 @@ def server(input, output, session):
         # Generate HTML with unique ID
        plot_html = fig.to_html(include_plotlyjs='cdn', full_html=False, div_id="isa_scatter_plot", config={'displayModeBar': False})
 
-        # Custom JS to handle clicks and send to Shiny
+        # Custom JS to handle clicks, send to Shiny, AND stop loading state
+        # This is placed here because the ISA plot is the heaviest component.
+        # When this renders, we know the data is ready.
         js = """
        <script>
        (function() {
+            // Stop loading state (Button Reset)
+            var btn = $('#generate_all');
+            if (btn.data('original-content')) {
+                btn.html(btn.data('original-content'));
+            } else {
+                btn.html('Generate All');
+            }
+            btn.prop('disabled', false).css('opacity', '1');
+
+            // Show Dashboard
+            $('#dashboard-container').removeClass('content-hidden').addClass('content-visible');
+
            console.log("DEBUG: Initializing ISA Plot Script");
            function initPlot() {
                var plot = document.getElementById('isa_scatter_plot');
@@ -1251,8 +1269,9 @@ def server(input, output, session):
            attentions, tokens, target_idx, source_idx, boundaries
        )
 
-        toks_target = tokens_combined[:src_start]
-        toks_source = tokens_combined[src_start:]
+        # Clean tokens for display in the heatmap
+        toks_target = [t.replace("Ġ", "").replace("##", "") for t in tokens_combined[:src_start]]
+        toks_source = [t.replace("Ġ", "").replace("##", "") for t in tokens_combined[src_start:]]
 
        # Custom colorscale for heatmap (Light Blue -> Deep Blue/Purple)
        heatmap_colorscale = [
@@ -1429,7 +1448,9 @@ def server(input, output, session):
            [1.0, '#1e3a8a']
        ]
 
-        fig = px.imshow(att, x=tokens, y=tokens, color_continuous_scale=att_colorscale, aspect="auto")
+        # Clean tokens for display in the imshow plot
+        cleaned_tokens = [t.replace("##", "").replace("Ġ", "") for t in tokens]
+        fig = px.imshow(att, x=cleaned_tokens, y=cleaned_tokens, color_continuous_scale=att_colorscale, aspect="auto")
        fig.update_traces(customdata=custom, hovertemplate=hover)
        fig.update_layout(
            xaxis_title="Key (attending to)",
@@ -1507,6 +1528,8 @@ def server(input, output, session):
        block_width = 0.95 / n_tokens # Maximum spacing
 
        for i, tok in enumerate(tokens):
+            # Clean token for display
+            cleaned_tok = tok.replace("##", "").replace("Ġ", "")
            color = color_palette[i % len(color_palette)]
            x_pos = i / n_tokens + block_width / 2
            show_focus = focus_idx is not None
@@ -1521,15 +1544,15 @@ def server(input, output, session):
            font_size = 13 if is_selected else 10
 
            text_color = color if (show_focus and is_selected) else "#111827"
-            fig.add_trace(go.Scatter(x=[x_pos], y=[1.05], mode='text', text=tok, textfont=dict(size=font_size, color=text_color, family='monospace', weight='bold'), showlegend=False, hoverinfo='skip'))
-            fig.add_trace(go.Scatter(x=[x_pos], y=[-0.05], mode='text', text=tok, textfont=dict(size=font_size, color=text_color, family='monospace', weight='bold'), showlegend=False, hoverinfo='skip'))
+            fig.add_trace(go.Scatter(x=[x_pos], y=[1.05], mode='text', text=cleaned_tok, textfont=dict(size=font_size, color=text_color, family='monospace', weight='bold'), showlegend=False, hoverinfo='skip'))
+            fig.add_trace(go.Scatter(x=[x_pos], y=[-0.05], mode='text', text=cleaned_tok, textfont=dict(size=font_size, color=text_color, family='monospace', weight='bold'), showlegend=False, hoverinfo='skip'))
 
        threshold = 0.04
        for i in range(n_tokens):
            for j in range(n_tokens):
                weight = att[i, j]
                if weight > threshold:
-                    is_line_focused = (focus_idx is None) or (i == focus_idx)
+                    is_line_focused = (focus_idx is not None and i == focus_idx) or (focus_idx is None)
                    x_source = i / n_tokens + block_width / 2
                    x_target = j / n_tokens + block_width / 2
                    x_vals = [x_source, (x_source + x_target) / 2, x_target]
@@ -1542,12 +1565,19 @@ def server(input, output, session):
                        line_color = '#2a2a2a'
                        line_opacity = 0.003
                        line_width = 0.1
-                    fig.add_trace(go.Scatter(x=x_vals, y=y_vals, mode='lines', line=dict(color=line_color, width=line_width), opacity=line_opacity, showlegend=False, hoverinfo='text' if is_line_focused else 'skip', hovertext=f"<b>{tokens[i]} to {tokens[j]}</b><br>Attention: {weight:.4f}"))
+
+                    # Clean tokens for hovertext
+                    cleaned_token_i = tokens[i].replace("##", "").replace("Ġ", "")
+                    cleaned_token_j = tokens[j].replace("##", "").replace("Ġ", "")
+
+                    fig.add_trace(go.Scatter(x=x_vals, y=y_vals, mode='lines', line=dict(color=line_color, width=line_width), opacity=line_opacity, showlegend=False, hoverinfo='text' if is_line_focused else 'skip', hovertext=f"<b>{cleaned_token_i} to {cleaned_token_j}</b><br>Attention: {weight:.4f}"))
 
        title_text = ""
        if focus_idx is not None:
            focus_color = color_palette[focus_idx % len(color_palette)]
-            title_text += f" · <b style='color:{focus_color}'>Focused: '{tokens[focus_idx]}'</b>"
+            # Clean token for title
+            cleaned_focus_token = tokens[focus_idx].replace("##", "").replace("Ġ", "")
+            title_text += f" · <b style='color:{focus_color}'>Focused: '{cleaned_focus_token}'</b>"
 
        fig.update_layout(
            title=title_text,
@@ -1754,7 +1784,7 @@ def server(input, output, session):
            return None
 
        def build_node(current_idx, current_depth, current_value):
-            token = tokens[current_idx]
+            token = tokens[current_idx].replace("##", "").replace("Ġ", "")
            node = {
                "name": f"{current_idx}: {token}",
                "att": current_value,
attention_app/server/renderers.py CHANGED
@@ -86,9 +86,10 @@ def get_embedding_table(res):
         vec = embeddings[i]
         strip = array_to_base64_img(vec[:64], cmap="Blues", height=0.18)
         tip = "Embedding (first 32 dims): " + ", ".join(f"{v:.3f}" for v in vec[:32])
+        clean_tok = tok.replace("##", "").replace("Ġ", "")
         rows.append(
             f"<tr>"
-            f"<td class='token-name'>{tok}</td>"
+            f"<td class='token-name'>{clean_tok}</td>"
             f"<td><img class='heatmap' src='data:image/png;base64,{strip}' title='{tip}'></td>"
             f"</tr>"
         )
@@ -109,11 +110,12 @@ def get_segment_embedding_view(res):
 
     rows = ""
     for i, (tok, seg) in enumerate(zip(tokens, ids)):
+        clean_tok = tok.replace("##", "").replace("Ġ", "")
         row_class = f"seg-row-{seg}" if seg in [0, 1] else ""
         seg_label = "A" if seg == 0 else "B" if seg == 1 else str(seg)
         rows += f"""
         <tr class='{row_class}'>
-            <td class='token-cell'>{tok}</td>
+            <td class='token-cell'>{clean_tok}</td>
             <td class='segment-cell'>{seg_label}</td>
         </tr>
         """
@@ -142,11 +144,12 @@ def get_posenc_table(res):
     rows = []
     for i, tok in enumerate(tokens):
         pe = pos_enc[i]
+        clean_tok = tok.replace("##", "").replace("Ġ", "")
         strip = array_to_base64_img(pe[:64], cmap="Blues", height=0.18)
         tip = f"Position {i} encoding: " + ", ".join(f"{v:.3f}" for v in pe[:32])
         rows.append(
             f"<tr>"
-            f"<td class='token-name'>{tok}</td>"
+            f"<td class='token-name'>{clean_tok}</td>"
             f"<td><img class='heatmap' src='data:image/png;base64,{strip}' title='{tip}'></td>"
             f"</tr>"
         )
@@ -194,11 +197,12 @@ def get_sum_layernorm_view(res, encoder_model):
     norm_np = normalized[0].cpu().numpy()
     rows = []
     for i, tok in enumerate(tokens):
+        clean_tok = tok.replace("##", "").replace("Ġ", "")
         sum_strip = array_to_base64_img(summed_np[i][:96], "Blues", 0.15)
         norm_strip = array_to_base64_img(norm_np[i][:96], "Blues", 0.15)
         rows.append(
             "<tr>"
-            f"<td class='token-name'>{tok}</td>"
+            f"<td class='token-name'>{clean_tok}</td>"
             f"<td><img class='heatmap' src='data:image/png;base64,{sum_strip}' title='Sum of embeddings'></td>"
             f"<td><img class='heatmap' src='data:image/png;base64,{norm_strip}' title='LayerNorm output'></td>"
             "</tr>"
@@ -221,7 +225,7 @@ def get_qkv_table(res, layer_idx):
     cards = []
     for i, tok in enumerate(tokens):
         # Clean token for display
-        display_tok = tok.replace("##", "") if tok.startswith("##") else tok
+        display_tok = tok.replace("##", "").replace("Ġ", "")
 
         q_strip = array_to_base64_img(Q[i][:48], "Greens", 0.12)
         k_strip = array_to_base64_img(K[i][:48], "Oranges", 0.12)
@@ -293,9 +297,9 @@ def get_scaled_attention_view(res, layer_idx, head_idx, focus_idx):
             <div class='scaled-rank'>#{rank}</div>
             <div class='scaled-details'>
                 <div class='scaled-connection'>
-                    <span class='token-name' style='color:#ff5ca9;'>{tokens[focus_idx]}</span>
+                    <span class='token-name' style='color:#ff5ca9;'>{tokens[focus_idx].replace("##", "").replace("Ġ", "")}</span>
                     <span style='color:#94a3b8;margin:0 4px;'>→</span>
-                    <span class='token-name' style='color:#3b82f6;'>{tokens[j]}</span>
+                    <span class='token-name' style='color:#3b82f6;'>{tokens[j].replace("##", "").replace("Ġ", "")}</span>
                 </div>
                 <div class='scaled-values'>
                     <span class='scaled-step'>Q·K = <b>{dot:.2f}</b></span>
@@ -325,15 +329,16 @@ def get_add_norm_view(res, layer_idx):
     hs_out = hidden_states[layer_idx + 1][0].cpu().numpy()
     rows = []
     for i, tok in enumerate(tokens):
+        clean_tok = tok.replace("##", "").replace("Ġ", "")
         diff = np.linalg.norm(hs_out[i] - hs_in[i])
         norm = np.linalg.norm(hs_in[i]) + 1e-6
         ratio = diff / norm
         width = max(4, min(100, int(ratio * 80)))
         rows.append(
-            f"<tr><td class='token-name'>{tok}</td>"
+            f"<tr><td class='token-name'>{clean_tok}</td>"
             f"<td><div style='background:#e5e7eb;border-radius:999px;height:10px;' title='Change: {ratio:.1%}'>"
             f"<div style='width:{width}%;height:10px;border-radius:999px;"
-            f"background:linear-gradient(90deg,#22c55e,#22d3ee);'></div></div></td></tr>"
+            f"background:linear-gradient(90deg,#ff5ca9,#3b82f6);'></div></div></td></tr>"
         )
     return ui.HTML(
         "<div class='card-scroll'>"
@@ -363,11 +368,12 @@ def get_ffn_view(res, layer_idx):
     proj_np = proj.cpu().numpy()
     rows = []
     for i, tok in enumerate(tokens):
+        clean_tok = tok.replace("##", "").replace("Ġ", "")
         inter_strip = array_to_base64_img(inter_np[i][:96], "Blues", 0.15)
         proj_strip = array_to_base64_img(proj_np[i][:96], "Blues", 0.15)
         rows.append(
             "<tr>"
-            f"<td class='token-name'>{tok}</td>"
+            f"<td class='token-name'>{clean_tok}</td>"
             f"<td><img class='heatmap' src='data:image/png;base64,{inter_strip}' title='Intermediate 3072 dims'></td>"
             f"<td><img class='heatmap' src='data:image/png;base64,{proj_strip}' title='Projection back to 768 dims'></td>"
             "</tr>"
@@ -388,15 +394,16 @@ def get_add_norm_post_ffn_view(res, layer_idx):
     hs_out = hidden_states[layer_idx + 2][0].cpu().numpy()
     rows = []
     for i, tok in enumerate(tokens):
+        clean_tok = tok.replace("##", "").replace("Ġ", "")
         diff = np.linalg.norm(hs_out[i] - hs_mid[i])
         norm = np.linalg.norm(hs_mid[i]) + 1e-6
         ratio = diff / norm
         width = max(4, min(100, int(ratio * 80)))
         rows.append(
-            f"<tr><td class='token-name'>{tok}</td>"
+            f"<tr><td class='token-name'>{clean_tok}</td>"
             f"<td><div style='background:#e5e7eb;border-radius:999px;height:10px;' title='Change: {ratio:.1%}'>"
             f"<div style='width:{width}%;height:10px;border-radius:999px;"
-            f"background:linear-gradient(90deg,#14b8a6,#0ea5e9);'></div></div></td></tr>"
+            f"background:linear-gradient(90deg,#ff5ca9,#3b82f6);'></div></div></td></tr>"
         )
     return ui.HTML(
         "<div class='card-scroll'>"
@@ -414,6 +421,7 @@ def get_layer_output_view(res, layer_idx):
 
     rows = []
     for i, tok in enumerate(tokens):
+        clean_tok = tok.replace("##", "").replace("Ġ", "")
         vec_strip = array_to_base64_img(hs[i][:64], "Blues", 0.15)
         vec_tip = "Hidden state (first 32 dims): " + ", ".join(f"{v:.3f}" for v in hs[i][:32])
         mean_val = float(hs[i].mean())
@@ -422,7 +430,7 @@
 
         rows.append(f"""
         <tr>
-            <td class='token-name'>{tok}</td>
+            <td class='token-name'>{clean_tok}</td>
             <td><img class='heatmap' src='data:image/png;base64,{vec_strip}' title='{vec_tip}'></td>
             <td style='font-size:9px;color:#374151;white-space:nowrap;'>
                 μ={mean_val:.3f}, σ={std_val:.3f}, max={max_val:.3f}
@@ -470,6 +478,9 @@ def get_output_probabilities(res, use_mlm, text):
     top_k = 5
 
     for i, tok in enumerate(mlm_tokens):
+        # Clean token header
+        tok = tok.replace("##", "").replace("Ġ", "")
+        if not tok: tok = "&nbsp;"
         token_probs = probs[i]
         top_vals, top_idx = torch.topk(token_probs, top_k)
 
attention_app/ui/layouts.py CHANGED
@@ -56,11 +56,6 @@ attention_analysis_page = ui.page_fluid(
         ui.input_text_area("text_input", None, "All women are naturally nurturing and emotional. Men are logical and suited for leadership positions.", rows=6),
         ui.div(
             ui.input_action_button("generate_all", "Generate All", class_="btn-primary"),
-            ui.div(
-                {"id": "loading_spinner", "class": "loading-container", "style": "display:none;"},
-                ui.div({"class": "spinner"}),
-                ui.span("Processing...")
-            ),
         ),
     ),
 
attention_app/ui/scripts.py CHANGED
@@ -146,14 +146,23 @@ JS_INTERACTIVE = """
 
    // Custom message handlers
    Shiny.addCustomMessageHandler('start_loading', function(msg) {
-        $('#loading_spinner').css('display', 'flex');
-        $('#generate_all').prop('disabled', true).css('opacity', '0.7');
+        var btn = $('#generate_all');
+        if (!btn.data('original-content')) {
+            btn.data('original-content', btn.html());
+        }
+        btn.html('<div class="spinner" style="width:16px;height:16px;border-width:2px;display:inline-block;vertical-align:middle;margin-right:8px;"></div>Processing<span class="loading-dots"></span>');
+        btn.prop('disabled', true).css('opacity', '0.8');
        $('#dashboard-container').addClass('content-hidden').removeClass('content-visible');
    });
 
    Shiny.addCustomMessageHandler('stop_loading', function(msg) {
-        $('#loading_spinner').css('display', 'none');
-        $('#generate_all').prop('disabled', false).css('opacity', '1');
+        var btn = $('#generate_all');
+        if (btn.data('original-content')) {
+            btn.html(btn.data('original-content'));
+        } else {
+            btn.html('Generate All');
+        }
+        btn.prop('disabled', false).css('opacity', '1');
    });
 
    // Bias Loading Handlers
attention_app/ui/styles.py CHANGED
@@ -534,6 +534,25 @@ CSS = """
 
    @keyframes spin { to { transform: rotate(360deg); } }
 
+    /* Spinner inside primary button (needs to be white) */
+    .btn-primary .spinner {
+        border-color: rgba(255, 255, 255, 0.3);
+        border-top-color: white;
+    }
+
+    /* Loading Dots Animation */
+    .loading-dots:after {
+        content: '.';
+        animation: dots 1.5s steps(5, end) infinite;
+    }
+
+    @keyframes dots {
+        0%, 20% { content: '.'; }
+        40% { content: '..'; }
+        60% { content: '...'; }
+        80%, 100% { content: ''; }
+    }
+
    /* Tables */
    .token-table {
        width: 100%;