# File size: 6,154 Bytes
# 31b6ae7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python3
"""
Discover ECG-FM Model Labels

Send a synthetic 12-lead ECG to the deployed ECG-FM API and inspect which
label names the fine-tuned model returns, saving any that are found.
"""

import csv
import json
import time
from pathlib import Path
from typing import Dict, Any, List

import numpy as np
import requests
import torch

def _build_sample_payload(fs: int = 500, n_samples: int = 500) -> Dict[str, Any]:
    """Build an /analyze request body carrying a random 12-lead ECG.

    The signal is unseeded Gaussian noise (mean 0, std 0.1) laid out as
    one list of ``n_samples`` values per lead. It only exercises the API's
    output format and carries no clinical meaning. ``recording_duration``
    is derived from ``n_samples / fs`` so it always matches the signal.
    """
    lead_names = ["I", "II", "III", "aVR", "aVL", "aVF",
                  "V1", "V2", "V3", "V4", "V5", "V6"]
    signal = np.random.normal(0, 0.1, (len(lead_names), n_samples)).tolist()
    return {
        "signal": signal,
        "fs": fs,
        "lead_names": lead_names,
        "recording_duration": n_samples / fs,
    }


def _format_probability(prob: Any) -> str:
    """Format a numeric probability to 3 decimals; fall back to repr() otherwise."""
    if isinstance(prob, (int, float)):
        return f"{prob:.3f}"
    return repr(prob)


def _report_analysis(result: Any) -> None:
    """Print the structure of a successful /analyze response and save any label names.

    Looks for ``clinical_analysis.label_probabilities`` (label -> probability)
    and a top-level ``probabilities`` sequence. The response schema is not
    guaranteed, so each level is type-checked before use. An unexpected shape
    is reported instead of raising KeyError/TypeError halfway through the report.
    """
    if not isinstance(result, dict):
        print(f"❌ Unexpected response type: {type(result).__name__}")
        return

    print("\nπŸ“‹ Response Structure Analysis:")
    print(f"   Keys: {list(result.keys())}")

    clinical = result.get('clinical_analysis')
    has_label_probs = isinstance(clinical, dict) and 'label_probabilities' in clinical

    if isinstance(clinical, dict):
        print(f"\nπŸ₯ Clinical Analysis Keys: {list(clinical.keys())}")

        if has_label_probs:
            label_probs = clinical['label_probabilities']
            if isinstance(label_probs, dict):
                print(f"\n🏷️  Label Probabilities Found: {len(label_probs)} labels")
                print("   Labels and probabilities:")
                for label, prob in label_probs.items():
                    print(f"      {label}: {_format_probability(prob)}")

                # Persist the label names in the order the API returned them.
                save_discovered_labels(list(label_probs))
            else:
                print(f"❌ label_probabilities is a {type(label_probs).__name__}, expected an object")
        else:
            print("❌ No label_probabilities found in response")
            print("   This suggests the model might not be outputting clinical labels yet")
    elif 'clinical_analysis' in result:
        print(f"❌ clinical_analysis is a {type(clinical).__name__}, expected an object")

    if 'probabilities' in result:
        probs = result['probabilities']
        if isinstance(probs, (list, tuple)):
            print(f"\nπŸ“Š Raw Probabilities Array: {len(probs)} values")
            print(f"   First 10 values: {probs[:10]}")

            # Probabilities without names mean the label mapping must come from elsewhere.
            if probs and not has_label_probs:
                print("\n⚠️  Model outputs probabilities but no label names")
                print("   This suggests we need to find the label definitions from the model")
        else:
            print(f"❌ probabilities is a {type(probs).__name__}, expected an array")


