ecg-fm-api / discover_model_labels.py
mystic_CBK
Deploy ECG-FM Dual Model API v2.0.0
31b6ae7
raw
history blame
6.15 kB
#!/usr/bin/env python3
"""
Discover ECG-FM Model Labels
Inspect the actual labels that the finetuned model outputs
"""
import torch
import numpy as np
import json
from typing import Dict, Any, List
import requests
import time
def test_model_with_sample_ecg(api_url: str = "https://mystic-cbk-ecg-fm-api.hf.space"):
    """Probe the ECG-FM API with a synthetic ECG and report the labels it returns.

    Performs a ``/health`` check followed by a single ``/analyze`` request
    carrying a random-noise 12-lead signal, prints the structure of the JSON
    reply, and hands any label names found under
    ``clinical_analysis.label_probabilities`` to ``save_discovered_labels``.

    Args:
        api_url: Base URL of the API to probe. Defaults to the deployed
            Hugging Face Space, so existing zero-argument calls behave as
            before. A trailing slash is tolerated.

    Returns:
        None. Findings are printed; failures are reported on stdout rather
        than raised, because this is a best-effort diagnostic.
    """
    print("🔍 Discovering ECG-FM Model Labels")
    print("=" * 50)

    # Gaussian noise stands in for a real recording: 12 leads x 500 samples,
    # i.e. one second at the declared 500 Hz sampling rate.
    sample_ecg = np.random.normal(0, 0.1, (12, 500)).tolist()
    payload = {
        "signal": sample_ecg,
        "fs": 500,
        "lead_names": ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"],
        "recording_duration": 1.0,
    }
    print("📊 Testing with sample ECG signal...")
    print(f" Signal shape: {len(sample_ecg)} leads x {len(sample_ecg[0])} samples")

    base_url = api_url.rstrip("/")
    try:
        print(f"\n🌐 Testing deployed API: {api_url}")

        # Check health first so a sleeping/broken deployment is reported
        # without waiting on the much slower analysis request.
        health_response = requests.get(f"{base_url}/health", timeout=30)
        if health_response.status_code != 200:
            print(f"❌ API health check failed: {health_response.status_code}")
            return
        print("✅ API is healthy")

        print("\n🔬 Testing full ECG analysis...")
        analysis_response = requests.post(
            f"{base_url}/analyze",
            json=payload,
            timeout=180,
        )
        if analysis_response.status_code != 200:
            print(f"❌ Analysis failed: {analysis_response.status_code}")
            print(f" Response: {analysis_response.text}")
            return

        # Decode before announcing success so a non-JSON body is reported
        # as an error instead of a successful analysis.
        result = analysis_response.json()
        print("✅ Analysis successful!")
        _report_analysis_result(result)
    except Exception as e:
        # Deliberately broad: network errors, malformed JSON and unexpected
        # response shapes should all be reported, never crash the script.
        print(f"❌ Error testing API: {e}")


def _format_probability(value) -> str:
    """Format a probability with three decimals, falling back to repr()."""
    try:
        return f"{float(value):.3f}"
    except (TypeError, ValueError):
        # Unexpected payload (None, nested object, ...): show it verbatim so
        # one odd entry does not abort the whole discovery run.
        return repr(value)


def _report_analysis_result(result) -> None:
    """Print the layout of an /analyze reply and save any label names in it."""
    print("\n📋 Response Structure Analysis:")
    print(f" Keys: {list(result.keys())}")

    if "clinical_analysis" in result:
        clinical = result["clinical_analysis"]
        print(f"\n🏥 Clinical Analysis Keys: {list(clinical.keys())}")

        if "label_probabilities" in clinical:
            label_probs = clinical["label_probabilities"]
            print(f"\n🏷️ Label Probabilities Found: {len(label_probs)} labels")
            print(" Labels and probabilities:")
            for label, prob in label_probs.items():
                print(f" {label}: {_format_probability(prob)}")

            # Persist the label names in the order the API returned them.
            save_discovered_labels(list(label_probs))
        else:
            print("❌ No label_probabilities found in response")
            print(" This suggests the model might not be outputting clinical labels yet")

    if "probabilities" in result:
        probs = result["probabilities"]
        print(f"\n📊 Raw Probabilities Array: {len(probs)} values")
        print(f" First 10 values: {probs[:10]}")

        # Probabilities without names mean the label mapping still has to be
        # found elsewhere. `or {}` tolerates an explicit null clinical_analysis.
        clinical_section = result.get("clinical_analysis") or {}
        if len(probs) > 0 and "label_probabilities" not in clinical_section:
            print("\n⚠️ Model outputs probabilities but no label names")
            print(" This suggests we need to find the label definitions from the model")
