| |
"""
filter_sft_v2.py — SFT data quality filter (JSONL ``messages`` format).

Filter rules:
    1. Remove literal ``</s>`` tags (stripped from assistant messages).
    2. Remove Q/A markers such as ``Q:``, ``A:``, ``질문:`` (question) and
       ``답변:`` (answer) from the start of each message's content.
    3. Drop samples with an extremely short assistant reply
       (fewer than 50 characters).
    4. Drop samples whose assistant reply has a 4-gram repetition ratio
       above 30%.
    5. Drop samples with an assistant reply longer than 20,000 characters.

Usage:
    python data/filter_sft_v2.py \\
        --input data/sft_combined/train.jsonl \\
        --output data/sft_combined/train_filtered.jsonl
"""
|
|
| import argparse |
| import json |
| import re |
| import sys |
| from collections import Counter |
| from pathlib import Path |
|
|
|
|
| |
| |
| |
# The literal end-of-sequence tag, matched in any letter case ("</s>", "</S>").
_EOS_PATTERN = re.compile(r"</s>", flags=re.I)


def strip_eos_tag(text: str) -> str:
    """Return *text* with every ``</s>`` tag deleted and outer whitespace trimmed.

    Only the tags themselves are removed; whitespace that surrounded a tag in
    the middle of the text is left as it was.
    """
    untagged = _EOS_PATTERN.sub("", text)
    return untagged.strip()
|
|
|
|
| |
| |
| |
| |
# Q/A role markers that sometimes prefix a message: "질문:" (question),
# "답변:" (answer), "Q:", "A:", "Answer:", "Question:".  Each marker may use an
# ASCII colon or a fullwidth colon (U+FF1A), and a chain of markers is removed
# in one go ("Q: A: ...").  Anchored at the start of the string only.
# The Korean words and the fullwidth colon are written as \u escapes so that a
# mis-decoding of this source file cannot silently turn them into other
# characters (which would make those markers never match).
_QA_MARKER_PATTERN = re.compile(
    r"^\s*(?:"
    r"\uC9C8\uBB38\s*[:\uFF1A]\s*"  # 질문 (question)
    r"|\uB2F5\uBCC0\s*[:\uFF1A]\s*"  # 답변 (answer)
    r"|Q\s*[:\uFF1A]\s*"
    r"|A\s*[:\uFF1A]\s*"
    r"|Answer\s*[:\uFF1A]\s*"
    r"|Question\s*[:\uFF1A]\s*"
    r")+",
    re.IGNORECASE,
)


def strip_qa_markers(text: str) -> str:
    """Remove Q/A markers at the very start of *text* and trim whitespace.

    Only a leading run of markers is removed (``"Q: A: hi"`` -> ``"hi"``);
    markers that appear later in the text are kept.
    """
    return _QA_MARKER_PATTERN.sub("", text).strip()
|
|
|
|
| |
| |
| |
# Fewest characters an assistant reply may have before the sample is dropped.
MIN_ASSISTANT_LEN = 50


def is_too_short(text: str) -> bool:
    """Return True when *text* has fewer than ``MIN_ASSISTANT_LEN`` characters."""
    char_count = len(text)
    return MIN_ASSISTANT_LEN > char_count
|
|
|
|
| |
| |
| |
| NGRAM_SIZE = 4 |
| MAX_REPEAT_RATIO = 0.30 |
|
|
|
|
| def _tokenize_ngrams(text: str, n: int): |
| """๊ณต๋ฐฑ ๋จ์ ํ ํฌ๋์ด์ฆ ํ n-gram ๋ฆฌ์คํธ ๋ฐํ. ํ๊ตญ์ด fallback ํฌํจ.""" |
| tokens = text.split() |
| |
| if len(tokens) < n * 3: |
| |
| chars = [c for c in text if not c.isspace()] |
| if len(chars) < n: |
| return [] |
| return [tuple(chars[i : i + n]) for i in range(len(chars) - n + 1)] |
| if len(tokens) < n: |
| return [] |
| return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)] |
|
|
|
|
def ngram_repeat_ratio(text: str, n: int = NGRAM_SIZE) -> float:
    """Return the fraction of n-grams in *text* that are repeats.

    Every occurrence of an n-gram beyond its first counts as one repeat, so
    the result equals ``(total - distinct) / total``.  Returns 0.0 when the
    text yields no n-grams at all.
    """
    grams = _tokenize_ngrams(text, n)
    if not grams:
        return 0.0
    repeats = len(grams) - len(set(grams))
    return repeats / len(grams)
