Files
JPMC-quant/utils.py
2026-02-03 01:32:06 +01:00

1828 lines
71 KiB
Python

import polars as pl
from pathlib import Path
import pandas as pd
from typing import Union
import json
import re
import hashlib
import os
from io import BytesIO
import imagehash
from PIL import Image
from plots import QualtricsPlotsMixin
from pptx import Presentation
from pptx.enum.shapes import MSO_SHAPE_TYPE
def image_alt_text_generator(fpath, include_dataset_dirname=False) -> str:
    """Convert an image file path under ``figures/`` into its alt-text key.

    The alt text is the path with the leading ``figures`` component removed
    (and, by default, the dataset/export directory that follows it), joined
    with forward slashes on every platform.

    Args:
        fpath (str or Path): path to the image file; its first component must
            be ``'figures'``.
        include_dataset_dirname (bool): whether to keep the dataset directory
            name (the component right after ``figures``). Recommended to keep
            False so that images are not tied to a specific dataset export --
            tying them would defeat the purpose of using alt text to swap in
            images produced from newly exported datasets.

    Returns:
        str: POSIX-style relative path, e.g. ``'All_Respondents/chart.png'``.

    Raises:
        AssertionError: if ``fpath`` is empty or does not start with
            ``figures``. Raised explicitly (not with ``assert``) so the check
            still runs under ``python -O``; the exception type is kept because
            callers in this module catch ``AssertionError``.
    """
    fparts = Path(fpath).parts
    if not fparts or fparts[0] != 'figures':
        raise AssertionError("Image file path must start with 'figures'")
    # Drop 'figures' and, unless requested, the dataset directory after it.
    first_kept = 1 if include_dataset_dirname else 2
    return Path('/'.join(fparts[first_kept:])).as_posix()
def _get_shape_alt_text(shape) -> str:
"""
Extract alt text from a PowerPoint shape.
Args:
shape: A python-pptx shape object.
Returns:
str: The alt text (descr attribute) or empty string if not found.
"""
try:
# Check for common property names used by python-pptx elements to store non-visual props
# nvPicPr (Picture), nvSpPr (Shape/Placeholder), nvGrpSpPr (Group),
# nvGraphicFramePr (GraphicFrame), nvCxnSpPr (Connector)
nvPr = None
for attr in ['nvPicPr', 'nvSpPr', 'nvGrpSpPr', 'nvGraphicFramePr', 'nvCxnSpPr']:
if hasattr(shape._element, attr):
nvPr = getattr(shape._element, attr)
break
if nvPr is not None and hasattr(nvPr, 'cNvPr'):
return nvPr.cNvPr.get("descr", "")
except Exception:
pass
return ""
def pptx_replace_images_from_directory(
    presentation_path: Union[str, Path],
    image_source_dir: Union[str, Path],
    save_path: Union[str, Path] = None
) -> dict:
    """
    Replace all images in a PowerPoint presentation using images from a directory
    where subdirectory/filename paths match the alt_text of each image.

    This function scans all images in the presentation (including pictures inside
    group shapes), extracts their alt_text, and looks for a matching image file in
    the source directory. The alt_text should be a relative path (e.g.
    "All_Respondents/chart_name.png") that corresponds to the directory structure
    under image_source_dir.

    Args:
        presentation_path (str/Path): Path to the source .pptx file.
        image_source_dir (str/Path): Root directory containing replacement images.
            The directory structure should mirror the alt_text paths.
            Example: if alt_text is "All_Respondents/voice_scale.png", the
            replacement image should be at image_source_dir/All_Respondents/voice_scale.png
        save_path (str/Path, optional): Path to save the modified presentation.
            If None, overwrites the input file.

    Returns:
        dict: Summary with keys:
            - 'replaced': List of dicts with slide number, shape name, and matched path
            - 'not_found': List of dicts with slide number, shape name, and alt_text;
              entries whose replacement raised also carry an 'error' message
            - 'no_alt_text': List of dicts with slide number and shape name
            - 'total_images': Total number of picture shapes processed

    Raises:
        FileNotFoundError: If the presentation or the image source directory is missing.

    Example:
        >>> pptx_replace_images_from_directory(
        ...     "presentation.pptx",
        ...     "figures/2-2-26/",
        ...     "presentation_updated.pptx"
        ... )

    Notes:
        - Alt text should be set using update_ppt_alt_text() or image_alt_text_generator()
        - Images without alt_text are skipped
        - The new picture takes the original's bounding box (position and size),
          shape name, alt text, z-order position and parent (group or slide). It
          is stretched to that box, so the new image's own aspect ratio is kept
          only when it matches the box.
        - The new picture is added before the old one is removed, so a failed
          replacement leaves the original picture in place.
    """
    presentation_path = Path(presentation_path)
    image_source_dir = Path(image_source_dir)
    save_path = presentation_path if save_path is None else Path(save_path)

    if not presentation_path.exists():
        raise FileNotFoundError(f"Presentation not found: {presentation_path}")
    if not image_source_dir.exists():
        raise FileNotFoundError(f"Image source directory not found: {image_source_dir}")

    # Lookup keyed by the POSIX path relative to image_source_dir (the alt-text format).
    image_suffixes = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
    available_images = {
        img_path.relative_to(image_source_dir).as_posix(): img_path
        for img_path in image_source_dir.rglob("*")
        if img_path.is_file() and img_path.suffix.lower() in image_suffixes
    }
    print(f"Found {len(available_images)} images in source directory")

    prs = Presentation(presentation_path)

    results = {
        'replaced': [],
        'not_found': [],
        'no_alt_text': [],
        'total_images': 0
    }

    print(f"Processing {len(prs.slides)} slides...")
    for slide_num, slide in enumerate(prs.slides, start=1):
        # Materialise the generator first: the loop body mutates the shape tree.
        for shape in list(_iter_picture_shapes(slide.shapes)):
            results['total_images'] += 1
            original_name = shape.name
            shape_name = original_name or f"Unnamed (ID: {getattr(shape, 'shape_id', 'unknown')})"

            alt_text = _get_shape_alt_text(shape)
            if not alt_text:
                results['no_alt_text'].append({
                    'slide': slide_num,
                    'shape_name': shape_name
                })
                continue

            # Try the alt text as-is, then (if it has no extension) with common extensions.
            matched_path = available_images.get(alt_text)
            if matched_path is None and not Path(alt_text).suffix:
                for ext in ('.png', '.jpg', '.jpeg', '.gif'):
                    matched_path = available_images.get(f"{alt_text}{ext}")
                    if matched_path is not None:
                        break

            if matched_path is None:
                results['not_found'].append({
                    'slide': slide_num,
                    'shape_name': shape_name,
                    'alt_text': alt_text
                })
                continue

            try:
                old_element = shape._element
                # Add the replacement before touching the original so that a failure
                # (unreadable image, bad geometry, ...) leaves the slide unchanged.
                new_shape = slide.shapes.add_picture(
                    str(matched_path), shape.left, shape.top, shape.width, shape.height
                )
                new_element = new_shape._element

                # add_picture appends to the slide's top-level shape tree (front of the
                # z-order). Put the new element in the old one's exact slot instead, so
                # z-order is kept and pictures inside a group stay in that group, whose
                # child coordinate space the recorded left/top/width/height belong to.
                parent = old_element.getparent()
                position = parent.index(old_element)
                parent.remove(old_element)
                parent.insert(position, new_element)

                # Keep alt text and name so the picture can be matched again next time.
                c_nv_pr = new_element.nvPicPr.cNvPr
                c_nv_pr.set("descr", alt_text)
                if original_name:
                    c_nv_pr.set("name", original_name)

                results['replaced'].append({
                    'slide': slide_num,
                    'shape_name': shape_name,
                    'matched_path': str(matched_path)
                })
                print(f"Slide {slide_num}: Replaced '{alt_text}'")
            except Exception as e:
                results['not_found'].append({
                    'slide': slide_num,
                    'shape_name': shape_name,
                    'alt_text': alt_text,
                    'error': str(e)
                })

    prs.save(save_path)

    # Print summary
    print("\n" + "=" * 80)
    if results['replaced']:
        print(f"✓ Saved updated presentation to {save_path} with {len(results['replaced'])} replacements.")
    else:
        print("No images matched or required updates.")

    if results['not_found']:
        print(f"\n{len(results['not_found'])} image(s) not found in source directory:")
        for item in results['not_found']:
            detail = f" (error: {item['error']})" if 'error' in item else ""
            print(f" • Slide {item['slide']}: '{item.get('alt_text', 'N/A')}'{detail}")

    if results['no_alt_text']:
        print(f"\n{len(results['no_alt_text'])} image(s) without alt text (skipped):")
        for item in results['no_alt_text']:
            print(f" • Slide {item['slide']}: '{item['shape_name']}'")

    if not results['not_found'] and not results['no_alt_text']:
        print("\n✓ All images replaced successfully!")

    print("=" * 80)
    return results
