# champ-chatbot / main.py
# Latest commit: "no filter" (ebab2ac, verified) by qyle
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import uuid
from codecarbon import EmissionsTracker
from ecologits import EcoLogits
import torch
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, File, Form, Request, Response, UploadFile
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from opentelemetry import trace
from slowapi import Limiter
from slowapi.util import get_remote_address
from uvicorn.logging import DefaultFormatter
from classes.base_models import (
ChatRequest,
CommentRequest,
DeleteFileRequest,
FeedbackRequest,
)
from classes.pii_filter import PIIFilter
from classes.session_conversation_store import SessionConversationStore
from classes.session_document_store import SessionDocumentStore
from classes.session_tracker import SessionTracker
from constants import (
MAX_ID_LENGTH,
STATUS_CODE_EXCEED_SIZE_LIMIT,
STATUS_CODE_INTERNAL_SERVER_ERROR,
)
from exceptions import (
FILE_EXTRACTION_ERROR_STATUS_CODES,
FILE_VALIDATION_ERROR_STATUS_CODES,
FileExtractionException,
FileValidationException,
)
from helpers.dynamodb_helper import log_chat_event, log_environment_event
from helpers.file_helper import (
extract_text_from_file,
replace_spaces_in_filename,
validate_file,
)
from helpers.lifespan_helper import cleanup_loop, load_heavy_models, run_cleanup
from helpers.llm_helper import call_llm
from telemetry import setup_telemetry
# Load variables from a local .env file before any configuration is read.
load_dotenv()

# Reuse uvicorn's logger so application messages share its handlers.
logger = logging.getLogger("uvicorn")

# -------------------- Config --------------------
# True only when the ENV environment variable is exactly "dev".
DEV = os.getenv("ENV") == "dev"

# -------------------- Helpers --------------------
# Conversations and uploaded documents currently live in process memory.
# That is acceptable for a demo, but it means session state is kept in what
# should be a stateless server; this will have to move to Redis (or another
# real-time database) at some point.
session_tracker = SessionTracker()
session_document_store = SessionDocumentStore()
session_conversation_store = SessionConversationStore()
# -------------------- Environmental Impact --------------------
# CodeCarbon tracker for the machine running this server. It is created and
# started at import time; log_environment_infra() flushes it later on.
# Nothing is written to a local file (save_to_file=False).
tracker = EmissionsTracker(
    project_name="test", measure_power_secs=5, save_to_file=False
)
tracker.start()
logger.info(f"Detected hardware: {tracker.get_detected_hardware()}")
# NOTE(review): `_geo` is a private attribute of the tracker; it may change
# between CodeCarbon releases - confirm when upgrading.
logger.info(f"Geographic metadata: {tracker._geo}")
def log_environment_infra():
    """Flush the emissions tracker and log the infrastructure footprint.

    Reads the value returned by ``tracker.flush()`` together with the
    tracker's accumulated energy and water totals, and hands them to
    ``log_environment_event`` under the ``"infrastructure"`` event type.

    This is best-effort: any failure (flush, attribute access or logging) is
    written to the application log with its traceback and then swallowed, so
    that a single bad measurement cannot terminate the hourly background
    loop that calls this function.
    """
    try:
        # flush() is inside the try block on purpose: the periodic caller has
        # no error handling of its own.
        gwp_emissions = tracker.flush()
        # NOTE(review): `_total_energy` / `_total_water` are private tracker
        # attributes - confirm they still exist when upgrading CodeCarbon.
        infra_data = {
            "energy_kWh": tracker._total_energy.kWh,
            "co2eq_kg": gwp_emissions,
            "water_L": tracker._total_water.litres,
        }
        log_environment_event("infrastructure", infra_data)
    except Exception:
        logger.exception("Failed to log infrastructure environmental impact")
async def environment_infra_loop():
    """Log the infrastructure footprint once per hour, forever.

    Intended to run as a background task for the lifetime of the app; it only
    stops when the task is cancelled.
    """
    one_hour_seconds = 3600
    while True:
        # Sleep first, so the first measurement is logged one hour after start.
        await asyncio.sleep(one_hour_seconds)
        log_environment_infra()
