Files
JPMC-quant/plots.py

962 lines
33 KiB
Python

"""Plotting functions for Voice Branding analysis."""
import re
from pathlib import Path
import plotly.graph_objects as go
import polars as pl
from theme import ColorPalette
class JPMCPlotsMixin:
"""Mixin class for plotting functions in JPMCSurvey."""
def _sanitize_filename(self, title: str) -> str:
    """Convert a plot title into a safe, lowercase file stem.

    HTML tags (e.g. ``<br>``) are replaced by spaces, every character that
    is not a word character, whitespace or hyphen is dropped, whitespace
    runs become a single underscore, and the result is lowercased and cut
    to 100 characters.

    Args:
        title: Plot title, possibly containing HTML markup.

    Returns:
        The sanitized stem (no extension), or ``"untitled"`` when nothing
        usable is left (empty, markup-only or punctuation-only title), so
        the caller never ends up writing a bare ``.png`` file.
    """
    # Tags become a space (not nothing) so "A<br>B" keeps its word break.
    text = re.sub(r'<[^>]+>', ' ', title)
    # Drop everything except word characters, whitespace and hyphens.
    text = re.sub(r'[^\w\s-]', '', text)
    # Whitespace runs -> "_", then collapse any repeated underscores.
    text = re.sub(r'\s+', '_', text.strip())
    text = re.sub(r'_+', '_', text)
    stem = text.lower()[:100]
    return stem or "untitled"
def _get_filter_slug(self) -> str:
    """Build a directory-friendly slug describing the active filters.

    For each of the age, gender, consumer, ethnicity and income filters the
    instance attribute ``filter_<name>`` is read (missing attributes count
    as "no filter"). A filter is skipped when it is None, empty, or selects
    exactly the values listed in the matching ``options_<name>`` attribute.
    Remaining filters contribute ``<Code>-<values>``, where the values are
    joined with ``+`` (characters other than word characters, hyphens and
    dots removed) or summarised as ``<n>_grps`` when more than three are
    selected.

    Returns:
        The parts joined with ``_``, or ``"All_Respondents"`` when no
        filter is active.
    """
    # (short code used in the slug, selected value(s), attribute holding all options)
    filters = [
        ('Age', getattr(self, 'filter_age', None), 'options_age'),
        ('Gen', getattr(self, 'filter_gender', None), 'options_gender'),
        ('Cons', getattr(self, 'filter_consumer', None), 'options_consumer'),
        ('Eth', getattr(self, 'filter_ethnicity', None), 'options_ethnicity'),
        ('Inc', getattr(self, 'filter_income', None), 'options_income'),
    ]
    parts = []
    for short_code, value, options_attr in filters:
        if value is None:
            continue
        # Normalise to a list. Tuples and sets are multi-value selections
        # too; wrapping them as a single element would stringify the whole
        # container into one garbled token.
        if isinstance(value, (set, frozenset)):
            # Sets have no stable order; sort so the slug is reproducible.
            values = sorted(value, key=str)
        elif isinstance(value, (list, tuple)):
            values = list(value)
        else:
            values = [value]
        if not values:
            continue
        # Selecting every available option is equivalent to no filter.
        master_list = getattr(self, options_attr, None)
        if master_list and set(values) == set(master_list):
            continue
        if len(values) > 3:
            # Keep the slug short when many options are selected.
            val_str = f"{len(values)}_grps"
        else:
            # Strip characters that could be problematic in directory names.
            val_str = "+".join(re.sub(r'[^\w\-\.]', '', str(v)) for v in values)
        parts.append(f"{short_code}-{val_str}")
    if not parts:
        return "All_Respondents"
    return "_".join(parts)
def _save_plot(self, fig: go.Figure, title: str) -> None:
    """Write *fig* as a PNG below ``fig_save_dir``, if that attribute is set.

    The file goes to ``<fig_save_dir>/<filter slug>/<sanitized title>.png``
    and the directory is created on demand. Nothing happens when
    ``fig_save_dir`` is missing or falsy.

    Args:
        fig: Figure to export; its layout width and height are forwarded to
            ``write_image``.
        title: Plot title from which the file name is derived.
    """
    save_dir = getattr(self, 'fig_save_dir', None)
    if not save_dir:
        return
    # One sub-folder per filter combination.
    target_dir = Path(save_dir) / self._get_filter_slug()
    # exist_ok=True already covers an existing directory, so no separate
    # exists() check is needed (which would also be a check-then-create race).
    target_dir.mkdir(parents=True, exist_ok=True)
    filename = f"{self._sanitize_filename(title)}.png"
    fig.write_image(target_dir / filename, width=fig.layout.width, height=fig.layout.height)