|
|
|
|
def is_repetitive(text: str) -> bool:
    """Return True when the duplicate n-gram share of *text* exceeds ``MAX_REPEAT_RATIO``."""
    ratio = ngram_repeat_ratio(text)
    return MAX_REPEAT_RATIO < ratio
|
|
|
|
| |
| |
| |
# Most characters an assistant reply may have before the sample is dropped.
MAX_CHAR_LEN = 20000


def is_too_long(text: str) -> bool:
    """Return True when *text* is longer than ``MAX_CHAR_LEN`` characters."""
    char_count = len(text)
    return MAX_CHAR_LEN < char_count
|
|
|
|
| |
| |
| |
|
|
def clean_message_content(content: str, role: str) -> str:
    """Return the cleaned text of a single message.

    Assistant messages first have ``</s>`` tags stripped.  Leading Q/A markers
    are then removed for every role (user and system messages included).
    """
    cleaned = strip_eos_tag(content) if role == "assistant" else content
    return strip_qa_markers(cleaned)
|
|
|
|
def filter_sample(sample: dict) -> tuple[dict | None, str]:
    """Clean one sample and decide whether it passes the quality filters.

    Every message's content is cleaned with ``clean_message_content``.  The
    sample is then rejected if it has no assistant turn, or if any assistant
    reply is too short, too long or too repetitive.

    Args:
        sample: One decoded JSONL record, expected to hold a ``"messages"``
            list of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        ``(cleaned_sample, "")`` when the sample passes.  Otherwise returns
        ``(None, reason)``, where *reason* is one of ``"no_messages"``,
        ``"invalid_message"``, ``"no_assistant_turn"``, ``"too_short"``,
        ``"too_long"`` or ``"repetitive"``.  The input sample and its message
        dicts are not mutated; shallow copies are returned.
    """
    if not isinstance(sample, dict):
        # A JSON line that decodes to a list/str/number carries no messages.
        return None, "no_messages"
    messages = sample.get("messages")
    if not messages or not isinstance(messages, list):
        return None, "no_messages"

    cleaned_messages = []
    for msg in messages:
        if not isinstance(msg, dict):
            # Malformed entry (e.g. a bare string); reject instead of crashing.
            return None, "invalid_message"
        role = msg.get("role", "")
        content = msg.get("content", "")
        if content is None:
            # JSON null would otherwise become the literal text "None".
            content = ""
        elif not isinstance(content, str):
            content = str(content)
        cleaned_messages.append(
            {**msg, "content": clean_message_content(content, role)}
        )

    assistant_contents = [
        m["content"] for m in cleaned_messages if m.get("role") == "assistant"
    ]
    if not assistant_contents:
        return None, "no_assistant_turn"

    # The first failing check on the first failing reply decides the reason.
    for text in assistant_contents:
        if is_too_short(text):
            return None, "too_short"
        if is_too_long(text):
            return None, "too_long"
        if is_repetitive(text):
            return None, "repetitive"

    return {**sample, "messages": cleaned_messages}, ""
