Need a guide to run the model

#2 opened by se-ok

I would love an official guide to run the model.

As of now,

  • vLLM, SGLang, and similar serving engines are not supported yet.
  • with transformers 4.57.3, the model's remote code requires attn_implementation="flash_attention_2" (see the single-process load sketch below)
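
For a quick smoke test without torchrun, something like the following should work, assuming flash-attn and accelerate are installed and the checkpoint fits in the memory of the visible GPUs (untested sketch; MODEL is a placeholder for the repo id):

"""infer_single.py -- minimal smoke test without torchrun (sketch, untested)"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = ...  # placeholder for the repo id

tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # required by the model code
    device_map="auto",                        # spread layers across visible GPUs (needs accelerate)
).eval()

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "who are you"}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

with torch.inference_mode():
    out = model.generate(input_ids=input_ids, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(out[0, input_ids.shape[-1]:], skip_special_tokens=False))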

I was able to run simple inference at TP=8 with the code below, but I have not been able to generate with continuous batching yet, which hinders testing the model. (A static-batching stopgap is sketched at the end of the post.)

"""infer_tp8.py"""
import os

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = ...

def main():
    # torchrun sets these
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Initialize distributed (TP needs a process group)
    dist.init_process_group(backend="nccl")

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL,
        trust_remote_code=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
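        # tp_plan="auto" shards the weights across the torchrun process group (tensor parallelism)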
        tp_plan="auto",
    ).eval()

    messages = [
        {"role": "user", "content": "who are you"},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(torch.device("cuda", local_rank))

    with torch.inference_mode():
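        # every rank must call generate() so the TP shards stay in sync; only rank 0 prints below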
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=256,
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    if dist.get_rank() == 0:
        new_tokens = out[0, input_ids.shape[-1] :]
        text = tokenizer.decode(new_tokens, skip_special_tokens=False)
        print(text)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

  • Command: torchrun --nproc-per-node 8 infer_tp8.py
  • Settings: H100 x8, transformers 4.57.3, flash-attn 2.8.3
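
Until continuous batching is available, a static-batching stopgap may be enough for testing: build a left-padded batch of prompts and call generate() once. Below is an untested sketch with made-up prompts; it is meant to replace the messages/input_ids/generate/print section inside main() of infer_tp8.py (keep the dist.destroy_process_group() call at the end):

    # --- static-batching stopgap (sketch, untested); drop into main() of infer_tp8.py ---
    prompts = ["who are you", "write a haiku about GPUs"]  # hypothetical test prompts

    # left padding so every prompt ends right where generation starts
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    texts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": p}], tokenize=False, add_generation_prompt=True
        )
        for p in prompts
    ]
    inputs = tokenizer(
        texts, return_tensors="pt", padding=True, add_special_tokens=False
    ).to(torch.device("cuda", local_rank))

    with torch.inference_mode():
        out = model.generate(
            **inputs,  # input_ids + attention_mask
            max_new_tokens=256,
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    if dist.get_rank() == 0:
        for i in range(len(prompts)):
            new_tokens = out[i, inputs["input_ids"].shape[-1]:]
            print(tokenizer.decode(new_tokens, skip_special_tokens=True))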
