Files
JPMC-quant/plots.py

857 lines
24 KiB
Python

"""Plotting functions for Voice Branding analysis."""
import plotly.graph_objects as go
import polars as pl
from theme import ColorPalette
def plot_average_scores_with_counts(
    df: pl.DataFrame,
    title: str = "General Impression (1-10)<br>Per Voice with Number of Participants Who Rated It",
    x_label: str = "Stimuli",
    y_label: str = "Average General Impression Rating (1-10)",
    color: str = ColorPalette.PRIMARY,
    height: int = 500,
    width: int = 1000,
) -> go.Figure:
    """
    Create a bar plot showing average scores and count of non-null values for each column.

    Every column except ``_recordId`` is treated as one stimulus. Bars are
    sorted by average score, highest first. Columns with no non-null values
    (whose mean is therefore null) are placed at the end instead of the start.
    The number printed inside each bar is the count of non-null ratings for
    that column.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame containing numeric columns to analyze. A ``_recordId``
        column, if present, is ignored.
    title : str, optional
        Plot title.
    x_label : str, optional
        X-axis label.
    y_label : str, optional
        Y-axis label.
    color : str, optional
        Bar color (hex code or named color).
    height : int, optional
        Plot height in pixels.
    width : int, optional
        Plot width in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object. An empty figure is returned when ``df`` has
        no columns other than ``_recordId``.
    """
    value_cols = [c for c in df.columns if c != '_recordId']
    if not value_cols:
        # Nothing to plot. Sorting a stats frame built from an empty list
        # would raise, because the 'average' column would not exist.
        return go.Figure()

    # Average and number of non-null values for each column.
    stats = [
        {
            'column': col,
            'average': df[col].mean(),
            'count': df[col].drop_nulls().len(),
        }
        for col in value_cols
    ]
    # An all-null column has a null mean. With the default nulls_last=False,
    # polars would place it first, ahead of the best-rated voice.
    stats_df = pl.DataFrame(stats).sort('average', descending=True, nulls_last=True)

    # Extract voice identifiers from column names (e.g. "V14" from
    # "Voice_Scale_1_10__V14"). Names without "__" are kept whole.
    labels = [col.split('__')[-1] for col in stats_df['column']]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=stats_df['average'],
        text=stats_df['count'],
        textposition='inside',
        textfont=dict(size=10, color='black'),
        marker_color=color,
        hovertemplate='<b>%{x}</b><br>Average: %{y:.2f}<br>Count: %{text}<extra></extra>'
    ))
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            range=[0, 10],
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        font=dict(size=11)
    )
    return fig
