Spaces:
Sleeping
Sleeping
mystic_CBK
Fix JSON serialization error: Convert all numpy types to Python native types for clinical analysis
58e60a2
| #!/usr/bin/env python3 | |
| """ | |
| Clinical Analysis Module for ECG-FM | |
| Handles real clinical predictions from finetuned model | |
| """ | |
| import numpy as np | |
| import torch | |
| from typing import Dict, Any, List | |
def _logits_to_probabilities(logits: Any) -> np.ndarray:
    """Apply an element-wise sigmoid to ``logits`` and flatten the result to 1-D.

    Tensors are detached and moved to the CPU before conversion; any other
    input is converted with ``np.array`` first. Flattening makes a leading
    batch dimension irrelevant, so no explicit ``squeeze`` is needed.
    """
    if isinstance(logits, torch.Tensor):
        return torch.sigmoid(logits).detach().cpu().numpy().ravel()
    return 1 / (1 + np.exp(-np.array(logits).ravel()))


def analyze_ecg_features(model_output: Dict[str, Any]) -> Dict[str, Any]:
    """Extract clinical predictions from finetuned ECG-FM model output.

    Supported shapes of ``model_output``, checked in this order:

    1. a ``torch.Tensor`` holding the logits directly,
    2. a ``tuple`` containing a 2-D tensor whose second dimension is 17,
    3. a mapping with a ``'label_logits'`` key,
    4. a mapping with an ``'out'`` key,
    5. a mapping with a ``'features'`` key (pretrained-model fallback).

    Args:
        model_output: Raw output of the model forward pass.

    Returns:
        The dictionary produced by ``extract_clinical_from_probabilities`` or
        ``estimate_clinical_from_features``, or a fallback response when no
        usable data is found or any exception is raised.
    """
    # Width of the logits tensor searched for inside tuple outputs.
    expected_num_labels = 17

    try:
        # DEBUG: Log what we're receiving
        print(f"π DEBUG: analyze_ecg_features received: {type(model_output)}")
        if isinstance(model_output, dict):
            print(f"π DEBUG: Keys: {list(model_output.keys())}")
            for key, value in model_output.items():
                if isinstance(value, torch.Tensor):
                    print(f"π DEBUG: {key} shape: {value.shape}, dtype: {value.dtype}")
                else:
                    print(f"π DEBUG: {key}: {type(value)} - {value}")

        # Type checks come BEFORE any `'key' in model_output` test: a string
        # membership test against a tensor is not a key lookup and fails, which
        # previously sent every direct-tensor output to the "Analysis error"
        # fallback.
        if isinstance(model_output, torch.Tensor):
            print("β Found direct logits tensor - using classifier model output")
            return extract_clinical_from_probabilities(
                _logits_to_probabilities(model_output)
            )

        if isinstance(model_output, tuple):
            print("β Found tuple output - checking for logits")
            for item in model_output:
                if (
                    isinstance(item, torch.Tensor)
                    and item.dim() == 2
                    and item.shape[1] == expected_num_labels
                ):
                    print("β Found logits in tuple - using classifier model output")
                    return extract_clinical_from_probabilities(
                        _logits_to_probabilities(item)
                    )
            print("β No suitable logits found in tuple")
            return create_fallback_response("Tuple output but no logits found")

        # Keyed outputs from the finetuned model.
        if 'label_logits' in model_output:
            print("β Found label_logits - using finetuned model output")
            return extract_clinical_from_probabilities(
                _logits_to_probabilities(model_output['label_logits'])
            )

        if 'out' in model_output:
            print("β Found 'out' key - using finetuned model output")
            return extract_clinical_from_probabilities(
                _logits_to_probabilities(model_output['out'])
            )

        if 'features' in model_output:
            # PRETRAINED MODEL - fall back to feature analysis
            features = model_output.get('features', [])
            if isinstance(features, torch.Tensor):
                features = features.detach().cpu().numpy()
            if len(features) > 0:
                return estimate_clinical_from_features(features)
            return create_fallback_response("Insufficient features")

        return create_fallback_response("No clinical data available")
    except Exception as e:
        # Top-level boundary: never let an analysis failure propagate.
        print(f"β Error in clinical analysis: {e}")
        return create_fallback_response("Analysis error")
