import torch
from typing import List, Optional, Union
from PIL import Image
from custom_st import Transformer


class InferenceEmbeddings:
    def __init__(self, model_path: str):
        """
        Initialize the embedding model for inference.

        Args:
            model_path: Path to the model directory
        """
        self.model_path = model_path
        self.model = Transformer(
            model_name_or_path=model_path,
            model_args={"default_task": "retrieval", "trust_remote_code": True},
            trust_remote_code=True,
        )
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def encode_text(
        self,
        texts: List[str],
        task: str = "retrieval",
        prompt_name: Optional[str] = None,
        truncate_dim: Optional[int] = None,
        return_multivector: bool = False,
        max_length: Optional[int] = None,
        batch_size: int = 32,
    ) -> torch.Tensor:
        """
        Encode text inputs to embeddings.

        Args:
            texts: List of text inputs to encode
            task: Task for which to generate embeddings (retrieval, text-matching, code)
            prompt_name: Optional prompt type (query, passage)
            truncate_dim: Optional dimension to truncate embeddings to
            return_multivector: Whether to return multi-vector embeddings
                (currently unused; reserved for future support)
            max_length: Maximum token length (currently unused; the
                tokenizer's default is applied)
            batch_size: Batch size for encoding

        Returns:
            Tensor of embeddings
        """
        # Add prompt prefix based on prompt_name
        if prompt_name == "query":
            texts = [f"Query: {text}" for text in texts]
        elif prompt_name == "passage":
            texts = [f"Passage: {text}" for text in texts]

        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            features = self.model.tokenize(batch_texts)

            # Move input tensors to the model's device
            for key, value in features.items():
                if isinstance(value, torch.Tensor):
                    features[key] = value.to(self.device)

            with torch.no_grad():
                outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
                batch_embeddings = outputs.get("sentence_embedding", None)
                if batch_embeddings is not None:
                    embeddings.append(batch_embeddings.cpu())

        if not embeddings:
            raise RuntimeError("Failed to generate embeddings")
        return torch.cat(embeddings, dim=0)

    def encode_image(
        self,
        images: List[Union[str, Image.Image]],
        task: str = "retrieval",
        truncate_dim: Optional[int] = None,
        return_multivector: bool = False,
        max_pixels: Optional[int] = None,
        batch_size: int = 8,
    ) -> torch.Tensor:
        """
        Encode image inputs to embeddings.

        Args:
            images: List of image inputs (file paths, URLs, or PIL Images)
            task: Task for which to generate embeddings
            truncate_dim: Optional dimension to truncate embeddings to
            return_multivector: Whether to return multi-vector embeddings
                (currently unused; reserved for future support)
            max_pixels: Maximum number of pixels for image resizing
                (currently unused; the processor's default is applied)
            batch_size: Batch size for encoding

        Returns:
            Tensor of embeddings
        """
        embeddings = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i:i + batch_size]
            features = self.model.tokenize(batch_images)

            # Move input tensors to the model's device
            for key, value in features.items():
                if isinstance(value, torch.Tensor):
                    features[key] = value.to(self.device)

            with torch.no_grad():
                outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
                batch_embeddings = outputs.get("sentence_embedding", None)
                if batch_embeddings is not None:
                    embeddings.append(batch_embeddings.cpu())

        if not embeddings:
            raise RuntimeError("Failed to generate embeddings")
        return torch.cat(embeddings, dim=0)


def load_model(model_path: str) -> InferenceEmbeddings:
    """
    Load the embedding model for inference.

    Args:
        model_path: Path to the model directory

    Returns:
        Loaded model instance
    """
    return InferenceEmbeddings(model_path)