def plot_top3_ranking_distribution(
    df: pl.DataFrame,
    title: str = "Top 3 Rankings Distribution<br>Count of 1st, 2nd, and 3rd Place Votes per Voice",
    x_label: str = "Voices",
    y_label: str = "Number of Mentions in Top 3",
    height: int = 500,
    width: int = 1000,
) -> go.Figure:
    """
    Create a stacked bar chart showing how often each voice was ranked 1st, 2nd, or 3rd.

    The total height of a bar is how often the voice appeared in a Top 3
    (popularity). The segments show how those mentions split across ranks.
    Bars are sorted by total mentions, with ties broken by the number of
    1st-place votes.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame containing ranking columns (values 1, 2, 3). A
        ``_recordId`` column, if present, is ignored. Columns that never
        contain 1, 2 or 3 are omitted from the chart.
    title : str, optional
        Plot title.
    x_label : str, optional
        X-axis label.
    y_label : str, optional
        Y-axis label.
    height : int, optional
        Plot height in pixels.
    width : int, optional
        Plot width in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object. An empty figure is returned when no column
        contains a value equal to 1, 2 or 3.
    """
    stats = []
    for col in [c for c in df.columns if c != '_recordId']:
        # Count rows holding exactly the integer ranks 1, 2 and 3.
        rank1 = df.filter(pl.col(col) == 1).height
        rank2 = df.filter(pl.col(col) == 2).height
        rank3 = df.filter(pl.col(col) == 3).height
        total = rank1 + rank2 + rank3
        # Skip voices that were never placed in a Top 3 to keep the chart clean.
        if total > 0:
            stats.append({
                'column': col,
                'Rank 1': rank1,
                'Rank 2': rank2,
                'Rank 3': rank3,
                'Total': total,
            })

    if not stats:
        # No votes at all. Sorting an empty stats frame on 'Total' would
        # raise, so return an empty figure instead.
        return go.Figure()

    # Most popular overall first; tie-break by number of 1st-place votes.
    stats_df = pl.DataFrame(stats).sort(['Total', 'Rank 1'], descending=[True, True])

    # Extract voice identifiers from column names (the part after the last "__").
    labels = [col.split('__')[-1] for col in stats_df['column']]

    fig = go.Figure()
    # Traces are added in stack order: Rank 1 at the bottom, then Rank 2,
    # then Rank 3. This gives the first-choice segment a shared baseline,
    # so it is easy to compare across bars.
    for trace_name, rank_key, rank_color in (
        ('Rank 1 (1st Choice)', 'Rank 1', ColorPalette.RANK_1),
        ('Rank 2 (2nd Choice)', 'Rank 2', ColorPalette.RANK_2),
        ('Rank 3 (3rd Choice)', 'Rank 3', ColorPalette.RANK_3),
    ):
        fig.add_trace(go.Bar(
            name=trace_name,
            x=labels,
            y=stats_df[rank_key],
            marker_color=rank_color,
            hovertemplate=f'<b>%{{x}}</b><br>{rank_key}: %{{y}}<extra></extra>'
        ))

    fig.update_layout(
        barmode='stack',
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            traceorder="normal"
        ),
        font=dict(size=11)
    )
    return fig
def plot_ranking_distribution(
    df: pl.DataFrame,
    title: str = "Rankings Distribution<br>(1st to 4th Place)",
    x_label: str = "Item",
    y_label: str = "Number of Votes",
    height: int = 500,
    width: int = 1000,
) -> go.Figure:
    """
    Build a stacked bar chart of how many 1st-4th place votes each item received.

    Every column other than ``_recordId`` is treated as one ranked item
    (character or voice). Items with no vote of 1, 2, 3 or 4 are dropped.
    Bars are ordered by the number of 1st-place votes, then by the number
    of 2nd-place votes, both descending.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame whose columns hold the rank each respondent gave an item.
    title : str, optional
        Plot title.
    x_label : str, optional
        X-axis label.
    y_label : str, optional
        Y-axis label.
    height : int, optional
        Plot height in pixels.
    width : int, optional
        Plot width in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object. An empty figure is returned when no item
        received any rank from 1 to 4.
    """
    ranked_items = [name for name in df.columns if name != '_recordId']

    rows = []
    for item in ranked_items:
        # Row counts for ranks 1..4, keyed 'Rank 1' .. 'Rank 4'.
        per_rank = {f'Rank {r}': df.filter(pl.col(item) == r).height for r in range(1, 5)}
        if sum(per_rank.values()) > 0:
            rows.append({'column': item, **per_rank})

    if not rows:
        return go.Figure()

    # Winner first: most 1st-place votes, then most 2nd-place votes.
    ordered = pl.DataFrame(rows).sort(['Rank 1', 'Rank 2'], descending=[True, True])

    # Strip known prefixes and turn underscores into spaces,
    # e.g. "Character_Ranking_The_Coach" -> "The Coach".
    def _pretty(raw: str) -> str:
        stripped = raw.replace('Character_Ranking_', '').replace('Top_3_Voices_ranking__', '')
        return stripped.replace('_', ' ').strip()

    tick_labels = [_pretty(raw) for raw in ordered['column']]

    # (legend name, stats column, colour, hover template), in stack order
    # from bottom (best rank) to top (worst rank).
    segments = [
        ('Rank 1 (Best)', 'Rank 1', ColorPalette.RANK_1,
         '<b>%{x}</b><br>Rank 1: %{y}<extra></extra>'),
        ('Rank 2', 'Rank 2', ColorPalette.RANK_2,
         '<b>%{x}</b><br>Rank 2: %{y}<extra></extra>'),
        ('Rank 3', 'Rank 3', ColorPalette.RANK_3,
         '<b>%{x}</b><br>Rank 3: %{y}<extra></extra>'),
        ('Rank 4 (Worst)', 'Rank 4', ColorPalette.RANK_4,
         '<b>%{x}</b><br>Rank 4: %{y}<extra></extra>'),
    ]

    figure = go.Figure()
    for legend_name, stat_key, shade, hover in segments:
        figure.add_trace(go.Bar(
            name=legend_name,
            x=tick_labels,
            y=ordered[stat_key],
            marker_color=shade,
            hovertemplate=hover,
        ))

    axis_grid = dict(showgrid=True, gridcolor=ColorPalette.GRID)
    figure.update_layout(
        barmode='stack',
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(**axis_grid, tickangle=-45),
        yaxis=dict(**axis_grid),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            traceorder="normal",
        ),
        font=dict(size=11),
    )
    return figure
