Drop Voice 46 from the 1-10 voice scales. Fix line breaks in plot titles.

This commit is contained in:
2026-01-29 21:10:56 +01:00
parent 8aee09f968
commit becc435d3c
3 changed files with 75 additions and 50 deletions

View File

@@ -1,7 +1,7 @@
import marimo import marimo
__generated_with = "0.19.2" __generated_with = "0.19.2"
app = marimo.App(width="medium") app = marimo.App(width="full")
@app.cell @app.cell
@@ -167,17 +167,19 @@ def _(S, mo):
''') ''')
return (filter_form,) return
@app.cell @app.cell
def _(S, data_validated, filter_form, mo): def _(data_validated):
mo.stop(filter_form.value is None, mo.md("**Please submit filter above to proceed**")) # mo.stop(filter_form.value is None, mo.md("**Please submit filter above to proceed**"))
_d = S.filter_data(data_validated, age=filter_form.value['age'], gender=filter_form.value['gender'], income=filter_form.value['income'], ethnicity=filter_form.value['ethnicity'], consumer=filter_form.value['consumer']) # _d = S.filter_data(data_validated, age=filter_form.value['age'], gender=filter_form.value['gender'], income=filter_form.value['income'], ethnicity=filter_form.value['ethnicity'], consumer=filter_form.value['consumer'])
# Stop execution and prevent other cells from running if no data is selected # # Stop execution and prevent other cells from running if no data is selected
mo.stop(len(_d.collect()) == 0, mo.md("**No Data available for current filter combination**")) # mo.stop(len(_d.collect()) == 0, mo.md("**No Data available for current filter combination**"))
data = _d # data = _d
data = data_validated
data.collect() data.collect()
return (data,) return (data,)
@@ -391,28 +393,25 @@ def _(S, mo, vscales):
mo.md(f""" mo.md(f"""
### How does each voice score on a scale from 1-10? ### How does each voice score on a scale from 1-10?
{mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000))} {mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000, domain=[1,10]))}
""") """)
return return
@app.cell @app.cell
def _(vscales): def _(utils, vscales):
target_cols=[c for c in vscales.columns if c not in ['_recordId']] _target_cols=[c for c in vscales.collect().columns if c not in ['_recordId']]
target_cols vscales_row_norm = utils.normalize_row_values(vscales.collect(), target_cols=_target_cols)
return (target_cols,) vscales_row_norm
return (vscales_row_norm,)
@app.cell @app.cell
def _(target_cols, utils, vscales): def _(S, mo, vscales_row_norm):
vscales_row_norm = utils.normalize_row_values(vscales.collect(), target_cols=target_cols) mo.md(f"""
return ### Voice scale 1-10 normalized per respondent?
@app.cell
def _(mo):
mo.md(r"""
{mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales_row_norm, x_label='Voice', width=1000))}
""") """)
return return

View File

