"""Inference utilities for text and image embeddings.

Thin wrapper around the custom sentence-transformers module in custom_st.py,
exposing batched encode_text and encode_image helpers.
"""

from typing import List, Optional, Union

import torch
from PIL import Image

from custom_st import Transformer


class InferenceEmbeddings:
    def __init__(self, model_path: str):
        """
        Initialize the embedding model for inference.

        Args:
            model_path: Path to the model directory
        """
        self.model_path = model_path
        # Load the model through the custom sentence-transformers wrapper
        # defined in custom_st.py; "retrieval" is the default task.
        self.model = Transformer(
            model_name_or_path=model_path,
            model_args={"default_task": "retrieval", "trust_remote_code": True},
            trust_remote_code=True,
        )
        self.model.eval()
        # Run on GPU when available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
|
|
    def encode_text(
        self,
        texts: List[str],
        task: str = "retrieval",
        prompt_name: Optional[str] = None,
        truncate_dim: Optional[int] = None,
        return_multivector: bool = False,
        max_length: Optional[int] = None,
        batch_size: int = 32,
    ) -> torch.Tensor:
        """
        Encode text inputs to embeddings.

        Args:
            texts: List of text inputs to encode
            task: Task for which to generate embeddings (retrieval, text-matching, code)
            prompt_name: Optional prompt type (query, passage)
            truncate_dim: Optional dimension to truncate embeddings to
            return_multivector: Whether to return multi-vector embeddings
            max_length: Maximum token length
            batch_size: Batch size for encoding

        Returns:
            Tensor of embeddings
        """
        # NOTE: return_multivector and max_length are accepted for API
        # compatibility but are not forwarded to the model in this version.

        # Prefix texts so the model can distinguish queries from passages.
        if prompt_name == "query":
            texts = [f"Query: {text}" for text in texts]
        elif prompt_name == "passage":
            texts = [f"Passage: {text}" for text in texts]

        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            features = self.model.tokenize(batch_texts)

            # Move all tokenized tensors to the model's device.
            for key, value in features.items():
                if isinstance(value, torch.Tensor):
                    features[key] = value.to(self.device)

            with torch.no_grad():
                outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
            batch_embeddings = outputs.get("sentence_embedding")

            if batch_embeddings is not None:
                embeddings.append(batch_embeddings.cpu())

        if embeddings:
            return torch.cat(embeddings, dim=0)
        raise RuntimeError("Failed to generate embeddings")
    def encode_image(
        self,
        images: List[Union[str, Image.Image]],
        task: str = "retrieval",
        truncate_dim: Optional[int] = None,
        return_multivector: bool = False,
        max_pixels: Optional[int] = None,
        batch_size: int = 8,
    ) -> torch.Tensor:
        """
        Encode image inputs to embeddings.

        Args:
            images: List of image inputs (file paths, URLs, or PIL Images)
            task: Task for which to generate embeddings
            truncate_dim: Optional dimension to truncate embeddings to
            return_multivector: Whether to return multi-vector embeddings
            max_pixels: Maximum number of pixels for image resizing
            batch_size: Batch size for encoding

        Returns:
            Tensor of embeddings
        """
        # NOTE: return_multivector and max_pixels are accepted for API
        # compatibility but are not forwarded to the model in this version.
        embeddings = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i:i + batch_size]
            # The custom_st tokenizer accepts image inputs (paths, URLs,
            # or PIL Images) as well as text.
            features = self.model.tokenize(batch_images)

            # Move all tokenized tensors to the model's device.
            for key, value in features.items():
                if isinstance(value, torch.Tensor):
                    features[key] = value.to(self.device)

            with torch.no_grad():
                outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
            batch_embeddings = outputs.get("sentence_embedding")

            if batch_embeddings is not None:
                embeddings.append(batch_embeddings.cpu())

        if embeddings:
            return torch.cat(embeddings, dim=0)
        raise RuntimeError("Failed to generate embeddings")


def load_model(model_path: str) -> InferenceEmbeddings:
    """
    Load the embedding model for inference.

    Args:
        model_path: Path to the model directory

    Returns:
        Loaded InferenceEmbeddings instance
    """
    return InferenceEmbeddings(model_path)
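

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. Assumes a local
    # checkpoint directory at "./model" (placeholder path) compatible with
    # the custom_st Transformer wrapper; adjust for your environment.
    model = load_model("./model")

    # Encode a query and a passage, then score them by cosine similarity.
    query_emb = model.encode_text(["what is a vector database?"], prompt_name="query")
    passage_emb = model.encode_text(
        ["A vector database stores embeddings for similarity search."],
        prompt_name="passage",
    )
    score = torch.nn.functional.cosine_similarity(query_emb, passage_emb)
    print(f"query/passage similarity: {score.item():.4f}")

    # Images (file paths, URLs, or PIL Images) go through encode_image, e.g.:
    # image_embs = model.encode_image(["/path/to/image.png"])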