def pptx_replace_named_image(presentation_path, target_tag, new_image_path, save_path):
    """
    Finds and replaces a specific image in a PowerPoint presentation while
    preserving its original position and size.

    Every top-level shape on every slide whose Shape Name (Selection Pane) or
    Alt Text contains ``target_tag`` is replaced. The new picture is inserted into
    the same bounding box and the same z-order position, and inherits the old
    shape's name and alt text so it can be targeted again later. The new picture
    is added before the old shape is removed, so a failed replacement leaves the
    original in place.

    Note: For batch replacement of all images using a directory structure,
    use pptx_replace_images_from_directory() instead.

    Args:
        presentation_path (str): The file path to the source .pptx file.
        target_tag (str): The unique identifier to look for (e.g., 'HERO_IMAGE').
            This is case-sensitive and checks both the shape name and alt text.
        new_image_path (str): The file path to the new image (PNG, JPG, etc.).
        save_path (str): The file path where the modified presentation will be saved.

    Returns:
        None: Saves the file directly to the provided save_path.

    Raises:
        ValueError: If target_tag is empty (an empty tag would match every shape).
        FileNotFoundError: If the source presentation or new image is not found.
        PermissionError: If the save_path is currently open or locked.
    """
    if not target_tag:
        raise ValueError("target_tag must be a non-empty string; an empty tag matches every shape.")

    prs = Presentation(presentation_path)
    replaced = 0

    for i, slide in enumerate(prs.slides):
        print(f"Processing Slide {i + 1}...")
        print(f"Total Shapes: {len(slide.shapes)} shapes")
        # Iterate over a list copy: matching shapes are removed during the loop.
        for shape in list(slide.shapes):
            try:
                shape_type = shape.shape_type
            except NotImplementedError:
                # python-pptx raises this for autoshapes it cannot classify.
                shape_type = "unknown"
            print(f"Checking shape: {shape.name} of type {shape_type}...")
            shape_name = shape.name or ""
            alt_text = _get_shape_alt_text(shape)
            print(f"Alt Text for shape '{shape_name}': {alt_text}")

            if target_tag not in shape_name and target_tag not in alt_text:
                print(f"Skipping shape '{shape_name}' with alt text '{alt_text}'")
                continue

            print(f"Found it! Replacing {shape_name}...")
            try:
                left, top, width, height = shape.left, shape.top, shape.width, shape.height
                new_shape = slide.shapes.add_picture(str(new_image_path), left, top, width, height)
            except AttributeError:
                print(f"Could not replace {shape_name} - might be missing dimensions.")
                continue

            # add_picture appends at the front of the z-order; move the new picture
            # into the old shape's slot, then drop the old shape.
            old_element = shape._element
            new_element = new_shape._element
            parent = old_element.getparent()
            position = parent.index(old_element)
            parent.remove(old_element)
            parent.insert(position, new_element)

            # Carry over name and alt text so the picture stays addressable by target_tag.
            c_nv_pr = new_element.nvPicPr.cNvPr
            if shape_name:
                c_nv_pr.set("name", shape_name)
            if alt_text:
                c_nv_pr.set("descr", alt_text)
            replaced += 1

    if replaced == 0:
        print(f"Warning: no shape matched target tag '{target_tag}'.")
    prs.save(save_path)
    print(f"Successfully saved to {save_path}")
def _calculate_file_sha1(file_path: Union[str, Path]) -> str:
"""Calculate SHA1 hash of a file."""
sha1 = hashlib.sha1()
with open(file_path, 'rb') as f:
while True:
data = f.read(65536)
if not data:
break
sha1.update(data)
return sha1.hexdigest()
def _calculate_perceptual_hash(image_source: Union[str, Path, bytes]) -> str:
    """
    Calculate the perceptual hash (pHash) of an image based on visual content.

    pHash is robust against:
    - Metadata differences
    - Minor compression differences
    - Small color/contrast variations

    Args:
        image_source: File path to image or raw image bytes.

    Returns:
        str: Hexadecimal string representation of the perceptual hash.
    """
    source = BytesIO(image_source) if isinstance(image_source, bytes) else image_source
    # Use the context manager so the file handle is released deterministically
    # instead of whenever the Image object is garbage-collected (Pillow does not
    # close every format's file after decoding, e.g. multi-frame images).
    with Image.open(source) as img:
        # pHash works on plain colour/greyscale data; convert RGBA, P (palette), etc.
        prepared = img if img.mode in ('RGB', 'L') else img.convert('RGB')
        phash = imagehash.phash(prepared)
    return str(phash)
def _build_image_hash_map(root_dir: Union[str, Path], use_perceptual_hash: bool = True) -> dict:
    """
    Recursively walk a directory and map image hashes to file paths.

    Only files with common image extensions are hashed. Traversal is sorted so
    the result is reproducible. When several files share a hash (identical files,
    or -- with pHash -- visually indistinguishable images), the first path in
    sorted traversal order is kept and the others are reported, because such
    collisions make alt-text matching ambiguous.

    Args:
        root_dir: Root directory to scan for images.
        use_perceptual_hash: If True, uses perceptual hashing (robust against metadata
            differences). If False, uses SHA1 byte hashing (exact match only).

    Returns:
        dict: Mapping of hash strings to file paths.
    """
    hash_map = {}
    valid_extensions = {'.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif', '.webp'}
    root = Path(root_dir)

    hash_type = "perceptual" if use_perceptual_hash else "SHA1"
    print(f"Building image hash map from {root} using {hash_type} hashing...")

    count = 0
    duplicates = 0
    for root_path, dirs, files in os.walk(root):
        # os.walk yields entries in filesystem order; sort in place so traversal,
        # and therefore which file wins a hash collision, is deterministic.
        dirs.sort()
        for file in sorted(files):
            file_path = Path(root_path) / file
            if file_path.suffix.lower() not in valid_extensions:
                continue
            try:
                if use_perceptual_hash:
                    file_hash = _calculate_perceptual_hash(file_path)
                else:
                    file_hash = _calculate_file_sha1(file_path)
            except Exception as e:
                print(f"Error hashing {file_path}: {e}")
                continue

            count += 1
            if file_hash in hash_map:
                duplicates += 1
                print(f"Warning: {file_path} has the same {hash_type} hash as "
                      f"{hash_map[file_hash]}; keeping {hash_map[file_hash]}.")
                continue
            hash_map[file_hash] = file_path

    print(f"Indexed {count} images.")
    if duplicates:
        print(f"{duplicates} image(s) shared a hash with an earlier image and were not indexed.")
    return hash_map
def _iter_picture_shapes(shapes):
    """
    Recursively yield the picture shapes in ``shapes``, descending into groups.

    A shape counts as a picture when its ``image`` property can be read
    (pictures and placeholders holding an embedded image). Reading ``image``
    raises ``ValueError`` in python-pptx for a picture without an embedded image
    (e.g. a linked picture); such shapes are skipped rather than aborting the
    walk -- a plain ``hasattr(shape, 'image')`` check would only swallow
    ``AttributeError`` and let that ``ValueError`` escape.
    """
    for shape in shapes:
        try:
            is_group = shape.shape_type == MSO_SHAPE_TYPE.GROUP
        except NotImplementedError:
            # python-pptx raises this for autoshapes it cannot classify; those are never groups.
            is_group = False

        if is_group:
            yield from _iter_picture_shapes(shape.shapes)
            continue

        try:
            _ = shape.image
        except (AttributeError, ValueError):
            continue
        yield shape