@@ -13,6 +13,12 @@ import hashlib
class JPMCPlotsMixin: class JPMCPlotsMixin:
"""Mixin class for plotting functions in JPMCSurvey.""" """Mixin class for plotting functions in JPMCSurvey."""
def _process_title(self, title: str) -> str | list[str]:
"""Process title to handle <br> tags for Altair."""
if isinstance(title, str) and '<br>' in title:
return title.split('<br>')
return title
def _sanitize_filename(self, title: str) -> str: def _sanitize_filename(self, title: str) -> str:
"""Convert plot title to a safe filename.""" """Convert plot title to a safe filename."""
# Remove HTML tags # Remove HTML tags
@@ -156,8 +162,8 @@ class JPMCPlotsMixin:
chart_spec = chart.to_dict() chart_spec = chart.to_dict()
existing_title = chart_spec.get('title', '') existing_title = chart_spec.get('title', '')
# Handle different title formats (string vs dict) # Handle different title formats (string vs dict vs list)
if isinstance(existing_title, str): if isinstance(existing_title, (str, list)):
title_config = { title_config = {
'text': existing_title, 'text': existing_title,
'subtitle': lines, 'subtitle': lines,
@@ -260,6 +266,7 @@ class JPMCPlotsMixin:
color: str = ColorPalette.PRIMARY, color: str = ColorPalette.PRIMARY,
height: int | None = None, height: int | None = None,
width: int | str | None = None, width: int | str | None = None,
domain: list[float] | None = None,
) -> alt.Chart: ) -> alt.Chart:
"""Create a bar plot showing average scores and count of non-null values for each column.""" """Create a bar plot showing average scores and count of non-null values for each column."""
df = self._ensure_dataframe(data) df = self._ensure_dataframe(data)
@@ -279,10 +286,13 @@ class JPMCPlotsMixin:
# Convert to pandas for Altair (sort by average descending) # Convert to pandas for Altair (sort by average descending)
stats_df = pl.DataFrame(stats).sort('average', descending=True).to_pandas() stats_df = pl.DataFrame(stats).sort('average', descending=True).to_pandas()
if domain is None:
domain = [stats_df['average'].min(), stats_df['average'].max()]
# Base bar chart # Base bar chart
bars = alt.Chart(stats_df).mark_bar(color=color).encode( bars = alt.Chart(stats_df).mark_bar(color=color).encode(
x=alt.X('voice:N', title=x_label, sort='-y'), x=alt.X('voice:N', title=x_label, sort='-y'),
y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=[0, 10])), y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=domain)),
tooltip=[ tooltip=[
alt.Tooltip('voice:N', title='Voice'), alt.Tooltip('voice:N', title='Voice'),
alt.Tooltip('average:Q', title='Average', format='.2f'), alt.Tooltip('average:Q', title='Average', format='.2f'),
@@ -303,7 +313,7 @@ class JPMCPlotsMixin:
# Combine layers # Combine layers
chart = (bars + text).properties( chart = (bars + text).properties(
title=title, title=self._process_title(title),
width=width or 800, width=width or 800,
height=height or getattr(self, 'plot_height', 400) height=height or getattr(self, 'plot_height', 400)
) )
@@ -360,7 +370,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='Count') alt.Tooltip('count:Q', title='Count')
] ]
).add_params(selection).properties( ).add_params(selection).properties(
title=title, title=self._process_title(title),
width=width or 800, width=width or 800,
height=height or getattr(self, 'plot_height', 400) height=height or getattr(self, 'plot_height', 400)
) )
@@ -420,7 +430,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='Count') alt.Tooltip('count:Q', title='Count')
] ]
).add_params(selection).properties( ).add_params(selection).properties(
title=title, title=self._process_title(title),
width=width or 800, width=width or 800,
height=height or getattr(self, 'plot_height', 400) height=height or getattr(self, 'plot_height', 400)
) )
@@ -473,7 +483,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='1st Place Votes') alt.Tooltip('count:Q', title='1st Place Votes')
] ]
).properties( ).properties(
title=title, title=self._process_title(title),
width=width or 800, width=width or 800,
height=height or getattr(self, 'plot_height', 400) height=height or getattr(self, 'plot_height', 400)
) )
@@ -514,7 +524,7 @@ class JPMCPlotsMixin:
) )
chart = (bars + text).properties( chart = (bars + text).properties(
title=title, title=self._process_title(title),
width=width or 800, width=width or 800,
height=height or getattr(self, 'plot_height', 400) height=height or getattr(self, 'plot_height', 400)
) )
@@ -571,7 +581,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='Selections') alt.Tooltip('count:Q', title='Selections')
] ]
).properties( ).properties(
title=title, title=self._process_title(title),
width=width or 800, width=width or 800,
height=height or getattr(self, 'plot_height', 400) height=height or getattr(self, 'plot_height', 400)
) )
@@ -627,7 +637,7 @@ class JPMCPlotsMixin:
alt.Tooltip('count:Q', title='In Top 3') alt.Tooltip('count:Q', title='In Top 3')
] ]
).properties( ).properties(
title=title, title=self._process_title(title),
width=width or 800, width=width or 800,
height=height or getattr(self, 'plot_height', 400) height=height or getattr(self, 'plot_height', 400)
) )
@@ -713,7 +723,7 @@ class JPMCPlotsMixin:
# Combine layers # Combine layers
chart = (bars + text).properties( chart = (bars + text).properties(
title={ title={
"text": title, "text": self._process_title(title),
"subtitle": [trait_description, "(Numbers on bars indicate respondent count)"] "subtitle": [trait_description, "(Numbers on bars indicate respondent count)"]
}, },
width=width or 800, width=width or 800,
@@ -776,7 +786,7 @@ class JPMCPlotsMixin:
alt.Tooltip('correlation:Q', format='.2f') alt.Tooltip('correlation:Q', format='.2f')
] ]
).properties( ).properties(
title=title, title=self._process_title(title),
width=width or 800, width=width or 800,
height=height or 350 height=height or 350
) )
@@ -832,7 +842,7 @@ class JPMCPlotsMixin:
alt.Tooltip('correlation:Q', format='.2f') alt.Tooltip('correlation:Q', format='.2f')
] ]
).properties( ).properties(
title=title, title=self._process_title(title),
width=width or 800, width=width or 800,
height=height or 350 height=height or 350
) )

