diff --git a/02_quant_analysis.py b/02_quant_analysis.py
index d0f7b1e..055bca7 100644
--- a/02_quant_analysis.py
+++ b/02_quant_analysis.py
@@ -39,7 +39,6 @@ def _():
 def _(JPMCSurvey, QSF_FILE, RESULTS_FILE):
     S = JPMCSurvey(RESULTS_FILE, QSF_FILE)
     data_all = S.load_data()
-    data_all.collect()
     return S, data_all
 
 
@@ -96,22 +95,9 @@ def _(mo):
 
 
 @app.cell(hide_code=True)
-def _(data_all, mo):
-    data_all_collected = data_all.collect()
-    age = mo.ui.multiselect(options=data_all_collected["QID1"], value=data_all_collected["QID1"].unique(), label="Select Age Group(s):")
-    income = mo.ui.multiselect(data_all_collected["QID15"], value=data_all_collected["QID15"], label="Select Income Group(s):")
-    gender = mo.ui.multiselect(data_all_collected["QID2"], value=data_all_collected["QID2"], label="Select Gender(s)")
-    ethnicity = mo.ui.multiselect(data_all_collected["QID3"], value=data_all_collected["QID3"], label="Select Ethnicities:")
-    consumer = mo.ui.multiselect(data_all_collected["Consumer"], value=data_all_collected["Consumer"], label="Select Consumer Groups:")
-    return age, consumer, ethnicity, gender, income
-
-
-@app.cell
-def _(age, consumer, ethnicity, gender, income, mo):
-
-    mo.md(f"""
-    # Data Filters
-
+def _(S, mo):
+    filter_form = mo.md('''
+    # Data Filter
 
     {age}
 
@@ -122,16 +108,40 @@ def _(age, consumer, ethnicity, gender, income, mo):
     {income}
 
     {consumer}
-
-    """)
-
-
-    return
+    '''
+    ).batch(
+        age=mo.ui.multiselect(options=S.options_age, value=S.options_age, label="Select Age Group(s):"),
+        gender=mo.ui.multiselect(options=S.options_gender, value=S.options_gender, label="Select Gender(s):"),
+        ethnicity=mo.ui.multiselect(options=S.options_ethnicity, value=S.options_ethnicity, label="Select Ethnicities:"),
+        income=mo.ui.multiselect(options=S.options_income, value=S.options_income, label="Select Income Group(s):"),
+        consumer=mo.ui.multiselect(options=S.options_consumer, value=S.options_consumer, label="Select Consumer Groups:")
+    ).form()
+    filter_form
     return (filter_form,)
 
 
 @app.cell
-def _(S, age, consumer, data_all, ethnicity, gender, income):
-    data = S.filter_data(data_all, age=age.value, gender=gender.value, income=income.value, ethnicity=ethnicity.value, consumer=consumer.value)
+def _(S, data_all, filter_form, mo):
+    # A marimo form's value is None until it is first submitted. Fall back to None for
+    # every filter (filter_data skips a filter that is None) instead of subscripting
+    # None, which would raise a TypeError when the notebook first loads.
+    _selection = filter_form.value or {}
+    _d = S.filter_data(
+        data_all,
+        age=_selection.get('age'),
+        gender=_selection.get('gender'),
+        income=_selection.get('income'),
+        ethnicity=_selection.get('ethnicity'),
+        consumer=_selection.get('consumer'),
+    )
+
+    # Collect once; the same frame serves the emptiness check and the cell output
+    _collected = _d.collect()
+
+    # Stop execution and prevent other cells from running if no data is selected
+    mo.stop(len(_collected) == 0, mo.md("**No Data available for current filter combination**"))
+    data = _d
+
+    _collected
     return (data,)
 
diff --git a/plots.py b/plots.py
index 8c12b6c..c4da523 100644
--- a/plots.py
+++ b/plots.py
@@ -24,10 +24,74 @@ class JPMCPlotsMixin:
         # Lowercase and limit length
         return clean.lower()[:100]
 
+    def _get_filter_slug(self) -> str:
+        """Generate a directory-friendly slug based on active filters.
+
+        A filter that is unset, empty, or selects every available option is
+        treated as inactive and contributes nothing to the slug.
+
+        Returns:
+            Underscore-joined ``<code>-<values>`` parts, one per active filter,
+            or ``"All_Respondents"`` when no filter is active.
+        """
+        parts = []
+
+        # Each entry: (attribute name, short_code, value, options_attr)
+        filters = [
+            ('age', 'Age', getattr(self, 'filter_age', None), 'options_age'),
+            ('gender', 'Gen', getattr(self, 'filter_gender', None), 'options_gender'),
+            ('consumer', 'Cons', getattr(self, 'filter_consumer', None), 'options_consumer'),
+            ('ethnicity', 'Eth', getattr(self, 'filter_ethnicity', None), 'options_ethnicity'),
+            ('income', 'Inc', getattr(self, 'filter_income', None), 'options_income'),
+        ]
+
+        for _, short_code, value, options_attr in filters:
+            if value is None:
+                continue
+
+            # Normalize value to a list for uniform handling. Tuples and sets are
+            # converted element-wise (wrapping them whole would turn a multi-value
+            # selection into a single bogus value); sets are sorted for a stable slug.
+            if isinstance(value, (set, frozenset)):
+                value = sorted(value, key=str)
+            elif isinstance(value, tuple):
+                value = list(value)
+            elif not isinstance(value, list):
+                value = [value]
+
+            if not value:
+                continue
+
+            # Check if all options are selected (equivalent to no filter)
+            # We compare the set of selected values to the set of all available options
+            master_list = getattr(self, options_attr, None)
+            if master_list and set(value) == set(master_list):
+                continue
+
+            if len(value) > 3:
+                # If more than 3 options selected, use count to keep slug short
+                val_str = f"{len(value)}_grps"
+            else:
+                # Join values with '+', keeping only word characters (letters, digits
+                # and underscore), hyphens and dots so the result is safe in dir names
+                val_str = "+".join(re.sub(r'[^\w\-\.]', '', str(v)) for v in value)
+
+            parts.append(f"{short_code}-{val_str}")
+
+        if not parts:
+            return "All_Respondents"
+
+        return "_".join(parts)
+
     def _save_plot(self, fig: go.Figure, title: str) -> None:
         """Save plot to PNG file if fig_save_dir is set."""
         if hasattr(self, 'fig_save_dir') and self.fig_save_dir:
             path = Path(self.fig_save_dir)
+
+            # Add filter slug subfolder
+            filter_slug = self._get_filter_slug()
+            path = path / filter_slug
+
             if not path.exists():
                 path.mkdir(parents=True, exist_ok=True)
 
diff --git a/utils.py b/utils.py
index 440a7d6..757082d 100644
--- a/utils.py
+++ b/utils.py
@@ -210,6 +210,13 @@ class JPMCSurvey(JPMCPlotsMixin):
         # Rename columns with the extracted ImportIds
         df.columns = new_columns
 
+        # Store unique values for filters (ignoring nulls) to detect "all selected" state
+        self.options_age = sorted(df['QID1'].drop_nulls().unique().to_list()) if 'QID1' in df.columns else []
+        self.options_gender = sorted(df['QID2'].drop_nulls().unique().to_list()) if 'QID2' in df.columns else []
+        self.options_consumer = sorted(df['Consumer'].drop_nulls().unique().to_list()) if 'Consumer' in df.columns else []
+        self.options_ethnicity = sorted(df['QID3'].drop_nulls().unique().to_list()) if 'QID3' in df.columns else []
+        self.options_income = sorted(df['QID15'].drop_nulls().unique().to_list()) if 'QID15' in df.columns else []
+
         return df.lazy()
 
     def _get_subset(self, q: pl.LazyFrame, QIDs, rename_cols=True, include_record_id=True) -> pl.LazyFrame:
@@ -239,24 +246,24 @@ class JPMCSurvey(JPMCPlotsMixin):
         """
 
         # Apply filters
+        self.filter_age = age
         if age is not None:
-            self.filter_age = age
             q = q.filter(pl.col('QID1').is_in(age))
 
+        self.filter_gender = gender
         if gender is not None:
-            self.filter_gender = gender
             q = q.filter(pl.col('QID2').is_in(gender))
 
+        self.filter_consumer = consumer
         if consumer is not None:
-            self.filter_consumer = consumer
             q = q.filter(pl.col('Consumer').is_in(consumer))
 
+        self.filter_ethnicity = ethnicity
         if ethnicity is not None:
-            self.filter_ethnicity = ethnicity
             q = q.filter(pl.col('QID3').is_in(ethnicity))
 
+        self.filter_income = income
         if income is not None:
-            self.filter_income = income
             q = q.filter(pl.col('QID15').is_in(income))
 
         self.data_filtered = q