move plots to mixin class of JPMCSurvey to simplify file saving

This commit is contained in:
2026-01-28 14:54:36 +01:00
parent 23136b5c2e
commit 365e70b834
3 changed files with 899 additions and 1103 deletions

View File

@@ -4,6 +4,9 @@ import pandas as pd
from typing import Union
import json
import re
from plots import JPMCPlotsMixin
import marimo as mo
def extract_voice_label(html_str: str) -> str:
"""
@@ -54,7 +57,7 @@ def combine_exclusive_columns(df: pl.DataFrame, id_col: str = "_recordId", targe
def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame:
"""
Calculate weighted scores for character or voice rankings.
Points system: 1st place = 3 pts, 2nd place = 2 pts, 3rd place = 1 pt.
@@ -69,6 +72,9 @@ def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
pl.DataFrame
DataFrame with columns 'Character' and 'Weighted Score', sorted by score.
"""
if isinstance(df, pl.LazyFrame):
df = df.collect()
scores = []
# Identify ranking columns (assume all columns except _recordId)
ranking_cols = [c for c in df.columns if c != '_recordId']
@@ -93,7 +99,7 @@ def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
return pl.DataFrame(scores).sort('Weighted Score', descending=True)
class JPMCSurvey:
class JPMCSurvey(JPMCPlotsMixin):
"""Class to handle JPMorgan Chase survey data."""
def __init__(self, data_path: Union[str, Path], qsf_path: Union[str, Path]):
@@ -112,6 +118,18 @@ class JPMCSurvey:
self.fig_save_dir = Path('figures') / self.data_filepath.parts[2]
if not self.fig_save_dir.exists():
self.fig_save_dir.mkdir(parents=True, exist_ok=True)
self.data_filtered = None
self.plot_height = 500
self.plot_width = 1000
# Filter values
self.filter_age:list = None
self.filter_gender:list = None
self.filter_consumer:list = None
self.filter_ethnicity:list = None
self.filter_income:list = None
def _extract_qid_descr_map(self) -> dict:
@@ -217,25 +235,32 @@ class JPMCSurvey:
- ethnicity: list
- income: list
Returns filtered polars LazyFrame.
Also saves the result to self.data_filtered.
"""
# Apply filters
if age is not None:
self.filter_age = age
q = q.filter(pl.col('QID1').is_in(age))
if gender is not None:
self.filter_gender = gender
q = q.filter(pl.col('QID2').is_in(gender))
if consumer is not None:
self.filter_consumer = consumer
q = q.filter(pl.col('Consumer').is_in(consumer))
if ethnicity is not None:
self.filter_ethnicity = ethnicity
q = q.filter(pl.col('QID3').is_in(ethnicity))
if income is not None:
self.filter_income = income
q = q.filter(pl.col('QID15').is_in(income))
return q
self.data_filtered = q
return self.data_filtered
def get_demographics(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
"""Extract columns containing the demographics.