1901 lines
79 KiB
Python
1901 lines
79 KiB
Python
"""Plotting functions for Voice Branding analysis using Altair."""
|
||
|
||
import re
|
||
import math
|
||
from pathlib import Path
|
||
|
||
import altair as alt
|
||
import pandas as pd
|
||
import polars as pl
|
||
from theme import ColorPalette
|
||
from reference import VOICE_GENDER_MAPPING
|
||
|
||
import hashlib
|
||
|
||
class QualtricsPlotsMixin:
|
||
"""Mixin class for plotting functions in QualtricsSurvey."""
|
||
|
||
def _process_title(self, title: str) -> str | list[str]:
|
||
"""Process title to handle <br> tags for Altair."""
|
||
if isinstance(title, str) and '<br>' in title:
|
||
return title.split('<br>')
|
||
return title
|
||
|
||
def _sanitize_filename(self, title: str) -> str:
|
||
"""Convert plot title to a safe filename."""
|
||
# Remove HTML tags
|
||
clean = re.sub(r'<[^>]+>', ' ', title)
|
||
# Replace special characters with underscores
|
||
clean = re.sub(r'[^\w\s-]', '', clean)
|
||
# Replace whitespace with underscores
|
||
clean = re.sub(r'\s+', '_', clean.strip())
|
||
# Remove consecutive underscores
|
||
clean = re.sub(r'_+', '_', clean)
|
||
# Lowercase and limit length
|
||
return clean.lower()[:100]
|
||
|
||
def _get_filter_slug(self) -> str:
|
||
"""Generate a directory-friendly slug based on active filters."""
|
||
parts = []
|
||
|
||
# Mapping of attribute name to (short_code, value, options_attr)
|
||
filters = [
|
||
('age', 'Age', getattr(self, 'filter_age', None), 'options_age'),
|
||
('gender', 'Gen', getattr(self, 'filter_gender', None), 'options_gender'),
|
||
('consumer', 'Cons', getattr(self, 'filter_consumer', None), 'options_consumer'),
|
||
('ethnicity', 'Eth', getattr(self, 'filter_ethnicity', None), 'options_ethnicity'),
|
||
('income', 'Inc', getattr(self, 'filter_income', None), 'options_income'),
|
||
]
|
||
|
||
for _, short_code, value, options_attr in filters:
|
||
if value is None:
|
||
continue
|
||
|
||
# Ensure value is a list for uniform handling
|
||
if not isinstance(value, list):
|
||
value = [value]
|
||
|
||
if len(value) == 0:
|
||
continue
|
||
|
||
# Check if all options are selected (equivalent to no filter)
|
||
# We compare the set of selected values to the set of all available options
|
||
master_list = getattr(self, options_attr, None)
|
||
if master_list and set(value) == set(master_list):
|
||
continue
|
||
|
||
if len(value) > 3:
|
||
# If more than 3 options selected, create a hash of the sorted values
|
||
# This ensures uniqueness properly while keeping the slug short
|
||
sorted_vals = sorted([str(v) for v in value])
|
||
vals_str = "".join(sorted_vals)
|
||
# Create short 6-char hash
|
||
val_hash = hashlib.md5(vals_str.encode()).hexdigest()[:6]
|
||
val_str = f"{len(value)}_grps_{val_hash}"
|
||
else:
|
||
# Join values with '+'
|
||
clean_values = []
|
||
for v in value:
|
||
# Simple sanitization: keep alphanum and hyphens/dots, remove others
|
||
s = str(v)
|
||
# Remove special chars that might be problematic in dir names
|
||
s = re.sub(r'[^\w\-\.]', '', s)
|
||
clean_values.append(s)
|
||
val_str = "+".join(clean_values)
|
||
|
||
parts.append(f"{short_code}-{val_str}")
|
||
|
||
if not parts:
|
||
return "All_Respondents"
|
||
|
||
return "_".join(parts)
|
||
|
||
def _get_filter_description(self) -> str:
|
||
"""Generate a human-readable description of active filters."""
|
||
parts = []
|
||
|
||
# Mapping of attribute name to (display_name, value, options_attr)
|
||
filters = [
|
||
('Age', getattr(self, 'filter_age', None), 'options_age'),
|
||
('Gender', getattr(self, 'filter_gender', None), 'options_gender'),
|
||
('Consumer', getattr(self, 'filter_consumer', None), 'options_consumer'),
|
||
('Ethnicity', getattr(self, 'filter_ethnicity', None), 'options_ethnicity'),
|
||
('Income', getattr(self, 'filter_income', None), 'options_income'),
|
||
]
|
||
|
||
for display_name, value, options_attr in filters:
|
||
if value is None:
|
||
continue
|
||
|
||
# Ensure value is a list for uniform handling
|
||
if not isinstance(value, list):
|
||
value = [value]
|
||
|
||
if len(value) == 0:
|
||
continue
|
||
|
||
# Check if all options are selected (equivalent to no filter)
|
||
master_list = getattr(self, options_attr, None)
|
||
if master_list and set(value) == set(master_list):
|
||
continue
|
||
|
||
# Use original values for display (full list)
|
||
clean_values = [str(v) for v in value]
|
||
val_str = ", ".join(clean_values)
|
||
# Use UPPERCASE for category name to distinguish from values
|
||
parts.append(f"{display_name.upper()}: {val_str}")
|
||
|
||
if not parts:
|
||
return ""
|
||
|
||
# Join with clear separator - double space for visual break
|
||
return "Filters: " + " — ".join(parts)
|
||
|
||
def _add_filter_footnote(self, chart: alt.Chart) -> alt.Chart:
    """Attach the active-filter description to the chart as a subtitle.

    The filter text comes from ``self._get_filter_description()`` and is
    wrapped into lines of at most ~100 characters without splitting words.

    Parameters:
        chart: Chart whose title should receive the subtitle.

    Returns:
        ``chart`` unchanged when no filter is active; otherwise the result
        of ``chart.properties(title=...)`` carrying the existing title plus
        a gray, 10pt, start-anchored subtitle.
    """
    description = self._get_filter_description()
    if not description:
        return chart

    # Greedy word wrap: extend the pending line until the next word would
    # push it past the limit, then start a new line with that word.
    max_line_length = 100
    wrapped = []
    pending = ""
    for token in description.split():
        candidate = f"{pending} {token}" if pending else token
        if len(candidate) <= max_line_length:
            pending = candidate
            continue
        if pending:
            wrapped.append(pending)
        pending = token
    if pending:
        wrapped.append(pending)

    # The current title is read from the serialized chart spec.
    existing_title = chart.to_dict().get('title', '')

    if isinstance(existing_title, dict):
        # Keep whatever title settings already exist and layer ours on top.
        title_config = existing_title.copy()
    elif isinstance(existing_title, (str, list)):
        title_config = {'text': existing_title}
    else:
        # Unrecognized title shape: show the filters under an empty title.
        title_config = {'text': ''}

    title_config['subtitle'] = wrapped
    title_config['subtitleColor'] = 'gray'
    title_config['subtitleFontSize'] = 10
    title_config['anchor'] = 'start'

    return chart.properties(title=title_config)
|
||
|
||
def _save_plot(self, chart: alt.Chart, title: str) -> alt.Chart:
    """Add the filter footnote and, if configured, write the chart to PNG.

    Nothing is written unless ``self.fig_save_dir`` exists and is truthy.
    When it is set, the PNG goes to
    ``<fig_save_dir>/<filter slug>/<sanitized title>.png``.

    Parameters:
        chart: Chart to annotate and optionally save.
        title: Plot title; used to derive the file name.

    Returns:
        The chart returned by ``self._add_filter_footnote``.
    """
    chart = self._add_filter_footnote(chart)

    save_root = getattr(self, 'fig_save_dir', None)
    if not save_root:
        return chart

    # Plots for each filter combination land in their own sub-folder.
    target_dir = Path(save_root) / self._get_filter_slug()
    if not target_dir.exists():
        target_dir.mkdir(parents=True, exist_ok=True)

    filepath = target_dir / f"{self._sanitize_filename(title)}.png"

    # Render through vl_convert directly, passing the theme config
    # explicitly so saved files use the project theme.
    import vl_convert as vlc
    from theme import jpmc_altair_theme

    png_data = vlc.vegalite_to_png(
        vl_spec=chart.to_dict(),
        scale=2.0,
        ppi=72,
        config=jpmc_altair_theme()['config'],
    )
    filepath.write_bytes(png_data)

    print(f"Saved plot to {filepath}")
    return chart
|
||
|
||
def _ensure_dataframe(self, data: pl.LazyFrame | pl.DataFrame | None) -> pl.DataFrame:
    """Return an eager DataFrame for ``data`` or for ``self.data_filtered``.

    Parameters:
        data: Frame to use. When ``None``, ``self.data_filtered`` is used
            instead (if that attribute exists).

    Returns:
        The frame itself, or the result of ``collect()`` when it is a
        ``pl.LazyFrame``.

    Raises:
        ValueError: If neither ``data`` nor ``self.data_filtered`` provides
            a frame.
    """
    frame = getattr(self, 'data_filtered', None) if data is None else data
    if frame is None:
        raise ValueError("No data provided and self.data_filtered is None.")

    return frame.collect() if isinstance(frame, pl.LazyFrame) else frame