def plot_most_ranked_1(
    df: pl.DataFrame,
    title: str = "Most Popular Choice<br>(Number of Times Ranked 1st)",
    x_label: str = "Item",
    y_label: str = "Count of 1st Place Rankings",
    height: int = 500,
    width: int = 1000,
) -> go.Figure:
    """
    Create a bar chart showing which item (character/voice) was ranked #1 the most.

    Every column except ``_recordId`` is one item. Bars are sorted by
    1st-place count in descending order, with ties broken alphabetically by
    column name so the order, and therefore the highlight, is
    deterministic. The first three bars are drawn in the primary colour and
    the rest in a neutral colour.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame containing ranking columns.
    title : str, optional
        Plot title.
    x_label : str, optional
        X-axis label.
    y_label : str, optional
        Y-axis label.
    height : int, optional
        Plot height in pixels.
    width : int, optional
        Plot width in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object. An empty figure is returned when ``df`` has
        no columns other than ``_recordId``.
    """
    ranking_cols = [c for c in df.columns if c != '_recordId']
    if not ranking_cols:
        # Sorting a stats frame built from an empty list would raise
        # (no 'count' column), so bail out early.
        return go.Figure()

    # Number of rows in which each item was ranked exactly 1.
    stats = [
        {'column': col, 'count': df.filter(pl.col(col) == 1).height}
        for col in ranking_cols
    ]
    # The secondary key on 'column' makes tie ordering deterministic. A
    # count-only sort does not guarantee the order of equal values, which
    # decides which items get the top-3 highlight.
    stats_df = pl.DataFrame(stats).sort(['count', 'column'], descending=[True, False])

    # Clean up labels: drop known prefixes and turn underscores into spaces.
    labels = [
        col.replace('Character_Ranking_', '').replace('Top_3_Voices_ranking__', '').replace('_', ' ').strip()
        for col in stats_df['column']
    ]

    # Top 3 get PRIMARY, others get NEUTRAL.
    colors = [
        ColorPalette.PRIMARY if i < 3 else ColorPalette.NEUTRAL
        for i in range(stats_df.height)
    ]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=stats_df['count'],
        text=stats_df['count'],
        textposition='inside',
        textfont=dict(size=10, color='white'),
        marker_color=colors,
        hovertemplate='<b>%{x}</b><br>1st Place Votes: %{y}<extra></extra>'
    ))
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        font=dict(size=11)
    )
    return fig
