ecg-fm-api / clinical_analysis.py
mystic_CBK
Fix JSON serialization error: Convert all numpy types to Python native types for clinical analysis
58e60a2
#!/usr/bin/env python3
"""
Clinical Analysis Module for ECG-FM
Handles real clinical predictions from finetuned model
"""
import numpy as np
import torch
from typing import Dict, Any, List
def _logits_to_probabilities(logits: Any) -> np.ndarray:
    """Apply an element-wise sigmoid to ``logits`` and return a flat NumPy array.

    Tensors are detached and moved to CPU first; anything else is converted
    with ``np.array``. ``ravel()`` flattens the result, so a tensor with a
    leading batch dimension of size 1 gives the same output as one without it.
    """
    if isinstance(logits, torch.Tensor):
        return torch.sigmoid(logits).detach().cpu().numpy().ravel()
    return 1 / (1 + np.exp(-np.array(logits).ravel()))


def analyze_ecg_features(model_output: Dict[str, Any]) -> Dict[str, Any]:
    """Extract clinical predictions from finetuned ECG-FM model output.

    Despite the ``Dict`` annotation, several output shapes are accepted and
    are checked in this order:

    * ``torch.Tensor``: the logits themselves (classifier model);
    * ``tuple``: the first 2-D tensor whose second dimension is 17 is used
      as the logits;
    * ``dict`` with ``'label_logits'`` or ``'out'``: the logits under that key;
    * ``dict`` with ``'features'``: pretrained-model features (fallback path).

    Logits are turned into probabilities with a sigmoid and passed to
    ``extract_clinical_from_probabilities``. Any other input, or any
    exception, produces ``create_fallback_response``.
    """
    try:
        # DEBUG: log what we are receiving
        print(f"🔍 DEBUG: analyze_ecg_features received: {type(model_output)}")
        if isinstance(model_output, dict):
            print(f"🔍 DEBUG: Keys: {list(model_output.keys())}")
            for key, value in model_output.items():
                if isinstance(value, torch.Tensor):
                    print(f"🔍 DEBUG: {key} shape: {value.shape}, dtype: {value.dtype}")
                else:
                    print(f"🔍 DEBUG: {key}: {type(value)} - {value}")

        # Tensor and tuple outputs must be handled BEFORE any ``'key' in
        # model_output`` test: ``str in Tensor`` raises RuntimeError, which
        # previously made the direct-tensor branch unreachable.
        if isinstance(model_output, torch.Tensor):
            print("✅ Found direct logits tensor - using classifier model output")
            probs = _logits_to_probabilities(model_output)
            return extract_clinical_from_probabilities(probs)

        if isinstance(model_output, tuple):
            print("✅ Found tuple output - checking for logits")
            for item in model_output:
                if isinstance(item, torch.Tensor) and item.dim() == 2 and item.shape[1] == 17:
                    print("✅ Found logits in tuple - using classifier model output")
                    probs = _logits_to_probabilities(item)
                    return extract_clinical_from_probabilities(probs)
            print("❌ No suitable logits found in tuple")
            return create_fallback_response("Tuple output but no logits found")

        # Key lookups below only make sense for mapping outputs.
        if not isinstance(model_output, dict):
            return create_fallback_response("No clinical data available")

        if 'label_logits' in model_output:
            print("✅ Found label_logits - using finetuned model output")
            probs = _logits_to_probabilities(model_output['label_logits'])
            return extract_clinical_from_probabilities(probs)

        if 'out' in model_output:
            print("✅ Found 'out' key - using finetuned model output")
            probs = _logits_to_probabilities(model_output['out'])
            return extract_clinical_from_probabilities(probs)

        if 'features' in model_output:
            # PRETRAINED MODEL - fall back to feature analysis
            features = model_output.get('features', [])
            if isinstance(features, torch.Tensor):
                features = features.detach().cpu().numpy()
            if len(features) > 0:
                return estimate_clinical_from_features(features)
            return create_fallback_response("Insufficient features")

        return create_fallback_response("No clinical data available")
    except Exception as e:
        print(f"❌ Error in clinical analysis: {e}")
        return create_fallback_response("Analysis error")
