kasimali committed
Commit ee405f2 · verified · 1 Parent(s): 875e456

Update app.py

Files changed (1): app.py +120 -182
app.py CHANGED
@@ -1,190 +1,128 @@
  import os
  import re
- import gradio as gr
  import fasttext
  import torch
  from transformers import AutoTokenizer
  from huggingface_hub import hf_hub_download

- # ---------------------------------------------------------------------------------
- # Configuration (thresholds and tokenizer)
- # ---------------------------------------------------------------------------------
- ROMAN_SPLIT_THRESHOLD = float(os.getenv("INDICLID_INPUT_ROMAN_THRESHOLD", "0.5"))  # >50% A–Z => romanized
- FTR_CONF_THRESHOLD = float(os.getenv("INDICLID_ROMAN_CONF_THRESHOLD", "0.6"))  # FTR prob threshold for BERT fallback
- BERT_TOKENIZER_ID = os.getenv("INDICLID_BERT_TOKENIZER", "ai4bharat/IndicBERTv2-MLM-only")
-
- # Persist Hugging Face cache if Space has persistent storage enabled
- os.environ["HF_HOME"] = os.getenv("HF_HOME", "/data/.huggingface")
-
- # Local filenames (no models/ folder)
- FTN_LOCAL = "indiclid_ftn.bin"
- FTR_LOCAL = "indiclid_ftr.bin"
- BERT_LOCAL = "indiclid_bert.pt"
-
- # Repos and filenames confirmed from upstream
- # - FTN fastText (native): model_baseline_roman.bin (as used in official paths in repo)
- # - FTR fastText (roman): model_baseline_roman.bin
- # - BERT fallback: basline_nn_simple.pt
- FTN_REPO = "ai4bharat/IndicLID-FTN"  # file exists in repo; official code references this filename
- FTN_FILENAME = "model_baseline_roman.bin"  # per upstream repo path usage
-
- FTR_REPO = "ai4bharat/IndicLID-FTR"
- FTR_FILENAME = "model_baseline_roman.bin"  # per HF commit/files
-
- BERT_REPO = "ai4bharat/IndicLID-BERT"
- BERT_FILENAME = "basline_nn_simple.pt"  # per HF file listing
-
- # ---------------------------------------------------------------------------------
- # Utilities
- # ---------------------------------------------------------------------------------
- def ensure_artifact(local_path: str, repo_id: str, filename: str):
-     if os.path.exists(local_path):
-         return local_path
-     downloaded = hf_hub_download(repo_id=repo_id, filename=filename)
-     if downloaded != local_path:
-         try:
-             os.rename(downloaded, local_path)
-         except Exception:
-             import shutil
-             shutil.copyfile(downloaded, local_path)
-     return local_path
-
- # ---------------------------------------------------------------------------------
- # Download and load models
- # ---------------------------------------------------------------------------------
- FTN_PATH = ensure_artifact(FTN_LOCAL, FTN_REPO, FTN_FILENAME)  # native-script fastText
- FTR_PATH = ensure_artifact(FTR_LOCAL, FTR_REPO, FTR_FILENAME)  # roman fastText
- BERT_PATH = ensure_artifact(BERT_LOCAL, BERT_REPO, BERT_FILENAME)  # BERT fallback
-
- DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
-
- ftn_model = fasttext.load_model(FTN_PATH)
- ftr_model = fasttext.load_model(FTR_PATH)
-
- # Note: basline_nn_simple.pt is a ready-to-call module in the official pipeline that returns logits
- bert_model = torch.load(BERT_PATH, map_location=DEVICE)
- if hasattr(bert_model, "to"):
-     bert_model = bert_model.to(DEVICE)
- bert_model.eval()
- tokenizer = AutoTokenizer.from_pretrained(BERT_TOKENIZER_ID)
-
- # If the checkpoint exposes label_map_reverse, prefer it; else use a safe placeholder.
- LABEL_MAP_REVERSE = getattr(bert_model, "label_map_reverse", None)
- if LABEL_MAP_REVERSE is None:
-     # Replace with official mapping from the IndicLID inference file for exact codes.
-     LABEL_MAP_REVERSE = {i: f"label_{i}" for i in range(60)}
-
- # ---------------------------------------------------------------------------------
- # Inference helpers
- # ---------------------------------------------------------------------------------
- def roman_char_ratio(text: str) -> float:
-     if not text:
-         return 0.0
-     roman = len(re.findall(r"[A-Za-z]", text))
-     return roman / max(len(text), 1)
-
- def predict_ftn(texts):
-     labels, scores = ftn_model.predict(texts)
-     out = []
-     for t, ls, sc in zip(texts, labels, scores):
-         out.append({
-             "text": t,
-             "label": ls[0].replace("__label__", ""),
-             "score": float(sc[0]),
-             "model": "IndicLID-FTN"
-         })
-     return out
-
- def ftr_predict_or_route(texts):
-     labels, scores = ftr_model.predict(texts)
-     kept, route = [], []
-     for idx, (t, ls, sc) in enumerate(zip(texts, labels, scores)):
-         conf = float(sc[0])
-         lbl = ls[0].replace("__label__", "")
-         if conf > FTR_CONF_THRESHOLD:
-             kept.append({"index": idx, "text": t, "label": lbl, "score": conf, "model": "IndicLID-FTR"})
-         else:
-             route.append((idx, t))
-     return kept, route
-
- @torch.no_grad()
- def bert_predict(indexed_inputs):
-     if not indexed_inputs:
-         return []
-     idxs = [i for i, _ in indexed_inputs]
-     texts = [t for _, t in indexed_inputs]
-     enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
-     for k in enc:
-         enc[k] = enc[k].to(DEVICE)
-     outputs = bert_model(
-         enc["input_ids"],
-         token_type_ids=enc.get("token_type_ids"),
-         attention_mask=enc.get("attention_mask"),
-     )
-     logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
-     probs = torch.softmax(logits, dim=1)
-     preds = torch.argmax(probs, dim=1)
-     scores = probs.gather(1, preds.unsqueeze(1)).squeeze(1)
-     results = []
-     for i, t, p, s in zip(idxs, texts, preds, scores):
-         label_idx = int(p.item())
-         label = LABEL_MAP_REVERSE.get(label_idx, str(label_idx))
-         results.append({"index": i, "text": t, "label": label, "score": float(s.item()), "model": "IndicLID-BERT"})
-     results.sort(key=lambda x: x["index"])
-     return results
-
- def ensemble_predict(texts):
-     roman_inputs, native_inputs = [], []
-     for i, t in enumerate(texts):
-         if roman_char_ratio(t) > ROMAN_SPLIT_THRESHOLD:
-             roman_inputs.append((i, t))
-         else:
-             native_inputs.append((i, t))
-
-     outputs = {}
-
-     if native_inputs:
-         nat_texts = [t for _, t in native_inputs]
-         nat_out = predict_ftn(nat_texts)
-         for (i, _), r in zip(native_inputs, nat_out):
-             outputs[i] = r
-
-     if roman_inputs:
-         rom_texts = [t for _, t in roman_inputs]
-         ftr_kept, bert_inputs = ftr_predict_or_route(rom_texts)
-         for kept in ftr_kept:
-             i_orig = roman_inputs[kept["index"]][0]
-             outputs[i_orig] = {
-                 "text": kept["text"], "label": kept["label"], "score": kept["score"], "model": kept["model"]
-             }
-         if bert_inputs:
-             bert_out = bert_predict(bert_inputs)
-             for r in bert_out:
-                 i_orig = roman_inputs[r["index"]][0]
-                 outputs[i_orig] = {
-                     "text": r["text"], "label": r["label"], "score": r["score"], "model": r["model"]
-                 }
-
-     return [outputs[i] for i in sorted(outputs.keys())]
-
- # ---------------------------------------------------------------------------------
- # Gradio UI
- # ---------------------------------------------------------------------------------
- def detect(texts_str: str):
-     if not texts_str or not texts_str.strip():
-         return []
-     lines = [t.strip() for t in texts_str.split("\n") if t.strip()]
-     return ensemble_predict(lines)
-
- with gr.Blocks(title="IndicLID Ensemble (AI4Bharat) — Gradio Space") as demo:
-     gr.Markdown(
-         "## IndicLID Ensemble (AI4Bharat)\n"
-         "Two-stage LID for 22 Indian languages (47 classes), with native fastText (FTN), roman fastText (FTR), "
-         "and IndicBERT fallback for low-confidence romanized inputs."
-     )
-     inp = gr.Textbox(lines=8, label="Enter text(s) — one per line")
-     out = gr.JSON(label="Predictions")
-     gr.Button("Detect").click(fn=detect, inputs=inp, outputs=out)
-
  if __name__ == "__main__":
-     demo.launch()
  import os
  import re
+ import pandas as pd
  import fasttext
  import torch
+ from torch.utils.data import Dataset, DataLoader
  from transformers import AutoTokenizer
  from huggingface_hub import hf_hub_download

+ # ------------------------------
+ # Download models automatically
+ # ------------------------------
+ print("Downloading IndicLID models from Hugging Face...")
+
+ FTN_PATH = hf_hub_download("ai4bharat/IndicLID-FTN", filename="model_baseline_roman.bin")
+ FTR_PATH = hf_hub_download("ai4bharat/IndicLID-FTR", filename="model_baseline_roman.bin")
+ BERT_PATH = hf_hub_download("ai4bharat/IndicLID-BERT", filename="basline_nn_simple.pt")
+
+ print("Download complete.")
+
+ # ------------------------------
+ # Dataset class for BERT batching
+ # ------------------------------
+ class IndicBERT_Data(Dataset):
+     def __init__(self, indices, X):
+         self.x = list(X)
+         self.i = list(indices)
+     def __len__(self):
+         return len(self.x)
+     def __getitem__(self, idx):
+         return self.i[idx], self.x[idx]
+
+ # ------------------------------
+ # Full IndicLID Ensemble
+ # ------------------------------
+ class IndicLID:
+     def __init__(self, input_threshold=0.5, roman_lid_threshold=0.6):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.IndicLID_FTN = fasttext.load_model(FTN_PATH)
+         self.IndicLID_FTR = fasttext.load_model(FTR_PATH)
+         self.IndicLID_BERT = torch.load(BERT_PATH, map_location=self.device)
+         self.IndicLID_BERT.eval()
+         self.IndicLID_BERT_tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")
+         self.input_threshold = input_threshold
+         self.model_threshold = roman_lid_threshold
+
+         # Official label mapping from AI4Bharat
+         self.label_map_reverse = {
+             0: 'asm_Latn', 1: 'ben_Latn', 2: 'brx_Latn', 3: 'guj_Latn', 4: 'hin_Latn',
+             5: 'kan_Latn', 6: 'kas_Latn', 7: 'kok_Latn', 8: 'mai_Latn', 9: 'mal_Latn',
+             10: 'mni_Latn', 11: 'mar_Latn', 12: 'nep_Latn', 13: 'ori_Latn', 14: 'pan_Latn',
+             15: 'san_Latn', 16: 'snd_Latn', 17: 'tam_Latn', 18: 'tel_Latn', 19: 'urd_Latn',
+             20: 'eng_Latn', 21: 'other', 22: 'asm_Beng', 23: 'ben_Beng', 24: 'brx_Deva',
+             25: 'doi_Deva', 26: 'guj_Gujr', 27: 'hin_Deva', 28: 'kan_Knda', 29: 'kas_Arab',
+             30: 'kas_Deva', 31: 'kok_Deva', 32: 'mai_Deva', 33: 'mal_Mlym', 34: 'mni_Beng',
+             35: 'mni_Meti', 36: 'mar_Deva', 37: 'nep_Deva', 38: 'ori_Orya', 39: 'pan_Guru',
+             40: 'san_Deva', 41: 'sat_Olch', 42: 'snd_Arab', 43: 'tam_Tamil', 44: 'tel_Telu',
+             45: 'urd_Arab'
+         }
+
+     def char_percent_check(self, text):
+         # Fraction of alphabetic characters that are Latin letters
+         total_chars = sum(c.isalpha() for c in text)
+         roman_chars = sum(bool(re.match(r"[A-Za-z]", c)) for c in text)
+         return roman_chars / total_chars if total_chars else 0
+
+     def native_inference(self, data, out_dict):
+         if not data: return out_dict
+         texts = [x[1] for x in data]
+         preds = self.IndicLID_FTN.predict(texts)
+         for (idx, txt), lbls, scrs in zip(data, preds[0], preds[1]):
+             out_dict[idx] = (txt, lbls[0][9:], float(scrs[0]), 'IndicLID-FTN')
+         return out_dict
+
+     def ftr_inference(self, data, out_dict, batch_size):
+         if not data: return out_dict
+         texts = [x[1] for x in data]
+         preds = self.IndicLID_FTR.predict(texts)
+         bert_inputs = []
+         for (idx, txt), lbls, scrs in zip(data, preds[0], preds[1]):
+             if float(scrs[0]) > self.model_threshold:
+                 out_dict[idx] = (txt, lbls[0][9:], float(scrs[0]), 'IndicLID-FTR')
+             else:
+                 bert_inputs.append((idx, txt))
+         return self.bert_inference(bert_inputs, out_dict, batch_size)
+
+     def bert_inference(self, data, out_dict, batch_size):
+         if not data: return out_dict
+         ds = IndicBERT_Data([x[0] for x in data], [x[1] for x in data])
+         dl = DataLoader(ds, batch_size=batch_size)
+         with torch.no_grad():
+             for idxs, texts in dl:
+                 enc = self.IndicLID_BERT_tokenizer(list(texts), return_tensors="pt", padding=True,
+                                                    truncation=True, max_length=512).to(self.device)
+                 outputs = self.IndicLID_BERT(**enc)
+                 preds = torch.argmax(outputs.logits, dim=1)
+                 probs = torch.softmax(outputs.logits, dim=1)
+                 for row, (i, t, p) in enumerate(zip(idxs, texts, preds)):
+                     label = self.label_map_reverse[p.item()]
+                     # Index the softmax output by batch row; indexing by the original
+                     # text index would read the wrong row (or go out of range).
+                     score = probs[row, p].item()
+                     out_dict[i.item()] = (t, label, score, 'IndicLID-BERT')
+         return out_dict
+
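+     # batch_predict routes each input: mostly-Latin text goes to the roman
+     # fastText model, whose low-confidence predictions fall through to
+     # IndicBERT; everything else goes to the native-script fastText model.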
+     def batch_predict(self, texts, batch_size=8):
+         native, roman = [], []
+         for i, t in enumerate(texts):
+             if self.char_percent_check(t) > self.input_threshold:
+                 roman.append((i, t))
+             else:
+                 native.append((i, t))
+         out_dict = {}
+         out_dict = self.native_inference(native, out_dict)
+         out_dict = self.ftr_inference(roman, out_dict, batch_size)
+         return [out_dict[i] for i in sorted(out_dict.keys())]
+
+ # ------------------------------
+ # Run a quick test
+ # ------------------------------
  if __name__ == "__main__":
+     detector = IndicLID()
+     samples = [
+         "यह एक हिंदी वाक्य है।",
+         "ennai pudikkuma?",
+         "ఇది ఒక తెలుగు వాక్యం",
+         "Hello, how are you?"
+     ]
+     results = detector.batch_predict(samples)
+     for text, label, score, model in results:
+         print(f"Text: {text}\nPredicted: {label} | Score: {score:.4f} | Model: {model}\n")