Spaces:
Running
Running
Fix sidebar and GPT2 tokens
Browse files- attention_app/isa.py +3 -2
- attention_app/server/main.py +50 -20
- attention_app/server/renderers.py +24 -13
- attention_app/ui/layouts.py +0 -5
- attention_app/ui/scripts.py +13 -4
- attention_app/ui/styles.py +19 -0
attention_app/isa.py
CHANGED
|
@@ -3,15 +3,16 @@ import numpy as np
|
|
| 3 |
import nltk
|
| 4 |
from typing import List, Dict, Tuple, Optional
|
| 5 |
|
|
|
|
| 6 |
# Ensure nltk data is downloaded
|
| 7 |
try:
|
| 8 |
nltk.data.find('tokenizers/punkt')
|
| 9 |
-
except LookupError:
|
| 10 |
nltk.download('punkt')
|
| 11 |
|
| 12 |
try:
|
| 13 |
nltk.data.find('tokenizers/punkt_tab')
|
| 14 |
-
except LookupError:
|
| 15 |
nltk.download('punkt_tab')
|
| 16 |
|
| 17 |
def get_sentence_boundaries(text: str, tokens: List[str], tokenizer, inputs) -> Tuple[List[str], List[int]]:
|
|
|
|
| 3 |
import nltk
|
| 4 |
from typing import List, Dict, Tuple, Optional
|
| 5 |
|
| 6 |
+
# Ensure nltk data is downloaded
|
| 7 |
# Ensure nltk data is downloaded
|
| 8 |
try:
|
| 9 |
nltk.data.find('tokenizers/punkt')
|
| 10 |
+
except (LookupError, OSError):
|
| 11 |
nltk.download('punkt')
|
| 12 |
|
| 13 |
try:
|
| 14 |
nltk.data.find('tokenizers/punkt_tab')
|
| 15 |
+
except (LookupError, OSError):
|
| 16 |
nltk.download('punkt_tab')
|
| 17 |
|
| 18 |
def get_sentence_boundaries(text: str, tokens: List[str], tokenizer, inputs) -> Tuple[List[str], List[int]]:
|
attention_app/server/main.py
CHANGED
|
@@ -149,12 +149,13 @@ def server(input, output, session):
|
|
| 149 |
print(f"ERROR in compute_all: {e}")
|
| 150 |
traceback.print_exc()
|
| 151 |
cached_result.set(None)
|
| 152 |
-
await session.send_custom_message('stop_loading', {})
|
| 153 |
finally:
|
| 154 |
running.set(False)
|
| 155 |
|
| 156 |
|
| 157 |
|
|
|
|
|
|
|
| 158 |
@output
|
| 159 |
@render.ui
|
| 160 |
def preview_text():
|
|
@@ -171,10 +172,11 @@ def server(input, output, session):
|
|
| 171 |
att_received_norm = (attention_received - attention_received.min()) / (attention_received.max() - attention_received.min() + 1e-10)
|
| 172 |
token_html = []
|
| 173 |
for i, (tok, att_recv, recv_norm) in enumerate(zip(tokens, attention_received, att_received_norm)):
|
|
|
|
| 174 |
opacity = 0.2 + (recv_norm * 0.6)
|
| 175 |
bg_color = f"rgba(59, 130, 246, {opacity})" # Keep blue for attention
|
| 176 |
-
tooltip = f"Token: {
|
| 177 |
-
token_html.append(f'<span class="token-viz" style="background:{bg_color};" title="{tooltip}">{
|
| 178 |
html = '<div class="token-viz-container">' + ''.join(token_html) + '</div>'
|
| 179 |
legend_html = '''
|
| 180 |
<div style="display:flex;gap:12px;margin-top:8px;font-size:9px;color:#6b7280;">
|
|
@@ -220,7 +222,7 @@ def server(input, output, session):
|
|
| 220 |
try: tree_root_idx = int(input.tree_root_token())
|
| 221 |
except: tree_root_idx = 0
|
| 222 |
|
| 223 |
-
clean_tokens = [t.replace("##", "") if t.startswith("##") else t for t in tokens]
|
| 224 |
|
| 225 |
return ui.div(
|
| 226 |
{"class": "dashboard-stack gpt2-layout"},
|
|
@@ -813,8 +815,7 @@ def server(input, output, session):
|
|
| 813 |
return ui.div(
|
| 814 |
{"class": "card"},
|
| 815 |
ui.h4("Global Attention Metrics"),
|
| 816 |
-
get_metrics_display(res)
|
| 817 |
-
ui.tags.script("$('#loading_spinner').hide(); $('#generate_all').prop('disabled', false).css('opacity', '1'); $('#dashboard-container').removeClass('content-hidden').addClass('content-visible');")
|
| 818 |
)
|
| 819 |
|
| 820 |
def dashboard_layout_helper(is_gpt2, num_layers, num_heads, clean_tokens):
|
|
@@ -1034,7 +1035,7 @@ def server(input, output, session):
|
|
| 1034 |
res = cached_result.get()
|
| 1035 |
if not res: return []
|
| 1036 |
tokens = res[0]
|
| 1037 |
-
return [t.replace("##", "") if t.startswith("##") else t for t in tokens]
|
| 1038 |
|
| 1039 |
@reactive.effect
|
| 1040 |
def update_selectors():
|
|
@@ -1053,7 +1054,7 @@ def server(input, output, session):
|
|
| 1053 |
def dashboard_content():
|
| 1054 |
config = current_layout_config.get()
|
| 1055 |
if not config:
|
| 1056 |
-
return ui.HTML("<script>$('#
|
| 1057 |
|
| 1058 |
is_gpt2, num_layers, num_heads = config
|
| 1059 |
|
|
@@ -1087,8 +1088,11 @@ def server(input, output, session):
|
|
| 1087 |
y_flat = y.flatten().tolist()
|
| 1088 |
scores = np.nan_to_num(matrix.flatten(), nan=0.0).tolist()
|
| 1089 |
|
|
|
|
|
|
|
|
|
|
| 1090 |
hover_texts = [
|
| 1091 |
-
f"Target ← {
|
| 1092 |
for r, c, s in zip(y_flat, x_flat, scores)
|
| 1093 |
]
|
| 1094 |
|
|
@@ -1130,7 +1134,7 @@ def server(input, output, session):
|
|
| 1130 |
customdata=customdata
|
| 1131 |
))
|
| 1132 |
|
| 1133 |
-
labels = [s[:30] + "..." if len(s) > 30 else s for s in sentences]
|
| 1134 |
|
| 1135 |
fig.update_layout(
|
| 1136 |
xaxis=dict(
|
|
@@ -1171,10 +1175,24 @@ def server(input, output, session):
|
|
| 1171 |
# Generate HTML with unique ID
|
| 1172 |
plot_html = fig.to_html(include_plotlyjs='cdn', full_html=False, div_id="isa_scatter_plot", config={'displayModeBar': False})
|
| 1173 |
|
| 1174 |
-
# Custom JS to handle clicks
|
|
|
|
|
|
|
| 1175 |
js = """
|
| 1176 |
<script>
|
| 1177 |
(function() {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1178 |
console.log("DEBUG: Initializing ISA Plot Script");
|
| 1179 |
function initPlot() {
|
| 1180 |
var plot = document.getElementById('isa_scatter_plot');
|
|
@@ -1251,8 +1269,9 @@ def server(input, output, session):
|
|
| 1251 |
attentions, tokens, target_idx, source_idx, boundaries
|
| 1252 |
)
|
| 1253 |
|
| 1254 |
-
|
| 1255 |
-
|
|
|
|
| 1256 |
|
| 1257 |
# Custom colorscale for heatmap (Light Blue -> Deep Blue/Purple)
|
| 1258 |
heatmap_colorscale = [
|
|
@@ -1429,7 +1448,9 @@ def server(input, output, session):
|
|
| 1429 |
[1.0, '#1e3a8a']
|
| 1430 |
]
|
| 1431 |
|
| 1432 |
-
|
|
|
|
|
|
|
| 1433 |
fig.update_traces(customdata=custom, hovertemplate=hover)
|
| 1434 |
fig.update_layout(
|
| 1435 |
xaxis_title="Key (attending to)",
|
|
@@ -1507,6 +1528,8 @@ def server(input, output, session):
|
|
| 1507 |
block_width = 0.95 / n_tokens # Maximum spacing
|
| 1508 |
|
| 1509 |
for i, tok in enumerate(tokens):
|
|
|
|
|
|
|
| 1510 |
color = color_palette[i % len(color_palette)]
|
| 1511 |
x_pos = i / n_tokens + block_width / 2
|
| 1512 |
show_focus = focus_idx is not None
|
|
@@ -1521,15 +1544,15 @@ def server(input, output, session):
|
|
| 1521 |
font_size = 13 if is_selected else 10
|
| 1522 |
|
| 1523 |
text_color = color if (show_focus and is_selected) else "#111827"
|
| 1524 |
-
fig.add_trace(go.Scatter(x=[x_pos], y=[1.05], mode='text', text=
|
| 1525 |
-
fig.add_trace(go.Scatter(x=[x_pos], y=[-0.05], mode='text', text=
|
| 1526 |
|
| 1527 |
threshold = 0.04
|
| 1528 |
for i in range(n_tokens):
|
| 1529 |
for j in range(n_tokens):
|
| 1530 |
weight = att[i, j]
|
| 1531 |
if weight > threshold:
|
| 1532 |
-
is_line_focused = (focus_idx is None
|
| 1533 |
x_source = i / n_tokens + block_width / 2
|
| 1534 |
x_target = j / n_tokens + block_width / 2
|
| 1535 |
x_vals = [x_source, (x_source + x_target) / 2, x_target]
|
|
@@ -1542,12 +1565,19 @@ def server(input, output, session):
|
|
| 1542 |
line_color = '#2a2a2a'
|
| 1543 |
line_opacity = 0.003
|
| 1544 |
line_width = 0.1
|
| 1545 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1546 |
|
| 1547 |
title_text = ""
|
| 1548 |
if focus_idx is not None:
|
| 1549 |
focus_color = color_palette[focus_idx % len(color_palette)]
|
| 1550 |
-
|
|
|
|
|
|
|
| 1551 |
|
| 1552 |
fig.update_layout(
|
| 1553 |
title=title_text,
|
|
@@ -1754,7 +1784,7 @@ def server(input, output, session):
|
|
| 1754 |
return None
|
| 1755 |
|
| 1756 |
def build_node(current_idx, current_depth, current_value):
|
| 1757 |
-
token = tokens[current_idx]
|
| 1758 |
node = {
|
| 1759 |
"name": f"{current_idx}: {token}",
|
| 1760 |
"att": current_value,
|
|
|
|
| 149 |
print(f"ERROR in compute_all: {e}")
|
| 150 |
traceback.print_exc()
|
| 151 |
cached_result.set(None)
|
|
|
|
| 152 |
finally:
|
| 153 |
running.set(False)
|
| 154 |
|
| 155 |
|
| 156 |
|
| 157 |
+
|
| 158 |
+
|
| 159 |
@output
|
| 160 |
@render.ui
|
| 161 |
def preview_text():
|
|
|
|
| 172 |
att_received_norm = (attention_received - attention_received.min()) / (attention_received.max() - attention_received.min() + 1e-10)
|
| 173 |
token_html = []
|
| 174 |
for i, (tok, att_recv, recv_norm) in enumerate(zip(tokens, attention_received, att_received_norm)):
|
| 175 |
+
clean_tok = tok.replace("##", "").replace("Ġ", "")
|
| 176 |
opacity = 0.2 + (recv_norm * 0.6)
|
| 177 |
bg_color = f"rgba(59, 130, 246, {opacity})" # Keep blue for attention
|
| 178 |
+
tooltip = f"Token: {clean_tok} Attention Received: {att_recv:.3f}"
|
| 179 |
+
token_html.append(f'<span class="token-viz" style="background:{bg_color};" title="{tooltip}">{clean_tok}</span>')
|
| 180 |
html = '<div class="token-viz-container">' + ''.join(token_html) + '</div>'
|
| 181 |
legend_html = '''
|
| 182 |
<div style="display:flex;gap:12px;margin-top:8px;font-size:9px;color:#6b7280;">
|
|
|
|
| 222 |
try: tree_root_idx = int(input.tree_root_token())
|
| 223 |
except: tree_root_idx = 0
|
| 224 |
|
| 225 |
+
clean_tokens = [t.replace("##", "") if t.startswith("##") else t.replace("Ġ", "") for t in tokens]
|
| 226 |
|
| 227 |
return ui.div(
|
| 228 |
{"class": "dashboard-stack gpt2-layout"},
|
|
|
|
| 815 |
return ui.div(
|
| 816 |
{"class": "card"},
|
| 817 |
ui.h4("Global Attention Metrics"),
|
| 818 |
+
get_metrics_display(res)
|
|
|
|
| 819 |
)
|
| 820 |
|
| 821 |
def dashboard_layout_helper(is_gpt2, num_layers, num_heads, clean_tokens):
|
|
|
|
| 1035 |
res = cached_result.get()
|
| 1036 |
if not res: return []
|
| 1037 |
tokens = res[0]
|
| 1038 |
+
return [t.replace("##", "") if t.startswith("##") else t.replace("Ġ", "") for t in tokens]
|
| 1039 |
|
| 1040 |
@reactive.effect
|
| 1041 |
def update_selectors():
|
|
|
|
| 1054 |
def dashboard_content():
|
| 1055 |
config = current_layout_config.get()
|
| 1056 |
if not config:
|
| 1057 |
+
return ui.HTML("<script>$('#generate_all').html('Generate All').prop('disabled', false).css('opacity', '1');</script>")
|
| 1058 |
|
| 1059 |
is_gpt2, num_layers, num_heads = config
|
| 1060 |
|
|
|
|
| 1088 |
y_flat = y.flatten().tolist()
|
| 1089 |
scores = np.nan_to_num(matrix.flatten(), nan=0.0).tolist()
|
| 1090 |
|
| 1091 |
+
# Clean tokens for display in hover_texts
|
| 1092 |
+
cleaned_sentences = [s.replace("Ġ", "").replace("##", "") for s in sentences]
|
| 1093 |
+
|
| 1094 |
hover_texts = [
|
| 1095 |
+
f"Target ← {cleaned_sentences[int(r)][:60]}...<br>Source → {cleaned_sentences[int(c)][:60]}...<br>ISA = {s:.4f}"
|
| 1096 |
for r, c, s in zip(y_flat, x_flat, scores)
|
| 1097 |
]
|
| 1098 |
|
|
|
|
| 1134 |
customdata=customdata
|
| 1135 |
))
|
| 1136 |
|
| 1137 |
+
labels = [s[:30].replace("Ġ", "").replace("##", "") + "..." if len(s) > 30 else s.replace("Ġ", "").replace("##", "") for s in sentences]
|
| 1138 |
|
| 1139 |
fig.update_layout(
|
| 1140 |
xaxis=dict(
|
|
|
|
| 1175 |
# Generate HTML with unique ID
|
| 1176 |
plot_html = fig.to_html(include_plotlyjs='cdn', full_html=False, div_id="isa_scatter_plot", config={'displayModeBar': False})
|
| 1177 |
|
| 1178 |
+
# Custom JS to handle clicks, send to Shiny, AND stop loading state
|
| 1179 |
+
# This is placed here because the ISA plot is the heaviest component.
|
| 1180 |
+
# When this renders, we know the data is ready.
|
| 1181 |
js = """
|
| 1182 |
<script>
|
| 1183 |
(function() {
|
| 1184 |
+
// Stop loading state (Button Reset)
|
| 1185 |
+
var btn = $('#generate_all');
|
| 1186 |
+
if (btn.data('original-content')) {
|
| 1187 |
+
btn.html(btn.data('original-content'));
|
| 1188 |
+
} else {
|
| 1189 |
+
btn.html('Generate All');
|
| 1190 |
+
}
|
| 1191 |
+
btn.prop('disabled', false).css('opacity', '1');
|
| 1192 |
+
|
| 1193 |
+
// Show Dashboard
|
| 1194 |
+
$('#dashboard-container').removeClass('content-hidden').addClass('content-visible');
|
| 1195 |
+
|
| 1196 |
console.log("DEBUG: Initializing ISA Plot Script");
|
| 1197 |
function initPlot() {
|
| 1198 |
var plot = document.getElementById('isa_scatter_plot');
|
|
|
|
| 1269 |
attentions, tokens, target_idx, source_idx, boundaries
|
| 1270 |
)
|
| 1271 |
|
| 1272 |
+
# Clean tokens for display in the heatmap
|
| 1273 |
+
toks_target = [t.replace("Ġ", "").replace("##", "") for t in tokens_combined[:src_start]]
|
| 1274 |
+
toks_source = [t.replace("Ġ", "").replace("##", "") for t in tokens_combined[src_start:]]
|
| 1275 |
|
| 1276 |
# Custom colorscale for heatmap (Light Blue -> Deep Blue/Purple)
|
| 1277 |
heatmap_colorscale = [
|
|
|
|
| 1448 |
[1.0, '#1e3a8a']
|
| 1449 |
]
|
| 1450 |
|
| 1451 |
+
# Clean tokens for display in the imshow plot
|
| 1452 |
+
cleaned_tokens = [t.replace("##", "").replace("Ġ", "") for t in tokens]
|
| 1453 |
+
fig = px.imshow(att, x=cleaned_tokens, y=cleaned_tokens, color_continuous_scale=att_colorscale, aspect="auto")
|
| 1454 |
fig.update_traces(customdata=custom, hovertemplate=hover)
|
| 1455 |
fig.update_layout(
|
| 1456 |
xaxis_title="Key (attending to)",
|
|
|
|
| 1528 |
block_width = 0.95 / n_tokens # Maximum spacing
|
| 1529 |
|
| 1530 |
for i, tok in enumerate(tokens):
|
| 1531 |
+
# Clean token for display
|
| 1532 |
+
cleaned_tok = tok.replace("##", "").replace("Ġ", "")
|
| 1533 |
color = color_palette[i % len(color_palette)]
|
| 1534 |
x_pos = i / n_tokens + block_width / 2
|
| 1535 |
show_focus = focus_idx is not None
|
|
|
|
| 1544 |
font_size = 13 if is_selected else 10
|
| 1545 |
|
| 1546 |
text_color = color if (show_focus and is_selected) else "#111827"
|
| 1547 |
+
fig.add_trace(go.Scatter(x=[x_pos], y=[1.05], mode='text', text=cleaned_tok, textfont=dict(size=font_size, color=text_color, family='monospace', weight='bold'), showlegend=False, hoverinfo='skip'))
|
| 1548 |
+
fig.add_trace(go.Scatter(x=[x_pos], y=[-0.05], mode='text', text=cleaned_tok, textfont=dict(size=font_size, color=text_color, family='monospace', weight='bold'), showlegend=False, hoverinfo='skip'))
|
| 1549 |
|
| 1550 |
threshold = 0.04
|
| 1551 |
for i in range(n_tokens):
|
| 1552 |
for j in range(n_tokens):
|
| 1553 |
weight = att[i, j]
|
| 1554 |
if weight > threshold:
|
| 1555 |
+
is_line_focused = (focus_idx is not None and i == focus_idx) or (focus_idx is None)
|
| 1556 |
x_source = i / n_tokens + block_width / 2
|
| 1557 |
x_target = j / n_tokens + block_width / 2
|
| 1558 |
x_vals = [x_source, (x_source + x_target) / 2, x_target]
|
|
|
|
| 1565 |
line_color = '#2a2a2a'
|
| 1566 |
line_opacity = 0.003
|
| 1567 |
line_width = 0.1
|
| 1568 |
+
|
| 1569 |
+
# Clean tokens for hovertext
|
| 1570 |
+
cleaned_token_i = tokens[i].replace("##", "").replace("Ġ", "")
|
| 1571 |
+
cleaned_token_j = tokens[j].replace("##", "").replace("Ġ", "")
|
| 1572 |
+
|
| 1573 |
+
fig.add_trace(go.Scatter(x=x_vals, y=y_vals, mode='lines', line=dict(color=line_color, width=line_width), opacity=line_opacity, showlegend=False, hoverinfo='text' if is_line_focused else 'skip', hovertext=f"<b>{cleaned_token_i} to {cleaned_token_j}</b><br>Attention: {weight:.4f}"))
|
| 1574 |
|
| 1575 |
title_text = ""
|
| 1576 |
if focus_idx is not None:
|
| 1577 |
focus_color = color_palette[focus_idx % len(color_palette)]
|
| 1578 |
+
# Clean token for title
|
| 1579 |
+
cleaned_focus_token = tokens[focus_idx].replace("##", "").replace("Ġ", "")
|
| 1580 |
+
title_text += f" · <b style='color:{focus_color}'>Focused: '{cleaned_focus_token}'</b>"
|
| 1581 |
|
| 1582 |
fig.update_layout(
|
| 1583 |
title=title_text,
|
|
|
|
| 1784 |
return None
|
| 1785 |
|
| 1786 |
def build_node(current_idx, current_depth, current_value):
|
| 1787 |
+
token = tokens[current_idx].replace("##", "").replace("Ġ", "")
|
| 1788 |
node = {
|
| 1789 |
"name": f"{current_idx}: {token}",
|
| 1790 |
"att": current_value,
|
attention_app/server/renderers.py
CHANGED
|
@@ -86,9 +86,10 @@ def get_embedding_table(res):
|
|
| 86 |
vec = embeddings[i]
|
| 87 |
strip = array_to_base64_img(vec[:64], cmap="Blues", height=0.18)
|
| 88 |
tip = "Embedding (first 32 dims): " + ", ".join(f"{v:.3f}" for v in vec[:32])
|
|
|
|
| 89 |
rows.append(
|
| 90 |
f"<tr>"
|
| 91 |
-
f"<td class='token-name'>{
|
| 92 |
f"<td><img class='heatmap' src='data:image/png;base64,{strip}' title='{tip}'></td>"
|
| 93 |
f"</tr>"
|
| 94 |
)
|
|
@@ -109,11 +110,12 @@ def get_segment_embedding_view(res):
|
|
| 109 |
|
| 110 |
rows = ""
|
| 111 |
for i, (tok, seg) in enumerate(zip(tokens, ids)):
|
|
|
|
| 112 |
row_class = f"seg-row-{seg}" if seg in [0, 1] else ""
|
| 113 |
seg_label = "A" if seg == 0 else "B" if seg == 1 else str(seg)
|
| 114 |
rows += f"""
|
| 115 |
<tr class='{row_class}'>
|
| 116 |
-
<td class='token-cell'>{
|
| 117 |
<td class='segment-cell'>{seg_label}</td>
|
| 118 |
</tr>
|
| 119 |
"""
|
|
@@ -142,11 +144,12 @@ def get_posenc_table(res):
|
|
| 142 |
rows = []
|
| 143 |
for i, tok in enumerate(tokens):
|
| 144 |
pe = pos_enc[i]
|
|
|
|
| 145 |
strip = array_to_base64_img(pe[:64], cmap="Blues", height=0.18)
|
| 146 |
tip = f"Position {i} encoding: " + ", ".join(f"{v:.3f}" for v in pe[:32])
|
| 147 |
rows.append(
|
| 148 |
f"<tr>"
|
| 149 |
-
f"<td class='token-name'>{
|
| 150 |
f"<td><img class='heatmap' src='data:image/png;base64,{strip}' title='{tip}'></td>"
|
| 151 |
f"</tr>"
|
| 152 |
)
|
|
@@ -194,11 +197,12 @@ def get_sum_layernorm_view(res, encoder_model):
|
|
| 194 |
norm_np = normalized[0].cpu().numpy()
|
| 195 |
rows = []
|
| 196 |
for i, tok in enumerate(tokens):
|
|
|
|
| 197 |
sum_strip = array_to_base64_img(summed_np[i][:96], "Blues", 0.15)
|
| 198 |
norm_strip = array_to_base64_img(norm_np[i][:96], "Blues", 0.15)
|
| 199 |
rows.append(
|
| 200 |
"<tr>"
|
| 201 |
-
f"<td class='token-name'>{
|
| 202 |
f"<td><img class='heatmap' src='data:image/png;base64,{sum_strip}' title='Sum of embeddings'></td>"
|
| 203 |
f"<td><img class='heatmap' src='data:image/png;base64,{norm_strip}' title='LayerNorm output'></td>"
|
| 204 |
"</tr>"
|
|
@@ -221,7 +225,7 @@ def get_qkv_table(res, layer_idx):
|
|
| 221 |
cards = []
|
| 222 |
for i, tok in enumerate(tokens):
|
| 223 |
# Clean token for display
|
| 224 |
-
display_tok = tok.replace("##", "")
|
| 225 |
|
| 226 |
q_strip = array_to_base64_img(Q[i][:48], "Greens", 0.12)
|
| 227 |
k_strip = array_to_base64_img(K[i][:48], "Oranges", 0.12)
|
|
@@ -293,9 +297,9 @@ def get_scaled_attention_view(res, layer_idx, head_idx, focus_idx):
|
|
| 293 |
<div class='scaled-rank'>#{rank}</div>
|
| 294 |
<div class='scaled-details'>
|
| 295 |
<div class='scaled-connection'>
|
| 296 |
-
<span class='token-name' style='color:#ff5ca9;'>{tokens[focus_idx]}</span>
|
| 297 |
<span style='color:#94a3b8;margin:0 4px;'>→</span>
|
| 298 |
-
<span class='token-name' style='color:#3b82f6;'>{tokens[j]}</span>
|
| 299 |
</div>
|
| 300 |
<div class='scaled-values'>
|
| 301 |
<span class='scaled-step'>Q·K = <b>{dot:.2f}</b></span>
|
|
@@ -325,15 +329,16 @@ def get_add_norm_view(res, layer_idx):
|
|
| 325 |
hs_out = hidden_states[layer_idx + 1][0].cpu().numpy()
|
| 326 |
rows = []
|
| 327 |
for i, tok in enumerate(tokens):
|
|
|
|
| 328 |
diff = np.linalg.norm(hs_out[i] - hs_in[i])
|
| 329 |
norm = np.linalg.norm(hs_in[i]) + 1e-6
|
| 330 |
ratio = diff / norm
|
| 331 |
width = max(4, min(100, int(ratio * 80)))
|
| 332 |
rows.append(
|
| 333 |
-
f"<tr><td class='token-name'>{
|
| 334 |
f"<td><div style='background:#e5e7eb;border-radius:999px;height:10px;' title='Change: {ratio:.1%}'>"
|
| 335 |
f"<div style='width:{width}%;height:10px;border-radius:999px;"
|
| 336 |
-
f"background:linear-gradient(90deg,#
|
| 337 |
)
|
| 338 |
return ui.HTML(
|
| 339 |
"<div class='card-scroll'>"
|
|
@@ -363,11 +368,12 @@ def get_ffn_view(res, layer_idx):
|
|
| 363 |
proj_np = proj.cpu().numpy()
|
| 364 |
rows = []
|
| 365 |
for i, tok in enumerate(tokens):
|
|
|
|
| 366 |
inter_strip = array_to_base64_img(inter_np[i][:96], "Blues", 0.15)
|
| 367 |
proj_strip = array_to_base64_img(proj_np[i][:96], "Blues", 0.15)
|
| 368 |
rows.append(
|
| 369 |
"<tr>"
|
| 370 |
-
f"<td class='token-name'>{
|
| 371 |
f"<td><img class='heatmap' src='data:image/png;base64,{inter_strip}' title='Intermediate 3072 dims'></td>"
|
| 372 |
f"<td><img class='heatmap' src='data:image/png;base64,{proj_strip}' title='Projection back to 768 dims'></td>"
|
| 373 |
"</tr>"
|
|
@@ -388,15 +394,16 @@ def get_add_norm_post_ffn_view(res, layer_idx):
|
|
| 388 |
hs_out = hidden_states[layer_idx + 2][0].cpu().numpy()
|
| 389 |
rows = []
|
| 390 |
for i, tok in enumerate(tokens):
|
|
|
|
| 391 |
diff = np.linalg.norm(hs_out[i] - hs_mid[i])
|
| 392 |
norm = np.linalg.norm(hs_mid[i]) + 1e-6
|
| 393 |
ratio = diff / norm
|
| 394 |
width = max(4, min(100, int(ratio * 80)))
|
| 395 |
rows.append(
|
| 396 |
-
f"<tr><td class='token-name'>{
|
| 397 |
f"<td><div style='background:#e5e7eb;border-radius:999px;height:10px;' title='Change: {ratio:.1%}'>"
|
| 398 |
f"<div style='width:{width}%;height:10px;border-radius:999px;"
|
| 399 |
-
f"background:linear-gradient(90deg,#
|
| 400 |
)
|
| 401 |
return ui.HTML(
|
| 402 |
"<div class='card-scroll'>"
|
|
@@ -414,6 +421,7 @@ def get_layer_output_view(res, layer_idx):
|
|
| 414 |
|
| 415 |
rows = []
|
| 416 |
for i, tok in enumerate(tokens):
|
|
|
|
| 417 |
vec_strip = array_to_base64_img(hs[i][:64], "Blues", 0.15)
|
| 418 |
vec_tip = "Hidden state (first 32 dims): " + ", ".join(f"{v:.3f}" for v in hs[i][:32])
|
| 419 |
mean_val = float(hs[i].mean())
|
|
@@ -422,7 +430,7 @@ def get_layer_output_view(res, layer_idx):
|
|
| 422 |
|
| 423 |
rows.append(f"""
|
| 424 |
<tr>
|
| 425 |
-
<td class='token-name'>{
|
| 426 |
<td><img class='heatmap' src='data:image/png;base64,{vec_strip}' title='{vec_tip}'></td>
|
| 427 |
<td style='font-size:9px;color:#374151;white-space:nowrap;'>
|
| 428 |
μ={mean_val:.3f}, σ={std_val:.3f}, max={max_val:.3f}
|
|
@@ -470,6 +478,9 @@ def get_output_probabilities(res, use_mlm, text):
|
|
| 470 |
top_k = 5
|
| 471 |
|
| 472 |
for i, tok in enumerate(mlm_tokens):
|
|
|
|
|
|
|
|
|
|
| 473 |
token_probs = probs[i]
|
| 474 |
top_vals, top_idx = torch.topk(token_probs, top_k)
|
| 475 |
|
|
|
|
| 86 |
vec = embeddings[i]
|
| 87 |
strip = array_to_base64_img(vec[:64], cmap="Blues", height=0.18)
|
| 88 |
tip = "Embedding (first 32 dims): " + ", ".join(f"{v:.3f}" for v in vec[:32])
|
| 89 |
+
clean_tok = tok.replace("##", "").replace("Ġ", "")
|
| 90 |
rows.append(
|
| 91 |
f"<tr>"
|
| 92 |
+
f"<td class='token-name'>{clean_tok}</td>"
|
| 93 |
f"<td><img class='heatmap' src='data:image/png;base64,{strip}' title='{tip}'></td>"
|
| 94 |
f"</tr>"
|
| 95 |
)
|
|
|
|
| 110 |
|
| 111 |
rows = ""
|
| 112 |
for i, (tok, seg) in enumerate(zip(tokens, ids)):
|
| 113 |
+
clean_tok = tok.replace("##", "").replace("Ġ", "")
|
| 114 |
row_class = f"seg-row-{seg}" if seg in [0, 1] else ""
|
| 115 |
seg_label = "A" if seg == 0 else "B" if seg == 1 else str(seg)
|
| 116 |
rows += f"""
|
| 117 |
<tr class='{row_class}'>
|
| 118 |
+
<td class='token-cell'>{clean_tok}</td>
|
| 119 |
<td class='segment-cell'>{seg_label}</td>
|
| 120 |
</tr>
|
| 121 |
"""
|
|
|
|
| 144 |
rows = []
|
| 145 |
for i, tok in enumerate(tokens):
|
| 146 |
pe = pos_enc[i]
|
| 147 |
+
clean_tok = tok.replace("##", "").replace("Ġ", "")
|
| 148 |
strip = array_to_base64_img(pe[:64], cmap="Blues", height=0.18)
|
| 149 |
tip = f"Position {i} encoding: " + ", ".join(f"{v:.3f}" for v in pe[:32])
|
| 150 |
rows.append(
|
| 151 |
f"<tr>"
|
| 152 |
+
f"<td class='token-name'>{clean_tok}</td>"
|
| 153 |
f"<td><img class='heatmap' src='data:image/png;base64,{strip}' title='{tip}'></td>"
|
| 154 |
f"</tr>"
|
| 155 |
)
|
|
|
|
| 197 |
norm_np = normalized[0].cpu().numpy()
|
| 198 |
rows = []
|
| 199 |
for i, tok in enumerate(tokens):
|
| 200 |
+
clean_tok = tok.replace("##", "").replace("Ġ", "")
|
| 201 |
sum_strip = array_to_base64_img(summed_np[i][:96], "Blues", 0.15)
|
| 202 |
norm_strip = array_to_base64_img(norm_np[i][:96], "Blues", 0.15)
|
| 203 |
rows.append(
|
| 204 |
"<tr>"
|
| 205 |
+
f"<td class='token-name'>{clean_tok}</td>"
|
| 206 |
f"<td><img class='heatmap' src='data:image/png;base64,{sum_strip}' title='Sum of embeddings'></td>"
|
| 207 |
f"<td><img class='heatmap' src='data:image/png;base64,{norm_strip}' title='LayerNorm output'></td>"
|
| 208 |
"</tr>"
|
|
|
|
| 225 |
cards = []
|
| 226 |
for i, tok in enumerate(tokens):
|
| 227 |
# Clean token for display
|
| 228 |
+
display_tok = tok.replace("##", "").replace("Ġ", "")
|
| 229 |
|
| 230 |
q_strip = array_to_base64_img(Q[i][:48], "Greens", 0.12)
|
| 231 |
k_strip = array_to_base64_img(K[i][:48], "Oranges", 0.12)
|
|
|
|
| 297 |
<div class='scaled-rank'>#{rank}</div>
|
| 298 |
<div class='scaled-details'>
|
| 299 |
<div class='scaled-connection'>
|
| 300 |
+
<span class='token-name' style='color:#ff5ca9;'>{tokens[focus_idx].replace("##", "").replace("Ġ", "")}</span>
|
| 301 |
<span style='color:#94a3b8;margin:0 4px;'>→</span>
|
| 302 |
+
<span class='token-name' style='color:#3b82f6;'>{tokens[j].replace("##", "").replace("Ġ", "")}</span>
|
| 303 |
</div>
|
| 304 |
<div class='scaled-values'>
|
| 305 |
<span class='scaled-step'>Q·K = <b>{dot:.2f}</b></span>
|
|
|
|
| 329 |
hs_out = hidden_states[layer_idx + 1][0].cpu().numpy()
|
| 330 |
rows = []
|
| 331 |
for i, tok in enumerate(tokens):
|
| 332 |
+
clean_tok = tok.replace("##", "").replace("Ġ", "")
|
| 333 |
diff = np.linalg.norm(hs_out[i] - hs_in[i])
|
| 334 |
norm = np.linalg.norm(hs_in[i]) + 1e-6
|
| 335 |
ratio = diff / norm
|
| 336 |
width = max(4, min(100, int(ratio * 80)))
|
| 337 |
rows.append(
|
| 338 |
+
f"<tr><td class='token-name'>{clean_tok}</td>"
|
| 339 |
f"<td><div style='background:#e5e7eb;border-radius:999px;height:10px;' title='Change: {ratio:.1%}'>"
|
| 340 |
f"<div style='width:{width}%;height:10px;border-radius:999px;"
|
| 341 |
+
f"background:linear-gradient(90deg,#ff5ca9,#3b82f6);'></div></div></td></tr>"
|
| 342 |
)
|
| 343 |
return ui.HTML(
|
| 344 |
"<div class='card-scroll'>"
|
|
|
|
| 368 |
proj_np = proj.cpu().numpy()
|
| 369 |
rows = []
|
| 370 |
for i, tok in enumerate(tokens):
|
| 371 |
+
clean_tok = tok.replace("##", "").replace("Ġ", "")
|
| 372 |
inter_strip = array_to_base64_img(inter_np[i][:96], "Blues", 0.15)
|
| 373 |
proj_strip = array_to_base64_img(proj_np[i][:96], "Blues", 0.15)
|
| 374 |
rows.append(
|
| 375 |
"<tr>"
|
| 376 |
+
f"<td class='token-name'>{clean_tok}</td>"
|
| 377 |
f"<td><img class='heatmap' src='data:image/png;base64,{inter_strip}' title='Intermediate 3072 dims'></td>"
|
| 378 |
f"<td><img class='heatmap' src='data:image/png;base64,{proj_strip}' title='Projection back to 768 dims'></td>"
|
| 379 |
"</tr>"
|
|
|
|
| 394 |
hs_out = hidden_states[layer_idx + 2][0].cpu().numpy()
|
| 395 |
rows = []
|
| 396 |
for i, tok in enumerate(tokens):
|
| 397 |
+
clean_tok = tok.replace("##", "").replace("Ġ", "")
|
| 398 |
diff = np.linalg.norm(hs_out[i] - hs_mid[i])
|
| 399 |
norm = np.linalg.norm(hs_mid[i]) + 1e-6
|
| 400 |
ratio = diff / norm
|
| 401 |
width = max(4, min(100, int(ratio * 80)))
|
| 402 |
rows.append(
|
| 403 |
+
f"<tr><td class='token-name'>{clean_tok}</td>"
|
| 404 |
f"<td><div style='background:#e5e7eb;border-radius:999px;height:10px;' title='Change: {ratio:.1%}'>"
|
| 405 |
f"<div style='width:{width}%;height:10px;border-radius:999px;"
|
| 406 |
+
f"background:linear-gradient(90deg,#ff5ca9,#3b82f6);'></div></div></td></tr>"
|
| 407 |
)
|
| 408 |
return ui.HTML(
|
| 409 |
"<div class='card-scroll'>"
|
|
|
|
| 421 |
|
| 422 |
rows = []
|
| 423 |
for i, tok in enumerate(tokens):
|
| 424 |
+
clean_tok = tok.replace("##", "").replace("Ġ", "")
|
| 425 |
vec_strip = array_to_base64_img(hs[i][:64], "Blues", 0.15)
|
| 426 |
vec_tip = "Hidden state (first 32 dims): " + ", ".join(f"{v:.3f}" for v in hs[i][:32])
|
| 427 |
mean_val = float(hs[i].mean())
|
|
|
|
| 430 |
|
| 431 |
rows.append(f"""
|
| 432 |
<tr>
|
| 433 |
+
<td class='token-name'>{clean_tok}</td>
|
| 434 |
<td><img class='heatmap' src='data:image/png;base64,{vec_strip}' title='{vec_tip}'></td>
|
| 435 |
<td style='font-size:9px;color:#374151;white-space:nowrap;'>
|
| 436 |
μ={mean_val:.3f}, σ={std_val:.3f}, max={max_val:.3f}
|
|
|
|
| 478 |
top_k = 5
|
| 479 |
|
| 480 |
for i, tok in enumerate(mlm_tokens):
|
| 481 |
+
# Clean token header
|
| 482 |
+
tok = tok.replace("##", "").replace("Ġ", "")
|
| 483 |
+
if not tok: tok = " "
|
| 484 |
token_probs = probs[i]
|
| 485 |
top_vals, top_idx = torch.topk(token_probs, top_k)
|
| 486 |
|
attention_app/ui/layouts.py
CHANGED
|
@@ -56,11 +56,6 @@ attention_analysis_page = ui.page_fluid(
|
|
| 56 |
ui.input_text_area("text_input", None, "All women are naturally nurturing and emotional. Men are logical and suited for leadership positions.", rows=6),
|
| 57 |
ui.div(
|
| 58 |
ui.input_action_button("generate_all", "Generate All", class_="btn-primary"),
|
| 59 |
-
ui.div(
|
| 60 |
-
{"id": "loading_spinner", "class": "loading-container", "style": "display:none;"},
|
| 61 |
-
ui.div({"class": "spinner"}),
|
| 62 |
-
ui.span("Processing...")
|
| 63 |
-
),
|
| 64 |
),
|
| 65 |
),
|
| 66 |
|
|
|
|
| 56 |
ui.input_text_area("text_input", None, "All women are naturally nurturing and emotional. Men are logical and suited for leadership positions.", rows=6),
|
| 57 |
ui.div(
|
| 58 |
ui.input_action_button("generate_all", "Generate All", class_="btn-primary"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
),
|
| 60 |
),
|
| 61 |
|
attention_app/ui/scripts.py
CHANGED
|
@@ -146,14 +146,23 @@ JS_INTERACTIVE = """
|
|
| 146 |
|
| 147 |
// Custom message handlers
|
| 148 |
Shiny.addCustomMessageHandler('start_loading', function(msg) {
|
| 149 |
-
$('#
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
$('#dashboard-container').addClass('content-hidden').removeClass('content-visible');
|
| 152 |
});
|
| 153 |
|
| 154 |
Shiny.addCustomMessageHandler('stop_loading', function(msg) {
|
| 155 |
-
$('#
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
});
|
| 158 |
|
| 159 |
// Bias Loading Handlers
|
|
|
|
| 146 |
|
| 147 |
// Custom message handlers
|
| 148 |
Shiny.addCustomMessageHandler('start_loading', function(msg) {
|
| 149 |
+
var btn = $('#generate_all');
|
| 150 |
+
if (!btn.data('original-content')) {
|
| 151 |
+
btn.data('original-content', btn.html());
|
| 152 |
+
}
|
| 153 |
+
btn.html('<div class="spinner" style="width:16px;height:16px;border-width:2px;display:inline-block;vertical-align:middle;margin-right:8px;"></div>Processing<span class="loading-dots"></span>');
|
| 154 |
+
btn.prop('disabled', true).css('opacity', '0.8');
|
| 155 |
$('#dashboard-container').addClass('content-hidden').removeClass('content-visible');
|
| 156 |
});
|
| 157 |
|
| 158 |
Shiny.addCustomMessageHandler('stop_loading', function(msg) {
|
| 159 |
+
var btn = $('#generate_all');
|
| 160 |
+
if (btn.data('original-content')) {
|
| 161 |
+
btn.html(btn.data('original-content'));
|
| 162 |
+
} else {
|
| 163 |
+
btn.html('Generate All');
|
| 164 |
+
}
|
| 165 |
+
btn.prop('disabled', false).css('opacity', '1');
|
| 166 |
});
|
| 167 |
|
| 168 |
// Bias Loading Handlers
|
attention_app/ui/styles.py
CHANGED
|
@@ -534,6 +534,25 @@ CSS = """
|
|
| 534 |
|
| 535 |
@keyframes spin { to { transform: rotate(360deg); } }
|
| 536 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
/* Tables */
|
| 538 |
.token-table {
|
| 539 |
width: 100%;
|
|
|
|
| 534 |
|
| 535 |
@keyframes spin { to { transform: rotate(360deg); } }
|
| 536 |
|
| 537 |
+
/* Spinner inside primary button (needs to be white) */
|
| 538 |
+
.btn-primary .spinner {
|
| 539 |
+
border-color: rgba(255, 255, 255, 0.3);
|
| 540 |
+
border-top-color: white;
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
/* Loading Dots Animation */
|
| 544 |
+
.loading-dots:after {
|
| 545 |
+
content: '.';
|
| 546 |
+
animation: dots 1.5s steps(5, end) infinite;
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
@keyframes dots {
|
| 550 |
+
0%, 20% { content: '.'; }
|
| 551 |
+
40% { content: '..'; }
|
| 552 |
+
60% { content: '...'; }
|
| 553 |
+
80%, 100% { content: ''; }
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
/* Tables */
|
| 557 |
.token-table {
|
| 558 |
width: 100%;
|