"""
HITL-KG Medical Reasoning System - Main Application

Human-in-the-Loop Knowledge Graph Visualization for Medical Reasoning

Features:
- Interactive knowledge graph visualization with Cytoscape
- Multilingual support with embedding-based entity extraction
- Session persistence with chat history
- RLHF feedback collection
- Glass-box visualization of reasoning process

Refactored for:
- Simplified state management
- Cleaner callbacks
- Embedding-based search (replacing keyword matching)
- Configuration-driven setup
"""

import logging
import os
import threading
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import dash
import dash_bootstrap_components as dbc
import dash_cytoscape as cyto
from dash import ALL, Input, Output, State, callback, ctx, dcc, html
from dash.exceptions import PreventUpdate

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Load extra layouts (the hierarchical layout below uses "dagre")
cyto.load_extra_layouts()

# Import core modules
from src.core import (
    load_knowledge_graph,
    create_engine,
    LLMProvider,
    GenerationConfig,
    GraphSynchronizer,
    NodeType,
    NODE_TYPE_INFO,
    get_config,
    get_session_manager,
    detect_language,
)

# Import styles; fall back to a minimal built-in theme when src.styles is absent.
try:
    from src.styles import CYTOSCAPE_STYLESHEET, LAYOUT_CONFIGS, CUSTOM_CSS
except ImportError:
    CYTOSCAPE_STYLESHEET = [
        {"selector": "node", "style": {
            "label": "data(label)",
            "background-color": "#818cf8",
            "color": "#fff",
            "font-size": "10px",
            "text-wrap": "wrap",
            "text-max-width": "100px"
        }},
        {"selector": "edge", "style": {
            "curve-style": "bezier",
            "target-arrow-shape": "triangle",
            "line-color": "#64748b",
            "target-arrow-color": "#64748b"
        }},
        {"selector": ".query", "style": {"background-color": "#38bdf8"}},
        {"selector": ".fact", "style": {"background-color": "#4ade80"}},
        {"selector": ".reasoning", "style": {"background-color": "#818cf8"}},
        {"selector": ".hypothesis", "style": {"background-color": "#fbbf24"}},
        {"selector": ".conclusion", "style": {"background-color": "#f472b6"}},
        {"selector": ".ghost", "style": {"background-color": "#94a3b8", "opacity": 0.6}},
    ]
    LAYOUT_CONFIGS = {
        "hierarchical": {"name": "dagre", "rankDir": "TB", "spacingFactor": 1.2},
        "force": {"name": "cose", "animate": False},
        "radial": {"name": "concentric", "animate": False},
    }
    CUSTOM_CSS = ""

# ============================================================================
# CONFIGURATION
# ============================================================================

config = get_config()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
DEFAULT_PROVIDER = "openai" if OPENAI_API_KEY else "local"

# The OpenAI option is only offered when an API key is configured.
AVAILABLE_PROVIDERS = []
if OPENAI_API_KEY:
    AVAILABLE_PROVIDERS.append({"label": "🤖 OpenAI GPT-4", "value": "openai"})
AVAILABLE_PROVIDERS.append({"label": "📊 Local Knowledge Graph", "value": "local"})

EXAMPLE_QUERIES = {
    "en": [
        {"text": "Fever and cough for 3 days", "icon": "🤒"},
        {"text": "Headache with fatigue", "icon": "😫"},
        {"text": "Sore throat and runny nose", "icon": "🤧"},
        {"text": "Shortness of breath", "icon": "😷"},
    ],
    "uk": [
        {"text": "Температура і кашель", "icon": "🤒"},
        {"text": "Головний біль з втомою", "icon": "😫"},
        {"text": "Біль у горлі та нежить", "icon": "🤧"},
        {"text": "Задишка", "icon": "😷"},
    ],
}


# ============================================================================
# USER STATE MANAGEMENT (Simplified)
# ============================================================================

