benjamin committed · Commit b07d89d · verified · 1 Parent(s): 6b48560

Create README.md

Files changed (1): README.md (+82 -0)

---
license: mit
base_model:
- google/gemma-2-2b-it
---

# Gemma2-2B-IT-Byte 🔢

__[Gemma2-2B-IT](https://huggingface.co/google/gemma-2-2b-it) transferred to byte-level tokenization via [cross-tokenizer distillation](https://arxiv.org/abs/2503.20083).__

__🚧 This model is intended as a proof of concept that we can quickly and effectively transfer pretrained (subword-based) models to the byte level. It is not optimized for production use (in particular, it is not optimized for speed)! 🚧__

## Benchmarks

Gemma2-2B-IT-Byte performs competitively even though it was trained on only 1.3B bytes (≈ 328M subword tokens in total).

| Model | MMLU | BoolQ | PiQA | IFEval | ARC-C | Avg. |
|-------|------|-------|------|--------|-------|------|
| [EvaByte-6.5B-SFT](https://huggingface.co/EvaByte/EvaByte-SFT) | 49.5 | 79.5* | 74.1* | 60.2 | 64.6* | 65.6 |
| [Llama3.2-3B-Instruct (original)](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) | 62.4 | 78.8 | 76.9 | 76.6 | 43.9 | 67.7 |
| [Gemma2-2B-IT (original)](https://huggingface.co/google/gemma-2-2b-it) | 56.9 | 83.8 | 79.6 | 62.5 | 50.4 | 66.6 |
| [Llama3-2-3B-IT-Byte](https://huggingface.co/benjamin/Llama3-2-3B-IT-Byte) | 57.0 | 76.6 | 73.6 | 58.8 | 39.8 | 61.2 |
| __Gemma2-2B-IT-Byte (this model)__ | __51.0__ | __80.5__ | __71.5__ | __51.9__ | __38.2__ | __58.6__ |

<small>*Numbers from EvaByte-6.5B (Base), since they are not reported for the SFT model.</small>
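
For reference, the Avg. column is consistent with an unweighted mean over the five tasks (an assumption inferred from the numbers above, not something stated in the evaluation setup):

```python
# Reproduce the Avg. column as a plain mean of the five task scores (assumed aggregation).
scores = {
    "EvaByte-6.5B-SFT": [49.5, 79.5, 74.1, 60.2, 64.6],
    "Gemma2-2B-IT-Byte": [51.0, 80.5, 71.5, 51.9, 38.2],
}
for name, vals in scores.items():
    print(f"{name}: {sum(vals) / len(vals):.1f}")  # 65.6 and 58.6, matching the table
```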

## Usage

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("benjamin/Gemma2-2B-IT-Byte")

device = "cuda"
# trust_remote_code is needed because the byte-level model ships custom modeling code.
model = AutoModelForCausalLM.from_pretrained("benjamin/Gemma2-2B-IT-Byte", trust_remote_code=True)
model = model.to(device)

# Build a chat prompt (with the assistant turn prefix appended) and generate a response.
tokens = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello, how are you doing?"}],
    add_generation_prompt=True,
    return_tensors="pt",
)
out = model.generate(tokens.to(model.device), eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(out[0]))
```
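
Since the model operates on bytes, a prompt expands to roughly one token per UTF-8 byte (plus any special tokens), so inputs are several times longer than with the original subword tokenizer. A quick way to see this, reusing the `tokenizer` loaded above (a sketch; exact counts depend on the special tokens the tokenizer adds):

```python
# Byte-level tokenization: the number of input ids tracks the UTF-8 byte count of the text.
text = "Hello, how are you doing?"
input_ids = tokenizer(text)["input_ids"]
print(len(text.encode("utf-8")), len(input_ids))  # byte count vs. token count (close, up to special tokens)
```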

## Training

This model has been trained using [`tokenkit`](https://github.com/bminixhofer/tokenkit) with the following command:

```bash
python3 scripts/cross_tokenizer_distill.py \
--config=configs/cross_tokenizer_distill.yaml \
--overrides \
losses=[sft,alm_unconstrained,alm_latents] \
multitask_aggregation_fn=approx_gradmag_preserve_mag \
alm_mode=merge_by_space_prob+append_space \
tokenizer_pair_bias_threshold=0.1 \
max_student_length=2048 \
steps=20000 \
eval_interval=20000 \
save_interval=20000 \
optimizer.learning_rate=3.e-5 \
optimizer.weight_decay=0.0 \
optimizer.max_grad_norm=null \
optimizer.grad_acc_steps=1 \
train_model_mode=full \
expand_input_ids=true \
output_embeddings_mode=untie \
eval.tasks=[arc_easy,arc_challenge,piqa,boolq,arithmetic,mmlu,ifeval,agieval_en,agieval_cn] \
data.batch_size=32 \
student.pretrained_model_name_or_path=benjamin/gemma-2-2b-it-flax \
student.tokenizer_name=google/gemma-2-2b-it:source=Gemma2 \
target_tokenizer_name=google/gemma-2-2b-it:source=Gemma2:target=Gemma2:conversion=byte \
n_model_parallel=4 \
n_data_parallel=4 \
data.num_workers=16 \
num_workers=16 \
name=gemma2_to_byte_20k
```
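
As a rough sanity check, the data budget implied by these settings matches the totals quoted under Future Work below (this assumes every batch is fully packed to `max_student_length`):

```python
# Back-of-the-envelope training data budget from the command above (assumes fully packed batches).
steps, batch_size, max_len = 20_000, 32, 2048  # steps, data.batch_size, max_student_length
total_bytes = steps * batch_size * max_len
print(f"{total_bytes:,} bytes")                    # 1,310,720,000 ≈ 1.3B bytes
print(f"{total_bytes / 328e6:.1f} bytes/subword")  # ≈ 4.0, relative to the ~328M-subword-token figure
```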

## Future Work

The current version of this model was trained for 20k steps with 32 × 2048 bytes per batch (= 1.3B bytes ≈ 328M subword tokens in total). We did not expect it to perform as well as it does after such a short training run. We plan to train a new version for more steps (you can also do so yourself using [`tokenkit`](https://github.com/bminixhofer/tokenkit)).

To preserve the efficiency of the original subword model, we would have to add (a combination of) [BLT-style hierarchical processing](https://arxiv.org/abs/2412.09871), [attention approximations](https://hkunlp.github.io/blog/2025/evabyte/), and [self-speculative decoding](https://arxiv.org/abs/2309.08168).