|
|
from fastapi import FastAPI, Request, Response, HTTPException
|
|
|
from fastapi.responses import JSONResponse
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
import uvicorn
|
|
|
import torch
|
|
|
import json
|
|
|
import base64
|
|
|
from io import BytesIO
|
|
|
from PIL import Image
|
|
|
import requests
|
|
|
from typing import List, Dict, Any, Union, Optional
|
|
|
from pydantic import BaseModel, Field
|
|
|
import numpy as np
|
|
|
import os
|
|
|
|
|
|
|
|
|
from handler import ModelHandler
|
|
|
|
|
|
# ASGI application object; the route decorators further down attach to it.
app = FastAPI(title="Embedding Model API")

# Cross-origin access is left fully open: any origin, method and header.
# NOTE(review): FastAPI's CORS documentation says wildcard values should not
# be combined with allow_credentials=True -- confirm whether credentialed
# requests are really needed, and if so list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# A single handler instance is created and initialised at import time and is
# then shared by the request handlers in this module. `None` is passed as the
# only initialisation argument.
# NOTE(review): presumably this is where the model weights get loaded --
# confirm against handler.ModelHandler.
model_handler = ModelHandler()
model_handler.initialize(None)
|
|
|
|
|
|
|
|
|
class TextInput(BaseModel):
    # Request item carrying a piece of text to embed. The field is required.
    text: str = Field(
        ...,
        description="The text to generate embeddings for",
    )
|
|
|
|
|
|
class ImageInput(BaseModel):
    # Request item carrying an image reference to embed. The value is a
    # string holding either a URL or base64-encoded data; it is required.
    image: str = Field(
        ...,
        description="URL or base64-encoded image to generate embeddings for",
    )
|
|
|
|
|
|
class EmbeddingRequest(BaseModel):
    # Body of POST /embeddings.

    # Required list of items; each one is either a text or an image input.
    inputs: List[Union[TextInput, ImageInput]] = Field(
        ...,
        description="List of text or image inputs",
    )

    # Optional task name; falls back to "retrieval" when the client omits it.
    task: str = Field(
        "retrieval",
        description="Task type: retrieval, text-matching, or code",
    )
|
|
|
|
|
|
class EmbeddingResponse(BaseModel):
    # Response body of POST /embeddings: a list of float vectors.
    embeddings: List[List[float]] = Field(
        ...,
        description="List of embeddings",
    )
|
|
|
|
|
|
@app.get("/")
async def root():
    # Status endpoint: always answers with the same fixed message.
    payload = {"message": "Embedding Model API is running"}
    return payload
|
|
|
|
|
|
# Upper bound, in seconds, on how long a remote image download may take.
_IMAGE_FETCH_TIMEOUT_S = 10


def _load_image(image_data: str) -> Image.Image:
    """Turn an image reference from the request into an RGB PIL image.

    Args:
        image_data: Either an ``http://`` / ``https://`` URL or a
            ``data:image...`` URI whose payload after the first comma is
            base64-encoded.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: With status 400 when the reference has an unsupported
            form, cannot be downloaded, is not valid base64, or does not
            decode as an image.
    """
    if image_data.startswith(("http://", "https://")):
        # NOTE(review): the URL is client-supplied, so the server can be made
        # to fetch arbitrary addresses (SSRF) -- consider an allow-list.
        try:
            response = requests.get(image_data, timeout=_IMAGE_FETCH_TIMEOUT_S)
            response.raise_for_status()
        except requests.RequestException as exc:
            raise HTTPException(
                status_code=400, detail=f"Could not fetch image: {exc}"
            ) from exc
        raw = response.content
    elif image_data.startswith("data:image"):
        _, separator, payload = image_data.partition(",")
        if not separator:
            # No comma means there is no payload section to decode.
            raise HTTPException(status_code=400, detail="Invalid image format")
        try:
            raw = base64.b64decode(payload)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise HTTPException(
                status_code=400, detail="Invalid base64 image data"
            ) from exc
    else:
        raise HTTPException(status_code=400, detail="Invalid image format")

    try:
        return Image.open(BytesIO(raw)).convert("RGB")
    except (OSError, ValueError) as exc:  # covers PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail="Could not decode image"
        ) from exc


@app.post("/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Generate embeddings for a list of text and/or image inputs.

    Text inputs are passed to the model as strings. Image inputs are given as
    an http(s) URL or a ``data:image`` base64 URI and are decoded to RGB
    before being passed to the model.

    Responds with status 400 when an image reference is invalid or cannot be
    loaded, and with status 500 when the model fails or returns no
    ``sentence_embedding``.
    """
    try:
        inputs = []
        for item in request.inputs:
            if isinstance(item, TextInput):
                inputs.append(item.text)
            elif isinstance(item, ImageInput):
                inputs.append(_load_image(item.image))

        # NOTE(review): tokenize/forward and the image download are blocking
        # calls inside an async handler -- consider offloading to a thread.
        # Gradients are not needed to serve embeddings, so skip tracking them.
        with torch.no_grad():
            features = model_handler.model.tokenize(inputs)
            outputs = model_handler.model.forward(features, task=request.task)

        embeddings = outputs.get("sentence_embedding", None)
        if embeddings is None:
            raise HTTPException(
                status_code=500, detail="Failed to generate embeddings"
            )

        # detach() keeps .numpy() valid even if the tensor still tracks grad.
        embeddings_list = embeddings.detach().cpu().numpy().tolist()
        return {"embeddings": embeddings_list}
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400s above) must keep their status
        # instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app on every interface. The port is read
    # from the PORT environment variable and falls back to 8000.
    listen_port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
|
|
|