import asyncio
import logging
import os
import uuid
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from codecarbon import EmissionsTracker
from ecologits import EcoLogits
import torch
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, File, Form, Request, Response, UploadFile
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from opentelemetry import trace
from slowapi import Limiter
from slowapi.util import get_remote_address
from uvicorn.logging import DefaultFormatter

from classes.base_models import (
    ChatRequest,
    CommentRequest,
    DeleteFileRequest,
    FeedbackRequest,
)
from classes.pii_filter import PIIFilter
from classes.session_conversation_store import SessionConversationStore
from classes.session_document_store import SessionDocumentStore
from classes.session_tracker import SessionTracker
from constants import (
    MAX_ID_LENGTH,
    STATUS_CODE_EXCEED_SIZE_LIMIT,
    STATUS_CODE_INTERNAL_SERVER_ERROR,
)
from exceptions import (
    FILE_EXTRACTION_ERROR_STATUS_CODES,
    FILE_VALIDATION_ERROR_STATUS_CODES,
    FileExtractionException,
    FileValidationException,
)
from helpers.dynamodb_helper import log_chat_event, log_environment_event
from helpers.file_helper import (
    extract_text_from_file,
    replace_spaces_in_filename,
    validate_file,
)
from helpers.lifespan_helper import cleanup_loop, load_heavy_models, run_cleanup
from helpers.llm_helper import call_llm
from telemetry import setup_telemetry

# Load variables from a .env file into the process environment before any
# configuration below is read.
load_dotenv()

# Reuse uvicorn's logger so application messages share the server's handlers.
logger = logging.getLogger("uvicorn")

# -------------------- Config --------------------
# True only when the ENV environment variable is exactly "dev".
DEV = os.getenv("ENV") == "dev"

# -------------------- Helpers --------------------
# For now, conversations and uploaded documents are stored in RAM.
# This is tolerable for a demo, but we will have to switch to
# Redis (or another real-time database) at some point.
# We are currently storing sessions in what should be a stateless server.
session_tracker = SessionTracker()
session_document_store = SessionDocumentStore()
session_conversation_store = SessionConversationStore()

# -------------------- Environmental Impact --------------------
tracker = EmissionsTracker(
    project_name="test", measure_power_secs=5, save_to_file=False
)
tracker.start()
logger.info(f"Detected hardware: {tracker.get_detected_hardware()}")
# NOTE(review): `_geo` is a private CodeCarbon attribute and may change
# between releases.
logger.info(f"Geographic metadata: {tracker._geo}")


def log_environment_infra():
    """Flush the CodeCarbon tracker and log the infrastructure footprint.

    Sends one ``"infrastructure"`` event through ``log_environment_event``
    containing the tracker's total energy (kWh), the value returned by
    ``tracker.flush()`` (stored as ``co2eq_kg``) and total water (litres).

    This is best-effort: any failure, including a failing ``flush()``, is
    logged with its traceback and swallowed so that callers (the hourly
    background loop and the manual flush endpoint) keep working.
    """
    try:
        # flush() lives inside the try block on purpose: an exception raised
        # here would otherwise terminate the hourly background task for good.
        gwp_emissions = tracker.flush()
        infra_data = {
            # NOTE(review): `_total_energy` / `_total_water` are private
            # CodeCarbon attributes — confirm they exist in the pinned version.
            "energy_kWh": tracker._total_energy.kWh,
            "co2eq_kg": gwp_emissions,
            "water_L": tracker._total_water.litres,
        }
        log_environment_event("infrastructure", infra_data)
    except Exception:
        logger.exception("Failed to log infrastructure environmental impact")


async def environment_infra_loop():
    """Background task that runs forever while the app is alive.

    Once per hour it logs the infrastructure footprint. The synchronous
    logging call is run in the default executor so the event loop stays
    responsive (the helper presumably performs network I/O — TODO confirm).
    """
    loop = asyncio.get_running_loop()
    while True:
        await asyncio.sleep(3600)  # 1 hour
        await loop.run_in_executor(None, log_environment_infra)