def _ensure_dataframe(self, data: pl.LazyFrame | pl.DataFrame | None) -> pl.DataFrame:
    """Return an eager DataFrame for *data*.

    When *data* is None, ``self.data_filtered`` is used instead (a missing
    attribute counts as None).

    Raises:
        ValueError: If neither *data* nor ``self.data_filtered`` is available.
    """
    frame = getattr(self, 'data_filtered', None) if data is None else data
    if frame is None:
        raise ValueError("No data provided and self.data_filtered is None.")
    # LazyFrames are materialised; anything else is handed back untouched.
    return frame.collect() if isinstance(frame, pl.LazyFrame) else frame
def plot_average_scores_with_counts(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "General Impression (1-10)<br>Per Voice with Number of Participants Who Rated It",
    x_label: str = "Stimuli",
    y_label: str = "Average General Impression Rating (1-10)",
    color: str = ColorPalette.PRIMARY,
    height: int | None = None,
    width: int | None = None,
) -> go.Figure:
    """Bar chart of each column's mean, labelled with its non-null count.

    Every column except ``_recordId`` becomes one bar. Bars are sorted by
    mean in descending order and the text inside each bar is the number of
    non-null values in that column. The figure is passed to ``_save_plot``.

    Args:
        data: Frame to plot; ``self.data_filtered`` is used when None.
        title: Figure title, also used to derive the saved file name.
        x_label: X-axis title.
        y_label: Y-axis title.
        color: Bar colour.
        height: Figure height; when falsy, ``self.plot_height`` (or 500).
        width: Figure width; when falsy, ``self.plot_width`` (or 1000).

    Returns:
        The bar figure, or an empty figure when there is no column to plot.
    """
    df = self._ensure_dataframe(data)
    stats = [
        {
            'column': col,
            'average': df[col].mean(),
            'count': df[col].drop_nulls().len(),
        }
        for col in df.columns
        if col != '_recordId'
    ]
    if not stats:
        # Without any score column the stats frame would have no 'average'
        # column and the sort below would raise; return an empty figure,
        # as plot_ranking_distribution does for the same situation.
        return go.Figure()
    stats_df = pl.DataFrame(stats).sort('average', descending=True)
    # Keep only the part after the last "__" (e.g. "V14" from
    # "Voice_Scale_1_10__V14"); names without "__" are used unchanged.
    labels = [col.split('__')[-1] for col in stats_df['column']]
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=stats_df['average'],
        text=stats_df['count'],
        textposition='inside',
        textfont=dict(size=10, color='black'),
        marker_color=color,
        hovertemplate='<b>%{x}</b><br>Average: %{y:.2f}<br>Count: %{text}<extra></extra>'
    ))
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height if height else getattr(self, 'plot_height', 500),
        width=width if width else getattr(self, 'plot_width', 1000),
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            range=[0, 10],
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        font=dict(size=11)
    )
    self._save_plot(fig, title)
    return fig