def extract_clinical_from_probabilities(probs: np.ndarray) -> Dict[str, Any]:
    """Extract clinical findings from a probability array using the loaded labels.

    Labels and per-label thresholds are loaded on every call. If ``probs`` and
    the label list differ in length, ``probs`` is truncated or zero-padded to
    the label count and the mismatch is reported in the returned ``warning``
    field (it was previously only printed).

    Args:
        probs: One probability per label, in label order.

    Returns:
        A JSON-serializable dict (native Python types only) with the detected
        abnormalities, derived rhythm, confidence metrics, raw probabilities
        and the labels/thresholds used. On any exception a fallback response
        is returned instead.
    """
    try:
        # Load official labels and thresholds
        labels = load_label_definitions()
        thresholds = load_clinical_thresholds()

        warning = None
        if len(probs) != len(labels):
            warning = (
                f"Probability array length ({len(probs)}) "
                f"doesn't match label count ({len(labels)})"
            )
            print(f"β οΈ Warning: {warning}")
            # Truncate or pad as needed
            if len(probs) > len(labels):
                probs = probs[:len(labels)]
            else:
                probs = np.pad(
                    probs, (0, len(labels) - len(probs)),
                    'constant', constant_values=0.0,
                )

        # A label counts as detected when its probability reaches its
        # threshold (0.7 when the label has no configured threshold).
        abnormalities = [
            label
            for label, prob in zip(labels, probs)
            if prob >= thresholds.get(label, 0.7)
        ]

        rhythm = determine_rhythm_from_abnormalities(abnormalities)
        confidence_metrics = calculate_confidence_metrics(probs, thresholds)

        # Convert all numpy types to Python native types for JSON serialization
        probabilities_list = [float(p) for p in probs]
        label_probs_dict = {
            str(label): float(prob) for label, prob in zip(labels, probs)
        }

        return {
            "rhythm": str(rhythm),
            "heart_rate": None,  # Will be calculated from features if available
            "qrs_duration": None,  # Will be calculated from features if available
            "qt_interval": None,  # Will be calculated from features if available
            "pr_interval": None,  # Will be calculated from features if available
            # NOTE(review): reported as "Normal" without being computed here,
            # while the fallback response uses "Unknown" — confirm intended.
            "axis_deviation": "Normal",  # Will be calculated from features if available
            "abnormalities": [str(abnormality) for abnormality in abnormalities],
            "confidence": float(confidence_metrics["overall_confidence"]),
            "confidence_level": str(confidence_metrics["confidence_level"]),
            "review_required": bool(confidence_metrics["review_required"]),
            "probabilities": probabilities_list,
            "label_probabilities": label_probs_dict,
            "method": "clinical_predictions",
            "warning": warning,
            "labels_used": [str(label) for label in labels],
            "thresholds_used": {str(k): float(v) for k, v in thresholds.items()},
        }
    except Exception as e:
        print(f"β Error in clinical probability extraction: {e}")
        return create_fallback_response(f"Clinical analysis failed: {str(e)}")
def estimate_clinical_from_features(features: np.ndarray) -> Dict[str, Any]:
    """Fallback path for feature-only model output.

    No estimation is performed: the function always returns a fallback
    response. The reason string distinguishes an empty ``features`` input
    from the "not yet validated" case, and any exception raised while
    inspecting ``features`` is reported in the fallback reason as well.
    """
    try:
        feature_count = len(features)
        if feature_count != 0:
            # ECG-FM features require proper validation and analysis; no
            # clinical estimate is produced without validated algorithms.
            notices = (
                "β οΈ Clinical estimation from features requires validated ECG-FM algorithms",
                " Returning fallback response to prevent incorrect clinical information",
            )
            for notice in notices:
                print(notice)
            reason = "Clinical estimation from features not yet validated"
        else:
            reason = "No features available for estimation"
        return create_fallback_response(reason)
    except Exception as exc:
        print(f"β Error in clinical feature estimation: {exc}")
        return create_fallback_response(f"Feature estimation error: {exc}")
