pinthoz committed on
Commit
83f9f95
·
verified ·
1 Parent(s): 93eb2f0

Upload 50 files

Browse files
.gitattributes CHANGED
@@ -1,2 +1,3 @@
1
  static/favicon.ico filter=lfs diff=lfs merge=lfs -text
2
  static/images/architecture.png filter=lfs diff=lfs merge=lfs -text
 
 
1
  static/favicon.ico filter=lfs diff=lfs merge=lfs -text
2
  static/images/architecture.png filter=lfs diff=lfs merge=lfs -text
3
+ attention_app/server/__pycache__/main.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
attention_app/bias/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (945 Bytes). View file
 
attention_app/bias/__pycache__/attention_bias.cpython-311.pyc ADDED
Binary file (12.3 kB). View file
 
attention_app/bias/__pycache__/embedding_analyzer.cpython-311.pyc ADDED
Binary file (11.7 kB). View file
 
attention_app/bias/__pycache__/feature_extraction.cpython-311.pyc ADDED
Binary file (11.3 kB). View file
 
attention_app/bias/__pycache__/token_detector.cpython-311.pyc ADDED
Binary file (14.9 kB). View file
 
attention_app/bias/__pycache__/visualizations.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
attention_app/models.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --- Module setup: imports and third-party log/warning suppression ----------
import gc
import logging
import warnings

import torch
from transformers import (
    BertForMaskedLM,
    BertModel,
    BertTokenizer,
    GPT2LMHeadModel,
    GPT2Model,
    GPT2Tokenizer,
)
from transformers.utils import logging as transformers_logging

# huggingface_hub emits a FutureWarning about its deprecated `resume_download`
# argument on every download; silence just that specific warning.
warnings.filterwarnings(
    "ignore",
    message="`resume_download` is deprecated",
    category=FutureWarning,
    module="huggingface_hub.file_download",
)

# Keep the transformers library quiet except for genuine errors.
transformers_logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)
19
+
20
class ModelManager:
    """Loads and caches Hugging Face models (BERT or GPT-2 families).

    Only one model is kept in memory at a time: requesting a different
    model first unloads the cached one (and empties the CUDA cache when
    a GPU is present) to bound memory usage.
    """

    # model_name -> (tokenizer, encoder, lm_head_or_None)
    _instances = {}

    @classmethod
    def get_model(cls, model_name: str):
        """
        Return (tokenizer, encoder_model, mlm_model) for *model_name*.

        Serves from the in-process cache when available, otherwise loads
        from Hugging Face. ``mlm_model`` may be ``None`` when the
        LM/MLM head could not be loaded (best-effort).

        Raises:
            RuntimeError: if the model cannot be loaded at all.
        """
        # Fast path: already loaded.
        if model_name in cls._instances:
            return cls._instances[model_name]

        # Evict any previously loaded model to free memory before loading.
        if cls._instances:
            print("Unloading previous models to free memory...")
            cls._instances.clear()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"Loading model: {model_name}...")

        try:
            # Model-family dispatch: any name containing "gpt2" is treated
            # as a GPT-2 checkpoint, everything else as BERT.
            if "gpt2" in model_name:
                tokenizer, encoder, mlm = cls._load_gpt2(model_name)
            else:
                tokenizer, encoder, mlm = cls._load_bert(model_name)

            # Move to GPU if available.
            device = cls.get_device()
            encoder.to(device)
            if mlm is not None:
                mlm.to(device)

            cls._instances[model_name] = (tokenizer, encoder, mlm)
            return tokenizer, encoder, mlm

        except Exception as e:
            # Chain the original cause so the load failure is debuggable.
            raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e

    @staticmethod
    def _load_gpt2(model_name: str):
        """Load GPT-2 tokenizer, encoder, and (best-effort) LM head."""
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        # GPT-2 ships without a pad token; reuse EOS so padding works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        encoder = GPT2Model.from_pretrained(
            model_name,
            output_attentions=True,
            output_hidden_states=True,
        )
        encoder.eval()

        try:
            mlm = GPT2LMHeadModel.from_pretrained(
                model_name,
                output_attentions=False,
                output_hidden_states=False,
            )
            mlm.eval()
        except Exception as e:
            # Best-effort: the encoder alone is still usable downstream.
            print(f"Warning: Could not load LM head for {model_name}: {e}")
            mlm = None
        return tokenizer, encoder, mlm

    @staticmethod
    def _load_bert(model_name: str):
        """Load BERT tokenizer, encoder, and (best-effort) MLM head."""
        tokenizer = BertTokenizer.from_pretrained(model_name)

        encoder = BertModel.from_pretrained(
            model_name,
            output_attentions=True,
            output_hidden_states=True,
        )
        encoder.eval()

        try:
            mlm = BertForMaskedLM.from_pretrained(
                model_name,
                output_attentions=False,
                output_hidden_states=False,
            )
            mlm.eval()
        except Exception as e:
            # Best-effort: the encoder alone is still usable downstream.
            print(f"Warning: Could not load MLM head for {model_name}: {e}")
            mlm = None
        return tokenizer, encoder, mlm

    @staticmethod
    def get_device() -> str:
        """Return the torch device string: 'cuda' when available, else 'cpu'."""
        return "cuda" if torch.cuda.is_available() else "cpu"
107
+
108
+ __all__ = ["ModelManager"]
attention_app/server/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (493 Bytes). View file
 
attention_app/server/__pycache__/bias_handlers.cpython-311.pyc ADDED
Binary file (24.3 kB). View file
 
attention_app/server/__pycache__/logic.cpython-311.pyc ADDED
Binary file (5.46 kB). View file
 
attention_app/server/__pycache__/main.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fbca4c333ba02d6b4139d59210196392c9f834097a8aaca78c574bd7fe9e564
3
+ size 101097
attention_app/server/__pycache__/renderers.cpython-311.pyc ADDED
Binary file (34.9 kB). View file
 
attention_app/ui/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """UI package for Attention Atlas."""
2
+
3
+ from .layouts import app_ui
4
+
5
+ __all__ = ["app_ui"]
attention_app/ui/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (314 Bytes). View file
 
attention_app/ui/__pycache__/bias_ui.cpython-311.pyc ADDED
Binary file (6.23 kB). View file
 
attention_app/ui/__pycache__/components.cpython-311.pyc ADDED
Binary file (2.3 kB). View file
 
attention_app/ui/__pycache__/layouts.cpython-311.pyc ADDED
Binary file (5.15 kB). View file
 
attention_app/ui/__pycache__/modals.cpython-311.pyc ADDED
Binary file (5.31 kB). View file
 
attention_app/ui/__pycache__/scripts.cpython-311.pyc ADDED
Binary file (55.5 kB). View file
 
attention_app/ui/__pycache__/styles.cpython-311.pyc ADDED
Binary file (55.1 kB). View file