|
|
import gradio as gr |
|
|
import torch |
|
|
import os |
|
|
from miditoolkit import MidiFile |
|
|
from src.model.generate import batch_performance_render, map_midi |
|
|
from src.model.pianoformer import PianoT5Gemma |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model():
    """Load the pretrained Pianist Transformer rendering model.

    The checkpoint ``yhj137/pianist-transformer-rendering`` is loaded through
    ``PianoT5Gemma.from_pretrained`` with ``torch.bfloat16`` requested as the
    dtype. The access token is read from the ``hf_token`` environment
    variable; ``None`` is passed when the variable is unset.

    Returns:
        The loaded model, after ``eval()`` has been called on it.
    """
    print("Loading model...")

    hf_token = os.environ.get("hf_token")
    pianist = PianoT5Gemma.from_pretrained(
        "yhj137/pianist-transformer-rendering",
        token=hf_token,
        torch_dtype=torch.bfloat16,
    )
    # Inference only: switch the model to evaluation mode before handing it out.
    pianist.eval()
    return pianist
|
|
|
|
|
# Load the model once at import time; render_midi reads this module-level
# instance on every request instead of reloading the checkpoint.
model = load_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def render_midi(midi_file, temperature, top_p):
    """Render an uploaded score MIDI into a performance MIDI.

    Args:
        midi_file: Value delivered by the ``gr.File`` input. Either an object
            exposing its path as ``.name`` or a plain path string is accepted.
        temperature: Forwarded to ``batch_performance_render`` as
            ``temperature``.
        top_p: Forwarded to ``batch_performance_render`` as ``top_p``.

    Returns:
        A two-element list ``[raw_path, editable_path]`` matching the two
        ``gr.File`` outputs. ``editable_path`` is ``None`` when ``map_midi``
        fails, so the raw render is still delivered.

    Raises:
        gr.Error: If no file was uploaded, or if loading, rendering or
            writing the raw result fails.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    if midi_file is None:
        raise gr.Error("Please upload a MIDI file first.")

    try:
        # Accept both file-like uploads (path in ``.name``) and plain paths.
        input_path = getattr(midi_file, "name", midi_file)
        midi = MidiFile(input_path)

        res = batch_performance_render(
            model,
            [midi],
            temperature=temperature,
            top_p=top_p,
            device="cpu",
        )

        # A fresh directory per request: fixed file names in the working
        # directory would let overlapping requests overwrite each other's
        # results. The files must outlive this call so Gradio can serve them,
        # so the directory is intentionally not removed here.
        out_dir = tempfile.mkdtemp(prefix="pianist_render_")
        raw_out_path = os.path.join(out_dir, "raw_render.mid")
        res[0].dump(raw_out_path)
    except Exception as e:
        raise gr.Error(f"Inference failed: {e}") from e

    editable_out_path = os.path.join(out_dir, "editable_render.mid")
    try:
        mapped = map_midi(midi, res[0])
        mapped.dump(editable_out_path)
    except Exception as e:
        # Best effort: mapping is optional, the raw render is still useful.
        print(f"[Warning] map_midi failed: {e}")
        # The second output is a file component, so an error string cannot be
        # returned there (it would be treated as a path). Show the message as
        # a UI warning when this Gradio version provides one.
        notify = getattr(gr, "Warning", None)
        if notify is not None:
            notify(f"[Error] map_midi failed: {e}")
        return [raw_out_path, None]

    return [raw_out_path, editable_out_path]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: one MIDI upload plus two sampling sliders in, two downloadable
# MIDI files out. All requests are handled by render_midi.
demo = gr.Interface(
    title="🎹 Pianist Transformer Rendering",
    description=(
        "Upload a piano score MIDI file and let the Pianist Transformer render it into "
        "a more expressive performance MIDI.\n\n"
        "Two versions will be saved:\n\n"
        "• **Raw Performance** – directly generated by the model\n\n"
        "• **Editable Version** – aligned with the score using our Expressive Tempo Mapping algorithm\n\n"
        "If mapping fails, only the raw version will be returned with an error message.\n\n"
        "⚠️ **This is only a demo running on limited compute resources. Please do not upload long pieces — "
        "we recommend clips shorter than 1 minute.**"
    ),
    fn=render_midi,
    # Input order must match render_midi(midi_file, temperature, top_p).
    inputs=[
        gr.File(
            label="Upload a Score MIDI File (.mid or .midi)",
            file_types=[".mid", ".midi", ".MID", ".MIDI"],
        ),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.01, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.01, label="Top-p"),
    ],
    # Output order must match the list returned by render_midi.
    outputs=[
        gr.File(label="Raw Performance"),
        gr.File(label="Editable Version"),
    ],
)
|
|
|
|
|
# Start the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()