burtenshaw HF Staff commited on
Commit
4c6e4e8
·
verified ·
1 Parent(s): 2d26ed8

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_nanogpt.py +58 -0
modeling_nanogpt.py CHANGED
@@ -1,5 +1,7 @@
1
  import math
 
2
  from dataclasses import dataclass
 
3
 
4
  import torch
5
  import torch.nn as nn
@@ -170,5 +172,61 @@ class NanoGPTModel(PreTrainedModel):
170
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1, reduction='mean')
171
  return {"loss": loss, "logits": logits}
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
 
 
1
  import math
2
+ import os
3
  from dataclasses import dataclass
4
+ from pathlib import Path
5
 
6
  import torch
7
  import torch.nn as nn
 
172
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1, reduction='mean')
173
  return {"loss": loss, "logits": logits}
174
 
175
+ @classmethod
176
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
177
+ config = kwargs.pop("config", None)
178
+ subfolder = kwargs.pop("subfolder", None)
179
+ device_map = kwargs.get("device_map")
180
+ if device_map is not None:
181
+ # Delegate complex dispatch (like accelerate) to the base implementation.
182
+ if subfolder is not None:
183
+ kwargs["subfolder"] = subfolder
184
+ if config is not None:
185
+ kwargs["config"] = config
186
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
187
+
188
+ base_path = Path(pretrained_model_name_or_path)
189
+ if subfolder:
190
+ base_path = base_path / subfolder
191
+
192
+ weight_path = None
193
+ if base_path.is_dir():
194
+ candidate_files = [
195
+ base_path / "pytorch_model.bin",
196
+ base_path / "model.bin",
197
+ ]
198
+ candidate_files.extend(sorted(base_path.glob("model_*.pt"), reverse=True))
199
+ candidate_files.extend(sorted(base_path.glob("*.bin"), reverse=True))
200
+ for cand in candidate_files:
201
+ if cand.is_file():
202
+ weight_path = cand
203
+ break
204
+
205
+ if weight_path is None:
206
+ # Fall back to the default behaviour (e.g. remote repo or standard filenames)
207
+ if subfolder is not None:
208
+ kwargs["subfolder"] = subfolder
209
+ if config is not None:
210
+ kwargs["config"] = config
211
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
212
+
213
+ if config is None:
214
+ config = NanoGPTConfig.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder)
215
+
216
+ torch_dtype = kwargs.pop("torch_dtype", None)
217
+ strict = kwargs.pop("strict", True)
218
+
219
+ state_dict = torch.load(str(weight_path), map_location="cpu")
220
+ if isinstance(state_dict, dict) and "state_dict" in state_dict:
221
+ state_dict = state_dict["state_dict"]
222
+ state_dict = {k.lstrip("_orig_mod."): v for k, v in state_dict.items()}
223
+
224
+ model = cls(config, *model_args)
225
+ model.load_state_dict(state_dict, strict=strict)
226
+ if torch_dtype is not None:
227
+ model = model.to(dtype=torch_dtype)
228
+ model.eval()
229
+ return model
230
+
231
 
232