def plot_top3_ranking_distribution(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "Top 3 Rankings Distribution<br>Count of 1st, 2nd, and 3rd Place Votes per Voice",
    x_label: str = "Voices",
    y_label: str = "Number of Mentions in Top 3",
    height: int | None = None,
    width: int | None = None,
) -> go.Figure:
    """Stacked bar chart of how often each column holds the value 1, 2 or 3.

    Every column except ``_recordId`` is counted; columns without a single
    1/2/3 entry are left out. Bars are sorted by total count, ties broken
    by the rank-1 count (both descending). The figure is passed to
    ``_save_plot``.

    Args:
        data: Frame to plot; ``self.data_filtered`` is used when None.
        title: Figure title, also used to derive the saved file name.
        x_label: X-axis title.
        y_label: Y-axis title.
        height: Figure height; when falsy, ``self.plot_height`` (or 500).
        width: Figure width; when falsy, ``self.plot_width`` (or 1000).

    Returns:
        The stacked bar figure, or an empty figure when no column received
        any rank 1-3 entry.
    """
    df = self._ensure_dataframe(data)
    ranks = (1, 2, 3)
    stats = []
    for col in df.columns:
        if col == '_recordId':
            continue
        counts = {f'Rank {rank}': df.filter(pl.col(col) == rank).height for rank in ranks}
        total = sum(counts.values())
        # Columns without any vote are dropped to keep the chart clean.
        if total > 0:
            stats.append({'column': col, **counts, 'Total': total})
    if not stats:
        # Because of the total > 0 filter the list can be empty; sorting a
        # column-less frame by 'Total' would raise, so return an empty
        # figure, as plot_ranking_distribution does.
        return go.Figure()
    # Most mentioned overall first; ties broken by number of first places.
    stats_df = pl.DataFrame(stats).sort(['Total', 'Rank 1'], descending=[True, True])
    # Keep only the part after the last "__"; other names stay unchanged.
    labels = [col.split('__')[-1] for col in stats_df['column']]
    # Stack order: rank 1 at the bottom so first-choice volume is easy to
    # compare across bars, then rank 2, then rank 3.
    trace_specs = [
        ('Rank 1 (1st Choice)', 'Rank 1', ColorPalette.RANK_1),
        ('Rank 2 (2nd Choice)', 'Rank 2', ColorPalette.RANK_2),
        ('Rank 3 (3rd Choice)', 'Rank 3', ColorPalette.RANK_3),
    ]
    fig = go.Figure()
    for legend_name, key, bar_color in trace_specs:
        fig.add_trace(go.Bar(
            name=legend_name,
            x=labels,
            y=stats_df[key],
            marker_color=bar_color,
            hovertemplate=f'<b>%{{x}}</b><br>{key}: %{{y}}<extra></extra>'
        ))
    fig.update_layout(
        barmode='stack',
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height if height else getattr(self, 'plot_height', 500),
        width=width if width else getattr(self, 'plot_width', 1000),
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            traceorder="normal"
        ),
        font=dict(size=11)
    )
    self._save_plot(fig, title)
    return fig
def plot_ranking_distribution(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "Rankings Distribution<br>(1st to 4th Place)",
    x_label: str = "Item",
    y_label: str = "Number of Votes",
    height: int | None = None,
    width: int | None = None,
) -> go.Figure:
    """Stacked bar chart of how often each column holds the value 1, 2, 3 or 4.

    All columns except ``_recordId`` are treated as ranking columns;
    columns without any 1-4 entry are omitted. Bars are ordered by rank-1
    count, then rank-2 count (both descending). Returns an empty figure
    (without saving) when nothing remains; otherwise the figure is passed
    to ``_save_plot`` and returned.
    """
    df = self._ensure_dataframe(data)
    rank_keys = ['Rank 1', 'Rank 2', 'Rank 3', 'Rank 4']
    stats = []
    for col in df.columns:
        if col == '_recordId':
            continue
        # Number of rows holding each rank value in this column.
        counts = {
            key: df.filter(pl.col(col) == rank).height
            for rank, key in enumerate(rank_keys, start=1)
        }
        if sum(counts.values()) > 0:
            stats.append({'column': col, **counts})
    if not stats:
        return go.Figure()
    # Winner (most first places) first; second places break ties.
    stats_df = pl.DataFrame(stats).sort(['Rank 1', 'Rank 2'], descending=[True, True])
    # Strip the known prefixes and turn underscores into spaces,
    # e.g. "Character_Ranking_The_Coach" -> "The Coach".
    labels = []
    for col in stats_df['column']:
        stripped = col.replace('Character_Ranking_', '').replace('Top_3_Voices_ranking__', '')
        labels.append(stripped.replace('_', ' ').strip())
    # (legend name, stats column, colour), in stacking order from best to worst.
    trace_specs = [
        ('Rank 1 (Best)', 'Rank 1', ColorPalette.RANK_1),
        ('Rank 2', 'Rank 2', ColorPalette.RANK_2),
        ('Rank 3', 'Rank 3', ColorPalette.RANK_3),
        ('Rank 4 (Worst)', 'Rank 4', ColorPalette.RANK_4),
    ]
    fig = go.Figure()
    for legend_name, key, bar_color in trace_specs:
        fig.add_trace(go.Bar(
            name=legend_name,
            x=labels,
            y=stats_df[key],
            marker_color=bar_color,
            hovertemplate=f'<b>%{{x}}</b><br>{key}: %{{y}}<extra></extra>',
        ))
    grid_axis = dict(showgrid=True, gridcolor=ColorPalette.GRID)
    fig.update_layout(
        barmode='stack',
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height or getattr(self, 'plot_height', 500),
        width=width or getattr(self, 'plot_width', 1000),
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(grid_axis, tickangle=-45),
        yaxis=dict(grid_axis),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
            traceorder="normal",
        ),
        font=dict(size=11),
    )
    self._save_plot(fig, title)
    return fig
