import os
import tempfile

import gradio as gr
import torch
from miditoolkit import MidiFile

from src.model.generate import batch_performance_render, map_midi
from src.model.pianoformer import PianoT5Gemma

# Device passed to batch_performance_render. The model below is loaded without
# an explicit device; only switch to "cuda" if batch_performance_render moves
# the model to the requested device (not visible from this file -- confirm).
DEVICE = "cpu"


# ------------------------------
# Load model
# ------------------------------
def load_model():
    """Load the pretrained rendering model in bfloat16 and put it in eval mode.

    The access token is read from the ``hf_token`` environment variable and is
    ``None`` when that variable is unset.
    """
    print("Loading model...")
    model = PianoT5Gemma.from_pretrained(
        "yhj137/pianist-transformer-rendering",
        token=os.environ.get("hf_token"),
        torch_dtype=torch.bfloat16,
    )
    model.eval()
    return model


model = load_model()


# ------------------------------
# Define inference function
# ------------------------------
def _resolve_input_path(midi_file):
    """Return the filesystem path of a ``gr.File`` upload value.

    Depending on the Gradio version, ``gr.File`` hands the handler either a
    path string (``type="filepath"``) or a tempfile-like object that exposes
    the path as ``.name`` (``type="file"``); accept both.
    """
    if isinstance(midi_file, (str, os.PathLike)):
        return os.fspath(midi_file)
    return midi_file.name


def _warn(message):
    """Print *message* and, when this Gradio version provides it, show a UI warning."""
    print(f"[Warning] {message}")
    ui_warning = getattr(gr, "Warning", None)
    if ui_warning is not None:
        ui_warning(message)


def render_midi(midi_file, temperature, top_p):
    """Render an uploaded score MIDI into an expressive performance MIDI.

    Args:
        midi_file: Value from the ``gr.File`` input (path string or file object).
        temperature: Sampling temperature forwarded to the renderer.
        top_p: Nucleus-sampling threshold forwarded to the renderer.

    Returns:
        ``[raw_path, editable_path]``. ``editable_path`` is ``None`` (and a
        warning is shown) when tempo mapping with ``map_midi`` fails.

    Raises:
        gr.Error: If no file was uploaded, or if loading/rendering the MIDI fails.
    """
    # Validate before the broad try below so this message is not re-wrapped
    # as "Inference failed: ...".
    if midi_file is None:
        raise gr.Error("Please upload a MIDI file first.")

    try:
        midi = MidiFile(_resolve_input_path(midi_file))

        # Run inference
        res = batch_performance_render(
            model,
            [midi],
            temperature=temperature,
            top_p=top_p,
            device=DEVICE,
        )

        # Use a fresh directory per request: fixed file names in the working
        # directory let concurrent requests overwrite each other's results.
        out_dir = tempfile.mkdtemp(prefix="pianist_render_")

        # Save raw (unmapped) result
        raw_out_path = os.path.join(out_dir, "raw_render.mid")
        res[0].dump(raw_out_path)
    except Exception as e:
        raise gr.Error(f"Inference failed: {e}") from e

    # Try to create editable (mapped) version; failure here is non-fatal.
    editable_out_path = os.path.join(out_dir, "editable_render.mid")
    try:
        mapped = map_midi(midi, res[0])
        mapped.dump(editable_out_path)
    except Exception as e:
        _warn(f"map_midi failed: {e}")
        # A gr.File output needs a path or None; an error string would be
        # treated as a (non-existent) file path.
        return [raw_out_path, None]
    return [raw_out_path, editable_out_path]


# ------------------------------
# Build Gradio interface
# ------------------------------
demo = gr.Interface(
    fn=render_midi,
    inputs=[
        gr.File(
            label="Upload a Score MIDI File (.mid or .midi)",
            file_types=[".mid", ".midi", ".MID", ".MIDI"],
        ),
        gr.Slider(0.1, 2.0, value=1.0, step=0.01, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p"),
    ],
    outputs=[
        gr.File(label="Raw Performance"),
        gr.File(label="Editable Version"),
    ],
    title="🎹 Pianist Transformer Rendering",
    description=(
        "Upload a piano score MIDI file and let the Pianist Transformer render it into "
        "a more expressive performance MIDI.\n\n"
        "Two versions will be saved:\n\n"
        "• **Raw Performance** – directly generated by the model\n\n"
        "• **Editable Version** – aligned with the score using our Expressive Tempo Mapping algorithm\n\n"
        "If mapping fails, only the raw version will be returned and a warning will be shown.\n\n"
        "⚠️ **This is only a demo running on limited compute resources. Please do not upload long pieces — "
        "we recommend clips shorter than 1 minute.**"
    ),
)

if __name__ == "__main__":
    demo.launch()