# Krishna Indukuri
# Upload 31 files
# 22fcf31 verified
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import torch
import json
import base64
from io import BytesIO
from PIL import Image
import requests
from typing import List, Dict, Any, Union, Optional
from pydantic import BaseModel, Field
import numpy as np
import os
# Import handler
from handler import ModelHandler
app = FastAPI(title="Embedding Model API")

# Add CORS middleware.
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable; when it is unset or empty they default to "*"
# (any origin).
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive for a credentialed API -- set CORS_ALLOW_ORIGINS to an explicit
# list of trusted origins in production.
_cors_allow_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the model handler once at import time; request handlers reuse
# this module-level instance.
model_handler = ModelHandler()
model_handler.initialize(None)  # We'll handle context manually (no serving-framework context here)
# Define request/response models

# One text item to embed.
class TextInput(BaseModel):
    text: str = Field(
        default=...,
        description="The text to generate embeddings for",
    )


# One image item to embed; `image` holds an image reference string.
class ImageInput(BaseModel):
    image: str = Field(
        default=...,
        description="URL or base64-encoded image to generate embeddings for",
    )


# Body of POST /embeddings: a batch of text/image items plus a task name.
class EmbeddingRequest(BaseModel):
    # Each element must validate as either a TextInput or an ImageInput.
    inputs: List[Union[TextInput, ImageInput]] = Field(
        default=...,
        description="List of text or image inputs",
    )
    # Free-form string; the listed task names are not enforced here.
    task: str = Field(
        default="retrieval",
        description="Task type: retrieval, text-matching, or code",
    )


# Response of POST /embeddings: the embedding vectors as nested float lists.
class EmbeddingResponse(BaseModel):
    embeddings: List[List[float]] = Field(
        default=...,
        description="List of embeddings",
    )
@app.get("/")
async def root():
    # Liveness probe: a fixed JSON payload that confirms the server answers.
    status_payload = {"message": "Embedding Model API is running"}
    return status_payload
# Seconds to wait when fetching an image URL; without a timeout a slow or
# unresponsive host could hold the request open indefinitely.
_IMAGE_FETCH_TIMEOUT = 10


def _load_image(image_data: str) -> Image.Image:
    """Turn an image reference string into an RGB PIL image.

    Accepted forms: an ``http(s)`` URL, a ``data:image/...;base64,...`` data
    URI, or a bare base64 string.

    Raises:
        HTTPException: 400 if the image cannot be fetched, decoded or opened.
    """
    try:
        if image_data.startswith("http"):
            # NOTE(review): fetching arbitrary client-supplied URLs from the
            # server is an SSRF vector; consider a host allow-list and a
            # response-size cap.
            resp = requests.get(image_data, timeout=_IMAGE_FETCH_TIMEOUT)
            resp.raise_for_status()
            raw = resp.content
        elif image_data.startswith("data:image"):
            # Payload follows the first comma; a URI without one yields an
            # empty payload, which fails to open below (-> 400).
            raw = base64.b64decode(image_data.partition(",")[2])
        else:
            # Bare base64, as the ImageInput field description promises.
            # validate=True rejects strings containing non-base64 characters.
            raw = base64.b64decode(image_data, validate=True)
        return Image.open(BytesIO(raw)).convert("RGB")
    except requests.RequestException as e:
        # Checked before OSError: RequestException derives from it.
        raise HTTPException(status_code=400, detail=f"Could not fetch image: {e}") from e
    except (OSError, ValueError) as e:
        # PIL's UnidentifiedImageError is an OSError; binascii.Error (bad
        # base64) is a ValueError.
        raise HTTPException(status_code=400, detail=f"Could not decode image: {e}") from e


def _compute_embeddings(request: EmbeddingRequest) -> List[List[float]]:
    """Embed every input in *request*. Blocking -- run it off the event loop.

    Returns:
        One list of floats per row of the model's ``sentence_embedding``
        output, or an empty list when the request has no inputs.

    Raises:
        HTTPException: 400 for unusable images; 500 when the model output has
            no ``sentence_embedding`` entry.
    """
    inputs: List[Any] = [
        item.text if isinstance(item, TextInput) else _load_image(item.image)
        for item in request.inputs
    ]
    if not inputs:
        return []

    # Inference only: no_grad skips building an autograd graph, saving memory.
    with torch.no_grad():
        features = model_handler.model.tokenize(inputs)
        # NOTE(review): features are passed exactly as tokenize() returns
        # them; if the model lives on a GPU they may need moving to its
        # device -- confirm against ModelHandler.
        outputs = model_handler.model.forward(features, task=request.task)

    embeddings = outputs.get("sentence_embedding")
    if embeddings is None:
        raise HTTPException(status_code=500, detail="Failed to generate embeddings")
    # detach() + float() make the conversion safe for grad-tracking and
    # half/bfloat16 tensors (numpy has no bfloat16); tolist() yields plain
    # Python floats for JSON serialization.
    return embeddings.detach().float().cpu().tolist()


@app.post("/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Generate embeddings for a batch of text and/or image inputs."""
    # fastapi is already a dependency of this module; imported locally to
    # keep this endpoint's requirements next to its use.
    from fastapi.concurrency import run_in_threadpool

    try:
        # Image downloads and model inference block; running them in a worker
        # thread keeps the event loop free to serve other requests.
        # NOTE(review): concurrent requests can now run inference in parallel
        # threads; serialize with a lock if the model is not safe for that.
        embeddings_list = await run_in_threadpool(_compute_embeddings, request)
    except HTTPException:
        # Keep deliberate status codes (e.g. 400 for bad images) instead of
        # collapsing them into the generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"embeddings": embeddings_list}
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces; the PORT
    # environment variable overrides the default port of 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))