|
|
""" |
|
|
LLM Engine Module (Refactored) |
|
|
|
|
|
Simplified reasoning engine with: |
|
|
1. Embedding-based entity extraction (keyword matching is kept only as a local fallback)
|
|
2. Clean separation between OpenAI and local modes |
|
|
3. Proper context building with language support |
|
|
4. ReasoningChainCache for Graph-of-Thoughts structure |
|
|
|
|
|
References: |
|
|
- Chain-of-Thought prompting (Wei et al., 2022) |
|
|
- Tree of Thoughts (Yao et al., 2023) |
|
|
- Graph of Thoughts (Besta et al., 2023) |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import logging |
|
|
from abc import ABC, abstractmethod |
|
|
from dataclasses import dataclass, field |
|
|
from datetime import datetime |
|
|
from enum import Enum |
|
|
from typing import Dict, List, Optional, Any, Generator, Tuple |
|
|
|
|
|
from .knowledge_graph import ( |
|
|
KnowledgeGraph, ReasoningNode, ReasoningEdge, |
|
|
NodeType, EdgeType, EntityCategory, Entity, create_node_id |
|
|
) |
|
|
from .embedding_service import get_embedding_service, SearchResult |
|
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ChainNode:
    """One entry of a reasoning chain, linked to its parents and children.

    ``depth`` is 0 for a root entry; ReasoningChainCache.add_node sets it to
    one more than the deepest known parent.
    """

    node_id: str  # id of the reasoning node this entry describes
    node_type: NodeType  # kind of reasoning step
    parents: List[str] = field(default_factory=lambda: [])  # ids of direct predecessors
    children: List[str] = field(default_factory=lambda: [])  # ids of direct successors
    depth: int = 0  # hierarchy level within the chain
|
|
|
|
|
|
|
|
class ReasoningChainCache:
    """
    Manages the structure of reasoning chains for Graph-of-Thoughts.

    Tracks:
    - Parent-child relationships between reasoning steps
    - Multiple converging/diverging paths (a node may have several parents)
    - Depth of each step: 0 when it has no known parent, otherwise one more
      than its deepest known parent

    Reference: Graph of Thoughts (Besta et al., 2023)
    """

    def __init__(self):
        # node_id -> ChainNode for every node added so far.
        self.chains: Dict[str, ChainNode] = {}
        # Ids of nodes added without any parent ids, in insertion order.
        self.root_nodes: List[str] = []
        # Ids on the branch currently being extended (reset by create_branch).
        self.current_branch: List[str] = []

    def add_node(
        self,
        node_id: str,
        node_type: NodeType,
        parent_ids: Optional[List[str]] = None
    ) -> ChainNode:
        """Add a node to the reasoning chain.

        Args:
            node_id: Identifier of the new node.
            node_type: Kind of reasoning step. A falsy value is logged as
                "unknown".
            parent_ids: Ids of the node's parents. They are all recorded on
                the node, but only parents already in the cache contribute to
                its depth and receive a child link. Duplicates are ignored.

        Returns:
            The newly created ChainNode.
        """
        # Copy (and de-duplicate, keeping order) so later mutation of the
        # caller's list cannot silently alter the stored chain.
        parent_ids = list(dict.fromkeys(parent_ids or []))
        known_parents = [pid for pid in parent_ids if pid in self.chains]

        # default=-1 keeps depth at 0 when no parent is known yet; a bare
        # max() over an empty sequence would raise ValueError.
        depth = max(
            (self.chains[pid].depth for pid in known_parents), default=-1
        ) + 1

        chain_node = ChainNode(
            node_id=node_id,
            node_type=node_type,
            parents=parent_ids,
            depth=depth
        )
        self.chains[node_id] = chain_node

        for pid in known_parents:
            self.chains[pid].children.append(node_id)

        if not parent_ids:
            self.root_nodes.append(node_id)

        self.current_branch.append(node_id)

        node_type_str = node_type.value if node_type else "unknown"
        logger.debug(
            "Chain: Added %s node %s at depth %d", node_type_str, node_id[:8], depth
        )
        return chain_node

    def get_active_nodes(self) -> List[str]:
        """Get nodes that can be extended (nodes without children)."""
        return [nid for nid, node in self.chains.items() if not node.children]

    def get_ancestors(self, node_id: str) -> List[str]:
        """Get all ancestor node IDs of ``node_id``, nearest first.

        Traverses parent links breadth-first. Each ancestor is listed once,
        even when reachable through several paths. Parent ids that are not in
        the cache are still listed, but are not expanded further. An unknown
        ``node_id`` yields an empty list.
        """
        from collections import deque

        ancestors: List[str] = []
        seen = {node_id}
        queue = deque([node_id])

        while queue:
            current = queue.popleft()
            node = self.chains.get(current)
            if node is None:
                continue
            for parent in node.parents:
                if parent in seen:
                    continue
                seen.add(parent)
                ancestors.append(parent)
                queue.append(parent)

        return ancestors

    def create_branch(self, from_node_id: str) -> None:
        """Start a new branch from the specified node (ignored if unknown)."""
        if from_node_id in self.chains:
            self.current_branch = [from_node_id]
            logger.info(f"Started new branch from node {from_node_id[:8]}")
        else:
            logger.warning("Cannot branch from unknown node %s", from_node_id)

    def get_context_nodes(self, max_nodes: int = 10) -> List[str]:
        """Get up to ``max_nodes`` most recent ids on the current branch.

        A non-positive ``max_nodes`` returns an empty list. Note that
        ``lst[-0:]`` would return the whole list, which is why this is checked
        explicitly.
        """
        if max_nodes <= 0:
            return []
        return self.current_branch[-max_nodes:]

    def clear(self):
        """Clear all chain data."""
        self.chains.clear()
        self.root_nodes.clear()
        self.current_branch.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LLMProvider(str, Enum):
    """Backends available for reasoning generation.

    Because of the ``str`` mix-in, each member compares equal to its raw value.
    """

    OPENAI = "openai"  # OpenAI-backed generation
    LOCAL = "local"  # local generation without an LLM
|
|
|
|
|
|
|
|
@dataclass
class GenerationConfig:
    """Settings for one reasoning-generation run."""

    # Model identifier sent with the chat-completion request.
    model: str = "gpt-4o-mini"
    # Sampling temperature forwarded to the model.
    temperature: float = 0.7
    # Upper bound on tokens in the model's reply.
    max_tokens: int = 2048
    # Maximum number of reasoning steps.
    # NOTE(review): confirm which engines enforce this.
    max_reasoning_steps: int = 10
    # Whether alternative (ghost) hypotheses are emitted.
    include_alternatives: bool = True
    # Language code. NOTE(review): the engines here read
    # ReasoningContext.language instead; confirm whether this is used.
    language: str = "en"
|
|
|
|
|
|
|
|
@dataclass
class ReasoningContext:
    """Inputs an engine needs in order to reason about one query."""

    # Raw text of the user's query.
    query: str
    # Language code detected for the query.
    language: str = "en"
    # Knowledge-graph entities matched to the query by embedding search.
    matched_entities: List[SearchResult] = field(default_factory=lambda: [])
    # Recent graph nodes as dicts with "role", "content", "id" and "type" keys.
    previous_reasoning: List[Dict] = field(default_factory=lambda: [])
    # Node the new query should attach to, when the caller chose one.
    anchor_node_id: Optional[str] = None
    # True when the anchor differs from the graph's last active node.
    is_branching: bool = False
|
|
|
|
|
|
|
|
|
|
|
# System prompts keyed by language code. Every prompt asks for a JSON object
# with a "steps" array whose "supports" entries are step indices: 0 is the
# user query, and 1+ are the generated steps. The parser can only resolve
# indices of earlier steps, so every prompt states that explicitly.
SYSTEM_PROMPTS = {
    "en": """You are a medical reasoning assistant using Graph-of-Thoughts methodology.
