# Page residue from the hosting site: last commit "update" (29fa0af) by yhj137.
import os
import tempfile

import gradio as gr
import torch
from miditoolkit import MidiFile

from src.model.generate import batch_performance_render, map_midi
from src.model.pianoformer import PianoT5Gemma
# ------------------------------
# Load model
# ------------------------------
def load_model():
    """Load the pretrained PianoT5Gemma checkpoint and return it in eval mode.

    The ``token`` argument is read from the ``hf_token`` environment variable
    (``None`` when the variable is unset) and the weights are requested as
    ``torch.bfloat16``.
    """
    print("Loading model...")
    pretrained_kwargs = {
        "token": os.environ.get("hf_token"),
        "torch_dtype": torch.bfloat16,
    }
    net = PianoT5Gemma.from_pretrained(
        "yhj137/pianist-transformer-rendering",
        **pretrained_kwargs,
    )
    # Switch to evaluation mode before handing the model to callers.
    net.eval()
    return net
# Loaded once at import time; render_midi reads this module-level instance.
model = load_model()
# ------------------------------
# Define inference function
# ------------------------------
def render_midi(midi_file, temperature, top_p):
    """Render an uploaded score MIDI into an expressive performance.

    Args:
        midi_file: Value delivered by the ``gr.File`` input: either a plain
            path string or an object exposing the path as ``.name``.
        temperature: Sampling temperature forwarded to
            ``batch_performance_render``.
        top_p: Top-p value forwarded to ``batch_performance_render``.

    Returns:
        A two-item list ``[raw_path, editable_path]`` feeding the two
        ``gr.File`` outputs. ``editable_path`` is ``None`` when ``map_midi``
        fails, in which case a warning is shown to the user instead.

    Raises:
        gr.Error: If no file was uploaded, or if loading, rendering or
            saving the raw result fails.
    """
    # Guard outside the try block so this message is not re-wrapped below.
    if midi_file is None:
        raise gr.Error("Please upload a MIDI file first.")

    # Accept both plain path strings and wrappers carrying the path in .name.
    input_path = getattr(midi_file, "name", midi_file)

    try:
        midi = MidiFile(input_path)
        # Run inference
        res = batch_performance_render(
            model,
            [midi],
            temperature=temperature,
            top_p=top_p,
            device="cpu",  # change to "cuda" if GPU available
        )
        # One directory per request: fixed file names in the working directory
        # would let concurrent requests overwrite each other's results.
        # NOTE: the directory is left in place because the returned paths are
        # still read after this function returns.
        out_dir = tempfile.mkdtemp(prefix="render_")
        # Save raw (unmapped) result
        raw_out_path = os.path.join(out_dir, "raw_render.mid")
        res[0].dump(raw_out_path)
    except Exception as e:
        raise gr.Error(f"Inference failed: {e}") from e

    # Try to create editable (mapped) version; failure here is non-fatal.
    editable_out_path = os.path.join(out_dir, "editable_render.mid")
    try:
        mapped = map_midi(midi, res[0])
        mapped.dump(editable_out_path)
    except Exception as e:
        print(f"[Warning] map_midi failed: {e}")
        # A gr.File output needs a real path or None, so the message is shown
        # as a UI warning rather than returned in place of a file path.
        gr.Warning(f"[Error] map_midi failed: {e}")
        return [raw_out_path, None]
    return [raw_out_path, editable_out_path]
# ------------------------------
# Build Gradio interface
# ------------------------------
# One file upload plus two sampling sliders in; two downloadable files out.
demo = gr.Interface(
    fn=render_midi,
    inputs=[
        gr.File(
            label="Upload a Score MIDI File (.mid or .midi)",
            file_types=[".mid", ".midi", ".MID", ".MIDI"],
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=1.0,
            step=0.01,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.01,
            label="Top-p",
        ),
    ],
    outputs=[
        gr.File(label="Raw Performance"),
        gr.File(label="Editable Version"),
    ],
    title="🎹 Pianist Transformer Rendering",
    # Adjacent literals are concatenated into a single description string.
    description=(
        "Upload a piano score MIDI file and let the Pianist Transformer render it into "
        "a more expressive performance MIDI.\n\n"
        "Two versions will be saved:\n\n"
        "• **Raw Performance** – directly generated by the model\n\n"
        "• **Editable Version** – aligned with the score using our Expressive Tempo Mapping algorithm\n\n"
        "If mapping fails, only the raw version will be returned with an error message.\n\n"
        "⚠️ **This is only a demo running on limited compute resources. Please do not upload long pieces — "
        "we recommend clips shorter than 1 minute.**"
    ),
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()