|
||
|
||
def _clean_voice_label(self, col_name: str) -> str:
|
||
"""Extract and clean voice name from column name for display.
|
||
|
||
Handles patterns like:
|
||
- 'Voice_Scale__The_Coach' -> 'The Coach'
|
||
- 'Character_Ranking_The_Coach' -> 'The Coach'
|
||
- 'Top_3_Voices_ranking__Familiar_Friend' -> 'Familiar Friend'
|
||
"""
|
||
# First split by __ if present
|
||
label = col_name.split('__')[-1] if '__' in col_name else col_name
|
||
# Remove common prefixes
|
||
label = label.replace('Character_Ranking_', '')
|
||
label = label.replace('Top_3_Voices_ranking_', '')
|
||
# Replace underscores with spaces
|
||
label = label.replace('_', ' ').strip()
|
||
return label
|
||
|
||
def _get_voice_gender(self, voice_label: str) -> str:
|
||
"""Get the gender of a voice from its label.
|
||
|
||
Parameters:
|
||
voice_label: Voice label (e.g., 'V14', 'Voice 14', etc.)
|
||
|
||
Returns:
|
||
'Male' or 'Female', defaults to 'Male' if not found
|
||
"""
|
||
# Extract voice code (e.g., 'V14' from 'Voice 14' or 'V14')
|
||
voice_code = None
|
||
|
||
# Try to find VXX pattern
|
||
match = re.search(r'V(\d+)', voice_label)
|
||
if match:
|
||
voice_code = f"V{match.group(1)}"
|
||
else:
|
||
# Try to extract number and prepend V
|
||
match = re.search(r'(\d+)', voice_label)
|
||
if match:
|
||
voice_code = f"V{match.group(1)}"
|
||
|
||
if voice_code and voice_code in VOICE_GENDER_MAPPING:
|
||
return VOICE_GENDER_MAPPING[voice_code]
|
||
|
||
return "Male" # Default to Male if unknown
|
||
|
||
def _get_gender_color(self, gender: str, color_type: str = "primary") -> str:
    """Pick the palette color for a gender and shade type.

    Parameters:
        gender: 'Male' or 'Female'; any other value uses the male shades.
        color_type: One of 'primary', 'rank_1', 'rank_2', 'rank_3',
            'neutral'; any other value yields ``ColorPalette.PRIMARY``.

    Returns:
        The matching ``ColorPalette`` value.
    """
    male_shades = {
        "primary": ColorPalette.GENDER_MALE,
        "rank_1": ColorPalette.GENDER_MALE_RANK_1,
        "rank_2": ColorPalette.GENDER_MALE_RANK_2,
        "rank_3": ColorPalette.GENDER_MALE_RANK_3,
        "neutral": ColorPalette.GENDER_MALE_NEUTRAL,
    }
    female_shades = {
        "primary": ColorPalette.GENDER_FEMALE,
        "rank_1": ColorPalette.GENDER_FEMALE_RANK_1,
        "rank_2": ColorPalette.GENDER_FEMALE_RANK_2,
        "rank_3": ColorPalette.GENDER_FEMALE_RANK_3,
        "neutral": ColorPalette.GENDER_FEMALE_NEUTRAL,
    }
    shades_by_gender = {"Male": male_shades, "Female": female_shades}

    shades = shades_by_gender.get(gender, male_shades)
    return shades.get(color_type, ColorPalette.PRIMARY)
|
||
|
||
def plot_average_scores_with_counts(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "General Impression (1-10)\nPer Voice with Number of Participants Who Rated It",
    x_label: str = "Stimuli",
    y_label: str = "Average General Impression Rating (1-10)",
    color: str = ColorPalette.PRIMARY,
    height: int | None = None,
    width: int | str | None = None,
    domain: list[float] | None = None,
    color_gender: bool = False,
) -> alt.Chart:
    """Create a bar plot of the average score per column, labelled with counts.

    Every column except ``_recordId`` becomes one bar showing the column
    mean; the number of non-null values is drawn above the bar.

    Parameters:
        data: Frame to plot; falls back to ``self.data_filtered`` when None.
        title: Chart title (also used to derive the saved file name).
        x_label: X-axis title.
        y_label: Y-axis title.
        color: Bar color used when ``color_gender`` is False.
        height: Chart height; defaults to ``self.plot_height`` or 400.
        width: Chart width; defaults to 800.
        domain: ``[min, max]`` of the y scale; defaults to the min and max
            of the computed averages.
        color_gender: If True, color bars by voice gender (blue=male, pink=female).

    Returns:
        The layered chart (after ``_save_plot``), or a "No data" text chart
        when there are no columns to plot.
    """
    df = self._ensure_dataframe(data)

    # One stats row per rating column (the record id is not a rating).
    stats = []
    for col in df.columns:
        if col == '_recordId':
            continue
        series = df[col]
        label = self._clean_voice_label(col)
        stats.append({
            'voice': label,
            'average': series.mean(),
            'count': series.drop_nulls().len(),
            'gender': self._get_voice_gender(label) if color_gender else None,
        })

    # Without any rating column the sort below would fail on a frame that
    # has no 'average' column, so return a placeholder chart instead.
    if not stats:
        return alt.Chart(pd.DataFrame({'text': ['No data']})).mark_text().encode(text='text:N')

    # Convert to pandas for Altair (sorted by average, descending).
    stats_df = pl.DataFrame(stats).sort('average', descending=True).to_pandas()

    if domain is None:
        # NOTE(review): with this default the lowest-scoring bar starts and
        # ends at domain[0], so it has zero height. Pass an explicit domain
        # if that bar should be visible.
        domain = [stats_df['average'].min(), stats_df['average'].max()]

    # Encodings shared by both color modes. y2 pins the bar base to the
    # bottom of the (possibly non-zero based) y domain.
    shared_encodings = {
        'x': alt.X('voice:N', title=x_label, sort='-y'),
        'y': alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=domain)),
        'y2': alt.datum(domain[0]),
    }
    tooltips = [
        alt.Tooltip('voice:N', title='Voice'),
        alt.Tooltip('average:Q', title='Average', format='.2f'),
        alt.Tooltip('count:Q', title='Count'),
    ]

    if color_gender:
        tooltips.append(alt.Tooltip('gender:N', title='Gender'))
        bars = alt.Chart(stats_df).mark_bar().encode(
            color=alt.Color(
                'gender:N',
                scale=alt.Scale(
                    domain=['Male', 'Female'],
                    range=[ColorPalette.GENDER_MALE, ColorPalette.GENDER_FEMALE],
                ),
                legend=alt.Legend(orient='top', direction='horizontal', title='Gender'),
            ),
            tooltip=tooltips,
            **shared_encodings,
        )
    else:
        bars = alt.Chart(stats_df).mark_bar(color=color).encode(
            tooltip=tooltips,
            **shared_encodings,
        )

    # Text overlay showing the respondent count just above each bar.
    text = alt.Chart(stats_df).mark_text(
        dy=-5,
        color='black',
        fontSize=10,
    ).encode(
        x=alt.X('voice:N', sort='-y'),
        y=alt.Y('average:Q'),
        text=alt.Text('count:Q'),
    )

    chart = (bars + text).properties(
        title=self._process_title(title),
        width=width or 800,
        height=height or getattr(self, 'plot_height', 400),
    )

    return self._save_plot(chart, title)
|
||
|
||
def plot_top3_ranking_distribution(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "Top 3 Rankings Distribution\nCount of 1st, 2nd, and 3rd Place Votes per Voice",
    x_label: str = "Voices",
    y_label: str = "Number of Mentions in Top 3",
    height: int | None = None,
    width: int | str | None = None,
) -> alt.Chart:
    """Create a stacked bar chart showing how often each voice was ranked 1st, 2nd, or 3rd.

    Every column except ``_recordId`` is treated as one voice; columns with
    no 1st/2nd/3rd-place values are left out.

    Parameters:
        data: Frame to plot; falls back to ``self.data_filtered`` when None.
        title: Chart title (also used to derive the saved file name).
        x_label: X-axis title.
        y_label: Y-axis title.
        height: Chart height; defaults to ``self.plot_height`` or 400.
        width: Chart width; defaults to 800.

    Returns:
        The stacked bar chart (after ``_save_plot``), or a "No data" text
        chart when no voice received a top-3 ranking.
    """
    df = self._ensure_dataframe(data)

    rank_labels = {
        1: 'Rank 1 (1st Choice)',
        2: 'Rank 2 (2nd Choice)',
        3: 'Rank 3 (3rd Choice)',
    }

    # Long format: one row per (voice, rank) with the voice's overall total.
    stats = []
    for col in df.columns:
        if col == '_recordId':
            continue
        counts = {
            rank: df.filter(pl.col(col) == rank).height for rank in rank_labels
        }
        total = sum(counts.values())
        if total == 0:
            continue
        label = self._clean_voice_label(col)
        for rank, rank_label in rank_labels.items():
            stats.append({
                'voice': label,
                'rank': rank_label,
                'count': counts[rank],
                'total': total,
            })

    # Same empty-data handling as plot_ranking_distribution.
    if not stats:
        return alt.Chart(pd.DataFrame({'text': ['No data']})).mark_text().encode(text='text:N')

    stats_df = pl.DataFrame(stats).to_pandas()

    # Interactive legend selection - click a legend entry to highlight it.
    selection = alt.selection_point(fields=['rank'], bind='legend')

    chart = alt.Chart(stats_df).mark_bar().encode(
        # Voices are ordered by their summed total, largest first.
        x=alt.X(
            'voice:N',
            title=x_label,
            sort=alt.EncodingSortField(field='total', op='sum', order='descending'),
        ),
        y=alt.Y('count:Q', title=y_label, stack='zero'),
        color=alt.Color(
            'rank:N',
            scale=alt.Scale(
                domain=list(rank_labels.values()),
                range=[ColorPalette.RANK_1, ColorPalette.RANK_2, ColorPalette.RANK_3],
            ),
            legend=alt.Legend(orient='top', direction='horizontal', title=None),
        ),
        order=alt.Order('rank:N', sort='ascending'),
        opacity=alt.condition(selection, alt.value(1), alt.value(0.2)),
        tooltip=[
            alt.Tooltip('voice:N', title='Voice'),
            alt.Tooltip('rank:N', title='Rank'),
            alt.Tooltip('count:Q', title='Count'),
        ],
    ).add_params(selection).properties(
        title=self._process_title(title),
        width=width or 800,
        height=height or getattr(self, 'plot_height', 400),
    )

    return self._save_plot(chart, title)