def update_ppt_alt_text(ppt_path: Union[str, Path], image_source_dir: Union[str, Path], output_path: Union[str, Path] = None, use_perceptual_hash: bool = True):
    """
    Updates the alt text of images in a PowerPoint presentation by matching
    their content with images in a source directory.

    Args:
        ppt_path (str/Path): Path to the PowerPoint file.
        image_source_dir (str/Path): Directory containing source images to match against.
        output_path (str/Path, optional): Path to save the updated presentation.
            If None, overwrites the input file.
        use_perceptual_hash (bool): If True (default), uses perceptual hashing which
            matches images based on visual content (robust against metadata differences,
            re-compression, etc.). If False, uses SHA1 byte hashing (exact file match only).

    Notes:
        - Alt text comes from image_alt_text_generator() applied to the matched
          file's path (relativized to the current working directory when
          possible), so matched files must live under ``figures/``.
        - The presentation is saved when at least one alt text changed. When
          output_path differs from ppt_path it is saved even without changes, so
          the requested output file always exists afterwards.
        - Errors opening the presentation are printed and the function returns None.
    """
    if output_path is None:
        output_path = ppt_path
    writes_new_file = Path(output_path).resolve() != Path(ppt_path).resolve()

    # 1. Build lookup map of {hash: file_path} from the source directory
    image_hash_map = _build_image_hash_map(image_source_dir, use_perceptual_hash=use_perceptual_hash)

    # 2. Open Presentation
    try:
        prs = Presentation(ppt_path)
    except Exception as e:
        print(f"Error opening presentation {ppt_path}: {e}")
        return

    hash_type = "pHash" if use_perceptual_hash else "SHA1"
    updates_count = 0
    unmatched_images = []  # reported at the end

    slides = list(prs.slides)
    print(f"Processing {len(slides)} slides...")

    for slide_num, slide in enumerate(slides, start=1):
        # Recursive iterator finds pictures inside groups/placeholders too.
        for shape in list(_iter_picture_shapes(slide.shapes)):
            try:
                if use_perceptual_hash:
                    # Perceptual hash of the embedded blob for visual-content matching
                    current_hash = _calculate_perceptual_hash(shape.image.blob)
                else:
                    # SHA1 from python-pptx (exact byte match)
                    current_hash = shape.image.sha1
            except Exception as e:
                print(f"Error processing shape on slide {slide_num}: {e}")
                continue

            original_path = image_hash_map.get(current_hash)
            if original_path is None:
                # Images that already carry alt text are not reported as unmatched.
                if _get_shape_alt_text(shape):
                    continue
                shape_id = getattr(shape, 'shape_id', getattr(shape, 'id', 'Unknown ID'))
                shape_name = shape.name if shape.name else f"Unnamed Shape (ID: {shape_id})"
                unmatched_images.append({
                    'slide': slide_num,
                    'shape_name': shape_name,
                    'hash_type': hash_type,
                    'hash': current_hash
                })
                continue

            try:
                # The generator expects a 'figures/...' path; relativize to CWD if possible.
                pass_path = original_path
                try:
                    pass_path = original_path.relative_to(Path.cwd())
                except ValueError:
                    pass
                new_alt_text = image_alt_text_generator(pass_path)

                # cNvPr lives under a shape-kind-specific non-visual element:
                # Picture: nvPicPr, Shape/Placeholder: nvSpPr, Group: nvGrpSpPr,
                # GraphicFrame: nvGraphicFramePr, Connector: nvCxnSpPr.
                nvPr = None
                for attr in ['nvPicPr', 'nvSpPr', 'nvGrpSpPr', 'nvGraphicFramePr', 'nvCxnSpPr']:
                    if hasattr(shape._element, attr):
                        nvPr = getattr(shape._element, attr)
                        break

                # Compare with None explicitly: lxml element truthiness means "has children".
                if nvPr is None or not hasattr(nvPr, 'cNvPr'):
                    print(f"Could not find cNvPr for shape on slide {slide_num}")
                    continue

                cNvPr = nvPr.cNvPr
                existing_alt_text = cNvPr.get("descr", "")
                if existing_alt_text != new_alt_text:
                    print(f"Slide {slide_num}: Updating alt text for image matches '{pass_path}'")
                    print(f" Old: '{existing_alt_text}' -> New: '{new_alt_text}'")
                    cNvPr.set("descr", new_alt_text)
                    updates_count += 1
            except AssertionError as e:
                print(f"Skipping match for {original_path} due to generator error: {e}")
            except Exception as e:
                print(f"Error updating alt text for {original_path}: {e}")

    # Print summary
    print("\n" + "=" * 80)
    if updates_count > 0:
        prs.save(output_path)
        print(f"✓ Saved updated presentation to {output_path} with {updates_count} updates.")
    elif writes_new_file:
        prs.save(output_path)
        print(f"No images matched or required updates; saved an unchanged copy to {output_path}.")
    else:
        print("No images matched or required updates.")

    if unmatched_images:
        print(f"\n{len(unmatched_images)} image(s) not found in source directory:")
        for img in unmatched_images:
            print(f" • Slide {img['slide']}: '{img['shape_name']}' ({img['hash_type']}: {img['hash']})")
    else:
        print("\n✓ All images matched successfully!")

    print("=" * 80)
def extract_voice_label(html_str: str) -> Union[str, None]:
    """
    Extract the first voice label from an HTML string and convert it to short format.

    Parameters:
        html_str (str): HTML string containing a voice label in format "Voice N"

    Returns:
        str or None: Voice label in format "VN" (e.g., "V14"), or None when no
        "Voice <digits>" text is present.

    Example:
        >>> extract_voice_label('<span style="...">Voice 14<br />...')
        'V14'
    """
    match = re.search(r'Voice (\d+)', html_str)
    return f"V{match.group(1)}" if match else None
def extract_qid(val):
    """Return the 'ImportId' from a Qualtrics export-metadata value.

    Qualtrics writes the third header row of a CSV export as JSON objects such as
    '{"ImportId":"startDate","timeZone":"America/Denver"}'. ``val`` may be such a
    string or an already-parsed dict.

    The string is parsed with ``json.loads``; Python-literal dicts (e.g. single
    quotes) fall back to ``ast.literal_eval``. Unlike ``eval``, neither can run
    code embedded in the CSV.

    Raises:
        ValueError: if ``val`` cannot be parsed into a dict.
        KeyError: if the dict has no 'ImportId' key.
    """
    import ast  # only needed for the non-JSON fallback

    if isinstance(val, str) and val.startswith('{') and val.endswith('}'):
        try:
            val = json.loads(val)
        except json.JSONDecodeError:
            try:
                val = ast.literal_eval(val)
            except (ValueError, SyntaxError, TypeError) as err:
                raise ValueError(f"Cannot parse export metadata {val!r}") from err
    if not isinstance(val, dict):
        raise ValueError(f"Cannot extract ImportId from {val!r}")
    return val['ImportId']
def combine_exclusive_columns(df: pl.DataFrame, id_col: str = "_recordId", target_col_name: str = "combined_value") -> pl.DataFrame:
    """
    Collapse mutually exclusive columns into a single column.

    Every column other than ``id_col`` is merged (first non-null value per row)
    into ``target_col_name``; the result holds just ``id_col`` and that column.

    Raises:
        ValueError: if any row has more than one non-null value among the merged columns.
    """
    value_cols = [name for name in df.columns if name != id_col]

    # Number of populated (non-null) merge columns in each row.
    populated_per_row = df.select(
        pl.sum_horizontal(pl.col(value_cols).is_not_null())
    ).to_series()
    has_conflict = (populated_per_row > 1).any()
    if has_conflict:
        raise ValueError("Invalid Data: Multiple columns populated for a single record row.")

    merged = pl.coalesce(value_cols).alias(target_col_name)
    return df.select([pl.col(id_col), merged])
def calculate_weighted_ranking_scores(df: Union[pl.LazyFrame, pl.DataFrame]) -> pl.DataFrame:
    """
    Calculate weighted scores for character or voice rankings.

    Points system: 1st place = 3 pts, 2nd place = 2 pts, 3rd place = 1 pt.

    Parameters
    ----------
    df : pl.DataFrame or pl.LazyFrame
        Ranking columns (every column except '_recordId' is treated as one
        ranked item, holding rank values 1-3). A LazyFrame is collected first.

    Returns
    -------
    pl.DataFrame
        Columns 'Character' (column name with the 'Character_Ranking_' /
        'Top_3_Voices_ranking__' prefixes removed and underscores turned into
        spaces) and 'Weighted Score', sorted by score descending. Ties are broken
        alphabetically by 'Character' so the order is deterministic. An input
        without ranking columns yields an empty frame with the same schema.
    """
    if isinstance(df, pl.LazyFrame):
        df = df.collect()

    points_per_rank = {1: 3, 2: 2, 3: 1}
    scores = []
    for col in (c for c in df.columns if c != '_recordId'):
        weighted_score = sum(
            df.filter(pl.col(col) == rank).height * points
            for rank, points in points_per_rank.items()
        )
        clean_name = col.replace('Character_Ranking_', '').replace('Top_3_Voices_ranking__', '').replace('_', ' ').strip()
        scores.append({
            'Character': clean_name,
            'Weighted Score': weighted_score
        })

    # Explicit schema keeps the columns present even when there is nothing to score.
    schema = {'Character': pl.Utf8, 'Weighted Score': pl.Int64}
    return pl.DataFrame(scores, schema=schema).sort(
        ['Weighted Score', 'Character'], descending=[True, False]
    )
def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
    """
    Min-max normalize the given columns within each row to a 0-10 scale.

    Formula: ((x - row_min) / (row_max - row_min)) * 10, with min/max taken
    across ``target_cols`` of that row (cast to Float64, nulls ignored).
    Nulls stay null. When every non-null value in a row is equal (zero range),
    those values become 5.0, the midpoint of the scale.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe.
    target_cols : list[str]
        Column names to normalize.

    Returns
    -------
    pl.DataFrame
        ``df`` with the target columns replaced by their row-normalized values.
    """
    as_float = [pl.col(name).cast(pl.Float64) for name in target_cols]
    lo = pl.min_horizontal(as_float)
    span = pl.max_horizontal(as_float) - lo

    def _scaled(name):
        # Constant row: midpoint for present values, keep nulls as null.
        constant_row_value = pl.when(pl.col(name).is_null()).then(None).otherwise(5.0)
        regular_value = ((pl.col(name).cast(pl.Float64) - lo) / span) * 10
        return pl.when(span == 0).then(constant_row_value).otherwise(regular_value).alias(name)

    return df.with_columns([_scaled(name) for name in target_cols])
def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
    """
    Normalize the given columns to a 0-10 scale using one global min/max.

    Formula: ((x - global_min) / (global_max - global_min)) * 10, where the
    min/max are taken over all values of all ``target_cols`` (cast to Float64).
    Nulls are ignored for the min/max and stay null in the output.

    Accepts a DataFrame or LazyFrame and returns the same kind. The input is
    returned unchanged when ``target_cols`` is empty, when there are no
    non-null values, or when every value is identical (zero range). Note that
    this differs from ``normalize_row_values``, which maps constant rows to 5.0.
    """
    # Scalar stats must be extracted eagerly.
    was_lazy = isinstance(df, pl.LazyFrame)
    if was_lazy:
        df = df.collect()

    if len(target_cols) == 0:
        return df.lazy() if was_lazy else df

    # Per-column min/max combined horizontally: same result as stacking all
    # columns, without DataFrame.melt (deprecated in polars 1.0 in favour of unpivot).
    stats = df.select(
        pl.min_horizontal([pl.col(c).cast(pl.Float64).min() for c in target_cols]).alias("min"),
        pl.max_horizontal([pl.col(c).cast(pl.Float64).max() for c in target_cols]).alias("max"),
    )
    global_min = stats["min"][0]
    global_max = stats["max"][0]

    # No data, or all values identical: nothing meaningful to scale.
    if global_min is None or global_max is None or global_max == global_min:
        return df.lazy() if was_lazy else df

    global_range = global_max - global_min
    res = df.with_columns([
        (((pl.col(col).cast(pl.Float64) - global_min) / global_range) * 10).alias(col)
        for col in target_cols
    ])
    return res.lazy() if was_lazy else res
