move plots to mixin class of JPMCSurvey to simplify file saving

This commit is contained in:
2026-01-28 14:54:36 +01:00
parent 23136b5c2e
commit 365e70b834
3 changed files with 899 additions and 1103 deletions

View File

@@ -12,9 +12,6 @@ def _():
from validation import check_progress, duration_validation from validation import check_progress, duration_validation
from utils import JPMCSurvey, combine_exclusive_columns, calculate_weighted_ranking_scores from utils import JPMCSurvey, combine_exclusive_columns, calculate_weighted_ranking_scores
from plots import plot_average_scores_with_counts, plot_top3_ranking_distribution, plot_ranking_distribution, plot_most_ranked_1, plot_weighted_ranking_score, plot_voice_selection_counts, plot_top3_selection_counts
import plots
import utils import utils
from speaking_styles import SPEAKING_STYLES from speaking_styles import SPEAKING_STYLES
@@ -27,13 +24,6 @@ def _():
duration_validation, duration_validation,
mo, mo,
pl, pl,
plot_most_ranked_1,
plot_ranking_distribution,
plot_top3_ranking_distribution,
plot_top3_selection_counts,
plot_voice_selection_counts,
plot_weighted_ranking_score,
plots,
utils, utils,
) )
@@ -47,10 +37,10 @@ def _():
@app.cell @app.cell
def _(JPMCSurvey, QSF_FILE, RESULTS_FILE): def _(JPMCSurvey, QSF_FILE, RESULTS_FILE):
survey = JPMCSurvey(RESULTS_FILE, QSF_FILE) S = JPMCSurvey(RESULTS_FILE, QSF_FILE)
data_all = survey.load_data() data_all = S.load_data()
data_all.collect() data_all.collect()
return data_all, survey return S, data_all
@app.cell @app.cell
@@ -108,18 +98,22 @@ def _(mo):
@app.cell(hide_code=True) @app.cell(hide_code=True)
def _(data_all, mo): def _(data_all, mo):
data_all_collected = data_all.collect() data_all_collected = data_all.collect()
ages = mo.ui.multiselect(options=data_all_collected["QID1"], value=data_all_collected["QID1"].unique(), label="Select Age Group(s):") age = mo.ui.multiselect(options=data_all_collected["QID1"], value=data_all_collected["QID1"].unique(), label="Select Age Group(s):")
income = mo.ui.multiselect(data_all_collected["QID15"], value=data_all_collected["QID15"], label="Select Income Group(s):") income = mo.ui.multiselect(data_all_collected["QID15"], value=data_all_collected["QID15"], label="Select Income Group(s):")
gender = mo.ui.multiselect(data_all_collected["QID2"], value=data_all_collected["QID2"], label="Select Gender(s)") gender = mo.ui.multiselect(data_all_collected["QID2"], value=data_all_collected["QID2"], label="Select Gender(s)")
ethnicity = mo.ui.multiselect(data_all_collected["QID3"], value=data_all_collected["QID3"], label="Select Ethnicities:") ethnicity = mo.ui.multiselect(data_all_collected["QID3"], value=data_all_collected["QID3"], label="Select Ethnicities:")
consumer = mo.ui.multiselect(data_all_collected["Consumer"], value=data_all_collected["Consumer"], label="Select Consumer Groups:") consumer = mo.ui.multiselect(data_all_collected["Consumer"], value=data_all_collected["Consumer"], label="Select Consumer Groups:")
return age, consumer, ethnicity, gender, income
@app.cell
def _(age, consumer, ethnicity, gender, income, mo):
mo.md(f""" mo.md(f"""
# Data Filters # Data Filters
{ages} {age}
{gender} {gender}
@@ -130,12 +124,14 @@ def _(data_all, mo):
{consumer} {consumer}
""") """)
return ages, consumer, ethnicity, gender, income
return
@app.cell @app.cell
def _(ages, consumer, data_all, ethnicity, gender, income, survey): def _(S, age, consumer, data_all, ethnicity, gender, income):
data = survey.filter_data(data_all, age=ages.value, gender=gender.value, income=income.value, ethnicity=ethnicity.value, consumer=consumer.value) data = S.filter_data(data_all, age=age.value, gender=gender.value, income=income.value, ethnicity=ethnicity.value, consumer=consumer.value)
data.collect() data.collect()
return (data,) return (data,)
@@ -159,49 +155,42 @@ def _(mo):
@app.cell @app.cell
def _(data, survey): def _(S, data):
char_rank = survey.get_character_ranking(data)[0].collect() char_rank = S.get_character_ranking(data)[0]
return (char_rank,) return (char_rank,)
@app.cell @app.cell
def _(char_rank, mo, plot_top3_ranking_distribution, survey): def _(S, char_rank, mo):
mo.md(f""" mo.md(f"""
### 1. Which character personality is ranked best? ### 1. Which character personality is ranked best?
{mo.ui.plotly(plot_top3_ranking_distribution(char_rank, x_label='Character Personality', width=1000, results_dir=survey.fig_save_dir))} {mo.ui.plotly(S.plot_top3_ranking_distribution(char_rank, x_label='Character Personality', width=1000))}
""") """)
return return
@app.cell @app.cell
def _(char_rank, mo, plot_most_ranked_1, survey): def _(S, char_rank, mo):
mo.md(f""" mo.md(f"""
### 2. Which character personality is ranked 1st the most? ### 2. Which character personality is ranked 1st the most?
{mo.ui.plotly(plot_most_ranked_1(char_rank, title="Most Popular Character<br>(Number of Times Ranked 1st)", x_label='Character Personality', width=1000, results_dir=survey.fig_save_dir))} {mo.ui.plotly(S.plot_most_ranked_1(char_rank, title="Most Popular Character<br>(Number of Times Ranked 1st)", x_label='Character Personality', width=1000))}
""") """)
return return
@app.cell @app.cell
def _( def _(S, calculate_weighted_ranking_scores, char_rank, mo):
calculate_weighted_ranking_scores,
char_rank,
mo,
plot_weighted_ranking_score,
survey,
):
char_rank_weighted = calculate_weighted_ranking_scores(char_rank) char_rank_weighted = calculate_weighted_ranking_scores(char_rank)
# plot_weighted_ranking_score(char_rank_weighted, x_label='Voice', width=1000)
mo.md(f""" mo.md(f"""
### 3. Which character personality most popular based on weighted scores? ### 3. Which character personality most popular based on weighted scores?
{mo.ui.plotly(plot_weighted_ranking_score(char_rank_weighted, title="Most Popular Character - Weighted Popularity Score<br>(1st=3pts, 2nd=2pts, 3rd=1pt)", x_label='Voice', width=1000, results_dir=survey.fig_save_dir))} {mo.ui.plotly(S.plot_weighted_ranking_score(char_rank_weighted, title="Most Popular Character - Weighted Popularity Score<br>(1st=3pts, 2nd=2pts, 3rd=1pt)", x_label='Voice', width=1000))}
""") """)
return return
@@ -215,73 +204,73 @@ def _(mo):
@app.cell @app.cell
def _(data, survey): def _(S, data):
v_18_8_3 = survey.get_18_8_3(data)[0].collect() v_18_8_3 = S.get_18_8_3(data)[0].collect()
# print(v_18_8_3.head()) # print(v_18_8_3.head())
return (v_18_8_3,) return (v_18_8_3,)
@app.cell(hide_code=True) @app.cell(hide_code=True)
def _(mo, plot_voice_selection_counts, survey, v_18_8_3): def _(S, mo, v_18_8_3):
mo.md(f""" mo.md(f"""
### Which 8 voices are chosen the most out of 18? ### Which 8 voices are chosen the most out of 18?
{mo.ui.plotly(plot_voice_selection_counts(v_18_8_3, height=500, width=1000, results_dir=survey.fig_save_dir))} {mo.ui.plotly(S.plot_voice_selection_counts(v_18_8_3, height=500, width=1000))}
""") """)
return return
@app.cell(hide_code=True) @app.cell(hide_code=True)
def _(mo, plot_top3_selection_counts, survey, v_18_8_3): def _(S, mo, v_18_8_3):
mo.md(f""" mo.md(f"""
### Which 3 voices are chosen the most out of 18? ### Which 3 voices are chosen the most out of 18?
How many times does each voice end up in the top 3? ( this is based on the survey question where participants need to choose 3 out of the earlier selected 8 voices. So how often each of the 18 stimuli ended up in participants Top 3, after they first selected 8 out of 18. How many times does each voice end up in the top 3? ( this is based on the survey question where participants need to choose 3 out of the earlier selected 8 voices. So how often each of the 18 stimuli ended up in participants Top 3, after they first selected 8 out of 18.
{mo.ui.plotly(plot_top3_selection_counts(v_18_8_3, height=500, width=1000, results_dir=survey.fig_save_dir))} {mo.ui.plotly(S.plot_top3_selection_counts(v_18_8_3, height=500, width=1000))}
""") """)
return return
@app.cell(hide_code=True) @app.cell(hide_code=True)
def _(calculate_weighted_ranking_scores, data, survey): def _(S, calculate_weighted_ranking_scores, data):
top3_voices = survey.get_top_3_voices(data)[0].collect() top3_voices = S.get_top_3_voices(data)[0]
top3_voices_weighted = calculate_weighted_ranking_scores(top3_voices) top3_voices_weighted = calculate_weighted_ranking_scores(top3_voices)
return top3_voices, top3_voices_weighted return top3_voices, top3_voices_weighted
@app.cell @app.cell
def _(mo, plot_ranking_distribution, survey, top3_voices): def _(S, mo, top3_voices):
mo.md(f""" mo.md(f"""
### Which voice is ranked best in the ranking question for top 3? ### Which voice is ranked best in the ranking question for top 3?
(not best 3 out of 8 question) (not best 3 out of 8 question)
{mo.ui.plotly(plot_ranking_distribution(top3_voices, x_label='Voice', width=1000, results_dir=survey.fig_save_dir))} {mo.ui.plotly(S.plot_ranking_distribution(top3_voices, x_label='Voice', width=1000))}
""") """)
return return
@app.cell @app.cell
def _(mo, plot_weighted_ranking_score, survey, top3_voices_weighted): def _(S, mo, top3_voices_weighted):
mo.md(f""" mo.md(f"""
### Most popular **voice** based on weighted scores? ### Most popular **voice** based on weighted scores?
- E.g. 1 point for place 3. 2 points for place 2 and 3 points for place 1. The voice with most points is ranked best. - E.g. 1 point for place 3. 2 points for place 2 and 3 points for place 1. The voice with most points is ranked best.
Distribution of the rankings for each voice: Distribution of the rankings for each voice:
{mo.ui.plotly(plot_weighted_ranking_score(top3_voices_weighted, title="Most Popular Voice - Weighted Popularity Score<br>(1st = 3pts, 2nd = 2pts, 3rd = 1pt)", height=500, width=1000, results_dir=survey.fig_save_dir))} {mo.ui.plotly(S.plot_weighted_ranking_score(top3_voices_weighted, title="Most Popular Voice - Weighted Popularity Score<br>(1st = 3pts, 2nd = 2pts, 3rd = 1pt)", height=500, width=1000))}
""") """)
return return
@app.cell @app.cell
def _(mo, plot_most_ranked_1, survey, top3_voices): def _(S, mo, top3_voices):
mo.md(f""" mo.md(f"""
### Which voice is ranked number 1 the most? ### Which voice is ranked number 1 the most?
(not always the voice with most points) (not always the voice with most points)
{mo.ui.plotly(plot_most_ranked_1(top3_voices, title="Most Popular Voice<br>(Number of Times Ranked 1st)", x_label='Voice', width=1000, results_dir=survey.fig_save_dir))} {mo.ui.plotly(S.plot_most_ranked_1(top3_voices, title="Most Popular Voice<br>(Number of Times Ranked 1st)", x_label='Voice', width=1000))}
""") """)
return return
@@ -297,9 +286,9 @@ def _(mo):
@app.cell @app.cell
def _(data, survey, utils): def _(S, data, utils):
ss_or, choice_map_or = survey.get_ss_orange_red(data) ss_or, choice_map_or = S.get_ss_orange_red(data)
ss_gb, choice_map_gb = survey.get_ss_green_blue(data) ss_gb, choice_map_gb = S.get_ss_green_blue(data)
# Combine the data # Combine the data
ss_all = ss_or.join(ss_gb, on='_recordId') ss_all = ss_or.join(ss_gb, on='_recordId')
@@ -313,7 +302,7 @@ def _(data, survey, utils):
@app.cell @app.cell
def _(mo, pl, plots, ss_long, survey): def _(S, mo, pl, ss_long):
content = """### How does each voice score for each “speaking style labeled trait”?""" content = """### How does each voice score for each “speaking style labeled trait”?"""
for i, trait in enumerate(ss_long.select("Description").unique().to_series().to_list()): for i, trait in enumerate(ss_long.select("Description").unique().to_series().to_list()):
@@ -322,7 +311,7 @@ def _(mo, pl, plots, ss_long, survey):
content += f""" content += f"""
### {i+1}) {trait.replace(":", "")} ### {i+1}) {trait.replace(":", "")}
{mo.ui.plotly(plots.plot_speaking_style_trait_scores(trait_d, title=trait.replace(":", ""), height=550, results_dir=survey.fig_save_dir))} {mo.ui.plotly(S.plot_speaking_style_trait_scores(trait_d, title=trait.replace(":", ""), height=550))}
""" """
mo.md(content) mo.md(content)
@@ -338,18 +327,18 @@ def _(mo):
@app.cell @app.cell
def _(data, survey): def _(S, data):
vscales = survey.get_voice_scale_1_10(data)[0].collect() vscales = S.get_voice_scale_1_10(data)[0]
# plot_average_scores_with_counts(vscales, x_label='Voice', width=1000) # plot_average_scores_with_counts(vscales, x_label='Voice', width=1000)
return (vscales,) return (vscales,)
@app.cell @app.cell
def _(mo, plots, survey, vscales): def _(S, mo, vscales):
mo.md(f""" mo.md(f"""
### How does each voice score on a scale from 1-10? ### How does each voice score on a scale from 1-10?
{mo.ui.plotly(plots.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000, results_dir=survey.fig_save_dir))} {mo.ui.plotly(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000))}
""") """)
return return
@@ -394,7 +383,7 @@ def _(mo):
@app.cell @app.cell
def _(choice_map, ss_all, utils, vscales): def _(choice_map, ss_all, utils, vscales):
df_style = utils.process_speaking_style_data(ss_all.collect(), choice_map) df_style = utils.process_speaking_style_data(ss_all, choice_map)
df_voice_long = utils.process_voice_scale_data(vscales) df_voice_long = utils.process_voice_scale_data(vscales)
joined_df = df_style.join(df_voice_long, on=["_recordId", "Voice"], how="inner") joined_df = df_style.join(df_voice_long, on=["_recordId", "Voice"], how="inner")
@@ -403,19 +392,18 @@ def _(choice_map, ss_all, utils, vscales):
@app.cell @app.cell
def _(SPEAKING_STYLES, joined_df, mo, plots, survey): def _(S, SPEAKING_STYLES, joined_df, mo):
_content = """### Total Results _content = """### Total Results
""" """
for style, traits in SPEAKING_STYLES.items(): for style, traits in SPEAKING_STYLES.items():
# print(f"Correlation plot for {style}...") # print(f"Correlation plot for {style}...")
fig = plots.plot_speaking_style_correlation( fig = S.plot_speaking_style_correlation(
df=joined_df, data=joined_df,
style_color=style, style_color=style,
style_traits=traits, style_traits=traits,
title=f"Correlation: Speaking Style {style} and Voice Scale 1-10", title=f"Correlation: Speaking Style {style} and Voice Scale 1-10"
results_dir=survey.fig_save_dir
) )
_content += f""" _content += f"""
#### Speaking Style **{style}**: #### Speaking Style **{style}**:
@@ -470,7 +458,7 @@ def _(mo):
@app.cell @app.cell
def _(SPEAKING_STYLES, df_style, mo, plots, survey, top3_voices, utils): def _(S, SPEAKING_STYLES, df_style, mo, top3_voices, utils):
df_ranking = utils.process_voice_ranking_data(top3_voices) df_ranking = utils.process_voice_ranking_data(top3_voices)
joined = df_style.join(df_ranking, on=['_recordId', 'Voice'], how='inner') joined = df_style.join(df_ranking, on=['_recordId', 'Voice'], how='inner')
@@ -480,7 +468,7 @@ def _(SPEAKING_STYLES, df_style, mo, plots, survey, top3_voices, utils):
""" """
for _style, _traits in SPEAKING_STYLES.items(): for _style, _traits in SPEAKING_STYLES.items():
_fig = plots.plot_speaking_style_ranking_correlation(joined, _style, _traits, results_dir=survey.fig_save_dir) _fig = S.plot_speaking_style_ranking_correlation(data=joined, style_color=_style, style_traits=_traits)
_content += f""" _content += f"""
#### Speaking Style **{_style}**: #### Speaking Style **{_style}**:

