Upload 18 files
- .gitattributes +2 -0
- gradio_demo_c_api.py +127 -0
- gradio_demo_python_api.py +129 -0
- llm.py +457 -0
- main_api_axcl_aarch64 +3 -0
- main_axcl_aarch64 +3 -0
- requirements.txt +5 -0
- run_internvl_3_2b_448_api_ax650.sh +1 -1
- run_internvl_3_2b_448_api_axcl_aarch64.sh +11 -0
- run_internvl_3_2b_448_api_axcl_x86.sh +1 -1
- run_internvl_3_2b_448_ax650.sh +1 -1
- run_internvl_3_2b_448_axcl_aarch64.sh +12 -0
- run_internvl_3_2b_448_axcl_x86.sh +1 -1
- webgui.png +3 -0
.gitattributes CHANGED

```diff
@@ -42,3 +42,5 @@ main_api_ax650 filter=lfs diff=lfs merge=lfs -text
 main_api_axcl_x86 filter=lfs diff=lfs merge=lfs -text
 main_ax650 filter=lfs diff=lfs merge=lfs -text
 main_axcl_x86 filter=lfs diff=lfs merge=lfs -text
+main_api_axcl_aarch64 filter=lfs diff=lfs merge=lfs -text
+main_axcl_aarch64 filter=lfs diff=lfs merge=lfs -text
```
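The two added lines track the new AArch64 binaries with Git LFS, so the repository stores only small pointer files for them (shown under `main_api_axcl_aarch64` and `main_axcl_aarch64` below). Entries of this form are typically generated with `git lfs track <path>` rather than edited by hand.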
gradio_demo_c_api.py ADDED (+127 lines)
```python
import mimetypes
import os
import gradio as gr
import requests
import json, time

base_url = "http://10.126.33.235:8000"

def upload_image(file_path):
    if file_path is None:
        return None
    # Gradio File component returns a tempfile-like object
    # file_path = image.name
    filename = os.path.basename(file_path)
    # Guess MIME type
    mime_type, _ = mimetypes.guess_type(filename)
    mime_type = mime_type or 'application/octet-stream'
    # Open file in binary mode for upload
    with open(file_path, 'rb') as f:
        file_bytes = f.read()
    # Prepare multipart form data
    files = {
        'image': (filename, file_bytes, mime_type)
    }
    # Send to upload endpoint
    resp = requests.post(
        f'{base_url}/api/upload',
        files=files
    )
    resp.raise_for_status()
    data = resp.json()
    return data.get('file_path')

def stop_generation():
    try:
        requests.get(f'{base_url}/api/stop')
    except requests.RequestException:
        pass

def respond(prompt, image: gr.Image, temp, rep_penalty, tp, tk, history=None):
    if history is None:
        history = []
    if not prompt.strip():
        return history
    # append empty response to history
    if image is None:
        file_path = None
    else:
        file_path = upload_image(image)
        history.append(("", None))
        relative_path = os.path.relpath(file_path)
        # html = f"<img src='{relative_path}' style='max-width:300px;'/>"
        # history.append((html, None))
        # print(relative_path)

    history.append((prompt, ""))
    yield history
    # stream updates

    payload = {
        "prompt": prompt,
        "temperature": temp,
        "repetition_penalty": rep_penalty,
        "top-p": tp,
        "top-k": tk
    }
    if file_path:
        payload["file_path"] = file_path

    response = requests.post(
        f'{base_url}/api/generate',
        json=payload
    )
    response.raise_for_status()

    while True:
        time.sleep(0.01)
        response = requests.get(
            f'{base_url}/api/generate_provider'
        )
        data = response.json()
        chunk: str = data.get("response", "")
        done = data.get("done", False)
        if done:
            break
        if chunk.strip() == "":
            continue
        history[-1] = (prompt, history[-1][1] + chunk)
        yield history

    print("end")


def chat_interface():
    with gr.Blocks(theme=gr.themes.Soft(font="Consolas"), fill_width=True) as demo:
        gr.Markdown("## Chat with LLM\nUpload an image and chat with the model!")
        with gr.Row():
            image = gr.Image(label="Upload Image", type="filepath")
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=600)
                prompt = gr.Textbox(placeholder="Type your message...", label="Prompt", value="描述一下这张图片")
                with gr.Row():
                    btn_chat = gr.Button("Chat", variant="primary")
                    btn_stop = gr.Button("Stop", variant="stop")

            with gr.Column(scale=1):
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.7, label="Temperature")
                repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.01, value=1.0, label="Repetition Penalty")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, label="Top-p Sampling")
                top_k = gr.Slider(minimum=0, maximum=100, step=1, value=40, label="Top-k Sampling")

        btn_stop.click(fn=stop_generation, inputs=None, outputs=None)
        btn_chat.click(
            fn=respond,
            inputs=[prompt, image, temperature, repetition_penalty, top_p, top_k, chatbot],
            outputs=chatbot
        )

    demo.launch(server_name="0.0.0.0", server_port=7860)

if __name__ == "__main__":
    chat_interface()
```
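This demo treats the C API server (launched by the `run_internvl_3_2b_448_api_*.sh` scripts) as a plain HTTP service, so it can also be driven headlessly. A minimal sketch of the same flow; the endpoint names, payload fields, and response shape (`response`/`done`) are inferred from the demo code above rather than from a documented API spec, and the server address is a placeholder:

```python
import time
import requests

BASE_URL = "http://127.0.0.1:8000"  # placeholder; point at your running API server

def ask(prompt, image_path=None):
    payload = {"prompt": prompt, "temperature": 0.7,
               "repetition_penalty": 1.0, "top-p": 0.9, "top-k": 40}
    if image_path:
        # upload first; the server is assumed to return a server-side path to reference
        with open(image_path, "rb") as f:
            resp = requests.post(f"{BASE_URL}/api/upload", files={"image": f})
        resp.raise_for_status()
        payload["file_path"] = resp.json().get("file_path")
    requests.post(f"{BASE_URL}/api/generate", json=payload).raise_for_status()
    # poll for streamed chunks until the server reports completion
    answer = ""
    while True:
        data = requests.get(f"{BASE_URL}/api/generate_provider").json()
        if data.get("done", False):
            break
        answer += data.get("response", "")
        time.sleep(0.01)
    return answer

if __name__ == "__main__":
    print(ask("Describe this image", "./example.jpg"))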
gradio_demo_python_api.py ADDED (+129 lines)
```python
import cv2
import gradio as gr
from llm import LLM
import llm as llm
import argparse
import socket

parser = argparse.ArgumentParser(description="Model configuration parameters")
parser.add_argument("--hf_model", type=str, default="./InternVL3-2B",
                    help="Path to HuggingFace model")
parser.add_argument("--axmodel_path", type=str, default="./InternVL3-2B_axmodel",
                    help="Path to the compiled axmodel of the language model")
parser.add_argument("--vit_model", type=str, default=None,
                    help="Path to the compiled axmodel of the ViT image encoder")
args = parser.parse_args()

hf_model_path = args.hf_model
axmodel_path = args.axmodel_path
vit_axmodel_path = args.vit_model

gllm = LLM(hf_model_path, axmodel_path, vit_axmodel_path)

# get the local IP address
def get_local_ip():
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.connect(("8.8.8.8", 80))
        ip = s.getsockname()[0]
        s.close()
        return ip
    except Exception:
        return "127.0.0.1"

def stop_generation():
    gllm.stop_generate()

def respond(prompt, video, image, is_image, video_segments, image_segments_cols, image_segments_rows, history=None):
    if history is None:
        history = []
    if not prompt.strip():
        return history
    # append empty response to history

    gllm.tag = "video" if not is_image else "image"

    history.append((prompt, ""))
    yield history

    print(video)
    print(image)

    if is_image:
        img = cv2.imread(image)
        images_list = []
        if image_segments_cols == 1 and image_segments_rows == 1:
            images_list.append(img)
        elif image_segments_cols * image_segments_rows > 8:
            # gr.Error("image segments cols * image segments rows > 8")
            history[-1] = (prompt, history[-1][1] + "image segments cols * image segments rows > 8")
            yield history
            return
        else:
            height, width, _ = img.shape
            segment_width = width // image_segments_cols
            segment_height = height // image_segments_rows
            for i in range(image_segments_rows):
                for j in range(image_segments_cols):
                    x1 = j * segment_width
                    y1 = i * segment_height
                    x2 = (j + 1) * segment_width
                    y2 = (i + 1) * segment_height
                    segment = img[y1:y2, x1:x2]
                    images_list.append(segment)
    else:
        images_list = llm.load_video_opencv(video, num_segments=video_segments)

    for msg in gllm.generate(images_list, prompt):
        print(msg, end="", flush=True)
        history[-1] = (prompt, history[-1][1] + msg)
        yield history
    print("\n\n\n")

def chat_interface():
    with gr.Blocks() as demo:
        gr.Markdown("## Raspberry Pi 5 InternVL3 Chat DEMO using AXCL\nUpload an image or video and chat with InternVL3.")
        with gr.Row():
            with gr.Column(scale=1):
                video = gr.Video(label="Upload Video", format="mp4")
                video_segments = gr.Slider(minimum=2, maximum=8, step=1, value=4, label="video segments")
                image = gr.Image(label="Upload Image", type="filepath")
                image_segments_cols = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="image cols segments")
                image_segments_rows = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="image rows segments")
                checkbox = gr.Checkbox(label="Use Image")
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=650)
                prompt = gr.Textbox(placeholder="Type your message...", label="Prompt", value="描述一下这组图片")
                with gr.Row():
                    btn_chat = gr.Button("Chat", variant="primary")
                    btn_stop = gr.Button("Stop", variant="stop")

        btn_stop.click(fn=stop_generation, inputs=None, outputs=None)

        btn_chat.click(
            fn=respond,
            inputs=[prompt, video, image, checkbox, video_segments, image_segments_cols, image_segments_rows, chatbot],
            outputs=chatbot
        )

        def on_video_uploaded(video):
            if video is not None:
                return gr.update(value=False)
            return gr.update()

        def on_image_uploaded(image):
            if image is not None:
                return gr.update(value=True)
            return gr.update()

        video.change(fn=on_video_uploaded, inputs=video, outputs=checkbox)
        image.change(fn=on_image_uploaded, inputs=image, outputs=checkbox)

    local_ip = get_local_ip()
    server_port = 7860
    print(f"HTTP server address: http://{local_ip}:{server_port}")
    demo.launch(server_name=local_ip, server_port=server_port)

if __name__ == "__main__":
    chat_interface()
```
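Since this demo only wraps `llm.LLM`, the model can also be used without Gradio. A minimal sketch against the class as defined in `llm.py` below; the model and axmodel paths are placeholders for your local directories:

```python
from llm import LLM, load_video_opencv

# placeholders: HF config/tokenizer dir, compiled decoder axmodels dir, ViT axmodel
model = LLM("./InternVL3-2B", "./InternVL3-2B_axmodel", "./internvl3_2b_vit.axmodel")

# single image: generate() accepts a path, an ndarray, or a list of either
for chunk in model.generate(["./example.jpg"], "描述一下这张图片"):
    print(chunk, end="", flush=True)

# video: sample 8 frames and switch the prompt template to per-frame mode
model.tag = "video"
frames = load_video_opencv("./example.mp4", num_segments=8)
for chunk in model.generate(frames, "What happens in this video?"):
    print(chunk, end="", flush=True)
```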
llm.py ADDED (+457 lines)
```python
import cv2
from transformers import AutoTokenizer, AutoConfig
import numpy as np
from ml_dtypes import bfloat16
from axengine import InferenceSession
from tqdm import tqdm
# from decord import VideoReader

def img_preprocess(img, input_size):
    IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    IMAGENET_STD = np.array((0.229, 0.224, 0.225), dtype=np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (input_size, input_size))
    img = img.astype(np.float32) / 255.0
    img = (img - IMAGENET_MEAN) / IMAGENET_STD
    img = img.transpose(2, 0, 1).reshape(1, 3, input_size, input_size)
    return img

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image: np.ndarray, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_height, orig_width = image.shape[:2]
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    # resized_img = image.resize((target_width, target_height))
    resized_img = cv2.resize(image, (target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        # split_img = resized_img.crop(box)
        split_img = resized_img[box[1]:box[3], box[0]:box[2]]
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        # thumbnail_img = image.resize((image_size, image_size))
        thumbnail_img = cv2.resize(image, (image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def pre_process(image, input_size=448, max_num=12):
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [img_preprocess(image, input_size) for image in images]
    pixel_values = np.concatenate(pixel_values, axis=0)
    return pixel_values

def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    if bound:
        start, end = bound[0], bound[1]
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array([
        int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
        for idx in range(num_segments)
    ])
    return frame_indices


def load_video_opencv(video_path, bound=None, num_segments=32):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")

    fps = cap.get(cv2.CAP_PROP_FPS)
    max_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1

    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)

    images_list = []
    for frame_index in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = cap.read()
        if not ret:
            print(f"⚠ Failed to read frame {frame_index}")
            continue
        images_list.append(frame)

    cap.release()
    return images_list


def is_video_file(path):
    return str(path).lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm"))

def is_image_file(path):
    return str(path).lower().endswith((".jpg", ".png", ".jpeg", ".webp"))

def load_image(path):
    image = cv2.imread(str(path))
    if image is None:
        raise ValueError(f"Image {path} not found or cannot be read.")
    return image

def post_process(data, topk=1, topp=0.9, temperature=0.6):
    def top_p(l: np.ndarray, p: float) -> np.ndarray:
        index = np.argsort(l)
        res = l.copy()
        sum_p = 0
        for i in index[::-1]:
            if sum_p >= p:
                res[i] = 0
            sum_p += res[i]
        return res / sum_p

    def softmax(l: np.ndarray) -> np.ndarray:
        l_max = l - l.max()
        l_exp = np.exp(l_max)
        res = l_exp / np.sum(l_exp)
        return res.astype(np.float64)

    r = data.astype(np.float32)
    r = r.flatten()
    # topk
    candidate_index = np.argpartition(r, -topk)[-topk:]
    candidate_value = r[candidate_index]
    # temperature
    candidate_value /= temperature
    # softmax
    candidate_soft = softmax(candidate_value)
    # topp
    candidate_soft = top_p(candidate_soft, topp)
    candidate_soft = candidate_soft.astype(np.float64) / candidate_soft.sum()
    pos = np.random.multinomial(1, candidate_soft).argmax()
    next_token = candidate_index[pos]
    return next_token, candidate_index, candidate_soft


class LLM:

    def __init__(self, hf_model_path, axmodel_path, vit_axmodel_path):
        self.hf_model_path = hf_model_path
        self.tag = "image"

        config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(hf_model_path, trust_remote_code=True, use_fast=False)
        self.cfg = config.llm_config

        self.prefill_slice_len = 128
        self.kv_cache_len = 2559

        self.prefill_decoder_sessions = []
        for i in tqdm(range(self.cfg.num_hidden_layers), desc="Init InferenceSession"):
            session = InferenceSession(
                f"{axmodel_path}/qwen2_p128_l{i}_together.axmodel"
            )
            self.prefill_decoder_sessions.append(session)

        self.post_process_session = InferenceSession(
            f"{axmodel_path}/qwen2_post.axmodel"
        )
        print("model load done!")

        self.kv_dim = self.cfg.hidden_size // self.cfg.num_attention_heads * self.cfg.num_key_value_heads

        self.vit_session = InferenceSession(vit_axmodel_path)

        self.embeds = np.load(f"{axmodel_path}/model.embed_tokens.weight.npy")

        self.stop = False

    def stop_generate(self):
        self.stop = True

    def image_encode(self, images_list):
        pixel_values_list = []
        vit_output_list = []
        if images_list is not None:
            for img in images_list:
                pixel_values = pre_process(img, input_size=448, max_num=1)
                pixel_values_list.append(pixel_values)
            print(f"number of input images: {len(pixel_values_list)}")
            print("preprocess image done!")

        # extract img feature by vit
        for idx, pixel_values in enumerate(pixel_values_list):
            vit_output = self.vit_session.run(None, {"image": pixel_values})[0]
            vit_output_list.append(vit_output.copy())  # copy so the ViT outputs do not share the same memory

        print(f"vit_output.shape is {vit_output_list[0].shape}, vit feature extract done!")

        return vit_output_list

    def prompt_encode(self, question, num_of_images) -> list:
        prompt = "<|im_start|>system\n你是书生·万象, 英文名是InternVL, 是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型.<|im_end|>\n"
        # question = args.question

        if num_of_images > 0:
            for idx in range(num_of_images):
                if self.tag == "video":
                    prompt += "<|im_start|>user"
                    prompt += f"\nFrame{idx+1}: <img>" + "<IMG_CONTEXT>" * 256 + "</img>\n"
                    prompt += f"\n{question}<|im_end|>\n<|im_start|>assistant\n"
                else:
                    prompt += "<|im_start|>user\n" + question
                    prompt += "\n<img>" + "<IMG_CONTEXT>" * 256 + "</img>\n"
                    prompt += "<|im_end|>\n<|im_start|>assistant\n"

        token_ids = self.tokenizer.encode(prompt)
        print(f"prompt is {prompt}, \ntoken_len is {len(token_ids)}")
        return token_ids


    def generate(self, sources, prompt, video_segments=8):
        self.stop = False
        images_list = []

        # 1. Handle single video path string
        if isinstance(sources, str) and is_video_file(sources):
            images_list = load_video_opencv(sources, num_segments=video_segments)

        # 2. Handle [video_path] list
        elif isinstance(sources, list) and len(sources) == 1 and isinstance(sources[0], str) and is_video_file(sources[0]):
            images_list = load_video_opencv(sources[0], num_segments=video_segments)

        # 3. Handle single image path
        elif isinstance(sources, str) and is_image_file(sources):
            images_list = [load_image(sources)]

        # 4. Handle single image as np.ndarray
        elif isinstance(sources, np.ndarray):
            images_list = [sources]

        # 5. Handle list of images or paths
        elif isinstance(sources, list):
            for img in sources:
                if isinstance(img, str):
                    images_list.append(load_image(img))
                elif isinstance(img, np.ndarray):
                    images_list.append(img)
                else:
                    raise ValueError(f"Unsupported image type: {type(img)}")
        else:
            raise ValueError("Unsupported input format for 'sources'.")

        vit_output_list = self.image_encode(images_list)

        token_ids = self.prompt_encode(prompt, len(vit_output_list))

        k_caches = [
            np.zeros((1, self.kv_cache_len, self.kv_dim), dtype=bfloat16)
            for _ in range(self.cfg.num_hidden_layers)
        ]
        v_caches = [
            np.zeros((1, self.kv_cache_len, self.kv_dim), dtype=bfloat16)
            for _ in range(self.cfg.num_hidden_layers)
        ]

        # image understanding
        image_start_indices = np.where(np.array(token_ids) == 151665)[0].tolist()  # <img> tag

        prefill_data = np.take(self.embeds, token_ids, axis=0)
        prefill_data = prefill_data.astype(bfloat16)
        token_len = len(token_ids)

        assert token_len < 2048 + 128, f"input prompt ({token_len} tokens) exceeds the maximum length!"
        for idx, image_start_index in enumerate(image_start_indices):
            image_insert_index = image_start_index + 1
            prefill_data[image_insert_index : image_insert_index + 256] = vit_output_list[idx][0, :, :]
        ##################################
        print("prefill token_len: ", token_len)

        """
        prefill
        """
        prefill_slice_len = self.prefill_slice_len
        # slice_indexs = [0, 1, 2, 3, 4, 5, 6, 7, 8]
        slice_indexs = [
            e for e in range(token_len // prefill_slice_len + 1)
        ]
        # print(f"slice_indexs is {slice_indexs}")
        prefill_len = prefill_slice_len * slice_indexs[-1] if slice_indexs[-1] != 0 else prefill_slice_len  # 128 here is prefill_slice_len

        if prefill_len > 0:
            for slice_index in tqdm(slice_indexs, desc="prefill"):
                indices = np.array(
                    list(
                        range(
                            slice_index * prefill_slice_len,
                            (slice_index + 1) * prefill_slice_len,
                        )
                    ),
                    np.uint32,
                ).reshape((1, prefill_slice_len))

                mask = (
                    np.zeros((1, prefill_slice_len, prefill_slice_len * (slice_index + 1)))
                    - 65536
                )
                data = np.zeros((1, prefill_slice_len, self.cfg.hidden_size)).astype(bfloat16)
                for i, t in enumerate(
                    range(
                        slice_index * prefill_slice_len,
                        (slice_index + 1) * prefill_slice_len,
                    )
                ):
                    if t < len(token_ids):
                        mask[:, i, : slice_index * prefill_slice_len + i + 1] = 0
                        data[:, i : i + 1, :] = (
                            prefill_data[t]
                            .reshape((1, 1, self.cfg.hidden_size))
                            .astype(bfloat16)
                        )

                if slice_index == slice_indexs[-1]:
                    remain_len = token_len - slice_index * prefill_slice_len
                else:
                    remain_len = prefill_slice_len
                mask = mask.astype(bfloat16)
                for i in range(self.cfg.num_hidden_layers):
                    input_feed = {
                        "K_cache": (
                            k_caches[i][:, 0 : prefill_slice_len * slice_index, :]
                            if slice_index
                            else np.zeros((1, 1, self.cfg.hidden_size), dtype=bfloat16)
                        ),
                        "V_cache": (
                            v_caches[i][:, 0 : prefill_slice_len * slice_index, :]
                            if slice_index
                            else np.zeros((1, 1, self.cfg.hidden_size), dtype=bfloat16)
                        ),
                        "indices": indices,
                        "input": data,
                        "mask": mask,
                    }
                    outputs = self.prefill_decoder_sessions[i].run(None, input_feed, shape_group=slice_index + 1)
                    k_caches[i][
                        :,
                        slice_index * prefill_slice_len : slice_index * prefill_slice_len + remain_len,
                        :,
                    ] = outputs[0][:, :remain_len, :]
                    v_caches[i][
                        :,
                        slice_index * prefill_slice_len : slice_index * prefill_slice_len + remain_len,
                        :,
                    ] = outputs[1][:, :remain_len, :]
                    data = outputs[2]

                if self.stop:
                    return

                # print("slice prefill done", slice_index)
        post_out = self.post_process_session.run(
            None,
            {
                "input": data[
                    :, token_len - (len(slice_indexs) - 1) * prefill_slice_len - 1, None, :
                ]
            }
        )[0]
        next_token, possible_tokens, possible_soft = post_process(post_out)
        possibles = [self.tokenizer.decode([t]) for t in possible_tokens]
        possible_pairs = [str((t, s)) for t, s in zip(possibles, possible_soft)]
        token_ids.append(next_token)

        # set to decoder
        token_ids_cached = []
        token_ids_cached.append(next_token)

        mask = np.zeros((1, 1, self.kv_cache_len + 1), dtype=np.float32).astype(bfloat16)
        mask[:, :, :self.kv_cache_len] -= 65536
        if prefill_len > 0:
            mask[:, :, :token_len] = 0

        for start_indice in range(self.kv_cache_len):
            if prefill_len > 0 and start_indice < token_len:
                continue

            next_token = token_ids[start_indice]
            indices = np.array([start_indice], np.uint32).reshape((1, 1))
            data = self.embeds[next_token, :].reshape((1, 1, self.cfg.hidden_size)).astype(bfloat16)
            for i in range(self.cfg.num_hidden_layers):
                input_feed = {
                    "K_cache": k_caches[i],
                    "V_cache": v_caches[i],
                    "indices": indices,
                    "input": data,
                    "mask": mask,
                }
                outputs = self.prefill_decoder_sessions[i].run(None, input_feed, shape_group=0)
                k_caches[i][:, start_indice, :] = outputs[0][:, :, :]
                v_caches[i][:, start_indice, :] = outputs[1][:, :, :]
                data = outputs[2]
            mask[..., start_indice] = 0
            if start_indice < token_len - 1:
                pass
            else:
                post_out = self.post_process_session.run(None, {"input": data})[0]
                next_token, possible_tokens, possible_soft = post_process(post_out)
                token_ids.append(next_token)

                if next_token == self.tokenizer.eos_token_id and next_token > token_len:
                    if len(token_ids_cached) > 0:
                        msg = self.tokenizer.decode(token_ids_cached)
                        token_ids_cached.clear()
                        if "\ufffd" in msg:
                            msg = msg.replace("\ufffd", "")
                        # print(msg, end="", flush=True)
                        yield msg
                    break

                token_ids_cached.append(next_token)

                if len(token_ids_cached) >= 3:
                    msg = self.tokenizer.decode(token_ids_cached)
                    token_ids_cached.clear()
                    if "\ufffd" in msg:
                        msg = msg.replace("\ufffd", "")
                    # print(msg, end="", flush=True)
                    yield msg

            if self.stop:
                return
```
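The `post_process` sampler above applies top-k truncation, temperature scaling, softmax, and top-p renormalization before drawing a single token. A toy illustration of the same pipeline on a hand-made 5-logit vector (the values are invented for demonstration only):

```python
import numpy as np
from llm import post_process

logits = np.array([1.0, 3.5, 0.2, 2.8, -1.0], dtype=np.float32)

# with topk=1 (the default) sampling degenerates to greedy: index 1 always wins
token, candidates, probs = post_process(logits, topk=1)
assert token == 1

# with topk=3, the three largest logits (indices 1, 3, 0) share the probability
# mass after temperature/softmax/top-p, and the draw is stochastic
token, candidates, probs = post_process(logits, topk=3, topp=0.9, temperature=0.6)
print(token, dict(zip(candidates.tolist(), probs.round(3).tolist())))
```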
main_api_axcl_aarch64 ADDED (Git LFS pointer)

```text
version https://git-lfs.github.com/spec/v1
oid sha256:fe28789b7b719911f38e18fbcd38b149ac77f4d75129e8629d7fce1c7b3cddf8
size 1870304
```
main_axcl_aarch64 ADDED (Git LFS pointer)

```text
version https://git-lfs.github.com/spec/v1
oid sha256:03ff6647009917de1687a84e7601d457e00587af490b54bf1da41b4f0208f691
size 1786544
```
requirements.txt ADDED

```text
opencv-python
transformers
numpy
ml_dtypes
tqdm
```
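Note that `llm.py` also imports `axengine` (the Python bindings for the AXERA runtime), which is not listed here and is presumably installed separately on the target device; the remaining dependencies install with `pip install -r requirements.txt`.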
run_internvl_3_2b_448_api_ax650.sh CHANGED

```diff
@@ -5,6 +5,6 @@
     --use_mmap_load_embed 1 \
     --filename_tokenizer_model "http://0.0.0.0:12345" \
     --filename_post_axmodel "./internvl3_2b_axmodel/qwen2_post.axmodel" \
-    --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.
+    --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bf16.bin" \
     --tokens_embed_num 151674 \
     --tokens_embed_size 1536
```
run_internvl_3_2b_448_api_axcl_aarch64.sh ADDED

```sh
./main_api_axcl_aarch64 \
    --template_filename_axmodel "./internvl3_2b_ax650/qwen2_p128_l%d_together.axmodel" \
    --axmodel_num 28 \
    --filename_image_encoder_axmodedl "./internvl3_2b_ax650/internvl3_2b_vit.axmodel" \
    --use_mmap_load_embed 1 \
    --filename_tokenizer_model "http://0.0.0.0:12345" \
    --filename_post_axmodel "./internvl3_2b_ax650/qwen2_post.axmodel" \
    --filename_tokens_embed "./internvl3_2b_ax650/model.embed_tokens.weight.bfloat16.bin" \
    --tokens_embed_num 151674 \
    --tokens_embed_size 1536 \
    --devices 0 \
```
run_internvl_3_2b_448_api_axcl_x86.sh CHANGED

```diff
@@ -5,7 +5,7 @@
     --use_mmap_load_embed 1 \
     --filename_tokenizer_model "http://0.0.0.0:12345" \
     --filename_post_axmodel "./internvl3_2b_axmodel/qwen2_post.axmodel" \
-    --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.
+    --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bf16.bin" \
     --tokens_embed_num 151674 \
     --tokens_embed_size 1536 \
     --devices 0,2,4 \
```
run_internvl_3_2b_448_ax650.sh CHANGED

```diff
@@ -5,7 +5,7 @@
     --use_mmap_load_embed 1 \
     --filename_tokenizer_model "http://0.0.0.0:12345" \
     --filename_post_axmodel "./internvl3_2b_axmodel/qwen2_post.axmodel" \
-    --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.
+    --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bf16.bin" \
     --tokens_embed_num 151674 \
     --tokens_embed_size 1536 \
     --live_print 1
```
run_internvl_3_2b_448_axcl_aarch64.sh ADDED

```sh
./main_axcl_aarch64 \
    --template_filename_axmodel "./internvl3_2b_ax650/qwen2_p128_l%d_together.axmodel" \
    --axmodel_num 28 \
    --filename_image_encoder_axmodedl "./internvl3_2b_ax650/internvl3_2b_vit.axmodel" \
    --use_mmap_load_embed 1 \
    --filename_tokenizer_model "http://0.0.0.0:12345" \
    --filename_post_axmodel "./internvl3_2b_ax650/qwen2_post.axmodel" \
    --filename_tokens_embed "./internvl3_2b_ax650/model.embed_tokens.weight.bfloat16.bin" \
    --tokens_embed_num 151674 \
    --tokens_embed_size 1536 \
    --devices 0 \
    --live_print 1
```
run_internvl_3_2b_448_axcl_x86.sh CHANGED

```diff
@@ -5,7 +5,7 @@
     --use_mmap_load_embed 1 \
     --filename_tokenizer_model "http://0.0.0.0:12345" \
     --filename_post_axmodel "./internvl3_2b_axmodel/qwen2_post.axmodel" \
-    --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.
+    --filename_tokens_embed "./internvl3_2b_axmodel/model.embed_tokens.weight.bf16.bin" \
     --tokens_embed_num 151674 \
     --tokens_embed_size 1536 \
     --devices 0,2,4 \
```
webgui.png ADDED (binary image, stored with Git LFS)