# -------------------- FastAPI setup --------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: configure logging, load models, run background jobs.

    Startup: installs a coloured formatter on the first uvicorn handler,
    reports CUDA availability, loads the heavy models, initialises EcoLogits
    and starts the environmental-impact and session-cleanup background tasks.

    Shutdown: cancels both background tasks and waits for them to finish, even
    when an exception is thrown into the context manager.
    """
    # Setup logging
    uvicorn_logger = logging.getLogger("uvicorn")
    if uvicorn_logger.handlers:
        colored_formatter = DefaultFormatter(
            fmt="%(levelprefix)s %(asctime)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        uvicorn_logger.handlers[0].setFormatter(colored_formatter)
    uvicorn_logger.info("Logging configured!")

    if torch.cuda.is_available():
        uvicorn_logger.info("CUDA is available")
    else:
        uvicorn_logger.warning("CUDA is NOT available")

    # Setup heavy models
    load_heavy_models()

    # Setup Ecologits
    EcoLogits.init(
        providers=["huggingface_hub", "openai", "google_genai"],
        electricity_mix_zone="USA",
    )

    # Setup CodeCarbon
    environment_infra_bg_task = asyncio.create_task(environment_infra_loop())

    # Setup cleanup loop
    cleanup_bg_task = asyncio.create_task(
        cleanup_loop(
            session_tracker, session_document_store, session_conversation_store
        )
    )

    background_jobs = (cleanup_bg_task, environment_infra_bg_task)
    try:
        yield
    finally:
        for job in background_jobs:
            job.cancel()
        # Await the cancelled tasks so they are fully unwound before the loop
        # closes; return_exceptions=True absorbs the resulting CancelledError.
        await asyncio.gather(*background_jobs, return_exceptions=True)


app = FastAPI(lifespan=lifespan)
setup_telemetry(app)
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")


@app.middleware("http")
async def cleanup_middleware(request: Request, call_next):
    """Run the session cleanup, then hand the request to the next handler."""
    run_cleanup(session_tracker, session_document_store, session_conversation_store)
    return await call_next(request)


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the ``index.html`` template."""
    return templates.TemplateResponse(name="index.html", request=request)


# Time profiler
tracer = trace.get_tracer(__name__)

# Rate limiter
limiter = Limiter(key_func=get_remote_address)


def _participant_fields(payload):
    """Return the consent/demographic fields shared by every logged event.

    Works with any request model exposing ``consent``, ``age_group``,
    ``gender``, ``roles`` and ``participant_id``.
    """
    return {
        "consent": payload.consent,
        "age_group": payload.age_group,
        "gender": payload.gender,
        "roles": payload.roles,
        "participant_id": payload.participant_id,
    }


@app.post("/chat")
@limiter.limit("450/minute")
async def chat_endpoint(
    payload: ChatRequest, background_tasks: BackgroundTasks, request: Request
):
    """Answer a chat message, either as a stream or as a single JSON reply.

    The human message is appended to the session's conversation, ``call_llm``
    is run in the default executor, and the exchange is logged through
    ``log_chat_event`` as a background task.

    Returns:
        A ``StreamingResponse`` (``text/event-stream`` with an ``X-Reply-ID``
        header) when ``call_llm`` returns an async generator; otherwise a dict
        with ``reply``, ``reply_id``, ``gwp_kgcoeq`` and ``n_tokens``.
    """
    session_id = payload.session_id
    model_type = payload.model_type
    lang = payload.lang
    conversation_id = payload.conversation_id
    human_message = payload.human_message

    session_tracker.update_session(session_id)

    # NOTE: PII sanitisation of chat messages is currently disabled; the
    # message is stored and logged as received.
    pii_filtered_msg = human_message

    conversation = session_conversation_store.add_human_message(
        session_id, conversation_id, pii_filtered_msg
    )
    document_contents = session_document_store.get_document_contents(session_id)

    # Fields common to every chat log event emitted by this request.
    chat_fields = {
        **_participant_fields(payload),
        "model_type": model_type,
        "human_message": pii_filtered_msg,
        "conversation_id": conversation_id,
        "lang": lang,
    }

    # Defaults used for the response when the LLM call fails.
    reply = ""
    reply_id = str(uuid.uuid4())
    gwp_kgcoeq = 0.0
    triage_meta = {}
    context = []
    n_tokens = 0

    try:
        loop = asyncio.get_running_loop()
        with tracer.start_as_current_span("call_llm"):
            result = await loop.run_in_executor(
                None, call_llm, model_type, lang, conversation, document_contents
            )

        if isinstance(result, AsyncGenerator):

            async def logging_wrapper():
                """Relay tokens to the client and log the full reply afterwards."""
                tokens = []
                async for token in result:
                    tokens.append(token)
                    yield token
                streamed_reply = "".join(tokens)

                # Save the messages in DB
                background_tasks.add_task(
                    log_chat_event,
                    user_id=payload.user_id,
                    session_id=session_id,
                    data={
                        **chat_fields,
                        "reply": streamed_reply,
                        "reply_id": reply_id,
                        "triage_meta": {},
                    },
                )
                # Save the messages in session_conversation_store
                background_tasks.add_task(
                    session_conversation_store.add_assistant_reply,
                    session_id=session_id,
                    conversation_id=conversation_id,
                    reply=streamed_reply,
                )

            return StreamingResponse(
                logging_wrapper(),
                media_type="text/event-stream",
                headers={"X-Reply-ID": reply_id},
            )

        reply, gwp_kgcoeq, triage_meta, context, n_tokens = result
    except Exception as e:
        # Keep the traceback in the server log; the DB event only stores str(e).
        logger.exception("Chat request failed")
        background_tasks.add_task(
            log_chat_event,
            user_id=payload.user_id,
            session_id=session_id,
            data={**chat_fields, "error": str(e)},
        )
        # NOTE(review): execution deliberately falls through, so a failed call
        # still records an empty assistant reply and returns the default
        # values — confirm this is the intended client contract.

    background_tasks.add_task(
        log_chat_event,
        user_id=payload.user_id,
        session_id=session_id,
        data={
            **chat_fields,
            "reply": reply,
            "reply_id": reply_id,
            "context": context,
            # Spread last so triage metadata keeps precedence on key clashes.
            **(triage_meta or {}),
        },
    )
    session_conversation_store.add_assistant_reply(session_id, conversation_id, reply)

    return {
        "reply": reply,
        "reply_id": reply_id,
        "gwp_kgcoeq": gwp_kgcoeq,
        "n_tokens": n_tokens,
    }


