update preprocessing
README.md CHANGED

````diff
@@ -107,6 +107,19 @@ class PowerToDB(torch.nn.Module):
         return log_spec
 
 
+
+# Initialize the transformations
+
+spectrogram_converter = torchaudio.transforms.Spectrogram(
+    n_fft=1024, hop_length=320, power=2.0
+)
+mel_converter = torchaudio.transforms.MelScale(
+    n_mels=128, n_stft=513, sample_rate=32_000
+)
+normalizer = transforms.Normalize((-4.268,), (4.569,))
+powerToDB = PowerToDB(top_db=80)
+
+
 def preprocess(audio, sample_rate_of_audio):
     """
     Preprocess the audio to the format that the model expects
@@ -115,30 +128,28 @@ def preprocess(audio, sample_rate_of_audio):
     - Normalize the melscale spectrogram with mean: -4.268, std: 4.569 (from AudioSet)
 
     """
-
-
-
-
-    )
-    audio = resample(audio)
-    spectrogram = torchaudio.transforms.Spectrogram(
-        n_fft=1024, hop_length=320, power=2.0
-    )(audio)
-    melspec = torchaudio.transforms.MelScale(n_mels=128, n_stft=513)(spectrogram)
+    # convert waveform to spectrogram
+    spectrogram = spectrogram_converter(audio)
+    spectrogram = spectrogram.to(torch.float32)
+    melspec = mel_converter(spectrogram)
     dbscale = powerToDB(melspec)
-    normalized_dbscale =
+    normalized_dbscale = normalizer(dbscale)
+    # add dimension 3 from left
+    normalized_dbscale = normalized_dbscale.unsqueeze(-3)
     return normalized_dbscale
 
 preprocessed_audio = preprocess(audio, sample_rate)
+print("Preprocessed_audio shape:", preprocessed_audio.shape)
+
 
-
+
+logits = model(preprocessed_audio).logits
 print("Logits shape: ", logits.shape)
 
 top5 = torch.topk(logits, 5)
 print("Top 5 logits:", top5.values)
 print("Top 5 predicted classes:")
 print([model.config.id2label[i] for i in top5.indices.squeeze().tolist()])
-
 ```
 
 ## Model Source
````
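Note: with this change, `preprocess` no longer resamples internally (the old `audio = resample(audio)` call is removed), and `MelScale` is now constructed with `sample_rate=32_000`, so the waveform handed to `preprocess` presumably needs to be at 32 kHz already. A minimal sketch of that pre-step, assuming `torchaudio`'s `Resample` transform (the original resampling code is not visible in this diff):

```python
import torchaudio

# Hypothetical pre-step (not part of the diff): bring the waveform to the
# 32 kHz rate that the transforms in the updated README appear to assume.
def resample_to_32k(audio, sample_rate_of_audio):
    if sample_rate_of_audio != 32_000:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate_of_audio, new_freq=32_000
        )
        audio = resampler(audio)
    return audio
```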
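Note: the `PowerToDB` module instantiated here with `top_db=80` is defined earlier in the README; only its closing `return log_spec` line falls inside these hunks. For orientation, a power-to-dB conversion of this kind is typically implemented along the lines of librosa's `power_to_db`; the sketch below is an illustrative stand-in, not the README's actual definition:

```python
import torch

# Illustrative stand-in for the PowerToDB module used above, modeled on
# librosa.power_to_db; the README's real implementation may differ.
class PowerToDB(torch.nn.Module):
    def __init__(self, ref=1.0, amin=1e-10, top_db=80.0):
        super().__init__()
        self.ref = ref
        self.amin = amin
        self.top_db = top_db

    def forward(self, spec):
        # 10 * log10(power) relative to a reference power, clamped below at amin
        log_spec = 10.0 * torch.log10(torch.clamp(spec, min=self.amin))
        log_spec -= 10.0 * torch.log10(
            torch.clamp(torch.tensor(self.ref, dtype=spec.dtype), min=self.amin)
        )
        if self.top_db is not None:
            # limit the dynamic range to top_db below the peak value
            log_spec = torch.maximum(log_spec, log_spec.max() - self.top_db)
        return log_spec
```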