class QualtricsSurvey(QualtricsPlotsMixin):
"""Class to handle Qualtrics survey data."""
def __init__(self, data_path: Union[str, Path], qsf_path: Union[str, Path]):
    """Bind the survey's CSV export and QSF definition, and load their metadata.

    Args:
        data_path: Qualtrics CSV export. The third path component is used as the
            figure directory name (e.g. 'data/exports/OneDrive_2026-01-21/...').
        qsf_path: Qualtrics survey definition file (.qsf, JSON).

    Side effects:
        Reads both files and creates 'figures/<third path component>' if missing.
    """
    self.data_filepath = Path(data_path) if isinstance(data_path, str) else data_path
    self.qsf_filepath = Path(qsf_path) if isinstance(qsf_path, str) else qsf_path

    # Metadata from the CSV header rows first, then from the QSF.
    self.qid_descr_map = self._extract_qid_descr_map()
    self.qsf: dict = self._load_qsf()

    # data/exports/OneDrive_2026-01-21/... -> figures/OneDrive_2026-01-21
    export_dir_name = self.data_filepath.parts[2]
    self.fig_save_dir = Path('figures') / export_dir_name
    if not self.fig_save_dir.exists():
        self.fig_save_dir.mkdir(parents=True, exist_ok=True)

    self.data_filtered = None
    self.plot_height, self.plot_width = 500, 1000

    # Current filter selections; None means the filter is not applied.
    self.filter_age: list = None
    self.filter_gender: list = None
    self.filter_consumer: list = None
    self.filter_ethnicity: list = None
    self.filter_income: list = None
def _extract_qid_descr_map(self) -> dict:
"""Extract mapping of Qualtrics ImportID to Question Description from results file."""
if '1_1-16-2026' in self.data_filepath.as_posix():
df_questions = pd.read_csv(self.data_filepath, nrows=1)
df_questions
return df_questions.iloc[0].to_dict()
else:
# First row contains Qualtrics Editor question names (ie 'B_VOICE SEL. 18-8')
# Second row which contains the question content
# Third row contains the Export Metadata (ie '{"ImportId":"startDate","timeZone":"America/Denver"}')
df_questions = pd.read_csv(self.data_filepath, nrows=2)
# transpose df_questions
df_questions = df_questions.T.reset_index()
df_questions.columns = ['QName', 'Description', 'export_metadata']
df_questions['ImportID'] = df_questions['export_metadata'].apply(extract_qid)
df_questions = df_questions[['ImportID', 'QName', 'Description']]
# return dict as {ImportID: [QName, Description]}
return df_questions.set_index('ImportID')[['QName', 'Description']].T.to_dict()
def _load_qsf(self) -> dict:
"""Load QSF file to extract question metadata if needed."""
with open(self.qsf_filepath, 'r', encoding='utf-8') as f:
qsf_data = json.load(f)
return qsf_data
def _get_qsf_question_by_QID(self, QID: str) -> dict:
"""Get question metadata from QSF using the Question ID."""
q_elem = [elem for elem in self.qsf['SurveyElements'] if elem['PrimaryAttribute'] == QID]
if len(q_elem) == 0:
raise ValueError(f"SurveyElement with 'PrimaryAttribute': '{QID}' not found in QSF.")
if len(q_elem) > 1:
raise ValueError(f"Multiple SurveyElements with 'PrimaryAttribute': '{QID}' found in QSF: \n{q_elem}")
return q_elem[0]
def load_data(self) -> pl.LazyFrame:
    """
    Load the survey responses with each column named by its Qualtrics ImportId.

    The CSV starts with three header rows (editor question names, question text,
    and export metadata such as '{"ImportId":"startDate","timeZone":"America/Denver"}'),
    followed by one row per respondent. The ImportIds are read from the third
    header row and used as column names.

    Side effects:
        Sets ``self.options_age/gender/consumer/ethnicity/income`` to the sorted
        unique non-null values of QID1/QID2/Consumer/QID3/QID15 (empty list when
        the column is absent); these are used to detect the "all selected" filter state.

    Returns:
        pl.LazyFrame: Polars LazyFrame with ImportId as column names.

    Raises:
        NotImplementedError: for the '1_1-16-2026' export format.
    """
    if '1_1-16-2026' in self.data_filepath.as_posix():
        raise NotImplementedError("This method does not support the '1_1-16-2026' export format.")

    # Read only the 3rd row (metadata dicts) as raw values.
    df_meta = pd.read_csv(self.data_filepath, nrows=1, skiprows=2, header=None)
    new_columns = [extract_qid(val) for val in df_meta.iloc[0]]

    # After skipping the three header rows the next row is the first respondent.
    # has_header=False keeps polars from consuming that row as a header (which
    # silently dropped one respondent); the names are assigned below instead.
    df = pl.read_csv(self.data_filepath, skip_rows=3, has_header=False)
    df.columns = new_columns

    def _sorted_options(column):
        # Unique non-null values of a filter column, or [] when it is absent.
        return sorted(df[column].drop_nulls().unique().to_list()) if column in df.columns else []

    self.options_age = _sorted_options('QID1')
    self.options_gender = _sorted_options('QID2')
    self.options_consumer = _sorted_options('Consumer')
    self.options_ethnicity = _sorted_options('QID3')
    self.options_income = _sorted_options('QID15')

    return df.lazy()
def _get_subset(self, q: pl.LazyFrame, QIDs, rename_cols=True, include_record_id=True) -> pl.LazyFrame:
    """Select the given question columns from ``q``.

    Args:
        q: Query to select from.
        QIDs: List of ImportId column names.
        rename_cols: When True, rename each selected column (except '_recordId')
            that appears in ``self.qid_descr_map`` to its 'QName'.
        include_record_id: When True, prepend '_recordId' if it is not already listed.
    """
    columns = QIDs
    if include_record_id and '_recordId' not in columns:
        columns = ['_recordId'] + columns

    if not rename_cols:
        return q.select(columns)

    renames = {}
    for qid in columns:
        if qid != '_recordId' and qid in self.qid_descr_map:
            renames[qid] = self.qid_descr_map[qid]['QName']
    return q.select(columns).rename(renames)
def filter_data(self, q: pl.LazyFrame, age:list=None, gender:list=None, consumer:list=None, ethnicity:list=None, income:list=None) -> pl.LazyFrame:
    """Keep only rows whose demographic answers are in the given selections.

    Each argument is a list of allowed values, or None to skip that filter:
      - age -> column 'QID1'
      - gender -> 'QID2'
      - consumer -> 'Consumer'
      - ethnicity -> 'QID3'
      - income -> 'QID15'

    Every selection (including None) is recorded on ``self.filter_<name>`` and the
    filtered query is stored in ``self.data_filtered`` before being returned.
    """
    # (attribute that records the selection, column filtered, selected values)
    selections = (
        ('filter_age', 'QID1', age),
        ('filter_gender', 'QID2', gender),
        ('filter_consumer', 'Consumer', consumer),
        ('filter_ethnicity', 'QID3', ethnicity),
        ('filter_income', 'QID15', income),
    )
    for attr_name, column, allowed in selections:
        setattr(self, attr_name, allowed)
        if allowed is None:
            continue
        q = q.filter(pl.col(column).is_in(allowed))

    self.data_filtered = q
    return self.data_filtered
def get_demographics(self, q: pl.LazyFrame) -> tuple[pl.LazyFrame, None]:
    """Select the demographic columns (plus '_recordId').

    Columns are renamed to their Qualtrics editor names (QName) through
    ``_get_subset``. Returns a ``(subset, None)`` tuple, matching the
    ``(subset, extra)`` shape of the other getters.
    """
    QIDs = ['QID1', 'QID2', 'QID3', 'QID4', 'QID7', 'QID13', 'QID14', 'QID15', 'QID16', 'QID17', 'Consumer']
    return self._get_subset(q, QIDs), None
def get_top_8_traits(self, q: pl.LazyFrame) -> tuple[pl.LazyFrame, None]:
    """Select the answer to which 8 characteristics are most important for the Chase virtual assistant to have.

    Returns a ``(subset, None)`` tuple; the subset has '_recordId' and QID25
    renamed to 'Top_8_Traits', and can be chained with other polars queries.
    """
    QIDs = ['QID25']
    return self._get_subset(q, QIDs, rename_cols=False).rename({'QID25': 'Top_8_Traits'}), None
def get_top_3_traits(self, q: pl.LazyFrame) -> tuple[pl.LazyFrame, None]:
    """Select the 3 characteristics the Chase virtual assistant should prioritize.

    Returns a ``(subset, None)`` tuple; the subset has '_recordId' and
    QID26_0_GROUP renamed to 'Top_3_Traits', and can be chained with other
    polars queries.
    """
    QIDs = ['QID26_0_GROUP']
    return self._get_subset(q, QIDs, rename_cols=False).rename({'QID26_0_GROUP': 'Top_3_Traits'}), None
