---
license: apache-2.0
language:
- en
---

# DistilViT2 Image Captioning Model

This model performs image captioning using a prefix-conditioning architecture with LoRA adapters.

![Browser demo showing on-device captioning](demo-screenshot.png)

## Architecture

```
Image → SigLIP Vision Encoder → Projection Layer → SmolLM + LoRA → Caption
```

**Components:**

- **Vision Encoder**: SigLIP-base-patch16-224 (frozen during training)
- **Projection Layer**: Linear layer mapping vision features to the text embedding space
- **Language Model**: SmolLM-135M with LoRA adapters (rank=16)

## Model Contents

The `model.safetensors` file (984 MB) contains **all weights** needed for inference:

### Complete Model Weights (723 tensors total)

1. **Vision Encoder** (~300 MB)
   - Complete SigLIP-base-patch16-224 weights
   - Keys: `vision_encoder.*`
   - Hidden size: 768
   - Patches: 14×14 = 196 tokens per image

2. **Projection Layer** (~1 MB)
   - Linear projection: 768 → 576
   - Keys: `projection.*`
   - Maps vision features to the language model embedding space

3. **Language Model Base Weights** (~660 MB)
   - Complete SmolLM-135M base weights
   - Keys: `language_model.base_model.model.*`
   - 30 layers, 576 hidden size, 49152 vocab size

4. **LoRA Adapters** (~2 MB trainable)
   - Separate low-rank matrices (not merged into the base)
   - Keys: `language_model.*.lora_A.default.weight`, `language_model.*.lora_B.default.weight`
   - Applied to: q_proj, k_proj, v_proj, o_proj
   - Rank: 16, Alpha: 16, Dropout: 0.1

## Training Details

- **Trainable Parameters**: 2.2M / 221M total (1%)
- **Frozen**: Vision encoder (SigLIP)
- **Trainable**: Projection layer + LoRA adapters
- **Datasets**: Flickr30k, COCO
- **Architecture**: Prefix-conditioning (no cross-attention)

## Usage

### Python CLI (torch vs ONNX)

Run the side-by-side comparison script:

```bash
python compare_inference.py --model-dir . --onnx-dir onnx --image cat.jpg --prompt "A photo of" --max-new-tokens 15
```

Key path inside `compare_inference.py`:

```python
vision_encoder, projection, language_model, _ = load_models(args.model_dir, device)
pixel_values = preprocess(args.image, processor, device)

torch_caption = torch_generate(
    vision_encoder, projection, language_model, tokenizer,
    pixel_values, args.prompt, args.max_new_tokens
)

vision_sess, proj_sess, lm_sess = load_onnx_sessions(args.onnx_dir)
onnx_caption = onnx_generate(
    vision_sess, proj_sess, lm_sess, language_model, tokenizer,
    pixel_values, args.prompt, args.max_new_tokens
)
```

`--image` defaults to `cat.jpg` in the repo if you do not pass one. The script prints both captions so you can verify parity between torch and ONNX.
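For orientation, the torch path boils down to a short prefix-conditioning step. The sketch below reuses the component names from `compare_inference.py` (`vision_encoder`, `projection`, `language_model`, `tokenizer`, `pixel_values`) and assumes they behave like standard Hugging Face modules (a SigLIP vision model, a linear projection, and a PEFT-wrapped causal LM); it illustrates the idea and is not the repo's exact `torch_generate`:

```python
import torch

@torch.no_grad()
def caption_once(vision_encoder, projection, language_model, tokenizer,
                 pixel_values, prompt="A photo of", max_new_tokens=15):
    # Encode the image: SigLIP patch features, roughly [1, 196, 768].
    vision_hidden = vision_encoder(pixel_values=pixel_values).last_hidden_state

    # Project into the language model's embedding space: [1, 196, 576].
    prefix_embeds = projection(vision_hidden)

    # Embed the text prompt with the LM's own embedding table.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(prefix_embeds.device)
    prompt_embeds = language_model.get_input_embeddings()(input_ids)

    # Prefix-conditioning: image tokens are prepended to the prompt embeddings,
    # so generation attends to them without any cross-attention layers.
    inputs_embeds = torch.cat([prefix_embeds, prompt_embeds], dim=1)
    attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long,
                                device=inputs_embeds.device)

    generated = language_model.generate(inputs_embeds=inputs_embeds,
                                        attention_mask=attention_mask,
                                        max_new_tokens=max_new_tokens)
    # With inputs_embeds, recent transformers versions return only the new tokens.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
```

The ONNX path in `compare_inference.py` mirrors the same flow using `onnxruntime` sessions in place of the torch modules.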
### Browser demo (ort.js)

An offline, browser-only demo lives in `demo/` and runs the ONNX exports via `ort.js` + `transformers.js` (vision encoder → projection → prefix_init → decoder). The `demo/models` directory points to the ONNX files in `onnx/`, so make sure those exports are present locally; no remote fetches are required.

Run a static server from `demo/`:

```bash
cd demo && python -m http.server 8000
```

`demo/main.js` drives the full pipeline on-device:

```javascript
await loadAll(); // loads tokenizer/processor assets and ONNX models from ./demo/models

const pixelData = await preprocessImage(currentImage);
const visionHidden = await runVision(pixelData);
const projected = await runProjection(visionHidden);

const prompt = ''; // no text prompt needed; the image prefix conditions generation
const encoded = await tokenizer(prompt);
const initFeeds = {
  prefix_embeddings: projected,
  input_ids: new ort.Tensor(
    'int64',
    BigInt64Array.from(encoded.input_ids.data.map(BigInt)),
    [1, encoded.input_ids.data.length]
  ),
};
const initOutputs = await prefixInitSession.run(initFeeds);

// Then decode step by step with the cached past:
const feeds = buildDecoderInputs([BigInt(nextToken)], attention, position, past);
const outputs = await lmSession.run(feeds);
```

Open `http://localhost:8000`, drop an image, and click **Generate caption** to watch the vision → projection → prefix-init → decode flow run entirely in the browser.

## Model Specifications

| Component | Size | Parameters | Status |
|-----------|------|------------|--------|
| Vision Encoder | 768 hidden | ~87M | Frozen |
| Projection | 768 → 576 | ~443K | Trainable |
| Language Model | 576 hidden, 30 layers | ~134M | Base frozen |
| LoRA Adapters | rank=16 | ~1.8M | Trainable |
| **Total** | | **~221M** | **2.2M trainable** |

## Key Features

- **Combined weights**: All components are stored in a single `model.safetensors`; the ONNX path uses three files in `onnx/`
- **LoRA preserved**: LoRA weights are stored separately (not merged into the base) for flexibility
- **Efficient**: Only 1% of parameters are trained; 5.7× faster than the cross-attention baseline

## Technical Notes

- Training scripts: https://github.com/tarekziade/distilvit2
- For a full walkthrough of the architecture and export flow, see the blog post: https://blog.ziade.org/2025/12/16/better-alt-text-part-2/

## License

Model weights inherit licenses from the base models:

- SigLIP: Apache 2.0
- SmolLM: Apache 2.0
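As a quick, optional sanity check that the exports the browser demo expects are present and loadable from `onnx/`, you can open them with `onnxruntime` in Python. The snippet below discovers files by globbing rather than assuming specific file names, since the export names are not fixed by this card:

```python
import glob
import onnxruntime as ort

# List every ONNX export in onnx/ and print its input/output signature.
for path in sorted(glob.glob("onnx/*.onnx")):
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    inputs = [(i.name, i.shape) for i in sess.get_inputs()]
    outputs = [o.name for o in sess.get_outputs()]
    print(f"{path}: inputs={inputs} outputs={outputs}")
```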