Update README.md
Browse files
README.md
CHANGED
|
@@ -51,32 +51,26 @@ import pandas as pd, numpy as np, warnings, torch, re
|
|
| 51 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 52 |
from bs4 import BeautifulSoup
|
| 53 |
warnings.filterwarnings("ignore", category=UserWarning, module='bs4')
|
| 54 |
-
|
| 55 |
# Helper Functions
|
| 56 |
def clean_and_parse_tweet(tweet):
    """Normalize a raw tweet string before classification.

    Steps, in order: replace URLs with the placeholder token ``URL``, strip
    HTML markup with BeautifulSoup, turn newlines (real ones or literal
    ``\\n`` sequences) into spaces, drop leading ``.``/``:`` characters,
    trim surrounding whitespace, and collapse runs of spaces.

    Args:
        tweet: Raw tweet text (``str``).

    Returns:
        The cleaned text, or ``None`` when the parsed markup contains the
        substring ``"filename"`` or when HTML stripping leaves no text.
    """
    tweet = re.sub(r"https?://\S+|www\.\S+", " URL ", tweet)

    # Parse once and reuse the soup for both the "filename" check and the
    # text extraction (the markup used to be parsed twice).
    soup = BeautifulSoup(tweet, "html.parser")
    if "filename" in str(soup):
        return None

    text = soup.get_text()
    if not text:
        return None

    text = re.sub(r"\\n+|\n+", " ", text)
    text = re.sub(r"^[.:]+", "", text).strip()
    return re.sub(r" +", " ", text)
