celarium / main.py
user9200's picture
Upload 4 files
6de83f6 verified
import uvicorn
import os
import uuid
import re
import random
import json
from datetime import datetime
from typing import Union, List, Dict, Any
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from faker import Faker
from gliner import GLiNER
app = FastAPI(title="Celarium AI")

# NOTE(review): wildcard origins together with allow_credentials=True is the
# most permissive CORS setup available. Authentication in this file uses the
# X-API-Key header rather than cookies, so credentialed CORS may be
# unnecessary -- confirm with the frontend before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Header that carries the caller's API key. auto_error is disabled so that a
# missing header is passed on as None and rejected by get_api_key with a 401.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

# Keys accepted by get_api_key. Set CELARIUM_API_KEYS to a comma-separated
# list to supply keys without editing source; the built-in test keys are used
# only when that variable is unset or contains no non-blank entries.
_configured_api_keys = {
    key.strip()
    for key in os.getenv("CELARIUM_API_KEYS", "").split(",")
    if key.strip()
}
VALID_API_KEYS = _configured_api_keys or {
    "sk_test_celarium_founder_001",
    "sk_test_celarium_beta_001",
}

# In-memory session store: session_id -> {"mapping", "created", "api_key"}.
# NOTE(review): nothing in this file removes entries, so the store grows for
# the lifetime of the process and is lost on restart.
SESSIONS = {}

fake = Faker()

# Load Model (happens once, at import time)
print("Loading GLiNER...")
model = GLiNER.from_pretrained("urchade/gliner_small-v2.1")
print("Loaded.")
# Regex & Labels
# Label -> regular expression for entities with a predictable surface form.
# Every regex hit is treated as a finding with score 1.0 by the caller.
REGEX_PATTERNS = {
    # The TLD class is [A-Za-z]; the previous [A-Z|a-z] also accepted a
    # literal "|" inside the top-level domain.
    "EMAIL_ADDRESS": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
    "PHONE_NUMBER": r'(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}',
    "MRN": r'\bMRN[-_]\w+\b',
    "SSN": r'\b\d{3}-\d{2}-\d{4}\b',
    "INSURANCE_GROUP": r'\bG\d{5,}\b',
    # Non-capturing group: only the whole match is ever used.
    "INSURANCE_POLICY": r'\b(?:POL|POLICY)[-_]?\d+\b',
    "FULL_ADDRESS": r'\d+\s+[A-Za-z0-9\s\.]+,\s+[A-Za-z\s\.]+,\s+[A-Z]{2}\s+\d{5}(?:-\d{4})?',
}

# Entity labels passed to the GLiNER model for free-text detection.
AI_LABELS = ["person", "physical address", "organization", "date of birth"]
# Generators
def generate_clean_name():
    """Build a fake "First Last" name from Faker's first- and last-name providers."""
    given = fake.first_name()
    family = fake.last_name()
    return " ".join((given, family))
def generate_matching_email(fake_name: str):
    """Return an example.com address whose local part is derived from *fake_name*.

    The first two name parts are lower-cased, stripped of everything except
    ASCII letters and digits (so "O'Connor" becomes "oconnor"), concatenated,
    and followed by a random 3-4 digit number. When no usable characters
    remain (empty, None, or whitespace/punctuation-only input) a generic
    "userNNNN@example.com" address is returned instead.
    """
    cleaned = (
        re.sub(r"[^a-z0-9]", "", part)
        for part in (fake_name or "").lower().split()
    )
    parts = [part for part in cleaned if part]
    if not parts:
        return f"user{random.randint(1000, 9999)}@example.com"
    base = "".join(parts[:2])
    return f"{base}{random.randint(100, 9999)}@example.com"
# --- UPDATED GENERATORS ---
def generate_clean_phone():
    """Produce a random phone number shaped like +1-XXX-XXX-XXXX.

    The first two groups are drawn from 200-999 and the last from 1000-9999.
    """
    area_code = random.randint(200, 999)
    exchange = random.randint(200, 999)
    subscriber = random.randint(1000, 9999)
    return "+1-" + "-".join(str(group) for group in (area_code, exchange, subscriber))
def generate_medical_org():
    """Return a made-up healthcare organization name.

    The name is a prefix followed by a clinical suffix. The prefix is a fake
    city (e.g. "Austin Regional Health") when the random draw exceeds 0.5,
    and a fake surname (e.g. "Rivera Medical Center") otherwise.
    """
    use_city = random.random() > 0.5
    if use_city:
        lead = fake.city()
    else:
        lead = fake.last_name()

    tail = random.choice([
        "Medical Center",
        "Regional Health",
        "General Hospital",
        "Health Group",
        "Family Clinic",
        "Community Care",
        "Medical Associates",
        "Health System",
        "Diagnostics Lab",
    ])
    return " ".join((lead, tail))
