Spaces:
Sleeping
Sleeping
| """ | |
| Dataset Visualization Module for Jigsaw Toxic Comment Classification | |
| """ | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from wordcloud import WordCloud | |
| from typing import Tuple | |
| import re | |
| import streamlit as st | |
| # Set style for better-looking plots | |
| sns.set_style("whitegrid") | |
| plt.rcParams['figure.figsize'] = (12, 6) | |
| def load_dataset(file_path: str) -> pd.DataFrame: | |
| """Load the train.csv dataset""" | |
| try: | |
| df = pd.read_csv(file_path) | |
| return df | |
| except Exception as e: | |
| st.error(f"Error loading dataset: {str(e)}") | |
| return None | |
| def prepare_data(df: pd.DataFrame) -> Tuple[pd.Series, pd.Series, dict]: | |
| """ | |
| Prepare data for visualization | |
| Returns: toxic_texts, non_toxic_texts, label_counts | |
| """ | |
| # Get label columns | |
| label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] | |
| # Calculate label frequencies | |
| label_counts = df[label_columns].sum().to_dict() | |
| # Create binary column for any toxicity | |
| df['any_toxic'] = df[label_columns].max(axis=1) | |
| # Separate toxic and non-toxic texts | |
| toxic_df = df[df['any_toxic'] == 1] | |
| non_toxic_df = df[df['any_toxic'] == 0] | |
| # Sample for word clouds if dataset is too large | |
| max_samples = 5000 | |
| if len(toxic_df) > max_samples: | |
| toxic_df = toxic_df.sample(n=max_samples, random_state=42) | |
| if len(non_toxic_df) > max_samples: | |
| non_toxic_df = non_toxic_df.sample(n=max_samples, random_state=42) | |
| # Combine text | |
| toxic_texts = ' '.join(toxic_df['comment_text'].astype(str)) | |
| non_toxic_texts = ' '.join(non_toxic_df['comment_text'].astype(str)) | |
| return toxic_texts, non_toxic_texts, label_counts | |
| def clean_text_for_wordcloud(text: str) -> str: | |
| """Clean text for word cloud generation""" | |
| # Remove URLs | |
| text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE) | |
| # Remove special characters but keep spaces | |
| text = re.sub(r'[^\w\s]', '', text) | |
| # Convert to lowercase | |
| text = text.lower() | |
| return text | |
| def create_label_frequency_chart(label_counts: dict): | |
| """Create a bar chart showing label frequencies""" | |
| labels = list(label_counts.keys()) | |
| counts = list(label_counts.values()) | |
| plt.figure(figsize=(10, 6)) | |
| bars = plt.bar(labels, counts, color=['#ff6b6b', '#4ecdc4', '#45b7d1', '#f9ca24', '#f0932b', '#eb4d4b']) | |
| plt.xlabel('Toxicity Type', fontsize=12, fontweight='bold') | |
| plt.ylabel('Count', fontsize=12, fontweight='bold') | |
| plt.title('📊 Label Distribution in Training Dataset', fontsize=14, fontweight='bold', pad=20) | |
| plt.xticks(rotation=45, ha='right') | |
| # Add value labels on bars | |
| for bar in bars: | |
| height = bar.get_height() | |
| plt.text(bar.get_x() + bar.get_width()/2., height, | |
| f'{int(height):,}', ha='center', va='bottom', fontsize=10) | |
| plt.tight_layout() | |
| return plt.gcf() | |
| def create_wordcloud(text: str, title: str, colors: str, width: int = 800, height: int = 400): | |
| """Create a word cloud from text""" | |
| # Clean text | |
| cleaned_text = clean_text_for_wordcloud(text) | |
| # Create word cloud | |
| wordcloud = WordCloud( | |
| width=width, | |
| height=height, | |
| background_color='white', | |
| colormap=colors, | |
| max_words=100, | |
| prefer_horizontal=0.7, | |
| relative_scaling=0.5, | |
| min_font_size=10 | |
| ).generate(cleaned_text) | |
| # Plot | |
| plt.figure(figsize=(12, 6)) | |
| plt.imshow(wordcloud, interpolation='bilinear') | |
| plt.axis('off') | |
| plt.title(title, fontsize=14, fontweight='bold', pad=20) | |
| plt.tight_layout() | |
| return plt.gcf() | |
| def create_toxicity_comparison_chart(df: pd.DataFrame): | |
| """Create a pie chart showing toxic vs non-toxic distribution""" | |
| label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] | |
| df['any_toxic'] = df[label_columns].max(axis=1) | |
| toxic_count = df['any_toxic'].sum() | |
| non_toxic_count = len(df) - toxic_count | |
| plt.figure(figsize=(8, 8)) | |
| colors = ['#95e1d3', '#f38181'] | |
| explode = (0.05, 0.05) | |
| plt.pie( | |
| [non_toxic_count, toxic_count], | |
| labels=['Non-Toxic', 'Toxic'], | |
| autopct='%1.1f%%', | |
| startangle=90, | |
| colors=colors, | |
| explode=explode, | |
| shadow=True, | |
| textprops={'fontsize': 14, 'fontweight': 'bold'} | |
| ) | |
| plt.title('🧩 Toxic vs Non-Toxic Comments', fontsize=16, fontweight='bold', pad=20) | |
| plt.tight_layout() | |
| return plt.gcf() | |
| def create_overlap_heatmap(df: pd.DataFrame): | |
| """Create a heatmap showing label overlaps""" | |
| label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'] | |
| # Calculate pairwise overlaps | |
| overlap_matrix = np.zeros((len(label_columns), len(label_columns))) | |
| for i, label1 in enumerate(label_columns): | |
| for j, label2 in enumerate(label_columns): | |
| overlap = ((df[label1] == 1) & (df[label2] == 1)).sum() | |
| overlap_matrix[i, j] = overlap | |
| # Create heatmap | |
| plt.figure(figsize=(10, 8)) | |
| mask = np.triu(np.ones_like(overlap_matrix, dtype=bool), k=1) | |
| sns.heatmap( | |
| overlap_matrix, | |
| annot=True, | |
| fmt='.0f', | |
| cmap='YlOrRd', | |
| xticklabels=label_columns, | |
| yticklabels=label_columns, | |
| square=True, | |
| cbar_kws={"shrink": 0.8}, | |
| mask=mask, | |
| linewidths=0.5 | |
| ) | |
| plt.title('🔥 Label Co-occurrence Heatmap', fontsize=14, fontweight='bold', pad=20) | |
| plt.tight_layout() | |
| return plt.gcf() | |
| def main_visualization(file_path: str = 'train.csv'): | |
| """Main function to generate all visualizations""" | |
| # Load data | |
| df = load_dataset(file_path) | |
| if df is None: | |
| return None, None, None, None, None | |
| # Prepare data | |
| toxic_texts, non_toxic_texts, label_counts = prepare_data(df) | |
| # Create visualizations | |
| fig1 = create_label_frequency_chart(label_counts) | |
| # Create word clouds | |
| fig2 = create_wordcloud(toxic_texts, "🔴 Most Common Words in Toxic Comments", 'Reds') | |
| fig3 = create_wordcloud(non_toxic_texts, "🟢 Most Common Words in Non-Toxic Comments", 'Greens') | |
| # Create pie chart | |
| fig4 = create_toxicity_comparison_chart(df) | |
| # Create heatmap | |
| fig5 = create_overlap_heatmap(df) | |
| return fig1, fig2, fig3, fig4, fig5 | |
| # Streamlit-specific functions | |
| def load_data_cached(file_path: str): | |
| """Cached version of load_dataset for Streamlit""" | |
| return load_dataset(file_path) | |
| def generate_wordcloud_cached(text: str, colors: str, width: int = 800, height: int = 400): | |
| """Cached wordcloud generation""" | |
| cleaned_text = clean_text_for_wordcloud(text) | |
| wordcloud = WordCloud( | |
| width=width, | |
| height=height, | |
| background_color='white', | |
| colormap=colors, | |
| max_words=100, | |
| prefer_horizontal=0.7, | |
| relative_scaling=0.5, | |
| min_font_size=10 | |
| ).generate(cleaned_text) | |
| return wordcloud | |