import os
import torch
import json
from typing import Dict, List, Union, Optional, Any
from PIL import Image
from transformers import AutoConfig, AutoTokenizer

from custom_st import Transformer


class InferenceEmbeddings:
    def __init__(self, model_path: str):
        """
        Initialize the embedding model for inference

        Args:
            model_path: Path to the model directory
        """
        self.model_path = model_path
        self.model = Transformer(
            model_name_or_path=model_path,
            model_args={"default_task": "retrieval", "trust_remote_code": True},
            trust_remote_code=True,
        )
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def encode_text(self,
                    texts: List[str],
                    task: str = "retrieval",
                    prompt_name: Optional[str] = None,
                    truncate_dim: Optional[int] = None,
                    return_multivector: bool = False,
                    max_length: Optional[int] = None,
                    batch_size: int = 32) -> torch.Tensor:
        """
        Encode text inputs to embeddings

        Args:
            texts: List of text inputs to encode
            task: Task for which to generate embeddings (retrieval, text-matching, code)
            prompt_name: Optional prompt type (query, passage)
            truncate_dim: Optional dimension to truncate embeddings to
            return_multivector: Whether to return multi-vector embeddings
            max_length: Maximum token length
            batch_size: Batch size for encoding

        Returns:
            Tensor of embeddings
        """
        if prompt_name:
            # Add prompt prefix based on prompt_name
            if prompt_name == "query":
                texts = [f"Query: {text}" for text in texts]
            elif prompt_name == "passage":
                texts = [f"Passage: {text}" for text in texts]

        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            features = self.model.tokenize(batch_texts)
            # Move tensors to device
            for key, value in features.items():
                if isinstance(value, torch.Tensor):
                    features[key] = value.to(self.device)
            with torch.no_grad():
                outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
            batch_embeddings = outputs.get("sentence_embedding", None)
            if batch_embeddings is not None:
                embeddings.append(batch_embeddings.cpu())

        if embeddings:
            return torch.cat(embeddings, dim=0)
        else:
            raise RuntimeError("Failed to generate embeddings")

    def encode_image(self,
                     images: List[Union[str, Image.Image]],
                     task: str = "retrieval",
                     truncate_dim: Optional[int] = None,
                     return_multivector: bool = False,
                     max_pixels: Optional[int] = None,
                     batch_size: int = 8) -> torch.Tensor:
        """
        Encode image inputs to embeddings

        Args:
            images: List of image inputs (file paths, URLs, or PIL Images)
            task: Task for which to generate embeddings
            truncate_dim: Optional dimension to truncate embeddings to
            return_multivector: Whether to return multi-vector embeddings
            max_pixels: Maximum number of pixels for image resizing
            batch_size: Batch size for encoding

        Returns:
            Tensor of embeddings
        """
        embeddings = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i:i + batch_size]
            features = self.model.tokenize(batch_images)
            # Move tensors to device
            for key, value in features.items():
                if isinstance(value, torch.Tensor):
                    features[key] = value.to(self.device)
            with torch.no_grad():
                outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
            batch_embeddings = outputs.get("sentence_embedding", None)
            if batch_embeddings is not None:
                embeddings.append(batch_embeddings.cpu())

        if embeddings:
            return torch.cat(embeddings, dim=0)
        else:
            raise RuntimeError("Failed to generate embeddings")


def load_model(model_path: str):
    """
    Load the embedding model for inference

    Args:
        model_path: Path to the model directory

    Returns:
        Loaded model instance
    """
    return InferenceEmbeddings(model_path)
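

# Minimal usage sketch of the API above: load the model, embed a query and a
# passage for retrieval, and score the pair with cosine similarity. The model
# directory, example strings, and image path are placeholders (assumptions,
# not files provided with this module); point them at a real checkpoint that
# custom_st.Transformer can load before running.
if __name__ == "__main__":
    model = load_model("./model_dir")  # placeholder checkpoint directory

    # Query/passage prefixes are added by encode_text via prompt_name.
    query_emb = model.encode_text(["what is vector search?"], prompt_name="query")
    passage_emb = model.encode_text(
        ["Vector search retrieves items by embedding similarity."],
        prompt_name="passage",
    )
    score = torch.nn.functional.cosine_similarity(query_emb, passage_emb)
    print("query/passage similarity:", score.item())

    # Images go through encode_image; "example.jpg" is a placeholder file.
    # image_emb = model.encode_image(["example.jpg"])
    # print("image embedding shape:", tuple(image_emb.shape))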