# IRMSEmbeddingsV4 / inference.py
import torch
from typing import List, Optional, Union
from PIL import Image

from custom_st import Transformer

class InferenceEmbeddings:
    def __init__(self, model_path: str):
        """
        Initialize the embedding model for inference.

        Args:
            model_path: Path to the model directory.
        """
        self.model_path = model_path
        # custom_st.Transformer is the sentence-transformers style wrapper
        # shipped alongside this file; it exposes tokenize() and forward().
        self.model = Transformer(
            model_name_or_path=model_path,
            model_args={"default_task": "retrieval", "trust_remote_code": True},
            trust_remote_code=True,
        )
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
    def encode_text(self,
                    texts: List[str],
                    task: str = "retrieval",
                    prompt_name: Optional[str] = None,
                    truncate_dim: Optional[int] = None,
                    return_multivector: bool = False,
                    max_length: Optional[int] = None,
                    batch_size: int = 32) -> torch.Tensor:
        """
        Encode text inputs to embeddings.

        Args:
            texts: List of text inputs to encode.
            task: Task for which to generate embeddings (retrieval, text-matching, code).
            prompt_name: Optional prompt type ("query" or "passage").
            truncate_dim: Optional dimension to truncate embeddings to.
            return_multivector: Whether to return multi-vector embeddings
                (currently unused by this wrapper).
            max_length: Maximum token length (currently unused by this wrapper).
            batch_size: Batch size for encoding.

        Returns:
            Tensor of embeddings.
        """
        if prompt_name:
            # Prefix each text according to its role in retrieval.
            if prompt_name == "query":
                texts = [f"Query: {text}" for text in texts]
            elif prompt_name == "passage":
                texts = [f"Passage: {text}" for text in texts]

        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            features = self.model.tokenize(batch_texts)
            # Move all input tensors to the model's device.
            for key, value in features.items():
                if isinstance(value, torch.Tensor):
                    features[key] = value.to(self.device)
            with torch.no_grad():
                outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
            batch_embeddings = outputs.get("sentence_embedding", None)
            if batch_embeddings is not None:
                embeddings.append(batch_embeddings.cpu())

        if embeddings:
            return torch.cat(embeddings, dim=0)
        raise RuntimeError("Failed to generate embeddings")
    def encode_image(self,
                     images: List[Union[str, Image.Image]],
                     task: str = "retrieval",
                     truncate_dim: Optional[int] = None,
                     return_multivector: bool = False,
                     max_pixels: Optional[int] = None,
                     batch_size: int = 8) -> torch.Tensor:
        """
        Encode image inputs to embeddings.

        Args:
            images: List of image inputs (file paths, URLs, or PIL Images).
            task: Task for which to generate embeddings.
            truncate_dim: Optional dimension to truncate embeddings to.
            return_multivector: Whether to return multi-vector embeddings
                (currently unused by this wrapper).
            max_pixels: Maximum number of pixels for image resizing
                (currently unused by this wrapper).
            batch_size: Batch size for encoding.

        Returns:
            Tensor of embeddings.
        """
        embeddings = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i:i + batch_size]
            # The wrapper's tokenize() handles image preprocessing as well as text.
            features = self.model.tokenize(batch_images)
            # Move all input tensors to the model's device.
            for key, value in features.items():
                if isinstance(value, torch.Tensor):
                    features[key] = value.to(self.device)
            with torch.no_grad():
                outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
            batch_embeddings = outputs.get("sentence_embedding", None)
            if batch_embeddings is not None:
                embeddings.append(batch_embeddings.cpu())

        if embeddings:
            return torch.cat(embeddings, dim=0)
        raise RuntimeError("Failed to generate embeddings")

def load_model(model_path: str) -> InferenceEmbeddings:
    """
    Load the embedding model for inference.

    Args:
        model_path: Path to the model directory.

    Returns:
        Loaded InferenceEmbeddings instance.
    """
    return InferenceEmbeddings(model_path)
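
# Example usage: a minimal sketch of driving this module end to end. The
# checkpoint directory and image path below are placeholders (not files that
# ship with this repo); point them at your own local copies before running.
if __name__ == "__main__":
    model = load_model("./IRMSEmbeddingsV4")  # placeholder checkpoint path

    # Asymmetric retrieval: queries and passages receive different prefixes
    # via prompt_name, matching the prefixing logic in encode_text.
    query_emb = model.encode_text(
        ["what is dense retrieval?"], prompt_name="query"
    )
    passage_emb = model.encode_text(
        ["Dense retrieval encodes text as vectors and ranks by similarity."],
        prompt_name="passage",
    )

    # Score the pair with cosine similarity.
    score = torch.nn.functional.cosine_similarity(query_emb, passage_emb)
    print("query-passage similarity:", score.item())

    # Images flow through the same encoder; file paths or PIL images work.
    image_emb = model.encode_image(["example.jpg"])  # placeholder image path
    print("image embedding shape:", tuple(image_emb.shape))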