qqc1989 committed
Commit 6bced81 · verified · 1 Parent(s): c7b9cf2

Upload 18 files

.gitattributes CHANGED
@@ -42,3 +42,5 @@ main_api_ax650 filter=lfs diff=lfs merge=lfs -text
  main_api_axcl_x86 filter=lfs diff=lfs merge=lfs -text
  main_ax650 filter=lfs diff=lfs merge=lfs -text
  main_axcl_x86 filter=lfs diff=lfs merge=lfs -text
+ main_api_axcl_aarch64 filter=lfs diff=lfs merge=lfs -text
+ main_axcl_aarch64 filter=lfs diff=lfs merge=lfs -text
gradio_demo_c_api.py ADDED
@@ -0,0 +1,127 @@
+ import mimetypes
+ import os
+ import time
+
+ import gradio as gr
+ import requests
+
+ base_url = "http://10.126.33.235:8000"
+
+ def upload_image(file_path):
+     if file_path is None:
+         return None
+     filename = os.path.basename(file_path)
+     # Guess the MIME type from the filename
+     mime_type, _ = mimetypes.guess_type(filename)
+     mime_type = mime_type or 'application/octet-stream'
+     # Read the file in binary mode for upload
+     with open(file_path, 'rb') as f:
+         file_bytes = f.read()
+     # Prepare multipart form data
+     files = {
+         'image': (filename, file_bytes, mime_type)
+     }
+     # Send to the upload endpoint
+     resp = requests.post(f'{base_url}/api/upload', files=files)
+     resp.raise_for_status()
+     data = resp.json()
+     return data.get('file_path')
+
+ def stop_generation():
+     try:
+         requests.get(f'{base_url}/api/stop')
+     except requests.RequestException:
+         pass
+
+ def respond(prompt, image, temp, rep_penalty, tp, tk, history=None):
+     # `image` is a file path string from gr.Image(type="filepath")
+     if history is None:
+         history = []
+     if not prompt.strip():
+         yield history  # nothing to do; re-emit the unchanged history
+         return
+     if image is None:
+         file_path = None
+     else:
+         file_path = upload_image(image)
+         history.append((f'![]({file_path})', None))
+
+     # Append an empty response so the prompt shows up immediately
+     history.append((prompt, ""))
+     yield history
+
+     payload = {
+         "prompt": prompt,
+         "temperature": temp,
+         "repetition_penalty": rep_penalty,
+         "top-p": tp,
+         "top-k": tk
+     }
+     if file_path:
+         payload["file_path"] = file_path
+
+     response = requests.post(f'{base_url}/api/generate', json=payload)
+     response.raise_for_status()
+
+     # Poll the server and stream chunks into the chat history
+     while True:
+         time.sleep(0.01)
+         response = requests.get(f'{base_url}/api/generate_provider')
+         data = response.json()
+         chunk: str = data.get("response", "")
+         done = data.get("done", False)
+         if done:
+             break
+         if chunk.strip() == "":
+             continue
+         history[-1] = (prompt, history[-1][1] + chunk)
+         yield history
+
+     print("end")
+
+ def chat_interface():
+     with gr.Blocks(theme=gr.themes.Soft(font="Consolas"), fill_width=True) as demo:
+         gr.Markdown("## Chat with LLM\nUpload an image and chat with the model!")
+         with gr.Row():
+             image = gr.Image(label="Upload Image", type="filepath")
+             with gr.Column(scale=3):
+                 chatbot = gr.Chatbot(height=600)
+                 prompt = gr.Textbox(placeholder="Type your message...", label="Prompt", value="描述一下这张图片")  # "Describe this image"
+                 with gr.Row():
+                     btn_chat = gr.Button("Chat", variant="primary")
+                     btn_stop = gr.Button("Stop", variant="stop")
+
+             with gr.Column(scale=1):
+                 temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.7, label="Temperature")
+                 repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.01, value=1.0, label="Repetition Penalty")
+                 top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, label="Top-p Sampling")
+                 top_k = gr.Slider(minimum=0, maximum=100, step=1, value=40, label="Top-k Sampling")
+
+         btn_stop.click(fn=stop_generation, inputs=None, outputs=None)
+         btn_chat.click(
+             fn=respond,
+             inputs=[prompt, image, temperature, repetition_penalty, top_p, top_k, chatbot],
+             outputs=chatbot
+         )
+
+     demo.launch(server_name="0.0.0.0", server_port=7860)
+
+ if __name__ == "__main__":
+     chat_interface()
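
For reference, the HTTP endpoints used above (/api/upload, /api/generate, /api/generate_provider, /api/stop) can be exercised without Gradio. A minimal sketch, assuming one of the main_api_* servers from this commit is listening at base_url; the payload keys and the polling response schema simply mirror what gradio_demo_c_api.py sends, they are not otherwise documented here:

import time
import requests

base_url = "http://127.0.0.1:8000"  # assumption: adjust to your server address

# Kick off generation with the same payload keys the demo sends.
payload = {
    "prompt": "描述一下这张图片",  # "Describe this image"
    "temperature": 0.7,
    "repetition_penalty": 1.0,
    "top-p": 0.9,
    "top-k": 40,
}
requests.post(f"{base_url}/api/generate", json=payload).raise_for_status()

# Poll for streamed chunks until the server reports done, as respond() does.
answer = ""
while True:
    data = requests.get(f"{base_url}/api/generate_provider").json()
    if data.get("done", False):
        break
    answer += data.get("response", "")
    time.sleep(0.01)
print(answer)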
gradio_demo_python_api.py ADDED
@@ -0,0 +1,129 @@
+ import argparse
+ import socket
+
+ import cv2
+ import gradio as gr
+
+ import llm
+ from llm import LLM
+
+ parser = argparse.ArgumentParser(description="Model configuration parameters")
+ parser.add_argument("--hf_model", type=str, default="./InternVL3-2B",
+                     help="Path to the HuggingFace model")
+ parser.add_argument("--axmodel_path", type=str, default="./InternVL3-2B_axmodel",
+                     help="Path to the compiled axmodel of the LLM")
+ parser.add_argument("--vit_model", type=str, default=None,
+                     help="Path to the compiled axmodel of the ViT image encoder")
+ args = parser.parse_args()
+
+ hf_model_path = args.hf_model
+ axmodel_path = args.axmodel_path
+ vit_axmodel_path = args.vit_model
+
+ gllm = LLM(hf_model_path, axmodel_path, vit_axmodel_path)
+
+ # Get the local IP address
+ def get_local_ip():
+     try:
+         s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+         s.connect(("8.8.8.8", 80))
+         ip = s.getsockname()[0]
+         s.close()
+         return ip
+     except Exception:
+         return "127.0.0.1"
+
+ def stop_generation():
+     gllm.stop_generate()
+
+ def respond(prompt, video, image, is_image, video_segments, image_segments_cols, image_segments_rows, history=None):
+     if history is None:
+         history = []
+     if not prompt.strip():
+         yield history  # nothing to do; re-emit the unchanged history
+         return
+
+     gllm.tag = "video" if not is_image else "image"
+
+     # Append an empty response so the prompt shows up immediately
+     history.append((prompt, ""))
+     yield history
+
+     print(video)
+     print(image)
+
+     if is_image:
+         img = cv2.imread(image)
+         images_list = []
+         if image_segments_cols == 1 and image_segments_rows == 1:
+             images_list.append(img)
+         elif image_segments_cols * image_segments_rows > 8:
+             history[-1] = (prompt, history[-1][1] + "image segments cols * image segments rows > 8")
+             yield history
+             return
+         else:
+             # Tile the image into a cols x rows grid of crops
+             height, width, _ = img.shape
+             segment_width = width // image_segments_cols
+             segment_height = height // image_segments_rows
+             for i in range(image_segments_rows):
+                 for j in range(image_segments_cols):
+                     x1 = j * segment_width
+                     y1 = i * segment_height
+                     x2 = (j + 1) * segment_width
+                     y2 = (i + 1) * segment_height
+                     segment = img[y1:y2, x1:x2]
+                     images_list.append(segment)
+     else:
+         images_list = llm.load_video_opencv(video, num_segments=video_segments)
+
+     for msg in gllm.generate(images_list, prompt):
+         print(msg, end="", flush=True)
+         history[-1] = (prompt, history[-1][1] + msg)
+         yield history
+     print("\n\n\n")
+
+ def chat_interface():
+     with gr.Blocks() as demo:
+         gr.Markdown("## Raspberry Pi 5 InternVL3 Chat DEMO using AXCL\nUpload an image or video and chat with the InternVL3.")
+         with gr.Row():
+             with gr.Column(scale=1):
+                 video = gr.Video(label="Upload Video", format="mp4")
+                 video_segments = gr.Slider(minimum=2, maximum=8, step=1, value=4, label="video segments")
+                 image = gr.Image(label="Upload Image", type="filepath")
+                 image_segments_cols = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="image cols segments")
+                 image_segments_rows = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="image rows segments")
+                 checkbox = gr.Checkbox(label="Use Image")
+             with gr.Column(scale=3):
+                 chatbot = gr.Chatbot(height=650)
+                 prompt = gr.Textbox(placeholder="Type your message...", label="Prompt", value="描述一下这组图片")  # "Describe this set of images"
+                 with gr.Row():
+                     btn_chat = gr.Button("Chat", variant="primary")
+                     btn_stop = gr.Button("Stop", variant="stop")
+
+         btn_stop.click(fn=stop_generation, inputs=None, outputs=None)
+
+         btn_chat.click(
+             fn=respond,
+             inputs=[prompt, video, image, checkbox, video_segments, image_segments_cols, image_segments_rows, chatbot],
+             outputs=chatbot
+         )
+
+         # Keep the "Use Image" checkbox in sync with the most recent upload
+         def on_video_uploaded(video):
+             if video is not None:
+                 return gr.update(value=False)
+             return gr.update()
+
+         def on_image_uploaded(image):
+             if image is not None:
+                 return gr.update(value=True)
+             return gr.update()
+
+         video.change(fn=on_video_uploaded, inputs=video, outputs=checkbox)
+         image.change(fn=on_image_uploaded, inputs=image, outputs=checkbox)
+
+     local_ip = get_local_ip()
+     server_port = 7860
+     print(f"HTTP server address: http://{local_ip}:{server_port}")
+     demo.launch(server_name=local_ip, server_port=server_port)
+
+ if __name__ == "__main__":
+     chat_interface()
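
The Gradio layer above is optional; the LLM class from llm.py can be driven directly. A minimal sketch, assuming the three model paths below point at valid HuggingFace and axmodel exports (the ViT path reuses the one from the run scripts, and ./test.jpg is a placeholder):

import cv2
from llm import LLM

# Assumed paths: replace with your actual model directories / ViT axmodel.
gllm = LLM("./InternVL3-2B", "./InternVL3-2B_axmodel",
           "./internvl3_2b_ax650/internvl3_2b_vit.axmodel")
gllm.tag = "image"  # "video" switches prompt_encode() to per-frame turns

img = cv2.imread("./test.jpg")  # BGR ndarray, exactly what respond() passes
for chunk in gllm.generate([img], "描述一下这张图片"):  # "Describe this image"
    print(chunk, end="", flush=True)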
llm.py ADDED
@@ -0,0 +1,457 @@
+ import cv2
+ import numpy as np
+ from ml_dtypes import bfloat16
+ from transformers import AutoTokenizer, AutoConfig
+ from axengine import InferenceSession
+ from tqdm import tqdm
+
+ def img_preprocess(img, input_size):
+     # ImageNet normalization, NCHW float32 output
+     IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+     IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+     img = cv2.resize(img, (input_size, input_size))
+     img = img.astype(np.float32) / 255.0
+     img = (img - IMAGENET_MEAN) / IMAGENET_STD
+     img = img.transpose(2, 0, 1).reshape(1, 3, input_size, input_size)
+     return img
+
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float('inf')
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     return best_ratio
+
+ def dynamic_preprocess(image: np.ndarray, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+     orig_height, orig_width = image.shape[:2]
+     aspect_ratio = orig_width / orig_height
+
+     # enumerate the candidate tiling grids
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize, then split the image into image_size x image_size tiles
+     resized_img = cv2.resize(image, (target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         split_img = resized_img[box[1]:box[3], box[0]:box[2]]
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = cv2.resize(image, (image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images
+
+ def pre_process(image, input_size=448, max_num=12):
+     images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+     pixel_values = [img_preprocess(img, input_size) for img in images]
+     pixel_values = np.concatenate(pixel_values, axis=0)
+     return pixel_values
+
+ def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+     if bound:
+         start, end = bound[0], bound[1]
+     else:
+         start, end = -100000, 100000
+     start_idx = max(first_idx, round(start * fps))
+     end_idx = min(round(end * fps), max_frame)
+     seg_size = float(end_idx - start_idx) / num_segments
+     frame_indices = np.array([
+         int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+         for idx in range(num_segments)
+     ])
+     return frame_indices
+
+ def load_video_opencv(video_path, bound=None, num_segments=32):
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise IOError(f"Cannot open video: {video_path}")
+
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     max_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
+
+     frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
+
+     images_list = []
+     for frame_index in frame_indices:
+         cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
+         ret, frame = cap.read()
+         if not ret:
+             print(f"⚠ Failed to read frame {frame_index}")
+             continue
+         images_list.append(frame)
+
+     cap.release()
+     return images_list
+
+ def is_video_file(path):
+     return str(path).lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm"))
+
+ def is_image_file(path):
+     return str(path).lower().endswith((".jpg", ".png", ".jpeg", ".webp"))
+
+ def load_image(path):
+     image = cv2.imread(str(path))
+     if image is None:
+         raise ValueError(f"Image {path} not found or cannot be read.")
+     return image
+
+ def post_process(data, topk=1, topp=0.9, temperature=0.6):
+     def top_p(l: np.ndarray, p: float) -> np.ndarray:
+         # zero out the probability tail once the cumulative mass reaches p
+         index = np.argsort(l)
+         res = l.copy()
+         sum_p = 0
+         for i in index[::-1]:
+             if sum_p >= p:
+                 res[i] = 0
+             sum_p += res[i]
+         return res / sum_p
+
+     def softmax(l: np.ndarray) -> np.ndarray:
+         l_max = l - l.max()
+         l_exp = np.exp(l_max)
+         res = l_exp / np.sum(l_exp)
+         return res.astype(np.float64)
+
+     r = data.astype(np.float32)
+     r = r.flatten()
+     # top-k
+     candidate_index = np.argpartition(r, -topk)[-topk:]
+     candidate_value = r[candidate_index]
+     # temperature
+     candidate_value /= temperature
+     # softmax
+     candidate_soft = softmax(candidate_value)
+     # top-p
+     candidate_soft = top_p(candidate_soft, topp)
+     candidate_soft = candidate_soft.astype(np.float64) / candidate_soft.sum()
+     pos = np.random.multinomial(1, candidate_soft).argmax()
+     next_token = candidate_index[pos]
+     return next_token, candidate_index, candidate_soft
+
+
+ class LLM:
+
+     def __init__(self, hf_model_path, axmodel_path, vit_axmodel_path):
+         self.hf_model_path = hf_model_path
+         self.tag = "image"
+
+         config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=True)
+         self.tokenizer = AutoTokenizer.from_pretrained(hf_model_path, trust_remote_code=True, use_fast=False)
+         self.cfg = config.llm_config
+
+         self.prefill_slice_len = 128
+         self.kv_cache_len = 2559
+
+         self.prefill_decoder_sessions = []
+         for i in tqdm(range(self.cfg.num_hidden_layers), desc="Init InferenceSession"):
+             session = InferenceSession(
+                 f"{axmodel_path}/qwen2_p128_l{i}_together.axmodel"
+             )
+             self.prefill_decoder_sessions.append(session)
+
+         self.post_process_session = InferenceSession(
+             f"{axmodel_path}/qwen2_post.axmodel"
+         )
+         print("model load done!")
+
+         self.kv_dim = self.cfg.hidden_size // self.cfg.num_attention_heads * self.cfg.num_key_value_heads
+
+         self.vit_session = InferenceSession(vit_axmodel_path)
+
+         self.embeds = np.load(f"{axmodel_path}/model.embed_tokens.weight.npy")
+
+         self.stop = False
+
+     def stop_generate(self):
+         self.stop = True
+
+     def image_encode(self, images_list):
+         pixel_values_list = []
+         vit_output_list = []
+         if images_list is not None:
+             for img in images_list:
+                 pixel_values = pre_process(img, input_size=448, max_num=1)
+                 pixel_values_list.append(pixel_values)
+             print(f"number of input images: {len(pixel_values_list)}")
+             print("preprocess image done!")
+
+         # extract image features with the ViT
+         for idx, pixel_values in enumerate(pixel_values_list):
+             vit_output = self.vit_session.run(None, {"image": pixel_values})[0]
+             vit_output_list.append(vit_output.copy())  # copy so the ViT outputs do not share one buffer
+
+         if vit_output_list:
+             print(f"vit_output.shape is {vit_output_list[0].shape}, vit feature extract done!")
+
+         return vit_output_list
+
+     def prompt_encode(self, question, num_of_images) -> list:
+         # System prompt (Chinese): "You are InternVL (书生·万象), a multimodal large language
+         # model jointly developed by Shanghai AI Laboratory, Tsinghua University and partners."
+         prompt = "<|im_start|>system\n你是书生·万象, 英文名是InternVL, 是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型.<|im_end|>\n"
+
+         if num_of_images > 0:
+             for idx in range(num_of_images):
+                 if self.tag == "video":
+                     prompt += "<|im_start|>user"
+                     prompt += f"\nFrame{idx+1}: <img>" + "<IMG_CONTEXT>" * 256 + "</img>\n"
+                     prompt += f"\n{question}<|im_end|>\n<|im_start|>assistant\n"
+                 else:
+                     prompt += "<|im_start|>user\n" + question
+                     prompt += "\n<img>" + "<IMG_CONTEXT>" * 256 + "</img>\n"
+                     prompt += "<|im_end|>\n<|im_start|>assistant\n"
+
+         token_ids = self.tokenizer.encode(prompt)
+         print(f"prompt is {prompt}, \ntoken_len is {len(token_ids)}")
+         return token_ids
+
+     def generate(self, sources, prompt, video_segments=8):
+         self.stop = False
+         images_list = []
+
+         # 1. Handle single video path string
+         if isinstance(sources, str) and is_video_file(sources):
+             images_list = load_video_opencv(sources, num_segments=video_segments)
+
+         # 2. Handle [video_path] list
+         elif isinstance(sources, list) and len(sources) == 1 and isinstance(sources[0], str) and is_video_file(sources[0]):
+             images_list = load_video_opencv(sources[0], num_segments=video_segments)
+
+         # 3. Handle single image path
+         elif isinstance(sources, str) and is_image_file(sources):
+             images_list = [load_image(sources)]
+
+         # 4. Handle single image as np.ndarray
+         elif isinstance(sources, np.ndarray):
+             images_list = [sources]
+
+         # 5. Handle list of images or paths
+         elif isinstance(sources, list):
+             for img in sources:
+                 if isinstance(img, str):
+                     images_list.append(load_image(img))
+                 elif isinstance(img, np.ndarray):
+                     images_list.append(img)
+                 else:
+                     raise ValueError(f"Unsupported image type: {type(img)}")
+         else:
+             raise ValueError("Unsupported input format for 'sources'.")
+
+         vit_output_list = self.image_encode(images_list)
+
+         token_ids = self.prompt_encode(prompt, len(vit_output_list))
+
+         k_caches = [
+             np.zeros((1, self.kv_cache_len, self.kv_dim), dtype=bfloat16)
+             for _ in range(self.cfg.num_hidden_layers)
+         ]
+         v_caches = [
+             np.zeros((1, self.kv_cache_len, self.kv_dim), dtype=bfloat16)
+             for _ in range(self.cfg.num_hidden_layers)
+         ]
+
+         # locate each <img> tag (token id 151665) so the ViT features can be spliced in
+         image_start_indices = np.where(np.array(token_ids) == 151665)[0].tolist()
+
+         prefill_data = np.take(self.embeds, token_ids, axis=0)
+         prefill_data = prefill_data.astype(bfloat16)
+         token_len = len(token_ids)
+
+         assert token_len < 2048 + 128, f"input prompt ({token_len} tokens) exceeds the maximum length!"
+         for idx, image_start_index in enumerate(image_start_indices):
+             # overwrite the 256 <IMG_CONTEXT> embeddings after each <img> with ViT features
+             image_insert_index = image_start_index + 1
+             prefill_data[image_insert_index : image_insert_index + 256] = vit_output_list[idx][0, :, :]
+         print("prefill token_len: ", token_len)
+
+         # prefill, in slices of prefill_slice_len (128) tokens
+         prefill_slice_len = self.prefill_slice_len
+         slice_indexs = [
+             e for e in range(token_len // prefill_slice_len + 1)
+         ]
+         prefill_len = prefill_slice_len * slice_indexs[-1] if slice_indexs[-1] != 0 else prefill_slice_len
+
+         if prefill_len > 0:
+             for slice_index in tqdm(slice_indexs, desc="prefill"):
+                 indices = np.array(
+                     list(
+                         range(
+                             slice_index * prefill_slice_len,
+                             (slice_index + 1) * prefill_slice_len,
+                         )
+                     ),
+                     np.uint32,
+                 ).reshape((1, prefill_slice_len))
+
+                 # causal mask: start fully masked (-65536), open it up token by token below
+                 mask = (
+                     np.zeros((1, prefill_slice_len, prefill_slice_len * (slice_index + 1)))
+                     - 65536
+                 )
+                 data = np.zeros((1, prefill_slice_len, self.cfg.hidden_size)).astype(bfloat16)
+                 for i, t in enumerate(
+                     range(
+                         slice_index * prefill_slice_len,
+                         (slice_index + 1) * prefill_slice_len,
+                     )
+                 ):
+                     if t < len(token_ids):
+                         mask[:, i, : slice_index * prefill_slice_len + i + 1] = 0
+                         data[:, i : i + 1, :] = (
+                             prefill_data[t]
+                             .reshape((1, 1, self.cfg.hidden_size))
+                             .astype(bfloat16)
+                         )
+
+                 if slice_index == slice_indexs[-1]:
+                     remain_len = token_len - slice_index * prefill_slice_len
+                 else:
+                     remain_len = prefill_slice_len
+                 mask = mask.astype(bfloat16)
+                 for i in range(self.cfg.num_hidden_layers):
+                     input_feed = {
+                         "K_cache": (
+                             k_caches[i][:, 0 : prefill_slice_len * slice_index, :]
+                             if slice_index
+                             else np.zeros((1, 1, self.cfg.hidden_size), dtype=bfloat16)
+                         ),
+                         "V_cache": (
+                             v_caches[i][:, 0 : prefill_slice_len * slice_index, :]
+                             if slice_index
+                             else np.zeros((1, 1, self.cfg.hidden_size), dtype=bfloat16)
+                         ),
+                         "indices": indices,
+                         "input": data,
+                         "mask": mask,
+                     }
+                     outputs = self.prefill_decoder_sessions[i].run(None, input_feed, shape_group=slice_index + 1)
+                     k_caches[i][
+                         :,
+                         slice_index * prefill_slice_len : slice_index * prefill_slice_len + remain_len,
+                         :,
+                     ] = outputs[0][:, :remain_len, :]
+                     v_caches[i][
+                         :,
+                         slice_index * prefill_slice_len : slice_index * prefill_slice_len + remain_len,
+                         :,
+                     ] = outputs[1][:, :remain_len, :]
+                     data = outputs[2]
+
+                 if self.stop:
+                     return
+
+         # sample the first generated token from the last prefill position
+         post_out = self.post_process_session.run(
+             None,
+             {
+                 "input": data[
+                     :, token_len - (len(slice_indexs) - 1) * prefill_slice_len - 1, None, :
+                 ]
+             }
+         )[0]
+         next_token, possible_tokens, possible_soft = post_process(post_out)
+         token_ids.append(next_token)
+
+         # hand over to the decoder
+         token_ids_cached = []
+         token_ids_cached.append(next_token)
+
+         mask = np.zeros((1, 1, self.kv_cache_len + 1), dtype=np.float32).astype(bfloat16)
+         mask[:, :, :self.kv_cache_len] -= 65536
+         if prefill_len > 0:
+             mask[:, :, :token_len] = 0
+
+         for start_indice in range(self.kv_cache_len):
+             if prefill_len > 0 and start_indice < token_len:
+                 continue
+
+             next_token = token_ids[start_indice]
+             indices = np.array([start_indice], np.uint32).reshape((1, 1))
+             data = self.embeds[next_token, :].reshape((1, 1, self.cfg.hidden_size)).astype(bfloat16)
+             for i in range(self.cfg.num_hidden_layers):
+                 input_feed = {
+                     "K_cache": k_caches[i],
+                     "V_cache": v_caches[i],
+                     "indices": indices,
+                     "input": data,
+                     "mask": mask,
+                 }
+                 outputs = self.prefill_decoder_sessions[i].run(None, input_feed, shape_group=0)
+                 k_caches[i][:, start_indice, :] = outputs[0][:, :, :]
+                 v_caches[i][:, start_indice, :] = outputs[1][:, :, :]
+                 data = outputs[2]
+             mask[..., start_indice] = 0
+             if start_indice < token_len - 1:
+                 pass
+             else:
+                 post_out = self.post_process_session.run(None, {"input": data})[0]
+                 next_token, possible_tokens, possible_soft = post_process(post_out)
+                 token_ids.append(next_token)
+
+                 if next_token == self.tokenizer.eos_token_id:
+                     # flush whatever is still cached, then stop
+                     if len(token_ids_cached) > 0:
+                         msg = self.tokenizer.decode(token_ids_cached)
+                         token_ids_cached.clear()
+                         if "\ufffd" in msg:
+                             msg = msg.replace("\ufffd", "")
+                         yield msg
+                     break
+
+                 token_ids_cached.append(next_token)
+
+                 # decode in batches of 3 tokens to avoid emitting split multi-byte characters
+                 if len(token_ids_cached) >= 3:
+                     msg = self.tokenizer.decode(token_ids_cached)
+                     token_ids_cached.clear()
+                     if "\ufffd" in msg:
+                         msg = msg.replace("\ufffd", "")
+                     yield msg
+
+             if self.stop:
+                 return
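
The module-level helpers in llm.py are usable on their own, which is handy for testing the sampling and the video loader in isolation. A minimal sketch, assuming llm.py (and its axengine dependency) is importable and ./demo.mp4 is a placeholder path:

import numpy as np
from llm import post_process, load_video_opencv

# Sample one token from fake vocab-sized logits with top-k/top-p/temperature.
logits = np.random.randn(151674).astype(np.float32)  # 151674 = tokens_embed_num
next_token, candidate_ids, candidate_probs = post_process(
    logits, topk=40, topp=0.9, temperature=0.7)
print(next_token, candidate_ids, candidate_probs)

# Grab 4 evenly spaced frames, as the video path of the Gradio demo does.
frames = load_video_opencv("./demo.mp4", num_segments=4)
print(len(frames), frames[0].shape)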
main_api_axcl_aarch64 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fe28789b7b719911f38e18fbcd38b149ac77f4d75129e8629d7fce1c7b3cddf8
+ size 1870304
main_axcl_aarch64 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:03ff6647009917de1687a84e7601d457e00587af490b54bf1da41b4f0208f691
+ size 1786544
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ opencv-python
+ transformers
+ numpy
+ ml_dtypes
+ tqdm
run_internvl_3_2b_448_api_ax650.sh CHANGED
@@ -5,6 +5,6 @@
  --use_mmap_load_embed 1 \
  --filename_tokenizer_model "http://0.0.0.0:12345" \
  --filename_post_axmodel "./internvl3_2b_axmodel/qwen2_post.axmodel" \
- --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bfloat16.bin" \
+ --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bf16.bin" \
  --tokens_embed_num 151674 \
  --tokens_embed_size 1536
run_internvl_3_2b_448_api_axcl_aarch64.sh ADDED
@@ -0,0 +1,11 @@
+ ./main_api_axcl_aarch64 \
+ --template_filename_axmodel "./internvl3_2b_ax650/qwen2_p128_l%d_together.axmodel" \
+ --axmodel_num 28 \
+ --filename_image_encoder_axmodedl "./internvl3_2b_ax650/internvl3_2b_vit.axmodel" \
+ --use_mmap_load_embed 1 \
+ --filename_tokenizer_model "http://0.0.0.0:12345" \
+ --filename_post_axmodel "./internvl3_2b_ax650/qwen2_post.axmodel" \
+ --filename_tokens_embed "./internvl3_2b_ax650/model.embed_tokens.weight.bfloat16.bin" \
+ --tokens_embed_num 151674 \
+ --tokens_embed_size 1536 \
+ --devices 0 \
run_internvl_3_2b_448_api_axcl_x86.sh CHANGED
@@ -5,7 +5,7 @@
  --use_mmap_load_embed 1 \
  --filename_tokenizer_model "http://0.0.0.0:12345" \
  --filename_post_axmodel "./internvl3_2b_axmodel/qwen2_post.axmodel" \
- --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bfloat16.bin" \
+ --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bf16.bin" \
  --tokens_embed_num 151674 \
  --tokens_embed_size 1536 \
  --devices 0,2,4 \
run_internvl_3_2b_448_ax650.sh CHANGED
@@ -5,7 +5,7 @@
  --use_mmap_load_embed 1 \
  --filename_tokenizer_model "http://0.0.0.0:12345" \
  --filename_post_axmodel "./internvl3_2b_axmodel/qwen2_post.axmodel" \
- --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bfloat16.bin" \
+ --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bf16.bin" \
  --tokens_embed_num 151674 \
  --tokens_embed_size 1536 \
  --live_print 1
run_internvl_3_2b_448_axcl_aarch64.sh ADDED
@@ -0,0 +1,12 @@
+ ./main_axcl_aarch64 \
+ --template_filename_axmodel "./internvl3_2b_ax650/qwen2_p128_l%d_together.axmodel" \
+ --axmodel_num 28 \
+ --filename_image_encoder_axmodedl "./internvl3_2b_ax650/internvl3_2b_vit.axmodel" \
+ --use_mmap_load_embed 1 \
+ --filename_tokenizer_model "http://0.0.0.0:12345" \
+ --filename_post_axmodel "./internvl3_2b_ax650/qwen2_post.axmodel" \
+ --filename_tokens_embed "./internvl3_2b_ax650/model.embed_tokens.weight.bfloat16.bin" \
+ --tokens_embed_num 151674 \
+ --tokens_embed_size 1536 \
+ --devices 0 \
+ --live_print 1
run_internvl_3_2b_448_axcl_x86.sh CHANGED
@@ -5,7 +5,7 @@
  --use_mmap_load_embed 1 \
  --filename_tokenizer_model "http://0.0.0.0:12345" \
  --filename_post_axmodel "./internvl3_2b_axmodel/qwen2_post.axmodel" \
- --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bfloat16.bin" \
+ --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bf16.bin" \
  --tokens_embed_num 151674 \
  --tokens_embed_size 1536 \
  --devices 0,2,4 \
webgui.png ADDED
Git LFS Details

  • SHA256: 4017e49ec821141cd3356aea983b4cda5997583c6b90dbc666f4ca7b838e4ed3
  • Pointer size: 131 Bytes
  • Size of remote file: 329 kB