---
tags:
  - text-classification
  - medical
  - prototypical-networks
  - transformers
library_name: transformers
language: en
license: mit
datasets:
  - your_dataset_name_here
model-index:
  - name: ProtoPatient
    results:
      - task:
          type: multi-label-classification
        dataset:
          name: your_dataset_name_here
          type: text
        metrics:
          - name: Accuracy
            type: accuracy
            value: 0.XX  # Update with real value
          - name: F1-score
            type: f1
            value: 0.XX  # Update with real value
---


# ProtoPatient Model for Multi-Label Classification

## Paper Reference

**van Aken, Betty, Jens-Michalis Papaioannou, Marcel G. Naik, Georgios Eleftheriadis, Wolfgang Nejdl, Felix A. Gers, and Alexander Löser. 2022.**  
*This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text.*  
[arXiv:2210.08500](https://arxiv.org/abs/2210.08500)

ProtoPatient is a transformer-based architecture that combines prototypical networks with label-wise attention to perform multi-label diagnosis prediction on clinical admission notes. Unlike standard black-box models, ProtoPatient offers inherent interpretability by:

- **Highlighting Relevant Tokens:** Shows the most important words for each possible diagnosis.
- **Retrieving Prototypical Patients:** Finds training examples with similar textual patterns to provide intuitive justifications for clinicians—essentially answering, “This patient looks like that patient.”

---

## Model Overview

### Prototype-Based Classification

- The model learns **prototypical vectors** (\(u_c\)) for each diagnosis \(c\).
- A patient’s admission note is encoded by a PubMedBERT encoder followed by a linear compression layer; a label-wise attention mechanism then aggregates the token embeddings into a diagnosis-specific representation \(v_{p,c}\).
- Classification scores are computed as the **negative Euclidean distance** between \(v_{p,c}\) and \(u_c\), which directly measures the note’s similarity to the learned prototype; a minimal sketch of this scoring rule follows.
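
Below is an illustrative sketch of the scoring rule only; the tensor names and shapes are assumptions, not the repository’s API:

```python
import torch

# v_pc: label-specific note representations, shape (batch, num_labels, dim)
# u_c:  learned prototype vectors, shape (num_labels, dim)
def prototype_scores(v_pc: torch.Tensor, u_c: torch.Tensor) -> torch.Tensor:
    # Negative Euclidean distance: the closer a note's representation is to
    # a label's prototype, the higher (less negative) its score.
    return -torch.linalg.norm(v_pc - u_c.unsqueeze(0), dim=-1)
```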

### Label-Wise Attention

- For each diagnosis, a separate attention vector identifies relevant tokens in the admission note.
- This mechanism provides interpretability by indicating which tokens are most influential for each prediction; a sketch of the computation follows.
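
A minimal sketch of label-wise attention under assumed tensor names (the repository’s `proto_model/proto.py` holds the authoritative implementation):

```python
import torch

# H: encoder token embeddings, shape (batch, seq_len, dim)
# A: one learned attention vector per diagnosis, shape (num_labels, dim)
def label_wise_attention(H: torch.Tensor, A: torch.Tensor):
    scores = torch.einsum("bsd,ld->bls", H, A)    # token relevance per label
    weights = torch.softmax(scores, dim=-1)       # normalize over tokens
    v = torch.einsum("bls,bsd->bld", weights, H)  # label-specific note vectors
    return v, weights                             # weights double as token highlights
```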

### Interpretable Output

- **Token Highlights:** The top attended words (often correlating with symptoms, risk factors, or diagnostic descriptors).
- **Prototypical Patients:** Examples from the training set that are closest to each prototype, representing typical presentations of a diagnosis.

---

## Key Features and Benefits

- **Improved Performance on Rare Diagnoses:**  
  Prototype-based learning has strong few-shot capabilities, which is especially beneficial for diagnoses with very few samples.
  
- **Faithful Interpretations:**  
  Quantitative evaluations (see Section 5 in the paper) indicate that the attention-based highlights are more faithful to the model’s decision process than post-hoc methods such as LIME, occlusion, and gradient-based approaches.
  
- **Clinical Utility:**  
  - Provides label-wise explanations to help clinicians assess whether the predictions align with actual risk factors.  
  - Points to prototypical patients, allowing for comparison of new cases with typical (or atypical) presentations.

---

## Performance Metrics

Evaluated on **MIMIC-III**:
- **Admission Notes:** 48,745  
- **Diagnosis Labels:** 1,266  

Performance (approximate):
- **Macro ROC AUC:** ~87–88%
- **Micro ROC AUC:** ~97%
- **Macro PR AUC:** ~18–21%

The model shows particularly strong gains on rare diagnoses (fewer than 50 training samples) compared with baselines such as PubMedBERT alone or hierarchical attention RNNs (e.g., HAN, HA-GRU).

Additionally, the model transfers well to out-of-domain **i2b2** data (1,118 admission notes), demonstrating robustness across clinical environments.

*Refer to Tables 1, 2, and 3 in the paper for detailed results and ablation studies.*

---

## Repository Structure

```plaintext
ProtoPatient/
├── proto_model/
│   ├── proto.py
│   ├── utils.py
│   ├── metrics.py
│   └── __init__.py
├── config.json
├── setup.py
├── model.safetensors
├── tokenizer.json
├── tokenizer_config.json
├── vocab.txt
├── README.md
└── .gitattributes
```

---

## How to Use the Model

### 1. Install Dependencies

```bash
git clone https://huggingface.co/row56/ProtoPatient
cd ProtoPatient
pip install -e . transformers torch safetensors scikit-learn numpy
export TOKENIZERS_PARALLELISM=false
```

### 2. Load the Model via Hugging Face

```python
import os
import warnings

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import logging as hf_logging
hf_logging.set_verbosity_error() 

warnings.filterwarnings("ignore", category=UserWarning)
import torch
from transformers import AutoTokenizer
from proto_model.configuration_proto import ProtoConfig
from proto_model.modeling_proto import ProtoForMultiLabelClassification

cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
cfg.pretrained_model_name_or_path = "bert-base-uncased"
cfg.use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if cfg.use_cuda else "cpu")

tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
model = ProtoForMultiLabelClassification.from_pretrained(
    "row56/ProtoPatient",
    config=cfg,
)
model.to(device)
model.eval()

def get_proto_logits(texts):
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # The proto module expects raw tokens alongside the usual tensors;
    # move the tensors to the same device as the model.
    batch = {
        "input_ids":       enc["input_ids"].to(device),
        "attention_masks": enc["attention_mask"].to(device),
        "token_type_ids":  enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
        "tokens":          [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
    }
    with torch.no_grad():
        logits, _ = model.proto_module(batch)
    return logits

texts = [
    "Patient shows elevated heart rate and low oxygen saturation.",
    "No significant findings; patient is healthy."
]
logits = get_proto_logits(texts)
print("Logits shape:", logits.shape)
print("Logits:\n", logits)
```
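
Because the task is multi-label, each logit is typically mapped to an independent per-label probability with a sigmoid. A short follow-up; the 0.5 threshold is purely illustrative and would be tuned per label in practice:

```python
probs = torch.sigmoid(logits)  # independent per-label probabilities
preds = (probs > 0.5).int()    # illustrative threshold
print("Predicted (example, label) indices:", preds.nonzero(as_tuple=False).tolist())
```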

### 3. Training Data & Licenses

This model was trained on the MIMIC-III Clinical Database (v1.4), a large de-identified ICU dataset released under a data use agreement.

To obtain MIMIC-III:

1. Visit https://physionet.org/content/mimiciii/1.4/.
2. Register for a free PhysioNet account and complete the CITI “Data or Specimens Only Research” training.
3. Sign the MIMIC-III Data Use Agreement (DUA).
4. Download the raw notes and run the preprocessing scripts from the paper’s repository.

**Note:** We do not redistribute MIMIC-III itself; users must obtain it directly under its license.

### 4. Load Precomputed Training Data for Prototype Retrieval

After you have obtained MIMIC-III and applied the published preprocessing, you should have:

- `data/train_embeds.npy` — NumPy array of shape (N, d) with per-example, per-class embeddings.
- `data/train_texts.json` — JSON array of length N containing the raw admission-note strings.

Place both files in `data/` and then:

```python
import numpy as np
import json

train_embeds = np.load("data/train_embeds.npy")
with open("data/train_texts.json", "r") as f:
    train_texts = json.load(f)

print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
```
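
Since the retrieval example in Section 5 compares these embeddings against the learned prototypes, a quick dimensionality check can save debugging later (`prototype_vectors` is the same attribute used below):

```python
# The embeddings must live in the same space as the prototypes.
proto_dim = model.proto_module.prototype_vectors.shape[-1]
assert train_embeds.shape[1] == proto_dim, (
    f"embedding dim {train_embeds.shape[1]} != prototype dim {proto_dim}"
)
```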

### 5. Interpreting Outputs & Retrieving Prototypes

```python
from sklearn.neighbors import NearestNeighbors

text = "Patient has chest pain and shortness of breath."
enc  = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
batch = {
    "input_ids":       enc["input_ids"].to(device),
    "attention_masks": enc["attention_mask"].to(device),
    "token_type_ids":  enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
    "tokens":          [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
}

with torch.no_grad():
    logits, metadata = model.proto_module(batch)

# Per-label attention weights over the tokens of the first (only) example.
attn_scores = metadata["attentions"][0]
for label_id, scores in enumerate(attn_scores):
    topk = sorted(zip(batch["tokens"][0], scores.tolist()),
                  key=lambda x: -x[1])[:5]
    print(f"Label {label_id} top tokens:", topk)

# Prototypes are learned parameters, so detach before converting to NumPy.
proto_vecs = model.proto_module.prototype_vectors.detach().cpu().numpy()
nn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)

for label_id, u_c in enumerate(proto_vecs):
    dist, idx = nn.kneighbors(u_c.reshape(1, -1))
    print(f"\nLabel {label_id} prototype (distance={dist[0][0]:.3f}):")
    print(train_texts[idx[0][0]])
```

---

## Intended Use, Limitations & Ethical Considerations

### Intended Use

- **Research & Education:**  
  ProtoPatient is designed primarily for academic research and educational purposes in clinical NLP.
  
- **Interpretability Demonstration:**  
  The model demonstrates how prototype-based methods can provide interpretable multi-label classification on clinical admission notes.

---

### Limitations

- **Generalization:**  
  The model was trained on public ICU datasets (MIMIC-III, i2b2) and may not generalize to other patient populations.
  
- **Prototype Scope:**  
  The current version uses a single prototype per diagnosis, though some diagnoses might have multiple typical presentations—this is an area for future improvement.
  
- **Inter-diagnosis Relationships:**  
  The model does not explicitly model relationships (e.g., conflicts or comorbidities) between different diagnoses.

---

### Ethical & Regulatory Considerations

- **Not for Direct Clinical Use:**  
  This model is not intended for direct clinical decision-making. Always consult healthcare professionals.
  
- **Bias and Fairness:**  
  Users should be aware of potential biases in the training data; rare conditions might still be misclassified.
  
- **Patient Privacy:**  
  When applying the model to real clinical data, patient privacy must be strictly maintained.

---

## Example Interpretability Output

Based on the approach described in the paper (see Section 5 and Table 5):

- **Highlighted Tokens:**  
  Tokens such as “worst headache of her life,” “vomiting,” “fever,” and “infiltrate” strongly indicate specific diagnoses.
  
- **Prototypical Sample:**  
  A snippet from a training patient with similar text segments provides a rationale for the prediction.

*This interpretability output aids clinicians in understanding the model's reasoning – for example: "The system suggests intracerebral hemorrhage because the patient's note closely resembles typical cases with that diagnosis."*

---

## Recommended Citation

If you use ProtoPatient in your research, please cite:

```bibtex
@misc{vanaken2022this,
  title={This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text},
  author={van Aken, Betty and Papaioannou, Jens-Michalis and Naik, Marcel G. and Eleftheriadis, Georgios and Nejdl, Wolfgang and Gers, Felix A. and L{\"o}ser, Alexander},
  year={2022},
  eprint={2210.08500},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```