def create_fallback_response(reason: str) -> Dict[str, Any]:
    """Build the response returned when clinical analysis cannot be completed.

    Args:
        reason: Explanation stored in the ``warning`` field.

    Returns:
        A new dict with the same keys as a successful analysis result: all
        measurements are ``None``, all collections are empty, confidence is
        zero and ``review_required`` is ``True``.
    """
    response: Dict[str, Any] = {"rhythm": "Analysis Failed"}

    # Interval/rate measurements are never available on the fallback path.
    unavailable = ("heart_rate", "qrs_duration", "qt_interval", "pr_interval")
    response.update(dict.fromkeys(unavailable))

    response["axis_deviation"] = "Unknown"
    response["abnormalities"] = []
    response["confidence"] = 0.0
    response["confidence_level"] = "None"
    response["review_required"] = True
    response["probabilities"] = []
    response["label_probabilities"] = {}
    response["method"] = "fallback"
    response["warning"] = reason
    response["labels_used"] = []
    response["thresholds_used"] = {}
    return response
| # New helper functions for enhanced clinical analysis | |
def load_label_definitions() -> List[str]:
    """Load ECG-FM label names from ``label_def.csv`` in the working directory.

    The file is read without a header row; the second column holds the label
    names. A file with fewer than two columns yields an empty list. A warning
    is printed when the number of labels is not 17.

    Returns:
        The label names in file order.

    Raises:
        RuntimeError: If the file cannot be read or parsed. The underlying
            exception is attached as ``__cause__``.
    """
    try:
        import pandas as pd

        df = pd.read_csv('label_def.csv', header=None)

        # Take the second column directly instead of iterating row by row.
        label_names = df.iloc[:, 1].tolist() if df.shape[1] >= 2 else []

        # Validate that we have the expected 17 labels
        if len(label_names) != 17:
            print(f"β οΈ Warning: Expected 17 labels, got {len(label_names)}")
            print(f" Labels: {label_names}")

        print(f"β Loaded {len(label_names)} official ECG-FM labels")
        return label_names
    except Exception as e:
        print(f"β CRITICAL ERROR: Could not load label_def.csv: {e}")
        print(" ECG-FM clinical analysis cannot proceed without proper labels")
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Failed to load ECG-FM label definitions: {e}") from e
def load_clinical_thresholds() -> Dict[str, float]:
    """Load per-label decision thresholds from ``thresholds.json``.

    Thresholds are read from the ``clinical_thresholds`` object of the JSON
    file in the working directory. Labels without an entry get the default
    threshold 0.7. If the JSON file cannot be read or parsed, every label
    gets the default threshold.

    The JSON read and the label load are handled separately, so a label-file
    failure is no longer reported as a ``thresholds.json`` failure and the
    label file is loaded only once.

    Returns:
        Mapping of label name to threshold.

    Raises:
        RuntimeError: If the label definitions cannot be loaded, since
            thresholds can then be neither validated nor defaulted.
    """
    import json

    thresholds = None
    json_error = None
    try:
        with open('thresholds.json', 'r', encoding='utf-8') as f:
            config = json.load(f)
        # dict(...) copies the mapping and rejects a non-mapping value here,
        # inside the try, so a malformed file falls back to defaults.
        thresholds = dict(config.get('clinical_thresholds', {}))
    except Exception as e:
        json_error = e
        print(f"β CRITICAL ERROR: Could not load thresholds.json: {e}")
        print(" Using default threshold of 0.7 for all labels")

    try:
        expected_labels = load_label_definitions()
    except Exception as label_error:
        if thresholds is None:
            print(f"β CRITICAL ERROR: Cannot create default thresholds: {label_error}")
            raise RuntimeError(
                f"Failed to load clinical thresholds: {json_error}"
            ) from label_error
        raise RuntimeError(
            f"Failed to load clinical thresholds: {label_error}"
        ) from label_error

    if thresholds is None:
        default_thresholds = {label: 0.7 for label in expected_labels}
        print(f"β Created default thresholds for {len(default_thresholds)} labels")
        return default_thresholds

    # Validate that thresholds match our labels
    missing_labels = [label for label in expected_labels if label not in thresholds]
    if missing_labels:
        print(f"β οΈ Warning: Missing thresholds for labels: {missing_labels}")
        # Use default threshold for missing labels
        for label in missing_labels:
            thresholds[label] = 0.7

    print(f"β Loaded thresholds for {len(thresholds)} clinical labels")
    return thresholds