|
|
|
|
| |
| |
| |
|
|
| def parse_args(): |
| parser = argparse.ArgumentParser( |
| description="SFT ๋ฐ์ดํฐ ํ์ง ํํฐ (JSONL messages ํฌ๋งท)" |
| ) |
| parser.add_argument("--input", required=True, help="์
๋ ฅ JSONL ํ์ผ ๊ฒฝ๋ก") |
| parser.add_argument("--output", required=True, help="์ถ๋ ฅ JSONL ํ์ผ ๊ฒฝ๋ก") |
| return parser.parse_args() |
|
|
|
|
def _filter_file(in_path: Path, out_path: Path) -> dict[str, int]:
    """Run every JSONL record of *in_path* through ``filter_sample``.

    Passing samples are written to *out_path*.  Blank lines are skipped and
    not counted.  Lines that are not valid JSON are reported on stderr and
    counted as ``json_error``.  Valid JSON that is not an object is counted
    as ``no_messages``.

    Returns:
        Counters keyed by ``"total"``, ``"passed"`` and each removal reason.
    """
    stats: dict[str, int] = {
        "total": 0,
        "no_messages": 0,
        "no_assistant_turn": 0,
        "too_short": 0,
        "too_long": 0,
        "repetitive": 0,
        "json_error": 0,
        "passed": 0,
    }
    # Explicit UTF-8 on both ends.  The platform default encoding (e.g. cp949
    # on Windows) could mis-decode the input, or fail to encode the non-ASCII
    # text written with ensure_ascii=False.
    with in_path.open("r", encoding="utf-8", errors="replace") as fin:
        with out_path.open("w", encoding="utf-8") as fout:
            for lineno, raw in enumerate(fin, 1):
                raw = raw.strip()
                if not raw:
                    continue
                stats["total"] += 1

                try:
                    sample = json.loads(raw)
                except json.JSONDecodeError as e:
                    print(f"[WARN] 라인 {lineno} JSON 파싱 실패: {e}", file=sys.stderr)
                    stats["json_error"] += 1
                    continue
                if not isinstance(sample, dict):
                    stats["no_messages"] += 1
                    continue

                cleaned, reason = filter_sample(sample)
                if cleaned is None:
                    stats[reason] = stats.get(reason, 0) + 1
                else:
                    stats["passed"] += 1
                    fout.write(json.dumps(cleaned, ensure_ascii=False) + "\n")
    return stats


def _print_report(stats: dict[str, int], in_path: Path, out_path: Path) -> None:
    """Print the per-reason removal breakdown and pass rate of one run."""
    total = stats["total"]
    passed = stats["passed"]
    removed = total - passed

    def pct(count: int) -> str:
        # An empty input file gives total == 0; avoid ZeroDivisionError.
        return f"{count / total * 100:.1f}%" if total else "n/a"

    # Labels are built from the filter constants so they cannot drift.
    reason_labels = {
        "no_messages": "no_messages",
        "no_assistant_turn": "no_assistant_turn",
        "too_short": f"too_short (<{MIN_ASSISTANT_LEN}자)",
        "too_long": f"too_long (>{MAX_CHAR_LEN}자)",
        "json_error": "json_error",
        "repetitive": f"repetitive ({NGRAM_SIZE}-gram >{MAX_REPEAT_RATIO:.0%})",
    }
    # Any other reason key returned by filter_sample is still reported.
    for key in sorted(set(stats) - set(reason_labels) - {"total", "passed"}):
        reason_labels[key] = key
    width = max(len(label) for label in reason_labels.values())

    print("=" * 60)
    print(f" 입력 파일 : {in_path}")
    print(f" 출력 파일 : {out_path}")
    print("=" * 60)
    print(f" 총 입력 : {total:>10,}")
    for key, label in reason_labels.items():
        print(f" [제거] {label:<{width}}: {stats.get(key, 0):>10,}")
    print(f" 총 제거 : {removed:>10,} ({pct(removed)})")
    print(f" 최종 생존 : {passed:>10,} ({pct(passed)})")
    print("=" * 60)


def main() -> None:
    """Filter the ``--input`` JSONL file into ``--output`` and print a summary.

    Exits with status 1 if the input file does not exist.  It also exits with
    status 1 if the input is the same file as the output, because opening the
    output for writing would truncate the input before it is read.
    """
    args = parse_args()
    in_path = Path(args.input)
    out_path = Path(args.output)

    if not in_path.exists():
        print(f"ERROR: 입력 파일을 찾을 수 없습니다: {in_path}", file=sys.stderr)
        sys.exit(1)
    if out_path.exists() and in_path.samefile(out_path):
        print(f"ERROR: 입력과 출력 경로가 같습니다: {in_path}", file=sys.stderr)
        sys.exit(1)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    stats = _filter_file(in_path, out_path)
    _print_report(stats, in_path, out_path)
|
|
|
|
if __name__ == "__main__":
    # Run the filter only when executed as a script, never on import.
    main()
|
|