Files
JPMC-quant/plots.py
2026-01-22 20:51:57 +01:00

214 lines
6.0 KiB
Python

"""Plotting functions for Voice Branding analysis."""
import plotly.graph_objects as go
import polars as pl
from theme import ColorPalette
def plot_average_scores_with_counts(
    df: pl.DataFrame,
    title: str = "General Impression (1-10)<br>Per Voice with Number of Participants Who Rated It",
    x_label: str = "Stimuli",
    y_label: str = "Average General Impression Rating (1-10)",
    color: str = ColorPalette.PRIMARY,
    height: int = 500,
    width: int = 1000,
) -> go.Figure:
    """
    Create a bar plot of the mean of every column, annotated with its non-null count.

    Bars are ordered by mean, highest first. Columns whose mean is missing
    (null or NaN, e.g. a column with no ratings at all) are placed last so they
    never displace the best-rated entries. A frame without columns yields an
    empty, fully styled figure instead of raising.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame containing numeric columns to analyze. Every column becomes
        one bar.
    title : str, optional
        Plot title.
    x_label : str, optional
        X-axis label.
    y_label : str, optional
        Y-axis label.
    color : str, optional
        Bar color (hex code or named color).
    height : int, optional
        Plot height in pixels.
    width : int, optional
        Plot width in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object with a single bar trace; the y-axis is fixed to
        the range 0-10.
    """
    # Mean and number of non-null values per column.
    stats = []
    for col in df.columns:
        series = df[col]
        stats.append({
            'column': col,
            'average': series.mean(),
            # Same value as drop_nulls().len(), without building a new Series.
            'count': series.len() - series.null_count(),
        })

    def _by_average(entry):
        """Sort key: entries with a usable mean outrank those without one."""
        avg = entry['average']
        has_value = avg is not None and avg == avg  # NaN is the only value != itself
        return (has_value, avg if has_value else 0.0)

    # Highest average first; entries without a usable mean end up last.
    stats.sort(key=_by_average, reverse=True)

    # Bar label is the text after the last "__" (e.g. "V14" from
    # "Voice_Scale_1_10__V14"); names without "__" are used unchanged.
    labels = [entry['column'].split('__')[-1] for entry in stats]
    averages = [entry['average'] for entry in stats]
    counts = [entry['count'] for entry in stats]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=averages,
        text=counts,
        textposition='inside',
        textfont=dict(size=10, color='black'),
        marker_color=color,
        hovertemplate='<b>%{x}</b><br>Average: %{y:.2f}<br>Count: %{text}<extra></extra>'
    ))
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            range=[0, 10],
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        font=dict(size=11)
    )
    return fig
def plot_top3_ranking_distribution(
    df: pl.DataFrame,
    title: str = "Top 3 Rankings Distribution<br>Count of 1st, 2nd, and 3rd Place Votes per Voice",
    x_label: str = "Voices",
    y_label: str = "Number of Mentions in Top 3",
    height: int = 500,
    width: int = 1000,
) -> go.Figure:
    """
    Create a stacked bar chart showing how often each voice was ranked 1st, 2nd, or 3rd.

    The total height of a bar is the number of times the voice appeared in a
    Top 3, while the segments show how those mentions split across the three
    ranks. Columns without any 1/2/3 value are left out. If no column
    qualifies (or the frame has no columns) an empty, fully styled figure is
    returned instead of raising.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame containing ranking columns (values 1, 2, 3). Every column
        is treated as one voice.
    title : str, optional
        Plot title.
    x_label : str, optional
        X-axis label.
    y_label : str, optional
        Y-axis label.
    height : int, optional
        Plot height in pixels.
    width : int, optional
        Plot width in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object with three stacked bar traces (Rank 1, Rank 2,
        Rank 3), bars ordered by total mentions and then by Rank 1 count,
        both descending.
    """
    stats = []
    for col in df.columns:
        series = df[col]
        # Count on the single column rather than filtering the whole frame
        # three times; cells that are null or hold another value never match.
        rank1, rank2, rank3 = (
            series.filter(series == rank).len() for rank in (1, 2, 3)
        )
        total = rank1 + rank2 + rank3
        # Voices that never made a Top 3 are dropped to keep the chart clean.
        if total > 0:
            stats.append({
                'column': col,
                'Rank 1': rank1,
                'Rank 2': rank2,
                'Rank 3': rank3,
                'Total': total
            })

    # Most mentioned overall first; ties broken by the number of 1st places.
    # Sorting the plain list also works when nothing qualified.
    stats.sort(key=lambda entry: (entry['Total'], entry['Rank 1']), reverse=True)

    # Bar label is the text after the last "__"; names without "__" are kept.
    labels = [entry['column'].split('__')[-1] for entry in stats]

    fig = go.Figure()
    # Stack order: Rank 1 at the bottom -> Rank 2 -> Rank 3, so the
    # "first choice" volume shares a common baseline across bars.
    fig.add_trace(go.Bar(
        name='Rank 1 (1st Choice)',
        x=labels,
        y=[entry['Rank 1'] for entry in stats],
        marker_color=ColorPalette.RANK_1,
        hovertemplate='<b>%{x}</b><br>Rank 1: %{y}<extra></extra>'
    ))
    fig.add_trace(go.Bar(
        name='Rank 2 (2nd Choice)',
        x=labels,
        y=[entry['Rank 2'] for entry in stats],
        marker_color=ColorPalette.RANK_2,
        hovertemplate='<b>%{x}</b><br>Rank 2: %{y}<extra></extra>'
    ))
    fig.add_trace(go.Bar(
        name='Rank 3 (3rd Choice)',
        x=labels,
        y=[entry['Rank 3'] for entry in stats],
        marker_color=ColorPalette.RANK_3,
        hovertemplate='<b>%{x}</b><br>Rank 3: %{y}<extra></extra>'
    ))
    fig.update_layout(
        barmode='stack',
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            traceorder="normal"
        ),
        font=dict(size=11)
    )
    return fig