import torch
from transformers import AutoModelForSequenceClassification

from loraLinear import LoRALinear
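
# Checkpoint name and LoRA hyperparameters. RANK and ALPHA must match the values
# used when the adapters in DISTILBERT_WITH_LORA.pth were trained, otherwise the
# state dict will not load.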
MODEL_CKPT = "distilbert-base-uncased"
RANK = 4
ALPHA = 4
DEVICE = "cpu"
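
# Rebuild the LoRA-wrapped architecture as it was during fine-tuning: wrap the
# query and value projections of every transformer block, then load the
# fine-tuned weights.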
lora_model = AutoModelForSequenceClassification.from_pretrained(MODEL_CKPT)

for blk in lora_model.distilbert.transformer.layer:
    blk.attention.q_lin = LoRALinear(blk.attention.q_lin, RANK, ALPHA)
    blk.attention.v_lin = LoRALinear(blk.attention.v_lin, RANK, ALPHA)

lora_model.load_state_dict(torch.load("DISTILBERT_WITH_LORA.pth", map_location=DEVICE))
lora_model.eval()
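
# Merge each low-rank update into its base weight in place:
# W_merged = W + (B @ A) * scaling, where scaling comes from LoRALinear
# (typically ALPHA / RANK).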
for blk in lora_model.distilbert.transformer.layer:
    for name in ("q_lin", "v_lin"):
        wrap = getattr(blk.attention, name)
        with torch.no_grad():
            base_W = wrap.original_layer.weight   # (out_features, in_features)
            A = wrap.lora.loraA.weight            # (rank, in_features)
            B = wrap.lora.loraB.weight            # (out_features, rank)
            base_W += (B @ A) * wrap.lora.scaling
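
# Copy the merged attention weights into a fresh, un-wrapped DistilBERT so the
# exported state dict uses standard module names (no LoRALinear wrappers).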
plain_model = AutoModelForSequenceClassification.from_pretrained(MODEL_CKPT)

with torch.no_grad():
    for i in range(6):  # distilbert-base-uncased has 6 transformer blocks
        plain_blk = plain_model.distilbert.transformer.layer[i]
        lora_blk = lora_model.distilbert.transformer.layer[i]

        for lin in ("q_lin", "v_lin"):
            pl = getattr(plain_blk.attention, lin)
            lr = getattr(lora_blk.attention, lin).original_layer
            pl.weight.copy_(lr.weight)
            pl.bias.copy_(lr.bias)
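
    # The classification head is not LoRA-wrapped, so its weights are copied
    # over unchanged from the fine-tuned model.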
    plain_model.pre_classifier.weight.copy_(lora_model.pre_classifier.weight)
    plain_model.pre_classifier.bias.copy_(lora_model.pre_classifier.bias)
    plain_model.classifier.weight.copy_(lora_model.classifier.weight)
    plain_model.classifier.bias.copy_(lora_model.classifier.bias)
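
# Export a standard DistilBERT checkpoint; it can be loaded into a stock
# AutoModelForSequenceClassification without the custom LoRALinear class.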
torch.save(plain_model.state_dict(), "DISTILBERT_MERGED.pth")
print("✅ Merged weights saved to DISTILBERT_MERGED.pth")