import torch.nn as nn


class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank, alpha):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        # Low-rank factors: A projects down to rank r, B projects back up.
        self.loraA = nn.Linear(in_features, rank, bias=False)
        self.loraB = nn.Linear(rank, out_features, bias=False)
        # A uses the same Kaiming-uniform init as nn.Linear's default;
        # B starts at zero so the LoRA update is a no-op before training.
        nn.init.kaiming_uniform_(self.loraA.weight, a=5**0.5)
        nn.init.zeros_(self.loraB.weight)

    def forward(self, x):
        # (x·A)·B: (B, S, D) -> (B, S, R) -> (B, S, D)
        delta = self.loraB(self.loraA(x))
        return self.scaling * delta
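
# A minimal usage sketch (not part of the original code): the wrapper name
# LinearWithLoRA and the rank/alpha values below are illustrative. It shows
# the usual pattern of freezing a pretrained nn.Linear and adding the
# LoRALayer delta to its output, so only A and B receive gradients.
class LinearWithLoRA(nn.Module):
    def __init__(self, base_linear, rank=8, alpha=16):
        super().__init__()
        self.base = base_linear
        self.base.weight.requires_grad_(False)  # freeze pretrained weights
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora = LoRALayer(
            base_linear.in_features, base_linear.out_features, rank, alpha
        )

    def forward(self, x):
        # frozen pretrained path + scaled low-rank update
        return self.base(x) + self.lora(x)


# Example: wrap an existing projection layer, e.g.
#   layer = LinearWithLoRA(pretrained_linear, rank=8, alpha=16)
# Only the two rank-r factor matrices remain trainable.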