Eric Xu committed on
Commit
ff413e9
·
unverified ·
1 Parent(s): 6dcf8e5

Fix XSS, VJP contract, binding, session security, name collisions, and error handling

Browse files

1. XSS: Add esc() helper to sanitize all user/LLM data interpolated into
innerHTML (logStep, gradient table details, eval log, bias audit table).
2. VJP contract: Backend analyze_gradient() now returns structured ranked
data alongside text; frontend renders from backend rankings instead of
recomputing unweighted averages.
3. Bind uvicorn to 127.0.0.1 instead of 0.0.0.0.
4. Session-bound tickets: store sid in counterfactual ticket and verify on
stream pickup.
5. Name collision: use composite key (name_user_id) for cohort_map in both
web app and scripts to avoid persona lookup collisions; propagate
user_id through _evaluator dict.
6. Counterfactual error handling: wrap fut.result() in try/except, matching
the evaluate stream pattern.

scripts/counterfactual.py CHANGED
@@ -89,10 +89,20 @@ def build_changes_block(changes):
89
  return "\n".join(lines)
90
 
91
 
 
 
 
 
 
 
 
 
 
 
92
  def probe_one(client, model, eval_result, cohort_map, all_changes):
93
  ev = eval_result.get("_evaluator", {})
94
  name = ev.get("name", "")
95
- persona_text = cohort_map.get(name, {}).get("persona", "")
96
 
97
  prompt = PROBE_PROMPT.format(
98
  name=name, age=ev.get("age", ""),
@@ -150,10 +160,17 @@ def compute_goal_weights(client, model, eval_results, cohort_map, goal, parallel
150
  """Score each evaluator's relevance to the goal. Returns {name: weight}."""
151
  weights = {}
152
 
 
 
 
 
 
 
153
  def score_one(r):
154
  ev = r.get("_evaluator", {})
155
  name = ev.get("name", "")
156
- persona = cohort_map.get(name, {})
 
157
  prompt = GOAL_RELEVANCE_PROMPT.format(
158
  goal=goal, name=name, age=ev.get("age", ""),
159
  occupation=ev.get("occupation", ""),
@@ -170,15 +187,15 @@ def compute_goal_weights(client, model, eval_results, cohort_map, goal, parallel
170
  content = resp.choices[0].message.content
171
  content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
172
  data = json.loads(content)
173
- return name, float(data.get("relevance", 0.5)), data.get("reasoning", "")
174
  except Exception:
175
- return name, 0.5, "default"
176
 
177
  with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as pool:
178
  futs = [pool.submit(score_one, r) for r in eval_results]
179
  for fut in concurrent.futures.as_completed(futs):
180
- name, weight, reasoning = fut.result()
181
- weights[name] = {"weight": weight, "reasoning": reasoning}
182
 
183
  return weights
184
 
@@ -186,22 +203,26 @@ def compute_goal_weights(client, model, eval_results, cohort_map, goal, parallel
186
  def analyze_gradient(results, all_changes, goal_weights=None):
187
  valid = [r for r in results if "counterfactuals" in r]
188
  if not valid:
189
- return "No valid results."
190
 
191
  has_goal = goal_weights is not None
192
  labels = {c["id"]: c["label"] for c in all_changes}
193
  jacobian = defaultdict(list)
194
 
195
  for r in valid:
196
- name = r["_evaluator"].get("name", "")
197
- w = goal_weights.get(name, {}).get("weight", 1.0) if has_goal else 1.0
 
 
 
198
  for cf in r.get("counterfactuals", []):
199
  jacobian[cf.get("change_id", "")].append({
200
  "delta": cf.get("delta", 0),
201
  "weighted_delta": cf.get("delta", 0) * w,
202
  "weight": w,
203
  "name": name,
204
- "age": r["_evaluator"].get("age", ""),
 
205
  "reasoning": cf.get("reasoning", ""),
206
  })
207
 
@@ -268,7 +289,22 @@ def analyze_gradient(results, all_changes, goal_weights=None):
268
  lines.append(f" {d['delta']} {d['name']} ({d['age']}){w_label}: {d['reasoning']}")
269
  lines.append("")
270
 
271
- return "\n".join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
 
274
  def main():
@@ -300,7 +336,7 @@ def main():
300
  for cat in changes_data.values():
301
  all_changes.extend(cat if isinstance(cat, list) else cat.get("changes", []))
302
 
303
- cohort_map = {p["name"]: p for p in cohort}
304
 
305
  movable = [r for r in eval_results
306
  if "score" in r and args.min_score <= r["score"] <= args.max_score]
@@ -356,7 +392,7 @@ def main():
356
  relevant = sum(1 for v in goal_weights.values() if v["weight"] >= 0.5)
357
  print(f" {relevant}/{len(goal_weights)} evaluators relevant to goal\n")
358
 
359
- gradient = analyze_gradient(results, all_changes, goal_weights=goal_weights)
360
  with open(out_dir / "gradient.md", "w") as f:
361
  f.write(gradient)
362
 
 
89
  return "\n".join(lines)
90
 
91
 
92
+ def _cohort_lookup(cohort_map, ev):
93
+ """Look up persona by composite key (name_user_id), falling back to name."""
94
+ name = ev.get("name", "")
95
+ uid = ev.get("user_id", "")
96
+ key = f"{name}_{uid}"
97
+ if key in cohort_map:
98
+ return cohort_map[key]
99
+ return cohort_map.get(name, {})
100
+
101
+
102
  def probe_one(client, model, eval_result, cohort_map, all_changes):
103
  ev = eval_result.get("_evaluator", {})
104
  name = ev.get("name", "")
105
+ persona_text = _cohort_lookup(cohort_map, ev).get("persona", "")
106
 
107
  prompt = PROBE_PROMPT.format(
108
  name=name, age=ev.get("age", ""),
 
160
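  """Score each evaluator's relevance to the goal. Returns {evaluator_key: {"weight": float, "reasoning": str}}, keyed by name_user_id (or name alone when there is no user_id)."""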
161
  weights = {}
162
 
163
+ def _eval_key(ev):
164
+ """Composite key matching cohort_map keys to avoid name collisions."""
165
+ name = ev.get("name", "")
166
+ uid = ev.get("user_id", "")
167
+ return f"{name}_{uid}" if uid else name
168
+
169
  def score_one(r):
170
  ev = r.get("_evaluator", {})
171
  name = ev.get("name", "")
172
+ key = _eval_key(ev)
173
+ persona = _cohort_lookup(cohort_map, ev)
174
  prompt = GOAL_RELEVANCE_PROMPT.format(
175
  goal=goal, name=name, age=ev.get("age", ""),
176
  occupation=ev.get("occupation", ""),
 
187
  content = resp.choices[0].message.content
188
  content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
189
  data = json.loads(content)
190
+ return key, float(data.get("relevance", 0.5)), data.get("reasoning", "")
191
  except Exception:
192
+ return key, 0.5, "default"
193
 
194
  with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as pool:
195
  futs = [pool.submit(score_one, r) for r in eval_results]
196
  for fut in concurrent.futures.as_completed(futs):
197
+ key, weight, reasoning = fut.result()
198
+ weights[key] = {"weight": weight, "reasoning": reasoning}
199
 
200
  return weights
201
 
 
203
  def analyze_gradient(results, all_changes, goal_weights=None):
204
  valid = [r for r in results if "counterfactuals" in r]
205
  if not valid:
206
+ return "No valid results.", []
207
 
208
  has_goal = goal_weights is not None
209
  labels = {c["id"]: c["label"] for c in all_changes}
210
  jacobian = defaultdict(list)
211
 
212
  for r in valid:
213
+ ev = r["_evaluator"]
214
+ name = ev.get("name", "")
215
+ uid = ev.get("user_id", "")
216
+ key = f"{name}_{uid}" if uid else name
217
+ w = goal_weights.get(key, {}).get("weight", 1.0) if has_goal else 1.0
218
  for cf in r.get("counterfactuals", []):
219
  jacobian[cf.get("change_id", "")].append({
220
  "delta": cf.get("delta", 0),
221
  "weighted_delta": cf.get("delta", 0) * w,
222
  "weight": w,
223
  "name": name,
224
+ "age": ev.get("age", ""),
225
+ "occupation": ev.get("occupation", ""),
226
  "reasoning": cf.get("reasoning", ""),
227
  })
228
 
 
289
  lines.append(f" {d['delta']} {d['name']} ({d['age']}){w_label}: {d['reasoning']}")
290
  lines.append("")
291
 
292
+ # Return ranked data alongside text for structured consumers (web UI)
293
+ ranked_data = [{
294
+ "id": r["id"], "label": r["label"],
295
+ "avg_delta": round(r["avg_delta"], 2),
296
+ "raw_avg_delta": round(r["raw_avg_delta"], 2),
297
+ "max_delta": r["max_delta"], "min_delta": r["min_delta"],
298
+ "positive": r["positive"], "negative": r["negative"],
299
+ "n": r["n"],
300
+ "details": sorted([{
301
+ "name": d["name"], "age": d.get("age", ""),
302
+ "occupation": d.get("occupation", ""),
303
+ "delta": d["delta"], "reasoning": d.get("reasoning", ""),
304
+ } for d in r["details"]], key=lambda x: x["delta"], reverse=True),
305
+ } for r in ranked]
306
+
307
+ return "\n".join(lines), ranked_data
308
 
309
 
310
  def main():
 
336
  for cat in changes_data.values():
337
  all_changes.extend(cat if isinstance(cat, list) else cat.get("changes", []))
338
 
339
+ cohort_map = {f"{p.get('name','')}_{p.get('user_id','')}": p for p in cohort}
340
 
341
  movable = [r for r in eval_results
342
  if "score" in r and args.min_score <= r["score"] <= args.max_score]
 
392
  relevant = sum(1 for v in goal_weights.values() if v["weight"] >= 0.5)
393
  print(f" {relevant}/{len(goal_weights)} evaluators relevant to goal\n")
394
 
395
+ gradient, _ranked = analyze_gradient(results, all_changes, goal_weights=goal_weights)
396
  with open(out_dir / "gradient.md", "w") as f:
397
  f.write(gradient)
398
 
scripts/evaluate.py CHANGED
@@ -125,6 +125,7 @@ def evaluate_one(client, model, evaluator, entity_text, system_prompt=None):
125
  result = json.loads(content)
126
  result["_evaluator"] = {
127
  "name": evaluator["name"],
 
128
  "age": evaluator.get("age"),
129
  "city": evaluator.get("city"),
130
  "state": evaluator.get("state"),
 
125
  result = json.loads(content)
126
  result["_evaluator"] = {
127
  "name": evaluator["name"],
128
+ "user_id": evaluator.get("user_id"),
129
  "age": evaluator.get("age"),
130
  "city": evaluator.get("city"),
131
  "state": evaluator.get("state"),
web/app.py CHANGED
@@ -583,7 +583,7 @@ async def prepare_counterfactual(sid: str, req: CounterfactualRequest):
583
  expired = [k for k, v in _cf_pending.items() if now - v.get("ts", 0) > 600]
584
  for k in expired:
585
  del _cf_pending[k]
586
- _cf_pending[ticket] = {"req": req, "ts": now}
587
  return {"ticket": ticket}
588
 
589
 
@@ -598,6 +598,8 @@ async def counterfactual_stream(sid: str, ticket: str):
598
  entry = _cf_pending.pop(ticket, None)
599
  if not entry:
600
  raise HTTPException(400, "Invalid or expired ticket")
 
 
601
  req = entry["req"]
602
 
603
  all_changes = req.changes
@@ -611,7 +613,7 @@ async def counterfactual_stream(sid: str, ticket: str):
611
  model = get_model()
612
  cohort = session["cohort"]
613
  eval_results = session["eval_results"]
614
- cohort_map = {p["name"]: p for p in cohort}
615
 
616
  movable = [r for r in eval_results
617
  if "score" in r and min_score <= r["score"] <= max_score]
@@ -659,7 +661,10 @@ async def counterfactual_stream(sid: str, ticket: str):
659
  }
660
  for fut in concurrent.futures.as_completed(futs):
661
  idx = futs[fut]
662
- result = fut.result()
 
 
 
663
  results[idx] = result
664
  done += 1
665
 
@@ -678,13 +683,14 @@ async def counterfactual_stream(sid: str, ticket: str):
678
  yield {"event": "progress", "data": json.dumps(progress)}
679
 
680
  elapsed = time.time() - t0
681
- gradient_text = analyze_gradient(results, all_changes,
682
- goal_weights=goal_weights)
683
  session["gradient"] = gradient_text
684
 
685
  yield {"event": "complete", "data": json.dumps({
686
  "elapsed": round(elapsed, 1),
687
  "gradient": gradient_text,
 
688
  "results": results,
689
  "goal": goal if has_goal else None,
690
  })}
@@ -830,4 +836,4 @@ if __name__ == "__main__":
830
  import uvicorn
831
  print(f"\n SGO Web Interface")
832
  print(f" http://localhost:8000\n")
833
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
583
  expired = [k for k, v in _cf_pending.items() if now - v.get("ts", 0) > 600]
584
  for k in expired:
585
  del _cf_pending[k]
586
+ _cf_pending[ticket] = {"req": req, "ts": now, "sid": sid}
587
  return {"ticket": ticket}
588
 
589
 
 
598
  entry = _cf_pending.pop(ticket, None)
599
  if not entry:
600
  raise HTTPException(400, "Invalid or expired ticket")
601
+ if entry.get("sid") != sid:
602
+ raise HTTPException(403, "Ticket does not belong to this session")
603
  req = entry["req"]
604
 
605
  all_changes = req.changes
 
613
  model = get_model()
614
  cohort = session["cohort"]
615
  eval_results = session["eval_results"]
616
+ cohort_map = {f"{p.get('name','')}_{p.get('user_id','')}": p for p in cohort}
617
 
618
  movable = [r for r in eval_results
619
  if "score" in r and min_score <= r["score"] <= max_score]
 
661
  }
662
  for fut in concurrent.futures.as_completed(futs):
663
  idx = futs[fut]
664
+ try:
665
+ result = fut.result()
666
+ except Exception as e:
667
+ result = {"error": str(e), "_evaluator": {"name": "?"}}
668
  results[idx] = result
669
  done += 1
670
 
 
683
  yield {"event": "progress", "data": json.dumps(progress)}
684
 
685
  elapsed = time.time() - t0
686
+ gradient_text, ranked_data = analyze_gradient(results, all_changes,
687
+ goal_weights=goal_weights)
688
  session["gradient"] = gradient_text
689
 
690
  yield {"event": "complete", "data": json.dumps({
691
  "elapsed": round(elapsed, 1),
692
  "gradient": gradient_text,
693
+ "ranked": ranked_data,
694
  "results": results,
695
  "goal": goal if has_goal else None,
696
  })}
 
836
  import uvicorn
837
  print(f"\n SGO Web Interface")
838
  print(f" http://localhost:8000\n")
839
+ uvicorn.run(app, host="127.0.0.1", port=8000)
web/static/index.html CHANGED
@@ -558,6 +558,17 @@ const TEMPLATES = {
558
  let sessionId = null;
559
  let evalResultsData = null;
560
 
 
 
 
 
 
 
 
 
 
 
 
561
  // ── Init ──
562
 
563
  async function init() {
@@ -652,7 +663,7 @@ function goToStep(n) {
652
 
653
  function logStep(msg, cls = '') {
654
  const log = document.getElementById('evalLog');
655
- log.innerHTML += `<div class="${cls}">${msg}</div>`;
656
  log.scrollTop = log.scrollHeight;
657
  }
658
 
@@ -875,7 +886,7 @@ async function runDirections() {
875
 
876
  const log = document.getElementById('cfLog');
877
  concerns.slice(0, 8).forEach(c => {
878
- log.innerHTML += `<div style="color:var(--text2)">Concern: ${c}</div>`;
879
  });
880
  log.innerHTML += `<div>${concerns.length} unique concerns from ${persuadable.length} persuadable evaluators</div>`;
881
  document.getElementById('cfProgressBar').style.width = '15%';
@@ -891,7 +902,7 @@ async function runDirections() {
891
  suggestedChanges = suggestData.changes || [];
892
 
893
  suggestedChanges.forEach(c => {
894
- log.innerHTML += `<div class="pos">Change: ${c.label} — ${c.description}</div>`;
895
  });
896
  log.scrollTop = log.scrollHeight;
897
  document.getElementById('cfProgressBar').style.width = '25%';
@@ -926,7 +937,7 @@ async function runDirections() {
926
  es.addEventListener('goal_weights', (e) => {
927
  const d = JSON.parse(e.data);
928
  document.getElementById('cfProgressText').textContent = d.message;
929
- log.innerHTML += `<div>${d.message}</div>`;
930
  log.scrollTop = log.scrollHeight;
931
  });
932
 
@@ -938,7 +949,7 @@ async function runDirections() {
938
 
939
  const delta = d.best_delta > 0 ? `+${d.best_delta}` : d.best_delta;
940
  const changeName = (suggestedChanges.find(c => c.id === d.best_change) || {}).label || d.best_change;
941
- log.innerHTML += `<div>${d.name} (orig ${d.original_score}): best ${delta} from "${changeName}"</div>`;
942
  log.scrollTop = log.scrollHeight;
943
  });
944
 
@@ -955,7 +966,7 @@ async function runDirections() {
955
  return;
956
  }
957
 
958
- renderGradientTable(d.results, suggestedChanges);
959
  document.getElementById('gradientText').textContent = d.gradient;
960
  document.getElementById('changesTested').textContent =
961
  suggestedChanges.map(c => `${c.label}: ${c.description}`).join('\n');
@@ -971,48 +982,55 @@ async function runDirections() {
971
  }
972
  }
973
 
974
- function renderGradientTable(results, changes) {
975
- const valid = results.filter(r => r && r.counterfactuals);
976
- const labels = {};
977
- const descs = {};
978
- changes.forEach(c => { labels[c.id] = c.label; descs[c.id] = c.description; });
979
-
980
- // Aggregate with per-evaluator details
981
- const byChange = {};
982
- valid.forEach(r => {
983
- const ev = r._evaluator || {};
984
- (r.counterfactuals || []).forEach(cf => {
985
- const cid = cf.change_id;
986
- if (!byChange[cid]) byChange[cid] = {deltas: [], pos: 0, neg: 0, details: []};
987
- const delta = cf.delta || 0;
988
- byChange[cid].deltas.push(delta);
989
- if (delta > 0) byChange[cid].pos++;
990
- if (delta < 0) byChange[cid].neg++;
991
- byChange[cid].details.push({
992
- name: ev.name || '?',
993
- age: ev.age || '',
994
- occupation: ev.occupation || '',
995
- delta: delta,
996
- reasoning: cf.reasoning || '',
997
  });
998
  });
999
- });
1000
-
1001
- const ranked = Object.entries(byChange).map(([cid, d]) => {
1002
- const avg = d.deltas.reduce((a, b) => a + b, 0) / d.deltas.length;
1003
- const min = Math.min(...d.deltas);
1004
- const max = Math.max(...d.deltas);
1005
- d.details.sort((a, b) => b.delta - a.delta);
1006
- return {id: cid, label: labels[cid] || cid, desc: descs[cid] || '', avg, min, max, pos: d.pos, neg: d.neg, details: d.details};
1007
- });
1008
- ranked.sort((a, b) => b.avg - a.avg);
 
 
 
 
 
 
1009
 
1010
  const tbody = document.querySelector('#gradientTable tbody');
1011
  tbody.innerHTML = '';
1012
  ranked.forEach((r, i) => {
1013
- const cls = r.avg >= 0 ? 'delta-pos' : 'delta-neg';
1014
- const barWidth = Math.min(Math.abs(r.avg) * 30, 120);
1015
- const barColor = r.avg >= 0 ? 'var(--green)' : 'var(--red)';
 
1016
  const rowId = `gradient-detail-${i}`;
1017
 
1018
  // Summary row (clickable)
@@ -1020,35 +1038,36 @@ function renderGradientTable(results, changes) {
1020
  <tr onclick="document.getElementById('${rowId}').classList.toggle('hidden')" style="cursor:pointer">
1021
  <td>${i + 1}</td>
1022
  <td>
1023
- <div style="font-weight:600">${r.label}</div>
1024
- <div style="font-size:0.75rem;color:var(--text2);margin-top:2px">${r.desc}</div>
1025
  </td>
1026
  <td class="${cls}">
1027
- ${r.avg >= 0 ? '+' : ''}${r.avg.toFixed(1)}
1028
  <span class="delta-bar" style="width:${barWidth}px;background:${barColor};margin-left:8px"></span>
1029
  </td>
1030
- <td style="color:var(--text2)">${r.min >= 0 ? '+' : ''}${r.min} to +${r.max}</td>
1031
- <td style="color:var(--green)">${r.pos}</td>
1032
- <td style="color:var(--red)">${r.neg}</td>
1033
  </tr>
1034
  `;
1035
 
1036
  // Detail row (hidden by default)
1037
- const helped = r.details.filter(d => d.delta > 0).slice(0, 5);
1038
- const hurt = r.details.filter(d => d.delta < 0).slice(0, 3);
1039
- const neutral = r.details.filter(d => d.delta === 0).length;
 
1040
 
1041
  let detailHtml = '<div style="padding:12px 16px;font-size:0.8rem;line-height:1.6">';
1042
  if (helped.length) {
1043
  detailHtml += '<div style="color:var(--green);font-weight:600;margin-bottom:4px">Helps:</div>';
1044
  helped.forEach(d => {
1045
- detailHtml += `<div style="margin-left:12px;margin-bottom:4px">+${d.delta} <strong>${d.name}</strong> (${d.age}, ${d.occupation}): ${d.reasoning}</div>`;
1046
  });
1047
  }
1048
  if (hurt.length) {
1049
  detailHtml += '<div style="color:var(--red);font-weight:600;margin-top:8px;margin-bottom:4px">Hurts:</div>';
1050
  hurt.forEach(d => {
1051
- detailHtml += `<div style="margin-left:12px;margin-bottom:4px">${d.delta} <strong>${d.name}</strong> (${d.age}, ${d.occupation}): ${d.reasoning}</div>`;
1052
  });
1053
  }
1054
  if (neutral) {
@@ -1121,7 +1140,7 @@ function runBiasAudit() {
1121
 
1122
  d.analyses.forEach(a => {
1123
  if (a.error) {
1124
- tbody.innerHTML += `<tr><td>${a.probe}</td><td colspan="4">Error: ${a.error}</td></tr>`;
1125
  return;
1126
  }
1127
  const expected = baselines[a.probe];
@@ -1137,7 +1156,7 @@ function runBiasAudit() {
1137
 
1138
  tbody.innerHTML += `
1139
  <tr>
1140
- <td style="font-weight:600">${a.probe}</td>
1141
  <td>${a.shifted_pct.toFixed(1)}%</td>
1142
  <td>${a.avg_abs_delta.toFixed(2)}</td>
1143
  <td style="color:var(--text2)">${expected !== undefined ? expected + '%' : '—'}</td>
 
558
  let sessionId = null;
559
  let evalResultsData = null;
560
 
561
+ // XSS sanitization helper
562
+ function esc(str) {
563
+ if (str == null) return '';
564
+ return String(str)
565
+ .replace(/&/g, '&amp;')
566
+ .replace(/</g, '&lt;')
567
+ .replace(/>/g, '&gt;')
568
+ .replace(/"/g, '&quot;')
569
+ .replace(/'/g, '&#039;');
570
+ }
571
+
572
  // ── Init ──
573
 
574
  async function init() {
 
663
 
664
  function logStep(msg, cls = '') {
665
  const log = document.getElementById('evalLog');
666
+ log.innerHTML += `<div class="${esc(cls)}">${esc(msg)}</div>`;
667
  log.scrollTop = log.scrollHeight;
668
  }
669
 
 
886
 
887
  const log = document.getElementById('cfLog');
888
  concerns.slice(0, 8).forEach(c => {
889
+ log.innerHTML += `<div style="color:var(--text2)">Concern: ${esc(c)}</div>`;
890
  });
891
  log.innerHTML += `<div>${concerns.length} unique concerns from ${persuadable.length} persuadable evaluators</div>`;
892
  document.getElementById('cfProgressBar').style.width = '15%';
 
902
  suggestedChanges = suggestData.changes || [];
903
 
904
  suggestedChanges.forEach(c => {
905
+ log.innerHTML += `<div class="pos">Change: ${esc(c.label)} — ${esc(c.description)}</div>`;
906
  });
907
  log.scrollTop = log.scrollHeight;
908
  document.getElementById('cfProgressBar').style.width = '25%';
 
937
  es.addEventListener('goal_weights', (e) => {
938
  const d = JSON.parse(e.data);
939
  document.getElementById('cfProgressText').textContent = d.message;
940
+ log.innerHTML += `<div>${esc(d.message)}</div>`;
941
  log.scrollTop = log.scrollHeight;
942
  });
943
 
 
949
 
950
  const delta = d.best_delta > 0 ? `+${d.best_delta}` : d.best_delta;
951
  const changeName = (suggestedChanges.find(c => c.id === d.best_change) || {}).label || d.best_change;
952
+ log.innerHTML += `<div>${esc(d.name)} (orig ${d.original_score}): best ${delta} from "${esc(changeName)}"</div>`;
953
  log.scrollTop = log.scrollHeight;
954
  });
955
 
 
966
  return;
967
  }
968
 
969
+ renderGradientTable(d.results, suggestedChanges, d.ranked);
970
  document.getElementById('gradientText').textContent = d.gradient;
971
  document.getElementById('changesTested').textContent =
972
  suggestedChanges.map(c => `${c.label}: ${c.description}`).join('\n');
 
982
  }
983
  }
984
 
985
+ function renderGradientTable(results, changes, ranked) {
986
+ // Use backend-provided ranked data (respects goal weights / VJP) when available,
987
+ // falling back to client-side aggregation only for legacy responses.
988
+ if (!ranked || !ranked.length) {
989
+ // Legacy fallback: recompute from raw results (unweighted)
990
+ const valid = results.filter(r => r && r.counterfactuals);
991
+ const labels = {};
992
+ const descs = {};
993
+ changes.forEach(c => { labels[c.id] = c.label; descs[c.id] = c.description; });
994
+ const byChange = {};
995
+ valid.forEach(r => {
996
+ const ev = r._evaluator || {};
997
+ (r.counterfactuals || []).forEach(cf => {
998
+ const cid = cf.change_id;
999
+ if (!byChange[cid]) byChange[cid] = {deltas: [], pos: 0, neg: 0, details: []};
1000
+ const delta = cf.delta || 0;
1001
+ byChange[cid].deltas.push(delta);
1002
+ if (delta > 0) byChange[cid].pos++;
1003
+ if (delta < 0) byChange[cid].neg++;
1004
+ byChange[cid].details.push({
1005
+ name: ev.name || '?', age: ev.age || '',
1006
+ occupation: ev.occupation || '', delta, reasoning: cf.reasoning || '',
1007
+ });
1008
  });
1009
  });
1010
+ ranked = Object.entries(byChange).map(([cid, d]) => {
1011
+ const avg = d.deltas.reduce((a, b) => a + b, 0) / d.deltas.length;
1012
+ d.details.sort((a, b) => b.delta - a.delta);
1013
+ return {
1014
+ id: cid, label: labels[cid] || cid, desc: descs[cid] || '',
1015
+ avg_delta: avg, min_delta: Math.min(...d.deltas), max_delta: Math.max(...d.deltas),
1016
+ positive: d.pos, negative: d.neg, details: d.details,
1017
+ };
1018
+ });
1019
+ ranked.sort((a, b) => b.avg_delta - a.avg_delta);
1020
+ } else {
1021
+ // Attach descriptions from changes list
1022
+ const descs = {};
1023
+ changes.forEach(c => { descs[c.id] = c.description; });
1024
+ ranked.forEach(r => { if (!r.desc) r.desc = descs[r.id] || ''; });
1025
+ }
1026
 
1027
  const tbody = document.querySelector('#gradientTable tbody');
1028
  tbody.innerHTML = '';
1029
  ranked.forEach((r, i) => {
1030
+ const avg = r.avg_delta;
1031
+ const cls = avg >= 0 ? 'delta-pos' : 'delta-neg';
1032
+ const barWidth = Math.min(Math.abs(avg) * 30, 120);
1033
+ const barColor = avg >= 0 ? 'var(--green)' : 'var(--red)';
1034
  const rowId = `gradient-detail-${i}`;
1035
 
1036
  // Summary row (clickable)
 
1038
  <tr onclick="document.getElementById('${rowId}').classList.toggle('hidden')" style="cursor:pointer">
1039
  <td>${i + 1}</td>
1040
  <td>
1041
+ <div style="font-weight:600">${esc(r.label)}</div>
1042
+ <div style="font-size:0.75rem;color:var(--text2);margin-top:2px">${esc(r.desc)}</div>
1043
  </td>
1044
  <td class="${cls}">
1045
+ ${avg >= 0 ? '+' : ''}${avg.toFixed(1)}
1046
  <span class="delta-bar" style="width:${barWidth}px;background:${barColor};margin-left:8px"></span>
1047
  </td>
1048
+ <td style="color:var(--text2)">${r.min_delta >= 0 ? '+' : ''}${r.min_delta} to +${r.max_delta}</td>
1049
+ <td style="color:var(--green)">${r.positive}</td>
1050
+ <td style="color:var(--red)">${r.negative}</td>
1051
  </tr>
1052
  `;
1053
 
1054
  // Detail row (hidden by default)
1055
+ const details = r.details || [];
1056
+ const helped = details.filter(d => d.delta > 0).slice(0, 5);
1057
+ const hurt = details.filter(d => d.delta < 0).slice(0, 3);
1058
+ const neutral = details.filter(d => d.delta === 0).length;
1059
 
1060
  let detailHtml = '<div style="padding:12px 16px;font-size:0.8rem;line-height:1.6">';
1061
  if (helped.length) {
1062
  detailHtml += '<div style="color:var(--green);font-weight:600;margin-bottom:4px">Helps:</div>';
1063
  helped.forEach(d => {
1064
+ detailHtml += `<div style="margin-left:12px;margin-bottom:4px">+${d.delta} <strong>${esc(d.name)}</strong> (${esc(d.age)}, ${esc(d.occupation)}): ${esc(d.reasoning)}</div>`;
1065
  });
1066
  }
1067
  if (hurt.length) {
1068
  detailHtml += '<div style="color:var(--red);font-weight:600;margin-top:8px;margin-bottom:4px">Hurts:</div>';
1069
  hurt.forEach(d => {
1070
+ detailHtml += `<div style="margin-left:12px;margin-bottom:4px">${d.delta} <strong>${esc(d.name)}</strong> (${esc(d.age)}, ${esc(d.occupation)}): ${esc(d.reasoning)}</div>`;
1071
  });
1072
  }
1073
  if (neutral) {
 
1140
 
1141
  d.analyses.forEach(a => {
1142
  if (a.error) {
1143
+ tbody.innerHTML += `<tr><td>${esc(a.probe)}</td><td colspan="4">Error: ${esc(a.error)}</td></tr>`;
1144
  return;
1145
  }
1146
  const expected = baselines[a.probe];
 
1156
 
1157
  tbody.innerHTML += `
1158
  <tr>
1159
+ <td style="font-weight:600">${esc(a.probe)}</td>
1160
  <td>${a.shifted_pct.toFixed(1)}%</td>
1161
  <td>${a.avg_abs_delta.toFixed(2)}</td>
1162
  <td style="color:var(--text2)">${expected !== undefined ? expected + '%' : '—'}</td>