def test_model_with_sample_ecg(api_url: str = "https://mystic-cbk-ecg-fm-api.hf.space") -> None:
    """Query the deployed API with a random ECG and report which labels it outputs.

    Calls ``GET {api_url}/health`` and stops early if it does not return 200.
    Then it calls ``POST {api_url}/analyze`` with a synthetic 12-lead payload
    and prints the response structure. Any label names found are saved via
    save_discovered_labels().

    Args:
        api_url: Base URL of the deployed API. A trailing slash is tolerated.

    Network failures and non-JSON responses are printed, not raised, so the
    rest of the discovery script can still run.
    """
    print("πŸ” Discovering ECG-FM Model Labels")
    print("=" * 50)

    payload = _build_sample_payload()
    signal = payload["signal"]

    print("πŸ“Š Testing with sample ECG signal...")
    print(f"   Signal shape: {len(signal)} leads x {len(signal[0])} samples")

    base_url = api_url.rstrip("/")

    try:
        print(f"\n🌐 Testing deployed API: {base_url}")

        # Check health first, so a sleeping/broken Space fails fast.
        health_response = requests.get(f"{base_url}/health", timeout=30)
        if health_response.status_code != 200:
            print(f"❌ API health check failed: {health_response.status_code}")
            return
        print("βœ… API is healthy")

        print("\nπŸ”¬ Testing full ECG analysis...")
        # Inference can be slow on a cold start, hence the long timeout.
        analysis_response = requests.post(
            f"{base_url}/analyze",
            json=payload,
            timeout=180,
        )
    except requests.RequestException as e:
        print(f"❌ Error testing API: {e}")
        return

    if analysis_response.status_code != 200:
        print(f"❌ Analysis failed: {analysis_response.status_code}")
        print(f"   Response: {analysis_response.text}")
        return

    try:
        result = analysis_response.json()
    except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
        print(f"❌ Error testing API: response is not valid JSON ({e})")
        return

    print("βœ… Analysis successful!")
    _report_analysis(result)


# Not a pytest test despite the name: it makes live network calls.
test_model_with_sample_ecg.__test__ = False

def save_discovered_labels(labels: List[str], output_dir: str = ".") -> None:
    """Write discovered label names to ``discovered_labels.csv`` and ``model_labels.txt``.

    ``discovered_labels.csv`` gets one ``index,label`` row per label, with no
    header; the index is the label's position in ``labels``.
    ``model_labels.txt`` gets one label per line. Both files are written as
    UTF-8 into ``output_dir`` and overwrite any existing files.

    Args:
        labels: Label names in model output order.
        output_dir: Directory to write both files into (default: current
            working directory, i.e. the same paths as before).

    File-system errors are printed rather than raised, so a failed save does
    not abort the discovery run.
    """
    out_dir = Path(output_dir)
    csv_path = out_dir / "discovered_labels.csv"
    txt_path = out_dir / "model_labels.txt"

    try:
        # csv.writer quotes labels containing commas, quotes or newlines,
        # which a hand-built f"{i},{label}" row would silently corrupt.
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerows(enumerate(labels))

        print(f"\nπŸ’Ύ Discovered labels saved to: {csv_path}")
        print(f"   Total labels: {len(labels)}")

        # Also create a simple one-label-per-line file.
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write("\n".join(str(label) for label in labels))

        print(f"   Labels list saved to: {txt_path}")

    except OSError as e:
        print(f"❌ Error saving discovered labels: {e}")

def inspect_model_checkpoint():
    """Print suggested manual steps for recovering the model's label mapping.

    Despite its name, this does not load or read a checkpoint. It only writes
    a heading and two numbered suggestion lists to stdout.
    """
    checkpoint_steps = (
        "Load the model checkpoint locally",
        "Inspect the model's classification head",
        "Check for label mapping in the checkpoint",
        "Or test with known ECG data to see output patterns",
    )
    alternative_steps = (
        "Check ECG-FM paper/repository for label definitions",
        "Contact the model authors for label mapping",
        "Use a small labeled dataset to map outputs to known conditions",
    )

    print("\nπŸ” Model Checkpoint Inspection")
    print("=" * 40)

    print("πŸ’‘ To properly discover model labels, you should:")
    for number, step in enumerate(checkpoint_steps, start=1):
        print(f"{number}. {step}")

    print("\nπŸ“š Alternative approaches:")
    for number, step in enumerate(alternative_steps, start=1):
        print(f"{number}. {step}")

def main():
    """Run the label-discovery workflow from start to finish.

    Prints an intro banner, then queries the deployed API through
    test_model_with_sample_ecg(). It then prints follow-up guidance through
    inspect_model_checkpoint() and closes with a numbered list of next steps.
    """
    intro_lines = (
        "πŸ§ͺ ECG-FM Model Label Discovery",
        "=" * 50,
        "🎯 Goal: Discover the actual labels that the finetuned model outputs",
        "   This will help us create the correct label_def.csv",
    )
    next_steps = (
        "Run this script to test the deployed API",
        "Check if label_probabilities are returned",
        "If yes, use those labels; if no, investigate further",
        "Update label_def.csv with the correct labels",
    )

    for line in intro_lines:
        print(line)

    # Live query against the deployed API (requires network access).
    test_model_with_sample_ecg()

    # Printed guidance for investigating the label mapping by hand.
    inspect_model_checkpoint()

    print("\nπŸ’‘ Next Steps:")
    for idx, step in enumerate(next_steps, 1):
        print(f"{idx}. {step}")


if __name__ == "__main__":
    main()