Holy-fox commited on
Commit
f507e9e
·
verified ·
1 Parent(s): ea0a107

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +42 -146
README.md CHANGED
@@ -25,199 +25,95 @@ Gemma 3 ファミリーと同様に、テキストと画像のマルチモーダ
25
  まず、必要なライブラリをインストールします。Gemma 3は `transformers` 4.50.0 以降が必要です。
26
 
27
  ```sh
28
- pip install -U transformers accelerate Pillow vllm
29
  # CPUのみで使用する場合や特定の環境ではvllmのインストールが異なる場合があります。
30
  # vLLMの公式ドキュメントを参照してください: https://docs.vllm.ai/en/latest/getting_started/installation.html
31
  ```
32
 
33
- ### vLLMでの推論
34
-
35
- [vLLM](https://github.com/vllm-project/vllm) を使用して高速な推論を行うサンプルコードです。
36
 
37
  ```python
38
- from vllm import LLM, SamplingParams
39
- from transformers import AutoTokenizer
 
 
40
 
41
  model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-jp"
42
 
43
- # Gemma 3のチャットテンプレートを使用するためにTokenizerをロード
44
- tokenizer = AutoTokenizer.from_pretrained(model_id)
45
-
46
- # プロンプトの準備 (チャット形式)
47
- messages = [
48
- {"role": "system", "content": "あなたは親切なAIアシスタントです。"},
49
- {"role": "user", "content": "日本の首都とその見どころを教えてください。"}
50
- ]
51
-
52
- # チャットテンプレートを適用
53
- # vLLMは直接チャットテンプレートを適用できないため、tokenizerで文字列に変換します
54
- # 注意: vLLMのバージョンや設定によっては、より効率的な方法がある可能性があります
55
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
56
-
57
- # LLMの初期化
58
- # 必要に応じて tensor_parallel_size を調整してください
59
- llm = LLM(model=model_id, trust_remote_code=True) # Gemma 3 モデルによっては trust_remote_code が必要
60
-
61
- # サンプリングパラメータの設定
62
- sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=512)
63
-
64
- # 推論の実行
65
- outputs = llm.generate(prompt, sampling_params)
66
-
67
- # 結果の表示
68
- for output in outputs:
69
- generated_text = output.outputs[0].text
70
- print(f"Generated text: {generated_text!r}")
71
-
72
- # >>> Generated text: '東京は日本の首都であり、多くの魅力的な観光スポットがあります。\n\n* **東京タワー:** 市街を一望できる象徴的なランドマークです。\n* **浅草寺:** 歴史ある寺院で、仲見世通りでの買い物も楽しめます。\n* **渋谷スクランブル交差点:** 世界的に有名な活気あふれる交差点です。\n* **新宿御苑:** 都心にある広大な庭園で、四季折々の自然を楽しめます。\n* **築地場外市場:** 新鮮な海産物やグルメを堪能できます。\n\nこれらの他にも、美術館、博物館、ショッピングエリアなど、見どころは尽きません。'
73
- ```
74
-
75
- ### Transformersでのテキスト推論
76
-
77
- `transformers` ライブラリを使用して、テキストのみ(システムプロンプトとユーザープロンプト)で推論を行うサンプルコードです。
78
 
79
- ```python
80
- # pip install accelerate が必要になる場合があります
81
- from transformers import AutoTokenizer, AutoModelForCausalLM # Gemma 3はConditionalGenerationですが、テキストのみならこちらでもロードできる場合があります
82
- # もし上記でエラーが出る場合や、公式に合わせる場合は以下を使用
83
- # from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
84
- import torch
85
 
86
- model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-jp"
87
- device = "cuda" # GPUが利用可能な場合
88
-
89
- # トークナイザーとモデルのロード
90
- tokenizer = AutoTokenizer.from_pretrained(model_id)
91
- # テキストのみの場合でも Gemma3ForConditionalGeneration を使用するのが確実です
92
- model = AutoModelForCausalLM.from_pretrained( # または Gemma3ForConditionalGeneration.from_pretrained
93
- model_id,
94
- torch_dtype=torch.bfloat16, # bfloat16を推奨
95
- device_map="auto", # 自動的にGPUに配置
96
- )
97
- # model = Gemma3ForConditionalGeneration.from_pretrained(
98
- # model_id,
99
- # torch_dtype=torch.bfloat16,
100
- # device_map="auto",
101
- # )
102
- model.eval()
103
-
104
- # チャット形式のプロンプト
105
  messages = [
106
- {"role": "system", "content": "あなたは知識豊富な歴史解説家です。簡潔に説明してください。"},
107
- {"role": "user", "content": "戦国時代の三英傑について教えてください。"}
 
 
 
 
 
 
 
 
 
108
  ]
109
 
110
- # プロンプトをトークナイズ (チャットテンプレートを適用)
111
- # Gemma 3 instruction-tuned モデルでは add_generation_prompt=True が重要です
112
- inputs = tokenizer.apply_chat_template(
113
- messages,
114
- add_generation_prompt=True,
115
- tokenize=True,
116
- return_tensors="pt"
117
- ).to(model.device)
118
 
119
- input_len = inputs.shape[-1]
120
 
121
- # 推論の実行
122
  with torch.inference_mode():
123
- generation = model.generate(
124
- inputs,
125
- max_new_tokens=200,
126
- do_sample=True, # サンプリングを行う場合
127
- temperature=0.2,
128
- top_p=0.9
129
- )
130
- # 入力部分を除いた生成されたトークンのみを取得
131
- generated_ids = generation[0][input_len:]
132
-
133
- # 結果をデコード
134
- decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
135
- print(decoded)
136
 
137
- # >>> 戦国時代の三英傑とは、織田信長、豊臣秀吉、徳川家康の3人を指します。
138
- # >>>
139
- # >>> * **織田信長:** 尾張の小大名から身を起こし、天下統一を目前にしながら本能寺の変で倒れました。革新的な政策や戦術で知られます。
140
- # >>> * **豊臣秀吉:** 信長の後を継ぎ、天下統一を成し遂げました。農民出身から最高権力者に上り詰めた人物です。
141
- # >>> * **徳川家康:** 秀吉の死後、関ヶ原の戦いで勝利し、江戸幕府を開いて約260年続く泰平の世を築きました。
142
  ```
143
-
144
- ### Transformersでの画像とテキスト推論
145
-
146
- `transformers` ライブラリを使用して、画像とテキストを入力として推論を行うサンプルコードです。
147
 
148
  ```python
149
- # pip install accelerate が必要になる場合があります
150
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration
151
- from PIL import Image
152
- import requests
153
  import torch
154
 
155
  model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-jp"
156
- device = "cuda" # GPUが利用可能な場合
157
 
158
- # プロセッサーとモデルのロード
159
- processor = AutoProcessor.from_pretrained(model_id)
160
  model = Gemma3ForConditionalGeneration.from_pretrained(
161
- model_id,
162
- torch_dtype=torch.bfloat16, # bfloat16を推奨
163
- device_map="auto", # 自動的にGPUに配置
164
  ).eval()
165
 
166
- # チャット形式のプロンプト (画像とテキストを含む)
167
- # 画像のURLやローカルパスを指定できます
168
- image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
169
- # ローカルファイルの場合: image = Image.open("path/to/your/image.jpg")
170
- image = Image.open(requests.get(image_url, stream=True).raw)
171
 
172
  messages = [
173
  {
174
  "role": "system",
175
- "content": [{"type": "text", "text": "あなたは画像について説明するAIアシスタントです。"}]
176
  },
177
  {
178
  "role": "user",
179
  "content": [
180
- {"type": "image", "image": image}, # PILイメージオブジェクトを渡す
181
- # URLを直接渡すことも可能な場合があります (ライブラリのバージョンによる)
182
- # {"type": "image", "url": image_url},
183
- {"type": "text", "text": "この画像に写っている花と昆虫について説明してください。"}
184
  ]
185
  }
186
  ]
187
 
188
- # プロンプトを処理してトークナイズ
189
- # Gemma 3 instruction-tuned モデルでは add_generation_prompt=True が重要です
190
  inputs = processor.apply_chat_template(
191
- messages,
192
- add_generation_prompt=True,
193
- tokenize=True,
194
- return_dict=True, # return_tensors="pt" と合わせて辞書形式で受け取る
195
- return_tensors="pt"
196
- ).to(model.device) # processorがtorch_dtypeを適切に扱わない場合があるため、ここで .to(dtype=torch.bfloat16) を追加する必要があるかもしれません
197
 
198
  input_len = inputs["input_ids"].shape[-1]
199
 
200
- # 推論の実行
201
  with torch.inference_mode():
202
- generation = model.generate(
203
- **inputs,
204
- max_new_tokens=150,
205
- do_sample=False # 決定的な出力を得る場合
206
- )
207
- # 入力部分を除いた生成されたトークンのみを取得
208
- generated_ids = generation[0][input_len:]
209
-
210
- # 結果をデコード
211
- # processor.decode は text/image トークンを適切に扱います
212
- decoded = processor.decode(generated_ids, skip_special_tokens=True)
213
- print(decoded)
214
 
215
- # >>> 画像には、ピンク色のコスモスのような花にミツバチ(またはマルハナバチ)が止まっている様子が写っています。
216
- # >>>
217
- # >>> * **花:** ピンク色の花びらを持つキク科の植物で、おそらくコスモスでしょう。中央には黄色い花粉が見えます。
218
- # >>> * **昆虫:** 体に黄色と黒の縞模様があり、毛深い外見からマルハナバチ(Bumblebee)である可能性が高いです。花の中心部で蜜や花粉を集めているようです。
219
- # >>>
220
- # >>> 背景は緑色で、自然光の下で撮影されたような、柔らかい雰囲気の写真です。
221
  ```
222
 
223
  ## License
 
25
  まず、必要なライブラリをインストールします。Gemma 3は `transformers` 4.50.0 以降が必要です。
26
 
27
  ```sh
28
+ pip install -U transformers accelerate Pillow
29
  # CPUのみで使用する場合や特定の環境ではvllmのインストールが異なる場合があります。
30
  # vLLMの公式ドキュメントを参照してください: https://docs.vllm.ai/en/latest/getting_started/installation.html
31
  ```
32
 
33
+ ### 画像付き推論
 
 
34
 
35
  ```python
36
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
37
+ from PIL import Image
38
+ import requests
39
+ import torch
40
 
41
  model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-jp"
42
 
43
+ model = Gemma3ForConditionalGeneration.from_pretrained(
44
+ model_id, device_map="auto"
45
+ ).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ processor = AutoProcessor.from_pretrained(model_id)
 
 
 
 
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  messages = [
50
+ {
51
+ "role": "system",
52
+ "content": [{"type": "text", "text": "あなたは素晴らしい日本語アシスタントです。"}]
53
+ },
54
+ {
55
+ "role": "user",
56
+ "content": [
57
+ {"type": "image", "image": "https://cs.stanford.edu/people/rak248/VG_100K_2/2399540.jpg"},
58
+ {"type": "text", "text": "この画像を説明してください。"}
59
+ ]
60
+ }
61
  ]
62
 
63
+ inputs = processor.apply_chat_template(
64
+ messages, add_generation_prompt=True, tokenize=True,
65
+ return_dict=True, return_tensors="pt"
66
+ ).to(model.device, dtype=torch.bfloat16)
 
 
 
 
67
 
68
+ input_len = inputs["input_ids"].shape[-1]
69
 
 
70
  with torch.inference_mode():
71
+ generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
72
+ generation = generation[0][input_len:]
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ decoded = processor.decode(generation, skip_special_tokens=True)
75
+ print(decoded)
 
 
 
76
  ```
77
+ ### 画像無し推論
 
 
 
78
 
79
  ```python
 
80
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration
 
 
81
  import torch
82
 
83
  model_id = "DataPilot/ArrowMint-Gemma3-4B-ChocoMint-jp"
 
84
 
 
 
85
  model = Gemma3ForConditionalGeneration.from_pretrained(
86
+ model_id, device_map="auto"
 
 
87
  ).eval()
88
 
89
+ processor = AutoProcessor.from_pretrained(model_id)
 
 
 
 
90
 
91
  messages = [
92
  {
93
  "role": "system",
94
+ "content": [{"type": "text", "text": "あなたは素晴らしい日本語アシスタントです。"}]
95
  },
96
  {
97
  "role": "user",
98
  "content": [
99
+ {"type": "text", "text": "福岡に一人で遊びに行くのですがお勧めスポットはありますか?"}
 
 
 
100
  ]
101
  }
102
  ]
103
 
 
 
104
  inputs = processor.apply_chat_template(
105
+ messages, add_generation_prompt=True, tokenize=True,
106
+ return_dict=True, return_tensors="pt"
107
+ ).to(model.device, dtype=torch.bfloat16)
 
 
 
108
 
109
  input_len = inputs["input_ids"].shape[-1]
110
 
 
111
  with torch.inference_mode():
112
+ generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
113
+ generation = generation[0][input_len:]
 
 
 
 
 
 
 
 
 
 
114
 
115
+ decoded = processor.decode(generation, skip_special_tokens=True)
116
+ print(decoded)
 
 
 
 
117
  ```
118
 
119
  ## License