move plots to mixin class of JPMCSurvey to simplify file saving
This commit is contained in:
33
utils.py
33
utils.py
@@ -4,6 +4,9 @@ import pandas as pd
|
||||
from typing import Union
|
||||
import json
|
||||
import re
|
||||
from plots import JPMCPlotsMixin
|
||||
|
||||
import marimo as mo
|
||||
|
||||
def extract_voice_label(html_str: str) -> str:
|
||||
"""
|
||||
@@ -54,7 +57,7 @@ def combine_exclusive_columns(df: pl.DataFrame, id_col: str = "_recordId", targe
|
||||
|
||||
|
||||
|
||||
def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
|
||||
def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame:
|
||||
"""
|
||||
Calculate weighted scores for character or voice rankings.
|
||||
Points system: 1st place = 3 pts, 2nd place = 2 pts, 3rd place = 1 pt.
|
||||
@@ -69,6 +72,9 @@ def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
|
||||
pl.DataFrame
|
||||
DataFrame with columns 'Character' and 'Weighted Score', sorted by score.
|
||||
"""
|
||||
if isinstance(df, pl.LazyFrame):
|
||||
df = df.collect()
|
||||
|
||||
scores = []
|
||||
# Identify ranking columns (assume all columns except _recordId)
|
||||
ranking_cols = [c for c in df.columns if c != '_recordId']
|
||||
@@ -93,7 +99,7 @@ def calculate_weighted_ranking_scores(df: pl.DataFrame) -> pl.DataFrame:
|
||||
return pl.DataFrame(scores).sort('Weighted Score', descending=True)
|
||||
|
||||
|
||||
class JPMCSurvey:
|
||||
class JPMCSurvey(JPMCPlotsMixin):
|
||||
"""Class to handle JPMorgan Chase survey data."""
|
||||
|
||||
def __init__(self, data_path: Union[str, Path], qsf_path: Union[str, Path]):
|
||||
@@ -112,6 +118,18 @@ class JPMCSurvey:
|
||||
self.fig_save_dir = Path('figures') / self.data_filepath.parts[2]
|
||||
if not self.fig_save_dir.exists():
|
||||
self.fig_save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.data_filtered = None
|
||||
self.plot_height = 500
|
||||
self.plot_width = 1000
|
||||
|
||||
# Filter values
|
||||
self.filter_age:list = None
|
||||
self.filter_gender:list = None
|
||||
self.filter_consumer:list = None
|
||||
self.filter_ethnicity:list = None
|
||||
self.filter_income:list = None
|
||||
|
||||
|
||||
|
||||
def _extract_qid_descr_map(self) -> dict:
|
||||
@@ -217,25 +235,32 @@ class JPMCSurvey:
|
||||
- ethnicity: list
|
||||
- income: list
|
||||
|
||||
Returns filtered polars LazyFrame.
|
||||
Also saves the result to self.data_filtered.
|
||||
"""
|
||||
|
||||
# Apply filters
|
||||
if age is not None:
|
||||
self.filter_age = age
|
||||
q = q.filter(pl.col('QID1').is_in(age))
|
||||
|
||||
if gender is not None:
|
||||
self.filter_gender = gender
|
||||
q = q.filter(pl.col('QID2').is_in(gender))
|
||||
|
||||
if consumer is not None:
|
||||
self.filter_consumer = consumer
|
||||
q = q.filter(pl.col('Consumer').is_in(consumer))
|
||||
|
||||
if ethnicity is not None:
|
||||
self.filter_ethnicity = ethnicity
|
||||
q = q.filter(pl.col('QID3').is_in(ethnicity))
|
||||
|
||||
if income is not None:
|
||||
self.filter_income = income
|
||||
q = q.filter(pl.col('QID15').is_in(income))
|
||||
|
||||
return q
|
||||
self.data_filtered = q
|
||||
return self.data_filtered
|
||||
|
||||
def get_demographics(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
|
||||
"""Extract columns containing the demographics.
|
||||
|
||||
Reference in New Issue
Block a user