move plots to mixin class of JPMCSurvey to simplify file saving

This commit is contained in:
2026-01-28 14:54:36 +01:00
parent 23136b5c2e
commit 365e70b834
3 changed files with 899 additions and 1103 deletions

View File

@@ -12,9 +12,6 @@ def _():
from validation import check_progress, duration_validation
from utils import JPMCSurvey, combine_exclusive_columns, calculate_weighted_ranking_scores
from plots import plot_average_scores_with_counts, plot_top3_ranking_distribution, plot_ranking_distribution, plot_most_ranked_1, plot_weighted_ranking_score, plot_voice_selection_counts, plot_top3_selection_counts
import plots
import utils
from speaking_styles import SPEAKING_STYLES
@@ -27,13 +24,6 @@ def _():
duration_validation,
mo,
pl,
plot_most_ranked_1,
plot_ranking_distribution,
plot_top3_ranking_distribution,
plot_top3_selection_counts,
plot_voice_selection_counts,
plot_weighted_ranking_score,
plots,
utils,
)
@@ -47,10 +37,10 @@ def _():
@app.cell
def _(JPMCSurvey, QSF_FILE, RESULTS_FILE):
survey = JPMCSurvey(RESULTS_FILE, QSF_FILE)
data_all = survey.load_data()
S = JPMCSurvey(RESULTS_FILE, QSF_FILE)
data_all = S.load_data()
data_all.collect()
return data_all, survey
return S, data_all
@app.cell
@@ -108,18 +98,22 @@ def _(mo):
@app.cell(hide_code=True)
def _(data_all, mo):
data_all_collected = data_all.collect()
ages = mo.ui.multiselect(options=data_all_collected["QID1"], value=data_all_collected["QID1"].unique(), label="Select Age Group(s):")
age = mo.ui.multiselect(options=data_all_collected["QID1"], value=data_all_collected["QID1"].unique(), label="Select Age Group(s):")
income = mo.ui.multiselect(data_all_collected["QID15"], value=data_all_collected["QID15"], label="Select Income Group(s):")
gender = mo.ui.multiselect(data_all_collected["QID2"], value=data_all_collected["QID2"], label="Select Gender(s)")
ethnicity = mo.ui.multiselect(data_all_collected["QID3"], value=data_all_collected["QID3"], label="Select Ethnicities:")
consumer = mo.ui.multiselect(data_all_collected["Consumer"], value=data_all_collected["Consumer"], label="Select Consumer Groups:")
return age, consumer, ethnicity, gender, income
@app.cell
def _(age, consumer, ethnicity, gender, income, mo):
mo.md(f"""
# Data Filters
{ages}
{age}
{gender}
@@ -130,12 +124,14 @@ def _(data_all, mo):
{consumer}
""")
return ages, consumer, ethnicity, gender, income
return
@app.cell
def _(ages, consumer, data_all, ethnicity, gender, income, survey):
data = survey.filter_data(data_all, age=ages.value, gender=gender.value, income=income.value, ethnicity=ethnicity.value, consumer=consumer.value)
def _(S, age, consumer, data_all, ethnicity, gender, income):
data = S.filter_data(data_all, age=age.value, gender=gender.value, income=income.value, ethnicity=ethnicity.value, consumer=consumer.value)
data.collect()
return (data,)
@@ -159,49 +155,42 @@ def _(mo):
@app.cell
def _(data, survey):
char_rank = survey.get_character_ranking(data)[0].collect()
def _(S, data):
char_rank = S.get_character_ranking(data)[0]
return (char_rank,)
@app.cell
def _(char_rank, mo, plot_top3_ranking_distribution, survey):
def _(S, char_rank, mo):
mo.md(f"""
### 1. Which character personality is ranked best?
{mo.ui.plotly(plot_top3_ranking_distribution(char_rank, x_label='Character Personality', width=1000, results_dir=survey.fig_save_dir))}
{mo.ui.plotly(S.plot_top3_ranking_distribution(char_rank, x_label='Character Personality', width=1000))}
""")
return
@app.cell
def _(char_rank, mo, plot_most_ranked_1, survey):
def _(S, char_rank, mo):
mo.md(f"""
### 2. Which character personality is ranked 1st the most?
{mo.ui.plotly(plot_most_ranked_1(char_rank, title="Most Popular Character<br>(Number of Times Ranked 1st)", x_label='Character Personality', width=1000, results_dir=survey.fig_save_dir))}
{mo.ui.plotly(S.plot_most_ranked_1(char_rank, title="Most Popular Character<br>(Number of Times Ranked 1st)", x_label='Character Personality', width=1000))}
""")
return
@app.cell
def _(
calculate_weighted_ranking_scores,
char_rank,
mo,
plot_weighted_ranking_score,
survey,
):
def _(S, calculate_weighted_ranking_scores, char_rank, mo):
char_rank_weighted = calculate_weighted_ranking_scores(char_rank)
# plot_weighted_ranking_score(char_rank_weighted, x_label='Voice', width=1000)
mo.md(f"""
### 3. Which character personality most popular based on weighted scores?
{mo.ui.plotly(plot_weighted_ranking_score(char_rank_weighted, title="Most Popular Character - Weighted Popularity Score<br>(1st=3pts, 2nd=2pts, 3rd=1pt)", x_label='Voice', width=1000, results_dir=survey.fig_save_dir))}
{mo.ui.plotly(S.plot_weighted_ranking_score(char_rank_weighted, title="Most Popular Character - Weighted Popularity Score<br>(1st=3pts, 2nd=2pts, 3rd=1pt)", x_label='Voice', width=1000))}
""")
return
@@ -215,73 +204,73 @@ def _(mo):
@app.cell
def _(data, survey):
v_18_8_3 = survey.get_18_8_3(data)[0].collect()
def _(S, data):
v_18_8_3 = S.get_18_8_3(data)[0].collect()
# print(v_18_8_3.head())
return (v_18_8_3,)
@app.cell(hide_code=True)
def _(mo, plot_voice_selection_counts, survey, v_18_8_3):
def _(S, mo, v_18_8_3):
mo.md(f"""
### Which 8 voices are chosen the most out of 18?
{mo.ui.plotly(plot_voice_selection_counts(v_18_8_3, height=500, width=1000, results_dir=survey.fig_save_dir))}
{mo.ui.plotly(S.plot_voice_selection_counts(v_18_8_3, height=500, width=1000))}
""")
return
@app.cell(hide_code=True)
def _(mo, plot_top3_selection_counts, survey, v_18_8_3):
def _(S, mo, v_18_8_3):
mo.md(f"""
### Which 3 voices are chosen the most out of 18?
How many times does each voice end up in the top 3? ( this is based on the survey question where participants need to choose 3 out of the earlier selected 8 voices. So how often each of the 18 stimuli ended up in participants Top 3, after they first selected 8 out of 18.
{mo.ui.plotly(plot_top3_selection_counts(v_18_8_3, height=500, width=1000, results_dir=survey.fig_save_dir))}
{mo.ui.plotly(S.plot_top3_selection_counts(v_18_8_3, height=500, width=1000))}
""")
return
@app.cell(hide_code=True)
def _(calculate_weighted_ranking_scores, data, survey):
top3_voices = survey.get_top_3_voices(data)[0].collect()
def _(S, calculate_weighted_ranking_scores, data):
top3_voices = S.get_top_3_voices(data)[0]
top3_voices_weighted = calculate_weighted_ranking_scores(top3_voices)
return top3_voices, top3_voices_weighted
@app.cell
def _(mo, plot_ranking_distribution, survey, top3_voices):
def _(S, mo, top3_voices):
mo.md(f"""
### Which voice is ranked best in the ranking question for top 3?
(not best 3 out of 8 question)
{mo.ui.plotly(plot_ranking_distribution(top3_voices, x_label='Voice', width=1000, results_dir=survey.fig_save_dir))}
{mo.ui.plotly(S.plot_ranking_distribution(top3_voices, x_label='Voice', width=1000))}
""")
return
@app.cell
def _(mo, plot_weighted_ranking_score, survey, top3_voices_weighted):
def _(S, mo, top3_voices_weighted):
mo.md(f"""
### Most popular **voice** based on weighted scores?
- E.g. 1 point for place 3. 2 points for place 2 and 3 points for place 1. The voice with most points is ranked best.
Distribution of the rankings for each voice:
{mo.ui.plotly(plot_weighted_ranking_score(top3_voices_weighted, title="Most Popular Voice - Weighted Popularity Score<br>(1st = 3pts, 2nd = 2pts, 3rd = 1pt)", height=500, width=1000, results_dir=survey.fig_save_dir))}
{mo.ui.plotly(S.plot_weighted_ranking_score(top3_voices_weighted, title="Most Popular Voice - Weighted Popularity Score<br>(1st = 3pts, 2nd = 2pts, 3rd = 1pt)", height=500, width=1000))}
""")
return
@app.cell
def _(mo, plot_most_ranked_1, survey, top3_voices):
def _(S, mo, top3_voices):
mo.md(f"""
### Which voice is ranked number 1 the most?
(not always the voice with most points)
{mo.ui.plotly(plot_most_ranked_1(top3_voices, title="Most Popular Voice<br>(Number of Times Ranked 1st)", x_label='Voice', width=1000, results_dir=survey.fig_save_dir))}
{mo.ui.plotly(S.plot_most_ranked_1(top3_voices, title="Most Popular Voice<br>(Number of Times Ranked 1st)", x_label='Voice', width=1000))}
""")
return
@@ -297,9 +286,9 @@ def _(mo):
@app.cell
def _(data, survey, utils):
ss_or, choice_map_or = survey.get_ss_orange_red(data)
ss_gb, choice_map_gb = survey.get_ss_green_blue(data)
def _(S, data, utils):
ss_or, choice_map_or = S.get_ss_orange_red(data)
ss_gb, choice_map_gb = S.get_ss_green_blue(data)
# Combine the data
ss_all = ss_or.join(ss_gb, on='_recordId')
@@ -313,7 +302,7 @@ def _(data, survey, utils):
@app.cell
def _(mo, pl, plots, ss_long, survey):
def _(S, mo, pl, ss_long):
content = """### How does each voice score for each “speaking style labeled trait”?"""
for i, trait in enumerate(ss_long.select("Description").unique().to_series().to_list()):
@@ -322,7 +311,7 @@ def _(mo, pl, plots, ss_long, survey):
content += f"""
### {i+1}) {trait.replace(":", "")}
{mo.ui.plotly(plots.plot_speaking_style_trait_scores(trait_d, title=trait.replace(":", ""), height=550, results_dir=survey.fig_save_dir))}
{mo.ui.plotly(S.plot_speaking_style_trait_scores(trait_d, title=trait.replace(":", ""), height=550))}
"""
mo.md(content)
@@ -338,18 +327,18 @@ def _(mo):
@app.cell
def _(data, survey):
vscales = survey.get_voice_scale_1_10(data)[0].collect()
def _(S, data):
vscales = S.get_voice_scale_1_10(data)[0]
# plot_average_scores_with_counts(vscales, x_label='Voice', width=1000)
return (vscales,)
@app.cell
def _(mo, plots, survey, vscales):
def _(S, mo, vscales):
mo.md(f"""
### How does each voice score on a scale from 1-10?
{mo.ui.plotly(plots.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000, results_dir=survey.fig_save_dir))}
{mo.ui.plotly(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000))}
""")
return
@@ -394,7 +383,7 @@ def _(mo):
@app.cell
def _(choice_map, ss_all, utils, vscales):
df_style = utils.process_speaking_style_data(ss_all.collect(), choice_map)
df_style = utils.process_speaking_style_data(ss_all, choice_map)
df_voice_long = utils.process_voice_scale_data(vscales)
joined_df = df_style.join(df_voice_long, on=["_recordId", "Voice"], how="inner")
@@ -403,19 +392,18 @@ def _(choice_map, ss_all, utils, vscales):
@app.cell
def _(SPEAKING_STYLES, joined_df, mo, plots, survey):
def _(S, SPEAKING_STYLES, joined_df, mo):
_content = """### Total Results
"""
for style, traits in SPEAKING_STYLES.items():
# print(f"Correlation plot for {style}...")
fig = plots.plot_speaking_style_correlation(
df=joined_df,
fig = S.plot_speaking_style_correlation(
data=joined_df,
style_color=style,
style_traits=traits,
title=f"Correlation: Speaking Style {style} and Voice Scale 1-10",
results_dir=survey.fig_save_dir
title=f"Correlation: Speaking Style {style} and Voice Scale 1-10"
)
_content += f"""
#### Speaking Style **{style}**:
@@ -470,7 +458,7 @@ def _(mo):
@app.cell
def _(SPEAKING_STYLES, df_style, mo, plots, survey, top3_voices, utils):
def _(S, SPEAKING_STYLES, df_style, mo, top3_voices, utils):
df_ranking = utils.process_voice_ranking_data(top3_voices)
joined = df_style.join(df_ranking, on=['_recordId', 'Voice'], how='inner')
@@ -480,7 +468,7 @@ def _(SPEAKING_STYLES, df_style, mo, plots, survey, top3_voices, utils):
"""
for _style, _traits in SPEAKING_STYLES.items():
_fig = plots.plot_speaking_style_ranking_correlation(joined, _style, _traits, results_dir=survey.fig_save_dir)
_fig = S.plot_speaking_style_ranking_correlation(data=joined, style_color=_style, style_traits=_traits)
_content += f"""
#### Speaking Style **{_style}**:

467
plots.py
View File

@@ -8,7 +8,10 @@ import polars as pl
from theme import ColorPalette
def _sanitize_filename(title: str) -> str:
class JPMCPlotsMixin:
"""Mixin class for plotting functions in JPMCSurvey."""
def _sanitize_filename(self, title: str) -> str:
"""Convert plot title to a safe filename."""
# Remove HTML tags
clean = re.sub(r'<[^>]+>', ' ', title)
@@ -21,53 +24,41 @@ def _sanitize_filename(title: str) -> str:
# Lowercase and limit length
return clean.lower()[:100]
def _save_plot(fig: go.Figure, results_dir: str | None, title: str) -> None:
"""Save plot to PNG file if results_dir is provided."""
if results_dir:
path = Path(results_dir)
def _save_plot(self, fig: go.Figure, title: str) -> None:
"""Save plot to PNG file if fig_save_dir is set."""
if hasattr(self, 'fig_save_dir') and self.fig_save_dir:
path = Path(self.fig_save_dir)
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
filename = f"{_sanitize_filename(title)}.png"
filename = f"{self._sanitize_filename(title)}.png"
fig.write_image(path / filename, width=fig.layout.width, height=fig.layout.height)
def _ensure_dataframe(self, data: pl.LazyFrame | pl.DataFrame | None) -> pl.DataFrame:
    """Return ``data`` as an eager polars DataFrame.

    When ``data`` is None, the instance's ``data_filtered`` attribute is
    used instead (a missing attribute is treated as None). A LazyFrame is
    collected; any other value is returned unchanged.

    Parameters
    ----------
    data : pl.LazyFrame | pl.DataFrame | None
        Frame to normalise, or None to fall back to ``self.data_filtered``.

    Returns
    -------
    pl.DataFrame
        The collected frame, or the input itself if it was not lazy.

    Raises
    ------
    ValueError
        If ``data`` is None and no ``data_filtered`` value is available.
    """
    frame = data
    if frame is None:
        # Fall back to the filtered data stored on the instance, if any.
        frame = getattr(self, 'data_filtered', None)
        if frame is None:
            raise ValueError("No data provided and self.data_filtered is None.")

    return frame.collect() if isinstance(frame, pl.LazyFrame) else frame
def plot_average_scores_with_counts(
df: pl.DataFrame,
def plot_average_scores_with_counts(
self,
data: pl.LazyFrame | pl.DataFrame | None = None,
title: str = "General Impression (1-10)<br>Per Voice with Number of Participants Who Rated It",
x_label: str = "Stimuli",
y_label: str = "Average General Impression Rating (1-10)",
color: str = ColorPalette.PRIMARY,
height: int = 500,
width: int = 1000,
results_dir: str | None = None,
) -> go.Figure:
height: int | None = None,
width: int | None = None,
) -> go.Figure:
"""
Create a bar plot showing average scores and count of non-null values for each column.
Parameters
----------
df : pl.DataFrame
DataFrame containing numeric columns to analyze.
title : str, optional
Plot title.
x_label : str, optional
X-axis label.
y_label : str, optional
Y-axis label.
color : str, optional
Bar color (hex code or named color).
height : int, optional
Plot height in pixels.
width : int, optional
Plot width in pixels.
Returns
-------
go.Figure
Plotly figure object.
"""
# Calculate average and count of non-null values for each column
df = self._ensure_dataframe(data)
# Exclude _recordId column
stats = []
for col in [c for c in df.columns if c != '_recordId']:
@@ -102,8 +93,8 @@ def plot_average_scores_with_counts(
title=title,
xaxis_title=x_label,
yaxis_title=y_label,
height=height,
width=width,
height=height if height else getattr(self, 'plot_height', 500),
width=width if width else getattr(self, 'plot_width', 1000),
plot_bgcolor=ColorPalette.BACKGROUND,
xaxis=dict(
showgrid=True,
@@ -118,45 +109,23 @@ def plot_average_scores_with_counts(
font=dict(size=11)
)
_save_plot(fig, results_dir, title)
self._save_plot(fig, title)
return fig
def plot_top3_ranking_distribution(
df: pl.DataFrame,
def plot_top3_ranking_distribution(
self,
data: pl.LazyFrame | pl.DataFrame | None = None,
title: str = "Top 3 Rankings Distribution<br>Count of 1st, 2nd, and 3rd Place Votes per Voice",
x_label: str = "Voices",
y_label: str = "Number of Mentions in Top 3",
height: int = 500,
width: int = 1000,
results_dir: str | None = None,
) -> go.Figure:
height: int | None = None,
width: int | None = None,
) -> go.Figure:
"""
Create a stacked bar chart showing how often each voice was ranked 1st, 2nd, or 3rd.
The total height of the bar represents the popularity (frequency of being in Top 3),
while the segments show the quality of those rankings.
Parameters
----------
df : pl.DataFrame
DataFrame containing ranking columns (values 1, 2, 3).
title : str, optional
Plot title.
x_label : str, optional
X-axis label.
y_label : str, optional
Y-axis label.
height : int, optional
Plot height in pixels.
width : int, optional
Plot width in pixels.
Returns
-------
go.Figure
Plotly figure object.
"""
df = self._ensure_dataframe(data)
# Exclude _recordId column
stats = []
for col in [c for c in df.columns if c != '_recordId']:
@@ -219,8 +188,8 @@ def plot_top3_ranking_distribution(
title=title,
xaxis_title=x_label,
yaxis_title=y_label,
height=height,
width=width,
height=height if height else getattr(self, 'plot_height', 500),
width=width if width else getattr(self, 'plot_width', 1000),
plot_bgcolor=ColorPalette.BACKGROUND,
xaxis=dict(
showgrid=True,
@@ -242,43 +211,24 @@ def plot_top3_ranking_distribution(
font=dict(size=11)
)
_save_plot(fig, results_dir, title)
self._save_plot(fig, title)
return fig
def plot_ranking_distribution(
df: pl.DataFrame,
def plot_ranking_distribution(
self,
data: pl.LazyFrame | pl.DataFrame | None = None,
title: str = "Rankings Distribution<br>(1st to 4th Place)",
x_label: str = "Item",
y_label: str = "Number of Votes",
height: int = 500,
width: int = 1000,
results_dir: str | None = None,
) -> go.Figure:
height: int | None = None,
width: int | None = None,
) -> go.Figure:
"""
Create a stacked bar chart showing the distribution of rankings (1st to 4th) for characters or voices.
Sorted by the number of Rank 1 votes.
Parameters
----------
df : pl.DataFrame
DataFrame containing ranking columns.
title : str, optional
Plot title.
x_label : str, optional
X-axis label.
y_label : str, optional
Y-axis label.
height : int, optional
Plot height in pixels.
width : int, optional
Plot width in pixels.
Returns
-------
go.Figure
Plotly figure object.
"""
df = self._ensure_dataframe(data)
stats = []
# Identify ranking columns (assume all columns except _recordId)
ranking_cols = [c for c in df.columns if c != '_recordId']
@@ -359,8 +309,8 @@ def plot_ranking_distribution(
title=title,
xaxis_title=x_label,
yaxis_title=y_label,
height=height,
width=width,
height=height if height else getattr(self, 'plot_height', 500),
width=width if width else getattr(self, 'plot_width', 1000),
plot_bgcolor=ColorPalette.BACKGROUND,
xaxis=dict(
showgrid=True,
@@ -382,43 +332,24 @@ def plot_ranking_distribution(
font=dict(size=11)
)
_save_plot(fig, results_dir, title)
self._save_plot(fig, title)
return fig
def plot_most_ranked_1(
df: pl.DataFrame,
def plot_most_ranked_1(
self,
data: pl.LazyFrame | pl.DataFrame | None = None,
title: str = "Most Popular Choice<br>(Number of Times Ranked 1st)",
x_label: str = "Item",
y_label: str = "Count of 1st Place Rankings",
height: int = 500,
width: int = 1000,
results_dir: str | None = None,
) -> go.Figure:
height: int | None = None,
width: int | None = None,
) -> go.Figure:
"""
Create a bar chart showing which item (character/voice) was ranked #1 the most.
Top 3 items are highlighted.
Parameters
----------
df : pl.DataFrame
DataFrame containing ranking columns.
title : str, optional
Plot title.
x_label : str, optional
X-axis label.
y_label : str, optional
Y-axis label.
height : int, optional
Plot height in pixels.
width : int, optional
Plot width in pixels.
Returns
-------
go.Figure
Plotly figure object.
"""
df = self._ensure_dataframe(data)
stats = []
# Identify ranking columns (assume all columns except _recordId)
ranking_cols = [c for c in df.columns if c != '_recordId']
@@ -463,8 +394,8 @@ def plot_most_ranked_1(
title=title,
xaxis_title=x_label,
yaxis_title=y_label,
height=height,
width=width,
height=height if height else getattr(self, 'plot_height', 500),
width=width if width else getattr(self, 'plot_width', 1000),
plot_bgcolor=ColorPalette.BACKGROUND,
xaxis=dict(
showgrid=True,
@@ -478,46 +409,23 @@ def plot_most_ranked_1(
font=dict(size=11)
)
_save_plot(fig, results_dir, title)
self._save_plot(fig, title)
return fig
def plot_weighted_ranking_score(
weighted_df: pl.DataFrame,
def plot_weighted_ranking_score(
self,
data: pl.LazyFrame | pl.DataFrame | None = None,
title: str = "Weighted Popularity Score<br>(1st=3pts, 2nd=2pts, 3rd=1pt)",
x_label: str = "Character Personality",
y_label: str = "Total Weighted Score",
color: str = ColorPalette.PRIMARY,
height: int = 500,
width: int = 1000,
results_dir: str | None = None,
) -> go.Figure:
height: int | None = None,
width: int | None = None,
) -> go.Figure:
"""
Create a bar chart showing the weighted ranking score for each character.
Parameters
----------
df : pl.DataFrame
DataFrame containing ranking columns.
title : str, optional
Plot title.
x_label : str, optional
X-axis label.
y_label : str, optional
Y-axis label.
color : str, optional
Bar color.
height : int, optional
Plot height.
width : int, optional
Plot width.
Returns
-------
go.Figure
Plotly figure object.
"""
weighted_df = self._ensure_dataframe(data)
fig = go.Figure()
@@ -535,8 +443,8 @@ def plot_weighted_ranking_score(
title=title,
xaxis_title=x_label,
yaxis_title=y_label,
height=height,
width=width,
height=height if height else getattr(self, 'plot_height', 500),
width=width if width else getattr(self, 'plot_width', 1000),
plot_bgcolor=ColorPalette.BACKGROUND,
xaxis=dict(
showgrid=True,
@@ -550,48 +458,24 @@ def plot_weighted_ranking_score(
font=dict(size=11)
)
_save_plot(fig, results_dir, title)
self._save_plot(fig, title)
return fig
def plot_voice_selection_counts(
df: pl.DataFrame,
def plot_voice_selection_counts(
self,
data: pl.LazyFrame | pl.DataFrame | None = None,
target_column: str = "8_Combined",
title: str = "Most Frequently Chosen Voices<br>(Top 8 Highlighted)",
x_label: str = "Voice",
y_label: str = "Number of Times Chosen",
height: int = 500,
width: int = 1000,
results_dir: str | None = None,
) -> go.Figure:
height: int | None = None,
width: int | None = None,
) -> go.Figure:
"""
Create a bar plot showing the frequency of voice selections.
Takes a column containing comma-separated values (e.g. "Voice 1, Voice 2..."),
counts occurrences, and highlights the top 8 most frequent voices.
Parameters
----------
df : pl.DataFrame
DataFrame containing the selection column.
target_column : str, optional
Name of the column containing comma-separated voice selections.
Defaults to "8_Combined".
title : str, optional
Plot title.
x_label : str, optional
X-axis label.
y_label : str, optional
Y-axis label.
height : int, optional
Plot height in pixels.
width : int, optional
Plot width in pixels.
Returns
-------
go.Figure
Plotly figure object.
"""
df = self._ensure_dataframe(data)
if target_column not in df.columns:
return go.Figure()
@@ -634,8 +518,8 @@ def plot_voice_selection_counts(
title=title,
xaxis_title=x_label,
yaxis_title=y_label,
height=height,
width=width,
height=height if height else getattr(self, 'plot_height', 500),
width=width if width else getattr(self, 'plot_width', 1000),
plot_bgcolor=ColorPalette.BACKGROUND,
xaxis=dict(
showgrid=True,
@@ -649,51 +533,24 @@ def plot_voice_selection_counts(
font=dict(size=11),
)
_save_plot(fig, results_dir, title)
self._save_plot(fig, title)
return fig
def plot_top3_selection_counts(
df: pl.DataFrame,
def plot_top3_selection_counts(
self,
data: pl.LazyFrame | pl.DataFrame | None = None,
target_column: str = "3_Ranked",
title: str = "Most Frequently Chosen Top 3 Voices<br>(Top 3 Highlighted)",
x_label: str = "Voice",
y_label: str = "Count of Mentions in Top 3",
height: int = 500,
width: int = 1000,
results_dir: str | None = None,
) -> go.Figure:
height: int | None = None,
width: int | None = None,
) -> go.Figure:
"""
Question: Which 3 voices are chosen the most out of 18?
How many times does each voice end up in the top 3?
(this is based on the survey question where participants need to choose 3 out
of the earlier selected 8 voices). So how often each of the 18 stimuli ended
up in participants' Top 3, after they first selected 8 out of 18.
Parameters
----------
df : pl.DataFrame
DataFrame containing the ranking column (comma-separated strings).
target_column : str, optional
Name of the column containing comma-separated Top 3 voice elections.
Defaults to "3_Ranked".
title : str, optional
Plot title.
x_label : str, optional
X-axis label.
y_label : str, optional
Y-axis label.
height : int, optional
Plot height in pixels.
width : int, optional
Plot width in pixels.
Returns
-------
go.Figure
Plotly figure object.
"""
df = self._ensure_dataframe(data)
if target_column not in df.columns:
return go.Figure()
@@ -732,8 +589,8 @@ def plot_top3_selection_counts(
title=title,
xaxis_title=x_label,
yaxis_title=y_label,
height=height,
width=width,
height=height if height else getattr(self, 'plot_height', 500),
width=width if width else getattr(self, 'plot_width', 1000),
plot_bgcolor=ColorPalette.BACKGROUND,
xaxis=dict(
showgrid=True,
@@ -747,53 +604,24 @@ def plot_top3_selection_counts(
font=dict(size=11),
)
_save_plot(fig, results_dir, title)
self._save_plot(fig, title)
return fig
def plot_speaking_style_trait_scores(
df: pl.DataFrame,
def plot_speaking_style_trait_scores(
self,
data: pl.LazyFrame | pl.DataFrame | None = None,
trait_description: str = None,
left_anchor: str = None,
right_anchor: str = None,
title: str = "Speaking Style Trait Analysis",
height: int = 500,
width: int = 1000,
results_dir: str | None = None,
) -> go.Figure:
height: int | None = None,
width: int | None = None,
) -> go.Figure:
"""
Plot scores for a single speaking style trait across multiple voices.
The plot shows the average score per Voice, sorted by score.
It expects the DataFrame to contain 'Voice' and 'score' columns,
typically filtered for a single trait/description.
Parameters
----------
df : pl.DataFrame
DataFrame containing at least 'Voice' and 'score' columns.
Produced by utils.process_speaking_style_data and filtered.
trait_description : str, optional
Description of the trait being analyzed (e.g. "Indifferent : Attentive").
If not provided, it will be constructed from annotations.
left_anchor : str, optional
Label for the lower end of the scale (e.g. "Indifferent").
If not provided, attempts to read 'Left_Anchor' column from df.
right_anchor : str, optional
Label for the upper end of the scale (e.g. "Attentive").
If not provided, attempts to read 'Right_Anchor' column from df.
title : str, optional
Plot title.
height : int, optional
Plot height.
width : int, optional
Plot width.
Returns
-------
go.Figure
Plotly figure object.
"""
df = self._ensure_dataframe(data)
if df.is_empty():
return go.Figure()
@@ -878,8 +706,8 @@ def plot_speaking_style_trait_scores(
),
xaxis_title="Average Score (1-5)",
yaxis_title="Voice",
height=height,
width=width,
height=height if height else getattr(self, 'plot_height', 500),
width=width if width else getattr(self, 'plot_width', 1000),
plot_bgcolor=ColorPalette.BACKGROUND,
xaxis=dict(
range=[1, 5],
@@ -894,34 +722,23 @@ def plot_speaking_style_trait_scores(
annotations=annotations,
font=dict(size=11)
)
_save_plot(fig, results_dir, title)
self._save_plot(fig, title)
return fig
def plot_speaking_style_correlation(
df: pl.DataFrame,
def plot_speaking_style_correlation(
self,
style_color: str,
style_traits: list[str],
title=f"Speaking style and voice scale 1-10 correlations",
results_dir: str | None = None,
) -> go.Figure:
data: pl.LazyFrame | pl.DataFrame | None = None,
title: str | None = None,
) -> go.Figure:
"""
Plots the correlation between Speaking Style Trait Scores (1-5) and Voice Scale (1-10) using a Bar Chart.
Each bar represents one trait.
Parameters
----------
df : pl.DataFrame
Joined dataframe containing 'Right_Anchor', 'score' (Trait Score), and 'Voice_Scale_Score'.
style_color : str
The name of the style (e.g., 'Green', 'Blue') for title and coloring.
style_traits : list[str]
List of trait descriptions (positive side) to include in the plot.
These should match the 'Right_Anchor' column values.
Returns
-------
go.Figure
"""
df = self._ensure_dataframe(data)
if title is None:
title = f"Speaking style and voice scale 1-10 correlations"
trait_correlations = []
@@ -940,13 +757,7 @@ def plot_speaking_style_correlation(
# Calculate Pearson Correlation
corr_val = valid_data.select(pl.corr("score", "Voice_Scale_Score")).item()
# Trait Label for Plot (Use the provided list text, maybe truncated or wrapped later)
trait_label = f"Trait {i+1}: {trait}"
# Or just "Trait {i+1}" and put full text in hover or subtitle?
# User example showed "Trait 1", "Trait 2".
# User request said "Use the traits directly".
# Let's use the trait text as the x-axis label, perhaps wrapped.
# Trait Label for Plot
trait_correlations.append({
"trait_full": trait,
"trait_short": f"Trait {i+1}",
@@ -982,17 +793,6 @@ def plot_speaking_style_correlation(
customdata=plot_df["trait_full"] # Full text on hover
))
# 3. Add Trait Descriptions as Subtitle or Annotation?
# Or put on X-axis? The traits are long strings "Friendly | Conversational ...".
# User's example has "Trait 1", "Trait 2" on axis.
# But user specifically said "Use the traits directly".
# This might mean "Don't map choice 1->Green, choice 2->Blue dynamically, trusting indices. Instead use the text match".
# It might ALSO mean "Show the text on the chart".
# The example image has simple "Trait X" labels.
# I will stick to "Trait X" on axis but add the legend/list in the title or as annotations,
# OR better: Use the full text on X-axis but with <br> wrapping.
# Given the length ("Optimistic | Benevolent | Positive | Appreciative"), wrapping is needed.
# Wrap text at the "|" separator for cleaner line breaks
def wrap_text_at_pipe(text):
parts = [p.strip() for p in text.split("|")]
@@ -1008,43 +808,26 @@ def plot_speaking_style_correlation(
yaxis_title="Correlation",
yaxis=dict(range=[-1, 1], zeroline=True, zerolinecolor="black"),
xaxis=dict(tickangle=0), # Keep flat if possible
height=400,
height=400, # Use fixed default from original
width=1000,
template="plotly_white",
showlegend=False
)
_save_plot(fig, results_dir, title)
self._save_plot(fig, title)
return fig
def plot_speaking_style_ranking_correlation(
df: pl.DataFrame,
def plot_speaking_style_ranking_correlation(
self,
style_color: str,
style_traits: list[str],
title: str = None,
results_dir: str | None = None,
) -> go.Figure:
data: pl.LazyFrame | pl.DataFrame | None = None,
title: str | None = None,
) -> go.Figure:
"""
Plots the correlation between Speaking Style Trait Scores (1-5) and Voice Ranking Points (0-3).
Each bar represents one trait.
Parameters
----------
df : pl.DataFrame
Joined dataframe containing 'Right_Anchor', 'score' (Trait Score), and 'Ranking_Points'.
style_color : str
The name of the style (e.g., 'Green', 'Blue') for title and coloring.
style_traits : list[str]
List of trait descriptions (positive side) to include in the plot.
These should match the 'Right_Anchor' column values.
title : str, optional
Custom title for the plot. If None, uses default.
Returns
-------
go.Figure
"""
df = self._ensure_dataframe(data)
if title is None:
title = f"Speaking style {style_color} and voice ranking points correlations"
@@ -1118,5 +901,5 @@ def plot_speaking_style_ranking_correlation(
showlegend=False
)
_save_plot(fig, results_dir, title)
self._save_plot(fig, title)
return fig

View File

@@ -4,6 +4,9 @@ import pandas as pd
from typing import Union
import json
import re
from plots import JPMCPlotsMixin
import marimo as mo
def extract_voice_label(html_str: str) -> str:
"""
@@ -54,7 +57,7 @@ def combine_exclusive_columns(df: pl.DataFrame, id_col: str = "_recordId", targe
def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame:
"""
Calculate weighted scores for character or voice rankings.
Points system: 1st place = 3 pts, 2nd place = 2 pts, 3rd place = 1 pt.
@@ -69,6 +72,9 @@ def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
pl.DataFrame
DataFrame with columns 'Character' and 'Weighted Score', sorted by score.
"""
if isinstance(df, pl.LazyFrame):
df = df.collect()
scores = []
# Identify ranking columns (assume all columns except _recordId)
ranking_cols = [c for c in df.columns if c != '_recordId']
@@ -93,7 +99,7 @@ def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
return pl.DataFrame(scores).sort('Weighted Score', descending=True)
class JPMCSurvey:
class JPMCSurvey(JPMCPlotsMixin):
"""Class to handle JPMorgan Chase survey data."""
def __init__(self, data_path: Union[str, Path], qsf_path: Union[str, Path]):
@@ -113,6 +119,18 @@ class JPMCSurvey:
if not self.fig_save_dir.exists():
self.fig_save_dir.mkdir(parents=True, exist_ok=True)
self.data_filtered = None
self.plot_height = 500
self.plot_width = 1000
# Filter values
self.filter_age:list = None
self.filter_gender:list = None
self.filter_consumer:list = None
self.filter_ethnicity:list = None
self.filter_income:list = None
def _extract_qid_descr_map(self) -> dict:
"""Extract mapping of Qualtrics ImportID to Question Description from results file."""
@@ -217,25 +235,32 @@ class JPMCSurvey:
- ethnicity: list
- income: list
Returns filtered polars LazyFrame.
Also saves the result to self.data_filtered.
"""
# Apply filters
if age is not None:
self.filter_age = age
q = q.filter(pl.col('QID1').is_in(age))
if gender is not None:
self.filter_gender = gender
q = q.filter(pl.col('QID2').is_in(gender))
if consumer is not None:
self.filter_consumer = consumer
q = q.filter(pl.col('Consumer').is_in(consumer))
if ethnicity is not None:
self.filter_ethnicity = ethnicity
q = q.filter(pl.col('QID3').is_in(ethnicity))
if income is not None:
self.filter_income = income
q = q.filter(pl.col('QID15').is_in(income))
return q
self.data_filtered = q
return self.data_filtered
def get_demographics(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
"""Extract columns containing the demographics.