Spaces:
Running
Running
| #!/usr/bin/env python3 | |
"""
Discover ECG-FM Model Labels

Inspect the actual labels that the fine-tuned model outputs.
"""
| import torch | |
| import numpy as np | |
| import json | |
| from typing import Dict, Any, List | |
| import requests | |
| import time | |
def _build_sample_payload(n_samples: int = 500, fs: int = 500) -> Dict[str, Any]:
    """Build an /analyze request payload holding a random 12-lead signal."""
    lead_names = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
    # Low-amplitude Gaussian noise: the goal is to see which labels come back,
    # not to obtain a meaningful reading.
    signal = np.random.normal(0, 0.1, (len(lead_names), n_samples)).tolist()
    return {
        "signal": signal,
        "fs": fs,
        "lead_names": lead_names,
        "recording_duration": n_samples / fs,
    }


def _format_probability(value: Any) -> str:
    """Render a probability with three decimals; fall back to str() for non-numeric values."""
    try:
        return f"{value:.3f}"
    except (TypeError, ValueError):
        # e.g. a JSON null or a string: show it as-is instead of aborting the report.
        return str(value)


def _report_analysis(result: Dict[str, Any]) -> None:
    """Print the structure of an /analyze response and save any label names it contains."""
    print("\nπ Response Structure Analysis:")
    print(f" Keys: {list(result.keys())}")

    clinical = result.get('clinical_analysis')
    # Computed once so the raw-probabilities check below cannot trip over a
    # 'clinical_analysis' value that is present but not a dict.
    has_labels = isinstance(clinical, dict) and 'label_probabilities' in clinical

    if isinstance(clinical, dict):
        print(f"\nπ₯ Clinical Analysis Keys: {list(clinical.keys())}")
        if has_labels:
            label_probs = clinical['label_probabilities']
            print(f"\nπ·οΈ Label Probabilities Found: {len(label_probs)} labels")
            print(" Labels and probabilities:")
            for label, prob in label_probs.items():
                print(f" {label}: {_format_probability(prob)}")
            # Save discovered labels
            save_discovered_labels(list(label_probs))
        else:
            print("β No label_probabilities found in response")
            print(" This suggests the model might not be outputting clinical labels yet")

    if 'probabilities' in result:
        probs = result['probabilities']
        print(f"\nπ Raw Probabilities Array: {len(probs)} values")
        print(f" First 10 values: {probs[:10]}")
        # Probabilities without names mean the label mapping still has to be found.
        if probs and not has_labels:
            print("\nβ οΈ Model outputs probabilities but no label names")
            print(" This suggests we need to find the label definitions from the model")


def test_model_with_sample_ecg(
    api_url: str = "https://mystic-cbk-ecg-fm-api.hf.space",
) -> None:
    """Test the deployed model to see what labels it actually outputs.

    Sends a random 12-lead, 500-sample signal to ``{api_url}/analyze`` (after
    a ``{api_url}/health`` check), prints the structure of the JSON response
    and, when ``clinical_analysis.label_probabilities`` is present, passes the
    label names to ``save_discovered_labels``.

    Args:
        api_url: Base URL of the deployed API. A trailing slash is ignored.

    Returns:
        None. All findings are reported on stdout; failures are printed, not
        raised.
    """
    print("π Discovering ECG-FM Model Labels")
    print("=" * 50)

    payload = _build_sample_payload()
    sample_ecg = payload["signal"]
    print("π Testing with sample ECG signal...")
    print(f" Signal shape: {len(sample_ecg)} leads x {len(sample_ecg[0])} samples")

    # Avoid "//health" when the caller passes a URL ending in "/".
    api_url = api_url.rstrip("/")
    try:
        print(f"\nπ Testing deployed API: {api_url}")

        # Test health first
        health_response = requests.get(f"{api_url}/health", timeout=30)
        if health_response.status_code != 200:
            print(f"β API health check failed: {health_response.status_code}")
            return
        print("β API is healthy")

        # Test full analysis
        print("\n㪠Testing full ECG analysis...")
        analysis_response = requests.post(
            f"{api_url}/analyze",
            json=payload,
            timeout=180,
        )
        if analysis_response.status_code != 200:
            print(f"β Analysis failed: {analysis_response.status_code}")
            print(f" Response: {analysis_response.text}")
            return

        result = analysis_response.json()
        print("β Analysis successful!")
        _report_analysis(result)
    except Exception as e:
        # Script-level boundary: report network/JSON/shape problems and carry
        # on so the caller can still print its follow-up guidance.
        print(f"β Error testing API: {e}")
