base correlations
This commit is contained in:
63
utils.py
63
utils.py
@@ -1676,6 +1676,69 @@ def join_voice_and_style_data(
|
||||
how="inner"
|
||||
)
|
||||
|
||||
|
||||
def transform_speaking_style_color_correlation(
    joined_df: pl.LazyFrame | pl.DataFrame,
    speaking_styles: dict[str, list[str]],
    target_column: str = "Voice_Scale_Score"
) -> tuple[pl.DataFrame, dict | None]:
    """Aggregate speaking style correlation by color (Green, Blue, Orange, Red).

    Original use-case: "I want to create high-level correlation plots between
    'green, blue, orange, red' speaking styles and the 'voice scale scores'.
    I want to go to one plot with one bar for each color."

    For every trait listed under a color, the rows whose ``Right_Anchor``
    equals that trait are selected and the correlation between ``score`` and
    ``target_column`` is computed on the rows where both are non-null. The
    per-color value is the plain (unweighted) mean of those per-trait
    correlations.

    A trait is left out of the average when it has fewer than two usable rows
    or when its correlation is undefined (``None`` or NaN, e.g. when one of
    the two columns is constant). A color with no usable trait is omitted
    from the result entirely.

    Parameters
    ----------
    joined_df : pl.LazyFrame or pl.DataFrame
        Pre-fetched data from joining speaking style data with target data.
        Must have columns: Right_Anchor, score, and the target_column.
        A LazyFrame is collected once up front.
    speaking_styles : dict
        Dictionary mapping color names to their constituent traits.
        Typically imported from speaking_styles.SPEAKING_STYLES
    target_column : str
        The column to correlate against speaking style scores.
        Default: "Voice_Scale_Score" (for voice scale 1-10)
        Alternative: "Ranking_Points" (for top 3 voice ranking)

    Returns
    -------
    tuple[pl.DataFrame, dict | None]
        (DataFrame with columns [Color, correlation, n_traits], None).
        ``n_traits`` is the number of traits that actually contributed to the
        mean. The second element is always None in this function. The frame
        keeps its three columns even when it has no rows.
    """
    # Function-scope import: the file's top-level import block is outside the
    # reviewed region, so the dependency is brought in locally.
    import math

    if isinstance(joined_df, pl.LazyFrame):
        joined_df = joined_df.collect()

    color_correlations = []

    for color, traits in speaking_styles.items():
        trait_corrs = []

        for trait in traits:
            # Restrict to this trait and to rows where both values are present.
            valid_data = (
                joined_df
                .filter(pl.col("Right_Anchor") == trait)
                .select(["score", target_column])
                .drop_nulls()
            )

            # A correlation needs at least two observations.
            if valid_data.height < 2:
                continue

            corr_val = valid_data.select(pl.corr("score", target_column)).item()

            # NaN must be rejected as well as None: a single NaN would
            # otherwise poison the mean for the whole color.
            if corr_val is None or math.isnan(corr_val):
                continue

            trait_corrs.append(corr_val)

        # Average across all usable traits for this color.
        if trait_corrs:
            color_correlations.append({
                "Color": color,
                "correlation": sum(trait_corrs) / len(trait_corrs),
                "n_traits": len(trait_corrs),
            })

    if not color_correlations:
        # Without rows there is nothing to infer columns from; build the
        # documented columns explicitly so callers can still select them.
        empty_df = pl.DataFrame(
            schema={
                "Color": pl.Utf8,
                "correlation": pl.Float64,
                "n_traits": pl.Int64,
            }
        )
        return empty_df, None

    result_df = pl.DataFrame(color_correlations)
    return result_df, None
|
||||
|
||||
|
||||
def process_voice_ranking_data(
|
||||
df: Union[pl.LazyFrame, pl.DataFrame]
|
||||
) -> pl.DataFrame:
|
||||
|
||||
Reference in New Issue
Block a user