Spaces:
Running
Running
File size: 14,459 Bytes
31b6ae7 660f0f8 31b6ae7 660f0f8 31b6ae7 d205f72 10b881c d205f72 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 58e60a2 31b6ae7 58e60a2 0d7408c 58e60a2 31b6ae7 0d7408c 58e60a2 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 0d7408c 31b6ae7 58e60a2 31b6ae7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 |
#!/usr/bin/env python3
"""
Clinical Analysis Module for ECG-FM
Handles real clinical predictions from finetuned model
"""
import numpy as np
import torch
from typing import Dict, Any, List
def _logits_to_probs(logits: Any) -> np.ndarray:
    """Apply a sigmoid to ``logits`` (tensor or array-like) and flatten to 1-D.

    Flattening drops a leading batch dimension of size 1. For larger batches
    all rows are concatenated; extract_clinical_from_probabilities truncates
    to the label count, so effectively only the first sample is used.
    """
    if isinstance(logits, torch.Tensor):
        probs = torch.sigmoid(logits.detach())
        # NumPy has no bfloat16, so .numpy() would fail on mixed-precision output.
        if probs.dtype == torch.bfloat16:
            probs = probs.float()
        return probs.cpu().numpy().ravel()
    return 1 / (1 + np.exp(-np.array(logits).ravel()))


def analyze_ecg_features(model_output: Dict[str, Any]) -> Dict[str, Any]:
    """Extract clinical predictions from finetuned ECG-FM model output.

    Supported forms of ``model_output``, checked in this order:

    * a ``torch.Tensor`` of logits (classifier model returning logits directly);
    * a ``tuple`` containing a 2-D logits tensor with 17 columns;
    * a dict with logits under ``'label_logits'`` or ``'out'`` (finetuned model);
    * a dict with ``'features'`` (pretrained model; feature-based fallback).

    Logits are turned into per-label probabilities with a sigmoid and passed
    to ``extract_clinical_from_probabilities``.

    Returns:
        The clinical result dict, or a response from ``create_fallback_response``
        when no usable output is found. Exceptions raised during analysis are
        caught, logged, and reported as an "Analysis error" fallback.
    """
    try:
        # DEBUG: Log what we're receiving
        print(f"🔍 DEBUG: analyze_ecg_features received: {type(model_output)}")
        if isinstance(model_output, dict):
            print(f"🔍 DEBUG: Keys: {list(model_output.keys())}")
            for key, value in model_output.items():
                if isinstance(value, torch.Tensor):
                    print(f"🔍 DEBUG: {key} shape: {value.shape}, dtype: {value.dtype}")
                else:
                    print(f"🔍 DEBUG: {key}: {type(value)} - {value}")

        # Tensor and tuple outputs must be recognised before any
        # `'key' in model_output` test: PyTorch's Tensor.__contains__ raises
        # RuntimeError for a str operand, so a direct tensor would otherwise
        # never reach its branch and would end up as "Analysis error".
        if isinstance(model_output, torch.Tensor):
            print("✅ Found direct logits tensor - using classifier model output")
            return extract_clinical_from_probabilities(_logits_to_probs(model_output))

        if isinstance(model_output, tuple):
            print("✅ Found tuple output - checking for logits")
            for item in model_output:
                # 17 = number of official ECG-FM labels (see load_label_definitions)
                if isinstance(item, torch.Tensor) and item.dim() == 2 and item.shape[1] == 17:
                    print("✅ Found logits in tuple - using classifier model output")
                    return extract_clinical_from_probabilities(_logits_to_probs(item))
            print("❌ No suitable logits found in tuple")
            return create_fallback_response("Tuple output but no logits found")

        # FINETUNED MODEL - logits under 'label_logits' or 'out'
        if 'label_logits' in model_output:
            print("✅ Found label_logits - using finetuned model output")
            return extract_clinical_from_probabilities(
                _logits_to_probs(model_output['label_logits'])
            )

        if 'out' in model_output:
            print("✅ Found 'out' key - using finetuned model output")
            return extract_clinical_from_probabilities(
                _logits_to_probs(model_output['out'])
            )

        if 'features' in model_output:
            # PRETRAINED MODEL - Fallback to feature analysis
            features = model_output.get('features')
            if isinstance(features, torch.Tensor):
                features = features.detach().cpu().numpy()
            if features is None or len(features) == 0:
                return create_fallback_response("Insufficient features")
            return estimate_clinical_from_features(features)

        return create_fallback_response("No clinical data available")
    except Exception as e:
        print(f"❌ Error in clinical analysis: {e}")
        return create_fallback_response("Analysis error")
