gagannarula committed (verified)
Commit 6d06ff9 · 1 Parent(s): be9659f

gagan/modelmerging_and_multiturn (#6)

- Model merging + audio cache (eabe9bf6849303d5a373ac512aa5ea120e210ea4)

Files changed (5)
  1. NatureLM/config.py +1 -0
  2. NatureLM/models/NatureLM.py +188 -19
  3. app.py +72 -46
  4. configs/inference.yml +1 -0
  5. data_store.py +16 -11
NatureLM/config.py CHANGED
@@ -136,6 +136,7 @@ class GenerateConfig(BaseModel, extra="forbid", validate_assignment=True):
     temperature: float
     repetition_penalty: float
     length_penalty: float
+    merging_alpha: float = 1.0


 class ModelConfig(BaseModel, extra="forbid", validate_assignment=True):
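Because the new field defaults to 1.0, configs that omit merging_alpha keep the current (un-merged) behaviour even though the model uses extra="forbid". A small self-contained sketch of that pattern (toy class only; the real GenerateConfig has more required fields than shown here):

from pydantic import BaseModel


class ToyGenerateConfig(BaseModel, extra="forbid", validate_assignment=True):
    temperature: float
    repetition_penalty: float
    length_penalty: float
    merging_alpha: float = 1.0  # new field; 1.0 means "no merging"


# An old-style config without the key still validates and picks up the default,
# while an explicit value (e.g. from configs/inference.yml) overrides it.
old_cfg = ToyGenerateConfig(temperature=0.1, repetition_penalty=1.0, length_penalty=1.0)
new_cfg = ToyGenerateConfig(
    temperature=0.1, repetition_penalty=1.0, length_penalty=1.0, merging_alpha=0.5
)
print(old_cfg.merging_alpha, new_cfg.merging_alpha)  # 1.0 0.5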
NatureLM/models/NatureLM.py CHANGED
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import hashlib
 import logging
 import os
+from collections import OrderedDict
 from pathlib import Path
 from typing import Literal, Union

@@ -35,8 +37,98 @@ from .Qformer import BertConfig, BertLMHeadModel
 from .utils import StoppingCriteriaSub

 torch.backends.cuda.matmul.allow_tf32 = True
-
-auth_token = os.getenv('llama')
-
+auth_token = os.getenv("llama", None)
+
+
+class AudioEncodingCache:
+    """LRU cache for audio encoding with content-based hashing."""
+
+    def __init__(self, capacity: int = 100):
+        self.capacity = capacity
+        self.cache = OrderedDict()
+        self.hits = 0
+        self.misses = 0
+
+    def _compute_hash(
+        self, raw_wav: torch.Tensor, audio_padding_mask: torch.Tensor | None = None
+    ) -> str:
+        """Compute a hash key from the audio tensor and padding mask."""
+        # Use a sample of the tensor for efficiency (first, middle, last portions)
+        B, L = raw_wav.shape
+        sample_size = min(1000, L)  # Sample 1000 points or entire length if smaller
+
+        # Sample from beginning, middle, and end
+        indices = torch.cat(
+            [
+                torch.arange(min(sample_size // 3, L)),
+                torch.arange(L // 2, min(L // 2 + sample_size // 3, L)),
+                torch.arange(max(0, L - sample_size // 3), L),
+            ]
+        )
+
+        sampled_wav = raw_wav[:, indices].cpu().numpy().tobytes()
+
+        # Create hash from audio data, shape, and padding mask presence
+        hash_obj = hashlib.sha256(sampled_wav)
+        hash_obj.update(str(raw_wav.shape).encode())
+        hash_obj.update(str(raw_wav.dtype).encode())
+
+        if audio_padding_mask is not None:
+            mask_sample = audio_padding_mask[:, indices].cpu().numpy().tobytes()
+            hash_obj.update(mask_sample)
+            hash_obj.update(str(audio_padding_mask.shape).encode())
+        else:
+            hash_obj.update(b"no_mask")
+
+        return hash_obj.hexdigest()
+
+    def get(self, raw_wav: torch.Tensor, audio_padding_mask: torch.Tensor = None):
+        """Retrieve cached encoding if available."""
+        key = self._compute_hash(raw_wav, audio_padding_mask)
+
+        if key in self.cache:
+            self.hits += 1
+            # Move to end (most recently used)
+            self.cache.move_to_end(key)
+            return self.cache[key]
+
+        self.misses += 1
+        return None
+
+    def put(self, raw_wav: torch.Tensor, audio_padding_mask: torch.Tensor, value: tuple):
+        """Store encoding in cache (on CPU to save GPU memory)."""
+        key = self._compute_hash(raw_wav, audio_padding_mask)
+
+        # Move tensors to CPU for storage
+        audio_embeds, audio_atts = value
+        cached_value = (audio_embeds.cpu(), audio_atts.cpu())
+
+        # Add to cache
+        self.cache[key] = cached_value
+        self.cache.move_to_end(key)
+
+        # Evict oldest if over capacity
+        if len(self.cache) > self.capacity:
+            self.cache.popitem(last=False)
+
+    def clear(self):
+        """Clear the cache."""
+        self.cache.clear()
+        self.hits = 0
+        self.misses = 0
+
+    def get_stats(self):
+        """Get cache statistics."""
+        total = self.hits + self.misses
+        hit_rate = self.hits / total if total > 0 else 0
+        return {
+            "hits": self.hits,
+            "misses": self.misses,
+            "hit_rate": hit_rate,
+            "size": len(self.cache),
+            "capacity": self.capacity,
+        }
+
+
 class NatureLM(nn.Module, PyTorchModelHubMixin):
     def __init__(
@@ -65,9 +157,16 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
         max_txt_len: int = 128,
         end_sym: str = "</s>",
         device: str = "cuda",
+        audio_encoding_cache_size: int = 100,
     ):
         super().__init__()

+        self.audio_encoding_cache = (
+            AudioEncodingCache(capacity=audio_encoding_cache_size)
+            if audio_encoding_cache_size > 0
+            else None
+        )
+
         self.beats_path = beats_path
         self.beats_cfg = beats_cfg
         self.use_audio_Qformer = use_audio_Qformer
@@ -84,7 +183,9 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):

         logging.info(f"Llama path: {llama_path}")
         logging.info("Loading Llama Tokenizer")
-        self.llama_tokenizer = AutoTokenizer.from_pretrained(llama_path, use_fast=False, use_auth_token=auth_token)
+        self.llama_tokenizer = AutoTokenizer.from_pretrained(
+            llama_path, use_fast=False, use_auth_token=auth_token
+        )
         self.llama_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         self.llama_tokenizer.padding_side = "right"

@@ -95,7 +196,6 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
             torch_dtype=torch.float32,
             attn_implementation="eager",
             device_map="cpu",
-            use_auth_token=auth_token
         )
         # An issue with tiny-llama is that pad_token_id was set to -1, but
         # model.save_pretrained checks generation configs and does not allow -1 as
@@ -106,7 +206,6 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
             llama_path,
             torch_dtype=torch.bfloat16,
             attn_implementation=flash_attn,
-            use_auth_token=auth_token
         )

         self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
@@ -135,7 +234,9 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
         self.beats = BEATs(cfg=BEATsConfig(dict(self.beats_cfg)))

         if self.beats_path:
-            beats_ckpt = universal_torch_load(self.beats_path, cache_mode="none", map_location="cpu")
+            beats_ckpt = universal_torch_load(
+                self.beats_path, cache_mode="none", map_location="cpu"
+            )
             self.beats.load_state_dict(beats_ckpt["model"])

         self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim)
@@ -336,11 +437,15 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
             audio_embeds = self.ln_audio(audio_embeds)

             # Generate attention mask
-            audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
+            audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
+                audio_embeds.device
+            )

             if self.window_level_Qformer:
                 B, T, C = audio_embeds.shape  # batch, T, Channels
-                kernel = round(1500 * self.second_per_window / 30.0)  # 160 ms patches; calculate kernel size
+                kernel = round(
+                    1500 * self.second_per_window / 30.0
+                )  # 160 ms patches; calculate kernel size
                 stride = round(1500 * self.second_stride / 30.0)  # Calculate stride size
                 kernel = (1, kernel)
                 stride = (1, stride)
@@ -360,7 +465,9 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
                     audio_embeds_overlap, [0, 3, 2, 1]
                 )  # (B, num_windows, kernel_size, C)
                 audio_embeds = audio_embeds_overlap.reshape(-1, kernel[1], C)
-                audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
+                audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
+                    audio_embeds.device
+                )

             # Q-Former mechanism
             query_tokens = self.audio_query_tokens.expand(audio_embeds.shape[0], -1, -1)
@@ -376,13 +483,19 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
             if self.window_level_Qformer:
                 audio_embeds = audio_embeds.view(B, -1, audio_embeds.size(2)).contiguous()

-                audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
+                audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
+                    audio_embeds.device
+                )

         elif self.htsat:
             # HTSAT processing
             audio_embeds = self.ln_audio(audio_embeds)
-            audio_embeds = self.audio_llama_proj(audio_embeds).reshape(-1, 30, self.llama_model.config.hidden_size)
-            audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(audio_embeds.device)
+            audio_embeds = self.audio_llama_proj(audio_embeds).reshape(
+                -1, 30, self.llama_model.config.hidden_size
+            )
+            audio_atts = torch.ones(audio_embeds.size()[:-1], dtype=torch.long).to(
+                audio_embeds.device
+            )

         else:
             raise NotImplementedError("no audio qformer or max pooling")
@@ -390,9 +503,32 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
         return audio_embeds, audio_atts

     def encode_audio(self, raw_wav, audio_padding_mask=None):
+        # Only use cache during inference (not training)
+        if self.audio_encoding_cache is not None and not self.training:
+            cached_result = self.audio_encoding_cache.get(raw_wav, audio_padding_mask)
+            if cached_result is not None:
+                print("#### Audio encoding cache hit ####")
+                # Move cached tensors back to the model's device
+                audio_embeds, audio_atts = cached_result
+                return audio_embeds.to(self.device), audio_atts.to(self.device)
+
+        # Compute encoding if not cached
         with torch.autocast(self.device.type, dtype=torch.bfloat16):
             audio_embeds, audio_pad_mask = self.beats(raw_wav, padding_mask=audio_padding_mask)
-            return self._encode_auditory_feature(audio_embeds=audio_embeds, audio_pad_mask=audio_pad_mask)
+            result = self._encode_auditory_feature(
+                audio_embeds=audio_embeds, audio_pad_mask=audio_pad_mask
+            )
+
+            # Store in cache if enabled and in inference mode
+            if self.audio_encoding_cache is not None and not self.training:
+                self.audio_encoding_cache.put(raw_wav, audio_padding_mask, result)
+
+            return result
+
+    def clear_audio_embed_cache(self):
+        """Clear the audio encoding cache."""
+        if self.audio_encoding_cache is not None:
+            self.audio_encoding_cache.clear()

     def prompt_wrap(self, audio_embeds, audio_atts, prompt: list[str]):
         """Merge audio embeddings with embeddings of the tokens in the prompt.
@@ -440,7 +576,9 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
         wrapped_atts = []

         for part in prompt_parts:
-            tokens = self.llama_tokenizer(part, return_tensors="pt", add_special_tokens=False).to(device)
+            tokens = self.llama_tokenizer(
+                part, return_tensors="pt", add_special_tokens=False
+            ).to(device)
             part_embeds = self.llama_embed_tokens(tokens.input_ids).squeeze(0)
             part_atts = tokens.attention_mask.squeeze(0)
             wrapped_embeds.append(part_embeds)
@@ -506,7 +644,9 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):

         # BOS token embeddings
         bos_token_id = self.llama_tokenizer.bos_token_id
-        bos = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=audio_embeds.device)
+        bos = torch.full(
+            (batch_size, 1), bos_token_id, dtype=torch.long, device=audio_embeds.device
+        )
         bos_embeds = self.llama_embed_tokens(bos)

         # Prepare lists to collect per-sample embeddings, attention masks, and targets
@@ -521,7 +661,9 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):

             # Extract non-padded text embeddings and attention mask
             text_embed = to_regress_embeds[i][to_regress_tokens.attention_mask[i].bool()]
-            text_att = to_regress_tokens.attention_mask[i][to_regress_tokens.attention_mask[i].bool()]
+            text_att = to_regress_tokens.attention_mask[i][
+                to_regress_tokens.attention_mask[i].bool()
+            ]

             # Extract corresponding targets for the text tokens
             target = targets[i][to_regress_tokens.attention_mask[i].bool()]
@@ -581,7 +723,9 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
             shift_logits.view(-1, nvocab),  # Flatten to [batch_size * (seq_len-1), vocab_size]
             shift_labels.view(-1),  # Flatten to [batch_size * (seq_len-1)]
         )
-        loss_per_token = loss_per_token.view(shift_labels.size())  # Reshape back to [batch_size, seq_len-1]
+        loss_per_token = loss_per_token.view(
+            shift_labels.size()
+        )  # Reshape back to [batch_size, seq_len-1]

         # Create mask
         mask = shift_labels != -100  # [batch_size, seq_len-1]
@@ -597,7 +741,9 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
         predicted_tokens = shift_logits.argmax(dim=-1)  # [batch_size, seq_len-1]

         # Compute per-example correct counts
-        correct_per_sample = ((predicted_tokens == shift_labels) & mask).sum(dim=1).float()  # [batch_size]
+        correct_per_sample = (
+            ((predicted_tokens == shift_labels) & mask).sum(dim=1).float()
+        )  # [batch_size]
         total_tokens_per_sample = mask.sum(dim=1).float()  # [batch_size]

         # Total correct and total tokens across the batch
@@ -615,8 +761,31 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):

         return {"loss": loss, "per_example_loss": loss_per_example}

+    def model_merging_scaling(self, merging_alpha, adapter_name="default"):
+        """
+        Performs model merging with the base model by adjusting the scaling of the LoRA adapters as described in
+        "Model Merging Improves Zero-Shot Generalization in Bioacoustic Foundation Models"
+        (https://arxiv.org/abs/2511.05171).
+
+        The best value for alpha is task- and dataset-specific, but the paper found alpha values between
+        0.4 and 0.6 to perform generally well.
+
+        Args:
+            merging_alpha: The merging_alpha used for interpolation.
+            adapter_name (str): The name of the adapter to rescale when merging.
+        """
+
+        for module in self.llama_model.modules():
+            # Check if the module is a LoRA layer and has the specified adapter
+            if hasattr(module, "r") and isinstance(module.r, dict) and adapter_name in module.r:
+                module.scaling[adapter_name] = merging_alpha * module.scaling[adapter_name]
+
     @torch.inference_mode()
-    def generate(self, samples, generate_cfg, prompts):
+    def generate(self, samples, generate_cfg, prompts) -> list[str]:
+        merging_alpha = getattr(generate_cfg, "merging_alpha", 1.0)
+        if merging_alpha != 1.0:
+            self.model_merging_scaling(merging_alpha)
+
         batch_size = len(prompts)

         raw_wav = samples["raw_wav"]
@@ -645,7 +814,7 @@ class NatureLM(nn.Module, PyTorchModelHubMixin):
         stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

         with torch.autocast(self.device.type, dtype=torch.bfloat16):
-            outputs = self.llama_model.generate(
+            outputs = self.llama_model.generate(  # TODO: Wrap the llama_model with outlines https://outlines-dev.github.io/outlines/reference/models/transformers/
                 inputs_embeds=embeds.bfloat16(),
                 max_new_tokens=generate_cfg.max_new_tokens,
                 stopping_criteria=stopping_criteria,
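A standalone sketch of the caching behaviour added above, using toy CPU tensors (the import path is assumed from the file path in this diff; the cached embedding shapes are placeholders):

import torch

from NatureLM.models.NatureLM import AudioEncodingCache  # path assumed from this diff

cache = AudioEncodingCache(capacity=2)
wav_a, wav_b, wav_c = (torch.randn(1, 16000) for _ in range(3))  # ~1 s of fake 16 kHz audio each

assert cache.get(wav_a) is None                                   # miss
cache.put(wav_a, None, (torch.zeros(1, 4, 8), torch.ones(1, 4)))  # placeholder (embeds, atts) pair
assert cache.get(wav_a) is not None                               # hit, now most-recently-used
cache.put(wav_b, None, (torch.zeros(1, 4, 8), torch.ones(1, 4)))
cache.put(wav_c, None, (torch.zeros(1, 4, 8), torch.ones(1, 4)))  # capacity 2, so wav_a is evicted
print(cache.get_stats())  # {'hits': 1, 'misses': 1, 'hit_rate': 0.5, 'size': 2, 'capacity': 2}

Since encode_audio only consults the cache when self.training is False, training batches never see cached embeddings.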
 
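The rescaling in model_merging_scaling works because a LoRA layer's effective weight is W_base + scaling * (B @ A): multiplying `scaling` by merging_alpha is therefore the same as linearly interpolating between the base model (alpha = 0) and the fully fine-tuned model (alpha = 1). A toy numerical check of that identity (illustrative shapes, no peft required):

import torch

torch.manual_seed(0)
d, r = 8, 2
W_base = torch.randn(d, d)
A, B = torch.randn(r, d), torch.randn(d, r)
scaling = 16 / r            # a typical lora_alpha / r factor
alpha = 0.5                 # merging_alpha from the config

W_finetuned = W_base + scaling * (B @ A)
W_scaled_adapter = W_base + (alpha * scaling) * (B @ A)    # what rescaling `scaling` produces
W_interpolated = (1 - alpha) * W_base + alpha * W_finetuned

assert torch.allclose(W_scaled_adapter, W_interpolated)
print("rescaling the LoRA scaling factor == linear weight interpolation")

Note that the rescaling is applied in place, so repeated calls compound the factor (two calls with 0.5 leave the adapters at 0.25 of their original scaling).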
 
app.py CHANGED
@@ -1,4 +1,5 @@
 import spaces
+import uuid
 import warnings
 import traceback
 import numpy as np
@@ -17,19 +18,29 @@ from NatureLM.infer import Pipeline

 from data_store import upload_data

+
 warnings.filterwarnings("ignore")
 SAMPLE_RATE = 16000  # Default sample rate for NatureLM-audio
 DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
+MIN_AUDIO_DURATION: float = 0.5  # seconds
+MAX_HISTORY_TURNS = (
+    3  # Maximum number of conversation turns to include in context (user + assistant pairs)
+)

 # Load model at startup if CUDA is available
 print(f"Device: {DEVICE}")
-if DEVICE == "cuda":
-    print("CUDA available, loading model at startup...")
-    model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
-    model = model.eval().to(DEVICE)
-    model = Pipeline(model)
-else:
-    print("CUDA not available, model will not be loaded at startup")
+model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
+model = model.eval().to(DEVICE)
+model = Pipeline(model)
+
+
+def check_audio_duration_greater(audio_path: str) -> bool:
+    """Check the duration of the audio file."""
+    info = torchaudio.info(audio_path)
+    duration = info.num_frames / info.sample_rate
+    if not duration >= MIN_AUDIO_DURATION:
+        raise gr.Error(f"Audio duration must be at least {MIN_AUDIO_DURATION} seconds.")
+

 def get_spectrogram(audio: torch.Tensor) -> plt.Figure:
     """Generate a spectrogram from the audio tensor."""
@@ -98,13 +109,11 @@ def prompt_lm(
     hop_length_seconds: float = 10.0,
 ) -> list[str]:
     """Generate response using the model
-
     Args:
         audios (list[str]): List of audio file paths
         queries (list[str] | str): Query or list of queries to process
         window_length_seconds (float): Length of the window for processing audio
         hop_length_seconds (float): Hop length for processing audio
-
     Returns:
         list[str]: List of generated responses for each audio-query pair
     """
@@ -157,28 +166,61 @@ def add_user_query(chatbot_history: list[dict], chat_input: str) -> list[dict]:
     return chatbot_history


-def send_data_to_hub(chatbot_history: list[dict], audio: str):
+def send_data_to_hub(chatbot_history: list[dict], audio: str, session_id: str):
     """Upload data to hub"""
     if not chatbot_history or len(chatbot_history) < 2:
         return
     user_text = chatbot_history[-2]["content"]
     model_response = chatbot_history[-1]["content"]
-    upload_data(audio, user_text, model_response)
+    upload_data(audio, user_text, model_response, session_id)


 def get_response(chatbot_history: list[dict], audio_input: str) -> list[dict]:
-    """Generate response from the model based on user input and audio file"""
+    """Generate response from the model based on user input and audio file with conversation history"""
     try:
-        # Get the last user message from chat history
+        # Warn if conversation is getting long
+        num_turns = len(chatbot_history)
+        if num_turns > MAX_HISTORY_TURNS * 2:  # Each turn = user + assistant message
+            gr.Warning(
+                "⚠️ Long conversations may affect response quality. Consider starting a new conversation with the Clear button."
+            )
+
+        # Build conversation context from history
+        conversation_context = []
+        for message in chatbot_history:
+            if message["role"] == "user":
+                conversation_context.append(f"User: {message['content']}")
+            elif message["role"] == "assistant":
+                conversation_context.append(f"Assistant: {message['content']}")
+
+        # Get the last user message
         last_user_message = ""
         for message in reversed(chatbot_history):
             if message["role"] == "user":
                 last_user_message = message["content"]
                 break
-        print("\nUser message:", last_user_message)
+
+        # Format the full prompt with conversation history
+        if len(conversation_context) > 2:  # More than just the current query
+            # Include previous turns (limit to last MAX_HISTORY_TURNS exchanges)
+            recent_context = conversation_context[
+                -(MAX_HISTORY_TURNS + 1) : -1
+            ]  # Exclude current message
+
+            full_prompt = (
+                "Previous conversation:\n"
+                + "\n".join(recent_context)
+                + "\n\nCurrent question: "
+                + last_user_message
+            )
+        else:
+            full_prompt = last_user_message
+
+        print("\nFull prompt with history:", full_prompt)
+
         response = prompt_lm(
             audios=[audio_input],
-            queries=[last_user_message.strip()],
+            queries=[full_prompt.strip()],
             window_length_seconds=100_000,
             hop_length_seconds=100_000,
         )
@@ -192,7 +234,7 @@ def get_response(chatbot_history: list[dict], audio_input: str) -> list[dict]:
         print(f"Error generating response: {e}")
         traceback.print_exc()
         response = "Error generating response. Please try again."
-
+
     # Add model response to chat history
     chatbot_history.append({"role": "assistant", "content": response})

@@ -201,17 +243,7 @@ def get_response(chatbot_history: list[dict], audio_input: str) -> list[dict]:

 def main(
     assets_dir: Path,
-    cfg_path: str | Path,
-    options: list[str] = [],
 ):
-    # Load configuration
-    try:
-        cfg = Config.from_sources(yaml_file=cfg_path, cli_args=options)
-        print("Configuration loaded successfully")
-    except Exception as e:
-        print(f"Warning: Could not load config: {e}")
-        print("Running in demo mode")
-
     # Check if assets directory exists, if not create a placeholder
     if not assets_dir.exists():
         print(f"Warning: Assets directory {assets_dir} does not exist")
@@ -248,13 +280,11 @@ def main(
         "Caption the audio (Humpback Whale)": [str(whale_audio), "Caption the audio."],
     }

-    gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
+    gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])

     with gr.Blocks(
         title="NatureLM-audio",
-        theme=gr.themes.Base(
-            primary_hue="blue", font=[gr.themes.GoogleFont("Noto Sans")]
-        )
+        theme=gr.themes.Base(primary_hue="blue", font=[gr.themes.GoogleFont("Noto Sans")]),
     ) as app:
         with gr.Row():
             gr.HTML("""
@@ -272,7 +302,8 @@ def main(

         with gr.Tabs():
             with gr.Tab("Analyze Audio"):
-                uploaded_audio = gr.State()
+                session_id = gr.State(str(uuid.uuid4()))
+                # uploaded_audio = gr.State()
                 # Status indicator
                 # status_text = gr.Textbox(
                 #     value=model_manager.get_status(),
@@ -297,8 +328,6 @@ def main(
                     """,
                     padding=False,
                 )
-
-

                 with gr.Column(visible=True) as upload_section:
                     audio_input = gr.Audio(
@@ -307,6 +336,14 @@ def main(
                        interactive=True,
                        sources=["upload"],
                    )
+                    # check that audio duration is greater than MIN_AUDIO_DURATION
+                    # raise
+                    audio_input.change(
+                        fn=check_audio_duration_greater,
+                        inputs=[audio_input],
+                        outputs=[],
+                    )
+
                     with gr.Accordion(
                         label="Toggle Spectrogram", open=False, visible=False
                     ) as spectrogram:
@@ -445,13 +482,11 @@ def main(
                     [clear_button],
                 ).then(
                     send_data_to_hub,
-                    [chatbot, audio_input],
+                    [chatbot, audio_input, session_id],
                     None,
                 )

-                clear_button.click(
-                    lambda: gr.ClearButton(visible=False), None, [clear_button]
-                )
+                clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])

             with gr.Tab("Sample Library"):
                 with gr.Row():
@@ -494,7 +529,6 @@ def main(
                        type="filepath",
                        show_download_button=True,
                    )
-

             with gr.Tab("💡 Help"):
                 gr.HTML("""
@@ -519,7 +553,6 @@ def main(
                 </ol>
                 <p></p>
                 </div>
-
                 <div class="guide-section">
                 <h3>Tips</h3>
                 <b>Prompting Best Practices</b>
@@ -561,7 +594,6 @@ def main(
         background: white;
         flex: 1;
     }
-
     #chat-input .submit-button {
         padding: 10px;
         margin: 2px 6px;
@@ -590,7 +622,6 @@ def main(
         color: #374151;
         margin-bottom: 4px;
     }
-
     .banner .banner-text {
         style="font-size: 14px;
         color: #6b7280;
@@ -609,7 +640,6 @@ def main(
         display: inline-block;
         transition: background 0.2s ease;
     }
-
     .link-btn:hover {
         background: #2563eb;
     }
@@ -635,12 +665,10 @@ def main(
     #chat-input {
         background: #1e1e1e;
     }
-
     #chat-input textarea {
         background: #1e1e1e;
         color: white;
     }
-
     .banner {
         background: #1e1e1e;
         color: white;
@@ -657,8 +685,6 @@ def main(
 # Create and launch the app
 app = main(
     assets_dir=Path("assets"),
-    cfg_path=Path("configs/inference.yml"),
-    options=[],
 )

 if __name__ == "__main__":
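A standalone sketch of how the new multi-turn logic flattens the Gradio chat history into a single text prompt for prompt_lm (toy history; the slicing mirrors get_response above):

MAX_HISTORY_TURNS = 3

history = [
    {"role": "user", "content": "What species is vocalizing?"},
    {"role": "assistant", "content": "A Common Nightingale."},
    {"role": "user", "content": "Is it a song or a call?"},
]

context = [f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}" for m in history]
last_user_message = history[-1]["content"]

if len(context) > 2:
    recent = context[-(MAX_HISTORY_TURNS + 1):-1]  # previous turns, excluding the current question
    full_prompt = "Previous conversation:\n" + "\n".join(recent) + "\n\nCurrent question: " + last_user_message
else:
    full_prompt = last_user_message

print(full_prompt)
# Previous conversation:
# User: What species is vocalizing?
# Assistant: A Common Nightingale.
#
# Current question: Is it a song or a call?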
 
configs/inference.yml CHANGED
@@ -59,3 +59,4 @@ generate:
   temperature: 0.1
   repetition_penalty: 1.0
   length_penalty: 1.0
+  merging_alpha: 0.5
data_store.py CHANGED
@@ -7,33 +7,38 @@ from huggingface_hub import HfApi, HfFileSystem
 DATASET_REPO = "EarthSpeciesProject/naturelm-audio-space-logs"
 SPLIT = "test"
 TESTING = os.getenv("TESTING", "0") == "1"
-api = HfApi(token=os.getenv("HF_TOKEN",None))
+api = HfApi(token=os.getenv("HF_TOKEN", None))
 # Upload audio
 # check if file exists
-hf_fs = HfFileSystem(token=os.getenv("HF_TOKEN",None))
+hf_fs = HfFileSystem(token=os.getenv("HF_TOKEN", None))


-def upload_data(audio: str | Path, user_text: str, model_response: str):
+def upload_data(audio: str | Path, user_text: str, model_response: str, session_id: str = ""):
     data_id = str(uuid.uuid4())
+
     if TESTING:
         data_id = "test-" + data_id
+        session_id = "test-" + session_id
+
     # Audio path in repo
     suffix = Path(audio).suffix
-    audio_p = f"{SPLIT}/audio/" + data_id + suffix
+    audio_p = f"{SPLIT}/audio/" + session_id + suffix

-    api.upload_file(
-        path_or_fileobj=str(audio),
-        path_in_repo=audio_p,
-        repo_id=DATASET_REPO,
-        repo_type="dataset",
-    )
+    if not hf_fs.exists(f"datasets/{DATASET_REPO}/{audio_p}"):
+        api.upload_file(
+            path_or_fileobj=str(audio),
+            path_in_repo=audio_p,
+            repo_id=DATASET_REPO,
+            repo_type="dataset",
+        )

     text = {
         "user_message": user_text,
         "model_response": model_response,
-        "file_name": "audio/" + data_id + suffix,  # has to be relative to metadata.jsonl
+        "file_name": "audio/" + session_id + suffix,  # has to be relative to metadata.jsonl
         "original_fn": os.path.basename(audio),
         "id": data_id,
+        "session_id": session_id,
     }

     # Append to a jsonl file in the repo
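With the session-keyed path, the audio file is uploaded once per chat session and each further turn only appends a metadata row. An illustrative record as it would be written to the dataset's metadata.jsonl (all values below are made up; the jsonl append itself happens further down in the file):

{
    "user_message": "What species is vocalizing?",
    "model_response": "A Common Nightingale.",
    "file_name": "audio/<session uuid>.wav",  # keyed by session_id: one audio file per session
    "original_fn": "nightingale.wav",
    "id": "<message uuid>",                   # still a fresh uuid per message
    "session_id": "<session uuid>",
}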
 