import os.path

import torch
import torch.nn as nn
from transformers import RobertaTokenizerFast, RobertaForMaskedLM
import streamlit as st


class SimpleClassifier(nn.Module):
    """Feed-forward head that maps a pooled RoBERTa embedding to a single score."""

    def __init__(self, in_features: int, hidden_features: int,
                 out_features: int, activation=nn.ReLU()):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.in2hid = nn.Linear(in_features, hidden_features)
        self.activation = activation
        self.hid2hid = nn.Linear(hidden_features, hidden_features)
        self.hid2out = nn.Linear(hidden_features, out_features)
        # bn2 is never used in forward(); it is kept so that checkpoints saved
        # with this attribute still load cleanly via load_state_dict().
        self.bn2 = nn.BatchNorm1d(hidden_features)

    def forward(self, X):
        X = self.bn(X)
        X = self.in2hid(X)
        X = self.activation(X)
        X = self.hid2hid(X)
        X = self.activation(X)
        X = self.hid2out(X)
        # Squash the output into a probability in [0, 1].
        return torch.sigmoid(X)
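
# Quick shape check for SimpleClassifier (illustrative only, not executed by
# the app; assumes a batch of pooled 768-dimensional embeddings):
#
#   clf = SimpleClassifier(768, 768, 1)
#   clf.eval()
#   with torch.no_grad():
#       probs = clf(torch.randn(4, 768))  # shape (4, 1), values in [0, 1]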


# Note: st.cache is deprecated in recent Streamlit releases; st.cache_resource
# is the modern equivalent for caching models.
@st.cache(allow_output_mutation=True)
def load_models():
    # Load the masked-LM checkpoint but replace its LM head with an identity,
    # so calling the model returns the encoder's hidden states rather than
    # vocabulary logits.
    model = RobertaForMaskedLM.from_pretrained("roberta-base")
    model.lm_head = nn.Identity()
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    # The classification head trained on top of the pooled embeddings.
    my_classifier = SimpleClassifier(768, 768, 1)
    weights_path = "twitter_model_91_5-.pth"
    # `device` is the module-level torch.device defined at the bottom of the file.
    my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
    my_classifier.eval()

    return {
        "tokenizer": tokenizer,
        "model": model,
        "classifier": my_classifier
    }


def classify_text(text: str) -> float:
    models = load_models()
    tokenizer, model, classifier = models["tokenizer"], models["model"], models["classifier"]

    # Tokenize and keep only the input IDs (the attention mask is not used here).
    X = tokenizer(
        text,
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )["input_ids"]
    with torch.no_grad():
        # With lm_head replaced by nn.Identity, the last element of the model
        # output is the (1, seq_len, 768) tensor of token embeddings; sum over
        # the tokens and restore a batch dimension of 1.
        X = model(X)[-1][0].sum(dim=0)[None, :]
        return classifier(X).item()


# Inference device used when loading the classifier weights in load_models().
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
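
# Example of how the Streamlit front end might call classify_text (a minimal
# sketch; the widget labels and output format are illustrative, not part of
# this section):
#
#   st.title("Tweet classifier")
#   user_text = st.text_area("Enter a tweet")
#   if st.button("Classify") and user_text:
#       score = classify_text(user_text)
#       st.write(f"Predicted probability: {score:.3f}")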