Spaces:
Build error
Build error
Upload 5 files
Browse files- app.py +262 -0
- requirements.txt +7 -3
- stress_calculation_guide.ipynb +159 -0
- stress_detection_fre_ck+.ipynb +0 -0
- wesad.ipynb +337 -0
app.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cv2
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
# Define the stress functions (copy these from your notebook)
|
| 9 |
+
# Source: https://www.kaggle.com/code/ahmedmagaji/stress-detection/notebook
|
| 10 |
+
# Formula: t = a*p + b; stress = c*ln(t)
|
| 11 |
+
# where p is probability (0-1), t is intermediate value, and stress is the output
|
| 12 |
+
def anger(p):
|
| 13 |
+
t = 0.343 * p + 1.003
|
| 14 |
+
return 2.332 * math.log(t)
|
| 15 |
+
|
| 16 |
+
def fear(p):
|
| 17 |
+
t = 1.356 * p + 1
|
| 18 |
+
return 1.763 * math.log(t)
|
| 19 |
+
|
| 20 |
+
def contempt(p):
|
| 21 |
+
t = 0.01229 * p + 1.036
|
| 22 |
+
return 5.03 * math.log(t)
|
| 23 |
+
|
| 24 |
+
def disgust(p):
|
| 25 |
+
t = 0.0123 * p + 1.019
|
| 26 |
+
return 7.351 * math.log(t)
|
| 27 |
+
|
| 28 |
+
def happy(p):
|
| 29 |
+
t = 5.221e-5 * p + 0.9997
|
| 30 |
+
return 532.2 * math.log(t)
|
| 31 |
+
|
| 32 |
+
def sad(p):
|
| 33 |
+
t = 0.1328 * p + 1.009
|
| 34 |
+
return 2.851 * math.log(t)
|
| 35 |
+
|
| 36 |
+
def surprise(p):
|
| 37 |
+
t = 0.2825 * p + 1.003
|
| 38 |
+
return 2.478 * math.log(t)
|
| 39 |
+
|
| 40 |
+
# Mapping from emotion index to label and stress function
|
| 41 |
+
emotion_map = {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Sad', 5: 'Surprise', 6: 'Neutral'}
|
| 42 |
+
aligned_func = [anger, disgust, fear, happy, sad, surprise, contempt] # Ensure correct order matching emotion_map keys
|
| 43 |
+
|
| 44 |
+
# Load Haar Cascade for face detection
|
| 45 |
+
@st.cache_resource
|
| 46 |
+
def load_face_cascade():
|
| 47 |
+
cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
|
| 48 |
+
return cv2.CascadeClassifier(cascade_path)
|
| 49 |
+
|
| 50 |
+
# Load the trained model
|
| 51 |
+
@st.cache_resource # Cache the model to avoid reloading on each interaction
|
| 52 |
+
def load_model():
|
| 53 |
+
model_path = "models/model_fer_bloss.h5"
|
| 54 |
+
try:
|
| 55 |
+
model = tf.keras.models.load_model(model_path)
|
| 56 |
+
return model
|
| 57 |
+
except FileNotFoundError:
|
| 58 |
+
st.error(f"Model file not found at {model_path}. Please ensure the .h5 file is in the 'models/' folder.")
|
| 59 |
+
return None
|
| 60 |
+
except Exception as e:
|
| 61 |
+
st.error(f"Error loading model: {e}")
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
face_cascade = load_face_cascade()
|
| 65 |
+
model = load_model()
|
| 66 |
+
|
| 67 |
+
st.title("Facial Expression Stress Detector")
|
| 68 |
+
st.write("Upload an image or use your camera to detect faces, process them, and estimate stress level.")
|
| 69 |
+
st.markdown("---")
|
| 70 |
+
|
| 71 |
+
# Tab selection for input method
|
| 72 |
+
tab1, tab2 = st.tabs(["Upload Image", "Live Camera"])
|
| 73 |
+
|
| 74 |
+
with tab1:
|
| 75 |
+
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
| 76 |
+
image_source = uploaded_file
|
| 77 |
+
|
| 78 |
+
with tab2:
|
| 79 |
+
camera_input = st.camera_input("Take a picture")
|
| 80 |
+
image_source = camera_input
|
| 81 |
+
|
| 82 |
+
if image_source is not None:
|
| 83 |
+
# Load and display original image
|
| 84 |
+
original_image = Image.open(image_source).convert("RGB")
|
| 85 |
+
|
| 86 |
+
col1, col2 = st.columns(2)
|
| 87 |
+
with col1:
|
| 88 |
+
st.subheader("1οΈβ£ Original Image")
|
| 89 |
+
st.image(original_image, use_container_width=True)
|
| 90 |
+
|
| 91 |
+
# Convert to OpenCV format for processing
|
| 92 |
+
img_cv = cv2.cvtColor(np.array(original_image), cv2.COLOR_RGB2BGR)
|
| 93 |
+
gray_img = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
|
| 94 |
+
|
| 95 |
+
# Apply histogram equalization to improve contrast
|
| 96 |
+
gray_img_eq = cv2.equalizeHist(gray_img)
|
| 97 |
+
|
| 98 |
+
# Detect faces with multiple parameter sets (more lenient)
|
| 99 |
+
faces = face_cascade.detectMultiScale(
|
| 100 |
+
gray_img_eq,
|
| 101 |
+
scaleFactor=1.1, # Reduced from 1.3 (slower but finds smaller faces)
|
| 102 |
+
minNeighbors=3, # Reduced from 5 (more lenient)
|
| 103 |
+
minSize=(20, 20), # Reduced from (30, 30)
|
| 104 |
+
maxSize=(400, 400)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# If no faces found, try less strict parameters
|
| 108 |
+
if len(faces) == 0:
|
| 109 |
+
faces = face_cascade.detectMultiScale(
|
| 110 |
+
gray_img_eq,
|
| 111 |
+
scaleFactor=1.05,
|
| 112 |
+
minNeighbors=2,
|
| 113 |
+
minSize=(15, 15)
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# If still no faces, try original grayscale without equalization
|
| 117 |
+
if len(faces) == 0:
|
| 118 |
+
faces = face_cascade.detectMultiScale(
|
| 119 |
+
gray_img,
|
| 120 |
+
scaleFactor=1.1,
|
| 121 |
+
minNeighbors=3,
|
| 122 |
+
minSize=(20, 20)
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
with col2:
|
| 126 |
+
st.subheader("π Detection Info")
|
| 127 |
+
st.write(f"**Faces Found:** {len(faces)}")
|
| 128 |
+
st.write(f"**Image Size:** {gray_img.shape}")
|
| 129 |
+
|
| 130 |
+
if len(faces) == 0:
|
| 131 |
+
st.warning("β οΈ No face detected. Try:")
|
| 132 |
+
st.write("- Ensure your face is clearly visible")
|
| 133 |
+
st.write("- Make sure lighting is adequate")
|
| 134 |
+
st.write("- Face should be front-facing")
|
| 135 |
+
st.write("- Try uploading a cropped image with just the face")
|
| 136 |
+
|
| 137 |
+
# Option to skip face detection
|
| 138 |
+
if st.checkbox("Skip face detection and process entire image?"):
|
| 139 |
+
st.info("Processing entire image as 48x48...")
|
| 140 |
+
|
| 141 |
+
# Resize entire image to 48x48
|
| 142 |
+
face_resized = cv2.resize(gray_img, (48, 48))
|
| 143 |
+
face_normalized = face_resized / 255.0
|
| 144 |
+
input_img = np.reshape(face_normalized, (1, 48, 48, 1))
|
| 145 |
+
|
| 146 |
+
if model:
|
| 147 |
+
predictions = model.predict(input_img, verbose=0)
|
| 148 |
+
predicted_emotion_index = np.argmax(predictions)
|
| 149 |
+
predicted_emotion_label = emotion_map[predicted_emotion_index]
|
| 150 |
+
confidence = np.max(predictions) * 100
|
| 151 |
+
probability = np.max(predictions)
|
| 152 |
+
|
| 153 |
+
col_pred1, col_pred2 = st.columns(2)
|
| 154 |
+
with col_pred1:
|
| 155 |
+
st.write(f"**Predicted Emotion:** {predicted_emotion_label}")
|
| 156 |
+
st.write(f"**Confidence:** {confidence:.2f}%")
|
| 157 |
+
|
| 158 |
+
if predicted_emotion_index < len(aligned_func):
|
| 159 |
+
try:
|
| 160 |
+
stress_score = aligned_func[predicted_emotion_index](probability)
|
| 161 |
+
normalized_stress = (stress_score / 9) * 100
|
| 162 |
+
|
| 163 |
+
with col_pred2:
|
| 164 |
+
if normalized_stress < 33:
|
| 165 |
+
st.success(f"**Stress:** {normalized_stress:.2f}% π’ Low")
|
| 166 |
+
elif normalized_stress < 66:
|
| 167 |
+
st.warning(f"**Stress:** {normalized_stress:.2f}% π‘ Medium")
|
| 168 |
+
else:
|
| 169 |
+
st.error(f"**Stress:** {normalized_stress:.2f}% π΄ High")
|
| 170 |
+
except Exception as e:
|
| 171 |
+
st.error(f"Error: {e}")
|
| 172 |
+
else:
|
| 173 |
+
st.success(f"β
Detected {len(faces)} face(s)")
|
| 174 |
+
|
| 175 |
+
# Process each detected face
|
| 176 |
+
for idx, (x, y, w, h) in enumerate(faces):
|
| 177 |
+
st.markdown(f"### Face {idx + 1}")
|
| 178 |
+
|
| 179 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 180 |
+
|
| 181 |
+
# 1. Face Region (from original image)
|
| 182 |
+
face_region = gray_img[y:y+h, x:x+w]
|
| 183 |
+
with col1:
|
| 184 |
+
st.subheader("2οΈβ£ Detected Face")
|
| 185 |
+
face_pil = Image.fromarray(face_region)
|
| 186 |
+
st.image(face_pil, use_container_width=True)
|
| 187 |
+
|
| 188 |
+
# 2. Resized to 48x48
|
| 189 |
+
face_resized = cv2.resize(face_region, (48, 48))
|
| 190 |
+
with col2:
|
| 191 |
+
st.subheader("3οΈβ£ Resized (48Γ48)")
|
| 192 |
+
resized_pil = Image.fromarray(face_resized)
|
| 193 |
+
st.image(resized_pil, use_container_width=True)
|
| 194 |
+
|
| 195 |
+
# 3. Normalized (0-1 range)
|
| 196 |
+
face_normalized = face_resized / 255.0
|
| 197 |
+
face_normalized_display = (face_normalized * 255).astype(np.uint8)
|
| 198 |
+
with col3:
|
| 199 |
+
st.subheader("4οΈβ£ Normalized")
|
| 200 |
+
normalized_pil = Image.fromarray(face_normalized_display)
|
| 201 |
+
st.image(normalized_pil, use_container_width=True)
|
| 202 |
+
|
| 203 |
+
# 4. Model Input Shape (1, 48, 48, 1)
|
| 204 |
+
input_img = np.reshape(face_normalized, (1, 48, 48, 1))
|
| 205 |
+
with col4:
|
| 206 |
+
st.subheader("5οΈβ£ Model Input")
|
| 207 |
+
st.write(f"**Shape:** {input_img.shape}")
|
| 208 |
+
st.write(f"**Min:** {input_img.min():.3f}")
|
| 209 |
+
st.write(f"**Max:** {input_img.max():.3f}")
|
| 210 |
+
st.write(f"**Mean:** {input_img.mean():.3f}")
|
| 211 |
+
|
| 212 |
+
# Run model prediction
|
| 213 |
+
if model:
|
| 214 |
+
st.markdown("---")
|
| 215 |
+
st.subheader("π Prediction Results")
|
| 216 |
+
|
| 217 |
+
predictions = model.predict(input_img, verbose=0)
|
| 218 |
+
predicted_emotion_index = np.argmax(predictions)
|
| 219 |
+
predicted_emotion_label = emotion_map[predicted_emotion_index]
|
| 220 |
+
confidence = np.max(predictions) * 100
|
| 221 |
+
probability = np.max(predictions)
|
| 222 |
+
|
| 223 |
+
col_pred1, col_pred2 = st.columns(2)
|
| 224 |
+
with col_pred1:
|
| 225 |
+
st.write(f"**Predicted Emotion:** {predicted_emotion_label}")
|
| 226 |
+
st.write(f"**Confidence:** {confidence:.2f}%")
|
| 227 |
+
|
| 228 |
+
# Calculate and display stress level
|
| 229 |
+
if predicted_emotion_index < len(aligned_func):
|
| 230 |
+
try:
|
| 231 |
+
stress_score = aligned_func[predicted_emotion_index](probability)
|
| 232 |
+
normalized_stress = (stress_score / 9) * 100
|
| 233 |
+
|
| 234 |
+
with col_pred2:
|
| 235 |
+
# Color-coded stress level
|
| 236 |
+
if normalized_stress < 33:
|
| 237 |
+
st.success(f"**Stress Level:** {normalized_stress:.2f}% π’ Low")
|
| 238 |
+
elif normalized_stress < 66:
|
| 239 |
+
st.warning(f"**Stress Level:** {normalized_stress:.2f}% π‘ Medium")
|
| 240 |
+
else:
|
| 241 |
+
st.error(f"**Stress Level:** {normalized_stress:.2f}% π΄ High")
|
| 242 |
+
except Exception as e:
|
| 243 |
+
st.error(f"Could not calculate stress: {e}")
|
| 244 |
+
|
| 245 |
+
# Display all emotion probabilities as bar chart
|
| 246 |
+
st.subheader("π All Emotion Probabilities")
|
| 247 |
+
emotion_probs = {emotion_map[i]: float(predictions[0][i]) * 100 for i in range(7)}
|
| 248 |
+
emotion_probs_sorted = dict(sorted(emotion_probs.items(), key=lambda x: x[1], reverse=True))
|
| 249 |
+
|
| 250 |
+
col_chart1, col_chart2 = st.columns([2, 1])
|
| 251 |
+
with col_chart1:
|
| 252 |
+
st.bar_chart(emotion_probs_sorted)
|
| 253 |
+
|
| 254 |
+
with col_chart2:
|
| 255 |
+
st.write("**Emotion Breakdown:**")
|
| 256 |
+
for emotion, prob in emotion_probs_sorted.items():
|
| 257 |
+
st.write(f"{emotion}: **{prob:.2f}%**")
|
| 258 |
+
|
| 259 |
+
st.markdown("---")
|
| 260 |
+
else:
|
| 261 |
+
st.error("β Model not loaded. Please check the model path.")
|
| 262 |
+
|
requirements.txt
CHANGED
|
@@ -1,3 +1,7 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
tensorflow
|
| 3 |
+
numpy
|
| 4 |
+
opencv-python
|
| 5 |
+
Pillow
|
| 6 |
+
scipy
|
| 7 |
+
matplotlib
|
stress_calculation_guide.ipynb
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "9236892e",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# Stress Calculation Mathematics\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"This notebook documents the stress calculation formulas used in the Facial Expression Stress Detector.\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"Source: https://www.kaggle.com/code/ahmedmagaji/stress-detection/notebook\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"## Formula Structure\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"Each emotion has a stress function of the form:\n",
|
| 17 |
+
"- `t = a * p + b` (linear transformation of probability)\n",
|
| 18 |
+
"- `stress = c * ln(t)` (logarithmic scaling)\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"Where `p` is the emotion probability (0-1 range)"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": null,
|
| 26 |
+
"id": "7f7215a9",
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"import math\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"# Stress calculation formulas\n",
|
| 33 |
+
"stress_formulas = {\n",
|
| 34 |
+
" 'Angry': {\n",
|
| 35 |
+
" 'formula': 't = 0.343 * p + 1.003; stress = 2.332 * ln(t)',\n",
|
| 36 |
+
" 'a': 0.343, 'b': 1.003, 'c': 2.332\n",
|
| 37 |
+
" },\n",
|
| 38 |
+
" 'Disgust': {\n",
|
| 39 |
+
" 'formula': 't = 0.0123 * p + 1.019; stress = 7.351 * ln(t)',\n",
|
| 40 |
+
" 'a': 0.0123, 'b': 1.019, 'c': 7.351\n",
|
| 41 |
+
" },\n",
|
| 42 |
+
" 'Fear': {\n",
|
| 43 |
+
" 'formula': 't = 1.356 * p + 1; stress = 1.763 * ln(t)',\n",
|
| 44 |
+
" 'a': 1.356, 'b': 1.0, 'c': 1.763\n",
|
| 45 |
+
" },\n",
|
| 46 |
+
" 'Happy': {\n",
|
| 47 |
+
" 'formula': 't = 5.221e-5 * p + 0.9997; stress = 532.2 * ln(t)',\n",
|
| 48 |
+
" 'a': 5.221e-5, 'b': 0.9997, 'c': 532.2\n",
|
| 49 |
+
" },\n",
|
| 50 |
+
" 'Sad': {\n",
|
| 51 |
+
" 'formula': 't = 0.1328 * p + 1.009; stress = 2.851 * ln(t)',\n",
|
| 52 |
+
" 'a': 0.1328, 'b': 1.009, 'c': 2.851\n",
|
| 53 |
+
" },\n",
|
| 54 |
+
" 'Surprise': {\n",
|
| 55 |
+
" 'formula': 't = 0.2825 * p + 1.003; stress = 2.478 * ln(t)',\n",
|
| 56 |
+
" 'a': 0.2825, 'b': 1.003, 'c': 2.478\n",
|
| 57 |
+
" },\n",
|
| 58 |
+
" 'Contempt': {\n",
|
| 59 |
+
" 'formula': 't = 0.01229 * p + 1.036; stress = 5.03 * ln(t)',\n",
|
| 60 |
+
" 'a': 0.01229, 'b': 1.036, 'c': 5.03\n",
|
| 61 |
+
" }\n",
|
| 62 |
+
"}\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"# Display formulas\n",
|
| 65 |
+
"for emotion, details in stress_formulas.items():\n",
|
| 66 |
+
" print(f\"{emotion}: {details['formula']}\")"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "code",
|
| 71 |
+
"execution_count": null,
|
| 72 |
+
"id": "d13c61d5",
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"outputs": [],
|
| 75 |
+
"source": [
|
| 76 |
+
"# Example calculations with probability = 0.5 (50% confidence)\n",
|
| 77 |
+
"def anger(p):\n",
|
| 78 |
+
" t = 0.343 * p + 1.003\n",
|
| 79 |
+
" return 2.332 * math.log(t)\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"def fear(p):\n",
|
| 82 |
+
" t = 1.356 * p + 1\n",
|
| 83 |
+
" return 1.763 * math.log(t)\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"def contempt(p):\n",
|
| 86 |
+
" t = 0.01229 * p + 1.036\n",
|
| 87 |
+
" return 5.03 * math.log(t)\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"def disgust(p):\n",
|
| 90 |
+
" t = 0.0123 * p + 1.019\n",
|
| 91 |
+
" return 7.351 * math.log(t)\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"def happy(p):\n",
|
| 94 |
+
" t = 5.221e-5 * p + 0.9997\n",
|
| 95 |
+
" return 532.2 * math.log(t)\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"def sad(p):\n",
|
| 98 |
+
" t = 0.1328 * p + 1.009\n",
|
| 99 |
+
" return 2.851 * math.log(t)\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"def surprise(p):\n",
|
| 102 |
+
" t = 0.2825 * p + 1.003\n",
|
| 103 |
+
" return 2.478 * math.log(t)\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"# Test with p = 0.5\n",
|
| 106 |
+
"test_probability = 0.5\n",
|
| 107 |
+
"print(f\"\\nStress values for probability p = {test_probability}:\")\n",
|
| 108 |
+
"print(f\"Angry: {anger(test_probability):.4f}\")\n",
|
| 109 |
+
"print(f\"Disgust: {disgust(test_probability):.4f}\")\n",
|
| 110 |
+
"print(f\"Fear: {fear(test_probability):.4f}\")\n",
|
| 111 |
+
"print(f\"Happy: {happy(test_probability):.4f}\")\n",
|
| 112 |
+
"print(f\"Sad: {sad(test_probability):.4f}\")\n",
|
| 113 |
+
"print(f\"Surprise: {surprise(test_probability):.4f}\")\n",
|
| 114 |
+
"print(f\"Contempt: {contempt(test_probability):.4f}\")\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"# Normalize to 0-100 scale (divide by 9 as mentioned in your code)\n",
|
| 117 |
+
"print(f\"\\nNormalized Stress Levels (0-100):\")\n",
|
| 118 |
+
"print(f\"Angry: {(anger(test_probability) / 9 * 100):.2f}\")\n",
|
| 119 |
+
"print(f\"Disgust: {(disgust(test_probability) / 9 * 100):.2f}\")\n",
|
| 120 |
+
"print(f\"Fear: {(fear(test_probability) / 9 * 100):.2f}\")\n",
|
| 121 |
+
"print(f\"Happy: {(happy(test_probability) / 9 * 100):.2f}\")\n",
|
| 122 |
+
"print(f\"Sad: {(sad(test_probability) / 9 * 100):.2f}\")\n",
|
| 123 |
+
"print(f\"Surprise: {(surprise(test_probability) / 9 * 100):.2f}\")\n",
|
| 124 |
+
"print(f\"Contempt: {(contempt(test_probability) / 9 * 100):.2f}\")"
|
| 125 |
+
]
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
"cell_type": "markdown",
|
| 129 |
+
"id": "17bde9e1",
|
| 130 |
+
"metadata": {},
|
| 131 |
+
"source": [
|
| 132 |
+
"## Key Points\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"1. **Input Range**: All stress functions expect `p` in range [0, 1] (probability, not percentage)\n",
|
| 135 |
+
"2. **Output Range**: Raw stress scores are typically 0-9\n",
|
| 136 |
+
"3. **Normalization**: Divide by 9 and multiply by 100 to get 0-100 scale\n",
|
| 137 |
+
"4. **Interpretation**:\n",
|
| 138 |
+
" - Low stress (0-33): Calm, relaxed expressions\n",
|
| 139 |
+
" - Medium stress (33-66): Moderate emotional intensity\n",
|
| 140 |
+
" - High stress (66-100): High emotional intensity or negative emotions\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"## Validation\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"Your Streamlit app correctly:\n",
|
| 145 |
+
"- β
Extracts probability from model predictions (0-1 range)\n",
|
| 146 |
+
"- β
Applies the stress functions\n",
|
| 147 |
+
"- β
Normalizes to 0-100 scale\n",
|
| 148 |
+
"- β
Color-codes results by stress level"
|
| 149 |
+
]
|
| 150 |
+
}
|
| 151 |
+
],
|
| 152 |
+
"metadata": {
|
| 153 |
+
"language_info": {
|
| 154 |
+
"name": "python"
|
| 155 |
+
}
|
| 156 |
+
},
|
| 157 |
+
"nbformat": 4,
|
| 158 |
+
"nbformat_minor": 5
|
| 159 |
+
}
|
stress_detection_fre_ck+.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
wesad.ipynb
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "6d1cc4dd",
|
| 7 |
+
"metadata": {
|
| 8 |
+
"vscode": {
|
| 9 |
+
"languageId": "plaintext"
|
| 10 |
+
}
|
| 11 |
+
},
|
| 12 |
+
"outputs": [],
|
| 13 |
+
"source": [
|
| 14 |
+
"# ============================================================================\n",
|
| 15 |
+
"# WESAD Stress Detection - Google Colab with T4 GPU Support\n",
|
| 16 |
+
"# ============================================================================\n",
|
| 17 |
+
"# Enable GPU: Runtime > Change runtime type > Hardware accelerator > T4 GPU\n",
|
| 18 |
+
"# ============================================================================\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"# === 1. Install Required Packages ===\n",
|
| 21 |
+
"print(\"Installing dependencies...\")\n",
|
| 22 |
+
"!pip install -q kagglehub scipy scikit-learn pandas numpy joblib\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"# Install RAPIDS cuML for GPU acceleration (T4 compatible)\n",
|
| 25 |
+
"try:\n",
|
| 26 |
+
" import cuml\n",
|
| 27 |
+
" print(\"β cuML already installed\")\n",
|
| 28 |
+
"except ImportError:\n",
|
| 29 |
+
" print(\"Installing cuML for GPU acceleration...\")\n",
|
| 30 |
+
" !pip install -q cuml-cu11 --extra-index-url=https://pypi.nvidia.com\n",
|
| 31 |
+
" print(\"β cuML installed\")\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"# === 2. Download WESAD Dataset ===\n",
|
| 34 |
+
"print(\"\\nDownloading WESAD dataset...\")\n",
|
| 35 |
+
"import kagglehub\n",
|
| 36 |
+
"import os\n",
|
| 37 |
+
"import glob\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"# Download the dataset\n",
|
| 40 |
+
"path = kagglehub.dataset_download(\"orvile/wesad-wearable-stress-affect-detection-dataset\")\n",
|
| 41 |
+
"print(f\"Dataset downloaded to: {path}\")\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"# Find all .pkl files\n",
|
| 44 |
+
"pkl_files = glob.glob(os.path.join(path, \"**/*.pkl\"), recursive=True)\n",
|
| 45 |
+
"print(f\"Found {len(pkl_files)} .pkl files\")\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"# === 3. Import Libraries ===\n",
|
| 48 |
+
"import pickle\n",
|
| 49 |
+
"import numpy as np\n",
|
| 50 |
+
"import pandas as pd\n",
|
| 51 |
+
"from scipy import signal, stats\n",
|
| 52 |
+
"from sklearn.metrics import classification_report, f1_score, confusion_matrix\n",
|
| 53 |
+
"from sklearn.model_selection import LeaveOneGroupOut\n",
|
| 54 |
+
"import joblib\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"# Try GPU-accelerated RandomForest\n",
|
| 57 |
+
"try:\n",
|
| 58 |
+
" from cuml.ensemble import RandomForestClassifier as cuRFC\n",
|
| 59 |
+
" import cupy as cp\n",
|
| 60 |
+
" USE_GPU = True\n",
|
| 61 |
+
" print(\"β Using GPU-accelerated cuML RandomForest\")\n",
|
| 62 |
+
" # Check GPU availability\n",
|
| 63 |
+
" !nvidia-smi --query-gpu=name,memory.total --format=csv\n",
|
| 64 |
+
"except ImportError:\n",
|
| 65 |
+
" from sklearn.ensemble import RandomForestClassifier as cuRFC\n",
|
| 66 |
+
" USE_GPU = False\n",
|
| 67 |
+
" print(\"β cuML not available, using CPU sklearn RandomForest\")\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"# === 4. Load Subject Data ===\n",
|
| 70 |
+
"print(\"\\n\" + \"=\"*60)\n",
|
| 71 |
+
"print(\"LOADING SUBJECT DATA\")\n",
|
| 72 |
+
"print(\"=\"*60)\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"subjects = []\n",
|
| 75 |
+
"for p in pkl_files:\n",
|
| 76 |
+
" with open(p, 'rb') as f:\n",
|
| 77 |
+
" data = pickle.load(f, encoding='latin1')\n",
|
| 78 |
+
" \n",
|
| 79 |
+
" subj = data['subject']\n",
|
| 80 |
+
" sig = data['signal']\n",
|
| 81 |
+
" \n",
|
| 82 |
+
" chest = sig.get('chest', {})\n",
|
| 83 |
+
" wrist = sig.get('wrist', {})\n",
|
| 84 |
+
" \n",
|
| 85 |
+
" subj_dict = {\n",
|
| 86 |
+
" 'id': subj,\n",
|
| 87 |
+
" 'ecg': np.asarray(chest.get('ECG', []), dtype=np.float64),\n",
|
| 88 |
+
" 'eda': np.asarray(chest.get('EDA', []) if 'EDA' in chest else wrist.get('EDA', []), dtype=np.float64),\n",
|
| 89 |
+
" 'temp': np.asarray(chest.get('TEMP', []) if 'TEMP' in chest else wrist.get('TEMP', []), dtype=np.float64),\n",
|
| 90 |
+
" 'label': np.asarray(data['label'], dtype=np.int32)\n",
|
| 91 |
+
" }\n",
|
| 92 |
+
" \n",
|
| 93 |
+
" # Skip subjects with missing channels\n",
|
| 94 |
+
" if subj_dict['ecg'].size==0 or subj_dict['eda'].size==0 or subj_dict['temp'].size==0:\n",
|
| 95 |
+
" print(f\" β Skipping {subj} - missing channel\")\n",
|
| 96 |
+
" continue\n",
|
| 97 |
+
" \n",
|
| 98 |
+
" subjects.append(subj_dict)\n",
|
| 99 |
+
" print(f\" β Loaded {subj}: ECG={len(subj_dict['ecg'])}, EDA={len(subj_dict['eda'])}, TEMP={len(subj_dict['temp'])}\")\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"print(f\"\\nβ Successfully loaded {len(subjects)} subjects: {[s['id'] for s in subjects]}\")\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"# === 5. Signal Processing Functions ===\n",
|
| 104 |
+
"def to_numpy(x):\n",
|
| 105 |
+
" \"\"\"Convert any array type to CPU numpy array\"\"\"\n",
|
| 106 |
+
" if hasattr(x, 'get'): # CuPy array\n",
|
| 107 |
+
" return np.asarray(x.get(), dtype=np.float64).flatten()\n",
|
| 108 |
+
" return np.asarray(x, dtype=np.float64).flatten()\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"def downsample(x, orig_fs=700, target_fs=4):\n",
|
| 111 |
+
" \"\"\"Downsample signal from 700 Hz to 4 Hz\"\"\"\n",
|
| 112 |
+
" x = to_numpy(x)\n",
|
| 113 |
+
" if x.size==0: \n",
|
| 114 |
+
" return x\n",
|
| 115 |
+
" factor = int(orig_fs/target_fs)\n",
|
| 116 |
+
" if factor <= 1: \n",
|
| 117 |
+
" return x\n",
|
| 118 |
+
" return signal.decimate(x, factor, ftype='fir', zero_phase=True)\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"def extract_features(ecg, eda, temp):\n",
|
| 121 |
+
" \"\"\"Extract time-domain and signal-specific features\"\"\"\n",
|
| 122 |
+
" # Convert all inputs to CPU numpy 1-D arrays\n",
|
| 123 |
+
" ecg = to_numpy(ecg)\n",
|
| 124 |
+
" eda = to_numpy(eda)\n",
|
| 125 |
+
" temp = to_numpy(temp)\n",
|
| 126 |
+
" \n",
|
| 127 |
+
" feats = {}\n",
|
| 128 |
+
" \n",
|
| 129 |
+
" # Basic statistical features for all signals\n",
|
| 130 |
+
" for name, arr in [('ecg', ecg), ('eda', eda), ('temp', temp)]:\n",
|
| 131 |
+
" feats[f'{name}_mean'] = float(np.mean(arr))\n",
|
| 132 |
+
" feats[f'{name}_std'] = float(np.std(arr, ddof=0))\n",
|
| 133 |
+
" feats[f'{name}_min'] = float(np.min(arr))\n",
|
| 134 |
+
" feats[f'{name}_max'] = float(np.max(arr))\n",
|
| 135 |
+
" feats[f'{name}_skew'] = float(stats.skew(arr))\n",
|
| 136 |
+
" feats[f'{name}_kurt'] = float(stats.kurtosis(arr))\n",
|
| 137 |
+
" feats[f'{name}_rms'] = float(np.sqrt(np.mean(arr**2)))\n",
|
| 138 |
+
" \n",
|
| 139 |
+
" # EDA-specific: peak count (skin conductance responses)\n",
|
| 140 |
+
" try:\n",
|
| 141 |
+
" peaks, _ = signal.find_peaks(eda, distance=4, height=None)\n",
|
| 142 |
+
" feats['eda_peaks_count'] = len(peaks)\n",
|
| 143 |
+
" except Exception as e:\n",
|
| 144 |
+
" feats['eda_peaks_count'] = 0\n",
|
| 145 |
+
" \n",
|
| 146 |
+
" # ECG-derived: Heart rate and HRV features\n",
|
| 147 |
+
" try:\n",
|
| 148 |
+
" ecg_peaks, _ = signal.find_peaks(ecg, distance=2, height=None)\n",
|
| 149 |
+
" if len(ecg_peaks) >= 2:\n",
|
| 150 |
+
" ibi = np.diff(ecg_peaks) / 4.0\n",
|
| 151 |
+
" feats['hr_mean'] = float(60.0 / np.mean(ibi))\n",
|
| 152 |
+
" feats['hr_std'] = float(np.std(60.0/ibi))\n",
|
| 153 |
+
" feats['ibi_mean'] = float(np.mean(ibi))\n",
|
| 154 |
+
" feats['ibi_sdnn'] = float(np.std(ibi, ddof=0))\n",
|
| 155 |
+
" else:\n",
|
| 156 |
+
" feats['hr_mean']=feats['hr_std']=feats['ibi_mean']=feats['ibi_sdnn']=0.0\n",
|
| 157 |
+
" except Exception as e:\n",
|
| 158 |
+
" feats['hr_mean']=feats['hr_std']=feats['ibi_mean']=feats['ibi_sdnn']=0.0\n",
|
| 159 |
+
" \n",
|
| 160 |
+
" return feats\n",
|
| 161 |
+
"\n",
|
| 162 |
+
# === 6. Feature Extraction with Sliding Windows ===
print("\n" + "="*60)
print("FEATURE EXTRACTION")
print("="*60)

window_sec = 60                # 60-second windows
fs = 4                         # Target sampling frequency (Hz)
win = window_sec * fs          # 240 samples per window
step = int(win * 0.5)          # 50% overlap (30-second step)

X_rows, y_rows, groups = [], [], []
labels_map = {1: 'baseline', 2: 'stress', 3: 'amusement', 4: 'meditation'}

for subj in subjects:
    print(f"Processing {subj['id']}...", end=' ')

    # Downsample the physiological signals to the 4 Hz target rate.
    ecg_ds = downsample(subj['ecg'])
    eda_ds = downsample(subj['eda'])
    temp_ds = downsample(subj['temp'])

    # BUGFIX: the label stream is categorical. signal.decimate() applies an
    # FIR low-pass filter before decimating, which blends neighbouring class
    # IDs into meaningless intermediate values (e.g. 1.5 on a baseline->stress
    # transition) and can shift the windowed mode. Subsample by striding
    # instead, which preserves the original integer labels. (decimate() with
    # q=175 in a single stage also violates scipy's recommendation of q <= 13
    # per stage.) Assumes raw labels are sampled at 700 Hz — TODO confirm
    # against the WESAD loader.
    q = int(700 / fs)  # decimation factor: 175
    labels_ds = subj['label'].astype(float)[::q]

    # Ensure all streams share the same length before windowing.
    L = min(len(ecg_ds), len(eda_ds), len(temp_ds), len(labels_ds))
    ecg_ds = ecg_ds[:L]
    eda_ds = eda_ds[:L]
    temp_ds = temp_ds[:L]
    labels_ds = labels_ds[:L]

    window_count = 0
    # Slide a 60 s window with 50% overlap across the recording.
    for start in range(0, L - win + 1, step):
        end = start + win
        lab_window = labels_ds[start:end]

        # Majority label within the window decides its class.
        mode_result = stats.mode(lab_window, nan_policy='omit', keepdims=True)
        lab = int(mode_result.mode[0])

        # Keep only the four defined study conditions.
        if lab not in labels_map:
            continue

        # Extract features from the window; skip windows that fail.
        try:
            feats = extract_features(ecg_ds[start:end], eda_ds[start:end], temp_ds[start:end])
            X_rows.append(feats)
            y_rows.append(lab)
            groups.append(subj['id'])
            window_count += 1
        except Exception as e:
            print(f"\n β Error processing window: {e}")
            continue

    print(f"{window_count} windows")

|
| 218 |
+
# Assemble the feature matrix and the aligned label/group arrays.
X = pd.DataFrame(X_rows).fillna(0)
y = np.array(y_rows)
groups = np.array(groups)

print(f"\nβ Feature matrix: {X.shape} (samples Γ features)")
print(f"β Features: {list(X.columns)}")
print(f"β Class distribution:")
# Report how many windows each condition contributed.
for label_id in sorted(labels_map.keys()):
    print(f"  {labels_map[label_id]}: {np.sum(y == label_id)} windows")

|
| 230 |
+
# === 7. Leave-One-Subject-Out Cross-Validation ===
# Each fold trains on all subjects but one and tests on the held-out
# subject, giving a subject-independent performance estimate.
print("\n" + "="*60)
print("LEAVE-ONE-SUBJECT-OUT CROSS-VALIDATION")
print("="*60)

logo = LeaveOneGroupOut()
f1s = []
all_predictions = []
all_true_labels = []

for fold, (train_idx, test_idx) in enumerate(logo.split(X, y, groups), 1):
    test_subject = groups[test_idx][0]

    # Initialize model (cuML forest on GPU; CPU branch presumably aliases
    # cuRFC to sklearn's RandomForestClassifier, which accepts n_jobs —
    # verify against the alias set up earlier in the notebook).
    if USE_GPU:
        clf = cuRFC(n_estimators=200, random_state=42, max_depth=16)
    else:
        clf = cuRFC(n_estimators=200, random_state=42, n_jobs=-1)

    # Train on the remaining subjects, predict the held-out subject.
    clf.fit(X.iloc[train_idx], y[train_idx])
    y_pred = clf.predict(X.iloc[test_idx])

    all_predictions.extend(y_pred)
    all_true_labels.extend(y[test_idx])

    f1_macro = f1_score(y[test_idx], y_pred, average='macro')
    f1_weighted = f1_score(y[test_idx], y_pred, average='weighted')
    f1s.append(f1_macro)

    print(f"Fold {fold:2d} | {test_subject} | F1-Macro: {f1_macro:.4f} | F1-Weighted: {f1_weighted:.4f}")

print("="*60)
print(f"β Mean F1-Macro: {np.mean(f1s):.4f} Β± {np.std(f1s):.4f}")
print(f"β Median F1-Macro: {np.median(f1s):.4f}")
print("="*60)

# Confusion Matrix
print("\nConfusion Matrix:")
# BUGFIX: pass `labels=` explicitly. Without it, confusion_matrix() only
# includes classes that actually occur in y_true/y_pred, so the 4x4
# DataFrame construction below raises a shape mismatch whenever any of the
# four conditions is missing from the pooled results.
class_ids = sorted(labels_map.keys())
cm = confusion_matrix(all_true_labels, all_predictions, labels=class_ids)
cm_df = pd.DataFrame(cm,
                     index=[labels_map[i] for i in class_ids],
                     columns=[labels_map[i] for i in class_ids])
print(cm_df)

|
| 277 |
+
# === 8. Train Final Model on All Data ===
print("\n" + "="*60)
print("TRAINING FINAL MODEL")
print("="*60)

# Larger forest for the deployed model; same GPU/CPU split as in CV.
if USE_GPU:
    clf_final = cuRFC(n_estimators=300, random_state=42, max_depth=16)
else:
    clf_final = cuRFC(n_estimators=300, random_state=42, n_jobs=-1)

print("Training on all data...")
clf_final.fit(X, y)

# NOTE: this report evaluates on the training data itself, so it reflects
# fit quality, not generalization — the LOSO results above are the honest
# performance estimate.
y_pred_final = clf_final.predict(X)
print("\nFinal Model Performance:")
# BUGFIX: pass `labels=` so the report rows always align with the four
# target_names even if a class is absent from y or the predictions;
# otherwise classification_report raises a length mismatch.
class_ids = sorted(labels_map.keys())
print(classification_report(y, y_pred_final,
                            labels=class_ids,
                            target_names=[labels_map[i] for i in class_ids]))

|
| 295 |
+
# === 9. Save Model and Artifacts ===
print("\n" + "="*60)
print("SAVING MODEL")
print("="*60)

# Persist to /content/ (Colab's default working directory).
model_path = '/content/wesad_rfc_model.joblib'
features_path = '/content/wesad_feature_names.joblib'
scaler_info_path = '/content/wesad_preprocessing_info.joblib'

# Trained classifier, the feature-column order it expects, and the
# preprocessing settings needed to reproduce the windowing at inference.
joblib.dump(clf_final, model_path)
joblib.dump(X.columns.tolist(), features_path)
preprocessing_info = {
    'labels_map': labels_map,
    'window_sec': window_sec,
    'fs': fs,
    'feature_count': X.shape[1]
}
joblib.dump(preprocessing_info, scaler_info_path)

print(f"β Model saved")
print(f"β Feature names saved")
print(f"β Preprocessing info saved")
print("\nβ Training complete!")

|
| 319 |
+
# === 10. Download Model Files ===
from google.colab import files

# Push the three saved artifacts to the browser via Colab's download API.
print("\nDownloading model files...")
for artifact_path in ('/content/wesad_rfc_model.joblib',
                      '/content/wesad_feature_names.joblib',
                      '/content/wesad_preprocessing_info.joblib'):
    files.download(artifact_path)
print("β Downloads complete")
|
| 327 |
+
]
|
| 328 |
+
}
|
| 329 |
+
],
|
| 330 |
+
"metadata": {
|
| 331 |
+
"language_info": {
|
| 332 |
+
"name": "python"
|
| 333 |
+
}
|
| 334 |
+
},
|
| 335 |
+
"nbformat": 4,
|
| 336 |
+
"nbformat_minor": 5
|
| 337 |
+
}
|