def extract_clinical_from_probabilities(probs: np.ndarray) -> Dict[str, Any]:
    """Convert per-label probabilities into the clinical result dictionary.

    Label names come from ``load_label_definitions()`` and per-label decision
    thresholds from ``load_clinical_thresholds()``. A label whose probability
    is >= its threshold (0.7 when the label has no threshold) is reported as
    an abnormality.

    Args:
        probs: Probabilities in label order. Any array-like is accepted and
            flattened to 1-D, so a ``(1, num_labels)`` batch also works. When
            its length differs from the label count it is truncated, or padded
            with 0.0.

    Returns:
        A dictionary holding only Python built-in values, so it can be
        JSON-serialized. On any exception, the result of
        ``create_fallback_response`` with the error message in ``warning``.
    """
    try:
        # Load official labels and thresholds
        labels = load_label_definitions()
        thresholds = load_clinical_thresholds()

        # Flatten first: np.pad with a 1-D pad spec pads EVERY axis of a 2-D
        # input, and zip() over a 2-D array would yield rows, not scalars.
        probs = np.asarray(probs, dtype=float).ravel()
        n_labels = len(labels)
        if len(probs) != n_labels:
            print(f"⚠️ Warning: Probability array length ({len(probs)}) doesn't match label count ({n_labels})")
            # Truncate extra values or pad the missing tail with 0.0
            if len(probs) > n_labels:
                probs = probs[:n_labels]
            else:
                probs = np.pad(probs, (0, n_labels - len(probs)), 'constant', constant_values=0.0)

        abnormalities = [
            label
            for label, prob in zip(labels, probs)
            if prob >= thresholds.get(label, 0.7)
        ]
        rhythm = determine_rhythm_from_abnormalities(abnormalities)
        confidence_metrics = calculate_confidence_metrics(probs, thresholds)

        # Convert NumPy scalars to Python natives for JSON serialization
        probabilities_list = [float(p) for p in probs]
        label_probs_dict = {str(label): float(prob) for label, prob in zip(labels, probs)}

        return {
            "rhythm": str(rhythm),
            "heart_rate": None,  # Will be calculated from features if available
            "qrs_duration": None,  # Will be calculated from features if available
            "qt_interval": None,  # Will be calculated from features if available
            "pr_interval": None,  # Will be calculated from features if available
            "axis_deviation": "Normal",  # Will be calculated from features if available
            "abnormalities": [str(abnormality) for abnormality in abnormalities],
            "confidence": float(confidence_metrics["overall_confidence"]),
            "confidence_level": str(confidence_metrics["confidence_level"]),
            "review_required": bool(confidence_metrics["review_required"]),
            "probabilities": probabilities_list,
            "label_probabilities": label_probs_dict,
            "method": "clinical_predictions",
            "warning": None,
            "labels_used": [str(label) for label in labels],
            "thresholds_used": {str(k): float(v) for k, v in thresholds.items()},
        }
    except Exception as e:
        print(f"❌ Error in clinical probability extraction: {e}")
        return create_fallback_response(f"Clinical analysis failed: {str(e)}")
def estimate_clinical_from_features(features: np.ndarray) -> Dict[str, Any]:
    """Fallback path for models that only expose features (no logits).

    No validated feature-to-clinical mapping exists yet, so every call ends in
    ``create_fallback_response``. The ``warning`` text tells apart empty input,
    the deliberate refusal to estimate, and an unexpected error (for example
    ``features`` having no length).
    """
    try:
        if not len(features):
            reason = "No features available for estimation"
        else:
            # Refuse to guess: unvalidated estimates would be clinically misleading.
            print("⚠️ Clinical estimation from features requires validated ECG-FM algorithms")
            print(" Returning fallback response to prevent incorrect clinical information")
            reason = "Clinical estimation from features not yet validated"
        return create_fallback_response(reason)
    except Exception as exc:
        print(f"❌ Error in clinical feature estimation: {exc}")
        return create_fallback_response(f"Feature estimation error: {str(exc)}")