def get_character_ranking(self, q: pl.LazyFrame) -> tuple[pl.LazyFrame, None]:
    """Select the ranking of characters for the Chase virtual assistant (QID27).

    Export columns are named 'QID27_<recode value>'; the QSF payload maps each
    choice's recode value to its variable name, which becomes
    'Character_Ranking_<name with spaces as underscores>'.

    Returns a ``(subset, None)`` tuple that can be chained with other polars queries.
    """
    cfg = self._get_qsf_question_by_QID('QID27')['Payload']
    QIDs_rename = {
        f'QID27_{recode}': f"Character_Ranking_{cfg['VariableNaming'][choice].replace(' ', '_')}"
        for choice, recode in cfg['RecodeValues'].items()
    }
    return self._get_subset(q, list(QIDs_rename.keys()), rename_cols=False).rename(QIDs_rename), None
def get_18_8_3(self, q: pl.LazyFrame) -> tuple[pl.LazyFrame, None]:
    """Select the 18-8-3 voice selection answers for the Chase virtual assistant.

    Returns a ``(subset, None)`` tuple. The subset's columns are, in order:
    '_recordId', '18-8_Set-A' (QID29), '18-8_Set-B' (QID101), '8_Combined'
    (first non-null of Set A and Set B) and '3_Ranked' (QID36_0_GROUP).
    """
    QIDs = ['QID29', 'QID101', 'QID36_0_GROUP']
    rename_dict = {
        'QID29': '18-8_Set-A',
        'QID101': '18-8_Set-B',
        'QID36_0_GROUP': '3_Ranked'
    }
    subset = self._get_subset(q, QIDs, rename_cols=False).rename(rename_dict)

    # Merge Set A and Set B into a single column.
    subset = subset.with_columns(
        pl.coalesce(['18-8_Set-A', '18-8_Set-B']).alias('8_Combined')
    )

    subset = subset.select(['_recordId', '18-8_Set-A', '18-8_Set-B', '8_Combined', '3_Ranked'])
    return subset, None
def get_voice_scale_1_10(self, q: pl.LazyFrame, drop_cols=('Voice_Scale_1_10__V46',)) -> tuple[pl.LazyFrame, None]:
    """Select the Voice Scale 1-10 ratings for the Chase virtual assistant.

    Columns whose QName contains 'Scale 1-10_1' (e.g. 'Voice 16 Scale 1-10_1')
    are renamed to 'Voice_Scale_1_10__V<N>' (e.g. 'Voice_Scale_1_10__V16').

    Args:
        q: Query to select from.
        drop_cols: Renamed columns to leave out (an iterable of names, a single
            name, or None). Defaults to V46, which was improperly configured in
            the survey and did not show up for respondents.

    Returns:
        ``(subset, None)`` tuple that can be chained with other polars queries.
    """
    # Immutable default; accept a bare string or None for convenience.
    if drop_cols is None:
        drop_cols = ()
    elif isinstance(drop_cols, str):
        drop_cols = (drop_cols,)
    excluded = set(drop_cols)

    QIDs_map = {}
    for qid, val in self.qid_descr_map.items():
        if 'Scale 1-10_1' not in val['QName']:
            continue
        new_name = f"Voice_Scale_1_10__V{val['QName'].split()[1]}"
        # Skip every column that maps to an excluded name, not just the first.
        if new_name not in excluded:
            QIDs_map[qid] = new_name

    return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), None
def get_ss_green_blue(self, q: pl.LazyFrame) -> tuple[pl.LazyFrame, dict]:
    """Select the SS Green/Blue ratings for the Chase virtual assistant.

    Columns whose QName contains 'SS Green-Blue' (e.g. 'V14 SS Green-Blue_1') are
    renamed to 'SS_Green_Blue__<voice>__Choice_<n>'.

    Returns:
        ``(subset, choices_map)``, where ``choices_map`` maps each new column
        name to the 'Display' text of that choice in the question's QSF payload.

    Raises:
        ValueError: if QID35 or a referenced question is missing from (or
            duplicated in) the QSF.
    """
    # QSF lookups scan every SurveyElement, and columns like QIDnn_1, QIDnn_2
    # share one base QID, so cache payloads per base QID. Seeding with QID35
    # keeps the up-front check that QID35 exists in the QSF.
    payloads = {'QID35': self._get_qsf_question_by_QID('QID35')['Payload']}

    QIDs_map = {}
    choices_map = {}
    for qid, val in self.qid_descr_map.items():
        if 'SS Green-Blue' not in val['QName']:
            continue
        base_qid = qid.split('_')[0]
        if base_qid not in payloads:
            payloads[base_qid] = self._get_qsf_question_by_QID(base_qid)['Payload']
        cfg = payloads[base_qid]

        # e.g. "V14 SS Green-Blue_1" -> voice "V14", trait "1"
        qname_parts = val['QName'].split()
        voice = qname_parts[0]
        trait_num = qname_parts[-1].split('_')[-1]

        new_name = f"SS_Green_Blue__{voice}__Choice_{trait_num}"
        QIDs_map[qid] = new_name
        choices_map[new_name] = cfg['Choices'][trait_num]['Display']

    return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), choices_map
def get_top_3_voices(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
    """Extract the "Rank Top 3 Voices" columns for the Chase virtual assistant.

    The ranking questions draw their choices dynamically from QID36. Each
    column is renamed to ``Top_3_Voices_ranking__<voice>``. The voice label comes
    from ``extract_voice_label`` applied to the QID36 choice whose id equals
    the suffix after the last '_' in the ranking question's QName.

    Returns:
        tuple: (subquery with the renamed columns, None). The subquery can be
        chained with other polars queries.

    Raises:
        ValueError: if a ranking question's DynamicChoices Locator is not the
            expected QID36 selected-choices locator.
        KeyError: if a QName suffix is not a QID36 choice id.
    """
    expected_locator = r"q://QID36/ChoiceGroup/SelectedChoicesInGroup/1"
    source_choices = self._get_qsf_question_by_QID('QID36')['Payload']['Choices']
    voice_by_choice = {
        choice_id: extract_voice_label(choice['Display'])
        for choice_id, choice in source_choices.items()
    }

    rename_map = {}
    for qid, descr in self.qid_descr_map.items():
        qname = descr['QName']
        if 'Rank Top 3 Voices' not in qname:
            continue
        payload = self._get_qsf_question_by_QID(qid.split('_')[0])['Payload']
        locator = payload['DynamicChoices']['Locator']
        # Guard: the choice-id -> voice mapping below is only valid if the
        # ranking question really pulls its choices from QID36.
        if locator != expected_locator:
            raise ValueError(f"Unexpected DynamicChoices Locator for QID '{qid}': {locator}")
        choice_id = qname.split('_')[-1]
        rename_map[qid] = f"Top_3_Voices_ranking__{voice_by_choice[choice_id]}"

    subset = self._get_subset(q, list(rename_map.keys()), rename_cols=False).rename(rename_map)
    return subset, None
def get_ss_orange_red(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, dict]:
    """Extract the SS Orange/Red speaking-style rating columns for the Chase virtual assistant.

    Each question whose QName contains 'SS Orange-Red' (e.g. "V14 SS Orange-Red_1")
    is renamed to ``SS_Orange_Red__<voice>__Choice_<n>``. Here ``<voice>`` is the
    first QName token and ``<n>`` is the suffix after the last underscore.

    Returns:
        tuple: (subquery with the renamed columns, choices_map). choices_map
        maps each new column name to the matching choice's 'Display' text
        from the question's QSF payload. The subquery can be chained with
        other polars queries.
    """
    QIDs_map = {}
    choices_map = {}
    # Several sub-column QIDs can share one base question (the part before the
    # first '_'), so fetch each base question's payload only once.
    payload_by_base_qid = {}
    for qid, val in self.qid_descr_map.items():
        if 'SS Orange-Red' not in val['QName']:
            continue
        base_qid = qid.split('_')[0]
        if base_qid not in payload_by_base_qid:
            payload_by_base_qid[base_qid] = self._get_qsf_question_by_QID(base_qid)['Payload']
        cfg = payload_by_base_qid[base_qid]
        # ie: "V14 SS Orange-Red_1" -> voice "V14", trait_num "1"
        qname_parts = val['QName'].split()
        voice = qname_parts[0]
        trait_num = qname_parts[-1].split('_')[-1]
        new_name = f"SS_Orange_Red__{voice}__Choice_{trait_num}"
        QIDs_map[qid] = new_name
        choices_map[new_name] = cfg['Choices'][trait_num]['Display']
    return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), choices_map