View File

@@ -351,18 +351,22 @@ def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame:
def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame: def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
""" """
Normalizes values in the specified columns row-wise (Standardization: (x - mean) / std). Normalizes values in the specified columns row-wise to 0-10 scale (Min-Max normalization).
Ignores null values (NaNs). Only applied if there are at least 2 non-null values in the row. Formula: ((x - min) / (max - min)) * 10
Ignores null values (NaNs).
""" """
# Using list evaluation for row-wise stats # Using list evaluation for row-wise stats
# We create a temporary list column containing values from all target columns # We create a temporary list column containing values from all target columns
# Ensure columns are cast to Float64 to avoid type errors with mixed/string data
df_norm = df.with_columns( df_norm = df.with_columns(
pl.concat_list(target_cols) pl.concat_list([pl.col(c).cast(pl.Float64) for c in target_cols])
.list.eval( .list.eval(
# Apply standardization: (x - mean) / std # Apply Min-Max scaling to 0-10
# std(ddof=1) is the sample standard deviation (
(pl.element() - pl.element().mean()) / pl.element().std(ddof=1) (pl.element() - pl.element().min()) /
(pl.element().max() - pl.element().min())
) * 10
) )
.alias("_normalized_values") .alias("_normalized_values")
) )
@@ -377,8 +381,8 @@ def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFra
def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame: def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
""" """
Normalizes values in the specified columns globally (Standardization: (x - global_mean) / global_std). Normalizes values in the specified columns globally to 0-10 scale.
Computes a single mean and standard deviation across ALL values in the target_cols and applies it. Formula: ((x - global_min) / (global_max - global_min)) * 10
Ignores null values (NaNs). Ignores null values (NaNs).
""" """
# Ensure eager for scalar extraction # Ensure eager for scalar extraction
@@ -390,19 +394,23 @@ def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.Data
return df.lazy() if was_lazy else df return df.lazy() if was_lazy else df
# Calculate global stats efficiently by stacking all columns # Calculate global stats efficiently by stacking all columns
stats = df.select(target_cols).melt().select([ # Cast to Float64 to ensure numeric calculations
pl.col("value").mean().alias("mean"), stats = df.select([pl.col(c).cast(pl.Float64) for c in target_cols]).melt().select([
pl.col("value").std().alias("std") pl.col("value").min().alias("min"),
pl.col("value").max().alias("max")
]) ])
global_mean = stats["mean"][0] global_min = stats["min"][0]
global_std = stats["std"][0] global_max = stats["max"][0]
if global_std is None or global_std == 0: # Handle edge case where all values are same or none exist
if global_min is None or global_max is None or global_max == global_min:
return df.lazy() if was_lazy else df return df.lazy() if was_lazy else df
global_range = global_max - global_min
res = df.with_columns([ res = df.with_columns([
((pl.col(col) - global_mean) / global_std).alias(col) (((pl.col(col).cast(pl.Float64) - global_min) / global_range) * 10).alias(col)
for col in target_cols for col in target_cols
]) ])
@@ -649,10 +657,12 @@ class JPMCSurvey(JPMCPlotsMixin):
return subset, None return subset, None
def get_voice_scale_1_10(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]: def get_voice_scale_1_10(self, q: pl.LazyFrame, drop_cols=['Voice_Scale_1_10__V46']) -> Union[pl.LazyFrame, None]:
"""Extract columns containing the Voice Scale 1-10 ratings for the Chase virtual assistant. """Extract columns containing the Voice Scale 1-10 ratings for the Chase virtual assistant.
Returns subquery that can be chained with other polars queries. Returns subquery that can be chained with other polars queries.
Drops scores for V46 as it was improperly configured in the survey and thus did not show up for respondents.
""" """
QIDs_map = {} QIDs_map = {}
@@ -662,6 +672,12 @@ class JPMCSurvey(JPMCPlotsMixin):
# Convert "Voice 16 Scale 1-10_1" to "Scale_1_10__Voice_16" # Convert "Voice 16 Scale 1-10_1" to "Scale_1_10__Voice_16"
QIDs_map[qid] = f"Voice_Scale_1_10__V{val['QName'].split()[1]}" QIDs_map[qid] = f"Voice_Scale_1_10__V{val['QName'].split()[1]}"
for col in drop_cols:
if col in QIDs_map.values():
# remove from QIDs_map
qid_to_remove = [k for k,v in QIDs_map.items() if v == col][0]
del QIDs_map[qid_to_remove]
return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), None return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), None