2198 lines
85 KiB
Python
2198 lines
85 KiB
Python
import ast
import hashlib
import json
import os
import re
from io import BytesIO
from pathlib import Path
from typing import Union

import imagehash
import pandas as pd
import polars as pl
from PIL import Image
from pptx import Presentation
from pptx.enum.shapes import MSO_SHAPE_TYPE

from plots import QualtricsPlotsMixin
def image_alt_text_generator(fpath, include_dataset_dirname=False) -> str:
    """Convert an image file path to alt text.

    Args:
        fpath (str or Path): Path to an image file; must start with 'figures/'.
        include_dataset_dirname (bool): Whether to include the dataset directory
            name in the alt text. Recommended to keep False, so that the images
            do not get tied to a specific dataset export. (Defeats the purpose of
            assigning alt text to be able to update images when new datasets are
            exported.)

    Returns:
        str: POSIX-style relative path to use as alt text.

    Raises:
        ValueError: If the path does not start with a 'figures' directory.
    """
    if not isinstance(fpath, Path):
        fpath = Path(fpath)

    fparts = fpath.parts
    # Raise instead of assert: asserts are stripped under `python -O`, which
    # would silently accept malformed paths.
    if not fparts or fparts[0] != 'figures':
        raise ValueError("Image file path must start with 'figures'")

    # Drop 'figures/' always; additionally drop the dataset dirname unless asked
    # to keep it. (Single slice start collapses the duplicated branches.)
    start = 1 if include_dataset_dirname else 2
    return Path('/'.join(fparts[start:])).as_posix()
def _get_shape_alt_text(shape) -> str:
|
||
"""
|
||
Extract alt text from a PowerPoint shape.
|
||
|
||
Args:
|
||
shape: A python-pptx shape object.
|
||
|
||
Returns:
|
||
str: The alt text (descr attribute) or empty string if not found.
|
||
"""
|
||
try:
|
||
# Check for common property names used by python-pptx elements to store non-visual props
|
||
# nvPicPr (Picture), nvSpPr (Shape/Placeholder), nvGrpSpPr (Group),
|
||
# nvGraphicFramePr (GraphicFrame), nvCxnSpPr (Connector)
|
||
nvPr = None
|
||
for attr in ['nvPicPr', 'nvSpPr', 'nvGrpSpPr', 'nvGraphicFramePr', 'nvCxnSpPr']:
|
||
if hasattr(shape._element, attr):
|
||
nvPr = getattr(shape._element, attr)
|
||
break
|
||
|
||
if nvPr is not None and hasattr(nvPr, 'cNvPr'):
|
||
return nvPr.cNvPr.get("descr", "")
|
||
except Exception:
|
||
pass
|
||
return ""
|
||
|
||
|
||
def pptx_replace_images_from_directory(
    presentation_path: Union[str, Path],
    image_source_dir: Union[str, Path],
    save_path: Union[str, Path] = None
) -> dict:
    """
    Replace all images in a PowerPoint presentation using images from a directory
    where subdirectory/filename paths match the alt_text of each image.

    This function scans all images in the presentation, extracts their alt_text,
    and looks for a matching image file in the source directory. The alt_text
    should be a relative path (e.g., "All_Respondents/chart_name.png") that
    corresponds to the directory structure under image_source_dir.

    Args:
        presentation_path (str/Path): Path to the source .pptx file.
        image_source_dir (str/Path): Root directory containing replacement images.
            The directory structure should mirror the alt_text paths.
            Example: if alt_text is "All_Respondents/voice_scale.png", the
            replacement image should be at image_source_dir/All_Respondents/voice_scale.png
        save_path (str/Path, optional): Path to save the modified presentation.
            If None, overwrites the input file.

    Returns:
        dict: Summary with keys:
            - 'replaced': List of dicts with slide number, shape name, and matched path
            - 'not_found': List of dicts with slide number, shape name, and alt_text
              (also receives entries when a replacement attempt raised, with 'error')
            - 'no_alt_text': List of dicts with slide number and shape name
            - 'total_images': Total number of picture shapes processed

    Raises:
        FileNotFoundError: If the presentation or the image source directory is missing.

    Example:
        >>> pptx_replace_images_from_directory(
        ...     "presentation.pptx",
        ...     "figures/2-2-26/",
        ...     "presentation_updated.pptx"
        ... )

    Notes:
        - Alt text should be set using update_ppt_alt_text() or image_alt_text_generator()
        - Images without alt_text are skipped
        - Original image position, size, and aspect ratio are preserved
    """
    presentation_path = Path(presentation_path)
    image_source_dir = Path(image_source_dir)

    if save_path is None:
        save_path = presentation_path
    else:
        save_path = Path(save_path)

    if not presentation_path.exists():
        raise FileNotFoundError(f"Presentation not found: {presentation_path}")
    if not image_source_dir.exists():
        raise FileNotFoundError(f"Image source directory not found: {image_source_dir}")

    # Build a lookup of all available images in the source directory,
    # keyed by POSIX relative path (this is what the alt text should equal).
    available_images = {}
    for img_path in image_source_dir.rglob("*"):
        if img_path.is_file() and img_path.suffix.lower() in {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}:
            rel_path = img_path.relative_to(image_source_dir).as_posix()
            available_images[rel_path] = img_path

    print(f"Found {len(available_images)} images in source directory")

    # Open presentation
    prs = Presentation(presentation_path)

    # Track results
    results = {
        'replaced': [],
        'not_found': [],
        'no_alt_text': [],
        'total_images': 0
    }

    total_slides = len(prs.slides)
    print(f"Processing {total_slides} slides...")

    for slide_idx, slide in enumerate(prs.slides):
        slide_num = slide_idx + 1

        # Use recursive iterator to find all pictures including those in groups.
        # Materialize the list first because replacement mutates the slide XML.
        picture_shapes = list(_iter_picture_shapes(slide.shapes))

        for shape in picture_shapes:
            results['total_images'] += 1
            shape_name = shape.name or f"Unnamed (ID: {getattr(shape, 'shape_id', 'unknown')})"

            # Get alt text
            alt_text = _get_shape_alt_text(shape)

            if not alt_text:
                results['no_alt_text'].append({
                    'slide': slide_num,
                    'shape_name': shape_name
                })
                continue

            # Look for matching image: exact key first, then — if the alt text
            # carries no extension — try the common image extensions.
            matched_path = available_images.get(alt_text)
            if matched_path is None and not Path(alt_text).suffix:
                for ext in ['.png', '.jpg', '.jpeg', '.gif']:
                    test_key = f"{alt_text}{ext}"
                    if test_key in available_images:
                        matched_path = available_images[test_key]
                        break

            if matched_path is None:
                results['not_found'].append({
                    'slide': slide_num,
                    'shape_name': shape_name,
                    'alt_text': alt_text
                })
                continue

            # Replace the image
            try:
                # Record coordinates so the new picture fills the same bounding box
                left, top, width, height = shape.left, shape.top, shape.width, shape.height

                # Remove old shape from XML
                old_element = shape._element
                old_element.getparent().remove(old_element)

                # Add new image at the same position/size
                new_shape = slide.shapes.add_picture(str(matched_path), left, top, width, height)

                # Preserve the alt text on the new shape. Reuses the module
                # helper instead of duplicating the nvPr/cNvPr lookup inline.
                _set_shape_alt_text(new_shape, alt_text)

                results['replaced'].append({
                    'slide': slide_num,
                    'shape_name': shape_name,
                    'matched_path': str(matched_path)
                })
                print(f"Slide {slide_num}: Replaced '{alt_text}'")

            except Exception as e:
                # Record failed replacements alongside unmatched alt text.
                results['not_found'].append({
                    'slide': slide_num,
                    'shape_name': shape_name,
                    'alt_text': alt_text,
                    'error': str(e)
                })

    # Save presentation
    prs.save(save_path)

    # Print summary
    print("\n" + "=" * 80)
    if results['replaced']:
        print(f"✓ Saved updated presentation to {save_path} with {len(results['replaced'])} replacements.")
    else:
        print("No images matched or required updates.")

    if results['not_found']:
        print(f"\n⚠ {len(results['not_found'])} image(s) not found in source directory:")
        for item in results['not_found']:
            print(f" • Slide {item['slide']}: '{item.get('alt_text', 'N/A')}'")

    if results['no_alt_text']:
        print(f"\n⚠ {len(results['no_alt_text'])} image(s) without alt text (skipped):")
        for item in results['no_alt_text']:
            print(f" • Slide {item['slide']}: '{item['shape_name']}'")

    if not results['not_found'] and not results['no_alt_text']:
        print("\n✓ All images replaced successfully!")
    print("=" * 80)

    return results
def pptx_replace_named_image(presentation_path, target_tag, new_image_path, save_path):
    """
    Finds and replaces a specific image in a PowerPoint presentation while
    preserving its original position, size, and aspect ratio.

    This function performs a 'surgical' replacement: it records the coordinates
    of the existing image, removes it from the slide's XML, and inserts a
    new image into the exact same bounding box. It identifies the target
    image by searching for a specific string within the Shape Name
    (Selection Pane) or Alt Text.

    Note: For batch replacement of all images using a directory structure,
    use pptx_replace_images_from_directory() instead.

    Args:
        presentation_path (str): The file path to the source .pptx file.
        target_tag (str): The unique identifier to look for (e.g., 'HERO_IMAGE').
            This is case-sensitive and checks both the shape name and alt text.
        new_image_path (str): The file path to the new image (PNG, JPG, etc.).
        save_path (str): The file path where the modified presentation will be saved.

    Returns:
        None: Saves the file directly to the provided save_path.

    Raises:
        FileNotFoundError: If the source presentation or new image is not found.
        PermissionError: If the save_path is currently open or locked.
    """
    prs = Presentation(presentation_path)

    for i, slide in enumerate(prs.slides):
        # Iterate over a list copy of shapes to safely modify the slide during iteration
        print(f"Processing Slide {i + 1}...")
        print(f"Total Shapes: {len(slide.shapes)} shapes")

        for shape in list(slide.shapes):
            print(f"Checking shape: {shape.name} of type {shape.shape_type}...")

            shape_name = shape.name or ""
            alt_text = _get_shape_alt_text(shape)

            print(f"Alt Text for shape '{shape_name}': {alt_text}")

            # Substring match against either identifier; every match on any
            # slide is replaced (not just the first).
            if target_tag in shape_name or target_tag in alt_text:
                print(f"Found it! Replacing {shape_name}...")

                try:
                    # Record coordinates
                    left, top, width, height = shape.left, shape.top, shape.width, shape.height

                    # Remove old shape
                    old_element = shape._element
                    old_element.getparent().remove(old_element)

                    # Add new image at the same spot
                    slide.shapes.add_picture(str(new_image_path), left, top, width, height)
                # Shapes without position/size attributes (e.g. some placeholders)
                # raise AttributeError above; skip them rather than abort the run.
                except AttributeError:
                    print(f"Could not replace {shape_name} - might be missing dimensions.")

            else:
                print(f"Skipping shape '{shape_name}' with alt text '{alt_text}'")

    prs.save(save_path)
    print(f"Successfully saved to {save_path}")