def get_character_refine(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
    """Extract the character-refine feedback columns for the Chase virtual assistant.

    Selects the fixed question set QID44, QID97, QID95 and QID96 via
    ``self._get_subset`` with ``rename_cols=True``.

    Returns:
        tuple: (subquery, None). The subquery can be chained with other
        polars queries.
    """
    refine_question_ids = ['QID44', 'QID97', 'QID95', 'QID96']
    subset = self._get_subset(q, refine_question_ids, rename_cols=True)
    return subset, None
def transform_character_trait_frequency(
    self,
    char_df: pl.LazyFrame | pl.DataFrame,
    character_column: str,
) -> tuple[pl.DataFrame, dict | None]:
    """Transform character refine data to trait frequency counts for a single character.

    Original use-case: "I need a bar plot that shows the frequency of the times
    each trait is chosen per brand character."

    Takes a DataFrame with comma-separated trait selections per character. It
    splits and strips the traits, drops empty entries, and counts how often each
    trait was chosen for the given character.

    Args:
        char_df: Pre-fetched data.
            Expected columns: '_recordId', '<character_column>' (with comma-separated traits)
        character_column: Name of the character column to analyze. One of
            'Bank Teller', 'Familiar Friend', 'The Coach', 'Personal Assistant'
            is looked up in ``reference.ORIGINAL_CHARACTER_TRAITS``. Any other
            name yields no original traits, so every ``is_original`` is False.

    Returns:
        tuple: (DataFrame with columns ['trait', 'count', 'is_original'], None)
            - 'trait': individual trait name
            - 'count': frequency count
            - 'is_original': boolean indicating if trait is in the original definition
            Rows are sorted by count (descending), then trait (ascending), so the
            order is deterministic even when counts tie.
    """
    from reference import ORIGINAL_CHARACTER_TRAITS

    if isinstance(char_df, pl.LazyFrame):
        char_df = char_df.collect()

    # Map display names to reference keys
    character_key_map = {
        'Bank Teller': 'the_bank_teller',
        'Familiar Friend': 'the_familiar_friend',
        'The Coach': 'the_coach',
        'Personal Assistant': 'the_personal_assistant',
    }
    ref_key = character_key_map.get(character_column)
    original_traits = list(set(ORIGINAL_CHARACTER_TRAITS.get(ref_key, [])))

    # An empty list has no element dtype to infer, so skip is_in entirely.
    is_original = (
        pl.col('trait').is_in(original_traits) if original_traits else pl.lit(False)
    )

    freq_df = (
        char_df
        .filter(pl.col(character_column).is_not_null())
        .select(pl.col(character_column).str.split(',').alias('traits'))
        .explode('traits')
        .with_columns(pl.col('traits').str.strip_chars().alias('trait'))
        .filter(pl.col('trait') != '')
        .group_by('trait')
        .agg(pl.len().alias('count'))
        # group_by output order is not guaranteed, so break count ties by trait
        # name to keep the result (and any bar plot built from it) stable.
        .sort(['count', 'trait'], descending=[True, False])
        .with_columns(is_original.alias('is_original'))
    )
    return freq_df, None
def compute_pairwise_significance(
    self,
    data: pl.LazyFrame | pl.DataFrame,
    test_type: str = "auto",
    alpha: float = 0.05,
    correction: str = "bonferroni",
) -> tuple[pl.DataFrame, dict]:
    """Compute pairwise statistical significance tests between columns.

    Original use-case: "I need to test for statistical significance and present
    this in a logical manner. It should be a generalized function to work on
    many dataframes."

    Runs a test for every pair of Float64/Float32/Int64/Int32 columns
    (excluding '_recordId') to find which groups differ significantly. Nulls
    are dropped per column before testing.

    Args:
        data: Pre-fetched data with numeric columns to compare.
            Expected format: rows are observations, columns are groups/categories.
            Example: Voice_Scale_1_10__V14, Voice_Scale_1_10__V04, etc.
        test_type: Statistical test to use:
            - "auto": "chi2" if the first column is integer-typed with fewer
              than 10% unique values, otherwise "mannwhitney" (default)
            - "mannwhitney": Mann-Whitney U test (non-parametric, for continuous)
            - "ttest": Independent samples t-test (parametric, for continuous)
            - "chi2": Chi-square test on binned values (for count/frequency data)
        alpha: Significance level (default 0.05)
        correction: Multiple comparison correction method:
            - "bonferroni": Bonferroni correction (conservative), multiplies by
              the total number of pairs
            - "holm": Holm-Bonferroni (less conservative), over the pairs that
              produced a p-value
            - "none": No correction

    Returns:
        tuple: (pairwise_df, metadata)
            - pairwise_df: DataFrame with columns ['group1', 'group2', 'p_value',
              'effect_size', 'mean1', 'mean2', 'n1', 'n2', 'p_adjusted',
              'significant']. A pair with an empty group gets a null
              p_value/effect_size, a NaN p_adjusted and significant=False.
              effect_size is the rank-biserial correlation (mannwhitney),
              Cohen's d (ttest) or Cramér's V (chi2).
            - metadata: dict with 'test_type', 'alpha', 'correction',
              'n_comparisons', 'overall_test', 'overall_stat',
              'overall_p_value'. The overall stat/p-value are None for chi2,
              or when the overall test cannot be computed.

    Raises:
        ValueError: if test_type or correction is unknown, or fewer than 2
            numeric columns are present.
    """
    from scipy import stats as scipy_stats
    import numpy as np

    if test_type not in ("auto", "mannwhitney", "ttest", "chi2"):
        raise ValueError(f"Unknown test_type: {test_type}")
    if correction not in ("bonferroni", "holm", "none"):
        raise ValueError(f"Unknown correction: {correction}")

    df = data.collect() if isinstance(data, pl.LazyFrame) else data

    # Get numeric columns (exclude _recordId and other non-data columns)
    numeric_dtypes = [pl.Float64, pl.Float32, pl.Int64, pl.Int32]
    value_cols = [c for c in df.columns if c != '_recordId' and df[c].dtype in numeric_dtypes]
    if len(value_cols) < 2:
        raise ValueError(f"Need at least 2 numeric columns for comparison, found {len(value_cols)}")

    # Auto-detect test type: few distinct integer values look like counts.
    if test_type == "auto":
        test_type = "mannwhitney"  # Default to non-parametric
        sample_col = df[value_cols[0]].drop_nulls()
        if len(sample_col) > 0:
            is_integer = sample_col.dtype in [pl.Int64, pl.Int32]
            unique_ratio = sample_col.n_unique() / len(sample_col)
            if is_integer and unique_ratio < 0.1:
                test_type = "chi2"

    group_data = {col: df[col].drop_nulls().to_numpy() for col in value_cols}

    # Overall test (Kruskal-Wallis for non-parametric, ANOVA for parametric)
    non_empty_groups = [group_data[col] for col in value_cols if len(group_data[col]) > 0]
    overall_stat, overall_p = None, None
    if test_type == "mannwhitney":
        overall_test_name = "Kruskal-Wallis"
        overall_fn = scipy_stats.kruskal
    elif test_type == "ttest":
        overall_test_name = "One-way ANOVA"
        overall_fn = scipy_stats.f_oneway
    else:
        overall_test_name = "N/A (Chi-square)"
        overall_fn = None
    if overall_fn is not None and len(non_empty_groups) >= 2:
        try:
            overall_stat, overall_p = overall_fn(*non_empty_groups)
        except ValueError:
            # e.g. Kruskal-Wallis rejects input where every value is identical;
            # the overall test is then undefined and reported as None.
            overall_stat, overall_p = None, None

    # Pairwise tests
    results = []
    n_comparisons = len(value_cols) * (len(value_cols) - 1) // 2
    for i, col1 in enumerate(value_cols):
        for col2 in value_cols[i + 1:]:
            data1, data2 = group_data[col1], group_data[col2]
            n1, n2 = len(data1), len(data2)
            mean1 = float(np.mean(data1)) if n1 > 0 else None
            mean2 = float(np.mean(data2)) if n2 > 0 else None
            row = {
                'group1': self._clean_voice_label(col1),
                'group2': self._clean_voice_label(col2),
                'p_value': None,
                'effect_size': None,
                'mean1': mean1,
                'mean2': mean2,
                'n1': n1,
                'n2': n2,
            }
            # Skip the test if either group has no data
            if n1 == 0 or n2 == 0:
                results.append(row)
                continue

            if test_type == "mannwhitney":
                stat, p_value = scipy_stats.mannwhitneyu(data1, data2, alternative='two-sided')
                # Rank-biserial correlation. scipy's statistic is U for data1,
                # so this is positive when group2 tends to rank higher.
                # NOTE(review): this is the opposite sign convention to Cohen's d below.
                effect_size = 1 - (2 * stat) / (n1 * n2)
            elif test_type == "ttest":
                stat, p_value = scipy_stats.ttest_ind(data1, data2)
                # Cohen's d with pooled SD built from *sample* variances (ddof=1).
                # A single-observation group contributes zero to the pooled sum.
                dof = n1 + n2 - 2
                if dof > 0:
                    var1 = np.var(data1, ddof=1) if n1 > 1 else 0.0
                    var2 = np.var(data2, ddof=1) if n2 > 1 else 0.0
                    pooled_std = np.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / dof)
                else:
                    pooled_std = 0.0
                effect_size = (mean1 - mean2) / pooled_std if pooled_std > 0 else 0
            else:  # "chi2"
                # Bin both samples on shared edges to build a 2 x k contingency table
                all_data = np.concatenate([data1, data2])
                bins = np.histogram_bin_edges(all_data, bins='auto')
                counts1, _ = np.histogram(data1, bins=bins)
                counts2, _ = np.histogram(data2, bins=bins)
                contingency = np.array([counts1, counts2])
                # Remove empty bins (they would give zero expected frequencies)
                contingency = contingency[:, contingency.sum(axis=0) > 0]
                if contingency.shape[1] > 1:
                    stat, p_value, _, _ = scipy_stats.chi2_contingency(contingency)
                    # Cramér's V
                    effect_size = np.sqrt(stat / (contingency.sum() * (min(contingency.shape) - 1)))
                else:
                    p_value, effect_size = 1.0, 0.0

            row['p_value'] = float(p_value)
            row['effect_size'] = float(effect_size)
            results.append(row)

    results_df = pl.DataFrame(results)

    # Multiple comparison correction. Casting first turns null p-values
    # (skipped pairs) into NaN in a proper float array.
    p_values = results_df['p_value'].cast(pl.Float64).to_numpy()
    valid_mask = ~np.isnan(p_values)
    p_adjusted = np.full_like(p_values, np.nan, dtype=float)
    if correction == "bonferroni":
        p_adjusted[valid_mask] = np.minimum(p_values[valid_mask] * n_comparisons, 1.0)
    elif correction == "holm":
        # Holm-Bonferroni step-down: multiply the j-th smallest p by (m - j),
        # enforce monotonicity with a running max, cap at 1, restore order.
        valid_p = p_values[valid_mask]
        m = len(valid_p)
        order = np.argsort(valid_p)
        stepped = np.maximum.accumulate(valid_p[order] * (m - np.arange(m)))
        adjusted = np.empty(m)
        adjusted[order] = np.minimum(stepped, 1.0)
        p_adjusted[valid_mask] = adjusted
    else:  # "none"
        p_adjusted = p_values.astype(float)

    results_df = results_df.with_columns([
        pl.Series('p_adjusted', p_adjusted),
        pl.Series('significant', p_adjusted < alpha),
    ])

    metadata = {
        'test_type': test_type,
        'alpha': alpha,
        'correction': correction,
        'n_comparisons': n_comparisons,
        'overall_test': overall_test_name,
        'overall_stat': overall_stat,
        'overall_p_value': overall_p,
    }
    return results_df, metadata
