jjkim committed · Commit 2128ba2 · 1 Parent(s): f35f0d4

refactor & fix order bug & add early stop option

Browse files:
- code_eval.py  +71 -58
- requirements.txt  +2 -1
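For context, a minimal usage sketch of the two keyword arguments this commit introduces (task_ids and early_stop). The repository id, the sample data, and the HF_ALLOW_CODE_EVAL guard are assumptions carried over from the upstream code_eval metric rather than part of this commit:

import os

import evaluate

# Allow execution of untrusted model-generated code, as the upstream metric requires.
os.environ["HF_ALLOW_CODE_EVAL"] = "1"

# Hypothetical repo id; point this at the metric repository containing code_eval.py.
code_eval = evaluate.load("jjkim/code_eval")

# One task with two candidate completions; references hold a list of test-case strings per task.
predictions = [["def add(a, b):\n    return a + b", "def add(a, b):\n    return a - b"]]
references = [["assert add(1, 2) == 3", "assert add(0, 0) == 0"]]

scores = code_eval.compute(
    predictions=predictions,
    references=references,
    task_ids=[0],      # new: explicit task ids (defaults to range(len(predictions)))
    k=[1],
    early_stop=True,   # new: cancel a candidate's remaining test cases once a result is in
)
print(scores)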
code_eval.py
CHANGED
@@ -20,11 +20,14 @@ import itertools
 import os
 from collections import Counter, defaultdict
 from concurrent.futures import CancelledError, ThreadPoolExecutor, as_completed
+from typing import List, Optional
+import time

 import datasets
 import evaluate
 import numpy as np
 from tqdm import tqdm
+from pydantic import BaseModel

 from .execute import check_correctness


@@ -155,9 +158,11 @@ class CodeEval(evaluate.Metric):
         self,
         predictions,
         references,
+        task_ids=None,
         k=[1, 10, 100],
         num_workers=4,
         timeout=3.0,
+        early_stop=False,
     ):
         """Returns the scores"""


@@ -169,69 +174,43 @@ class CodeEval(evaluate.Metric):
                 "This metric is currently not supported on Windows."
             )

+        task_ids = task_ids or list(range(len(predictions)))
+
         with ThreadPoolExecutor(max_workers=num_workers) as executor:
-            for _test_case in test_case:
-                assert isinstance(_test_case, str)
-                test_program = candidate + "\n" + _test_case
-                args = (
-                    test_program,
-                    timeout,
-                    task_id,
-                    completion_id[task_id],
-                )
+            results = {}
+            for tid, pred, ref in zip(task_ids, predictions, references):
+                results[tid] = []
+                for candidate in pred:
+                    result = Result(task_id=tid, completion_id=len(results))
+                    for test_case in ref:
+                        assert isinstance(test_case, str)
+                        test_program = candidate + "\n" + test_case
+                        args = (test_program, timeout, tid)
                         future = executor.submit(check_correctness, *args)
-            for key, result in results.items():
-                new_result = []
-                result.sort(key=lambda x: x[0])
-                for completion_id, group in itertools.groupby(result, key=lambda x: x[0]):
-                    group = list(group)
-                    new_result.append(
-                        (
-                            completion_id,
-                            dict(
-                                task_id=key,
-                                passed=all(r[1]["passed"] for r in group),
-                                result=[r[1]["result"] for r in group],
-                                completion_id=completion_id,
-                            ),
-                        )
-                    )
-                new_results[key] = new_result
-            results = new_results
+                        result.add(future)
+                    results[tid].append(result)
+
+            pbar = tqdm(total=len(results))
+            prev_done_count = 0
+            while not all(r.done() for r in results.values()):
+                cur_done_count = 0
+                for result in results.values():
+                    for r in result:
+                        if not r.done():
+                            r.refresh(early_stop)
+                        else:
+                            cur_done_count += 1
+                pbar.update(cur_done_count - prev_done_count)
+                prev_done_count = cur_done_count
+                time.sleep(1)
+
+            results = {
+                task_id: [(r.completion_id, r.dict(exclude={"futures"})) for r in result]
+                for task_id, result in results.items()
+            }

         total, correct = [], []
         for result in results.values():
-            result.sort(key=lambda x: x[0])
             passed = [r[1]["passed"] for r in result]
             total.append(len(passed))
             correct.append(sum(passed))


@@ -266,3 +245,37 @@ def estimate_pass_at_k(num_samples, num_correct, k):
     return np.array(
         [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
     )
+
+
+class Result(BaseModel):
+    task_id: int
+    completion_id: int
+
+    passed: Optional[bool] = None
+    result: List[str] = []
+    futures: List[object] = []
+
+    def add(self, future):
+        self.futures.append(future)
+        self.result.append(None)
+
+    def refresh(self, early_stop=False):
+        for i, future in enumerate(self.futures):
+            if self.result[i] is None and future.done():
+                try:
+                    self.result[i] = future.result()
+                except CancelledError:
+                    self.result[i] = "Early Stopped"
+                except Exception as e:
+                    self.result[i] = str(e)
+
+                if early_stop:
+                    # cancel all other futures
+                    for future in self.futures[i + 1 :]:
+                        future.cancel()
+
+        if all(r is not None for r in self.result):
+            self.passed = all(r["passed"] for r in self.result)
+
+    def done(self):
+        return self.passed is not None
requirements.txt
CHANGED
@@ -1 +1,2 @@
-
+pydantic
+numpy
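requirements.txt picks up pydantic because the new Result class is a BaseModel and the compute loop serialises it with .dict(exclude={"futures"}). That is the pydantic v1-style API (pydantic v2 renames it to model_dump, so pinning v1 may be the safer reading, though the commit does not say). A simplified sketch of that call, with the futures field typed loosely for illustration:

from typing import List, Optional

from pydantic import BaseModel

class Result(BaseModel):
    task_id: int
    completion_id: int
    passed: Optional[bool] = None
    result: List[str] = []
    futures: list = []  # live Future objects in the real class; dropped when serialising

r = Result(task_id=0, completion_id=0, passed=True, result=["passed"])
print(r.dict(exclude={"futures"}))
# {'task_id': 0, 'completion_id': 0, 'passed': True, 'result': ['passed']}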
|