def create_fallback_response(reason: str) -> Dict[str, Any]:
    """Build the response returned whenever clinical analysis cannot complete.

    The result has the same keys as a successful analysis, with:

    * every measurement left empty;
    * ``method`` set to ``"fallback"``;
    * ``review_required`` always True;
    * ``reason`` carried in ``warning``.

    A new dictionary, with new list and dict values, is built on every call,
    so callers may mutate it freely.
    """
    return dict(
        rhythm="Analysis Failed",
        heart_rate=None,
        qrs_duration=None,
        qt_interval=None,
        pr_interval=None,
        axis_deviation="Unknown",
        abnormalities=[],
        confidence=0.0,
        confidence_level="None",
        review_required=True,
        probabilities=[],
        label_probabilities={},
        method="fallback",
        warning=reason,
        labels_used=[],
        thresholds_used={},
    )
# New helper functions for enhanced clinical analysis
def load_label_definitions(expected_count: int = 17) -> List[str]:
    """Load the official ECG-FM label names from ``label_def.csv``.

    The CSV is read from the current working directory with no header row.
    The value in the second column (index 1) of each row is taken as a label
    name, in file order. A warning is printed when the number of labels
    differs from ``expected_count``.

    Args:
        expected_count: Label count used only for the sanity-check warning
            (17 for the official ECG-FM label set).

    Returns:
        Label names in file order; an empty list if the file has fewer than
        two columns.

    Raises:
        RuntimeError: If the file cannot be read or parsed. The underlying
            error is chained as ``__cause__``.
    """
    try:
        import pandas as pd
        df = pd.read_csv('label_def.csv', header=None)
        # Column 1 holds the label names. Taking the column directly replaces
        # a per-row iterrows() loop; a file with fewer than two columns still
        # yields no labels, as the old per-row ``len(row) >= 2`` check did.
        label_names = df[1].tolist() if df.shape[1] >= 2 else []
        if len(label_names) != expected_count:
            print(f"⚠️ Warning: Expected {expected_count} labels, got {len(label_names)}")
            print(f" Labels: {label_names}")
        print(f"✅ Loaded {len(label_names)} official ECG-FM labels")
        return label_names
    except Exception as e:
        print(f"❌ CRITICAL ERROR: Could not load label_def.csv: {e}")
        print(" ECG-FM clinical analysis cannot proceed without proper labels")
        raise RuntimeError(f"Failed to load ECG-FM label definitions: {e}") from e
def load_clinical_thresholds() -> Dict[str, float]:
    """Load per-label decision thresholds from ``thresholds.json``.

    The file is read from the current working directory, and the mapping is
    taken from its ``"clinical_thresholds"`` key with every value coerced to
    ``float``. Each label from ``load_label_definitions()`` that has no entry
    gets the default 0.7. If the file is missing or malformed (including a
    non-numeric threshold value), every label gets 0.7 instead.

    Returns:
        Mapping of label name -> threshold. Extra keys present in the file
        that are not current labels are kept.

    Raises:
        RuntimeError: If the label definitions cannot be loaded. This is
            either propagated from ``load_label_definitions`` or, in the
            defaults path, raised here with the label error chained as
            ``__cause__``.
    """
    import json

    default_threshold = 0.7
    try:
        with open('thresholds.json', 'r', encoding='utf-8') as f:
            config = json.load(f)
        # Coerce now so a bad value (e.g. a string) is caught here instead of
        # raising TypeError on every later ``prob >= threshold`` comparison.
        thresholds = {
            label: float(value)
            for label, value in config.get('clinical_thresholds', {}).items()
        }
    except Exception as e:
        print(f"❌ CRITICAL ERROR: Could not load thresholds.json: {e}")
        print(" Using default threshold of 0.7 for all labels")
        try:
            labels = load_label_definitions()
        except Exception as label_error:
            print(f"❌ CRITICAL ERROR: Cannot create default thresholds: {label_error}")
            raise RuntimeError(f"Failed to load clinical thresholds: {e}") from label_error
        default_thresholds = {label: default_threshold for label in labels}
        print(f"✅ Created default thresholds for {len(default_thresholds)} labels")
        return default_thresholds

    # Label loading sits outside the try above, so a label_def.csv failure is
    # no longer misreported as a thresholds.json problem.
    expected_labels = load_label_definitions()
    missing_labels = [label for label in expected_labels if label not in thresholds]
    if missing_labels:
        print(f"⚠️ Warning: Missing thresholds for labels: {missing_labels}")
        for label in missing_labels:
            thresholds[label] = default_threshold
    print(f"✅ Loaded thresholds for {len(thresholds)} clinical labels")
    return thresholds
