update
Browse files- .gitignore +2 -0
- app.py +66 -6
- src/inference/inference.py +33 -0
- src/model/generate.py +377 -0
- src/model/pianoformer.py +459 -0
- src/utils/func.py +5 -0
- src/utils/midi.py +602 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
.DS_Store
|
app.py
CHANGED
|
@@ -1,10 +1,70 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
import os
|
| 3 |
-
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
model =
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
import os
|
| 4 |
+
from miditoolkit import MidiFile
|
| 5 |
+
from src.model.generate import batch_performance_render, map_midi
|
| 6 |
+
from src.model.pianoformer import PianoT5Gemma
|
| 7 |
|
| 8 |
+
# ------------------------------
|
| 9 |
+
# Load model
|
| 10 |
+
# ------------------------------
|
| 11 |
+
def load_model():
|
| 12 |
+
print("Loading model...")
|
| 13 |
+
model = PianoT5Gemma.from_pretrained(
|
| 14 |
+
"yhj137/pianist-transformer-rendering",
|
| 15 |
+
token=os.environ.get("hf_token"),
|
| 16 |
+
torch_dtype=torch.bfloat16
|
| 17 |
+
)
|
| 18 |
+
model.eval()
|
| 19 |
+
return model
|
| 20 |
|
| 21 |
+
model = load_model()
|
| 22 |
+
|
| 23 |
+
# ------------------------------
|
| 24 |
+
# Define inference function
|
| 25 |
+
# ------------------------------
|
| 26 |
+
def render_midi(midi_file, temperature, top_p, top_k):
|
| 27 |
+
# Save uploaded file temporarily
|
| 28 |
+
input_path = midi_file.name
|
| 29 |
+
midi = MidiFile(input_path)
|
| 30 |
+
|
| 31 |
+
# Run inference
|
| 32 |
+
res = batch_performance_render(
|
| 33 |
+
model,
|
| 34 |
+
[midi],
|
| 35 |
+
temperature=temperature,
|
| 36 |
+
top_p=top_p,
|
| 37 |
+
top_k=top_k,
|
| 38 |
+
device="cpu" # change to "cuda" if you use GPU space
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Map result and save
|
| 42 |
+
mapped = map_midi(midi, res[0])
|
| 43 |
+
out_path = "output.mid"
|
| 44 |
+
mapped.dump(out_path)
|
| 45 |
+
|
| 46 |
+
return out_path
|
| 47 |
+
|
| 48 |
+
# ------------------------------
|
| 49 |
+
# Build Gradio interface
|
| 50 |
+
# ------------------------------
|
| 51 |
+
demo = gr.Interface(
|
| 52 |
+
fn=render_midi,
|
| 53 |
+
inputs=[
|
| 54 |
+
gr.File(label="Upload a Score MIDI File (.mid)", file_types=[".mid"]),
|
| 55 |
+
gr.Slider(0.1, 2.0, value=1.0, step=0.01, label="Temperature"),
|
| 56 |
+
gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p"),
|
| 57 |
+
gr.Slider(1, 100, value=50, step=1, label="Top-k"),
|
| 58 |
+
],
|
| 59 |
+
outputs=gr.File(label="Rendered Performance MIDI"),
|
| 60 |
+
title="🎹 Pianist Transformer Rendering",
|
| 61 |
+
description=(
|
| 62 |
+
"Upload a symbolic (score) MIDI file and let the Pianist Transformer render it into "
|
| 63 |
+
"a more expressive performance MIDI. Adjust decoding parameters below to control "
|
| 64 |
+
"the expressiveness and randomness of the output."
|
| 65 |
+
),
|
| 66 |
+
examples=None,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
demo.launch()
|
src/inference/inference.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.model.generate import batch_performance_render, map_midi
|
| 2 |
+
from src.model.pianoformer import PianoT5Gemma, PianoT5GemmaConfig
|
| 3 |
+
import torch
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
import os
|
| 6 |
+
from miditoolkit import MidiFile
|
| 7 |
+
from src.utils.midi import midi_to_ids, ids_to_midi
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
if __name__ == "__main__":
|
| 11 |
+
model = PianoT5Gemma.from_pretrained(
|
| 12 |
+
"models/sft/",
|
| 13 |
+
torch_dtype=torch.bfloat16
|
| 14 |
+
)#.cuda()
|
| 15 |
+
|
| 16 |
+
midis = []
|
| 17 |
+
for i in range(1):
|
| 18 |
+
midis.append(MidiFile(f"data/midis/testset/score/{i}.mid"))
|
| 19 |
+
|
| 20 |
+
res = batch_performance_render(
|
| 21 |
+
model,
|
| 22 |
+
midis,
|
| 23 |
+
temperature=1.0,
|
| 24 |
+
top_p=0.95,
|
| 25 |
+
device="cpu"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
if not os.path.exists("data/midis/testset/inference"):
|
| 29 |
+
os.makedirs("data/midis/testset/inference")
|
| 30 |
+
|
| 31 |
+
for i, mid in enumerate(res):
|
| 32 |
+
mid = map_midi(midis[i], mid)
|
| 33 |
+
mid.dump(f"data/midis/testset/inference/{i}.mid")
|
src/model/generate.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.utils.midi import ids_to_midi, midi_to_ids
|
| 2 |
+
from src.model.pianoformer import PianoT5Gemma
|
| 3 |
+
from miditoolkit import MidiFile
|
| 4 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 5 |
+
from transformers import LogitsProcessorList, LogitsProcessor
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import torch
|
| 8 |
+
from src.utils.midi import normalize_midi
|
| 9 |
+
from miditoolkit import MidiFile, Note, TempoChange, Instrument, ControlChange
|
| 10 |
+
import bisect
|
| 11 |
+
|
| 12 |
+
class BatchSparseForcedTokenProcessor(LogitsProcessor):
|
| 13 |
+
def __init__(self, input_ids, config, target_len, origin_len, already, weight, progress_callback):
|
| 14 |
+
self.batch_map = [{j: input_ids[i][j] for j in range(0, len(input_ids[i]), 8)} for i in range(len(input_ids))]
|
| 15 |
+
self.valid_id_range = config.valid_id_range
|
| 16 |
+
self.target_len = target_len
|
| 17 |
+
self.origin_len = origin_len
|
| 18 |
+
self.already = already
|
| 19 |
+
self.weight = weight
|
| 20 |
+
self.progress_callback = progress_callback
|
| 21 |
+
|
| 22 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 23 |
+
if self.progress_callback:
|
| 24 |
+
self.progress_callback(
|
| 25 |
+
(input_ids.shape[1] - self.origin_len) / (self.target_len - self.origin_len) * self.weight + self.already
|
| 26 |
+
)
|
| 27 |
+
step = input_ids.shape[1] - 1
|
| 28 |
+
batch_size = scores.shape[0]
|
| 29 |
+
for i in range(batch_size):
|
| 30 |
+
sample_map = self.batch_map[i]
|
| 31 |
+
if step in sample_map:
|
| 32 |
+
forced_token_id = sample_map[step]
|
| 33 |
+
scores[i] = float('-inf')
|
| 34 |
+
scores[i, forced_token_id] = 0.0
|
| 35 |
+
else:
|
| 36 |
+
step = step % 8
|
| 37 |
+
scores[i, :self.valid_id_range[step][0]] = float('-inf')
|
| 38 |
+
scores[i, self.valid_id_range[step][1]:] = float('-inf')
|
| 39 |
+
#if step % 8 > 3:
|
| 40 |
+
# scores = scores / 0.95
|
| 41 |
+
return scores
|
| 42 |
+
|
| 43 |
+
@torch.no_grad()
|
| 44 |
+
def batch_performance_render(
|
| 45 |
+
model,
|
| 46 |
+
score_midi_objs,
|
| 47 |
+
max_context_length=4096,
|
| 48 |
+
overlap_ratio=0.5,
|
| 49 |
+
temperature=1.0,
|
| 50 |
+
top_p=0.95,
|
| 51 |
+
device="cpu",
|
| 52 |
+
progress_callback=None
|
| 53 |
+
):
|
| 54 |
+
def slide_window(total_len, window_len):
|
| 55 |
+
if total_len <= window_len:
|
| 56 |
+
return [(0, total_len)]
|
| 57 |
+
window_len = window_len // 8 * 8
|
| 58 |
+
out = []
|
| 59 |
+
start = 0
|
| 60 |
+
while start + window_len <= total_len:
|
| 61 |
+
out.append((start, start + window_len))
|
| 62 |
+
start += int(window_len * (1 - overlap_ratio)) // 8 * 8
|
| 63 |
+
if out[-1][1] != total_len:
|
| 64 |
+
out.append((start, total_len))
|
| 65 |
+
return out
|
| 66 |
+
if max_context_length > 4096:
|
| 67 |
+
raise ValueError("You should set max_context_length <= 4096!")
|
| 68 |
+
batch_ids = [torch.tensor(midi_to_ids(model.config, score_midi_obj), dtype=torch.long).to(device) for score_midi_obj in score_midi_objs]
|
| 69 |
+
len_list = [len(batch_ids[i]) for i in range(len(batch_ids))]
|
| 70 |
+
|
| 71 |
+
input_ids = pad_sequence(batch_ids, batch_first=True, padding_value=model.config.pad_token_id)
|
| 72 |
+
windows = slide_window(input_ids.shape[1], max_context_length)
|
| 73 |
+
#print(windows)
|
| 74 |
+
output_list = []
|
| 75 |
+
res_tensor = None
|
| 76 |
+
for i in tqdm(range(len(windows))):
|
| 77 |
+
start, end = windows[i]
|
| 78 |
+
logits_processor = LogitsProcessorList([
|
| 79 |
+
BatchSparseForcedTokenProcessor(
|
| 80 |
+
input_ids[:,start:end],
|
| 81 |
+
model.config,
|
| 82 |
+
end,
|
| 83 |
+
start,
|
| 84 |
+
i / len(windows),
|
| 85 |
+
1 / len(windows),
|
| 86 |
+
progress_callback,
|
| 87 |
+
)
|
| 88 |
+
])
|
| 89 |
+
if i == 0:
|
| 90 |
+
output = model.generate(
|
| 91 |
+
input_ids[:,start:end],
|
| 92 |
+
do_sample=True,
|
| 93 |
+
max_new_tokens=end-start,
|
| 94 |
+
logits_processor=logits_processor,
|
| 95 |
+
temperature=temperature,
|
| 96 |
+
top_p=top_p,
|
| 97 |
+
)
|
| 98 |
+
res_tensor = output[:,1:]
|
| 99 |
+
else:
|
| 100 |
+
last_start, last_end = windows[i-1]
|
| 101 |
+
length = int(((last_end-last_start) - (start-last_start)) * 0.2)
|
| 102 |
+
decoder_input_ids = output_list[i-1][:, start-last_start:last_end-last_start - length]
|
| 103 |
+
start_tensor = torch.tensor([[model.config.bos_token_id] for _ in range(input_ids.shape[0])], dtype=torch.long).to(device)
|
| 104 |
+
decoder_input_ids = torch.cat([start_tensor, decoder_input_ids], dim=1)
|
| 105 |
+
#print(decoder_input_ids.shape)
|
| 106 |
+
output = model.generate(
|
| 107 |
+
input_ids[:,start:end],
|
| 108 |
+
decoder_input_ids=decoder_input_ids,
|
| 109 |
+
do_sample=True,
|
| 110 |
+
max_new_tokens=end-last_end+length,
|
| 111 |
+
logits_processor=logits_processor,
|
| 112 |
+
temperature=temperature,
|
| 113 |
+
top_p=top_p,
|
| 114 |
+
)
|
| 115 |
+
res_tensor = torch.cat([res_tensor[:,:-length], output[:,-(end-last_end+length):]], dim=1)
|
| 116 |
+
output_list.append(output)
|
| 117 |
+
res_tensor = res_tensor.cpu().numpy().tolist()
|
| 118 |
+
#print(res_tensor)
|
| 119 |
+
res = []
|
| 120 |
+
for i in range(len(res_tensor)):
|
| 121 |
+
#print(res_tensor[i][:len_list[i]])
|
| 122 |
+
res.append(ids_to_midi(model.config, res_tensor[i][:len_list[i]]))
|
| 123 |
+
return res
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def map_midi(score_midi_obj, performance_midi_obj):
|
| 127 |
+
def compute_duration(start_time, target_duration, tempo_list):
|
| 128 |
+
if target_duration <= 0:
|
| 129 |
+
return 0
|
| 130 |
+
if not tempo_list:
|
| 131 |
+
# 如果没有提供tempo信息,则假定为默认的120 BPM
|
| 132 |
+
tempo_list = [TempoChange(120, 0)]
|
| 133 |
+
|
| 134 |
+
# --- 步骤1: 定位start_time所在的BPM区间 ---
|
| 135 |
+
# 提取所有tempo变化的时间点
|
| 136 |
+
tempo_times = [t.time for t in tempo_list]
|
| 137 |
+
# 使用二分查找找到start_time应该插入的位置
|
| 138 |
+
# bisect_right返回的是插入点索引,因此当前生效的tempo在索引-1的位置
|
| 139 |
+
start_tempo_idx = bisect.bisect_right(tempo_times, start_time) - 1
|
| 140 |
+
# 如果start_time在第一个tempo变化之前,索引会是-1,修正为0
|
| 141 |
+
if start_tempo_idx < 0:
|
| 142 |
+
start_tempo_idx = 0
|
| 143 |
+
|
| 144 |
+
# --- 步骤2: 初始化循环变量 ---
|
| 145 |
+
total_ticks_duration = 0.0
|
| 146 |
+
time_remaining_ms = float(target_duration)
|
| 147 |
+
current_tick = start_time
|
| 148 |
+
current_tempo_idx = start_tempo_idx
|
| 149 |
+
|
| 150 |
+
# --- 步骤3: 循环处理每个BPM区间,直到消耗完target_duration ---
|
| 151 |
+
# 使用一个极小值(epsilon)来处理浮点数精度问题
|
| 152 |
+
while time_remaining_ms > 1e-9:
|
| 153 |
+
current_tempo_event = tempo_list[current_tempo_idx]
|
| 154 |
+
current_bpm = current_tempo_event.tempo
|
| 155 |
+
|
| 156 |
+
# 计算在当前BPM下,每个tick持续多少毫秒
|
| 157 |
+
# 1分钟 = 60,000毫秒
|
| 158 |
+
# 每分钟节拍数 = bpm
|
| 159 |
+
# 每拍tick数 = TICK_PER_BEAT
|
| 160 |
+
# ms_per_tick = (毫秒/分钟) / (节拍/分钟) / (tick/节拍) = (60000 / bpm) / TICK_PER_BEAT
|
| 161 |
+
ms_per_tick = (60 * 1000.0 / current_bpm) / 500
|
| 162 |
+
|
| 163 |
+
# 确定当前BPM区间的结束点
|
| 164 |
+
# 如果是最后一个tempo,则它会一直持续下去
|
| 165 |
+
end_of_segment_tick = float('inf')
|
| 166 |
+
if current_tempo_idx + 1 < len(tempo_list):
|
| 167 |
+
end_of_segment_tick = tempo_list[current_tempo_idx + 1].time
|
| 168 |
+
|
| 169 |
+
# 计算从当前位置到本BPM区间结束,有多少tick
|
| 170 |
+
ticks_in_segment = end_of_segment_tick - current_tick
|
| 171 |
+
# 这些tick总共持续多少毫秒
|
| 172 |
+
ms_in_segment = ticks_in_segment * ms_per_tick
|
| 173 |
+
|
| 174 |
+
# --- 步骤4: 判断与更新 ---
|
| 175 |
+
if time_remaining_ms <= ms_in_segment:
|
| 176 |
+
# 如果剩余需要的时间,在本BPM区间内就能满足
|
| 177 |
+
# 计算还需要多少tick来凑够剩余的毫秒数
|
| 178 |
+
ticks_needed = time_remaining_ms / ms_per_tick
|
| 179 |
+
total_ticks_duration += ticks_needed
|
| 180 |
+
# 时间已全部消耗完毕,跳出循环
|
| 181 |
+
time_remaining_ms = 0
|
| 182 |
+
else:
|
| 183 |
+
# 如果本BPM区间的时间不够用
|
| 184 |
+
# 消耗掉整个区间的tick和毫秒数
|
| 185 |
+
total_ticks_duration += ticks_in_segment
|
| 186 |
+
time_remaining_ms -= ms_in_segment
|
| 187 |
+
|
| 188 |
+
# 更新“指针”,移动到下一个BPM区间的起点
|
| 189 |
+
current_tick = end_of_segment_tick
|
| 190 |
+
current_tempo_idx += 1
|
| 191 |
+
|
| 192 |
+
# 返回四舍五入后的总tick数
|
| 193 |
+
return round(total_ticks_duration)
|
| 194 |
+
|
| 195 |
+
def ms_to_tick(target_ms, tempo_list):
|
| 196 |
+
# --- 边缘情况处理 ---
|
| 197 |
+
if target_ms <= 0:
|
| 198 |
+
return 0
|
| 199 |
+
if not tempo_list:
|
| 200 |
+
# 如果没有提供tempo信息,则假定为默认的120 BPM
|
| 201 |
+
tempo_list = [TempoChange(120, 0)]
|
| 202 |
+
|
| 203 |
+
# --- 步骤1: 初始化累加器 ---
|
| 204 |
+
accumulated_ms = 0.0
|
| 205 |
+
|
| 206 |
+
# --- 步骤2: 遍历所有“有终点”的BPM区间 ---
|
| 207 |
+
# 我们遍历到倒数第二个元素,因为每个循环处理的是 tempo[i] 到 tempo[i+1] 的区间
|
| 208 |
+
for i in range(len(tempo_list) - 1):
|
| 209 |
+
current_tempo_event = tempo_list[i]
|
| 210 |
+
next_tempo_event = tempo_list[i+1]
|
| 211 |
+
|
| 212 |
+
current_bpm = current_tempo_event.tempo
|
| 213 |
+
|
| 214 |
+
# 计算当前区间的tick数和对应的毫秒数
|
| 215 |
+
ticks_in_segment = next_tempo_event.time - current_tempo_event.time
|
| 216 |
+
|
| 217 |
+
# 如果区间长度为0,直接跳过,避免除零错误
|
| 218 |
+
if ticks_in_segment == 0:
|
| 219 |
+
continue
|
| 220 |
+
|
| 221 |
+
ms_per_tick = (60 * 1000.0 / current_bpm) / 500
|
| 222 |
+
ms_in_segment = ticks_in_segment * ms_per_tick
|
| 223 |
+
|
| 224 |
+
# --- 步骤3: 判断目标是否在本区间内 ---
|
| 225 |
+
if target_ms <= accumulated_ms + ms_in_segment:
|
| 226 |
+
# 目标在本区间内!
|
| 227 |
+
ms_into_segment = target_ms - accumulated_ms
|
| 228 |
+
ticks_needed = ms_into_segment / ms_per_tick
|
| 229 |
+
|
| 230 |
+
# 最终tick = 本区间起始tick + 在本区间内转换出的tick
|
| 231 |
+
final_tick = current_tempo_event.time + ticks_needed
|
| 232 |
+
return round(final_tick)
|
| 233 |
+
|
| 234 |
+
# 如果目标不在本区间,则累加本区间的总毫秒数,继续下一个循环
|
| 235 |
+
accumulated_ms += ms_in_segment
|
| 236 |
+
|
| 237 |
+
# --- 步骤4: 如果循环结束仍未返回,说明目标在最后一个BPM区间内 ---
|
| 238 |
+
last_tempo_event = tempo_list[-1]
|
| 239 |
+
last_bpm = last_tempo_event.tempo
|
| 240 |
+
|
| 241 |
+
ms_per_tick = (60 * 1000.0 / last_bpm) / 500
|
| 242 |
+
|
| 243 |
+
# 计算进入最后一个区间后,还需要多少毫秒
|
| 244 |
+
ms_into_segment = target_ms - accumulated_ms
|
| 245 |
+
ticks_needed = ms_into_segment / ms_per_tick
|
| 246 |
+
|
| 247 |
+
# 最终tick = 最后一个区间的起始tick + 剩余毫秒转换的tick
|
| 248 |
+
final_tick = last_tempo_event.time + ticks_needed
|
| 249 |
+
return round(final_tick)
|
| 250 |
+
|
| 251 |
+
norm_score = normalize_midi(score_midi_obj)
|
| 252 |
+
norm_performance = normalize_midi(performance_midi_obj)
|
| 253 |
+
|
| 254 |
+
score_notes = norm_score.instruments[0].notes
|
| 255 |
+
performance_notes = norm_performance.instruments[0].notes
|
| 256 |
+
performance_ccs = norm_performance.instruments[0].control_changes
|
| 257 |
+
|
| 258 |
+
start_list = []
|
| 259 |
+
last = -1
|
| 260 |
+
score_start = score_notes[0].start
|
| 261 |
+
performance_start = performance_notes[0].start
|
| 262 |
+
for i in range(len(score_notes)):
|
| 263 |
+
performance_notes[i].end -= performance_start
|
| 264 |
+
performance_notes[i].start -= performance_start
|
| 265 |
+
score_notes[i].end -= score_start
|
| 266 |
+
score_notes[i].start -= score_start
|
| 267 |
+
if score_notes[i].start != last:
|
| 268 |
+
start_list.append((score_notes[i].start, performance_notes[i].start, i))
|
| 269 |
+
last = score_notes[i].start
|
| 270 |
+
|
| 271 |
+
for i in range(len(performance_ccs)):
|
| 272 |
+
performance_ccs[i].time -= performance_start
|
| 273 |
+
|
| 274 |
+
score_interval_list = []
|
| 275 |
+
performance_interval_list = []
|
| 276 |
+
|
| 277 |
+
for i in range(len(start_list)-1):
|
| 278 |
+
score_interval_list.append(start_list[i+1][0] - start_list[i][0])
|
| 279 |
+
performance_interval_list.append(start_list[i+1][1] - start_list[i][1])
|
| 280 |
+
#print(score_interval_list)
|
| 281 |
+
#print(performance_interval_list)
|
| 282 |
+
|
| 283 |
+
tempo_list = []
|
| 284 |
+
start_note_offset = []
|
| 285 |
+
for i in range(len(score_interval_list)):
|
| 286 |
+
if performance_interval_list[i] != 0:
|
| 287 |
+
bpm = 120.0 / performance_interval_list[i] * score_interval_list[i]
|
| 288 |
+
else:
|
| 289 |
+
bpm = 300
|
| 290 |
+
|
| 291 |
+
if bpm > 300:
|
| 292 |
+
start_note_offset.append(300 / 120.0 * performance_interval_list[i] - score_interval_list[i])
|
| 293 |
+
elif bpm < 10:
|
| 294 |
+
start_note_offset.append(10 / 120.0 * performance_interval_list[i] - score_interval_list[i])
|
| 295 |
+
else:
|
| 296 |
+
start_note_offset.append(0)
|
| 297 |
+
tempo_list.append(max(min(bpm, 300), 10))
|
| 298 |
+
#tempo_list.append(120.0 / performance_interval_list[i] * score_interval_list[i])
|
| 299 |
+
#print(tempo_list)
|
| 300 |
+
|
| 301 |
+
for i in range(1, len(start_note_offset)):
|
| 302 |
+
start_note_offset[i] += start_note_offset[i-1]
|
| 303 |
+
#print(start_note_offset)
|
| 304 |
+
|
| 305 |
+
#print(len(tempo_list))
|
| 306 |
+
#print(len(start_list))
|
| 307 |
+
note_tempo_list = []
|
| 308 |
+
note_performance_align = []
|
| 309 |
+
note_start_offset = [0]
|
| 310 |
+
cnt = 0
|
| 311 |
+
for i in range(len(score_notes)):
|
| 312 |
+
if cnt < len(start_list) - 2 and i >= start_list[cnt + 1][2]:
|
| 313 |
+
cnt += 1
|
| 314 |
+
note_tempo_list.append(tempo_list[cnt])
|
| 315 |
+
note_performance_align.append(start_list[cnt][1])
|
| 316 |
+
note_start_offset.append(start_note_offset[cnt])
|
| 317 |
+
#print(note_start_offset)
|
| 318 |
+
|
| 319 |
+
for i in range(len(score_notes)):
|
| 320 |
+
score_notes[i].start += note_start_offset[i]
|
| 321 |
+
note_interval_list = [0]
|
| 322 |
+
for i in range(len(score_notes)-1):
|
| 323 |
+
note_interval_list.append(score_notes[i+1].start - score_notes[i].start)
|
| 324 |
+
|
| 325 |
+
#print(note_tempo_list)
|
| 326 |
+
#print(note_performance_align)
|
| 327 |
+
|
| 328 |
+
#for i in range(len(performance_notes)):
|
| 329 |
+
#print(performance_notes[i].start)
|
| 330 |
+
|
| 331 |
+
micro_shift_list = [0]
|
| 332 |
+
cnt = 1
|
| 333 |
+
last_time = 0
|
| 334 |
+
for i in range(1, len(score_notes)):
|
| 335 |
+
last_time += note_interval_list[i] / note_tempo_list[i-1] * 120
|
| 336 |
+
micro_shift_list.append((performance_notes[i].start - last_time) / 120 * note_tempo_list[i-1])
|
| 337 |
+
#last_time = note_performance_align[i]
|
| 338 |
+
#print(last_time)
|
| 339 |
+
|
| 340 |
+
#print(micro_shift_list)
|
| 341 |
+
#plt.plot(tempo_list)
|
| 342 |
+
|
| 343 |
+
res = MidiFile(ticks_per_beat=500)
|
| 344 |
+
res_notes = []
|
| 345 |
+
start_time_list = []
|
| 346 |
+
tempo_list_filter = []
|
| 347 |
+
cc_list = []
|
| 348 |
+
last = -1
|
| 349 |
+
for i in range(len(score_notes)):
|
| 350 |
+
start_time_list.append(round(score_notes[i].start + micro_shift_list[i]))
|
| 351 |
+
#res_notes.append(Note(performance_notes[i].velocity, score_notes[i].pitch, round(score_notes[i].start + micro_shift_list[i]), round(score_notes[i].start + micro_shift_list[i]) + 100))
|
| 352 |
+
#res.tempo_changes.append(TempoChange(round(note_tempo_list[i]), round(score_notes[i].start + micro_shift_list[i])))
|
| 353 |
+
#print(last , round(note_tempo_list[i]))
|
| 354 |
+
if last != round(note_tempo_list[i]):
|
| 355 |
+
tempo_list_filter.append(TempoChange(round(note_tempo_list[i]), round(score_notes[i].start + micro_shift_list[i])))
|
| 356 |
+
last = round(note_tempo_list[i])
|
| 357 |
+
for i in range(len(score_notes)):
|
| 358 |
+
res_notes.append(
|
| 359 |
+
Note(
|
| 360 |
+
performance_notes[i].velocity,
|
| 361 |
+
score_notes[i].pitch,
|
| 362 |
+
start_time_list[i],
|
| 363 |
+
start_time_list[i]+compute_duration(start_time_list[i], performance_notes[i].duration, tempo_list_filter)
|
| 364 |
+
)
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
for cc in performance_ccs:
|
| 368 |
+
cc_list.append(ControlChange(64, cc.value, ms_to_tick(cc.time, tempo_list_filter)))
|
| 369 |
+
|
| 370 |
+
#print(tempo_list_filter)
|
| 371 |
+
res.tempo_changes = tempo_list_filter
|
| 372 |
+
res.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=res_notes, control_changes=cc_list))
|
| 373 |
+
return res
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
if __name__ == "__main__":
|
| 377 |
+
pass
|
src/model/pianoformer.py
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional, Union
|
| 2 |
+
from transformers import T5GemmaModel, T5GemmaConfig, T5GemmaModuleConfig, T5GemmaPreTrainedModel, T5GemmaForConditionalGeneration, AutoTokenizer
|
| 3 |
+
import torch
|
| 4 |
+
from transformers.models.t5gemma.modeling_t5gemma import (
|
| 5 |
+
T5GemmaLMHead,
|
| 6 |
+
GenerationMixin,
|
| 7 |
+
logger,
|
| 8 |
+
T5GemmaSelfAttention,
|
| 9 |
+
T5GemmaEncoderLayer,
|
| 10 |
+
T5GemmaRMSNorm,
|
| 11 |
+
T5GemmaRotaryEmbedding,
|
| 12 |
+
make_default_2d_attention_mask,
|
| 13 |
+
create_causal_mask,
|
| 14 |
+
bidirectional_mask_function,
|
| 15 |
+
create_sliding_window_causal_mask,
|
| 16 |
+
sliding_window_bidirectional_mask_function,
|
| 17 |
+
T5GemmaDecoder
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from transformers.modeling_outputs import (
|
| 21 |
+
BaseModelOutput,
|
| 22 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 23 |
+
Seq2SeqLMOutput,
|
| 24 |
+
Seq2SeqModelOutput,
|
| 25 |
+
SequenceClassifierOutput,
|
| 26 |
+
TokenClassifierOutput,
|
| 27 |
+
)
|
| 28 |
+
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
| 29 |
+
from transformers.processing_utils import Unpack
|
| 30 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
|
| 33 |
+
class PianoT5GemmaConfig(T5GemmaConfig):
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
hidden_size=768,
|
| 38 |
+
intermediate_size=3072,
|
| 39 |
+
num_attention_heads=8,
|
| 40 |
+
num_key_value_heads=4,
|
| 41 |
+
head_dim=128,
|
| 42 |
+
encoder_layers_num=8,
|
| 43 |
+
decoder_layers_num=4,
|
| 44 |
+
**kwargs
|
| 45 |
+
):
|
| 46 |
+
total_vocab_size = 5389
|
| 47 |
+
|
| 48 |
+
self.mask_token_id = 1
|
| 49 |
+
self.bos_token_id = 2
|
| 50 |
+
self.play_token_id = 4
|
| 51 |
+
self.pitch_start = 5
|
| 52 |
+
self.velocity_start = 5 + 128
|
| 53 |
+
self.timing_start = 5 + 128 + 128
|
| 54 |
+
self.pedal_start = 5 + 128 + 128 + 5000
|
| 55 |
+
self.hidden_size = hidden_size
|
| 56 |
+
|
| 57 |
+
self.valid_id_range = [
|
| 58 |
+
(5, 133),
|
| 59 |
+
(261, 5252),
|
| 60 |
+
(133, 261),
|
| 61 |
+
(261, 5261),
|
| 62 |
+
(5261, 5389),
|
| 63 |
+
(5261, 5389),
|
| 64 |
+
(5261, 5389),
|
| 65 |
+
(5261, 5389),
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
encoder_config = T5GemmaModuleConfig(
|
| 69 |
+
vocab_size=total_vocab_size,
|
| 70 |
+
hidden_size=hidden_size,
|
| 71 |
+
intermediate_size=intermediate_size,
|
| 72 |
+
num_hidden_layers=encoder_layers_num,
|
| 73 |
+
num_attention_heads=num_attention_heads,
|
| 74 |
+
num_key_value_heads=num_key_value_heads,
|
| 75 |
+
head_dim=head_dim,
|
| 76 |
+
pad_token_id=0,
|
| 77 |
+
bos_token_id=2,
|
| 78 |
+
eos_token_id=3,
|
| 79 |
+
)
|
| 80 |
+
decoder_config = T5GemmaModuleConfig(
|
| 81 |
+
vocab_size=total_vocab_size,
|
| 82 |
+
hidden_size=hidden_size,
|
| 83 |
+
intermediate_size=intermediate_size,
|
| 84 |
+
num_hidden_layers=decoder_layers_num,
|
| 85 |
+
num_attention_heads=num_attention_heads,
|
| 86 |
+
num_key_value_heads=num_key_value_heads,
|
| 87 |
+
head_dim=head_dim,
|
| 88 |
+
pad_token_id=0,
|
| 89 |
+
bos_token_id=2,
|
| 90 |
+
eos_token_id=3,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
super().__init__(
|
| 94 |
+
encoder=encoder_config,
|
| 95 |
+
decoder=decoder_config,
|
| 96 |
+
vocab_size=total_vocab_size,
|
| 97 |
+
**kwargs,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
class PianoEncoderEmbeddings(nn.Module):
|
| 101 |
+
def __init__(self, config):
|
| 102 |
+
super().__init__()
|
| 103 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 104 |
+
if config.hidden_size % 8 != 0:
|
| 105 |
+
raise ValueError("Invalid hidden size: must be a multiple of 8.")
|
| 106 |
+
self.projection_layers = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size // 8) for i in range(8)])
|
| 107 |
+
self.hidden_size = config.hidden_size
|
| 108 |
+
|
| 109 |
+
def forward(
|
| 110 |
+
self,
|
| 111 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 112 |
+
) -> torch.Tensor:
|
| 113 |
+
input_shape = input_ids.size()
|
| 114 |
+
|
| 115 |
+
batch_size = input_shape[0]
|
| 116 |
+
seq_length = input_shape[1]
|
| 117 |
+
|
| 118 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 119 |
+
grouped_embeds = inputs_embeds.view(batch_size, seq_length // 8, 8, -1)
|
| 120 |
+
projection_list = []
|
| 121 |
+
for i in range(8):
|
| 122 |
+
projection_list.append(self.projection_layers[i](grouped_embeds[:,:,i,:]))
|
| 123 |
+
projection_cat = torch.cat(projection_list, dim=-1)
|
| 124 |
+
inputs_embeds = projection_cat.view(batch_size, -1, self.hidden_size)
|
| 125 |
+
embeddings = inputs_embeds
|
| 126 |
+
return embeddings
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class PianoT5GemmaEncoder(T5GemmaPreTrainedModel):
|
| 130 |
+
_can_record_outputs = {
|
| 131 |
+
"attentions": T5GemmaSelfAttention,
|
| 132 |
+
"hidden_states": T5GemmaEncoderLayer,
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
def __init__(self, config):
|
| 136 |
+
super().__init__(config)
|
| 137 |
+
self.padding_idx = config.pad_token_id
|
| 138 |
+
self.vocab_size = config.vocab_size
|
| 139 |
+
|
| 140 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 141 |
+
self.embeddings = PianoEncoderEmbeddings(config)
|
| 142 |
+
self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 143 |
+
self.rotary_emb = T5GemmaRotaryEmbedding(config=config)
|
| 144 |
+
self.gradient_checkpointing = False
|
| 145 |
+
|
| 146 |
+
self.layers = nn.ModuleList(
|
| 147 |
+
[T5GemmaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 148 |
+
)
|
| 149 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
| 150 |
+
|
| 151 |
+
# Initialize weights and apply final processing
|
| 152 |
+
self.post_init()
|
| 153 |
+
|
| 154 |
+
def forward(
|
| 155 |
+
self,
|
| 156 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 157 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 158 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 159 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 160 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 161 |
+
) -> BaseModelOutput:
|
| 162 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 163 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 164 |
+
|
| 165 |
+
if inputs_embeds is None:
|
| 166 |
+
inputs_embeds = self.embeddings(input_ids)
|
| 167 |
+
input_ids = None
|
| 168 |
+
|
| 169 |
+
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
| 170 |
+
|
| 171 |
+
if position_ids is None:
|
| 172 |
+
position_ids = cache_position.unsqueeze(0)
|
| 173 |
+
|
| 174 |
+
if attention_mask is not None:
|
| 175 |
+
B, L = attention_mask.shape
|
| 176 |
+
block_mask = attention_mask.view(B, L // 8, 8)
|
| 177 |
+
mask2 = block_mask.any(dim=-1).long()
|
| 178 |
+
attention_mask = mask2.view(B, -1)
|
| 179 |
+
|
| 180 |
+
if attention_mask is None:
|
| 181 |
+
attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)
|
| 182 |
+
|
| 183 |
+
if not isinstance(self_attn_mask_mapping := attention_mask, dict):
|
| 184 |
+
mask_kwargs = {
|
| 185 |
+
"config": self.config,
|
| 186 |
+
"input_embeds": inputs_embeds,
|
| 187 |
+
"attention_mask": attention_mask,
|
| 188 |
+
"cache_position": cache_position,
|
| 189 |
+
"past_key_values": None,
|
| 190 |
+
"position_ids": position_ids,
|
| 191 |
+
}
|
| 192 |
+
self_attn_mask_mapping = {
|
| 193 |
+
"full_attention": create_causal_mask(
|
| 194 |
+
**mask_kwargs,
|
| 195 |
+
or_mask_function=bidirectional_mask_function(attention_mask),
|
| 196 |
+
),
|
| 197 |
+
"sliding_attention": create_sliding_window_causal_mask(
|
| 198 |
+
**mask_kwargs,
|
| 199 |
+
or_mask_function=sliding_window_bidirectional_mask_function(self.config.sliding_window),
|
| 200 |
+
and_mask_function=bidirectional_mask_function(attention_mask),
|
| 201 |
+
),
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
hidden_states = inputs_embeds
|
| 205 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 206 |
+
|
| 207 |
+
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
|
| 208 |
+
hidden_states = hidden_states * normalizer
|
| 209 |
+
hidden_states = self.dropout(hidden_states)
|
| 210 |
+
|
| 211 |
+
for layer_module in self.layers[: self.config.num_hidden_layers]:
|
| 212 |
+
hidden_states = layer_module(
|
| 213 |
+
hidden_states,
|
| 214 |
+
position_embeddings,
|
| 215 |
+
self_attn_mask_mapping[layer_module.attention_type],
|
| 216 |
+
position_ids,
|
| 217 |
+
**kwargs,
|
| 218 |
+
)
|
| 219 |
+
hidden_states = self.norm(hidden_states)
|
| 220 |
+
hidden_states = self.dropout(hidden_states)
|
| 221 |
+
return BaseModelOutput(
|
| 222 |
+
last_hidden_state=hidden_states,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
class PianoT5GemmaModel(T5GemmaPreTrainedModel):
|
| 226 |
+
def __init__(self, config: T5GemmaConfig):
|
| 227 |
+
super().__init__(config)
|
| 228 |
+
|
| 229 |
+
if not config.is_encoder_decoder:
|
| 230 |
+
raise ValueError("T5GemmaModel only support encoder-decoder modeling. Use `T5GemmaEncoderModel` instead.")
|
| 231 |
+
|
| 232 |
+
self.encoder = PianoT5GemmaEncoder(config.encoder)
|
| 233 |
+
self.decoder = T5GemmaDecoder(config.decoder)
|
| 234 |
+
|
| 235 |
+
self.post_init()
|
| 236 |
+
|
| 237 |
+
def get_encoder(self):
|
| 238 |
+
return self.encoder
|
| 239 |
+
|
| 240 |
+
def get_decoder(self):
|
| 241 |
+
return self.decoder
|
| 242 |
+
|
| 243 |
+
def get_input_embeddings(self):
|
| 244 |
+
return self.encoder.get_input_embeddings()
|
| 245 |
+
|
| 246 |
+
def set_input_embeddings(self, new_embeddings):
|
| 247 |
+
return self.encoder.set_input_embeddings(new_embeddings)
|
| 248 |
+
|
| 249 |
+
def forward(
|
| 250 |
+
self,
|
| 251 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 252 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 253 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 254 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
| 255 |
+
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
| 256 |
+
decoder_position_ids: Optional[torch.LongTensor] = None,
|
| 257 |
+
encoder_outputs: Optional[BaseModelOutput] = None,
|
| 258 |
+
past_key_values: Optional[EncoderDecoderCache] = None,
|
| 259 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 260 |
+
decoder_inputs_embeds: Optional[torch.Tensor] = None,
|
| 261 |
+
use_cache: Optional[bool] = None,
|
| 262 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 263 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 264 |
+
) -> Seq2SeqModelOutput:
|
| 265 |
+
r"""
|
| 266 |
+
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
|
| 267 |
+
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 268 |
+
config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
| 269 |
+
"""
|
| 270 |
+
if encoder_outputs is None:
|
| 271 |
+
encoder_outputs = self.encoder(
|
| 272 |
+
input_ids=input_ids,
|
| 273 |
+
attention_mask=attention_mask,
|
| 274 |
+
position_ids=position_ids,
|
| 275 |
+
inputs_embeds=inputs_embeds,
|
| 276 |
+
**kwargs,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
encoder_hidden_states = encoder_outputs.last_hidden_state
|
| 280 |
+
|
| 281 |
+
decoder_outputs = self.decoder(
|
| 282 |
+
input_ids=decoder_input_ids,
|
| 283 |
+
attention_mask=decoder_attention_mask,
|
| 284 |
+
position_ids=decoder_position_ids,
|
| 285 |
+
inputs_embeds=decoder_inputs_embeds,
|
| 286 |
+
past_key_values=past_key_values,
|
| 287 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 288 |
+
encoder_attention_mask=attention_mask,
|
| 289 |
+
use_cache=use_cache,
|
| 290 |
+
cache_position=cache_position,
|
| 291 |
+
**kwargs,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
return Seq2SeqModelOutput(
|
| 295 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
| 296 |
+
past_key_values=decoder_outputs.past_key_values,
|
| 297 |
+
decoder_hidden_states=decoder_outputs.hidden_states
|
| 298 |
+
if kwargs.get("output_hidden_states", False)
|
| 299 |
+
else (decoder_outputs.last_hidden_state,),
|
| 300 |
+
decoder_attentions=decoder_outputs.attentions,
|
| 301 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
| 302 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
| 303 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
| 304 |
+
encoder_attentions=encoder_outputs.attentions,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
class PianoT5Gemma(T5GemmaPreTrainedModel, GenerationMixin):
|
| 309 |
+
_tied_weights_keys = ["model.decoder.embed_tokens.weight", "lm_head.out_proj.weight"]
|
| 310 |
+
_tp_plan = {"lm_head.out_proj": "colwise_rep"}
|
| 311 |
+
_pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])}
|
| 312 |
+
|
| 313 |
+
def __init__(self, config: PianoT5GemmaConfig):
|
| 314 |
+
config.is_encoder_decoder = True
|
| 315 |
+
super().__init__(config)
|
| 316 |
+
self.embeddings = PianoEncoderEmbeddings(config)
|
| 317 |
+
self.model = PianoT5GemmaModel(config)
|
| 318 |
+
self.vocab_size = config.decoder.vocab_size
|
| 319 |
+
self.lm_head = T5GemmaLMHead(config.decoder.hidden_size, self.vocab_size)
|
| 320 |
+
self.loss_type = "ForMaskedLM"
|
| 321 |
+
|
| 322 |
+
self.post_init()
|
| 323 |
+
|
| 324 |
+
def set_output_embeddings(self, new_embeddings):
|
| 325 |
+
self.lm_head.out_proj = new_embeddings
|
| 326 |
+
|
| 327 |
+
def get_output_embeddings(self):
|
| 328 |
+
return self.lm_head.out_proj
|
| 329 |
+
|
| 330 |
+
def _tie_weights(self):
|
| 331 |
+
# Decoder input and output embeddings are tied.
|
| 332 |
+
if self.config.tie_word_embeddings:
|
| 333 |
+
self._tie_or_clone_weights(self.lm_head.out_proj, self.get_decoder().get_input_embeddings())
|
| 334 |
+
|
| 335 |
+
def get_encoder(self):
|
| 336 |
+
return self.model.encoder
|
| 337 |
+
|
| 338 |
+
def get_decoder(self):
|
| 339 |
+
return self.model.decoder
|
| 340 |
+
|
| 341 |
+
def forward(
|
| 342 |
+
self,
|
| 343 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 344 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 345 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 346 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
| 347 |
+
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
| 348 |
+
decoder_position_ids: Optional[torch.LongTensor] = None,
|
| 349 |
+
encoder_outputs: Optional[BaseModelOutput] = None,
|
| 350 |
+
past_key_values: Optional[EncoderDecoderCache] = None,
|
| 351 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 352 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 353 |
+
labels: Optional[torch.LongTensor] = None,
|
| 354 |
+
use_cache: Optional[bool] = None,
|
| 355 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 356 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 357 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 358 |
+
) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
| 359 |
+
r"""
|
| 360 |
+
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
|
| 361 |
+
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 362 |
+
config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
| 363 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 364 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 365 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 366 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 367 |
+
"""
|
| 368 |
+
if self.training and self.config._attn_implementation != "eager":
|
| 369 |
+
msg = (
|
| 370 |
+
"It is strongly recommended to train T5Gemma models with the `eager` attention implementation "
|
| 371 |
+
f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
|
| 372 |
+
)
|
| 373 |
+
if is_torchdynamo_compiling():
|
| 374 |
+
raise ValueError(msg)
|
| 375 |
+
else:
|
| 376 |
+
logger.warning_once(msg)
|
| 377 |
+
|
| 378 |
+
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
|
| 379 |
+
# get decoder inputs from shifting lm labels to the right
|
| 380 |
+
decoder_input_ids = self._shift_right(labels)
|
| 381 |
+
|
| 382 |
+
#if input_ids is not None:
|
| 383 |
+
# inputs_embeds = self.embeddings(input_ids)
|
| 384 |
+
|
| 385 |
+
#if attention_mask is not None:
|
| 386 |
+
# B, L = attention_mask.shape
|
| 387 |
+
# block_mask = attention_mask.view(B, L // 8, 8)
|
| 388 |
+
# mask2 = block_mask.any(dim=-1).long()
|
| 389 |
+
# attention_mask = mask2.view(B, -1)
|
| 390 |
+
|
| 391 |
+
#print(attention_mask)
|
| 392 |
+
|
| 393 |
+
decoder_outputs: Seq2SeqModelOutput = self.model(
|
| 394 |
+
input_ids=input_ids,
|
| 395 |
+
attention_mask=attention_mask,
|
| 396 |
+
position_ids=position_ids,
|
| 397 |
+
decoder_input_ids=decoder_input_ids,
|
| 398 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 399 |
+
decoder_position_ids=decoder_position_ids,
|
| 400 |
+
encoder_outputs=encoder_outputs,
|
| 401 |
+
past_key_values=past_key_values,
|
| 402 |
+
inputs_embeds=inputs_embeds,
|
| 403 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 404 |
+
use_cache=use_cache,
|
| 405 |
+
cache_position=cache_position,
|
| 406 |
+
**kwargs,
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
hidden_states = decoder_outputs.last_hidden_state
|
| 410 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 411 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 412 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 413 |
+
decoder_config = self.get_decoder().config
|
| 414 |
+
if decoder_config.final_logit_softcapping is not None:
|
| 415 |
+
logits = logits / decoder_config.final_logit_softcapping
|
| 416 |
+
logits = torch.tanh(logits)
|
| 417 |
+
logits = logits * decoder_config.final_logit_softcapping
|
| 418 |
+
|
| 419 |
+
loss = None
|
| 420 |
+
if labels is not None:
|
| 421 |
+
# Input has right-shifted so we directly perform masked lm loss
|
| 422 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
| 423 |
+
|
| 424 |
+
return Seq2SeqLMOutput(
|
| 425 |
+
loss=loss,
|
| 426 |
+
logits=logits,
|
| 427 |
+
past_key_values=decoder_outputs.past_key_values,
|
| 428 |
+
decoder_hidden_states=decoder_outputs.decoder_hidden_states,
|
| 429 |
+
decoder_attentions=decoder_outputs.decoder_attentions,
|
| 430 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
| 431 |
+
encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state,
|
| 432 |
+
encoder_hidden_states=decoder_outputs.encoder_hidden_states,
|
| 433 |
+
encoder_attentions=decoder_outputs.encoder_attentions,
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
| 437 |
+
return self._shift_right(labels)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
if __name__ == "__main__":
|
| 442 |
+
config = PianoT5GemmaConfig()
|
| 443 |
+
test = PianoEncoderEmbeddings(config)
|
| 444 |
+
model = PianoT5Gemma(config).cuda()
|
| 445 |
+
#encoder_config = T5GemmaModuleConfig(num_hidden_layers=1)
|
| 446 |
+
#decoder_config = T5GemmaModuleConfig(num_hidden_layers=1)
|
| 447 |
+
#config = T5GemmaConfig(encoder_config, decoder_config, attn_implementation='eager')
|
| 448 |
+
|
| 449 |
+
#model = T5GemmaForConditionalGeneration(config).cuda()
|
| 450 |
+
|
| 451 |
+
toy_ids = torch.tensor([[1,2,3,4,1,2,3,4,1,2,3,4,1,2,3,4]], dtype=torch.long).cuda()
|
| 452 |
+
#tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-2b-2b-ul2")
|
| 453 |
+
#input_text = "Write me a poem about Machine Learning. Answer:"
|
| 454 |
+
#input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
| 455 |
+
print(model.generate(toy_ids, decoder_input_ids=toy_ids, max_new_tokens=32))
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
#print(model(input_ids=toy_ids, decoder_input_ids=toy_ids).logits.shape)
|
| 459 |
+
|
src/utils/func.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
|
| 3 |
+
def filter_valid_args(arg_dict, class_type):
|
| 4 |
+
valid_keys = inspect.signature(class_type).parameters.keys()
|
| 5 |
+
return {k: v for k, v in arg_dict.items() if k in valid_keys}
|
src/utils/midi.py
ADDED
|
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from miditoolkit import MidiFile, Note, Instrument, TempoChange, ControlChange
|
| 2 |
+
import bisect
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
from copy import copy
|
| 6 |
+
import random
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
def normalize_midi(midi_obj, target_ticks_per_beat = 500, target_tempo = 120):
|
| 11 |
+
ticks_per_beat = midi_obj.ticks_per_beat
|
| 12 |
+
merged_events = []
|
| 13 |
+
for i in range(len(midi_obj.instruments)):
|
| 14 |
+
filter_control_changes = []
|
| 15 |
+
for cc in midi_obj.instruments[i].control_changes:
|
| 16 |
+
if cc.number == 64:
|
| 17 |
+
filter_control_changes.append(cc)
|
| 18 |
+
merged_events.extend(midi_obj.instruments[i].notes + filter_control_changes)
|
| 19 |
+
merged_events.sort(key=lambda x: (x.start, x.pitch) if isinstance(x, Note) else (x.time, x.number))
|
| 20 |
+
|
| 21 |
+
time_interval = []
|
| 22 |
+
last_time = 0
|
| 23 |
+
for note in merged_events:
|
| 24 |
+
if isinstance(note, Note):
|
| 25 |
+
time_interval.append(note.start - last_time)
|
| 26 |
+
last_time = note.start
|
| 27 |
+
else:
|
| 28 |
+
time_interval.append(note.time - last_time)
|
| 29 |
+
last_time = note.time
|
| 30 |
+
|
| 31 |
+
output_notes = []
|
| 32 |
+
output_cc = []
|
| 33 |
+
ind = -1
|
| 34 |
+
now_tempo = 120
|
| 35 |
+
now_time = 0
|
| 36 |
+
for i, note in enumerate(merged_events):
|
| 37 |
+
if isinstance(note, Note):
|
| 38 |
+
time = note.start
|
| 39 |
+
else:
|
| 40 |
+
time = note.time
|
| 41 |
+
while ind + 1 < len(midi_obj.tempo_changes) and time >= midi_obj.tempo_changes[ind+1].time:
|
| 42 |
+
now_tempo = midi_obj.tempo_changes[ind+1].tempo
|
| 43 |
+
ind += 1
|
| 44 |
+
ratio = target_ticks_per_beat * target_tempo / now_tempo / ticks_per_beat
|
| 45 |
+
start_time = time_interval[i] * ratio + now_time
|
| 46 |
+
if isinstance(note, Note):
|
| 47 |
+
end_time = (note.end - note.start) * ratio + start_time
|
| 48 |
+
output_notes.append(Note(note.velocity, note.pitch, round(start_time), round(end_time)))
|
| 49 |
+
else:
|
| 50 |
+
output_cc.append(ControlChange(64, note.value, round(start_time)))
|
| 51 |
+
now_time = round(start_time)
|
| 52 |
+
|
| 53 |
+
output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
|
| 54 |
+
output_midi_obj.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=output_notes, control_changes=output_cc))
|
| 55 |
+
output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
|
| 56 |
+
for note in output_notes:
|
| 57 |
+
output_midi_obj.max_tick = max(output_midi_obj.max_tick, note.end)
|
| 58 |
+
for cc in output_cc:
|
| 59 |
+
output_midi_obj.max_tick = max(output_midi_obj.max_tick, cc.time)
|
| 60 |
+
return output_midi_obj
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
"""
|
| 64 |
+
def normalize_midi(midi_obj, target_ticks_per_beat=500, target_tempo=120):
|
| 65 |
+
# 创建一个新的、干净的MidiFile对象用于输出
|
| 66 |
+
output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
|
| 67 |
+
output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
|
| 68 |
+
|
| 69 |
+
# 获取原始MIDI的tick到秒的精确映射
|
| 70 |
+
# 这是最关键的一步,partitura和miditoolkit都有类似功能
|
| 71 |
+
# miditoolkit的get_tick_to_time_mapping()可以处理所有tempo变化
|
| 72 |
+
tick_to_time_map = midi_obj.get_tick_to_time_mapping()
|
| 73 |
+
|
| 74 |
+
# 计算从秒转换回目标tick的比例因子
|
| 75 |
+
# 目标MIDI中,每秒对应的tick数 = target_ticks_per_beat * (target_tempo / 60)
|
| 76 |
+
seconds_to_target_ticks_factor = target_ticks_per_beat * (target_tempo / 60.0)
|
| 77 |
+
|
| 78 |
+
merged_notes = []
|
| 79 |
+
merged_cc = []
|
| 80 |
+
|
| 81 |
+
# 遍历所有乐器轨道
|
| 82 |
+
for instrument in midi_obj.instruments:
|
| 83 |
+
# 只处理非鼓组的乐器
|
| 84 |
+
if not instrument.is_drum:
|
| 85 |
+
# --- 处理音符 (Notes) ---
|
| 86 |
+
for note in instrument.notes:
|
| 87 |
+
# 1. 将原始tick转换为绝对秒数
|
| 88 |
+
start_time_sec = tick_to_time_map[note.start]
|
| 89 |
+
end_time_sec = tick_to_time_map[note.end]
|
| 90 |
+
|
| 91 |
+
# 2. 将绝对秒数转换为目标tick
|
| 92 |
+
new_start_tick = round(start_time_sec * seconds_to_target_ticks_factor)
|
| 93 |
+
new_end_tick = round(end_time_sec * seconds_to_target_ticks_factor)
|
| 94 |
+
|
| 95 |
+
# 避免duration为0的音符
|
| 96 |
+
if new_start_tick == new_end_tick:
|
| 97 |
+
new_end_tick += 1
|
| 98 |
+
|
| 99 |
+
merged_notes.append(Note(velocity=note.velocity,
|
| 100 |
+
pitch=note.pitch,
|
| 101 |
+
start=new_start_tick,
|
| 102 |
+
end=new_end_tick))
|
| 103 |
+
|
| 104 |
+
# --- 处理延音踏板 (CC #64) ---
|
| 105 |
+
for cc in instrument.control_changes:
|
| 106 |
+
if cc.number == 64:
|
| 107 |
+
# 1. 将原始tick转换为绝对秒数
|
| 108 |
+
time_sec = tick_to_time_map[cc.time]
|
| 109 |
+
|
| 110 |
+
# 2. 将绝对秒数转换为目标tick
|
| 111 |
+
new_time_tick = round(time_sec * seconds_to_target_ticks_factor)
|
| 112 |
+
|
| 113 |
+
merged_cc.append(ControlChange(number=64,
|
| 114 |
+
value=cc.value,
|
| 115 |
+
time=new_time_tick))
|
| 116 |
+
|
| 117 |
+
# --- 排序并创建新乐器 ---
|
| 118 |
+
# 按开始时间排序,对于同时开始的事件,CC优先于Note
|
| 119 |
+
merged_notes.sort(key=lambda x: (x.start, x.pitch))
|
| 120 |
+
merged_cc.sort(key=lambda x: (x.time, x.number))
|
| 121 |
+
|
| 122 |
+
output_instrument = Instrument(program=0, is_drum=False, name="Piano")
|
| 123 |
+
output_instrument.notes = merged_notes
|
| 124 |
+
output_instrument.control_changes = merged_cc
|
| 125 |
+
output_midi_obj.instruments.append(output_instrument)
|
| 126 |
+
|
| 127 |
+
# --- 正确计算 max_tick ---
|
| 128 |
+
max_tick = 0
|
| 129 |
+
if output_instrument.notes:
|
| 130 |
+
max_tick = max(max_tick, max(n.end for n in output_instrument.notes))
|
| 131 |
+
if output_instrument.control_changes:
|
| 132 |
+
max_tick = max(max_tick, max(c.time for c in output_instrument.control_changes))
|
| 133 |
+
|
| 134 |
+
output_midi_obj.max_tick = max_tick
|
| 135 |
+
|
| 136 |
+
return output_midi_obj
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def normalize_midi(midi_obj, target_ticks_per_beat=500, target_tempo=120):
|
| 140 |
+
"""
|
| 141 |
+
将一个MidiFile对象标准化:
|
| 142 |
+
1. 合并所有轨道的钢琴音符和延音踏板事件。
|
| 143 |
+
2. 将所有时间信息(包括tempo变化)统一转换为一个固定的ticks_per_beat和tempo。
|
| 144 |
+
3. 清理重叠音符以避免解析错误。
|
| 145 |
+
4. 正确计算并设置max_tick。
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
midi_obj (MidiFile): 原始的MidiFile对象。
|
| 149 |
+
target_ticks_per_beat (int): 目标ticks_per_beat.
|
| 150 |
+
target_tempo (float): 目标tempo (BPM).
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
MidiFile: 标准化后的新MidiFile对象。
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
# 创建一个新的、干净的MidiFile对象用于输出
|
| 157 |
+
output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
|
| 158 |
+
output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
|
| 159 |
+
|
| 160 |
+
tick_to_time_map = midi_obj.get_tick_to_time_mapping()
|
| 161 |
+
seconds_to_target_ticks_factor = target_ticks_per_beat * (target_tempo / 60.0)
|
| 162 |
+
|
| 163 |
+
# --- 1. 收集并转换所有音符 ---
|
| 164 |
+
all_converted_notes = []
|
| 165 |
+
for instrument in midi_obj.instruments:
|
| 166 |
+
if not instrument.is_drum:
|
| 167 |
+
for note in instrument.notes:
|
| 168 |
+
start_time_sec = tick_to_time_map[note.start]
|
| 169 |
+
end_time_sec = tick_to_time_map[note.end]
|
| 170 |
+
|
| 171 |
+
new_start_tick = round(start_time_sec * seconds_to_target_ticks_factor)
|
| 172 |
+
new_end_tick = round(end_time_sec * seconds_to_target_ticks_factor)
|
| 173 |
+
|
| 174 |
+
if new_start_tick >= new_end_tick:
|
| 175 |
+
# 确保音符至少有1 tick的长度
|
| 176 |
+
new_end_tick = new_start_tick + 1
|
| 177 |
+
|
| 178 |
+
all_converted_notes.append(Note(velocity=note.velocity,
|
| 179 |
+
pitch=note.pitch,
|
| 180 |
+
start=new_start_tick,
|
| 181 |
+
end=new_end_tick))
|
| 182 |
+
|
| 183 |
+
# --- 2. 清理重叠音符 (关键新增部分) ---
|
| 184 |
+
# 首先按音高分组,然后按开始时间排序
|
| 185 |
+
notes_by_pitch = defaultdict(list)
|
| 186 |
+
for note in all_converted_notes:
|
| 187 |
+
notes_by_pitch[note.pitch].append(note)
|
| 188 |
+
|
| 189 |
+
merged_notes = []
|
| 190 |
+
for pitch in sorted(notes_by_pitch.keys()):
|
| 191 |
+
# 对每个音高的音符列表按开始时间排序
|
| 192 |
+
sorted_notes = sorted(notes_by_pitch[pitch], key=lambda n: n.start)
|
| 193 |
+
|
| 194 |
+
# 迭代并修复重叠
|
| 195 |
+
if len(sorted_notes) > 1:
|
| 196 |
+
for i in range(len(sorted_notes) - 1):
|
| 197 |
+
current_note = sorted_notes[i]
|
| 198 |
+
next_note = sorted_notes[i+1]
|
| 199 |
+
|
| 200 |
+
# 如果当前音符的结束时间晚于或等于下一个音符的开始时间
|
| 201 |
+
if current_note.end >= next_note.start:
|
| 202 |
+
# 修正当前音符的结束时间,让它在下一个音符开始前结束
|
| 203 |
+
# 我们可以让它在下一个音符开始时就结束
|
| 204 |
+
current_note.end = next_note.start
|
| 205 |
+
# 如果修复后导致时长为0,则丢弃该音符(或者设置为1 tick,这里选择前者更干净)
|
| 206 |
+
if current_note.start >= current_note.end:
|
| 207 |
+
# 标记为待删除,而不是直接删除,以避免迭代问题
|
| 208 |
+
current_note.pitch = -1 # 用一个无效音高作为标记
|
| 209 |
+
|
| 210 |
+
# 将处理过的(且未被标记删除的)音符添加到最终列表
|
| 211 |
+
merged_notes.extend([n for n in sorted_notes if n.pitch != -1])
|
| 212 |
+
|
| 213 |
+
# --- 3. 收集并转换CC事件 ---
|
| 214 |
+
merged_cc = []
|
| 215 |
+
for instrument in midi_obj.instruments:
|
| 216 |
+
if not instrument.is_drum:
|
| 217 |
+
for cc in instrument.control_changes:
|
| 218 |
+
if cc.number == 64:
|
| 219 |
+
time_sec = tick_to_time_map[cc.time]
|
| 220 |
+
new_time_tick = round(time_sec * seconds_to_target_ticks_factor)
|
| 221 |
+
merged_cc.append(ControlChange(number=64,
|
| 222 |
+
value=cc.value,
|
| 223 |
+
time=new_time_tick))
|
| 224 |
+
|
| 225 |
+
# --- 4. 排序并创建新乐器 ---
|
| 226 |
+
merged_notes.sort(key=lambda x: (x.start, x.pitch))
|
| 227 |
+
merged_cc.sort(key=lambda x: (x.time, x.number))
|
| 228 |
+
|
| 229 |
+
output_instrument = Instrument(program=0, is_drum=False, name="Piano")
|
| 230 |
+
output_instrument.notes = merged_notes
|
| 231 |
+
output_instrument.control_changes = merged_cc
|
| 232 |
+
output_midi_obj.instruments.append(output_instrument)
|
| 233 |
+
|
| 234 |
+
# --- 5. 正确计算 max_tick ---
|
| 235 |
+
max_tick = 0
|
| 236 |
+
if output_instrument.notes:
|
| 237 |
+
max_tick = max(max_tick, max(n.end for n in output_instrument.notes if n.end is not None))
|
| 238 |
+
if output_instrument.control_changes:
|
| 239 |
+
max_tick = max(max_tick, max(c.time for c in output_instrument.control_changes if c.time is not None))
|
| 240 |
+
|
| 241 |
+
# 添加一个小的buffer,确保最后一个事件不会被截断
|
| 242 |
+
output_midi_obj.max_tick = max_tick + target_ticks_per_beat
|
| 243 |
+
|
| 244 |
+
return output_midi_obj
|
| 245 |
+
|
| 246 |
+
def midi_to_ids(config, midi_obj, normalize=True):
|
| 247 |
+
def get_pedal(time_list, ccs, time):
|
| 248 |
+
i = bisect.bisect_right(time_list, time)
|
| 249 |
+
if i == 0:
|
| 250 |
+
return 0
|
| 251 |
+
else:
|
| 252 |
+
return ccs[i-1].value
|
| 253 |
+
if normalize:
|
| 254 |
+
norm_midi_obj = normalize_midi(midi_obj)
|
| 255 |
+
else:
|
| 256 |
+
norm_midi_obj = midi_obj
|
| 257 |
+
time_list = [cc.time for cc in norm_midi_obj.instruments[0].control_changes]
|
| 258 |
+
#print(time_list)
|
| 259 |
+
intervals = []
|
| 260 |
+
last_time = 0
|
| 261 |
+
for note in norm_midi_obj.instruments[0].notes:
|
| 262 |
+
intervals.append(note.start - last_time)
|
| 263 |
+
last_time = note.start
|
| 264 |
+
intervals.append(4990)
|
| 265 |
+
|
| 266 |
+
ids = []
|
| 267 |
+
last_time = 0
|
| 268 |
+
for i, note in enumerate(norm_midi_obj.instruments[0].notes):
|
| 269 |
+
interval = config.timing_start + intervals[i]
|
| 270 |
+
#print(interval - interval_start)
|
| 271 |
+
|
| 272 |
+
pitch = config.pitch_start + note.pitch
|
| 273 |
+
velocity = config.velocity_start + note.velocity
|
| 274 |
+
duration = config.timing_start + note.duration
|
| 275 |
+
last_time = last_time + intervals[i]
|
| 276 |
+
|
| 277 |
+
pedal1 = config.pedal_start + get_pedal(time_list, norm_midi_obj.instruments[0].control_changes, last_time)
|
| 278 |
+
pedal2 = config.pedal_start + get_pedal(time_list, norm_midi_obj.instruments[0].control_changes, last_time + intervals[i+1] * 1 / 4)
|
| 279 |
+
pedal3 = config.pedal_start + get_pedal(time_list, norm_midi_obj.instruments[0].control_changes, last_time + intervals[i+1] * 2 / 4)
|
| 280 |
+
pedal4 = config.pedal_start + get_pedal(time_list, norm_midi_obj.instruments[0].control_changes, last_time + intervals[i+1] * 3 / 4)
|
| 281 |
+
|
| 282 |
+
pitch = min(config.valid_id_range[0][1] - 1, max(config.valid_id_range[0][0], pitch))
|
| 283 |
+
interval = min(config.valid_id_range[1][1] - 1, max(config.valid_id_range[1][0], interval))
|
| 284 |
+
velocity = min(config.valid_id_range[2][1] - 1, max(config.valid_id_range[2][0], velocity))
|
| 285 |
+
duration = min(config.valid_id_range[3][1] - 1, max(config.valid_id_range[3][0], duration))
|
| 286 |
+
pedal1 = min(config.valid_id_range[4][1] - 1, max(config.valid_id_range[4][0], pedal1))
|
| 287 |
+
pedal2 = min(config.valid_id_range[5][1] - 1, max(config.valid_id_range[5][0], pedal2))
|
| 288 |
+
pedal3 = min(config.valid_id_range[6][1] - 1, max(config.valid_id_range[6][0], pedal3))
|
| 289 |
+
pedal4 = min(config.valid_id_range[7][1] - 1, max(config.valid_id_range[7][0], pedal4))
|
| 290 |
+
|
| 291 |
+
ids.extend([pitch, interval, velocity, duration, pedal1, pedal2, pedal3, pedal4])
|
| 292 |
+
return ids
|
| 293 |
+
|
| 294 |
+
def ids_to_midi(config, ids, target_ticks_per_beat = 500, target_tempo = 120):
|
| 295 |
+
note_list = []
|
| 296 |
+
cc_list = []
|
| 297 |
+
intervals = []
|
| 298 |
+
for i in range(0, len(ids), 8):
|
| 299 |
+
intervals.append(ids[i+1] - config.timing_start)
|
| 300 |
+
intervals.append(4990)
|
| 301 |
+
|
| 302 |
+
last_time = 0
|
| 303 |
+
for i in range(0, len(ids), 8):
|
| 304 |
+
interval = intervals[i // 8]
|
| 305 |
+
pitch = ids[i] - config.pitch_start
|
| 306 |
+
velocity = ids[i+2] - config.velocity_start
|
| 307 |
+
duration = ids[i+3] - config.timing_start
|
| 308 |
+
pedal1 = ids[i+4] - config.pedal_start
|
| 309 |
+
pedal2 = ids[i+5] - config.pedal_start
|
| 310 |
+
pedal3 = ids[i+6] - config.pedal_start
|
| 311 |
+
pedal4 = ids[i+7] - config.pedal_start
|
| 312 |
+
note_list.append(Note(velocity, pitch, last_time + interval, last_time + interval + duration))
|
| 313 |
+
last_time += interval
|
| 314 |
+
#cc_list.append(ControlChange(64, pedal1, last_time))
|
| 315 |
+
#cc_list.append(ControlChange(64, pedal2, round(last_time + min(intervals[i // 8 + 1] * 1 / 10, 5))))
|
| 316 |
+
#cc_list.append(ControlChange(64, pedal3, round(last_time + max(intervals[i // 8 + 1] * 8 / 10, intervals[i // 8 + 1] * 8 / 10 - 10))))
|
| 317 |
+
#cc_list.append(ControlChange(64, pedal4, round(last_time + max(intervals[i // 8 + 1] * 9 / 10, intervals[i // 8 + 1] * 9 / 10 - 5))))
|
| 318 |
+
cc_list.append(ControlChange(64, pedal1, last_time))
|
| 319 |
+
cc_list.append(ControlChange(64, pedal2, round(last_time + intervals[i // 8 + 1] * 1 / 4)))
|
| 320 |
+
cc_list.append(ControlChange(64, pedal3, round(last_time + intervals[i // 8 + 1] * 2 / 4)))
|
| 321 |
+
cc_list.append(ControlChange(64, pedal4, round(last_time + intervals[i // 8 + 1] * 3 / 4)))
|
| 322 |
+
|
| 323 |
+
max_tick = 0
|
| 324 |
+
for note in note_list:
|
| 325 |
+
max_tick = max(max_tick, note.end)
|
| 326 |
+
for cc in cc_list:
|
| 327 |
+
max_tick = max(max_tick, cc.time)
|
| 328 |
+
max_tick = max_tick + 1
|
| 329 |
+
|
| 330 |
+
output = MidiFile(ticks_per_beat=target_ticks_per_beat)
|
| 331 |
+
output.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=note_list, control_changes=cc_list))
|
| 332 |
+
output.tempo_changes.append(TempoChange(target_tempo, 0))
|
| 333 |
+
output.max_tick = max_tick
|
| 334 |
+
|
| 335 |
+
return output
|
| 336 |
+
|
| 337 |
+
def read_corresp(corresp_path):
|
| 338 |
+
out = []
|
| 339 |
+
performacne_id_list = []
|
| 340 |
+
with open(corresp_path, "r") as f:
|
| 341 |
+
align_txt = f.readlines()
|
| 342 |
+
|
| 343 |
+
score_ids_map = {}
|
| 344 |
+
performance_ids_map = {}
|
| 345 |
+
score_temp_list = []
|
| 346 |
+
performance_temp_list = set()
|
| 347 |
+
for line in align_txt[1:]:
|
| 348 |
+
informs = line.split("\t")
|
| 349 |
+
if informs[0] != '*':
|
| 350 |
+
score_temp_list.append((float(informs[1]), int(informs[3]), int(informs[0])))
|
| 351 |
+
if informs[5] != '*':
|
| 352 |
+
performance_temp_list.add((float(informs[6]), int(informs[8]), int(informs[5])))
|
| 353 |
+
performance_temp_list = list(performance_temp_list)
|
| 354 |
+
score_temp_list.sort()
|
| 355 |
+
performance_temp_list.sort()
|
| 356 |
+
for i, inform in enumerate(score_temp_list):
|
| 357 |
+
score_ids_map[inform[2]] = i
|
| 358 |
+
for i, inform in enumerate(performance_temp_list):
|
| 359 |
+
performance_ids_map[inform[2]] = i
|
| 360 |
+
|
| 361 |
+
for line in align_txt[1:]:
|
| 362 |
+
informs = line.split("\t")
|
| 363 |
+
if informs[0] == '*':
|
| 364 |
+
break
|
| 365 |
+
if informs[5] != '*':
|
| 366 |
+
out.append((score_ids_map[int(informs[0])], performance_ids_map[int(informs[5])]))
|
| 367 |
+
else:
|
| 368 |
+
out.append((score_ids_map[int(informs[0])], -1))
|
| 369 |
+
|
| 370 |
+
for line in align_txt[1:]:
|
| 371 |
+
informs = line.split("\t")
|
| 372 |
+
if informs[5] != '*':
|
| 373 |
+
performacne_id_list.append(performance_ids_map[int(informs[5])])
|
| 374 |
+
if out[0][1] == -1:
|
| 375 |
+
out[0] = (out[0][0], min(performacne_id_list))
|
| 376 |
+
if out[-1][1] == -1:
|
| 377 |
+
out[-1] = (out[-1][0], max(performacne_id_list))
|
| 378 |
+
out.sort()
|
| 379 |
+
return out
|
| 380 |
+
|
| 381 |
+
def interpolate(a, b):
|
| 382 |
+
a = np.array(a) + np.linspace(0, 1e-5, len(a))
|
| 383 |
+
b = np.array(b)
|
| 384 |
+
known_inds = np.where(~np.isnan(b))[0]
|
| 385 |
+
x_known = a[known_inds]
|
| 386 |
+
y_known = b[known_inds]
|
| 387 |
+
res = np.interp(a, x_known, y_known)
|
| 388 |
+
res[known_inds] = b[known_inds]
|
| 389 |
+
return [round(i) for i in res.tolist()]
|
| 390 |
+
|
| 391 |
+
def segment_sequences(x, label, unknown_ids, total_notes, max_consecutive_missing, min_segment_notes):
|
| 392 |
+
|
| 393 |
+
if not unknown_ids:
|
| 394 |
+
if total_notes >= min_segment_notes:
|
| 395 |
+
return [x], [label]
|
| 396 |
+
else:
|
| 397 |
+
return [], []
|
| 398 |
+
|
| 399 |
+
x_segments = []
|
| 400 |
+
label_segments = []
|
| 401 |
+
|
| 402 |
+
unknown_set = set(unknown_ids)
|
| 403 |
+
|
| 404 |
+
last_cut_note_idx = 0
|
| 405 |
+
consecutive_missing_count = 0
|
| 406 |
+
|
| 407 |
+
for i in range(total_notes):
|
| 408 |
+
if i in unknown_set:
|
| 409 |
+
consecutive_missing_count += 1
|
| 410 |
+
else:
|
| 411 |
+
consecutive_missing_count = 0
|
| 412 |
+
|
| 413 |
+
if consecutive_missing_count >= max_consecutive_missing:
|
| 414 |
+
segment_end_note_idx = i - consecutive_missing_count + 1
|
| 415 |
+
|
| 416 |
+
if segment_end_note_idx - last_cut_note_idx >= min_segment_notes:
|
| 417 |
+
start_token = last_cut_note_idx * 8
|
| 418 |
+
end_token = segment_end_note_idx * 8
|
| 419 |
+
|
| 420 |
+
x_segments.append(x[start_token:end_token])
|
| 421 |
+
label_segments.append(label[start_token:end_token])
|
| 422 |
+
|
| 423 |
+
last_cut_note_idx = i + 1
|
| 424 |
+
consecutive_missing_count = 0
|
| 425 |
+
|
| 426 |
+
if total_notes - last_cut_note_idx >= min_segment_notes:
|
| 427 |
+
start_token = last_cut_note_idx * 8
|
| 428 |
+
x_segments.append(x[start_token:])
|
| 429 |
+
label_segments.append(label[start_token:])
|
| 430 |
+
|
| 431 |
+
return x_segments, label_segments
|
| 432 |
+
|
| 433 |
+
def align_score_and_performance(config, score_midi_obj, performance_midi_obj):
|
| 434 |
+
norm_score_midi_obj = normalize_midi(score_midi_obj)
|
| 435 |
+
norm_performance_midi_obj = normalize_midi(performance_midi_obj)
|
| 436 |
+
|
| 437 |
+
norm_score_midi_obj.dump("temp/score.mid")
|
| 438 |
+
norm_performance_midi_obj.dump("temp/performance.mid")
|
| 439 |
+
|
| 440 |
+
os.chdir("./tools/AlignmentTool")
|
| 441 |
+
os.system(f"timeout 120s ./MIDIToMIDIAlign.sh ../../temp/performance ../../temp/score")
|
| 442 |
+
os.chdir("./../../")
|
| 443 |
+
|
| 444 |
+
corresp_list = read_corresp("temp/score_corresp.txt")
|
| 445 |
+
aligned_midi_obj = MidiFile(ticks_per_beat=500)
|
| 446 |
+
score_notes = norm_score_midi_obj.instruments[0].notes
|
| 447 |
+
performance_notes = norm_performance_midi_obj.instruments[0].notes
|
| 448 |
+
score_start_list = []
|
| 449 |
+
output_notes = []
|
| 450 |
+
output_ccs = []
|
| 451 |
+
vel_list = []
|
| 452 |
+
start_list = []
|
| 453 |
+
duration_list = []
|
| 454 |
+
unknown_ids = []
|
| 455 |
+
for i, ids in enumerate(corresp_list):
|
| 456 |
+
if ids[1] != -1:
|
| 457 |
+
vel_list.append(performance_notes[ids[1]].velocity)
|
| 458 |
+
start_list.append(performance_notes[ids[1]].start)
|
| 459 |
+
duration_list.append(performance_notes[ids[1]].end - performance_notes[ids[1]].start)
|
| 460 |
+
else:
|
| 461 |
+
vel_list.append(np.nan)
|
| 462 |
+
duration_list.append(np.nan)
|
| 463 |
+
unknown_ids.append(i)
|
| 464 |
+
score_start_list.append(score_notes[ids[0]].start)
|
| 465 |
+
start_list.sort()
|
| 466 |
+
temp = []
|
| 467 |
+
cnt = 0
|
| 468 |
+
for i in range(len(corresp_list)):
|
| 469 |
+
if i not in unknown_ids:
|
| 470 |
+
temp.append(start_list[cnt])
|
| 471 |
+
cnt += 1
|
| 472 |
+
else:
|
| 473 |
+
temp.append(np.nan)
|
| 474 |
+
start_list = interpolate(score_start_list, temp)
|
| 475 |
+
vel_list = interpolate(start_list, vel_list)
|
| 476 |
+
duration_list = interpolate(start_list, duration_list)
|
| 477 |
+
|
| 478 |
+
end_list = []
|
| 479 |
+
for i, ids in enumerate(corresp_list):
|
| 480 |
+
end = start_list[i]+duration_list[i]
|
| 481 |
+
end_list.append(end)
|
| 482 |
+
output_notes.append(Note(vel_list[i], score_notes[ids[0]].pitch, start_list[i], end))
|
| 483 |
+
max_tick = max(end_list) + 4999
|
| 484 |
+
for cc in norm_performance_midi_obj.instruments[0].control_changes:
|
| 485 |
+
if cc.time <= max_tick:
|
| 486 |
+
output_ccs.append(cc)
|
| 487 |
+
else:
|
| 488 |
+
break
|
| 489 |
+
|
| 490 |
+
aligned_midi_obj.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=output_notes, control_changes=output_ccs))
|
| 491 |
+
x = midi_to_ids(config, norm_score_midi_obj)
|
| 492 |
+
label = midi_to_ids(config, aligned_midi_obj, normalize=False)
|
| 493 |
+
assert(len(x) == len(label))
|
| 494 |
+
for i in range(len(x)):
|
| 495 |
+
if i % 8 == 0:
|
| 496 |
+
assert(x[i] == label[i])
|
| 497 |
+
|
| 498 |
+
total_notes = len(score_notes)
|
| 499 |
+
xs, labels = segment_sequences(
|
| 500 |
+
x,
|
| 501 |
+
label,
|
| 502 |
+
unknown_ids,
|
| 503 |
+
total_notes,
|
| 504 |
+
5,
|
| 505 |
+
64,
|
| 506 |
+
)
|
| 507 |
+
return xs, labels
|
| 508 |
+
|
| 509 |
+
def enhanced_ids(config, ids):
|
| 510 |
+
res = copy(ids)
|
| 511 |
+
retry = 10
|
| 512 |
+
for i in range(len(res)):
|
| 513 |
+
j = i % 8
|
| 514 |
+
if j == 3:
|
| 515 |
+
value = res[i] - config.valid_id_range[j][0]
|
| 516 |
+
if value == 10:
|
| 517 |
+
noise = 0
|
| 518 |
+
for _ in range(retry):
|
| 519 |
+
n = round(np.random.randn() * 5)
|
| 520 |
+
if n >= -9 and n <= 5:
|
| 521 |
+
noise = n
|
| 522 |
+
break
|
| 523 |
+
else:
|
| 524 |
+
noise = 0
|
| 525 |
+
for _ in range(retry):
|
| 526 |
+
n = round(np.random.randn() * 5)
|
| 527 |
+
if n >= -4 and n <= 5:
|
| 528 |
+
noise = n
|
| 529 |
+
break
|
| 530 |
+
value = min(max(value + noise, 0), 4999)
|
| 531 |
+
res[i] = config.valid_id_range[j][0] + value
|
| 532 |
+
elif j == 2:
|
| 533 |
+
value = res[i] - config.valid_id_range[j][0]
|
| 534 |
+
if value == 5:
|
| 535 |
+
noise = 0
|
| 536 |
+
for _ in range(retry):
|
| 537 |
+
n = round(np.random.randn() * 2.5)
|
| 538 |
+
if n >= -4 and n <= 2:
|
| 539 |
+
noise = n
|
| 540 |
+
break
|
| 541 |
+
elif value == 120:
|
| 542 |
+
noise = 0
|
| 543 |
+
for _ in range(retry):
|
| 544 |
+
n = round(np.random.randn() * 2.5)
|
| 545 |
+
if n >= -2 and n <= 7:
|
| 546 |
+
noise = n
|
| 547 |
+
break
|
| 548 |
+
else:
|
| 549 |
+
noise = 0
|
| 550 |
+
for _ in range(retry):
|
| 551 |
+
n = round(np.random.randn() * 2.5)
|
| 552 |
+
if n >= -2 and n <= 2:
|
| 553 |
+
noise = n
|
| 554 |
+
break
|
| 555 |
+
value = min(max(value + noise, 0), 127)
|
| 556 |
+
res[i] = config.valid_id_range[j][0] + value
|
| 557 |
+
elif j == 1:
|
| 558 |
+
value = res[i] - config.valid_id_range[j][0]
|
| 559 |
+
noise = 0
|
| 560 |
+
for _ in range(retry):
|
| 561 |
+
n = round(np.random.randn() * 5)
|
| 562 |
+
if n >= -4 and n <= 5:
|
| 563 |
+
noise = n
|
| 564 |
+
break
|
| 565 |
+
value = min(max(value + noise, 0), 4990)
|
| 566 |
+
res[i] = config.valid_id_range[j][0] + value
|
| 567 |
+
return res
|
| 568 |
+
|
| 569 |
+
def enhanced_ids_uniform(config, ids):
|
| 570 |
+
res = copy(ids)
|
| 571 |
+
for i in range(len(res)):
|
| 572 |
+
j = i % 8
|
| 573 |
+
if j == 3:
|
| 574 |
+
value = res[i] - config.valid_id_range[j][0]
|
| 575 |
+
if value == 10:
|
| 576 |
+
noise = random.randint(-9, 5)
|
| 577 |
+
else:
|
| 578 |
+
noise = random.randint(-4, 5)
|
| 579 |
+
value = min(max(value + noise, 0), 4999)
|
| 580 |
+
res[i] = config.valid_id_range[j][0] + value
|
| 581 |
+
elif j == 2:
|
| 582 |
+
value = res[i] - config.valid_id_range[j][0]
|
| 583 |
+
if value == 5:
|
| 584 |
+
noise = random.randint(-4, 2)
|
| 585 |
+
elif value == 120:
|
| 586 |
+
noise = random.randint(-2, 7)
|
| 587 |
+
else:
|
| 588 |
+
noise = random.randint(-2, 2)
|
| 589 |
+
value = min(max(value + noise, 0), 127)
|
| 590 |
+
res[i] = config.valid_id_range[j][0] + value
|
| 591 |
+
elif j == 1:
|
| 592 |
+
value = res[i] - config.valid_id_range[j][0]
|
| 593 |
+
noise = random.randint(-4, 5)
|
| 594 |
+
value = min(max(value + noise, 0), 4990)
|
| 595 |
+
res[i] = config.valid_id_range[j][0] + value
|
| 596 |
+
return res
|
| 597 |
+
|
| 598 |
+
#if __name__ == "__main__":
|
| 599 |
+
# midi_obj = MidiFile("data/midi/test/2.mid")
|
| 600 |
+
# ids = midi_to_ids(midi_obj)
|
| 601 |
+
# midi = ids_to_midi(ids)
|
| 602 |
+
# midi.dump("data/rebuild/2.mid")
|