class UserState:
    """Per-user state: knowledge graph, reasoning engine and session data.

    Instances are created lazily by ``get_user_state``. ``last_access`` holds a
    ``time.monotonic()`` timestamp refreshed on every lookup, so
    ``cleanup_user_states`` can evict the least-recently-used users.
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.kg = load_knowledge_graph(use_embeddings=True)
        self.engine = None
        self.synchronizer = None
        self.provider = DEFAULT_PROVIDER
        self.language = "en"
        self.last_access = time.monotonic()
        self._init_engine()
        self._restore_session()

    def touch(self) -> None:
        """Mark this state as just used (drives LRU eviction)."""
        self.last_access = time.monotonic()

    def _init_engine(self):
        """(Re)build the reasoning engine and graph synchronizer.

        Uses ``OpenAIEngine`` when ``self.provider == "openai"`` and an API key
        is configured, otherwise ``LocalEngine``. If construction fails, the
        error is logged with its traceback and ``LocalEngine`` is used instead.
        """
        from src.core import LocalEngine

        try:
            if self.provider == "openai" and OPENAI_API_KEY:
                from src.core import OpenAIEngine
                self.engine = OpenAIEngine(self.kg, api_key=OPENAI_API_KEY)
            else:
                self.engine = LocalEngine(self.kg)
            self.synchronizer = GraphSynchronizer(self.engine, self.kg)
        except Exception as e:
            logger.exception("Engine init failed: %s", e)
            self.engine = LocalEngine(self.kg)
            self.synchronizer = GraphSynchronizer(self.engine, self.kg)

    def _restore_session(self):
        """Restore graph state and language from the session manager, if saved."""
        sm = get_session_manager()
        session = sm.get_session(self.session_id)
        if session and session.graph_state:
            try:
                self.kg.restore_state(session.graph_state)
                self.language = session.language
            except Exception as e:
                logger.warning(f"Failed to restore session: {e}")

    def set_provider(self, provider: str):
        """Switch LLM provider, rebuilding the engine only when it changes."""
        if provider != self.provider:
            self.provider = provider
            self._init_engine()

    def reset(self):
        """Reset reasoning state completely (graph reasoning, chat, language)."""
        self.kg.clear_reasoning()

        sm = get_session_manager()
        session = sm.get_or_create(self.session_id)
        session.chat_history.clear()
        session.graph_state = None
        sm.save_session(self.session_id)

        self.language = "en"
        logger.info(f"Session {self.session_id} reset. Graph now has {len(self.kg.nodes)} nodes")

    def save(self):
        """Persist the current graph state through the session manager."""
        sm = get_session_manager()
        sm.update_graph_state(self.session_id, self.kg.get_state())

    def get_chat_history(self) -> List[Dict]:
        """Return the chat history as a list of ``{"role", "content"}`` dicts."""
        sm = get_session_manager()
        session = sm.get_or_create(self.session_id)
        return [{"role": m.role, "content": m.content} for m in session.chat_history]

    def add_message(self, role: str, content: str):
        """Append a message (tagged with the current language) and save the session."""
        sm = get_session_manager()
        session = sm.get_or_create(self.session_id)
        session.add_message(role, content, self.language)
        sm.save_session(self.session_id)


# User state storage
_user_states: Dict[str, UserState] = {}
_user_states_lock = threading.Lock()


def get_user_state(session_id: str) -> UserState:
    """Get or create the ``UserState`` for *session_id* and mark it as used."""
    with _user_states_lock:
        state = _user_states.get(session_id)
        if state is None:
            state = UserState(session_id)
            _user_states[session_id] = state
            logger.info(f"Created user state: {session_id}")
        state.touch()
        return state


def cleanup_user_states():
    """Evict least-recently-used user states once ``config.max_sessions`` is exceeded.

    Evicts enough states to get back down to the limit, and at least 10% of
    them. The original ``len // 10`` alone evicted nothing when fewer than 10
    states existed.
    """
    with _user_states_lock:
        total = len(_user_states)
        if total <= config.max_sessions:
            return
        n_evict = max(total - config.max_sessions, total // 10)
        least_recent_first = sorted(
            _user_states, key=lambda sid: _user_states[sid].last_access
        )
        for sid in least_recent_first[:n_evict]:
            del _user_states[sid]
        logger.info(f"Evicted {n_evict} idle user states ({total - n_evict} remain)")


# ============================================================================
# DASH APPLICATION
# ============================================================================

app = dash.Dash(
    __name__,
    external_stylesheets=[
        dbc.themes.DARKLY,
        "https://fonts.googleapis.com/css2?family=DM+Sans:wght@400;500;600&display=swap",
    ],
    suppress_callback_exceptions=True,
    title="HITL-KG Medical Reasoning",
)
server = app.server


# ============================================================================
# LAYOUT COMPONENTS
# ============================================================================

def create_header():
    """Application header: branding, provider status, language and global buttons."""
    status = "🟢 OpenAI" if OPENAI_API_KEY else "🟡 Local"

    return dbc.Navbar(
        dbc.Container([
            dbc.Row([
                dbc.Col([
                    html.Span("⚕️", style={"fontSize": "1.5rem", "marginRight": "10px"}),
                    html.Span("HITL-KG", style={"fontWeight": "700", "fontSize": "1.2rem"}),
                    dbc.Badge("Medical Reasoning", color="info", className="ms-2"),
                ], className="d-flex align-items-center"),
            ]),
            dbc.Row([
                dbc.Col([
                    html.Span(status, className="me-3", style={"fontSize": "0.85rem"}),
                    html.Span(id="language-indicator", children="🌐 EN"),
                ]),
                dbc.Col([
                    dbc.Button("❓ Help", id="btn-help", size="sm", outline=True, className="me-2"),
                    dbc.Button("↺ Reset", id="btn-reset", size="sm", outline=True),
                ], width="auto"),
            ], className="g-0"),
        ], fluid=True),
        dark=True,
        className="mb-3",
        style={"backgroundColor": "#1e293b"},
    )


def create_chat_panel():
    """Chat interface panel with tabs for the current chat and saved history."""
    return dbc.Card([
        dbc.CardHeader([
            html.Div([
                html.Span("💬", className="me-2"),
                html.Span("Symptom Analysis", style={"fontWeight": "600"}),
            ]),
            dbc.Button("+ New", id="btn-new-chat", size="sm", color="primary", className="float-end"),
        ], className="d-flex justify-content-between align-items-center"),
        dbc.CardBody([
            # Tabs for Chat and History
            dbc.Tabs([
                dbc.Tab(label="💬 Chat", tab_id="tab-chat", children=[
                    html.Div([
                        # Provider selector
                        html.Div([
                            html.Label("AI Model", className="small mt-2", style={"color": "#94a3b8"}),
                            dcc.Dropdown(
                                id="provider-select",
                                options=AVAILABLE_PROVIDERS,
                                value=DEFAULT_PROVIDER,
                                clearable=False,
                                className="mb-3",
                            ),
                        ]),
                        # Chat history
                        html.Div(
                            id="chat-history",
                            className="chat-container",
                            style={
                                "height": "180px",
                                "overflowY": "auto",
                                "backgroundColor": "#0f172a",
                                "borderRadius": "8px",
                                "padding": "10px",
                                "marginBottom": "15px",
                            },
                            children=[create_welcome_message()]
                        ),
                        # Quick examples
                        html.Div([
                            html.Label("Quick examples:", className="small mb-2", style={"color": "#94a3b8"}),
                            html.Div(id="example-queries", children=[
                                html.Button(
                                    [html.Span(q["icon"], className="me-1"), q["text"]],
                                    id={"type": "example", "index": i},
                                    className="btn btn-outline-secondary btn-sm me-2 mb-2",
                                )
                                for i, q in enumerate(EXAMPLE_QUERIES["en"])
                            ], className="d-flex flex-wrap"),
                        ], className="mb-3"),
                        # Input
                        dbc.Textarea(
                            id="chat-input",
                            placeholder="Describe your symptoms...",
                            style={"height": "60px", "resize": "none"},
                            className="mb-2",
                        ),
                        html.Div([
                            dbc.Button("🔍 Analyze", id="btn-send", color="primary", className="me-2"),
                            dbc.Button("🗑️ Clear", id="btn-clear", color="secondary", outline=True, size="sm"),
                        ]),
                    ])
                ]),
                dbc.Tab(label="📜 History", tab_id="tab-history", children=[
                    html.Div([
                        html.Label("Session History", className="small mt-2 mb-2", style={"color": "#94a3b8"}),
                        html.Div(
                            id="session-history-list",
                            style={
                                "height": "350px",
                                "overflowY": "auto",
                                "backgroundColor": "#0f172a",
                                "borderRadius": "8px",
                                "padding": "10px",
                            },
                            children=[
                                html.P("No saved sessions yet.", className="text-muted small text-center mt-3")
                            ]
                        ),
                        html.Div([
                            dbc.Button("💾 Save Session", id="btn-save-session", color="success", size="sm",
                                       className="mt-2 me-2"),
                            dbc.Button("🗑️ Clear History", id="btn-clear-history", color="danger", size="sm",
                                       className="mt-2", outline=True),
                        ]),
                    ])
                ]),
            ], id="chat-tabs", active_tab="tab-chat"),
        ]),
    ], style={"backgroundColor": "#1e293b"})


def create_graph_panel():
    """Graph visualization panel: zoom/layout controls, Cytoscape graph, legend, stats."""
    return dbc.Card([
        dbc.CardHeader([
            html.Div([
                html.Span("🧠", className="me-2"),
                html.Span("Reasoning Graph", style={"fontWeight": "600"}),
            ]),
            html.Div([
                dbc.Button("−", id="btn-zoom-out", size="sm", outline=True),
                dbc.Button("⟲", id="btn-zoom-fit", size="sm", outline=True, className="mx-1"),
                dbc.Button("+", id="btn-zoom-in", size="sm", outline=True),
                dbc.ButtonGroup([
                    dbc.Button("⇄", id="btn-layout-dag", size="sm", outline=True, active=True),
                    dbc.Button("◎", id="btn-layout-force", size="sm", outline=True),
                    dbc.Button("◉", id="btn-layout-radial", size="sm", outline=True),
                ], size="sm", className="ms-2"),
            ], className="d-flex"),
        ], className="d-flex justify-content-between align-items-center"),
        dbc.CardBody([
            cyto.Cytoscape(
                id="reasoning-graph",
                elements=[],
                layout=LAYOUT_CONFIGS.get("hierarchical", {"name": "dagre"}),
                style={"width": "100%", "height": "350px", "backgroundColor": "#0a1020"},
                stylesheet=CYTOSCAPE_STYLESHEET,
                boxSelectionEnabled=True,
                minZoom=0.2,
                maxZoom=3.0,
            ),
            # Legend (first six node types)
            html.Div([
                html.Div([
                    html.Span("●", style={"color": info["color"], "marginRight": "4px"}),
                    html.Span(info["name"], style={"fontSize": "0.75rem", "marginRight": "10px"}),
                ], className="d-inline-block")
                for info in list(NODE_TYPE_INFO.values())[:6]
            ], className="mt-2 text-center"),
            # Stats
            html.Div(
                id="stats-display",
                children="Ready — Enter symptoms to begin",
                className="mt-2 text-center",
                style={"color": "#94a3b8", "fontSize": "0.85rem"},
            ),
        ]),
    ], style={"backgroundColor": "#1e293b"})


def create_control_panel():
    """Steering controls panel: node info, RLHF feedback, display options, actions."""
    return dbc.Card([
        dbc.CardHeader([
            html.Span("🎛️", className="me-2"),
            html.Span("Controls", style={"fontWeight": "600"}),
        ]),
        dbc.CardBody([
            # Selected node info
            html.Div([
                html.Label("Selected Node", className="small", style={"color": "#94a3b8"}),
                html.Div(id="selected-node-info", children=[
                    html.P("👆 Click a node", className="text-muted small")
                ], style={"minHeight": "80px"}),
            ], className="mb-3"),

            html.Hr(style={"borderColor": "#475569"}),

            # Feedback
            html.Div([
                html.Label("Feedback (RLHF)", className="small", style={"color": "#94a3b8"}),
                html.Div([
                    dbc.Button("✓ Correct", id="btn-correct", color="success", size="sm",
                               disabled=True, className="me-2"),
                    dbc.Button("✗ Incorrect", id="btn-incorrect", color="danger", size="sm",
                               disabled=True),
                ], className="mb-2"),
                html.Div(id="feedback-status"),
            ], className="mb-3"),

            html.Hr(style={"borderColor": "#475569"}),

            # Display options
            html.Div([
                html.Label("Display", className="small", style={"color": "#94a3b8"}),
                dbc.Checklist(
                    id="display-options",
                    options=[{"label": " Show pruned paths", "value": "ghosts"}],
                    value=[],
                    switch=True,
                    className="mb-2",
                ),
                html.Label("Confidence threshold", className="small",
                           style={"color": "#64748b", "fontSize": "0.75rem"}),
                dcc.Slider(
                    id="confidence-slider",
                    min=0, max=1, step=0.1, value=0,
                    marks={0: "0%", 0.5: "50%", 1: "100%"},
                ),
            ], className="mb-3"),

            html.Hr(style={"borderColor": "#475569"}),

            # Actions
            html.Div([
                html.Label("Actions", className="small", style={"color": "#94a3b8"}),
                dbc.ButtonGroup([
                    dbc.Button("✂️ Prune", id="btn-prune", color="danger", size="sm", disabled=True),
                    dbc.Button("👻 Resurrect", id="btn-resurrect", color="warning", size="sm", disabled=True),
                ], className="w-100 mb-2"),
                dbc.Button("🚀 Start reasoning from here", id="btn-branch", color="info", size="sm",
                           disabled=True, className="w-100"),
            ], className="mb-3"),

            html.Hr(style={"borderColor": "#475569"}),

            # Fact injection
            html.Div([
                html.Label("Inject Fact", className="small", style={"color": "#94a3b8"}),
                dbc.InputGroup([
                    dbc.Input(id="fact-input", placeholder="e.g., Patient has diabetes", size="sm"),
                    dbc.Button("➕", id="btn-inject", color="success", size="sm", disabled=True),
                ], size="sm"),
            ]),
        ], style={"maxHeight": "calc(100vh - 200px)", "overflowY": "auto"}),
    ], style={"backgroundColor": "#1e293b"})


def create_welcome_message():
    """Welcome message shown in an empty chat."""
    return html.Div([
        html.Div("👋", className="text-center", style={"fontSize": "1.5rem"}),
        html.P("Welcome to HITL-KG Medical Reasoning", className="text-center mb-1",
               style={"fontWeight": "600"}),
        html.P("Describe symptoms to see AI reasoning visualized.",
               className="text-center small text-muted"),
        html.P("🌐 Type in any language!", className="text-center small",
               style={"fontStyle": "italic", "color": "#64748b"}),
    ], className="p-3")


def create_help_modal():
    """Help modal describing node types, interactions and language support."""
    return dbc.Modal([
        dbc.ModalHeader(dbc.ModalTitle("📖 How to Use")),
        dbc.ModalBody([
            html.H6("🎯 Overview"),
            html.P("HITL-KG visualizes AI medical reasoning as an interactive graph."),

            html.H6("🔵 Node Types", className="mt-3"),
            html.Ul([
                html.Li([html.Strong(info["icon"] + " " + info["name"] + ": "), info["description"]])
                for info in list(NODE_TYPE_INFO.values())[:6]
            ]),

            html.H6("🎮 Interactions", className="mt-3"),
            html.Ul([
                html.Li("Click nodes to select and view details"),
                html.Li("Prune to remove incorrect reasoning paths"),
                html.Li("Resurrect to restore pruned nodes"),
                html.Li("Inject facts to add medical information"),
            ]),

            html.H6("🌐 Languages", className="mt-3"),
            html.P("Supports English, Ukrainian, Russian, Spanish, German, French and more."),
        ]),
        dbc.ModalFooter(dbc.Button("Got it!", id="btn-close-help", color="primary")),
    ], id="help-modal", size="lg", is_open=False)


# Main layout
app.layout = html.Div([
    # Stores
    dcc.Store(id="session-id", storage_type="session"),
    dcc.Store(id="selected-node-id", data=None),
    dcc.Store(id="graph-version", data=0),

    # Header
    create_header(),

    # Help modal
    create_help_modal(),

    # Main content
    dbc.Container([
        dbc.Row([
            dbc.Col(create_chat_panel(), lg=3, md=12, className="mb-3"),
            dbc.Col(create_graph_panel(), lg=6, md=12, className="mb-3"),
            dbc.Col(create_control_panel(), lg=3, md=12, className="mb-3"),
        ], className="g-3"),
    ], fluid=True),

    # Loading indicator
    dcc.Loading(id="loading", type="circle", fullscreen=False, children=html.Div(id="loading-target")),

    # Cleanup interval (every 5 minutes)
    dcc.Interval(id="cleanup-interval", interval=300000),

], style={"minHeight": "100vh", "backgroundColor": "#0f172a"})


# ============================================================================
# CALLBACKS
# ============================================================================

@callback(
    Output("session-id", "data"),
    Input("session-id", "data"),
)
def init_session(existing_id):
    """Keep an existing browser-session id, or mint a new short id."""
    if existing_id:
        return existing_id
    return str(uuid.uuid4())[:12]


@callback(
    Output("help-modal", "is_open"),
    [Input("btn-help", "n_clicks"), Input("btn-close-help", "n_clicks")],
    State("help-modal", "is_open"),
)
def toggle_help(open_clicks, close_clicks, is_open):
    """Toggle the help modal on either the open or the close button."""
    if open_clicks or close_clicks:
        return not is_open
    return is_open


@callback(
    Output("loading-target", "children"),
    Input("provider-select", "value"),
    State("session-id", "data"),
)
def switch_provider(provider, session_id):
    """Switch the LLM provider for the current user."""
    if provider and session_id:
        state = get_user_state(session_id)
        state.set_provider(provider)
    return ""


@callback(
    Output("chat-input", "value", allow_duplicate=True),
    Input({"type": "example", "index": ALL}, "n_clicks"),
    prevent_initial_call=True
)
def fill_example(clicks):
    """Copy the clicked example query into the chat input."""
    if not any(clicks):
        raise PreventUpdate

    triggered = ctx.triggered_id
    if triggered and isinstance(triggered, dict):
        idx = triggered.get("index", 0)
        if idx < len(EXAMPLE_QUERIES["en"]):
            return EXAMPLE_QUERIES["en"][idx]["text"]

    raise PreventUpdate


@callback(
    [
        Output("reasoning-graph", "elements"),
        Output("chat-history", "children"),
        Output("stats-display", "children"),
        Output("chat-input", "value"),
        Output("graph-version", "data"),
        Output("language-indicator", "children"),
    ],
    [
        Input("btn-send", "n_clicks"),
        Input("btn-clear", "n_clicks"),
        Input("btn-reset", "n_clicks"),
        Input("btn-new-chat", "n_clicks"),
    ],
    [
        State("chat-input", "value"),
        State("selected-node-id", "data"),
        State("display-options", "value"),
        State("confidence-slider", "value"),
        State("graph-version", "data"),
        State("session-id", "data"),
    ],
    prevent_initial_call=True
)
def handle_main_actions(send_clicks, clear_clicks, reset_clicks, new_clicks,
                        input_text, selected_node, options, conf_threshold,
                        version, session_id):
    """Handle Send, Clear, Reset and New-chat.

    Clear, Reset and New all fully reset the user's reasoning state and chat.
    Send detects the input language, runs the reasoning engine (anchored on the
    selected node, if any), persists the state and re-renders graph and chat.

    Returns (graph elements, chat children, stats text, cleared input value,
    graph version, language indicator).
    """
    if not session_id:
        raise PreventUpdate

    state = get_user_state(session_id)
    triggered = ctx.triggered_id

    if triggered in ("btn-clear", "btn-reset", "btn-new-chat"):
        state.reset()
        # A pending branch anchor would point into the discarded graph.
        _branch_anchor_store.pop(session_id, None)
        return (
            [],
            [create_welcome_message()],
            "Ready — Enter symptoms to begin",
            "",
            0,
            "🌐 EN",
        )

    if triggered != "btn-send":
        raise PreventUpdate

    if not input_text or not input_text.strip():
        raise PreventUpdate

    # When a branch anchor is pending, handle_branch_send owns this click;
    # generating here as well would add a second reasoning run and duplicate
    # chat messages. NOTE(review): both callbacks are fired by the same click,
    # so this only prevents the duplicate when this request is served first;
    # merging the two send handlers into one callback would remove the race.
    if session_id in _branch_anchor_store:
        raise PreventUpdate

    lang = detect_language(input_text)
    state.language = lang

    state.add_message("user", input_text.strip())

    try:
        context = state.engine.build_context(input_text, selected_node)
        context.language = lang

        # Named gen_config so it does not shadow the module-level `config`.
        gen_config = GenerationConfig(
            model="gpt-4o-mini" if state.provider == "openai" else "local",
            language=lang
        )

        response_content = ""
        node_count = 0
        for node in state.engine.generate(context, gen_config):
            node_count += 1
            if node.node_type == NodeType.CONCLUSION:
                response_content = node.content

        graph_stats = state.kg.get_stats()
        logger.info(f"Generation complete: {node_count} nodes generated, graph has "
                    f"{graph_stats['nodes']} nodes and {graph_stats['edges']} edges")

        state.add_message(
            "assistant",
            response_content or "Analysis complete. See the reasoning graph."
        )
    except Exception as e:
        logger.exception(f"Generation error: {e}")
        state.add_message("error", f"Analysis failed: {str(e)}")

    state.save()

    chat_display = build_chat_display(state.get_chat_history())
    elements = state.kg.to_cytoscape_elements(
        include_ghosts="ghosts" in (options or []),
        confidence_threshold=conf_threshold
    )
    stats = state.kg.get_stats()

    return (
        elements,
        chat_display,
        f"📊 {stats['nodes']} nodes • {stats['edges']} edges",
        "",
        (version or 0) + 1,
        f"🌐 {lang.upper()}",
    )


@callback(
    Output("reasoning-graph", "elements", allow_duplicate=True),
    [Input("display-options", "value"), Input("confidence-slider", "value")],
    State("session-id", "data"),
    prevent_initial_call=True
)
def update_display(options, threshold, session_id):
    """Re-render the graph when the ghost toggle or confidence threshold changes."""
    if not session_id:
        raise PreventUpdate

    state = get_user_state(session_id)
    include_ghosts = "ghosts" in (options or [])
    return state.kg.to_cytoscape_elements(
        include_ghosts=include_ghosts,
        confidence_threshold=threshold
    )
@callback(
    [
        Output("selected-node-info", "children"),
        Output("selected-node-id", "data"),
        Output("btn-prune", "disabled"),
        Output("btn-resurrect", "disabled"),
        Output("btn-inject", "disabled"),
        Output("btn-correct", "disabled"),
        Output("btn-incorrect", "disabled"),
        Output("btn-branch", "disabled"),
    ],
    Input("reasoning-graph", "tapNodeData"),
)
def handle_node_click(node_data):
    """Show details for the tapped node and enable the actions valid for its type.

    Returns the info panel, the selected node id, and the ``disabled`` flags
    for prune, resurrect, inject, correct, incorrect and branch (in that order).
    Missing or unrecognized node types render as "Unknown" instead of raising.
    """
    if not node_data:
        return (
            html.P("👆 Click a node", className="text-muted small"),
            None, True, True, True, True, True, True
        )

    node_id = node_data.get("id")
    node_type = str(node_data.get("type") or "unknown")
    try:
        confidence = float(node_data.get("confidence") or 0)
    except (TypeError, ValueError):
        confidence = 0.0
    content = str(node_data.get("content", node_data.get("full_label", "")) or "")

    unknown_info = {"icon": "●", "name": "Unknown", "color": "#64748b"}
    try:
        info = NODE_TYPE_INFO.get(NodeType(node_type), unknown_info)
    except ValueError:
        # NodeType(...) raises for non-member values (e.g. the "unknown"
        # fallback above); show a neutral badge instead of failing the callback.
        info = unknown_info

    node_info = html.Div([
        html.Div([
            html.Span(info["icon"], className="me-2"),
            dbc.Badge(node_type.upper(), style={"backgroundColor": info["color"]}),
            html.Span(f" {confidence:.0%}", className="ms-2",
                      style={"color": "#34d399" if confidence > 0.7 else "#facc15"}),
        ], className="mb-2"),
        html.Div(content[:200], style={"fontSize": "0.85rem", "color": "#e2e8f0"}),
    ])

    is_ghost = node_type == "ghost"
    can_prune = node_type not in {"query", "ghost"}
    can_feedback = node_type in {"hypothesis", "conclusion", "reasoning"}
    can_branch = node_type in {"query", "hypothesis", "reasoning", "fact"}

    return (
        node_info,
        node_id,
        not can_prune,
        not is_ghost,
        node_id is None,
        not can_feedback,
        not can_feedback,
        not can_branch,
    )


@callback(
    [
        Output("reasoning-graph", "elements", allow_duplicate=True),
        Output("stats-display", "children", allow_duplicate=True),
        Output("feedback-status", "children"),
        Output("fact-input", "value"),
    ],
    [
        Input("btn-prune", "n_clicks"),
        Input("btn-resurrect", "n_clicks"),
        Input("btn-inject", "n_clicks"),
        Input("btn-correct", "n_clicks"),
        Input("btn-incorrect", "n_clicks"),
    ],
    [
        State("selected-node-id", "data"),
        State("fact-input", "value"),
        State("display-options", "value"),
        State("confidence-slider", "value"),
        State("session-id", "data"),
    ],
    prevent_initial_call=True
)
def handle_actions(prune, resurrect, inject, correct, incorrect,
                   selected_node, fact_text, options, threshold, session_id):
    """Apply a steering action (prune / resurrect / inject / feedback) to the selected node.

    Every action is also recorded through the session manager. Whitespace-only
    facts are ignored. Returns the refreshed graph, stats text, a feedback
    badge, and (after an injection) the cleared fact input.
    """
    if not session_id:
        raise PreventUpdate

    state = get_user_state(session_id)
    triggered = ctx.triggered_id
    feedback_status = dash.no_update
    clear_fact_input = dash.no_update
    sm = get_session_manager()
    fact = (fact_text or "").strip()

    logger.info(f"Action triggered: {triggered}, selected_node: {selected_node}")

    if triggered == "btn-prune" and selected_node:
        result = state.synchronizer.prune_node(selected_node)
        sm.record_interaction(session_id, 'prune', node_id=selected_node)
        logger.info(f"Pruned node {selected_node}: {result}")
        feedback_status = html.Small("✂️ Pruned node", style={"color": "#f87171"})

    elif triggered == "btn-resurrect" and selected_node:
        result = state.synchronizer.resurrect_node(selected_node)
        sm.record_interaction(session_id, 'resurrect', node_id=selected_node)
        logger.info(f"Resurrected node {selected_node}: {result}")
        feedback_status = html.Small("👻 Resurrected", style={"color": "#facc15"})

    elif triggered == "btn-inject" and selected_node and fact:
        result = state.synchronizer.inject_fact(selected_node, fact)
        sm.record_interaction(session_id, 'inject', node_id=selected_node, content=fact)
        logger.info(f"Injected fact to {selected_node}: {result}")
        feedback_status = html.Small("➕ Fact injected", style={"color": "#4ade80"})
        clear_fact_input = ""

    elif triggered == "btn-correct" and selected_node:
        state.synchronizer.record_feedback(selected_node, "correct")
        sm.add_feedback(session_id, selected_node, "correct")
        feedback_status = html.Small("✓ Marked correct", style={"color": "#4ade80"})

    elif triggered == "btn-incorrect" and selected_node:
        state.synchronizer.record_feedback(selected_node, "incorrect")
        sm.add_feedback(session_id, selected_node, "incorrect")
        feedback_status = html.Small("✗ Marked incorrect", style={"color": "#f87171"})

    else:
        raise PreventUpdate

    state.save()

    elements = state.kg.to_cytoscape_elements(
        include_ghosts="ghosts" in (options or []),
        confidence_threshold=threshold
    )
    stats = state.kg.get_stats()

    return (
        elements,
        f"📊 {stats['nodes']} nodes • {stats['edges']} edges",
        feedback_status,
        clear_fact_input,
    )


# Pending branch anchor node per session id; consumed by the next send.
_branch_anchor_store: Dict[str, str] = {}


@callback(
    [Output("chat-input", "placeholder"),
     Output("chat-input", "value", allow_duplicate=True)],
    Input("btn-branch", "n_clicks"),
    [State("selected-node-id", "data"),
     State("session-id", "data")],
    prevent_initial_call=True
)
def set_branch_anchor(n_clicks, selected_node, session_id):
    """Remember the selected node as the anchor for the next reasoning request."""
    if not n_clicks or not selected_node or not session_id:
        raise PreventUpdate

    _branch_anchor_store[session_id] = selected_node
    logger.info(f"Set branch anchor for session {session_id}: {selected_node}")

    return "Enter new reasoning to branch from selected node...", ""


@callback(
    Output("reasoning-graph", "elements", allow_duplicate=True),
    Output("chat-history", "children", allow_duplicate=True),
    Output("stats-display", "children", allow_duplicate=True),
    Output("chat-input", "placeholder", allow_duplicate=True),
    Output("chat-input", "value", allow_duplicate=True),
    Input("btn-send", "n_clicks"),
    [State("chat-input", "value"),
     State("display-options", "value"),
     State("confidence-slider", "value"),
     State("session-id", "data")],
    prevent_initial_call=True
)
def handle_branch_send(n_clicks, input_text, options, threshold, session_id):
    """Run branching reasoning when a branch anchor is pending for this session.

    Fires on the same Send click as ``handle_main_actions``; without a pending
    anchor it does nothing. The anchor is consumed only for non-blank input.
    Returns the refreshed graph, chat, stats, the default placeholder and a
    cleared input value.
    NOTE(review): both send callbacks fire on every click, so duplicate
    generation is possible whenever the main handler does not defer to this
    one; a single merged send callback would be race-free.
    """
    if not n_clicks or not input_text or not input_text.strip() or not session_id:
        raise PreventUpdate

    anchor_node = _branch_anchor_store.pop(session_id, None)
    if not anchor_node:
        # No anchor: the main send callback handles this click.
        raise PreventUpdate

    state = get_user_state(session_id)
    lang = detect_language(input_text)
    state.language = lang

    state.add_message("user", f"[Branching from node] {input_text.strip()}")

    try:
        context = state.engine.build_context(input_text, anchor_node)
        context.language = lang
        context.is_branching = True

        # Named gen_config so it does not shadow the module-level `config`.
        gen_config = GenerationConfig(
            model="gpt-4o-mini" if state.provider == "openai" else "local",
            language=lang
        )

        response_content = ""
        for node in state.engine.generate(context, gen_config):
            if node.node_type == NodeType.CONCLUSION:
                response_content = node.content

        state.add_message(
            "assistant",
            response_content or "Branch analysis complete. See the reasoning graph."
        )
    except Exception as e:
        logger.exception(f"Branch generation error: {e}")
        state.add_message("error", f"Branch failed: {str(e)}")

    state.save()

    chat_display = build_chat_display(state.get_chat_history())
    elements = state.kg.to_cytoscape_elements(
        include_ghosts="ghosts" in (options or []),
        confidence_threshold=threshold
    )
    stats = state.kg.get_stats()

    return (
        elements,
        chat_display,
        f"📊 {stats['nodes']} nodes • {stats['edges']} edges",
        "Describe your symptoms...",
        "",
    )


@callback(
    Output("reasoning-graph", "layout"),
    [Input("btn-layout-dag", "n_clicks"),
     Input("btn-layout-force", "n_clicks"),
     Input("btn-layout-radial", "n_clicks")],
    prevent_initial_call=True
)
def change_layout(dag, force, radial):
    """Switch the graph layout; ignore the click if the layout is not configured."""
    layout_keys = {
        "btn-layout-dag": "hierarchical",
        "btn-layout-force": "force",
        "btn-layout-radial": "radial",
    }
    layout = LAYOUT_CONFIGS.get(layout_keys.get(ctx.triggered_id, "hierarchical"))
    if layout is None:
        # Returning None would hand Cytoscape an invalid layout.
        raise PreventUpdate
    return layout


@callback(
    Output("reasoning-graph", "zoom"),
    [Input("btn-zoom-in", "n_clicks"),
     Input("btn-zoom-out", "n_clicks"),
     Input("btn-zoom-fit", "n_clicks")],
    State("reasoning-graph", "zoom"),
    prevent_initial_call=True
)
def handle_zoom(zoom_in, zoom_out, fit, current):
    """Zoom in/out by a fixed factor within [0.2, 3.0]; the fit button resets to 1.0."""
    current = current or 1.0
    triggered = ctx.triggered_id

    if triggered == "btn-zoom-in":
        return min(current * 1.3, 3.0)
    elif triggered == "btn-zoom-out":
        return max(current * 0.7, 0.2)
    return 1.0


@callback(
    Output("cleanup-interval", "disabled"),
    Input("cleanup-interval", "n_intervals"),
)
def periodic_cleanup(n):
    """Periodically evict idle user states and stale sessions; keep the timer enabled."""
    cleanup_user_states()
    get_session_manager().cleanup_stale_sessions()
    return False


def build_chat_display(history: List[Dict]) -> List:
    """Render chat messages; roles other than user/assistant are shown as errors."""
    if not history:
        return [create_welcome_message()]

    display = []
    for msg in history:
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if role == "user":
            display.append(html.Div([
                html.Span("You: ", style={"fontWeight": "600", "color": "#a5b4fc"}),
                content
            ], className="mb-2 p-2", style={"backgroundColor": "#1e3a5f", "borderRadius": "8px"}))
        elif role == "assistant":
            display.append(html.Div([
                html.Span("🤖 ", className="me-1"),
                content
            ], className="mb-2 p-2", style={"backgroundColor": "#1e293b", "borderRadius": "8px"}))
        else:
            display.append(html.Div([
                html.Span("⚠️ ", className="me-1"),
                content
            ], className="mb-2 p-2", style={"backgroundColor": "#450a0a", "borderRadius": "8px"}))

    return display


# Saved-session snapshots per session id (in-memory only, newest first).
_session_history_storage: Dict[str, List[Dict]] = {}


@callback(
    Output("session-history-list", "children"),
    [Input("btn-save-session", "n_clicks"),
     Input("btn-clear-history", "n_clicks"),
     Input({"type": "load-session", "index": ALL}, "n_clicks")],
    [State("session-id", "data"),
     State("chat-history", "children")],
    prevent_initial_call=True
)
def handle_session_history(save_clicks, clear_clicks, load_clicks, session_id, chat_children):
    """Save / clear session snapshots and render the saved-session list.

    Up to 10 snapshots (chat messages plus graph state) are kept per session,
    newest first. Loading is handled by ``load_saved_session``; here a load
    click only re-renders the list.
    """
    triggered = ctx.triggered_id

    if not session_id:
        raise PreventUpdate

    saved_list = _session_history_storage.setdefault(session_id, [])

    if triggered == "btn-clear-history":
        _session_history_storage[session_id] = []
        return [html.P("History cleared.", className="text-muted small text-center mt-3")]

    if triggered == "btn-save-session":
        state = get_user_state(session_id)
        history = state.get_chat_history()

        if history:
            first_text = history[0].get("content", "Empty")
            # Only mark the preview as truncated when it actually is.
            preview = first_text[:50] + "..." if len(first_text) > 50 else first_text
            saved_list.insert(0, {
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M"),
                "preview": preview,
                "messages": history,
                "graph_state": state.kg.get_state()
            })
            del saved_list[10:]

    sessions = _session_history_storage.get(session_id, [])
    if not sessions:
        return [html.P("No saved sessions yet.", className="text-muted small text-center mt-3")]

    items = []
    for i, sess in enumerate(sessions):
        items.append(html.Div([
            html.Div([
                html.Small(sess["timestamp"], className="text-muted"),
                html.Div(sess["preview"], style={"fontSize": "0.85rem"}),
            ], style={"flex": "1"}),
            dbc.Button("Load", id={"type": "load-session", "index": i},
                       size="sm", color="info", outline=True),
        ], className="d-flex justify-content-between align-items-center p-2 mb-2",
            style={"backgroundColor": "#1e3a5f", "borderRadius": "6px"}))

    return items


@callback(
    [Output("reasoning-graph", "elements", allow_duplicate=True),
     Output("chat-history", "children", allow_duplicate=True),
     Output("chat-tabs", "active_tab")],
    Input({"type": "load-session", "index": ALL}, "n_clicks"),
    [State("session-id", "data"),
     State("display-options", "value"),
     State("confidence-slider", "value")],
    prevent_initial_call=True
)
def load_saved_session(clicks, session_id, options=None, threshold=0):
    """Restore a saved snapshot's chat and graph, persist it, and switch to the chat tab.

    The graph is rendered with the current ghost/threshold display options,
    consistent with the other graph-updating callbacks.
    """
    if not clicks or not any(clicks):
        raise PreventUpdate

    triggered = ctx.triggered_id
    if not isinstance(triggered, dict):
        raise PreventUpdate

    index = triggered.get("index")
    sessions = _session_history_storage.get(session_id, [])

    if index is None or not 0 <= index < len(sessions):
        raise PreventUpdate

    saved = sessions[index]
    state = get_user_state(session_id)

    # Replace the live chat history with the snapshot's messages.
    sm = get_session_manager()
    session = sm.get_or_create(session_id)
    session.chat_history.clear()
    for msg in saved["messages"]:
        state.add_message(msg["role"], msg["content"])

    if saved.get("graph_state"):
        state.kg.restore_state(saved["graph_state"])

    # Persist the restored graph so a later session restore sees it.
    state.save()

    chat_display = build_chat_display(saved["messages"])
    elements = state.kg.to_cytoscape_elements(
        include_ghosts="ghosts" in (options or []),
        confidence_threshold=threshold
    )

    return elements, chat_display, "tab-chat"


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    print("=" * 60)
    print("  ⚕️  HITL-KG Medical Reasoning System")
    print("=" * 60)
    print(f"  🔑 OpenAI: {'✅' if OPENAI_API_KEY else '❌ Local mode'}")
    print("  🌐 Embeddings: Multilingual (50+ languages)")
    print(f"  📊 Default provider: {DEFAULT_PROVIDER}")
    print(f"  🚀 Starting at http://localhost:{config.port}")
    print("=" * 60)

    app.run(
        debug=config.debug,
        host=config.host,
        port=config.port,
    )