# Ordered (official ECG-FM label, reported rhythm) pairs. The first label
# found among the detected abnormalities decides the reported rhythm.
_RHYTHM_BY_PRIORITY = (
    ("Atrial fibrillation", "Atrial Fibrillation"),
    ("Atrial flutter", "Atrial Flutter"),
    ("Ventricular tachycardia", "Ventricular Tachycardia"),
    ("Supraventricular tachycardia with aberrancy", "Supraventricular Tachycardia with Aberrancy"),
    ("Bradycardia", "Bradycardia"),
    ("Tachycardia", "Tachycardia"),
    ("Premature ventricular contraction", "Premature Ventricular Contractions"),
    ("1st degree atrioventricular block", "1st Degree AV Block"),
    ("Atrioventricular block", "AV Block"),
    ("Right bundle branch block", "Right Bundle Branch Block"),
    ("Left bundle branch block", "Left Bundle Branch Block"),
    ("Bifascicular block", "Bifascicular Block"),
    ("Accessory pathway conduction", "Accessory Pathway Conduction"),
    ("Infarction", "Myocardial Infarction"),
    ("Electronic pacemaker", "Electronic Pacemaker"),
    ("Poor data quality", "Poor Data Quality - Rhythm Unclear"),
)


def determine_rhythm_from_abnormalities(abnormalities: List[str]) -> str:
    """Reduce the detected ECG-FM labels to a single rhythm description.

    No abnormalities means "Normal Sinus Rhythm". Otherwise the labels are
    checked against a fixed priority table, and the rhythm of the first label
    present is returned. If no table label is present, the result is
    "Abnormal Rhythm".
    """
    if not abnormalities:
        return "Normal Sinus Rhythm"
    return next(
        (rhythm for label, rhythm in _RHYTHM_BY_PRIORITY if label in abnormalities),
        "Abnormal Rhythm",
    )
def calculate_confidence_metrics(probs: np.ndarray, thresholds: Dict[str, float]) -> Dict[str, Any]:
    """Summarize a probability vector into confidence metrics and a review flag.

    Confidence levels (from the maximum probability):

    * "High": max >= 0.8;
    * "Medium": max >= 0.6;
    * "Low": otherwise.

    ``overall_confidence`` is the maximum probability. Review is required when
    the maximum is below 0.6, the mean is below 0.4, or any input value is
    non-finite.

    Non-finite values (NaN or +/-inf) are replaced before computing the
    statistics: NaN becomes 0.0, +inf becomes 1.0 and -inf becomes 0.0.
    Without this, NaN made every comparison False, so review was NOT
    requested, and the NaN results were not valid JSON.

    Args:
        probs: Array-like of probabilities; flattened to 1-D.
        thresholds: Currently unused; kept for interface compatibility.

    Returns:
        Dict with ``overall_confidence``, ``confidence_level``,
        ``review_required``, ``mean_probability`` and ``max_probability``,
        all as Python built-in types.

    Raises:
        ValueError: If ``probs`` is empty.
    """
    values = np.asarray(probs, dtype=float).ravel()
    if values.size == 0:
        raise ValueError("calculate_confidence_metrics() requires at least one probability")

    has_invalid = not bool(np.all(np.isfinite(values)))
    if has_invalid:
        values = np.nan_to_num(values, nan=0.0, posinf=1.0, neginf=0.0)

    max_prob = float(np.max(values))
    mean_prob = float(np.mean(values))

    if max_prob >= 0.8:
        confidence_level = "High"
    elif max_prob >= 0.6:
        confidence_level = "Medium"
    else:
        confidence_level = "Low"

    review_required = has_invalid or max_prob < 0.6 or mean_prob < 0.4

    return {
        "overall_confidence": max_prob,
        "confidence_level": confidence_level,
        "review_required": bool(review_required),
        "mean_probability": mean_prob,
        "max_probability": max_prob,
    }