|
||
|
||
def plot_ranking_distribution(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "Rankings Distribution\n(1st to 3rd Place)",
    x_label: str = "Item",
    y_label: str = "Number of Votes",
    height: int | None = None,
    width: int | str | None = None,
    color_gender: bool = False,
) -> alt.Chart:
    """Create a stacked bar chart showing the distribution of rankings (1st to 3rd).

    Every column except ``_recordId`` is treated as one item; items with no
    1st/2nd/3rd-place values are left out.

    Parameters:
        data: Frame to plot; falls back to ``self.data_filtered`` when None.
        title: Chart title (also used to derive the saved file name).
        x_label: X-axis title.
        y_label: Y-axis title.
        height: Chart height; defaults to ``self.plot_height`` or 400.
        width: Chart width; defaults to 800.
        color_gender: If True, color bars by voice gender with rank intensity
            (blue shades=male, pink shades=female).

    Returns:
        The stacked bar chart (after ``_save_plot``), or a "No data" text
        chart when no item received a top-3 ranking.
    """
    df = self._ensure_dataframe(data)

    rank_labels = {1: 'Rank 1 (Best)', 2: 'Rank 2', 3: 'Rank 3'}

    # Long format: one row per (item, rank).
    stats = []
    for col in df.columns:
        if col == '_recordId':
            continue
        counts = {
            rank: df.filter(pl.col(col) == rank).height for rank in rank_labels
        }
        total = sum(counts.values())
        if total == 0:
            continue
        label = self._clean_voice_label(col)
        gender = self._get_voice_gender(label) if color_gender else None
        for rank, rank_label in rank_labels.items():
            stats.append({
                'item': label,
                'rank': rank_label,
                'count': counts[rank],
                'total': total,
                'gender': gender,
                'rank_order': rank,
            })

    if not stats:
        return alt.Chart(pd.DataFrame({'text': ['No data']})).mark_text().encode(text='text:N')

    stats_df = pl.DataFrame(stats).to_pandas()

    tooltips = [
        alt.Tooltip('item:N', title='Item'),
        alt.Tooltip('rank:N', title='Rank'),
        alt.Tooltip('count:Q', title='Count'),
    ]

    if color_gender:
        # Combined gender + rank key drives the color encoding.
        stats_df['gender_rank'] = stats_df['gender'] + ' - ' + stats_df['rank']
        legend_field = 'gender_rank'
        color_encoding = alt.Color(
            'gender_rank:N',
            scale=alt.Scale(
                domain=[
                    'Male - Rank 1 (Best)', 'Male - Rank 2', 'Male - Rank 3',
                    'Female - Rank 1 (Best)', 'Female - Rank 2', 'Female - Rank 3',
                ],
                range=[
                    ColorPalette.GENDER_MALE_RANK_1,
                    ColorPalette.GENDER_MALE_RANK_2,
                    ColorPalette.GENDER_MALE_RANK_3,
                    ColorPalette.GENDER_FEMALE_RANK_1,
                    ColorPalette.GENDER_FEMALE_RANK_2,
                    ColorPalette.GENDER_FEMALE_RANK_3,
                ],
            ),
            legend=alt.Legend(orient='top', direction='horizontal', title=None, columns=3),
        )
        tooltips.append(alt.Tooltip('gender:N', title='Gender'))
    else:
        legend_field = 'rank'
        color_encoding = alt.Color(
            'rank:N',
            scale=alt.Scale(
                domain=list(rank_labels.values()),
                range=[ColorPalette.RANK_1, ColorPalette.RANK_2, ColorPalette.RANK_3],
            ),
            legend=alt.Legend(orient='top', direction='horizontal', title=None),
        )

    # Interactive legend selection - click a legend entry to highlight it.
    # A legend-bound selection must project the field the legend is built
    # from, so in gender mode it targets 'gender_rank' rather than 'rank'.
    selection = alt.selection_point(fields=[legend_field], bind='legend')

    chart = alt.Chart(stats_df).mark_bar().encode(
        x=alt.X(
            'item:N',
            title=x_label,
            sort=alt.EncodingSortField(field='total', order='descending'),
        ),
        y=alt.Y('count:Q', title=y_label, stack='zero'),
        color=color_encoding,
        # Stack segments in rank order (1st place at the bottom).
        order=alt.Order('rank_order:Q', sort='ascending'),
        opacity=alt.condition(selection, alt.value(1), alt.value(0.2)),
        tooltip=tooltips,
    ).add_params(selection).properties(
        title=self._process_title(title),
        width=width or 800,
        height=height or getattr(self, 'plot_height', 400),
    )

    return self._save_plot(chart, title)
|
||
|
||
def plot_most_ranked_1(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "Most Popular Choice\n(Number of Times Ranked 1st)",
    x_label: str = "Item",
    y_label: str = "Count of 1st Place Rankings",
    height: int | None = None,
    width: int | str | None = None,
    color_gender: bool = False,
) -> alt.Chart:
    """Create a bar chart showing which item was ranked #1 the most. Top 3 highlighted.

    Every column except ``_recordId`` is treated as one item. The three
    items with the most 1st-place rankings are labelled 'Top 3'; ties keep
    the column order of the input frame.

    Parameters:
        data: Frame to plot; falls back to ``self.data_filtered`` when None.
        title: Chart title (also used to derive the saved file name).
        x_label: X-axis title.
        y_label: Y-axis title.
        height: Chart height; defaults to ``self.plot_height`` or 400.
        width: Chart width; defaults to 800.
        color_gender: If True, color bars by voice gender with highlight/neutral intensity
            (blue shades=male, pink shades=female).

    Returns:
        The bar chart (after ``_save_plot``), or a "No data" text chart when
        there are no ranking columns.
    """
    df = self._ensure_dataframe(data)

    stats = []
    for col in df.columns:
        if col == '_recordId':
            continue
        label = self._clean_voice_label(col)
        stats.append({
            'item': label,
            'count': df.filter(pl.col(col) == 1).height,
            'gender': self._get_voice_gender(label) if color_gender else None,
        })

    # Without ranking columns the sort below would fail on a frame that has
    # no 'count' column, so return a placeholder chart instead.
    if not stats:
        return alt.Chart(pd.DataFrame({'text': ['No data']})).mark_text().encode(text='text:N')

    # Sort by count; maintain_order keeps ties in column order so the
    # 'Top 3' assignment is reproducible between runs.
    stats_df = (
        pl.DataFrame(stats)
        .sort('count', descending=True, maintain_order=True)
        .with_row_index('rank_index')
        .with_columns(
            pl.when(pl.col('rank_index') < 3)
            .then(pl.lit('Top 3'))
            .otherwise(pl.lit('Other'))
            .alias('category')
        )
        .to_pandas()
    )

    tooltips = [
        alt.Tooltip('item:N', title='Item'),
        alt.Tooltip('count:Q', title='1st Place Votes'),
    ]

    if color_gender:
        # Combined gender + category key drives the color encoding.
        stats_df['gender_category'] = stats_df['gender'] + ' - ' + stats_df['category']
        color_encoding = alt.Color(
            'gender_category:N',
            scale=alt.Scale(
                domain=['Male - Top 3', 'Male - Other', 'Female - Top 3', 'Female - Other'],
                range=[
                    ColorPalette.GENDER_MALE, ColorPalette.GENDER_MALE_NEUTRAL,
                    ColorPalette.GENDER_FEMALE, ColorPalette.GENDER_FEMALE_NEUTRAL,
                ],
            ),
            legend=alt.Legend(orient='top', direction='horizontal', title=None),
        )
        tooltips.append(alt.Tooltip('gender:N', title='Gender'))
    else:
        color_encoding = alt.Color(
            'category:N',
            scale=alt.Scale(
                domain=['Top 3', 'Other'],
                range=[ColorPalette.PRIMARY, ColorPalette.NEUTRAL],
            ),
            legend=None,
        )

    chart = alt.Chart(stats_df).mark_bar().encode(
        x=alt.X('item:N', title=x_label, sort='-y'),
        y=alt.Y('count:Q', title=y_label),
        color=color_encoding,
        tooltip=tooltips,
    ).properties(
        title=self._process_title(title),
        width=width or 800,
        height=height or getattr(self, 'plot_height', 400),
    )

    return self._save_plot(chart, title)
|
||
|
||
def plot_weighted_ranking_score(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "Weighted Popularity Score\n(1st=3pts, 2nd=2pts, 3rd=1pt)",
    x_label: str = "Character Personality",
    y_label: str = "Total Weighted Score",
    color: str = ColorPalette.PRIMARY,
    height: int | None = None,
    width: int | str | None = None,
    color_gender: bool = False,
) -> alt.Chart:
    """Create a bar chart showing the weighted ranking score for each character.

    The input frame is plotted as-is using its ``Character`` and
    ``Weighted Score`` columns.

    Parameters:
        data: Frame to plot; falls back to ``self.data_filtered`` when None.
        title: Chart title (also used to derive the saved file name).
        x_label: X-axis title.
        y_label: Y-axis title.
        color: Bar color used when ``color_gender`` is False.
        height: Chart height; defaults to ``self.plot_height`` or 400.
        width: Chart width; defaults to 800.
        color_gender: If True, color bars by voice gender (blue=male, pink=female).

    Returns:
        The layered bar + label chart, after ``_save_plot``.
    """
    weighted_df = self._ensure_dataframe(data).to_pandas()

    x_encoding = alt.X('Character:N', title=x_label, sort='-y')
    y_encoding = alt.Y('Weighted Score:Q', title=y_label)
    base_tooltips = [
        alt.Tooltip('Character:N'),
        alt.Tooltip('Weighted Score:Q', title='Score'),
    ]

    if not color_gender:
        # Single-color bars.
        bars = alt.Chart(weighted_df).mark_bar(color=color).encode(
            x=x_encoding,
            y=y_encoding,
            tooltip=base_tooltips,
        )
    else:
        # Gender is derived from the Character name, then mapped to color.
        weighted_df['gender'] = weighted_df['Character'].apply(self._get_voice_gender)
        bars = alt.Chart(weighted_df).mark_bar().encode(
            x=x_encoding,
            y=y_encoding,
            color=alt.Color(
                'gender:N',
                scale=alt.Scale(
                    domain=['Male', 'Female'],
                    range=[ColorPalette.GENDER_MALE, ColorPalette.GENDER_FEMALE],
                ),
                legend=alt.Legend(orient='top', direction='horizontal', title='Gender'),
            ),
            tooltip=base_tooltips + [alt.Tooltip('gender:N', title='Gender')],
        )

    # Score labels reuse the bar encodings and are shifted up by 5px.
    # NOTE(review): the labels are white, while the count labels in
    # plot_average_scores_with_counts are black - confirm they are legible
    # against the theme background.
    label_layer = bars.mark_text(
        dy=-5,
        color='white',
        fontSize=11,
    ).encode(
        text='Weighted Score:Q',
    )

    layered = (bars + label_layer).properties(
        title=self._process_title(title),
        width=width or 800,
        height=height or getattr(self, 'plot_height', 400),
    )

    return self._save_plot(layered, title)
