mahmoudsaber0 committed
Commit 080d131 · verified · 1 Parent(s): ab32be2

Update app.py

Files changed (1):
  1. app.py +93 -142
app.py CHANGED
@@ -1,145 +1,96 @@
  import os
- import re
  import torch
- from fastapi import FastAPI
- from pydantic import BaseModel
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- from tokenizers import normalizers
- from tokenizers.normalizers import Sequence, Replace, Strip
- from tokenizers import Regex
-
- # ✅ Environment cache setup for safe deployment
- os.environ["HF_HOME"] = "/tmp"
- os.environ["TRANSFORMERS_CACHE"] = "/tmp"
- os.environ["HF_DATASETS_CACHE"] = "/tmp"
- os.environ["HF_HUB_CACHE"] = "/tmp"
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # ✅ Model & Tokenizer Setup
- model1_path = "modernbert.bin"
- model2_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed12"
- model3_path = "https://huggingface.co/mihalykiss/modernbert_2/resolve/main/Model_groups_3class_seed22"
-
- tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
-
- # ✅ Load 3 models
- model_1 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
- model_1.load_state_dict(torch.load(model1_path, map_location=device))
- model_1.to(device).eval()
-
- model_2 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
- model_2.load_state_dict(torch.hub.load_state_dict_from_url(model2_path, map_location=device))
- model_2.to(device).eval()
-
- model_3 = AutoModelForSequenceClassification.from_pretrained("answerdotai/ModernBERT-base", num_labels=41)
- model_3.load_state_dict(torch.hub.load_state_dict_from_url(model3_path, map_location=device))
- model_3.to(device).eval()
-
- # ✅ Label mapping
- label_mapping = {
-     0: '13B', 1: '30B', 2: '65B', 3: '7B', 4: 'GLM130B', 5: 'bloom_7b',
-     6: 'bloomz', 7: 'cohere', 8: 'davinci', 9: 'dolly', 10: 'dolly-v2-12b',
-     11: 'flan_t5_base', 12: 'flan_t5_large', 13: 'flan_t5_small',
-     14: 'flan_t5_xl', 15: 'flan_t5_xxl', 16: 'gemma-7b-it', 17: 'gemma2-9b-it',
-     18: 'gpt-3.5-turbo', 19: 'gpt-35', 20: 'gpt4', 21: 'gpt4o',
-     22: 'gpt_j', 23: 'gpt_neox', 24: 'human', 25: 'llama3-70b', 26: 'llama3-8b',
-     27: 'mixtral-8x7b', 28: 'opt_1.3b', 29: 'opt_125m', 30: 'opt_13b',
-     31: 'opt_2.7b', 32: 'opt_30b', 33: 'opt_350m', 34: 'opt_6.7b',
-     35: 'opt_iml_30b', 36: 'opt_iml_max_1.3b', 37: 't0_11b', 38: 't0_3b',
-     39: 'text-davinci-002', 40: 'text-davinci-003'
- }
-
- # ✅ Text cleaning and normalization
- def clean_text(text: str) -> str:
-     text = re.sub(r'\s{2,}', ' ', text)
-     text = re.sub(r'\s+([,.;:?!])', r'\1', text)
-     return text
-
- newline_to_space = Replace(Regex(r'\s*\n\s*'), " ")
- join_hyphen_break = Replace(Regex(r'(\w+)[--]\s*\n\s*(\w+)'), r"\1\2")
-
- tokenizer.backend_tokenizer.normalizer = Sequence([
-     tokenizer.backend_tokenizer.normalizer,
-     join_hyphen_break,
-     newline_to_space,
-     Strip()
- ])
-
- # ✅ FastAPI app
- app = FastAPI(title="ModernBERT AI Text Detector")
-
- class InputText(BaseModel):
-     text: str
-
- def classify_text_ensemble(text: str):
-     """Run ensemble classification and return percentages + identified model"""
-     cleaned_text = clean_text(text)
-     if not cleaned_text.strip():
-         return None
-
-     inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True, padding=True).to(device)
-
-     with torch.no_grad():
-         logits_1 = model_1(**inputs).logits
-         logits_2 = model_2(**inputs).logits
-         logits_3 = model_3(**inputs).logits
-
-     softmax_1 = torch.softmax(logits_1, dim=1)
-     softmax_2 = torch.softmax(logits_2, dim=1)
-     softmax_3 = torch.softmax(logits_3, dim=1)
-
-     averaged_probabilities = (softmax_1 + softmax_2 + softmax_3) / 3
-     probabilities = averaged_probabilities[0]
-
-     human_prob = probabilities[24].item()
-     ai_probs_clone = probabilities.clone()
-     ai_probs_clone[24] = 0
-     ai_total_prob = ai_probs_clone.sum().item()
-
-     total = human_prob + ai_total_prob
-     human_percentage = (human_prob / total) * 100
-     ai_percentage = (ai_total_prob / total) * 100
-
-     ai_argmax_index = torch.argmax(ai_probs_clone).item()
-     ai_model_name = label_mapping[ai_argmax_index]
-
-     return {
-         "ai_percentage": round(ai_percentage, 2),
-         "human_percentage": round(human_percentage, 2),
-         "identified_model": ai_model_name,
-         "is_ai": ai_percentage > human_percentage
-     }
-
- @app.get("/")
- def root():
-     return {"message": "ModernBERT AI Text Detector API is running. Use POST /analyze"}
-
  @app.post("/analyze")
- async def analyze(data: InputText):
-     text = data.text.strip()
-     if not text:
-         return {"success": False, "code": 400, "message": "Empty text"}
-
-     result = classify_text_ensemble(text)
-     if not result:
-         return {"success": False, "code": 400, "message": "Text too short or invalid"}
-
-     feedback = (
-         f"The text is {result['human_percentage']}% likely human-written."
-         if not result["is_ai"]
-         else f"The text is {result['ai_percentage']}% likely AI-generated. Identified LLM: {result['identified_model']}."
-     )
-
-     return {
-         "success": True,
-         "code": 200,
-         "data": {
-             "input_text": text,
-             "ai_percentage": result["ai_percentage"],
-             "human_percentage": result["human_percentage"],
-             "identified_model": result["identified_model"],
-             "feedback": feedback,
-             "is_ai": result["is_ai"]
-         }
-     }
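For reference, the removed version served a single POST /analyze route returning human/AI percentages plus the most likely source model. A minimal client sketch against that old contract (the host and port are assumptions; the diff does not specify how the app is served):

# Hypothetical client for the removed ensemble endpoint; assumes the
# server runs locally on port 8000, e.g. via uvicorn app:app.
import requests

resp = requests.post(
    "http://localhost:8000/analyze",
    json={"text": "Sample passage to classify."},
    timeout=30,
)
print(resp.json())
# On success, the removed handler returned:
# {"success": True, "code": 200, "data": {"ai_percentage": ..., "human_percentage": ...,
#  "identified_model": ..., "feedback": ..., "is_ai": ...}}

The rewritten app.py below replaces the three-model ensemble and custom tokenizer normalization with a single text-classification pipeline, and adds CORS middleware and a WebSocket endpoint.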
  import os
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
  import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
+ import asyncio
+
+ # =====================================================
+ # ✅ Fix Hugging Face Cache Permission Errors
+ # =====================================================
+ CACHE_DIR = "/tmp/hf_cache"
+ os.environ["HF_HOME"] = CACHE_DIR
+ os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
+ os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
+ os.environ["HF_HUB_CACHE"] = CACHE_DIR
+ os.makedirs(CACHE_DIR, exist_ok=True)
+
+ # =====================================================
+ # ✅ Initialize Model and Tokenizer
+ # =====================================================
+ MODEL_NAME = "answerdotai/ModernBERT-base"
+
+ print("Loading model and tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ # NOTE: ModernBERT-base is a base checkpoint, so this classification head
+ # is randomly initialized; swap in a fine-tuned checkpoint for real labels.
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
+
+ classifier = pipeline(
+     "text-classification",
+     model=model,
+     tokenizer=tokenizer,
+     device=0 if torch.cuda.is_available() else -1
+ )
+
+ # =====================================================
+ # ✅ FastAPI App Setup
+ # =====================================================
+ app = FastAPI(title="ModernBERT FastAPI Server")
+
+ # Allow all origins (for testing)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # =====================================================
+ # ✅ REST Endpoint Example
+ # =====================================================
  @app.post("/analyze")
+ async def analyze_text(data: dict):
+     try:
+         text = data.get("text", "")
+         if not text.strip():
+             return JSONResponse({"error": "Empty text provided"}, status_code=400)
+
+         result = classifier(text)
+         return {"result": result}
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
+ # =====================================================
+ # ✅ WebSocket Endpoint (real-time classification)
+ # =====================================================
+ @app.websocket("/ws")
+ async def websocket_endpoint(ws: WebSocket):
+     await ws.accept()
+     idle_timeout = 60  # seconds
+     last_activity = asyncio.get_running_loop().time()
+
+     async def close_if_idle():
+         # Close the socket only after a full idle_timeout window
+         # with no incoming messages.
+         while True:
+             await asyncio.sleep(idle_timeout)
+             if asyncio.get_running_loop().time() - last_activity >= idle_timeout:
+                 await ws.close(code=1000)
+                 break
+
+     watchdog = asyncio.create_task(close_if_idle())
+
+     try:
+         while True:
+             message = await ws.receive_text()
+             last_activity = asyncio.get_running_loop().time()
+             if message.lower() in ["exit", "quit"]:
+                 await ws.close(code=1000)
+                 break
+             result = classifier(message)
+             await ws.send_json(result)
+     except WebSocketDisconnect:
+         pass  # client disconnected (or the watchdog closed the socket)
+     finally:
+         watchdog.cancel()  # stop the idle watchdog once the session ends
+
+ # =====================================================
+ # ✅ Root Endpoint
+ # =====================================================
+ @app.get("/")
+ def home():
+     return {"status": "ok", "model": MODEL_NAME, "device": "cuda" if torch.cuda.is_available() else "cpu"}
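A quick way to exercise the new API (the base URL and the third-party websockets client library are assumptions, not part of this commit):

# Hypothetical smoke test for the rewritten endpoints; assumes the server
# is reachable at localhost:8000 and that requests and websockets are installed.
import asyncio
import requests
import websockets

# REST: POST a JSON body carrying a "text" field.
print(requests.post(
    "http://localhost:8000/analyze",
    json={"text": "Hello from the client."},
    timeout=30,
).json())

async def ws_demo():
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send("Classify this sentence.")  # plain text in
        print(await ws.recv())                    # JSON result out
        await ws.send("exit")                     # tells the server to close

asyncio.run(ws_demo())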