def save_discovered_labels(labels: List[str]) -> None:
    """Save discovered labels to ``discovered_labels.csv`` and ``model_labels.txt``.

    Both files are written to the current working directory as UTF-8:

    * ``discovered_labels.csv`` holds one ``index,label`` row per label, with
      the label quoted when it contains a comma, quote or newline.
    * ``model_labels.txt`` holds one label per line.

    Args:
        labels: Label names in model output order. Non-string entries are
            converted with ``str``.

    Returns:
        None. File and encoding errors are printed rather than raised.
    """
    # Local import keeps this block self-contained; the file's import block
    # does not pull in csv.
    import csv

    names = [str(label) for label in labels]
    try:
        # Create a proper label definition file. csv.writer handles quoting,
        # which plain string concatenation would get wrong for labels that
        # contain commas.
        with open('discovered_labels.csv', 'w', encoding='utf-8', newline='') as f:
            csv.writer(f, lineterminator='\n').writerows(enumerate(names))
        print(f"\nπΎ Discovered labels saved to: discovered_labels.csv")
        print(f" Total labels: {len(names)}")

        # Also create a simple list file
        with open('model_labels.txt', 'w', encoding='utf-8') as f:
            f.write('\n'.join(names))
        print(f" Labels list saved to: model_labels.txt")
    except (OSError, UnicodeError) as e:
        print(f"β Error saving discovered labels: {e}")
def inspect_model_checkpoint():
    """Print guidance on how to find the model's label mapping.

    Nothing is loaded or inspected here: the function only writes a fixed
    checklist of suggested approaches to stdout.
    """
    guidance = (
        "\nπ Model Checkpoint Inspection",
        "=" * 40,
        "π‘ To properly discover model labels, you should:",
        "1. Load the model checkpoint locally",
        "2. Inspect the model's classification head",
        "3. Check for label mapping in the checkpoint",
        "4. Or test with known ECG data to see output patterns",
        "\nπ Alternative approaches:",
        "1. Check ECG-FM paper/repository for label definitions",
        "2. Contact the model authors for label mapping",
        "3. Use a small labeled dataset to map outputs to known conditions",
    )
    # One print per entry, so the output matches a sequence of print() calls.
    for text in guidance:
        print(text)
def main():
    """Run the label-discovery workflow from start to finish.

    Prints a banner, probes the deployed API via
    ``test_model_with_sample_ecg``, prints the checkpoint-inspection guidance
    via ``inspect_model_checkpoint``, and ends with a list of next steps.
    """
    banner = (
        "π§ͺ ECG-FM Model Label Discovery",
        "=" * 50,
        "π― Goal: Discover the actual labels that the finetuned model outputs",
        " This will help us create the correct label_def.csv",
    )
    for text in banner:
        print(text)

    # Test with deployed API
    test_model_with_sample_ecg()

    # Provide guidance for further investigation
    inspect_model_checkpoint()

    next_steps = (
        "\nπ‘ Next Steps:",
        "1. Run this script to test the deployed API",
        "2. Check if label_probabilities are returned",
        "3. If yes, use those labels; if no, investigate further",
        "4. Update label_def.csv with the correct labels",
    )
    for text in next_steps:
        print(text)


# Only run the workflow when executed as a script, not when imported.
if __name__ == "__main__":
    main()