|
||
|
||
def plot_voice_selection_counts(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    target_column: str = "8_Combined",
    title: str = "Most Frequently Chosen Voices\n(Top 8 Highlighted)",
    x_label: str = "Voice",
    y_label: str = "Number of Times Chosen",
    height: int | None = None,
    width: int | str | None = None,
    color_gender: bool = False,
) -> alt.Chart:
    """Create a bar plot showing the frequency of voice selections.

    ``target_column`` is expected to hold comma-separated voice labels.
    Each label is counted once per row it appears in, and the eight most
    frequent labels are marked 'Top 8' (ties are broken by label).

    Parameters:
        data: Frame to plot; falls back to ``self.data_filtered`` when None.
        target_column: Column holding the comma-separated selections.
        title: Chart title (also used to derive the saved file name).
        x_label: X-axis title.
        y_label: Y-axis title.
        height: Chart height; defaults to ``self.plot_height`` or 400.
        width: Chart width; defaults to 800.
        color_gender: If True, color bars by voice gender with highlight/neutral intensity
            (blue shades=male, pink shades=female).

    Returns:
        The bar chart (after ``_save_plot``), or a text chart with an error
        message when ``target_column`` is missing.
    """
    df = self._ensure_dataframe(data)

    if target_column not in df.columns:
        return alt.Chart(pd.DataFrame({'text': [f"Column '{target_column}' not found"]})).mark_text().encode(text='text:N')

    # Split the comma-separated selections, count each label, then rank.
    stats_df = (
        df.select(pl.col(target_column))
        .drop_nulls()
        .with_columns(pl.col(target_column).str.split(","))
        .explode(target_column)
        .with_columns(pl.col(target_column).str.strip_chars())
        .filter(pl.col(target_column) != "")
        .group_by(target_column)
        .agg(pl.len().alias("count"))
        # group_by does not guarantee row order, so break count ties by
        # label to keep the 'Top 8' assignment reproducible between runs.
        .sort(["count", target_column], descending=[True, False])
        .with_row_index('rank_index')
        .with_columns(
            pl.when(pl.col('rank_index') < 8)
            .then(pl.lit('Top 8'))
            .otherwise(pl.lit('Other'))
            .alias('category')
        )
        .to_pandas()
    )

    tooltips = [
        alt.Tooltip(f'{target_column}:N', title='Voice'),
        alt.Tooltip('count:Q', title='Selections'),
    ]

    if color_gender:
        # Gender comes from the voice label; the combined gender + category
        # key drives the color encoding.
        stats_df['gender'] = stats_df[target_column].apply(self._get_voice_gender)
        stats_df['gender_category'] = stats_df['gender'] + ' - ' + stats_df['category']
        color_encoding = alt.Color(
            'gender_category:N',
            scale=alt.Scale(
                domain=['Male - Top 8', 'Male - Other', 'Female - Top 8', 'Female - Other'],
                range=[
                    ColorPalette.GENDER_MALE, ColorPalette.GENDER_MALE_NEUTRAL,
                    ColorPalette.GENDER_FEMALE, ColorPalette.GENDER_FEMALE_NEUTRAL,
                ],
            ),
            legend=alt.Legend(orient='top', direction='horizontal', title=None),
        )
        tooltips.append(alt.Tooltip('gender:N', title='Gender'))
    else:
        color_encoding = alt.Color(
            'category:N',
            scale=alt.Scale(
                domain=['Top 8', 'Other'],
                range=[ColorPalette.PRIMARY, ColorPalette.NEUTRAL],
            ),
            legend=None,
        )

    chart = alt.Chart(stats_df).mark_bar().encode(
        x=alt.X(f'{target_column}:N', title=x_label, sort='-y'),
        y=alt.Y('count:Q', title=y_label),
        color=color_encoding,
        tooltip=tooltips,
    ).properties(
        title=self._process_title(title),
        width=width or 800,
        height=height or getattr(self, 'plot_height', 400),
    )

    return self._save_plot(chart, title)
|
||
|
||
def plot_top3_selection_counts(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    target_column: str = "3_Ranked",
    title: str = "Most Frequently Chosen Top 3 Voices\n(Top 3 Highlighted)",
    x_label: str = "Voice",
    y_label: str = "Count of Mentions in Top 3",
    height: int | None = None,
    width: int | str | None = None,
    color_gender: bool = False,
) -> alt.Chart:
    """Question: Which 3 voices are chosen the most out of 18?

    ``target_column`` is expected to hold comma-separated voice labels.
    Each label is counted once per row it appears in, and the three most
    frequent labels are marked 'Top 3' (ties are broken by label).

    Parameters:
        data: Frame to plot; falls back to ``self.data_filtered`` when None.
        target_column: Column holding the comma-separated selections.
        title: Chart title (also used to derive the saved file name).
        x_label: X-axis title.
        y_label: Y-axis title.
        height: Chart height; defaults to ``self.plot_height`` or 400.
        width: Chart width; defaults to 800.
        color_gender: If True, color bars by voice gender with highlight/neutral intensity
            (blue shades=male, pink shades=female).

    Returns:
        The bar chart (after ``_save_plot``), or a text chart with an error
        message when ``target_column`` is missing.
    """
    df = self._ensure_dataframe(data)

    if target_column not in df.columns:
        return alt.Chart(pd.DataFrame({'text': [f"Column '{target_column}' not found"]})).mark_text().encode(text='text:N')

    # Split the comma-separated selections, count each label, then rank.
    stats_df = (
        df.select(pl.col(target_column))
        .drop_nulls()
        .with_columns(pl.col(target_column).str.split(","))
        .explode(target_column)
        .with_columns(pl.col(target_column).str.strip_chars())
        .filter(pl.col(target_column) != "")
        .group_by(target_column)
        .agg(pl.len().alias("count"))
        # group_by does not guarantee row order, so break count ties by
        # label to keep the 'Top 3' assignment reproducible between runs.
        .sort(["count", target_column], descending=[True, False])
        .with_row_index('rank_index')
        .with_columns(
            pl.when(pl.col('rank_index') < 3)
            .then(pl.lit('Top 3'))
            .otherwise(pl.lit('Other'))
            .alias('category')
        )
        .to_pandas()
    )

    tooltips = [
        alt.Tooltip(f'{target_column}:N', title='Voice'),
        alt.Tooltip('count:Q', title='In Top 3'),
    ]

    if color_gender:
        # Gender comes from the voice label; the combined gender + category
        # key drives the color encoding.
        stats_df['gender'] = stats_df[target_column].apply(self._get_voice_gender)
        stats_df['gender_category'] = stats_df['gender'] + ' - ' + stats_df['category']
        color_encoding = alt.Color(
            'gender_category:N',
            scale=alt.Scale(
                domain=['Male - Top 3', 'Male - Other', 'Female - Top 3', 'Female - Other'],
                range=[
                    ColorPalette.GENDER_MALE, ColorPalette.GENDER_MALE_NEUTRAL,
                    ColorPalette.GENDER_FEMALE, ColorPalette.GENDER_FEMALE_NEUTRAL,
                ],
            ),
            legend=alt.Legend(orient='top', direction='horizontal', title=None),
        )
        tooltips.append(alt.Tooltip('gender:N', title='Gender'))
    else:
        color_encoding = alt.Color(
            'category:N',
            scale=alt.Scale(
                domain=['Top 3', 'Other'],
                range=[ColorPalette.PRIMARY, ColorPalette.NEUTRAL],
            ),
            legend=None,
        )

    chart = alt.Chart(stats_df).mark_bar().encode(
        x=alt.X(f'{target_column}:N', title=x_label, sort='-y'),
        y=alt.Y('count:Q', title=y_label),
        color=color_encoding,
        tooltip=tooltips,
    ).properties(
        title=self._process_title(title),
        width=width or 800,
        height=height or getattr(self, 'plot_height', 400),
    )

    return self._save_plot(chart, title)