def extract_clinical_from_probabilities(probs: np.ndarray) -> Dict[str, Any]:
    """Build the clinical result dict from per-label probabilities.

    Args:
        probs: One probability per label from ``load_label_definitions``.
            Any array-like is accepted and flattened to 1-D, so a
            ``(1, n_labels)`` batch behaves like ``(n_labels,)``. If the
            length differs from the label count, the vector is truncated or
            zero-padded and a warning is printed.

    Returns:
        A JSON-serializable dict. ``abnormalities`` lists every label whose
        probability is >= its threshold (0.7 when the label has none).
        ``rhythm`` is derived from those abnormalities. The confidence fields
        come from ``calculate_confidence_metrics``. Probabilities are included
        both as a list and keyed by label. Heart rate and interval fields are
        ``None`` because they are not computed here. On any error, a
        ``create_fallback_response`` dict is returned instead.
    """
    try:
        # Load official labels and thresholds
        labels = load_label_definitions()
        thresholds = load_clinical_thresholds()

        # np.pad and the zip below assume a flat 1-D vector.
        probs = np.asarray(probs, dtype=float).ravel()
        n_labels = len(labels)

        if len(probs) != n_labels:
            print(f"⚠️ Warning: Probability array length ({len(probs)}) doesn't match label count ({n_labels})")
            # Truncate or pad as needed
            if len(probs) > n_labels:
                probs = probs[:n_labels]
            else:
                probs = np.pad(probs, (0, n_labels - len(probs)), 'constant', constant_values=0.0)

        # Find abnormalities at or above their per-label threshold
        abnormalities = [
            label
            for label, prob in zip(labels, probs)
            if prob >= thresholds.get(label, 0.7)
        ]

        rhythm = determine_rhythm_from_abnormalities(abnormalities)
        confidence_metrics = calculate_confidence_metrics(probs, thresholds)

        # Convert numpy scalars to Python natives for JSON serialization
        probabilities_list = [float(p) for p in probs]
        label_probs_dict = {str(label): float(prob) for label, prob in zip(labels, probs)}

        return {
            "rhythm": str(rhythm),
            "heart_rate": None,  # Will be calculated from features if available
            "qrs_duration": None,  # Will be calculated from features if available
            "qt_interval": None,  # Will be calculated from features if available
            "pr_interval": None,  # Will be calculated from features if available
            "axis_deviation": "Normal",  # Will be calculated from features if available
            "abnormalities": [str(abnormality) for abnormality in abnormalities],
            "confidence": float(confidence_metrics["overall_confidence"]),
            "confidence_level": str(confidence_metrics["confidence_level"]),
            "review_required": bool(confidence_metrics["review_required"]),
            "probabilities": probabilities_list,
            "label_probabilities": label_probs_dict,
            "method": "clinical_predictions",
            "warning": None,
            "labels_used": [str(label) for label in labels],
            "thresholds_used": {str(k): float(v) for k, v in thresholds.items()},
        }
    except Exception as e:
        print(f"❌ Error in clinical probability extraction: {e}")
        return create_fallback_response(f"Clinical analysis failed: {str(e)}")
