correlation matrix speech characteristics vs score
This commit is contained in:
116
plots.py
116
plots.py
@@ -2861,3 +2861,119 @@ class QualtricsPlotsMixin:
|
|||||||
|
|
||||||
chart = self._save_plot(chart, title)
|
chart = self._save_plot(chart, title)
|
||||||
return chart
|
return chart
|
||||||
|
|
||||||
|
def plot_speech_attribute_correlation(
    self,
    corr_df: pl.DataFrame | pd.DataFrame,
    title: str = "Speech Attributes vs Survey Metrics<br>Pearson Correlation",
    filename: str | None = None,
    height: int | None = None,
    width: int | None = None,
    show_values: bool = True,
    color_scheme: str | None = None,
) -> alt.Chart:
    """Draw a metric-by-attribute heatmap of Pearson correlations.

    The input must be in long form, one row per (metric, attribute) pair,
    with these columns:

    - ``metric``: label shown on the y axis (e.g. "Weighted Rank").
    - ``attribute``: label shown on the x axis (speech characteristic).
    - ``correlation``: the Pearson r value used for colour and cell text.

    Axis order follows the order in which labels first appear in the data.

    Args:
        corr_df: Long-form correlation table. A polars frame is converted
            to pandas; anything else is used as given.
        title: Chart title; passed through ``self._process_title`` before
            being applied (``<br>`` is meant as a line break).
        filename: Optional explicit filename (without extension),
            forwarded to ``self._save_plot``.
        height: Chart height in pixels. When falsy, defaults to
            ``max(120, 50 * n_metrics + 60)``.
        width: Chart width in pixels. When falsy, defaults to
            ``max(600, 55 * n_attributes)``.
        show_values: When true, overlay each cell with its correlation
            formatted to two decimals.
        color_scheme: Name of an Altair colour scheme for the fill; falls
            back to ``"redblue"``. The colour domain is fixed to [-1, 1].

    Returns:
        alt.Chart: Whatever ``self._save_plot`` returns for the finished
        heatmap (with text layers added when ``show_values`` is true).
    """
    frame = corr_df.to_pandas() if isinstance(corr_df, pl.DataFrame) else corr_df

    # Explicit sort lists pin both axes to the data's own label order
    # instead of Altair's default alphabetical ordering.
    attr_order = frame["attribute"].unique().tolist()
    metric_order = frame["metric"].unique().tolist()

    # Size the canvas from the grid dimensions unless the caller overrides.
    chart_width = width or max(600, len(attr_order) * 55)
    chart_height = height or max(120, len(metric_order) * 50 + 60)

    x_encoding = alt.X(
        "attribute:N",
        title=None,
        sort=attr_order,
        axis=alt.Axis(labelAngle=-45, labelLimit=180, grid=False),
    )
    y_encoding = alt.Y(
        "metric:N",
        title=None,
        sort=metric_order,
        axis=alt.Axis(labelLimit=200, grid=False),
    )
    fill_encoding = alt.Color(
        "correlation:Q",
        scale=alt.Scale(domain=[-1, 1], scheme=color_scheme or "redblue"),
        legend=alt.Legend(title="Pearson r"),
    )
    hover_fields = [
        alt.Tooltip("metric:N", title="Metric"),
        alt.Tooltip("attribute:N", title="Attribute"),
        alt.Tooltip("correlation:Q", title="r", format=".3f"),
    ]

    chart = (
        alt.Chart(frame)
        .mark_rect(stroke="white", strokeWidth=1)
        .encode(x=x_encoding, y=y_encoding, color=fill_encoding, tooltip=hover_fields)
    )

    if show_values:
        # The labels are drawn as two separate layers, each with a fixed
        # mark colour, rather than one layer with a conditional colour
        # encoding: conflicting colour encodings break vl_convert PNG export.
        magnitude = frame["correlation"].abs()
        # Two explicit comparisons (not a negated mask) so rows whose
        # correlation is NaN land in neither subset and get no label.
        label_groups = (
            (frame[magnitude <= 0.45], "black"),  # weak cells: dark text
            (frame[magnitude > 0.45], "white"),  # strong cells: light text
        )
        for subset, ink in label_groups:
            if subset.empty:
                continue
            labels = (
                alt.Chart(subset)
                .mark_text(fontSize=11, fontWeight="normal", color=ink)
                .encode(
                    x=alt.X("attribute:N", sort=attr_order),
                    y=alt.Y("metric:N", sort=metric_order),
                    text=alt.Text("correlation:Q", format=".2f"),
                )
            )
            chart = chart + labels

    chart = chart.properties(
        title=self._process_title(title),
        width=chart_width,
        height=chart_height,
    )

    return self._save_plot(chart, title, filename=filename)
|
||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user