Need a guide to run the model
I would love an official guide to run the model.
As of now:
- vLLM / SGLang and similar inference engines are not supported yet.
- With transformers 4.57.3, the model files require `attn_implementation="flash_attention_2"`, so flash-attn must be installed before loading (see the preflight check below).
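A minimal preflight sketch for that requirement; it only assumes that the `flash_attn` package has to be importable before `from_pretrained` is called:

```python
# Preflight: fail fast if flash-attn is missing, since the model config
# appears to hard-require attn_implementation="flash_attention_2".
import importlib.util

if importlib.util.find_spec("flash_attn") is None:
    raise RuntimeError(
        "flash-attn is not installed; try: pip install flash-attn --no-build-isolation"
    )
```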
I was able to run simple inference with the following code at TP 8, but I have not been able to generate with continuous batching yet, which hinders testing of the model. (A static-batching stopgap is sketched at the end of this post.)
"""infer_tp8.py"""
import os
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL = ...
def main():
# torchrun sets these
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
# Initialize distributed (TP needs a process group)
dist.init_process_group(backend="nccl")
tokenizer = AutoTokenizer.from_pretrained(
MODEL,
trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
MODEL,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
tp_plan="auto",
).eval()
messages = [
{"role": "user", "content": "who are you"},
]
input_ids = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
).to(torch.device("cuda", local_rank))
with torch.inference_mode():
out = model.generate(
input_ids=input_ids,
max_new_tokens=256,
temperature=0.6,
top_p=0.95,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
if dist.get_rank() == 0:
new_tokens = out[0, input_ids.shape[-1] :]
text = tokenizer.decode(new_tokens, skip_special_tokens=False)
print(text)
dist.destroy_process_group()
if __name__ == "__main__":
main()
- Command: `torchrun --nproc-per-node 8 infer_tp8.py`
- Settings: `H100x8, transformers 4.57.3, flash-attn 2.8.3`
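Until continuous batching is available, here is a static-batching stopgap sketch. It pads a list of prompts to the same length and generates them in one `generate()` call; the file name, the example prompts, and the padding handling (left padding, plus `add_special_tokens=False` on the assumption that the chat template already inserts BOS) are my assumptions and are untested against this model:

```python
"""batch_infer_tp8.py: static batching only, not continuous batching."""
import os

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = ...

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
    # Decoder-only models need left padding for batched generate()
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        tp_plan="auto",
    ).eval()

    # Hypothetical example prompts
    prompts = ["who are you", "explain tensor parallelism in one sentence"]
    texts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": p}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for p in prompts
    ]
    # add_special_tokens=False assumes the chat template already adds BOS
    batch = tokenizer(
        texts, return_tensors="pt", padding=True, add_special_tokens=False
    ).to(torch.device("cuda", local_rank))

    with torch.inference_mode():
        out = model.generate(
            **batch,
            max_new_tokens=256,
            temperature=0.6,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    if dist.get_rank() == 0:
        # With left padding, every row shares the same prompt length
        new_tokens = out[:, batch["input_ids"].shape[-1]:]
        for text in tokenizer.batch_decode(new_tokens, skip_special_tokens=True):
            print(text)
            print("---")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```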