---
tags:
- text-classification
- medical
- prototypical-networks
- transformers
library_name: transformers
language: en
license: mit
datasets:
- your_dataset_name_here
model-index:
- name: ProtoPatient
  results:
  - task:
      type: multi-label-classification
    dataset:
      name: your_dataset_name_here
      type: text
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.XX # Update with real value
    - name: F1-score
      type: f1
      value: 0.XX # Update with real value
---
# ProtoPatient Model for Multi-Label Classification
## Paper Reference
**van Aken, Betty, Jens-Michalis Papaioannou, Marcel G. Naik, Georgios Eleftheriadis, Wolfgang Nejdl, Felix A. Gers, and Alexander Löser. 2022.**
*This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text.*
[arXiv:2210.08500](https://arxiv.org/abs/2210.08500)
ProtoPatient is a transformer-based architecture that uses prototypical networks and label-wise attention to perform multi-label diagnosis classification of clinical admission notes. Unlike standard black-box models, ProtoPatient offers inherent interpretability by:
- **Highlighting Relevant Tokens:** Shows the most important words for each possible diagnosis.
- **Retrieving Prototypical Patients:** Finds training examples with similar textual patterns to provide intuitive justifications for clinicians—essentially answering, “This patient looks like that patient.”
---
## Model Overview
### Prototype-Based Classification
- The model learns **prototypical vectors** (\(u_c\)) for each diagnosis \(c\).
- A patient’s admission note is encoded with PubMedBERT, reduced by a linear compression layer, and pooled by a label-wise attention mechanism into a diagnosis-specific representation (\(v_{p,c}\)).
- Classification scores are computed as the **negative Euclidean distance** between \(v_{p,c}\) and \(u_c\), which directly measures the note’s similarity to the learned prototype.
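The scoring rule can be illustrated with a few lines of NumPy. This is a minimal sketch, not the repository's implementation: the names `prototype_scores`, `patient_reps` (row c holds \(v_{p,c}\)) and `prototypes` (row c holds \(u_c\)) are hypothetical, and the diagnosis-specific representations are assumed to be already computed.

```python
import numpy as np


def prototype_scores(patient_reps: np.ndarray, prototypes: np.ndarray) -> np.ndarray:
    """Score each diagnosis c as -||v_{p,c} - u_c||_2 (hypothetical helper).

    patient_reps: (C, d), row c is the diagnosis-specific representation v_{p,c}
    prototypes:   (C, d), row c is the learned prototype u_c
    returns:      (C,), scores closer to 0 mean the note is closer to that prototype
    """
    if patient_reps.shape != prototypes.shape:
        raise ValueError("patient_reps and prototypes must both have shape (C, d)")
    # Row-wise Euclidean distance between each v_{p,c} and its own prototype u_c
    distances = np.linalg.norm(patient_reps - prototypes, axis=1)
    return -distances


# Toy example: 3 diagnoses with 2-dimensional representations
prototypes = np.array([[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]])
patient_reps = np.array([[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]])
scores = prototype_scores(patient_reps, prototypes)
print(scores)
```

Here diagnosis 0 matches its prototype exactly (score 0), diagnosis 1 is one unit away (score -1), and diagnosis 2 is five units away (score -5), so the ranking follows distance directly.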
### Label-Wise Attention
- For each diagnosis, a separate attention vector identifies relevant tokens in the admission note.
- This mechanism provides interpretability by indicating which tokens are most influential in driving each prediction.
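A simplified, single-note version of label-wise attention can be sketched in NumPy. It assumes, for illustration only, dot-product scoring between each token embedding and a per-diagnosis query vector, followed by a softmax over tokens; the function and variable names are hypothetical, and the exact parameterization in the paper and repository may differ.

```python
import numpy as np


def label_wise_attention(token_embeds: np.ndarray, label_queries: np.ndarray, mask: np.ndarray):
    """Pool one note's token embeddings into one vector per diagnosis.

    token_embeds:  (T, d) token embeddings of a single note
    label_queries: (C, d) one (hypothetical) learned query vector per diagnosis
    mask:          (T,) 1 for real tokens, 0 for padding; at least one real token
    returns:       attention weights (C, T) and diagnosis-specific vectors (C, d)
    """
    logits = label_queries @ token_embeds.T                # (C, T) token relevance per label
    logits = np.where(mask[None, :] > 0, logits, -np.inf)  # padding gets zero weight
    logits = logits - logits.max(axis=1, keepdims=True)    # stabilize the softmax
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=1, keepdims=True) # softmax over tokens
    reps = weights @ token_embeds                          # (C, d) weighted token averages
    return weights, reps


# Toy note with 3 tokens (the last one is padding) and 2 diagnoses
token_embeds = np.array([[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
label_queries = np.array([[10.0, 0.0], [0.0, 10.0]])
mask = np.array([1, 1, 0])
weights, reps = label_wise_attention(token_embeds, label_queries, mask)
print(np.round(weights, 3))
```

Row c of `weights` is what a token highlight for diagnosis c is read from, and row c of `reps` plays the role of \(v_{p,c}\) in the distance-based score.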
### Interpretable Output
- **Token Highlights:** The top attended words (often correlating with symptoms, risk factors, or diagnostic descriptors).
- **Prototypical Patients:** Examples from the training set that are closest to each prototype, representing typical presentations of a diagnosis.
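Attention is computed over WordPiece tokens, so readable highlights usually need special tokens dropped and `##` continuation pieces merged back into whole words. The helper below is a hypothetical post-processing step, not part of the repository; summing the scores of a word's pieces is an assumed aggregation (taking the maximum would be a reasonable alternative).

```python
from typing import List, Sequence, Tuple

SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]"}


def top_words(tokens: Sequence[str], scores: Sequence[float], k: int = 5) -> List[Tuple[str, float]]:
    """Merge WordPiece pieces into words and return the k highest-scoring words."""
    words: List[Tuple[str, float]] = []
    for token, score in zip(tokens, scores):
        if token in SPECIAL_TOKENS:
            continue  # [CLS]/[SEP]/[PAD] carry no readable content
        if token.startswith("##") and words:
            # Continuation piece: glue it onto the previous word and add its score
            prev_word, prev_score = words[-1]
            words[-1] = (prev_word + token[2:], prev_score + float(score))
        else:
            words.append((token, float(score)))
    return sorted(words, key=lambda item: item[1], reverse=True)[:k]


tokens = ["[CLS]", "short", "##ness", "of", "breath", "[SEP]", "[PAD]"]
scores = [0.30, 0.20, 0.15, 0.05, 0.25, 0.05, 0.0]
print(top_words(tokens, scores, k=2))
```

In this toy input the high score on `[CLS]` is discarded, and "short" + "##ness" combine into "shortness", which then outranks "breath".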
---
## Key Features and Benefits
- **Improved Performance on Rare Diagnoses:**
Prototype-based learning has strong few-shot capabilities, which is especially beneficial for diagnoses with very few samples.
- **Faithful Interpretations:**
Quantitative evaluations (see Section 5 in the paper) indicate that the attention-based highlights are more faithful to the model’s decision process than post-hoc methods such as LIME, Occlusion, and gradient-based approaches.
- **Clinical Utility:**
- Provides label-wise explanations to help clinicians assess whether the predictions align with actual risk factors.
- Points to prototypical patients, allowing for comparison of new cases with typical (or atypical) presentations.
---
## Performance Metrics
Evaluated on **MIMIC-III**:
- **Admission Notes:** 48,745
- **Diagnosis Labels:** 1,266
Performance (approximate):
- **Macro ROC AUC:** ~87–88%
- **Micro ROC AUC:** ~97%
- **Macro PR AUC:** ~18–21%
The model shows particularly strong gains for rare diagnoses (fewer than 50 samples) when compared with baselines like PubMedBERT alone or hierarchical attention RNNs (e.g., HAN, HA-GRU).
Additionally, the model transfers well to **i2b2** data (1,118 admission notes), which come from a different clinical environment.
*Refer to Tables 1, 2, and 3 in the paper for detailed results and ablation studies.*
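The three metrics above can be computed on a toy multi-label matrix to show how they differ. This NumPy sketch is illustrative only: it uses the rank-based (Mann-Whitney) form of ROC AUC and average precision as the PR AUC estimate, which are assumptions about how these numbers are computed; scikit-learn's `roc_auc_score` and `average_precision_score` provide well-tested versions.

```python
import numpy as np


def roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """ROC AUC as the probability that a random positive outranks a random negative."""
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    # Compare every positive with every negative; ties count as half a win
    wins = (pos[:, None] > neg[None, :]).sum() + 0.5 * (pos[:, None] == neg[None, :]).sum()
    return float(wins / (len(pos) * len(neg)))


def average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Mean precision at the rank of each positive (assumes no tied scores)."""
    order = np.argsort(-y_score)
    hits = y_true[order]
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return float(precision_at_k[hits == 1].mean())


def summarize(y_true: np.ndarray, y_score: np.ndarray) -> dict:
    """y_true, y_score: (num_notes, num_labels); each label needs a positive and a negative."""
    n_labels = y_true.shape[1]
    return {
        "macro_roc_auc": float(np.mean([roc_auc(y_true[:, j], y_score[:, j]) for j in range(n_labels)])),
        "micro_roc_auc": roc_auc(y_true.ravel(), y_score.ravel()),
        "macro_pr_auc": float(np.mean([average_precision(y_true[:, j], y_score[:, j]) for j in range(n_labels)])),
    }


# Label 0 is frequent and well ranked; label 1 is rare and poorly ranked
y_true = np.array([[1, 0], [1, 0], [0, 1], [0, 0]])
y_score = np.array([[0.9, 0.2], [0.8, 0.1], [0.3, 0.15], [0.2, 0.3]])
print(summarize(y_true, y_score))
```

Macro averages weight every label equally, so the single poorly ranked rare label pulls them down more than the micro average, which pools all (note, label) pairs and is dominated by frequent labels.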
---
## Repository Structure
```plaintext
ProtoPatient/
├── proto_model/
│ ├── proto.py
│ ├── utils.py
│ ├── metrics.py
│ └── __init__.py
├── config.json
├── setup.py
├── model.safetensors
├── tokenizer.json
├── tokenizer_config.json
├── vocab.txt
├── README.md
└── .gitattributes
```
---
## How to Use the Model
### 1. Install Dependencies
```bash
git clone https://huggingface.co/row56/ProtoPatient
cd ProtoPatient
pip install -e . transformers torch safetensors
export TOKENIZERS_PARALLELISM=false
```
### 2. Load the Model via Hugging Face
```python
import os
import warnings

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import logging as hf_logging

hf_logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)

import torch
from transformers import AutoTokenizer

from proto_model.configuration_proto import ProtoConfig
from proto_model.modeling_proto import ProtoForMultiLabelClassification

cfg = ProtoConfig.from_pretrained("row56/ProtoPatient")
# Encoder/tokenizer name used by this example. The paper's encoder is PubMedBERT,
# so check that this matches the encoder the checkpoint was trained with.
cfg.pretrained_model_name_or_path = "bert-base-uncased"
cfg.use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if cfg.use_cuda else "cpu")

tokenizer = AutoTokenizer.from_pretrained(cfg.pretrained_model_name_or_path)
model = ProtoForMultiLabelClassification.from_pretrained(
    "row56/ProtoPatient",
    config=cfg,
)
model.to(device)
model.eval()


def get_proto_logits(texts):
    """Tokenize a list of notes and return the model's per-diagnosis logits."""
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    batch = {
        # Move inputs to the same device as the model
        "input_ids": enc["input_ids"].to(device),
        # Key name kept as in the original example ("attention_masks", plural)
        "attention_masks": enc["attention_mask"].to(device),
        "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
        "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
    }
    with torch.no_grad():
        logits, _ = model.proto_module(batch)
    return logits


texts = [
    "Patient shows elevated heart rate and low oxygen saturation.",
    "No significant findings; patient is healthy.",
]
logits = get_proto_logits(texts)
print("Logits shape:", logits.shape)
print("Logits:\n", logits)
```
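The logits hold one score per diagnosis for each note, so a simple way to inspect them is to list the highest-scoring diagnoses. The helper below is hypothetical and not part of the repository. It deliberately avoids a fixed probability threshold: if the logits are negative distances, a sigmoid maps them to values of at most 0.5, so any decision threshold should be tuned on validation data rather than assumed. The `id2label` mapping is also an assumption (for example `cfg.id2label`, if it holds diagnosis names); the helper falls back to raw label indices.

```python
from typing import Dict, List, Optional, Tuple

import numpy as np


def top_k_diagnoses(
    logits: np.ndarray,
    k: int = 3,
    id2label: Optional[Dict[int, str]] = None,
) -> List[List[Tuple[str, float]]]:
    """Return the k highest-scoring labels for each row of a (batch, num_labels) score matrix."""
    id2label = id2label or {}
    results = []
    for row in logits:
        # argsort of the negated row lists label indices from highest to lowest score
        best = np.argsort(-row)[:k]
        results.append([(id2label.get(int(i), str(int(i))), float(row[i])) for i in best])
    return results


# With the model above: top_k_diagnoses(logits.cpu().numpy(), k=5)
toy_logits = np.array([[-3.2, -0.4, -1.7], [-0.9, -2.5, -0.1]])
for note_idx, preds in enumerate(top_k_diagnoses(toy_logits, k=2)):
    print(note_idx, preds)
```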
### 3. Training Data & Licenses
This model was trained on the MIMIC-III Clinical Database (v1.4), a large de-identified ICU dataset released under a data use agreement.

To obtain MIMIC-III:
1. Visit https://physionet.org/content/mimiciii/1.4/
2. Register for a free PhysioNet account, apply for credentialed access, and complete the CITI “Data or Specimens Only Research” training.
3. Sign the MIMIC-III Data Use Agreement (DUA).
4. Download the raw notes and run the preprocessing scripts from the paper’s repository.

**Note:** We do not redistribute MIMIC-III itself; users must obtain it directly under its license.
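### 4. Load Precomputed Training Data for Prototype Retrieval
After you have obtained MIMIC-III and applied the published preprocessing, you should produce:
- `data/train_embeds.npy`: a NumPy array of shape (N, d) with one d-dimensional embedding per training note. These vectors must lie in the same space as the model's prototype vectors, because retrieval compares the two directly. Note that an (N, d) array holds only one vector per note; keeping every diagnosis-specific representation would require an (N, C, d) array.
- `data/train_texts.json`: a JSON array of length N with the raw admission-note strings, in the same order as the embeddings.

Place both files in `data/` and then run: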
```python
import numpy as np
import json
train_embeds = np.load("data/train_embeds.npy")
with open("data/train_texts.json", "r") as f:
train_texts = json.load(f)
print(f"Loaded {train_embeds.shape[0]} embeddings of dim {train_embeds.shape[1]}")
```
### 5. Interpreting Outputs & Retrieving Prototypes
```python
from sklearn.neighbors import NearestNeighbors

# Assumes `torch`, `tokenizer`, `model`, and `device` from step 2,
# and `train_embeds` / `train_texts` from step 4.
text = "Patient has chest pain and shortness of breath."
enc = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
batch = {
    "input_ids": enc["input_ids"].to(device),
    "attention_masks": enc["attention_mask"].to(device),
    "token_type_ids": enc.get("token_type_ids", torch.zeros_like(enc["input_ids"])).to(device),
    "tokens": [tokenizer.convert_ids_to_tokens(ids.tolist()) for ids in enc["input_ids"]],
}

with torch.no_grad():
    logits, metadata = model.proto_module(batch)

# Label-wise attention for the first (only) note: one row of token scores per label
attn_scores = metadata["attentions"][0]
for label_id, scores in enumerate(attn_scores):
    topk = sorted(zip(batch["tokens"][0], scores.tolist()), key=lambda x: x[1], reverse=True)[:5]
    print(f"Label {label_id} top tokens:", topk)

# detach() is required if prototype_vectors requires grad (e.g. an nn.Parameter)
proto_vecs = model.proto_module.prototype_vectors.detach().cpu().numpy()
knn = NearestNeighbors(n_neighbors=1, metric="euclidean").fit(train_embeds)
for label_id, u_c in enumerate(proto_vecs):
    dist, idx = knn.kneighbors(u_c.reshape(1, -1))
    print(f"\nLabel {label_id} prototype (distance={dist[0][0]:.3f}):")
    print(train_texts[idx[0][0]])
```
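The snippet above compares each prototype against a single embedding per training note. If you instead export the diagnosis-specific representations \(v_{p,c}\) of every training note, i.e. an array of shape (N, C, d), retrieval can be done per diagnosis, matching each \(u_c\) only against the \(v_{p,c}\) of the same diagnosis. The (N, C, d) layout and the helper name are assumptions for illustration; the toy arrays stand in for the exported training representations and `proto_vecs`.

```python
import numpy as np


def nearest_patients_per_label(train_reps: np.ndarray, prototypes: np.ndarray, k: int = 1):
    """For each diagnosis c, find the k training notes whose v_{p,c} is closest to u_c.

    train_reps: (N, C, d) diagnosis-specific representations of N training notes
    prototypes: (C, d) learned prototype vectors
    returns:    note indices of shape (C, k) and Euclidean distances of shape (C, k)
    """
    # Broadcast (N, C, d) - (1, C, d) and take the norm over d -> (N, C)
    dists = np.linalg.norm(train_reps - prototypes[None, :, :], axis=2)
    order = np.argsort(dists, axis=0)[:k].T                # (C, k) nearest notes per label
    nearest = np.take_along_axis(dists.T, order, axis=1)   # (C, k) matching distances
    return order, nearest


# Toy data: 3 training notes, 2 diagnoses, 2-dimensional representations
train_reps = np.array([
    [[0.0, 0.0], [5.0, 5.0]],
    [[1.0, 1.0], [2.0, 2.0]],
    [[4.0, 4.0], [0.0, 0.0]],
])
prototypes = np.array([[0.0, 0.0], [2.0, 2.0]])
idx, dist = nearest_patients_per_label(train_reps, prototypes, k=2)
print(idx)
print(np.round(dist, 3))
```

With real data, `train_texts[idx[c, 0]]` would be the most prototypical training note for diagnosis c.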
---
# Intended Use, Limitations & Ethical Considerations
## Intended Use
- **Research & Education:**
ProtoPatient is designed primarily for academic research and educational purposes in clinical NLP.
- **Interpretability Demonstration:**
The model demonstrates how prototype-based methods can provide interpretable multi-label classification on clinical admission notes.
---
## Limitations
- **Generalization:**
The model was trained on MIMIC-III ICU admission notes (with transfer evaluated on i2b2) and may not generalize to other patient populations or clinical settings.
- **Prototype Scope:**
The current version uses a single prototype per diagnosis, although some diagnoses may have several typical presentations; supporting multiple prototypes per diagnosis is an area for future improvement.
- **Inter-diagnosis Relationships:**
The model does not explicitly model relationships (e.g., conflicts or comorbidities) between different diagnoses.
---
## Ethical & Regulatory Considerations
- **Not for Direct Clinical Use:**
This model is not intended for direct clinical decision-making. Always consult healthcare professionals.
- **Bias and Fairness:**
Users should be aware of potential biases in the training data; rare conditions might still be misclassified.
- **Patient Privacy:**
When applying the model to real clinical data, patient privacy must be strictly maintained.
---
# Example Interpretability Output
Based on the approach described in the paper (see Section 5 and Table 5):
- **Highlighted Tokens:**
Tokens such as “worst headache of her life,” “vomiting,” “fever,” and “infiltrate” strongly indicate specific diagnoses.
- **Prototypical Sample:**
A snippet from a training patient with similar text segments provides a rationale for the prediction.
*This interpretability output aids clinicians in understanding the model's reasoning – for example: "The system suggests intracerebral hemorrhage because the patient's note closely resembles typical cases with that diagnosis."*
---
# Recommended Citation
If you use ProtoPatient in your research, please cite:
```bibtex
@misc{vanaken2022this,
title={This Patient Looks Like That Patient: Prototypical Networks for Interpretable Diagnosis Prediction from Clinical Text},
author={van Aken, Betty and Papaioannou, Jens-Michalis and Naik, Marcel G. and Eleftheriadis, Georgios and Nejdl, Wolfgang and Gers, Felix A. and L{\"o}ser, Alexander},
year={2022},
eprint={2210.08500},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```