Spaces:
Running
Running
mystic_CBK
commited on
Commit
·
10b881c
1
Parent(s):
d205f72
Fix clinical analysis to use 'out' key from finetuned model (actual working output)
Browse files- clinical_analysis.py +17 -0
clinical_analysis.py
CHANGED
|
@@ -35,6 +35,23 @@ def analyze_ecg_features(model_output: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 35 |
clinical_result = extract_clinical_from_probabilities(probs)
|
| 36 |
return clinical_result
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
# NEW: Check if the model output IS the logits tensor directly (classifier model)
|
| 39 |
elif isinstance(model_output, torch.Tensor):
|
| 40 |
print("✅ Found direct logits tensor - using classifier model output")
|
|
|
|
| 35 |
clinical_result = extract_clinical_from_probabilities(probs)
|
| 36 |
return clinical_result
|
| 37 |
|
| 38 |
+
# NEW: Check for 'out' key (actual finetuned model output)
|
| 39 |
+
elif 'out' in model_output:
|
| 40 |
+
print("✅ Found 'out' key - using finetuned model output")
|
| 41 |
+
# FINETUNED MODEL - Extract real clinical predictions
|
| 42 |
+
logits = model_output['out']
|
| 43 |
+
if isinstance(logits, torch.Tensor):
|
| 44 |
+
# Remove batch dimension if present
|
| 45 |
+
if logits.dim() == 2: # [batch, num_labels]
|
| 46 |
+
logits = logits.squeeze(0) # Remove batch dimension
|
| 47 |
+
probs = torch.sigmoid(logits).detach().cpu().numpy().ravel()
|
| 48 |
+
else:
|
| 49 |
+
probs = 1 / (1 + np.exp(-np.array(logits).ravel()))
|
| 50 |
+
|
| 51 |
+
# Extract clinical parameters from probabilities
|
| 52 |
+
clinical_result = extract_clinical_from_probabilities(probs)
|
| 53 |
+
return clinical_result
|
| 54 |
+
|
| 55 |
# NEW: Check if the model output IS the logits tensor directly (classifier model)
|
| 56 |
elif isinstance(model_output, torch.Tensor):
|
| 57 |
print("✅ Found direct logits tensor - using classifier model output")
|