ecg-fm-api / diagnose_model_outputs.py
mystic_CBK
🚀 Deploy ECG-FM v2.1.0 - Physiological Parameter Extraction Now Working! - Added comprehensive physiological parameter extraction (HR, QRS, QT, PR, Axis) using ECG-FM features - Implemented statistical pattern recognition algorithms - Added clinical range validation and confidence scoring - Created comprehensive test script for real ECG samples - Updated documentation and status reports - All endpoints now provide actual measurements instead of null values
0d7408c
#!/usr/bin/env python3
"""
Diagnostic Script for ECG-FM Model Outputs
Examines actual model outputs to understand clinical analysis issues
"""
import pandas as pd
import requests
import json
import time
import os
# Configuration. Both values can be overridden through environment variables,
# so the script can target another deployment or ECG sample without edits.
# A trailing "/" is stripped because endpoint URLs are built as f"{API_URL}/...".
API_URL = os.environ.get(
    "ECG_FM_API_URL", "https://mystic-cbk-ecg-fm-api.hf.space"
).rstrip("/")
# Resolved relative to the current working directory, not to this script.
ECG_FILE = os.environ.get(
    "ECG_FM_ECG_FILE",
    "../ecg_uploads_greenwich/ecg_98408931-6f8e-47cc-954a-ba0c058a0f3d.csv",
)
# Lead names sent with every request. Assumes the CSV columns are already in
# this order, one column per lead -- NOTE(review): confirm against uploads.
_LEAD_NAMES = ("I", "II", "III", "aVR", "aVL", "aVF",
               "V1", "V2", "V3", "V4", "V5", "V6")
# Per-request timeouts in seconds; /analyze runs both models, so it gets more.
_FEATURE_TIMEOUT_S = 120
_ANALYSIS_TIMEOUT_S = 180
# Rhythm string this script treats as "rhythm detection failed".
_RHYTHM_UNKNOWN = "Unable to determine"


def diagnose_model_outputs(api_url=None, ecg_file=None, fs=500):
    """Run the ECG-FM endpoints on one ECG file and print what they return.

    Steps: load the CSV, call /extract_features, call /analyze, then print
    a pass/fail summary of the clinical analysis.

    Args:
        api_url: Base URL of the API; defaults to the module-level API_URL.
        ecg_file: CSV file with one column per lead; defaults to ECG_FILE.
        fs: Sampling rate sent as the payload's ``fs`` field (default 500,
            presumably Hz).

    Returns:
        None. Problems are reported on stdout (unexpected exceptions also
        print a traceback) instead of being raised.
    """
    api_url = API_URL if api_url is None else api_url
    ecg_file = ECG_FILE if ecg_file is None else ecg_file

    print("🔍 DIAGNOSING ECG-FM MODEL OUTPUTS")
    print("=" * 60)
    print(f"🌐 API URL: {api_url}")
    print(f"📁 ECG File: {ecg_file}")
    print()
    try:
        payload = _load_payload(ecg_file, fs)
        if payload is None:
            return
        if not _report_features(api_url, payload):
            return
        clinical = _report_analysis(api_url, payload)
        if clinical is None:
            return
        _print_summary(clinical)
    except Exception as e:
        # Top-level boundary of a diagnostic script: report and stop.
        print(f"❌ Diagnosis failed with error: {e}")
        import traceback
        traceback.print_exc()


def _load_payload(ecg_file, fs):
    """Read the ECG CSV and build the request payload; return None if unusable."""
    print("📁 Loading ECG data...")
    if not os.path.exists(ecg_file):
        print(f"❌ ECG file not found: {ecg_file}")
        return None
    df = pd.read_csv(ecg_file)
    if df.empty:
        # No columns or no rows: nothing meaningful to send to the API.
        print(f"❌ ECG file contains no samples: {ecg_file}")
        return None
    # One list of samples per CSV column, i.e. per lead.
    signal = [df[col].tolist() for col in df.columns]
    if len(signal) != len(_LEAD_NAMES):
        print(f"⚠️ Expected {len(_LEAD_NAMES)} leads but CSV has {len(signal)} columns;"
              " lead names sent to the API will not line up with the data")
    print(f"✅ Loaded ECG: {len(signal)} leads, {len(signal[0])} samples")
    return {"signal": signal, "fs": fs, "lead_names": list(_LEAD_NAMES)}


def _print_mapping(mapping):
    """Print each key/value pair of a response sub-dict on its own line."""
    for key, value in mapping.items():
        print(f" {key}: {value}")


def _print_probabilities(probs):
    """Print count, preview and range of probabilities (list or label->value dict)."""
    if probs is None:
        probs = []
    is_mapping = isinstance(probs, dict)
    values = list(probs.values()) if is_mapping else list(probs)
    print(f" Probabilities count: {len(values)}")
    if not values:
        return
    preview = list(probs.items())[:5] if is_mapping else values[:5]
    print(f" First 5 probabilities: {preview}")
    try:
        print(f" Max probability: {max(values):.4f}")
        print(f" Min probability: {min(values):.4f}")
    except (TypeError, ValueError):
        # e.g. nested lists or non-numeric entries cannot be formatted as floats.
        print(" Probabilities are not a flat list of numbers; range not computed")


