Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from datasets import load_dataset, concatenate_datasets | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForQuestionAnswering | |
| import spacy | |
| import nltk | |
| from nltk.corpus import stopwords | |
| from nltk.tokenize import word_tokenize | |
| import re | |
| from bs4 import BeautifulSoup | |
# === Data loading and preparation ===
@st.cache_data
def load_data():
    """Load the school-history QA dataset and add preprocessed text columns.

    Returns:
        A pandas DataFrame with the dataset's columns plus ``Pt_question`` and
        ``Pt_answer``: the output of preprocess_text applied to the
        ``question`` and ``answer`` columns.

    Cached with st.cache_data. Streamlit re-executes the script on every
    widget interaction, so without caching the dataset load and the
    per-token spaCy preprocessing would be repeated on each click.
    """
    data = load_dataset('Romyx/ru_QA_school_history', split='train')
    df = pd.DataFrame(data)
    df['Pt_question'] = df['question'].apply(preprocess_text)
    df['Pt_answer'] = df['answer'].apply(preprocess_text)
    return df
@st.cache_resource
def load_model_and_tokenizer(model_name="AlexKay/xlm-roberta-large-qa-multilingual-finedtuned-ru"):
    """Load an extractive question-answering model and its tokenizer.

    Args:
        model_name: Hugging Face Hub id or local path of a checkpoint that has
            a question-answering head (AutoModelForQuestionAnswering). A plain
            encoder such as "bert-base-uncased" has no trained QA head and
            would give meaningless spans. The default checkpoint's name
            indicates an XLM-RoBERTa-large model fine-tuned for Russian QA.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Cached with st.cache_resource so the checkpoint is loaded once per
    process instead of on every Streamlit rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    return tokenizer, model
@st.cache_resource
def build_vectorizer(_df):
    """Fit a TF-IDF index over the preprocessed questions and answers.

    Args:
        _df: DataFrame with ``Pt_question`` and ``Pt_answer`` columns. The
            leading underscore tells Streamlit's cache not to hash this
            argument, so the cached index is reused for the lifetime of the
            process; rebuild it (clear the cache) if the data changes.

    Returns:
        ``(vectorizer, tfidf_matrix)``. Matrix rows ``0 .. len(_df) - 1`` are
        the questions and rows ``len(_df) .. 2 * len(_df) - 1`` are the answers
        in the same order, so matrix row ``r`` corresponds to DataFrame row
        ``r % len(_df)``.
    """
    combined_texts = _df['Pt_question'].tolist() + _df['Pt_answer'].tolist()
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(combined_texts)
    return vectorizer, tfidf_matrix
# === Text preprocessing ===
@st.cache_resource
def _ensure_nltk_data():
    """Make sure the NLTK resources used by this module are installed.

    stopwords.words('russian') reads the "stopwords" corpus, and word_tokenize
    needs the Punkt models ("punkt"; newer NLTK releases load "punkt_tab").
    These are not bundled with the nltk package, so on a fresh host the
    lookups raise LookupError. Missing resources are downloaded once per
    process (the function is cached).
    """
    for resource_path, package in (
        ("corpora/stopwords", "stopwords"),
        ("tokenizers/punkt", "punkt"),
        ("tokenizers/punkt_tab", "punkt_tab"),
    ):
        try:
            nltk.data.find(resource_path)
        except LookupError:
            # Best effort: nltk.download reports a failure (e.g. a package
            # name unknown to an older NLTK, or no network) and returns False
            # instead of raising.
            nltk.download(package, quiet=True)


@st.cache_resource
def _load_spacy_model():
    """Load the Russian spaCy pipeline (used for lemmatization) once per process."""
    try:
        return spacy.load('ru_core_news_lg')
    except OSError as err:
        raise OSError(
            "spaCy model 'ru_core_news_lg' is not installed; add it to the "
            "environment (e.g. `python -m spacy download ru_core_news_lg`)."
        ) from err


_ensure_nltk_data()
nlp = _load_spacy_model()
stop_words = set(stopwords.words('russian'))
# word -> lemma memo used by get_norm_form.
cache_dict = {}
def get_norm_form(word):
    """Return the spaCy lemma of a single token, memoized in ``cache_dict``.

    The word is run through the module-level ``nlp`` pipeline on its own
    (without sentence context) and the lemma of the first resulting token is
    used. If spaCy produces no tokens (e.g. for an empty string), the input
    is returned unchanged instead of raising IndexError.
    """
    try:
        return cache_dict[word]
    except KeyError:
        pass
    doc = nlp(word)
    norm_form = doc[0].lemma_ if len(doc) > 0 else word
    cache_dict[word] = norm_form
    return norm_form
def remove_html_tags(text):
    """Strip HTML markup from *text* using BeautifulSoup's built-in html.parser.

    Text nodes are concatenated with no separator (the default of
    ``get_text()``), so text from adjacent elements is joined directly.
    """
    return BeautifulSoup(text, 'html.parser').get_text()
# Patterns and constants used by preprocess_text, built once at import time.
_PUNCT_RE = re.compile(r'([^\w\s-]|_)')
_SPACES_RE = re.compile(r'\s+')
# A hyphen between two word characters; lookarounds do not consume the
# neighbouring words, so chains like "ростов-на-дону" are fully split.
_INNER_HYPHEN_RE = re.compile(r'(?<=\w)-(?=\w)')
_NUMBER_UNIT_RE = re.compile(r'(\d+)(г|кг|см|м|мм|л|мл)')
# Everything except letters, digits and whitespace (\w alone would keep "_").
_NON_ALNUM_RE = re.compile(r'[^\w\s]|_')
_EXTRA_STOP_WORDS = frozenset({"ответ", "new"})


def preprocess_text(text):
    """Normalize text for TF-IDF matching.

    Steps:
      1. strip HTML and lowercase;
      2. turn punctuation and word-internal hyphens into word separators;
      3. split a number from a directly attached unit ("10кг" -> "10 кг");
      4. remove every character that is not a letter, digit or whitespace;
      5. tokenize, drop NLTK Russian stop words, lemmatize each token with
         get_norm_form, and drop the lemmas "ответ" and "new".

    Args:
        text: Raw text. Missing values (None / NaN) are accepted; other
            non-string values are converted with str().

    Returns:
        The remaining lemmas joined by single spaces ("" for missing input).
    """
    if pd.isna(text):  # pd.isna(None) is True as well
        return ""
    text = remove_html_tags(str(text))
    text = text.lower()
    # Pad punctuation (and "_") with spaces so it separates words when removed.
    text = _PUNCT_RE.sub(r' \1 ', text)
    text = _SPACES_RE.sub(' ', text)
    text = _INNER_HYPHEN_RE.sub(' ', text)
    text = _NUMBER_UNIT_RE.sub(r'\1 \2', text)
    text = _NON_ALNUM_RE.sub('', text)

    tokens = word_tokenize(text)
    lemmas = [get_norm_form(token) for token in tokens if token not in stop_words]
    return ' '.join(lemma for lemma in lemmas if lemma not in _EXTRA_STOP_WORDS)
# === Main answer-lookup function ===
def _select_answer_span(inputs, start_logits, end_logits):
    """Return ``(start, end)`` token indices of the best answer span, or None.

    A span is valid when ``start <= end`` and, if the tokenizer exposes
    sequence ids (fast tokenizers), both ends lie in the context, i.e. the
    second sequence of the (question, context) pair. The span score is
    ``start_logit + end_logit``. Returns None when no valid span exists.
    """
    # The model returns one row of logits per encoded pair; only one is encoded.
    if start_logits.dim() == 2:
        start_logits, end_logits = start_logits[0], end_logits[0]
    seq_len = start_logits.shape[0]
    device = start_logits.device

    try:
        sequence_ids = inputs.sequence_ids(0)
    except (AttributeError, ValueError):
        # Slow tokenizers do not provide sequence ids; allow every position.
        sequence_ids = None
    if sequence_ids is None:
        in_context = torch.ones(seq_len, dtype=torch.bool, device=device)
    else:
        in_context = torch.tensor(
            [sid == 1 for sid in sequence_ids], dtype=torch.bool, device=device
        )

    positions = torch.arange(seq_len, device=device)
    valid = positions[:, None] <= positions[None, :]  # start <= end
    valid &= in_context[:, None] & in_context[None, :]
    if not bool(valid.any()):
        return None

    scores = (start_logits[:, None] + end_logits[None, :]).masked_fill(~valid, float('-inf'))
    start_idx, end_idx = divmod(int(torch.argmax(scores)), seq_len)
    return start_idx, end_idx


def get_answer_from_qa_model(user_question, df, vectorizer, tfidf_matrix, model, tokenizer,
                             min_similarity=0.1):
    """Answer a question by TF-IDF retrieval followed by extractive QA.

    1. Preprocess the question and find the most cosine-similar row of
       ``tfidf_matrix``.
    2. If that similarity exceeds ``min_similarity``, use the matching
       entry's ``answer`` text as context and let the QA model extract the
       best-scoring span from it.

    Assumes ``tfidf_matrix`` is laid out the way build_vectorizer builds it:
    rows ``0 .. len(df) - 1`` are questions and rows
    ``len(df) .. 2 * len(df) - 1`` the corresponding answers, so matrix row
    ``r`` maps to DataFrame row ``r % len(df)``.

    Args:
        user_question: Raw question text entered by the user.
        df: DataFrame with an ``answer`` column, row-aligned with the matrix.
        vectorizer: Fitted vectorizer used to build ``tfidf_matrix``.
        tfidf_matrix: Document-term matrix of the stacked questions/answers.
        model: Hugging Face question-answering model.
        tokenizer: The model's tokenizer.
        min_similarity: Best cosine similarity must be strictly greater than
            this for the question to be answered.

    Returns:
        The extracted answer text, or one of the fallback messages
        "Тема не входит в программу этих классов." (empty index),
        "Извините, я не понимаю вопрос." (no sufficiently similar entry) or
        "Ответ не найден." (no usable context or answer span).
    """
    processed = preprocess_text(user_question)
    user_vec = vectorizer.transform([processed])
    similarities = cosine_similarity(user_vec, tfidf_matrix).ravel()

    n_rows = len(df)
    if similarities.size == 0 or n_rows == 0:
        return "Тема не входит в программу этих классов."

    best_match_idx = int(similarities.argmax())
    if similarities[best_match_idx] <= min_similarity:
        return "Извините, я не понимаю вопрос."

    # A match in the answer half of the matrix belongs to the same DataFrame
    # row as its question; map it back instead of discarding it.
    context = df.iloc[best_match_idx % n_rows]['answer']
    if not isinstance(context, str) or not context.strip():
        return "Ответ не найден."

    inputs = tokenizer(user_question, context, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    span = _select_answer_span(inputs, outputs.start_logits, outputs.end_logits)
    if span is None:
        return "Ответ не найден."
    start_idx, end_idx = span
    answer = tokenizer.decode(
        inputs['input_ids'][0][start_idx:end_idx + 1], skip_special_tokens=True
    ).strip()
    # A span made only of special tokens decodes to an empty string.
    return answer or "Ответ не найден."
# === Streamlit interface ===
st.title("🤖 ИИ-ассистент по истории (на основе вопрос-ответа)")
st.write("Задайте вопрос, и я постараюсь найти на него ответ из базы.")

# Load the data, the QA model and the TF-IDF index. Streamlit re-executes this
# script on every interaction, so these calls run on each rerun unless the
# loader functions are cached.
df = load_data()
tokenizer, model = load_model_and_tokenizer()
vectorizer, tfidf_matrix = build_vectorizer(df)

# Question input and submit button.
user_input = st.text_input("Введите ваш вопрос:")
ask_clicked = st.button("Получить ответ")
has_question = bool(user_input.strip())

if ask_clicked and not has_question:
    st.warning("Пожалуйста, введите вопрос.")
elif ask_clicked:
    with st.spinner("Ищем ответ..."):
        response = get_answer_from_qa_model(
            user_input, df, vectorizer, tfidf_matrix, model, tokenizer
        )
    st.success("Ответ:")
    st.write(response)