| |
| """ |
| Fix batch5 by correctly converting environment role to observation. |
| """ |
|
|
| import json |
| from datasets import load_dataset |
| from tqdm import tqdm |
|
|
def convert_to_llamafactory_format(sample):
    """
    Convert one Dolci sample into a LlamaFactory conversation record.

    Dolci format (``sample['messages']``), one dict per message:
      - role: system/user/assistant/environment
      - content: text content
      - function_calls: function call string (in assistant messages)
      - functions: available functions JSON string (in system message)

    LlamaFactory format (``conversations``):
      - from: human/gpt/function_call/observation/system
      - value: text or JSON

    Mapping applied here:
      - system: not emitted as a turn. The last non-empty ``content`` becomes
        ``result['system']`` and the first non-empty ``functions`` value
        becomes ``result['tools']``.
      - user -> human
      - assistant with ``function_calls`` -> function_call (any content on
        the same message is dropped); assistant with only content -> gpt;
        an assistant message with neither is skipped.
      - environment -> observation
      - any other role is ignored.

    Args:
        sample: mapping with an optional ``'messages'`` list of dicts.

    Returns:
        dict with ``'conversations'`` plus ``'system'`` / ``'tools'`` when
        those were found.
    """
    conversations = []
    tools = None
    system_prompt = None

    # `or []` also covers an explicit None for the messages field.
    for msg in sample.get('messages') or []:
        role = msg.get('role', '')
        # `or ''` instead of a .get() default: rows coming from a datasets
        # struct column can carry an explicit None for an absent field, which
        # .get('content', '') would pass straight through into the output.
        content = msg.get('content') or ''
        function_calls = msg.get('function_calls')
        functions = msg.get('functions')

        if role == 'system':
            # Tool schema: keep the first one seen.
            if functions and not tools:
                tools = functions
            # System prompt: a later non-empty one overrides an earlier one.
            if content:
                system_prompt = content
            continue

        if role == 'user':
            conversations.append({'from': 'human', 'value': content})
        elif role == 'assistant':
            if function_calls:
                # A tool call takes precedence over any text on the message.
                conversations.append({'from': 'function_call', 'value': function_calls})
            elif content:
                conversations.append({'from': 'gpt', 'value': content})
        elif role == 'environment':
            # Tool output returned to the model.
            conversations.append({'from': 'observation', 'value': content})

    result = {'conversations': conversations}
    if system_prompt:
        result['system'] = system_prompt
    if tools:
        result['tools'] = tools
    return result
|
|
def get_sample_hash(sample):
    """Return a duplicate-detection key for *sample*.

    The key is ``hash()`` of the content of the FIRST message whose role is
    ``'user'``; later messages do not contribute. Returns ``None`` if the
    sample has no user message. Note that ``hash()`` of a ``str`` is salted
    per interpreter process, so keys are only comparable within one run.
    """
    first_user_msg = next(
        (entry for entry in sample.get('messages', []) if entry.get('role') == 'user'),
        None,
    )
    if first_user_msg is None:
        return None
    return hash(first_user_msg.get('content', ''))
|
|
def has_tool_calling(messages):
    """Return True if any message carries a truthy ``function_calls`` field."""
    return any(message.get('function_calls') for message in messages)
|
|
def _load_existing_hashes(batch_nums):
    """Return dedup keys (hash of the first 'human' turn) from earlier batch files.

    Missing batch files are reported and skipped; other errors propagate.
    """
    existing_hashes = set()
    for batch_num in batch_nums:
        batch_file = f"data/dolci_10k_with_tool_call_batch{batch_num}.json"
        try:
            with open(batch_file, 'r', encoding='utf-8') as f:
                batch_data = json.load(f)
        except FileNotFoundError:
            print(f" Warning: {batch_file} not found, skipping...")
            continue
        for sample in batch_data:
            # Same key as get_sample_hash(): only the first human turn counts.
            for conv in sample.get('conversations', []):
                if conv.get('from') == 'human':
                    existing_hashes.add(hash(conv.get('value', '')))
                    break
        print(f" Loaded batch{batch_num}: {len(batch_data)} samples")
    return existing_hashes