def compute_ranking_significance(
    self,
    data: pl.LazyFrame | pl.DataFrame,
    alpha: float = 0.05,
    correction: str = "bonferroni",
) -> tuple[pl.DataFrame, dict]:
    """Compute statistical significance for ranking data (e.g., Top 3 Voices).

    Original use-case: "Test whether voices are ranked significantly differently
    based on the distribution of 1st, 2nd, 3rd place votes."

    This function takes raw ranking data (rows = respondents, columns = voices,
    values = rank 1/2/3 or null) and performs:
    1. An overall chi-square test on the rank x voice contingency table. Ranks
       or voices with zero total votes are excluded from this test, because
       they would produce zero expected frequencies.
    2. Pairwise two-proportion z-tests comparing each voice's Rank 1 share.
       The share is Rank 1 votes / times the voice received any of ranks 1-3.

    Args:
        data: Pre-fetched ranking data from get_top_3_voices() or get_character_ranking().
            Expected format: rows are respondents, columns are voices/characters,
            values are 1, 2, 3 (rank) or null (not ranked).
        alpha: Significance level (default 0.05)
        correction: Multiple comparison correction method:
            - "bonferroni": Bonferroni correction (conservative)
            - "holm": Holm-Bonferroni (less conservative)
            - "none": No correction

    Returns:
        tuple: (pairwise_df, metadata)
            - pairwise_df: DataFrame with columns ['group1', 'group2', 'p_value',
              'rank1_count1', 'rank1_count2', 'rank1_pct1', 'rank1_pct2',
              'total1', 'total2', 'p_adjusted', 'significant'], sorted by p_value.
            - metadata: dict with 'test_type', 'alpha', 'correction',
              'n_comparisons', 'chi2_stat', 'chi2_p_value', 'chi2_dof',
              'overall_test', 'overall_stat', 'overall_p_value' and
              'contingency_table' (label -> [rank1, rank2, rank3] counts).
              The chi-square values are NaN (dof 0) when fewer than 2 ranks or
              voices have any votes.

    Raises:
        ValueError: if correction is unknown or fewer than 2 ranking columns exist.

    Example:
        >>> ranking_data, _ = S.get_top_3_voices(data)
        >>> pairwise_df, meta = S.compute_ranking_significance(ranking_data)
        >>> # See which voices have significantly different Rank 1 proportions
        >>> print(pairwise_df.filter(pl.col('significant') == True))
    """
    from scipy import stats as scipy_stats
    import numpy as np

    if correction not in ("bonferroni", "holm", "none"):
        raise ValueError(f"Unknown correction: {correction}")

    df = data.collect() if isinstance(data, pl.LazyFrame) else data

    # Get ranking columns (exclude _recordId)
    ranking_cols = [c for c in df.columns if c != '_recordId']
    if len(ranking_cols) < 2:
        raise ValueError(f"Need at least 2 ranking columns, found {len(ranking_cols)}")

    # Count how many times each voice received rank 1, 2 and 3
    contingency_data = {}
    for col in ranking_cols:
        label = self._clean_voice_label(col)
        contingency_data[label] = [df.filter(pl.col(col) == rank).height for rank in (1, 2, 3)]

    labels = list(contingency_data.keys())
    contingency_table = np.array([contingency_data[lbl] for lbl in labels]).T  # 3 x n_voices

    # Overall chi-square test (is rank distribution independent of voice?).
    # chi2_contingency raises on zero expected frequencies, which occur whenever
    # a rank row or voice column sums to zero (e.g. a voice nobody ranked in a
    # small subgroup), so drop those empty rows/columns first.
    reduced = contingency_table[contingency_table.sum(axis=1) > 0]
    reduced = reduced[:, reduced.sum(axis=0) > 0]
    if reduced.shape[0] >= 2 and reduced.shape[1] >= 2:
        chi2_stat, chi2_p, chi2_dof, _ = scipy_stats.chi2_contingency(reduced)
    else:
        chi2_stat, chi2_p, chi2_dof = float('nan'), float('nan'), 0

    # Pairwise two-proportion z-tests on Rank 1 share
    results = []
    n_comparisons = len(labels) * (len(labels) - 1) // 2
    for i, label1 in enumerate(labels):
        for label2 in labels[i + 1:]:
            r1_count1 = contingency_data[label1][0]  # Rank 1 votes for voice 1
            r1_count2 = contingency_data[label2][0]  # Rank 1 votes for voice 2
            # Total times each voice was ranked (1st + 2nd + 3rd)
            total1 = sum(contingency_data[label1])
            total2 = sum(contingency_data[label2])
            pct1 = r1_count1 / total1 if total1 > 0 else 0
            pct2 = r1_count2 / total2 if total2 > 0 else 0

            # H0: p1 = p2 (both voices have the same proportion of Rank 1)
            p_value = 1.0
            if total1 > 0 and total2 > 0 and (r1_count1 + r1_count2) > 0:
                p_pooled = (r1_count1 + r1_count2) / (total1 + total2)
                se = np.sqrt(p_pooled * (1 - p_pooled) * (1 / total1 + 1 / total2))
                if se > 0:
                    z_stat = (pct1 - pct2) / se
                    # Two-tailed; sf() keeps precision where 1 - cdf() would round to 0
                    p_value = 2 * scipy_stats.norm.sf(abs(z_stat))

            results.append({
                'group1': label1,
                'group2': label2,
                'p_value': float(p_value),
                'rank1_count1': r1_count1,
                'rank1_count2': r1_count2,
                'rank1_pct1': round(pct1 * 100, 1),
                'rank1_pct2': round(pct2 * 100, 1),
                'total1': total1,
                'total2': total2,
            })

    results_df = pl.DataFrame(results)

    # Multiple comparison correction (every pair has a p-value here)
    p_values = results_df['p_value'].to_numpy().astype(float)
    if correction == "bonferroni":
        p_adjusted = np.minimum(p_values * n_comparisons, 1.0)
    elif correction == "holm":
        # Step-down: j-th smallest p times (m - j), running max, cap at 1
        m = len(p_values)
        order = np.argsort(p_values)
        stepped = np.maximum.accumulate(p_values[order] * (m - np.arange(m)))
        p_adjusted = np.empty(m)
        p_adjusted[order] = np.minimum(stepped, 1.0)
    else:  # "none"
        p_adjusted = p_values.copy()

    results_df = results_df.with_columns([
        pl.Series('p_adjusted', p_adjusted),
        pl.Series('significant', p_adjusted < alpha),
    ])
    # Sort by p_value for easier inspection
    results_df = results_df.sort('p_value')

    metadata = {
        'test_type': 'proportion_z_test',
        'alpha': alpha,
        'correction': correction,
        'n_comparisons': n_comparisons,
        'chi2_stat': chi2_stat,
        'chi2_p_value': chi2_p,
        'chi2_dof': chi2_dof,
        'overall_test': 'Chi-square',
        'overall_stat': chi2_stat,
        'overall_p_value': chi2_p,
        'contingency_table': {label: contingency_data[label] for label in labels},
    }
    return results_df, metadata