def plot_most_ranked_1(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "Most Popular Choice<br>(Number of Times Ranked 1st)",
    x_label: str = "Item",
    y_label: str = "Count of 1st Place Rankings",
    height: int | None = None,
    width: int | None = None,
) -> go.Figure:
    """Bar chart of how often each column holds the value 1.

    All columns except ``_recordId`` are treated as ranking columns. Bars
    are sorted by count in descending order; the first three bars use
    ``ColorPalette.PRIMARY``, the rest ``ColorPalette.NEUTRAL``. The figure
    is passed to ``_save_plot``.

    Args:
        data: Frame to plot; ``self.data_filtered`` is used when None.
        title: Figure title, also used to derive the saved file name.
        x_label: X-axis title.
        y_label: Y-axis title.
        height: Figure height; when falsy, ``self.plot_height`` (or 500).
        width: Figure width; when falsy, ``self.plot_width`` (or 1000).

    Returns:
        The bar figure, or an empty figure when there is no ranking column.
    """
    df = self._ensure_dataframe(data)
    stats = [
        {'column': col, 'count': df.filter(pl.col(col) == 1).height}
        for col in df.columns
        if col != '_recordId'
    ]
    if not stats:
        # Sorting a column-less frame by 'count' would raise; return an
        # empty figure, as plot_ranking_distribution does.
        return go.Figure()
    stats_df = pl.DataFrame(stats).sort('count', descending=True)
    # Strip the known prefixes and turn underscores into spaces.
    labels = [
        col.replace('Character_Ranking_', '').replace('Top_3_Voices_ranking__', '').replace('_', ' ').strip()
        for col in stats_df['column']
    ]
    # Highlight the three leading bars.
    colors = [
        ColorPalette.PRIMARY if position < 3 else ColorPalette.NEUTRAL
        for position in range(len(stats_df))
    ]
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=labels,
        y=stats_df['count'],
        text=stats_df['count'],
        textposition='inside',
        textfont=dict(size=10, color='white'),
        marker_color=colors,
        hovertemplate='<b>%{x}</b><br>1st Place Votes: %{y}<extra></extra>'
    ))
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height if height else getattr(self, 'plot_height', 500),
        width=width if width else getattr(self, 'plot_width', 1000),
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        font=dict(size=11)
    )
    self._save_plot(fig, title)
    return fig
def plot_weighted_ranking_score(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "Weighted Popularity Score<br>(1st=3pts, 2nd=2pts, 3rd=1pt)",
    x_label: str = "Character Personality",
    y_label: str = "Total Weighted Score",
    color: str = ColorPalette.PRIMARY,
    height: int | None = None,
    width: int | None = None,
) -> go.Figure:
    """Bar chart of the ``Weighted Score`` column against the ``Character`` column.

    The frame is plotted as given (no sorting or aggregation here); the
    score is also shown as text inside each bar. The figure is passed to
    ``_save_plot`` and returned.
    """
    scores = self._ensure_dataframe(data)
    bars = go.Bar(
        x=scores['Character'],
        y=scores['Weighted Score'],
        text=scores['Weighted Score'],
        textposition='inside',
        textfont=dict(size=11, color='white'),
        marker_color=color,
        hovertemplate='<b>%{x}</b><br>Score: %{y}<extra></extra>',
    )
    fig = go.Figure()
    fig.add_trace(bars)
    grid_axis = dict(showgrid=True, gridcolor=ColorPalette.GRID)
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height or getattr(self, 'plot_height', 500),
        width=width or getattr(self, 'plot_width', 1000),
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(grid_axis, tickangle=-45),
        yaxis=dict(grid_axis),
        font=dict(size=11),
    )
    self._save_plot(fig, title)
    return fig
