# performative_dashboard/bar_plot.py
# Author: ror (HF Staff); commit 79e7993 — "Probably v1" (8.82 kB).
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import io
import numpy as np
import base64
from plot_utils import get_color_for_config
from data import load_data, ModelBenchmarkData
# Sort priorities used by ``reorder_data``; lower values are plotted first.
_ATTN_IMPL_PRIORITY = {
    "flash_attention_2": 0,
    "sdpa": 1,
    "eager": 2,
    "flex_attention": 3,
}
_SDPA_BACKEND_PRIORITY = {
    None: -1,
    "flash_attention": 0,
    "math": 1,
    "efficient_attention": 2,
    "cudnn_attention": 3,
}


def reorder_data(per_scenario_data: dict) -> dict:
    """Return ``per_scenario_data`` with its keys in a fixed plotting order.

    Each value must be a dict with a ``"config"`` entry holding the keys
    ``attn_implementation``, ``sdpa_backend``, ``kernelize`` and
    ``compile_mode``. Scenarios are sorted by, in order of precedence:

    1. attention implementation (flash_attention_2, sdpa, eager, flex_attention);
    2. SDPA backend (None first, then flash_attention, math,
       efficient_attention, cudnn_attention);
    3. the ``kernelize`` value (False before True);
    4. whether ``compile_mode`` is set (uncompiled before compiled).

    Implementations or backends not listed above sort after all known ones
    instead of raising ``KeyError``. The sort is stable, so ties keep their
    input order. The input dict is not modified and values are not copied.
    """

    def sort_key(key: str) -> tuple:
        cfg = per_scenario_data[key]["config"]
        return (
            # Unknown values get a priority one past the largest known one.
            _ATTN_IMPL_PRIORITY.get(
                cfg["attn_implementation"], len(_ATTN_IMPL_PRIORITY)
            ),
            _SDPA_BACKEND_PRIORITY.get(
                cfg["sdpa_backend"], len(_SDPA_BACKEND_PRIORITY)
            ),
            cfg["kernelize"],
            cfg["compile_mode"] is not None,
        )

    return {key: per_scenario_data[key] for key in sorted(per_scenario_data, key=sort_key)}
def infer_bar_label(config: dict) -> str:
    """Build a readable legend label for one benchmark configuration.

    The label has three comma-separated parts: the attention implementation
    (including the SDPA backend when the implementation is ``"sdpa"``),
    whether a compile mode is set, and whether ``kernelize`` is truthy.
    Unrecognised implementations or SDPA backends get a generic label
    instead of raising.
    """
    impl = config["attn_implementation"]
    if impl == "sdpa":
        # The backend is only consulted for SDPA configurations.
        sdpa_names = {
            "flash_attention": "SDPA (flash attention)",
            "efficient_attention": "SDPA (efficient_attention)",
            "cudnn_attention": "SDPA (cudnn)",
            "math": "SDPA (math)",
        }
        impl_part = sdpa_names.get(config["sdpa_backend"], "SDPA (unknown backend)")
    else:
        impl_names = {
            "eager": "Eager",
            "flash_attention_2": "Flash attention",
            "flex_attention": "Flex attention",
        }
        impl_part = impl_names.get(impl, "Unknown")
    compile_part = "no compile" if config["compile_mode"] is None else "compiled"
    kernel_part = "kernelized" if config["kernelize"] else "no kernels"
    return ", ".join((impl_part, compile_part, kernel_part))
def infer_bar_hatch(config: dict) -> str:
    """Return the bar hatch pattern: ``"/"`` when a compile mode is set, otherwise none."""
    compiled = config["compile_mode"] is not None
    return "/" if compiled else ""
def make_bar_kwargs(
    per_device_data: dict[str, ModelBenchmarkData], key: str
) -> tuple[dict, list, list]:
    """Flatten per-device benchmark results into arguments for ``Axes.bar``.

    Bars are laid out left to right, one per scenario, grouped by device.
    Bars within a device group are 1 unit apart, and an extra 1.5-unit gap
    separates consecutive groups. Within a group, scenarios are ordered with
    ``reorder_data``. Devices with no scenarios are skipped.

    Args:
        per_device_data: mapping from device name to its benchmark data.
        key: metric to plot (the callers in this file pass ``"ttft"`` and
            ``"itl"``). Each scenario entry must hold an array-like of samples
            under this key; the median and standard deviation are taken over it.

    Returns:
        A ``(bar_kwargs, error_bars, x_ticks)`` tuple:
        - ``bar_kwargs``: parallel lists ``x``, ``height`` (sample median),
          ``color``, ``label`` and ``hatch``, one entry per bar;
        - ``error_bars``: sample standard deviation, one per bar;
        - ``x_ticks``: ``(group_center_x, device_name)`` pairs, one per
          device that contributed at least one bar.
    """
    current_x = 0
    bar_kwargs = {"x": [], "height": [], "color": [], "label": [], "hatch": []}
    error_bars = []
    x_ticks = []
    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        if not per_scenario_data:
            # np.mean([]) would give a NaN tick position (plus a RuntimeWarning)
            # and leave an empty gap in the plot, so skip the device entirely.
            continue
        device_xs = []
        for scenario_data in per_scenario_data.values():
            config = scenario_data["config"]
            samples = scenario_data[key]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(samples))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            bar_kwargs["hatch"].append(infer_bar_hatch(config))
            error_bars.append(np.std(samples))
            device_xs.append(current_x)
            current_x += 1
        # Center the device label under its group of bars.
        x_ticks.append((np.mean(device_xs), device_name))
        current_x += 1.5
    return bar_kwargs, error_bars, x_ticks