Analyze symptoms and provide structured diagnostic analysis with BRANCHING reasoning paths.

CRITICAL: You MUST always return a JSON object with a non-empty "steps" array.

IMPORTANT - CREATE NON-LINEAR REASONING:
- Generate multiple parallel reasoning branches, not just sequential steps
- Use "supports" array to indicate which prior steps support each new step
- A step can be supported by MULTIPLE prior steps (converging evidence)
- Create at least 2-3 alternative diagnostic pathways

OUTPUT FORMAT (JSON) - ALWAYS INCLUDE STEPS:
{
  "steps": [
    {"type": "fact", "content": "Patient reports headache", "confidence": 0.95, "supports": [0]},
    {"type": "fact", "content": "Patient has fever 38.5°C", "confidence": 0.95, "supports": [0]},
    {"type": "reasoning", "content": "Symptoms suggest infection", "confidence": 0.8, "supports": [1, 2]},
    {"type": "reasoning", "content": "Could indicate tension headache", "confidence": 0.6, "supports": [1]},
    {"type": "hypothesis", "content": "Primary: Viral infection", "confidence": 0.75, "supports": [3]},
    {"type": "hypothesis", "content": "Alternative: Bacterial infection", "confidence": 0.5, "supports": [3]},
    {"type": "conclusion", "content": "Recommend tests and monitoring", "confidence": 0.85, "supports": [5, 6]}
  ],
  "alternatives": [
    {"content": "Migraine if symptoms persist without fever", "confidence": 0.4, "reason": "Headache pattern"}
  ]
}

Step indices: 0 = user query, 1+ = your generated steps.
Every "supports" entry must be the index of an EARLIER step (or 0).
GUIDELINES:
1. ALWAYS generate 5-8 reasoning steps - NEVER return empty steps array
2. Multiple facts can support the same reasoning step (supports: [1, 2, 3])
3. Create divergent then convergent reasoning paths
4. Include at least 2 alternative hypotheses
5. Respond in the SAME LANGUAGE as the user query

DISCLAIMER: Educational purposes only. Consult healthcare professionals.""",

    "uk": """Ви — медичний асистент, що використовує методологію Graph-of-Thoughts.
Аналізуйте симптоми та створюйте РОЗГАЛУЖЕНІ шляхи міркування.

КРИТИЧНО: Ви ПОВИННІ завжди повертати JSON об'єкт з непорожнім масивом "steps".

ФОРМАТ ВИВОДУ (JSON) - ЗАВЖДИ ВКЛЮЧАЙТЕ STEPS:
{
  "steps": [
    {"type": "fact", "content": "Пацієнт скаржиться на головний біль", "confidence": 0.95, "supports": [0]},
    {"type": "fact", "content": "У пацієнта температура 38.5°C", "confidence": 0.95, "supports": [0]},
    {"type": "reasoning", "content": "Симптоми вказують на інфекцію", "confidence": 0.8, "supports": [1, 2]},
    {"type": "hypothesis", "content": "Первинна: Вірусна інфекція", "confidence": 0.75, "supports": [3]},
    {"type": "hypothesis", "content": "Альтернатива: Застуда", "confidence": 0.5, "supports": [3]},
    {"type": "conclusion", "content": "Рекомендовано обстеження", "confidence": 0.85, "supports": [4, 5]}
  ],
  "alternatives": [
    {"content": "Мігрень, якщо симптоми без температури", "confidence": 0.4, "reason": "Характер болю"}
  ]
}

Індекси кроків: 0 = запит користувача, 1+ = ваші згенеровані кроки.
Кожен елемент "supports" має бути індексом ПОПЕРЕДНЬОГО кроку (або 0).

ВАЖЛИВО:
- ЗАВЖДИ генеруйте 5-8 кроків міркування - НІКОЛИ не повертайте порожній масив steps
- Використовуйте масив "supports" для зв'язку кроків
- Відповідайте УКРАЇНСЬКОЮ МОВОЮ

ВІДМОВА: Лише в освітніх цілях. Зверніться до лікаря.""",

    "ru": """Вы — медицинский ассистент, использующий методологию Graph-of-Thoughts.
Анализируйте симптомы и создавайте РАЗВЕТВЛЁННЫЕ пути рассуждений.

КРИТИЧНО: Вы ДОЛЖНЫ всегда возвращать JSON объект с непустым массивом "steps".

ФОРМАТ ВЫВОДА (JSON) - ВСЕГДА ВКЛЮЧАЙТЕ STEPS:
{
  "steps": [
    {"type": "fact", "content": "Пациент жалуется на головную боль", "confidence": 0.95, "supports": [0]},
    {"type": "fact", "content": "У пациента температура 38.5°C", "confidence": 0.95, "supports": [0]},
    {"type": "reasoning", "content": "Симптомы указывают на инфекцию", "confidence": 0.8, "supports": [1, 2]},
    {"type": "hypothesis", "content": "Первичная: Вирусная инфекция", "confidence": 0.75, "supports": [3]},
    {"type": "hypothesis", "content": "Альтернатива: Простуда", "confidence": 0.5, "supports": [3]},
    {"type": "conclusion", "content": "Рекомендовано обследование", "confidence": 0.85, "supports": [4, 5]}
  ],
  "alternatives": [
    {"content": "Мигрень, если симптомы без температуры", "confidence": 0.4, "reason": "Характер боли"}
  ]
}

Индексы шагов: 0 = запрос пользователя, 1+ = ваши сгенерированные шаги.
Каждый элемент "supports" должен быть индексом ПРЕДЫДУЩЕГО шага (или 0).

ВАЖНО:
- ВСЕГДА генерируйте 5-8 шагов рассуждений - НИКОГДА не возвращайте пустой массив steps
- Используйте массив "supports" для связи шагов
- Отвечайте НА РУССКОМ ЯЗЫКЕ

