diff --git a/02_quant_analysis.py b/02_quant_analysis.py
index 2c33ce8..edc787e 100644
--- a/02_quant_analysis.py
+++ b/02_quant_analysis.py
@@ -1,7 +1,7 @@
import marimo
__generated_with = "0.19.2"
-app = marimo.App(width="medium")
+app = marimo.App(width="full")
@app.cell
@@ -167,17 +167,19 @@ def _(S, mo):
''')
- return (filter_form,)
+ return
@app.cell
-def _(S, data_validated, filter_form, mo):
- mo.stop(filter_form.value is None, mo.md("**Please submit filter above to proceed**"))
- _d = S.filter_data(data_validated, age=filter_form.value['age'], gender=filter_form.value['gender'], income=filter_form.value['income'], ethnicity=filter_form.value['ethnicity'], consumer=filter_form.value['consumer'])
+def _(data_validated):
+ # mo.stop(filter_form.value is None, mo.md("**Please submit filter above to proceed**"))
+ # _d = S.filter_data(data_validated, age=filter_form.value['age'], gender=filter_form.value['gender'], income=filter_form.value['income'], ethnicity=filter_form.value['ethnicity'], consumer=filter_form.value['consumer'])
- # Stop execution and prevent other cells from running if no data is selected
- mo.stop(len(_d.collect()) == 0, mo.md("**No Data available for current filter combination**"))
- data = _d
+ # # Stop execution and prevent other cells from running if no data is selected
+ # mo.stop(len(_d.collect()) == 0, mo.md("**No Data available for current filter combination**"))
+ # data = _d
+
+ data = data_validated
data.collect()
return (data,)
@@ -391,28 +393,25 @@ def _(S, mo, vscales):
mo.md(f"""
### How does each voice score on a scale from 1-10?
- {mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000))}
+ {mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000, domain=[1,10]))}
""")
return
@app.cell
-def _(vscales):
- target_cols=[c for c in vscales.columns if c not in ['_recordId']]
- target_cols
- return (target_cols,)
+def _(utils, vscales):
+ _target_cols=[c for c in vscales.collect().columns if c not in ['_recordId']]
+ vscales_row_norm = utils.normalize_row_values(vscales.collect(), target_cols=_target_cols)
+ vscales_row_norm
+ return (vscales_row_norm,)
@app.cell
-def _(target_cols, utils, vscales):
- vscales_row_norm = utils.normalize_row_values(vscales.collect(), target_cols=target_cols)
- return
+def _(S, mo, vscales_row_norm):
+ mo.md(f"""
+ ### Voice scale 1-10 normalized per respondent?
-
-@app.cell
-def _(mo):
- mo.md(r"""
-
+ {mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales_row_norm, x_label='Voice', width=1000))}
""")
return
diff --git a/plots.py b/plots.py
index 8c68bd6..897b07d 100644
--- a/plots.py
+++ b/plots.py
@@ -13,6 +13,12 @@ import hashlib
class JPMCPlotsMixin:
"""Mixin class for plotting functions in JPMCSurvey."""
+ def _process_title(self, title: str) -> str | list[str]:
+ """Process title to handle
tags for Altair."""
+ if isinstance(title, str) and '
' in title:
+ return title.split('
')
+ return title
+
def _sanitize_filename(self, title: str) -> str:
"""Convert plot title to a safe filename."""
# Remove HTML tags
@@ -156,8 +162,8 @@ class JPMCPlotsMixin:
chart_spec = chart.to_dict()
existing_title = chart_spec.get('title', '')
- # Handle different title formats (string vs dict)
- if isinstance(existing_title, str):
+ # Handle different title formats (string vs dict vs list)
+ if isinstance(existing_title, (str, list)):
title_config = {
'text': existing_title,
'subtitle': lines,
@@ -260,6 +266,7 @@ class JPMCPlotsMixin:
color: str = ColorPalette.PRIMARY,
height: int | None = None,
width: int | str | None = None,
+ domain: list[float] | None = None,
) -> alt.Chart:
"""Create a bar plot showing average scores and count of non-null values for each column."""
df = self._ensure_dataframe(data)
@@ -278,11 +285,14 @@ class JPMCPlotsMixin:
# Convert to pandas for Altair (sort by average descending)
stats_df = pl.DataFrame(stats).sort('average', descending=True).to_pandas()
+
+ if domain is None:
+ domain = [stats_df['average'].min(), stats_df['average'].max()]
# Base bar chart
bars = alt.Chart(stats_df).mark_bar(color=color).encode(
x=alt.X('voice:N', title=x_label, sort='-y'),
- y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=[0, 10])),
+ y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=domain)),
tooltip=[
alt.Tooltip('voice:N', title='Voice'),
alt.Tooltip('average:Q', title='Average', format='.2f'),
@@ -303,7 +313,7 @@ class JPMCPlotsMixin:
# Combine layers
chart = (bars + text).properties(
- title=title,
+ title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -360,7 +370,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='Count')
]
).add_params(selection).properties(
- title=title,
+ title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -420,7 +430,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='Count')
]
).add_params(selection).properties(
- title=title,
+ title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -473,7 +483,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='1st Place Votes')
]
).properties(
- title=title,
+ title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -514,7 +524,7 @@ class JPMCPlotsMixin:
)
chart = (bars + text).properties(
- title=title,
+ title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -571,7 +581,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='Selections')
]
).properties(
- title=title,
+ title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -627,7 +637,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='In Top 3')
]
).properties(
- title=title,
+ title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -713,7 +723,7 @@ class JPMCPlotsMixin:
# Combine layers
chart = (bars + text).properties(
title={
- "text": title,
+ "text": self._process_title(title),
"subtitle": [trait_description, "(Numbers on bars indicate respondent count)"]
},
width=width or 800,
@@ -776,7 +786,7 @@ class JPMCPlotsMixin:
alt.Tooltip('correlation:Q', format='.2f')
]
).properties(
- title=title,
+ title=self._process_title(title),
width=width or 800,
height=height or 350
)
@@ -832,7 +842,7 @@ class JPMCPlotsMixin:
alt.Tooltip('correlation:Q', format='.2f')
]
).properties(
- title=title,
+ title=self._process_title(title),
width=width or 800,
height=height or 350
)
diff --git a/utils.py b/utils.py
index 28d3a2c..3b6fc0a 100644
--- a/utils.py
+++ b/utils.py
@@ -351,18 +351,22 @@ def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame:
def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
"""
- Normalizes values in the specified columns row-wise (Standardization: (x - mean) / std).
- Ignores null values (NaNs). Only applied if there are at least 2 non-null values in the row.
+ Normalizes values in the specified columns row-wise to 0-10 scale (Min-Max normalization).
+ Formula: ((x - min) / (max - min)) * 10
+ Ignores null values (NaNs).
"""
# Using list evaluation for row-wise stats
# We create a temporary list column containing values from all target columns
+ # Ensure columns are cast to Float64 to avoid type errors with mixed/string data
df_norm = df.with_columns(
- pl.concat_list(target_cols)
+ pl.concat_list([pl.col(c).cast(pl.Float64) for c in target_cols])
.list.eval(
- # Apply standardization: (x - mean) / std
- # std(ddof=1) is the sample standard deviation
- (pl.element() - pl.element().mean()) / pl.element().std(ddof=1)
+ # Apply Min-Max scaling to 0-10
+ (
+ (pl.element() - pl.element().min()) /
+ (pl.element().max() - pl.element().min())
+ ) * 10
)
.alias("_normalized_values")
)
@@ -377,8 +381,8 @@ def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFra
def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
"""
- Normalizes values in the specified columns globally (Standardization: (x - global_mean) / global_std).
- Computes a single mean and standard deviation across ALL values in the target_cols and applies it.
+ Normalizes values in the specified columns globally to 0-10 scale.
+ Formula: ((x - global_min) / (global_max - global_min)) * 10
Ignores null values (NaNs).
"""
# Ensure eager for scalar extraction
@@ -390,19 +394,23 @@ def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.Data
return df.lazy() if was_lazy else df
# Calculate global stats efficiently by stacking all columns
- stats = df.select(target_cols).melt().select([
- pl.col("value").mean().alias("mean"),
- pl.col("value").std().alias("std")
+ # Cast to Float64 to ensure numeric calculations
+ stats = df.select([pl.col(c).cast(pl.Float64) for c in target_cols]).melt().select([
+ pl.col("value").min().alias("min"),
+ pl.col("value").max().alias("max")
])
- global_mean = stats["mean"][0]
- global_std = stats["std"][0]
+ global_min = stats["min"][0]
+ global_max = stats["max"][0]
- if global_std is None or global_std == 0:
+ # Handle edge case where all values are same or none exist
+ if global_min is None or global_max is None or global_max == global_min:
return df.lazy() if was_lazy else df
+ global_range = global_max - global_min
+
res = df.with_columns([
- ((pl.col(col) - global_mean) / global_std).alias(col)
+ (((pl.col(col).cast(pl.Float64) - global_min) / global_range) * 10).alias(col)
for col in target_cols
])
@@ -649,10 +657,12 @@ class JPMCSurvey(JPMCPlotsMixin):
return subset, None
- def get_voice_scale_1_10(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
+ def get_voice_scale_1_10(self, q: pl.LazyFrame, drop_cols=['Voice_Scale_1_10__V46']) -> Union[pl.LazyFrame, None]:
"""Extract columns containing the Voice Scale 1-10 ratings for the Chase virtual assistant.
Returns subquery that can be chained with other polars queries.
+
+ Drops scores for V46 as it was improperly configured in the survey and thus did not show up for respondents.
"""
QIDs_map = {}
@@ -662,6 +672,12 @@ class JPMCSurvey(JPMCPlotsMixin):
# Convert "Voice 16 Scale 1-10_1" to "Scale_1_10__Voice_16"
QIDs_map[qid] = f"Voice_Scale_1_10__V{val['QName'].split()[1]}"
+ for col in drop_cols:
+ if col in QIDs_map.values():
+ # remove from QIDs_map
+ qid_to_remove = [k for k,v in QIDs_map.items() if v == col][0]
+ del QIDs_map[qid_to_remove]
+
return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), None