# frankenstallm / data / filter_sft_v2.py
# NOTE(review): the four lines above/around here were Hugging Face file-page
# residue ("pathcosmos's picture", commit message, hash) captured with the
# source; they are kept as comments so the file remains valid Python.
# The shebang below is no longer on line 1, so direct `./` execution relies
# on invoking via `python` explicitly.
#!/usr/bin/env python3
"""
filter_sft_v2.py โ€” SFT ๋ฐ์ดํ„ฐ ํ’ˆ์งˆ ํ•„ํ„ฐ (JSONL messages ํฌ๋งท)
ํ•„ํ„ฐ ๊ทœ์น™:
1. </s> ๋ฆฌํ„ฐ๋Ÿด ์ œ๊ฑฐ (assistant ๋ฉ”์‹œ์ง€์—์„œ </s> ํƒœ๊ทธ strip)
2. Q:, A:, ์งˆ๋ฌธ:, ๋‹ต๋ณ€: ๋“ฑ Q/A ๋งˆ์ปค ์ œ๊ฑฐ (content ์‹œ์ž‘ ๋ถ€๋ถ„)
3. 50์ž ๋ฏธ๋งŒ ๊ทน๋‹จ ๋‹จ๋ฌธ ์ œ๊ฑฐ (assistant ์‘๋‹ต ๊ธฐ์ค€)
4. 4-gram ๋ฐ˜๋ณต๋ฅ  >30% ์ œ๊ฑฐ (assistant ์‘๋‹ต ๊ธฐ์ค€)
Usage:
python data/filter_sft_v2.py \\
--input data/sft_combined/train.jsonl \\
--output data/sft_combined/train_filtered.jsonl
"""
import argparse
import json
import re
import sys
from collections import Counter
from pathlib import Path
# ---------------------------------------------------------------------------
# Filter 1: strip literal </s> tags
# ---------------------------------------------------------------------------
_EOS_PATTERN = re.compile(r"</s>", re.IGNORECASE)


def strip_eos_tag(text: str) -> str:
    """Delete every literal </s> tag (any letter case) and trim outer whitespace."""
    without_tags = _EOS_PATTERN.sub("", text)
    return without_tags.strip()
# ---------------------------------------------------------------------------
# Filter 2: strip Q/A markers
# ---------------------------------------------------------------------------
# Marker patterns at the start of content (handles Korean and English forms)
# One or more Q/A prefixes anchored at the very start; [:๏ผš] covers both the
# ASCII and fullwidth colon. NOTE(review): the Korean marker literals appear
# mojibake-encoded in this file — confirm the bytes match the dataset's actual
# encoding before changing them.
_QA_MARKER_PATTERN = re.compile(
    r"^\s*(?:"
    r"์งˆ๋ฌธ\s*[:๏ผš]\s*"
    r"|๋‹ต๋ณ€\s*[:๏ผš]\s*"
    r"|Q\s*[:๏ผš]\s*"
    r"|A\s*[:๏ผš]\s*"
    r"|Answer\s*[:๏ผš]\s*"
    r"|Question\s*[:๏ผš]\s*"
    r")+",
    re.IGNORECASE,
)
def strip_qa_markers(text: str) -> str:
    """Remove Q/A markers (Q:, A:, Question:, Answer:, Korean forms) from the start of content, then strip outer whitespace."""
    return _QA_MARKER_PATTERN.sub("", text).strip()
# ---------------------------------------------------------------------------
# Filter 3: extreme short-reply check
# ---------------------------------------------------------------------------
MIN_ASSISTANT_LEN = 50  # minimum assistant reply length, in characters


def is_too_short(text: str) -> bool:
    """Return True when *text* has fewer than MIN_ASSISTANT_LEN characters."""
    return len(text) < MIN_ASSISTANT_LEN