def plot_weighted_ranking_score(
    weighted_df: pl.DataFrame,
    title: str = "Weighted Popularity Score<br>(1st=3pts, 2nd=2pts, 3rd=1pt)",
    x_label: str = "Character Personality",
    y_label: str = "Total Weighted Score",
    color: str = ColorPalette.PRIMARY,
    height: int = 500,
    width: int = 1000,
) -> go.Figure:
    """
    Create a bar chart of pre-computed weighted ranking scores, one bar per character.

    This function does not compute any scores. It plots ``weighted_df`` as
    given, in its existing row order, with no sorting. The ``Character``
    column supplies the x-axis, and ``Weighted Score`` supplies both the bar
    height and the label printed inside each bar.

    Parameters
    ----------
    weighted_df : pl.DataFrame
        DataFrame with a ``Character`` column (bar labels) and a
        ``Weighted Score`` column (bar heights). Other columns are ignored.
        NOTE(review): the default title describes the scheme
        1st=3pts, 2nd=2pts, 3rd=1pt; confirm the upstream computation uses it.
    title : str, optional
        Plot title.
    x_label : str, optional
        X-axis label.
    y_label : str, optional
        Y-axis label.
    color : str, optional
        Bar color.
    height : int, optional
        Plot height in pixels.
    width : int, optional
        Plot width in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object.
    """
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=weighted_df['Character'],
        y=weighted_df['Weighted Score'],
        text=weighted_df['Weighted Score'],
        textposition='inside',
        textfont=dict(size=11, color='white'),
        marker_color=color,
        hovertemplate='<b>%{x}</b><br>Score: %{y}<extra></extra>'
    ))
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        font=dict(size=11)
    )
    return fig
def plot_voice_selection_counts(
    df: pl.DataFrame,
    target_column: str = "8_Combined",
    title: str = "Most Frequently Chosen Voices<br>(Top 8 Highlighted)",
    x_label: str = "Voice",
    y_label: str = "Number of Times Chosen",
    height: int = 500,
    width: int = 1000,
) -> go.Figure:
    """
    Create a bar plot showing the frequency of voice selections.

    Takes a column of comma-separated values (e.g. "Voice 1, Voice 2..."),
    counts how often each trimmed, non-empty entry occurs, and highlights
    the 8 most frequent. Bars are sorted by count in descending order, with
    ties broken alphabetically by voice name so the order and the highlight
    are reproducible between runs.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame containing the selection column.
    target_column : str, optional
        Name of the column containing comma-separated voice selections.
        Defaults to "8_Combined".
    title : str, optional
        Plot title.
    x_label : str, optional
        X-axis label.
    y_label : str, optional
        Y-axis label.
    height : int, optional
        Plot height in pixels.
    width : int, optional
        Plot width in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object. An empty figure is returned when
        ``target_column`` is not a column of ``df``.
    """
    if target_column not in df.columns:
        return go.Figure()

    # 1. Keep only the selection column and drop null rows.
    # 2. Split each comma-separated string into a list and explode it so
    #    each voice gets its own row.
    # 3. Strip whitespace so "Voice 1" and " Voice 1" match; drop empty
    #    entries (e.g. from trailing commas).
    # 4. Count occurrences per voice.
    # The secondary sort key is needed because group_by output order is
    # arbitrary: without it, voices with equal counts (and so membership in
    # the highlighted top 8) could change from run to run.
    stats_df = (
        df.select(pl.col(target_column))
        .drop_nulls()
        .with_columns(pl.col(target_column).str.split(","))
        .explode(target_column)
        .with_columns(pl.col(target_column).str.strip_chars())
        .filter(pl.col(target_column) != "")
        .group_by(target_column)
        .agg(pl.len().alias("count"))
        .sort(["count", target_column], descending=[True, False])
    )

    # Top 8 get PRIMARY, the rest get NEUTRAL.
    colors = [
        ColorPalette.PRIMARY if i < 8 else ColorPalette.NEUTRAL
        for i in range(stats_df.height)
    ]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=stats_df[target_column],
        y=stats_df['count'],
        text=stats_df['count'],
        textposition='outside',
        marker_color=colors,
        hovertemplate='<b>%{x}</b><br>Selections: %{y}<extra></extra>'
    ))
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        font=dict(size=11),
    )
    return fig