|
||
|
||
def plot_speaking_style_trait_scores(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    trait_description: str | None = None,
    left_anchor: str | None = None,
    right_anchor: str | None = None,
    title: str = "Speaking Style Trait Analysis",
    height: int | None = None,
    width: int | str | None = None,
) -> alt.Chart:
    """Plot scores for a single speaking style trait across multiple voices.

    Draws one horizontal bar per voice showing the mean of ``score`` on a
    fixed 1-5 axis, with the number of non-null scores printed inside the
    end of each bar.

    Args:
        data: Frame with at least ``Voice`` and ``score`` columns; it is
            passed through ``self._ensure_dataframe``. Optional
            ``Left_Anchor``, ``Right_Anchor`` and ``Description`` columns
            are used to fill in values that were not passed explicitly.
        trait_description: First subtitle line. When omitted it is built
            from the anchors (text before the first ``|`` of each), else
            taken from the first non-null ``Description``, else empty.
        left_anchor: Left scale anchor; read from ``Left_Anchor`` if None.
        right_anchor: Right scale anchor; read from ``Right_Anchor`` if None.
        title: Chart title (``<br>`` splits it into several lines).
        height: Chart height; falls back to ``self.plot_height`` or 400.
        width: Chart width; falls back to 800.

    Returns:
        The result of ``self._save_plot`` for the layered chart, or an
        unsaved text-only placeholder chart when the frame is empty or
        lacks the required columns.
    """
    df = self._ensure_dataframe(data)

    if df.is_empty():
        return alt.Chart(pd.DataFrame({'text': ['No data']})).mark_text().encode(text='text:N')

    required_cols = ["Voice", "score"]
    if not all(col in df.columns for col in required_cols):
        return alt.Chart(pd.DataFrame({'text': ['Missing required columns']})).mark_text().encode(text='text:N')

    # Per-voice mean and respondent count over non-null scores.
    # Rows are ordered ascending by mean; the chart's own sort='-x' is what
    # controls the displayed bar order.
    stats = (
        df.filter(pl.col("score").is_not_null())
        .group_by("Voice")
        .agg([
            pl.col("score").mean().alias("mean_score"),
            pl.col("score").count().alias("count"),
        ])
        .sort("mean_score", descending=False)
        .to_pandas()
    )

    # Fill in missing anchors from the first row that has a left anchor.
    if (left_anchor is None or right_anchor is None) and "Left_Anchor" in df.columns:
        head = df.filter(pl.col("Left_Anchor").is_not_null()).head(1)
        if not head.is_empty():
            if left_anchor is None:
                left_anchor = head["Left_Anchor"][0]
            # Only Left_Anchor is known to exist here, so check before
            # reading Right_Anchor instead of failing on a missing column.
            if right_anchor is None and "Right_Anchor" in head.columns:
                right_anchor = head["Right_Anchor"][0]

    if trait_description is None:
        if left_anchor and right_anchor:
            trait_description = f"{left_anchor.split('|')[0]} vs. {right_anchor.split('|')[0]}"
        elif "Description" in df.columns:
            head = df.filter(pl.col("Description").is_not_null()).head(1)
            trait_description = head["Description"][0] if not head.is_empty() else ""
        else:
            trait_description = ""

    # Horizontal bars; x2 pins the bar start to 1, the left edge of the domain.
    bars = alt.Chart(stats).mark_bar(color=ColorPalette.PRIMARY).encode(
        x=alt.X('mean_score:Q', title='Average Score (1-5)', scale=alt.Scale(domain=[1, 5])),
        x2=alt.datum(1),
        y=alt.Y('Voice:N', title='Voice', sort='-x'),
        tooltip=[
            alt.Tooltip('Voice:N'),
            alt.Tooltip('mean_score:Q', title='Average', format='.2f'),
            alt.Tooltip('count:Q', title='Count'),
        ],
    )

    # Respondent count, right-aligned just inside the end of each bar.
    text = alt.Chart(stats).mark_text(
        align='right',
        baseline='middle',
        color='white',
        fontSize=12,
        dx=-5,
    ).encode(
        x='mean_score:Q',
        y=alt.Y('Voice:N', sort='-x'),
        text='count:Q',
    )

    chart = (bars + text).properties(
        title={
            "text": self._process_title(title),
            "subtitle": [trait_description, "(Numbers on bars indicate respondent count)"],
        },
        width=width or 800,
        height=height or getattr(self, 'plot_height', 400),
    )

    chart = self._save_plot(chart, title)
    return chart
|
||
|
||
def plot_speaking_style_correlation(
    self,
    style_color: str,
    style_traits: list[str],
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str | None = None,
    width: int | str | None = None,
    height: int | None = None,
) -> alt.Chart:
    """Plot correlation between Speaking Style Trait Scores (1-5) and Voice Scale (1-10).

    For each trait, rows whose ``Right_Anchor`` equals the trait are
    selected and the correlation between ``score`` and
    ``Voice_Scale_Score`` is computed over rows where both are non-null.
    Traits with fewer than two such rows are left out of the chart.

    Args:
        style_color: Speaking style name; used in the "no data" message.
        style_traits: Trait identifiers matched against ``Right_Anchor``.
            A ``|`` inside a trait becomes a line break in its axis label.
        data: Frame with ``Right_Anchor``, ``score`` and
            ``Voice_Scale_Score`` columns; passed through
            ``self._ensure_dataframe``.
        title: Chart title; a generic default is used when None.
        width: Chart width; falls back to 800.
        height: Chart height; falls back to 350.

    Returns:
        The result of ``self._save_plot`` for the bar chart (green bars for
        correlations >= 0, red otherwise), or an unsaved text-only
        placeholder chart when no trait has enough data.
    """
    df = self._ensure_dataframe(data)

    if title is None:
        title = "Speaking style and voice scale 1-10 correlations"

    trait_correlations = []

    for i, trait in enumerate(style_traits):
        valid_data = (
            df.filter(pl.col("Right_Anchor") == trait)
            .select(["score", "Voice_Scale_Score"])
            .drop_nulls()
        )

        # A correlation needs at least two paired observations.
        if valid_data.height <= 1:
            continue

        corr_val = valid_data.select(pl.corr("score", "Voice_Scale_Score")).item()
        # Fall back to 0.0 for an undefined correlation. Besides None this
        # also covers NaN, which a constant column is expected to produce
        # (NOTE(review): confirm for the polars version in use).
        if corr_val is None or math.isnan(corr_val):
            corr_val = 0.0

        trait_correlations.append({
            # Wrap trait text at '|' for display
            "trait_display": trait.replace('|', '\n'),
            "trait_index": f"Trait {i+1}",
            "correlation": corr_val,
        })

    if not trait_correlations:
        return alt.Chart(pd.DataFrame({'text': [f"No data for {style_color} Style"]})).mark_text().encode(text='text:N')

    plot_df = pl.DataFrame(trait_correlations).to_pandas()

    # Bar color depends on the sign of the correlation.
    chart = alt.Chart(plot_df).mark_bar().encode(
        x=alt.X('trait_display:N', title=None, axis=alt.Axis(labelAngle=0)),
        y=alt.Y('correlation:Q', title='Correlation', scale=alt.Scale(domain=[-1, 1])),
        color=alt.condition(
            alt.datum.correlation >= 0,
            alt.value('green'),
            alt.value('red'),
        ),
        tooltip=[
            alt.Tooltip('trait_display:N', title='Trait'),
            alt.Tooltip('correlation:Q', format='.2f'),
        ],
    ).properties(
        title=self._process_title(title),
        width=width or 800,
        height=height or 350,
    )

    chart = self._save_plot(chart, title)
    return chart
|
||
|
||
def plot_speaking_style_color_correlation(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "Speaking Style and Voice Scale 1-10 Correlations<br>(Average by Color)",
    width: int | str | None = None,
    height: int | None = None,
) -> alt.Chart:
    """Draw one correlation bar for each speaking style color.

    This is the high-level companion to the per-trait correlation plot: the
    green / blue / orange / red speaking styles are each summarised by a
    single bar of their average correlation with the voice scale scores.

    Args:
        data: Frame providing ``Color``, ``correlation`` and ``n_traits``
            columns (the original documentation names
            utils.transform_speaking_style_color_correlation as its source).
        title: Chart title; ``<br>`` marks a line break.
        width: Chart width in pixels (400 when not given).
        height: Chart height in pixels (350 when not given).

    Returns:
        Altair bar chart with one bar per speaking style color, as returned
        by ``self._save_plot``.
    """
    frame = self._ensure_dataframe(data)

    # Colors appear on the x axis in a fixed, hand-picked order.
    color_axis = alt.X(
        'Color:N',
        title=None,
        axis=alt.Axis(labelAngle=0),
        sort=["Green", "Blue", "Orange", "Red"],
    )
    correlation_axis = alt.Y(
        'correlation:Q',
        title='Average Correlation',
        scale=alt.Scale(domain=[-1, 1]),
    )
    # Same sign-based coloring as the per-trait correlation plot.
    sign_color = alt.condition(
        alt.datum.correlation >= 0,
        alt.value('green'),
        alt.value('red'),
    )
    hover_fields = [
        alt.Tooltip('Color:N', title='Speaking Style'),
        alt.Tooltip('correlation:Q', format='.3f', title='Avg Correlation'),
        alt.Tooltip('n_traits:Q', title='# Traits'),
    ]

    bar_chart = alt.Chart(frame.to_pandas()).mark_bar().encode(
        x=color_axis,
        y=correlation_axis,
        color=sign_color,
        tooltip=hover_fields,
    )
    bar_chart = bar_chart.properties(
        title=self._process_title(title),
        width=width or 400,
        height=height or 350,
    )

    return self._save_plot(bar_chart, title)
|
||
|
||
def plot_demographic_distribution(
    self,
    column: str,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str | None = None,
    height: int | None = None,
    width: int | str | None = None,
    show_counts: bool = True,
) -> alt.Chart:
    """Horizontal bar chart of how respondents are distributed over one demographic column.

    The chart is intentionally small so that roughly six of them fit on a
    single slide; horizontal bars keep long category labels readable.

    Parameters:
        column: Column to summarise (e.g., 'Age', 'Gender', 'Race/Ethnicity').
        data: Optional frame, resolved through ``self._ensure_dataframe``.
        title: Chart title; derived from ``column`` when None.
        height: Chart height in pixels (default: grows with the number of
            categories, never below 120).
        width: Chart width in pixels (default: 200).
        show_counts: When True, print each count next to its bar.

    Returns:
        alt.Chart: The distribution chart as returned by ``self._save_plot``,
        or an unsaved text-only placeholder when the column is missing or
        there is nothing to count.
    """
    frame = self._ensure_dataframe(data)

    if column not in frame.columns:
        message = pd.DataFrame({'text': [f"Column '{column}' not found"]})
        return alt.Chart(message).mark_text().encode(text='text:N')

    # Tally every value; nulls are kept under an explicit label.
    labelled = frame.select(pl.col(column)).with_columns(
        pl.col(column).fill_null("(No Response)")
    )
    tallies = labelled.group_by(column).agg(pl.len().alias("count"))
    counts_df = tallies.sort("count", descending=True).to_pandas()

    if counts_df.empty:
        message = pd.DataFrame({'text': ['No data']})
        return alt.Chart(message).mark_text().encode(text='text:N')

    # Share of all respondents, one decimal place.
    respondent_total = counts_df['count'].sum()
    counts_df['percentage'] = (counts_df['count'] / respondent_total * 100).round(1)

    if title is None:
        readable_name = column.replace('_', ' ').replace('/', ' / ')
        title = f"Distribution: {readable_name}"

    # 18px per bar plus 40px for title/padding, with a 120px floor.
    pixels_per_bar = 18
    auto_height = max(120, len(counts_df) * pixels_per_bar + 40)

    category_field = f'{column}:N'

    # Categories run down the Y axis, counts along the X axis.
    bars = alt.Chart(counts_df).mark_bar(color=ColorPalette.PRIMARY).encode(
        x=alt.X('count:Q', title='Count', axis=alt.Axis(grid=False)),
        y=alt.Y(category_field, title=None, sort='-x', axis=alt.Axis(labelLimit=150)),
        tooltip=[
            alt.Tooltip(category_field, title=column.replace('_', ' ')),
            alt.Tooltip('count:Q', title='Count'),
            alt.Tooltip('percentage:Q', title='Percentage', format='.1f'),
        ],
    )

    chart = bars
    if show_counts:
        # Count labels sit just past the end of each bar.
        count_labels = alt.Chart(counts_df).mark_text(
            align='left',
            baseline='middle',
            dx=3,
            fontSize=9,
            color=ColorPalette.TEXT,
        ).encode(
            x='count:Q',
            y=alt.Y(category_field, sort='-x'),
            text='count:Q',
        )
        chart = bars + count_labels

    # Compact dimensions for the six-per-slide layout.
    chart = chart.properties(
        title=self._process_title(title),
        width=width or 200,
        height=height or auto_height,
    )

    return self._save_plot(chart, title)
