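"""Fold the trained LoRA adapters back into DistilBERT's attention weights and export a
plain (wrapper-free) checkpoint loadable with a vanilla AutoModelForSequenceClassification."""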
import torch
from transformers import AutoModelForSequenceClassification
from loraLinear import LoRALinear
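# NOTE: `LoRALinear` comes from the local `loraLinear` module used during training. The
# merge below assumes the wrapper exposes `original_layer` (the frozen nn.Linear),
# `lora.loraA` / `lora.loraB` (the low-rank factors), and `lora.scaling`.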
MODEL_CKPT = "distilbert-base-uncased"
RANK = 4
ALPHA = 4
DEVICE = "cpu" # fine for Spaces; merge is fast
# Re-create the LoRA-injected architecture
lora_model = AutoModelForSequenceClassification.from_pretrained(MODEL_CKPT)
for blk in lora_model.distilbert.transformer.layer:
    blk.attention.q_lin = LoRALinear(blk.attention.q_lin, RANK, ALPHA)
    blk.attention.v_lin = LoRALinear(blk.attention.v_lin, RANK, ALPHA)
lora_model.load_state_dict(torch.load("DISTILBERT_WITH_LORA.pth", map_location=DEVICE))
lora_model.eval()
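
# --- Optional sanity check, part 1 (illustrative sketch, not required for the merge) ---
# Capture reference logits *before* folding the adapters in, so we can confirm at the end
# that merging is mathematically a no-op. `check_ids` is an arbitrary dummy batch.
check_ids = torch.randint(0, lora_model.config.vocab_size, (1, 16))
with torch.no_grad():
    ref_logits = lora_model(input_ids=check_ids).logits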
# Collapse each adapter: W ← W + (B @ A)·scale
for blk in lora_model.distilbert.transformer.layer:
    for name in ("q_lin", "v_lin"):
        wrap = getattr(blk.attention, name)
        with torch.no_grad():
            base_W = wrap.original_layer.weight    # (out, in)
            A = wrap.lora.loraA.weight             # (rank, in)
            B = wrap.lora.loraB.weight             # (out, rank)
            base_W += (B @ A) * wrap.lora.scaling  # in-place update of the frozen weight
# Copy the merged weights into a *plain* DistilBERT (no wrappers). All other weights are
# assumed to have stayed frozen during LoRA fine-tuning, so the fresh pretrained copy
# already matches them.
plain_model = AutoModelForSequenceClassification.from_pretrained(MODEL_CKPT)
with torch.no_grad():
    for plain_blk, lora_blk in zip(
        plain_model.distilbert.transformer.layer,
        lora_model.distilbert.transformer.layer,
    ):
        for lin in ("q_lin", "v_lin"):
            pl = getattr(plain_blk.attention, lin)
            lr = getattr(lora_blk.attention, lin).original_layer
            pl.weight.copy_(lr.weight)
            pl.bias.copy_(lr.bias)

    # Classification head: freshly initialized by from_pretrained, so copy the fine-tuned values
    plain_model.pre_classifier.weight.copy_(lora_model.pre_classifier.weight)
    plain_model.pre_classifier.bias.copy_(lora_model.pre_classifier.bias)
    plain_model.classifier.weight.copy_(lora_model.classifier.weight)
    plain_model.classifier.bias.copy_(lora_model.classifier.bias)
# Save
torch.save(plain_model.state_dict(), "DISTILBERT_MERGED.pth")
print("✅ Merged weights saved to DISTILBERT_MERGED.pth")