def _collect_new_tool_samples(samples, existing_hashes):
    """Convert samples that call a tool, are unseen, and keep both tool turns.

    A sample is kept only if its converted conversation contains at least one
    'function_call' and one 'observation' turn. Keys of kept samples are
    tracked so two rows sharing a first user prompt are not both selected.
    """
    seen = set(existing_hashes)  # copy: don't mutate the caller's set
    collected = []
    for sample in tqdm(samples, desc="Filtering tool calling samples"):
        if not has_tool_calling(sample.get('messages', [])):
            continue
        sample_hash = get_sample_hash(sample)
        if sample_hash in seen:
            continue
        converted = convert_to_llamafactory_format(sample)
        roles = {c['from'] for c in converted.get('conversations', [])}
        if 'function_call' in roles and 'observation' in roles:
            collected.append(converted)
            # Samples without a user message have no key; never dedup them.
            if sample_hash is not None:
                seen.add(sample_hash)
    return collected


def _print_format_summary(samples):
    """Print the most common role sequences among the first 100 samples and one example."""
    print("\n=== Verifying format ===")
    role_patterns = {}
    for sample in samples[:100]:
        pattern = ' -> '.join(c['from'] for c in sample['conversations'])
        role_patterns[pattern] = role_patterns.get(pattern, 0) + 1

    print("Top patterns in first 100 samples:")
    for pattern, count in sorted(role_patterns.items(), key=lambda x: -x[1])[:5]:
        print(f" [{count:3d}] {pattern}")

    if samples:
        print("\nSample entry:")
        print(json.dumps(samples[0], ensure_ascii=False, indent=2)[:1000] + "...")


def main():
    """Rebuild batch5 of the Dolci tool-calling subset in LlamaFactory format.

    Steps:
      1. Load the ``train`` split of allenai/Dolci-Instruct-SFT-Tool-Use.
      2. Collect dedup keys from data/dolci_10k_with_tool_call_batch{1..4}.json.
      3. From the last 20,000 rows, keep unseen tool-calling samples whose
         converted form has both function_call and observation turns.
      4. Write at most 10,000 of them to
         data/dolci_10k_with_tool_call_batch5.json and print a format summary.
    """
    print("Loading allenai/Dolci-Instruct-SFT-Tool-Use dataset...")
    dataset = load_dataset("allenai/Dolci-Instruct-SFT-Tool-Use", split="train")
    total_samples = len(dataset)
    print(f"Total samples in dataset: {total_samples}")

    print("\nLoading existing batches to avoid duplicates...")
    existing_hashes = _load_existing_hashes(range(1, 5))
    print(f"Total existing samples to avoid: {len(existing_hashes)}")

    start_idx = max(0, total_samples - 20000)
    last_20k = dataset.select(range(start_idx, total_samples))
    print(f"\nProcessing last 20k samples (from index {start_idx} to {total_samples})")

    tool_calling_samples = _collect_new_tool_samples(last_20k, existing_hashes)
    print(f"\nFound {len(tool_calling_samples)} NEW tool calling samples with proper format")

    if len(tool_calling_samples) > 10000:
        selected_samples = tool_calling_samples[:10000]
        print("Selected first 10,000 samples for batch5")
    else:
        selected_samples = tool_calling_samples
        print(f"Using all {len(selected_samples)} samples for batch5")

    if not selected_samples:
        print("\n❌ No new tool calling samples found!")
        return

    output_file = "data/dolci_10k_with_tool_call_batch5.json"
    print(f"\nSaving to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, ensure_ascii=False, indent=2)
    print(f"✓ Successfully created batch5 with {len(selected_samples)} samples")

    _print_format_summary(selected_samples)
|
|
# Run the batch5 rebuild only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|