def _calculate_file_sha1(file_path: Union[str, Path]) -> str:
|
||
"""Calculate SHA1 hash of a file."""
|
||
sha1 = hashlib.sha1()
|
||
with open(file_path, 'rb') as f:
|
||
while True:
|
||
data = f.read(65536)
|
||
if not data:
|
||
break
|
||
sha1.update(data)
|
||
return sha1.hexdigest()
|
||
|
||
|
||
def _calculate_perceptual_hash(image_source: Union[str, Path, bytes]) -> str:
    """
    Calculate perceptual hash of an image based on visual content.

    Uses pHash (perceptual hash) which is robust against:
    - Metadata differences
    - Minor compression differences
    - Small color/contrast variations

    Args:
        image_source: File path to image or raw image bytes.

    Returns:
        str: Hexadecimal string representation of the perceptual hash.
    """
    # Raw bytes need a file-like wrapper; paths can be opened directly.
    source = BytesIO(image_source) if isinstance(image_source, bytes) else image_source
    img = Image.open(source)

    # Normalize mode (handles RGBA, palette 'P' mode, etc.) before hashing.
    if img.mode not in ('RGB', 'L'):
        img = img.convert('RGB')

    return str(imagehash.phash(img))
def _build_image_hash_map(root_dir: Union[str, Path], use_perceptual_hash: bool = True) -> dict:
    """
    Recursively walk the directory and build a map of image hashes to file paths.
    Only includes common image extensions.

    Args:
        root_dir: Root directory to scan for images.
        use_perceptual_hash: If True, uses perceptual hashing (robust against metadata
            differences). If False, uses SHA1 byte hashing (exact match only).

    Returns:
        dict: Mapping of hash strings to file paths. If two files hash
        identically, the one visited last wins.
    """
    hash_map = {}
    valid_extensions = {'.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif'}

    root = Path(root_dir)
    hash_type = "perceptual" if use_perceptual_hash else "SHA1"
    print(f"Building image hash map from {root} using {hash_type} hashing...")

    count = 0
    # Path.rglob replaces os.walk for consistency with the directory scan in
    # pptx_replace_images_from_directory().
    for file_path in root.rglob("*"):
        if not file_path.is_file() or file_path.suffix.lower() not in valid_extensions:
            continue
        try:
            if use_perceptual_hash:
                file_hash = _calculate_perceptual_hash(file_path)
            else:
                file_hash = _calculate_file_sha1(file_path)
            # We store the absolute path for reference, but we might just need the path relative to project for alt text
            hash_map[file_hash] = file_path
            count += 1
        except Exception as e:
            # A single unreadable/corrupt image shouldn't abort the whole scan.
            print(f"Error hashing {file_path}: {e}")

    print(f"Indexed {count} images.")
    return hash_map
def _iter_picture_shapes(shapes):
    """
    Recursively iterate over shapes and yield those that are pictures
    (have an 'image' property), descending into groups.
    """
    for current in shapes:
        if current.shape_type == MSO_SHAPE_TYPE.GROUP:
            # Groups can nest arbitrarily deep; recurse into their children.
            yield from _iter_picture_shapes(current.shapes)
        elif hasattr(current, 'image'):
            # Pictures (and image-filled placeholders) expose an 'image' property.
            yield current
def _set_shape_alt_text(shape, alt_text: str):
|
||
"""
|
||
Set alt text (descr attribute) for a PowerPoint shape.
|
||
"""
|
||
nvPr = None
|
||
# Check for common property names used by python-pptx elements
|
||
for attr in ['nvPicPr', 'nvSpPr', 'nvGrpSpPr', 'nvGraphicFramePr', 'nvCxnSpPr']:
|
||
if hasattr(shape._element, attr):
|
||
nvPr = getattr(shape._element, attr)
|
||
break
|
||
|
||
if nvPr and hasattr(nvPr, 'cNvPr'):
|
||
nvPr.cNvPr.set("descr", alt_text)
|
||
|
||
|
||
def update_ppt_alt_text(ppt_path: Union[str, Path], image_source_dir: Union[str, Path], output_path: Union[str, Path] = None, use_perceptual_hash: bool = True):
    """
    Updates the alt text of images in a PowerPoint presentation.

    1. First pass: Validates existing alt-text format (<filter>/<filename>).
       - Fixes full paths by keeping only the last two parts.
       - Clears invalid alt-text.
    2. Second pass: If images are missing alt-text, matches them against source directory
       using perceptual hash or SHA1.

    Args:
        ppt_path (str/Path): Path to the PowerPoint file.
        image_source_dir (str/Path): Directory containing source images to match against.
        output_path (str/Path, optional): Path to save the updated presentation.
            If None, overwrites the input file.
        use_perceptual_hash (bool): If True (default), uses perceptual hashing which
            matches images based on visual content (robust against metadata differences,
            re-compression, etc.). If False, uses SHA1 byte hashing (exact file match only).

    Returns:
        None. The presentation is saved only when at least one update was made.
    """
    if output_path is None:
        output_path = ppt_path

    # Open Presentation
    try:
        prs = Presentation(ppt_path)
    except Exception as e:
        # Unreadable/corrupt file: report and bail out without raising.
        print(f"Error opening presentation {ppt_path}: {e}")
        return

    updates_count = 0
    images_needing_match = []

    slides = list(prs.slides)
    total_slides = len(slides)

    print(f"Scanning {total_slides} slides for existing alt-text...")

    # Pass 1: Scan and clean existing alt-text
    for i, slide in enumerate(slides):
        picture_shapes = list(_iter_picture_shapes(slide.shapes))

        for shape in picture_shapes:
            alt_text = _get_shape_alt_text(shape)
            has_valid_alt = False

            if alt_text:
                # Handle potential path separators and whitespace
                clean_alt = alt_text.strip().replace('\\', '/')
                parts = clean_alt.split('/')

                # Check if it looks like a path/file reference (at least 2 parts like dir/file)
                if len(parts) >= 2:
                    # Enforce format: keep last 2 parts (e.g. filter/image.png)
                    new_alt = '/'.join(parts[-2:])

                    if new_alt != alt_text:
                        print(f"Slide {i+1}: Fixing alt-text format: '{alt_text}' -> '{new_alt}'")
                        _set_shape_alt_text(shape, new_alt)
                        updates_count += 1

                    has_valid_alt = True
                else:
                    # User requested deleting other cases that do not meet format
                    # If it's single word or doesn't look like our path format
                    pass # logic below handles this

            if not has_valid_alt:
                if alt_text:
                    # Invalid format: clear so pass 2 re-derives it by hash.
                    print(f"Slide {i+1}: Invalid/Legacy alt-text '{alt_text}'. Clearing for re-matching.")
                    _set_shape_alt_text(shape, "")
                    updates_count += 1

                # Queue for hash matching
                shape_id = getattr(shape, 'shape_id', getattr(shape, 'id', 'Unknown ID'))
                shape_name = shape.name if shape.name else f"Unnamed Shape (ID: {shape_id})"
                images_needing_match.append({
                    'slide_idx': i, # 0-based
                    'slide_num': i+1,
                    'shape': shape,
                    'shape_name': shape_name
                })

    if not images_needing_match:
        print("\nAll images have valid alt-text format. No hash matching needed.")
        if updates_count > 0:
            prs.save(output_path)
            print(f"✓ Saved updated presentation to {output_path} with {updates_count} updates.")
        else:
            print("Presentation is up to date.")
        return

    # Pass 2: Hash Matching
    print(f"\n{len(images_needing_match)} images missing proper alt-text. Proceeding with hash matching...")

    # Build lookup map of {hash: file_path} only if needed
    image_hash_map = _build_image_hash_map(image_source_dir, use_perceptual_hash=use_perceptual_hash)

    unmatched_images = []

    for item in images_needing_match:
        shape = item['shape']
        slide_num = item['slide_num']

        try:
            # Get image hash (embedded blob for pHash; python-pptx precomputes sha1)
            if use_perceptual_hash:
                current_hash = _calculate_perceptual_hash(shape.image.blob)
            else:
                current_hash = shape.image.sha1

            if current_hash in image_hash_map:
                original_path = image_hash_map[current_hash]

                # Generate Alt Text
                try:
                    # Try to relativize to CWD if capable
                    pass_path = original_path
                    try:
                        pass_path = original_path.relative_to(Path.cwd())
                    except ValueError:
                        # Path lies outside CWD; keep it as-is.
                        pass

                    new_alt_text = image_alt_text_generator(pass_path)

                    print(f"Slide {slide_num}: Match found! Assigning alt-text '{new_alt_text}'")
                    _set_shape_alt_text(shape, new_alt_text)
                    updates_count += 1

                except Exception as e:
                    # e.g. path does not start with 'figures/'
                    print(f"Error generating alt text for {original_path}: {e}")
            else:
                hash_type = "pHash" if use_perceptual_hash else "SHA1"
                unmatched_images.append({
                    'slide': slide_num,
                    'shape_name': item['shape_name'],
                    'hash_type': hash_type,
                    'hash': current_hash
                })

        except Exception as e:
            print(f"Error processing shape on slide {slide_num}: {e}")

    # Save and Print Summary
    print("\n" + "="*80)
    if updates_count > 0:
        prs.save(output_path)
        print(f"✓ Saved updated presentation to {output_path} with {updates_count} updates.")
    else:
        print("No matches found for missing images.")

    if unmatched_images:
        print(f"\n⚠ {len(unmatched_images)} image(s) could not be matched:")
        for img in unmatched_images:
            print(f" • Slide {img['slide']}: '{img['shape_name']}' ({img['hash_type']}: {img['hash']})")
    else:
        print("\n✓ All images processed successfully!")
    print("="*80)
def extract_voice_label(html_str: str) -> Union[str, None]:
    """
    Extract voice label from HTML string and convert to short format.

    Parameters:
        html_str (str): HTML string containing voice label in format "Voice N"

    Returns:
        str or None: Voice label in format "VN" (e.g., "V14"), or None when no
        "Voice N" pattern is present. (Annotation fixed: the original claimed
        a plain str but returned None on no match.)

    Example:
        >>> extract_voice_label('<span style="...">Voice 14<br />...')
        'V14'
    """
    match = re.search(r'Voice (\d+)', html_str)
    return f"V{match.group(1)}" if match else None
def extract_qid(val):
    """Extracts the 'ImportId' from a string representation of a dictionary.

    Args:
        val: Either a dict-literal string such as
            '{"ImportId":"startDate","timeZone":"America/Denver"}', or an
            already-parsed mapping with an 'ImportId' key.

    Returns:
        The 'ImportId' value.
    """
    if isinstance(val, str) and val.startswith('{') and val.endswith('}'):
        # ast.literal_eval parses only Python literals — unlike eval(), it
        # cannot execute arbitrary code smuggled into an export file.
        val = ast.literal_eval(val)
    return val['ImportId']