# -------------------- FastAPI setup --------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup and shutdown work.

    Startup:
        * installs a timestamped formatter on uvicorn's first log handler,
        * reports whether CUDA is available,
        * loads the heavy models,
        * initialises EcoLogits,
        * starts the hourly environmental-impact task and the session
          cleanup task.

    Shutdown:
        Cancels both background tasks and waits for them to finish, so they
        are not left pending when the event loop closes. The shutdown code
        runs even if an exception is thrown into the context manager.

    Args:
        app: The FastAPI application (unused, required by the lifespan
            protocol).
    """
    # Setup logging
    logger = logging.getLogger("uvicorn")
    if logger.handlers:
        colored_formatter = DefaultFormatter(
            fmt="%(levelprefix)s %(asctime)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        logger.handlers[0].setFormatter(colored_formatter)
    logger.info("Logging configured!")

    if torch.cuda.is_available():
        logger.info("CUDA is available")
    else:
        logger.warning("CUDA is NOT available")

    # Setup heavy models
    load_heavy_models()

    # Setup Ecologits
    EcoLogits.init(
        providers=["huggingface_hub", "openai", "google_genai"],
        electricity_mix_zone="USA",
    )

    # Setup CodeCarbon
    environment_infra_bg_task = asyncio.create_task(environment_infra_loop())

    # Setup cleanup loop
    cleanup_bg_task = asyncio.create_task(
        cleanup_loop(
            session_tracker, session_document_store, session_conversation_store
        )
    )

    try:
        yield
    finally:
        cleanup_bg_task.cancel()
        environment_infra_bg_task.cancel()
        # Wait for the cancellations to be processed. return_exceptions=True
        # collects each task's CancelledError (or earlier failure) instead of
        # raising it during shutdown.
        await asyncio.gather(
            cleanup_bg_task, environment_infra_bg_task, return_exceptions=True
        )
# Application object; startup and shutdown work is delegated to lifespan().
app = FastAPI(lifespan=lifespan)
setup_telemetry(app)

# Static assets are served from ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# HTML templates are loaded from ./templates.
templates = Jinja2Templates(directory="templates")
@app.middleware("http")
async def cleanup_middleware(request: Request, call_next):
    # Runs the session cleanup routine over the three in-memory stores before
    # every request, then passes the request on unchanged.
    stores = (session_tracker, session_document_store, session_conversation_store)
    run_cleanup(*stores)
    return await call_next(request)
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    # Serves the front-end page rendered from the index.html template.
    page = templates.TemplateResponse(request=request, name="index.html")
    return page
# Time profiler: OpenTelemetry tracer used to wrap slow steps in spans.
tracer = trace.get_tracer(__name__)
# Rate limiter: limits are counted per key returned by get_remote_address.
limiter = Limiter(key_func=get_remote_address)
def _chat_event_base(payload: ChatRequest, human_message: str) -> dict:
    """Return the fields shared by every chat event logged for one request.

    Args:
        payload: The validated chat request.
        human_message: The user message in the form it should be stored.

    Returns:
        A new dict; callers extend it with reply- or error-specific keys.
    """
    return {
        "model_type": payload.model_type,
        "consent": payload.consent,
        "human_message": human_message,
        "age_group": payload.age_group,
        "gender": payload.gender,
        "roles": payload.roles,
        "participant_id": payload.participant_id,
        "conversation_id": payload.conversation_id,
        "lang": payload.lang,
    }


@app.post("/chat")
@limiter.limit("450/minute")
async def chat_endpoint(
    payload: ChatRequest, background_tasks: BackgroundTasks, request: Request
):
    # Handles one chat turn.
    #
    # The user message is appended to the session's conversation, the LLM is
    # called in the default executor, and the reply is returned either as a
    # stream (when call_llm yields an async generator) or as a JSON object
    # with the keys "reply", "reply_id", "gwp_kgcoeq" and "n_tokens".
    # Every turn is logged through log_chat_event as a background task.
    #
    # `request` is not used in the body; it is part of the signature for the
    # rate-limiting decorator.
    session_id = payload.session_id
    model_type = payload.model_type
    lang = payload.lang
    conversation_id = payload.conversation_id
    human_message = payload.human_message

    session_tracker.update_session(session_id)

    # PII filtering of chat messages is currently switched off: the message
    # is stored and logged exactly as received. The disabled code is kept for
    # reference.
    # pii_filter = PIIFilter()
    # with tracer.start_as_current_span("sanitize_document"):
    #     pii_filtered_msg = pii_filter.sanitize(human_message)
    pii_filtered_msg = human_message

    conversation = session_conversation_store.add_human_message(
        session_id, conversation_id, pii_filtered_msg
    )
    document_contents = session_document_store.get_document_contents(session_id)

    # Defaults used in the response when the LLM call fails.
    reply = ""
    reply_id = str(uuid.uuid4())
    gwp_kgcoeq = 0.0
    triage_meta = {}
    context = []
    n_tokens = 0

    try:
        loop = asyncio.get_running_loop()
        with tracer.start_as_current_span("call_llm"):
            # call_llm is synchronous; run it off the event loop.
            result = await loop.run_in_executor(
                None, call_llm, model_type, lang, conversation, document_contents
            )

        if isinstance(result, AsyncGenerator):

            async def logging_wrapper():
                # Forward tokens to the client while keeping a copy, so the
                # full reply can be logged and stored once streaming ends.
                tokens = []
                async for token in result:
                    tokens.append(token)
                    yield token
                streamed_reply = "".join(tokens)

                # Save the messages in DB
                background_tasks.add_task(
                    log_chat_event,
                    user_id=payload.user_id,
                    session_id=payload.session_id,
                    data={
                        **_chat_event_base(payload, pii_filtered_msg),
                        "reply": streamed_reply,
                        "reply_id": reply_id,
                        "triage_meta": {},
                    },
                )
                # Save the messages in session_conversation_store
                background_tasks.add_task(
                    session_conversation_store.add_assistant_reply,
                    session_id=session_id,
                    conversation_id=conversation_id,
                    reply=streamed_reply,
                )

            return StreamingResponse(
                logging_wrapper(),
                media_type="text/event-stream",
                headers={"X-Reply-ID": reply_id},
            )

        reply, gwp_kgcoeq, triage_meta, context, n_tokens = result
    except Exception as e:
        # Record the failure in the application log as well as in the event
        # store, otherwise the traceback is lost.
        logger.exception("Chat request failed")
        background_tasks.add_task(
            log_chat_event,
            user_id=payload.user_id,
            session_id=payload.session_id,
            data={
                **_chat_event_base(payload, pii_filtered_msg),
                "error": str(e),
            },
        )
        # NOTE(review): execution deliberately continues below, so a failed
        # call also logs a regular event with an empty reply, stores an empty
        # assistant reply and answers 200 with `"reply": ""`. Confirm the
        # front end relies on this before changing it.

    background_tasks.add_task(
        log_chat_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data={
            **_chat_event_base(payload, pii_filtered_msg),
            "reply": reply,
            "reply_id": reply_id,
            "context": context,
            # Triage metadata keys are merged last, as before, so they take
            # precedence over the fields above.
            **(triage_meta or {}),
        },
    )
    session_conversation_store.add_assistant_reply(session_id, conversation_id, reply)

    return {
        "reply": reply,
        "reply_id": reply_id,
        "gwp_kgcoeq": gwp_kgcoeq,
        "n_tokens": n_tokens,
    }
# Endpoint for feedback on one specific reply/response
@app.post("/feedback")
@limiter.limit("450/minute")
def feedback_endpoint(
    payload: FeedbackRequest, background_tasks: BackgroundTasks, request: Request
):
    # Collects the rating and comment given for a single reply and schedules
    # them to be logged after the response has been sent. `request` is not
    # used in the body; it is part of the signature for the rate limiter.
    feedback_event = {
        "consent": payload.consent,
        "comment": payload.comment,
        "age_group": payload.age_group,
        "gender": payload.gender,
        "roles": payload.roles,
        "participant_id": payload.participant_id,
        "message_index": payload.message_index,
        "rating": payload.rating,
        "reply_content": payload.reply_content,
        "reply_id": str(payload.reply_id),
    }
    background_tasks.add_task(
        log_chat_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data=feedback_event,
    )
# Endpoint for generic comments (not tied to a specific reply)
@app.post("/comment")
@limiter.limit("450/minute")
def comment_endpoint(
    payload: CommentRequest, background_tasks: BackgroundTasks, request: Request
):
    # Schedules a free-form user comment to be logged after the response has
    # been sent. `request` is not used in the body; it is part of the
    # signature for the rate limiter.
    logger.info("Received comment")
    comment_event = {
        "consent": payload.consent,
        "comment": payload.comment,
        "age_group": payload.age_group,
        "gender": payload.gender,
        "roles": payload.roles,
        "participant_id": payload.participant_id,
    }
    background_tasks.add_task(
        log_chat_event,
        user_id=payload.user_id,
        session_id=payload.session_id,
        data=comment_event,
    )
@app.put("/file")
@limiter.limit("12/minute")
async def upload_file(
    request: Request,
    file: UploadFile = File(...),
    session_id: str = Form(
        pattern="^[a-zA-Z0-9_-]+$", min_length=1, max_length=MAX_ID_LENGTH
    ),
):
    # Validates an uploaded file, extracts its text, removes PII from that
    # text and stores the result for the given session.
    #
    # Responses:
    #   * validation failure  -> status from FILE_VALIDATION_ERROR_STATUS_CODES
    #   * extraction failure  -> status from FILE_EXTRACTION_ERROR_STATUS_CODES
    #   * unexpected failure  -> STATUS_CODE_INTERNAL_SERVER_ERROR (logged)
    #   * document not stored -> STATUS_CODE_EXCEED_SIZE_LIMIT
    #   * success             -> default response with an empty body
    try:
        validated_file = await validate_file(file)
    except FileValidationException as e:
        return Response(status_code=FILE_VALIDATION_ERROR_STATUS_CODES[e.error])

    file_content = validated_file.content
    file_name = validated_file.filename
    file_mime = validated_file.mime_type

    try:
        file_text = await extract_text_from_file(file_content, file_mime)
    except FileExtractionException as e:
        return Response(status_code=FILE_EXTRACTION_ERROR_STATUS_CODES[e.error])
    except Exception:
        # Keep the traceback; the client only sees a generic server error.
        # The file name is left out of the message because it is user input.
        logger.exception(
            "Unexpected failure while extracting text from an uploaded file "
            "(mime type: %s)",
            file_mime,
        )
        return Response(status_code=STATUS_CODE_INTERNAL_SERVER_ERROR)

    pii_filter = PIIFilter()
    with tracer.start_as_current_span("sanitize_document"):
        pii_filtered_file_text = pii_filter.sanitize(file_text)

    stored = session_document_store.create_document(
        session_id, pii_filtered_file_text, file_name
    )
    if not stored:
        return Response(status_code=STATUS_CODE_EXCEED_SIZE_LIMIT)

    # Only refresh the session once the document has actually been stored.
    session_tracker.update_session(session_id)
@app.delete("/file")
@limiter.limit("20/minute")
def delete_file(
    payload: DeleteFileRequest,
    request: Request,
):
    # Removes one uploaded document from the session's document store.
    # The file name goes through replace_spaces_in_filename first -
    # presumably so it matches the name used when the file was stored
    # (TODO confirm against the upload path).
    stored_name = replace_spaces_in_filename(payload.file_name)
    session_document_store.delete_document(payload.session_id, stored_name)
@app.post("/flush-environmental-infra-impact")
@limiter.limit("2/minute")
def flush_environmental_infra_impact(request: Request):
    # On-demand version of the hourly job: logs the infrastructure footprint
    # immediately. `request` is not used in the body; it is part of the
    # signature for the rate limiter.
    log_environment_infra()
    return None