import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from diffusers.models.embeddings import Timesteps, TimestepEmbedding


# Define the Expert Network
class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, use_softmax=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return (
            self.net(x) if not self.use_softmax else torch.softmax(self.net(x), dim=1)
        )
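
# Usage sketch (illustrative, not part of the original file): each Expert is a
# two-layer MLP applied over the last dimension, so a (bs, 1024, input_dim)
# tensor maps to (bs, 1024, output_dim).
#   expert = Expert(input_dim=64, hidden_dim=32, output_dim=64)
#   y = expert(torch.randn(2, 1024, 64))  # -> torch.Size([2, 1024, 64])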


class DynamicGatingNetwork(nn.Module):
    def __init__(self, hidden_dim=64, embed_dim=64, dtype=torch.bfloat16):
        super().__init__()
        # Timestep processing
        self.time_proj = Timesteps(
            hidden_dim, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.timestep_embedding = TimestepEmbedding(hidden_dim, embed_dim)
        self.timestep_embedding = self.timestep_embedding.to(dtype=dtype)
        # noise_latent projection
        self.noise_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dtype = dtype
        # Gating head: one weight per expert (20 in total)
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 20),
        )

    def forward(self, condition_latents, noise_latent, timestep):
        """
        condition_latents: (bs, 1024, 64)
        noise_latent: (bs, 1024, 64)
        timestep: (bs,)
        """
        bs, seq_len, hidden_dim = condition_latents.shape
        # Process the timestep
        time_emb = self.time_proj(timestep)  # (bs, hidden_dim)
        time_emb = time_emb.to(self.dtype)
        time_emb = self.timestep_embedding(time_emb)  # (bs, embed_dim)
        time_emb = time_emb.unsqueeze(1).expand(
            -1, seq_len, -1
        )  # (bs, 1024, embed_dim)
        # Project the noise latent
        noise_emb = self.noise_proj(noise_latent)  # (bs, 1024, 64)
        # Fuse all inputs (additive fusion instead of concatenation)
        # fused_input = torch.cat([condition_latents, noise_emb, time_emb], dim=2)  # (bs, 1024, 64+64+128)
        fused_input = condition_latents + noise_emb + time_emb
        # Compute per-token expert weights
        weight = self.gate(fused_input)  # (bs, 1024, 20)
        weight = F.softmax(weight, dim=2)  # normalize across experts
        return weight
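
# Shape sketch (illustrative, assuming the (bs, 1024, 64) latents documented above):
#   gating = DynamicGatingNetwork(hidden_dim=64, embed_dim=64, dtype=torch.float32)
#   w = gating(torch.randn(2, 1024, 64), torch.randn(2, 1024, 64), torch.tensor([10, 500]))
#   w.shape -> torch.Size([2, 1024, 20]); w.sum(dim=-1) is ~1 (softmax over the 20 experts)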


class MoGLE(nn.Module):
    def __init__(
        self,
        num_experts=20,
        input_dim=64,
        hidden_dim=32,
        output_dim=64,
        has_expert=True,
        has_gating=True,
        weight_is_scale=False,
    ):
        super().__init__()
        expert_model = Expert if has_expert else nn.Identity
        self.global_expert = expert_model(
            input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
        )
        self.local_experts = nn.ModuleList(
            [
                expert_model(
                    input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
                )
                for _ in range(num_experts - 1)
            ]
        )
        # self.gating = Gating(input_dim=input_dim, num_experts=num_experts)
        self.gating = DynamicGatingNetwork() if has_gating else nn.Identity()
        self.weight_is_scale = weight_is_scale

    def forward(self, x: torch.Tensor, noise_latent, timestep):
        global_mask = x[:, 0]  # (bs, 1024, 64)
        local_mask = x[:, 1:]  # (bs, 19, 1024, 64)
        if not isinstance(self.gating, nn.Identity):
            weights = self.gating(
                global_mask, noise_latent=noise_latent, timestep=timestep
            )  # (bs, 1024, 20)
        _, num_local, h, w = local_mask.shape
        global_output = self.global_expert(global_mask).unsqueeze(1)  # (bs, 1, 1024, 64)
        local_outputs = torch.stack(
            [self.local_experts[i](local_mask[:, i]) for i in range(num_local)], dim=1
        )  # (bs, 19, 1024, 64)
        global_local_outputs = torch.cat(
            [global_output, local_outputs], dim=1
        )  # (bs, 20, 1024, 64)
        if isinstance(self.gating, nn.Identity):
            # No gating: sum the expert outputs with equal weight
            return global_local_outputs.sum(dim=1)
        if self.weight_is_scale:
            # Collapse per-token weights into a single scale per expert
            weights = torch.mean(weights, dim=1, keepdim=True)  # (bs, 1, 20)
        weights_expanded = weights.unsqueeze(-1)  # (bs, 1024, 20, 1) or (bs, 1, 20, 1)
        output = (global_local_outputs.permute(0, 2, 1, 3) * weights_expanded).sum(
            dim=2
        )
        return output  # (bs, 1024, 64)
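

# --- Usage sketch (added for illustration; not part of the original Space code) ---
# Runs MoGLE on random tensors to check shapes; dimensions follow the defaults above
# (20 experts, 1024 tokens, 64 channels). The float32 gating dtype is an assumption
# made here to keep the check CPU-friendly.
if __name__ == "__main__":
    model = MoGLE(
        num_experts=20,
        input_dim=64,
        hidden_dim=32,
        output_dim=64,
        has_expert=True,
        has_gating=True,
        weight_is_scale=False,
    )
    model.gating = DynamicGatingNetwork(dtype=torch.float32)  # assumption: avoid bfloat16 on CPU
    x = torch.randn(2, 20, 1024, 64)  # global mask + 19 local masks
    noise_latent = torch.randn(2, 1024, 64)
    timestep = torch.randint(0, 1000, (2,))
    out = model(x, noise_latent, timestep)
    print(out.shape)  # expected: torch.Size([2, 1024, 64])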