|
| 60 |
-
|
| 61 |
def predict_tweet(tweet, model, tokenizer, device, threshold=0.5):
    """Score a tweet and report which of the four tracked labels it triggers.

    Args:
        tweet: Text to classify.
        model: Callable classification model; it is invoked with the
            tokenizer output as keyword arguments and must expose
            ``.logits`` on its result and ``config.id2label``.
        tokenizer: Callable that encodes ``tweet`` and returns an object
            with a ``.to(device)`` method.
        device: Device the encoded inputs are moved to.
        threshold: Minimum sigmoid probability for a label to be reported.

    Returns:
        A ``(probs, labels)`` tuple. ``probs`` holds the sigmoid of each
        logit for the first row of the batch; ``labels`` lists the names
        among ``Product``, ``Place``, ``Price`` and ``Promotion`` whose
        probability is at least ``threshold``.
    """
    reportable = {"Product", "Place", "Price", "Promotion"}
    # Take the label names from the model that was passed in, so the function
    # does not depend on a module-level global being defined.
    label_names = model.config.id2label

    inputs = tokenizer(
        tweet,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(device)

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.sigmoid(logits).detach().cpu().numpy()[0]

    labels = [
        label_names[i]
        for i, p in enumerate(probs)
        if label_names[i] in reportable and p >= threshold
    ]
    return probs, labels
|
| 65 |
-
|
| 66 |
# Setup
|
| 67 |
# Pick the compute device: MPS when built and available, then CUDA, else CPU.
if torch.backends.mps.is_built() and torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Identifier passed to both from_pretrained calls below.
synxp = "dmr76/mmx_classifier_microblog_ENv02"
model = AutoModelForSequenceClassification.from_pretrained(synxp).to(device)
tokenizer = AutoTokenizer.from_pretrained(synxp)

# Index -> label-name mapping taken from the loaded model's config.
id2label = model.config.id2label
|
| 72 |
-
|
| 73 |
# ---->>> Define your Tweet <<<----
|
| 74 |
# Sample input; replace it with the text you want to classify.
# NOTE(review): the emoji in this sample look mis-encoded (mojibake) — confirm
# against the original text before relying on this exact string.
tweet = "Best cushioning ever!!! ๐ค๐ค๐ค my zoom vomeros are the bomb๐๐ฝโโ๏ธ๐จ!!! \n @nike #run #training https://randomurl.ai"
|
| 75 |
-
|
| 76 |
# Clean and Predict
|
| 77 |
cleaned_tweet = clean_and_parse_tweet(tweet)
# clean_and_parse_tweet can return None; stop with a clear message instead of
# handing None to the tokenizer.
if cleaned_tweet is None:
    raise ValueError("The tweet has no classifiable text after cleaning.")
probs, labels = predict_tweet(cleaned_tweet, model, tokenizer, device)

# Print Labels and Probabilities
print("Please don't forget to cite the paper: https://ssrn.com/abstract=4542949 if you use this code")
print(labels, probs)
|
|
|
|
| 51 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 52 |
from bs4 import BeautifulSoup
|
| 53 |
warnings.filterwarnings("ignore", category=UserWarning, module='bs4')
|
|
|
|
| 54 |
# Helper Functions
|
| 55 |
def clean_and_parse_tweet(tweet):
    """Normalize a raw tweet string before classification.

    Steps, in order: replace URLs with the placeholder token ``URL``, strip
    HTML markup with BeautifulSoup, turn newlines (real ones or literal
    ``\\n`` sequences) into spaces, drop leading ``.``/``:`` characters,
    trim surrounding whitespace, and collapse runs of spaces.

    Args:
        tweet: Raw tweet text (``str``).

    Returns:
        The cleaned text, or ``None`` when the parsed markup contains the
        substring ``"filename"`` or when HTML stripping leaves no text.
    """
    tweet = re.sub(r"https?://\S+|www\.\S+", " URL ", tweet)

    # Parse once and reuse the soup for both the "filename" check and the
    # text extraction (the markup used to be parsed twice).
    soup = BeautifulSoup(tweet, "html.parser")
    if "filename" in str(soup):
        return None

    text = soup.get_text()
    if not text:
        return None

    text = re.sub(r"\\n+|\n+", " ", text)
    text = re.sub(r"^[.:]+", "", text).strip()
    return re.sub(r" +", " ", text)
|
|
|
|
| 59 |
def predict_tweet(tweet, model, tokenizer, device, threshold=0.5):
    """Score a tweet and report which of the four tracked labels it triggers.

    Args:
        tweet: Text to classify.
        model: Callable classification model; it is invoked with the
            tokenizer output as keyword arguments and must expose
            ``.logits`` on its result and ``config.id2label``.
        tokenizer: Callable that encodes ``tweet`` and returns an object
            with a ``.to(device)`` method.
        device: Device the encoded inputs are moved to.
        threshold: Minimum sigmoid probability for a label to be reported.

    Returns:
        A ``(probs, labels)`` tuple. ``probs`` holds the sigmoid of each
        logit for the first row of the batch; ``labels`` lists the names
        among ``Product``, ``Place``, ``Price`` and ``Promotion`` whose
        probability is at least ``threshold``.
    """
    reportable = {"Product", "Place", "Price", "Promotion"}
    # Take the label names from the model that was passed in, so the function
    # does not depend on a module-level global being defined.
    label_names = model.config.id2label

    inputs = tokenizer(
        tweet,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(device)

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.sigmoid(logits).detach().cpu().numpy()[0]

    labels = [
        label_names[i]
        for i, p in enumerate(probs)
        if label_names[i] in reportable and p >= threshold
    ]
    return probs, labels
|
|
|
|
| 63 |
# Setup
|
| 64 |
# Pick the compute device: MPS when built and available, then CUDA, else CPU.
if torch.backends.mps.is_built() and torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Identifier passed to both from_pretrained calls below.
synxp = "dmr76/mmx_classifier_microblog_ENv02"
model = AutoModelForSequenceClassification.from_pretrained(synxp).to(device)
tokenizer = AutoTokenizer.from_pretrained(synxp)

# Index -> label-name mapping taken from the loaded model's config.
id2label = model.config.id2label
|
|
|
|
| 69 |
# ---->>> Define your Tweet <<<----
|
| 70 |
# Sample input; replace it with the text you want to classify.
# NOTE(review): the emoji in this sample look mis-encoded (mojibake) — confirm
# against the original text before relying on this exact string.
tweet = "Best cushioning ever!!! ๐ค๐ค๐ค my zoom vomeros are the bomb๐๐ฝโโ๏ธ๐จ!!! \n @nike #run #training https://randomurl.ai"
|
|
|
|
| 71 |
# Clean and Predict
|
| 72 |
cleaned_tweet = clean_and_parse_tweet(tweet)
# clean_and_parse_tweet can return None; stop with a clear message instead of
# handing None to the tokenizer.
if cleaned_tweet is None:
    raise ValueError("The tweet has no classifiable text after cleaning.")
probs, labels = predict_tweet(cleaned_tweet, model, tokenizer, device)

# Print Labels and Probabilities
print("Please don't forget to cite the paper: https://ssrn.com/abstract=4542949 if you use this code")
print(labels, probs)
|