def determine_rhythm_from_abnormalities(abnormalities: List[str]) -> str:
    """Map the detected abnormality labels to a single rhythm description.

    An empty input yields "Normal Sinus Rhythm". Otherwise the first label of
    the priority table below that is present in ``abnormalities`` decides the
    result; if none is present the result is "Abnormal Rhythm".
    """
    if not abnormalities:
        return "Normal Sinus Rhythm"

    # (detected label, reported rhythm), highest priority first.
    # NOTE(review): any detected label that is not listed here maps to
    # "Abnormal Rhythm" — confirm that is intended for every label defined
    # in label_def.csv.
    priority_table = (
        ("Atrial fibrillation", "Atrial Fibrillation"),
        ("Atrial flutter", "Atrial Flutter"),
        ("Ventricular tachycardia", "Ventricular Tachycardia"),
        ("Supraventricular tachycardia with aberrancy",
         "Supraventricular Tachycardia with Aberrancy"),
        ("Bradycardia", "Bradycardia"),
        ("Tachycardia", "Tachycardia"),
        ("Premature ventricular contraction", "Premature Ventricular Contractions"),
        ("1st degree atrioventricular block", "1st Degree AV Block"),
        ("Atrioventricular block", "AV Block"),
        ("Right bundle branch block", "Right Bundle Branch Block"),
        ("Left bundle branch block", "Left Bundle Branch Block"),
        ("Bifascicular block", "Bifascicular Block"),
        ("Accessory pathway conduction", "Accessory Pathway Conduction"),
        ("Infarction", "Myocardial Infarction"),
        ("Electronic pacemaker", "Electronic Pacemaker"),
        ("Poor data quality", "Poor Data Quality - Rhythm Unclear"),
    )

    for detected_label, rhythm_name in priority_table:
        if detected_label in abnormalities:
            return rhythm_name
    return "Abnormal Rhythm"
def calculate_confidence_metrics(probs: np.ndarray, thresholds: Dict[str, float]) -> Dict[str, Any]:
    """Calculate confidence metrics and the manual-review flag.

    The maximum probability is used as the overall confidence: >= 0.8 is
    "High", >= 0.6 is "Medium", anything else "Low". Review is required when
    the maximum is below 0.6 or the mean is below 0.4.

    Non-finite values (NaN/inf) are excluded from the maximum and mean, and
    their presence always forces ``review_required``. Previously a NaN made
    both comparisons false, so review was reported as *not* required and NaN
    was returned as the confidence. If every value is non-finite the metrics
    are reported as 0.0.

    Args:
        probs: Per-label probabilities.
        thresholds: Unused; kept so existing callers keep working.

    Returns:
        Dict with ``overall_confidence``, ``confidence_level``,
        ``review_required``, ``mean_probability`` and ``max_probability``,
        all as native Python types.

    Raises:
        ValueError: If ``probs`` is empty (raised by ``np.max``).
    """
    values = np.asarray(probs).ravel()
    finite = values[np.isfinite(values)]
    has_invalid = finite.size != values.size

    if values.size and not finite.size:
        # Nothing usable: report zero confidence rather than NaN.
        max_prob = 0.0
        mean_prob = 0.0
    else:
        max_prob = float(np.max(finite))
        mean_prob = float(np.mean(finite))

    # Determine confidence level
    if max_prob >= 0.8:
        confidence_level = "High"
    elif max_prob >= 0.6:
        confidence_level = "Medium"
    else:
        confidence_level = "Low"

    # Invalid model output must always be reviewed by a human.
    review_required = has_invalid or max_prob < 0.6 or mean_prob < 0.4

    return {
        "overall_confidence": float(max_prob),
        "confidence_level": str(confidence_level),
        "review_required": bool(review_required),
        "mean_probability": float(mean_prob),
        "max_probability": float(max_prob),
    }