|
||
|
||
def plot_speaking_style_ranking_correlation(
    self,
    style_color: str,
    style_traits: list[str],
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str | None = None,
    width: int | str | None = None,
    height: int | None = None,
) -> alt.Chart:
    """Plot correlation between Speaking Style Trait Scores (1-5) and Voice Ranking Points (0-3).

    For each trait, rows whose ``Right_Anchor`` equals the trait are
    selected and the correlation between ``score`` and ``Ranking_Points``
    is computed over rows where both are non-null. Traits with fewer than
    two such rows are left out of the chart.

    Args:
        style_color: Speaking style name; used in the default title and in
            the "no data" message.
        style_traits: Trait identifiers matched against ``Right_Anchor``.
            A ``|`` inside a trait becomes a line break in its axis label.
        data: Frame with ``Right_Anchor``, ``score`` and ``Ranking_Points``
            columns; passed through ``self._ensure_dataframe``.
        title: Chart title; a default naming ``style_color`` is used when None.
        width: Chart width; falls back to 800.
        height: Chart height; falls back to 350.

    Returns:
        The result of ``self._save_plot`` for the bar chart (green bars for
        correlations >= 0, red otherwise), or an unsaved text-only
        placeholder chart when no trait has enough data.
    """
    df = self._ensure_dataframe(data)

    if title is None:
        title = f"Speaking style {style_color} and voice ranking points correlations"

    trait_correlations = []

    for i, trait in enumerate(style_traits):
        valid_data = (
            df.filter(pl.col("Right_Anchor") == trait)
            .select(["score", "Ranking_Points"])
            .drop_nulls()
        )

        # A correlation needs at least two paired observations.
        if valid_data.height <= 1:
            continue

        corr_val = valid_data.select(pl.corr("score", "Ranking_Points")).item()
        # Fall back to 0.0 for an undefined correlation. Besides None this
        # also covers NaN, which a constant column is expected to produce
        # (NOTE(review): confirm for the polars version in use).
        if corr_val is None or math.isnan(corr_val):
            corr_val = 0.0

        trait_correlations.append({
            # Wrap trait text at '|' for display
            "trait_display": trait.replace('|', '\n'),
            "trait_index": f"Trait {i+1}",
            "correlation": corr_val,
        })

    if not trait_correlations:
        return alt.Chart(pd.DataFrame({'text': [f"No data for {style_color} Style"]})).mark_text().encode(text='text:N')

    plot_df = pl.DataFrame(trait_correlations).to_pandas()

    # Bar color depends on the sign of the correlation.
    chart = alt.Chart(plot_df).mark_bar().encode(
        x=alt.X('trait_display:N', title=None, axis=alt.Axis(labelAngle=0)),
        y=alt.Y('correlation:Q', title='Correlation', scale=alt.Scale(domain=[-1, 1])),
        color=alt.condition(
            alt.datum.correlation >= 0,
            alt.value('green'),
            alt.value('red'),
        ),
        tooltip=[
            alt.Tooltip('trait_display:N', title='Trait'),
            alt.Tooltip('correlation:Q', format='.2f'),
        ],
    ).properties(
        title=self._process_title(title),
        width=width or 800,
        height=height or 350,
    )

    chart = self._save_plot(chart, title)
    return chart