def save_discovered_labels(labels: List[str]):
    """Write the discovered label names to two files in the working directory.

    ``discovered_labels.csv`` receives one ``index,label`` row per label and
    ``model_labels.txt`` receives one label per line. Both files are written
    as UTF-8.

    Args:
        labels: Label names, in the order they should be indexed.

    Returns:
        None. Errors are reported on stdout instead of being raised, so a
        failed save never interrupts the surrounding discovery run.
    """
    import csv  # local import keeps this block self-contained

    try:
        # csv.writer quotes labels that contain commas, quotes or newlines,
        # which the previous hand-built "index,label" lines did not.
        # newline="" is the csv-module convention; rows end with "\n".
        with open("discovered_labels.csv", "w", encoding="utf-8", newline="") as f:
            csv.writer(f, lineterminator="\n").writerows(enumerate(labels))
        print("\n💾 Discovered labels saved to: discovered_labels.csv")
        print(f" Total labels: {len(labels)}")

        # Also create a simple list file, one label per line.
        with open("model_labels.txt", "w", encoding="utf-8") as f:
            f.write("\n".join(labels))
        print(" Labels list saved to: model_labels.txt")
    except Exception as e:
        # Best-effort by design: report the problem and carry on.
        print(f"❌ Error saving discovered labels: {e}")
def inspect_model_checkpoint():
    """Print guidance on how to find the model's label mapping manually.

    This function does not load or inspect any checkpoint itself; it only
    writes a fixed checklist of suggested follow-up steps to stdout.

    Returns:
        None.
    """
    guidance = (
        "\n🔍 Model Checkpoint Inspection",
        "=" * 40,
        "💡 To properly discover model labels, you should:",
        "1. Load the model checkpoint locally",
        "2. Inspect the model's classification head",
        "3. Check for label mapping in the checkpoint",
        "4. Or test with known ECG data to see output patterns",
        "\n📚 Alternative approaches:",
        "1. Check ECG-FM paper/repository for label definitions",
        "2. Contact the model authors for label mapping",
        "3. Use a small labeled dataset to map outputs to known conditions",
    )
    # One print per entry keeps the output identical to a sequence of
    # individual print() calls (entries starting with "\n" add a blank line).
    for line in guidance:
        print(line)
def main():
    """Run the label-discovery workflow and print the follow-up steps.

    Prints a banner, probes the deployed API via
    ``test_model_with_sample_ecg``, prints the manual-investigation guidance
    via ``inspect_model_checkpoint``, and ends with a short next-steps list.

    Returns:
        None.
    """
    print("🧪 ECG-FM Model Label Discovery")
    print("=" * 50)
    print("🎯 Goal: Discover the actual labels that the finetuned model outputs")
    print(" This will help us create the correct label_def.csv")

    # Test with deployed API
    test_model_with_sample_ecg()

    # Provide guidance for further investigation
    inspect_model_checkpoint()

    next_steps = (
        "\n💡 Next Steps:",
        "1. Run this script to test the deployed API",
        "2. Check if label_probabilities are returned",
        "3. If yes, use those labels; if no, investigate further",
        "4. Update label_def.csv with the correct labels",
    )
    for step in next_steps:
        print(step)


# Run the discovery only when executed as a script, not when imported.
if __name__ == "__main__":
    main()