"""
Dataset Loader Module (Refactored)

Generic dataset loading supporting multiple formats:
- OBO (Open Biomedical Ontologies)
- CSV/TSV
- JSON/JSON-LD
- Custom adapters

Configuration-driven to support any domain, not just medical.
"""

import re
import json
import csv
import logging
import urllib.request
from pathlib import Path
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field

from .knowledge_graph import Entity, EntityCategory, KnowledgeGraph
from .config import DatasetConfig, get_config

logger = logging.getLogger(__name__)

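# A DatasetConfig (defined in .config) is expected to carry at least the
# fields read in this module. A hypothetical entry might look like the
# following (field values are illustrative, not a real configuration):
#
#     DatasetConfig(
#         name="disease_ontology",
#         source_type="obo",
#         source_url="https://example.org/doid.obo",  # placeholder URL
#         cache_enabled=True,
#         cache_max_age_days=30,
#         entity_category="disease",
#     )
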
@dataclass
class OntologyTerm:
    """Generic ontology term representation."""
    id: str
    name: str
    definition: str = ""
    synonyms: List[str] = field(default_factory=list)
    xrefs: Dict[str, str] = field(default_factory=dict)
    is_a: List[str] = field(default_factory=list)
    relationships: List[Tuple[str, str]] = field(default_factory=list)
    namespace: str = ""
    is_obsolete: bool = False

    def to_entity(self, category: EntityCategory) -> Entity:
        """Convert to Entity."""
        return Entity(
            id=self.id,
            name=self.name,
            category=category,
            description=self.definition,
            synonyms=self.synonyms,
            xrefs=self.xrefs,
            properties={"is_a": self.is_a, "namespace": self.namespace},
        )

class DatasetAdapter(ABC):
    """Abstract base class for dataset adapters."""

    @abstractmethod
    def parse(self, content: str) -> Dict[str, OntologyTerm]:
        """Parse content and return a dictionary of terms keyed by term ID."""

    @abstractmethod
    def can_handle(self, source_type: str) -> bool:
        """Check if this adapter can handle the source type."""

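# Example: a minimal custom adapter for newline-delimited "id|name" records.
# This is an illustrative sketch, not a shipped adapter; the "pipe" source
# type is hypothetical. Register it via DatasetLoader.register_adapter().
#
#     class PipeDelimitedAdapter(DatasetAdapter):
#         def can_handle(self, source_type: str) -> bool:
#             return source_type.lower() == "pipe"
#
#         def parse(self, content: str) -> Dict[str, OntologyTerm]:
#             terms = {}
#             for line in content.splitlines():
#                 term_id, _, name = line.partition("|")
#                 if term_id.strip() and name.strip():
#                     terms[term_id.strip()] = OntologyTerm(
#                         id=term_id.strip(), name=name.strip())
#             return terms
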
class OBOAdapter(DatasetAdapter):
    """Parser for OBO (Open Biomedical Ontologies) format."""

    def can_handle(self, source_type: str) -> bool:
        return source_type.lower() == "obo"

    def parse(self, content: str) -> Dict[str, OntologyTerm]:
        """Parse OBO format content."""
        terms = {}

        # Split the file into stanzas; the first chunk is the file header.
        stanzas = re.split(r'\n\[', content)

        for stanza in stanzas[1:]:
            if stanza.startswith('Term]'):
                term = self._parse_term(stanza[5:])
                if term and not term.is_obsolete:
                    terms[term.id] = term

        logger.info(f"Parsed {len(terms)} terms from OBO content")
        return terms

    def _parse_term(self, stanza: str) -> Optional[OntologyTerm]:
        """Parse a single term stanza."""
        data = {
            "id": "", "name": "", "definition": "",
            "synonyms": [], "xrefs": {}, "is_a": [],
            "relationships": [], "namespace": "", "is_obsolete": False,
        }

        for line in stanza.split('\n'):
            line = line.strip()
            if not line or line.startswith('!') or ':' not in line:
                continue

            tag, _, value = line.partition(':')
            tag, value = tag.strip(), value.strip()

            if tag == 'id':
                data['id'] = value
            elif tag == 'name':
                data['name'] = value
            elif tag == 'def':
                # Definitions are quoted, optionally followed by references.
                match = re.match(r'"([^"]*)"', value)
                if match:
                    data['definition'] = match.group(1)
            elif tag == 'synonym':
                match = re.match(r'"([^"]*)"', value)
                if match:
                    data['synonyms'].append(match.group(1))
            elif tag == 'xref':
                if ':' in value:
                    xref_ns, _, xref_id = value.partition(':')
                    # Drop any trailing description after the ID.
                    xref_id = xref_id.split()[0] if ' ' in xref_id else xref_id
                    data['xrefs'][xref_ns.strip()] = xref_id.strip()
            elif tag == 'is_a':
                # Strip the trailing "! parent name" comment.
                parent_id = value.split('!')[0].strip()
                data['is_a'].append(parent_id)
            elif tag == 'relationship':
                parts = value.split()
                if len(parts) >= 2:
                    data['relationships'].append((parts[0], parts[1]))
            elif tag == 'is_obsolete':
                data['is_obsolete'] = value.lower() == 'true'
            elif tag == 'namespace':
                data['namespace'] = value

        if data['id'] and data['name']:
            return OntologyTerm(**data)
        return None

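# For reference, an OBO term stanza of the kind this adapter consumes
# (illustrative, not verbatim from any ontology release):
#
#     [Term]
#     id: DOID:8469
#     name: influenza
#     def: "An example definition." [https://example.org]
#     synonym: "flu" EXACT []
#     is_a: DOID:934 ! viral infectious disease
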
class CSVAdapter(DatasetAdapter):
    """Parser for CSV/TSV format datasets."""

    # Common column-name variants mapped to standard term fields.
    DEFAULT_MAPPINGS = {
        "id": ["id", "ID", "identifier", "code"],
        "name": ["name", "Name", "label", "Label", "title"],
        "definition": ["definition", "description", "Description", "desc"],
        "synonyms": ["synonyms", "aliases", "alt_names"],
    }

    def __init__(self, column_mappings: Optional[Dict[str, str]] = None):
        self.column_mappings = column_mappings or {}

    def can_handle(self, source_type: str) -> bool:
        return source_type.lower() in ["csv", "tsv"]

    def parse(self, content: str) -> Dict[str, OntologyTerm]:
        """Parse CSV content."""
        terms = {}

        # Sniff the delimiter; fall back to the default comma-delimited
        # dialect if sniffing fails (e.g. on single-column content).
        try:
            dialect = csv.Sniffer().sniff(content[:1024])
        except csv.Error:
            dialect = csv.excel
        reader = csv.DictReader(content.splitlines(), dialect=dialect)

        # Resolve which columns map to which term fields.
        col_map = self._map_columns(reader.fieldnames or [])

        for row in reader:
            term = self._row_to_term(row, col_map)
            if term:
                terms[term.id] = term

        logger.info(f"Parsed {len(terms)} terms from CSV content")
        return terms

    def _map_columns(self, fieldnames: List[str]) -> Dict[str, str]:
        """Map fieldnames to standard term fields."""
        col_map = {}

        # Named `term_field` to avoid shadowing dataclasses.field.
        for term_field, possible_names in self.DEFAULT_MAPPINGS.items():
            # Explicit user-supplied mappings take precedence.
            if term_field in self.column_mappings:
                col_map[term_field] = self.column_mappings[term_field]
            else:
                # Otherwise fall back to the first matching known variant.
                for name in possible_names:
                    if name in fieldnames:
                        col_map[term_field] = name
                        break

        return col_map

    def _row_to_term(self, row: Dict, col_map: Dict[str, str]) -> Optional[OntologyTerm]:
        """Convert CSV row to OntologyTerm."""
        term_id = row.get(col_map.get("id", ""), "")
        name = row.get(col_map.get("name", ""), "")

        if not term_id or not name:
            return None

        definition = row.get(col_map.get("definition", ""), "")

        # Synonyms may arrive as a JSON array or a comma-separated string;
        # guard against None for short rows.
        synonyms_raw = row.get(col_map.get("synonyms", ""), "") or ""
        if synonyms_raw.startswith("["):
            try:
                synonyms = json.loads(synonyms_raw)
            except json.JSONDecodeError:
                synonyms = []
        else:
            synonyms = [s.strip() for s in synonyms_raw.split(",") if s.strip()]

        return OntologyTerm(
            id=term_id,
            name=name,
            definition=definition,
            synonyms=synonyms,
        )

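# Example: parsing a nonstandard CSV layout by supplying explicit column
# mappings (the column names below are hypothetical):
#
#     adapter = CSVAdapter(column_mappings={"id": "concept_code",
#                                           "name": "concept_label"})
#     terms = adapter.parse("concept_code,concept_label\nC001,Fever\n")
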
class JSONAdapter(DatasetAdapter):
    """Parser for JSON format datasets."""

    def __init__(self, terms_path: str = "terms", id_field: str = "id", name_field: str = "name"):
        self.terms_path = terms_path
        self.id_field = id_field
        self.name_field = name_field

    def can_handle(self, source_type: str) -> bool:
        return source_type.lower() in ["json", "json-ld"]

    def parse(self, content: str) -> Dict[str, OntologyTerm]:
        """Parse JSON content."""
        terms = {}
        data = json.loads(content)

        # Walk the dotted terms_path (e.g. "results.terms") into the payload.
        items = data
        if self.terms_path:
            for key in self.terms_path.split("."):
                if isinstance(items, dict):
                    items = items.get(key, [])
                else:
                    break

        if not isinstance(items, list):
            items = [items] if isinstance(items, dict) else []

        for item in items:
            term = self._item_to_term(item)
            if term:
                terms[term.id] = term

        logger.info(f"Parsed {len(terms)} terms from JSON content")
        return terms

    def _item_to_term(self, item: Dict) -> Optional[OntologyTerm]:
        """Convert JSON item to OntologyTerm."""
        term_id = item.get(self.id_field, "")
        name = item.get(self.name_field, "")

        if not term_id or not name:
            return None

        return OntologyTerm(
            id=term_id,
            name=name,
            definition=item.get("definition", item.get("description", "")),
            synonyms=item.get("synonyms", item.get("aliases", [])),
            xrefs=item.get("xrefs", {}),
            is_a=item.get("is_a", item.get("parents", [])),
        )

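# Example payload accepted with the default settings (terms_path="terms",
# id_field="id", name_field="name"); the values shown are illustrative:
#
#     {"terms": [{"id": "X:1", "name": "Example term", "synonyms": ["alias"]}]}
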
class DatasetLoader:
    """
    Main dataset loader supporting multiple formats and sources.

    Usage:
        loader = DatasetLoader()
        loader.load_dataset(config)   # Single dataset
        loader.load_all_datasets()    # All datasets from config
    """

    def __init__(self, cache_dir: Optional[str] = None):
        self.cache_dir = Path(cache_dir or get_config().cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Built-in adapters; custom adapters can be registered at runtime.
        self.adapters: List[DatasetAdapter] = [
            OBOAdapter(),
            CSVAdapter(),
            JSONAdapter(),
        ]

        # Loaded terms, keyed by dataset name.
        self.datasets: Dict[str, Dict[str, OntologyTerm]] = {}

    def register_adapter(self, adapter: DatasetAdapter):
        """Register a custom adapter; it takes precedence over built-ins."""
        self.adapters.insert(0, adapter)

    def get_adapter(self, source_type: str) -> Optional[DatasetAdapter]:
        """Get the first adapter that can handle the source type."""
        for adapter in self.adapters:
            if adapter.can_handle(source_type):
                return adapter
        return None

    def load_dataset(self, config: DatasetConfig) -> Dict[str, OntologyTerm]:
        """Load a single dataset based on configuration."""
        logger.info(f"Loading dataset: {config.name}")

        # Serve from cache when a fresh copy exists.
        if config.cache_enabled:
            cached = self._load_from_cache(config)
            if cached:
                self.datasets[config.name] = cached
                return cached

        # Fetch raw content from the configured URL or file path.
        content = self._get_content(config)
        if not content:
            logger.warning(f"No content for dataset: {config.name}")
            return {}

        # Pick an adapter for the declared source type.
        adapter = self.get_adapter(config.source_type)
        if not adapter:
            logger.error(f"No adapter for source type: {config.source_type}")
            return {}

        terms = adapter.parse(content)

        # Persist for future runs.
        if config.cache_enabled:
            self._save_to_cache(config, terms)

        self.datasets[config.name] = terms
        return terms

    def load_all_datasets(self) -> Dict[str, Dict[str, OntologyTerm]]:
        """Load all datasets from configuration."""
        config = get_config()
        for dataset_config in config.datasets:
            self.load_dataset(dataset_config)
        return self.datasets

    def _get_content(self, config: DatasetConfig) -> Optional[str]:
        """Get content from URL or file path."""
        # Prefer the remote source when one is configured.
        if config.source_url:
            try:
                logger.info(f"Downloading from: {config.source_url}")
                req = urllib.request.Request(
                    config.source_url,
                    headers={'User-Agent': 'HITL-KG/1.0'}
                )
                with urllib.request.urlopen(req, timeout=60) as response:
                    return response.read().decode('utf-8')
            except Exception as e:
                logger.warning(f"Download failed: {e}")

        # Fall back to a local file.
        if config.source_path:
            path = Path(config.source_path)
            if path.exists():
                return path.read_text(encoding='utf-8')

        return None

    def _cache_path(self, config: DatasetConfig) -> Path:
        """Get cache file path for a dataset."""
        return self.cache_dir / f"{config.name}_cache.json"

    def _load_from_cache(self, config: DatasetConfig) -> Optional[Dict[str, OntologyTerm]]:
        """Load dataset from cache if present and not expired."""
        cache_path = self._cache_path(config)

        if not cache_path.exists():
            return None

        # Treat the cache as stale once it exceeds the configured age.
        mtime = datetime.fromtimestamp(cache_path.stat().st_mtime)
        age_days = (datetime.now() - mtime).days
        if age_days > config.cache_max_age_days:
            return None

        try:
            with open(cache_path) as f:
                data = json.load(f)

            terms = {}
            for term_id, term_data in data.get("terms", {}).items():
                terms[term_id] = OntologyTerm(
                    id=term_data["id"],
                    name=term_data["name"],
                    definition=term_data.get("definition", ""),
                    synonyms=term_data.get("synonyms", []),
                    xrefs=term_data.get("xrefs", {}),
                    is_a=term_data.get("is_a", []),
                    # JSON round-trips tuples as lists; restore the
                    # (relation_type, target) tuples.
                    relationships=[tuple(r) for r in term_data.get("relationships", [])],
                    namespace=term_data.get("namespace", ""),
                )

            logger.info(f"Loaded {len(terms)} terms from cache: {config.name}")
            return terms

        except Exception as e:
            logger.warning(f"Cache load failed: {e}")
            return None

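    # The cache file written by _save_to_cache below, and read back above,
    # has this shape (elided values depend on the dataset):
    #
    #     {"name": "...", "source_type": "...", "timestamp": "...",
    #      "terms": {"<term id>": {"id": "...", "name": "...", ...}}}
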
    def _save_to_cache(self, config: DatasetConfig, terms: Dict[str, OntologyTerm]):
        """Save dataset to cache."""
        try:
            cache_path = self._cache_path(config)

            data = {
                "name": config.name,
                "source_type": config.source_type,
                "timestamp": datetime.now().isoformat(),
                "terms": {
                    tid: {
                        "id": t.id,
                        "name": t.name,
                        "definition": t.definition,
                        "synonyms": t.synonyms,
                        "xrefs": t.xrefs,
                        "is_a": t.is_a,
                        "relationships": t.relationships,
                        "namespace": t.namespace,
                    }
                    for tid, t in terms.items()
                }
            }

            with open(cache_path, 'w') as f:
                json.dump(data, f)

            logger.info(f"Cached {len(terms)} terms for: {config.name}")

        except Exception as e:
            logger.warning(f"Cache save failed: {e}")

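# Example: end-to-end loading with a custom adapter registered first.
# PipeDelimitedAdapter refers to the hypothetical sketch earlier in this
# module, and my_pipe_config is likewise an illustrative DatasetConfig:
#
#     loader = DatasetLoader()
#     loader.register_adapter(PipeDelimitedAdapter())
#     terms = loader.load_dataset(my_pipe_config)
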
def build_knowledge_graph(loader: DatasetLoader) -> KnowledgeGraph:
    """
    Build a KnowledgeGraph from loaded datasets.

    This function:
    1. Converts OntologyTerms to Entities
    2. Creates relationships between entities
    3. Indexes entities for semantic search
    """
    kg = KnowledgeGraph()
    config = get_config()

    # Map each dataset to the entity category declared in its config,
    # skipping values that are not valid EntityCategory members.
    category_map = {
        ds.name: EntityCategory(ds.entity_category)
        for ds in config.datasets
        if ds.entity_category in [c.value for c in EntityCategory]
    }

    # Convert every loaded term into a graph entity.
    for dataset_name, terms in loader.datasets.items():
        category = category_map.get(dataset_name, EntityCategory.FINDING)

        for term in terms.values():
            entity = term.to_entity(category)
            kg.add_entity(entity)

    _build_relationships(kg, loader)

    logger.info(f"Built KG with {len(kg.entities)} entities")
    return kg

def _build_relationships(kg: KnowledgeGraph, loader: DatasetLoader):
    """Build relationships between entities."""
    # Index symptom entities by lowercased name and synonym once, rather
    # than rescanning the whole graph for every mapping.
    symptom_index: Dict[str, Entity] = {}
    for entity in kg.entities.values():
        if entity.category == EntityCategory.SYMPTOM:
            symptom_index.setdefault(entity.name.lower(), entity)
            for synonym in entity.synonyms:
                symptom_index.setdefault(synonym.lower(), entity)

    disease_symptom_mappings = _get_disease_symptom_mappings()

    for disease_id, symptom_mappings in disease_symptom_mappings.items():
        if disease_id not in kg.entities:
            continue

        for symptom_name, confidence in symptom_mappings:
            symptom_entity = symptom_index.get(symptom_name.lower())
            if symptom_entity:
                kg.add_relation(disease_id, symptom_entity.id, "causes", confidence)

    _add_treatment_entities(kg)

def _get_disease_symptom_mappings() -> Dict[str, List[Tuple[str, float]]]:
    """
    Get curated disease-symptom mappings.

    These are based on medical literature and provide high-quality
    associations that may not be present in the raw ontologies.
    """
    return {
        "DOID:8469": [
            ("fever", 0.95), ("cough", 0.85), ("fatigue", 0.90),
            ("body aches", 0.85), ("headache", 0.80), ("chills", 0.75),
        ],
        "DOID:0080600": [
            ("fever", 0.80), ("cough", 0.85), ("fatigue", 0.90),
            ("shortness of breath", 0.70), ("headache", 0.60),
            ("loss of taste", 0.50), ("loss of smell", 0.50),
        ],
        "DOID:10459": [
            ("runny nose", 0.95), ("sore throat", 0.80), ("cough", 0.75),
            ("nasal congestion", 0.85), ("sneezing", 0.80),
        ],
        "DOID:552": [
            ("fever", 0.90), ("cough", 0.95), ("shortness of breath", 0.85),
            ("chest pain", 0.70), ("fatigue", 0.80),
        ],
        "DOID:6132": [
            ("cough", 0.95), ("fatigue", 0.60),
            ("shortness of breath", 0.50),
        ],
        "DOID:10534": [
            ("sore throat", 0.98), ("fever", 0.80), ("headache", 0.50),
        ],
        "DOID:13084": [
            ("headache", 0.85), ("nasal congestion", 0.90),
            ("runny nose", 0.80),
        ],
        "DOID:8893": [
            ("headache", 0.99), ("nausea", 0.70),
        ],
    }

def _add_treatment_entities(kg: KnowledgeGraph):
    """Add treatment entities and relationships."""
    treatments = [
        Entity("tx_rest", "Rest", EntityCategory.TREATMENT,
               "Physical and mental rest", ["bed rest"]),
        Entity("tx_fluids", "Fluid Intake", EntityCategory.TREATMENT,
               "Increased hydration", ["hydration"]),
        Entity("tx_acetaminophen", "Acetaminophen", EntityCategory.MEDICATION,
               "Pain and fever reducer", ["paracetamol", "Tylenol"]),
        Entity("tx_ibuprofen", "Ibuprofen", EntityCategory.MEDICATION,
               "NSAID for pain and inflammation", ["Advil", "Motrin"]),
        Entity("tx_antiviral", "Antiviral Medication", EntityCategory.MEDICATION,
               "Medications for viral infections", ["oseltamivir", "Tamiflu"]),
        Entity("tx_decongestant", "Decongestants", EntityCategory.MEDICATION,
               "Nasal congestion relief", ["pseudoephedrine"]),
    ]

    for tx in treatments:
        kg.add_entity(tx)

    # Associate treatments with the diseases they address.
    treatment_map = {
        "DOID:8469": ["tx_rest", "tx_fluids", "tx_acetaminophen", "tx_antiviral"],
        "DOID:0080600": ["tx_rest", "tx_fluids", "tx_acetaminophen"],
        "DOID:10459": ["tx_rest", "tx_fluids", "tx_decongestant"],
        "DOID:552": ["tx_rest"],
    }

    for disease_id, treatment_ids in treatment_map.items():
        if disease_id in kg.entities:
            for tx_id in treatment_ids:
                if tx_id in kg.entities:
                    kg.add_relation(tx_id, disease_id, "treats", 0.8)

def load_knowledge_graph(use_embeddings: bool = True) -> KnowledgeGraph:
    """
    Main entry point: load datasets and build the knowledge graph.

    Args:
        use_embeddings: If True, also index entities for semantic search.
    """
    loader = DatasetLoader()
    loader.load_all_datasets()

    kg = build_knowledge_graph(loader)

    if use_embeddings:
        # Embedding indexing is optional; degrade gracefully if unavailable.
        try:
            from .embedding_service import get_embedding_service
            embedding_service = get_embedding_service()
            embedding_service.index_entities(kg.get_entity_dict_for_embedding())
        except Exception as e:
            logger.warning(f"Failed to initialize embeddings: {e}")

    return kg
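

# Typical entry-point usage:
#
#     kg = load_knowledge_graph(use_embeddings=False)
#     logger.info(f"Loaded {len(kg.entities)} entities")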