def combine_exclusive_columns(df: pl.DataFrame, id_col: str = "_recordId", target_col_name: str = "combined_value") -> pl.DataFrame:
    """
    Combines all columns except id_col into a single column.
    Raises ValueError if more than one column is populated in a single row.
    """
    value_cols = [name for name in df.columns if name != id_col]

    # Exclusivity check: each row may carry at most one populated value.
    populated_per_row = df.select(
        pl.sum_horizontal(pl.col(value_cols).is_not_null())
    ).to_series()

    if (populated_per_row > 1).any():
        raise ValueError("Invalid Data: Multiple columns populated for a single record row.")

    # coalesce picks the single non-null value (or null) per row.
    return df.select([
        pl.col(id_col),
        pl.coalesce(value_cols).alias(target_col_name),
    ])
def calculate_weighted_ranking_scores(df: pl.LazyFrame) -> pl.DataFrame:
    """
    Calculate weighted scores for character or voice rankings.
    Points system: 1st place = 3 pts, 2nd place = 2 pts, 3rd place = 1 pt.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame containing character/ voice ranking columns.

    Returns
    -------
    pl.DataFrame
        DataFrame with columns 'Character' and 'Weighted Score', sorted by score.
    """
    # Accept lazy input transparently.
    if isinstance(df, pl.LazyFrame):
        df = df.collect()

    # All columns except the record id are ranking columns.
    ranking_cols = [c for c in df.columns if c != '_recordId']

    scores = []
    for col in ranking_cols:
        # Weighted total: rank 1 -> 3 pts, rank 2 -> 2 pts, rank 3 -> 1 pt.
        weighted_score = sum(
            df.filter(pl.col(col) == rank).height * points
            for rank, points in ((1, 3), (2, 2), (3, 1))
        )

        # Strip question prefixes/underscores to get a display name.
        clean_name = (
            col.replace('Character_Ranking_', '')
               .replace('Top_3_Voices_ranking__', '')
               .replace('_', ' ')
               .strip()
        )

        scores.append({'Character': clean_name, 'Weighted Score': weighted_score})

    return pl.DataFrame(scores).sort('Weighted Score', descending=True)
def normalize_row_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
    """
    Normalizes values in the specified columns row-wise to 0-10 scale (Min-Max normalization).
    Formula: ((x - row_min) / (row_max - row_min)) * 10

    Nulls are preserved as nulls. If all non-null values in a row are equal (max == min),
    those values become 5.0 (midpoint of the scale).

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe.
    target_cols : list[str]
        List of column names to normalize.

    Returns
    -------
    pl.DataFrame
        DataFrame with target columns normalized row-wise.
    """
    as_float = [pl.col(c).cast(pl.Float64) for c in target_cols]

    # Row-wise extrema across the target columns (nulls ignored by the
    # horizontal reducers).
    lo = pl.min_horizontal(as_float)
    hi = pl.max_horizontal(as_float)
    spread = hi - lo

    scaled = [
        pl.when(spread == 0)
        # Degenerate row (all equal): map non-null values to the midpoint 5.0,
        # keep nulls as nulls.
        .then(pl.when(pl.col(name).is_null()).then(None).otherwise(5.0))
        .otherwise(((pl.col(name).cast(pl.Float64) - lo) / spread) * 10)
        .alias(name)
        for name in target_cols
    ]

    return df.with_columns(scaled)
def normalize_global_values(df: pl.DataFrame, target_cols: list[str]) -> pl.DataFrame:
    """
    Normalizes values in the specified columns globally to 0-10 scale.
    Formula: ((x - global_min) / (global_max - global_min)) * 10
    Ignores null values (NaNs).

    Parameters
    ----------
    df : pl.DataFrame or pl.LazyFrame
        Input frame; a LazyFrame is collected and returned lazy again.
    target_cols : list[str]
        Columns to normalize in place.

    Returns
    -------
    pl.DataFrame or pl.LazyFrame
        Frame with target columns rescaled. Returned unchanged when there is
        nothing to scale (no columns, all-null data, or constant data).
    """
    # Ensure eager for scalar extraction
    was_lazy = isinstance(df, pl.LazyFrame)
    if was_lazy:
        df = df.collect()

    if len(target_cols) == 0:
        return df.lazy() if was_lazy else df

    # Global min/max across all target columns: compute per-column extrema and
    # reduce in Python. This avoids DataFrame.melt(), which polars deprecated
    # in favor of unpivot().
    stats_row = df.select(
        [pl.col(c).cast(pl.Float64).min().alias(f"__min_{c}") for c in target_cols]
        + [pl.col(c).cast(pl.Float64).max().alias(f"__max_{c}") for c in target_cols]
    ).row(0)
    n = len(target_cols)
    col_mins = [v for v in stats_row[:n] if v is not None]
    col_maxs = [v for v in stats_row[n:] if v is not None]
    global_min = min(col_mins) if col_mins else None
    global_max = max(col_maxs) if col_maxs else None

    # Handle edge case where all values are same or none exist
    if global_min is None or global_max is None or global_max == global_min:
        return df.lazy() if was_lazy else df

    global_range = global_max - global_min

    res = df.with_columns([
        (((pl.col(col).cast(pl.Float64) - global_min) / global_range) * 10).alias(col)
        for col in target_cols
    ])

    return res.lazy() if was_lazy else res
class QualtricsSurvey(QualtricsPlotsMixin):
|
||
"""Class to handle Qualtrics survey data."""
|
||
|
||
    def __init__(self, data_path: Union[str, Path], qsf_path: Union[str, Path], figures_dir: Union[str, Path] = None):
        """Initialize the survey wrapper.

        Args:
            data_path (str/Path): Path to the Qualtrics CSV results export.
            qsf_path (str/Path): Path to the survey's QSF definition file.
            figures_dir (str/Path, optional): Directory for saving figures.
                Defaults to 'figures/<export dirname>' derived from data_path.
        """
        if isinstance(data_path, str):
            data_path = Path(data_path)

        if isinstance(qsf_path, str):
            qsf_path = Path(qsf_path)

        self.data_filepath = data_path
        self.qsf_filepath = qsf_path
        # Mapping of ImportID -> question name/description read from the CSV header rows.
        self.qid_descr_map = self._extract_qid_descr_map()
        # Full parsed QSF survey definition.
        self.qsf:dict = self._load_qsf()

        if figures_dir:
            self.fig_save_dir = Path(figures_dir)
        else:
            # get export directory name for saving figures ie if data_path='data/exports/OneDrive_2026-01-21/...' should be 'figures/OneDrive_2026-01-21'
            # NOTE(review): assumes data_path has at least 3 path components — confirm for other layouts.
            self.fig_save_dir = Path('figures') / self.data_filepath.parts[2]

        if not self.fig_save_dir.exists():
            self.fig_save_dir.mkdir(parents=True, exist_ok=True)

        # Cached result of the most recent filter_data() call.
        self.data_filtered = None
        # Default plot dimensions (pixels), presumably consumed by QualtricsPlotsMixin.
        self.plot_height = 500
        self.plot_width = 1000

        # Filter values (None = filter not applied)
        self.filter_age:list = None
        self.filter_gender:list = None
        self.filter_consumer:list = None
        self.filter_ethnicity:list = None
        self.filter_income:list = None
        self.filter_business_owner:list = None # QID4
        self.filter_ai_user:list = None # QID22
        self.filter_investable_assets:list = None # QID16
        self.filter_industry:list = None # QID17