1941
plots.py

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,9 @@ import pandas as pd
from typing import Union from typing import Union
import json import json
import re import re
from plots import JPMCPlotsMixin
import marimo as mo
def extract_voice_label(html_str: str) -> str: def extract_voice_label(html_str: str) -> str:
""" """
@@ -54,7 +57,7 @@ def combine_exclusive_columns(df: pl.DataFrame, id_col: str = "_recordId", targe
def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame: def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame:
""" """
Calculate weighted scores for character or voice rankings. Calculate weighted scores for character or voice rankings.
Points system: 1st place = 3 pts, 2nd place = 2 pts, 3rd place = 1 pt. Points system: 1st place = 3 pts, 2nd place = 2 pts, 3rd place = 1 pt.
@@ -69,6 +72,9 @@ def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
pl.DataFrame pl.DataFrame
DataFrame with columns 'Character' and 'Weighted Score', sorted by score. DataFrame with columns 'Character' and 'Weighted Score', sorted by score.
""" """
if isinstance(df, pl.LazyFrame):
df = df.collect()
scores = [] scores = []
# Identify ranking columns (assume all columns except _recordId) # Identify ranking columns (assume all columns except _recordId)
ranking_cols = [c for c in df.columns if c != '_recordId'] ranking_cols = [c for c in df.columns if c != '_recordId']
@@ -93,7 +99,7 @@ def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
return pl.DataFrame(scores).sort('Weighted Score', descending=True) return pl.DataFrame(scores).sort('Weighted Score', descending=True)
class JPMCSurvey: class JPMCSurvey(JPMCPlotsMixin):
"""Class to handle JPMorgan Chase survey data.""" """Class to handle JPMorgan Chase survey data."""
def __init__(self, data_path: Union[str, Path], qsf_path: Union[str, Path]): def __init__(self, data_path: Union[str, Path], qsf_path: Union[str, Path]):
@@ -113,6 +119,18 @@ class JPMCSurvey:
if not self.fig_save_dir.exists(): if not self.fig_save_dir.exists():
self.fig_save_dir.mkdir(parents=True, exist_ok=True) self.fig_save_dir.mkdir(parents=True, exist_ok=True)
self.data_filtered = None
self.plot_height = 500
self.plot_width = 1000
# Filter values
self.filter_age:list = None
self.filter_gender:list = None
self.filter_consumer:list = None
self.filter_ethnicity:list = None
self.filter_income:list = None
def _extract_qid_descr_map(self) -> dict: def _extract_qid_descr_map(self) -> dict:
"""Extract mapping of Qualtrics ImportID to Question Description from results file.""" """Extract mapping of Qualtrics ImportID to Question Description from results file."""
@@ -217,25 +235,32 @@ class JPMCSurvey:
- ethnicity: list - ethnicity: list
- income: list - income: list
Returns filtered polars LazyFrame. Also saves the result to self.data_filtered.
""" """
# Apply filters
if age is not None: if age is not None:
self.filter_age = age
q = q.filter(pl.col('QID1').is_in(age)) q = q.filter(pl.col('QID1').is_in(age))
if gender is not None: if gender is not None:
self.filter_gender = gender
q = q.filter(pl.col('QID2').is_in(gender)) q = q.filter(pl.col('QID2').is_in(gender))
if consumer is not None: if consumer is not None:
self.filter_consumer = consumer
q = q.filter(pl.col('Consumer').is_in(consumer)) q = q.filter(pl.col('Consumer').is_in(consumer))
if ethnicity is not None: if ethnicity is not None:
self.filter_ethnicity = ethnicity
q = q.filter(pl.col('QID3').is_in(ethnicity)) q = q.filter(pl.col('QID3').is_in(ethnicity))
if income is not None: if income is not None:
self.filter_income = income
q = q.filter(pl.col('QID15').is_in(income)) q = q.filter(pl.col('QID15').is_in(income))
return q self.data_filtered = q
return self.data_filtered
def get_demographics(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]: def get_demographics(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
"""Extract columns containing the demographics. """Extract columns containing the demographics.