File size: 7,600 Bytes
0d7408c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python3
"""
Diagnostic Script for ECG-FM Model Outputs
Examines actual model outputs to understand clinical analysis issues
"""

import pandas as pd
import requests
import json
import time
import os

# Configuration
# Base URL of the ECG-FM API; the diagnostic POSTs to its /extract_features
# and /analyze endpoints.
API_URL = "https://mystic-cbk-ecg-fm-api.hf.space"
# CSV recording to send to the API. Each CSV column is sent as one lead.
# This is a relative path, so it is resolved against the current working
# directory, not against this script's location.
ECG_FILE = "../ecg_uploads_greenwich/ecg_98408931-6f8e-47cc-954a-ba0c058a0f3d.csv"

# Lead names sent with every request, in the order the API is told to expect.
_STANDARD_12_LEADS = ("I", "II", "III", "aVR", "aVL", "aVF",
                      "V1", "V2", "V3", "V4", "V5", "V6")

# Rhythm value that the summary treats as "rhythm not detected".
_UNDETERMINED_RHYTHM = 'Unable to determine'


def _build_payload(df, fs=500):
    """Return the JSON request body for the ECG-FM API built from ``df``.

    Each DataFrame column becomes one entry of ``signal`` (column order is
    kept). ``lead_names`` is always the fixed 12-lead list, regardless of how
    many columns ``df`` has.
    """
    return {
        "signal": [df[col].tolist() for col in df.columns],
        "fs": fs,
        "lead_names": list(_STANDARD_12_LEADS),
    }


def _rhythm_detected(clinical):
    """Return True only when ``clinical`` carries a usable rhythm label.

    A missing, ``None`` or empty ``rhythm`` counts as *not* detected, as does
    the explicit 'Unable to determine' value.
    """
    rhythm = clinical.get('rhythm')
    return bool(rhythm) and rhythm != _UNDETERMINED_RHYTHM


def _load_payload(ecg_file, fs):
    """Load ``ecg_file`` as CSV and build the request payload.

    Prints progress messages. Returns the payload dict, or None (after
    printing why) when the file is missing or holds no samples.
    """
    print("πŸ“ Loading ECG data...")
    if not os.path.exists(ecg_file):
        print(f"❌ ECG file not found: {ecg_file}")
        return None

    df = pd.read_csv(ecg_file)
    # Check for emptiness first. A column-less CSV would otherwise raise
    # IndexError on signal[0]. A header-only CSV would send zero samples
    # to both endpoints.
    if df.empty:
        print(f"❌ ECG file contains no samples: {ecg_file}")
        return None

    payload = _build_payload(df, fs)
    signal = payload["signal"]
    print(f"βœ… Loaded ECG: {len(signal)} leads, {len(signal[0])} samples")
    if len(signal) != len(_STANDARD_12_LEADS):
        # The API is always told there are 12 leads. Flag a mismatch so it is
        # not mistaken for a model problem.
        print(f"   Warning: {len(signal)} signal columns but "
              f"{len(_STANDARD_12_LEADS)} lead names are sent")
    return payload


def _post(url, payload, timeout, label):
    """POST ``payload`` as JSON to ``url``.

    Returns the response on HTTP 200. Otherwise prints the status code and
    raw body, prefixed by ``label``, and returns None.
    """
    response = requests.post(url, json=payload, timeout=timeout)
    if response.status_code == 200:
        return response
    print(f"❌ {label} failed: {response.status_code}")
    print(f"   Response: {response.text}")
    return None


def _report_features(feature_data):
    """Print the key fields of a successful /extract_features response."""
    print("βœ… Feature extraction successful!")
    print(f"   Features count: {len(feature_data.get('features', []))}")
    print(f"   Input shape: {feature_data.get('input_shape', 'Unknown')}")
    print(f"   Model type: {feature_data.get('model_type', 'Unknown')}")

    # Show physiological parameters
    phys_params = feature_data.get('physiological_parameters', {})
    if phys_params:
        print(f"   Physiological parameters: {len(phys_params)}")
        for key, value in phys_params.items():
            print(f"     {key}: {value}")