ОТКАЗ: Только в образовательных целях. Обратитесь к врачу.""",
}
|
|
|
|
|
# Language code -> English language name, used to tell the model which
# language it should answer in.
LANGUAGE_NAMES = dict(
    en="English",
    uk="Ukrainian",
    ru="Russian",
    es="Spanish",
    de="German",
    fr="French",
)
|
|
|
|
|
|
|
|
def detect_language(text: str) -> str:
    """
    Detect the language of ``text`` using simple heuristics.

    Returns one of "uk", "ru", "es", "de", "fr" or "en". "en" is the default
    and is also returned for empty text.

    * Cyrillic: chosen when more than 30% of the *letters* are Cyrillic.
      Ukrainian-only letters (і, ї, є, ґ) then select "uk" over "ru".
    * Spanish / German / French: detected by distinctive characters, Spanish
      word endings ("-ción", "-mente") or whole function words. Function
      words are matched as whole words, so English text such as "shoulder",
      "under" or "which" is not mistaken for German.

    For production, use langdetect or similar library.
    """
    import re

    text_lower = text.lower()

    # Measure the Cyrillic share over letters only, so digits, spaces and
    # punctuation (e.g. "38.5°C") do not dilute it.
    letters = [ch for ch in text if ch.isalpha()]
    cyrillic_chars = sum(1 for ch in letters if '\u0400' <= ch <= '\u04FF')
    if letters and cyrillic_chars > len(letters) * 0.3:
        # і, ї, є, ґ occur in Ukrainian but not in Russian.
        ukrainian_markers = ('\u0456', '\u0457', '\u0454', '\u0491')
        if any(m in text_lower for m in ukrainian_markers):
            return "uk"
        return "ru"

    words = set(re.findall(r"\w+", text_lower))

    if any(ch in text_lower for ch in ('¿', '¡', 'ñ')) or any(
        w.endswith(('ción', 'mente')) for w in words
    ):
        return "es"
    if any(ch in text_lower for ch in ('ß', 'ü', 'ö', 'ä')) or (
        words & {'ich', 'und', 'der', 'die'}
    ):
        return "de"
    if any(ch in text_lower for ch in ('ç', 'œ')) or (
        words & {'être', 'avoir', 'très'}
    ):
        return "fr"

    return "en"
|
|
|
|
|
|
|
|
class ReasoningEngine(ABC):
    """Abstract base class for reasoning engines with Graph-of-Thoughts support.

    Subclasses implement :meth:`generate`. This base class provides the shared
    pieces: building a ReasoningContext for a query, creating and attaching
    the query node, and choosing a localized system prompt.
    """

    def __init__(self, kg: KnowledgeGraph):
        # Graph that receives every node and edge an engine creates.
        self.kg = kg
        # Chain-structure cache; emptied by reset_chain().
        self.chain_cache = ReasoningChainCache()

    @abstractmethod
    def generate(
        self,
        context: ReasoningContext,
        config: GenerationConfig
    ) -> Generator[ReasoningNode, None, None]:
        """Generate reasoning steps, yielding each node as it is created."""
        pass

    def reset_chain(self):
        """Reset the reasoning chain cache."""
        self.chain_cache.clear()

    def build_context(
        self,
        query: str,
        anchor_node_id: Optional[str] = None
    ) -> ReasoningContext:
        """Build the reasoning context for ``query``.

        1. Detect the query language.
        2. Match symptom entities with the embedding service. This is best
           effort: failures are logged and leave ``matched_entities`` empty.
        3. Attach the 10 most recent graph nodes (by timestamp) as context.
        4. If ``anchor_node_id`` names a node in the graph, record it. The
           query is flagged as branching when that node is not the last
           active node. Unknown anchors are ignored with a warning, so the
           query attaches to the last active node instead of being
           mislabelled as an alternative branch.
        """
        context = ReasoningContext(query=query, language=detect_language(query))
        context.matched_entities = self._match_symptom_entities(query)
        context.previous_reasoning = self._recent_reasoning(limit=10)

        if anchor_node_id:
            if anchor_node_id in self.kg.nodes:
                context.anchor_node_id = anchor_node_id
                last_node = self.kg.get_last_active_node()
                if last_node and anchor_node_id != last_node.id:
                    context.is_branching = True
            else:
                logger.warning(
                    "Anchor node %s not found in graph; ignoring it", anchor_node_id
                )

        return context

    def _match_symptom_entities(self, query: str) -> List[SearchResult]:
        """Return up to 5 symptom entities matching ``query`` (``[]`` on any failure)."""
        results: List[SearchResult] = []
        try:
            embedding_service = get_embedding_service()
            results = embedding_service.extract_entities_from_text(
                text=query,
                category="symptom",
                top_k=5,
                threshold=0.35
            ) or []
            if results:
                logger.info(
                    f"Extracted {len(results)} entities: "
                    f"{[r.entity_data.get('name') for r in results]}"
                )
        except Exception as e:
            # Entity matching is optional; reasoning continues without it.
            logger.warning(f"Embedding search failed: {e}")
        return results

    def _recent_reasoning(self, limit: int = 10) -> List[Dict]:
        """Render the ``limit`` most recent graph nodes, oldest first, as chat-style dicts."""
        if limit <= 0:
            return []
        recent_nodes = sorted(
            self.kg.nodes.values(),
            key=lambda n: n.timestamp
        )[-limit:]
        return [
            {
                "role": "assistant" if n.node_type != NodeType.QUERY else "user",
                "content": f"[{n.node_type.value}]: {n.content}",
                "id": n.id,
                "type": n.node_type.value
            }
            for n in recent_nodes
        ]

    def _create_query_node(
        self,
        context: ReasoningContext
    ) -> Tuple[ReasoningNode, Optional[ReasoningNode]]:
        """Create the query node and connect it to its parent.

        The parent is the anchor node when it exists in the graph. The edge is
        ALTERNATIVE when branching and LEADS_TO otherwise. If there is no
        usable anchor, the graph's last active node is the parent, joined by a
        FOLLOW_UP edge.

        Returns:
            ``(query_node, parent_node)``; ``parent_node`` may be None.
        """
        parent_node = None
        edge_type = EdgeType.LEADS_TO

        if context.anchor_node_id:
            parent_node = self.kg.nodes.get(context.anchor_node_id)
            if context.is_branching:
                edge_type = EdgeType.ALTERNATIVE

        if not parent_node:
            parent_node = self.kg.get_last_active_node()
            if parent_node:
                edge_type = EdgeType.FOLLOW_UP

        query_node = ReasoningNode(
            id=create_node_id(),
            label=context.query[:60],
            node_type=NodeType.QUERY,
            content=context.query,
            confidence=1.0,
            language=context.language
        )
        self.kg.add_node(query_node)

        if parent_node:
            self.kg.add_edge(ReasoningEdge(
                source=parent_node.id,
                target=query_node.id,
                edge_type=edge_type
            ))

        return query_node, parent_node

    def get_system_prompt(self, language: str) -> str:
        """Get the system prompt for ``language``, falling back to English."""
        return SYSTEM_PROMPTS.get(language, SYSTEM_PROMPTS["en"])
|
|
|
|
|
|
|
|
class OpenAIEngine(ReasoningEngine):
    """OpenAI-based reasoning engine.

    Sends the query, recent graph context and matched entities to the
    chat-completions API in JSON mode. The returned ``steps`` and
    ``alternatives`` become knowledge-graph nodes and edges. Model output is
    treated as untrusted: malformed steps, ``supports`` indices, confidences
    and content are normalised or skipped instead of aborting generation.
    """

    def __init__(self, kg: KnowledgeGraph, api_key: Optional[str] = None):
        super().__init__(kg)
        # An explicit key wins; otherwise use the OPENAI_API_KEY env variable.
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self._client = None

    @property
    def client(self):
        """Lazy-load the OpenAI client on first access.

        Raises:
            ImportError: if the ``openai`` package is not installed.
        """
        if self._client is None:
            try:
                from openai import OpenAI
            except ImportError as exc:
                raise ImportError("Install openai: pip install openai") from exc
            self._client = OpenAI(api_key=self.api_key)
        return self._client

    def generate(
        self,
        context: ReasoningContext,
        config: GenerationConfig
    ) -> Generator[ReasoningNode, None, None]:
        """Generate reasoning using OpenAI.

        Yields, in order:
        1. The query node.
        2. One node per usable step, at most ``config.max_reasoning_steps``
           when that value is positive.
        3. One ghost node per alternative, when ``config.include_alternatives``
           is set.

        If the model returns no usable steps, a fixed default chain is used.
        API or JSON-parsing failures yield a single error node linked to the
        query instead of raising.
        """
        query_node, _ = self._create_query_node(context)
        yield query_node

        user_prompt = self._build_prompt(context)
        system_prompt = self.get_system_prompt(context.language)

        try:
            response = self.client.chat.completions.create(
                model=config.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                response_format={"type": "json_object"}
            )
            data = self._parse_response(response.choices[0].message.content)

            steps = data.get("steps")
            if not isinstance(steps, list) or not any(isinstance(s, dict) for s in steps):
                logger.warning(
                    "OpenAI returned empty steps for query: %s...", context.query[:50]
                )
                logger.warning("Falling back to default reasoning steps")
                steps = self._fallback_steps(context.query)

            # Honour the configured step budget (non-positive/None = unlimited).
            max_steps = config.max_reasoning_steps
            if max_steps and max_steps > 0:
                steps = steps[:max_steps]

            # Model step index -> created node; index 0 is the query itself.
            step_nodes: Dict[int, ReasoningNode] = {0: query_node}
            previous_node = query_node

            for i, step in enumerate(steps, 1):
                if not isinstance(step, dict):
                    # Keep the model's numbering intact so later "supports"
                    # indices still refer to the right steps.
                    logger.warning("Skipping malformed step %d: %r", i, step)
                    continue
                node = self._create_step_node(step, context.language)
                self.kg.add_node(node)
                # Link BEFORE registering this step so it cannot support itself.
                self._link_step(node, i, step.get("supports"), step_nodes, previous_node)
                step_nodes[i] = node
                previous_node = node
                yield node

            if config.include_alternatives:
                yield from self._emit_alternatives(
                    data.get("alternatives"), previous_node, context.language
                )

        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse response: {e}")
            yield self._create_error_node(query_node, "Could not parse response")
        except Exception as e:
            logger.exception(f"OpenAI API error: {e}")
            yield self._create_error_node(query_node, str(e))

    @staticmethod
    def _parse_response(response_text: Optional[str]) -> Dict[str, Any]:
        """Decode the model reply into a dict.

        An empty/None reply, or JSON that is not an object, yields ``{}`` so
        the caller falls back to default steps.

        Raises:
            json.JSONDecodeError: if the reply is non-empty but invalid JSON.
        """
        if not response_text:
            logger.warning("OpenAI returned an empty response")
            return {}
        logger.debug("OpenAI response: %s...", response_text[:500])
        data = json.loads(response_text)
        if not isinstance(data, dict):
            logger.warning(
                "OpenAI response is not a JSON object: %s", type(data).__name__
            )
            return {}
        return data

    @staticmethod
    def _fallback_steps(query: str) -> List[Dict[str, Any]]:
        """Default step chain used when the model produced no usable steps."""
        return [
            {"type": "fact", "content": f"Query received: {query}", "confidence": 0.95, "supports": [0]},
            {"type": "reasoning", "content": "Analyzing the provided information", "confidence": 0.8, "supports": [1]},
            {"type": "hypothesis", "content": "Based on the query, further analysis needed", "confidence": 0.6, "supports": [2]},
            {"type": "conclusion", "content": "Please provide more specific symptoms for accurate analysis. Consult a healthcare professional.", "confidence": 0.5, "supports": [3]}
        ]

    @staticmethod
    def _normalize_supports(supports: Any) -> List[int]:
        """Coerce a step's ``supports`` value into a de-duplicated list of ints.

        Accepts a single index or a list/tuple. Entries that cannot be read
        as integers (including booleans) are dropped.
        """
        if supports is None:
            return []
        if not isinstance(supports, (list, tuple)):
            supports = [supports]
        indices: List[int] = []
        for raw in supports:
            if isinstance(raw, bool):
                continue
            try:
                idx = int(raw)
            except (TypeError, ValueError):
                continue
            if idx not in indices:
                indices.append(idx)
        return indices

    def _link_step(
        self,
        node: ReasoningNode,
        index: int,
        supports: Any,
        step_nodes: Dict[int, ReasoningNode],
        previous_node: ReasoningNode
    ) -> None:
        """Connect a step node to the earlier steps it cites, or to the previous node.

        Only indices strictly before ``index`` that map to created nodes are
        honoured; this rules out self-loops and forward references. If no
        SUPPORTS edge could be added, a LEADS_TO edge from ``previous_node``
        keeps the node attached to the graph.
        """
        connected = False
        for sup_idx in self._normalize_supports(supports):
            if sup_idx >= index or sup_idx not in step_nodes:
                continue
            edge = ReasoningEdge(
                source=step_nodes[sup_idx].id,
                target=node.id,
                edge_type=EdgeType.SUPPORTS,
                weight=node.confidence
            )
            if self.kg.add_edge(edge):
                connected = True
                logger.debug(f"Connected step {index} to step {sup_idx} via SUPPORTS")

        if connected:
            return

        edge = ReasoningEdge(
            source=previous_node.id,
            target=node.id,
            edge_type=EdgeType.LEADS_TO,
            weight=node.confidence
        )
        if self.kg.add_edge(edge):
            logger.debug(f"Connected step {index} to previous via LEADS_TO (fallback)")
        else:
            logger.error(f"Failed to connect step {index} - node may be isolated!")

    def _emit_alternatives(
        self,
        alternatives: Any,
        anchor_node: ReasoningNode,
        language: str
    ) -> Generator[ReasoningNode, None, None]:
        """Create, link (ALTERNATIVE edge from ``anchor_node``) and yield one ghost per alternative."""
        if not isinstance(alternatives, list):
            return
        for alt in alternatives:
            if not isinstance(alt, dict):
                logger.warning("Skipping malformed alternative: %r", alt)
                continue
            ghost = self._create_alternative_node(alt, language)
            self.kg.add_node(ghost)
            self.kg.add_edge(ReasoningEdge(
                source=anchor_node.id,
                target=ghost.id,
                edge_type=EdgeType.ALTERNATIVE,
                weight=ghost.confidence
            ))
            yield ghost

    def _build_prompt(self, context: ReasoningContext) -> str:
        """Build the user prompt.

        The prompt contains the language instruction, recent context, the
        query, matched entities and disease candidates from the knowledge
        graph.
        """
        parts = []

        # get_system_prompt falls back to English for languages it lacks, so
        # the target language is stated explicitly in the user turn.
        lang_name = LANGUAGE_NAMES.get(context.language, "English")
        parts.append(f"RESPOND IN {lang_name.upper()}.\n")

        if context.previous_reasoning:
            parts.append("PREVIOUS CONTEXT:")
            for item in context.previous_reasoning[-5:]:
                content = str(item.get('content', ''))
                parts.append(f"- [{item.get('type')}]: {content[:150]}")
            parts.append("\n--- NEW QUERY ---\n")

        parts.append(f"Query: {context.query}\n")

        if context.is_branching:
            parts.append("NOTE: Exploring alternative reasoning path.\n")

        if context.matched_entities:
            parts.append("\nMATCHED MEDICAL ENTITIES:")
            for result in context.matched_entities:
                entity = result.entity_data
                description = entity.get('description') or ''
                parts.append(
                    f"- {entity.get('name')} ({entity.get('category')}) "
                    f"[confidence: {result.score:.0%}]: {description[:80]}"
                )

            symptom_ids = [r.entity_id for r in context.matched_entities]
            diseases = self.kg.get_diseases_for_symptoms(symptom_ids)
            if diseases:
                parts.append("\nPOSSIBLE CONDITIONS:")
                for disease, score in diseases[:5]:
                    parts.append(f"- {disease.name}: {score:.0%} match")

        parts.append("\nProvide structured reasoning as JSON.")
        return "\n".join(parts)

    @staticmethod
    def _as_text(value: Any) -> str:
        """Render a model-supplied field as text (None -> "")."""
        return "" if value is None else str(value)

    @staticmethod
    def _coerce_confidence(value: Any, default: float) -> float:
        """Parse a model-supplied confidence clamped to [0, 1]; ``default`` if unparsable."""
        try:
            conf = float(value)
        except (TypeError, ValueError):
            return default
        if conf != conf:  # NaN is the only value not equal to itself
            return default
        return min(max(conf, 0.0), 1.0)

    def _create_step_node(self, step: Dict, language: str) -> ReasoningNode:
        """Create a reasoning node from one step dict.

        An unknown or missing "type" maps to REASONING, and the type name is
        matched case-insensitively. A missing or unparsable "confidence"
        defaults to 0.8.
        """
        type_map = {
            "fact": NodeType.FACT,
            "reasoning": NodeType.REASONING,
            "hypothesis": NodeType.HYPOTHESIS,
            "conclusion": NodeType.CONCLUSION,
            "evidence": NodeType.EVIDENCE,
        }
        type_name = self._as_text(step.get("type") or "reasoning").strip().lower()
        node_type = type_map.get(type_name, NodeType.REASONING)
        content = self._as_text(step.get("content"))

        return ReasoningNode(
            id=create_node_id(),
            label=content[:60],
            node_type=node_type,
            content=content,
            confidence=self._coerce_confidence(step.get("confidence"), 0.8),
            kg_entity_id=step.get("kg_entity_id"),
            language=language
        )

    def _create_alternative_node(self, alt: Dict, language: str) -> ReasoningNode:
        """Create a ghost node for an alternative (confidence defaults to 0.3)."""
        content = self._as_text(alt.get("content"))
        return ReasoningNode(
            id=create_node_id(),
            label=f"Alt: {content[:50]}",
            node_type=NodeType.GHOST,
            content=content,
            confidence=self._coerce_confidence(alt.get("confidence"), 0.3),
            metadata={"reason": alt.get("reason", ""), "original_type": "hypothesis"},
            language=language
        )

    def _create_error_node(self, query_node: ReasoningNode, error: str) -> ReasoningNode:
        """Create a zero-confidence error node linked from ``query_node``."""
        node = ReasoningNode(
            id=create_node_id(),
            label="Error",
            node_type=NodeType.REASONING,
            content=f"Analysis failed: {error}",
            confidence=0.0
        )
        self.kg.add_node(node)
        self.kg.add_edge(ReasoningEdge(
            source=query_node.id,
            target=node.id,
            edge_type=EdgeType.LEADS_TO
        ))
        return node
|
|
|
|
|
|
|
|
class LocalEngine(ReasoningEngine): |
|
|
""" |
|
|
Local knowledge-graph-based reasoning engine. |
|
|
Uses embeddings for entity matching, no LLM required. |
|
|
""" |
|
|
|
|
|
def __init__(self, kg: KnowledgeGraph):
    """Create a graph-only engine bound to ``kg``."""
    super().__init__(kg)
    # Id of the latest CONCLUSION node emitted by this engine (None until one exists).
    self._last_conclusion_id = None  # type: Optional[str]
|
|
|
|
|
def generate(
    self,
    context: ReasoningContext,
    config: GenerationConfig
) -> Generator[ReasoningNode, None, None]:
    """Generate reasoning from the knowledge graph alone (no LLM call).

    Yields the query node first. It then yields either entity-driven steps
    (when the context carries embedding matches) or the generic fallback
    steps. ``config`` is accepted for interface compatibility and is not read.
    """
    query_node = self._create_query_node(context)[0]
    yield query_node

    language = context.language
    if not context.matched_entities:
        yield from self._generic_reasoning(query_node, language)
        return

    yield from self._entity_based_reasoning(
        query_node, context.matched_entities, language
    )
|
|
|
|
|
def _entity_based_reasoning(
    self,
    query_node: ReasoningNode,
    matched_entities: List[SearchResult],
    language: str
) -> Generator[ReasoningNode, None, None]:
    """Build a Graph-of-Thoughts chain from embedding-matched symptom entities.

    Graph produced (nodes are yielded in creation order):
      query --LEADS_TO--> one FACT per matched symptom (first 5 matches)
      each FACT --SUPPORTS--> a single REASONING node
      REASONING --SUPPORTS--> primary HYPOTHESIS (best-ranked disease)
      REASONING --ALTERNATIVE--> GHOST nodes for the next two diseases
      each hypothesis --> CONCLUSION (SUPPORTS from the primary,
      ALTERNATIVE from the ghosts)
    When no disease is found, the no-match conclusion hangs off REASONING.
    """
    msg = self._get_messages(language)
    kg = self.kg

    # Step 1: one FACT per matched symptom, attached to the query.
    facts: List[ReasoningNode] = []
    matched_ids: List[str] = []
    for match in matched_entities[:5]:
        name = match.entity_data.get("name", "Unknown")
        entity_id = match.entity_id
        matched_ids.append(entity_id)

        fact = ReasoningNode(
            id=create_node_id(),
            label=f"Symptom: {name[:30]}",
            node_type=NodeType.FACT,
            content=f"{msg['identified']}: {name}",
            confidence=match.score,
            kg_entity_id=entity_id,
            language=language,
        )
        kg.add_node(fact)
        kg.add_edge(ReasoningEdge(
            source=query_node.id, target=fact.id, edge_type=EdgeType.LEADS_TO
        ))
        facts.append(fact)
        yield fact

    # Step 2: all symptom facts converge on one REASONING node.
    consult = ReasoningNode(
        id=create_node_id(),
        label=msg['searching'][:50],
        node_type=NodeType.REASONING,
        content=msg['consulting'],
        confidence=0.9,
        language=language,
    )
    kg.add_node(consult)
    for fact in facts:
        kg.add_edge(ReasoningEdge(
            source=fact.id, target=consult.id,
            edge_type=EdgeType.SUPPORTS, weight=fact.confidence,
        ))
    yield consult

    # Step 3: rank candidate diseases for the matched symptoms.
    ranked = kg.get_diseases_for_symptoms(matched_ids)

    # Step 4: best candidate -> primary hypothesis; next two -> ghost alternatives.
    hypotheses: List[ReasoningNode] = []
    primary: Optional[ReasoningNode] = None
    for rank, (disease, score) in enumerate(ranked[:3]):
        top_ranked = rank == 0
        overlap = [
            s.name for s in kg.get_symptoms_for_disease(disease.id)
            if s.id in matched_ids
        ]

        hypothesis = ReasoningNode(
            id=create_node_id(),
            label=f"{'Primary' if top_ranked else 'Alt'}: {disease.name}",
            node_type=NodeType.HYPOTHESIS if top_ranked else NodeType.GHOST,
            content="\n".join([
                f"{disease.name} ({score:.0%} {msg['match']})",
                f"{msg['description']}: {disease.description}",
                f"{msg['matching']}: {', '.join(overlap)}",
            ]),
            confidence=score,
            kg_entity_id=disease.id,
            metadata={} if top_ranked else {"original_type": "hypothesis"},
            language=language,
        )
        kg.add_node(hypothesis)
        kg.add_edge(ReasoningEdge(
            source=consult.id, target=hypothesis.id,
            edge_type=EdgeType.SUPPORTS if top_ranked else EdgeType.ALTERNATIVE,
            weight=score,
        ))

        hypotheses.append(hypothesis)
        if top_ranked:
            primary = hypothesis
        yield hypothesis

    # Step 5: conclusion, or the no-match fallback when nothing ranked.
    if not (primary and ranked):
        yield from self._no_match_conclusion(consult, language)
        return

    best_disease, best_score = ranked[0]
    treatments = kg.get_treatments_for_disease(best_disease.id)
    treatment_text = "\n".join(
        f"- {tx.name}: {tx.description}" for tx in treatments[:5]
    )

    conclusion = ReasoningNode(
        id=create_node_id(),
        label=msg['recommendation'][:50],
        node_type=NodeType.CONCLUSION,
        content=(
            f"{msg['based_on']} {best_disease.name} {msg['most_likely']}.\n\n"
            f"{msg['treatments']}:\n{treatment_text}\n\n"
            f"⚠️ {msg['disclaimer']}"
        ),
        confidence=best_score * 0.9,
        language=language,
    )
    kg.add_node(conclusion)

    for hyp in hypotheses:
        kg.add_edge(ReasoningEdge(
            source=hyp.id, target=conclusion.id,
            edge_type=EdgeType.SUPPORTS if hyp == primary else EdgeType.ALTERNATIVE,
            weight=hyp.confidence,
        ))

    self._last_conclusion_id = conclusion.id
    yield conclusion
|
|
|
|
|
def _generic_reasoning(
    self,
    query_node: ReasoningNode,
    language: str
) -> Generator[ReasoningNode, None, None]:
    """Keyword-based fallback reasoning for when no entities were matched.

    The query text is scanned for common symptom phrases. English phrases
    are always checked; for any non-English language the Ukrainian and
    Russian lists are checked too.

    If at least one symptom is found, the emitted graph branches:
      query --LEADS_TO--> one FACT per symptom (first 4)
      FACTs --SUPPORTS--> REASONING
      REASONING --SUPPORTS--> HYPOTHESIS, --ALTERNATIVE--> GHOST
      HYPOTHESIS --SUPPORTS--> CONCLUSION <--ALTERNATIVE-- GHOST
    Otherwise a single REASONING step is followed by the no-match conclusion.
    """
    import re

    messages = self._get_messages(language)
    query_text = query_node.content.lower()

    common_symptoms = {
        'en': [
            'headache', 'head ache', 'head pain', 'migraine',
            'fever', 'temperature', 'chills', 'sweating', 'hot',
            'cough', 'coughing', 'cold', 'flu', 'runny nose', 'congestion',
            'shortness of breath', 'breathing', 'sore throat', 'throat',
            'pain', 'ache', 'aching', 'hurts', 'hurt', 'sore', 'burning',
            'nausea', 'vomiting', 'diarrhea', 'stomach', 'belly', 'abdomen',
            'constipation', 'bloating',
            'fatigue', 'tired', 'weakness', 'weak', 'exhausted', 'dizzy',
            'dizziness', 'lightheaded', 'faint',
            'rash', 'itching', 'swelling', 'swollen',
            'insomnia', 'anxiety', 'stress', 'depression',
        ],
        'uk': [
            'головний біль', 'болить голова', 'біль голови', 'мігрень',
            'температура', 'гарячка', 'лихоманка', 'озноб', 'жар',
            'кашель', 'кашляю', 'застуда', 'грип', 'нежить', 'закладений ніс',
            'задишка', 'важко дихати', 'біль в горлі', 'горло болить',
            'біль', 'болить', 'боляче', 'ниє', 'печіння',
            'нудота', 'нудить', 'блювота', 'пронос', 'діарея', 'живіт',
            'шлунок', 'запор', 'здуття',
            'втома', 'слабкість', 'знесилення', 'запаморочення',
            'паморочиться', 'млість',
            'висип', 'свербіж', 'набряк', 'опух',
            'безсоння', 'тривога', 'стрес', 'депресія',
        ],
        'ru': [
            'головная боль', 'болит голова', 'боль в голове', 'мигрень',
            'температура', 'жар', 'лихорадка', 'озноб', 'потливость',
            'кашель', 'кашляю', 'простуда', 'грипп', 'насморк', 'заложенность',
            'одышка', 'тяжело дышать', 'боль в горле', 'горло болит',
            'боль', 'болит', 'больно', 'ноет', 'жжение',
            'тошнота', 'тошнит', 'рвота', 'понос', 'диарея', 'живот',
            'желудок', 'запор', 'вздутие',
            'усталость', 'слабость', 'утомление', 'головокружение',
            'кружится голова', 'обморок',
            'сыпь', 'зуд', 'отёк', 'опухло', 'опухоль',
            'бессонница', 'тревога', 'стресс', 'депрессия',
        ],
    }

    candidates = common_symptoms.get(language, []) + common_symptoms.get('en', [])
    if language != 'en':
        candidates += common_symptoms.get('uk', []) + common_symptoms.get('ru', [])

    def find_keywords(text: str, phrases: List[str]) -> List[str]:
        """Return the phrases that occur in ``text``, counting each text span once.

        A phrase must start at a word boundary, so "hot" does not fire on
        "shot" and "flu" not on "influence". It may end mid-word, so
        inflections such as "coughs" or "опухло" still match. Longer phrases
        claim their spans first; a shorter phrase counts only if it also
        occurs outside every claimed span. For example, "head pain" alone
        does not also report "pain". The result keeps the order of
        ``phrases`` without duplicates.
        """
        unique_phrases = list(dict.fromkeys(phrases))
        claimed: List[Tuple[int, int]] = []
        found = set()
        for phrase in sorted(unique_phrases, key=len, reverse=True):
            for match in re.finditer(r"(?<!\w)" + re.escape(phrase), text):
                start, end = match.span()
                if any(start < c_end and c_start < end for c_start, c_end in claimed):
                    continue
                claimed.append((start, end))
                found.add(phrase)
        return [p for p in unique_phrases if p in found]

    symptom_keywords = find_keywords(query_text, candidates)
    logger.debug(f"Detected symptoms in '{query_text[:50]}...': {symptom_keywords}")

    if not symptom_keywords:
        # Nothing recognisable: one analysis step, then the no-match conclusion.
        step1 = ReasoningNode(
            id=create_node_id(),
            label=messages['analyzing'][:50],
            node_type=NodeType.REASONING,
            content=messages['analyzing'],
            confidence=0.9,
            language=language
        )
        self.kg.add_node(step1)
        self.kg.add_edge(ReasoningEdge(
            source=query_node.id, target=step1.id, edge_type=EdgeType.LEADS_TO
        ))
        yield step1
        yield from self._no_match_conclusion(step1, language)
        return

    # One FACT per detected symptom (a single symptom is still reported).
    symptom_nodes = []
    for symptom in symptom_keywords[:4]:
        fact = ReasoningNode(
            id=create_node_id(),
            label=f"Symptom: {symptom.title()[:25]}",
            node_type=NodeType.FACT,
            content=f"{messages['identified']}: {symptom}",
            confidence=0.85,
            language=language
        )
        self.kg.add_node(fact)
        self.kg.add_edge(ReasoningEdge(
            source=query_node.id, target=fact.id, edge_type=EdgeType.LEADS_TO
        ))
        symptom_nodes.append(fact)
        yield fact

    # NOTE(review): the English fragments below are not localized.
    count = len(symptom_keywords)
    noun = "symptom" if count == 1 else "symptoms"
    reasoning = ReasoningNode(
        id=create_node_id(),
        label=messages['analyzing'][:50],
        node_type=NodeType.REASONING,
        content=f"{messages['consulting']} - analyzing {count} {noun}",
        confidence=0.9,
        language=language
    )
    self.kg.add_node(reasoning)

    # All symptom facts converge on the single reasoning node.
    for sym_node in symptom_nodes:
        self.kg.add_edge(ReasoningEdge(
            source=sym_node.id, target=reasoning.id,
            edge_type=EdgeType.SUPPORTS, weight=sym_node.confidence
        ))
    yield reasoning

    # Diverge into a primary hypothesis and a ghost alternative.
    hyp1 = ReasoningNode(
        id=create_node_id(),
        label="Possible: Common condition",
        node_type=NodeType.HYPOTHESIS,
        content="Common condition matching these symptoms",
        confidence=0.6,
        language=language
    )
    self.kg.add_node(hyp1)
    self.kg.add_edge(ReasoningEdge(
        source=reasoning.id, target=hyp1.id, edge_type=EdgeType.SUPPORTS
    ))
    yield hyp1

    hyp2 = ReasoningNode(
        id=create_node_id(),
        label="Alternative: Secondary condition",
        node_type=NodeType.GHOST,
        content="Alternative diagnosis to consider",
        confidence=0.4,
        metadata={"original_type": "hypothesis"},
        language=language
    )
    self.kg.add_node(hyp2)
    self.kg.add_edge(ReasoningEdge(
        source=reasoning.id, target=hyp2.id, edge_type=EdgeType.ALTERNATIVE
    ))
    yield hyp2

    # Converge both hypotheses on the conclusion.
    conclusion = ReasoningNode(
        id=create_node_id(),
        label=messages['recommendation'][:50],
        node_type=NodeType.CONCLUSION,
        content=f"{messages['provide_more']}\n\n⚠️ {messages['disclaimer']}",
        confidence=0.5,
        language=language
    )
    self.kg.add_node(conclusion)
    self.kg.add_edge(ReasoningEdge(
        source=hyp1.id, target=conclusion.id, edge_type=EdgeType.SUPPORTS
    ))
    self.kg.add_edge(ReasoningEdge(
        source=hyp2.id, target=conclusion.id, edge_type=EdgeType.ALTERNATIVE
    ))
    self._last_conclusion_id = conclusion.id
    yield conclusion
|
|
|
|
|
def _no_match_conclusion(
    self,
    parent_node: ReasoningNode,
    language: str
) -> Generator[ReasoningNode, None, None]:
    """Yield the fallback conclusion used when no matches were found.

    Builds a single CONCLUSION node (confidence 0.5) whose text asks the
    user for more detail followed by the localized disclaimer. The node is
    added to the knowledge graph, linked from ``parent_node`` with a
    LEADS_TO edge, and recorded in ``self._last_conclusion_id``.

    Args:
        parent_node: Reasoning node the conclusion is attached to.
        language: Language code passed to ``_get_messages`` and stored on
            the node.

    Yields:
        The conclusion node (exactly one).
    """
    text = self._get_messages(language)
    body = f"{text['provide_more']}\n\n⚠️ {text['disclaimer']}"

    fallback = ReasoningNode(
        id=create_node_id(),
        label=text['recommendation'][:50],
        node_type=NodeType.CONCLUSION,
        content=body,
        confidence=0.5,
        language=language,
    )
    self.kg.add_node(fallback)

    link = ReasoningEdge(
        source=parent_node.id,
        target=fallback.id,
        edge_type=EdgeType.LEADS_TO,
    )
    self.kg.add_edge(link)

    self._last_conclusion_id = fallback.id
    yield fallback
|
|
|
|
|
def _get_messages(self, language: str) -> Dict[str, str]:
    """Return the localized UI strings for ``language``.

    Supported codes are ``"en"``, ``"uk"`` and ``"ru"``; any other value
    falls back to the English table. The tables are rebuilt on every call,
    so the returned dict is private to the caller.

    Args:
        language: Language code to look up.

    Returns:
        A mapping from message key (e.g. ``"symptoms"``, ``"disclaimer"``)
        to the localized text.
    """
    english = dict(
        symptoms="Symptoms",
        identified="Identified symptoms",
        searching="Searching knowledge base...",
        consulting="Consulting medical knowledge graph for conditions",
        match="match",
        description="Description",
        matching="Matching symptoms",
        recommendation="Recommendation",
        based_on="Based on analysis,",
        most_likely="is the most likely condition",
        treatments="Recommended treatments",
        analyzing="Analyzing query for medical terms",
        provide_more="Could not identify specific symptoms. Please provide more details.",
        disclaimer="DISCLAIMER: Educational purposes only. Consult healthcare professionals.",
    )
    ukrainian = dict(
        symptoms="Симптоми",
        identified="Визначені симптоми",
        searching="Пошук у базі знань...",
        consulting="Консультація медичного графу знань",
        match="збіг",
        description="Опис",
        matching="Симптоми, що збігаються",
        recommendation="Рекомендація",
        based_on="На основі аналізу,",
        most_likely="є найбільш ймовірним станом",
        treatments="Рекомендоване лікування",
        analyzing="Аналіз запиту на медичні терміни",
        provide_more="Не вдалося визначити симптоми. Надайте більше деталей.",
        disclaimer="ВІДМОВА: Лише в освітніх цілях. Зверніться до лікаря.",
    )
    russian = dict(
        symptoms="Симптомы",
        identified="Определённые симптомы",
        searching="Поиск в базе знаний...",
        consulting="Консультация медицинского графа знаний",
        match="совпадение",
        description="Описание",
        matching="Совпадающие симптомы",
        recommendation="Рекомендация",
        based_on="На основе анализа,",
        most_likely="является наиболее вероятным состоянием",
        treatments="Рекомендуемое лечение",
        analyzing="Анализ запроса на медицинские термины",
        provide_more="Не удалось определить симптомы. Предоставьте больше деталей.",
        disclaimer="ОТКАЗ: Только в образовательных целях. Обратитесь к врачу.",
    )

    catalog = {"en": english, "uk": ukrainian, "ru": russian}
    if language in catalog:
        return catalog[language]
    # Unknown language codes fall back to English.
    return english
|
|
|
|
|
|
|
|
class GraphSynchronizer:
    """
    Apply UI-driven edits (prune, resurrect, inject, feedback) to the
    knowledge graph and keep an in-memory history of them.

    Each operation appends an entry to ``edit_history``; ``export_history``
    returns a copy of that list (intended for RLHF training data).
    """

    def __init__(self, engine: ReasoningEngine, kg: KnowledgeGraph):
        # ``engine`` is stored for callers; the methods below only use ``kg``.
        self.engine = engine
        self.kg = kg
        self.edit_history: List[Dict] = []

    def prune_node(self, node_id: str) -> Dict:
        """Prune a node and its descendants via ``kg.prune_branch``.

        Returns:
            ``{"success": True, "pruned": <result of kg.prune_branch>}``.
        """
        result = self.kg.prune_branch(node_id)
        self._log_edit("prune", node_id, result)
        return {"success": True, "pruned": result}

    def resurrect_node(self, node_id: str) -> Dict:
        """Resurrect a ghost node via ``kg.resurrect_node``.

        Returns:
            ``{"success": <bool returned by kg.resurrect_node>}``.
        """
        success = self.kg.resurrect_node(node_id)
        # Record the outcome so a failed attempt is not mistaken for a real
        # edit when the history is exported for training.
        self._log_edit("resurrect", node_id, {"success": success})
        return {"success": success}

    def inject_fact(
        self,
        parent_node_id: str,
        fact_content: str,
        entity_id: Optional[str] = None
    ) -> Dict:
        """Inject a user-supplied FACT node under an existing node.

        The new node has confidence 1.0 and a label truncated to 60
        characters. It is linked from the parent with a REQUIRES edge, and
        both node and edge are tagged ``{"user_injected": True}``.

        Args:
            parent_node_id: Id of the node the fact hangs off; it must be
                present in ``kg.nodes``.
            fact_content: Text of the fact.
            entity_id: Optional knowledge-graph entity the fact refers to.

        Returns:
            ``{"success": True, "new_node_id": <id>}``, or
            ``{"success": False, "error": "Parent node not found"}`` when the
            parent does not exist. In that case the graph is left untouched.
        """
        # Mirror record_feedback's lookup so we never create a dangling edge.
        if parent_node_id not in self.kg.nodes:
            return {"success": False, "error": "Parent node not found"}

        node = ReasoningNode(
            id=create_node_id(),
            label=fact_content[:60],
            node_type=NodeType.FACT,
            content=fact_content,
            confidence=1.0,
            kg_entity_id=entity_id,
            metadata={"user_injected": True}
        )

        self.kg.add_node(node)
        self.kg.add_edge(ReasoningEdge(
            source=parent_node_id,
            target=node.id,
            edge_type=EdgeType.REQUIRES,
            metadata={"user_injected": True}
        ))

        self._log_edit("inject", parent_node_id, {"new_node_id": node.id})
        return {"success": True, "new_node_id": node.id}

    def record_feedback(
        self,
        node_id: str,
        feedback_type: str,
        context: str = ""
    ) -> Dict:
        """Record user feedback on a node (for RLHF) and adjust its confidence.

        Stores ``feedback_type``, ``context`` and an ISO timestamp in the
        node's metadata, then pushes the changes through ``kg.update_node``.

        Confidence adjustment:
            - ``"correct"``: multiplied by 1.2, capped at 1.0.
            - ``"incorrect"``: halved, floored at 0.1. A confidence already
              below 0.1 is never raised by negative feedback.
            - anything else: unchanged.

        Returns:
            ``{"success": True, "new_confidence": float}``, or
            ``{"success": False, "error": "Node not found"}``.
        """
        node = self.kg.nodes.get(node_id)
        if not node:
            return {"success": False, "error": "Node not found"}

        node.metadata["feedback"] = feedback_type
        node.metadata["feedback_context"] = context
        node.metadata["feedback_timestamp"] = datetime.now().isoformat()

        if feedback_type == "correct":
            node.confidence = min(node.confidence * 1.2, 1.0)
        elif feedback_type == "incorrect":
            # Floor at 0.1, but min(conf, 0.1) keeps the floor from
            # *increasing* a confidence that is already below it.
            node.confidence = max(node.confidence * 0.5, min(node.confidence, 0.1))

        self.kg.update_node(node_id, confidence=node.confidence, metadata=node.metadata)
        self._log_edit("feedback", node_id, {"type": feedback_type})

        return {"success": True, "new_confidence": node.confidence}

    def _log_edit(self, op_type: str, node_id: str, data: Any = None):
        """Append one timestamped entry to ``edit_history``."""
        self.edit_history.append({
            "type": op_type,
            "node_id": node_id,
            "data": data,
            "timestamp": datetime.now().isoformat()
        })

    def export_history(self) -> List[Dict]:
        """Export edit history for RLHF training (a shallow copy of the list)."""
        return self.edit_history.copy()
|
|
|
|
|
|
|
|
def create_engine(
    provider: LLMProvider,
    kg: KnowledgeGraph,
    api_key: Optional[str] = None
) -> ReasoningEngine:
    """Build the reasoning engine that matches ``provider``.

    Any provider other than ``LLMProvider.OPENAI`` yields a ``LocalEngine``.
    For OpenAI, the explicit ``api_key`` takes precedence over the
    ``OPENAI_API_KEY`` environment variable.

    Args:
        provider: Which backend to construct.
        kg: Knowledge graph handed to the engine.
        api_key: Optional OpenAI key; only consulted for the OpenAI backend.

    Returns:
        An ``OpenAIEngine`` or a ``LocalEngine``.

    Raises:
        ValueError: OpenAI was requested but no non-empty key was found.
    """
    if provider != LLMProvider.OPENAI:
        return LocalEngine(kg)

    resolved_key = api_key or os.environ.get("OPENAI_API_KEY")
    if not resolved_key:
        raise ValueError("OpenAI API key required")
    return OpenAIEngine(kg, api_key=resolved_key)
|
|
|