def plot_top3_selection_counts(
    df: pl.DataFrame,
    target_column: str = "3_Ranked",
    title: str = "Most Frequently Chosen Top 3 Voices<br>(Top 3 Highlighted)",
    x_label: str = "Voice",
    y_label: str = "Count of Mentions in Top 3",
    height: int = 500,
    width: int = 1000,
) -> go.Figure:
    """
    Count how often each voice appears in participants' Top 3 and plot it.

    This answers the question: which 3 voices are chosen most often? It is
    based on the survey question where participants pick 3 voices out of the
    8 they previously selected from the 18 stimuli. Each bar shows how many
    participants placed that voice in their Top 3.

    Bars are sorted by count in descending order, with ties broken
    alphabetically by voice name so the order and the highlight are
    reproducible. The 3 most frequent voices are highlighted.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame containing the ranking column (comma-separated strings).
    target_column : str, optional
        Name of the column containing comma-separated Top 3 voice selections.
        Defaults to "3_Ranked".
    title : str, optional
        Plot title.
    x_label : str, optional
        X-axis label.
    y_label : str, optional
        Y-axis label.
    height : int, optional
        Plot height in pixels.
    width : int, optional
        Plot width in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object. An empty figure is returned when
        ``target_column`` is not a column of ``df``.
    """
    if target_column not in df.columns:
        return go.Figure()

    # Explode the comma-separated strings into one trimmed, non-empty voice
    # per row, then count occurrences. group_by output order is arbitrary,
    # so the voice name is used as a secondary sort key to make tie order,
    # and therefore the top-3 highlight, deterministic.
    stats_df = (
        df.select(pl.col(target_column))
        .drop_nulls()
        .with_columns(pl.col(target_column).str.split(","))
        .explode(target_column)
        .with_columns(pl.col(target_column).str.strip_chars())
        .filter(pl.col(target_column) != "")
        .group_by(target_column)
        .agg(pl.len().alias("count"))
        .sort(["count", target_column], descending=[True, False])
    )

    # Top 3 get PRIMARY, the rest get NEUTRAL.
    colors = [
        ColorPalette.PRIMARY if i < 3 else ColorPalette.NEUTRAL
        for i in range(stats_df.height)
    ]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=stats_df[target_column],
        y=stats_df['count'],
        text=stats_df['count'],
        textposition='outside',
        marker_color=colors,
        hovertemplate='<b>%{x}</b><br>In Top 3: %{y} times<extra></extra>'
    ))
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height,
        width=width,
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        font=dict(size=11),
    )
    return fig