def _report_features(api_url, payload):
    """Call /extract_features and print the result; return False on an HTTP error."""
    print("\n🧬 Testing Feature Extraction (Pretrained Model)...")
    print(" This should show what the pretrained model outputs")
    response = requests.post(
        f"{api_url}/extract_features", json=payload, timeout=_FEATURE_TIMEOUT_S
    )
    if response.status_code != 200:
        print(f"❌ Feature extraction failed: {response.status_code}")
        print(f" Response: {response.text}")
        return False
    data = response.json()
    print("✅ Feature extraction successful!")
    # `or` fallbacks also cover keys that are present but JSON null.
    print(f" Features count: {len(data.get('features') or [])}")
    print(f" Input shape: {data.get('input_shape', 'Unknown')}")
    print(f" Model type: {data.get('model_type', 'Unknown')}")
    phys_params = data.get("physiological_parameters") or {}
    if phys_params:
        print(f" Physiological parameters: {len(phys_params)}")
        _print_mapping(phys_params)
    return True


def _report_analysis(api_url, payload):
    """Call /analyze, print each response section, and return its clinical_analysis.

    Returns None when the request fails with a non-200 status; otherwise the
    clinical_analysis dict (empty if the response lacked it or it was null).
    """
    print("\n🏥 Testing Full Analysis (Both Models)...")
    print(" This should show what both models output together")
    response = requests.post(
        f"{api_url}/analyze", json=payload, timeout=_ANALYSIS_TIMEOUT_S
    )
    if response.status_code != 200:
        print(f"❌ Full analysis failed: {response.status_code}")
        print(f" Response: {response.text}")
        return None
    data = response.json()
    print("✅ Full analysis successful!")

    clinical = data.get("clinical_analysis") or {}
    print("\n📊 Clinical Analysis Details:")
    print(f" Rhythm: {clinical.get('rhythm', 'Unknown')}")
    print(f" Heart Rate: {clinical.get('heart_rate', 'Unknown')} BPM")
    print(f" QRS Duration: {clinical.get('qrs_duration', 'Unknown')} ms")
    print(f" QT Interval: {clinical.get('qt_interval', 'Unknown')} ms")
    print(f" PR Interval: {clinical.get('pr_interval', 'Unknown')} ms")
    print(f" Axis Deviation: {clinical.get('axis_deviation', 'Unknown')}")
    print(f" Confidence: {clinical.get('confidence', 'Unknown')}")
    print(f" Method: {clinical.get('method', 'Unknown')}")
    if "probabilities" in clinical:
        _print_probabilities(clinical["probabilities"])
    if "label_probabilities" in clinical:
        label_probs = clinical["label_probabilities"] or {}
        print(f" Label probabilities: {len(label_probs)}")
        if label_probs:
            # list() of a dict yields its keys, i.e. the label names.
            print(f" Sample labels: {list(label_probs)[:5]}")
    print(f" Abnormalities: {clinical.get('abnormalities', [])}")
    clinical_params = clinical.get("physiological_parameters") or {}
    if clinical_params:
        print("\n📊 Physiological Parameters (from clinical analysis):")
        _print_mapping(clinical_params)

    features = data.get("features") or []
    print("\n📊 Features:")
    print(f" Count: {len(features)}")
    if features:
        print(f" First 5 values: {features[:5]}")
        print(f" Last 5 values: {features[-5:]}")

    print(f"\n📊 Signal Quality: {data.get('signal_quality', 'Unknown')}")
    processing_time = data.get("processing_time")
    if isinstance(processing_time, (int, float)):
        print(f"⏱️ Processing Time: {processing_time}s")
    else:
        # Only numbers get the seconds unit (avoids printing "Unknowns").
        shown = "Unknown" if processing_time is None else processing_time
        print(f"⏱️ Processing Time: {shown}")
    return clinical


def _print_summary(clinical):
    """Print pass/fail verdicts derived from the clinical analysis, plus next steps."""
    print("\n" + "=" * 60)
    print("🔍 DIAGNOSIS SUMMARY")
    print("=" * 60)
    if clinical.get("method") == "clinical_predictions":
        print("✅ Clinical analysis method: clinical_predictions")
        print(" This means the finetuned model is working")
    else:
        print("❌ Clinical analysis method: NOT clinical_predictions")
        print(" This means the finetuned model is not producing proper outputs")
    probabilities = clinical.get("probabilities")
    if probabilities:
        print("✅ Probabilities are available")
        print(f" Count: {len(probabilities)}")
    else:
        print("❌ No probabilities available")
        print(" This explains why clinical analysis is failing")
    rhythm = clinical.get("rhythm")
    # A missing/empty rhythm is a failure too, not only the explicit sentinel.
    if rhythm and rhythm != _RHYTHM_UNKNOWN:
        print("✅ Rhythm detection working")
    else:
        print("❌ Rhythm detection failing")
        print(" This suggests clinical model output issues")
    print("\n🎯 RECOMMENDED ACTIONS:")
    print(" 1. Check if finetuned model is producing label_logits")
    print(" 2. Verify model output format matches expectations")
    print(" 3. Debug clinical_analysis.py logic")
    print(" 4. Test with simpler ECG data")
def _run_cli():
    """Script entry point: run the diagnostic with the module's configured defaults."""
    diagnose_model_outputs()


if __name__ == "__main__":
    _run_cli()