# Endpoint for specific replies/responses
@app.post("/feedback")
@limiter.limit("450/minute")
def feedback_endpoint(
    payload: FeedbackRequest, background_tasks: BackgroundTasks, request: Request
):
    """Log a user's rating/comment about one specific reply as a background task."""
    background_tasks.add_task(
        log_chat_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data={
            **_participant_fields(payload),
            "comment": payload.comment,
            "message_index": payload.message_index,
            "rating": payload.rating,
            "reply_content": payload.reply_content,
            "reply_id": str(payload.reply_id),
        },
    )


# Endpoint for specific generic comments
@app.post("/comment")
@limiter.limit("450/minute")
def comment_endpoint(
    payload: CommentRequest, background_tasks: BackgroundTasks, request: Request
):
    """Log a generic user comment (not tied to a reply) as a background task."""
    logger.info("Received comment")
    background_tasks.add_task(
        log_chat_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data={
            **_participant_fields(payload),
            "comment": payload.comment,
        },
    )


@app.put("/file")
@limiter.limit("12/minute")
async def upload_file(
    request: Request,
    file: UploadFile = File(...),
    session_id: str = Form(
        pattern="^[a-zA-Z0-9_-]+$", min_length=1, max_length=MAX_ID_LENGTH
    ),
):
    """Validate an uploaded file, extract and sanitise its text, and store it.

    Returns:
        ``None`` on success. Otherwise an empty ``Response`` whose status code
        comes from ``FILE_VALIDATION_ERROR_STATUS_CODES`` or
        ``FILE_EXTRACTION_ERROR_STATUS_CODES``, is
        ``STATUS_CODE_INTERNAL_SERVER_ERROR`` for an unexpected extraction
        failure, or ``STATUS_CODE_EXCEED_SIZE_LIMIT`` when the document store
        rejects the document.
    """
    try:
        validated_file = await validate_file(file)
    except FileValidationException as e:
        return Response(status_code=FILE_VALIDATION_ERROR_STATUS_CODES[e.error])

    file_content = validated_file.content
    file_name = validated_file.filename
    file_mime = validated_file.mime_type

    try:
        file_text = await extract_text_from_file(file_content, file_mime)
    except FileExtractionException as e:
        return Response(status_code=FILE_EXTRACTION_ERROR_STATUS_CODES[e.error])
    except Exception:
        # Top-level boundary: record the traceback instead of failing silently.
        logger.exception("Unexpected failure while extracting text from upload")
        return Response(status_code=STATUS_CODE_INTERNAL_SERVER_ERROR)

    pii_filter = PIIFilter()
    with tracer.start_as_current_span("sanitize_document"):
        pii_filtered_file_text = pii_filter.sanitize(file_text)

    if not session_document_store.create_document(
        session_id, pii_filtered_file_text, file_name
    ):
        return Response(status_code=STATUS_CODE_EXCEED_SIZE_LIMIT)

    session_tracker.update_session(session_id)


@app.delete("/file")
@limiter.limit("20/minute")
def delete_file(
    payload: DeleteFileRequest,
    request: Request,
):
    """Delete one document from the session's document store.

    The file name is passed through ``replace_spaces_in_filename`` before the
    store lookup.
    """
    file_name = replace_spaces_in_filename(payload.file_name)
    session_document_store.delete_document(payload.session_id, file_name)


@app.post("/flush-environmental-infra-impact")
@limiter.limit("2/minute")
def flush_environmental_infra_impact(request: Request):
    """Trigger an immediate infrastructure-impact log via ``log_environment_infra``."""
    log_environment_infra()