Instructions to use mbhr/nrms with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use mbhr/nrms with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("mbhr/nrms", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| # In[1]: | |
| import os | |
| from datetime import datetime | |
| from pathlib import Path | |
| import polars as pl | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
| from transformers import Trainer, TrainingArguments | |
| from accelerate import Accelerator, DistributedType | |
| from torch.optim import AdamW | |
| from torch.utils.data import DataLoader | |
| from utils._constants import * | |
| from utils._nlp import get_transformers_word_embeddings | |
| from utils._polars import concat_str_columns, slice_join_dataframes | |
| from utils._articles import ( | |
| convert_text2encoding_with_transformers, | |
| create_article_id_to_value_mapping | |
| ) | |
| from utils._python import make_lookup_objects | |
| from utils._behaviors import ( | |
| create_binary_labels_column, | |
| sampling_strategy_wu2019, | |
| truncate_history, | |
| ) | |
| from utils._articles_behaviors import map_list_article_id_to_value | |
| from dataset.pytorch_dataloader import ( | |
| ebnerd_from_path, | |
| NRMSDataset, | |
| NewsrecDataset, | |
| ) | |
| from evaluation import ( | |
| MetricEvaluator, | |
| AucScore, | |
| NdcgScore, | |
| MrrScore, | |
| F1Score, | |
| LogLossScore, | |
| RootMeanSquaredError, | |
| AccuracyScore | |
| ) | |
| from models.nrms import NRMSModel | |
# Switch off the Hugging Face tokenizers' internal parallelism while the
# articles are being preprocessed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# In[2]:
# Parquet file with the behaviour rows that will be scored.
TEST_DATA_PATH = "merged_0412_final.parquet"

# In[3]:
# Read the behaviours, then attach a placeholder "labels" column built from a
# single empty list.
# NOTE(review): presumably the dataset loader requires a "labels" column even
# though the test split has no ground truth -- confirm against
# NewsrecDataset.load_data.
test_df = pl.read_parquet(TEST_DATA_PATH)
test_df = test_df.with_columns(pl.Series("labels", [[]]))
# In[4]:
from transformers import AutoModel, AutoTokenizer

# One pretrained checkpoint supplies both the tokenizer used for the article
# text and the model whose word embeddings are extracted.
model_name = "Maltehb/danish-bert-botxo"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# NOTE: the misspelled variable name is kept on purpose; the model-construction
# cell further down reads it under this exact name.
word2vec_embeddimg = get_transformers_word_embeddings(model)
# In[5]:
ARTICLES_DATA_PATH = "/work/Blue/ebnerd/ebnerd_testset/articles.parquet"
TEXT_MAX_LENGTH = 30
# Text columns that are combined into one string per article.
ARTICLE_COLUMNS = [DEFAULT_TITLE_COL, DEFAULT_SUBTITLE_COL]

articles_df = pl.read_parquet(ARTICLES_DATA_PATH)

# Step 1: merge the selected text columns; the helper returns the new frame
# together with the name of the combined column.
df_articles, cat_col = concat_str_columns(articles_df, columns=ARTICLE_COLUMNS)

# Step 2: encode the combined column with the tokenizer, capped at
# TEXT_MAX_LENGTH; the helper returns the name of the column holding the
# encodings.
df_articles, token_col_title = convert_text2encoding_with_transformers(
    df_articles,
    tokenizer,
    cat_col,
    max_length=TEXT_MAX_LENGTH,
)

# Step 3: build the article-id -> encoding mapping consumed by the dataset.
article_mapping = create_article_id_to_value_mapping(
    df=df_articles,
    value_col=token_col_title,
)
| # In[6]: | |
| from dataclasses import dataclass, field | |
| import numpy as np | |
class NRMSTestDataset(NewsrecDataset):
    """Inference-only dataset that turns one behaviour row into model inputs.

    Article ids are translated to rows of ``lookup_article_matrix``; ids that
    are missing from ``article_dict`` fall back to row 0, which is all zeros.
    """

    def __post_init__(self):
        """Build the article lookup tables and load the behaviour data."""
        # Known articles are numbered from 1 because row 0 of the matrix is
        # reserved for unknown articles.
        self.lookup_article_index = {
            article_id: position
            for position, article_id in enumerate(self.article_dict, start=1)
        }

        known_rows = np.array(list(self.article_dict.values()))
        unknown_row = np.zeros(known_rows.shape[1], dtype=known_rows.dtype)
        self.lookup_article_matrix = np.vstack([unknown_row, known_rows])
        self.unknown_index = [0]

        self.X, self.y = self.load_data()
        if self.kwargs is not None:
            self.set_kwargs(self.kwargs)

    def __getitem__(self, idx) -> dict:
        """Return the model inputs for the behaviour row at ``idx``.

        Keys of the returned dict:
            user_id: first value of the row's ``user_id`` column.
            history_input_tensor: lookup-matrix rows for the ids in
                ``article_id_fixed``.
            candidate_article_id: first id in ``article_ids_inview``.
            candidate_input_title: lookup-matrix rows for every id in
                ``article_ids_inview``.
            labels: constant ``np.int32(0)`` placeholder.
        """
        row = self.X[idx]
        to_index = self.lookup_article_index.get

        # Unknown ids map to index 0, i.e. the all-zero row of the matrix.
        history_ids = row["article_id_fixed"].to_list()[0]
        history_input_tensor = self.lookup_article_matrix[
            [to_index(article_id, 0) for article_id in history_ids]
        ]

        inview_ids = row["article_ids_inview"].to_list()[0]
        candidate_input_title = self.lookup_article_matrix[
            [to_index(article_id, 0) for article_id in inview_ids]
        ]

        return {
            "user_id": row["user_id"][0],
            "history_input_tensor": history_input_tensor,
            "candidate_article_id": row["article_ids_inview"][0][0],
            "candidate_input_title": candidate_input_title,
            "labels": np.int32(0),
        }
# In[7]:
# Wrap the test behaviours and the article-id -> tokens mapping in the
# inference dataset.
test_dataset = NRMSTestDataset(
    behaviors=test_df,
    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    article_dict=article_mapping,
    unknown_representation="zeros",
    eval_mode=False,
)

# In[8]:
# Build the network around the pretrained word-embedding matrix.
nrms_model = NRMSModel(
    pretrained_weight=torch.tensor(word2vec_embeddimg),
    emb_dim=768,
    num_heads=16,
    hidden_dim=128,
    item_dim=64,
)

# Deserialize onto the CPU: without map_location, torch.load restores every
# tensor to the device it was saved from, which fails when that device is
# absent and otherwise allocates memory on a GPU other than the target one.
state_dict = torch.load("nrms_model.epoch0.step20001.pth", map_location="cpu")

# The model is compiled *before* the weights are loaded.
# NOTE(review): presumably the checkpoint was saved from a compiled module, so
# its keys only match the compiled wrapper -- confirm before reordering.
nrms_model = torch.compile(nrms_model)
nrms_model.load_state_dict(state_dict["model"])
nrms_model.to("cuda:1")
# In[ ]:
import os

import torch._dynamo
from torch.utils.data import DataLoader
from tqdm import tqdm

BATCH_SIZE = 256
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep the output rows in the same order as the input rows
    num_workers=60,
)

# NOTE(review): this re-enables tokenizer parallelism, yet nothing in this cell
# tokenizes text -- confirm the setting is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Fall back to eager execution when the compiler hits an error instead of
# aborting the run.
torch._dynamo.config.suppress_errors = True

nrms_model.eval()

# Write one "user_id,candidate_article_id,score" line per sample.
with open("test_set.txt", "w") as f, torch.no_grad():
    for batch in tqdm(test_dataloader):
        user_ids = batch["user_id"].cpu().tolist()
        article_ids = batch["candidate_article_id"].cpu().tolist()
        history = batch["history_input_tensor"].to("cuda:1")
        candidates = batch["candidate_input_title"].to("cuda:1")

        # Column 0 of the model output is kept; the dataset likewise reports
        # only the first in-view article id for each row.
        scores = nrms_model(history, candidates, None)[:, 0].cpu().tolist()

        f.writelines(
            f"{uid},{aid},{score}\n"
            for uid, aid, score in zip(user_ids, article_ids, scores)
        )
| # In[ ]: | |