# ToxcityDetector / visualization.py
# khushi-18's picture
# Upload 13 files
# 3a4a5df verified
"""
Dataset Visualization Module for Jigsaw Toxic Comment Classification
"""
import re
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
from wordcloud import WordCloud
# Module-wide plot defaults: seaborn's "whitegrid" look, plus a wide default
# canvas for any figure that is created without an explicit figsize.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 6)})
def load_dataset(file_path: str) -> Optional[pd.DataFrame]:
    """Load the train.csv dataset.

    Args:
        file_path: Path of the CSV file, passed straight to ``pd.read_csv``.

    Returns:
        The parsed DataFrame, or ``None`` when reading failed. On failure the
        error message is shown through ``st.error`` instead of being raised,
        so callers must check for ``None``.
    """
    try:
        return pd.read_csv(file_path)
    except Exception as e:
        # Deliberately broad: this is the UI boundary, and any read/parse
        # failure is reported to the user rather than crashing the app.
        st.error(f"Error loading dataset: {e}")
        return None
def prepare_data(df: pd.DataFrame) -> Tuple[str, str, dict]:
    """
    Prepare data for visualization.

    Args:
        df: Comment data with a ``comment_text`` column and the six label
            columns named in ``label_columns`` below. The frame is only
            read; no column is added to it.

    Returns:
        ``(toxic_texts, non_toxic_texts, label_counts)`` where the first two
        are single space-joined strings of the comments that have at least
        one label set / no label set (each side limited to a 5000-row sample),
        and ``label_counts`` maps each label column to its column sum.
    """
    # Get label columns
    label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    # Calculate label frequencies
    label_counts = df[label_columns].sum().to_dict()

    # Row-wise max over the label columns is 1 when any label is set. Kept as
    # a local Series so the caller's DataFrame is not modified as a side effect.
    any_toxic = df[label_columns].max(axis=1)

    # Separate toxic and non-toxic texts
    toxic_df = df[any_toxic == 1]
    non_toxic_df = df[any_toxic == 0]

    # Sample for word clouds if dataset is too large
    max_samples = 5000
    if len(toxic_df) > max_samples:
        toxic_df = toxic_df.sample(n=max_samples, random_state=42)
    if len(non_toxic_df) > max_samples:
        non_toxic_df = non_toxic_df.sample(n=max_samples, random_state=42)

    # Combine text. Missing comments are dropped first; otherwise astype(str)
    # would turn them into literal "nan"/"None" words in the word cloud input.
    toxic_texts = ' '.join(toxic_df['comment_text'].dropna().astype(str))
    non_toxic_texts = ' '.join(non_toxic_df['comment_text'].dropna().astype(str))
    return toxic_texts, non_toxic_texts, label_counts
def clean_text_for_wordcloud(text: str) -> str:
    """Clean text for word cloud generation.

    Removes URL-like tokens, strips every character that is neither a word
    character nor whitespace, and lowercases the result.

    Args:
        text: Raw text to clean.

    Returns:
        The cleaned, lowercased text. Whitespace is left as-is, so a removed
        URL leaves its surrounding spaces behind.
    """
    # Remove URLs. Matching is case-insensitive because lowercasing happens
    # last: an upper-case "HTTP://..." would otherwise survive and be fused
    # into one junk token by the punctuation strip. The \b keeps "www" from
    # matching in the middle of a word such as "awwww".
    text = re.sub(r'http\S+|\bwww\S+', '', text, flags=re.IGNORECASE)
    # Remove special characters but keep spaces
    text = re.sub(r'[^\w\s]', '', text)
    # Convert to lowercase
    return text.lower()
def create_label_frequency_chart(label_counts: dict):
    """Draw one bar per label, annotated with its count, and return the figure.

    Args:
        label_counts: Mapping of label name to count; keys become the x-axis
            categories and values the bar heights.

    Returns:
        The matplotlib figure holding the bar chart.
    """
    names = list(label_counts)
    totals = [label_counts[name] for name in names]
    palette = ['#ff6b6b', '#4ecdc4', '#45b7d1', '#f9ca24', '#f0932b', '#eb4d4b']

    fig, ax = plt.subplots(figsize=(10, 6))
    rects = ax.bar(names, totals, color=palette)

    ax.set_xlabel('Toxicity Type', fontsize=12, fontweight='bold')
    ax.set_ylabel('Count', fontsize=12, fontweight='bold')
    ax.set_title('📊 Label Distribution in Training Dataset', fontsize=14, fontweight='bold', pad=20)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # Write each bar's value just above its top edge, centred horizontally.
    for rect in rects:
        top = rect.get_height()
        centre_x = rect.get_x() + rect.get_width() / 2.
        ax.text(centre_x, top, f'{int(top):,}', ha='center', va='bottom', fontsize=10)

    fig.tight_layout()
    return fig
def create_wordcloud(text: str, title: str, colors: str, width: int = 800, height: int = 400):
    """Render a word cloud of ``text`` on a titled figure and return the figure.

    Args:
        text: Raw text; it is passed through ``clean_text_for_wordcloud`` first.
        title: Title drawn above the image.
        colors: Colormap name handed to ``WordCloud`` as ``colormap``.
        width: Width passed to ``WordCloud``.
        height: Height passed to ``WordCloud``.

    Returns:
        The matplotlib figure showing the generated cloud with axes hidden.
    """
    cloud_options = {
        'width': width,
        'height': height,
        'background_color': 'white',
        'colormap': colors,
        'max_words': 100,
        'prefer_horizontal': 0.7,
        'relative_scaling': 0.5,
        'min_font_size': 10,
    }
    # The cloud is built before any figure exists, so a failure in
    # generate() does not leave an empty figure behind.
    cloud = WordCloud(**cloud_options).generate(clean_text_for_wordcloud(text))

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.imshow(cloud, interpolation='bilinear')
    ax.axis('off')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    fig.tight_layout()
    return fig