def plot_voice_selection_counts(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    target_column: str = "8_Combined",
    title: str = "Most Frequently Chosen Voices<br>(Top 8 Highlighted)",
    x_label: str = "Voice",
    y_label: str = "Number of Times Chosen",
    height: int | None = None,
    width: int | None = None,
) -> go.Figure:
    """Bar chart of how often each entry occurs in a comma-separated column.

    Each non-null value of *target_column* is split on commas; the pieces
    are stripped of surrounding whitespace, empty pieces are discarded, and
    the occurrences of each distinct piece are counted. Bars are sorted by
    count (descending); the first eight use ``ColorPalette.PRIMARY``, the
    rest ``ColorPalette.NEUTRAL``. The figure is passed to ``_save_plot``.

    Args:
        data: Frame to plot; ``self.data_filtered`` is used when None.
        target_column: Name of the comma-separated column to count.
        title: Figure title, also used to derive the saved file name.
        x_label: X-axis title.
        y_label: Y-axis title.
        height: Figure height; when falsy, ``self.plot_height`` (or 500).
        width: Figure width; when falsy, ``self.plot_width`` (or 1000).

    Returns:
        The bar figure, or an empty figure when *target_column* is missing.
    """
    df = self._ensure_dataframe(data)
    if target_column not in df.columns:
        return go.Figure()
    stats_df = (
        df.select(pl.col(target_column))
        .drop_nulls()
        .with_columns(pl.col(target_column).str.split(","))
        # One row per selected entry.
        .explode(target_column)
        # Strip whitespace so "Voice 1" and " Voice 1" are counted together.
        .with_columns(pl.col(target_column).str.strip_chars())
        .filter(pl.col(target_column) != "")
        .group_by(target_column)
        .agg(pl.len().alias("count"))
        # group_by does not guarantee the order of its groups, so ties in
        # count are broken by name; otherwise bar order (and which bars get
        # highlighted) could differ between runs.
        .sort(["count", target_column], descending=[True, False])
    )
    # Highlight the eight leading bars.
    colors = [
        ColorPalette.PRIMARY if position < 8 else ColorPalette.NEUTRAL
        for position in range(len(stats_df))
    ]
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=stats_df[target_column],
        y=stats_df['count'],
        text=stats_df['count'],
        textposition='outside',
        marker_color=colors,
        hovertemplate='<b>%{x}</b><br>Selections: %{y}<extra></extra>'
    ))
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height if height else getattr(self, 'plot_height', 500),
        width=width if width else getattr(self, 'plot_width', 1000),
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        font=dict(size=11),
    )
    self._save_plot(fig, title)
    return fig
def plot_top3_selection_counts(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    target_column: str = "3_Ranked",
    title: str = "Most Frequently Chosen Top 3 Voices<br>(Top 3 Highlighted)",
    x_label: str = "Voice",
    y_label: str = "Count of Mentions in Top 3",
    height: int | None = None,
    width: int | None = None,
) -> go.Figure:
    """Bar chart of how often each entry occurs in a comma-separated column.

    Answers "which voices are chosen most often": each non-null value of
    *target_column* is split on commas, the pieces are stripped of
    surrounding whitespace, empty pieces are discarded, and occurrences of
    each distinct piece are counted. Bars are sorted by count (descending);
    the first three use ``ColorPalette.PRIMARY``, the rest
    ``ColorPalette.NEUTRAL``. The figure is passed to ``_save_plot``.

    Args:
        data: Frame to plot; ``self.data_filtered`` is used when None.
        target_column: Name of the comma-separated column to count.
        title: Figure title, also used to derive the saved file name.
        x_label: X-axis title.
        y_label: Y-axis title.
        height: Figure height; when falsy, ``self.plot_height`` (or 500).
        width: Figure width; when falsy, ``self.plot_width`` (or 1000).

    Returns:
        The bar figure, or an empty figure when *target_column* is missing.
    """
    df = self._ensure_dataframe(data)
    if target_column not in df.columns:
        return go.Figure()
    stats_df = (
        df.select(pl.col(target_column))
        .drop_nulls()
        .with_columns(pl.col(target_column).str.split(","))
        # One row per selected entry.
        .explode(target_column)
        # Strip whitespace so differently padded entries are counted together.
        .with_columns(pl.col(target_column).str.strip_chars())
        .filter(pl.col(target_column) != "")
        .group_by(target_column)
        .agg(pl.len().alias("count"))
        # group_by does not guarantee the order of its groups, so ties in
        # count are broken by name; otherwise bar order (and which bars get
        # highlighted) could differ between runs.
        .sort(["count", target_column], descending=[True, False])
    )
    # Highlight the three leading bars.
    colors = [
        ColorPalette.PRIMARY if position < 3 else ColorPalette.NEUTRAL
        for position in range(len(stats_df))
    ]
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=stats_df[target_column],
        y=stats_df['count'],
        text=stats_df['count'],
        textposition='outside',
        marker_color=colors,
        hovertemplate='<b>%{x}</b><br>In Top 3: %{y} times<extra></extra>'
    ))
    fig.update_layout(
        title=title,
        xaxis_title=x_label,
        yaxis_title=y_label,
        height=height if height else getattr(self, 'plot_height', 500),
        width=width if width else getattr(self, 'plot_width', 1000),
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            tickangle=-45
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=ColorPalette.GRID
        ),
        font=dict(size=11),
    )
    self._save_plot(fig, title)
    return fig
