yhj137 committed
Commit f5399d9 · 1 Parent(s): 73e4a98
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ .DS_Store
app.py CHANGED
@@ -1,10 +1,70 @@
 import gradio as gr
+ import torch
 import os
- def greet(name):
-     return "Hello " + name + "!!"
-
- from transformers import AutoModel, AutoTokenizer
-
- model = AutoModel.from_pretrained("yhj137/pianist-transformer-rendering", token=os.environ["hf_token"])
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ from miditoolkit import MidiFile
+ from src.model.generate import batch_performance_render, map_midi
+ from src.model.pianoformer import PianoT5Gemma
+
+ # ------------------------------
+ # Load model
+ # ------------------------------
+ def load_model():
+     print("Loading model...")
+     model = PianoT5Gemma.from_pretrained(
+         "yhj137/pianist-transformer-rendering",
+         token=os.environ.get("hf_token"),
+         torch_dtype=torch.bfloat16
+     )
+     model.eval()
+     return model
+
+ model = load_model()
+
+ # ------------------------------
+ # Define inference function
+ # ------------------------------
+ def render_midi(midi_file, temperature, top_p, top_k):
+     # Read the uploaded file from its temporary path
+     input_path = midi_file.name
+     midi = MidiFile(input_path)
+
+     # Run inference
+     res = batch_performance_render(
+         model,
+         [midi],
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         device="cpu"  # change to "cuda" if you use a GPU Space
+     )
+
+     # Map the result onto the score and save
+     mapped = map_midi(midi, res[0])
+     out_path = "output.mid"
+     mapped.dump(out_path)
+
+     return out_path
+
+ # ------------------------------
+ # Build Gradio interface
+ # ------------------------------
+ demo = gr.Interface(
+     fn=render_midi,
+     inputs=[
+         gr.File(label="Upload a Score MIDI File (.mid)", file_types=[".mid"]),
+         gr.Slider(0.1, 2.0, value=1.0, step=0.01, label="Temperature"),
+         gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p"),
+         gr.Slider(1, 100, value=50, step=1, label="Top-k"),
+     ],
+     outputs=gr.File(label="Rendered Performance MIDI"),
+     title="🎹 Pianist Transformer Rendering",
+     description=(
+         "Upload a symbolic (score) MIDI file and let the Pianist Transformer render it into "
+         "a more expressive performance MIDI. Adjust the decoding parameters below to control "
+         "the expressiveness and randomness of the output."
+     ),
+     examples=None,
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
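A quick way to exercise the new pipeline without launching the UI (a sketch, not part of the commit): `score.mid` is a hypothetical local file, and `SimpleNamespace` stands in for the uploaded-file object gradio hands to the callback.

```python
# Hypothetical smoke test for app.py's render_midi; "score.mid" is assumed to exist.
from types import SimpleNamespace

fake_upload = SimpleNamespace(name="score.mid")  # mimics gradio's uploaded-file object
out_path = render_midi(fake_upload, temperature=1.0, top_p=0.95, top_k=50)
print("Rendered performance written to", out_path)
```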
src/inference/inference.py ADDED
@@ -0,0 +1,33 @@
+ from src.model.generate import batch_performance_render, map_midi
+ from src.model.pianoformer import PianoT5Gemma, PianoT5GemmaConfig
+ import torch
+ from datasets import load_dataset
+ import os
+ from miditoolkit import MidiFile
+ from src.utils.midi import midi_to_ids, ids_to_midi
+ import random
+
+ if __name__ == "__main__":
+     model = PianoT5Gemma.from_pretrained(
+         "models/sft/",
+         torch_dtype=torch.bfloat16
+     )  # .cuda()
+
+     midis = []
+     for i in range(1):
+         midis.append(MidiFile(f"data/midis/testset/score/{i}.mid"))
+
+     res = batch_performance_render(
+         model,
+         midis,
+         temperature=1.0,
+         top_p=0.95,
+         device="cpu"
+     )
+
+     if not os.path.exists("data/midis/testset/inference"):
+         os.makedirs("data/midis/testset/inference")
+
+     for i, mid in enumerate(res):
+         mid = map_midi(midis[i], mid)
+         mid.dump(f"data/midis/testset/inference/{i}.mid")
src/model/generate.py ADDED
@@ -0,0 +1,377 @@
+ from src.utils.midi import ids_to_midi, midi_to_ids, normalize_midi
+ from src.model.pianoformer import PianoT5Gemma
+ from miditoolkit import MidiFile, Note, TempoChange, Instrument, ControlChange
+ from torch.nn.utils.rnn import pad_sequence
+ from transformers import LogitsProcessorList, LogitsProcessor
+ from tqdm import tqdm
+ import torch
+ import bisect
+
+ class BatchSparseForcedTokenProcessor(LogitsProcessor):
+     def __init__(self, input_ids, config, target_len, origin_len, already, weight, progress_callback):
+         # Every 8th position (the pitch slot) is copied from the score and must
+         # be reproduced verbatim; record those forced ids per sample.
+         self.batch_map = [{j: input_ids[i][j] for j in range(0, len(input_ids[i]), 8)} for i in range(len(input_ids))]
+         self.valid_id_range = config.valid_id_range
+         self.target_len = target_len
+         self.origin_len = origin_len
+         self.already = already
+         self.weight = weight
+         self.progress_callback = progress_callback
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+         if self.progress_callback:
+             self.progress_callback(
+                 (input_ids.shape[1] - self.origin_len) / (self.target_len - self.origin_len) * self.weight + self.already
+             )
+         step = input_ids.shape[1] - 1
+         batch_size = scores.shape[0]
+         for i in range(batch_size):
+             sample_map = self.batch_map[i]
+             if step in sample_map:
+                 forced_token_id = sample_map[step]
+                 scores[i] = float('-inf')
+                 scores[i, forced_token_id] = 0.0
+             else:
+                 step = step % 8
+                 scores[i, :self.valid_id_range[step][0]] = float('-inf')
+                 scores[i, self.valid_id_range[step][1]:] = float('-inf')
+         #if step % 8 > 3:
+         #    scores = scores / 0.95
+         return scores
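The masking trick above is easiest to see on a toy tensor (a sketch with assumed values, not part of the commit): forcing a token sends every other logit to `-inf`, so softmax collapses to a one-hot distribution.

```python
import torch

scores = torch.zeros(1, 10)        # pretend vocabulary of 10 ids
forced_token_id = 3                # id pinned by batch_map at this step
scores[0] = float('-inf')
scores[0, forced_token_id] = 0.0
print(scores.softmax(dim=-1))      # probability 1.0 on index 3
```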
+
+ @torch.no_grad()
+ def batch_performance_render(
+     model,
+     score_midi_objs,
+     max_context_length=4096,
+     overlap_ratio=0.5,
+     temperature=1.0,
+     top_p=0.95,
+     top_k=50,
+     device="cpu",
+     progress_callback=None
+ ):
+     def slide_window(total_len, window_len):
+         if total_len <= window_len:
+             return [(0, total_len)]
+         window_len = window_len // 8 * 8
+         out = []
+         start = 0
+         while start + window_len <= total_len:
+             out.append((start, start + window_len))
+             start += int(window_len * (1 - overlap_ratio)) // 8 * 8
+         if out[-1][1] != total_len:
+             out.append((start, total_len))
+         return out
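+     # Worked example: slide_window(10000, 4096) with overlap_ratio=0.5 yields
+     # [(0, 4096), (2048, 6144), (4096, 8192), (6144, 10000)]: fixed-size
+     # overlapping windows plus a final window that ends exactly at total_len.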
+     if max_context_length > 4096:
+         raise ValueError("You should set max_context_length <= 4096!")
+     batch_ids = [torch.tensor(midi_to_ids(model.config, score_midi_obj), dtype=torch.long).to(device) for score_midi_obj in score_midi_objs]
+     len_list = [len(batch_ids[i]) for i in range(len(batch_ids))]
+
+     input_ids = pad_sequence(batch_ids, batch_first=True, padding_value=model.config.pad_token_id)
+     windows = slide_window(input_ids.shape[1], max_context_length)
+     #print(windows)
+     output_list = []
+     res_tensor = None
+     for i in tqdm(range(len(windows))):
+         start, end = windows[i]
+         logits_processor = LogitsProcessorList([
+             BatchSparseForcedTokenProcessor(
+                 input_ids[:, start:end],
+                 model.config,
+                 end,
+                 start,
+                 i / len(windows),
+                 1 / len(windows),
+                 progress_callback,
+             )
+         ])
+         if i == 0:
+             output = model.generate(
+                 input_ids[:, start:end],
+                 do_sample=True,
+                 max_new_tokens=end - start,
+                 logits_processor=logits_processor,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+             )
+             res_tensor = output[:, 1:]
+         else:
+             last_start, last_end = windows[i - 1]
+             # Re-feed the tail of the previous window (minus a 20% stitch
+             # margin) as decoder context so consecutive windows join smoothly.
+             length = int(((last_end - last_start) - (start - last_start)) * 0.2)
+             decoder_input_ids = output_list[i - 1][:, start - last_start:last_end - last_start - length]
+             start_tensor = torch.tensor([[model.config.bos_token_id] for _ in range(input_ids.shape[0])], dtype=torch.long).to(device)
+             decoder_input_ids = torch.cat([start_tensor, decoder_input_ids], dim=1)
+             #print(decoder_input_ids.shape)
+             output = model.generate(
+                 input_ids[:, start:end],
+                 decoder_input_ids=decoder_input_ids,
+                 do_sample=True,
+                 max_new_tokens=end - last_end + length,
+                 logits_processor=logits_processor,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+             )
+             res_tensor = torch.cat([res_tensor[:, :-length], output[:, -(end - last_end + length):]], dim=1)
+         output_list.append(output)
+     res_tensor = res_tensor.cpu().numpy().tolist()
+     #print(res_tensor)
+     res = []
+     for i in range(len(res_tensor)):
+         #print(res_tensor[i][:len_list[i]])
+         res.append(ids_to_midi(model.config, res_tensor[i][:len_list[i]]))
+     return res
+
+
+ def map_midi(score_midi_obj, performance_midi_obj):
+     def compute_duration(start_time, target_duration, tempo_list):
+         if target_duration <= 0:
+             return 0
+         if not tempo_list:
+             # If no tempo information is provided, assume the default 120 BPM
+             tempo_list = [TempoChange(120, 0)]
+
+         # --- Step 1: locate the BPM segment that contains start_time ---
+         # Collect the times of all tempo changes
+         tempo_times = [t.time for t in tempo_list]
+         # Binary-search for the insertion point of start_time;
+         # bisect_right returns the insertion index, so the tempo currently
+         # in effect sits at index - 1
+         start_tempo_idx = bisect.bisect_right(tempo_times, start_time) - 1
+         # If start_time precedes the first tempo change the index is -1; clamp to 0
+         if start_tempo_idx < 0:
+             start_tempo_idx = 0
+
+         # --- Step 2: initialize the loop variables ---
+         total_ticks_duration = 0.0
+         time_remaining_ms = float(target_duration)
+         current_tick = start_time
+         current_tempo_idx = start_tempo_idx
+
+         # --- Step 3: walk the BPM segments until target_duration is consumed ---
+         # A tiny epsilon guards against floating-point error
+         while time_remaining_ms > 1e-9:
+             current_tempo_event = tempo_list[current_tempo_idx]
+             current_bpm = current_tempo_event.tempo
+
+             # Milliseconds per tick at the current BPM:
+             # 1 minute = 60,000 ms; beats per minute = bpm; ticks per beat = 500
+             # ms_per_tick = (ms/min) / (beats/min) / (ticks/beat) = (60000 / bpm) / 500
+             ms_per_tick = (60 * 1000.0 / current_bpm) / 500
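+             # e.g. at the normalized 120 BPM: ms_per_tick = 60000 / 120 / 500 = 1.0,
+             # so ticks and milliseconds coincide.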
+
+             # Determine where the current BPM segment ends;
+             # the last tempo stays in effect forever
+             end_of_segment_tick = float('inf')
+             if current_tempo_idx + 1 < len(tempo_list):
+                 end_of_segment_tick = tempo_list[current_tempo_idx + 1].time
+
+             # Ticks from the current position to the end of this segment
+             ticks_in_segment = end_of_segment_tick - current_tick
+             # Total milliseconds those ticks last
+             ms_in_segment = ticks_in_segment * ms_per_tick
+
+             # --- Step 4: consume this segment or advance past it ---
+             if time_remaining_ms <= ms_in_segment:
+                 # The remaining time fits inside this BPM segment:
+                 # convert the remaining milliseconds into ticks
+                 ticks_needed = time_remaining_ms / ms_per_tick
+                 total_ticks_duration += ticks_needed
+                 # All time consumed; leave the loop
+                 time_remaining_ms = 0
+             else:
+                 # This segment is too short: consume all of its ticks and
+                 # milliseconds, then move on
+                 total_ticks_duration += ticks_in_segment
+                 time_remaining_ms -= ms_in_segment
+
+                 # Advance the cursor to the start of the next BPM segment
+                 current_tick = end_of_segment_tick
+                 current_tempo_idx += 1
+
+         # Return the rounded total tick count
+         return round(total_ticks_duration)
+
+     def ms_to_tick(target_ms, tempo_list):
+         # --- Edge cases ---
+         if target_ms <= 0:
+             return 0
+         if not tempo_list:
+             # If no tempo information is provided, assume the default 120 BPM
+             tempo_list = [TempoChange(120, 0)]
+
+         # --- Step 1: initialize the accumulator ---
+         accumulated_ms = 0.0
+
+         # --- Step 2: walk every BPM segment that has an end point ---
+         # Iterate up to the second-to-last element, since each pass handles
+         # the segment between tempo[i] and tempo[i+1]
+         for i in range(len(tempo_list) - 1):
+             current_tempo_event = tempo_list[i]
+             next_tempo_event = tempo_list[i + 1]
+
+             current_bpm = current_tempo_event.tempo
+
+             # Ticks in this segment and the milliseconds they correspond to
+             ticks_in_segment = next_tempo_event.time - current_tempo_event.time
+
+             # Skip zero-length segments to avoid division by zero
+             if ticks_in_segment == 0:
+                 continue
+
+             ms_per_tick = (60 * 1000.0 / current_bpm) / 500
+             ms_in_segment = ticks_in_segment * ms_per_tick
+
+             # --- Step 3: does the target fall inside this segment? ---
+             if target_ms <= accumulated_ms + ms_in_segment:
+                 # The target lies inside this segment
+                 ms_into_segment = target_ms - accumulated_ms
+                 ticks_needed = ms_into_segment / ms_per_tick
+
+                 # Final tick = segment start tick + ticks converted inside it
+                 final_tick = current_tempo_event.time + ticks_needed
+                 return round(final_tick)
+
+             # Otherwise accumulate this segment's milliseconds and continue
+             accumulated_ms += ms_in_segment
+
+         # --- Step 4: if the loop finishes without returning, the target lies
+         # in the last BPM segment ---
+         last_tempo_event = tempo_list[-1]
+         last_bpm = last_tempo_event.tempo
+
+         ms_per_tick = (60 * 1000.0 / last_bpm) / 500
+
+         # Milliseconds still needed after entering the last segment
+         ms_into_segment = target_ms - accumulated_ms
+         ticks_needed = ms_into_segment / ms_per_tick
+
+         # Final tick = last segment's start tick + the remainder in ticks
+         final_tick = last_tempo_event.time + ticks_needed
+         return round(final_tick)
+
+     norm_score = normalize_midi(score_midi_obj)
+     norm_performance = normalize_midi(performance_midi_obj)
+
+     score_notes = norm_score.instruments[0].notes
+     performance_notes = norm_performance.instruments[0].notes
+     performance_ccs = norm_performance.instruments[0].control_changes
+
+     start_list = []
+     last = -1
+     score_start = score_notes[0].start
+     performance_start = performance_notes[0].start
+     for i in range(len(score_notes)):
+         performance_notes[i].end -= performance_start
+         performance_notes[i].start -= performance_start
+         score_notes[i].end -= score_start
+         score_notes[i].start -= score_start
+         if score_notes[i].start != last:
+             start_list.append((score_notes[i].start, performance_notes[i].start, i))
+             last = score_notes[i].start
+
+     for i in range(len(performance_ccs)):
+         performance_ccs[i].time -= performance_start
+
+     score_interval_list = []
+     performance_interval_list = []
+
+     for i in range(len(start_list) - 1):
+         score_interval_list.append(start_list[i + 1][0] - start_list[i][0])
+         performance_interval_list.append(start_list[i + 1][1] - start_list[i][1])
+     #print(score_interval_list)
+     #print(performance_interval_list)
+
+     # Derive a local BPM for each onset interval and clamp it to [10, 300];
+     # out-of-range intervals are absorbed into a cumulative onset offset.
+     tempo_list = []
+     start_note_offset = []
+     for i in range(len(score_interval_list)):
+         if performance_interval_list[i] != 0:
+             bpm = 120.0 / performance_interval_list[i] * score_interval_list[i]
+         else:
+             bpm = 300
+
+         if bpm > 300:
+             start_note_offset.append(300 / 120.0 * performance_interval_list[i] - score_interval_list[i])
+         elif bpm < 10:
+             start_note_offset.append(10 / 120.0 * performance_interval_list[i] - score_interval_list[i])
+         else:
+             start_note_offset.append(0)
+         tempo_list.append(max(min(bpm, 300), 10))
+         #tempo_list.append(120.0 / performance_interval_list[i] * score_interval_list[i])
+     #print(tempo_list)
+
+     for i in range(1, len(start_note_offset)):
+         start_note_offset[i] += start_note_offset[i - 1]
+     #print(start_note_offset)
+
+     #print(len(tempo_list))
+     #print(len(start_list))
+     note_tempo_list = []
+     note_performance_align = []
+     note_start_offset = [0]
+     cnt = 0
+     for i in range(len(score_notes)):
+         if cnt < len(start_list) - 2 and i >= start_list[cnt + 1][2]:
+             cnt += 1
+         note_tempo_list.append(tempo_list[cnt])
+         note_performance_align.append(start_list[cnt][1])
+         note_start_offset.append(start_note_offset[cnt])
+     #print(note_start_offset)
+
+     for i in range(len(score_notes)):
+         score_notes[i].start += note_start_offset[i]
+     note_interval_list = [0]
+     for i in range(len(score_notes) - 1):
+         note_interval_list.append(score_notes[i + 1].start - score_notes[i].start)
+
+     #print(note_tempo_list)
+     #print(note_performance_align)
+
+     # Micro-timing: the residual between each performed onset and the onset
+     # implied by the local tempo, expressed back in score ticks.
+     micro_shift_list = [0]
+     cnt = 1
+     last_time = 0
+     for i in range(1, len(score_notes)):
+         last_time += note_interval_list[i] / note_tempo_list[i - 1] * 120
+         micro_shift_list.append((performance_notes[i].start - last_time) / 120 * note_tempo_list[i - 1])
+         #last_time = note_performance_align[i]
+         #print(last_time)
+
+     #print(micro_shift_list)
+
+     res = MidiFile(ticks_per_beat=500)
+     res_notes = []
+     start_time_list = []
+     tempo_list_filter = []
+     cc_list = []
+     last = -1
+     for i in range(len(score_notes)):
+         start_time_list.append(round(score_notes[i].start + micro_shift_list[i]))
+         if last != round(note_tempo_list[i]):
+             tempo_list_filter.append(TempoChange(round(note_tempo_list[i]), round(score_notes[i].start + micro_shift_list[i])))
+             last = round(note_tempo_list[i])
+     for i in range(len(score_notes)):
+         res_notes.append(
+             Note(
+                 performance_notes[i].velocity,
+                 score_notes[i].pitch,
+                 start_time_list[i],
+                 start_time_list[i] + compute_duration(start_time_list[i], performance_notes[i].duration, tempo_list_filter)
+             )
+         )
+
+     for cc in performance_ccs:
+         cc_list.append(ControlChange(64, cc.value, ms_to_tick(cc.time, tempo_list_filter)))
+
+     #print(tempo_list_filter)
+     res.tempo_changes = tempo_list_filter
+     res.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=res_notes, control_changes=cc_list))
+     return res
+
+
+ if __name__ == "__main__":
+     pass
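The `progress_callback` hook reports a fraction in [0, 1] across windows; a minimal sketch, assuming a loaded `model` and a parsed `midi` as in app.py:

```python
def report(frac):
    print(f"rendering: {frac:.0%}")

rendered = batch_performance_render(model, [midi], progress_callback=report)
```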
src/model/pianoformer.py ADDED
@@ -0,0 +1,459 @@
+ from typing import Callable, Optional, Union
+ from transformers import T5GemmaModel, T5GemmaConfig, T5GemmaModuleConfig, T5GemmaPreTrainedModel, T5GemmaForConditionalGeneration, AutoTokenizer
+ import torch
+ from transformers.models.t5gemma.modeling_t5gemma import (
+     T5GemmaLMHead,
+     GenerationMixin,
+     logger,
+     T5GemmaSelfAttention,
+     T5GemmaEncoderLayer,
+     T5GemmaRMSNorm,
+     T5GemmaRotaryEmbedding,
+     make_default_2d_attention_mask,
+     create_causal_mask,
+     bidirectional_mask_function,
+     create_sliding_window_causal_mask,
+     sliding_window_bidirectional_mask_function,
+     T5GemmaDecoder
+ )
+
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPastAndCrossAttentions,
+     Seq2SeqLMOutput,
+     Seq2SeqModelOutput,
+     SequenceClassifierOutput,
+     TokenClassifierOutput,
+ )
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
+ from transformers.processing_utils import Unpack
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
+ import torch.nn as nn
+
+ class PianoT5GemmaConfig(T5GemmaConfig):
+
+     def __init__(
+         self,
+         hidden_size=768,
+         intermediate_size=3072,
+         num_attention_heads=8,
+         num_key_value_heads=4,
+         head_dim=128,
+         encoder_layers_num=8,
+         decoder_layers_num=4,
+         **kwargs
+     ):
+         total_vocab_size = 5389
+
+         self.mask_token_id = 1
+         self.bos_token_id = 2
+         self.play_token_id = 4
+         self.pitch_start = 5
+         self.velocity_start = 5 + 128
+         self.timing_start = 5 + 128 + 128
+         self.pedal_start = 5 + 128 + 128 + 5000
+         self.hidden_size = hidden_size
+
+         # Per-slot (low, high) id bounds for the 8 tokens of a note:
+         # pitch, interval, velocity, duration, and four pedal samples.
+         self.valid_id_range = [
+             (5, 133),
+             (261, 5252),
+             (133, 261),
+             (261, 5261),
+             (5261, 5389),
+             (5261, 5389),
+             (5261, 5389),
+             (5261, 5389),
+         ]
+
+         encoder_config = T5GemmaModuleConfig(
+             vocab_size=total_vocab_size,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             num_hidden_layers=encoder_layers_num,
+             num_attention_heads=num_attention_heads,
+             num_key_value_heads=num_key_value_heads,
+             head_dim=head_dim,
+             pad_token_id=0,
+             bos_token_id=2,
+             eos_token_id=3,
+         )
+         decoder_config = T5GemmaModuleConfig(
+             vocab_size=total_vocab_size,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             num_hidden_layers=decoder_layers_num,
+             num_attention_heads=num_attention_heads,
+             num_key_value_heads=num_key_value_heads,
+             head_dim=head_dim,
+             pad_token_id=0,
+             bos_token_id=2,
+             eos_token_id=3,
+         )
+
+         super().__init__(
+             encoder=encoder_config,
+             decoder=decoder_config,
+             vocab_size=total_vocab_size,
+             **kwargs,
+         )
+
+ class PianoEncoderEmbeddings(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+         if config.hidden_size % 8 != 0:
+             raise ValueError("Invalid hidden size: must be a multiple of 8.")
+         self.projection_layers = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size // 8) for _ in range(8)])
+         self.hidden_size = config.hidden_size
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+     ) -> torch.Tensor:
+         input_shape = input_ids.size()
+
+         batch_size = input_shape[0]
+         seq_length = input_shape[1]
+
+         # Embed all tokens, then compress each group of 8 note tokens into a
+         # single position: one projection per slot, concatenated back to hidden_size.
+         inputs_embeds = self.word_embeddings(input_ids)
+         grouped_embeds = inputs_embeds.view(batch_size, seq_length // 8, 8, -1)
+         projection_list = []
+         for i in range(8):
+             projection_list.append(self.projection_layers[i](grouped_embeds[:, :, i, :]))
+         projection_cat = torch.cat(projection_list, dim=-1)
+         inputs_embeds = projection_cat.view(batch_size, -1, self.hidden_size)
+         embeddings = inputs_embeds
+         return embeddings
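+ # Shape walk-through (worked example): input_ids of shape (2, 4096) with
+ # hidden_size=768 embed to (2, 4096, 768), regroup to (2, 512, 8, 768), each
+ # slot projects to 768 // 8 = 96 dims, and concatenation restores 768: the
+ # encoder sees one position per 8-token note.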
+
+
+ class PianoT5GemmaEncoder(T5GemmaPreTrainedModel):
+     _can_record_outputs = {
+         "attentions": T5GemmaSelfAttention,
+         "hidden_states": T5GemmaEncoderLayer,
+     }
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.embeddings = PianoEncoderEmbeddings(config)
+         self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.rotary_emb = T5GemmaRotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+
+         self.layers = nn.ModuleList(
+             [T5GemmaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+         self.dropout = nn.Dropout(config.dropout_rate)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> BaseModelOutput:
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embeddings(input_ids)
+             input_ids = None
+
+         cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         if attention_mask is not None:
+             B, L = attention_mask.shape
+             block_mask = attention_mask.view(B, L // 8, 8)
+             mask2 = block_mask.any(dim=-1).long()
+             attention_mask = mask2.view(B, -1)
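+             # PianoEncoderEmbeddings pools 8 tokens into one position, so pool
+             # the token-level mask the same way: a note position stays visible
+             # if any of its 8 tokens is unmasked.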
+
+         if attention_mask is None:
+             attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)
+
+         if not isinstance(self_attn_mask_mapping := attention_mask, dict):
+             mask_kwargs = {
+                 "config": self.config,
+                 "input_embeds": inputs_embeds,
+                 "attention_mask": attention_mask,
+                 "cache_position": cache_position,
+                 "past_key_values": None,
+                 "position_ids": position_ids,
+             }
+             self_attn_mask_mapping = {
+                 "full_attention": create_causal_mask(
+                     **mask_kwargs,
+                     or_mask_function=bidirectional_mask_function(attention_mask),
+                 ),
+                 "sliding_attention": create_sliding_window_causal_mask(
+                     **mask_kwargs,
+                     or_mask_function=sliding_window_bidirectional_mask_function(self.config.sliding_window),
+                     and_mask_function=bidirectional_mask_function(attention_mask),
+                 ),
+             }
+
+         hidden_states = inputs_embeds
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+         hidden_states = hidden_states * normalizer
+         hidden_states = self.dropout(hidden_states)
+
+         for layer_module in self.layers[: self.config.num_hidden_layers]:
+             hidden_states = layer_module(
+                 hidden_states,
+                 position_embeddings,
+                 self_attn_mask_mapping[layer_module.attention_type],
+                 position_ids,
+                 **kwargs,
+             )
+         hidden_states = self.norm(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+         )
+
+ class PianoT5GemmaModel(T5GemmaPreTrainedModel):
+     def __init__(self, config: T5GemmaConfig):
+         super().__init__(config)
+
+         if not config.is_encoder_decoder:
+             raise ValueError("T5GemmaModel only supports encoder-decoder modeling. Use `T5GemmaEncoderModel` instead.")
+
+         self.encoder = PianoT5GemmaEncoder(config.encoder)
+         self.decoder = T5GemmaDecoder(config.decoder)
+
+         self.post_init()
+
+     def get_encoder(self):
+         return self.encoder
+
+     def get_decoder(self):
+         return self.decoder
+
+     def get_input_embeddings(self):
+         return self.encoder.get_input_embeddings()
+
+     def set_input_embeddings(self, new_embeddings):
+         return self.encoder.set_input_embeddings(new_embeddings)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         decoder_position_ids: Optional[torch.LongTensor] = None,
+         encoder_outputs: Optional[BaseModelOutput] = None,
+         past_key_values: Optional[EncoderDecoderCache] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         decoder_inputs_embeds: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> Seq2SeqModelOutput:
+         r"""
+         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
+             Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
+             range `[0, config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+         """
+         if encoder_outputs is None:
+             encoder_outputs = self.encoder(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 inputs_embeds=inputs_embeds,
+                 **kwargs,
+             )
+
+         encoder_hidden_states = encoder_outputs.last_hidden_state
+
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             position_ids=decoder_position_ids,
+             inputs_embeds=decoder_inputs_embeds,
+             past_key_values=past_key_values,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=attention_mask,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         return Seq2SeqModelOutput(
+             last_hidden_state=decoder_outputs.last_hidden_state,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states
+             if kwargs.get("output_hidden_states", False)
+             else (decoder_outputs.last_hidden_state,),
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+             encoder_hidden_states=encoder_outputs.hidden_states,
+             encoder_attentions=encoder_outputs.attentions,
+         )
+
+
+ class PianoT5Gemma(T5GemmaPreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["model.decoder.embed_tokens.weight", "lm_head.out_proj.weight"]
+     _tp_plan = {"lm_head.out_proj": "colwise_rep"}
+     _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])}
+
+     def __init__(self, config: PianoT5GemmaConfig):
+         config.is_encoder_decoder = True
+         super().__init__(config)
+         self.embeddings = PianoEncoderEmbeddings(config)
+         self.model = PianoT5GemmaModel(config)
+         self.vocab_size = config.decoder.vocab_size
+         self.lm_head = T5GemmaLMHead(config.decoder.hidden_size, self.vocab_size)
+         self.loss_type = "ForMaskedLM"
+
+         self.post_init()
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head.out_proj = new_embeddings
+
+     def get_output_embeddings(self):
+         return self.lm_head.out_proj
+
+     def _tie_weights(self):
+         # Decoder input and output embeddings are tied.
+         if self.config.tie_word_embeddings:
+             self._tie_or_clone_weights(self.lm_head.out_proj, self.get_decoder().get_input_embeddings())
+
+     def get_encoder(self):
+         return self.model.encoder
+
+     def get_decoder(self):
+         return self.model.decoder
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         decoder_position_ids: Optional[torch.LongTensor] = None,
+         encoder_outputs: Optional[BaseModelOutput] = None,
+         past_key_values: Optional[EncoderDecoderCache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         logits_to_keep: Union[int, torch.Tensor] = 0,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+         r"""
+         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
+             Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
+             range `[0, config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see the `input_ids` docstring). Tokens with indices set to `-100` are
+             ignored (masked); the loss is only computed for tokens with labels in `[0, ..., config.vocab_size]`.
+         """
+         if self.training and self.config._attn_implementation != "eager":
+             msg = (
+                 "It is strongly recommended to train T5Gemma models with the `eager` attention implementation "
+                 f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+             )
+             if is_torchdynamo_compiling():
+                 raise ValueError(msg)
+             else:
+                 logger.warning_once(msg)
+
+         if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+             # get decoder inputs by shifting the lm labels to the right
+             decoder_input_ids = self._shift_right(labels)
+
+         decoder_outputs: Seq2SeqModelOutput = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             decoder_input_ids=decoder_input_ids,
+             decoder_attention_mask=decoder_attention_mask,
+             decoder_position_ids=decoder_position_ids,
+             encoder_outputs=encoder_outputs,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             decoder_inputs_embeds=decoder_inputs_embeds,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         hidden_states = decoder_outputs.last_hidden_state
+         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+         decoder_config = self.get_decoder().config
+         if decoder_config.final_logit_softcapping is not None:
+             logits = logits / decoder_config.final_logit_softcapping
+             logits = torch.tanh(logits)
+             logits = logits * decoder_config.final_logit_softcapping
+
+         loss = None
+         if labels is not None:
+             # Inputs are already right-shifted, so compute the masked lm loss directly
+             loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
+
+         return Seq2SeqLMOutput(
+             loss=loss,
+             logits=logits,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.decoder_hidden_states,
+             decoder_attentions=decoder_outputs.decoder_attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state,
+             encoder_hidden_states=decoder_outputs.encoder_hidden_states,
+             encoder_attentions=decoder_outputs.encoder_attentions,
+         )
+
+     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+         return self._shift_right(labels)
+
+
+ if __name__ == "__main__":
+     config = PianoT5GemmaConfig()
+     test = PianoEncoderEmbeddings(config)
+     model = PianoT5Gemma(config).cuda()
+     #encoder_config = T5GemmaModuleConfig(num_hidden_layers=1)
+     #decoder_config = T5GemmaModuleConfig(num_hidden_layers=1)
+     #config = T5GemmaConfig(encoder_config, decoder_config, attn_implementation='eager')
+
+     #model = T5GemmaForConditionalGeneration(config).cuda()
+
+     toy_ids = torch.tensor([[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]], dtype=torch.long).cuda()
+     #tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-2b-2b-ul2")
+     #input_text = "Write me a poem about Machine Learning. Answer:"
+     #input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
+     print(model.generate(toy_ids, decoder_input_ids=toy_ids, max_new_tokens=32))
+
+     #print(model(input_ids=toy_ids, decoder_input_ids=toy_ids).logits.shape)
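The vocabulary layout wired into `PianoT5GemmaConfig` can be sanity-checked directly (a sketch, not part of the commit):

```python
from src.model.pianoformer import PianoT5GemmaConfig

cfg = PianoT5GemmaConfig()
print(cfg.pitch_start, cfg.velocity_start, cfg.timing_start, cfg.pedal_start)
# -> 5 133 261 5261: special tokens, then 128 pitches, 128 velocities,
#    5000 timing values, and 128 pedal values inside the 5389-id vocabulary
print(cfg.valid_id_range)  # per-slot id bounds enforced during decoding
```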
src/utils/func.py ADDED
@@ -0,0 +1,5 @@
+ import inspect
+
+ def filter_valid_args(arg_dict, class_type):
+     valid_keys = inspect.signature(class_type).parameters.keys()
+     return {k: v for k, v in arg_dict.items() if k in valid_keys}
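Usage sketch (hypothetical values): the helper drops any keys a constructor would reject.

```python
from dataclasses import dataclass

@dataclass
class RenderArgs:
    temperature: float = 1.0
    top_p: float = 0.95

raw = {"temperature": 0.8, "top_p": 0.9, "unknown_flag": True}
print(filter_valid_args(raw, RenderArgs))  # {'temperature': 0.8, 'top_p': 0.9}
```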
src/utils/midi.py ADDED
@@ -0,0 +1,602 @@
+ from miditoolkit import MidiFile, Note, Instrument, TempoChange, ControlChange
+ import bisect
+ import numpy as np
+ import os
+ from copy import copy
+ import random
+ from collections import defaultdict
+
+ """
+ def normalize_midi(midi_obj, target_ticks_per_beat = 500, target_tempo = 120):
+     ticks_per_beat = midi_obj.ticks_per_beat
+     merged_events = []
+     for i in range(len(midi_obj.instruments)):
+         filter_control_changes = []
+         for cc in midi_obj.instruments[i].control_changes:
+             if cc.number == 64:
+                 filter_control_changes.append(cc)
+         merged_events.extend(midi_obj.instruments[i].notes + filter_control_changes)
+     merged_events.sort(key=lambda x: (x.start, x.pitch) if isinstance(x, Note) else (x.time, x.number))
+
+     time_interval = []
+     last_time = 0
+     for note in merged_events:
+         if isinstance(note, Note):
+             time_interval.append(note.start - last_time)
+             last_time = note.start
+         else:
+             time_interval.append(note.time - last_time)
+             last_time = note.time
+
+     output_notes = []
+     output_cc = []
+     ind = -1
+     now_tempo = 120
+     now_time = 0
+     for i, note in enumerate(merged_events):
+         if isinstance(note, Note):
+             time = note.start
+         else:
+             time = note.time
+         while ind + 1 < len(midi_obj.tempo_changes) and time >= midi_obj.tempo_changes[ind+1].time:
+             now_tempo = midi_obj.tempo_changes[ind+1].tempo
+             ind += 1
+         ratio = target_ticks_per_beat * target_tempo / now_tempo / ticks_per_beat
+         start_time = time_interval[i] * ratio + now_time
+         if isinstance(note, Note):
+             end_time = (note.end - note.start) * ratio + start_time
+             output_notes.append(Note(note.velocity, note.pitch, round(start_time), round(end_time)))
+         else:
+             output_cc.append(ControlChange(64, note.value, round(start_time)))
+         now_time = round(start_time)
+
+     output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
+     output_midi_obj.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=output_notes, control_changes=output_cc))
+     output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
+     for note in output_notes:
+         output_midi_obj.max_tick = max(output_midi_obj.max_tick, note.end)
+     for cc in output_cc:
+         output_midi_obj.max_tick = max(output_midi_obj.max_tick, cc.time)
+     return output_midi_obj
+ """
+
+ """
+ def normalize_midi(midi_obj, target_ticks_per_beat=500, target_tempo=120):
+     # Create a new, clean MidiFile object for the output
+     output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
+     output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
+
+     # Get the exact tick-to-second mapping of the source MIDI.
+     # This is the crucial step; partitura and miditoolkit both offer it, and
+     # miditoolkit's get_tick_to_time_mapping() handles every tempo change.
+     tick_to_time_map = midi_obj.get_tick_to_time_mapping()
+
+     # Scale factor for converting seconds back to target ticks:
+     # in the target MIDI, ticks per second = target_ticks_per_beat * (target_tempo / 60)
+     seconds_to_target_ticks_factor = target_ticks_per_beat * (target_tempo / 60.0)
+
+     merged_notes = []
+     merged_cc = []
+
+     # Iterate over all instrument tracks
+     for instrument in midi_obj.instruments:
+         # Only process non-drum instruments
+         if not instrument.is_drum:
+             # --- Notes ---
+             for note in instrument.notes:
+                 # 1. Convert the original ticks to absolute seconds
+                 start_time_sec = tick_to_time_map[note.start]
+                 end_time_sec = tick_to_time_map[note.end]
+
+                 # 2. Convert the absolute seconds to target ticks
+                 new_start_tick = round(start_time_sec * seconds_to_target_ticks_factor)
+                 new_end_tick = round(end_time_sec * seconds_to_target_ticks_factor)
+
+                 # Avoid zero-duration notes
+                 if new_start_tick == new_end_tick:
+                     new_end_tick += 1
+
+                 merged_notes.append(Note(velocity=note.velocity,
+                                          pitch=note.pitch,
+                                          start=new_start_tick,
+                                          end=new_end_tick))
+
+             # --- Sustain pedal (CC #64) ---
+             for cc in instrument.control_changes:
+                 if cc.number == 64:
+                     # 1. Convert the original tick to absolute seconds
+                     time_sec = tick_to_time_map[cc.time]
+
+                     # 2. Convert the absolute seconds to target ticks
+                     new_time_tick = round(time_sec * seconds_to_target_ticks_factor)
+
+                     merged_cc.append(ControlChange(number=64,
+                                                    value=cc.value,
+                                                    time=new_time_tick))
+
+     # --- Sort and build the new instrument ---
+     # Sort by start time; for simultaneous events, CC sorts before Note
+     merged_notes.sort(key=lambda x: (x.start, x.pitch))
+     merged_cc.sort(key=lambda x: (x.time, x.number))
+
+     output_instrument = Instrument(program=0, is_drum=False, name="Piano")
+     output_instrument.notes = merged_notes
+     output_instrument.control_changes = merged_cc
+     output_midi_obj.instruments.append(output_instrument)
+
+     # --- Compute max_tick correctly ---
+     max_tick = 0
+     if output_instrument.notes:
+         max_tick = max(max_tick, max(n.end for n in output_instrument.notes))
+     if output_instrument.control_changes:
+         max_tick = max(max_tick, max(c.time for c in output_instrument.control_changes))
+
+     output_midi_obj.max_tick = max_tick
+
+     return output_midi_obj
+ """
+
+ def normalize_midi(midi_obj, target_ticks_per_beat=500, target_tempo=120):
+     """
+     Normalize a MidiFile object:
+     1. Merge the piano notes and sustain-pedal events from all tracks.
+     2. Convert all timing information (including tempo changes) to a fixed
+        ticks_per_beat and tempo.
+     3. Clean up overlapping notes to avoid parsing errors.
+     4. Compute and set max_tick correctly.
+
+     Args:
+         midi_obj (MidiFile): the source MidiFile object.
+         target_ticks_per_beat (int): target ticks_per_beat.
+         target_tempo (float): target tempo (BPM).
+
+     Returns:
+         MidiFile: the new, normalized MidiFile object.
+     """
+
+     # Create a new, clean MidiFile object for the output
+     output_midi_obj = MidiFile(ticks_per_beat=target_ticks_per_beat)
+     output_midi_obj.tempo_changes.append(TempoChange(target_tempo, 0))
+
+     tick_to_time_map = midi_obj.get_tick_to_time_mapping()
+     seconds_to_target_ticks_factor = target_ticks_per_beat * (target_tempo / 60.0)
+
+     # --- 1. Collect and convert all notes ---
+     all_converted_notes = []
+     for instrument in midi_obj.instruments:
+         if not instrument.is_drum:
+             for note in instrument.notes:
+                 start_time_sec = tick_to_time_map[note.start]
+                 end_time_sec = tick_to_time_map[note.end]
+
+                 new_start_tick = round(start_time_sec * seconds_to_target_ticks_factor)
+                 new_end_tick = round(end_time_sec * seconds_to_target_ticks_factor)
+
+                 if new_start_tick >= new_end_tick:
+                     # Make sure every note is at least 1 tick long
+                     new_end_tick = new_start_tick + 1
+
+                 all_converted_notes.append(Note(velocity=note.velocity,
+                                                 pitch=note.pitch,
+                                                 start=new_start_tick,
+                                                 end=new_end_tick))
+
+     # --- 2. Clean up overlapping notes (the key addition) ---
+     # Group notes by pitch first, then sort each group by start time
+     notes_by_pitch = defaultdict(list)
+     for note in all_converted_notes:
+         notes_by_pitch[note.pitch].append(note)
+
+     merged_notes = []
+     for pitch in sorted(notes_by_pitch.keys()):
+         # Sort this pitch's notes by start time
+         sorted_notes = sorted(notes_by_pitch[pitch], key=lambda n: n.start)
+
+         # Iterate and repair overlaps
+         if len(sorted_notes) > 1:
+             for i in range(len(sorted_notes) - 1):
+                 current_note = sorted_notes[i]
+                 next_note = sorted_notes[i + 1]
+
+                 # If the current note ends at or after the next note starts
+                 if current_note.end >= next_note.start:
+                     # Trim the current note so it ends exactly where the
+                     # next note begins
+                     current_note.end = next_note.start
+                     # If trimming leaves a zero-length note, discard it
+                     # (cleaner than padding it to 1 tick)
+                     if current_note.start >= current_note.end:
+                         # Mark for deletion instead of deleting in place,
+                         # to avoid mutating the list while iterating
+                         current_note.pitch = -1  # an invalid pitch as a marker
+
+         # Append the processed (and unmarked) notes to the final list
+         merged_notes.extend([n for n in sorted_notes if n.pitch != -1])
+
+     # --- 3. Collect and convert CC events ---
+     merged_cc = []
+     for instrument in midi_obj.instruments:
+         if not instrument.is_drum:
+             for cc in instrument.control_changes:
+                 if cc.number == 64:
+                     time_sec = tick_to_time_map[cc.time]
+                     new_time_tick = round(time_sec * seconds_to_target_ticks_factor)
+                     merged_cc.append(ControlChange(number=64,
+                                                    value=cc.value,
+                                                    time=new_time_tick))
+
+     # --- 4. Sort and build the new instrument ---
+     merged_notes.sort(key=lambda x: (x.start, x.pitch))
+     merged_cc.sort(key=lambda x: (x.time, x.number))
+
+     output_instrument = Instrument(program=0, is_drum=False, name="Piano")
+     output_instrument.notes = merged_notes
+     output_instrument.control_changes = merged_cc
+     output_midi_obj.instruments.append(output_instrument)
+
+     # --- 5. Compute max_tick correctly ---
+     max_tick = 0
+     if output_instrument.notes:
+         max_tick = max(max_tick, max(n.end for n in output_instrument.notes if n.end is not None))
+     if output_instrument.control_changes:
+         max_tick = max(max_tick, max(c.time for c in output_instrument.control_changes if c.time is not None))
+
+     # Add a small buffer so the last event is not truncated
+     output_midi_obj.max_tick = max_tick + target_ticks_per_beat
+
+     return output_midi_obj
+
+ def midi_to_ids(config, midi_obj, normalize=True):
+     def get_pedal(time_list, ccs, time):
+         i = bisect.bisect_right(time_list, time)
+         if i == 0:
+             return 0
+         else:
+             return ccs[i-1].value
+     if normalize:
+         norm_midi_obj = normalize_midi(midi_obj)
+     else:
+         norm_midi_obj = midi_obj
+     time_list = [cc.time for cc in norm_midi_obj.instruments[0].control_changes]
+     #print(time_list)
+     intervals = []
+     last_time = 0
+     for note in norm_midi_obj.instruments[0].notes:
+         intervals.append(note.start - last_time)
+         last_time = note.start
+     intervals.append(4990)
+
+     ids = []
+     last_time = 0
+     for i, note in enumerate(norm_midi_obj.instruments[0].notes):
+         interval = config.timing_start + intervals[i]
+
+         pitch = config.pitch_start + note.pitch
+         velocity = config.velocity_start + note.velocity
+         duration = config.timing_start + note.duration
+         last_time = last_time + intervals[i]
+
+         # Sample the sustain pedal at the note onset and at 1/4, 2/4 and 3/4
+         # of the gap to the next onset
+         pedal1 = config.pedal_start + get_pedal(time_list, norm_midi_obj.instruments[0].control_changes, last_time)
+         pedal2 = config.pedal_start + get_pedal(time_list, norm_midi_obj.instruments[0].control_changes, last_time + intervals[i+1] * 1 / 4)
+         pedal3 = config.pedal_start + get_pedal(time_list, norm_midi_obj.instruments[0].control_changes, last_time + intervals[i+1] * 2 / 4)
+         pedal4 = config.pedal_start + get_pedal(time_list, norm_midi_obj.instruments[0].control_changes, last_time + intervals[i+1] * 3 / 4)
+
+         # Clamp every field into its slot's valid id range
+         pitch = min(config.valid_id_range[0][1] - 1, max(config.valid_id_range[0][0], pitch))
+         interval = min(config.valid_id_range[1][1] - 1, max(config.valid_id_range[1][0], interval))
+         velocity = min(config.valid_id_range[2][1] - 1, max(config.valid_id_range[2][0], velocity))
+         duration = min(config.valid_id_range[3][1] - 1, max(config.valid_id_range[3][0], duration))
+         pedal1 = min(config.valid_id_range[4][1] - 1, max(config.valid_id_range[4][0], pedal1))
+         pedal2 = min(config.valid_id_range[5][1] - 1, max(config.valid_id_range[5][0], pedal2))
+         pedal3 = min(config.valid_id_range[6][1] - 1, max(config.valid_id_range[6][0], pedal3))
+         pedal4 = min(config.valid_id_range[7][1] - 1, max(config.valid_id_range[7][0], pedal4))
+
+         ids.extend([pitch, interval, velocity, duration, pedal1, pedal2, pedal3, pedal4])
+     return ids
+
294
+ def ids_to_midi(config, ids, target_ticks_per_beat = 500, target_tempo = 120):
295
+ note_list = []
296
+ cc_list = []
297
+ intervals = []
298
+ for i in range(0, len(ids), 8):
299
+ intervals.append(ids[i+1] - config.timing_start)
300
+ intervals.append(4990)
301
+
302
+ last_time = 0
303
+ for i in range(0, len(ids), 8):
304
+ interval = intervals[i // 8]
305
+ pitch = ids[i] - config.pitch_start
306
+ velocity = ids[i+2] - config.velocity_start
307
+ duration = ids[i+3] - config.timing_start
308
+ pedal1 = ids[i+4] - config.pedal_start
309
+ pedal2 = ids[i+5] - config.pedal_start
310
+ pedal3 = ids[i+6] - config.pedal_start
311
+ pedal4 = ids[i+7] - config.pedal_start
312
+ note_list.append(Note(velocity, pitch, last_time + interval, last_time + interval + duration))
313
+ last_time += interval
314
+ #cc_list.append(ControlChange(64, pedal1, last_time))
315
+ #cc_list.append(ControlChange(64, pedal2, round(last_time + min(intervals[i // 8 + 1] * 1 / 10, 5))))
316
+ #cc_list.append(ControlChange(64, pedal3, round(last_time + max(intervals[i // 8 + 1] * 8 / 10, intervals[i // 8 + 1] * 8 / 10 - 10))))
317
+ #cc_list.append(ControlChange(64, pedal4, round(last_time + max(intervals[i // 8 + 1] * 9 / 10, intervals[i // 8 + 1] * 9 / 10 - 5))))
318
+ cc_list.append(ControlChange(64, pedal1, last_time))
319
+ cc_list.append(ControlChange(64, pedal2, round(last_time + intervals[i // 8 + 1] * 1 / 4)))
320
+ cc_list.append(ControlChange(64, pedal3, round(last_time + intervals[i // 8 + 1] * 2 / 4)))
321
+ cc_list.append(ControlChange(64, pedal4, round(last_time + intervals[i // 8 + 1] * 3 / 4)))
322
+
323
+ max_tick = 0
324
+ for note in note_list:
325
+ max_tick = max(max_tick, note.end)
326
+ for cc in cc_list:
327
+ max_tick = max(max_tick, cc.time)
328
+ max_tick = max_tick + 1
329
+
330
+ output = MidiFile(ticks_per_beat=target_ticks_per_beat)
331
+ output.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=note_list, control_changes=cc_list))
332
+ output.tempo_changes.append(TempoChange(target_tempo, 0))
333
+ output.max_tick = max_tick
334
+
335
+ return output
336
+
337
+ def read_corresp(corresp_path):
338
+ out = []
339
+ performacne_id_list = []
340
+ with open(corresp_path, "r") as f:
341
+ align_txt = f.readlines()
342
+
343
+ score_ids_map = {}
344
+ performance_ids_map = {}
345
+ score_temp_list = []
346
+ performance_temp_list = set()
347
+ for line in align_txt[1:]:
348
+ informs = line.split("\t")
349
+ if informs[0] != '*':
350
+ score_temp_list.append((float(informs[1]), int(informs[3]), int(informs[0])))
351
+ if informs[5] != '*':
352
+ performance_temp_list.add((float(informs[6]), int(informs[8]), int(informs[5])))
353
+ performance_temp_list = list(performance_temp_list)
354
+ score_temp_list.sort()
355
+ performance_temp_list.sort()
356
+ for i, inform in enumerate(score_temp_list):
357
+ score_ids_map[inform[2]] = i
358
+ for i, inform in enumerate(performance_temp_list):
359
+ performance_ids_map[inform[2]] = i
360
+
361
+ for line in align_txt[1:]:
362
+ informs = line.split("\t")
363
+ if informs[0] == '*':
364
+ break
365
+ if informs[5] != '*':
366
+ out.append((score_ids_map[int(informs[0])], performance_ids_map[int(informs[5])]))
367
+ else:
368
+ out.append((score_ids_map[int(informs[0])], -1))
369
+
370
+ for line in align_txt[1:]:
371
+ informs = line.split("\t")
372
+ if informs[5] != '*':
373
+ performacne_id_list.append(performance_ids_map[int(informs[5])])
374
+ if out[0][1] == -1:
375
+ out[0] = (out[0][0], min(performacne_id_list))
376
+ if out[-1][1] == -1:
377
+ out[-1] = (out[-1][0], max(performacne_id_list))
378
+ out.sort()
379
+ return out
380
+
381
+ def interpolate(a, b):
382
+ a = np.array(a) + np.linspace(0, 1e-5, len(a))
383
+ b = np.array(b)
384
+ known_inds = np.where(~np.isnan(b))[0]
385
+ x_known = a[known_inds]
386
+ y_known = b[known_inds]
387
+ res = np.interp(a, x_known, y_known)
388
+ res[known_inds] = b[known_inds]
389
+ return [round(i) for i in res.tolist()]
390
+
391
+ def segment_sequences(x, label, unknown_ids, total_notes, max_consecutive_missing, min_segment_notes):
392
+
393
+ if not unknown_ids:
394
+ if total_notes >= min_segment_notes:
395
+ return [x], [label]
396
+ else:
397
+ return [], []
398
+
399
+ x_segments = []
400
+ label_segments = []
401
+
402
+ unknown_set = set(unknown_ids)
403
+
404
+ last_cut_note_idx = 0
405
+ consecutive_missing_count = 0
406
+
407
+ for i in range(total_notes):
408
+ if i in unknown_set:
409
+ consecutive_missing_count += 1
410
+ else:
411
+ consecutive_missing_count = 0
412
+
413
+ if consecutive_missing_count >= max_consecutive_missing:
414
+ segment_end_note_idx = i - consecutive_missing_count + 1
415
+
416
+ if segment_end_note_idx - last_cut_note_idx >= min_segment_notes:
417
+ start_token = last_cut_note_idx * 8
418
+ end_token = segment_end_note_idx * 8
419
+
420
+ x_segments.append(x[start_token:end_token])
421
+ label_segments.append(label[start_token:end_token])
422
+
423
+ last_cut_note_idx = i + 1
424
+ consecutive_missing_count = 0
425
+
426
+ if total_notes - last_cut_note_idx >= min_segment_notes:
427
+ start_token = last_cut_note_idx * 8
428
+ x_segments.append(x[start_token:])
429
+ label_segments.append(label[start_token:])
430
+
431
+ return x_segments, label_segments
432
+
433
+ def align_score_and_performance(config, score_midi_obj, performance_midi_obj):
434
+ norm_score_midi_obj = normalize_midi(score_midi_obj)
435
+ norm_performance_midi_obj = normalize_midi(performance_midi_obj)
436
+
437
+ norm_score_midi_obj.dump("temp/score.mid")
438
+ norm_performance_midi_obj.dump("temp/performance.mid")
439
+
440
+ os.chdir("./tools/AlignmentTool")
441
+ os.system(f"timeout 120s ./MIDIToMIDIAlign.sh ../../temp/performance ../../temp/score")
442
+ os.chdir("./../../")
443
+
444
+ corresp_list = read_corresp("temp/score_corresp.txt")
445
+ aligned_midi_obj = MidiFile(ticks_per_beat=500)
446
+ score_notes = norm_score_midi_obj.instruments[0].notes
447
+ performance_notes = norm_performance_midi_obj.instruments[0].notes
448
+ score_start_list = []
449
+ output_notes = []
450
+ output_ccs = []
451
+ vel_list = []
452
+ start_list = []
453
+ duration_list = []
454
+ unknown_ids = []
455
+ for i, ids in enumerate(corresp_list):
456
+ if ids[1] != -1:
457
+ vel_list.append(performance_notes[ids[1]].velocity)
458
+ start_list.append(performance_notes[ids[1]].start)
459
+ duration_list.append(performance_notes[ids[1]].end - performance_notes[ids[1]].start)
460
+ else:
461
+ vel_list.append(np.nan)
462
+ duration_list.append(np.nan)
463
+ unknown_ids.append(i)
464
+ score_start_list.append(score_notes[ids[0]].start)
465
+ start_list.sort()
466
+ temp = []
467
+ cnt = 0
468
+ for i in range(len(corresp_list)):
469
+ if i not in unknown_ids:
470
+ temp.append(start_list[cnt])
471
+ cnt += 1
472
+ else:
473
+ temp.append(np.nan)
474
+ start_list = interpolate(score_start_list, temp)
475
+ vel_list = interpolate(start_list, vel_list)
476
+ duration_list = interpolate(start_list, duration_list)
477
+
478
+ end_list = []
479
+ for i, ids in enumerate(corresp_list):
480
+ end = start_list[i]+duration_list[i]
481
+ end_list.append(end)
482
+ output_notes.append(Note(vel_list[i], score_notes[ids[0]].pitch, start_list[i], end))
483
+ max_tick = max(end_list) + 4999
484
+ for cc in norm_performance_midi_obj.instruments[0].control_changes:
485
+ if cc.time <= max_tick:
486
+ output_ccs.append(cc)
487
+ else:
488
+ break
489
+
490
+ aligned_midi_obj.instruments.append(Instrument(program=0, is_drum=False, name="Piano", notes=output_notes, control_changes=output_ccs))
491
+ x = midi_to_ids(config, norm_score_midi_obj)
492
+ label = midi_to_ids(config, aligned_midi_obj, normalize=False)
493
+ assert(len(x) == len(label))
494
+ for i in range(len(x)):
495
+ if i % 8 == 0:
496
+ assert(x[i] == label[i])
497
+
498
+ total_notes = len(score_notes)
499
+ xs, labels = segment_sequences(
500
+ x,
501
+ label,
502
+ unknown_ids,
503
+ total_notes,
504
+ 5,
505
+ 64,
506
+ )
507
+ return xs, labels
508
+
509
+ def enhanced_ids(config, ids):
510
+ res = copy(ids)
511
+ retry = 10
512
+ for i in range(len(res)):
513
+ j = i % 8
514
+ if j == 3:
515
+ value = res[i] - config.valid_id_range[j][0]
516
+ if value == 10:
517
+ noise = 0
518
+ for _ in range(retry):
519
+ n = round(np.random.randn() * 5)
520
+ if n >= -9 and n <= 5:
521
+ noise = n
522
+ break
523
+ else:
524
+ noise = 0
525
+ for _ in range(retry):
526
+ n = round(np.random.randn() * 5)
527
+ if n >= -4 and n <= 5:
528
+ noise = n
529
+ break
530
+ value = min(max(value + noise, 0), 4999)
531
+ res[i] = config.valid_id_range[j][0] + value
532
+ elif j == 2:
533
+ value = res[i] - config.valid_id_range[j][0]
534
+ if value == 5:
535
+ noise = 0
536
+ for _ in range(retry):
537
+ n = round(np.random.randn() * 2.5)
538
+ if n >= -4 and n <= 2:
539
+ noise = n
540
+ break
541
+ elif value == 120:
542
+ noise = 0
543
+ for _ in range(retry):
544
+ n = round(np.random.randn() * 2.5)
545
+ if n >= -2 and n <= 7:
546
+ noise = n
547
+ break
548
+ else:
549
+ noise = 0
550
+ for _ in range(retry):
551
+ n = round(np.random.randn() * 2.5)
552
+ if n >= -2 and n <= 2:
553
+ noise = n
554
+ break
555
+ value = min(max(value + noise, 0), 127)
556
+ res[i] = config.valid_id_range[j][0] + value
557
+ elif j == 1:
558
+ value = res[i] - config.valid_id_range[j][0]
559
+ noise = 0
560
+ for _ in range(retry):
561
+ n = round(np.random.randn() * 5)
562
+ if n >= -4 and n <= 5:
563
+ noise = n
564
+ break
565
+ value = min(max(value + noise, 0), 4990)
566
+ res[i] = config.valid_id_range[j][0] + value
567
+ return res
568
+
569
+ def enhanced_ids_uniform(config, ids):
570
+ res = copy(ids)
571
+ for i in range(len(res)):
572
+ j = i % 8
573
+ if j == 3:
574
+ value = res[i] - config.valid_id_range[j][0]
575
+ if value == 10:
576
+ noise = random.randint(-9, 5)
577
+ else:
578
+ noise = random.randint(-4, 5)
579
+ value = min(max(value + noise, 0), 4999)
580
+ res[i] = config.valid_id_range[j][0] + value
581
+ elif j == 2:
582
+ value = res[i] - config.valid_id_range[j][0]
583
+ if value == 5:
584
+ noise = random.randint(-4, 2)
585
+ elif value == 120:
586
+ noise = random.randint(-2, 7)
587
+ else:
588
+ noise = random.randint(-2, 2)
589
+ value = min(max(value + noise, 0), 127)
590
+ res[i] = config.valid_id_range[j][0] + value
591
+ elif j == 1:
592
+ value = res[i] - config.valid_id_range[j][0]
593
+ noise = random.randint(-4, 5)
594
+ value = min(max(value + noise, 0), 4990)
595
+ res[i] = config.valid_id_range[j][0] + value
596
+ return res
597
+
598
+ #if __name__ == "__main__":
599
+ # midi_obj = MidiFile("data/midi/test/2.mid")
600
+ # ids = midi_to_ids(midi_obj)
601
+ # midi = ids_to_midi(ids)
602
+ # midi.dump("data/rebuild/2.mid")
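The commented-out block above predates the `config` argument on `midi_to_ids`/`ids_to_midi`; a corrected round-trip sketch (paths are illustrative):

```python
from src.model.pianoformer import PianoT5GemmaConfig

config = PianoT5GemmaConfig()
midi_obj = MidiFile("data/midi/test/2.mid")
ids = midi_to_ids(config, midi_obj)      # normalize + tokenize, 8 ids per note
rebuilt = ids_to_midi(config, ids)       # decode back to a 500-tpb, 120 BPM MIDI
rebuilt.dump("data/rebuild/2.mid")
```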