# ---------------------------------------------------------------------------
# Filter 4: 4-gram repeat ratio
# ---------------------------------------------------------------------------
NGRAM_SIZE = 4
MAX_REPEAT_RATIO = 0.30 # 30% ์ดˆ๊ณผ ์‹œ ์ œ๊ฑฐ
def _tokenize_ngrams(text: str, n: int):
"""๊ณต๋ฐฑ ๋‹จ์œ„ ํ† ํฌ๋‚˜์ด์ฆˆ ํ›„ n-gram ๋ฆฌ์ŠคํŠธ ๋ฐ˜ํ™˜. ํ•œ๊ตญ์–ด fallback ํฌํ•จ."""
tokens = text.split()
# ํ•œ๊ตญ์–ด fallback: ๊ณต๋ฐฑ ํ† ํฐ์ด ๋ถ€์กฑํ•˜๋ฉด ๋ฌธ์ž ๋ ˆ๋ฒจ n-gram ์‚ฌ์šฉ
if len(tokens) < n * 3:
# ๊ณต๋ฐฑ/๊ตฌ๋‘์  ์ œ๊ฑฐ ํ›„ ๋ฌธ์ž ๋‹จ์œ„
chars = [c for c in text if not c.isspace()]
if len(chars) < n:
return []
return [tuple(chars[i : i + n]) for i in range(len(chars) - n + 1)]
if len(tokens) < n:
return []
return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]
def ngram_repeat_ratio(text: str, n: int = NGRAM_SIZE) -> float:
"""
(์ค‘๋ณต n-gram ์ˆ˜) / (์ „์ฒด n-gram ์ˆ˜) ๋น„์œจ์„ ๋ฐ˜ํ™˜ํ•œ๋‹ค.
์ „์ฒด n-gram์ด ์—†์œผ๋ฉด 0.0 ๋ฐ˜ํ™˜.
"""
ngrams = _tokenize_ngrams(text, n)
total = len(ngrams)
if total == 0:
return 0.0
counts = Counter(ngrams)
# 1ํšŒ ์ดˆ๊ณผ ๋“ฑ์žฅํ•œ n-gram ๊ฐœ์ˆ˜(์ค‘๋ณต๋ถ„)
duplicated = sum(c - 1 for c in counts.values() if c > 1)
return duplicated / total
def is_repetitive(text: str) -> bool:
return ngram_repeat_ratio(text) > MAX_REPEAT_RATIO
# ---------------------------------------------------------------------------
# Filter 5: over-long reply filter
# ---------------------------------------------------------------------------
MAX_CHAR_LEN = 20000  # replies longer than 20K characters are dropped


def is_too_long(text: str) -> bool:
    """Return True when *text* exceeds MAX_CHAR_LEN characters."""
    return len(text) > MAX_CHAR_LEN
# ---------------------------------------------------------------------------
# Message sanitization / sample-level filtering
# ---------------------------------------------------------------------------
def clean_message_content(content: str, role: str) -> str:
    """Sanitize a single message's content.

    Assistant turns get literal </s> tags stripped first (filter 1); every
    role then has leading Q/A markers removed (filter 2).
    """
    if role == "assistant":
        content = strip_eos_tag(content)
    return strip_qa_markers(content)
def filter_sample(sample: dict) -> tuple[dict | None, str]:
    """Validate and sanitize one sample.

    Returns (cleaned sample, "") when the sample passes every filter, or
    (None, reason) when it should be dropped. Reasons: "no_messages",
    "no_assistant_turn", "too_short", "too_long", "repetitive".
    """
    raw_messages = sample.get("messages")
    if not raw_messages or not isinstance(raw_messages, list):
        return None, "no_messages"

    # Sanitize every turn; non-string content is coerced via str() first.
    cleaned_messages = []
    for msg in raw_messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if not isinstance(content, str):
            content = str(content)
        cleaned_messages.append(
            {**msg, "content": clean_message_content(content, role)}
        )

    # Assistant-turn filters, applied in a fixed order per reply.
    checks = (
        (is_too_short, "too_short"),      # filter 3
        (is_too_long, "too_long"),        # filter 5
        (is_repetitive, "repetitive"),    # filter 4
    )
    saw_assistant = False
    for msg in cleaned_messages:
        if msg.get("role") != "assistant":
            continue
        saw_assistant = True
        for predicate, reason in checks:
            if predicate(msg["content"]):
                return None, reason
    if not saw_assistant:
        return None, "no_assistant_turn"

    return {**sample, "messages": cleaned_messages}, ""
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_args():
    """Parse command-line arguments: required --input and --output JSONL paths."""
    parser = argparse.ArgumentParser(
        description="SFT ๋ฐ์ดํ„ฐ ํ’ˆ์งˆ ํ•„ํ„ฐ (JSONL messages ํฌ๋งท)"
    )
    for flag, help_text in (
        ("--input", "์ž…๋ ฅ JSONL ํŒŒ์ผ ๊ฒฝ๋กœ"),
        ("--output", "์ถœ๋ ฅ JSONL ํŒŒ์ผ ๊ฒฝ๋กœ"),
    ):
        parser.add_argument(flag, required=True, help=help_text)
    return parser.parse_args()