def get_fake_value(label: str, context: dict) -> str:
    """Generate a replacement value for an entity of the given *label*.

    The label is upper-cased and matched by substring against an ordered list
    of keywords; the first keyword that matches decides the generator. Order
    matters: "EMAIL" is tested before "ADDRESS" so that "EMAIL_ADDRESS" yields
    an e-mail. Generating a person stores the name in
    ``context["last_person"]`` so that a later e-mail can be derived from it.
    Labels that match nothing get a ``REDACTED_<6 hex chars>`` placeholder.
    """
    tag = label.upper()

    def _person() -> str:
        name = generate_clean_name()
        context["last_person"] = name
        return name

    def _address() -> str:
        # A complete "street, city, ST zip" block, generated left to right.
        street = fake.street_address()
        city = fake.city()
        state = fake.state_abbr()
        postcode = fake.zipcode()
        return f"{street}, {city}, {state} {postcode}"

    # (keywords, generator) pairs, checked top to bottom.
    dispatch = (
        (("PERSON",), _person),
        (("EMAIL",), lambda: generate_matching_email(context.get("last_person", ""))),
        (("PHONE",), lambda: generate_clean_phone()),
        (("ADDRESS", "LOCATION"), _address),
        (("MRN",), lambda: f"MRN-{fake.random_number(digits=8, fix_len=True)}"),
        (("SSN",), lambda: fake.ssn()),
        (("DATE",), lambda: str(fake.date_of_birth(minimum_age=18, maximum_age=90))),
        (("POLICY",), lambda: f"POL-{fake.random_number(digits=9, fix_len=True)}"),
        (("GROUP",), lambda: f"G{fake.random_number(digits=5, fix_len=True)}"),
        (("ORGANIZATION",), lambda: generate_medical_org()),
    )
    for keywords, make_value in dispatch:
        if any(keyword in tag for keyword in keywords):
            return make_value()

    return f"REDACTED_{uuid.uuid4().hex[:6]}"
# Strings that may be detected as entities when the input is serialized JSON;
# they are field names, and replacing them would corrupt the document's keys.
_JSON_KEY_NAMES = frozenset(
    {"person_name", "date_of_birth", "ssn", "mrn", "email", "phone", "address"}
)


def _collect_findings(text: str) -> list:
    """Return raw entity spans from every regex pattern and the GLiNER model."""
    import logging

    findings = []
    for label, pattern in REGEX_PATTERNS.items():
        for match in re.finditer(pattern, text):
            findings.append(
                {"start": match.start(), "end": match.end(), "label": label, "score": 1.0}
            )

    try:
        for pred in model.predict_entities(text, AI_LABELS, threshold=0.35):
            findings.append(
                {
                    "start": pred["start"],
                    "end": pred["end"],
                    "label": pred["label"],
                    "score": pred["score"],
                }
            )
    except Exception:
        # Best effort: the regex findings are still applied, but the failure
        # is logged because entities only the model detects go unredacted.
        logging.getLogger(__name__).warning(
            "GLiNER prediction failed; continuing with regex findings only",
            exc_info=True,
        )
    return findings


def _merge_overlapping(findings: list) -> list:
    """Collapse overlapping findings into disjoint spans, ordered by start.

    The label is taken from the finding with the higher score or the longer
    span (same rule as before), but the merged span covers the union of both
    so that no part of either detected entity is left in the output.
    """
    merged = []
    for current in sorted(findings, key=lambda f: f["start"]):
        if not merged or current["start"] >= merged[-1]["end"]:
            merged.append(dict(current))
            continue
        last = merged[-1]
        current_wins = (
            current["score"] > last["score"]
            or (current["end"] - current["start"]) > (last["end"] - last["start"])
        )
        winner = current if current_wins else last
        merged[-1] = {
            "start": last["start"],
            "end": max(last["end"], current["end"]),
            "label": winner["label"],
            "score": winner["score"],
        }
    return merged


def _make_unique(candidate: str, taken) -> str:
    """Return *candidate*, or *candidate* with the first free ``_N`` suffix."""
    if candidate not in taken:
        return candidate
    n = 1
    while f"{candidate}_{n}" in taken:
        n += 1
    return f"{candidate}_{n}"


