diff --git a/02_quant_analysis.py b/02_quant_analysis.py
index 187f463..d0f7b1e 100644
--- a/02_quant_analysis.py
+++ b/02_quant_analysis.py
@@ -12,9 +12,6 @@ def _():
from validation import check_progress, duration_validation
from utils import JPMCSurvey, combine_exclusive_columns, calculate_weighted_ranking_scores
- from plots import plot_average_scores_with_counts, plot_top3_ranking_distribution, plot_ranking_distribution, plot_most_ranked_1, plot_weighted_ranking_score, plot_voice_selection_counts, plot_top3_selection_counts
-
- import plots
import utils
from speaking_styles import SPEAKING_STYLES
@@ -27,13 +24,6 @@ def _():
duration_validation,
mo,
pl,
- plot_most_ranked_1,
- plot_ranking_distribution,
- plot_top3_ranking_distribution,
- plot_top3_selection_counts,
- plot_voice_selection_counts,
- plot_weighted_ranking_score,
- plots,
utils,
)
@@ -47,10 +37,10 @@ def _():
@app.cell
def _(JPMCSurvey, QSF_FILE, RESULTS_FILE):
- survey = JPMCSurvey(RESULTS_FILE, QSF_FILE)
- data_all = survey.load_data()
+ S = JPMCSurvey(RESULTS_FILE, QSF_FILE)
+ data_all = S.load_data()
data_all.collect()
- return data_all, survey
+ return S, data_all
@app.cell
@@ -108,18 +98,22 @@ def _(mo):
@app.cell(hide_code=True)
def _(data_all, mo):
data_all_collected = data_all.collect()
- ages = mo.ui.multiselect(options=data_all_collected["QID1"], value=data_all_collected["QID1"].unique(), label="Select Age Group(s):")
+ age = mo.ui.multiselect(options=data_all_collected["QID1"], value=data_all_collected["QID1"].unique(), label="Select Age Group(s):")
income = mo.ui.multiselect(data_all_collected["QID15"], value=data_all_collected["QID15"], label="Select Income Group(s):")
gender = mo.ui.multiselect(data_all_collected["QID2"], value=data_all_collected["QID2"], label="Select Gender(s)")
ethnicity = mo.ui.multiselect(data_all_collected["QID3"], value=data_all_collected["QID3"], label="Select Ethnicities:")
consumer = mo.ui.multiselect(data_all_collected["Consumer"], value=data_all_collected["Consumer"], label="Select Consumer Groups:")
+ return age, consumer, ethnicity, gender, income
+@app.cell
+def _(age, consumer, ethnicity, gender, income, mo):
+
mo.md(f"""
# Data Filters
- {ages}
+ {age}
{gender}
@@ -130,12 +124,14 @@ def _(data_all, mo):
{consumer}
""")
- return ages, consumer, ethnicity, gender, income
+
+
+ return
@app.cell
-def _(ages, consumer, data_all, ethnicity, gender, income, survey):
- data = survey.filter_data(data_all, age=ages.value, gender=gender.value, income=income.value, ethnicity=ethnicity.value, consumer=consumer.value)
+def _(S, age, consumer, data_all, ethnicity, gender, income):
+ data = S.filter_data(data_all, age=age.value, gender=gender.value, income=income.value, ethnicity=ethnicity.value, consumer=consumer.value)
data.collect()
return (data,)
@@ -159,49 +155,42 @@ def _(mo):
@app.cell
-def _(data, survey):
- char_rank = survey.get_character_ranking(data)[0].collect()
+def _(S, data):
+ char_rank = S.get_character_ranking(data)[0]
return (char_rank,)
@app.cell
-def _(char_rank, mo, plot_top3_ranking_distribution, survey):
+def _(S, char_rank, mo):
mo.md(f"""
### 1. Which character personality is ranked best?
- {mo.ui.plotly(plot_top3_ranking_distribution(char_rank, x_label='Character Personality', width=1000, results_dir=survey.fig_save_dir))}
+ {mo.ui.plotly(S.plot_top3_ranking_distribution(char_rank, x_label='Character Personality', width=1000))}
""")
return
@app.cell
-def _(char_rank, mo, plot_most_ranked_1, survey):
+def _(S, char_rank, mo):
mo.md(f"""
### 2. Which character personality is ranked 1st the most?
- {mo.ui.plotly(plot_most_ranked_1(char_rank, title="Most Popular Character
(Number of Times Ranked 1st)", x_label='Character Personality', width=1000, results_dir=survey.fig_save_dir))}
+ {mo.ui.plotly(S.plot_most_ranked_1(char_rank, title="Most Popular Character
(Number of Times Ranked 1st)", x_label='Character Personality', width=1000))}
""")
return
@app.cell
-def _(
- calculate_weighted_ranking_scores,
- char_rank,
- mo,
- plot_weighted_ranking_score,
- survey,
-):
+def _(S, calculate_weighted_ranking_scores, char_rank, mo):
char_rank_weighted = calculate_weighted_ranking_scores(char_rank)
- # plot_weighted_ranking_score(char_rank_weighted, x_label='Voice', width=1000)
mo.md(f"""
### 3. Which character personality most popular based on weighted scores?
- {mo.ui.plotly(plot_weighted_ranking_score(char_rank_weighted, title="Most Popular Character - Weighted Popularity Score
(1st=3pts, 2nd=2pts, 3rd=1pt)", x_label='Voice', width=1000, results_dir=survey.fig_save_dir))}
+ {mo.ui.plotly(S.plot_weighted_ranking_score(char_rank_weighted, title="Most Popular Character - Weighted Popularity Score
(1st=3pts, 2nd=2pts, 3rd=1pt)", x_label='Voice', width=1000))}
""")
return
@@ -215,73 +204,73 @@ def _(mo):
@app.cell
-def _(data, survey):
- v_18_8_3 = survey.get_18_8_3(data)[0].collect()
+def _(S, data):
+ v_18_8_3 = S.get_18_8_3(data)[0].collect()
# print(v_18_8_3.head())
return (v_18_8_3,)
@app.cell(hide_code=True)
-def _(mo, plot_voice_selection_counts, survey, v_18_8_3):
+def _(S, mo, v_18_8_3):
mo.md(f"""
### Which 8 voices are chosen the most out of 18?
- {mo.ui.plotly(plot_voice_selection_counts(v_18_8_3, height=500, width=1000, results_dir=survey.fig_save_dir))}
+ {mo.ui.plotly(S.plot_voice_selection_counts(v_18_8_3, height=500, width=1000))}
""")
return
@app.cell(hide_code=True)
-def _(mo, plot_top3_selection_counts, survey, v_18_8_3):
+def _(S, mo, v_18_8_3):
mo.md(f"""
### Which 3 voices are chosen the most out of 18?
How many times does each voice end up in the top 3? ( this is based on the survey question where participants need to choose 3 out of the earlier selected 8 voices. So how often each of the 18 stimuli ended up in participants’ Top 3, after they first selected 8 out of 18.
- {mo.ui.plotly(plot_top3_selection_counts(v_18_8_3, height=500, width=1000, results_dir=survey.fig_save_dir))}
+ {mo.ui.plotly(S.plot_top3_selection_counts(v_18_8_3, height=500, width=1000))}
""")
return
@app.cell(hide_code=True)
-def _(calculate_weighted_ranking_scores, data, survey):
- top3_voices = survey.get_top_3_voices(data)[0].collect()
+def _(S, calculate_weighted_ranking_scores, data):
+ top3_voices = S.get_top_3_voices(data)[0]
top3_voices_weighted = calculate_weighted_ranking_scores(top3_voices)
return top3_voices, top3_voices_weighted
@app.cell
-def _(mo, plot_ranking_distribution, survey, top3_voices):
+def _(S, mo, top3_voices):
mo.md(f"""
### Which voice is ranked best in the ranking question for top 3?
(not best 3 out of 8 question)
- {mo.ui.plotly(plot_ranking_distribution(top3_voices, x_label='Voice', width=1000, results_dir=survey.fig_save_dir))}
+ {mo.ui.plotly(S.plot_ranking_distribution(top3_voices, x_label='Voice', width=1000))}
""")
return
@app.cell
-def _(mo, plot_weighted_ranking_score, survey, top3_voices_weighted):
+def _(S, mo, top3_voices_weighted):
mo.md(f"""
### Most popular **voice** based on weighted scores?
- E.g. 1 point for place 3. 2 points for place 2 and 3 points for place 1. The voice with most points is ranked best.
Distribution of the rankings for each voice:
- {mo.ui.plotly(plot_weighted_ranking_score(top3_voices_weighted, title="Most Popular Voice - Weighted Popularity Score
(1st = 3pts, 2nd = 2pts, 3rd = 1pt)", height=500, width=1000, results_dir=survey.fig_save_dir))}
+ {mo.ui.plotly(S.plot_weighted_ranking_score(top3_voices_weighted, title="Most Popular Voice - Weighted Popularity Score
(1st = 3pts, 2nd = 2pts, 3rd = 1pt)", height=500, width=1000))}
""")
return
@app.cell
-def _(mo, plot_most_ranked_1, survey, top3_voices):
+def _(S, mo, top3_voices):
mo.md(f"""
### Which voice is ranked number 1 the most?
(not always the voice with most points)
- {mo.ui.plotly(plot_most_ranked_1(top3_voices, title="Most Popular Voice
(Number of Times Ranked 1st)", x_label='Voice', width=1000, results_dir=survey.fig_save_dir))}
+ {mo.ui.plotly(S.plot_most_ranked_1(top3_voices, title="Most Popular Voice
(Number of Times Ranked 1st)", x_label='Voice', width=1000))}
""")
return
@@ -297,9 +286,9 @@ def _(mo):
@app.cell
-def _(data, survey, utils):
- ss_or, choice_map_or = survey.get_ss_orange_red(data)
- ss_gb, choice_map_gb = survey.get_ss_green_blue(data)
+def _(S, data, utils):
+ ss_or, choice_map_or = S.get_ss_orange_red(data)
+ ss_gb, choice_map_gb = S.get_ss_green_blue(data)
# Combine the data
ss_all = ss_or.join(ss_gb, on='_recordId')
@@ -313,7 +302,7 @@ def _(data, survey, utils):
@app.cell
-def _(mo, pl, plots, ss_long, survey):
+def _(S, mo, pl, ss_long):
content = """### How does each voice score for each “speaking style labeled trait”?"""
for i, trait in enumerate(ss_long.select("Description").unique().to_series().to_list()):
@@ -322,7 +311,7 @@ def _(mo, pl, plots, ss_long, survey):
content += f"""
### {i+1}) {trait.replace(":", " ↔ ")}
- {mo.ui.plotly(plots.plot_speaking_style_trait_scores(trait_d, title=trait.replace(":", " ↔ "), height=550, results_dir=survey.fig_save_dir))}
+ {mo.ui.plotly(S.plot_speaking_style_trait_scores(trait_d, title=trait.replace(":", " ↔ "), height=550))}
"""
mo.md(content)
@@ -338,18 +327,18 @@ def _(mo):
@app.cell
-def _(data, survey):
- vscales = survey.get_voice_scale_1_10(data)[0].collect()
+def _(S, data):
+ vscales = S.get_voice_scale_1_10(data)[0]
# plot_average_scores_with_counts(vscales, x_label='Voice', width=1000)
return (vscales,)
@app.cell
-def _(mo, plots, survey, vscales):
+def _(S, mo, vscales):
mo.md(f"""
### How does each voice score on a scale from 1-10?
- {mo.ui.plotly(plots.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000, results_dir=survey.fig_save_dir))}
+ {mo.ui.plotly(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000))}
""")
return
@@ -394,7 +383,7 @@ def _(mo):
@app.cell
def _(choice_map, ss_all, utils, vscales):
- df_style = utils.process_speaking_style_data(ss_all.collect(), choice_map)
+ df_style = utils.process_speaking_style_data(ss_all, choice_map)
df_voice_long = utils.process_voice_scale_data(vscales)
joined_df = df_style.join(df_voice_long, on=["_recordId", "Voice"], how="inner")
@@ -403,19 +392,18 @@ def _(choice_map, ss_all, utils, vscales):
@app.cell
-def _(SPEAKING_STYLES, joined_df, mo, plots, survey):
+def _(S, SPEAKING_STYLES, joined_df, mo):
_content = """### Total Results
"""
for style, traits in SPEAKING_STYLES.items():
# print(f"Correlation plot for {style}...")
- fig = plots.plot_speaking_style_correlation(
- df=joined_df,
+ fig = S.plot_speaking_style_correlation(
+ data=joined_df,
style_color=style,
style_traits=traits,
- title=f"Correlation: Speaking Style {style} and Voice Scale 1-10",
- results_dir=survey.fig_save_dir
+ title=f"Correlation: Speaking Style {style} and Voice Scale 1-10"
)
_content += f"""
#### Speaking Style **{style}**:
@@ -470,7 +458,7 @@ def _(mo):
@app.cell
-def _(SPEAKING_STYLES, df_style, mo, plots, survey, top3_voices, utils):
+def _(S, SPEAKING_STYLES, df_style, mo, top3_voices, utils):
df_ranking = utils.process_voice_ranking_data(top3_voices)
joined = df_style.join(df_ranking, on=['_recordId', 'Voice'], how='inner')
@@ -480,7 +468,7 @@ def _(SPEAKING_STYLES, df_style, mo, plots, survey, top3_voices, utils):
"""
for _style, _traits in SPEAKING_STYLES.items():
- _fig = plots.plot_speaking_style_ranking_correlation(joined, _style, _traits, results_dir=survey.fig_save_dir)
+ _fig = S.plot_speaking_style_ranking_correlation(data=joined, style_color=_style, style_traits=_traits)
_content += f"""
#### Speaking Style **{_style}**:
diff --git a/plots.py b/plots.py
index 43b7fc5..8c12b6c 100644
--- a/plots.py
+++ b/plots.py
@@ -8,1115 +8,898 @@ import polars as pl
from theme import ColorPalette
-def _sanitize_filename(title: str) -> str:
- """Convert plot title to a safe filename."""
- # Remove HTML tags
- clean = re.sub(r'<[^>]+>', ' ', title)
- # Replace special characters with underscores
- clean = re.sub(r'[^\w\s-]', '', clean)
- # Replace whitespace with underscores
- clean = re.sub(r'\s+', '_', clean.strip())
- # Remove consecutive underscores
- clean = re.sub(r'_+', '_', clean)
- # Lowercase and limit length
- return clean.lower()[:100]
+class JPMCPlotsMixin:
+ """Mixin class for plotting functions in JPMCSurvey."""
+ def _sanitize_filename(self, title: str) -> str:
+ """Convert plot title to a safe filename."""
+ # Remove HTML tags
+ clean = re.sub(r'<[^>]+>', ' ', title)
+ # Remove special characters (anything that is not a word character, whitespace, or hyphen)
+ clean = re.sub(r'[^\w\s-]', '', clean)
+ # Replace whitespace with underscores
+ clean = re.sub(r'\s+', '_', clean.strip())
+ # Collapse runs of consecutive underscores into a single underscore
+ clean = re.sub(r'_+', '_', clean)
+ # Lowercase and limit length
+ return clean.lower()[:100]
-def _save_plot(fig: go.Figure, results_dir: str | None, title: str) -> None:
- """Save plot to PNG file if results_dir is provided."""
- if results_dir:
- path = Path(results_dir)
- path.mkdir(parents=True, exist_ok=True)
- filename = f"{_sanitize_filename(title)}.png"
- fig.write_image(path / filename, width=fig.layout.width, height=fig.layout.height)
+ def _save_plot(self, fig: go.Figure, title: str) -> None:
+ """Save plot to PNG file if fig_save_dir is set."""
+ if hasattr(self, 'fig_save_dir') and self.fig_save_dir:
+ path = Path(self.fig_save_dir)
+ if not path.exists():
+ path.mkdir(parents=True, exist_ok=True)
+
+ filename = f"{self._sanitize_filename(title)}.png"
+ fig.write_image(path / filename, width=fig.layout.width, height=fig.layout.height)
+ def _ensure_dataframe(self, data: pl.LazyFrame | pl.DataFrame | None) -> pl.DataFrame:
+ """Ensure data is an eager DataFrame, collecting if necessary."""
+ df = data if data is not None else getattr(self, 'data_filtered', None)
+ if df is None:
+ raise ValueError("No data provided and self.data_filtered is None.")
+
+ if isinstance(df, pl.LazyFrame):
+ return df.collect()
+ return df
+ def plot_average_scores_with_counts(
+ self,
+ data: pl.LazyFrame | pl.DataFrame | None = None,
+ title: str = "General Impression (1-10)
Per Voice with Number of Participants Who Rated It",
+ x_label: str = "Stimuli",
+ y_label: str = "Average General Impression Rating (1-10)",
+ color: str = ColorPalette.PRIMARY,
+ height: int | None = None,
+ width: int | None = None,
+ ) -> go.Figure:
+ """
+ Create a bar plot showing average scores and count of non-null values for each column.
+ """
+ df = self._ensure_dataframe(data)
-def plot_average_scores_with_counts(
- df: pl.DataFrame,
- title: str = "General Impression (1-10)
Per Voice with Number of Participants Who Rated It",
- x_label: str = "Stimuli",
- y_label: str = "Average General Impression Rating (1-10)",
- color: str = ColorPalette.PRIMARY,
- height: int = 500,
- width: int = 1000,
- results_dir: str | None = None,
-) -> go.Figure:
- """
- Create a bar plot showing average scores and count of non-null values for each column.
-
- Parameters
- ----------
- df : pl.DataFrame
- DataFrame containing numeric columns to analyze.
- title : str, optional
- Plot title.
- x_label : str, optional
- X-axis label.
- y_label : str, optional
- Y-axis label.
- color : str, optional
- Bar color (hex code or named color).
- height : int, optional
- Plot height in pixels.
- width : int, optional
- Plot width in pixels.
-
- Returns
- -------
- go.Figure
- Plotly figure object.
- """
- # Calculate average and count of non-null values for each column
- # Exclude _recordId column
- stats = []
- for col in [c for c in df.columns if c != '_recordId']:
- avg_score = df[col].mean()
- non_null_count = df[col].drop_nulls().len()
- stats.append({
- 'column': col,
- 'average': avg_score,
- 'count': non_null_count
- })
-
- # Sort by average score in descending order
- stats_df = pl.DataFrame(stats).sort('average', descending=True)
-
- # Extract voice identifiers from column names (e.g., "V14" from "Voice_Scale_1_10__V14")
- labels = [col.split('__')[-1] if '__' in col else col for col in stats_df['column']]
-
- # Create the plot
- fig = go.Figure()
-
- fig.add_trace(go.Bar(
- x=labels,
- y=stats_df['average'],
- text=stats_df['count'],
- textposition='inside',
- textfont=dict(size=10, color='black'),
- marker_color=color,
- hovertemplate='%{x}
Average: %{y:.2f}
Count: %{text}'
- ))
-
- fig.update_layout(
- title=title,
- xaxis_title=x_label,
- yaxis_title=y_label,
- height=height,
- width=width,
- plot_bgcolor=ColorPalette.BACKGROUND,
- xaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID,
- tickangle=-45
- ),
- yaxis=dict(
- range=[0, 10],
- showgrid=True,
- gridcolor=ColorPalette.GRID
- ),
- font=dict(size=11)
- )
-
- _save_plot(fig, results_dir, title)
- return fig
-
-
-def plot_top3_ranking_distribution(
- df: pl.DataFrame,
- title: str = "Top 3 Rankings Distribution
Count of 1st, 2nd, and 3rd Place Votes per Voice",
- x_label: str = "Voices",
- y_label: str = "Number of Mentions in Top 3",
- height: int = 500,
- width: int = 1000,
- results_dir: str | None = None,
-) -> go.Figure:
- """
- Create a stacked bar chart showing how often each voice was ranked 1st, 2nd, or 3rd.
-
- The total height of the bar represents the popularity (frequency of being in Top 3),
- while the segments show the quality of those rankings.
-
- Parameters
- ----------
- df : pl.DataFrame
- DataFrame containing ranking columns (values 1, 2, 3).
- title : str, optional
- Plot title.
- x_label : str, optional
- X-axis label.
- y_label : str, optional
- Y-axis label.
- height : int, optional
- Plot height in pixels.
- width : int, optional
- Plot width in pixels.
-
- Returns
- -------
- go.Figure
- Plotly figure object.
- """
- # Exclude _recordId column
- stats = []
- for col in [c for c in df.columns if c != '_recordId']:
- # Count occurrences of each rank (1, 2, 3)
- # We ensure we're just counting the specific integer values
- rank1 = df.filter(pl.col(col) == 1).height
- rank2 = df.filter(pl.col(col) == 2).height
- rank3 = df.filter(pl.col(col) == 3).height
- total = rank1 + rank2 + rank3
-
- # Only include if it received at least one vote (optional, but keeps chart clean)
- if total > 0:
+ # Exclude _recordId column
+ stats = []
+ for col in [c for c in df.columns if c != '_recordId']:
+ avg_score = df[col].mean()
+ non_null_count = df[col].drop_nulls().len()
stats.append({
'column': col,
- 'Rank 1': rank1,
- 'Rank 2': rank2,
- 'Rank 3': rank3,
- 'Total': total
+ 'average': avg_score,
+ 'count': non_null_count
})
- # Sort by Total count descending (Most popular overall)
- # Tie-break with Rank 1 count
- stats_df = pl.DataFrame(stats).sort(['Total', 'Rank 1'], descending=[True, True])
+ # Sort by average score in descending order
+ stats_df = pl.DataFrame(stats).sort('average', descending=True)
- # Extract voice identifiers from column names
- labels = [col.split('__')[-1] if '__' in col else col for col in stats_df['column']]
+ # Extract voice identifiers from column names (e.g., "V14" from "Voice_Scale_1_10__V14")
+ labels = [col.split('__')[-1] if '__' in col else col for col in stats_df['column']]
- fig = go.Figure()
+ # Create the plot
+ fig = go.Figure()
- # Add traces for Rank 1, 2, and 3.
- # Stack order: Rank 1 at bottom (Base) -> Rank 2 -> Rank 3
- # This makes it easy to compare the "First Choice" volume across bars.
+ fig.add_trace(go.Bar(
+ x=labels,
+ y=stats_df['average'],
+ text=stats_df['count'],
+ textposition='inside',
+ textfont=dict(size=10, color='black'),
+ marker_color=color,
+ hovertemplate='%{x}
Average: %{y:.2f}
Count: %{text}'
+ ))
- fig.add_trace(go.Bar(
- name='Rank 1 (1st Choice)',
- x=labels,
- y=stats_df['Rank 1'],
- marker_color=ColorPalette.RANK_1,
- hovertemplate='%{x}
Rank 1: %{y}'
- ))
+ fig.update_layout(
+ title=title,
+ xaxis_title=x_label,
+ yaxis_title=y_label,
+ height=height if height else getattr(self, 'plot_height', 500),
+ width=width if width else getattr(self, 'plot_width', 1000),
+ plot_bgcolor=ColorPalette.BACKGROUND,
+ xaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID,
+ tickangle=-45
+ ),
+ yaxis=dict(
+ range=[0, 10],
+ showgrid=True,
+ gridcolor=ColorPalette.GRID
+ ),
+ font=dict(size=11)
+ )
- fig.add_trace(go.Bar(
- name='Rank 2 (2nd Choice)',
- x=labels,
- y=stats_df['Rank 2'],
- marker_color=ColorPalette.RANK_2,
- hovertemplate='%{x}
Rank 2: %{y}'
- ))
+ self._save_plot(fig, title)
+ return fig
- fig.add_trace(go.Bar(
- name='Rank 3 (3rd Choice)',
- x=labels,
- y=stats_df['Rank 3'],
- marker_color=ColorPalette.RANK_3,
- hovertemplate='%{x}
Rank 3: %{y}'
- ))
+ def plot_top3_ranking_distribution(
+ self,
+ data: pl.LazyFrame | pl.DataFrame | None = None,
+ title: str = "Top 3 Rankings Distribution
Count of 1st, 2nd, and 3rd Place Votes per Voice",
+ x_label: str = "Voices",
+ y_label: str = "Number of Mentions in Top 3",
+ height: int | None = None,
+ width: int | None = None,
+ ) -> go.Figure:
+ """
+ Create a stacked bar chart showing how often each voice was ranked 1st, 2nd, or 3rd.
+ """
+ df = self._ensure_dataframe(data)
- fig.update_layout(
- barmode='stack',
- title=title,
- xaxis_title=x_label,
- yaxis_title=y_label,
- height=height,
- width=width,
- plot_bgcolor=ColorPalette.BACKGROUND,
- xaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID,
- tickangle=-45
- ),
- yaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID
- ),
- legend=dict(
- orientation="h",
- yanchor="bottom",
- y=1.02,
- xanchor="right",
- x=1,
- traceorder="normal"
- ),
- font=dict(size=11)
- )
+ # Exclude _recordId column
+ stats = []
+ for col in [c for c in df.columns if c != '_recordId']:
+ # Count occurrences of each rank (1, 2, 3)
+ # We ensure we're just counting the specific integer values
+ rank1 = df.filter(pl.col(col) == 1).height
+ rank2 = df.filter(pl.col(col) == 2).height
+ rank3 = df.filter(pl.col(col) == 3).height
+ total = rank1 + rank2 + rank3
- _save_plot(fig, results_dir, title)
- return fig
+ # Only include if it received at least one vote (optional, but keeps chart clean)
+ if total > 0:
+ stats.append({
+ 'column': col,
+ 'Rank 1': rank1,
+ 'Rank 2': rank2,
+ 'Rank 3': rank3,
+ 'Total': total
+ })
+ # Sort by Total count descending (Most popular overall)
+ # Tie-break with Rank 1 count
+ stats_df = pl.DataFrame(stats).sort(['Total', 'Rank 1'], descending=[True, True])
-def plot_ranking_distribution(
- df: pl.DataFrame,
- title: str = "Rankings Distribution
(1st to 4th Place)",
- x_label: str = "Item",
- y_label: str = "Number of Votes",
- height: int = 500,
- width: int = 1000,
- results_dir: str | None = None,
-) -> go.Figure:
- """
- Create a stacked bar chart showing the distribution of rankings (1st to 4th) for characters or voices.
- Sorted by the number of Rank 1 votes.
+ # Extract voice identifiers from column names
+ labels = [col.split('__')[-1] if '__' in col else col for col in stats_df['column']]
- Parameters
- ----------
- df : pl.DataFrame
- DataFrame containing ranking columns.
- title : str, optional
- Plot title.
- x_label : str, optional
- X-axis label.
- y_label : str, optional
- Y-axis label.
- height : int, optional
- Plot height in pixels.
- width : int, optional
- Plot width in pixels.
+ fig = go.Figure()
- Returns
- -------
- go.Figure
- Plotly figure object.
- """
- stats = []
- # Identify ranking columns (assume all columns except _recordId)
- ranking_cols = [c for c in df.columns if c != '_recordId']
+ # Add traces for Rank 1, 2, and 3.
+ # Stack order: Rank 1 at bottom (Base) -> Rank 2 -> Rank 3
+ # This makes it easy to compare the "First Choice" volume across bars.
- for col in ranking_cols:
- # Count occurrences of each rank (1, 2, 3, 4)
- # Using height/len to count rows in the filtered frame
- r1 = df.filter(pl.col(col) == 1).height
- r2 = df.filter(pl.col(col) == 2).height
- r3 = df.filter(pl.col(col) == 3).height
- r4 = df.filter(pl.col(col) == 4).height
- total = r1 + r2 + r3 + r4
+ fig.add_trace(go.Bar(
+ name='Rank 1 (1st Choice)',
+ x=labels,
+ y=stats_df['Rank 1'],
+ marker_color=ColorPalette.RANK_1,
+ hovertemplate='%{x}
Rank 1: %{y}'
+ ))
- if total > 0:
+ fig.add_trace(go.Bar(
+ name='Rank 2 (2nd Choice)',
+ x=labels,
+ y=stats_df['Rank 2'],
+ marker_color=ColorPalette.RANK_2,
+ hovertemplate='%{x}
Rank 2: %{y}'
+ ))
+
+ fig.add_trace(go.Bar(
+ name='Rank 3 (3rd Choice)',
+ x=labels,
+ y=stats_df['Rank 3'],
+ marker_color=ColorPalette.RANK_3,
+ hovertemplate='%{x}
Rank 3: %{y}'
+ ))
+
+ fig.update_layout(
+ barmode='stack',
+ title=title,
+ xaxis_title=x_label,
+ yaxis_title=y_label,
+ height=height if height else getattr(self, 'plot_height', 500),
+ width=width if width else getattr(self, 'plot_width', 1000),
+ plot_bgcolor=ColorPalette.BACKGROUND,
+ xaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID,
+ tickangle=-45
+ ),
+ yaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID
+ ),
+ legend=dict(
+ orientation="h",
+ yanchor="bottom",
+ y=1.02,
+ xanchor="right",
+ x=1,
+ traceorder="normal"
+ ),
+ font=dict(size=11)
+ )
+
+ self._save_plot(fig, title)
+ return fig
+
+ def plot_ranking_distribution(
+ self,
+ data: pl.LazyFrame | pl.DataFrame | None = None,
+ title: str = "Rankings Distribution
(1st to 4th Place)",
+ x_label: str = "Item",
+ y_label: str = "Number of Votes",
+ height: int | None = None,
+ width: int | None = None,
+ ) -> go.Figure:
+ """
+ Create a stacked bar chart showing the distribution of rankings (1st to 4th) for characters or voices.
+ Sorted by the number of Rank 1 votes.
+ """
+ df = self._ensure_dataframe(data)
+
+ stats = []
+ # Identify ranking columns (assume all columns except _recordId)
+ ranking_cols = [c for c in df.columns if c != '_recordId']
+
+ for col in ranking_cols:
+ # Count occurrences of each rank (1, 2, 3, 4)
+ # Using height/len to count rows in the filtered frame
+ r1 = df.filter(pl.col(col) == 1).height
+ r2 = df.filter(pl.col(col) == 2).height
+ r3 = df.filter(pl.col(col) == 3).height
+ r4 = df.filter(pl.col(col) == 4).height
+ total = r1 + r2 + r3 + r4
+
+ if total > 0:
+ stats.append({
+ 'column': col,
+ 'Rank 1': r1,
+ 'Rank 2': r2,
+ 'Rank 3': r3,
+ 'Rank 4': r4
+ })
+
+ if not stats:
+ return go.Figure()
+
+ # Sort by Rank 1 (Most "Best" votes) descending to show the winner first
+ # Secondary sort by Rank 2
+ stats_df = pl.DataFrame(stats).sort(['Rank 1', 'Rank 2'], descending=[True, True])
+
+ # Clean up labels: Remove prefix and underscores
+ # e.g. "Character_Ranking_The_Coach" -> "The Coach"
+ labels = [
+ col.replace('Character_Ranking_', '').replace('Top_3_Voices_ranking__', '').replace('_', ' ').strip()
+ for col in stats_df['column']
+ ]
+
+ fig = go.Figure()
+
+ # Rank 1 (Best)
+ fig.add_trace(go.Bar(
+ name='Rank 1 (Best)',
+ x=labels,
+ y=stats_df['Rank 1'],
+ marker_color=ColorPalette.RANK_1,
+ hovertemplate='%{x}
Rank 1: %{y}'
+ ))
+
+ # Rank 2
+ fig.add_trace(go.Bar(
+ name='Rank 2',
+ x=labels,
+ y=stats_df['Rank 2'],
+ marker_color=ColorPalette.RANK_2,
+ hovertemplate='%{x}
Rank 2: %{y}'
+ ))
+
+ # Rank 3
+ fig.add_trace(go.Bar(
+ name='Rank 3',
+ x=labels,
+ y=stats_df['Rank 3'],
+ marker_color=ColorPalette.RANK_3,
+ hovertemplate='%{x}
Rank 3: %{y}'
+ ))
+
+ # Rank 4 (Worst)
+ # The lowest rank uses its own palette entry (RANK_4) so the focus stays on the top ranks
+ fig.add_trace(go.Bar(
+ name='Rank 4 (Worst)',
+ x=labels,
+ y=stats_df['Rank 4'],
+ marker_color=ColorPalette.RANK_4,
+ hovertemplate='%{x}
Rank 4: %{y}'
+ ))
+
+ fig.update_layout(
+ barmode='stack',
+ title=title,
+ xaxis_title=x_label,
+ yaxis_title=y_label,
+ height=height if height else getattr(self, 'plot_height', 500),
+ width=width if width else getattr(self, 'plot_width', 1000),
+ plot_bgcolor=ColorPalette.BACKGROUND,
+ xaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID,
+ tickangle=-45
+ ),
+ yaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID
+ ),
+ legend=dict(
+ orientation="h",
+ yanchor="bottom",
+ y=1.02,
+ xanchor="right",
+ x=1,
+ traceorder="normal"
+ ),
+ font=dict(size=11)
+ )
+
+ self._save_plot(fig, title)
+ return fig
+
+ def plot_most_ranked_1(
+ self,
+ data: pl.LazyFrame | pl.DataFrame | None = None,
+ title: str = "Most Popular Choice
(Number of Times Ranked 1st)",
+ x_label: str = "Item",
+ y_label: str = "Count of 1st Place Rankings",
+ height: int | None = None,
+ width: int | None = None,
+ ) -> go.Figure:
+ """
+ Create a bar chart showing which item (character/voice) was ranked #1 the most.
+ Top 3 items are highlighted.
+ """
+ df = self._ensure_dataframe(data)
+
+ stats = []
+ # Identify ranking columns (assume all columns except _recordId)
+ ranking_cols = [c for c in df.columns if c != '_recordId']
+
+ for col in ranking_cols:
+ # Count occurrences of rank 1
+ count_rank_1 = df.filter(pl.col(col) == 1).height
+
stats.append({
'column': col,
- 'Rank 1': r1,
- 'Rank 2': r2,
- 'Rank 3': r3,
- 'Rank 4': r4
+ 'count': count_rank_1
})
- if not stats:
- return go.Figure()
+ # Sort by count descending
+ stats_df = pl.DataFrame(stats).sort('count', descending=True)
- # Sort by Rank 1 (Most "Best" votes) descending to show the winner first
- # Secondary sort by Rank 2
- stats_df = pl.DataFrame(stats).sort(['Rank 1', 'Rank 2'], descending=[True, True])
+ # Clean up labels
+ labels = [
+ col.replace('Character_Ranking_', '').replace('Top_3_Voices_ranking__', '').replace('_', ' ').strip()
+ for col in stats_df['column']
+ ]
- # Clean up labels: Remove prefix and underscores
- # e.g. "Character_Ranking_The_Coach" -> "The Coach"
- labels = [
- col.replace('Character_Ranking_', '').replace('Top_3_Voices_ranking__', '').replace('_', ' ').strip()
- for col in stats_df['column']
- ]
+ # Assign colors: Top 3 get PRIMARY (Blue), others get NEUTRAL (Grey)
+ colors = [
+ ColorPalette.PRIMARY if i < 3 else ColorPalette.NEUTRAL
+ for i in range(len(stats_df))
+ ]
- fig = go.Figure()
-
- # Rank 1 (Best)
- fig.add_trace(go.Bar(
- name='Rank 1 (Best)',
- x=labels,
- y=stats_df['Rank 1'],
- marker_color=ColorPalette.RANK_1,
- hovertemplate='%{x}
Rank 1: %{y}'
- ))
-
- # Rank 2
- fig.add_trace(go.Bar(
- name='Rank 2',
- x=labels,
- y=stats_df['Rank 2'],
- marker_color=ColorPalette.RANK_2,
- hovertemplate='%{x}
Rank 2: %{y}'
- ))
-
- # Rank 3
- fig.add_trace(go.Bar(
- name='Rank 3',
- x=labels,
- y=stats_df['Rank 3'],
- marker_color=ColorPalette.RANK_3,
- hovertemplate='%{x}
Rank 3: %{y}'
- ))
-
- # Rank 4 (Worst)
- # Using a neutral grey as a fallback for the lowest rank to keep focus on top ranks
- fig.add_trace(go.Bar(
- name='Rank 4 (Worst)',
- x=labels,
- y=stats_df['Rank 4'],
- marker_color=ColorPalette.RANK_4,
- hovertemplate='%{x}
Rank 4: %{y}'
- ))
-
- fig.update_layout(
- barmode='stack',
- title=title,
- xaxis_title=x_label,
- yaxis_title=y_label,
- height=height,
- width=width,
- plot_bgcolor=ColorPalette.BACKGROUND,
- xaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID,
- tickangle=-45
- ),
- yaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID
- ),
- legend=dict(
- orientation="h",
- yanchor="bottom",
- y=1.02,
- xanchor="right",
- x=1,
- traceorder="normal"
- ),
- font=dict(size=11)
- )
-
- _save_plot(fig, results_dir, title)
- return fig
-
-
-def plot_most_ranked_1(
- df: pl.DataFrame,
- title: str = "Most Popular Choice
(Number of Times Ranked 1st)",
- x_label: str = "Item",
- y_label: str = "Count of 1st Place Rankings",
- height: int = 500,
- width: int = 1000,
- results_dir: str | None = None,
-) -> go.Figure:
- """
- Create a bar chart showing which item (character/voice) was ranked #1 the most.
- Top 3 items are highlighted.
-
- Parameters
- ----------
- df : pl.DataFrame
- DataFrame containing ranking columns.
- title : str, optional
- Plot title.
- x_label : str, optional
- X-axis label.
- y_label : str, optional
- Y-axis label.
- height : int, optional
- Plot height in pixels.
- width : int, optional
- Plot width in pixels.
-
- Returns
- -------
- go.Figure
- Plotly figure object.
- """
- stats = []
- # Identify ranking columns (assume all columns except _recordId)
- ranking_cols = [c for c in df.columns if c != '_recordId']
-
- for col in ranking_cols:
- # Count occurrences of rank 1
- count_rank_1 = df.filter(pl.col(col) == 1).height
+ fig = go.Figure()
- stats.append({
- 'column': col,
- 'count': count_rank_1
- })
+ fig.add_trace(go.Bar(
+ x=labels,
+ y=stats_df['count'],
+ text=stats_df['count'],
+ textposition='inside',
+ textfont=dict(size=10, color='white'),
+ marker_color=colors,
+ hovertemplate='%{x}
1st Place Votes: %{y}'
+ ))
- # Sort by count descending
- stats_df = pl.DataFrame(stats).sort('count', descending=True)
+ fig.update_layout(
+ title=title,
+ xaxis_title=x_label,
+ yaxis_title=y_label,
+ height=height if height else getattr(self, 'plot_height', 500),
+ width=width if width else getattr(self, 'plot_width', 1000),
+ plot_bgcolor=ColorPalette.BACKGROUND,
+ xaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID,
+ tickangle=-45
+ ),
+ yaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID
+ ),
+ font=dict(size=11)
+ )
- # Clean up labels
- labels = [
- col.replace('Character_Ranking_', '').replace('Top_3_Voices_ranking__', '').replace('_', ' ').strip()
- for col in stats_df['column']
- ]
+ self._save_plot(fig, title)
+ return fig
- # Assign colors: Top 3 get PRIMARY (Blue), others get NEUTRAL (Grey)
- colors = [
- ColorPalette.PRIMARY if i < 3 else ColorPalette.NEUTRAL
- for i in range(len(stats_df))
- ]
+ def plot_weighted_ranking_score(
+ self,
+ data: pl.LazyFrame | pl.DataFrame | None = None,
+ title: str = "Weighted Popularity Score
(1st=3pts, 2nd=2pts, 3rd=1pt)",
+ x_label: str = "Character Personality",
+ y_label: str = "Total Weighted Score",
+ color: str = ColorPalette.PRIMARY,
+ height: int | None = None,
+ width: int | None = None,
+ ) -> go.Figure:
+ """
+ Create a bar chart showing the weighted ranking score for each character.
+ """
+ weighted_df = self._ensure_dataframe(data)
- fig = go.Figure()
-
- fig.add_trace(go.Bar(
- x=labels,
- y=stats_df['count'],
- text=stats_df['count'],
- textposition='inside',
- textfont=dict(size=10, color='white'),
- marker_color=colors,
- hovertemplate='%{x}
1st Place Votes: %{y}'
- ))
+ fig = go.Figure()
- fig.update_layout(
- title=title,
- xaxis_title=x_label,
- yaxis_title=y_label,
- height=height,
- width=width,
- plot_bgcolor=ColorPalette.BACKGROUND,
- xaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID,
- tickangle=-45
- ),
- yaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID
- ),
- font=dict(size=11)
- )
+ fig.add_trace(go.Bar(
+ x=weighted_df['Character'],
+ y=weighted_df['Weighted Score'],
+ text=weighted_df['Weighted Score'],
+ textposition='inside',
+ textfont=dict(size=11, color='white'),
+ marker_color=color,
+ hovertemplate='%{x}
Score: %{y}'
+ ))
- _save_plot(fig, results_dir, title)
- return fig
+ fig.update_layout(
+ title=title,
+ xaxis_title=x_label,
+ yaxis_title=y_label,
+ height=height if height else getattr(self, 'plot_height', 500),
+ width=width if width else getattr(self, 'plot_width', 1000),
+ plot_bgcolor=ColorPalette.BACKGROUND,
+ xaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID,
+ tickangle=-45
+ ),
+ yaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID
+ ),
+ font=dict(size=11)
+ )
+ self._save_plot(fig, title)
+ return fig
+ def plot_voice_selection_counts(
+ self,
+ data: pl.LazyFrame | pl.DataFrame | None = None,
+ target_column: str = "8_Combined",
+ title: str = "Most Frequently Chosen Voices
(Top 8 Highlighted)",
+ x_label: str = "Voice",
+ y_label: str = "Number of Times Chosen",
+ height: int | None = None,
+ width: int | None = None,
+ ) -> go.Figure:
+ """
+ Create a bar plot showing the frequency of voice selections.
+ """
+ df = self._ensure_dataframe(data)
-def plot_weighted_ranking_score(
- weighted_df: pl.DataFrame,
- title: str = "Weighted Popularity Score
(1st=3pts, 2nd=2pts, 3rd=1pt)",
- x_label: str = "Character Personality",
- y_label: str = "Total Weighted Score",
- color: str = ColorPalette.PRIMARY,
- height: int = 500,
- width: int = 1000,
- results_dir: str | None = None,
-) -> go.Figure:
- """
- Create a bar chart showing the weighted ranking score for each character.
+ if target_column not in df.columns:
+ return go.Figure()
- Parameters
- ----------
- df : pl.DataFrame
- DataFrame containing ranking columns.
- title : str, optional
- Plot title.
- x_label : str, optional
- X-axis label.
- y_label : str, optional
- Y-axis label.
- color : str, optional
- Bar color.
- height : int, optional
- Plot height.
- width : int, optional
- Plot width.
+ # Process the data:
+ # 1. Select the relevant column and remove nulls
+ # 2. Split the comma-separated string into a list
+ # 3. Explode the list so each voice gets its own row
+ # 4. Strip whitespace ensuring "Voice 1" and " Voice 1" match
+ # 5. Count occurrences
+ stats_df = (
+ df.select(pl.col(target_column))
+ .drop_nulls()
+ .with_columns(pl.col(target_column).str.split(","))
+ .explode(target_column)
+ .with_columns(pl.col(target_column).str.strip_chars())
+ .filter(pl.col(target_column) != "")
+ .group_by(target_column)
+ .agg(pl.len().alias("count"))
+ .sort("count", descending=True)
+ )
- Returns
- -------
- go.Figure
- Plotly figure object.
- """
+ # Define colors: Top 8 get PRIMARY, rest get NEUTRAL
+ colors = [
+ ColorPalette.PRIMARY if i < 8 else ColorPalette.NEUTRAL
+ for i in range(len(stats_df))
+ ]
- fig = go.Figure()
+ fig = go.Figure()
- fig.add_trace(go.Bar(
- x=weighted_df['Character'],
- y=weighted_df['Weighted Score'],
- text=weighted_df['Weighted Score'],
- textposition='inside',
- textfont=dict(size=11, color='white'),
- marker_color=color,
- hovertemplate='%{x}
Score: %{y}'
- ))
+ fig.add_trace(go.Bar(
+ x=stats_df[target_column],
+ y=stats_df['count'],
+ text=stats_df['count'],
+ textposition='outside',
+ marker_color=colors,
+ hovertemplate='%{x}
Selections: %{y}'
+ ))
- fig.update_layout(
- title=title,
- xaxis_title=x_label,
- yaxis_title=y_label,
- height=height,
- width=width,
- plot_bgcolor=ColorPalette.BACKGROUND,
- xaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID,
- tickangle=-45
- ),
- yaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID
- ),
- font=dict(size=11)
- )
+ fig.update_layout(
+ title=title,
+ xaxis_title=x_label,
+ yaxis_title=y_label,
+ height=height if height else getattr(self, 'plot_height', 500),
+ width=width if width else getattr(self, 'plot_width', 1000),
+ plot_bgcolor=ColorPalette.BACKGROUND,
+ xaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID,
+ tickangle=-45
+ ),
+ yaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID
+ ),
+ font=dict(size=11),
+ )
- _save_plot(fig, results_dir, title)
- return fig
+ self._save_plot(fig, title)
+ return fig
+ def plot_top3_selection_counts(
+ self,
+ data: pl.LazyFrame | pl.DataFrame | None = None,
+ target_column: str = "3_Ranked",
+ title: str = "Most Frequently Chosen Top 3 Voices
(Top 3 Highlighted)",
+ x_label: str = "Voice",
+ y_label: str = "Count of Mentions in Top 3",
+ height: int | None = None,
+ width: int | None = None,
+ ) -> go.Figure:
+ """
+ Question: Which 3 voices are chosen the most out of 18?
+ """
+ df = self._ensure_dataframe(data)
-def plot_voice_selection_counts(
- df: pl.DataFrame,
- target_column: str = "8_Combined",
- title: str = "Most Frequently Chosen Voices
(Top 8 Highlighted)",
- x_label: str = "Voice",
- y_label: str = "Number of Times Chosen",
- height: int = 500,
- width: int = 1000,
- results_dir: str | None = None,
-) -> go.Figure:
- """
- Create a bar plot showing the frequency of voice selections.
- Takes a column containing comma-separated values (e.g. "Voice 1, Voice 2..."),
- counts occurrences, and highlights the top 8 most frequent voices.
+ if target_column not in df.columns:
+ return go.Figure()
- Parameters
- ----------
- df : pl.DataFrame
- DataFrame containing the selection column.
- target_column : str, optional
- Name of the column containing comma-separated voice selections.
- Defaults to "8_Combined".
- title : str, optional
- Plot title.
- x_label : str, optional
- X-axis label.
- y_label : str, optional
- Y-axis label.
- height : int, optional
- Plot height in pixels.
- width : int, optional
- Plot width in pixels.
+ # Process the data:
+ # Same logic as plot_voice_selection_counts: explode comma-separated string
+ stats_df = (
+ df.select(pl.col(target_column))
+ .drop_nulls()
+ .with_columns(pl.col(target_column).str.split(","))
+ .explode(target_column)
+ .with_columns(pl.col(target_column).str.strip_chars())
+ .filter(pl.col(target_column) != "")
+ .group_by(target_column)
+ .agg(pl.len().alias("count"))
+ .sort("count", descending=True)
+ )
- Returns
- -------
- go.Figure
- Plotly figure object.
- """
- if target_column not in df.columns:
- return go.Figure()
+ # Define colors: Top 3 get PRIMARY, rest get NEUTRAL
+ colors = [
+ ColorPalette.PRIMARY if i < 3 else ColorPalette.NEUTRAL
+ for i in range(len(stats_df))
+ ]
- # Process the data:
- # 1. Select the relevant column and remove nulls
- # 2. Split the comma-separated string into a list
- # 3. Explode the list so each voice gets its own row
- # 4. Strip whitespace ensuring "Voice 1" and " Voice 1" match
- # 5. Count occurrences
- stats_df = (
- df.select(pl.col(target_column))
- .drop_nulls()
- .with_columns(pl.col(target_column).str.split(","))
- .explode(target_column)
- .with_columns(pl.col(target_column).str.strip_chars())
- .filter(pl.col(target_column) != "")
- .group_by(target_column)
- .agg(pl.len().alias("count"))
- .sort("count", descending=True)
- )
+ fig = go.Figure()
- # Define colors: Top 8 get PRIMARY, rest get NEUTRAL
- colors = [
- ColorPalette.PRIMARY if i < 8 else ColorPalette.NEUTRAL
- for i in range(len(stats_df))
- ]
+ fig.add_trace(go.Bar(
+ x=stats_df[target_column],
+ y=stats_df['count'],
+ text=stats_df['count'],
+ textposition='outside',
+ marker_color=colors,
+ hovertemplate='%{x}
In Top 3: %{y} times'
+ ))
- fig = go.Figure()
+ fig.update_layout(
+ title=title,
+ xaxis_title=x_label,
+ yaxis_title=y_label,
+ height=height if height else getattr(self, 'plot_height', 500),
+ width=width if width else getattr(self, 'plot_width', 1000),
+ plot_bgcolor=ColorPalette.BACKGROUND,
+ xaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID,
+ tickangle=-45
+ ),
+ yaxis=dict(
+ showgrid=True,
+ gridcolor=ColorPalette.GRID
+ ),
+ font=dict(size=11),
+ )
- fig.add_trace(go.Bar(
- x=stats_df[target_column],
- y=stats_df['count'],
- text=stats_df['count'],
- textposition='outside',
- marker_color=colors,
- hovertemplate='%{x}
Selections: %{y}'
- ))
+ self._save_plot(fig, title)
+ return fig
- fig.update_layout(
- title=title,
- xaxis_title=x_label,
- yaxis_title=y_label,
- height=height,
- width=width,
- plot_bgcolor=ColorPalette.BACKGROUND,
- xaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID,
- tickangle=-45
- ),
- yaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID
- ),
- font=dict(size=11),
- )
+ def plot_speaking_style_trait_scores(
+ self,
+ data: pl.LazyFrame | pl.DataFrame | None = None,
+ trait_description: str = None,
+ left_anchor: str = None,
+ right_anchor: str = None,
+ title: str = "Speaking Style Trait Analysis",
+ height: int | None = None,
+ width: int | None = None,
+ ) -> go.Figure:
+ """
+ Plot scores for a single speaking style trait across multiple voices.
+ """
+ df = self._ensure_dataframe(data)
- _save_plot(fig, results_dir, title)
- return fig
+ if df.is_empty():
+ return go.Figure()
+
+ required_cols = ["Voice", "score"]
+ if not all(col in df.columns for col in required_cols):
+ return go.Figure()
+ # Calculate stats: Mean, Count
+ stats = (
+ df.filter(pl.col("score").is_not_null())
+ .group_by("Voice")
+ .agg([
+ pl.col("score").mean().alias("mean_score"),
+ pl.col("score").count().alias("count")
+ ])
+ .sort("mean_score", descending=False) # Ascending for display bottom-to-top
+ )
-def plot_top3_selection_counts(
- df: pl.DataFrame,
- target_column: str = "3_Ranked",
- title: str = "Most Frequently Chosen Top 3 Voices
(Top 3 Highlighted)",
- x_label: str = "Voice",
- y_label: str = "Count of Mentions in Top 3",
- height: int = 500,
- width: int = 1000,
- results_dir: str | None = None,
-) -> go.Figure:
- """
- Question: Which 3 voices are chosen the most out of 18?
-
- How many times does each voice end up in the top 3?
- (this is based on the survey question where participants need to choose 3 out
- of the earlier selected 8 voices). So how often each of the 18 stimuli ended
- up in participants' Top 3, after they first selected 8 out of 18.
+ # Attempt to extract anchors from DF if not provided
+ if (left_anchor is None or right_anchor is None) and "Left_Anchor" in df.columns:
+ head = df.filter(pl.col("Left_Anchor").is_not_null()).head(1)
+ if not head.is_empty():
+ if left_anchor is None: left_anchor = head["Left_Anchor"][0]
+ if right_anchor is None: right_anchor = head["Right_Anchor"][0]
- Parameters
- ----------
- df : pl.DataFrame
- DataFrame containing the ranking column (comma-separated strings).
- target_column : str, optional
- Name of the column containing comma-separated Top 3 voice elections.
- Defaults to "3_Ranked".
- title : str, optional
- Plot title.
- x_label : str, optional
- X-axis label.
- y_label : str, optional
- Y-axis label.
- height : int, optional
- Plot height in pixels.
- width : int, optional
- Plot width in pixels.
-
- Returns
- -------
- go.Figure
- Plotly figure object.
- """
- if target_column not in df.columns:
- return go.Figure()
-
- # Process the data:
- # Same logic as plot_voice_selection_counts: explode comma-separated string
- stats_df = (
- df.select(pl.col(target_column))
- .drop_nulls()
- .with_columns(pl.col(target_column).str.split(","))
- .explode(target_column)
- .with_columns(pl.col(target_column).str.strip_chars())
- .filter(pl.col(target_column) != "")
- .group_by(target_column)
- .agg(pl.len().alias("count"))
- .sort("count", descending=True)
- )
-
- # Define colors: Top 3 get PRIMARY, rest get NEUTRAL
- colors = [
- ColorPalette.PRIMARY if i < 3 else ColorPalette.NEUTRAL
- for i in range(len(stats_df))
- ]
-
- fig = go.Figure()
-
- fig.add_trace(go.Bar(
- x=stats_df[target_column],
- y=stats_df['count'],
- text=stats_df['count'],
- textposition='outside',
- marker_color=colors,
- hovertemplate='%{x}
In Top 3: %{y} times'
- ))
-
- fig.update_layout(
- title=title,
- xaxis_title=x_label,
- yaxis_title=y_label,
- height=height,
- width=width,
- plot_bgcolor=ColorPalette.BACKGROUND,
- xaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID,
- tickangle=-45
- ),
- yaxis=dict(
- showgrid=True,
- gridcolor=ColorPalette.GRID
- ),
- font=dict(size=11),
- )
-
- _save_plot(fig, results_dir, title)
- return fig
-
-
-def plot_speaking_style_trait_scores(
- df: pl.DataFrame,
- trait_description: str = None,
- left_anchor: str = None,
- right_anchor: str = None,
- title: str = "Speaking Style Trait Analysis",
- height: int = 500,
- width: int = 1000,
- results_dir: str | None = None,
-) -> go.Figure:
- """
- Plot scores for a single speaking style trait across multiple voices.
-
- The plot shows the average score per Voice, sorted by score.
- It expects the DataFrame to contain 'Voice' and 'score' columns,
- typically filtered for a single trait/description.
-
- Parameters
- ----------
- df : pl.DataFrame
- DataFrame containing at least 'Voice' and 'score' columns.
- Produced by utils.process_speaking_style_data and filtered.
- trait_description : str, optional
- Description of the trait being analyzed (e.g. "Indifferent : Attentive").
- If not provided, it will be constructed from annotations.
- left_anchor : str, optional
- Label for the lower end of the scale (e.g. "Indifferent").
- If not provided, attempts to read 'Left_Anchor' column from df.
- right_anchor : str, optional
- Label for the upper end of the scale (e.g. "Attentive").
- If not provided, attempts to read 'Right_Anchor' column from df.
- title : str, optional
- Plot title.
- height : int, optional
- Plot height.
- width : int, optional
- Plot width.
-
- Returns
- -------
- go.Figure
- Plotly figure object.
- """
- if df.is_empty():
- return go.Figure()
-
- required_cols = ["Voice", "score"]
- if not all(col in df.columns for col in required_cols):
- return go.Figure()
-
- # Calculate stats: Mean, Count
- stats = (
- df.filter(pl.col("score").is_not_null())
- .group_by("Voice")
- .agg([
- pl.col("score").mean().alias("mean_score"),
- pl.col("score").count().alias("count")
- ])
- .sort("mean_score", descending=False) # Ascending for display bottom-to-top
- )
-
- # Attempt to extract anchors from DF if not provided
- if (left_anchor is None or right_anchor is None) and "Left_Anchor" in df.columns:
- head = df.filter(pl.col("Left_Anchor").is_not_null()).head(1)
- if not head.is_empty():
- if left_anchor is None: left_anchor = head["Left_Anchor"][0]
- if right_anchor is None: right_anchor = head["Right_Anchor"][0]
-
- if trait_description is None:
- if left_anchor and right_anchor:
- trait_description = f"{left_anchor.split('|')[0]} vs. {right_anchor.split('|')[0]}"
- else:
- # Try getting from Description column
- if "Description" in df.columns:
- head = df.filter(pl.col("Description").is_not_null()).head(1)
- if not head.is_empty():
- trait_description = head["Description"][0]
+ if trait_description is None:
+ if left_anchor and right_anchor:
+ trait_description = f"{left_anchor.split('|')[0]} vs. {right_anchor.split('|')[0]}"
+ else:
+ # Try getting from Description column
+ if "Description" in df.columns:
+ head = df.filter(pl.col("Description").is_not_null()).head(1)
+ if not head.is_empty():
+ trait_description = head["Description"][0]
+ else:
+ trait_description = ""
else:
- trait_description = ""
- else:
- trait_description = ""
+ trait_description = ""
- fig = go.Figure()
+ fig = go.Figure()
- fig.add_trace(go.Bar(
- y=stats["Voice"], # Y is Voice
- x=stats["mean_score"], # X is Score
- orientation='h',
- text=stats["count"],
- textposition='inside',
- textangle=0,
- textfont=dict(size=16, color='white'),
- texttemplate='%{text}', # Count on bar
- marker_color=ColorPalette.PRIMARY,
- hovertemplate='%{y}
Average: %{x:.2f}
Count: %{text}'
- ))
-
- # Add annotations for anchors
- annotations = []
-
- # Place anchors at the bottom
- if left_anchor:
- annotations.append(dict(
- xref='x', yref='paper',
- x=1, y=-0.2, # Below axis
- xanchor='left', yanchor='top',
- text=f"1: {left_anchor.split('|')[0]}",
- showarrow=False,
- font=dict(size=10, color='gray')
- ))
- if right_anchor:
- annotations.append(dict(
- xref='x', yref='paper',
- x=5, y=-0.2, # Below axis
- xanchor='right', yanchor='top',
- text=f"5: {right_anchor.split('|')[0]}",
- showarrow=False,
- font=dict(size=10, color='gray')
+ fig.add_trace(go.Bar(
+ y=stats["Voice"], # Y is Voice
+ x=stats["mean_score"], # X is Score
+ orientation='h',
+ text=stats["count"],
+ textposition='inside',
+ textangle=0,
+ textfont=dict(size=16, color='white'),
+ texttemplate='%{text}', # Count on bar
+ marker_color=ColorPalette.PRIMARY,
+ hovertemplate='%{y}
Average: %{x:.2f}
Count: %{text}'
))
- fig.update_layout(
- title=dict(
- text=f"{title}
{trait_description}
(Numbers on bars indicate respondent count)",
- y=0.92
- ),
- xaxis_title="Average Score (1-5)",
- yaxis_title="Voice",
- height=height,
- width=width,
- plot_bgcolor=ColorPalette.BACKGROUND,
- xaxis=dict(
- range=[1, 5],
- showgrid=True,
- gridcolor=ColorPalette.GRID,
- zeroline=False
- ),
- yaxis=dict(
- showgrid=False
- ),
- margin=dict(b=120),
- annotations=annotations,
- font=dict(size=11)
- )
- _save_plot(fig, results_dir, title)
- return fig
-
-def plot_speaking_style_correlation(
- df: pl.DataFrame,
- style_color: str,
- style_traits: list[str],
- title=f"Speaking style and voice scale 1-10 correlations",
- results_dir: str | None = None,
-) -> go.Figure:
- """
- Plots the correlation between Speaking Style Trait Scores (1-5) and Voice Scale (1-10) using a Bar Chart.
- Each bar represents one trait.
-
- Parameters
- ----------
- df : pl.DataFrame
- Joined dataframe containing 'Right_Anchor', 'score' (Trait Score), and 'Voice_Scale_Score'.
- style_color : str
- The name of the style (e.g., 'Green', 'Blue') for title and coloring.
- style_traits : list[str]
- List of trait descriptions (positive side) to include in the plot.
- These should match the 'Right_Anchor' column values.
+ # Add annotations for anchors
+ annotations = []
- Returns
- -------
- go.Figure
- """
-
- trait_correlations = []
-
- # 1. Calculate Correlations
- for i, trait in enumerate(style_traits):
- # Match against Right_Anchor which contains the positive trait description
- # Use exact match for reliability
- subset = df.filter(
- pl.col("Right_Anchor") == trait
+ # Place anchors at the bottom
+ if left_anchor:
+ annotations.append(dict(
+ xref='x', yref='paper',
+ x=1, y=-0.2, # Below axis
+ xanchor='left', yanchor='top',
+ text=f"1: {left_anchor.split('|')[0]}",
+ showarrow=False,
+ font=dict(size=10, color='gray')
+ ))
+ if right_anchor:
+ annotations.append(dict(
+ xref='x', yref='paper',
+ x=5, y=-0.2, # Below axis
+ xanchor='right', yanchor='top',
+ text=f"5: {right_anchor.split('|')[0]}",
+ showarrow=False,
+ font=dict(size=10, color='gray')
+ ))
+
+ fig.update_layout(
+ title=dict(
+ text=f"{title}
{trait_description}
(Numbers on bars indicate respondent count)",
+ y=0.92
+ ),
+ xaxis_title="Average Score (1-5)",
+ yaxis_title="Voice",
+ height=height if height else getattr(self, 'plot_height', 500),
+ width=width if width else getattr(self, 'plot_width', 1000),
+ plot_bgcolor=ColorPalette.BACKGROUND,
+ xaxis=dict(
+ range=[1, 5],
+ showgrid=True,
+ gridcolor=ColorPalette.GRID,
+ zeroline=False
+ ),
+ yaxis=dict(
+ showgrid=False
+ ),
+ margin=dict(b=120),
+ annotations=annotations,
+ font=dict(size=11)
+ )
+ self._save_plot(fig, title)
+ return fig
+
+ def plot_speaking_style_correlation(
+ self,
+ style_color: str,
+ style_traits: list[str],
+ data: pl.LazyFrame | pl.DataFrame | None = None,
+ title: str | None = None,
+ ) -> go.Figure:
+ """
+ Plots the correlation between Speaking Style Trait Scores (1-5) and Voice Scale (1-10) using a Bar Chart.
+ """
+ df = self._ensure_dataframe(data)
+
+ if title is None:
+ title = f"Speaking style and voice scale 1-10 correlations"
+
+ trait_correlations = []
+
+ # 1. Calculate Correlations
+ for i, trait in enumerate(style_traits):
+ # Match against Right_Anchor which contains the positive trait description
+ # Use exact match for reliability
+ subset = df.filter(
+ pl.col("Right_Anchor") == trait
+ )
+
+ # Drop Nulls for correlation calculation
+ valid_data = subset.select(["score", "Voice_Scale_Score"]).drop_nulls()
+
+ if valid_data.height > 1:
+ # Calculate Pearson Correlation
+ corr_val = valid_data.select(pl.corr("score", "Voice_Scale_Score")).item()
+
+ # Trait Label for Plot
+ trait_correlations.append({
+ "trait_full": trait,
+ "trait_short": f"Trait {i+1}",
+ "correlation": corr_val if corr_val is not None else 0.0
+ })
+
+ # 2. Build Plot Data
+ if not trait_correlations:
+ # Return empty fig with title
+ fig = go.Figure()
+ fig.update_layout(title=f"No data for {style_color} Style")
+ return fig
+
+ plot_df = pl.DataFrame(trait_correlations)
+
+ # Determine colors based on correlation sign
+ colors = []
+ for val in plot_df["correlation"]:
+ if val >= 0:
+ colors.append("green") # Positive
+ else:
+ colors.append("red") # Negative
+
+ fig = go.Figure()
+
+ fig.add_trace(go.Bar(
+ x=[f"Trait {i+1}" for i in range(len(plot_df))], # Simple Labels on Axis
+ y=plot_df["correlation"],
+ text=[f"{val:.2f}" for val in plot_df["correlation"]],
+ textposition='outside', # Or auto
+ marker_color=colors,
+ hovertemplate="%{customdata}
Correlation: %{y:.2f}",
+ customdata=plot_df["trait_full"] # Full text on hover
+ ))
+
+ # Wrap text at the "|" separator for cleaner line breaks
+ def wrap_text_at_pipe(text):
+ parts = [p.strip() for p in text.split("|")]
+ return "
".join(parts)
+
+ x_labels = [wrap_text_at_pipe(t) for t in plot_df["trait_full"]]
+
+ # Update trace to use full labels
+ fig.data[0].x = x_labels
+
+ fig.update_layout(
+ title=title,
+ yaxis_title="Correlation",
+ yaxis=dict(range=[-1, 1], zeroline=True, zerolinecolor="black"),
+ xaxis=dict(tickangle=0), # Keep flat if possible
+ height=400, # Use fixed default from original
+ width=1000,
+ template="plotly_white",
+ showlegend=False
)
- # Drop Nulls for correlation calculation
- valid_data = subset.select(["score", "Voice_Scale_Score"]).drop_nulls()
-
- if valid_data.height > 1:
- # Calculate Pearson Correlation
- corr_val = valid_data.select(pl.corr("score", "Voice_Scale_Score")).item()
-
- # Trait Label for Plot (Use the provided list text, maybe truncated or wrapped later)
- trait_label = f"Trait {i+1}: {trait}"
- # Or just "Trait {i+1}" and put full text in hover or subtitle?
- # User example showed "Trait 1", "Trait 2".
- # User request said "Use the traits directly".
- # Let's use the trait text as the x-axis label, perhaps wrapped.
-
- trait_correlations.append({
- "trait_full": trait,
- "trait_short": f"Trait {i+1}",
- "correlation": corr_val if corr_val is not None else 0.0
- })
-
- # 2. Build Plot Data
- if not trait_correlations:
- # Return empty fig with title
- fig = go.Figure()
- fig.update_layout(title=f"No data for {style_color} Style")
+ self._save_plot(fig, title)
return fig
-
- plot_df = pl.DataFrame(trait_correlations)
-
- # Determine colors based on correlation sign
- colors = []
- for val in plot_df["correlation"]:
- if val >= 0:
- colors.append("green") # Positive
- else:
- colors.append("red") # Negative
-
- fig = go.Figure()
-
- fig.add_trace(go.Bar(
- x=[f"Trait {i+1}" for i in range(len(plot_df))], # Simple Labels on Axis
- y=plot_df["correlation"],
- text=[f"{val:.2f}" for val in plot_df["correlation"]],
- textposition='outside', # Or auto
- marker_color=colors,
- hovertemplate="%{customdata}
Correlation: %{y:.2f}",
- customdata=plot_df["trait_full"] # Full text on hover
- ))
-
- # 3. Add Trait Descriptions as Subtitle or Annotation?
- # Or put on X-axis? The traits are long strings "Friendly | Conversational ...".
- # User's example has "Trait 1", "Trait 2" on axis.
- # But user specifically said "Use the traits directly".
- # This might mean "Don't map choice 1->Green, choice 2->Blue dynamically, trusting indices. Instead use the text match".
- # It might ALSO mean "Show the text on the chart".
- # The example image has simple "Trait X" labels.
- # I will stick to "Trait X" on axis but add the legend/list in the title or as annotations,
- # OR better: Use the full text on X-axis but with
wrapping.
- # Given the length ("Optimistic | Benevolent | Positive | Appreciative"), wrapping is needed.
-
- # Wrap text at the "|" separator for cleaner line breaks
- def wrap_text_at_pipe(text):
- parts = [p.strip() for p in text.split("|")]
- return "
".join(parts)
-
- x_labels = [wrap_text_at_pipe(t) for t in plot_df["trait_full"]]
-
- # Update trace to use full labels
- fig.data[0].x = x_labels
-
- fig.update_layout(
- title=title,
- yaxis_title="Correlation",
- yaxis=dict(range=[-1, 1], zeroline=True, zerolinecolor="black"),
- xaxis=dict(tickangle=0), # Keep flat if possible
- height=400,
- width=1000,
- template="plotly_white",
- showlegend=False
- )
-
- _save_plot(fig, results_dir, title)
- return fig
+ def plot_speaking_style_ranking_correlation(
+ self,
+ style_color: str,
+ style_traits: list[str],
+ data: pl.LazyFrame | pl.DataFrame | None = None,
+ title: str | None = None,
+ ) -> go.Figure:
+ """
+ Plots the correlation between Speaking Style Trait Scores (1-5) and Voice Ranking Points (0-3).
+ """
+ df = self._ensure_dataframe(data)
-def plot_speaking_style_ranking_correlation(
- df: pl.DataFrame,
- style_color: str,
- style_traits: list[str],
- title: str = None,
- results_dir: str | None = None,
-) -> go.Figure:
- """
- Plots the correlation between Speaking Style Trait Scores (1-5) and Voice Ranking Points (0-3).
- Each bar represents one trait.
-
- Parameters
- ----------
- df : pl.DataFrame
- Joined dataframe containing 'Right_Anchor', 'score' (Trait Score), and 'Ranking_Points'.
- style_color : str
- The name of the style (e.g., 'Green', 'Blue') for title and coloring.
- style_traits : list[str]
- List of trait descriptions (positive side) to include in the plot.
- These should match the 'Right_Anchor' column values.
- title : str, optional
- Custom title for the plot. If None, uses default.
+ if title is None:
+ title = f"Speaking style {style_color} and voice ranking points correlations"
- Returns
- -------
- go.Figure
- """
-
- if title is None:
- title = f"Speaking style {style_color} and voice ranking points correlations"
-
- trait_correlations = []
-
- # 1. Calculate Correlations
- for i, trait in enumerate(style_traits):
- # Match against Right_Anchor which contains the positive trait description
- subset = df.filter(pl.col("Right_Anchor") == trait)
+ trait_correlations = []
- # Drop Nulls for correlation calculation
- valid_data = subset.select(["score", "Ranking_Points"]).drop_nulls()
-
- if valid_data.height > 1:
- # Calculate Pearson Correlation
- corr_val = valid_data.select(pl.corr("score", "Ranking_Points")).item()
+ # 1. Calculate Correlations
+ for i, trait in enumerate(style_traits):
+ # Match against Right_Anchor which contains the positive trait description
+ subset = df.filter(pl.col("Right_Anchor") == trait)
- trait_correlations.append({
- "trait_full": trait,
- "trait_short": f"Trait {i+1}",
- "correlation": corr_val if corr_val is not None else 0.0
- })
-
- # 2. Build Plot Data
- if not trait_correlations:
+ # Drop Nulls for correlation calculation
+ valid_data = subset.select(["score", "Ranking_Points"]).drop_nulls()
+
+ if valid_data.height > 1:
+ # Calculate Pearson Correlation
+ corr_val = valid_data.select(pl.corr("score", "Ranking_Points")).item()
+
+ trait_correlations.append({
+ "trait_full": trait,
+ "trait_short": f"Trait {i+1}",
+ "correlation": corr_val if corr_val is not None else 0.0
+ })
+
+ # 2. Build Plot Data
+ if not trait_correlations:
+ fig = go.Figure()
+ fig.update_layout(title=f"No data for {style_color} Style")
+ return fig
+
+ plot_df = pl.DataFrame(trait_correlations)
+
+ # Determine colors based on correlation sign
+ colors = []
+ for val in plot_df["correlation"]:
+ if val >= 0:
+ colors.append("green")
+ else:
+ colors.append("red")
+
fig = go.Figure()
- fig.update_layout(title=f"No data for {style_color} Style")
- return fig
- plot_df = pl.DataFrame(trait_correlations)
-
- # Determine colors based on correlation sign
- colors = []
- for val in plot_df["correlation"]:
- if val >= 0:
- colors.append("green")
- else:
- colors.append("red")
+ fig.add_trace(go.Bar(
+ x=[f"Trait {i+1}" for i in range(len(plot_df))],
+ y=plot_df["correlation"],
+ text=[f"{val:.2f}" for val in plot_df["correlation"]],
+ textposition='outside',
+ marker_color=colors,
+ hovertemplate="%{customdata}
Correlation: %{y:.2f}",
+ customdata=plot_df["trait_full"]
+ ))
+
+ # Wrap text at the "|" separator for cleaner line breaks
+ def wrap_text_at_pipe(text):
+ parts = [p.strip() for p in text.split("|")]
+ return "
".join(parts)
- fig = go.Figure()
-
- fig.add_trace(go.Bar(
- x=[f"Trait {i+1}" for i in range(len(plot_df))],
- y=plot_df["correlation"],
- text=[f"{val:.2f}" for val in plot_df["correlation"]],
- textposition='outside',
- marker_color=colors,
- hovertemplate="%{customdata}
Correlation: %{y:.2f}",
- customdata=plot_df["trait_full"]
- ))
-
- # Wrap text at the "|" separator for cleaner line breaks
- def wrap_text_at_pipe(text):
- parts = [p.strip() for p in text.split("|")]
- return "
".join(parts)
+ x_labels = [wrap_text_at_pipe(t) for t in plot_df["trait_full"]]
- x_labels = [wrap_text_at_pipe(t) for t in plot_df["trait_full"]]
-
- # Update trace to use full labels
- fig.data[0].x = x_labels
-
- fig.update_layout(
- title=title,
- yaxis_title="Correlation",
- yaxis=dict(range=[-1, 1], zeroline=True, zerolinecolor="black"),
- xaxis=dict(tickangle=0),
- height=400,
- width=1000,
- template="plotly_white",
- showlegend=False
- )
-
- _save_plot(fig, results_dir, title)
- return fig
+ # Update trace to use full labels
+ fig.data[0].x = x_labels
+
+ fig.update_layout(
+ title=title,
+ yaxis_title="Correlation",
+ yaxis=dict(range=[-1, 1], zeroline=True, zerolinecolor="black"),
+ xaxis=dict(tickangle=0),
+ height=400,
+ width=1000,
+ template="plotly_white",
+ showlegend=False
+ )
+
+ self._save_plot(fig, title)
+ return fig
diff --git a/utils.py b/utils.py
index 15161f0..440a7d6 100644
--- a/utils.py
+++ b/utils.py
@@ -4,6 +4,9 @@ import pandas as pd
from typing import Union
import json
import re
+from plots import JPMCPlotsMixin
+
+import marimo as mo
def extract_voice_label(html_str: str) -> str:
"""
@@ -54,7 +57,7 @@ def combine_exclusive_columns(df: pl.DataFrame, id_col: str = "_recordId", targe
-def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
+def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame:
"""
Calculate weighted scores for character or voice rankings.
Points system: 1st place = 3 pts, 2nd place = 2 pts, 3rd place = 1 pt.
@@ -69,6 +72,9 @@ def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
pl.DataFrame
DataFrame with columns 'Character' and 'Weighted Score', sorted by score.
"""
+ if isinstance(df, pl.LazyFrame):
+ df = df.collect()
+
scores = []
# Identify ranking columns (assume all columns except _recordId)
ranking_cols = [c for c in df.columns if c != '_recordId']
@@ -93,7 +99,7 @@ def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
return pl.DataFrame(scores).sort('Weighted Score', descending=True)
-class JPMCSurvey:
+class JPMCSurvey(JPMCPlotsMixin):
"""Class to handle JPMorgan Chase survey data."""
def __init__(self, data_path: Union[str, Path], qsf_path: Union[str, Path]):
@@ -112,6 +118,18 @@ class JPMCSurvey:
self.fig_save_dir = Path('figures') / self.data_filepath.parts[2]
if not self.fig_save_dir.exists():
self.fig_save_dir.mkdir(parents=True, exist_ok=True)
+
+ self.data_filtered = None
+ self.plot_height = 500
+ self.plot_width = 1000
+
+ # Filter values
+ self.filter_age:list = None
+ self.filter_gender:list = None
+ self.filter_consumer:list = None
+ self.filter_ethnicity:list = None
+ self.filter_income:list = None
+
def _extract_qid_descr_map(self) -> dict:
@@ -217,25 +235,32 @@ class JPMCSurvey:
- ethnicity: list
- income: list
- Returns filtered polars LazyFrame.
+ Also saves the result to self.data_filtered.
"""
+ # Apply filters
if age is not None:
+ self.filter_age = age
q = q.filter(pl.col('QID1').is_in(age))
if gender is not None:
+ self.filter_gender = gender
q = q.filter(pl.col('QID2').is_in(gender))
if consumer is not None:
+ self.filter_consumer = consumer
q = q.filter(pl.col('Consumer').is_in(consumer))
if ethnicity is not None:
+ self.filter_ethnicity = ethnicity
q = q.filter(pl.col('QID3').is_in(ethnicity))
if income is not None:
+ self.filter_income = income
q = q.filter(pl.col('QID15').is_in(income))
- return q
+ self.data_filtered = q
+ return self.data_filtered
def get_demographics(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
"""Extract columns containing the demographics.