def _extract_qid_descr_map(self) -> dict:
|
||
"""Extract mapping of Qualtrics ImportID to Question Description from results file."""
|
||
|
||
if '1_1-16-2026' in self.data_filepath.as_posix():
|
||
df_questions = pd.read_csv(self.data_filepath, nrows=1)
|
||
df_questions
|
||
|
||
return df_questions.iloc[0].to_dict()
|
||
|
||
|
||
else:
|
||
# First row contains Qualtrics Editor question names (ie 'B_VOICE SEL. 18-8')
|
||
|
||
# Second row which contains the question content
|
||
# Third row contains the Export Metadata (ie '{"ImportId":"startDate","timeZone":"America/Denver"}')
|
||
df_questions = pd.read_csv(self.data_filepath, nrows=2)
|
||
|
||
|
||
|
||
# transpose df_questions
|
||
df_questions = df_questions.T.reset_index()
|
||
df_questions.columns = ['QName', 'Description', 'export_metadata']
|
||
df_questions['ImportID'] = df_questions['export_metadata'].apply(extract_qid)
|
||
|
||
df_questions = df_questions[['ImportID', 'QName', 'Description']]
|
||
|
||
# return dict as {ImportID: [QName, Description]}
|
||
return df_questions.set_index('ImportID')[['QName', 'Description']].T.to_dict()
|
||
|
||
def _load_qsf(self) -> dict:
|
||
"""Load QSF file to extract question metadata if needed."""
|
||
|
||
with open(self.qsf_filepath, 'r', encoding='utf-8') as f:
|
||
qsf_data = json.load(f)
|
||
return qsf_data
|
||
|
||
def _get_qsf_question_by_QID(self, QID: str) -> dict:
|
||
"""Get question metadata from QSF using the Question ID."""
|
||
|
||
q_elem = [elem for elem in self.qsf['SurveyElements'] if elem['PrimaryAttribute'] == QID]
|
||
|
||
if len(q_elem) == 0:
|
||
raise ValueError(f"SurveyElement with 'PrimaryAttribute': '{QID}' not found in QSF.")
|
||
if len(q_elem) > 1:
|
||
raise ValueError(f"Multiple SurveyElements with 'PrimaryAttribute': '{QID}' found in QSF: \n{q_elem}")
|
||
|
||
return q_elem[0]
|
||
|
||
|
||
def load_data(self) -> pl.LazyFrame:
|
||
"""
|
||
Load CSV where column headers are in row 3 as dict strings with ImportId.
|
||
|
||
The 3rd row contains metadata like '{"ImportId":"startDate","timeZone":"America/Denver"}'.
|
||
This function extracts the ImportId from each column and uses it as the column name.
|
||
|
||
Parameters:
|
||
file_path (Path): Path to the CSV file to load.
|
||
|
||
Returns:
|
||
pl.LazyFrame: Polars LazyFrame with ImportId as column names.
|
||
"""
|
||
if '1_1-16-2026' in self.data_filepath.as_posix():
|
||
raise NotImplementedError("This method does not support the '1_1-16-2026' export format.")
|
||
|
||
# Read the 3rd row (index 2) which contains the metadata dictionaries
|
||
# Use header=None to get raw values instead of treating them as column names
|
||
df_meta = pd.read_csv(self.data_filepath, nrows=1, skiprows=2, header=None)
|
||
|
||
# Extract ImportIds from each column value in this row
|
||
new_columns = [extract_qid(val) for val in df_meta.iloc[0]]
|
||
|
||
# Now read the actual data starting from row 4 (skip first 3 rows)
|
||
df = pl.read_csv(self.data_filepath, skip_rows=3)
|
||
|
||
# Rename columns with the extracted ImportIds
|
||
df.columns = new_columns
|
||
|
||
# Store unique values for filters (ignoring nulls) to detect "all selected" state
|
||
self.options_age = sorted(df['QID1'].drop_nulls().unique().to_list()) if 'QID1' in df.columns else []
|
||
self.options_gender = sorted(df['QID2'].drop_nulls().unique().to_list()) if 'QID2' in df.columns else []
|
||
self.options_consumer = sorted(df['Consumer'].drop_nulls().unique().to_list()) if 'Consumer' in df.columns else []
|
||
self.options_ethnicity = sorted(df['QID3'].drop_nulls().unique().to_list()) if 'QID3' in df.columns else []
|
||
self.options_income = sorted(df['QID15'].drop_nulls().unique().to_list()) if 'QID15' in df.columns else []
|
||
self.options_business_owner = sorted(df['QID4'].drop_nulls().unique().to_list()) if 'QID4' in df.columns else []
|
||
self.options_ai_user = sorted(df['QID22'].drop_nulls().unique().to_list()) if 'QID22' in df.columns else []
|
||
self.options_investable_assets = sorted(df['QID16'].drop_nulls().unique().to_list()) if 'QID16' in df.columns else []
|
||
self.options_industry = sorted(df['QID17'].drop_nulls().unique().to_list()) if 'QID17' in df.columns else []
|
||
|
||
return df.lazy()
|
||
|
||
def _get_subset(self, q: pl.LazyFrame, QIDs, rename_cols=True, include_record_id=True) -> pl.LazyFrame:
|
||
"""Extract subset of data based on specific questions."""
|
||
|
||
if include_record_id and '_recordId' not in QIDs:
|
||
QIDs = ['_recordId'] + QIDs
|
||
|
||
if not rename_cols:
|
||
return q.select(QIDs)
|
||
|
||
rename_dict = {qid: self.qid_descr_map[qid]['QName'] for qid in QIDs if qid in self.qid_descr_map and qid != '_recordId'}
|
||
|
||
return q.select(QIDs).rename(rename_dict)
|
||
|
||
def filter_data(self, q: pl.LazyFrame, age:list=None, gender:list=None, consumer:list=None, ethnicity:list=None, income:list=None, business_owner:list=None, ai_user:list=None, investable_assets:list=None, industry:list=None) -> pl.LazyFrame:
|
||
"""Filter data based on provided parameters
|
||
|
||
Possible parameters:
|
||
- age: list of age groups to include (QID1)
|
||
- gender: list (QID2)
|
||
- consumer: list (Consumer)
|
||
- ethnicity: list (QID3)
|
||
- income: list (QID15)
|
||
- business_owner: list (QID4)
|
||
- ai_user: list (QID22)
|
||
- investable_assets: list (QID16)
|
||
- industry: list (QID17)
|
||
|
||
Also saves the result to self.data_filtered.
|
||
"""
|
||
|
||
# Apply filters - skip if empty list (columns with all NULLs produce empty options)
|
||
# OR if all options are selected (to avoid dropping NULLs)
|
||
|
||
self.filter_age = age
|
||
if age is not None and len(age) > 0 and set(age) != set(self.options_age):
|
||
q = q.filter(pl.col('QID1').is_in(age))
|
||
|
||
self.filter_gender = gender
|
||
if gender is not None and len(gender) > 0 and set(gender) != set(self.options_gender):
|
||
q = q.filter(pl.col('QID2').is_in(gender))
|
||
|
||
self.filter_consumer = consumer
|
||
if consumer is not None and len(consumer) > 0 and set(consumer) != set(self.options_consumer):
|
||
q = q.filter(pl.col('Consumer').is_in(consumer))
|
||
|
||
self.filter_ethnicity = ethnicity
|
||
if ethnicity is not None and len(ethnicity) > 0 and set(ethnicity) != set(self.options_ethnicity):
|
||
q = q.filter(pl.col('QID3').is_in(ethnicity))
|
||
|
||
self.filter_income = income
|
||
if income is not None and len(income) > 0 and set(income) != set(self.options_income):
|
||
q = q.filter(pl.col('QID15').is_in(income))
|
||
|
||
self.filter_business_owner = business_owner
|
||
if business_owner is not None and len(business_owner) > 0 and set(business_owner) != set(self.options_business_owner):
|
||
q = q.filter(pl.col('QID4').is_in(business_owner))
|
||
|
||
self.filter_ai_user = ai_user
|
||
if ai_user is not None and len(ai_user) > 0 and set(ai_user) != set(self.options_ai_user):
|
||
q = q.filter(pl.col('QID22').is_in(ai_user))
|
||
|
||
self.filter_investable_assets = investable_assets
|
||
if investable_assets is not None and len(investable_assets) > 0 and set(investable_assets) != set(self.options_investable_assets):
|
||
q = q.filter(pl.col('QID16').is_in(investable_assets))
|
||
|
||
self.filter_industry = industry
|
||
if industry is not None and len(industry) > 0 and set(industry) != set(self.options_industry):
|
||
q = q.filter(pl.col('QID17').is_in(industry))
|
||
|
||
self.data_filtered = q
|
||
return self.data_filtered
|
||
|
||
def get_demographics(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
|
||
"""Extract columns containing the demographics.
|
||
|
||
Renames columns using qid_descr_map if provided.
|
||
"""
|
||
QIDs = ['QID1', 'QID2', 'QID3', 'QID4', 'QID7', 'QID13', 'QID14', 'QID15', 'QID16', 'QID17', 'Consumer']
|
||
return self._get_subset(q, QIDs), None
|
||
|
||
|
||
def get_top_8_traits(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
|
||
"""Extract columns containing the top 8 characteristics are most important for this Chase virtual assistant to have.
|
||
|
||
Returns subquery that can be chained with other polars queries.
|
||
"""
|
||
QIDs = ['QID25']
|
||
return self._get_subset(q, QIDs, rename_cols=False).rename({'QID25': 'Top_8_Traits'}), None
|
||
|
||
|
||
|
||
def get_top_3_traits(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
|
||
"""Extract columns containing the top 3 characteristics that the Chase virtual assistant should prioritize.
|
||
|
||
Returns subquery that can be chained with other polars queries.
|
||
"""
|
||
QIDs = ['QID26_0_GROUP']
|
||
return self._get_subset(q, QIDs, rename_cols=False).rename({'QID26_0_GROUP': 'Top_3_Traits'}), None
|
||
|
||
|
||
def get_character_ranking(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
|
||
"""Extract columns containing the ranking of characteristics for the Chase virtual assistant.
|
||
|
||
Returns subquery that can be chained with other polars queries.
|
||
"""
|
||
|
||
|
||
# Requires QSF to map "Character Ranking_2" to the actual character
|
||
cfg = self._get_qsf_question_by_QID('QID27')['Payload']
|
||
|
||
|
||
QIDs_map = {f'QID27_{v}': cfg['VariableNaming'][k] for k,v in cfg['RecodeValues'].items()}
|
||
QIDs_rename = {qid: f'Character_Ranking_{QIDs_map[qid].replace(" ", "_")}' for qid in QIDs_map}
|
||
|
||
return self._get_subset(q, list(QIDs_rename.keys()), rename_cols=False).rename(QIDs_rename), None
|
||
|
||
|
||
def get_18_8_3(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
|
||
"""Extract columns containing the 18-8-3 feedback for the Chase virtual assistant.
|
||
|
||
Returns subquery that can be chained with other polars queries.
|
||
"""
|
||
QIDs = ['QID29', 'QID101', 'QID36_0_GROUP']
|
||
|
||
rename_dict = {
|
||
'QID29': '18-8_Set-A',
|
||
'QID101': '18-8_Set-B',
|
||
'QID36_0_GROUP': '3_Ranked'
|
||
}
|
||
|
||
subset = self._get_subset(q, QIDs, rename_cols=False).rename(rename_dict)
|
||
|
||
# Combine 18-8 Set A and Set B into single column
|
||
subset = subset.with_columns(
|
||
pl.coalesce(['18-8_Set-A', '18-8_Set-B']).alias('8_Combined')
|
||
)
|
||
# Change order of columns
|
||
subset = subset.select(['_recordId', '18-8_Set-A', '18-8_Set-B', '8_Combined', '3_Ranked'])
|
||
|
||
return subset, None
|
||
|
||
|
||
def get_voice_scale_1_10(self, q: pl.LazyFrame, drop_cols=['Voice_Scale_1_10__V46']) -> Union[pl.LazyFrame, None]:
|
||
"""Extract columns containing the Voice Scale 1-10 ratings for the Chase virtual assistant.
|
||
|
||
Returns subquery that can be chained with other polars queries.
|
||
|
||
Drops scores for V46 as it was improperly configured in the survey and thus did not show up for respondents.
|
||
"""
|
||
|
||
QIDs_map = {}
|
||
|
||
for qid, val in self.qid_descr_map.items():
|
||
if 'Scale 1-10_1' in val['QName']:
|
||
# Convert "Voice 16 Scale 1-10_1" to "Scale_1_10__Voice_16"
|
||
QIDs_map[qid] = f"Voice_Scale_1_10__V{val['QName'].split()[1]}"
|
||
|
||
for col in drop_cols:
|
||
if col in QIDs_map.values():
|
||
# remove from QIDs_map
|
||
qid_to_remove = [k for k,v in QIDs_map.items() if v == col][0]
|
||
del QIDs_map[qid_to_remove]
|
||
|
||
return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), None
|
||
|
||
|
||
|
||
def get_ss_green_blue(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, dict]:
|
||
"""Extract columns containing the SS Green/Blue ratings for the Chase virtual assistant.
|
||
|
||
Returns subquery that can be chained with other polars queries.
|
||
"""
|
||
|
||
cfg = self._get_qsf_question_by_QID('QID35')['Payload']
|
||
|
||
QIDs_map = {}
|
||
choices_map = {}
|
||
for qid, val in self.qid_descr_map.items():
|
||
if 'SS Green-Blue' in val['QName']:
|
||
|
||
cfg = self._get_qsf_question_by_QID(qid.split('_')[0])['Payload']
|
||
|
||
# ie: "V14 SS Green-Blue_1"
|
||
qname_parts = val['QName'].split()
|
||
voice = qname_parts[0]
|
||
trait_num = qname_parts[-1].split('_')[-1]
|
||
|
||
QIDs_map[qid] = f"SS_Green_Blue__{voice}__Choice_{trait_num}"
|
||
|
||
choices_map[f"SS_Green_Blue__{voice}__Choice_{trait_num}"] = cfg['Choices'][trait_num]['Display']
|
||
|
||
return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), choices_map
|
||
|
||
|
||
def get_top_3_voices(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
|
||
"""Extract columns containing the top 3 voice choices for the Chase virtual assistant.
|
||
|
||
Returns subquery that can be chained with other polars queries.
|
||
"""
|
||
|
||
QIDs_map = {}
|
||
|
||
cfg36 = self._get_qsf_question_by_QID('QID36')['Payload']
|
||
choice_voice_map = {k: extract_voice_label(v['Display']) for k,v in cfg36['Choices'].items()}
|
||
|
||
|
||
for qid, val in self.qid_descr_map.items():
|
||
if 'Rank Top 3 Voices' in val['QName']:
|
||
|
||
cfg = self._get_qsf_question_by_QID(qid.split('_')[0])['Payload']
|
||
voice_num = val['QName'].split('_')[-1]
|
||
|
||
# Validate that the DynamicChoices Locator is as expected
|
||
if cfg['DynamicChoices']['Locator'] != r"q://QID36/ChoiceGroup/SelectedChoicesInGroup/1":
|
||
raise ValueError(f"Unexpected DynamicChoices Locator for QID '{qid}': {cfg['DynamicChoices']['Locator']}")
|
||
|
||
# extract the voice from the QID36 config
|
||
voice = choice_voice_map[voice_num]
|
||
|
||
# Convert "Top 3 Voices_1" to "Top_3_Voices__V14"
|
||
QIDs_map[qid] = f"Top_3_Voices_ranking__{voice}"
|
||
|
||
return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), None
|
||
|
||
def get_top_3_voices_missing_ranking(
|
||
self, q: pl.LazyFrame
|
||
) -> pl.DataFrame:
|
||
"""Identify respondents who completed the top-3 voice selection (QID36)
|
||
but are missing the explicit ranking question (QID98).
|
||
|
||
These respondents picked 3 voices in the selection step and have
|
||
selection-order data in ``QID36_G0_*_RANK``, but all 18 ``QID98_*``
|
||
ranking columns are null. This means ``get_top_3_voices()`` will
|
||
return all-null rows for them, causing plots like
|
||
``plot_most_ranked_1`` to undercount.
|
||
|
||
Parameters:
|
||
q: The (optionally filtered) LazyFrame from ``load_data()``.
|
||
|
||
Returns:
|
||
A collected ``pl.DataFrame`` with columns:
|
||
|
||
- ``_recordId`` – the respondent identifier
|
||
- ``3_Ranked`` – comma-separated text of the 3 voices they selected
|
||
- ``qid36_rank_cols`` – dict-like column with their QID36 selection-
|
||
order values (for reference; these are *not* preference ranks)
|
||
"""
|
||
# Get the top-3 ranking data (QID98-based)
|
||
top3, _ = self.get_top_3_voices(q)
|
||
top3_df = top3.collect()
|
||
|
||
ranking_cols = [c for c in top3_df.columns if c != '_recordId']
|
||
|
||
# Respondents where every QID98 ranking column is null
|
||
all_null_expr = pl.lit(True)
|
||
for col in ranking_cols:
|
||
all_null_expr = all_null_expr & pl.col(col).is_null()
|
||
|
||
missing_ids = top3_df.filter(all_null_expr).select('_recordId')
|
||
|
||
if missing_ids.height == 0:
|
||
return pl.DataFrame(schema={
|
||
'_recordId': pl.Utf8,
|
||
'3_Ranked': pl.Utf8,
|
||
})
|
||
|
||
# Enrich with the 3_Ranked text from the 18→8→3 question
|
||
v_18_8_3, _ = self.get_18_8_3(q)
|
||
v_df = v_18_8_3.collect()
|
||
|
||
result = missing_ids.join(
|
||
v_df.select(['_recordId', '3_Ranked']),
|
||
on='_recordId',
|
||
how='left',
|
||
)
|
||
|
||
return result
|
||
|
||
|
||
def get_ss_orange_red(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, dict]:
|
||
"""Extract columns containing the SS Orange/Red ratings for the Chase virtual assistant.
|
||
|
||
Returns subquery that can be chained with other polars queries.
|
||
"""
|
||
|
||
cfg = self._get_qsf_question_by_QID('QID40')['Payload']
|
||
|
||
QIDs_map = {}
|
||
choices_map = {}
|
||
for qid, val in self.qid_descr_map.items():
|
||
if 'SS Orange-Red' in val['QName']:
|
||
|
||
cfg = self._get_qsf_question_by_QID(qid.split('_')[0])['Payload']
|
||
|
||
# ie: "V14 SS Orange-Red_1"
|
||
qname_parts = val['QName'].split()
|
||
voice = qname_parts[0]
|
||
trait_num = qname_parts[-1].split('_')[-1]
|
||
|
||
QIDs_map[qid] = f"SS_Orange_Red__{voice}__Choice_{trait_num}"
|
||
|
||
choices_map[f"SS_Orange_Red__{voice}__Choice_{trait_num}"] = cfg['Choices'][trait_num]['Display']
|
||
|
||
return self._get_subset(q, list(QIDs_map.keys()), rename_cols=False).rename(QIDs_map), choices_map
|
||
|
||
|
||
def get_character_refine(self, q: pl.LazyFrame) -> Union[pl.LazyFrame, None]:
|
||
"""Extract columns containing the character refine feedback for the Chase virtual assistant.
|
||
|
||
Returns subquery that can be chained with other polars queries.
|
||
"""
|
||
QIDs = ['QID44', 'QID97', 'QID95', 'QID96']
|
||
|
||
return self._get_subset(q, QIDs, rename_cols=True), None
|
||
|
||
def transform_character_trait_frequency(
|
||
self,
|
||
char_df: pl.LazyFrame | pl.DataFrame,
|
||
character_column: str,
|
||
) -> tuple[pl.DataFrame, dict | None]:
|
||
"""Transform character refine data to trait frequency counts for a single character.
|
||
|
||
Original use-case: "I need a bar plot that shows the frequency of the times
|
||
each trait is chosen per brand character."
|
||
|
||
This function takes a DataFrame with comma-separated trait selections per
|
||
character, explodes traits, and counts their frequency for a single character.
|
||
|
||
Args:
|
||
char_df: Pre-fetched data
|
||
Expected columns: '_recordId', '<character_column>' (with comma-separated traits)
|
||
character_column: Name of the character column to analyze (e.g., 'Bank Teller')
|
||
|
||
Returns:
|
||
tuple: (DataFrame with columns ['trait', 'count', 'is_original'], None)
|
||
- 'trait': individual trait name
|
||
- 'count': frequency count
|
||
- 'is_original': boolean indicating if trait is in the original definition
|
||
"""
|
||
from reference import ORIGINAL_CHARACTER_TRAITS
|
||
|
||
if isinstance(char_df, pl.LazyFrame):
|
||
char_df = char_df.collect()
|
||
|
||
# Map display names to reference keys
|
||
character_key_map = {
|
||
'Bank Teller': 'the_bank_teller',
|
||
'Familiar Friend': 'the_familiar_friend',
|
||
'The Coach': 'the_coach',
|
||
'Personal Assistant': 'the_personal_assistant',
|
||
}
|
||
|
||
# Get original traits for this character
|
||
ref_key = character_key_map.get(character_column)
|
||
original_traits = set(ORIGINAL_CHARACTER_TRAITS.get(ref_key, []))
|
||
|
||
# Filter to rows where this character has a value (not null)
|
||
char_data = char_df.filter(pl.col(character_column).is_not_null())
|
||
|
||
# Split comma-separated traits and explode
|
||
exploded = (
|
||
char_data
|
||
.select(
|
||
pl.col(character_column)
|
||
.str.split(',')
|
||
.alias('traits')
|
||
)
|
||
.explode('traits')
|
||
.with_columns(
|
||
pl.col('traits').str.strip_chars().alias('trait')
|
||
)
|
||
.filter(pl.col('trait') != '')
|
||
)
|
||
|
||
# Count trait frequencies
|
||
freq_df = (
|
||
exploded
|
||
.group_by('trait')
|
||
.agg(pl.len().alias('count'))
|
||
.sort('count', descending=True)
|
||
)
|
||
|
||
# Add is_original flag
|
||
freq_df = freq_df.with_columns(
|
||
pl.col('trait').is_in(list(original_traits)).alias('is_original')
|
||
)
|
||
|
||
return freq_df, None
|
||
|
||
def compute_pairwise_significance(
|
||
self,
|
||
data: pl.LazyFrame | pl.DataFrame,
|
||
test_type: str = "auto",
|
||
alpha: float = 0.05,
|
||
correction: str = "bonferroni",
|
||
) -> tuple[pl.DataFrame, dict]:
|
||
"""Compute pairwise statistical significance tests between columns.
|
||
|
||
Original use-case: "I need to test for statistical significance and present
|
||
this in a logical manner. It should be a generalized function to work on
|
||
many dataframes."
|
||
|
||
This function performs pairwise statistical tests between all numeric columns
|
||
(excluding '_recordId') to determine which groups differ significantly.
|
||
|
||
Args:
|
||
data: Pre-fetched data with numeric columns to compare.
|
||
Expected format: rows are observations, columns are groups/categories.
|
||
Example: Voice_Scale_1_10__V14, Voice_Scale_1_10__V04, etc.
|
||
test_type: Statistical test to use:
|
||
- "auto": Automatically chooses based on data (default)
|
||
- "mannwhitney": Mann-Whitney U test (non-parametric, for continuous)
|
||
- "ttest": Independent samples t-test (parametric, for continuous)
|
||
- "chi2": Chi-square test (for count/frequency data)
|
||
alpha: Significance level (default 0.05)
|
||
correction: Multiple comparison correction method:
|
||
- "bonferroni": Bonferroni correction (conservative)
|
||
- "holm": Holm-Bonferroni (less conservative)
|
||
- "none": No correction
|
||
|
||
Returns:
|
||
tuple: (pairwise_df, metadata)
|
||
- pairwise_df: DataFrame with columns ['group1', 'group2', 'p_value',
|
||
'p_adjusted', 'significant', 'effect_size', 'mean1', 'mean2', 'n1', 'n2']
|
||
- metadata: dict with 'test_type', 'alpha', 'correction', 'n_comparisons',
|
||
'overall_test_stat', 'overall_p_value'
|
||
"""
|
||
from scipy import stats as scipy_stats
|
||
import numpy as np
|
||
|
||
if isinstance(data, pl.LazyFrame):
|
||
df = data.collect()
|
||
else:
|
||
df = data
|
||
|
||
# Get numeric columns (exclude _recordId and other non-data columns)
|
||
value_cols = [c for c in df.columns if c != '_recordId' and df[c].dtype in [pl.Float64, pl.Float32, pl.Int64, pl.Int32]]
|
||
|
||
if len(value_cols) < 2:
|
||
raise ValueError(f"Need at least 2 numeric columns for comparison, found {len(value_cols)}")
|
||
|
||
# Auto-detect test type based on data characteristics
|
||
if test_type == "auto":
|
||
# Check if data looks like counts (integers, small range) vs continuous
|
||
sample_col = df[value_cols[0]].drop_nulls()
|
||
if len(sample_col) > 0:
|
||
is_integer = sample_col.dtype in [pl.Int64, pl.Int32]
|
||
unique_ratio = sample_col.n_unique() / len(sample_col)
|
||
if is_integer and unique_ratio < 0.1:
|
||
test_type = "chi2"
|
||
else:
|
||
test_type = "mannwhitney" # Default to non-parametric
|
||
else:
|
||
test_type = "mannwhitney"
|
||
|
||
# Extract data as lists (dropping nulls for each column)
|
||
group_data = {}
|
||
for col in value_cols:
|
||
group_data[col] = df[col].drop_nulls().to_numpy()
|
||
|
||
# Compute overall test (Kruskal-Wallis for non-parametric, ANOVA for parametric)
|
||
all_groups = [group_data[col] for col in value_cols if len(group_data[col]) > 0]
|
||
if test_type in ["mannwhitney", "auto"]:
|
||
overall_stat, overall_p = scipy_stats.kruskal(*all_groups)
|
||
overall_test_name = "Kruskal-Wallis"
|
||
elif test_type == "ttest":
|
||
overall_stat, overall_p = scipy_stats.f_oneway(*all_groups)
|
||
overall_test_name = "One-way ANOVA"
|
||
else:
|
||
overall_stat, overall_p = None, None
|
||
overall_test_name = "N/A (Chi-square)"
|
||
|
||
# Compute pairwise tests
|
||
results = []
|
||
n_comparisons = len(value_cols) * (len(value_cols) - 1) // 2
|
||
|
||
for i, col1 in enumerate(value_cols):
|
||
for col2 in value_cols[i+1:]:
|
||
data1 = group_data[col1]
|
||
data2 = group_data[col2]
|
||
|
||
n1, n2 = len(data1), len(data2)
|
||
mean1 = float(np.mean(data1)) if n1 > 0 else None
|
||
mean2 = float(np.mean(data2)) if n2 > 0 else None
|
||
|
||
# Skip if either group has no data
|
||
if n1 == 0 or n2 == 0:
|
||
results.append({
|
||
'group1': self._clean_voice_label(col1),
|
||
'group2': self._clean_voice_label(col2),
|
||
'p_value': None,
|
||
'effect_size': None,
|
||
'mean1': mean1,
|
||
'mean2': mean2,
|
||
'n1': n1,
|
||
'n2': n2,
|
||
})
|
||
continue
|
||
|
||
# Perform the appropriate test
|
||
if test_type == "mannwhitney":
|
||
stat, p_value = scipy_stats.mannwhitneyu(data1, data2, alternative='two-sided')
|
||
# Effect size: rank-biserial correlation
|
||
effect_size = 1 - (2 * stat) / (n1 * n2)
|
||
elif test_type == "ttest":
|
||
stat, p_value = scipy_stats.ttest_ind(data1, data2)
|
||
# Effect size: Cohen's d
|
||
pooled_std = np.sqrt(((n1-1)*np.std(data1)**2 + (n2-1)*np.std(data2)**2) / (n1+n2-2))
|
||
effect_size = (mean1 - mean2) / pooled_std if pooled_std > 0 else 0
|
||
elif test_type == "chi2":
|
||
# Create contingency table from the two distributions
|
||
# Bin the data for chi-square
|
||
all_data = np.concatenate([data1, data2])
|
||
bins = np.histogram_bin_edges(all_data, bins='auto')
|
||
counts1, _ = np.histogram(data1, bins=bins)
|
||
counts2, _ = np.histogram(data2, bins=bins)
|
||
contingency = np.array([counts1, counts2])
|
||
# Remove zero columns
|
||
contingency = contingency[:, contingency.sum(axis=0) > 0]
|
||
if contingency.shape[1] > 1:
|
||
stat, p_value, _, _ = scipy_stats.chi2_contingency(contingency)
|
||
effect_size = np.sqrt(stat / (contingency.sum() * (min(contingency.shape) - 1)))
|
||
else:
|
||
p_value, effect_size = 1.0, 0.0
|
||
else:
|
||
raise ValueError(f"Unknown test_type: {test_type}")
|
||
|
||
results.append({
|
||
'group1': self._clean_voice_label(col1),
|
||
'group2': self._clean_voice_label(col2),
|
||
'p_value': float(p_value),
|
||
'effect_size': float(effect_size),
|
||
'mean1': mean1,
|
||
'mean2': mean2,
|
||
'n1': n1,
|
||
'n2': n2,
|
||
})
|
||
|
||
# Create DataFrame and apply multiple comparison correction
|
||
results_df = pl.DataFrame(results)
|
||
|
||
# Apply correction
|
||
p_values = results_df['p_value'].to_numpy()
|
||
valid_mask = ~np.isnan(p_values.astype(float))
|
||
p_adjusted = np.full_like(p_values, np.nan, dtype=float)
|
||
|
||
if correction == "bonferroni":
|
||
p_adjusted[valid_mask] = np.minimum(p_values[valid_mask] * n_comparisons, 1.0)
|
||
elif correction == "holm":
|
||
# Holm-Bonferroni step-down procedure
|
||
valid_p = p_values[valid_mask]
|
||
sorted_idx = np.argsort(valid_p)
|
||
sorted_p = valid_p[sorted_idx]
|
||
m = len(sorted_p)
|
||
adjusted = np.zeros(m)
|
||
for j in range(m):
|
||
adjusted[j] = sorted_p[j] * (m - j)
|
||
# Ensure monotonicity
|
||
for j in range(1, m):
|
||
adjusted[j] = max(adjusted[j], adjusted[j-1])
|
||
adjusted = np.minimum(adjusted, 1.0)
|
||
# Restore original order
|
||
p_adjusted[valid_mask] = adjusted[np.argsort(sorted_idx)]
|
||
elif correction == "none":
|
||
p_adjusted = p_values.astype(float)
|
||
|
||
results_df = results_df.with_columns([
|
||
pl.Series('p_adjusted', p_adjusted),
|
||
pl.Series('significant', p_adjusted < alpha),
|
||
])
|
||
|
||
metadata = {
|
||
'test_type': test_type,
|
||
'alpha': alpha,
|
||
'correction': correction,
|
||
'n_comparisons': n_comparisons,
|
||
'overall_test': overall_test_name,
|
||
'overall_stat': overall_stat,
|
||
'overall_p_value': overall_p,
|
||
}
|
||
|
||
return results_df, metadata
|
||
|
||
def compute_ranking_significance(
|
||
self,
|
||
data: pl.LazyFrame | pl.DataFrame,
|
||
alpha: float = 0.05,
|
||
correction: str = "bonferroni",
|
||
) -> tuple[pl.DataFrame, dict]:
|
||
"""Compute statistical significance for ranking data (e.g., Top 3 Voices).
|
||
|
||
Original use-case: "Test whether voices are ranked significantly differently
|
||
based on the distribution of 1st, 2nd, 3rd place votes."
|
||
|
||
This function takes raw ranking data (rows = respondents, columns = voices,
|
||
values = rank 1/2/3 or null) and performs:
|
||
1. Overall chi-square test on the full contingency table
|
||
2. Pairwise proportion tests comparing Rank 1 vote shares
|
||
|
||
Args:
|
||
data: Pre-fetched ranking data from get_top_3_voices() or get_character_ranking().
|
||
Expected format: rows are respondents, columns are voices/characters,
|
||
values are 1, 2, 3 (rank) or null (not ranked).
|
||
alpha: Significance level (default 0.05)
|
||
correction: Multiple comparison correction method:
|
||
- "bonferroni": Bonferroni correction (conservative)
|
||
- "holm": Holm-Bonferroni (less conservative)
|
||
- "none": No correction
|
||
|
||
Returns:
|
||
tuple: (pairwise_df, metadata)
|
||
- pairwise_df: DataFrame with columns ['group1', 'group2', 'p_value',
|
||
'p_adjusted', 'significant', 'rank1_count1', 'rank1_count2',
|
||
'rank1_pct1', 'rank1_pct2', 'total1', 'total2']
|
||
- metadata: dict with 'alpha', 'correction', 'n_comparisons',
|
||
'chi2_stat', 'chi2_p_value', 'contingency_table'
|
||
|
||
Example:
|
||
>>> ranking_data, _ = S.get_top_3_voices(data)
|
||
>>> pairwise_df, meta = S.compute_ranking_significance(ranking_data)
|
||
>>> # See which voices have significantly different Rank 1 proportions
|
||
>>> print(pairwise_df.filter(pl.col('significant') == True))
|
||
"""
|
||
from scipy import stats as scipy_stats
|
||
import numpy as np
|
||
|
||
if isinstance(data, pl.LazyFrame):
|
||
df = data.collect()
|
||
else:
|
||
df = data
|
||
|
||
# Get ranking columns (exclude _recordId)
|
||
ranking_cols = [c for c in df.columns if c != '_recordId']
|
||
|
||
if len(ranking_cols) < 2:
|
||
raise ValueError(f"Need at least 2 ranking columns, found {len(ranking_cols)}")
|
||
|
||
# Build contingency table: rows = ranks (1, 2, 3), columns = voices
|
||
# Count how many times each voice received each rank
|
||
contingency_data = {}
|
||
for col in ranking_cols:
|
||
label = self._clean_voice_label(col)
|
||
r1 = df.filter(pl.col(col) == 1).height
|
||
r2 = df.filter(pl.col(col) == 2).height
|
||
r3 = df.filter(pl.col(col) == 3).height
|
||
contingency_data[label] = [r1, r2, r3]
|
||
|
||
# Create contingency table as numpy array
|
||
labels = list(contingency_data.keys())
|
||
contingency_table = np.array([contingency_data[l] for l in labels]).T # 3 x n_voices
|
||
|
||
# Overall chi-square test on contingency table
|
||
# Tests whether rank distribution is independent of voice
|
||
chi2_stat, chi2_p, chi2_dof, _ = scipy_stats.chi2_contingency(contingency_table)
|
||
|
||
# Pairwise proportion tests for Rank 1 votes
|
||
# We use a two-proportion z-test to compare rank 1 proportions
|
||
results = []
|
||
n_comparisons = len(labels) * (len(labels) - 1) // 2
|
||
|
||
# Total respondents who ranked any voice in top 3
|
||
total_respondents = df.height
|
||
|
||
for i, label1 in enumerate(labels):
|
||
for label2 in labels[i+1:]:
|
||
r1_count1 = contingency_data[label1][0] # Rank 1 votes for voice 1
|
||
r1_count2 = contingency_data[label2][0] # Rank 1 votes for voice 2
|
||
|
||
# Total times each voice was ranked (1st + 2nd + 3rd)
|
||
total1 = sum(contingency_data[label1])
|
||
total2 = sum(contingency_data[label2])
|
||
|
||
# Calculate proportions of Rank 1 out of all rankings for each voice
|
||
pct1 = r1_count1 / total1 if total1 > 0 else 0
|
||
pct2 = r1_count2 / total2 if total2 > 0 else 0
|
||
|
||
# Two-proportion z-test
|
||
# H0: p1 = p2 (both voices have same proportion of Rank 1)
|
||
if total1 > 0 and total2 > 0 and (r1_count1 + r1_count2) > 0:
|
||
# Pooled proportion
|
||
p_pooled = (r1_count1 + r1_count2) / (total1 + total2)
|
||
|
||
# Standard error
|
||
se = np.sqrt(p_pooled * (1 - p_pooled) * (1/total1 + 1/total2))
|
||
|
||
if se > 0:
|
||
z_stat = (pct1 - pct2) / se
|
||
p_value = 2 * (1 - scipy_stats.norm.cdf(abs(z_stat))) # Two-tailed
|
||
else:
|
||
p_value = 1.0
|
||
else:
|
||
p_value = 1.0
|
||
|
||
results.append({
|
||
'group1': label1,
|
||
'group2': label2,
|
||
'p_value': float(p_value),
|
||
'rank1_count1': r1_count1,
|
||
'rank1_count2': r1_count2,
|
||
'rank1_pct1': round(pct1 * 100, 1),
|
||
'rank1_pct2': round(pct2 * 100, 1),
|
||
'total1': total1,
|
||
'total2': total2,
|
||
})
|
||
|
||
# Create DataFrame and apply correction
|
||
results_df = pl.DataFrame(results)
|
||
|
||
p_values = results_df['p_value'].to_numpy()
|
||
p_adjusted = np.full_like(p_values, np.nan, dtype=float)
|
||
|
||
if correction == "bonferroni":
|
||
p_adjusted = np.minimum(p_values * n_comparisons, 1.0)
|
||
elif correction == "holm":
|
||
sorted_idx = np.argsort(p_values)
|
||
sorted_p = p_values[sorted_idx]
|
||
m = len(sorted_p)
|
||
adjusted = np.zeros(m)
|
||
for j in range(m):
|
||
adjusted[j] = sorted_p[j] * (m - j)
|
||
for j in range(1, m):
|
||
adjusted[j] = max(adjusted[j], adjusted[j-1])
|
||
adjusted = np.minimum(adjusted, 1.0)
|
||
p_adjusted = adjusted[np.argsort(sorted_idx)]
|
||
elif correction == "none":
|
||
p_adjusted = p_values.astype(float)
|
||
|
||
results_df = results_df.with_columns([
|
||
pl.Series('p_adjusted', p_adjusted),
|
||
pl.Series('significant', p_adjusted < alpha),
|
||
])
|
||
|
||
# Sort by p_value for easier inspection
|
||
results_df = results_df.sort('p_value')
|
||
|
||
metadata = {
|
||
'test_type': 'proportion_z_test',
|
||
'alpha': alpha,
|
||
'correction': correction,
|
||
'n_comparisons': n_comparisons,
|
||
'chi2_stat': chi2_stat,
|
||
'chi2_p_value': chi2_p,
|
||
'chi2_dof': chi2_dof,
|
||
'overall_test': 'Chi-square',
|
||
'overall_stat': chi2_stat,
|
||
'overall_p_value': chi2_p,
|
||
'contingency_table': {label: contingency_data[label] for label in labels},
|
||
}
|
||
|
||
return results_df, metadata
|
||
|
||
def compute_mentions_significance(
|
||
self,
|
||
data: pl.LazyFrame | pl.DataFrame,
|
||
alpha: float = 0.05,
|
||
correction: str = "bonferroni",
|
||
) -> tuple[pl.DataFrame, dict]:
|
||
"""Compute statistical significance for Total Mentions (Rank 1+2+3).
|
||
|
||
Tests whether the proportion of respondents who included a voice in their Top 3
|
||
is significantly different between voices.
|
||
|
||
Args:
|
||
data: Ranking data (rows=respondents, cols=voices, values=rank).
|
||
alpha: Significance level.
|
||
correction: Multiple comparison correction method.
|
||
|
||
Returns:
|
||
tuple: (pairwise_df, metadata)
|
||
"""
|
||
from scipy import stats as scipy_stats
|
||
import numpy as np
|
||
|
||
if isinstance(data, pl.LazyFrame):
|
||
df = data.collect()
|
||
else:
|
||
df = data
|
||
|
||
ranking_cols = [c for c in df.columns if c != '_recordId']
|
||
if len(ranking_cols) < 2:
|
||
raise ValueError("Need at least 2 ranking columns")
|
||
|
||
total_respondents = df.height
|
||
mentions_data = {}
|
||
|
||
# Count mentions (any rank) for each voice
|
||
for col in ranking_cols:
|
||
label = self._clean_voice_label(col)
|
||
count = df.filter(pl.col(col).is_not_null()).height
|
||
mentions_data[label] = count
|
||
|
||
labels = sorted(list(mentions_data.keys()))
|
||
results = []
|
||
n_comparisons = len(labels) * (len(labels) - 1) // 2
|
||
|
||
for i, label1 in enumerate(labels):
|
||
for label2 in labels[i+1:]:
|
||
count1 = mentions_data[label1]
|
||
count2 = mentions_data[label2]
|
||
|
||
pct1 = count1 / total_respondents
|
||
pct2 = count2 / total_respondents
|
||
|
||
# Z-test for two proportions
|
||
n1 = total_respondents
|
||
n2 = total_respondents
|
||
|
||
p_pooled = (count1 + count2) / (n1 + n2)
|
||
se = np.sqrt(p_pooled * (1 - p_pooled) * (1/n1 + 1/n2))
|
||
|
||
if se > 0:
|
||
z_stat = (pct1 - pct2) / se
|
||
p_value = 2 * (1 - scipy_stats.norm.cdf(abs(z_stat)))
|
||
else:
|
||
p_value = 1.0
|
||
|
||
results.append({
|
||
'group1': label1,
|
||
'group2': label2,
|
||
'p_value': float(p_value),
|
||
'rank1_count1': count1, # Reusing column names for compatibility with heatmap plotting
|
||
'rank1_count2': count2,
|
||
'rank1_pct1': round(pct1 * 100, 1),
|
||
'rank1_pct2': round(pct2 * 100, 1),
|
||
'total1': n1,
|
||
'total2': n2,
|
||
'effect_size': pct1 - pct2 # Difference in proportions
|
||
})
|
||
|
||
results_df = pl.DataFrame(results)
|
||
|
||
p_values = results_df['p_value'].to_numpy()
|
||
p_adjusted = np.full_like(p_values, np.nan, dtype=float)
|
||
|
||
if correction == "bonferroni":
|
||
p_adjusted = np.minimum(p_values * n_comparisons, 1.0)
|
||
elif correction == "holm":
|
||
sorted_idx = np.argsort(p_values)
|
||
sorted_p = p_values[sorted_idx]
|
||
m = len(sorted_p)
|
||
adjusted = np.zeros(m)
|
||
for j in range(m):
|
||
adjusted[j] = sorted_p[j] * (m - j)
|
||
for j in range(1, m):
|
||
adjusted[j] = max(adjusted[j], adjusted[j-1])
|
||
adjusted = np.minimum(adjusted, 1.0)
|
||
p_adjusted = adjusted[np.argsort(sorted_idx)]
|
||
elif correction == "none":
|
||
p_adjusted = p_values.astype(float) # pyright: ignore
|
||
|
||
results_df = results_df.with_columns([
|
||
pl.Series('p_adjusted', p_adjusted),
|
||
pl.Series('significant', p_adjusted < alpha),
|
||
]).sort('p_value')
|
||
|
||
metadata = {
|
||
'test_type': 'proportion_z_test_mentions',
|
||
'alpha': alpha,
|
||
'correction': correction,
|
||
'n_comparisons': n_comparisons,
|
||
}
|
||
|
||
return results_df, metadata
|
||
|
||
def compute_rank1_significance(
    self,
    data: pl.LazyFrame | pl.DataFrame,
    alpha: float = 0.05,
    correction: str = "bonferroni",
) -> tuple[pl.DataFrame, dict]:
    """Compute statistical significance for Rank 1 selections only.

    Like compute_mentions_significance but counts only how many times each
    voice/character was ranked **1st**, using total respondents as the
    denominator. This tests whether first-choice preference differs
    significantly between voices.

    NOTE(review): both "groups" in each pairwise test come from the same
    respondents, so the independence assumption of the two-proportion
    z-test is only approximate here — confirm this is acceptable.

    Args:
        data: Ranking data (rows=respondents, cols=voices, values=rank).
        alpha: Significance level.
        correction: Multiple comparison correction method; one of
            "bonferroni", "holm", or "none".

    Returns:
        tuple: (pairwise_df, metadata)

    Raises:
        ValueError: If ``correction`` is unrecognized, fewer than 2 ranking
            columns are present, or there are no respondents.
    """
    from scipy import stats as scipy_stats
    import numpy as np

    # Fail fast on an unknown correction method instead of silently
    # emitting NaN adjusted p-values (the previous behavior, which made
    # every comparison look non-significant).
    if correction not in ("bonferroni", "holm", "none"):
        raise ValueError(f"Unknown correction method: {correction!r}")

    df = data.collect() if isinstance(data, pl.LazyFrame) else data

    ranking_cols = [c for c in df.columns if c != '_recordId']
    if len(ranking_cols) < 2:
        raise ValueError("Need at least 2 ranking columns")

    total_respondents = df.height
    # Guard against division by zero when computing proportions below.
    if total_respondents == 0:
        raise ValueError("No respondents in data")

    # Count rank-1 selections for each voice
    rank1_data: dict[str, int] = {}
    for col in ranking_cols:
        label = self._clean_voice_label(col)
        rank1_data[label] = df.filter(pl.col(col) == 1).height

    labels = sorted(rank1_data.keys())
    results = []
    n_comparisons = len(labels) * (len(labels) - 1) // 2

    for i, label1 in enumerate(labels):
        for label2 in labels[i + 1:]:
            count1 = rank1_data[label1]
            count2 = rank1_data[label2]

            pct1 = count1 / total_respondents
            pct2 = count2 / total_respondents

            # Z-test for two proportions (same denominator for both)
            n1 = total_respondents
            n2 = total_respondents

            p_pooled = (count1 + count2) / (n1 + n2)
            se = np.sqrt(p_pooled * (1 - p_pooled) * (1 / n1 + 1 / n2))

            if se > 0:
                z_stat = (pct1 - pct2) / se
                p_value = 2 * (1 - scipy_stats.norm.cdf(abs(z_stat)))
            else:
                # Zero pooled variance (both counts 0 or both equal to n):
                # no evidence of a difference.
                p_value = 1.0

            results.append({
                'group1': label1,
                'group2': label2,
                'p_value': float(p_value),
                'rank1_count1': count1,
                'rank1_count2': count2,
                'rank1_pct1': round(pct1 * 100, 1),
                'rank1_pct2': round(pct2 * 100, 1),
                'total1': n1,
                'total2': n2,
                'effect_size': pct1 - pct2,  # Difference in proportions
            })

    results_df = pl.DataFrame(results)

    p_values = results_df['p_value'].to_numpy()

    if correction == "bonferroni":
        p_adjusted = np.minimum(p_values * n_comparisons, 1.0)
    elif correction == "holm":
        # Holm step-down: scale sorted p-values by (m - rank), enforce
        # monotonicity, cap at 1, then restore the original order.
        sorted_idx = np.argsort(p_values)
        sorted_p = p_values[sorted_idx]
        m = len(sorted_p)
        adjusted = np.zeros(m)
        for j in range(m):
            adjusted[j] = sorted_p[j] * (m - j)
        for j in range(1, m):
            adjusted[j] = max(adjusted[j], adjusted[j - 1])
        adjusted = np.minimum(adjusted, 1.0)
        p_adjusted = adjusted[np.argsort(sorted_idx)]
    else:  # correction == "none" (validated above)
        p_adjusted = p_values.astype(float)

    results_df = results_df.with_columns([
        pl.Series('p_adjusted', p_adjusted),
        pl.Series('significant', p_adjusted < alpha),
    ]).sort('p_value')

    metadata = {
        'test_type': 'proportion_z_test_rank1',
        'alpha': alpha,
        'correction': correction,
        'n_comparisons': n_comparisons,
    }

    return results_df, metadata
|
||
|
||
|
||
|
||
def process_speaking_style_data(
    df: Union[pl.LazyFrame, pl.DataFrame],
    trait_map: dict[str, str]
) -> pl.DataFrame:
    """
    Process speaking style columns from wide to long format and map trait descriptions.

    Parses columns with format: SS_{StyleGroup}__{Voice}__{ChoiceID}
    Example: SS_Orange_Red__V14__Choice_1

    Parameters
    ----------
    df : pl.LazyFrame or pl.DataFrame
        Input dataframe containing SS_* columns and a _recordId column.
    trait_map : dict
        Dictionary mapping column names to trait descriptions.
        Keys should be full column names like "SS_Orange_Red__V14__Choice_1".
        Descriptions of the form "Left : Right" are split into anchors.

    Returns
    -------
    pl.DataFrame
        Long-format dataframe with columns:
        _recordId, Style_Group, Voice, Choice_ID, score (Int64),
        Description, Left_Anchor, Right_Anchor.
        If trait_map yields no usable mappings, the Description/anchor
        columns are omitted but the schema is otherwise identical.
    """
    # Normalize input to LazyFrame
    lf = df.lazy() if isinstance(df, pl.DataFrame) else df

    # 1. Melt SS_ columns
    melted = lf.melt(
        id_vars=["_recordId"],
        value_vars=pl.col("^SS_.*$"),
        variable_name="full_col_name",
        value_name="score"
    )

    # 2. Extract components from column name
    # Regex captures: Style_Group (e.g. SS_Orange_Red), Voice (e.g. V14), Choice_ID (e.g. Choice_1)
    pattern = r"^(?P<Style_Group>SS_.+?)__(?P<Voice>.+?)__(?P<Choice_ID>Choice_\d+)$"

    # Cast score here so BOTH return paths share the same Int64 schema
    # (previously the no-mapping early return skipped the cast).
    processed = melted.with_columns(
        pl.col("full_col_name").str.extract_groups(pattern)
    ).unnest("full_col_name").with_columns(
        pl.col("score").cast(pl.Int64, strict=False)
    )

    # 3. Create Mapping Lookup from the provided dictionary
    # We map (Style_Group, Choice_ID) -> Description
    mapping_data = []
    seen = set()

    for col_name, desc in trait_map.items():
        match = re.match(pattern, col_name)
        if not match:
            continue
        groups = match.groupdict()
        key = (groups["Style_Group"], groups["Choice_ID"])
        if key in seen:
            continue

        # Parse description into anchors if possible (Left : Right)
        parts = desc.split(':')
        left_anchor = parts[0].strip() if len(parts) > 0 else ""
        right_anchor = parts[1].strip() if len(parts) > 1 else ""

        mapping_data.append({
            "Style_Group": groups["Style_Group"],
            "Choice_ID": groups["Choice_ID"],
            "Description": desc,
            "Left_Anchor": left_anchor,
            "Right_Anchor": right_anchor
        })
        seen.add(key)

    if not mapping_data:
        return processed.collect()

    mapping_lf = pl.LazyFrame(mapping_data)

    # 4. Join Data with Mapping
    result = processed.join(
        mapping_lf,
        on=["Style_Group", "Choice_ID"],
        how="left"
    )

    return result.collect()
|
||
|
||
|
||
|
||
|
||
|
||
|
||
def process_voice_scale_data(
    df: Union[pl.LazyFrame, pl.DataFrame]
) -> pl.DataFrame:
    """
    Reshape Voice Scale columns from wide to long format.

    Parses columns with format: Voice_Scale_1_10__V{Voice}
    Example: Voice_Scale_1_10__V14

    Returns
    -------
    pl.DataFrame
        Long-format dataframe with columns:
        _recordId, Voice, Voice_Scale_Score
    """
    frame = df if isinstance(df, pl.LazyFrame) else df.lazy()

    # Wide -> long on every voice-scale column
    long_form = frame.melt(
        id_vars=["_recordId"],
        value_vars=pl.col("^Voice_Scale_1_10__V.*$"),
        variable_name="full_col_name",
        value_name="Voice_Scale_Score"
    )

    # Pull the numeric voice id out of the column name, then rebuild "V{n}"
    with_voice = long_form.with_columns(
        pl.col("full_col_name").str.extract(r"V(\d+)", 1).alias("Voice_Num")
    ).with_columns(
        ("V" + pl.col("Voice_Num")).alias("Voice")
    )

    # Score stays Float64 to match the source data
    return with_voice.select([
        "_recordId",
        "Voice",
        pl.col("Voice_Scale_Score").cast(pl.Float64, strict=False)
    ]).collect()
|
||
|
||
def join_voice_and_style_data(
    processed_style_data: pl.DataFrame,
    processed_voice_data: pl.DataFrame
) -> pl.DataFrame:
    """
    Inner-join processed Speaking Style data with Voice Scale 1-10 data.

    Parameters
    ----------
    processed_style_data : pl.DataFrame
        Result of process_speaking_style_data
    processed_voice_data : pl.DataFrame
        Result of process_voice_scale_data

    Returns
    -------
    pl.DataFrame
        Merged dataframe with columns from both, joined on _recordId and Voice.
    """
    join_keys = ["_recordId", "Voice"]
    merged = processed_style_data.join(
        processed_voice_data,
        on=join_keys,
        how="inner"
    )
    return merged
|
||
|
||
|
||
def transform_speaking_style_color_correlation(
    joined_df: pl.LazyFrame | pl.DataFrame,
    speaking_styles: dict[str, list[str]],
    target_column: str = "Voice_Scale_Score"
) -> tuple[pl.DataFrame, dict | None]:
    """Aggregate speaking style correlation by color (Green, Blue, Orange, Red).

    Original use-case: "I want to create high-level correlation plots between
    'green, blue, orange, red' speaking styles and the 'voice scale scores'.
    I want to go to one plot with one bar for each color."

    Computes, per color, the mean of the per-trait Pearson correlations
    between trait scores and the target column.

    Parameters
    ----------
    joined_df : pl.LazyFrame or pl.DataFrame
        Pre-fetched data from joining speaking style data with target data.
        Must have columns: Right_Anchor, score, and the target_column
    speaking_styles : dict
        Dictionary mapping color names to their constituent traits.
        Typically imported from speaking_styles.SPEAKING_STYLES
    target_column : str
        The column to correlate against speaking style scores.
        Default: "Voice_Scale_Score" (for voice scale 1-10)
        Alternative: "Ranking_Points" (for top 3 voice ranking)

    Returns
    -------
    tuple[pl.DataFrame, dict | None]
        (DataFrame with columns [Color, correlation, n_traits], None)
    """
    frame = joined_df.collect() if isinstance(joined_df, pl.LazyFrame) else joined_df

    rows = []
    for color, traits in speaking_styles.items():
        trait_correlations = []
        for trait in traits:
            # Rows for this trait only, with complete (score, target) pairs
            pairs = (
                frame
                .filter(pl.col("Right_Anchor") == trait)
                .select(["score", target_column])
                .drop_nulls()
            )
            if pairs.height <= 1:
                continue  # correlation undefined with < 2 observations
            corr_val = pairs.select(pl.corr("score", target_column)).item()
            if corr_val is not None:
                trait_correlations.append(corr_val)

        # One bar per color: mean of its trait correlations
        if trait_correlations:
            rows.append({
                "Color": color,
                "correlation": sum(trait_correlations) / len(trait_correlations),
                "n_traits": len(trait_correlations)
            })

    return pl.DataFrame(rows), None
|
||
|
||
|
||
def process_voice_ranking_data(
    df: Union[pl.LazyFrame, pl.DataFrame]
) -> pl.DataFrame:
    """
    Reshape Voice Ranking columns from wide to long format and convert ranks to points.

    Parses columns with format: Top_3_Voices_ranking__V{Voice}
    Converts ranks to points: 1st place = 3 pts, 2nd place = 2 pts, 3rd place = 1 pt

    Returns
    -------
    pl.DataFrame
        Long-format dataframe with columns:
        _recordId, Voice, Ranking_Points
    """
    frame = df if isinstance(df, pl.LazyFrame) else df.lazy()

    # Wide -> long on every ranking column
    long_form = frame.melt(
        id_vars=["_recordId"],
        value_vars=pl.col("^Top_3_Voices_ranking__V.*$"),
        variable_name="full_col_name",
        value_name="rank"
    )

    # Recover "V{n}" from the column name
    with_voice = long_form.with_columns(
        pl.col("full_col_name").str.extract(r"V(\d+)", 1).alias("Voice_Num")
    ).with_columns(
        ("V" + pl.col("Voice_Num")).alias("Voice")
    )

    # Invert rank into points: 1st -> 3, 2nd -> 2, 3rd -> 1, unranked/null -> 0
    points = (
        pl.when(pl.col("rank") == 1).then(3)
        .when(pl.col("rank") == 2).then(2)
        .when(pl.col("rank") == 3).then(1)
        .otherwise(0)
        .alias("Ranking_Points")
    )

    return with_voice.with_columns(points).select([
        "_recordId",
        "Voice",
        "Ranking_Points"
    ]).collect()
|
||
|
||
|
||
def split_consumer_groups(df: Union[pl.LazyFrame, pl.DataFrame], col: str = "Consumer") -> dict[str, pl.DataFrame]:
    """
    Split dataframe into groups based on a column.

    If col is 'Consumer', it combines A/B subgroups (e.g. Mass_A + Mass_B -> Mass).
    For other columns, it splits by unique values as-is.

    Raises
    ------
    ValueError
        If `col` is not present in the dataframe.
    """
    frame = df.collect() if isinstance(df, pl.LazyFrame) else df

    if col not in frame.columns:
        raise ValueError(f"Column '{col}' not found in DataFrame")

    group_col = f"{col}_Group"

    if col == "Consumer":
        # Merge A/B subgroups by stripping a trailing _A or _B suffix
        labelled = frame.with_columns(
            pl.col(col).str.replace(r"_[AB]$", "").alias(group_col)
        )
    else:
        # Group by raw values
        labelled = frame.with_columns(pl.col(col).alias(group_col))

    # One sub-frame per (non-null) group value
    return {
        name: labelled.filter(pl.col(group_col) == name)
        for name in labelled[group_col].drop_nulls().unique().to_list()
    }
|
||
|
||
|
||
|
||
# Filter SPEAKING_STYLES to only include traits containing any keyword
|
||
def filter_speaking_styles(speaking_styles: dict, keywords: list[str]) -> dict:
    """Filter speaking styles to only include traits matching any keyword.

    Matching is case-insensitive substring search; colors with no
    surviving traits are dropped from the result.
    """
    lowered = [kw.lower() for kw in keywords]
    filtered: dict = {}
    for color, traits in speaking_styles.items():
        kept = [
            trait for trait in traits
            if any(kw in trait.lower() for kw in lowered)
        ]
        if kept:
            filtered[color] = kept
    return filtered