def plot_speaking_style_trait_scores(
    df: pl.DataFrame,
    trait_description: str = None,
    left_anchor: str = None,
    right_anchor: str = None,
    title: str = "Speaking Style Trait Analysis",
    height: int = 500,
    width: int = 1000,
) -> go.Figure:
    """
    Plot scores for a single speaking style trait across multiple voices.

    Each voice gets one horizontal bar showing its average score. Voices are
    sorted so the highest average is at the top; ties are broken by voice
    name. The respondent count is printed on each bar. The DataFrame must
    contain 'Voice' and 'score' columns and is typically pre-filtered to a
    single trait/description.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame containing at least 'Voice' and 'score' columns.
        Produced by utils.process_speaking_style_data and filtered.
    trait_description : str, optional
        Description of the trait being analyzed (e.g. "Indifferent : Attentive").
        If not provided, it is built as "<left> vs. <right>" from the
        anchors. Failing that, the first non-null 'Description' value is
        used, and otherwise an empty string.
    left_anchor : str, optional
        Label for the lower end of the scale (e.g. "Indifferent").
        If not provided, it is read from the 'Left_Anchor' column.
    right_anchor : str, optional
        Label for the upper end of the scale (e.g. "Attentive").
        If not provided, it is read from the 'Right_Anchor' column (same row
        as the Left_Anchor value), when that column exists.
    title : str, optional
        Plot title.
    height : int, optional
        Plot height in pixels.
    width : int, optional
        Plot width in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object. An empty figure is returned when ``df`` is
        empty or lacks the 'Voice' or 'score' column.
    """
    if df.is_empty():
        return go.Figure()
    required_cols = ["Voice", "score"]
    if not all(col in df.columns for col in required_cols):
        return go.Figure()

    # Mean score and respondent count per voice. The sort is ascending
    # because horizontal bars draw the first category at the bottom, which
    # puts the best voice at the top. "Voice" is a secondary key because
    # group_by output order is arbitrary, and equal means are common with
    # small integer scales.
    stats = (
        df.filter(pl.col("score").is_not_null())
        .group_by("Voice")
        .agg([
            pl.col("score").mean().alias("mean_score"),
            pl.col("score").count().alias("count"),
        ])
        .sort(["mean_score", "Voice"], descending=False)
    )

    # Fill in missing anchors from the first row that has a Left_Anchor.
    # Right_Anchor is only read if that column exists; previously a frame
    # with Left_Anchor but no Right_Anchor column raised here.
    if (left_anchor is None or right_anchor is None) and "Left_Anchor" in df.columns:
        head = df.filter(pl.col("Left_Anchor").is_not_null()).head(1)
        if not head.is_empty():
            if left_anchor is None:
                left_anchor = head["Left_Anchor"][0]
            if right_anchor is None and "Right_Anchor" in head.columns:
                right_anchor = head["Right_Anchor"][0]

    # Only the text before the first '|' of each anchor is displayed.
    left_label = left_anchor.split('|')[0] if left_anchor else None
    right_label = right_anchor.split('|')[0] if right_anchor else None

    if trait_description is None:
        if left_anchor and right_anchor:
            trait_description = f"{left_label} vs. {right_label}"
        elif "Description" in df.columns:
            head = df.filter(pl.col("Description").is_not_null()).head(1)
            # Fall back to "" rather than leaving None, which would render
            # as the literal text "None" in the subtitle.
            trait_description = "" if head.is_empty() else head["Description"][0]
        else:
            trait_description = ""

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=stats["Voice"],        # Y is Voice
        x=stats["mean_score"],   # X is Score
        orientation='h',
        text=stats["count"],
        textposition='inside',
        textangle=0,
        textfont=dict(size=16, color='white'),
        texttemplate='%{text}',  # Count on bar
        marker_color=ColorPalette.PRIMARY,
        hovertemplate='<b>%{y}</b><br>Average: %{x:.2f}<br>Count: %{text}<extra></extra>'
    ))

    # Scale-end labels placed below the x-axis (paper y=-0.2), aligned to
    # data x=1 (left end) and x=5 (right end).
    annotations = []
    if left_anchor:
        annotations.append(dict(
            xref='x', yref='paper',
            x=1, y=-0.2,
            xanchor='left', yanchor='top',
            text=f"<b>1: {left_label}</b>",
            showarrow=False,
            font=dict(size=10, color='gray')
        ))
    if right_anchor:
        annotations.append(dict(
            xref='x', yref='paper',
            x=5, y=-0.2,
            xanchor='right', yanchor='top',
            text=f"<b>5: {right_label}</b>",
            showarrow=False,
            font=dict(size=10, color='gray')
        ))

    fig.update_layout(
        title=dict(
            text=f"{title}<br><sub>{trait_description}</sub><br><sub>(Numbers on bars indicate respondent count)</sub>",
            y=0.92
        ),
        xaxis_title="Average Score (1-5)",
        yaxis_title="Voice",
        height=height,
        width=width,
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            range=[1, 5],
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            zeroline=False
        ),
        yaxis=dict(
            showgrid=False
        ),
        margin=dict(b=120),  # room for the anchor annotations below the axis
        annotations=annotations,
        font=dict(size=11)
    )
    return fig