def main():
    """CLI entry point: stream-filter the input JSONL, write survivors, print stats.

    Exits with status 1 when the input file does not exist. Malformed JSON
    lines are warned about and counted, not fatal.
    """
    args = parse_args()
    in_path = Path(args.input)
    out_path = Path(args.output)
    if not in_path.exists():
        print(f"ERROR: ์ž…๋ ฅ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {in_path}", file=sys.stderr)
        sys.exit(1)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Removal-reason counters; keys mirror the reasons filter_sample returns.
    stats: dict[str, int] = {
        "total": 0,
        "no_messages": 0,
        "no_assistant_turn": 0,
        "too_short": 0,
        "too_long": 0,
        "repetitive": 0,
        "json_error": 0,
        "passed": 0,
    }
    # Explicit UTF-8 on both files: the data is non-ASCII JSONL, so do not
    # rely on the platform's default encoding.
    with in_path.open("r", encoding="utf-8", errors="replace") as fin, \
            out_path.open("w", encoding="utf-8") as fout:
        for lineno, raw in enumerate(fin, 1):
            raw = raw.strip()
            if not raw:
                continue  # blank lines are skipped and not counted
            stats["total"] += 1
            try:
                sample = json.loads(raw)
            except json.JSONDecodeError as e:
                print(f"[WARN] ๋ผ์ธ {lineno} JSON ํŒŒ์‹ฑ ์‹คํŒจ: {e}", file=sys.stderr)
                stats["json_error"] += 1
                continue
            cleaned, reason = filter_sample(sample)
            if cleaned is None:
                stats[reason] = stats.get(reason, 0) + 1
            else:
                stats["passed"] += 1
                fout.write(json.dumps(cleaned, ensure_ascii=False) + "\n")
    # Summary report.
    total = stats["total"]
    removed = total - stats["passed"]
    # Guard against ZeroDivisionError when the input has zero samples.
    denom = total or 1
    print("=" * 60)
    print(f" ์ž…๋ ฅ ํŒŒ์ผ : {in_path}")
    print(f" ์ถœ๋ ฅ ํŒŒ์ผ : {out_path}")
    print("=" * 60)
    print(f" ์ด ์ž…๋ ฅ : {total:>10,}")
    print(f" [์ œ๊ฑฐ] no_messages : {stats['no_messages']:>10,}")
    print(f" [์ œ๊ฑฐ] no_assistant_turn: {stats['no_assistant_turn']:>10,}")
    print(f" [์ œ๊ฑฐ] too_short (<50์ž): {stats['too_short']:>10,}")
    print(f" [์ œ๊ฑฐ] too_long (>{MAX_CHAR_LEN}์ž): {stats['too_long']:>10,}")
    print(f" [์ œ๊ฑฐ] json_error : {stats['json_error']:>10,}")
    print(f" [์ œ๊ฑฐ] repetitive (4-gram >30%): {stats['repetitive']:>10,}")
    print(f" ์ด ์ œ๊ฑฐ : {removed:>10,} ({removed/denom*100:.1f}%)")
    print(f" ์ตœ์ข… ์ž”์กด : {stats['passed']:>10,} ({stats['passed']/denom*100:.1f}%)")
    print("=" * 60)
if __name__ == "__main__":
main()