def process_speaking_style_data(
    df: Union[pl.LazyFrame, pl.DataFrame],
    trait_map: dict[str, str]
) -> pl.DataFrame:
    """
    Process speaking style columns from wide to long format and map trait descriptions.

    Parses columns with format: SS_{StyleGroup}__{Voice}__{ChoiceID}
    Example: SS_Orange_Red__V14__Choice_1

    Parameters
    ----------
    df : pl.LazyFrame or pl.DataFrame
        Input dataframe containing '_recordId' and SS_* columns.
    trait_map : dict
        Dictionary mapping column names to trait descriptions.
        Keys should be full column names like "SS_Orange_Red__V14__Choice_1".
        Descriptions of the form "Left : Right" are split on ':' into anchors.
        Only the first description seen per (Style_Group, Choice_ID) is used.

    Returns
    -------
    pl.DataFrame
        Long-format dataframe with columns:
        _recordId, Style_Group, Voice, Choice_ID, score (Int64; non-castable
        values become null), plus Description, Left_Anchor, Right_Anchor
        when at least one trait_map key matches the column pattern.
    """
    # Normalize input to LazyFrame
    lf = df.lazy() if isinstance(df, pl.DataFrame) else df

    # 1. Melt SS_ columns
    melted = lf.melt(
        id_vars=["_recordId"],
        value_vars=pl.col("^SS_.*$"),
        variable_name="full_col_name",
        value_name="score"
    )

    # 2. Extract components from column name
    # Regex captures: Style_Group (e.g. SS_Orange_Red), Voice (e.g. V14), Choice_ID (e.g. Choice_1)
    pattern = r"^(?P<Style_Group>SS_.+?)__(?P<Voice>.+?)__(?P<Choice_ID>Choice_\d+)$"
    processed = melted.with_columns(
        pl.col("full_col_name").str.extract_groups(pattern)
    ).unnest("full_col_name")

    # Cast score before any return path so the output dtype does not depend
    # on whether trait_map produced a mapping.
    processed = processed.with_columns(
        pl.col("score").cast(pl.Int64, strict=False)
    )

    # 3. Build the (Style_Group, Choice_ID) -> Description lookup
    compiled = re.compile(pattern)
    mapping_data = []
    seen = set()
    for col_name, desc in trait_map.items():
        match = compiled.match(col_name)
        if not match:
            continue
        key = (match["Style_Group"], match["Choice_ID"])
        if key in seen:
            continue
        seen.add(key)
        # Parse description into anchors if possible (Left : Right)
        parts = desc.split(':')
        left_anchor = parts[0].strip() if len(parts) > 0 else ""
        right_anchor = parts[1].strip() if len(parts) > 1 else ""
        mapping_data.append({
            "Style_Group": key[0],
            "Choice_ID": key[1],
            "Description": desc,
            "Left_Anchor": left_anchor,
            "Right_Anchor": right_anchor
        })

    if not mapping_data:
        return processed.collect()

    # 4. Join Data with Mapping
    result = processed.join(
        pl.LazyFrame(mapping_data),
        on=["Style_Group", "Choice_ID"],
        how="left"
    )
    return result.collect()
def process_voice_scale_data(
    df: Union[pl.LazyFrame, pl.DataFrame]
) -> pl.DataFrame:
    """
    Reshape Voice Scale 1-10 columns from wide to long format.

    Every column matching ``^Voice_Scale_1_10__V.*$`` (e.g. Voice_Scale_1_10__V14)
    is unpivoted into one row per respondent and column, keyed by '_recordId'.
    The voice label is "V" followed by the digits of the first "V<digits>"
    occurrence in the column name.

    Returns
    -------
    pl.DataFrame
        Long-format dataframe with columns:
        _recordId, Voice, Voice_Scale_Score (Float64; non-castable values become null)
    """
    lazy_frame = df if not isinstance(df, pl.DataFrame) else df.lazy()

    long_frame = lazy_frame.melt(
        id_vars=["_recordId"],
        value_vars=pl.col("^Voice_Scale_1_10__V.*$"),
        variable_name="full_col_name",
        value_name="Voice_Scale_Score",
    )

    voice_label = pl.lit("V") + pl.col("full_col_name").str.extract(r"V(\d+)", 1)
    # Scores stay floating point (the source data is f64)
    score = pl.col("Voice_Scale_Score").cast(pl.Float64, strict=False)

    return long_frame.select([
        "_recordId",
        voice_label.alias("Voice"),
        score,
    ]).collect()
def join_voice_and_style_data(
    processed_style_data: pl.DataFrame,
    processed_voice_data: pl.DataFrame
) -> pl.DataFrame:
    """
    Combine processed Speaking Style rows with Voice Scale 1-10 rows.

    Parameters
    ----------
    processed_style_data : pl.DataFrame
        Output of process_speaking_style_data.
    processed_voice_data : pl.DataFrame
        Output of process_voice_scale_data.

    Returns
    -------
    pl.DataFrame
        Inner join on (_recordId, Voice): only respondent/voice pairs present
        in both inputs are kept, with the columns of both.
    """
    join_keys = ["_recordId", "Voice"]
    merged = processed_style_data.join(processed_voice_data, on=join_keys, how="inner")
    return merged
def transform_speaking_style_color_correlation(
    joined_df: pl.LazyFrame | pl.DataFrame,
    speaking_styles: dict[str, list[str]],
    target_column: str = "Voice_Scale_Score"
) -> tuple[pl.DataFrame, dict | None]:
    """Aggregate speaking style correlation by color (Green, Blue, Orange, Red).

    Original use-case: "I want to create high-level correlation plots between
    'green, blue, orange, red' speaking styles and the 'voice scale scores'.
    I want to go to one plot with one bar for each color."

    For each trait, the Pearson correlation between 'score' and target_column
    is computed over rows whose Right_Anchor equals the trait, after dropping
    nulls. Each color's value is the mean of its traits' correlations.
    Traits with fewer than 2 valid rows, or whose correlation is undefined
    (None/NaN, e.g. constant scores), are skipped. Colors without any usable
    trait are omitted.

    Parameters
    ----------
    joined_df : pl.LazyFrame or pl.DataFrame
        Pre-fetched data from joining speaking style data with target data.
        Must have columns: Right_Anchor, score, and the target_column
    speaking_styles : dict
        Dictionary mapping color names to their constituent traits.
        Typically imported from speaking_styles.SPEAKING_STYLES
    target_column : str
        The column to correlate against speaking style scores.
        Default: "Voice_Scale_Score" (for voice scale 1-10)
        Alternative: "Ranking_Points" (for top 3 voice ranking)

    Returns
    -------
    tuple[pl.DataFrame, dict | None]
        (DataFrame with columns [Color, correlation, n_traits], None).
        The columns are present even when no color qualifies.
    """
    import math

    if isinstance(joined_df, pl.LazyFrame):
        joined_df = joined_df.collect()

    color_correlations = []
    for color, traits in speaking_styles.items():
        trait_corrs = []
        for trait in traits:
            valid_data = (
                joined_df
                .filter(pl.col("Right_Anchor") == trait)
                .select(["score", target_column])
                .drop_nulls()
            )
            if valid_data.height <= 1:
                continue
            corr_val = valid_data.select(pl.corr("score", target_column)).item()
            # pl.corr can return NaN rather than None (e.g. one side constant);
            # a single NaN would otherwise turn the whole color average into NaN.
            if corr_val is None or math.isnan(corr_val):
                continue
            trait_corrs.append(corr_val)

        # Average across all usable traits for this color
        if trait_corrs:
            color_correlations.append({
                "Color": color,
                "correlation": sum(trait_corrs) / len(trait_corrs),
                "n_traits": len(trait_corrs),
            })

    if not color_correlations:
        # Keep the documented schema so downstream column access still works
        empty = pl.DataFrame(
            schema={"Color": pl.Utf8, "correlation": pl.Float64, "n_traits": pl.Int64}
        )
        return empty, None
    return pl.DataFrame(color_correlations), None
def process_voice_ranking_data(
    df: Union[pl.LazyFrame, pl.DataFrame]
) -> pl.DataFrame:
    """
    Reshape Top-3 voice ranking columns to long format and score them as points.

    Every column matching ``^Top_3_Voices_ranking__V.*$`` is unpivoted into one
    row per respondent and column. The voice label is "V" followed by the
    digits of the first "V<digits>" occurrence in the column name.

    Points: rank 1 -> 3, rank 2 -> 2, rank 3 -> 1, anything else (including
    null, i.e. not ranked) -> 0.

    Returns
    -------
    pl.DataFrame
        Long-format dataframe with columns:
        _recordId, Voice, Ranking_Points
    """
    lazy_frame = df if not isinstance(df, pl.DataFrame) else df.lazy()

    long_frame = lazy_frame.melt(
        id_vars=["_recordId"],
        value_vars=pl.col("^Top_3_Voices_ranking__V.*$"),
        variable_name="full_col_name",
        value_name="rank",
    )

    rank = pl.col("rank")
    points = (
        pl.when(rank == 1).then(3)
        .when(rank == 2).then(2)
        .when(rank == 3).then(1)
        .otherwise(0)
    )
    voice_label = pl.lit("V") + pl.col("full_col_name").str.extract(r"V(\d+)", 1)

    return long_frame.select([
        "_recordId",
        voice_label.alias("Voice"),
        points.alias("Ranking_Points"),
    ]).collect()
def split_consumer_groups(df: Union[pl.LazyFrame, pl.DataFrame], col: str = "Consumer") -> dict[str, pl.DataFrame]:
    """
    Split dataframe into groups based on a column.

    If col is 'Consumer', a trailing '_A' or '_B' is removed first, which
    combines A/B subgroups (e.g. Mass_A + Mass_B -> Mass).
    For other columns, it splits by unique values as-is.

    Each returned frame carries an extra '<col>_Group' column holding the
    group key. Rows whose group key is null are not assigned to any group.
    Groups are returned in sorted key order, so iteration order is stable
    across runs.

    Raises
    ------
    ValueError
        If col is not a column of df.
    """
    if isinstance(df, pl.LazyFrame):
        df = df.collect()

    if col not in df.columns:
        raise ValueError(f"Column '{col}' not found in DataFrame")

    group_col_alias = f"{col}_Group"
    group_expr = pl.col(col)
    if col == "Consumer":
        # Strip a trailing _A or _B suffix to merge the subgroups
        group_expr = group_expr.str.replace(r"_[AB]$", "")
    df_clean = df.with_columns(group_expr.alias(group_col_alias))

    # unique() does not guarantee order; sort so the dict (and any plots
    # generated by iterating it) come out in the same order every run.
    unique_groups = df_clean[group_col_alias].drop_nulls().unique().sort().to_list()
    return {
        group: df_clean.filter(pl.col(group_col_alias) == group)
        for group in unique_groups
    }