def create_toxicity_comparison_chart(df: pd.DataFrame):
    """Create a pie chart showing toxic vs non-toxic distribution.

    A row counts as toxic when the row-wise maximum over the six label
    columns is non-zero. The input frame is only read; no helper column is
    written back to it.

    Args:
        df: Comment data containing the six label columns listed below.

    Returns:
        The matplotlib figure holding the two-slice pie chart.
    """
    label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    # Local Series instead of df['any_toxic'] = ...: drawing a chart should
    # not add columns to the caller's DataFrame.
    any_toxic = df[label_columns].max(axis=1)
    toxic_count = any_toxic.sum()
    non_toxic_count = len(df) - toxic_count

    plt.figure(figsize=(8, 8))
    colors = ['#95e1d3', '#f38181']
    explode = (0.05, 0.05)  # pull both slices slightly away from the centre
    plt.pie(
        [non_toxic_count, toxic_count],
        labels=['Non-Toxic', 'Toxic'],
        autopct='%1.1f%%',
        startangle=90,
        colors=colors,
        explode=explode,
        shadow=True,
        textprops={'fontsize': 14, 'fontweight': 'bold'}
    )
    plt.title('🧩 Toxic vs Non-Toxic Comments', fontsize=16, fontweight='bold', pad=20)
    plt.tight_layout()
    return plt.gcf()
def create_overlap_heatmap(df: pd.DataFrame):
    """Plot how often each pair of labels is set on the same row.

    Cell (i, j) holds the number of rows where both label i and label j
    equal 1; the diagonal therefore holds each label's own count. Only the
    lower triangle (including the diagonal) is drawn.

    Args:
        df: Comment data containing the six label columns listed below.

    Returns:
        The matplotlib figure holding the annotated heatmap.
    """
    label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    # 0/1 indicator matrix (rows x labels). Its Gram matrix counts, for every
    # label pair, the rows where both indicators are 1.
    flags = (df[label_columns] == 1).to_numpy(dtype=float)
    overlap_matrix = flags.T @ flags

    # Hide the strictly-upper triangle: the matrix is symmetric, so those
    # cells only repeat the lower half.
    upper_mask = np.triu(np.ones(overlap_matrix.shape, dtype=bool), k=1)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        overlap_matrix,
        mask=upper_mask,
        annot=True,
        fmt='.0f',
        cmap='YlOrRd',
        square=True,
        linewidths=0.5,
        xticklabels=label_columns,
        yticklabels=label_columns,
        cbar_kws={"shrink": 0.8},
        ax=ax,
    )
    ax.set_title('🔥 Label Co-occurrence Heatmap', fontsize=14, fontweight='bold', pad=20)
    fig.tight_layout()
    return fig
def main_visualization(file_path: str = 'train.csv'):
    """Build every chart for the dataset at ``file_path``.

    Args:
        file_path: CSV path handed to ``load_dataset``.

    Returns:
        A 5-tuple of figures in this order: label frequency bar chart, toxic
        word cloud, non-toxic word cloud, toxic/non-toxic pie chart, label
        co-occurrence heatmap. If the dataset could not be loaded, a 5-tuple
        of ``None`` is returned instead.
    """
    df = load_dataset(file_path)
    if df is None:
        # Keep the 5-slot shape so callers can always unpack the result.
        return (None,) * 5

    toxic_texts, non_toxic_texts, label_counts = prepare_data(df)

    # Tuple elements are evaluated top to bottom, so the figures are created
    # in the same order as they are returned.
    figures = (
        create_label_frequency_chart(label_counts),
        create_wordcloud(toxic_texts, "🔴 Most Common Words in Toxic Comments", 'Reds'),
        create_wordcloud(non_toxic_texts, "🟢 Most Common Words in Non-Toxic Comments", 'Greens'),
        create_toxicity_comparison_chart(df),
        create_overlap_heatmap(df),
    )
    return figures
# Streamlit-specific functions
@st.cache_data
def load_data_cached(file_path: str):
    """Load the dataset through ``load_dataset`` under ``st.cache_data``.

    Args:
        file_path: CSV path forwarded unchanged to ``load_dataset``.

    Returns:
        Whatever ``load_dataset`` returns for ``file_path``.
    """
    dataset = load_dataset(file_path)
    return dataset
@st.cache_data
def generate_wordcloud_cached(text: str, colors: str, width: int = 800, height: int = 400):
    """Build a word cloud under ``st.cache_data`` and return it without plotting.

    Args:
        text: Raw text; it is passed through ``clean_text_for_wordcloud`` first.
        colors: Colormap name handed to ``WordCloud`` as ``colormap``.
        width: Width passed to ``WordCloud``.
        height: Height passed to ``WordCloud``.

    Returns:
        The object returned by ``WordCloud(...).generate(...)``.
    """
    settings = {
        'width': width,
        'height': height,
        'background_color': 'white',
        'colormap': colors,
        'max_words': 100,
        'prefer_horizontal': 0.7,
        'relative_scaling': 0.5,
        'min_font_size': 10,
    }
    generator = WordCloud(**settings)
    return generator.generate(clean_text_for_wordcloud(text))