sapheneia-timesfm / src /visualization.py
rkovashikawa's picture
Initial Hugging Face Spaces deployment
48abd32
"""
Professional Visualization Module for TimesFM Forecasting
This module provides comprehensive visualization capabilities for TimesFM forecasting,
including professional-grade plots with prediction intervals, covariates displays,
and publication-ready styling.
Key Features:
- Professional forecast visualizations with seamless connections
- Prediction intervals with customizable confidence levels
- Covariates subplots integration
- Sapheneia-style professional formatting
- Interactive and static plot options
- Export capabilities for presentations and publications
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
from datetime import datetime
from typing import List, Dict, Optional, Union
import logging
# Module-level logger, named after this module so that applications can
# configure its verbosity independently.
logger = logging.getLogger(__name__)
# Set professional style
# NOTE: the two calls below run at import time and change process-wide
# matplotlib/seaborn defaults, so they affect every figure created after this
# module is imported, not only figures produced by this module.
# NOTE(review): the 'seaborn-v0_8' style-sheet name presumably needs a
# matplotlib release that ships it (3.6 or newer) - confirm against the pinned
# matplotlib version.
plt.style.use('seaborn-v0_8')
# Applied after the style sheet so that the "husl" palette is the colour cycle
# that ends up in effect.
sns.set_palette("husl")
class Visualizer:
    """
    Professional visualization class for TimesFM forecasting results.

    This class provides methods to create publication-quality visualizations
    of forecasting results, including prediction intervals, covariates analysis,
    and comprehensive time series plots.

    Every plotting method returns the Matplotlib figure it created; closing the
    figure is left to the caller.

    Example:
        >>> viz = Visualizer()
        >>> fig = viz.plot_forecast_with_intervals(
        ...     historical_data=historical,
        ...     forecast=point_forecast,
        ...     intervals=prediction_intervals,
        ...     title="Bitcoin Price Forecast"
        ... )
    """

    # Colour palette and default figure size for each supported style.
    _STYLE_PRESETS = {
        "professional": {
            # Sapheneia professional style
            "colors": {
                'historical': '#1f77b4',
                'forecast': '#d62728',
                'actual': '#2ca02c',
                'interval_80': '#ffb366',
                'interval_50': '#ff7f0e',
                'grid': '#e0e0e0',
                'background': '#fafafa',
            },
            "figsize": (16, 12),
        },
        "minimal": {
            # Clean minimal style
            "colors": {
                'historical': '#2E86AB',
                'forecast': '#A23B72',
                'actual': '#F18F01',
                'interval_80': '#C73E1D',
                'interval_50': '#F18F01',
                'grid': '#f0f0f0',
                'background': 'white',
            },
            "figsize": (14, 10),
        },
        "presentation": {
            # High contrast for presentations
            "colors": {
                'historical': '#003f5c',
                'forecast': '#ff6361',
                'actual': '#58508d',
                'interval_80': '#ffa600',
                'interval_50': '#ff6361',
                'grid': '#e8e8e8',
                'background': 'white',
            },
            "figsize": (18, 14),
        },
    }
    # Style applied when an unrecognised style name is given.
    _FALLBACK_STYLE = "presentation"

    # Colour cycles for quantile bands, covariate panels and compared forecasts.
    _BAND_COLORS = ('#ff9999', '#99ccff', '#99ff99', '#ffcc99', '#cc99ff', '#ffff99')
    _COVARIATE_COLORS = ('#9467bd', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', '#d62728')
    _FORECAST_COLORS = ('#d62728', '#ff7f0e', '#2ca02c', '#9467bd', '#8c564b')

    # Covariate panels are laid out two per row below the main plot, at most
    # three rows deep.
    _COVARIATE_COLUMNS = 2
    _MAX_COVARIATE_PANELS = 6

    def __init__(self, style: str = "professional"):
        """
        Initialize the Visualizer with specified styling.

        Args:
            style: Visualization style ("professional", "minimal", "presentation").
                Any other value falls back to the "presentation" settings.
        """
        self.style = style
        self._setup_style()
        logger.info("Visualizer initialized with '%s' style", style)

    def _setup_style(self) -> None:
        """Set ``self.colors`` and ``self.figsize`` from ``self.style``."""
        preset = self._STYLE_PRESETS.get(self.style) if isinstance(self.style, str) else None
        if preset is None:
            # Unknown names used to be treated as "presentation" silently;
            # keep that behaviour but make a typo visible in the logs.
            logger.warning("Unknown visualization style '%s'; using '%s' settings",
                           self.style, self._FALLBACK_STYLE)
            preset = self._STYLE_PRESETS[self._FALLBACK_STYLE]
        # Copy the palette so per-instance tweaks never leak into the presets.
        self.colors = dict(preset["colors"])
        self.figsize = preset["figsize"]

    # ------------------------------------------------------------------
    # Pure helpers (no plotting)
    # ------------------------------------------------------------------
    @staticmethod
    def _require_history(historical_data) -> None:
        """Raise ValueError when there is no historical point to anchor the forecast to."""
        if np.asarray(historical_data).size == 0:
            raise ValueError("historical_data must contain at least one value")

    @staticmethod
    def _build_time_axes(n_historical, n_forecast, dates_historical, dates_future):
        """
        Build the x coordinates for the historical and the forecast period.

        Without historical dates both periods use consecutive integer
        positions. With historical dates both periods use timestamps, so
        future dates are then mandatory: integer positions and timestamps
        cannot share one axis.

        Returns:
            Tuple ``(historical_x, future_x)``.

        Raises:
            ValueError: If historical dates are given without future dates.
        """
        if dates_historical is None:
            if dates_future is not None:
                logger.warning("dates_future is ignored because dates_historical was not provided")
            return (np.arange(n_historical),
                    np.arange(n_historical, n_historical + n_forecast))
        # `is None` / len() instead of truthiness: NumPy arrays and pandas
        # indexes raise on bool().
        if dates_future is None or len(dates_future) == 0:
            raise ValueError("dates_future is required when dates_historical is provided")
        return pd.to_datetime(dates_historical), pd.to_datetime(dates_future)

    @staticmethod
    def _fit_to_length(values, length):
        """
        Return ``values`` as a list of exactly ``length`` items.

        Longer inputs are truncated, shorter ones are padded with their last
        value. Returns None when there is nothing to plot.
        """
        fitted = list(values)[:length]
        if not fitted:
            return None
        fitted.extend([fitted[-1]] * (length - len(fitted)))
        return fitted

    @staticmethod
    def _confidence_levels(intervals):
        """
        Return the integer levels that have both a ``lower_<n>`` and an
        ``upper_<n>`` entry, widest first so narrower bands are drawn on top.

        Keys whose suffix is not a plain integer (e.g. ``lower_bound``) are
        ignored instead of raising.
        """
        levels = set()
        for key in intervals:
            if not (isinstance(key, str) and key.startswith('lower_')):
                continue
            suffix = key[len('lower_'):]
            if (suffix.isdecimal() and suffix == str(int(suffix))
                    and f'upper_{suffix}' in intervals):
                levels.add(int(suffix))
        return sorted(levels, reverse=True)

    @staticmethod
    def _collect_quantile_bands(intervals):
        """
        Return ``(name, lower, upper)`` for each complete
        ``quantile_band_<name>_lower`` / ``quantile_band_<name>_upper`` pair.

        Numeric names are ordered numerically (so "2" precedes "10"), followed
        by any other names in alphabetical order.
        """
        prefix, suffix = 'quantile_band_', '_lower'
        bands = {}
        for key in intervals:
            if not (isinstance(key, str) and key.startswith(prefix) and key.endswith(suffix)):
                continue
            # Slice instead of str.replace so a band name that itself
            # contains "_lower" is left intact.
            name = key[len(prefix):-len(suffix)]
            upper_key = f'{prefix}{name}_upper'
            if upper_key in intervals:
                bands[name] = (intervals[key], intervals[upper_key])

        def sort_key(name):
            return (0, int(name), name) if name.isdecimal() else (1, 0, name)

        return [(name, bands[name][0], bands[name][1])
                for name in sorted(bands, key=sort_key)]

    @staticmethod
    def _quantile_band_label(intervals, band_name):
        """Return the legend label for a quantile band."""
        label = intervals.get(f'quantile_band_{band_name}_label')
        if label is not None:
            return label
        # The default is only computed when needed: band names are not
        # guaranteed to be integers.
        if band_name.isdecimal():
            return f'Quantile Band {int(band_name)+1}'
        return f'Quantile Band {band_name}'

    # ------------------------------------------------------------------
    # Drawing helpers
    # ------------------------------------------------------------------
    def _draw_confidence_bands(self, ax, connection_x, anchor, intervals, label_suffix):
        """Shade every ``lower_<n>`` / ``upper_<n>`` pair, widest band first."""
        horizon = len(connection_x) - 1
        levels = self._confidence_levels(intervals)
        for level in levels:
            lower = self._fit_to_length(intervals[f'lower_{level}'], horizon)
            upper = self._fit_to_length(intervals[f'upper_{level}'], horizon)
            if lower is None or upper is None:
                logger.warning("Skipping %s%% interval: no values to plot", level)
                continue
            # Bands share one zorder, so they are painted in insertion order
            # (widest first) and always stay below the forecast line.
            ax.fill_between(connection_x, [anchor] + lower, [anchor] + upper,
                            alpha=0.3 if level == levels[0] else 0.5,
                            color=self.colors['interval_80'] if level >= 80
                            else self.colors['interval_50'],
                            label=f'{level}% {label_suffix}', zorder=1)

    def _draw_quantile_bands(self, ax, connection_x, anchor, intervals):
        """Shade every ``quantile_band_<name>`` pair found in ``intervals``."""
        bands = self._collect_quantile_bands(intervals)
        if not bands:
            return
        logger.info("Processing %d quantile bands", len(bands))
        horizon = len(connection_x) - 1
        alpha_span = max(1, len(bands) - 1)
        for position, (band_name, raw_lower, raw_upper) in enumerate(bands):
            # Each bound is fitted on its own so that lower and upper may
            # arrive with different lengths.
            lower = self._fit_to_length(raw_lower, horizon)
            upper = self._fit_to_length(raw_upper, horizon)
            if lower is None or upper is None:
                logger.warning("Skipping quantile band %s: no values to plot", band_name)
                continue
            ax.fill_between(connection_x, [anchor] + lower, [anchor] + upper,
                            # Fade from 0.5 (first band) to 0.3 (last band).
                            alpha=0.3 + (0.2 * (1 - position / alpha_span)),
                            color=self._BAND_COLORS[position % len(self._BAND_COLORS)],
                            label=self._quantile_band_label(intervals, band_name),
                            zorder=1)

    def _draw_actual_future(self, ax, connection_x, anchor, actual_future):
        """Plot observed future values, joined to the last historical point."""
        horizon = len(connection_x) - 1
        actual = list(np.asarray(actual_future))
        if len(actual) > horizon:
            logger.warning("actual_future has %d values but only %d positions are "
                           "available; extra values are not plotted",
                           len(actual), horizon)
            actual = actual[:horizon]
        if not actual:
            return
        # Fewer actual values than forecast steps is normal while the future
        # is still unfolding, so only the matching x positions are used.
        ax.plot(connection_x[:len(actual) + 1], [anchor] + actual,
                color=self.colors['actual'], linewidth=3,
                marker='o', markersize=6, markeredgecolor='white',
                markeredgewidth=1, label='Actual Future', zorder=6)

    def _draw_forecast_panel(self, ax, historical_x, historical_data, future_x,
                             forecast, intervals, actual_future, interval_label):
        """
        Draw history, interval bands, point forecast, actual values and the
        forecast-start marker on ``ax``.

        Returns:
            The x position at which the forecast starts.
        """
        anchor = historical_data[-1]
        # Prepending the last historical point makes every forecast artist
        # start exactly where the history ends.
        connection_x = [historical_x[-1]] + list(future_x)

        ax.set_facecolor(self.colors['background'])
        ax.plot(historical_x, historical_data,
                color=self.colors['historical'], linewidth=2.5,
                label='Historical Data', zorder=5)

        if intervals:
            self._draw_confidence_bands(ax, connection_x, anchor, intervals, interval_label)
            self._draw_quantile_bands(ax, connection_x, anchor, intervals)

        ax.plot(connection_x, [anchor] + list(forecast),
                color=self.colors['forecast'], linestyle='--', linewidth=2.5,
                label='Point Forecast', zorder=4)

        if actual_future is not None:
            self._draw_actual_future(ax, connection_x, anchor, actual_future)

        forecast_start = connection_x[0]
        ax.axvline(x=forecast_start, color='gray', linestyle=':',
                   alpha=0.7, linewidth=1.5, label='Forecast Start')
        return forecast_start

    @staticmethod
    def _stamp_and_save(fig, save_path, description):
        """Add the generation timestamp, tighten the layout and optionally save."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
        fig.text(0.99, 0.01, f'Generated: {timestamp}', ha='right', va='bottom',
                 fontsize=10, alpha=0.7)
        # Operate on the figure itself rather than on pyplot's "current"
        # figure, which may be a different one.
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
            logger.info("%s saved to: %s", description, save_path)

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------
    def plot_forecast_with_intervals(
        self,
        historical_data: Union[List[float], np.ndarray],
        forecast: Union[List[float], np.ndarray],
        intervals: Optional[Dict[str, np.ndarray]] = None,
        actual_future: Optional[Union[List[float], np.ndarray]] = None,
        dates_historical: Optional[List[Union[str, datetime]]] = None,
        dates_future: Optional[List[Union[str, datetime]]] = None,
        title: str = "TimesFM Forecast with Prediction Intervals",
        target_name: str = "Value",
        save_path: Optional[str] = None
    ) -> "plt.Figure":
        """
        Create a professional forecast visualization with prediction intervals.

        Args:
            historical_data: Historical time series data (at least one value).
            forecast: Point forecast values.
            intervals: Dictionary containing prediction intervals. Recognised
                keys are ``lower_<n>`` / ``upper_<n>`` pairs (``n`` an integer
                level) and ``quantile_band_<name>_lower`` /
                ``quantile_band_<name>_upper`` pairs with an optional
                ``quantile_band_<name>_label``. Bands are truncated or padded
                with their last value to match the forecast horizon.
            actual_future: Optional actual future values for comparison; may
                be shorter than the forecast.
            dates_historical: Optional dates for historical data.
            dates_future: Optional dates for forecast period; required when
                ``dates_historical`` is given and ignored otherwise.
            title: Plot title.
            target_name: Name of the target variable (y-axis label).
            save_path: Optional path to save the plot.

        Returns:
            Matplotlib Figure object.

        Raises:
            ValueError: If ``historical_data`` is empty, or if
                ``dates_historical`` is given without ``dates_future``.
        """
        logger.info("Creating forecast visualization: %s", title)

        # Validate before creating a figure so a bad call leaves none behind.
        historical_data = np.asarray(historical_data)
        forecast = np.asarray(forecast)
        self._require_history(historical_data)
        historical_x, future_x = self._build_time_axes(
            len(historical_data), len(forecast), dates_historical, dates_future)

        fig, ax = plt.subplots(figsize=self.figsize)
        self._draw_forecast_panel(ax, historical_x, historical_data, future_x,
                                  forecast, intervals, actual_future,
                                  'Quantile Interval')

        # Styling
        ax.set_title(title, fontsize=18, fontweight='bold', pad=20)
        ax.set_ylabel(target_name, fontsize=14, fontweight='bold')
        ax.set_xlabel('Time', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5, color=self.colors['grid'])

        legend = ax.legend(loc='upper left', fontsize=12, frameon=True,
                           fancybox=True, shadow=True, framealpha=0.95)
        legend.get_frame().set_facecolor('white')
        ax.tick_params(labelsize=12)

        if dates_historical is not None:
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
            # An automatic locator adapts to the plotted span; fixed monthly
            # ticks leave sub-month series almost without ticks and overcrowd
            # very long ones.
            ax.xaxis.set_major_locator(mdates.AutoDateLocator())
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)

        self._stamp_and_save(fig, save_path, "Plot")
        logger.info("✅ Forecast visualization completed")
        return fig

    def plot_forecast_with_covariates(
        self,
        historical_data: Union[List[float], np.ndarray],
        forecast: Union[List[float], np.ndarray],
        covariates_data: Dict[str, Dict[str, Union[List[float], float, str]]],
        intervals: Optional[Dict[str, np.ndarray]] = None,
        actual_future: Optional[Union[List[float], np.ndarray]] = None,
        dates_historical: Optional[List[Union[str, datetime]]] = None,
        dates_future: Optional[List[Union[str, datetime]]] = None,
        title: str = "TimesFM Forecast with Covariates Analysis",
        target_name: str = "Target Value",
        save_path: Optional[str] = None
    ) -> "plt.Figure":
        """
        Create a comprehensive visualization with main forecast and covariates subplots.

        Only covariates given as a dict with a ``'historical'`` entry get a
        panel; an optional ``'future'`` entry is drawn as a dashed
        continuation. At most six panels are drawn. When no covariate
        qualifies, the call is delegated to ``plot_forecast_with_intervals``.

        Args:
            historical_data: Historical time series data (at least one value).
            forecast: Point forecast values.
            covariates_data: Dictionary containing covariates information.
            intervals: Optional prediction intervals (same keys as for
                ``plot_forecast_with_intervals``).
            actual_future: Optional actual future values; may be shorter than
                the forecast.
            dates_historical: Optional historical dates.
            dates_future: Optional future dates; required when
                ``dates_historical`` is given and ignored otherwise.
            title: Main plot title.
            target_name: Name of target variable.
            save_path: Optional save path.

        Returns:
            Matplotlib Figure object.

        Raises:
            ValueError: If ``historical_data`` is empty, or if
                ``dates_historical`` is given without ``dates_future``.
        """
        logger.info("Creating comprehensive forecast with covariates: %s", title)

        panels = [(name, data) for name, data in (covariates_data or {}).items()
                  if isinstance(data, dict) and 'historical' in data]
        if not panels:
            return self.plot_forecast_with_intervals(
                historical_data, forecast, intervals, actual_future,
                dates_historical, dates_future, title, target_name, save_path
            )
        if len(panels) > self._MAX_COVARIATE_PANELS:
            logger.warning("Only the first %d of %d covariates are plotted",
                           self._MAX_COVARIATE_PANELS, len(panels))
            panels = panels[:self._MAX_COVARIATE_PANELS]

        # Validate before creating a figure so a bad call leaves none behind.
        historical_data = np.asarray(historical_data)
        forecast = np.asarray(forecast)
        self._require_history(historical_data)
        historical_x, future_x = self._build_time_axes(
            len(historical_data), len(forecast), dates_historical, dates_future)
        use_dates = dates_historical is not None

        # One tall row for the main plot, then one short row per pair of covariates.
        cols = self._COVARIATE_COLUMNS
        covariate_rows = -(-len(panels) // cols)  # ceiling division
        fig = plt.figure(figsize=(18, 14))
        gs = fig.add_gridspec(1 + covariate_rows, cols,
                              height_ratios=[3] + [1] * covariate_rows,
                              hspace=0.35, wspace=0.25)

        # Main forecast plot (top row, full width)
        ax_main = fig.add_subplot(gs[0, :])
        forecast_start = self._draw_forecast_panel(
            ax_main, historical_x, historical_data, future_x,
            forecast, intervals, actual_future, 'Prediction Interval')

        ax_main.set_title(title, fontsize=18, fontweight='bold', pad=20)
        ax_main.set_ylabel(target_name, fontsize=14, fontweight='bold')
        ax_main.grid(True, alpha=0.3, color=self.colors['grid'])
        ax_main.tick_params(labelsize=12)
        legend = ax_main.legend(loc='upper left', fontsize=12, frameon=True)
        legend.get_frame().set_facecolor('white')

        # Covariate subplots
        for panel_idx, (cov_name, cov_data) in enumerate(panels):
            ax_cov = fig.add_subplot(gs[1 + panel_idx // cols, panel_idx % cols])
            color = self._COVARIATE_COLORS[panel_idx % len(self._COVARIATE_COLORS)]

            cov_history = list(cov_data['historical'])
            ax_cov.plot(historical_x, cov_history,
                        color=color, linewidth=2.5, alpha=0.8, label='Historical')

            if 'future' in cov_data:
                combined_y = cov_history + list(cov_data['future'])
                if use_dates:
                    combined_x = list(historical_x) + list(future_x)
                else:
                    combined_x = list(range(len(combined_y)))
                # Start one point early so the dashed segment joins the
                # history, and stop where either x or y runs out.
                start = max(len(cov_history) - 1, 0)
                stop = min(len(combined_x), len(combined_y))
                ax_cov.plot(combined_x[start:stop], combined_y[start:stop],
                            color=color, linewidth=2.5, linestyle='--', alpha=0.9,
                            marker='s', markersize=4, label='Future')

            ax_cov.axvline(x=forecast_start, color='gray', linestyle=':', alpha=0.5)

            ax_cov.set_title(f'{cov_name.replace("_", " ").title()}',
                             fontsize=12, fontweight='bold')
            ax_cov.set_ylabel('Value', fontsize=10)
            ax_cov.grid(True, alpha=0.3, color=self.colors['grid'])
            ax_cov.tick_params(labelsize=9)
            ax_cov.legend(fontsize=8, loc='upper left')
            ax_cov.set_facecolor(self.colors['background'])

        # Format x-axis for dates
        if use_dates:
            for axis in fig.get_axes():
                axis.xaxis.set_major_formatter(mdates.DateFormatter('%m/%d'))
                plt.setp(axis.xaxis.get_majorticklabels(), rotation=45)

        fig.suptitle('TimesFM Comprehensive Forecasting Analysis',
                     fontsize=20, fontweight='bold', y=0.98)
        self._stamp_and_save(fig, save_path, "Comprehensive plot")
        logger.info("✅ Comprehensive forecast visualization completed")
        return fig

    def plot_forecast_comparison(
        self,
        forecasts_dict: Dict[str, np.ndarray],
        historical_data: Union[List[float], np.ndarray],
        actual_future: Optional[Union[List[float], np.ndarray]] = None,
        title: str = "Forecast Methods Comparison",
        save_path: Optional[str] = None
    ) -> "plt.Figure":
        """
        Compare multiple forecasting methods in a single plot.

        All series are drawn against integer positions; each forecast may have
        its own length.

        Args:
            forecasts_dict: Dictionary of {method_name: forecast_array}.
            historical_data: Historical data for context (at least one value).
            actual_future: Optional actual future values.
            title: Plot title.
            save_path: Optional save path.

        Returns:
            Matplotlib Figure object.

        Raises:
            ValueError: If ``historical_data`` is empty.
        """
        logger.info("Creating forecast comparison plot: %s", title)

        historical_data = np.asarray(historical_data)
        self._require_history(historical_data)
        n_history = len(historical_data)
        last_position = n_history - 1
        anchor = historical_data[-1]

        fig, ax = plt.subplots(figsize=self.figsize)
        ax.set_facecolor(self.colors['background'])
        ax.plot(np.arange(n_history), historical_data,
                color=self.colors['historical'], linewidth=2.5,
                label='Historical Data', zorder=5)

        for method_idx, (method, forecast) in enumerate(forecasts_dict.items()):
            values = list(np.asarray(forecast))
            # x starts at the last historical position for a seamless join.
            positions = list(range(last_position, n_history + len(values)))
            ax.plot(positions, [anchor] + values,
                    color=self._FORECAST_COLORS[method_idx % len(self._FORECAST_COLORS)],
                    linestyle='--' if method_idx == 0 else '-.', linewidth=2.5,
                    label=f'{method} Forecast', zorder=3)

        if actual_future is not None:
            actual = list(np.asarray(actual_future))
            positions = list(range(last_position, n_history + len(actual)))
            self._draw_actual_future(ax, positions, anchor, actual)

        ax.axvline(x=last_position, color='gray', linestyle=':',
                   alpha=0.7, linewidth=1.5, label='Forecast Start')

        # Styling
        ax.set_title(title, fontsize=18, fontweight='bold', pad=20)
        ax.set_ylabel('Value', fontsize=14, fontweight='bold')
        ax.set_xlabel('Time', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3, color=self.colors['grid'])
        ax.tick_params(labelsize=12)

        legend = ax.legend(loc='upper left', fontsize=12, frameon=True)
        legend.get_frame().set_facecolor('white')

        self._stamp_and_save(fig, save_path, "Comparison plot")
        logger.info("✅ Forecast comparison visualization completed")
        return fig