# Imports
import gradio as gr
import os
import matplotlib.pyplot as plt
import torch
import torchaudio
from torch import nn
import pytorch_lightning as pl
from ema_pytorch import EMA
import yaml
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler

# Load configs
def load_configs(config_path):
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    pl_configs = config['model']
    model_configs = config['model']['model']
    return pl_configs, model_configs

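# For reference, config.yaml is expected to look roughly like the sketch below.
# The key names follow the lookups in this file; the values shown are illustrative
# assumptions, not the actual hyperparameters shipped with the checkpoints:
#
# model:
#   lr: 1e-4
#   lr_beta1: 0.95
#   lr_beta2: 0.999
#   lr_eps: 1e-6
#   lr_weight_decay: 1e-3
#   ema_beta: 0.995
#   ema_power: 0.7
#   model:
#     in_channels: 2
#     channels: [8, 32, 64, 128, 256, 512]
#     factors: [1, 4, 4, 4, 2, 2]
#     items: [1, 2, 2, 2, 2, 2]
#     attentions: [0, 0, 0, 0, 1, 1]
#     attention_heads: 8
#     attention_features: 64
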
# Plot mel spectrogram
def plot_mel_spectrogram(sample, sr):
    transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_fft=1024,
        hop_length=512,
        n_mels=80,
        center=True,
        norm="slaney",
    )
    spectrogram = transform(torch.mean(sample, dim=0))  # downmix to mono and compute the mel spectrogram
    spectrogram = torchaudio.functional.amplitude_to_DB(
        spectrogram, multiplier=10.0, amin=1e-10, db_multiplier=0.0, top_db=80.0
    )  # convert the power spectrogram to dB with an 80 dB dynamic range

    # Plot the mel spectrogram
    fig = plt.figure(figsize=(7, 4))
    plt.imshow(spectrogram, aspect='auto', origin='lower')
    plt.colorbar(format='%+2.0f dB')
    plt.xlabel('Frame')
    plt.ylabel('Mel Bin')
    plt.title('Mel Spectrogram')
    plt.tight_layout()
    return fig

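# Example usage (hypothetical shapes): a [channels, samples] waveform such as
# torch.randn(2, 44100) with sr=44100 returns a matplotlib Figure:
#   fig = plot_mel_spectrogram(torch.randn(2, 44100), 44100)
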
# Define PyTorch Lightning model
class Model(pl.LightningModule):
    def __init__(
        self,
        lr: float,
        lr_beta1: float,
        lr_beta2: float,
        lr_eps: float,
        lr_weight_decay: float,
        ema_beta: float,
        ema_power: float,
        model: nn.Module,
    ):
        super().__init__()
        self.lr = lr
        self.lr_beta1 = lr_beta1
        self.lr_beta2 = lr_beta2
        self.lr_eps = lr_eps
        self.lr_weight_decay = lr_weight_decay
        self.model = model
        self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)

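# Note: only the constructor is needed here. Training hooks (training_step,
# configure_optimizers, etc.) are omitted because this app only restores a
# checkpoint into this wrapper and samples from the EMA copy of the network.
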
# Instantiate model (must match model that was trained)
def load_model(model_configs, pl_configs) -> nn.Module:
    # Diffusion model
    model = DiffusionModel(
        net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
        in_channels=model_configs['in_channels'],  # U-Net: number of input/output (audio) channels
        channels=model_configs['channels'],  # U-Net: channels at each layer
        factors=model_configs['factors'],  # U-Net: downsampling and upsampling factors at each layer
        items=model_configs['items'],  # U-Net: number of repeating items at each layer
        attentions=model_configs['attentions'],  # U-Net: attention enabled/disabled at each layer
        attention_heads=model_configs['attention_heads'],  # U-Net: number of attention heads per attention item
        attention_features=model_configs['attention_features'],  # U-Net: number of attention features per attention item
        diffusion_t=VDiffusion,  # The diffusion method used
        sampler_t=VSampler  # The diffusion sampler used
    )

    # Wrap in the PyTorch Lightning module
    model = Model(
        lr=pl_configs['lr'],
        lr_beta1=pl_configs['lr_beta1'],
        lr_beta2=pl_configs['lr_beta2'],
        lr_eps=pl_configs['lr_eps'],
        lr_weight_decay=pl_configs['lr_weight_decay'],
        ema_beta=pl_configs['ema_beta'],
        ema_power=pl_configs['ema_power'],
        model=model
    )
    return model

# Assign to GPU (if available)
def assign_to_gpu(model):
    if torch.cuda.is_available():
        model = model.to('cuda')
    print(f"Device: {model.device}")
    return model

# Load model checkpoint
def load_checkpoint(model, ckpt_path) -> None:
    checkpoint = torch.load(ckpt_path, map_location='cpu')['state_dict']
    model.load_state_dict(checkpoint)  # should output "<All keys matched successfully>"

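# load_state_dict is strict by default, so every key in the checkpoint must match
# the architecture built by load_model; a mismatch between config.yaml and the
# checkpoint raises an error here rather than failing silently at sampling time.
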
# Generate samples
def generate_samples(model_name, num_samples, num_steps, init_audio=None, noise_level=0.7, duration=32768):
    # Load the checkpoint for the selected model
    ckpt_path = models[model_name]
    load_checkpoint(model, ckpt_path)

    if num_samples > 1:
        duration = int(duration / 2)

    # Generate samples
    with torch.no_grad():
        if init_audio is not None:
            # Load the conditioning audio (Gradio supplies a (sample rate, numpy array) tuple)
            audio_sample = torch.tensor(init_audio[1].T, dtype=torch.float32)
            if audio_sample.dim() == 1:
                audio_sample = audio_sample.unsqueeze(0).repeat(2, 1)  # mono upload: duplicate to 2 channels to match the stereo models
            audio_sample = audio_sample.unsqueeze(0).to(model.device)
            audio_sample = audio_sample / torch.max(torch.abs(audio_sample))  # normalize init_audio

            # Trim or pad audio to the target duration
            og_shape = audio_sample.shape
            if duration < og_shape[2]:
                audio_sample = audio_sample[:, :, :duration]
            elif duration > og_shape[2]:
                # Pad tensor with zeros to match sample length
                audio_sample = torch.concat((audio_sample, torch.zeros(og_shape[0], og_shape[1], duration - og_shape[2]).to(model.device)), dim=2)
        else:
            # Unconditional generation: start from pure noise
            audio_sample = torch.zeros((1, 2, int(duration)), device=model.device)
            noise_level = 1.0

        all_samples = torch.zeros(2, 0)
        for i in range(num_samples):
            noise = torch.randn_like(audio_sample, device=model.device) * noise_level  # [batch_size, in_channels, length]
            audio = (audio_sample * abs(1 - noise_level)) + noise  # add noise
            # Generate a sample
            generated_sample = model.model_ema.ema_model.sample(audio, num_steps=num_steps).squeeze(0).cpu()  # suggested num_steps: 10-100
            # Concatenate all samples
            all_samples = torch.concat((all_samples, generated_sample), dim=1)

    torch.cuda.empty_cache()

    fig = plot_mel_spectrogram(all_samples, sr)
    plt.title(f"{model_name} Mel Spectrogram")
    return (sr, all_samples.cpu().detach().numpy().T), fig  # (sample rate, audio), plot

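# Example (illustrative call, mirroring what the Gradio button wires up below):
#   (sr_out, audio), fig = generate_samples("Kicks", num_samples=1, num_steps=50)
# returns a (sample_rate, [samples, channels] numpy array) tuple for gr.Audio and a
# matplotlib Figure for gr.Plot.
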
# Define constants & initialize model
# Load model & configs
sr = 44100  # sampling rate
config_path = "saved_models/config.yaml"  # config path
pl_configs, model_configs = load_configs(config_path)
model = load_model(model_configs, pl_configs)
model = assign_to_gpu(model)

models = {
    "Kicks": "saved_models/kicks/kicks_v7.ckpt",
    "Snares": "saved_models/snares/snares_v0.ckpt",
    "Hi-hats": "saved_models/hihats/hihats_v2.ckpt",
    "Percussion": "saved_models/percussion/percussion_v0.ckpt"
}

| intro = """ | |
| <h1 style="font-weight: 1400; text-align: center; margin-bottom: 6px;"> | |
| Tiny Audio Diffusion | |
| </h1> | |
| <h3 style="font-weight: 600; text-align: center;"> | |
| Christopher Landschoot - Audio waveform diffusion built to run on consumer-grade hardware (<2GB VRAM) | |
| </h3> | |
| <h4 style="text-align: center; margin-bottom: 6px;"> | |
| <a href="https://github.com/crlandsc/tiny-audio-diffusion" style="text-decoration: underline;" target="_blank">GitHub Repo</a> | |
| | <a href="https://youtu.be/m6Eh2srtTro" style="text-decoration: underline;" target="_blank">Repo Tutorial Video</a> | |
| | <a href="https://medium.com/towards-data-science/tiny-audio-diffusion-ddc19e90af9b" style="text-decoration: underline;" target="_blank">Towards Data Science Article</a> | |
| </h4> | |
| """ | |
with gr.Blocks() as demo:
    # Layout
    gr.HTML(intro)
    with gr.Row(equal_height=False):
        with gr.Column():
            # Inputs
            model_name = gr.Dropdown(choices=list(models.keys()), value=list(models.keys())[3], label="Model")
            num_samples = gr.Slider(1, 25, step=1, label="Number of Samples to Generate", value=3)
            num_steps = gr.Slider(1, 100, step=1, label="Number of Diffusion Steps", value=15)

            # Conditioning audio input
            with gr.Accordion("Input Audio (optional)", open=False):
                init_audio_description = gr.HTML('Upload an audio file to perform conditional "style transfer" diffusion.<br>Leaving input audio blank results in unconditional generation.')
                init_audio = gr.Audio(label="Input Audio Sample")
                init_audio_noise = gr.Slider(0, 1, step=0.01, label="Noise to add to input audio", value=0.70)

                # Examples
                gr.Examples(
                    examples=[
                        os.path.join(os.path.dirname(__file__), "samples", "guitar.wav"),
                        os.path.join(os.path.dirname(__file__), "samples", "snare.wav"),
                        os.path.join(os.path.dirname(__file__), "samples", "kick.wav"),
                        os.path.join(os.path.dirname(__file__), "samples", "hihat.wav")
                    ],
                    inputs=init_audio,
                    label="Example Audio Inputs"
                )

            # Buttons
            with gr.Row():
                with gr.Column():
                    clear_button = gr.Button(value="Reset All")
                with gr.Column():
                    generate_btn = gr.Button("Generate Samples!")

        with gr.Column():
            # Outputs
            output_audio = gr.Audio(label="Generated Audio Sample")
            output_plot = gr.Plot(label="Generated Audio Spectrogram")

    # Functionality
    # Generate samples
    generate_btn.click(fn=generate_samples, inputs=[model_name, num_samples, num_steps, init_audio, init_audio_noise], outputs=[output_audio, output_plot])

    # Clear button resets all inputs and outputs to their defaults
    clear_button.click(fn=lambda: [3, 15, None, 0.70, None, None], outputs=[num_samples, num_steps, init_audio, init_audio_noise, output_audio, output_plot])

if __name__ == "__main__":
    demo.launch()