kasimali commited on
Commit
3de163f
·
verified ·
1 Parent(s): ee405f2

Update indiclid_inference.py

Browse files
Files changed (1) hide show
  1. indiclid_inference.py +134 -0
indiclid_inference.py CHANGED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+ import fasttext
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from transformers import AutoTokenizer
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ # ------------------------------
11
+ # Download required models
12
+ # ------------------------------
13
+ print("Downloading IndicLID models from Hugging Face...")
14
+
15
+ FTN_PATH = hf_hub_download("ai4bharat/IndicLID-FTN", filename="model_baseline_roman.bin")
16
+ FTR_PATH = hf_hub_download("ai4bharat/IndicLID-FTR", filename="model_baseline_roman.bin")
17
+ BERT_PATH = hf_hub_download("ai4bharat/IndicLID-BERT", filename="basline_nn_simple.pt")
18
+
19
+ print("Download complete.")
20
+
21
+ # ------------------------------
22
+ # Data helper for BERT batching
23
+ # ------------------------------
24
+ class IndicBERT_Data(Dataset):
25
+ def __init__(self, indices, X):
26
+ self.x = list(X)
27
+ self.i = list(indices)
28
+ def __len__(self):
29
+ return len(self.x)
30
+ def __getitem__(self, idx):
31
+ return self.i[idx], self.x[idx]
32
+
33
+ # ------------------------------
34
+ # Main IndicLID Class
35
+ # ------------------------------
36
+ class IndicLID:
37
+ def __init__(self, input_threshold=0.5, roman_lid_threshold=0.6):
38
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+ self.IndicLID_FTN = fasttext.load_model(FTN_PATH)
40
+ self.IndicLID_FTR = fasttext.load_model(FTR_PATH)
41
+ self.IndicLID_BERT = torch.load(BERT_PATH, map_location=self.device)
42
+ self.IndicLID_BERT.eval()
43
+ self.IndicLID_BERT_tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")
44
+ self.input_threshold = input_threshold
45
+ self.model_threshold = roman_lid_threshold
46
+
47
+ # Official label map (index -> language code)
48
+ self.label_map_reverse = {
49
+ 0: 'asm_Latn', 1: 'ben_Latn', 2: 'brx_Latn', 3: 'guj_Latn',
50
+ 4: 'hin_Latn', 5: 'kan_Latn', 6: 'kas_Latn', 7: 'kok_Latn',
51
+ 8: 'mai_Latn', 9: 'mal_Latn', 10: 'mni_Latn', 11: 'mar_Latn',
52
+ 12: 'nep_Latn', 13: 'ori_Latn', 14: 'pan_Latn', 15: 'san_Latn',
53
+ 16: 'snd_Latn', 17: 'tam_Latn', 18: 'tel_Latn', 19: 'urd_Latn',
54
+ 20: 'eng_Latn', 21: 'other', 22: 'asm_Beng', 23: 'ben_Beng',
55
+ 24: 'brx_Deva', 25: 'doi_Deva', 26: 'guj_Gujr', 27: 'hin_Deva',
56
+ 28: 'kan_Knda', 29: 'kas_Arab', 30: 'kas_Deva', 31: 'kok_Deva',
57
+ 32: 'mai_Deva', 33: 'mal_Mlym', 34: 'mni_Beng', 35: 'mni_Meti',
58
+ 36: 'mar_Deva', 37: 'nep_Deva', 38: 'ori_Orya', 39: 'pan_Guru',
59
+ 40: 'san_Deva', 41: 'sat_Olch', 42: 'snd_Arab', 43: 'tam_Tamil',
60
+ 44: 'tel_Telu', 45: 'urd_Arab'
61
+ }
62
+
63
+ def char_percent_check(self, text):
64
+ total_chars = sum(c.isalpha() for c in text)
65
+ roman_chars = sum(bool(re.match(r"[A-Za-z]", c)) for c in text)
66
+ return roman_chars / total_chars if total_chars else 0
67
+
68
+ def native_inference(self, data, out_dict):
69
+ if not data: return out_dict
70
+ texts = [x[1] for x in data]
71
+ preds = self.IndicLID_FTN.predict(texts)
72
+ for (idx, txt), lbls, scrs in zip(data, preds[0], preds[1]):
73
+ out_dict[idx] = (txt, lbls[0][9:], float(scrs[0]), 'IndicLID-FTN')
74
+ return out_dict
75
+
76
+ def ftr_inference(self, data, out_dict, batch_size):
77
+ if not data: return out_dict
78
+ texts = [x[1] for x in data]
79
+ preds = self.IndicLID_FTR.predict(texts)
80
+ bert_inputs = []
81
+ for (idx, txt), lbls, scrs in zip(data, preds[0], preds[1]):
82
+ if float(scrs[0]) > self.model_threshold:
83
+ out_dict[idx] = (txt, lbls[0][9:], float(scrs[0]), 'IndicLID-FTR')
84
+ else:
85
+ bert_inputs.append((idx, txt))
86
+ return self.bert_inference(bert_inputs, out_dict, batch_size)
87
+
88
+ def bert_inference(self, data, out_dict, batch_size):
89
+ if not data: return out_dict
90
+ ds = IndicBERT_Data([x[0] for x in data], [x[1] for x in data])
91
+ dl = DataLoader(ds, batch_size=batch_size)
92
+ with torch.no_grad():
93
+ for idxs, texts in dl:
94
+ enc = self.IndicLID_BERT_tokenizer(
95
+ list(texts), return_tensors="pt", padding=True,
96
+ truncation=True, max_length=512
97
+ ).to(self.device)
98
+ outputs = self.IndicLID_BERT(**enc)
99
+ preds = torch.argmax(outputs.logits, dim=1)
100
+ probs = torch.softmax(outputs.logits, dim=1)
101
+ for batch_i, p in enumerate(preds):
102
+ i = idxs[batch_i].item()
103
+ label_idx = p.item()
104
+ label = self.label_map_reverse[label_idx]
105
+ score = probs[batch_i, label_idx].item()
106
+ out_dict[i] = (texts[batch_i], label, score, 'IndicLID-BERT')
107
+ return out_dict
108
+
109
+ def batch_predict(self, texts, batch_size=8):
110
+ native, roman = [], []
111
+ for i, t in enumerate(texts):
112
+ if self.char_percent_check(t) > self.input_threshold:
113
+ roman.append((i, t))
114
+ else:
115
+ native.append((i, t))
116
+ out_dict = {}
117
+ out_dict = self.native_inference(native, out_dict)
118
+ out_dict = self.ftr_inference(roman, out_dict, batch_size)
119
+ return [out_dict[i] for i in sorted(out_dict.keys())]
120
+
121
+ # ------------------------------
122
+ # Quick test if run directly
123
+ # ------------------------------
124
+ if __name__ == "__main__":
125
+ detector = IndicLID()
126
+ samples = [
127
+ "यह एक हिंदी वाक्य है।", # Hindi (native)
128
+ "ennai pudikkuma?", # Tamil (romanized)
129
+ "ఇది ఒక తెలుగు వాక్యం", # Telugu (native)
130
+ "Hello, how are you?" # English
131
+ ]
132
+ results = detector.batch_predict(samples)
133
+ for text, label, score, model in results:
134
+ print(f"Text: {text}\nPredicted: {label} | Score: {score:.4f} | Model: {model}\n")