# File size: 5,194 Bytes
# 22fcf31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import torch
import json
from typing import Dict, List, Union, Optional, Any
from PIL import Image
from transformers import AutoConfig, AutoTokenizer
from custom_st import Transformer

class InferenceEmbeddings:
    """Inference wrapper that turns texts or images into embedding tensors.

    The wrapper owns a ``Transformer`` module (imported from ``custom_st``),
    puts it in eval mode, moves it to CUDA when available (CPU otherwise) and
    exposes batched ``encode_text`` / ``encode_image`` helpers. Both helpers
    return embeddings on the CPU.
    """

    # Prefix prepended to every text for a given ``prompt_name``.
    # Any other prompt name leaves the texts unchanged.
    _PROMPT_PREFIXES = {"query": "Query: ", "passage": "Passage: "}

    def __init__(self, model_path: str):
        """Load the embedding model and prepare it for inference.

        Args:
            model_path: Path to the model directory.
        """
        self.model_path = model_path
        # NOTE(security): trust_remote_code=True is requested here, which
        # typically lets the loader execute Python code shipped with the
        # model. Only point ``model_path`` at sources you trust.
        self.model = Transformer(
            model_name_or_path=model_path,
            model_args={"default_task": "retrieval", "trust_remote_code": True},
            trust_remote_code=True,
        )
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def _encode_batches(self,
                        inputs: List[Any],
                        task: str,
                        truncate_dim: Optional[int],
                        batch_size: int) -> torch.Tensor:
        """Run ``inputs`` through the model in batches and stack the results.

        Shared implementation behind ``encode_text`` and ``encode_image``.

        Args:
            inputs: Items handed, one slice at a time, to ``self.model.tokenize``.
            task: Forwarded to ``self.model.forward`` as ``task``.
            truncate_dim: Forwarded to ``self.model.forward`` as ``truncate_dim``.
            batch_size: Number of items per forward pass; must be >= 1.

        Returns:
            The per-batch ``"sentence_embedding"`` outputs, moved to the CPU
            and concatenated along dimension 0.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            RuntimeError: If ``inputs`` is empty, or if the model output for
                any batch has no ``"sentence_embedding"`` entry.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if len(inputs) == 0:
            # Kept as RuntimeError: existing callers see this for empty input.
            raise RuntimeError("Failed to generate embeddings")

        chunks = []
        for start in range(0, len(inputs), batch_size):
            features = self.model.tokenize(inputs[start:start + batch_size])

            # Move tensors to the model's device. Values are replaced in place
            # so the container type returned by tokenize() is preserved;
            # non-tensor entries are left untouched.
            for key, value in list(features.items()):
                if isinstance(value, torch.Tensor):
                    features[key] = value.to(self.device)

            with torch.no_grad():
                outputs = self.model.forward(
                    features, task=task, truncate_dim=truncate_dim
                )

            batch_embeddings = outputs.get("sentence_embedding", None)
            if batch_embeddings is None:
                # Skipping the batch would return fewer rows than inputs and
                # silently misalign embeddings with their inputs.
                raise RuntimeError(
                    "Failed to generate embeddings: model output for inputs "
                    f"[{start}:{start + batch_size}] has no 'sentence_embedding'"
                )
            chunks.append(batch_embeddings.cpu())

        return torch.cat(chunks, dim=0)

    def encode_text(self,
                    texts: List[str],
                    task: str = "retrieval",
                    prompt_name: Optional[str] = None,
                    truncate_dim: Optional[int] = None,
                    return_multivector: bool = False,
                    max_length: Optional[int] = None,
                    batch_size: int = 32) -> torch.Tensor:
        """Encode text inputs to embeddings.

        Args:
            texts: List of text inputs to encode.
            task: Task for which to generate embeddings (retrieval,
                text-matching, code); forwarded to the model.
            prompt_name: Optional prompt type. ``"query"`` prepends
                ``"Query: "`` and ``"passage"`` prepends ``"Passage: "`` to
                every text; any other value leaves the texts unchanged.
            truncate_dim: Optional dimension to truncate embeddings to;
                forwarded to the model.
            return_multivector: Accepted for interface compatibility; this
                wrapper does not currently forward it to the model.
            max_length: Accepted for interface compatibility; this wrapper
                does not currently forward it to the model.
            batch_size: Batch size for encoding; must be >= 1.

        Returns:
            CPU tensor with the batch embeddings concatenated along
            dimension 0, in input order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            RuntimeError: If ``texts`` is empty or the model produces no
                ``"sentence_embedding"`` for some batch.
        """
        prefix = self._PROMPT_PREFIXES.get(prompt_name) if prompt_name else None
        if prefix is not None:
            texts = [f"{prefix}{text}" for text in texts]
        return self._encode_batches(texts, task, truncate_dim, batch_size)

    def encode_image(self,
                     images: List[Union[str, "Image.Image"]],
                     task: str = "retrieval",
                     truncate_dim: Optional[int] = None,
                     return_multivector: bool = False,
                     max_pixels: Optional[int] = None,
                     batch_size: int = 8) -> torch.Tensor:
        """Encode image inputs to embeddings.

        Args:
            images: List of image inputs. Items are passed unchanged to
                ``self.model.tokenize``; the accepted forms (file paths,
                URLs, PIL images) are decided there — TODO confirm against
                ``custom_st.Transformer``.
            task: Task for which to generate embeddings; forwarded to the
                model.
            truncate_dim: Optional dimension to truncate embeddings to;
                forwarded to the model.
            return_multivector: Accepted for interface compatibility; this
                wrapper does not currently forward it to the model.
            max_pixels: Accepted for interface compatibility; this wrapper
                does not currently forward it to the model.
            batch_size: Batch size for encoding; must be >= 1.

        Returns:
            CPU tensor with the batch embeddings concatenated along
            dimension 0, in input order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
            RuntimeError: If ``images`` is empty or the model produces no
                ``"sentence_embedding"`` for some batch.
        """
        return self._encode_batches(images, task, truncate_dim, batch_size)

def load_model(model_path: str):
    """Build the embedding wrapper used for inference.

    Args:
        model_path: Path to the model directory.

    Returns:
        An ``InferenceEmbeddings`` instance constructed from ``model_path``.
    """
    embedder = InferenceEmbeddings(model_path)
    return embedder