def estimate_clinical_from_features(features: np.ndarray) -> Dict[str, Any]:
    """Estimate clinical parameters from ECG features (fallback method).

    No validated feature-based algorithm exists yet, so this deliberately
    returns a fallback response instead of producing possibly incorrect
    clinical values.

    Args:
        features: Feature array from a pretrained (non-finetuned) model.
            ``None`` or an empty sequence counts as "no features".

    Returns:
        A ``create_fallback_response`` dict whose ``warning`` gives the reason.
    """
    try:
        if features is None or len(features) == 0:
            return create_fallback_response("No features available for estimation")

        # ECG-FM features require proper validation and analysis;
        # reliable clinical estimates cannot be given without validated algorithms.
        print("⚠️ Clinical estimation from features requires validated ECG-FM algorithms")
        print(" Returning fallback response to prevent incorrect clinical information")
        return create_fallback_response("Clinical estimation from features not yet validated")
    except Exception as e:
        print(f"❌ Error in clinical feature estimation: {e}")
        return create_fallback_response(f"Feature estimation error: {str(e)}")
def create_fallback_response(reason: str) -> Dict[str, Any]:
    """Return the placeholder result used when clinical analysis cannot run.

    The dict has the same keys, in the same order, as a successful result,
    so consumers can handle both uniformly. The rhythm is "Analysis Failed",
    every measurement is None, confidence is 0.0, review is always required,
    and ``reason`` is carried in the ``warning`` field. Fresh empty
    containers are built on every call, so callers may mutate them safely.
    """
    return dict(
        rhythm="Analysis Failed",
        heart_rate=None,
        qrs_duration=None,
        qt_interval=None,
        pr_interval=None,
        axis_deviation="Unknown",
        abnormalities=list(),
        confidence=0.0,
        confidence_level="None",
        review_required=True,
        probabilities=list(),
        label_probabilities=dict(),
        method="fallback",
        warning=reason,
        labels_used=list(),
        thresholds_used=dict(),
    )
# New helper functions for enhanced clinical analysis
def load_label_definitions(path: str = 'label_def.csv', expected_count: int = 17) -> List[str]:
    """Load the official ECG-FM label names from a CSV file.

    The CSV is read without a header row, and the label name is taken from the
    second column of every row. A file with fewer than two columns yields no
    labels.

    Args:
        path: CSV file to read. Defaults to ``label_def.csv``, resolved
            relative to the current working directory.
        expected_count: Number of labels the model is expected to emit. A
            mismatch prints a warning but does not raise.

    Returns:
        Label names in file order.

    Raises:
        RuntimeError: If pandas is unavailable or the file cannot be read or parsed.
    """
    # Keep the try body to the actual I/O so a logging failure below is not
    # misreported as a CSV-loading failure.
    try:
        import pandas as pd
        df = pd.read_csv(path, header=None)
    except Exception as e:
        print(f"❌ CRITICAL ERROR: Could not load {path}: {e}")
        print(" ECG-FM clinical analysis cannot proceed without proper labels")
        raise RuntimeError(f"Failed to load ECG-FM label definitions: {e}") from e

    # Second column contains label names (column-wise slice instead of a row loop)
    label_names = df.iloc[:, 1].tolist() if df.shape[1] >= 2 else []

    # Validate that we have the expected number of labels
    if len(label_names) != expected_count:
        print(f"⚠️ Warning: Expected {expected_count} labels, got {len(label_names)}")
        print(f" Labels: {label_names}")

    print(f"✅ Loaded {len(label_names)} official ECG-FM labels")
    return label_names
