# File size: 3,323 Bytes
# 22fcf31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import torch
import json
import base64
from io import BytesIO
from PIL import Image
import requests
from typing import List, Dict, Any, Union, Optional
from pydantic import BaseModel, Field
import numpy as np
import os

# Import handler
from handler import ModelHandler

# ASGI application object that the route decorators below attach to.
app = FastAPI(title="Embedding Model API")

# CORS policy: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# combination the FastAPI CORS docs advise against -- confirm whether
# credentialed cross-origin requests are actually required.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Build the model handler once at import time so all requests share it.
# No serving context is supplied (None); context is handled manually here.
model_handler = ModelHandler()
model_handler.initialize(None)

# Define request/response models
class TextInput(BaseModel):
    # One request item carrying text to embed; `text` is required.
    text: str = Field(
        default=...,
        description="The text to generate embeddings for",
    )

class ImageInput(BaseModel):
    # One request item carrying an image reference; `image` is required and
    # is interpreted by the /embeddings handler (URL or data URI).
    image: str = Field(
        default=...,
        description="URL or base64-encoded image to generate embeddings for",
    )

class EmbeddingRequest(BaseModel):
    # Body of POST /embeddings.
    # `inputs` is required and may mix text and image items.
    inputs: List[Union[TextInput, ImageInput]] = Field(
        default=...,
        description="List of text or image inputs",
    )
    # `task` is forwarded to the model's forward() call; defaults to
    # "retrieval" when the client omits it.
    task: str = Field(
        default="retrieval",
        description="Task type: retrieval, text-matching, or code",
    )

class EmbeddingResponse(BaseModel):
    # Body returned by POST /embeddings: a list of float vectors.
    embeddings: List[List[float]] = Field(
        default=...,
        description="List of embeddings",
    )

@app.get("/")
async def root():
    # Simple status endpoint: returns a fixed message and touches neither
    # the model nor any request data.
    status = {"message": "Embedding Model API is running"}
    return status

# Seconds to wait for a remote image download before giving up.
_IMAGE_FETCH_TIMEOUT = 10


def _load_image(image_data: str) -> "Image.Image":
    """Turn one image input string into an RGB PIL image.

    Args:
        image_data: Either an ``http://`` / ``https://`` URL or a
            ``data:image...`` URI whose payload after the first comma is
            base64-encoded.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        HTTPException: With status 400 when the string is neither a URL nor
            a data URI, when the download fails, or when the bytes cannot
            be decoded as an image.
    """
    if image_data.startswith(("http://", "https://")):
        # NOTE(review): fetching a caller-supplied URL from the server is an
        # SSRF vector; consider restricting allowed hosts if this API is
        # reachable by untrusted clients.
        try:
            response = requests.get(image_data, timeout=_IMAGE_FETCH_TIMEOUT)
            # Without this an HTML error page would be handed to PIL.
            response.raise_for_status()
        except requests.RequestException as exc:
            raise HTTPException(
                status_code=400, detail=f"Failed to fetch image: {exc}"
            ) from exc
        raw_bytes = response.content
    elif image_data.startswith("data:image"):
        _, separator, payload = image_data.partition(",")
        if not separator:
            raise HTTPException(status_code=400, detail="Invalid image format")
        try:
            raw_bytes = base64.b64decode(payload)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise HTTPException(
                status_code=400, detail="Invalid base64 image data"
            ) from exc
    else:
        raise HTTPException(status_code=400, detail="Invalid image format")

    try:
        return Image.open(BytesIO(raw_bytes)).convert("RGB")
    except (OSError, ValueError) as exc:
        raise HTTPException(
            status_code=400, detail="Could not decode image data"
        ) from exc


@app.post("/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Compute embeddings for a list of text and/or image inputs.

    Text items are passed to the model as strings; image items are loaded
    from a URL or a base64 data URI and passed as RGB images. The model's
    ``sentence_embedding`` output is returned as nested lists of floats.

    Raises:
        HTTPException: 400 for an empty input list or an unusable image
            input; 500 when the model yields no ``sentence_embedding`` or
            any other unexpected error occurs.
    """
    if not request.inputs:
        raise HTTPException(status_code=400, detail="inputs must not be empty")

    # NOTE(review): the image download and model inference below are
    # synchronous calls inside an async handler, so they occupy the event
    # loop while they run.
    try:
        inputs = []
        for item in request.inputs:
            if isinstance(item, TextInput):
                inputs.append(item.text)
            elif isinstance(item, ImageInput):
                inputs.append(_load_image(item.image))

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            features = model_handler.model.tokenize(inputs)
            outputs = model_handler.model.forward(features, task=request.task)

        embeddings = outputs.get("sentence_embedding")
        if embeddings is None:
            raise HTTPException(status_code=500, detail="Failed to generate embeddings")

        # detach() keeps the NumPy conversion valid even if the tensor is
        # still attached to an autograd graph.
        return {"embeddings": embeddings.detach().cpu().numpy().tolist()}

    except HTTPException:
        # Preserve deliberate status codes (e.g. 400) instead of letting the
        # generic handler below rewrap them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces. The port comes
    # from the PORT environment variable, falling back to 8000.
    listen_port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)