Drop Voice 46 from the 1-10 scales. Fix line breaks in plot titles.

This commit is contained in:
2026-01-29 21:10:56 +01:00
parent 8aee09f968
commit becc435d3c
3 changed files with 75 additions and 50 deletions

View File

@@ -1,7 +1,7 @@
import marimo
__generated_with = "0.19.2"
app = marimo.App(width="medium")
app = marimo.App(width="full")
@app.cell
@@ -167,17 +167,19 @@ def _(S, mo):
''')
return (filter_form,)
return
@app.cell
def _(S, data_validated, filter_form, mo):
mo.stop(filter_form.value is None, mo.md("**Please submit filter above to proceed**"))
_d = S.filter_data(data_validated, age=filter_form.value['age'], gender=filter_form.value['gender'], income=filter_form.value['income'], ethnicity=filter_form.value['ethnicity'], consumer=filter_form.value['consumer'])
def _(data_validated):
# mo.stop(filter_form.value is None, mo.md("**Please submit filter above to proceed**"))
# _d = S.filter_data(data_validated, age=filter_form.value['age'], gender=filter_form.value['gender'], income=filter_form.value['income'], ethnicity=filter_form.value['ethnicity'], consumer=filter_form.value['consumer'])
# Stop execution and prevent other cells from running if no data is selected
mo.stop(len(_d.collect()) == 0, mo.md("**No Data available for current filter combination**"))
data = _d
# # Stop execution and prevent other cells from running if no data is selected
# mo.stop(len(_d.collect()) == 0, mo.md("**No Data available for current filter combination**"))
# data = _d
data = data_validated
data.collect()
return (data,)
@@ -391,28 +393,25 @@ def _(S, mo, vscales):
mo.md(f"""
### How does each voice score on a scale from 1-10?
{mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000))}
{mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000, domain=[1,10]))}
""")
return
@app.cell
def _(vscales):
target_cols=[c for c in vscales.columns if c not in ['_recordId']]
target_cols
return (target_cols,)
def _(utils, vscales):
_target_cols=[c for c in vscales.collect().columns if c not in ['_recordId']]
vscales_row_norm = utils.normalize_row_values(vscales.collect(), target_cols=_target_cols)
vscales_row_norm
return (vscales_row_norm,)
@app.cell
def _(target_cols, utils, vscales):
vscales_row_norm = utils.normalize_row_values(vscales.collect(), target_cols=target_cols)
return
@app.cell
def _(mo):
mo.md(r"""
def _(S, mo, vscales_row_norm):
mo.md(f"""
### Voice scale 1-10 normalized per respondent?
{mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales_row_norm, x_label='Voice', width=1000))}
""")
return

View File

@@ -13,6 +13,12 @@ import hashlib
class JPMCPlotsMixin:
"""Mixin class for plotting functions in JPMCSurvey."""
def _process_title(self, title: str) -> str | list[str]:
"""Process title to handle <br> tags for Altair."""
if isinstance(title, str) and '<br>' in title:
return title.split('<br>')
return title
def _sanitize_filename(self, title: str) -> str:
"""Convert plot title to a safe filename."""
# Remove HTML tags
@@ -156,8 +162,8 @@ class JPMCPlotsMixin:
chart_spec = chart.to_dict()
existing_title = chart_spec.get('title', '')
# Handle different title formats (string vs dict)
if isinstance(existing_title, str):
# Handle different title formats (string vs dict vs list)
if isinstance(existing_title, (str, list)):
title_config = {
'text': existing_title,
'subtitle': lines,
@@ -260,6 +266,7 @@ class JPMCPlotsMixin:
color: str = ColorPalette.PRIMARY,
height: int | None = None,
width: int | str | None = None,
domain: list[float] | None = None,
) -> alt.Chart:
"""Create a bar plot showing average scores and count of non-null values for each column."""
df = self._ensure_dataframe(data)
@@ -279,10 +286,13 @@ class JPMCPlotsMixin:
# Convert to pandas for Altair (sort by average descending)
stats_df = pl.DataFrame(stats).sort('average', descending=True).to_pandas()
if domain is None:
domain = [stats_df['average'].min(), stats_df['average'].max()]
# Base bar chart
bars = alt.Chart(stats_df).mark_bar(color=color).encode(
x=alt.X('voice:N', title=x_label, sort='-y'),
y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=[0, 10])),
y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=domain)),
tooltip=[
alt.Tooltip('voice:N', title='Voice'),
alt.Tooltip('average:Q', title='Average', format='.2f'),
@@ -303,7 +313,7 @@ class JPMCPlotsMixin:
# Combine layers
chart = (bars + text).properties(
title=title,
title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -360,7 +370,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='Count')
]
).add_params(selection).properties(
title=title,
title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -420,7 +430,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='Count')
]
).add_params(selection).properties(
title=title,
title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -473,7 +483,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='1st Place Votes')
]
).properties(
title=title,
title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -514,7 +524,7 @@ class JPMCPlotsMixin:
)
chart = (bars + text).properties(
title=title,
title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -571,7 +581,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='Selections')
]
).properties(
title=title,
title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -627,7 +637,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='In Top 3')
]
).properties(
title=title,
title=self._process_title(title),
width=width or 800,
height=height or getattr(self, 'plot_height', 400)
)
@@ -713,7 +723,7 @@ class JPMCPlotsMixin:
# Combine layers
chart = (bars + text).properties(
title={
"text": title,
"text": self._process_title(title),
"subtitle": [trait_description, "(Numbers on bars indicate respondent count)"]
},
width=width or 800,
@@ -776,7 +786,7 @@ class JPMCPlotsMixin:
alt.Tooltip('correlation:Q', format='.2f')
]
).properties(
title=title,
title=self._process_title(title),
width=width or 800,
height=height or 350
)
@@ -832,7 +842,7 @@ class JPMCPlotsMixin:
alt.Tooltip('correlation:Q', format='.2f')
]
).properties(
title=title,
title=self._process_title(title),
width=width or 800,
height=height or 350
)

View File

@@ -351,18 +351,22 @@ def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame:
def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
"""
Normalizes values in the specified columns row-wise (Standardization: (x - mean) / std).
Ignores null values (NaNs). Only applied if there are at least 2 non-null values in the row.
Normalizes values in the specified columns row-wise to 0-10 scale (Min-Max normalization).
Formula: ((x - min) / (max - min)) * 10
Ignores null values (NaNs).
"""
# Using list evaluation for row-wise stats
# We create a temporary list column containing values from all target columns
# Ensure columns are cast to Float64 to avoid type errors with mixed/string data
df_norm = df.with_columns(
pl.concat_list(target_cols)
pl.concat_list([pl.col(c).cast(pl.Float64) for c in target_cols])
.list.eval(
# Apply standardization: (x - mean) / std
# std(ddof=1) is the sample standard deviation
(pl.element() - pl.element().mean()) / pl.element().std(ddof=1)
# Apply Min-Max scaling to 0-10
(
(pl.element() - pl.element().min()) /
(pl.element().max() - pl.element().min())
) * 10
)
.alias("_normalized_values")
)
@@ -377,8 +381,8 @@ def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFra
def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
"""
Normalizes values in the specified columns globally (Standardization: (x - global_mean) / global_std).
Computes a single mean and standard deviation across ALL values in the target_cols and applies it.
Normalizes values in the specified columns globally to 0-10 scale.
Formula: ((x - global_min) / (global_max - global_min)) * 10
Ignores null values (NaNs).
"""
# Ensure eager for scalar extraction
@@ -390,19 +394,23 @@ def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.Data
return df.lazy() if was_lazy else df
# Calculate global stats efficiently by stacking all columns
stats = df.select(target_cols).melt().select([
pl.col("value").mean().alias("mean"),
pl.col("value").std().alias("std")
# Cast to Float64 to ensure numeric calculations
stats = df.select([pl.col(c).cast(pl.Float64) for c in target_cols]).melt().select([
pl.col("value").min().alias("min"),
pl.col("value").max().alias("max")
])
global_mean = stats["mean"][0]
global_std = stats["std"][0]
global_min = stats["min"][0]
global_max = stats["max"][0]
if global_std is None or global_std == 0:
# Handle edge case where all values are same or none exist
if global_min is None or global_max is None or global_max == global_min:
return df.lazy() if was_lazy else df
global_range = global_max - global_min
res = df.with_columns([
((pl.col(col) - global_mean) / global_std).alias(col)
(((pl.col(col).cast(pl.Float64) - global_min) / global_range) * 10).alias(col)
for col in target_cols
])
@@ -649,10 +657,12 @@ class JPMCSurvey(JPMCPlotsMixin):
return subset, None
def get_voice_scale_1_10(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
def get_voice_scale_1_10(self, q: pl.LazyFrame, drop_cols=['Voice_Scale_1_10__V46']) -> Union[pl.LazyFrame, None]:
"""Extract columns containing the Voice Scale 1-10 ratings for the Chase virtual assistant.
Returns subquery that can be chained with other polars queries.
Drops scores for V46 as it was improperly configured in the survey and thus did not show up for respondents.
"""
QIDs_map = {}
@@ -662,6 +672,12 @@ class JPMCSurvey(JPMCPlotsMixin):
# Convert "Voice 16 Scale 1-10_1" to "Scale_1_10__Voice_16"
QIDs_map[qid] = f"Voice_Scale_1_10__V{val['QName'].split()[1]}"
for col in drop_cols:
if col in QIDs_map.values():
# remove from QIDs_map
qid_to_remove = [k for k,v in QIDs_map.items() if v == col][0]
del QIDs_map[qid_to_remove]
return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), None