JoydeepC committed on
Commit
fb7bf93
·
verified ·
1 Parent(s): 65be902

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +38 -0
handler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from typing import Dict, List, Any
5
+
6
+ # Replace with actual GraniteMoeForCausalLM import if available
7
+ # from granitemoe import GraniteMoeForCausalLM
8
+
9
class EndpointHandler:
    """Callable text-generation handler built around a causal language model.

    An instance loads a tokenizer and model once, then turns request payloads
    of the form ``{"inputs": <prompt>, "parameters": {...}}`` into
    ``{"generated_text": <decoded output>}``.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model stored at *path*.

        The model is loaded with bfloat16 weights and ``device_map="auto"``,
        then switched to evaluation mode.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate text for one request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt (defaults
                to ``""``). ``data["parameters"]`` is an optional mapping
                that may contain:

                * ``max_new_tokens`` - number of tokens to generate; when
                  present it is used instead of ``max_length``.
                * ``max_length`` - total length limit passed to ``generate``
                  (default 100); used only when ``max_new_tokens`` is absent.
                * ``do_sample`` - whether to sample (default ``True``).
                * ``temperature`` / ``top_p`` - sampling settings (default
                  1.0 each); forwarded only when ``do_sample`` is truthy.

        Returns:
            ``{"generated_text": text}`` where ``text`` is the first
            generated sequence decoded with special tokens skipped.
        """
        inputs = data.get("inputs", "")
        # A payload carrying `"parameters": null` must behave like a missing key.
        parameters = data.get("parameters") or {}

        encoded = self.tokenizer(inputs, return_tensors="pt")
        device = self.model.device
        input_ids = encoded["input_ids"].to(device)

        # Fall back to EOS for padding when the tokenizer defines no pad token,
        # so `generate` is never handed `pad_token_id=None`.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        do_sample = parameters.get("do_sample", True)
        generate_kwargs: Dict[str, Any] = {
            "do_sample": do_sample,
            "pad_token_id": pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        # Pass the mask explicitly instead of leaving `generate` to infer it.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            generate_kwargs["attention_mask"] = attention_mask.to(device)

        # `max_new_tokens` bounds only the continuation; `max_length` (the
        # historical default) is kept for callers that rely on it.
        if parameters.get("max_new_tokens") is not None:
            generate_kwargs["max_new_tokens"] = parameters["max_new_tokens"]
        else:
            generate_kwargs["max_length"] = parameters.get("max_length", 100)

        # Sampling knobs are only forwarded when sampling is enabled.
        if do_sample:
            generate_kwargs["temperature"] = parameters.get("temperature", 1.0)
            generate_kwargs["top_p"] = parameters.get("top_p", 1.0)

        with torch.no_grad():
            outputs = self.model.generate(input_ids, **generate_kwargs)

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": generated_text}