def plot_speaking_style_trait_scores(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    trait_description: str | None = None,
    left_anchor: str | None = None,
    right_anchor: str | None = None,
    title: str = "Speaking Style Trait Analysis",
    height: int | None = None,
    width: int | None = None,
) -> go.Figure:
    """Horizontal bar chart of the mean ``score`` per ``Voice`` for one trait.

    Rows with a null ``score`` are ignored. Bars are sorted by mean score
    in ascending order and the text on each bar is the number of scores
    behind it. The x-axis is fixed to the range 1-5, and the anchor texts
    (only the part before the first ``|``) are annotated below the axis at
    x=1 and x=5. The figure is passed to ``_save_plot``.

    Args:
        data: Frame with at least ``Voice`` and ``score`` columns;
            ``self.data_filtered`` is used when None. Optional columns
            ``Left_Anchor``, ``Right_Anchor`` and ``Description`` are used
            to fill in arguments that were not given.
        trait_description: Subtitle text. When None it becomes
            "<left> vs. <right>" if both anchors are known, otherwise the
            first non-null ``Description`` value, otherwise "".
        left_anchor: Label for the low end of the scale; when None it is
            taken from the first row with a non-null ``Left_Anchor``.
        right_anchor: Label for the high end of the scale; when None it is
            taken from that same row's ``Right_Anchor``, if the column exists.
        title: Main title, also used to derive the saved file name.
        height: Figure height; when falsy, ``self.plot_height`` (or 500).
        width: Figure width; when falsy, ``self.plot_width`` (or 1000).

    Returns:
        The bar figure, or an empty figure when the frame is empty or lacks
        the ``Voice``/``score`` columns.
    """
    df = self._ensure_dataframe(data)
    if df.is_empty():
        return go.Figure()
    if not all(col in df.columns for col in ("Voice", "score")):
        return go.Figure()
    stats = (
        df.filter(pl.col("score").is_not_null())
        .group_by("Voice")
        .agg([
            pl.col("score").mean().alias("mean_score"),
            pl.col("score").count().alias("count")
        ])
        # Ascending, because a horizontal bar chart is drawn bottom-to-top.
        .sort("mean_score", descending=False)
    )
    # Fill in missing anchors from the frame where possible.
    if (left_anchor is None or right_anchor is None) and "Left_Anchor" in df.columns:
        head = df.filter(pl.col("Left_Anchor").is_not_null()).head(1)
        if not head.is_empty():
            if left_anchor is None:
                left_anchor = head["Left_Anchor"][0]
            # Only Left_Anchor is known to exist at this point; check for
            # Right_Anchor separately instead of assuming both are present.
            if right_anchor is None and "Right_Anchor" in head.columns:
                right_anchor = head["Right_Anchor"][0]
    if trait_description is None:
        if left_anchor and right_anchor:
            trait_description = f"{left_anchor.split('|')[0]} vs. {right_anchor.split('|')[0]}"
        else:
            trait_description = ""
            if "Description" in df.columns:
                head = df.filter(pl.col("Description").is_not_null()).head(1)
                if not head.is_empty():
                    trait_description = head["Description"][0]
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=stats["Voice"],
        x=stats["mean_score"],
        orientation='h',
        text=stats["count"],
        textposition='inside',
        textangle=0,
        textfont=dict(size=16, color='white'),
        texttemplate='%{text}',
        marker_color=ColorPalette.PRIMARY,
        hovertemplate='<b>%{y}</b><br>Average: %{x:.2f}<br>Count: %{text}<extra></extra>'
    ))
    # Anchor labels sit below the axis, pinned to the scale ends (x in data
    # coordinates, y in paper coordinates).
    annotations = []
    if left_anchor:
        annotations.append(dict(
            xref='x', yref='paper',
            x=1, y=-0.2,
            xanchor='left', yanchor='top',
            text=f"<b>1: {left_anchor.split('|')[0]}</b>",
            showarrow=False,
            font=dict(size=10, color='gray')
        ))
    if right_anchor:
        annotations.append(dict(
            xref='x', yref='paper',
            x=5, y=-0.2,
            xanchor='right', yanchor='top',
            text=f"<b>5: {right_anchor.split('|')[0]}</b>",
            showarrow=False,
            font=dict(size=10, color='gray')
        ))
    fig.update_layout(
        title=dict(
            text=f"{title}<br><sub>{trait_description}</sub><br><sub>(Numbers on bars indicate respondent count)</sub>",
            y=0.92
        ),
        xaxis_title="Average Score (1-5)",
        yaxis_title="Voice",
        height=height if height else getattr(self, 'plot_height', 500),
        width=width if width else getattr(self, 'plot_width', 1000),
        plot_bgcolor=ColorPalette.BACKGROUND,
        xaxis=dict(
            range=[1, 5],
            showgrid=True,
            gridcolor=ColorPalette.GRID,
            zeroline=False
        ),
        yaxis=dict(
            showgrid=False
        ),
        # Extra bottom margin leaves room for the anchor annotations.
        margin=dict(b=120),
        annotations=annotations,
        font=dict(size=11)
    )
    # NOTE(review): the saved file name depends on `title` only, so calls
    # for different traits that share a title (e.g. the default) write to
    # the same file - confirm callers pass distinct titles.
    self._save_plot(fig, title)
    return fig
