| import gradio as gr | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from evaluation import input_classification | |
| from explainer import CustomExplainer | |
| import numpy as np | |
# Model checkpoint used for both the classifier and its tokenizer.
checkpoint = "Detsutut/medbit-assertion-negation"

# `local_files_only=True` makes transformers resolve the checkpoint from files
# already on disk instead of trying to download it.
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=3,
    local_files_only=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    local_files_only=True,
)

# Explainer wrapping the model/tokenizer pair; it is called once per request
# to obtain the attributions that get highlighted in the UI.
cls_explainer = CustomExplainer(model, tokenizer)

# Example text pre-filled in the input textbox.
sentence = "Il paziente non mostra alcun segno di [entità]."
def _highlight_spans(words_and_exp):
    """Turn (word, attribution) pairs into HighlightedText entity spans.

    Offsets are character positions inside the string obtained by joining
    the words with a single space, which is how ``compute`` builds the
    displayed text.

    Args:
        words_and_exp: sequence of ``(word, attribution)`` pairs.

    Returns:
        A list of ``{"entity": "+" | "-", "start": int, "end": int}`` dicts;
        empty when there are no words.
    """
    # np.percentile is not meaningful on an empty sequence, so bail out early.
    if not words_and_exp:
        return []

    magnitudes = [float(abs(weight)) for _, weight in words_and_exp]
    low_threshold, high_threshold = np.percentile(magnitudes, [25, 75])

    spans = []
    offset = 0
    for word, weight in words_and_exp:
        end = offset + len(word)
        # NOTE(review): both thresholds are percentiles of absolute values and
        # therefore non-negative, so `weight < low_threshold` holds for every
        # negative weight, i.e. all negative words are highlighted while
        # positive ones must exceed the 75th percentile. Behaviour is kept
        # as-is; confirm whether `weight < -high_threshold` was intended.
        if weight < 0 and weight < low_threshold:
            spans.append({"entity": "-", "start": offset, "end": end})
        elif weight > 0 and weight > high_threshold:
            spans.append({"entity": "+", "start": offset, "end": end})
        # +1 skips the single space that separates words in the joined text.
        offset = end + 1
    return spans


def compute(text):
    """Classify ``text`` and build the highlighted explanation for the UI.

    Args:
        text: the input string taken from the "Input Text" textbox.

    Returns:
        A pair ``(output, highlighted)`` where ``output`` is whatever
        ``input_classification`` returns (fed to the label component) and
        ``highlighted`` is a visible ``gr.HighlightedText`` marking words with
        positive ("+", green) or negative ("-", red) attribution.
    """
    output = input_classification(model=model, tokenizer=tokenizer, x=text, all_classes=True)
    words_and_exp = cls_explainer.merge_attributions(cls_explainer(text))

    highlighted = gr.HighlightedText(
        label="Explanation",
        visible=True,
        color_map={"+": "green", "-": "red"},
        value={
            "text": " ".join([word for word, *_ in words_and_exp]),
            "entities": _highlight_spans(words_and_exp),
        },
        combine_adjacent=True,
        adjacent_separator=" ",
    )
    return output, highlighted
# Single-page UI: input box, (initially hidden) explanation, predicted label,
# and a button that triggers `compute`.
with gr.Blocks(title="Inference GUI") as gui:
    text = gr.Textbox(
        label="Input Text",
        value=sentence,
    )
    # Hidden until `compute` returns a visible HighlightedText in its place.
    explanation = gr.HighlightedText(
        label="Explanation",
        visible=False,
    )
    output = gr.Label(
        label="Predicted Label",
        num_top_classes=3,
    )
    compute_btn = gr.Button("Predict")
    compute_btn.click(
        fn=compute,
        inputs=text,
        outputs=[output, explanation],
        api_name="compute",
    )

gui.launch()