import gradio as gr
import torch.nn as nn
import audresample
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
import torch
import librosa
import numpy as np
import types
from transformers import AutoModelForAudioClassification
from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
                                                             Wav2Vec2PreTrainedModel)

plt.style.use('seaborn-v0_8-whitegrid')

class ADV(nn.Module):
    '''Regression head predicting arousal, dominance and valence.'''

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):
        x = self.dense(x)
        x = torch.tanh(x)
        return self.out_proj(x)

class Dawn(Wav2Vec2PreTrainedModel):
    r"""Wav2Vec2 arousal/dominance/valence model, https://arxiv.org/abs/2203.07378"""

    def __init__(self, config):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ADV(config)

    def forward(self, x):
        # zero-mean / unit-variance normalisation of the raw waveform
        x -= x.mean(1, keepdim=True)
        variance = (x * x).mean(1, keepdim=True) + 1e-7
        x = self.wav2vec2(x / variance.sqrt())
        return self.classifier(x.last_hidden_state.mean(1))

def _forward(self, x):
    '''x: (batch, audio-samples-16KHz). Bound below as the forward of the WavLM
    baseline via types.MethodType.'''
    x = (x + self.config.mean) / self.config.std  # normalise with the model config's mean/std
    x = self.ssl_model(x, attention_mask=None).last_hidden_state
    # attentive statistics pooling: attention-weighted mean and standard deviation over time
    h = self.pool_model.sap_linear(x).tanh()
    w = torch.matmul(h, self.pool_model.attention).softmax(1)
    mu = (x * w).sum(1)
    x = torch.cat(
        [
            mu,
            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
        ], 1)
    return self.ser_model(x)

# WavLM
device = 'cpu'
base = AutoModelForAudioClassification.from_pretrained(
    '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
    trust_remote_code=True).to(device).eval()
base.forward = types.MethodType(_forward, base)

# Wav2Vec2
dawn = Dawn.from_pretrained(
    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
).to(device).eval()

# Wav2Small (torch, numpy, librosa, nn and the Wav2Vec2 base classes are already imported above)
import torch.nn.functional as F
from transformers import PretrainedConfig

def _prenorm(x, attention_mask=None):
    '''Zero-mean / unit-variance normalisation of the raw waveform.'''
    if attention_mask is not None:
        # the attention mask here is the unprocessed, sample-level input mask
        N = attention_mask.sum(1, keepdim=True)
        x -= x.sum(1, keepdim=True) / N
        var = (x * x).sum(1, keepdim=True) / N
    else:
        # mean() exports to a single ONNX ReduceMean, saving ops vs. casting N to float and dividing
        x -= x.mean(1, keepdim=True)
        var = (x * x).mean(1, keepdim=True)
    return x / torch.sqrt(var + 1e-7)
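
# A minimal sanity check for _prenorm (an added sketch, not part of the original Space):
# without an attention mask the output should have ~zero mean and ~unit variance.
def _prenorm_check():
    out = _prenorm(torch.randn(2, 16000) * 3 + 5)
    assert out.mean(1).abs().max() < 1e-3
    assert (out.var(1, unbiased=False) - 1).abs().max() < 1e-3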

class Spectrogram(nn.Module):
    '''Power spectrogram computed with two Conv1d layers holding the real and
    imaginary parts of a windowed DFT matrix.'''

    def __init__(self,
                 n_fft=64,      # rows of the DFT matrix; output bins = n_fft // 2 + 1
                 n_time=64,     # window length / columns of the DFT matrix
                 hop_length=32,
                 freeze_parameters=True):
        super().__init__()
        fft_window = librosa.filters.get_window('hann', n_time, fftbins=True)
        fft_window = librosa.util.pad_center(fft_window, size=n_time)
        out_channels = n_fft // 2 + 1
        (x, y) = np.meshgrid(np.arange(n_time), np.arange(n_fft))
        omega = np.exp(-2 * np.pi * 1j / n_time)
        dft_matrix = np.power(omega, x * y)            # (n_fft, n_time)
        dft_matrix = dft_matrix * fft_window[None, :]
        dft_matrix = dft_matrix[0:out_channels, :]
        dft_matrix = dft_matrix[:, None, :]
        # ---- asymmetric (non-square) DFT realised as 1-D convolutions
        self.conv_real = nn.Conv1d(1, out_channels, n_fft, stride=hop_length, padding=0, bias=False)
        self.conv_imag = nn.Conv1d(1, out_channels, n_fft, stride=hop_length, padding=0, bias=False)
        self.conv_real.weight.data = torch.tensor(np.real(dft_matrix),
                                                  dtype=self.conv_real.weight.dtype).to(self.conv_real.weight.device)
        self.conv_imag.weight.data = torch.tensor(np.imag(dft_matrix),
                                                  dtype=self.conv_imag.weight.dtype).to(self.conv_imag.weight.device)
        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        x = input[:, None, :]
        real = self.conv_real(x)
        imag = self.conv_imag(x)
        return real ** 2 + imag ** 2  # (bs, freq-bins, time-frames)
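
# A quick equivalence sketch (added for illustration, not part of the original Space):
# with center=False and the same Hann window, the Conv1d power spectrogram above should
# match |librosa.stft|**2 up to float32 rounding.
def _spectrogram_check():
    sig = np.random.randn(16000).astype(np.float32)
    conv_pow = Spectrogram()(torch.from_numpy(sig)[None, :]).numpy()[0]
    ref_pow = np.abs(librosa.stft(sig, n_fft=64, hop_length=32, win_length=64,
                                  window='hann', center=False)) ** 2
    assert conv_pow.shape == ref_pow.shape == (33, 499)
    assert np.allclose(conv_pow, ref_pow, rtol=1e-2, atol=1e-3)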

class LogmelFilterBank(nn.Module):
    def __init__(self,
                 sr=16000,
                 n_fft=64,
                 n_mels=26,   # halved to 13 mel bins by the max-pool inside Vgg7
                 fmin=0.0,
                 freeze_parameters=True):
        super().__init__()
        fmax = sr // 2
        W2 = librosa.filters.mel(sr=sr,
                                 n_fft=n_fft,
                                 n_mels=n_mels,
                                 fmin=fmin,
                                 fmax=fmax).T
        self.register_buffer('melW', torch.Tensor(W2))
        self.register_buffer('amin', torch.Tensor([1e-10]))

    def forward(self, x):
        # (bs, freq, time) -> (bs, 1, time, mel): only the mel axis changes, not the frame count
        x = torch.matmul(x[:, None, :, :].transpose(2, 3), self.melW)
        x = torch.where(x > self.amin, x, self.amin)  # floor at amin, not in place
        x = 10 * torch.log10(x)
        return x

def length_after_conv_layer(_length, k=None, pad=None, stride=None):
    return torch.floor((_length + 2 * pad - k) / stride + 1)

class Conv(nn.Module):
    def __init__(self, c_in, c_out, k=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, k, stride=stride, padding=padding, bias=False)
        self.norm = nn.BatchNorm2d(c_out)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return torch.relu_(x)

class Vgg7(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = Conv( 1, 13)
        self.l2 = Conv(13, 13)
        self.l3 = Conv(13, 13)
        self.maxpool_A = nn.MaxPool2d(3,
                                      stride=2,
                                      padding=1)
        self.l4 = Conv(13, 13)
        self.l5 = Conv(13, 13)
        self.l6 = Conv(13, 13)
        self.l7 = Conv(13, 13)
        self.lin = nn.Conv2d(13, 13, 1, padding=0, stride=1)
        self.sof = nn.Conv2d(13, 13, 1, padding=0, stride=1)  # attention logits for pooling over time
        self.spectrogram_extractor = Spectrogram()
        self.logmel_extractor = LogmelFilterBank()

    def final_length(self, L):
        # only layers with stride > 1 change the number of time frames
        conv_kernel = [64, 3]   # [n_fft, maxpool kernel]
        conv_stride = [32, 2]   # [hop_length, maxpool stride]
        conv_pad = [0, 1]       # [STFT padding, maxpool padding]
        for k, stride, pad in zip(conv_kernel, conv_stride, conv_pad):
            L = length_after_conv_layer(L, k=k, stride=stride, pad=pad)
        return L
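
    # A small worked example (added for illustration, assuming the 4 s / 64000-sample
    # crop used in process_audio below): the STFT yields floor((64000 - 64) / 32 + 1)
    # = 1999 frames and the max-pool then yields floor((1999 + 2 - 3) / 2 + 1) = 1000,
    # i.e. Vgg7().final_length(torch.tensor([64000.])) == tensor([1000.]).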

    def final_attention_mask(self, feature_vector_length, attention_mask=None):
        # map per-sample input lengths (in audio samples) to numbers of valid output frames
        non_padded_lengths = attention_mask.sum(1)
        out_lengths = self.final_length(non_padded_lengths)
        out_lengths = out_lengths.to(torch.long)
        bs, _ = attention_mask.shape
        attention_mask = torch.ones((bs, feature_vector_length),
                                    dtype=attention_mask.dtype,
                                    device=attention_mask.device)
        for b, _len in enumerate(out_lengths):
            attention_mask[b, _len:] = 0
        return attention_mask

    def forward(self, x, attention_mask=None):
        x = _prenorm(x, attention_mask=attention_mask)
        x = self.spectrogram_extractor(x)
        x = self.logmel_extractor(x)
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.maxpool_A(x)
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        if attention_mask is not None:
            bs, _, t, _ = x.shape
            a = self.final_attention_mask(feature_vector_length=t,
                                          attention_mask=attention_mask)[:, None, :, None]
            # zero the padded frames and exclude them from the attention softmax;
            # the mask has to affect lin as well, hence masked_fill before it
            x = torch.masked_fill(x, a < 1, 0)
            x = self.lin(x) * (self.sof(x) - 10000. * torch.logical_not(a)).softmax(2)
        else:
            x = self.lin(x) * self.sof(x).softmax(2)
        x = x.sum(2)  # (bs, ch, time-frames, half_mel) -> (bs, ch, half_mel)
        # concatenate the pooled features with their channel correlation matrix
        xT = x.transpose(1, 2)
        x = torch.cat([x,
                       torch.bmm(x, xT),  # (ch x mel) @ (mel x ch)
                       ], 2)
        return x.reshape(-1, 338)  # 338 = 2 * 13 * 13 = config.hidden
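
# An illustrative shape trace (added sketch, not part of the original Space), assuming a
# single 4 s / 64000-sample waveform at 16 kHz:
#   spectrogram -> (1, 33, 1999); logmel -> (1, 1, 1999, 26); l1..l3 -> (1, 13, 1999, 26);
#   maxpool_A -> (1, 13, 1000, 13); attention pooling over time -> (1, 13, 13);
#   cat with bmm -> (1, 13, 26) -> reshape -> (1, 338).
def _vgg7_shape_check():
    emb = Vgg7()(torch.randn(1, 64000))
    assert emb.shape == (1, 338)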

class Wav2SmallConfig(PretrainedConfig):
    model_type = "wav2vec2"

    def __init__(self,
                 **kwargs):
        super().__init__(**kwargs)
        self.half_mel = 13
        self.n_fft = 64
        self.n_time = 64
        self.hidden = 2 * self.half_mel * self.half_mel  # 338
        self.hop = self.n_time // 2

class Wav2Small(Wav2Vec2PreTrainedModel):
    def __init__(self,
                 config):
        super().__init__(config)
        self.vgg7 = Vgg7()
        self.adv = nn.Linear(config.hidden, 3)  # 0=arousal, 1=dominance, 2=valence

    def forward(self, x, attention_mask=None):
        x = self.vgg7(x, attention_mask=attention_mask)
        return self.adv(x)

def _ccc(x, y):
    '''Negative concordance correlation coefficient (CCC), usable as a loss.

    If len(x) == len(y) == 1 the plain CCC is 0/0; since numerator and denominator can
    both be negative, 1e-7 is added to the denominator with the denominator's own sign so
    that its sign is preserved. The L term penalises a mean offset between x and y
    (the official CCC has this term in the denominator only).'''
    mean_y = y.mean()
    mean_x = x.mean()
    a = x - mean_x
    b = y - mean_y
    L = (mean_x - mean_y).abs() * .1 * x.shape[0]
    numerator = torch.dot(a, b)
    denominator = torch.dot(a, a) + torch.dot(b, b) + L
    denominator = torch.where(denominator.sign() < 0,
                              denominator - 1e-7,
                              denominator + 1e-7)
    ccc = numerator / denominator
    return -ccc  # + F.l1_loss(a, b)
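
# An illustrative use of _ccc as a training criterion (added sketch with made-up values,
# not part of the original Space). Because the numerator omits the conventional factor
# of 2, identical inputs give a loss of about -0.5, while a constant prediction gives a
# loss of about 0.
def _ccc_example():
    target = torch.tensor([0.2, 0.5, 0.8, 0.4])
    loss_identical = _ccc(target, target)                         # ~ -0.5
    loss_constant = _ccc(torch.full_like(target, 0.475), target)  # ~ 0.0
    return loss_identical, loss_constant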

wav2small = Wav2Small.from_pretrained('audeering/wav2small').to(device).eval()

# Placeholder error figure returned when no audio is provided
fig_error, ax = plt.subplots(figsize=(8, 6))
error_message = "Error: No .wav or Mic. audio provided."
ax.text(0.5, 0.5, error_message,
        ha='center',
        va='center',
        fontsize=24,
        color='gray',
        fontweight='bold',
        transform=ax.transAxes)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_frame_on(True)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)

def process_audio(audio_filepath):
    if audio_filepath is None:
        return fig_error, fig_error
    waveform, sample_rate = librosa.load(audio_filepath, sr=None)
    # Resample the audio to 16 kHz if necessary
    if sample_rate != 16000:
        resampled_waveform_np = audresample.resample(waveform, sample_rate, 16000)
    else:
        resampled_waveform_np = waveform[None, :]
    x = torch.from_numpy(resampled_waveform_np[:, :64000]).to(torch.float)  # keep only 4 s for speed
    with torch.no_grad():
        logits_dawn = dawn(x).cpu().numpy()[0, :]
        logits_wavlm = base(x).cpu().numpy()[0, :]
        logits_wav2small = wav2small(x).cpu().numpy()[0, :]  # 17K-parameter student

    # --- Plot 1: Wav2Vec2 (Dawn) vs Wav2Small (17K params) outputs ---
    fig, ax = plt.subplots(figsize=(10, 6))
    left_bars_data = logits_dawn.clip(0, 1)
    right_bars_data = logits_wav2small.clip(0, 1)
    bar_labels = ['\nArousal', '\nDominance', '\nValence']
    y_pos = np.arange(len(bar_labels))
    # Define colormaps per category so each attribute has a distinct color
    category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
    left_filled_colors = []
    right_filled_colors = []
    background_colors = []
    # Assign specific shades for filled bars and background bars
    for i, cmap in enumerate(category_colormaps):
        left_filled_colors.append(cmap(0.74))
        right_filled_colors.append(cmap(0.64))
        background_colors.append(cmap(0.1))
    # Plot transparent background bars
    for i in range(len(bar_labels)):
        ax.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
        ax.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
    # Plot the filled bars for the actual data
    for i in range(len(bar_labels)):
        ax.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
        ax.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
    # Add a central vertical axis divider
    ax.axvline(0, color='black', linewidth=0.8, linestyle='--')
    # Set x-axis limits and y-axis ticks/labels
    ax.set_xlim(-1, 1)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(bar_labels, fontsize=12)

    # Custom formatter showing absolute percentage values on the x-axis
    def abs_tick_formatter(x, pos):
        return f'{int(abs(x) * 100)}%'

    ax.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
    # Set plot title and x-axis label
    ax.set_title('', fontsize=16, pad=20)
    ax.set_xlabel('Wav2vec2 (Dawn) Wav2Small (17K param.)', fontsize=12)
    # Remove top, right, and left spines for a cleaner look
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    # Annotate the filled bars with their percentage values
    for i in range(len(bar_labels)):
        ax.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
                va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
        ax.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
                va='center', ha='left', color=right_filled_colors[i], fontweight='bold')

    # --- Plot 2: WavLM vs Wav2Small Teacher (average of Dawn and WavLM) ---
    fig_2, ax_2 = plt.subplots(figsize=(10, 6))
    left_bars_data = logits_wavlm.clip(0, 1)
    right_bars_data = (.5 * logits_dawn + .5 * logits_wavlm).clip(0, 1)
    bar_labels = ['\nArousal', '\nDominance', '\nValence']
    y_pos = np.arange(len(bar_labels))
    # Define colormaps per category so each attribute has a distinct color
    category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
    left_filled_colors = []
    right_filled_colors = []
    background_colors = []
    # Assign specific shades for filled bars and background bars
    for i, cmap in enumerate(category_colormaps):
        left_filled_colors.append(cmap(0.74))
        right_filled_colors.append(cmap(0.64))
        background_colors.append(cmap(0.1))
    # Plot transparent background bars
    for i in range(len(bar_labels)):
        ax_2.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
        ax_2.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
    # Plot the filled bars for the actual data
    for i in range(len(bar_labels)):
        ax_2.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
        ax_2.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
    # Add a central vertical axis divider
    ax_2.axvline(0, color='black', linewidth=0.8, linestyle='--')
    # Set x-axis limits and y-axis ticks/labels
    ax_2.set_xlim(-1, 1)
    ax_2.set_yticks(y_pos)
    ax_2.set_yticklabels(bar_labels, fontsize=12)

    # Custom formatter showing absolute percentage values on the x-axis
    def abs_tick_formatter(x, pos):
        return f'{int(abs(x) * 100)}%'

    ax_2.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
    ax_2.set_title('', fontsize=16, pad=20)
    ax_2.set_xlabel('WavLM (Baseline) Wav2Small Teacher (0.4B param.)', fontsize=12)
    ax_2.spines['top'].set_visible(False)
    ax_2.spines['right'].set_visible(False)
    ax_2.spines['left'].set_visible(False)
    # Annotate the filled bars with their percentage values
    for i in range(len(bar_labels)):
        ax_2.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
                  va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
        ax_2.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
                  va='center', ha='left', color=right_filled_colors[i], fontweight='bold')
    return fig, fig_2

iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",   # the callback receives a file path
        label=''
    ),
    outputs=[
        gr.Plot(label="Wav2Vec2 vs Wav2Small (17K params) Plot"),   # first plot output
        gr.Plot(label="WavLM vs Wav2Small Teacher Plot"),           # second plot output
    ],
    title='',
    description='',
    flagging_mode="never",   # disable flagging
    examples=[
        "female-46-neutral.wav",
        "female-20-happy.wav",
        "male-60-angry.wav",
        "male-27-sad.wav",
    ],
    css="footer {visibility: hidden}"   # hide the Gradio footer
)

# Gradio Blocks for the tabbed interface
with gr.Blocks() as demo:
    # First tab: the Arousal / Dominance / Valence plots
    with gr.Tab(label="Arousal / Dominance / Valence"):
        iface.render()
    # Second tab: CCC (Concordance Correlation Coefficient) scores
    with gr.Tab(label="CCC"):
        gr.Markdown('''<table style="width:500px"><tr><th colspan=5 >CCC MSP Podcast v1.7</th></tr>
        <tr> <td> </td><td>Arousal</td> <td>Dominance</td> <td>Valence</td> <td> Associated Paper </td> </tr>
        <tr> <td> <a href="https://huggingface.co/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim">Wav2Vec2</a></td><td>0.744</td><td>0.655</td><td> 0.638 </td><td> <a href="https://arxiv.org/abs/2203.07378">arXiv</a> </td> </tr>
        <tr> <td> <a href="https://huggingface.co/dkounadis/wav2small">Wav2Small Teacher</a></td><td> 0.762 </td> <td> 0.684 </td><td> 0.676 </td><td> <a href="https://arxiv.org/abs/2408.13920">arXiv</a> </td> </tr>
        </table>
        ''')

# Launch the Gradio application
if __name__ == "__main__":
    demo.launch(share=False)