def _figure_to_html(fig: plt.Figure) -> str:
    """Render ``fig`` as a base64-embedded PNG inside an HTML <div>, then close it."""
    buffer = io.BytesIO()
    try:
        # High DPI for crisp text; a tight bounding box so the title and the
        # legend placed below the axes are not cropped.
        fig.savefig(buffer, format="png", facecolor="#000000", bbox_inches="tight", dpi=150)
    finally:
        # Always release the figure, otherwise pyplot keeps a reference to it.
        plt.close(fig)
    img_data = base64.b64encode(buffer.getvalue()).decode()
    # Full-page container for the embedded image.
    return f"""
    <div style="width: 90vw; height: 90vh; background: #000; display: flex; justify-content: center; align-items: center; margin: 0; padding: 0; top: 0; left: 0;">
        <img src="data:image/png;base64,{img_data}" style="width: 100%; height: 100%; object-fit: contain; max-width: none; max-height: none;" />
    </div>
    """


def _legend_handles(bar_kwargs: dict) -> list:
    """Build one legend patch per distinct bar label, keeping its first-seen color and hatch."""
    first_seen = {}
    for label, color, hatch in zip(bar_kwargs["label"], bar_kwargs["color"], bar_kwargs["hatch"]):
        first_seen.setdefault(label, (color, hatch))
    return [
        mpatches.Patch(facecolor=color, hatch=hatch, label=label, edgecolor="white")
        for label, (color, hatch) in first_seen.items()
    ]


def create_matplotlib_bar_plot() -> str:
    """Render stacked TTFT / ITL bar charts for every device and return them as HTML.

    Data comes from ``load_data()``. Every device must report the same batch
    size, sequence length and number of generated tokens (via
    ``ensure_coherence``). If they differ, the returned image contains only an
    error title describing the mismatch instead of the charts.

    Returns:
        An HTML snippet embedding the figure as a base64-encoded PNG.
    """
    # Dark theme, sized to fill the screen.
    plt.style.use("dark_background")
    fig, axs = plt.subplots(2, 1, figsize=(20, 11), sharex=True)  # used to be 30, 16
    fig.patch.set_facecolor("#000000")

    # Load data and check that all devices ran the same workload.
    per_device_data = load_data()
    batch_size, sequence_length, num_tokens_to_generate = None, None, None
    for device_data in per_device_data.values():
        bs, seqlen, n_tok = device_data.ensure_coherence()
        if batch_size is None:
            batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
        elif (bs, seqlen, n_tok) != (batch_size, sequence_length, num_tokens_to_generate):
            fig.suptitle(
                f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
                f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}",
                color="white",
                fontsize=18,
            )
            # Render the error message instead of silently returning nothing
            # (and leaking the open figure).
            for ax in axs:
                ax.set_axis_off()
            return _figure_to_html(fig)

    # TTFT plot (top)
    ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
    draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)
    # ITL plot (bottom)
    itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
    draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)

    # Title and tight layout
    title = "\n".join(
        [
            "Time to first token and inter-token latency (lower is better)",
            f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}",
        ]
    )
    fig.suptitle(title, color="white", fontsize=20, y=1.005, linespacing=1.5)
    fig.tight_layout()

    # Shared legend centered below both plots. Both panels use the same
    # scenarios, so the TTFT bars are enough to build it.
    fig.legend(
        handles=_legend_handles(ttft_bars),
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.515, -0.11),
        facecolor="black",
        edgecolor="white",
        labelcolor="white",
        fontsize=14,
    )
    return _figure_to_html(fig)
def draw_bar_plot(
    ax: plt.Axes,
    bar_kwargs: dict,
    errors: list,
    ylabel: str,
    xticks: list[tuple[float, str]],
    adapt_ylim: bool = False,
) -> None:
    """Draw one panel of bars with error bars onto ``ax`` (dark theme).

    Args:
        ax: Axes to draw into.
        bar_kwargs: keyword arguments forwarded to ``Axes.bar``. Must contain
            ``"x"`` and ``"height"`` sequences of equal length.
        errors: error-bar half-lengths, one per bar, aligned with ``bar_kwargs["x"]``.
        ylabel: y-axis label.
        xticks: ``(position, label)`` pairs used as x-axis tick labels.
        adapt_ylim: if True, narrow the y-limits to the bars' error range
            with a 2% outward margin (never widening the current limits).
    """
    ax.set_facecolor("#000000")
    ax.grid(True, alpha=0.3, color="white", axis="y", zorder=0)

    # Bars above the grid, error bars above the bars.
    ax.bar(**bar_kwargs, width=1.0, edgecolor="white", linewidth=1, zorder=3)
    ax.errorbar(
        bar_kwargs["x"],
        bar_kwargs["height"],
        yerr=errors,
        fmt="none",
        ecolor="white",
        alpha=0.8,
        elinewidth=1.5,
        capthick=1.5,
        capsize=4,
        zorder=4,
    )

    # Labels and ticks
    ax.set_ylabel(ylabel, color="white", fontsize=16)
    ax.tick_params(colors="white", labelsize=13)
    ax.set_xticks([pos for pos, _ in xticks], [label for _, label in xticks], fontsize=16)

    # Truncate the y axis to better fit the bars. Skip it when there are no
    # bars, which would otherwise produce inverted limits from the sentinels.
    if adapt_ylim and len(bar_kwargs["height"]) > 0:
        lows = [h - e for h, e in zip(bar_kwargs["height"], errors)]
        highs = [h + e for h, e in zip(bar_kwargs["height"], errors)]
        low, high = min(lows), max(highs)
        # Apply the 2% margin away from the data regardless of sign, so a
        # negative lower bound is not pulled toward zero (clipping error bars).
        new_ymin = low * 0.98 if low >= 0 else low * 1.02
        new_ymax = high * 1.02 if high >= 0 else high * 0.98
        ymin, ymax = ax.get_ylim()
        ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))