def load_clinical_thresholds(path: str = 'thresholds.json', default_threshold: float = 0.7) -> Dict[str, float]:
    """Load per-label decision thresholds from a JSON config file.

    Thresholds are read from the ``"clinical_thresholds"`` object in the JSON
    file. Each label from ``load_label_definitions`` that has no entry gets
    ``default_threshold``. If the file cannot be read or parsed, every label
    gets ``default_threshold``.

    Args:
        path: JSON config file. Defaults to ``thresholds.json``, resolved
            relative to the current working directory.
        default_threshold: Threshold used for missing labels, or for all
            labels when the file is unusable.

    Returns:
        Mapping of label name to threshold. Entries in the JSON whose names
        are not labels are kept unchanged.

    Raises:
        RuntimeError: If the label definitions cannot be loaded, since no
            thresholds (not even defaults) can be built without them.
    """
    import json

    # Labels are needed on both the success and the fallback path. Load them
    # once, up front, so a label failure is not misreported as a JSON failure.
    try:
        labels = load_label_definitions()
    except Exception as label_error:
        print(f"❌ CRITICAL ERROR: Cannot create default thresholds: {label_error}")
        raise RuntimeError(f"Failed to load clinical thresholds: {label_error}") from label_error

    try:
        with open(path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        thresholds = dict(config.get('clinical_thresholds', {}))
    except Exception as e:
        print(f"❌ CRITICAL ERROR: Could not load {path}: {e}")
        print(f" Using default threshold of {default_threshold} for all labels")
        default_thresholds = {label: default_threshold for label in labels}
        print(f"✅ Created default thresholds for {len(default_thresholds)} labels")
        return default_thresholds

    # Fill in the default threshold for labels missing from the config
    missing_labels = [label for label in labels if label not in thresholds]
    if missing_labels:
        print(f"⚠️ Warning: Missing thresholds for labels: {missing_labels}")
        for label in missing_labels:
            thresholds[label] = default_threshold

    print(f"✅ Loaded thresholds for {len(thresholds)} clinical labels")
    return thresholds
def determine_rhythm_from_abnormalities(abnormalities: List[str]) -> str:
    """Map detected ECG-FM labels to a single rhythm description.

    An empty input means "Normal Sinus Rhythm". Otherwise, the labels are
    checked in a fixed priority order (atrial fibrillation first, poor data
    quality last), and the first one present decides the result. A non-empty
    input with no recognised label yields "Abnormal Rhythm".
    """
    if not abnormalities:
        return "Normal Sinus Rhythm"

    # (official ECG-FM label, displayed rhythm), highest priority first
    priority_table = (
        ("Atrial fibrillation", "Atrial Fibrillation"),
        ("Atrial flutter", "Atrial Flutter"),
        ("Ventricular tachycardia", "Ventricular Tachycardia"),
        ("Supraventricular tachycardia with aberrancy", "Supraventricular Tachycardia with Aberrancy"),
        ("Bradycardia", "Bradycardia"),
        ("Tachycardia", "Tachycardia"),
        ("Premature ventricular contraction", "Premature Ventricular Contractions"),
        ("1st degree atrioventricular block", "1st Degree AV Block"),
        ("Atrioventricular block", "AV Block"),
        ("Right bundle branch block", "Right Bundle Branch Block"),
        ("Left bundle branch block", "Left Bundle Branch Block"),
        ("Bifascicular block", "Bifascicular Block"),
        ("Accessory pathway conduction", "Accessory Pathway Conduction"),
        ("Infarction", "Myocardial Infarction"),
        ("Electronic pacemaker", "Electronic Pacemaker"),
        ("Poor data quality", "Poor Data Quality - Rhythm Unclear"),
    )
    return next(
        (display for label, display in priority_table if label in abnormalities),
        "Abnormal Rhythm",
    )
def calculate_confidence_metrics(probs: np.ndarray, thresholds: Dict[str, float]) -> Dict[str, Any]:
    """Summarise a probability vector into confidence and review fields.

    Overall confidence is the highest single-label probability. It is "High"
    when >= 0.8, "Medium" when >= 0.6, and "Low" otherwise. Review is flagged
    when that maximum is below 0.6 or the mean probability across all labels
    is below 0.4.

    ``thresholds`` is accepted for interface compatibility but is not used.
    NOTE(review): with multi-label sigmoid outputs, most labels may sit near
    zero, so the mean-probability rule could flag review for most inputs.
    Confirm this is intended.

    Raises:
        ValueError: Raised by numpy when ``probs`` is empty.
    """
    peak = np.max(probs)
    average = np.mean(probs)

    level_cutoffs = ((0.8, "High"), (0.6, "Medium"))
    level = next((name for cutoff, name in level_cutoffs if peak >= cutoff), "Low")

    needs_review = peak < 0.6 or average < 0.4

    summary = {}
    summary["overall_confidence"] = float(peak)
    summary["confidence_level"] = str(level)
    summary["review_required"] = bool(needs_review)
    summary["mean_probability"] = float(average)
    summary["max_probability"] = float(peak)
    return summary
|