diff --git a/02_quant_analysis.py b/02_quant_analysis.py index 2c33ce8..edc787e 100644 --- a/02_quant_analysis.py +++ b/02_quant_analysis.py @@ -1,7 +1,7 @@ import marimo __generated_with = "0.19.2" -app = marimo.App(width="medium") +app = marimo.App(width="full") @app.cell @@ -167,17 +167,19 @@ def _(S, mo): ''') - return (filter_form,) + return @app.cell -def _(S, data_validated, filter_form, mo): - mo.stop(filter_form.value is None, mo.md("**Please submit filter above to proceed**")) - _d = S.filter_data(data_validated, age=filter_form.value['age'], gender=filter_form.value['gender'], income=filter_form.value['income'], ethnicity=filter_form.value['ethnicity'], consumer=filter_form.value['consumer']) +def _(data_validated): + # mo.stop(filter_form.value is None, mo.md("**Please submit filter above to proceed**")) + # _d = S.filter_data(data_validated, age=filter_form.value['age'], gender=filter_form.value['gender'], income=filter_form.value['income'], ethnicity=filter_form.value['ethnicity'], consumer=filter_form.value['consumer']) - # Stop execution and prevent other cells from running if no data is selected - mo.stop(len(_d.collect()) == 0, mo.md("**No Data available for current filter combination**")) - data = _d + # # Stop execution and prevent other cells from running if no data is selected + # mo.stop(len(_d.collect()) == 0, mo.md("**No Data available for current filter combination**")) + # data = _d + + data = data_validated data.collect() return (data,) @@ -391,28 +393,25 @@ def _(S, mo, vscales): mo.md(f""" ### How does each voice score on a scale from 1-10? 
- {mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000))} + {mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000, domain=[1,10]))} """) return @app.cell -def _(vscales): - target_cols=[c for c in vscales.columns if c not in ['_recordId']] - target_cols - return (target_cols,) +def _(utils, vscales): + _target_cols=[c for c in vscales.collect().columns if c not in ['_recordId']] + vscales_row_norm = utils.normalize_row_values(vscales.collect(), target_cols=_target_cols) + vscales_row_norm + return (vscales_row_norm,) @app.cell -def _(target_cols, utils, vscales): - vscales_row_norm = utils.normalize_row_values(vscales.collect(), target_cols=target_cols) - return +def _(S, mo, vscales_row_norm): + mo.md(f""" + ### Voice scale 1-10 normalized per respondent? - -@app.cell -def _(mo): - mo.md(r""" - + {mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales_row_norm, x_label='Voice', width=1000))} """) return diff --git a/plots.py b/plots.py index 8c68bd6..897b07d 100644 --- a/plots.py +++ b/plots.py @@ -13,6 +13,12 @@ import hashlib class JPMCPlotsMixin: """Mixin class for plotting functions in JPMCSurvey.""" + def _process_title(self, title: str) -> str | list[str]: + """Process title to handle
<br> tags for Altair.""" + if isinstance(title, str) and '<br>' in title: + return title.split('<br>') + return title + def _sanitize_filename(self, title: str) -> str: """Convert plot title to a safe filename.""" # Remove HTML tags @@ -156,8 +162,8 @@ class JPMCPlotsMixin: chart_spec = chart.to_dict() existing_title = chart_spec.get('title', '') - # Handle different title formats (string vs dict) - if isinstance(existing_title, str): + # Handle different title formats (string vs dict vs list) + if isinstance(existing_title, (str, list)): title_config = { 'text': existing_title, 'subtitle': lines, @@ -260,6 +266,7 @@ class JPMCPlotsMixin: color: str = ColorPalette.PRIMARY, height: int | None = None, width: int | str | None = None, + domain: list[float] | None = None, ) -> alt.Chart: """Create a bar plot showing average scores and count of non-null values for each column.""" df = self._ensure_dataframe(data) @@ -278,11 +285,14 @@ class JPMCPlotsMixin: # Convert to pandas for Altair (sort by average descending) stats_df = pl.DataFrame(stats).sort('average', descending=True).to_pandas() + + if domain is None: + domain = [stats_df['average'].min(), stats_df['average'].max()] # Base bar chart bars = alt.Chart(stats_df).mark_bar(color=color).encode( x=alt.X('voice:N', title=x_label, sort='-y'), - y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=[0, 10])), + y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=domain)), tooltip=[ alt.Tooltip('voice:N', title='Voice'), alt.Tooltip('average:Q', title='Average', format='.2f'), @@ -303,7 +313,7 @@ class JPMCPlotsMixin: # Combine layers chart = (bars + text).properties( - title=title, + title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) @@ -360,7 +370,7 @@ class JPMCPlotsMixin: alt.Tooltip('count:Q', title='Count') ] ).add_params(selection).properties( - title=title, + title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) @@ -420,7 +430,7 @@ class JPMCPlotsMixin: alt.Tooltip('count:Q', title='Count') ] 
).add_params(selection).properties( - title=title, + title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) @@ -473,7 +483,7 @@ class JPMCPlotsMixin: alt.Tooltip('count:Q', title='1st Place Votes') ] ).properties( - title=title, + title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) @@ -514,7 +524,7 @@ class JPMCPlotsMixin: ) chart = (bars + text).properties( - title=title, + title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) @@ -571,7 +581,7 @@ class JPMCPlotsMixin: alt.Tooltip('count:Q', title='Selections') ] ).properties( - title=title, + title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) @@ -627,7 +637,7 @@ class JPMCPlotsMixin: alt.Tooltip('count:Q', title='In Top 3') ] ).properties( - title=title, + title=self._process_title(title), width=width or 800, height=height or getattr(self, 'plot_height', 400) ) @@ -713,7 +723,7 @@ class JPMCPlotsMixin: # Combine layers chart = (bars + text).properties( title={ - "text": title, + "text": self._process_title(title), "subtitle": [trait_description, "(Numbers on bars indicate respondent count)"] }, width=width or 800, @@ -776,7 +786,7 @@ class JPMCPlotsMixin: alt.Tooltip('correlation:Q', format='.2f') ] ).properties( - title=title, + title=self._process_title(title), width=width or 800, height=height or 350 ) @@ -832,7 +842,7 @@ class JPMCPlotsMixin: alt.Tooltip('correlation:Q', format='.2f') ] ).properties( - title=title, + title=self._process_title(title), width=width or 800, height=height or 350 ) diff --git a/utils.py b/utils.py index 28d3a2c..3b6fc0a 100644 --- a/utils.py +++ b/utils.py @@ -351,18 +351,22 @@ def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame: def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame: """ - Normalizes values in the 
specified columns row-wise (Standardization: (x - mean) / std). - Ignores null values (NaNs). Only applied if there are at least 2 non-null values in the row. + Normalizes values in the specified columns row-wise to 0-10 scale (Min-Max normalization). + Formula: ((x - min) / (max - min)) * 10 + Ignores null values (NaNs). """ # Using list evaluation for row-wise stats # We create a temporary list column containing values from all target columns + # Ensure columns are cast to Float64 to avoid type errors with mixed/string data df_norm = df.with_columns( - pl.concat_list(target_cols) + pl.concat_list([pl.col(c).cast(pl.Float64) for c in target_cols]) .list.eval( - # Apply standardization: (x - mean) / std - # std(ddof=1) is the sample standard deviation - (pl.element() - pl.element().mean()) / pl.element().std(ddof=1) + # Apply Min-Max scaling to 0-10 + ( + (pl.element() - pl.element().min()) / + (pl.element().max() - pl.element().min()) + ) * 10 ) .alias("_normalized_values") ) @@ -377,8 +381,8 @@ def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFra def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame: """ - Normalizes values in the specified columns globally (Standardization: (x - global_mean) / global_std). - Computes a single mean and standard deviation across ALL values in the target_cols and applies it. + Normalizes values in the specified columns globally to 0-10 scale. + Formula: ((x - global_min) / (global_max - global_min)) * 10 Ignores null values (NaNs). 
""" # Ensure eager for scalar extraction @@ -390,19 +394,23 @@ def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.Data return df.lazy() if was_lazy else df # Calculate global stats efficiently by stacking all columns - stats = df.select(target_cols).melt().select([ - pl.col("value").mean().alias("mean"), - pl.col("value").std().alias("std") + # Cast to Float64 to ensure numeric calculations + stats = df.select([pl.col(c).cast(pl.Float64) for c in target_cols]).melt().select([ + pl.col("value").min().alias("min"), + pl.col("value").max().alias("max") ]) - global_mean = stats["mean"][0] - global_std = stats["std"][0] + global_min = stats["min"][0] + global_max = stats["max"][0] - if global_std is None or global_std == 0: + # Handle edge case where all values are same or none exist + if global_min is None or global_max is None or global_max == global_min: return df.lazy() if was_lazy else df + global_range = global_max - global_min + res = df.with_columns([ - ((pl.col(col) - global_mean) / global_std).alias(col) + (((pl.col(col).cast(pl.Float64) - global_min) / global_range) * 10).alias(col) for col in target_cols ]) @@ -649,10 +657,12 @@ class JPMCSurvey(JPMCPlotsMixin): return subset, None - def get_voice_scale_1_10(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]: + def get_voice_scale_1_10(self, q: pl.LazyFrame, drop_cols=['Voice_Scale_1_10__V46']) -> Union[pl.LazyFrame, None]: """Extract columns containing the Voice Scale 1-10 ratings for the Chase virtual assistant. Returns subquery that can be chained with other polars queries. + + Drops scores for V46 as it was improperly configured in the survey and thus did not show up for respondents. 
""" QIDs_map = {} @@ -662,6 +672,12 @@ class JPMCSurvey(JPMCPlotsMixin): # Convert "Voice 16 Scale 1-10_1" to "Scale_1_10__Voice_16" QIDs_map[qid] = f"Voice_Scale_1_10__V{val['QName'].split()[1]}" + for col in drop_cols: + if col in QIDs_map.values(): + # remove from QIDs_map + qid_to_remove = [k for k,v in QIDs_map.items() if v == col][0] + del QIDs_map[qid_to_remove] + return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), None