|
||
|
||
def plot_traits_wordcloud(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    column: str = 'Top_3_Traits',
    title: str = "Most Prominent Personality Traits",
    width: int = 1600,
    height: int = 800,
    background_color: str = 'white',
    random_state: int = 23,
):
    """Create a word cloud visualization of personality traits from survey data.

    Each non-null cell of ``column`` is split on commas; the trimmed,
    non-empty pieces are counted and the counts drive the word sizes.

    Args:
        data: Polars DataFrame or LazyFrame containing trait data
        column: Name of column containing comma-separated traits
        title: Title for the word cloud (also used for the saved file name)
        width: Width of the word cloud image in pixels
        height: Height of the word cloud image in pixels
        background_color: Background color for the word cloud
        random_state: Seed for the word layout and for the color choice
            (default: 23)

    Returns:
        matplotlib.figure.Figure: The word cloud figure for display in notebooks

    Raises:
        ValueError: If ``column`` yields no traits at all.
    """
    import matplotlib.pyplot as plt
    from wordcloud import WordCloud
    from collections import Counter
    import random

    df = self._ensure_dataframe(data)

    # Count traits; pieces that are empty after trimming (e.g. from a
    # trailing comma) are not traits and are skipped.
    trait_freq = Counter()
    for row in df[column].drop_nulls():
        cleaned = (part.strip() for part in row.split(','))
        trait_freq.update(trait for trait in cleaned if trait)

    if not trait_freq:
        raise ValueError(f"No traits found in column '{column}'; cannot build a word cloud")

    # Private generator for color selection: seeding the module-level
    # `random` would reset the random state of the whole process.
    color_rng = random.Random(random_state)
    palette = [
        ColorPalette.PRIMARY,
        ColorPalette.RANK_1,
        ColorPalette.RANK_2,
        ColorPalette.RANK_3,
    ]

    # Color callback in the signature WordCloud calls it with; only the
    # seeded generator above is used, the per-word arguments are ignored.
    def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
        return color_rng.choice(palette)

    wordcloud = WordCloud(
        width=width,
        height=height,
        background_color=background_color,
        color_func=color_func,
        relative_scaling=0.5,
        min_font_size=10,
        prefer_horizontal=0.7,
        collocations=False,  # Treat each word independently
        random_state=random_state,  # Seed for reproducible layout
    ).generate_from_frequencies(trait_freq)

    # Figure sized so that at 100 dpi it matches the requested pixel size.
    fig, ax = plt.subplots(figsize=(width/100, height/100), dpi=100)
    ax.imshow(wordcloud, interpolation='bilinear')
    ax.axis('off')
    ax.set_title(title, fontsize=16, pad=20, color=ColorPalette.TEXT)

    fig.tight_layout(pad=0)

    # Save figure if a directory is configured (same layout as other plots:
    # <fig_save_dir>/<filter slug>/<sanitized title>.png).
    if hasattr(self, 'fig_save_dir') and self.fig_save_dir:
        save_path = Path(self.fig_save_dir) / self._get_filter_slug()
        save_path.mkdir(parents=True, exist_ok=True)

        filepath = save_path / f"{self._sanitize_filename(title)}.png"

        # Save as PNG at high resolution
        fig.savefig(filepath, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Word cloud saved to: {filepath}")

    return fig
|
||
|
||
|
||
def plot_character_trait_frequency(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    title: str = "Trait Frequency per Brand Character",
    x_label: str = "Trait",
    y_label: str = "Frequency (Times Chosen)",
    height: int | None = None,
    width: int | str | None = None,
) -> alt.Chart:
    """Grouped bar chart of how often each trait was chosen for each character.

    Answers the request for a bar plot of trait-selection frequency per
    brand character. Traits run along the x axis (most frequent overall
    first) and each character gets its own colored bar within a trait.
    Clicking a legend entry dims the other characters.

    The input needs the columns Character, Trait and Count (the original
    documentation names transform_character_trait_frequency as the
    producer). When they are missing, an unsaved text-only placeholder
    chart is returned instead.
    """
    frame = self._ensure_dataframe(data)

    needed = {'Character', 'Trait', 'Count'}
    if needed - set(frame.columns):
        notice = pd.DataFrame({'text': ['Data must have Character, Trait, Count columns']})
        return alt.Chart(notice).mark_text().encode(text='text:N')

    # Altair wants pandas; pass the frame through if it cannot convert.
    if hasattr(frame, 'to_pandas'):
        plot_df = frame.to_pandas()
    else:
        plot_df = frame

    # X-axis order: traits by total count over all characters, largest first.
    totals_by_trait = plot_df.groupby('Trait')['Count'].sum()
    trait_order = totals_by_trait.sort_values(ascending=False).index.tolist()

    # Characters in order of first appearance drive the color mapping.
    characters = plot_df['Character'].unique().tolist()
    character_colors = alt.Scale(
        domain=characters,
        range=ColorPalette.CATEGORICAL[:len(characters)],
    )

    # Legend-bound selection: click a character to highlight it.
    selection = alt.selection_point(fields=['Character'], bind='legend')

    trait_axis = alt.X(
        'Trait:N',
        title=x_label,
        sort=trait_order,
        axis=alt.Axis(labelAngle=-45, labelLimit=200),
    )
    character_color = alt.Color(
        'Character:N',
        scale=character_colors,
        legend=alt.Legend(orient='top', direction='horizontal', title='Character'),
    )
    hover_fields = [
        alt.Tooltip('Character:N', title='Character'),
        alt.Tooltip('Trait:N', title='Trait'),
        alt.Tooltip('Count:Q', title='Frequency'),
    ]

    grouped_bars = alt.Chart(plot_df).mark_bar().encode(
        x=trait_axis,
        y=alt.Y('Count:Q', title=y_label),
        xOffset='Character:N',
        color=character_color,
        opacity=alt.condition(selection, alt.value(1), alt.value(0.2)),
        tooltip=hover_fields,
    )
    chart = grouped_bars.add_params(selection).properties(
        title=self._process_title(title),
        width=width or 900,
        height=height or getattr(self, 'plot_height', 400),
    )

    return self._save_plot(chart, title)
|
||
|
||
def plot_single_character_trait_frequency(
    self,
    data: pl.LazyFrame | pl.DataFrame | None = None,
    character_name: str = "Character",
    bar_color: str = ColorPalette.PRIMARY,
    highlight_color: str = ColorPalette.NEUTRAL,
    title: str | None = None,
    x_label: str = "Trait",
    y_label: str = "Frequency",
    trait_sort_order: list[str] | None = None,
    height: int | None = None,
    width: int | str | None = None,
) -> alt.Chart:
    """Create a horizontal bar plot showing trait frequency for a single character.

    Built for the request to show how often each trait is chosen per brand
    character, one plot per character, each in its own color, with the
    character's original traits highlighted. Call it once per character.

    Args:
        data: DataFrame with columns ['trait', 'count', 'is_original']
            (the original documentation names
            transform_character_trait_frequency() as the producer).
        character_name: Name of the character (for the default title).
            E.g., "Bank Teller"
        bar_color: Bar color for non-original traits.
        highlight_color: Bar color for original/expected traits.
        title: Custom title. If None, auto-generates from character_name.
        x_label: Label for the trait axis (drawn vertically, since the
            bars are horizontal).
        y_label: Label for the count axis (drawn horizontally).
        trait_sort_order: Optional list of traits for consistent sorting
            across all character plots; traits not in the list follow,
            ordered by count descending. If None, sorts by count descending.
        height: Chart height; falls back to ``self.plot_height`` or 450.
        width: Chart width; falls back to 400.

    Returns:
        alt.Chart: The result of ``self._save_plot`` for the layered chart,
        or an unsaved text-only placeholder when columns are missing.
    """
    df = self._ensure_dataframe(data)

    # Ensure we have the expected columns
    required_cols = {'trait', 'count', 'is_original'}
    if not required_cols.issubset(set(df.columns)):
        return alt.Chart(pd.DataFrame({
            'text': ['Data must have trait, count, is_original columns']
        })).mark_text().encode(text='text:N')

    # Convert to pandas for Altair
    plot_df = df.to_pandas() if hasattr(df, 'to_pandas') else df

    # Determine sort order
    if trait_sort_order is not None:
        # Use provided order, append any missing traits at the end (sorted by count)
        known_traits = set(trait_sort_order)
        extra_traits = plot_df[~plot_df['trait'].isin(known_traits)].sort_values(
            'count', ascending=False
        )['trait'].tolist()
        sort_order = list(trait_sort_order) + extra_traits
    else:
        # Default: sort by count descending
        sort_order = plot_df.sort_values('count', ascending=False)['trait'].tolist()

    # Category column for color encoding. assign() returns a new frame, so a
    # pandas frame that was passed straight through is not modified.
    plot_df = plot_df.assign(
        category=plot_df['is_original'].map({
            True: 'Original Trait',
            False: 'Other Trait',
        })
    )

    # Generate title if not provided
    if title is None:
        title = f"{character_name}<br>Trait Selection Frequency"

    # Title config with the sort order note as subtitle
    sort_note = "Sorted by total frequency across all characters" if trait_sort_order else "Sorted by frequency (descending)"
    title_config = {
        'text': self._process_title(title),
        'subtitle': sort_note,
        'subtitleColor': 'gray',
        'subtitleFontSize': 10,
        'anchor': 'start',
    }

    # Horizontal bars. A sort list on a discrete Y axis is laid out from the
    # top down, so the list is passed as-is to put its first entry (the
    # highest-ranked trait) at the top; reversing it would put it at the bottom.
    bars = alt.Chart(plot_df).mark_bar().encode(
        y=alt.Y('trait:N',
                title=x_label,
                sort=sort_order,
                axis=alt.Axis(labelLimit=200)),
        x=alt.X('count:Q', title=y_label),
        color=alt.Color('category:N',
                        scale=alt.Scale(
                            domain=['Original Trait', 'Other Trait'],
                            range=[highlight_color, bar_color],
                        ),
                        legend=alt.Legend(
                            orient='top',
                            direction='horizontal',
                            title=None,
                        )),
        tooltip=[
            alt.Tooltip('trait:N', title='Trait'),
            alt.Tooltip('count:Q', title='Frequency'),
            alt.Tooltip('category:N', title='Type'),
        ],
    )

    # Count labels to the right of each bar
    text = alt.Chart(plot_df).mark_text(
        dx=12,
        color='black',
        fontSize=10,
        align='left',
    ).encode(
        y=alt.Y('trait:N', sort=sort_order),
        x=alt.X('count:Q'),
        text=alt.Text('count:Q'),
    )

    chart = (bars + text).properties(
        title=title_config,
        width=width or 400,
        height=height or getattr(self, 'plot_height', 450),
    )

    chart = self._save_plot(chart, title)
    return chart
|
||
|
||
def plot_significance_heatmap(
    self,
    pairwise_df: pl.LazyFrame | pl.DataFrame | None = None,
    metadata: dict | None = None,
    title: str = "Pairwise Statistical Significance<br>(Adjusted p-values)",
    show_p_values: bool = True,
    show_effect_size: bool = False,
    height: int | None = None,
    width: int | None = None,
) -> alt.Chart:
    """Create a heatmap showing pairwise statistical significance between groups.

    Every group found in ``group1``/``group2`` becomes both a row and a
    column. Each off-diagonal cell is colored by the significance bucket of
    the pair's adjusted p-value (< 0.001, < 0.01, < 0.05, n.s.); diagonal
    cells are marked 'Self' and pairs without a comparison 'N/A'.

    Args:
        pairwise_df: Output from compute_pairwise_significance().
            Expected columns: ['group1', 'group2', 'p_value', 'p_adjusted',
            'significant']; 'effect_size' and 'rank1_pct1'/'rank1_pct2'
            are used when present.
        metadata: Optional dict; its 'test_type', 'overall_p_value' and
            'correction' entries are shown in the subtitle.
        title: Chart title (supports <br> for line breaks)
        show_p_values: Show adjusted p-values as cell text. When False the
            effect size is shown, else the Rank 1 % difference, else '—'.
        show_effect_size: Show that alternative text instead of p-values,
            regardless of ``show_p_values``.
        height: Chart height (default: auto-sized based on groups)
        width: Chart width (default: auto-sized based on groups)

    Returns:
        alt.Chart: The result of ``self._save_plot`` for the heatmap, or an
        unsaved text-only placeholder when the input has no groups.
    """
    df = self._ensure_dataframe(pairwise_df)

    # Every group that appears on either side of a comparison.
    all_groups = sorted(set(df['group1'].to_list() + df['group2'].to_list()))
    n_groups = len(all_groups)

    if not all_groups:
        return alt.Chart(pd.DataFrame({'text': ['No data']})).mark_text().encode(text='text:N')

    has_effect_col = 'effect_size' in df.columns
    has_rank_pcts = 'rank1_pct1' in df.columns and 'rank1_pct2' in df.columns

    # Index comparisons by unordered pair so (A, B) and (B, A) resolve to the
    # same row; setdefault keeps the first row seen for a pair.
    pair_rows = {}
    for record in df.iter_rows(named=True):
        pair_rows.setdefault(frozenset((record['group1'], record['group2'])), record)

    # Cell without a comparison behind it (diagonal or missing pair).
    def empty_cell(row_group, col_group, label, category):
        return {
            'row': row_group,
            'col': col_group,
            'p_adjusted': None,
            'p_value': None,
            'significant': None,
            'effect_size': None,
            'rank1_pct_diff': None,
            'text_label': label,
            'sig_category': category,
        }

    # Cell for a pair that has a comparison row.
    def comparison_cell(row_group, col_group, record):
        p_adj = record['p_adjusted']
        eff = record['effect_size'] if has_effect_col else None

        # For ranking data, the absolute Rank 1 % difference is available.
        pct_diff = None
        if has_rank_pcts:
            pct1, pct2 = record['rank1_pct1'], record['rank1_pct2']
            if pct1 is not None and pct2 is not None:
                pct_diff = abs(pct1 - pct2)

        # Text shown when p-values are not displayed.
        if eff is not None:
            alternative_text = f'{eff:.2f}'
        elif pct_diff is not None:
            alternative_text = f'{pct_diff:.1f}%'
        else:
            alternative_text = '—'

        # Significance bucket and the matching p-value text.
        if p_adj is None:
            sig_cat, p_text = 'N/A', 'N/A'
        elif p_adj < 0.001:
            sig_cat, p_text = 'p < 0.001', '<.001'
        elif p_adj < 0.01:
            sig_cat, p_text = 'p < 0.01', f'{p_adj:.3f}'
        elif p_adj < 0.05:
            sig_cat, p_text = 'p < 0.05', f'{p_adj:.3f}'
        else:
            sig_cat, p_text = 'n.s.', f'{p_adj:.2f}'

        # show_effect_size wins; a missing p-value always reads 'N/A' otherwise.
        if show_effect_size:
            label = alternative_text
        elif p_adj is None or show_p_values:
            label = p_text
        else:
            label = alternative_text

        return {
            'row': row_group,
            'col': col_group,
            'p_adjusted': p_adj,
            'p_value': record['p_value'],
            'significant': record['significant'],
            'effect_size': eff,
            'rank1_pct_diff': pct_diff,
            'text_label': label,
            'sig_category': sig_cat,
        }

    # Full symmetric matrix: one cell per (row, col) combination.
    heatmap_data = []
    for row_group in all_groups:
        for col_group in all_groups:
            if row_group == col_group:
                cell = empty_cell(row_group, col_group, '—', 'Self')
            else:
                record = pair_rows.get(frozenset((row_group, col_group)))
                if record is None:
                    cell = empty_cell(row_group, col_group, 'N/A', 'N/A')
                else:
                    cell = comparison_cell(row_group, col_group, record)
            heatmap_data.append(cell)

    heatmap_df = pl.DataFrame(heatmap_data).to_pandas()

    # Color scale for significance categories
    sig_domain = ['p < 0.001', 'p < 0.01', 'p < 0.05', 'n.s.', 'Self', 'N/A']
    sig_range = [
        ColorPalette.SIG_STRONG,    # p < 0.001
        ColorPalette.SIG_MODERATE,  # p < 0.01
        ColorPalette.SIG_WEAK,      # p < 0.05
        ColorPalette.SIG_NONE,      # not significant
        ColorPalette.SIG_DIAGONAL,  # diagonal (self)
        ColorPalette.NEUTRAL,       # N/A
    ]

    # Tooltip fields; optional ones are added only when they carry values.
    tooltip_fields = [
        alt.Tooltip('row:N', title='Group 1'),
        alt.Tooltip('col:N', title='Group 2'),
        alt.Tooltip('p_value:Q', title='p-value', format='.4f'),
        alt.Tooltip('p_adjusted:Q', title='Adjusted p', format='.4f'),
    ]
    if heatmap_df['effect_size'].notna().any():
        tooltip_fields.append(alt.Tooltip('effect_size:Q', title='Effect Size', format='.3f'))
    # Use the computed difference itself: the cell text holds the p-value
    # by default and must not be shown under the "Rank 1 % Diff" label.
    if heatmap_df['rank1_pct_diff'].notna().any():
        tooltip_fields.append(alt.Tooltip('rank1_pct_diff:Q', title='Rank 1 % Diff', format='.1f'))

    # Square layout: 45px per group plus room for labels.
    cell_size = 45
    auto_size = n_groups * cell_size + 100
    chart_width = width or auto_size
    chart_height = height or auto_size

    show_text = show_p_values or show_effect_size
    if show_text:
        # White text on the two darkest buckets, black elsewhere.
        heatmap_df['text_color'] = heatmap_df['sig_category'].apply(
            lambda x: 'white' if x in ['p < 0.001', 'p < 0.01'] else 'black'
        )

    # Base heatmap
    heatmap = alt.Chart(heatmap_df).mark_rect(stroke='white', strokeWidth=1).encode(
        x=alt.X('col:N', title=None, sort=all_groups,
                axis=alt.Axis(labelAngle=-45, labelLimit=150)),
        y=alt.Y('row:N', title=None, sort=all_groups,
                axis=alt.Axis(labelLimit=150)),
        color=alt.Color('sig_category:N',
                        scale=alt.Scale(domain=sig_domain, range=sig_range),
                        legend=alt.Legend(
                            title='Significance',
                            orient='right',
                            direction='vertical',
                        )),
        tooltip=tooltip_fields,
    )

    # Text annotations
    if show_text:
        text = alt.Chart(heatmap_df).mark_text(
            fontSize=9,
            fontWeight='normal',
        ).encode(
            x=alt.X('col:N', sort=all_groups),
            y=alt.Y('row:N', sort=all_groups),
            text='text_label:N',
            color=alt.Color('text_color:N', scale=None),
        )
        chart = (heatmap + text)
    else:
        chart = heatmap

    # Subtitle with test info
    subtitle_lines = []
    if metadata:
        test_info = f"Test: {metadata.get('test_type', 'N/A')}"
        if metadata.get('overall_p_value') is not None:
            test_info += f" | Overall p={metadata['overall_p_value']:.4f}"
        correction = metadata.get('correction', 'none')
        if correction != 'none':
            test_info += f" | Correction: {correction}"
        subtitle_lines.append(test_info)

    title_config = {
        'text': self._process_title(title),
        'subtitleColor': 'gray',
        'subtitleFontSize': 10,
        'anchor': 'start',
    }
    # Only set a subtitle when there is one; an explicit None subtitle may
    # be rejected by Altair's schema validation (NOTE(review): confirm).
    if subtitle_lines:
        title_config['subtitle'] = subtitle_lines

    chart = chart.properties(
        title=title_config,
        width=chart_width,
        height=chart_height,
    )

    chart = self._save_plot(chart, title)
    return chart
|
||
|
||
def plot_significance_summary(
    self,
    pairwise_df: pl.LazyFrame | pl.DataFrame | None = None,
    metadata: dict | None = None,
    title: str = "Significant Differences Summary<br>(Groups with significantly different means)",
    height: int | None = None,
    width: int | None = None,
) -> alt.Chart:
    """Create a summary bar chart showing which groups have significant differences.

    Each group gets one bar whose height is the number of pairwise comparisons
    involving that group that are flagged significant. When a per-group score is
    available (mean score or Rank 1 percentage) it is drawn as a text label
    above the bar and included in the tooltip.

    Args:
        pairwise_df: Output from compute_pairwise_significance() or
            compute_ranking_significance(). It is first passed through
            ``self._ensure_dataframe``. The resulting frame must contain the
            columns ``group1``, ``group2`` and ``significant``. If it has a
            ``mean1`` column the data is treated as continuous and scores are
            read from ``mean1``/``mean2``; otherwise, if it has a
            ``rank1_pct1`` column, scores are read from
            ``rank1_pct1``/``rank1_pct2``; otherwise no scores are shown.
        metadata: Metadata dict from the significance computation (optional).
            Only its ``alpha`` key is read (default 0.05), for the subtitle.
        title: Chart title. ``<br>`` is handled by ``self._process_title``.
        height: Chart height. Falls back to ``self.plot_height`` if that
            attribute exists, else 400.
        width: Chart width. Falls back to 800.

    Returns:
        alt.Chart: The value returned by ``self._save_plot`` for the bar chart
        (layered with score labels when any score is available).
    """
    df = self._ensure_dataframe(pairwise_df)

    # Detect data type: continuous (mean1/mean2) vs ranking (rank1_pct1/rank1_pct2).
    has_means = 'mean1' in df.columns
    has_ranks = 'rank1_pct1' in df.columns
    if has_means:
        score_cols = ('mean1', 'mean2')
    elif has_ranks:
        score_cols = ('rank1_pct1', 'rank1_pct2')
    else:
        score_cols = None

    group1_vals = df['group1'].to_list()
    group2_vals = df['group2'].to_list()
    all_groups = sorted(set(group1_vals + group2_vals))

    # Count, per group, the significant comparisons it takes part in
    # (a group may appear on either side of a pair).
    sig_df = df.filter(pl.col('significant') == True)  # noqa: E712 - polars expression
    sig_counts: dict = {}
    for member in sig_df['group1'].to_list() + sig_df['group2'].to_list():
        sig_counts[member] = sig_counts.get(member, 0) + 1

    # Score per group: the first value seen where the group is `group1`;
    # only groups that never appear as `group1` fall back to the first value
    # seen where they are `group2`.
    first_score: dict = {}
    if score_cols is not None:
        primary_col, fallback_col = score_cols
        for member, value in zip(group1_vals, df[primary_col].to_list()):
            first_score.setdefault(member, value)
        # Touch the fallback column only when it is actually needed.
        if any(member not in first_score for member in group2_vals):
            for member, value in zip(group2_vals, df[fallback_col].to_list()):
                first_score.setdefault(member, value)

    summary_rows = [
        {
            'group': member,
            'sig_count': sig_counts.get(member, 0),
            'score': first_score.get(member),
        }
        for member in all_groups
    ]

    if summary_rows:
        summary_df = (
            pl.DataFrame(summary_rows)
            .sort('score', descending=True, nulls_last=True)
            .to_pandas()
        )
    else:
        # No comparisons at all: an empty list would yield a column-less polars
        # frame and the sort on 'score' would raise, so build the empty frame
        # with the expected columns directly.
        summary_df = pd.DataFrame(columns=['group', 'sig_count', 'score'])

    # Create layered chart: bars for sig_count, text for score.
    tooltip_title = 'Mean Score' if has_means else 'Rank 1 %' if has_ranks else 'Score'

    bars = alt.Chart(summary_df).mark_bar(color=ColorPalette.PRIMARY).encode(
        x=alt.X('group:N', title='Group', sort='-y'),
        y=alt.Y('sig_count:Q', title='# of Significant Differences'),
        tooltip=[
            alt.Tooltip('group:N', title='Group'),
            alt.Tooltip('sig_count:Q', title='Sig. Differences'),
            alt.Tooltip('score:Q', title=tooltip_title, format='.1f'),
        ],
    )

    # Only add text labels if we have scores. Labels are plain numbers; the
    # unit is conveyed by the tooltip title (and the subtitle when present).
    if summary_df['score'].notna().any():
        text_format = '.1f' if has_means else '.0f'
        text = alt.Chart(summary_df).mark_text(
            dy=-8,
            color='black',
            fontSize=9,
        ).encode(
            x=alt.X('group:N', sort='-y'),
            y=alt.Y('sig_count:Q'),
            text=alt.Text('score:Q', format=text_format),
        )
        chart_layers = bars + text
    else:
        chart_layers = bars

    # Build subtitle (only when metadata was supplied).
    subtitle = None
    if metadata:
        alpha = metadata.get('alpha', 0.05)
        if has_means:
            subtitle = f"Mean scores shown above bars | α={alpha}"
        elif has_ranks:
            subtitle = f"Rank 1 % shown above bars | α={alpha}"
        else:
            subtitle = f"α={alpha}"

    title_config = {
        'text': self._process_title(title),
        'subtitle': subtitle,
        'subtitleColor': 'gray',
        'subtitleFontSize': 10,
        'anchor': 'start',
    }

    chart = chart_layers.properties(
        title=title_config,
        width=width or 800,
        height=height or getattr(self, 'plot_height', 400),
    )

    chart = self._save_plot(chart, title)
    return chart