Ashish Reddy committed
Commit: a090db7
Parent(s): a494631

committing

Files changed:
- .gitignore +2 -0
- DISTILBERT_MERGED.pth +3 -0
- app.py +50 -0
- baseline.py +94 -0
- loraLayer.py +19 -0
- loraLinear.py +45 -0
- loraTune.py +86 -0
- mergeWeights.py +50 -0
- requirements.txt +4 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+*.DS_Store
+__pycache__/
DISTILBERT_MERGED.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f124b3db5a1adae5c3d1472b849806e3aa23b352d3d3c9a53bdf404f8d0b2ca0
+size 267861563
app.py
ADDED
@@ -0,0 +1,50 @@
+import torch, torch.nn.functional as F
+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+MODEL_CKPT = "distilbert-base-uncased"
+DEVICE = "cpu"  # HF Spaces default
+
+print("--- Loading tokenizer & base model ---")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_CKPT)
+model = AutoModelForSequenceClassification.from_pretrained(MODEL_CKPT)
+
+print("--- Loading merged fine-tuned weights ---")
+model.load_state_dict(torch.load("DISTILBERT_MERGED.pth", map_location=DEVICE))
+model.to(DEVICE).eval()
+
+# nice label names for IMDB
+model.config.id2label = {0: "NEGATIVE", 1: "POSITIVE"}
+model.config.label2id = {v: k for k, v in model.config.id2label.items()}
+
+def predict(text):
+    tokens = tokenizer(
+        text,
+        return_tensors="pt",
+        padding="max_length",
+        truncation=True,
+        max_length=256
+    ).to(DEVICE)
+
+    with torch.no_grad():
+        probs = F.softmax(model(**tokens).logits, dim=-1)[0]
+    return {model.config.id2label[i]: float(p) for i, p in enumerate(probs)}
+
+demo = gr.Interface(
+    fn=predict,
+    inputs=gr.Textbox(lines=3, label="Movie Review"),
+    outputs=gr.Label(num_top_classes=2, label="Sentiment"),
+    title="Sentiment Analysis (LoRA-merged DistilBERT)",
+    description=(
+        "DistilBERT fine-tuned on IMDB with a custom LoRA adapter. "
+        "Adapters have been merged so the model runs with no extra parameters."
+    ),
+    examples=[
+        ["An absolute masterpiece with brilliant acting!"],
+        ["Total waste of two hours."],
+        ["Predictable plot but gorgeous visuals."]
+    ]
+)
+
+if __name__ == "__main__":
+    demo.launch()
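For a quick local smoke test without starting the web UI, the predict function above can be called directly. The sketch below is an illustrative addition, not part of the commit; it assumes app.py and DISTILBERT_MERGED.pth sit in the working directory, and relies on the __main__ guard so that importing app loads the model but does not launch the Gradio server.

from app import predict

scores = predict("An absolute masterpiece with brilliant acting!")
print(scores)  # e.g. {'NEGATIVE': 0.03, 'POSITIVE': 0.97} -- numbers are illustrative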
baseline.py
ADDED
@@ -0,0 +1,94 @@
+import torch
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSequenceClassification
+)
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+
+"""
+---- Device ----
+"""
+
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+    print("Using CUDA (GPU)")
+elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+    device = torch.device('mps')
+    print("Using MPS (Apple Silicon GPU)")
+else:
+    device = torch.device('cpu')
+    print("Using CPU")
+
+"""
+--- Model ---
+"""
+
+model_ckpt = "distilbert-base-uncased"
+
+print(f"--- Loading pre-trained model and tokenizer: {model_ckpt.upper()} ---")
+
+tok = AutoTokenizer.from_pretrained(model_ckpt)
+model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)
+model.to(device)
+print(f"Model moved to {device}")
+
+
+"""
+--- Data Prep ---
+"""
+
+print("\n--- Loading and preparing IMDB dataset ---")
+imdb_dataset = load_dataset("imdb")
+"""
+DatasetDict({
+    train: Dataset({
+        features: ['text', 'label'],
+        num_rows: 25000
+    })
+    test: Dataset({
+        features: ['text', 'label'],
+        num_rows: 25000
+    })
+    unsupervised: Dataset({
+        features: ['text', 'label'],
+        num_rows: 50000
+    })
+})
+"""
+
+def tokenize_fn(examples):
+    return tok(examples["text"], padding="max_length", truncation=True)
+
+tokenized_datasets = imdb_dataset.map(tokenize_fn, batched=True)
+
+tokenized_datasets = tokenized_datasets.remove_columns(["text"])
+tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+tokenized_datasets.set_format("torch")
+
+
+if __name__ == '__main__':
+
+    small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))  # 1000 randomly selected test examples
+    eval_dataloader = DataLoader(small_eval_dataset, batch_size=8)  # 125 batches of 8 examples, each holding 'labels', 'input_ids', 'attention_mask'
+
+    print("\n--- Evaluating baseline model performance ---")
+    model.eval()
+    num_correct = 0
+    num_samples = 0
+
+    with torch.no_grad():  # Disable gradient calculation for inference (no backprop)
+        for batch in eval_dataloader:
+            batch = {k: v.to(device) for k, v in batch.items()}
+
+            outputs = model(**batch)  # Forward pass
+            logits = outputs.logits
+
+            predictions = torch.argmax(logits, dim=-1)  # Class with the highest logit
+
+            # Compare predictions to true labels
+            num_correct += (predictions == batch["labels"]).sum().item()
+            num_samples += batch["labels"].size(0)
+
+    accuracy = num_correct / num_samples
+    print(f"Baseline Accuracy on 1000 samples: {accuracy:.4f}")  # ~0.49 here: the classification head is still randomly initialised, so the untrained model is essentially guessing and hovers around 50%
loraLayer.py
ADDED
@@ -0,0 +1,19 @@
+import torch.nn as nn
+
+class LoRALayer(nn.Module):
+    def __init__(self, in_features, out_features, rank, alpha):
+        super().__init__()
+        self.rank = rank
+        self.alpha = alpha
+        self.scaling = alpha / rank
+
+        self.loraA = nn.Linear(in_features, rank, bias=False)
+        self.loraB = nn.Linear(rank, out_features, bias=False)
+
+        nn.init.kaiming_uniform_(self.loraA.weight, a=5**0.5)
+        nn.init.zeros_(self.loraB.weight)
+
+    def forward(self, x):
+        delta = self.loraB(self.loraA(x))  # (xA)B: (B, S, D) -> (B, S, R) -> (B, S, D)
+        x = self.scaling * delta
+        return x
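Because loraB starts at zero, the adapter contributes nothing at initialisation, and its output is always the closed form (alpha/rank) * x A^T B^T. The sketch below is an illustrative check, not part of the commit; it confirms this equivalence and the output shape, which is the identity the later weight merge relies on.

import torch
from loraLayer import LoRALayer

layer = LoRALayer(in_features=768, out_features=768, rank=4, alpha=4)
x = torch.randn(2, 16, 768)                 # (batch, seq, hidden)

out = layer(x)                              # scaling * loraB(loraA(x))
A = layer.loraA.weight                      # (rank, in_features)
B = layer.loraB.weight                      # (out_features, rank)
closed_form = layer.scaling * (x @ A.T @ B.T)

print(out.shape)                            # torch.Size([2, 16, 768])
print(torch.allclose(out, closed_form))     # True (trivially all zeros right after init)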
loraLinear.py
ADDED
@@ -0,0 +1,45 @@
+import torch.nn as nn
+from transformers import (
+    AutoModelForSequenceClassification
+)
+
+RANK = 4
+ALPHA = 4
+model_ckpt = "distilbert-base-uncased"
+
+from loraLayer import LoRALayer
+
+class LoRALinear(nn.Module):
+    def __init__(self, original_layer, rank, alpha):
+        super().__init__()
+        self.in_features = original_layer.in_features
+        self.out_features = original_layer.out_features
+        self.original_layer = original_layer
+        self.lora = LoRALayer(self.in_features, self.out_features, rank, alpha)
+
+    def forward(self, x):
+        original_output = self.original_layer(x)  # Wo*x
+        lora_output = self.lora(x)                # (xA)B * scaling
+        return original_output + lora_output      # Wo*x + (xA)B * scaling
+
+model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)
+
+for param in model.parameters():
+    param.requires_grad = False  # Freeze all original parameters
+
+print("--- Injecting LoRA adapters into q_lin and v_lin layers of DISTILBERT ---")
+for layer in model.distilbert.transformer.layer:
+    layer.attention.q_lin = LoRALinear(layer.attention.q_lin, RANK, ALPHA)
+    layer.attention.v_lin = LoRALinear(layer.attention.v_lin, RANK, ALPHA)
+print("INFO: LoRA Adapters INJECTED")
+
+print("\nTrainable parameters:")
+for name, param in model.named_parameters():
+    if param.requires_grad:
+        print(name)
+
+total_params = sum(p.numel() for p in model.parameters())
+trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+print(f"\nTotal parameters: {total_params}")
+print(f"Trainable LoRA parameters: {trainable_params}")
+print(f"Percentage of trainable parameters: {100 * trainable_params / total_params:.4f}%")
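As a rough cross-check of the printed counts, the sketch below (an illustrative addition, not part of the commit) estimates the number of trainable LoRA parameters, assuming the standard distilbert-base-uncased geometry: 6 transformer layers with 768-dimensional q_lin and v_lin projections.

hidden = 768            # q_lin and v_lin are 768 -> 768 in distilbert-base-uncased (assumption stated above)
rank = 4
n_layers = 6            # transformer blocks in DistilBERT
adapters_per_layer = 2  # q_lin and v_lin

# each adapter holds A (rank x in_features) and B (out_features x rank)
params_per_adapter = rank * hidden + hidden * rank            # 6,144
lora_params = params_per_adapter * adapters_per_layer * n_layers
print(lora_params)      # 73728 -- roughly 0.1% of the ~67M parameters in the full model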
loraTune.py
ADDED
@@ -0,0 +1,86 @@
+import torch, torch.optim as optim
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+model_ckpt = "distilbert-base-uncased"
+batch_size = 16
+n_epochs = 3
+learning_rate = 1e-4
+RANK = 4
+ALPHA = 4
+
+"""
+---- Device ----
+"""
+
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+    print("Using CUDA (GPU)")
+elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+    device = torch.device('mps')
+    print("Using MPS (Apple Silicon GPU)")
+else:
+    device = torch.device('cpu')
+    print("Using CPU")
+
+
+from baseline import tokenized_datasets
+
+"""
+tokenized_datasets:
+
+DatasetDict({
+    train: Dataset({
+        features: ['labels', 'input_ids', 'attention_mask'],
+        num_rows: 25000
+    })
+    test: Dataset({
+        features: ['labels', 'input_ids', 'attention_mask'],
+        num_rows: 25000
+    })
+    unsupervised: Dataset({
+        features: ['labels', 'input_ids', 'attention_mask'],
+        num_rows: 50000
+    })
+})
+"""
+
+train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=batch_size)
+eval_dataloader = DataLoader(tokenized_datasets["test"], batch_size=batch_size)
+
+from loraLinear import model
+
+model.to(device)
+print(f"INFO: Moved model to {device}")
+
+trainable_params = [p for p in model.parameters() if p.requires_grad]  # 24 tensors: A and B for q_lin and v_lin in each of the 6 layers
+optimizer = optim.AdamW(trainable_params, lr=learning_rate)
+
+for epoch in range(n_epochs):
+    model.train()
+    print(f"\n--- Starting Epoch {epoch+1}/{n_epochs} ---")
+    for batch in tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}"):
+        batch = {k: v.to(device) for k, v in batch.items()}
+        optimizer.zero_grad()
+        outputs = model(**batch)
+        loss = outputs.loss
+        loss.backward()
+        optimizer.step()
+
+    model.eval()
+    num_correct = 0
+    num_samples = 0
+    with torch.no_grad():
+        for batch in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch+1}"):
+            batch = {k: v.to(device) for k, v in batch.items()}
+            outputs = model(**batch)
+            predictions = torch.argmax(outputs.logits, dim=-1)
+            num_correct += (predictions == batch["labels"]).sum().item()
+            num_samples += batch["labels"].size(0)
+
+    accuracy = num_correct / num_samples
+    print(f"--- Epoch {epoch+1} Validation Accuracy: {accuracy:.4f} ---")
+
+print("\nFine-tuning complete.")
+torch.save(model.state_dict(), "DISTILBERT_WITH_LORA.pth")
+print("Trained LoRA model saved.")
mergeWeights.py
ADDED
@@ -0,0 +1,50 @@
+import torch
+from transformers import AutoModelForSequenceClassification
+from loraLinear import LoRALinear
+
+MODEL_CKPT = "distilbert-base-uncased"
+RANK = 4
+ALPHA = 4
+DEVICE = "cpu"  # fine for Spaces; merge is fast
+
+# Re-create the LoRA-injected architecture
+lora_model = AutoModelForSequenceClassification.from_pretrained(MODEL_CKPT)
+for blk in lora_model.distilbert.transformer.layer:
+    blk.attention.q_lin = LoRALinear(blk.attention.q_lin, RANK, ALPHA)
+    blk.attention.v_lin = LoRALinear(blk.attention.v_lin, RANK, ALPHA)
+
+lora_model.load_state_dict(torch.load("DISTILBERT_WITH_LORA.pth", map_location=DEVICE))
+lora_model.eval()
+
+# Collapse each adapter: W ← W + (B @ A)·scale
+for blk in lora_model.distilbert.transformer.layer:
+    for name in ("q_lin", "v_lin"):
+        wrap = getattr(blk.attention, name)
+        with torch.no_grad():
+            base_W = wrap.original_layer.weight    # (out, in)
+            A = wrap.lora.loraA.weight             # (rank, in)
+            B = wrap.lora.loraB.weight             # (out, rank)
+            base_W += (B @ A) * wrap.lora.scaling  # in-place update
+
+# Copy the merged weights into a *plain* DistilBERT (no wrappers)
+plain_model = AutoModelForSequenceClassification.from_pretrained(MODEL_CKPT)
+with torch.no_grad():
+    for i in range(6):
+        plain_blk = plain_model.distilbert.transformer.layer[i]
+        lora_blk = lora_model.distilbert.transformer.layer[i]
+
+        for lin in ("q_lin", "v_lin"):
+            pl = getattr(plain_blk.attention, lin)
+            lr = getattr(lora_blk.attention, lin).original_layer
+            pl.weight.copy_(lr.weight)
+            pl.bias.copy_(lr.bias)
+
+    # classification head
+    plain_model.pre_classifier.weight.copy_(lora_model.pre_classifier.weight)
+    plain_model.pre_classifier.bias.copy_(lora_model.pre_classifier.bias)
+    plain_model.classifier.weight.copy_(lora_model.classifier.weight)
+    plain_model.classifier.bias.copy_(lora_model.classifier.bias)
+
+# Save
+torch.save(plain_model.state_dict(), "DISTILBERT_MERGED.pth")
+print("✅ Merged weights saved to DISTILBERT_MERGED.pth")
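The merge is exact because W x + scaling·(B A) x = (W + scaling·B A) x. The sketch below is an illustrative sanity check, not part of the commit: run it as a separate script after mergeWeights.py (the lora_model inside mergeWeights.py has already had its base weights updated in place, so a fresh, unmerged copy is rebuilt here), and the two models should agree on their logits to within floating-point noise.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from loraLinear import LoRALinear

ckpt = "distilbert-base-uncased"

# Unmerged LoRA model, rebuilt from the training checkpoint
lora_model = AutoModelForSequenceClassification.from_pretrained(ckpt)
for blk in lora_model.distilbert.transformer.layer:
    blk.attention.q_lin = LoRALinear(blk.attention.q_lin, 4, 4)
    blk.attention.v_lin = LoRALinear(blk.attention.v_lin, 4, 4)
lora_model.load_state_dict(torch.load("DISTILBERT_WITH_LORA.pth", map_location="cpu"))
lora_model.eval()

# Plain model with the merged weights
plain_model = AutoModelForSequenceClassification.from_pretrained(ckpt)
plain_model.load_state_dict(torch.load("DISTILBERT_MERGED.pth", map_location="cpu"))
plain_model.eval()

tok = AutoTokenizer.from_pretrained(ckpt)
batch = tok("A simple smoke-test sentence.", return_tensors="pt")
with torch.no_grad():
    diff = (lora_model(**batch).logits - plain_model(**batch).logits).abs().max()
print(f"max |logit difference| = {diff.item():.2e}")  # expected to be ~1e-6 or smaller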
requirements.txt
ADDED
@@ -0,0 +1,4 @@
+torch
+transformers
+datasets
+gradio