"""Plotting functions for Voice Branding analysis using Altair.""" import re import math from pathlib import Path import altair as alt import pandas as pd import polars as pl from theme import ColorPalette from reference import VOICE_GENDER_MAPPING import hashlib class QualtricsPlotsMixin: """Mixin class for plotting functions in QualtricsSurvey.""" def _process_title(self, title: str) -> str | list[str]: """Process title to handle
tags for Altair.""" if isinstance(title, str) and '
' in title: return title.split('
') return title def _sanitize_filename(self, title: str) -> str: """Convert plot title to a safe filename.""" # Remove HTML tags clean = re.sub(r'<[^>]+>', ' ', title) # Replace special characters with underscores clean = re.sub(r'[^\w\s-]', '', clean) # Replace whitespace with underscores clean = re.sub(r'\s+', '_', clean.strip()) # Remove consecutive underscores clean = re.sub(r'_+', '_', clean) # Lowercase and limit length return clean.lower()[:100] def _get_filter_slug(self) -> str: """Generate a directory-friendly slug based on active filters.""" parts = [] # Mapping of attribute name to (short_code, value, options_attr) filters = [ ('age', 'Age', getattr(self, 'filter_age', None), 'options_age'), ('gender', 'Gen', getattr(self, 'filter_gender', None), 'options_gender'), ('consumer', 'Cons', getattr(self, 'filter_consumer', None), 'options_consumer'), ('ethnicity', 'Eth', getattr(self, 'filter_ethnicity', None), 'options_ethnicity'), ('income', 'Inc', getattr(self, 'filter_income', None), 'options_income'), ('business_owner', 'BizOwn', getattr(self, 'filter_business_owner', None), 'options_business_owner'), ('employment_status', 'Emp', getattr(self, 'filter_employment_status', None), 'options_employment_status'), ('personal_products', 'Prod', getattr(self, 'filter_personal_products', None), 'options_personal_products'), ('ai_user', 'AI', getattr(self, 'filter_ai_user', None), 'options_ai_user'), ('investable_assets', 'InvAsts', getattr(self, 'filter_investable_assets', None), 'options_investable_assets'), ('industry', 'Ind', getattr(self, 'filter_industry', None), 'options_industry'), ] for _, short_code, value, options_attr in filters: if value is None: continue # Ensure value is a list for uniform handling if not isinstance(value, list): value = [value] if len(value) == 0: continue # Check if all options are selected (equivalent to no filter) # We compare the set of selected values to the set of all available options master_list = getattr(self, options_attr, None) if master_list and set(value) == set(master_list): continue if len(value) > 3: # If more than 3 options selected, create a hash of the sorted values # This ensures uniqueness properly while keeping the slug short sorted_vals = sorted([str(v) for v in value]) vals_str = "".join(sorted_vals) # Create short 6-char hash val_hash = hashlib.md5(vals_str.encode()).hexdigest()[:6] val_str = f"{len(value)}_grps_{val_hash}" else: # Join values with '+' clean_values = [] for v in value: # Simple sanitization: keep alphanum and hyphens/dots, remove others s = str(v) # Remove special chars that might be problematic in dir names s = re.sub(r'[^\w\-\.]', '', s) clean_values.append(s) val_str = "+".join(clean_values) parts.append(f"{short_code}-{val_str}") if not parts: return "All_Respondents" return "_".join(parts) def _get_filter_description(self) -> str: """Generate a human-readable description of active filters.""" parts = [] # Mapping of attribute name to (display_name, value, options_attr) filters = [ ('Age', getattr(self, 'filter_age', None), 'options_age'), ('Gender', getattr(self, 'filter_gender', None), 'options_gender'), ('Consumer', getattr(self, 'filter_consumer', None), 'options_consumer'), ('Ethnicity', getattr(self, 'filter_ethnicity', None), 'options_ethnicity'), ('Income', getattr(self, 'filter_income', None), 'options_income'), ('Business Owner', getattr(self, 'filter_business_owner', None), 'options_business_owner'), ('Employment Status', getattr(self, 'filter_employment_status', None), 'options_employment_status'), ('Personal Products', getattr(self, 'filter_personal_products', None), 'options_personal_products'), ('AI User', getattr(self, 'filter_ai_user', None), 'options_ai_user'), ('Investable Assets', getattr(self, 'filter_investable_assets', None), 'options_investable_assets'), ('Industry', getattr(self, 'filter_industry', None), 'options_industry'), ] for display_name, value, options_attr in filters: if value is None: continue # Ensure value is a list for uniform handling if not isinstance(value, list): value = [value] if len(value) == 0: continue # Check if all options are selected (equivalent to no filter) master_list = getattr(self, options_attr, None) if master_list and set(value) == set(master_list): continue # Use original values for display (full list) clean_values = [str(v) for v in value] val_str = ", ".join(clean_values) # Use UPPERCASE for category name to distinguish from values parts.append(f"{display_name.upper()}: {val_str}") if not parts: return "" # Join with clear separator - double space for visual break return "Filters: " + " — ".join(parts) def _add_filter_footnote(self, chart: alt.Chart) -> alt.Chart: """Add a footnote with active filters to the chart. Uses chart subtitle for filter text to avoid layout issues with vconcat. Returns the modified chart (or original if no filters). """ filter_text = self._get_filter_description() # Skip if no filters active - return original chart if not filter_text: return chart # Wrap text into multiple lines at ~100 chars, but don't break mid-word max_line_length = 100 words = filter_text.split() lines = [] current_line = "" for word in words: test_line = f"{current_line} {word}".strip() if current_line else word if len(test_line) <= max_line_length: current_line = test_line else: if current_line: lines.append(current_line) current_line = word if current_line: lines.append(current_line) # Get existing title from chart spec chart_spec = chart.to_dict() existing_title = chart_spec.get('title', '') # Handle different title formats (string vs dict vs list) if isinstance(existing_title, (str, list)): title_config = { 'text': existing_title, 'subtitle': lines, 'subtitleColor': 'gray', 'subtitleFontSize': 10, 'anchor': 'start', } elif isinstance(existing_title, dict): title_config = existing_title.copy() title_config['subtitle'] = lines title_config['subtitleColor'] = 'gray' title_config['subtitleFontSize'] = 10 title_config['anchor'] = 'start' else: # No existing title, just add filters as subtitle title_config = { 'text': '', 'subtitle': lines, 'subtitleColor': 'gray', 'subtitleFontSize': 10, 'anchor': 'start', } return chart.properties(title=title_config) def _save_plot(self, chart: alt.Chart, title: str) -> alt.Chart: """Save chart to PNG file if fig_save_dir is set. Returns the (potentially modified) chart with filter footnote added. """ # Add filter footnote - returns combined chart if filters active chart = self._add_filter_footnote(chart) if hasattr(self, 'fig_save_dir') and self.fig_save_dir: path = Path(self.fig_save_dir) # Add filter slug subfolder filter_slug = self._get_filter_slug() path = path / filter_slug if not path.exists(): path.mkdir(parents=True, exist_ok=True) filename = f"{self._sanitize_filename(title)}.png" filepath = path / filename # Use vl_convert directly with theme config for consistent rendering import vl_convert as vlc from theme import jpmc_altair_theme # Get chart spec and theme config chart_spec = chart.to_dict() theme_config = jpmc_altair_theme()['config'] png_data = vlc.vegalite_to_png( vl_spec=chart_spec, scale=2.0, ppi=72, config=theme_config ) with open(filepath, 'wb') as f: f.write(png_data) print(f"Saved plot to {filepath}") return chart def _ensure_dataframe(self, data: pl.LazyFrame | pl.DataFrame | None) -> pl.DataFrame: """Ensure data is an eager DataFrame, collecting if necessary.""" df = data if data is not None else getattr(self, 'data_filtered', None) if df is None: raise ValueError("No data provided and self.data_filtered is None.") if isinstance(df, pl.LazyFrame): return df.collect() return df def _clean_voice_label(self, col_name: str) -> str: """Extract and clean voice name from column name for display. Handles patterns like: - 'Voice_Scale__The_Coach' -> 'The Coach' - 'Character_Ranking_The_Coach' -> 'The Coach' - 'Top_3_Voices_ranking__Familiar_Friend' -> 'Familiar Friend' """ # First split by __ if present label = col_name.split('__')[-1] if '__' in col_name else col_name # Remove common prefixes label = label.replace('Character_Ranking_', '') label = label.replace('Top_3_Voices_ranking_', '') # Replace underscores with spaces label = label.replace('_', ' ').strip() return label def _get_voice_gender(self, voice_label: str) -> str: """Get the gender of a voice from its label. Parameters: voice_label: Voice label (e.g., 'V14', 'Voice 14', etc.) Returns: 'Male' or 'Female', defaults to 'Male' if not found """ # Extract voice code (e.g., 'V14' from 'Voice 14' or 'V14') voice_code = None # Try to find VXX pattern match = re.search(r'V(\d+)', voice_label) if match: voice_code = f"V{match.group(1)}" else: # Try to extract number and prepend V match = re.search(r'(\d+)', voice_label) if match: voice_code = f"V{match.group(1)}" if voice_code and voice_code in VOICE_GENDER_MAPPING: return VOICE_GENDER_MAPPING[voice_code] return "Male" # Default to Male if unknown def _get_gender_color(self, gender: str, color_type: str = "primary") -> str: """Get the appropriate color based on gender. Parameters: gender: 'Male' or 'Female' color_type: One of 'primary', 'rank_1', 'rank_2', 'rank_3', 'neutral' Returns: Hex color string """ color_map = { "Male": { "primary": ColorPalette.GENDER_MALE, "rank_1": ColorPalette.GENDER_MALE_RANK_1, "rank_2": ColorPalette.GENDER_MALE_RANK_2, "rank_3": ColorPalette.GENDER_MALE_RANK_3, "neutral": ColorPalette.GENDER_MALE_NEUTRAL, }, "Female": { "primary": ColorPalette.GENDER_FEMALE, "rank_1": ColorPalette.GENDER_FEMALE_RANK_1, "rank_2": ColorPalette.GENDER_FEMALE_RANK_2, "rank_3": ColorPalette.GENDER_FEMALE_RANK_3, "neutral": ColorPalette.GENDER_FEMALE_NEUTRAL, } } return color_map.get(gender, color_map["Male"]).get(color_type, ColorPalette.PRIMARY) def plot_average_scores_with_counts( self, data: pl.LazyFrame | pl.DataFrame | None = None, title: str = "General Impression (1-10)\nPer Voice with Number of Participants Who Rated It", x_label: str = "Stimuli", y_label: str = "Average General Impression Rating (1-10)", color: str = ColorPalette.PRIMARY, height: int | None = None, width: int | str | None = None, domain: list[float] | None = None, color_gender: bool = False, ) -> alt.Chart: """Create a bar plot showing average scores and count of non-null values for each column. Parameters: color_gender: If True, color bars by voice gender (blue=male, pink=female). """ df = self._ensure_dataframe(data) # Calculate stats for each column (exclude _recordId) stats = [] for col in [c for c in df.columns if c != '_recordId']: avg_score = df[col].mean() non_null_count = df[col].drop_nulls().len() label = self._clean_voice_label(col) gender = self._get_voice_gender(label) if color_gender else None stats.append({ 'voice': label, 'average': avg_score, 'count': non_null_count, 'gender': gender }) # Convert to pandas for Altair (sort by average descending) stats_df = pl.DataFrame(stats).sort('average', descending=True).to_pandas() if domain is None: domain = [stats_df['average'].min(), stats_df['average'].max()] # Base bar chart - use y2 to explicitly start bars at domain minimum if color_gender: bars = alt.Chart(stats_df).mark_bar().encode( x=alt.X('voice:N', title=x_label, sort='-y'), y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=domain)), y2=alt.datum(domain[0]), # Bars start at domain minimum (bottom edge) color=alt.Color('gender:N', scale=alt.Scale(domain=['Male', 'Female'], range=[ColorPalette.GENDER_MALE, ColorPalette.GENDER_FEMALE]), legend=alt.Legend(orient='top', direction='horizontal', title='Gender')), tooltip=[ alt.Tooltip('voice:N', title='Voice'), alt.Tooltip('average:Q', title='Average', format='.2f'), alt.Tooltip('count:Q', title='Count'), alt.Tooltip('gender:N', title='Gender') ] ) else: bars = alt.Chart(stats_df).mark_bar(color=color).encode( x=alt.X('voice:N', title=x_label, sort='-y'), y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=domain)), y2=alt.datum(domain[0]), # Bars start at domain minimum (bottom edge) tooltip=[ alt.Tooltip('voice:N', title='Voice'), alt.Tooltip('average:Q', title='Average', format='.2f'), alt.Tooltip('count:Q', title='Count') ] ) # Text overlay for counts text = alt.Chart(stats_df).mark_text( dy=-5, color='black', fontSize=10 ).encode( x=alt.X('voice:N', sort='-y'), y=alt.Y('average:Q'), text=alt.Text('count:Q') ) # Combine layers chart = (bars + text).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) chart = self._save_plot(chart, title) return chart def plot_top3_ranking_distribution( self, data: pl.LazyFrame | pl.DataFrame | None = None, title: str = "Top 3 Rankings Distribution\nCount of 1st, 2nd, and 3rd Place Votes per Voice", x_label: str = "Voices", y_label: str = "Number of Mentions in Top 3", height: int | None = None, width: int | str | None = None, ) -> alt.Chart: """Create a stacked bar chart showing how often each voice was ranked 1st, 2nd, or 3rd.""" df = self._ensure_dataframe(data) # Calculate stats per column stats = [] for col in [c for c in df.columns if c != '_recordId']: rank1 = df.filter(pl.col(col) == 1).height rank2 = df.filter(pl.col(col) == 2).height rank3 = df.filter(pl.col(col) == 3).height total = rank1 + rank2 + rank3 if total > 0: label = self._clean_voice_label(col) # Add 3 rows (one per rank) stats.append({'voice': label, 'rank': 'Rank 1 (1st Choice)', 'count': rank1, 'total': total}) stats.append({'voice': label, 'rank': 'Rank 2 (2nd Choice)', 'count': rank2, 'total': total}) stats.append({'voice': label, 'rank': 'Rank 3 (3rd Choice)', 'count': rank3, 'total': total}) # Convert to long format, sort by total stats_df = pl.DataFrame(stats).to_pandas() # Interactive legend selection - click to filter selection = alt.selection_point(fields=['rank'], bind='legend') # Create stacked bar chart with interactive legend chart = alt.Chart(stats_df).mark_bar().encode( x=alt.X('voice:N', title=x_label, sort=alt.EncodingSortField(field='total', op='sum', order='descending')), y=alt.Y('count:Q', title=y_label, stack='zero'), color=alt.Color('rank:N', scale=alt.Scale(domain=['Rank 1 (1st Choice)', 'Rank 2 (2nd Choice)', 'Rank 3 (3rd Choice)'], range=[ColorPalette.RANK_1, ColorPalette.RANK_2, ColorPalette.RANK_3]), legend=alt.Legend(orient='top', direction='horizontal', title=None)), order=alt.Order('rank:N', sort='ascending'), opacity=alt.condition(selection, alt.value(1), alt.value(0.2)), tooltip=[ alt.Tooltip('voice:N', title='Voice'), alt.Tooltip('rank:N', title='Rank'), alt.Tooltip('count:Q', title='Count') ] ).add_params(selection).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) chart = self._save_plot(chart, title) return chart def plot_ranking_distribution( self, data: pl.LazyFrame | pl.DataFrame | None = None, title: str = "Rankings Distribution\n(1st to 3rd Place)", x_label: str = "Item", y_label: str = "Number of Votes", height: int | None = None, width: int | str | None = None, color_gender: bool = False, ) -> alt.Chart: """Create a stacked bar chart showing the distribution of rankings (1st to 3rd). Parameters: color_gender: If True, color bars by voice gender with rank intensity (blue shades=male, pink shades=female). """ df = self._ensure_dataframe(data) stats = [] ranking_cols = [c for c in df.columns if c != '_recordId'] for col in ranking_cols: r1 = df.filter(pl.col(col) == 1).height r2 = df.filter(pl.col(col) == 2).height r3 = df.filter(pl.col(col) == 3).height # r4 = df.filter(pl.col(col) == 4).height total = r1 + r2 + r3 if total > 0: label = self._clean_voice_label(col) gender = self._get_voice_gender(label) if color_gender else None stats.append({'item': label, 'rank': 'Rank 1 (Best)', 'count': r1, 'total': total, 'gender': gender, 'rank_order': 1}) stats.append({'item': label, 'rank': 'Rank 2', 'count': r2, 'total': total, 'gender': gender, 'rank_order': 2}) stats.append({'item': label, 'rank': 'Rank 3', 'count': r3, 'total': total, 'gender': gender, 'rank_order': 3}) # stats.append({'item': label, 'rank': 'Rank 4 (Worst)', 'count': r4, 'total': total, 'gender': gender, 'rank_order': 4}) if not stats: return alt.Chart(pd.DataFrame({'text': ['No data']})).mark_text().encode(text='text:N') stats_df = pl.DataFrame(stats).to_pandas() # Interactive legend selection - click to filter selection = alt.selection_point(fields=['rank'], bind='legend') if color_gender: # Add gender_rank column for combined color encoding stats_df['gender_rank'] = stats_df['gender'] + ' - ' + stats_df['rank'] # Define combined domain and range for gender + rank domain = [ 'Male - Rank 1 (Best)', 'Male - Rank 2', 'Male - Rank 3', 'Female - Rank 1 (Best)', 'Female - Rank 2', 'Female - Rank 3' ] range_colors = [ ColorPalette.GENDER_MALE_RANK_1, ColorPalette.GENDER_MALE_RANK_2, ColorPalette.GENDER_MALE_RANK_3, ColorPalette.GENDER_FEMALE_RANK_1, ColorPalette.GENDER_FEMALE_RANK_2, ColorPalette.GENDER_FEMALE_RANK_3 ] chart = alt.Chart(stats_df).mark_bar().encode( x=alt.X('item:N', title=x_label, sort=alt.EncodingSortField(field='total', order='descending')), y=alt.Y('count:Q', title=y_label, stack='zero'), color=alt.Color('gender_rank:N', scale=alt.Scale(domain=domain, range=range_colors), legend=alt.Legend(orient='top', direction='horizontal', title=None, columns=3)), order=alt.Order('rank_order:Q', sort='ascending'), opacity=alt.condition(selection, alt.value(1), alt.value(0.2)), tooltip=[ alt.Tooltip('item:N', title='Item'), alt.Tooltip('rank:N', title='Rank'), alt.Tooltip('count:Q', title='Count'), alt.Tooltip('gender:N', title='Gender') ] ).add_params(selection).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) else: chart = alt.Chart(stats_df).mark_bar().encode( x=alt.X('item:N', title=x_label, sort=alt.EncodingSortField(field='total', order='descending')), y=alt.Y('count:Q', title=y_label, stack='zero'), color=alt.Color('rank:N', scale=alt.Scale(domain=['Rank 1 (Best)', 'Rank 2', 'Rank 3'], range=[ColorPalette.RANK_1, ColorPalette.RANK_2, ColorPalette.RANK_3]), legend=alt.Legend(orient='top', direction='horizontal', title=None)), order=alt.Order('rank_order:Q', sort='ascending'), opacity=alt.condition(selection, alt.value(1), alt.value(0.2)), tooltip=[ alt.Tooltip('item:N', title='Item'), alt.Tooltip('rank:N', title='Rank'), alt.Tooltip('count:Q', title='Count') ] ).add_params(selection).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) chart = self._save_plot(chart, title) return chart def plot_most_ranked_1( self, data: pl.LazyFrame | pl.DataFrame | None = None, title: str = "Most Popular Choice\n(Number of Times Ranked 1st)", x_label: str = "Item", y_label: str = "Count of 1st Place Rankings", height: int | None = None, width: int | str | None = None, color_gender: bool = False, ) -> alt.Chart: """Create a bar chart showing which item was ranked #1 the most. Top 3 highlighted. Parameters: color_gender: If True, color bars by voice gender with highlight/neutral intensity (blue shades=male, pink shades=female). """ df = self._ensure_dataframe(data) stats = [] ranking_cols = [c for c in df.columns if c != '_recordId'] for col in ranking_cols: count_rank_1 = df.filter(pl.col(col) == 1).height label = self._clean_voice_label(col) gender = self._get_voice_gender(label) if color_gender else None stats.append({'item': label, 'count': count_rank_1, 'gender': gender}) # Convert and sort stats_df = pl.DataFrame(stats).sort('count', descending=True) # Add rank column for coloring (1-3 vs 4+) stats_df = stats_df.with_row_index('rank_index') stats_df = stats_df.with_columns( pl.when(pl.col('rank_index') < 3) .then(pl.lit('Top 3')) .otherwise(pl.lit('Other')) .alias('category') ).to_pandas() if color_gender: # Add gender_category column for combined color encoding stats_df['gender_category'] = stats_df['gender'] + ' - ' + stats_df['category'] # Define combined domain and range for gender + category domain = ['Male - Top 3', 'Male - Other', 'Female - Top 3', 'Female - Other'] range_colors = [ ColorPalette.GENDER_MALE, ColorPalette.GENDER_MALE_NEUTRAL, ColorPalette.GENDER_FEMALE, ColorPalette.GENDER_FEMALE_NEUTRAL ] chart = alt.Chart(stats_df).mark_bar().encode( x=alt.X('item:N', title=x_label, sort='-y'), y=alt.Y('count:Q', title=y_label), color=alt.Color('gender_category:N', scale=alt.Scale(domain=domain, range=range_colors), legend=alt.Legend(orient='top', direction='horizontal', title=None)), tooltip=[ alt.Tooltip('item:N', title='Item'), alt.Tooltip('count:Q', title='1st Place Votes'), alt.Tooltip('gender:N', title='Gender') ] ).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) else: # Bar chart with conditional color chart = alt.Chart(stats_df).mark_bar().encode( x=alt.X('item:N', title=x_label, sort='-y'), y=alt.Y('count:Q', title=y_label), color=alt.Color('category:N', scale=alt.Scale(domain=['Top 3', 'Other'], range=[ColorPalette.PRIMARY, ColorPalette.NEUTRAL]), legend=None), tooltip=[ alt.Tooltip('item:N', title='Item'), alt.Tooltip('count:Q', title='1st Place Votes') ] ).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) chart = self._save_plot(chart, title) return chart def plot_weighted_ranking_score( self, data: pl.LazyFrame | pl.DataFrame | None = None, title: str = "Weighted Popularity Score\n(1st=3pts, 2nd=2pts, 3rd=1pt)", x_label: str = "Character Personality", y_label: str = "Total Weighted Score", color: str = ColorPalette.PRIMARY, height: int | None = None, width: int | str | None = None, color_gender: bool = False, ) -> alt.Chart: """Create a bar chart showing the weighted ranking score for each character. Parameters: color_gender: If True, color bars by voice gender (blue=male, pink=female). """ weighted_df = self._ensure_dataframe(data).to_pandas() if color_gender: # Add gender column based on Character name weighted_df['gender'] = weighted_df['Character'].apply(self._get_voice_gender) # Bar chart with gender coloring bars = alt.Chart(weighted_df).mark_bar().encode( x=alt.X('Character:N', title=x_label, sort='-y'), y=alt.Y('Weighted Score:Q', title=y_label), color=alt.Color('gender:N', scale=alt.Scale(domain=['Male', 'Female'], range=[ColorPalette.GENDER_MALE, ColorPalette.GENDER_FEMALE]), legend=alt.Legend(orient='top', direction='horizontal', title='Gender')), tooltip=[ alt.Tooltip('Character:N'), alt.Tooltip('Weighted Score:Q', title='Score'), alt.Tooltip('gender:N', title='Gender') ] ) else: # Bar chart bars = alt.Chart(weighted_df).mark_bar(color=color).encode( x=alt.X('Character:N', title=x_label, sort='-y'), y=alt.Y('Weighted Score:Q', title=y_label), tooltip=[ alt.Tooltip('Character:N'), alt.Tooltip('Weighted Score:Q', title='Score') ] ) # Text overlay text = bars.mark_text( dy=-5, color='white', fontSize=11 ).encode( text='Weighted Score:Q' ) chart = (bars + text).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) chart = self._save_plot(chart, title) return chart def plot_voice_selection_counts( self, data: pl.LazyFrame | pl.DataFrame | None = None, target_column: str = "8_Combined", title: str = "Most Frequently Chosen Voices\n(Top 8 Highlighted)", x_label: str = "Voice", y_label: str = "Number of Times Chosen", height: int | None = None, width: int | str | None = None, color_gender: bool = False, ) -> alt.Chart: """Create a bar plot showing the frequency of voice selections. Parameters: color_gender: If True, color bars by voice gender with highlight/neutral intensity (blue shades=male, pink shades=female). """ df = self._ensure_dataframe(data) if target_column not in df.columns: return alt.Chart(pd.DataFrame({'text': [f"Column '{target_column}' not found"]})).mark_text().encode(text='text:N') # Process data: split, explode, count stats_df = ( df.select(pl.col(target_column)) .drop_nulls() .with_columns(pl.col(target_column).str.split(",")) .explode(target_column) .with_columns(pl.col(target_column).str.strip_chars()) .filter(pl.col(target_column) != "") .group_by(target_column) .agg(pl.len().alias("count")) .sort("count", descending=True) .with_row_index('rank_index') .with_columns( pl.when(pl.col('rank_index') < 8) .then(pl.lit('Top 8')) .otherwise(pl.lit('Other')) .alias('category') ) .to_pandas() ) if color_gender: # Add gender column based on voice label stats_df['gender'] = stats_df[target_column].apply(self._get_voice_gender) # Add gender_category column for combined color encoding stats_df['gender_category'] = stats_df['gender'] + ' - ' + stats_df['category'] # Define combined domain and range for gender + category domain = ['Male - Top 8', 'Male - Other', 'Female - Top 8', 'Female - Other'] range_colors = [ ColorPalette.GENDER_MALE, ColorPalette.GENDER_MALE_NEUTRAL, ColorPalette.GENDER_FEMALE, ColorPalette.GENDER_FEMALE_NEUTRAL ] chart = alt.Chart(stats_df).mark_bar().encode( x=alt.X(f'{target_column}:N', title=x_label, sort='-y'), y=alt.Y('count:Q', title=y_label), color=alt.Color('gender_category:N', scale=alt.Scale(domain=domain, range=range_colors), legend=alt.Legend(orient='top', direction='horizontal', title=None)), tooltip=[ alt.Tooltip(f'{target_column}:N', title='Voice'), alt.Tooltip('count:Q', title='Selections'), alt.Tooltip('gender:N', title='Gender') ] ).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) else: chart = alt.Chart(stats_df).mark_bar().encode( x=alt.X(f'{target_column}:N', title=x_label, sort='-y'), y=alt.Y('count:Q', title=y_label), color=alt.Color('category:N', scale=alt.Scale(domain=['Top 8', 'Other'], range=[ColorPalette.PRIMARY, ColorPalette.NEUTRAL]), legend=None), tooltip=[ alt.Tooltip(f'{target_column}:N', title='Voice'), alt.Tooltip('count:Q', title='Selections') ] ).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) chart = self._save_plot(chart, title) return chart def plot_top3_selection_counts( self, data: pl.LazyFrame | pl.DataFrame | None = None, target_column: str = "3_Ranked", title: str = "Most Frequently Chosen Top 3 Voices\n(Top 3 Highlighted)", x_label: str = "Voice", y_label: str = "Count of Mentions in Top 3", height: int | None = None, width: int | str | None = None, color_gender: bool = False, ) -> alt.Chart: """Question: Which 3 voices are chosen the most out of 18? Parameters: color_gender: If True, color bars by voice gender with highlight/neutral intensity (blue shades=male, pink shades=female). """ df = self._ensure_dataframe(data) if target_column not in df.columns: return alt.Chart(pd.DataFrame({'text': [f"Column '{target_column}' not found"]})).mark_text().encode(text='text:N') stats_df = ( df.select(pl.col(target_column)) .drop_nulls() .with_columns(pl.col(target_column).str.split(",")) .explode(target_column) .with_columns(pl.col(target_column).str.strip_chars()) .filter(pl.col(target_column) != "") .group_by(target_column) .agg(pl.len().alias("count")) .sort("count", descending=True) .with_row_index('rank_index') .with_columns( pl.when(pl.col('rank_index') < 3) .then(pl.lit('Top 3')) .otherwise(pl.lit('Other')) .alias('category') ) .to_pandas() ) if color_gender: # Add gender column based on voice label stats_df['gender'] = stats_df[target_column].apply(self._get_voice_gender) # Add gender_category column for combined color encoding stats_df['gender_category'] = stats_df['gender'] + ' - ' + stats_df['category'] # Define combined domain and range for gender + category domain = ['Male - Top 3', 'Male - Other', 'Female - Top 3', 'Female - Other'] range_colors = [ ColorPalette.GENDER_MALE, ColorPalette.GENDER_MALE_NEUTRAL, ColorPalette.GENDER_FEMALE, ColorPalette.GENDER_FEMALE_NEUTRAL ] chart = alt.Chart(stats_df).mark_bar().encode( x=alt.X(f'{target_column}:N', title=x_label, sort='-y'), y=alt.Y('count:Q', title=y_label), color=alt.Color('gender_category:N', scale=alt.Scale(domain=domain, range=range_colors), legend=alt.Legend(orient='top', direction='horizontal', title=None)), tooltip=[ alt.Tooltip(f'{target_column}:N', title='Voice'), alt.Tooltip('count:Q', title='In Top 3'), alt.Tooltip('gender:N', title='Gender') ] ).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) else: chart = alt.Chart(stats_df).mark_bar().encode( x=alt.X(f'{target_column}:N', title=x_label, sort='-y'), y=alt.Y('count:Q', title=y_label), color=alt.Color('category:N', scale=alt.Scale(domain=['Top 3', 'Other'], range=[ColorPalette.PRIMARY, ColorPalette.NEUTRAL]), legend=None), tooltip=[ alt.Tooltip(f'{target_column}:N', title='Voice'), alt.Tooltip('count:Q', title='In Top 3') ] ).properties( title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) chart = self._save_plot(chart, title) return chart def plot_speaking_style_trait_scores( self, data: pl.LazyFrame | pl.DataFrame | None = None, trait_description: str = None, left_anchor: str = None, right_anchor: str = None, title: str = "Speaking Style Trait Analysis", height: int | None = None, width: int | str | None = None, ) -> alt.Chart: """Plot scores for a single speaking style trait across multiple voices.""" df = self._ensure_dataframe(data) if df.is_empty(): return alt.Chart(pd.DataFrame({'text': ['No data']})).mark_text().encode(text='text:N') required_cols = ["Voice", "score"] if not all(col in df.columns for col in required_cols): return alt.Chart(pd.DataFrame({'text': ['Missing required columns']})).mark_text().encode(text='text:N') # Calculate stats: Mean, Count stats = ( df.filter(pl.col("score").is_not_null()) .group_by("Voice") .agg([ pl.col("score").mean().alias("mean_score"), pl.col("score").count().alias("count") ]) .sort("mean_score", descending=False) # Ascending for bottom-to-top display .to_pandas() ) # Extract anchors from data if not provided if (left_anchor is None or right_anchor is None) and "Left_Anchor" in df.columns: head = df.filter(pl.col("Left_Anchor").is_not_null()).head(1) if not head.is_empty(): if left_anchor is None: left_anchor = head["Left_Anchor"][0] if right_anchor is None: right_anchor = head["Right_Anchor"][0] if trait_description is None: if left_anchor and right_anchor: trait_description = f"{left_anchor.split('|')[0]} vs. {right_anchor.split('|')[0]}" elif "Description" in df.columns: head = df.filter(pl.col("Description").is_not_null()).head(1) trait_description = head["Description"][0] if not head.is_empty() else "" else: trait_description = "" # Horizontal bar chart - use x2 to explicitly start bars at x=1 bars = alt.Chart(stats).mark_bar(color=ColorPalette.PRIMARY).encode( x=alt.X('mean_score:Q', title='Average Score (1-5)', scale=alt.Scale(domain=[1, 5])), x2=alt.datum(1), # Bars start at x=1 (left edge of domain) y=alt.Y('Voice:N', title='Voice', sort='-x'), tooltip=[ alt.Tooltip('Voice:N'), alt.Tooltip('mean_score:Q', title='Average', format='.2f'), alt.Tooltip('count:Q', title='Count') ] ) # Count text at end of bars (right-aligned inside bar) text = alt.Chart(stats).mark_text( align='right', baseline='middle', color='white', fontSize=12, dx=-5 # Slight padding from bar end ).encode( x='mean_score:Q', y=alt.Y('Voice:N', sort='-x'), text='count:Q' ) # Combine layers chart = (bars + text).properties( title={ "text": self._process_title(title), "subtitle": [trait_description, "(Numbers on bars indicate respondent count)"] }, width=width or 800, height=height or getattr(self, 'plot_height', 400) ) chart = self._save_plot(chart, title) return chart def plot_speaking_style_correlation( self, style_color: str, style_traits: list[str], data: pl.LazyFrame | pl.DataFrame | None = None, title: str | None = None, width: int | str | None = None, height: int | None = None, ) -> alt.Chart: """Plots correlation between Speaking Style Trait Scores (1-5) and Voice Scale (1-10).""" df = self._ensure_dataframe(data) if title is None: title = f"Speaking style and voice scale 1-10 correlations" trait_correlations = [] # Calculate correlations for i, trait in enumerate(style_traits): subset = df.filter(pl.col("Right_Anchor") == trait) valid_data = subset.select(["score", "Voice_Scale_Score"]).drop_nulls() if valid_data.height > 1: corr_val = valid_data.select(pl.corr("score", "Voice_Scale_Score")).item() # Wrap trait text at '|' for display trait_display = trait.replace('|', '\n') trait_correlations.append({ "trait_display": trait_display, "trait_index": f"Trait {i+1}", "correlation": corr_val if corr_val is not None else 0.0 }) if not trait_correlations: return alt.Chart(pd.DataFrame({'text': [f"No data for {style_color} Style"]})).mark_text().encode(text='text:N') plot_df = pl.DataFrame(trait_correlations).to_pandas() # Conditional color based on sign chart = alt.Chart(plot_df).mark_bar().encode( x=alt.X('trait_display:N', title=None, axis=alt.Axis(labelAngle=0)), y=alt.Y('correlation:Q', title='Correlation', scale=alt.Scale(domain=[-1, 1])), color=alt.condition( alt.datum.correlation >= 0, alt.value('green'), alt.value('red') ), tooltip=[ alt.Tooltip('trait_display:N', title='Trait'), alt.Tooltip('correlation:Q', format='.2f') ] ).properties( title=self._process_title(title), width=width or 800, height=height or 350 ) chart = self._save_plot(chart, title) return chart def plot_speaking_style_color_correlation( self, data: pl.LazyFrame | pl.DataFrame | None = None, title: str = "Speaking Style and Voice Scale 1-10 Correlations
(Average by Color)", width: int | str | None = None, height: int | None = None, ) -> alt.Chart: """Plot high-level correlation showing one bar per speaking style color. Original use-case: "I want to create high-level correlation plots between 'green, blue, orange, red' speaking styles and the 'voice scale scores'. I want to go to one plot with one bar for each color." Args: data: DataFrame with columns [Color, correlation, n_traits] from utils.transform_speaking_style_color_correlation title: Chart title (supports
for line breaks) width: Chart width in pixels height: Chart height in pixels Returns: Altair chart with one bar per speaking style color """ df = self._ensure_dataframe(data) # Conditional color based on sign (matches plot_speaking_style_correlation) chart = alt.Chart(df.to_pandas()).mark_bar().encode( x=alt.X('Color:N', title=None, axis=alt.Axis(labelAngle=0), sort=["Green", "Blue", "Orange", "Red"]), y=alt.Y('correlation:Q', title='Average Correlation', scale=alt.Scale(domain=[-1, 1])), color=alt.condition( alt.datum.correlation >= 0, alt.value('green'), alt.value('red') ), tooltip=[ alt.Tooltip('Color:N', title='Speaking Style'), alt.Tooltip('correlation:Q', format='.3f', title='Avg Correlation'), alt.Tooltip('n_traits:Q', title='# Traits') ] ).properties( title=self._process_title(title), width=width or 400, height=height or 350 ) chart = self._save_plot(chart, title) return chart def plot_demographic_distribution( self, column: str, data: pl.LazyFrame | pl.DataFrame | None = None, title: str | None = None, height: int | None = None, width: int | str | None = None, show_counts: bool = True, ) -> alt.Chart: """Create a horizontal bar chart showing the distribution of respondents by a demographic column. Designed to be compact so multiple charts (approx. 6) can fit on one slide. Uses horizontal bars for better readability with many categories. Parameters: column: The column name to analyze (e.g., 'Age', 'Gender', 'Race/Ethnicity'). data: Optional DataFrame. If None, uses self.data_filtered. title: Chart title. If None, auto-generates based on column name. height: Chart height in pixels (default: auto-sized based on categories). width: Chart width in pixels (default: 280 for compact layout). show_counts: If True, display count labels on the bars. Returns: alt.Chart: An Altair horizontal bar chart showing the distribution. """ df = self._ensure_dataframe(data) if column not in df.columns: return alt.Chart(pd.DataFrame({'text': [f"Column '{column}' not found"]})).mark_text().encode(text='text:N') # Count values in the column, including nulls stats_df = ( df.select(pl.col(column)) .with_columns(pl.col(column).fill_null("(No Response)")) .group_by(column) .agg(pl.len().alias("count")) .sort("count", descending=True) .to_pandas() ) if stats_df.empty: return alt.Chart(pd.DataFrame({'text': ['No data']})).mark_text().encode(text='text:N') # Calculate percentages total = stats_df['count'].sum() stats_df['percentage'] = (stats_df['count'] / total * 100).round(1) # Generate title if not provided if title is None: clean_col = column.replace('_', ' ').replace('/', ' / ') title = f"Distribution: {clean_col}" # Calculate appropriate height based on number of categories num_categories = len(stats_df) bar_height = 18 # pixels per bar calculated_height = max(120, num_categories * bar_height + 40) # min 120px, +40 for title/padding # Horizontal bar chart - categories on Y axis, counts on X axis bars = alt.Chart(stats_df).mark_bar(color=ColorPalette.PRIMARY).encode( x=alt.X('count:Q', title='Count', axis=alt.Axis(grid=False)), y=alt.Y(f'{column}:N', title=None, sort='-x', axis=alt.Axis(labelLimit=150)), tooltip=[ alt.Tooltip(f'{column}:N', title=column.replace('_', ' ')), alt.Tooltip('count:Q', title='Count'), alt.Tooltip('percentage:Q', title='Percentage', format='.1f') ] ) # Add count labels at end of bars if show_counts: text = alt.Chart(stats_df).mark_text( align='left', baseline='middle', dx=3, # Offset from bar end fontSize=9, color=ColorPalette.TEXT ).encode( x='count:Q', y=alt.Y(f'{column}:N', sort='-x'), text='count:Q' ) chart = (bars + text) else: chart = bars # Compact dimensions for 6-per-slide layout chart = chart.properties( title=self._process_title(title), width=width or 200, height=height or calculated_height ) chart = self._save_plot(chart, title) return chart def plot_speaking_style_ranking_correlation( self, style_color: str, style_traits: list[str], data: pl.LazyFrame | pl.DataFrame | None = None, title: str | None = None, width: int | str | None = None, height: int | None = None, ) -> alt.Chart: """Plots correlation between Speaking Style Trait Scores (1-5) and Voice Ranking Points (0-3).""" df = self._ensure_dataframe(data) if title is None: title = f"Speaking style {style_color} and voice ranking points correlations" trait_correlations = [] for i, trait in enumerate(style_traits): subset = df.filter(pl.col("Right_Anchor") == trait) valid_data = subset.select(["score", "Ranking_Points"]).drop_nulls() if valid_data.height > 1: corr_val = valid_data.select(pl.corr("score", "Ranking_Points")).item() trait_display = trait.replace('|', '\n') trait_correlations.append({ "trait_display": trait_display, "trait_index": f"Trait {i+1}", "correlation": corr_val if corr_val is not None else 0.0 }) if not trait_correlations: return alt.Chart(pd.DataFrame({'text': [f"No data for {style_color} Style"]})).mark_text().encode(text='text:N') plot_df = pl.DataFrame(trait_correlations).to_pandas() chart = alt.Chart(plot_df).mark_bar().encode( x=alt.X('trait_display:N', title=None, axis=alt.Axis(labelAngle=0)), y=alt.Y('correlation:Q', title='Correlation', scale=alt.Scale(domain=[-1, 1])), color=alt.condition( alt.datum.correlation >= 0, alt.value('green'), alt.value('red') ), tooltip=[ alt.Tooltip('trait_display:N', title='Trait'), alt.Tooltip('correlation:Q', format='.2f') ] ).properties( title=self._process_title(title), width=width or 800, height=height or 350 ) chart = self._save_plot(chart, title) return chart def plot_traits_wordcloud( self, data: pl.LazyFrame | pl.DataFrame | None = None, column: str = 'Top_3_Traits', title: str = "Most Prominent Personality Traits", width: int = 1600, height: int = 800, background_color: str = 'white', random_state: int = 23, ): """Create a word cloud visualization of personality traits from survey data. Args: data: Polars DataFrame or LazyFrame containing trait data column: Name of column containing comma-separated traits title: Title for the word cloud width: Width of the word cloud image in pixels height: Height of the word cloud image in pixels background_color: Background color for the word cloud random_state: Random seed for reproducible word cloud generation (default: 23) Returns: matplotlib.figure.Figure: The word cloud figure for display in notebooks """ import matplotlib.pyplot as plt from wordcloud import WordCloud from collections import Counter import random df = self._ensure_dataframe(data) # Extract and split traits traits_list = [] for row in df[column].drop_nulls(): # Split by comma and clean whitespace traits = [trait.strip() for trait in row.split(',')] traits_list.extend(traits) # Create frequency dictionary trait_freq = Counter(traits_list) # Set random seed for color selection random.seed(random_state) # Color function using JPMC colors def color_func(word, font_size, position, orientation, random_state=None, **kwargs): colors = [ ColorPalette.PRIMARY, ColorPalette.RANK_1, ColorPalette.RANK_2, ColorPalette.RANK_3, ] return random.choice(colors) # Generate word cloud wordcloud = WordCloud( width=width, height=height, background_color=background_color, color_func=color_func, relative_scaling=0.5, min_font_size=10, prefer_horizontal=0.7, collocations=False, # Treat each word independently random_state=random_state # Seed for reproducible layout ).generate_from_frequencies(trait_freq) # Create matplotlib figure fig, ax = plt.subplots(figsize=(width/100, height/100), dpi=100) ax.imshow(wordcloud, interpolation='bilinear') ax.axis('off') ax.set_title(title, fontsize=16, pad=20, color=ColorPalette.TEXT) plt.tight_layout(pad=0) # Save figure if directory specified (using same pattern as other plots) if hasattr(self, 'fig_save_dir') and self.fig_save_dir: save_path = Path(self.fig_save_dir) # Add filter slug subfolder filter_slug = self._get_filter_slug() save_path = save_path / filter_slug if not save_path.exists(): save_path.mkdir(parents=True, exist_ok=True) # Use _sanitize_filename for consistency filename = f"{self._sanitize_filename(title)}.png" filepath = save_path / filename # Save as PNG at high resolution fig.savefig(filepath, dpi=300, bbox_inches='tight', facecolor='white') print(f"Word cloud saved to: {filepath}") return fig def plot_character_trait_frequency( self, data: pl.LazyFrame | pl.DataFrame | None = None, title: str = "Trait Frequency per Brand Character", x_label: str = "Trait", y_label: str = "Frequency (Times Chosen)", height: int | None = None, width: int | str | None = None, ) -> alt.Chart: """Create a grouped bar plot showing how often each trait is chosen per character. Original request: "I need a bar plot that shows the frequency of the times each trait is chosen per brand character" Expects data with columns: Character, Trait, Count (as produced by transform_character_trait_frequency). """ df = self._ensure_dataframe(data) # Ensure we have the expected columns required_cols = {'Character', 'Trait', 'Count'} if not required_cols.issubset(set(df.columns)): return alt.Chart(pd.DataFrame({'text': ['Data must have Character, Trait, Count columns']})).mark_text().encode(text='text:N') # Convert to pandas for Altair plot_df = df.to_pandas() if hasattr(df, 'to_pandas') else df # Calculate total per trait for sorting (traits with highest overall frequency first) trait_totals = plot_df.groupby('Trait')['Count'].sum().sort_values(ascending=False) trait_order = trait_totals.index.tolist() # Get unique characters for color mapping characters = plot_df['Character'].unique().tolist() # Interactive legend selection - click to filter selection = alt.selection_point(fields=['Character'], bind='legend') chart = alt.Chart(plot_df).mark_bar().encode( x=alt.X('Trait:N', title=x_label, sort=trait_order, axis=alt.Axis(labelAngle=-45, labelLimit=200)), y=alt.Y('Count:Q', title=y_label), xOffset='Character:N', color=alt.Color('Character:N', scale=alt.Scale(domain=characters, range=ColorPalette.CATEGORICAL[:len(characters)]), legend=alt.Legend(orient='top', direction='horizontal', title='Character')), opacity=alt.condition(selection, alt.value(1), alt.value(0.2)), tooltip=[ alt.Tooltip('Character:N', title='Character'), alt.Tooltip('Trait:N', title='Trait'), alt.Tooltip('Count:Q', title='Frequency') ] ).add_params(selection).properties( title=self._process_title(title), width=width or 900, height=height or getattr(self, 'plot_height', 400) ) chart = self._save_plot(chart, title) return chart def plot_single_character_trait_frequency( self, data: pl.LazyFrame | pl.DataFrame | None = None, character_name: str = "Character", bar_color: str = ColorPalette.PRIMARY, highlight_color: str = ColorPalette.NEUTRAL, title: str | None = None, x_label: str = "Trait", y_label: str = "Frequency", trait_sort_order: list[str] | None = None, height: int | None = None, width: int | str | None = None, ) -> alt.Chart: """Create a bar plot showing trait frequency for a single character. Original request: "I need a bar plot that shows the frequency of the times each trait is chosen per brand character. The function should be generalized so that it can be used 4 times, once for each character. Each character should use a slightly different color. Original traits should be highlighted." This function creates one plot per character. Call it 4 times (once per character) to generate all plots for a slide. Args: data: DataFrame with columns ['trait', 'count', 'is_original'] as produced by transform_character_trait_frequency() character_name: Name of the character (for title). E.g., "Bank Teller" bar_color: Main bar color for non-original traits. Use ColorPalette constants like ColorPalette.CHARACTER_BANK_TELLER highlight_color: Lighter color for original/expected traits. Use the matching highlight like ColorPalette.CHARACTER_BANK_TELLER_HIGHLIGHT title: Custom title. If None, auto-generates from character_name x_label: X-axis label y_label: Y-axis label trait_sort_order: Optional list of traits for consistent sorting across all character plots. If None, sorts by count descending. height: Chart height width: Chart width Returns: alt.Chart: Altair bar chart """ df = self._ensure_dataframe(data) # Ensure we have the expected columns required_cols = {'trait', 'count', 'is_original'} if not required_cols.issubset(set(df.columns)): return alt.Chart(pd.DataFrame({ 'text': ['Data must have trait, count, is_original columns'] })).mark_text().encode(text='text:N') # Convert to pandas for Altair plot_df = df.to_pandas() if hasattr(df, 'to_pandas') else df # Determine sort order if trait_sort_order is not None: # Use provided order, append any missing traits at the end (sorted by count) known_traits = set(trait_sort_order) extra_traits = plot_df[~plot_df['trait'].isin(known_traits)].sort_values( 'count', ascending=False )['trait'].tolist() sort_order = trait_sort_order + extra_traits else: # Default: sort by count descending sort_order = plot_df.sort_values('count', ascending=False)['trait'].tolist() # Create category column for color encoding plot_df['category'] = plot_df['is_original'].map({ True: 'Original Trait', False: 'Other Trait' }) # Generate title if not provided if title is None: title = f"{character_name}
Trait Selection Frequency" # Build title config with sort order note as subtitle sort_note = "Sorted by total frequency across all characters" if trait_sort_order else "Sorted by frequency (descending)" title_text = self._process_title(title) title_config = { 'text': title_text, 'subtitle': sort_note, 'subtitleColor': 'gray', 'subtitleFontSize': 10, 'anchor': 'start', } # Create HORIZONTAL bar chart with conditional coloring # Reverse sort order for horizontal bars (highest at top) reversed_sort = list(reversed(sort_order)) bars = alt.Chart(plot_df).mark_bar().encode( y=alt.Y('trait:N', title=x_label, sort=reversed_sort, axis=alt.Axis(labelLimit=200)), x=alt.X('count:Q', title=y_label), color=alt.Color('category:N', scale=alt.Scale( domain=['Original Trait', 'Other Trait'], range=[highlight_color, bar_color] ), legend=alt.Legend( orient='top', direction='horizontal', title=None )), tooltip=[ alt.Tooltip('trait:N', title='Trait'), alt.Tooltip('count:Q', title='Frequency'), alt.Tooltip('category:N', title='Type') ] ) # Add count labels on bars (to the right of bars for horizontal) text = alt.Chart(plot_df).mark_text( dx=12, color='black', fontSize=10, align='left' ).encode( y=alt.Y('trait:N', sort=reversed_sort), x=alt.X('count:Q'), text=alt.Text('count:Q') ) chart = (bars + text).properties( title=title_config, width=width or 400, height=height or getattr(self, 'plot_height', 450) ) chart = self._save_plot(chart, title) return chart def plot_significance_heatmap( self, pairwise_df: pl.LazyFrame | pl.DataFrame | None = None, metadata: dict | None = None, title: str = "Pairwise Statistical Significance
(Adjusted p-values)", show_p_values: bool = True, show_effect_size: bool = False, height: int | None = None, width: int | None = None, ) -> alt.Chart: """Create a heatmap showing pairwise statistical significance between groups. Original use-case: "I need to test for statistical significance and present this in a logical manner - as a heatmap or similar visualization." This function visualizes the output of compute_pairwise_significance() as a color-coded heatmap where color intensity indicates significance level. Args: pairwise_df: Output from compute_pairwise_significance(). Expected columns: ['group1', 'group2', 'p_value', 'p_adjusted', 'significant'] metadata: Metadata dict from compute_pairwise_significance() (optional). Used to add test information to the plot subtitle. title: Chart title (supports
for line breaks) show_p_values: Whether to display p-values as text annotations show_effect_size: Whether to display effect sizes instead of p-values height: Chart height (default: auto-sized based on groups) width: Chart width (default: auto-sized based on groups) Returns: alt.Chart: Altair heatmap chart """ df = self._ensure_dataframe(pairwise_df) # Get unique groups all_groups = sorted(set(df['group1'].to_list() + df['group2'].to_list())) n_groups = len(all_groups) # Create symmetric matrix data for heatmap # We need both directions (A,B) and (B,A) for the full matrix heatmap_data = [] for row_group in all_groups: for col_group in all_groups: if row_group == col_group: # Diagonal - self comparison heatmap_data.append({ 'row': row_group, 'col': col_group, 'p_adjusted': None, 'p_value': None, 'significant': None, 'effect_size': None, 'text_label': '—', 'sig_category': 'Self', }) else: # Find the comparison (could be in either order) match = df.filter( ((pl.col('group1') == row_group) & (pl.col('group2') == col_group)) | ((pl.col('group1') == col_group) & (pl.col('group2') == row_group)) ) if match.height > 0: p_adj = match['p_adjusted'][0] p_val = match['p_value'][0] sig = match['significant'][0] eff = match['effect_size'][0] if 'effect_size' in match.columns else None # For ranking data, we can show Rank 1 % difference has_rank_pcts = 'rank1_pct1' in match.columns and 'rank1_pct2' in match.columns if has_rank_pcts: pct_diff = abs(match['rank1_pct1'][0] - match['rank1_pct2'][0]) else: pct_diff = None # Helper to get display text when not showing p-values def get_alt_text(): if eff is not None: return f'{eff:.2f}' elif pct_diff is not None: return f'{pct_diff:.1f}%' else: return '—' # Categorize significance level if p_adj is None: sig_cat = 'N/A' text = 'N/A' elif p_adj < 0.001: sig_cat = 'p < 0.001' text = '<.001' if show_p_values else get_alt_text() elif p_adj < 0.01: sig_cat = 'p < 0.01' text = f'{p_adj:.3f}' if show_p_values else get_alt_text() elif p_adj < 0.05: sig_cat = 'p < 0.05' text = f'{p_adj:.3f}' if show_p_values else get_alt_text() else: sig_cat = 'n.s.' text = f'{p_adj:.2f}' if show_p_values else get_alt_text() if show_effect_size: text = get_alt_text() heatmap_data.append({ 'row': row_group, 'col': col_group, 'p_adjusted': p_adj, 'p_value': p_val, 'significant': sig, 'effect_size': eff, 'text_label': text, 'sig_category': sig_cat, }) else: heatmap_data.append({ 'row': row_group, 'col': col_group, 'p_adjusted': None, 'p_value': None, 'significant': None, 'effect_size': None, 'text_label': 'N/A', 'sig_category': 'N/A', }) heatmap_df = pl.DataFrame(heatmap_data).to_pandas() # Define color scale for significance categories sig_domain = ['p < 0.001', 'p < 0.01', 'p < 0.05', 'n.s.', 'Self', 'N/A'] sig_range = [ ColorPalette.SIG_STRONG, # p < 0.001 ColorPalette.SIG_MODERATE, # p < 0.01 ColorPalette.SIG_WEAK, # p < 0.05 ColorPalette.SIG_NONE, # not significant ColorPalette.SIG_DIAGONAL, # diagonal (self) ColorPalette.NEUTRAL, # N/A ] # Build tooltip fields based on available data tooltip_fields = [ alt.Tooltip('row:N', title='Group 1'), alt.Tooltip('col:N', title='Group 2'), alt.Tooltip('p_value:Q', title='p-value', format='.4f'), alt.Tooltip('p_adjusted:Q', title='Adjusted p', format='.4f'), ] # Only add effect_size if it has non-null values (continuous data) has_effect_size = 'effect_size' in heatmap_df.columns and heatmap_df['effect_size'].notna().any() if has_effect_size: tooltip_fields.append(alt.Tooltip('effect_size:Q', title='Effect Size', format='.3f')) # Add rank info for ranking data has_rank_pcts = 'rank1_pct1' in df.columns if isinstance(df, pl.DataFrame) else False if has_rank_pcts: tooltip_fields.append(alt.Tooltip('text_label:N', title='Rank 1 % Diff')) # Calculate dimensions cell_size = 45 auto_size = n_groups * cell_size + 100 chart_width = width or auto_size chart_height = height or auto_size # Base heatmap heatmap = alt.Chart(heatmap_df).mark_rect(stroke='white', strokeWidth=1).encode( x=alt.X('col:N', title=None, sort=all_groups, axis=alt.Axis(labelAngle=-45, labelLimit=150)), y=alt.Y('row:N', title=None, sort=all_groups, axis=alt.Axis(labelLimit=150)), color=alt.Color('sig_category:N', scale=alt.Scale(domain=sig_domain, range=sig_range), legend=alt.Legend( title='Significance', orient='right', direction='vertical' )), tooltip=tooltip_fields ) # Text annotations if show_p_values or show_effect_size: # Add a column for text color based on significance heatmap_df['text_color'] = heatmap_df['sig_category'].apply( lambda x: 'white' if x in ['p < 0.001', 'p < 0.01'] else 'black' ) text = alt.Chart(heatmap_df).mark_text( fontSize=9, fontWeight='normal' ).encode( x=alt.X('col:N', sort=all_groups), y=alt.Y('row:N', sort=all_groups), text='text_label:N', color=alt.Color('text_color:N', scale=None), ) chart = (heatmap + text) else: chart = heatmap # Build subtitle with test info subtitle_lines = [] if metadata: test_info = f"Test: {metadata.get('test_type', 'N/A')}" if metadata.get('overall_p_value') is not None: test_info += f" | Overall p={metadata['overall_p_value']:.4f}" correction = metadata.get('correction', 'none') if correction != 'none': test_info += f" | Correction: {correction}" subtitle_lines.append(test_info) title_config = { 'text': self._process_title(title), 'subtitle': subtitle_lines if subtitle_lines else None, 'subtitleColor': 'gray', 'subtitleFontSize': 10, 'anchor': 'start', } chart = chart.properties( title=title_config, width=chart_width, height=chart_height, ) chart = self._save_plot(chart, title) return chart def plot_significance_summary( self, pairwise_df: pl.LazyFrame | pl.DataFrame | None = None, metadata: dict | None = None, title: str = "Significant Differences Summary
(Groups with significantly different means)", height: int | None = None, width: int | None = None, ) -> alt.Chart: """Create a summary bar chart showing which groups have significant differences. This shows each group with a count of how many other groups it differs from significantly, plus the mean score or Rank 1 percentage for reference. Args: pairwise_df: Output from compute_pairwise_significance() or compute_ranking_significance(). metadata: Metadata dict from the significance computation (optional). title: Chart title height: Chart height width: Chart width Returns: alt.Chart: Altair bar chart with significance count per group """ df = self._ensure_dataframe(pairwise_df) # Detect data type: continuous (has mean1/mean2) vs ranking (has rank1_pct1/rank1_pct2) has_means = 'mean1' in df.columns has_ranks = 'rank1_pct1' in df.columns # Count significant differences per group sig_df = df.filter(pl.col('significant') == True) # Count for each group (appears as either group1 or group2) group1_counts = sig_df.group_by('group1').agg(pl.len().alias('count')) group2_counts = sig_df.group_by('group2').agg(pl.len().alias('count')) # Combine counts all_groups = sorted(set(df['group1'].to_list() + df['group2'].to_list())) summary_data = [] for group in all_groups: count1 = group1_counts.filter(pl.col('group1') == group)['count'].to_list() count2 = group2_counts.filter(pl.col('group2') == group)['count'].to_list() total_sig = (count1[0] if count1 else 0) + (count2[0] if count2 else 0) # Get score for this group from pairwise data if has_means: # Continuous data - use means scores = df.filter(pl.col('group1') == group)['mean1'].to_list() if not scores: scores = df.filter(pl.col('group2') == group)['mean2'].to_list() score_val = scores[0] if scores else None score_label = 'mean' elif has_ranks: # Ranking data - use Rank 1 percentage scores = df.filter(pl.col('group1') == group)['rank1_pct1'].to_list() if not scores: scores = df.filter(pl.col('group2') == group)['rank1_pct2'].to_list() score_val = scores[0] if scores else None score_label = 'rank1_pct' else: score_val = None score_label = 'score' summary_data.append({ 'group': group, 'sig_count': total_sig, 'score': score_val, }) summary_df = pl.DataFrame(summary_data).sort('score', descending=True, nulls_last=True).to_pandas() # Create layered chart: bars for sig_count, text for score tooltip_title = 'Mean Score' if has_means else 'Rank 1 %' if has_ranks else 'Score' bars = alt.Chart(summary_df).mark_bar(color=ColorPalette.PRIMARY).encode( x=alt.X('group:N', title='Group', sort='-y'), y=alt.Y('sig_count:Q', title='# of Significant Differences'), tooltip=[ alt.Tooltip('group:N', title='Group'), alt.Tooltip('sig_count:Q', title='Sig. Differences'), alt.Tooltip('score:Q', title=tooltip_title, format='.1f'), ] ) # Only add text labels if we have scores if summary_df['score'].notna().any(): text_format = '.1f' if has_means else '.0f' text_suffix = '%' if has_ranks else '' text = alt.Chart(summary_df).mark_text( dy=-8, color='black', fontSize=9 ).encode( x=alt.X('group:N', sort='-y'), y=alt.Y('sig_count:Q'), text=alt.Text('score:Q', format=text_format) ) chart_layers = bars + text else: chart_layers = bars # Build subtitle subtitle = None if metadata: if has_means: subtitle = f"Mean scores shown above bars | α={metadata.get('alpha', 0.05)}" elif has_ranks: subtitle = f"Rank 1 % shown above bars | α={metadata.get('alpha', 0.05)}" else: subtitle = f"α={metadata.get('alpha', 0.05)}" title_config = { 'text': self._process_title(title), 'subtitle': subtitle, 'subtitleColor': 'gray', 'subtitleFontSize': 10, 'anchor': 'start', } chart = chart_layers.properties( title=title_config, width=width or 800, height=height or getattr(self, 'plot_height', 400), ) chart = self._save_plot(chart, title) return chart