def plot_speaking_style_correlation(
    self,
    style_color: str,
    style_traits: list[str],
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str | None = None,
) -> go.Figure:
    """Bar chart of the correlation between trait scores and voice-scale scores.

    For every trait in *style_traits*, the rows whose ``Right_Anchor``
    equals the trait exactly are selected and ``pl.corr`` is computed
    between their ``score`` and ``Voice_Scale_Score`` columns (rows with a
    null in either column are dropped). Traits with fewer than two usable
    rows are skipped. Bars are green for values >= 0 and red otherwise; the
    y-axis is fixed to [-1, 1]. The figure is passed to ``_save_plot``.

    Args:
        style_color: Name of the speaking style; used in the default title
            and in the "No data" title.
        style_traits: Trait texts to match against ``Right_Anchor``.
        data: Frame with ``Right_Anchor``, ``score`` and
            ``Voice_Scale_Score`` columns; ``self.data_filtered`` is used
            when None.
        title: Figure title, also used to derive the saved file name.
            Defaults to a title that includes *style_color*.

    Returns:
        The bar figure, or an unsaved figure titled
        "No data for <style_color> Style" when no trait could be correlated.
    """
    df = self._ensure_dataframe(data)
    if title is None:
        # Include the style name so different styles get distinct default
        # titles (and therefore distinct file names), matching
        # plot_speaking_style_ranking_correlation.
        title = f"Speaking style {style_color} and voice scale 1-10 correlations"
    trait_correlations = []
    for trait in style_traits:
        # Right_Anchor holds the positive trait description; exact match.
        valid_data = (
            df.filter(pl.col("Right_Anchor") == trait)
            .select(["score", "Voice_Scale_Score"])
            .drop_nulls()
        )
        # A correlation needs at least two observations.
        if valid_data.height > 1:
            corr_val = valid_data.select(pl.corr("score", "Voice_Scale_Score")).item()
            trait_correlations.append({
                "trait_full": trait,
                "correlation": corr_val if corr_val is not None else 0.0
            })
    if not trait_correlations:
        fig = go.Figure()
        fig.update_layout(title=f"No data for {style_color} Style")
        return fig
    plot_df = pl.DataFrame(trait_correlations)
    correlations = plot_df["correlation"]
    # Colour encodes the sign of the correlation.
    colors = ["green" if val >= 0 else "red" for val in correlations]
    # Break the axis labels at the "|" separator for cleaner line breaks.
    x_labels = [
        "<br>".join(part.strip() for part in trait.split("|"))
        for trait in plot_df["trait_full"]
    ]
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=x_labels,
        y=correlations,
        text=[f"{val:.2f}" for val in correlations],
        textposition='outside',
        marker_color=colors,
        hovertemplate="<b>%{customdata}</b><br>Correlation: %{y:.2f}<extra></extra>",
        # Unwrapped trait text for the hover box.
        customdata=plot_df["trait_full"]
    ))
    fig.update_layout(
        title=title,
        yaxis_title="Correlation",
        yaxis=dict(range=[-1, 1], zeroline=True, zerolinecolor="black"),
        xaxis=dict(tickangle=0),
        # Fixed size; self.plot_height / self.plot_width are not used here.
        height=400,
        width=1000,
        template="plotly_white",
        showlegend=False
    )
    self._save_plot(fig, title)
    return fig