def analyze_and_replace(text: str) -> "tuple[str, dict]":
    """Anonymize a single block of text.

    Entities are detected with the regex patterns and the GLiNER model,
    overlapping detections are merged, and each remaining span is replaced
    with a generated fake value.

    Returns:
        ``(anonymized_text, mapping)`` where ``mapping`` maps every fake value
        placed in the text to the original substring it replaced. Fake values
        are unique within one call.
    """
    if not text:
        return text, {}

    spans = _merge_overlapping(_collect_findings(text))

    mapping = {}
    context = {"last_person": ""}
    pieces = []
    cursor = 0
    # Spans are disjoint and ascending, so the output can be assembled in a
    # single forward pass: untouched text, fake, untouched text, ...
    for ent in spans:
        original = text[ent["start"]:ent["end"]]
        if original.lower() in _JSON_KEY_NAMES:
            continue
        fake_val = _make_unique(get_fake_value(ent["label"], context), mapping)
        mapping[fake_val] = original
        pieces.append(text[cursor:ent["start"]])
        pieces.append(fake_val)
        cursor = ent["end"]
    pieces.append(text[cursor:])
    return "".join(pieces), mapping
# --- ENDPOINTS ---
async def get_api_key(api_key: str = Security(api_key_header)):
    """Dependency that validates the X-API-Key header.

    Returns the key when it is present and listed in VALID_API_KEYS;
    otherwise raises a 401 HTTPException.
    """
    is_known = bool(api_key) and api_key in VALID_API_KEYS
    if is_known:
        return api_key
    raise HTTPException(401, "Invalid API Key")
# Request body for POST /v1/anonymize.
class AnonymizeRequest(BaseModel):
    # Payload to anonymize: a plain string, a list of items (each item is
    # serialized and anonymized on its own), or a single JSON object.
    text: Union[str, List[Any], Dict[str, Any]]
# Request body for POST /v1/restore.
class RestoreRequest(BaseModel):
    # Session id returned by /v1/anonymize; selects the fake -> original mapping.
    session_id: str
    # Text containing fake values that should be swapped back to the originals.
    text: str
@app.post("/v1/anonymize")
async def anonymize(req: AnonymizeRequest, api_key: str = Security(get_api_key)):
    """Anonymize the request payload and open a session for later restore.

    A list payload is processed item by item (each item serialized to JSON
    and anonymized separately, which keeps every model input small); a dict
    is serialized to JSON and processed as one block; anything else is
    processed as a string.

    Returns a dict with ``anonymized_text`` (a string; JSON-formatted for
    list input), the new ``session_id``, and ``entities_found`` (number of
    entries in the fake -> original mapping).
    """
    input_data = req.text
    global_mapping = {}

    if isinstance(input_data, list):
        anonymized_list = []
        for item in input_data:
            # ensure_ascii=False keeps non-ASCII text as real characters;
            # the default would hand "\uXXXX" escapes to the detectors.
            item_str = json.dumps(item, ensure_ascii=False)
            anon_str, item_map = analyze_and_replace(item_str)
            try:
                anonymized_list.append(json.loads(anon_str))
            except json.JSONDecodeError:
                # A replacement broke the JSON syntax (e.g. a bare number was
                # replaced by a non-numeric fake). Keep the anonymized text
                # as a plain string rather than failing the whole request.
                anonymized_list.append(anon_str)
            # NOTE(review): if two items happen to produce the same fake
            # value, the later entry overwrites the earlier one here.
            global_mapping.update(item_map)
        final_output_str = json.dumps(anonymized_list, indent=2, ensure_ascii=False)
    else:
        if isinstance(input_data, dict):
            text_to_process = json.dumps(input_data, ensure_ascii=False)
        else:
            text_to_process = str(input_data)
        final_output_str, global_mapping = analyze_and_replace(text_to_process)

    session_id = str(uuid.uuid4())
    # The session is tied to the caller's API key; restore checks it.
    SESSIONS[session_id] = {
        "mapping": global_mapping,
        "created": datetime.now(),
        "api_key": api_key,
    }
    return {
        "anonymized_text": final_output_str,
        "session_id": session_id,
        "entities_found": len(global_mapping),
    }
@app.post("/v1/restore")
async def restore(req: RestoreRequest, api_key: str = Security(get_api_key)):
    """Swap fake values in ``req.text`` back to the originals of a session.

    Raises a 404 HTTPException when the session does not exist or was created
    with a different API key. Returns ``{"restored_text": ...}``.
    """
    session = SESSIONS.get(req.session_id)
    if not session or session["api_key"] != api_key:
        raise HTTPException(404, "Session not found")

    mapping = session["mapping"]
    fakes = [fake_v for fake_v in mapping if fake_v]
    if not fakes:
        return {"restored_text": req.text}

    # One pass over the text with the longest fakes tried first: a fake that
    # is a prefix of another (e.g. "X" and "X_12") cannot clobber the longer
    # one, and text that has just been restored is never scanned again.
    fakes.sort(key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(fake_v) for fake_v in fakes))
    restored = pattern.sub(lambda match: mapping[match.group(0)], req.text)
    return {"restored_text": restored}
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, on the port named
    # by the PORT environment variable (8000 when it is not set).
    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)