drop voice46 from scales 1-10. fix plots breakline in title
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import marimo
|
import marimo
|
||||||
|
|
||||||
__generated_with = "0.19.2"
|
__generated_with = "0.19.2"
|
||||||
app = marimo.App(width="medium")
|
app = marimo.App(width="full")
|
||||||
|
|
||||||
|
|
||||||
@app.cell
|
@app.cell
|
||||||
@@ -167,17 +167,19 @@ def _(S, mo):
|
|||||||
''')
|
''')
|
||||||
|
|
||||||
|
|
||||||
return (filter_form,)
|
return
|
||||||
|
|
||||||
|
|
||||||
@app.cell
|
@app.cell
|
||||||
def _(S, data_validated, filter_form, mo):
|
def _(data_validated):
|
||||||
mo.stop(filter_form.value is None, mo.md("**Please submit filter above to proceed**"))
|
# mo.stop(filter_form.value is None, mo.md("**Please submit filter above to proceed**"))
|
||||||
_d = S.filter_data(data_validated, age=filter_form.value['age'], gender=filter_form.value['gender'], income=filter_form.value['income'], ethnicity=filter_form.value['ethnicity'], consumer=filter_form.value['consumer'])
|
# _d = S.filter_data(data_validated, age=filter_form.value['age'], gender=filter_form.value['gender'], income=filter_form.value['income'], ethnicity=filter_form.value['ethnicity'], consumer=filter_form.value['consumer'])
|
||||||
|
|
||||||
# Stop execution and prevent other cells from running if no data is selected
|
# # Stop execution and prevent other cells from running if no data is selected
|
||||||
mo.stop(len(_d.collect()) == 0, mo.md("**No Data available for current filter combination**"))
|
# mo.stop(len(_d.collect()) == 0, mo.md("**No Data available for current filter combination**"))
|
||||||
data = _d
|
# data = _d
|
||||||
|
|
||||||
|
data = data_validated
|
||||||
|
|
||||||
data.collect()
|
data.collect()
|
||||||
return (data,)
|
return (data,)
|
||||||
@@ -391,28 +393,25 @@ def _(S, mo, vscales):
|
|||||||
mo.md(f"""
|
mo.md(f"""
|
||||||
### How does each voice score on a scale from 1-10?
|
### How does each voice score on a scale from 1-10?
|
||||||
|
|
||||||
{mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000))}
|
{mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales, x_label='Voice', width=1000, domain=[1,10]))}
|
||||||
""")
|
""")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
@app.cell
|
@app.cell
|
||||||
def _(vscales):
|
def _(utils, vscales):
|
||||||
target_cols=[c for c in vscales.columns if c not in ['_recordId']]
|
_target_cols=[c for c in vscales.collect().columns if c not in ['_recordId']]
|
||||||
target_cols
|
vscales_row_norm = utils.normalize_row_values(vscales.collect(), target_cols=_target_cols)
|
||||||
return (target_cols,)
|
vscales_row_norm
|
||||||
|
return (vscales_row_norm,)
|
||||||
|
|
||||||
|
|
||||||
@app.cell
|
@app.cell
|
||||||
def _(target_cols, utils, vscales):
|
def _(S, mo, vscales_row_norm):
|
||||||
vscales_row_norm = utils.normalize_row_values(vscales.collect(), target_cols=target_cols)
|
mo.md(f"""
|
||||||
return
|
### Voice scale 1-10 normalized per respondent?
|
||||||
|
|
||||||
|
|
||||||
@app.cell
|
|
||||||
def _(mo):
|
|
||||||
mo.md(r"""
|
|
||||||
|
|
||||||
|
{mo.ui.altair_chart(S.plot_average_scores_with_counts(vscales_row_norm, x_label='Voice', width=1000))}
|
||||||
""")
|
""")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
36
plots.py
36
plots.py
@@ -13,6 +13,12 @@ import hashlib
|
|||||||
class JPMCPlotsMixin:
|
class JPMCPlotsMixin:
|
||||||
"""Mixin class for plotting functions in JPMCSurvey."""
|
"""Mixin class for plotting functions in JPMCSurvey."""
|
||||||
|
|
||||||
|
def _process_title(self, title: str) -> str | list[str]:
|
||||||
|
"""Process title to handle <br> tags for Altair."""
|
||||||
|
if isinstance(title, str) and '<br>' in title:
|
||||||
|
return title.split('<br>')
|
||||||
|
return title
|
||||||
|
|
||||||
def _sanitize_filename(self, title: str) -> str:
|
def _sanitize_filename(self, title: str) -> str:
|
||||||
"""Convert plot title to a safe filename."""
|
"""Convert plot title to a safe filename."""
|
||||||
# Remove HTML tags
|
# Remove HTML tags
|
||||||
@@ -156,8 +162,8 @@ class JPMCPlotsMixin:
|
|||||||
chart_spec = chart.to_dict()
|
chart_spec = chart.to_dict()
|
||||||
existing_title = chart_spec.get('title', '')
|
existing_title = chart_spec.get('title', '')
|
||||||
|
|
||||||
# Handle different title formats (string vs dict)
|
# Handle different title formats (string vs dict vs list)
|
||||||
if isinstance(existing_title, str):
|
if isinstance(existing_title, (str, list)):
|
||||||
title_config = {
|
title_config = {
|
||||||
'text': existing_title,
|
'text': existing_title,
|
||||||
'subtitle': lines,
|
'subtitle': lines,
|
||||||
@@ -260,6 +266,7 @@ class JPMCPlotsMixin:
|
|||||||
color: str = ColorPalette.PRIMARY,
|
color: str = ColorPalette.PRIMARY,
|
||||||
height: int | None = None,
|
height: int | None = None,
|
||||||
width: int | str | None = None,
|
width: int | str | None = None,
|
||||||
|
domain: list[float] | None = None,
|
||||||
) -> alt.Chart:
|
) -> alt.Chart:
|
||||||
"""Create a bar plot showing average scores and count of non-null values for each column."""
|
"""Create a bar plot showing average scores and count of non-null values for each column."""
|
||||||
df = self._ensure_dataframe(data)
|
df = self._ensure_dataframe(data)
|
||||||
@@ -279,10 +286,13 @@ class JPMCPlotsMixin:
|
|||||||
# Convert to pandas for Altair (sort by average descending)
|
# Convert to pandas for Altair (sort by average descending)
|
||||||
stats_df = pl.DataFrame(stats).sort('average', descending=True).to_pandas()
|
stats_df = pl.DataFrame(stats).sort('average', descending=True).to_pandas()
|
||||||
|
|
||||||
|
if domain is None:
|
||||||
|
domain = [stats_df['average'].min(), stats_df['average'].max()]
|
||||||
|
|
||||||
# Base bar chart
|
# Base bar chart
|
||||||
bars = alt.Chart(stats_df).mark_bar(color=color).encode(
|
bars = alt.Chart(stats_df).mark_bar(color=color).encode(
|
||||||
x=alt.X('voice:N', title=x_label, sort='-y'),
|
x=alt.X('voice:N', title=x_label, sort='-y'),
|
||||||
y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=[0, 10])),
|
y=alt.Y('average:Q', title=y_label, scale=alt.Scale(domain=domain)),
|
||||||
tooltip=[
|
tooltip=[
|
||||||
alt.Tooltip('voice:N', title='Voice'),
|
alt.Tooltip('voice:N', title='Voice'),
|
||||||
alt.Tooltip('average:Q', title='Average', format='.2f'),
|
alt.Tooltip('average:Q', title='Average', format='.2f'),
|
||||||
@@ -303,7 +313,7 @@ class JPMCPlotsMixin:
|
|||||||
|
|
||||||
# Combine layers
|
# Combine layers
|
||||||
chart = (bars + text).properties(
|
chart = (bars + text).properties(
|
||||||
title=title,
|
title=self._process_title(title),
|
||||||
width=width or 800,
|
width=width or 800,
|
||||||
height=height or getattr(self, 'plot_height', 400)
|
height=height or getattr(self, 'plot_height', 400)
|
||||||
)
|
)
|
||||||
@@ -360,7 +370,7 @@ class JPMCPlotsMixin:
|
|||||||
alt.Tooltip('count:Q', title='Count')
|
alt.Tooltip('count:Q', title='Count')
|
||||||
]
|
]
|
||||||
).add_params(selection).properties(
|
).add_params(selection).properties(
|
||||||
title=title,
|
title=self._process_title(title),
|
||||||
width=width or 800,
|
width=width or 800,
|
||||||
height=height or getattr(self, 'plot_height', 400)
|
height=height or getattr(self, 'plot_height', 400)
|
||||||
)
|
)
|
||||||
@@ -420,7 +430,7 @@ class JPMCPlotsMixin:
|
|||||||
alt.Tooltip('count:Q', title='Count')
|
alt.Tooltip('count:Q', title='Count')
|
||||||
]
|
]
|
||||||
).add_params(selection).properties(
|
).add_params(selection).properties(
|
||||||
title=title,
|
title=self._process_title(title),
|
||||||
width=width or 800,
|
width=width or 800,
|
||||||
height=height or getattr(self, 'plot_height', 400)
|
height=height or getattr(self, 'plot_height', 400)
|
||||||
)
|
)
|
||||||
@@ -473,7 +483,7 @@ class JPMCPlotsMixin:
|
|||||||
alt.Tooltip('count:Q', title='1st Place Votes')
|
alt.Tooltip('count:Q', title='1st Place Votes')
|
||||||
]
|
]
|
||||||
).properties(
|
).properties(
|
||||||
title=title,
|
title=self._process_title(title),
|
||||||
width=width or 800,
|
width=width or 800,
|
||||||
height=height or getattr(self, 'plot_height', 400)
|
height=height or getattr(self, 'plot_height', 400)
|
||||||
)
|
)
|
||||||
@@ -514,7 +524,7 @@ class JPMCPlotsMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
chart = (bars + text).properties(
|
chart = (bars + text).properties(
|
||||||
title=title,
|
title=self._process_title(title),
|
||||||
width=width or 800,
|
width=width or 800,
|
||||||
height=height or getattr(self, 'plot_height', 400)
|
height=height or getattr(self, 'plot_height', 400)
|
||||||
)
|
)
|
||||||
@@ -571,7 +581,7 @@ class JPMCPlotsMixin:
|
|||||||
alt.Tooltip('count:Q', title='Selections')
|
alt.Tooltip('count:Q', title='Selections')
|
||||||
]
|
]
|
||||||
).properties(
|
).properties(
|
||||||
title=title,
|
title=self._process_title(title),
|
||||||
width=width or 800,
|
width=width or 800,
|
||||||
height=height or getattr(self, 'plot_height', 400)
|
height=height or getattr(self, 'plot_height', 400)
|
||||||
)
|
)
|
||||||
@@ -627,7 +637,7 @@ class JPMCPlotsMixin:
|
|||||||
alt.Tooltip('count:Q', title='In Top 3')
|
alt.Tooltip('count:Q', title='In Top 3')
|
||||||
]
|
]
|
||||||
).properties(
|
).properties(
|
||||||
title=title,
|
title=self._process_title(title),
|
||||||
width=width or 800,
|
width=width or 800,
|
||||||
height=height or getattr(self, 'plot_height', 400)
|
height=height or getattr(self, 'plot_height', 400)
|
||||||
)
|
)
|
||||||
@@ -713,7 +723,7 @@ class JPMCPlotsMixin:
|
|||||||
# Combine layers
|
# Combine layers
|
||||||
chart = (bars + text).properties(
|
chart = (bars + text).properties(
|
||||||
title={
|
title={
|
||||||
"text": title,
|
"text": self._process_title(title),
|
||||||
"subtitle": [trait_description, "(Numbers on bars indicate respondent count)"]
|
"subtitle": [trait_description, "(Numbers on bars indicate respondent count)"]
|
||||||
},
|
},
|
||||||
width=width or 800,
|
width=width or 800,
|
||||||
@@ -776,7 +786,7 @@ class JPMCPlotsMixin:
|
|||||||
alt.Tooltip('correlation:Q', format='.2f')
|
alt.Tooltip('correlation:Q', format='.2f')
|
||||||
]
|
]
|
||||||
).properties(
|
).properties(
|
||||||
title=title,
|
title=self._process_title(title),
|
||||||
width=width or 800,
|
width=width or 800,
|
||||||
height=height or 350
|
height=height or 350
|
||||||
)
|
)
|
||||||
@@ -832,7 +842,7 @@ class JPMCPlotsMixin:
|
|||||||
alt.Tooltip('correlation:Q', format='.2f')
|
alt.Tooltip('correlation:Q', format='.2f')
|
||||||
]
|
]
|
||||||
).properties(
|
).properties(
|
||||||
title=title,
|
title=self._process_title(title),
|
||||||
width=width or 800,
|
width=width or 800,
|
||||||
height=height or 350
|
height=height or 350
|
||||||
)
|
)
|
||||||
|
|||||||
48
utils.py
48
utils.py
@@ -351,18 +351,22 @@ def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame:
|
|||||||
|
|
||||||
def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
|
def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
|
||||||
"""
|
"""
|
||||||
Normalizes values in the specified columns row-wise (Standardization: (x - mean) / std).
|
Normalizes values in the specified columns row-wise to 0-10 scale (Min-Max normalization).
|
||||||
Ignores null values (NaNs). Only applied if there are at least 2 non-null values in the row.
|
Formula: ((x - min) / (max - min)) * 10
|
||||||
|
Ignores null values (NaNs).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Using list evaluation for row-wise stats
|
# Using list evaluation for row-wise stats
|
||||||
# We create a temporary list column containing values from all target columns
|
# We create a temporary list column containing values from all target columns
|
||||||
|
# Ensure columns are cast to Float64 to avoid type errors with mixed/string data
|
||||||
df_norm = df.with_columns(
|
df_norm = df.with_columns(
|
||||||
pl.concat_list(target_cols)
|
pl.concat_list([pl.col(c).cast(pl.Float64) for c in target_cols])
|
||||||
.list.eval(
|
.list.eval(
|
||||||
# Apply standardization: (x - mean) / std
|
# Apply Min-Max scaling to 0-10
|
||||||
# std(ddof=1) is the sample standard deviation
|
(
|
||||||
(pl.element() - pl.element().mean()) / pl.element().std(ddof=1)
|
(pl.element() - pl.element().min()) /
|
||||||
|
(pl.element().max() - pl.element().min())
|
||||||
|
) * 10
|
||||||
)
|
)
|
||||||
.alias("_normalized_values")
|
.alias("_normalized_values")
|
||||||
)
|
)
|
||||||
@@ -377,8 +381,8 @@ def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFra
|
|||||||
|
|
||||||
def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
|
def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
|
||||||
"""
|
"""
|
||||||
Normalizes values in the specified columns globally (Standardization: (x - global_mean) / global_std).
|
Normalizes values in the specified columns globally to 0-10 scale.
|
||||||
Computes a single mean and standard deviation across ALL values in the target_cols and applies it.
|
Formula: ((x - global_min) / (global_max - global_min)) * 10
|
||||||
Ignores null values (NaNs).
|
Ignores null values (NaNs).
|
||||||
"""
|
"""
|
||||||
# Ensure eager for scalar extraction
|
# Ensure eager for scalar extraction
|
||||||
@@ -390,19 +394,23 @@ def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.Data
|
|||||||
return df.lazy() if was_lazy else df
|
return df.lazy() if was_lazy else df
|
||||||
|
|
||||||
# Calculate global stats efficiently by stacking all columns
|
# Calculate global stats efficiently by stacking all columns
|
||||||
stats = df.select(target_cols).melt().select([
|
# Cast to Float64 to ensure numeric calculations
|
||||||
pl.col("value").mean().alias("mean"),
|
stats = df.select([pl.col(c).cast(pl.Float64) for c in target_cols]).melt().select([
|
||||||
pl.col("value").std().alias("std")
|
pl.col("value").min().alias("min"),
|
||||||
|
pl.col("value").max().alias("max")
|
||||||
])
|
])
|
||||||
|
|
||||||
global_mean = stats["mean"][0]
|
global_min = stats["min"][0]
|
||||||
global_std = stats["std"][0]
|
global_max = stats["max"][0]
|
||||||
|
|
||||||
if global_std is None or global_std == 0:
|
# Handle edge case where all values are same or none exist
|
||||||
|
if global_min is None or global_max is None or global_max == global_min:
|
||||||
return df.lazy() if was_lazy else df
|
return df.lazy() if was_lazy else df
|
||||||
|
|
||||||
|
global_range = global_max - global_min
|
||||||
|
|
||||||
res = df.with_columns([
|
res = df.with_columns([
|
||||||
((pl.col(col) - global_mean) / global_std).alias(col)
|
(((pl.col(col).cast(pl.Float64) - global_min) / global_range) * 10).alias(col)
|
||||||
for col in target_cols
|
for col in target_cols
|
||||||
])
|
])
|
||||||
|
|
||||||
@@ -649,10 +657,12 @@ class JPMCSurvey(JPMCPlotsMixin):
|
|||||||
return subset, None
|
return subset, None
|
||||||
|
|
||||||
|
|
||||||
def get_voice_scale_1_10(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
|
def get_voice_scale_1_10(self, q: pl.LazyFrame, drop_cols=['Voice_Scale_1_10__V46']) -> Union[pl.LazyFrame, None]:
|
||||||
"""Extract columns containing the Voice Scale 1-10 ratings for the Chase virtual assistant.
|
"""Extract columns containing the Voice Scale 1-10 ratings for the Chase virtual assistant.
|
||||||
|
|
||||||
Returns subquery that can be chained with other polars queries.
|
Returns subquery that can be chained with other polars queries.
|
||||||
|
|
||||||
|
Drops scores for V46 as it was improperly configured in the survey and thus did not show up for respondents.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
QIDs_map = {}
|
QIDs_map = {}
|
||||||
@@ -662,6 +672,12 @@ class JPMCSurvey(JPMCPlotsMixin):
|
|||||||
# Convert "Voice 16 Scale 1-10_1" to "Scale_1_10__Voice_16"
|
# Convert "Voice 16 Scale 1-10_1" to "Scale_1_10__Voice_16"
|
||||||
QIDs_map[qid] = f"Voice_Scale_1_10__V{val['QName'].split()[1]}"
|
QIDs_map[qid] = f"Voice_Scale_1_10__V{val['QName'].split()[1]}"
|
||||||
|
|
||||||
|
for col in drop_cols:
|
||||||
|
if col in QIDs_map.values():
|
||||||
|
# remove from QIDs_map
|
||||||
|
qid_to_remove = [k for k,v in QIDs_map.items() if v == col][0]
|
||||||
|
del QIDs_map[qid_to_remove]
|
||||||
|
|
||||||
return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), None
|
return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user