def plot_speaking_style_ranking_correlation(
    self,
    style_color: str,
    style_traits: list[str],
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str | None = None,
) -> go.Figure:
    """Bar chart of the correlation between trait scores and ranking points.

    For each trait in *style_traits*, rows whose ``Right_Anchor`` equals
    the trait are selected and ``pl.corr`` is computed between ``score``
    and ``Ranking_Points`` (rows with a null in either are dropped; traits
    with fewer than two usable rows are skipped). Bars are green for
    values >= 0, red otherwise. When no trait yields a value, an unsaved
    figure titled "No data for <style_color> Style" is returned; otherwise
    the figure is passed to ``_save_plot`` and returned.
    """
    df = self._ensure_dataframe(data)
    if title is None:
        title = f"Speaking style {style_color} and voice ranking points correlations"
    rows = []
    for index, trait in enumerate(style_traits):
        # Right_Anchor holds the positive trait description.
        pairs = (
            df.filter(pl.col("Right_Anchor") == trait)
            .select(["score", "Ranking_Points"])
            .drop_nulls()
        )
        if pairs.height <= 1:
            # A correlation needs at least two observations.
            continue
        coefficient = pairs.select(pl.corr("score", "Ranking_Points")).item()
        rows.append({
            "trait_full": trait,
            "trait_short": f"Trait {index+1}",
            "correlation": 0.0 if coefficient is None else coefficient,
        })
    if not rows:
        empty_fig = go.Figure()
        empty_fig.update_layout(title=f"No data for {style_color} Style")
        return empty_fig
    plot_df = pl.DataFrame(rows)
    correlations = plot_df["correlation"]
    # Colour encodes the sign of the correlation.
    bar_colors = ["green" if value >= 0 else "red" for value in correlations]
    # Axis labels: full trait text, broken into lines at each "|".
    axis_labels = [
        "<br>".join(piece.strip() for piece in full_text.split("|"))
        for full_text in plot_df["trait_full"]
    ]
    bars = go.Bar(
        x=axis_labels,
        y=correlations,
        text=[f"{value:.2f}" for value in correlations],
        textposition='outside',
        marker_color=bar_colors,
        hovertemplate="<b>%{customdata}</b><br>Correlation: %{y:.2f}<extra></extra>",
        customdata=plot_df["trait_full"],
    )
    fig = go.Figure()
    fig.add_trace(bars)
    fig.update_layout(
        title=title,
        yaxis_title="Correlation",
        yaxis=dict(range=[-1, 1], zeroline=True, zerolinecolor="black"),
        xaxis=dict(tickangle=0),
        height=400,
        width=1000,
        template="plotly_white",
        showlegend=False,
    )
    self._save_plot(fig, title)
    return fig