def _report_analysis(analysis_data):
    """Print the fields of a successful /analyze response.

    Returns the ``clinical_analysis`` sub-dict, or {} if it is absent or null.
    """
    print("βœ… Full analysis successful!")

    # `or {}` / `or []` below: an explicit JSON null must not crash .get/len().
    clinical = analysis_data.get('clinical_analysis') or {}
    print("\nπŸ“Š Clinical Analysis Details:")
    print(f"   Rhythm: {clinical.get('rhythm', 'Unknown')}")
    print(f"   Heart Rate: {clinical.get('heart_rate', 'Unknown')} BPM")
    print(f"   QRS Duration: {clinical.get('qrs_duration', 'Unknown')} ms")
    print(f"   QT Interval: {clinical.get('qt_interval', 'Unknown')} ms")
    print(f"   PR Interval: {clinical.get('pr_interval', 'Unknown')} ms")
    print(f"   Axis Deviation: {clinical.get('axis_deviation', 'Unknown')}")
    print(f"   Confidence: {clinical.get('confidence', 'Unknown')}")
    print(f"   Method: {clinical.get('method', 'Unknown')}")

    if 'probabilities' in clinical:
        probs = clinical['probabilities'] or []
        print(f"   Probabilities count: {len(probs)}")
        if probs:
            print(f"   First 5 probabilities: {probs[:5]}")
            print(f"   Max probability: {max(probs):.4f}")
            print(f"   Min probability: {min(probs):.4f}")

    if 'label_probabilities' in clinical:
        label_probs = clinical['label_probabilities'] or {}
        print(f"   Label probabilities: {len(label_probs)}")
        if label_probs:
            print(f"   Sample labels: {list(label_probs)[:5]}")

    abnormalities = clinical.get('abnormalities', [])
    print(f"   Abnormalities: {abnormalities}")

    phys_params = clinical.get('physiological_parameters', {})
    if phys_params:
        print("\nπŸ“Š Physiological Parameters (from clinical analysis):")
        for key, value in phys_params.items():
            print(f"   {key}: {value}")

    features = analysis_data.get('features') or []
    print("\nπŸ“Š Features:")
    print(f"   Count: {len(features)}")
    if features:
        print(f"   First 5 values: {features[:5]}")
        print(f"   Last 5 values: {features[-5:]}")

    signal_quality = analysis_data.get('signal_quality', 'Unknown')
    print(f"\nπŸ“Š Signal Quality: {signal_quality}")

    processing_time = analysis_data.get('processing_time', 'Unknown')
    print(f"⏱️  Processing Time: {processing_time}s")
    return clinical


def _print_summary(clinical):
    """Print pass/fail verdicts for the clinical analysis plus next steps."""
    print("\n" + "=" * 60)
    print("πŸ” DIAGNOSIS SUMMARY")
    print("=" * 60)

    if clinical.get('method') == 'clinical_predictions':
        print("βœ… Clinical analysis method: clinical_predictions")
        print("   This means the finetuned model is working")
    else:
        print("❌ Clinical analysis method: NOT clinical_predictions")
        print("   This means the finetuned model is not producing proper outputs")

    probs = clinical.get('probabilities')
    if probs:
        print("βœ… Probabilities are available")
        print(f"   Count: {len(probs)}")
    else:
        print("❌ No probabilities available")
        print("   This explains why clinical analysis is failing")

    # A missing rhythm is a failure, not a success.
    if _rhythm_detected(clinical):
        print("βœ… Rhythm detection working")
    else:
        print("❌ Rhythm detection failing")
        print("   This suggests clinical model output issues")

    print("\n🎯 RECOMMENDED ACTIONS:")
    print("   1. Check if finetuned model is producing label_logits")
    print("   2. Verify model output format matches expectations")
    print("   3. Debug clinical_analysis.py logic")
    print("   4. Test with simpler ECG data")


def diagnose_model_outputs(api_url=None, ecg_file=None, fs=500):
    """Run both ECG-FM API endpoints on one recording and print a diagnosis.

    Steps:
    1. Load ``ecg_file`` as CSV.
    2. POST it to ``{api_url}/extract_features`` (timeout 120 s).
    3. POST it to ``{api_url}/analyze`` (timeout 180 s).
    4. Print the interesting fields of each response, then a pass/fail
       summary of the clinical analysis.

    Processing stops early (after printing why) if the file is missing or
    empty, or if either request returns a non-200 status. All output goes to
    stdout, and the function always returns None.

    Args:
        api_url: Base URL of the API. Defaults to the module-level
            ``API_URL``.
        ecg_file: Path to the CSV recording, one column per lead. Defaults
            to the module-level ``ECG_FILE``.
        fs: Value sent as ``fs`` in the payload (the sampling rate,
            presumably in Hz).

    Any ``Exception`` raised while loading, requesting or parsing is caught.
    Its message and traceback are printed, and it is not re-raised.
    """
    # Resolve defaults at call time so the module constants remain the single
    # source of configuration.
    api_url = API_URL if api_url is None else api_url
    ecg_file = ECG_FILE if ecg_file is None else ecg_file

    print("πŸ” DIAGNOSING ECG-FM MODEL OUTPUTS")
    print("=" * 60)
    print(f"🌐 API URL: {api_url}")
    print(f"πŸ“ ECG File: {ecg_file}")
    print()

    try:
        # 1. Load ECG data
        payload = _load_payload(ecg_file, fs)
        if payload is None:
            return

        # 2. Test feature extraction (pretrained model)
        print("\n🧬 Testing Feature Extraction (Pretrained Model)...")
        print("   This should show what the pretrained model outputs")
        feature_response = _post(f"{api_url}/extract_features", payload,
                                 120, "Feature extraction")
        if feature_response is None:
            return
        _report_features(feature_response.json())

        # 3. Test full analysis (both models)
        print("\nπŸ₯ Testing Full Analysis (Both Models)...")
        print("   This should show what both models output together")
        analysis_response = _post(f"{api_url}/analyze", payload,
                                  180, "Full analysis")
        if analysis_response is None:
            return
        clinical = _report_analysis(analysis_response.json())

        # 4. Summary and diagnosis
        _print_summary(clinical)

    except Exception as e:
        # Top-level boundary of a diagnostic script: report, don't crash.
        print(f"❌ Diagnosis failed with error: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    # Run the diagnostic with the module-level API_URL / ECG_FILE settings
    # only when executed directly, not when imported.
    diagnose_model_outputs()