anoushka2000 commited on
Commit
de15e64
·
verified ·
1 Parent(s): 6a18edf

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_mist_finetuned.py +17 -7
modeling_mist_finetuned.py CHANGED
@@ -20,6 +20,8 @@ import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
 
 
 
23
  MODEL_TYPE_ALIASES = {}
24
 
25
  def build_encoder(enc_dict: Dict[str, Any]):
@@ -121,10 +123,12 @@ class MISTFinetuned(PreTrainedModel):
121
  if getattr(self, "tokenizer", None) is not None:
122
  return self.tokenizer
123
  try:
124
- return AutoTokenizer.from_pretrained(self.name_or_path, use_fast=True)
 
 
125
  except Exception:
126
  return AutoTokenizer.from_pretrained(
127
- self.config._name_or_path, use_fast=True
128
  )
129
 
130
  def embed(self, smi: List[str], tokenizer=None):
@@ -142,10 +146,14 @@ class MISTFinetuned(PreTrainedModel):
142
  def predict(self, smi: List[str], return_dict: bool = True, tokenizer=None):
143
  tok = self._resolve_tokenizer(tokenizer)
144
  batch = tok(smi)
145
- batch = DataCollatorWithPadding(tok)(batch)
146
- inputs = {k: v.to(self.device) for k, v in batch.items()}
 
 
 
 
147
  with torch.inference_mode():
148
- out = self(**inputs).cpu()
149
  if self.channels is None or not return_dict:
150
  return out
151
  return annotate_prediction(out, maybe_get_annotated_channels(self.channels))
@@ -287,10 +295,12 @@ class MISTMultiTask(PreTrainedModel):
287
  if getattr(self, "tokenizer", None) is not None:
288
  return self.tokenizer
289
  try:
290
- return AutoTokenizer.from_pretrained(self.name_or_path, use_fast=True)
 
 
291
  except Exception:
292
  return AutoTokenizer.from_pretrained(
293
- self.config._name_or_path, use_fast=True
294
  )
295
 
296
  def predict(self, smi: List[str], tokenizer=None):
 
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
 
23
+
24
+ AutoTokenizer.register("SmirkTokenizer", fast_tokenizer_class=SmirkTokenizerFast)
25
  MODEL_TYPE_ALIASES = {}
26
 
27
  def build_encoder(enc_dict: Dict[str, Any]):
 
123
  if getattr(self, "tokenizer", None) is not None:
124
  return self.tokenizer
125
  try:
126
+ return AutoTokenizer.from_pretrained(
127
+ self.name_or_path, use_fast=True, trust_remote_code=True
128
+ )
129
  except Exception:
130
  return AutoTokenizer.from_pretrained(
131
+ self.config._name_or_path, use_fast=True, trust_remote_code=True
132
  )
133
 
134
  def embed(self, smi: List[str], tokenizer=None):
 
146
  def predict(self, smi: List[str], return_dict: bool = True, tokenizer=None):
147
  tok = self._resolve_tokenizer(tokenizer)
148
  batch = tok(smi)
149
+ collate_fn = DataCollatorWithPadding(tok)
150
+ batch = collate_fn(batch)
151
+ batch = {
152
+ "input_ids": batch["input_ids"].to(self.encoder.device),
153
+ "attention_mask": batch["attention_mask"].to(self.encoder.device),
154
+ }
155
  with torch.inference_mode():
156
+ out = self(**batch).cpu()
157
  if self.channels is None or not return_dict:
158
  return out
159
  return annotate_prediction(out, maybe_get_annotated_channels(self.channels))
 
295
  if getattr(self, "tokenizer", None) is not None:
296
  return self.tokenizer
297
  try:
298
+ return AutoTokenizer.from_pretrained(
299
+ self.name_or_path, use_fast=True, trust_remote_code=True
300
+ )
301
  except Exception:
302
  return AutoTokenizer.from_pretrained(
303
+ self.config._name_or_path, use_fast=True, trust_remote_code=True
304
  )
305
 
306
  def predict(self, smi: List[str], tokenizer=None):