{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4", "name": "Fine-tune TinyLlama-1.1B for Web Search Tool Calling" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-Tune TinyLlama-1.1B-Chat for Web Search Tool Calling\n", "\n", "## What This Notebook Does\n", "\n", "Fine-tunes **TinyLlama/TinyLlama-1.1B-Chat-v1.0** so it can **call a web search tool** when it needs up-to-date information, and **answer directly** when it doesn't.\n", "\n", "| Feature | Detail |\n", "|---|---|\n", "| Web Search | Model outputs a structured tool call; your app calls Google Custom Search API and feeds results back |\n", "| Custom System Prompt | Set any system message / persona at inference time (no retraining needed) |\n", "| Custom Model Name | Configurable in settings |\n", "| No Reasoning | Zero chain-of-thought in training data — model stays tiny and fast |\n", "\n", "## Approach (Research-Backed)\n", "\n", "This notebook implements **two complementary training strategies** based on current research:\n", "\n", "1. **xLAM Function-Calling Format** — Uses the same data schema as [Salesforce xLAM](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k) (the industry-standard function-calling dataset). Each example includes:\n", " - A list of available tools with their JSON schemas\n", " - User query\n", " - Model's tool call response (JSON with `name` + `arguments`)\n", " - Tool observation (search results)\n", " - Final answer synthesized from results\n", "\n", "2. 
**LlamaFactory Agent Tuning Format** — Compatible with [LlamaFactory](https://github.com/hiyouga/LlamaFactory) ShareGPT format (`glaive_toolcall_en`), the most popular open-source tool-use dataset.\n", "\n", "### References\n", "- [HuggingFace Cookbook: Fine-tuning LLMs for Function Calling with xLAM](https://huggingface.co/learn/cookbook/en/function_calling_fine_tuning_llms_on_xlam)\n", "- [Microsoft: Fine-Tuning SLMs for Function Calling](https://techcommunity.microsoft.com/blog/azure-ai-foundry-blog/fine-tuning-small-language-models-for-function-calling-a-comprehensive-guide/4362539)\n", "- [TinyAgent: Function Calling at the Edge (Berkeley)](https://bair.berkeley.edu/blog/2024/05/29/tiny-agent)\n", "- [Gorilla: Large Language Model Connected with Massive APIs](https://arxiv.org/abs/2305.15334)\n", "- [LlamaFactory Agent Tuning Docs](https://llamafactory.readthedocs.io/en/latest/getting_started/data_preparation.html)\n", "\n", "### Hardware\n", "- **Free Colab T4** (16 GB VRAM) — works with QLoRA 4-bit quantization\n", "- **Colab Pro A100** — works and trains faster\n", "\n", "### How It Works After Training\n", "```\n", "User: Who won the 2024 Super Bowl?\n", "Model → Action: call web_search(query=\"2024 Super Bowl winner\")\n", " ↓ your app calls Google Custom Search API\n", " ↓ feeds results back to model\n", "Model → Answer: The Kansas City Chiefs won Super Bowl LVIII...\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 1 — Configuration\n", "Edit these before running." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# @title ⚙️ Settings\n", "MODEL_NAME = \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\"\n", "OUTPUT_DIR = \"tinyllama-websearch\"\n", "HF_USERNAME = \"your-username\" # for pushing to HF Hub (leave blank to skip)\n", "HF_REPO = f\"{HF_USERNAME}/tinyllama-websearch\" if HF_USERNAME else None\n", "\n", "# --- Custom model identity ---\n", "CUSTOM_SYSTEM_PROMPT = (\n", " \"You are WebSearchLlama, a helpful assistant. 
\"\n", " \"When you need up-to-date or real-time information that you cannot confidently answer, \"\n", " \"you MUST call the web_search tool. \"\n", " \"For general knowledge questions, math, or facts you are sure about, answer directly without calling any tool. \"\n", " \"Answer concisely. Do not explain your reasoning.\"\n", ")\n", "CUSTOM_MODEL_NAME = \"WebSearchLlama-1.1B\"\n", "\n", "# --- Training hyper-parameters ---\n", "NUM_EPOCHS = 3\n", "BATCH_SIZE = 4 # per-device micro-batch\n", "GRAD_ACCUM = 4 # effective batch = 4 × 4 = 16\n", "LEARNING_RATE = 2e-4\n", "MAX_SEQ_LEN = 1024 # keeps model small + fast\n", "LORA_R = 16\n", "LORA_ALPHA = 32\n", "LORA_DROPOUT = 0.05\n", "\n", "# --- Google Custom Search API (for inference demo) ---\n", "GOOGLE_API_KEY = \"\" # get from https://console.cloud.google.com/apis/credentials\n", "GOOGLE_CX_ID = \"\" # get from https://programmablesearchengine.google.com\n", "\n", "print(\"✅ Settings loaded\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 2 — Install Dependencies" ] }, { "cell_type": "code", "metadata": {}, "source": [ "!pip install -q --upgrade transformers datasets accelerate peft trl bitsandbytes sentencepiece protobuf\n", "!pip install -q huggingface_hub evaluate bert_score\n", "print(\"✅ All packages installed\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 3 — HuggingFace Login (optional)\n", "Skip if you only want to save locally." 
] }, { "cell_type": "code", "metadata": {}, "source": [ "# from huggingface_hub import notebook_login\n", "# notebook_login()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 4 — Training Data: Understanding Tool-Calling Formats\n", "\n", "Based on our research, there are **two main formats** for tool-calling training data:\n", "\n", "### Format A: xLAM (Salesforce) — Industry Standard\n", "Used by HuggingFace cookbook, Microsoft, and the Berkeley Function Calling Leaderboard.\n", "\n", "```json\n", "{\n", " \"tools\": [{\"name\": \"web_search\", \"description\": \"...\", \"parameters\": {...}}],\n", " \"query\": \"What is the latest news about AI?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"latest AI news\"}},\n", " \"answer\": \"Based on search results: \"\n", "}\n", "```\n", "\n", "### Format B: LlamaFactory ShareGPT — Most Popular Open Source\n", "Used by LlamaFactory's built-in agent tuning with `glaive_toolcall_en` dataset.\n", "\n", "```json\n", "[\n", " {\"from\": \"system\", \"value\": \"You are a helpful assistant with access to tools...\"},\n", " {\"from\": \"human\", \"value\": \"What is the weather in Tokyo?\"},\n", " {\"from\": \"function_call\", \"value\": \"{\\\"name\\\": \\\"web_search\\\", \\\"arguments\\\": {\\\"query\\\": \\\"Tokyo weather today\\\"}}\"},\n", " {\"from\": \"observation\", \"value\": \"[{\\\"title\\\": ..., \\\"snippet\\\": ...}]\"},\n", " {\"from\": \"gpt\", \"value\": \"Tokyo is currently 18°C with clear skies.\"}\n", "]\n", "```\n", "\n", "**This notebook uses Format A (xLAM)** because:\n", "1. It includes tool schemas in every example → model learns when/how to call tools\n", "2. Data is verified through 3 stages (format check, execution, semantic verification)\n", "3. 
Compatible with the full xLAM dataset (60k examples) for scaling up training\n", "\n", "We also provide a Format B converter in Section 4D if you prefer LlamaFactory." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "### 4A — Tool Definition (the tool the model will learn to call)" ] }, { "cell_type": "code", "metadata": {}, "source": [ "import json\n", "\n", "# ==========================================================\n", "# This is the tool schema the model learns from.\n", "# It matches OpenAI's function calling format — the industry standard.\n", "# Source: https://platform.openai.com/docs/guides/function-calling\n", "# ==========================================================\n", "\n", "WEB_SEARCH_TOOL = {\n", " \"type\": \"function\",\n", " \"function\": {\n", " \"name\": \"web_search\",\n", " \"description\": (\n", " \"Search the web using Google to find up-to-date information. \"\n", " \"Use this tool when you need real-time data, current events, \"\n", " \"recent news, live scores, stock prices, weather, or any information \"\n", " \"that may change frequently. Do NOT use for general knowledge, \"\n", " \"math, definitions, or historical facts you already know.\"\n", " ),\n", " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\n", " \"query\": {\n", " \"type\": \"string\",\n", " \"description\": \"The search query to send to Google. Be specific and concise.\"\n", " }\n", " },\n", " \"required\": [\"query\"]\n", " }\n", " }\n", "}\n", "\n", "print(json.dumps(WEB_SEARCH_TOOL, indent=2))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "### 4B — Build Training Dataset (xLAM Format)\n", "\n", "We create examples following the xLAM schema. 
Each example has:\n", "- `tools`: the tool schema (so the model learns the API shape)\n", "- `query`: the user's question\n", "- `function_call`: the model's tool call output\n", "- `observations`: the search results returned by the tool\n", "- `answer`: the final answer (after receiving tool results)\n", "\n", "We also include **no-tool examples** (general knowledge, math) so the model learns when NOT to call the tool." ] }, { "cell_type": "code", "metadata": {}, "source": [ "from datasets import Dataset, DatasetDict\n", "import random\n", "\n", "# ==========================================================\n", "# xLAM-format training examples\n", "# Reference: Salesforce/xlam-function-calling-60k\n", "# ==========================================================\n", "\n", "TOOL_CALL_EXAMPLES = [\n", "    # ===== Current Events / News =====\n", "    {\n", "        \"tools\": [WEB_SEARCH_TOOL],\n", "        \"query\": \"Who won the 2024 Super Bowl?\",\n", "        \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"2024 Super Bowl winner\"}},\n", "        \"observations\": [\n", "            {\"title\": \"Super Bowl LVIII - Wikipedia\", \"snippet\": \"The Kansas City Chiefs defeated the San Francisco 49ers 25-22 in overtime on February 11, 2024.\"},\n", "            {\"title\": \"NFL.com Recap\", \"snippet\": \"Patrick Mahomes led the Chiefs to their third Super Bowl title in five seasons with a 25-22 overtime victory.\"}\n", "        ],\n", "        \"answer\": \"The Kansas City Chiefs won Super Bowl LVIII on February 11, 2024, defeating the San Francisco 49ers 25-22 in overtime. Patrick Mahomes led the Chiefs to their third Super Bowl title in five seasons.\"\n", "    },\n", "    {\n", "        \"tools\": [WEB_SEARCH_TOOL],\n", "        \"query\": \"What are the top news headlines today?\",\n", "        \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"top news headlines today\"}},\n", "        \"observations\": [\n", "            {\"title\": \"BBC News\", \"snippet\": \"Global leaders meet for climate summit in Geneva. 
Tech stocks surge on AI breakthrough.\"},\n", " {\"title\": \"Reuters\", \"snippet\": \"UN Security Council votes on new resolution. Major earthquake hits Turkey, no casualties reported.\"}\n", " ],\n", " \"answer\": \"Top headlines today: 1) Global leaders gather in Geneva for a climate summit. 2) Tech stocks surge following an AI breakthrough. 3) UN Security Council votes on a new resolution. 4) Major earthquake hits Turkey with no reported casualties.\"\n", " },\n", "\n", " # ===== Real-Time Data =====\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What is the price of Bitcoin right now?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"Bitcoin price today USD\"}},\n", " \"observations\": [\n", " {\"title\": \"CoinGecko\", \"snippet\": \"Bitcoin BTC price is $67,234.50 as of today.\"},\n", " {\"title\": \"CoinMarketCap\", \"snippet\": \"BTC is trading at $67,200 with a 2.3% change in 24 hours.\"}\n", " ],\n", " \"answer\": \"Bitcoin is currently priced at approximately $67,234.50 USD, with a 2.3% change in the last 24 hours.\"\n", " },\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What is Apple's stock price today?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"Apple AAPL stock price today\"}},\n", " \"observations\": [\n", " {\"title\": \"Yahoo Finance AAPL\", \"snippet\": \"Apple Inc. (AAPL) - $189.84, +1.23 (+0.65%).\"},\n", " {\"title\": \"Google Finance\", \"snippet\": \"AAPL stock price $189.84, market cap $2.95T.\"}\n", " ],\n", " \"answer\": \"Apple (AAPL) is trading at $189.84, up $1.23 (+0.65%). 
Market cap is approximately $2.95 trillion.\"\n", " },\n", "\n", " # ===== Weather =====\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What is the weather in New York right now?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"New York weather today current conditions\"}},\n", " \"observations\": [\n", " {\"title\": \"Weather.com NYC\", \"snippet\": \"Currently 72°F, Partly Cloudy. High of 78°F, Low of 65°F.\"},\n", " {\"title\": \"AccuWeather NYC\", \"snippet\": \"New York City - Partly cloudy, 71°F, humidity 58%.\"}\n", " ],\n", " \"answer\": \"New York City is currently 72°F (22°C) and partly cloudy, with a high of 78°F and a low of 65°F today. Humidity is around 58%.\"\n", " },\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"Will it rain in London tomorrow?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"London weather forecast tomorrow rain\"}},\n", " \"observations\": [\n", " {\"title\": \"Met Office London\", \"snippet\": \"Tomorrow: Cloudy with periods of rain, 14°C. Rain probability 80%.\"},\n", " {\"title\": \"BBC Weather London\", \"snippet\": \"Tomorrow's forecast: Overcast with heavy rain expected from midday, 13-15°C.\"}\n", " ],\n", " \"answer\": \"Yes, rain is expected in London tomorrow. The forecast shows cloudy skies with periods of rain, temperatures around 13-15°C, and an 80% probability of rain.\"\n", " },\n", "\n", " # ===== Sports / Live Scores =====\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What were yesterday's Champions League results?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"Champions League results yesterday scores\"}},\n", " \"observations\": [\n", " {\"title\": \"UEFA Results\", \"snippet\": \"Real Madrid 3-1 Bayern Munich. 
PSG 2-0 Dortmund.\"},\n", " {\"title\": \"ESPN Soccer\", \"snippet\": \"Champions League semi-final: Real Madrid beat Bayern Munich 3-1, PSG defeated Dortmund 2-0.\"}\n", " ],\n", " \"answer\": \"Yesterday's Champions League results: Real Madrid defeated Bayern Munich 3-1, and PSG beat Dortmund 2-0.\"\n", " },\n", "\n", " # ===== Population / Statistics =====\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What is the current population of Tokyo?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"Tokyo population 2024\"}},\n", " \"observations\": [\n", " {\"title\": \"Tokyo Population 2024\", \"snippet\": \"Tokyo's population is approximately 13.96 million as of 2024.\"},\n", " {\"title\": \"Japan Statistics\", \"snippet\": \"Greater Tokyo Area has 37.4 million residents.\"}\n", " ],\n", " \"answer\": \"Tokyo's population is approximately 13.96 million (2024). The Greater Tokyo Area, the world's most populous metropolitan area, has about 37.4 million residents.\"\n", " },\n", "\n", " # ===== Upcoming Events =====\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"When is the next solar eclipse?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"next solar eclipse date 2025\"}},\n", " \"observations\": [\n", " {\"title\": \"NASA Solar Eclipses\", \"snippet\": \"The next total solar eclipse is March 29, 2025, visible across North Africa and Europe.\"},\n", " {\"title\": \"TimeandDate\", \"snippet\": \"Annular solar eclipse on October 2, 2025, visible in South America.\"}\n", " ],\n", " \"answer\": \"The next total solar eclipse is on March 29, 2025, visible across North Africa and Europe. 
There will also be an annular solar eclipse on October 2, 2025, visible in South America.\"\n", " },\n", "\n", " # ===== Product Reviews =====\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"Is the iPhone 15 Pro worth buying?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"iPhone 15 Pro review worth buying 2024\"}},\n", " \"observations\": [\n", " {\"title\": \"The Verge Review\", \"snippet\": \"Excellent camera, titanium design, USB-C. Score 9/10. Best iPhone yet.\"},\n", " {\"title\": \"CNET Review\", \"snippet\": \"Significant camera upgrades, lighter titanium build. Recommended.\"}\n", " ],\n", " \"answer\": \"Reviews are very positive. The iPhone 15 Pro scores 9/10 from The Verge and is recommended by CNET. Key highlights: excellent camera system, lighter titanium design, USB-C. Considered the best iPhone yet.\"\n", " },\n", "\n", " # ===== Tech / Software Versions =====\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What is the latest version of Python?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"latest Python version release 2024\"}},\n", " \"observations\": [\n", " {\"title\": \"Python.org Downloads\", \"snippet\": \"Python 3.12.3 is the latest stable release as of April 2024.\"},\n", " {\"title\": \"Real Python News\", \"snippet\": \"Python 3.13.0 alpha 1 available for testing.\"}\n", " ],\n", " \"answer\": \"The latest stable Python release is 3.12.3 (as of April 2024). 
Python 3.13.0 is also available as an alpha for early testing.\"\n", " },\n", "\n", " # ===== Health / Medical =====\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What are the current COVID-19 symptoms?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"COVID-19 symptoms 2024 latest variants\"}},\n", " \"observations\": [\n", " {\"title\": \"CDC Symptoms\", \"snippet\": \"Common symptoms include fever, cough, fatigue, sore throat, and new loss of taste or smell.\"},\n", " {\"title\": \"WHO Update\", \"snippet\": \"JN.1 variant symptoms are similar: fever, cough, fatigue, congestion, headache.\"}\n", " ],\n", " \"answer\": \"Current COVID-19 symptoms (including JN.1 variant) include fever, cough, fatigue, sore throat, congestion, headache, and in some cases loss of taste or smell. These are similar to earlier variants.\"\n", " },\n", "\n", " # ===== Travel / Visas =====\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"Do I need a visa to travel to Japan from the US?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"US citizen visa requirement Japan travel 2024\"}},\n", " \"observations\": [\n", " {\"title\": \"US State Department Japan\", \"snippet\": \"US citizens do not need a visa for tourist stays up to 90 days.\"},\n", " {\"title\": \"Japan Embassy\", \"snippet\": \"Visa-free entry for 68 countries including the US, for stays under 90 days.\"}\n", " ],\n", " \"answer\": \"US citizens do not need a visa to travel to Japan for tourism. Visa-free entry is allowed for stays up to 90 days.\"\n", " },\n", "\n", " # ===== Movies / Entertainment =====\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What are the top movies in theaters right now?\",\n", " \"function_call\": {\"name\": \"web_search\", \"arguments\": {\"query\": \"top movies in theaters now box office 2024\"}},\n", " \"observations\": [\n", " {\"title\": \"Box Office Mojo\", \"snippet\": \"1. 
Dune: Part Two $182M 2. Kung Fu Panda 4 $95M 3. Godzilla x Kong $110M.\"},\n", " {\"title\": \"IMDb Box Office\", \"snippet\": \"Current top films: Dune Part Two, Kung Fu Panda 4, Godzilla x Kong: The New Empire.\"}\n", " ],\n", " \"answer\": \"Current top movies: 1) Dune: Part Two ($182M) 2) Godzilla x Kong: The New Empire ($110M) 3) Kung Fu Panda 4 ($95M)\"\n", " },\n", "]\n", "\n", "# ==========================================================\n", "# No-tool examples — model should answer directly\n", "# ==========================================================\n", "NO_TOOL_EXAMPLES = [\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What is 15% of 200?\",\n", " \"function_call\": None,\n", " \"observations\": None,\n", " \"answer\": \"15% of 200 is 30.\"\n", " },\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"Convert 100 kilometers to miles.\",\n", " \"function_call\": None,\n", " \"observations\": None,\n", " \"answer\": \"100 kilometers is approximately 62.14 miles.\"\n", " },\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What is the capital of France?\",\n", " \"function_call\": None,\n", " \"observations\": None,\n", " \"answer\": \"The capital of France is Paris.\"\n", " },\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"How many planets are in our solar system?\",\n", " \"function_call\": None,\n", " \"observations\": None,\n", " \"answer\": \"There are 8 planets in our solar system: Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune.\"\n", " },\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"What is 2 + 2?\",\n", " \"function_call\": None,\n", " \"observations\": None,\n", " \"answer\": \"2 + 2 = 4\"\n", " },\n", " {\n", " \"tools\": [WEB_SEARCH_TOOL],\n", " \"query\": \"Explain gravity in one sentence.\",\n", " \"function_call\": None,\n", " \"observations\": None,\n", " \"answer\": \"Gravity is a fundamental force that attracts two bodies toward each other, 
proportional to the product of their masses and inversely proportional to the square of the distance between them.\"\n", "    },\n", "    {\n", "        \"tools\": [WEB_SEARCH_TOOL],\n", "        \"query\": \"What color is the sky on a clear day?\",\n", "        \"function_call\": None,\n", "        \"observations\": None,\n", "        \"answer\": \"The sky appears blue on a clear day due to Rayleigh scattering — sunlight is scattered by air molecules, and blue light is scattered more than other colors.\"\n", "    },\n", "    {\n", "        \"tools\": [WEB_SEARCH_TOOL],\n", "        \"query\": \"Who wrote Romeo and Juliet?\",\n", "        \"function_call\": None,\n", "        \"observations\": None,\n", "        \"answer\": \"Romeo and Juliet was written by William Shakespeare, likely around 1594-1596.\"\n", "    },\n", "]\n", "\n", "ALL_EXAMPLES = TOOL_CALL_EXAMPLES + NO_TOOL_EXAMPLES\n", "print(f\"✅ {len(TOOL_CALL_EXAMPLES)} tool-call + {len(NO_TOOL_EXAMPLES)} no-tool = {len(ALL_EXAMPLES)} total examples\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "### 4C — Convert to TinyLlama Chat Format for Training\n", "\n", "Following the approach from the HuggingFace cookbook and Microsoft's guide, we\n", "convert each xLAM example into the model's native chat template with special\n", "markers for tool calls and observations." ] }, { "cell_type": "code", "metadata": {}, "source": [ "def format_example_xlam(ex: dict, system: str) -> str:\n", "    \"\"\"Convert one xLAM example into TinyLlama chat format.\n", "\n", "    Format (with tool call):\n", "        <|system|>\\n{system}\n", "        <|tools|>\\n{tool_schemas}\n", "        <|user|>\\n{query}\n", "        <|assistant|>\\n {json_function_call} \n", "        <|observation|>\\n{observations_json}\n", "        <|assistant|>\\n{answer}\n", "\n", "    Format (no tool call):\n", "        <|system|>\\n{system}\n", "        <|tools|>\\n{tool_schemas}\n", "        <|user|>\\n{query}\n", "        <|assistant|>\\n{answer}\n", "\n", "    This follows the Toolformer-style approach (Schick et al. 2023):\n", "    Tools are encoded as special tokens so the model learns their structure.\n", "    Reference: https://arxiv.org/abs/2302.04761\n", "    \"\"\"\n", "    # Build tool schema string\n", "    tools_str = json.dumps(ex[\"tools\"], indent=2, ensure_ascii=False)\n", "\n", "    parts = []\n", "    parts.append(f\"<|system|>\\n{system}\")\n", "    parts.append(f\"<|tools|>\\n{tools_str}\")\n", "    parts.append(f\"<|user|>\\n{ex['query']}\")\n", "\n", "    if ex[\"function_call\"] is not None:\n", "        # Tool call\n", "        fc = ex[\"function_call\"]\n", "        action_str = json.dumps(fc, ensure_ascii=False)\n", "        parts.append(f\"<|assistant|>\\n {action_str} \")\n", "        # Observation\n", "        obs_str = json.dumps(ex[\"observations\"], indent=2, ensure_ascii=False)\n", "        parts.append(f\"<|observation|>\\n{obs_str}\")\n", "        # Final answer\n", "        parts.append(f\"<|assistant|>\\n{ex['answer']}\")\n", "    else:\n", "        # Direct answer, no tool call\n", "        parts.append(f\"<|assistant|>\\n{ex['answer']}\")\n", "\n", "    return \"\\n\".join(parts)\n", "\n", "\n", "# Build dataset\n", "rows = [{\"text\": format_example_xlam(ex, CUSTOM_SYSTEM_PROMPT)} for ex in ALL_EXAMPLES]\n", "train_ds = Dataset.from_list(rows)\n", "\n", "print(f\"✅ Dataset built — {len(train_ds)} examples\")\n", "print(f\"\\n{'='*60}\")\n", "print(\"--- Sample: tool-call example ---\")\n", "print(f\"{'='*60}\")\n", "print(train_ds[0][\"text\"][:800])\n", "print(f\"\\n{'='*60}\")\n", "print(\"--- Sample: no-tool example ---\")\n", "print(f\"{'='*60}\")\n", "print(train_ds[len(TOOL_CALL_EXAMPLES)][\"text\"])" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "### 4D — Optional: Convert to LlamaFactory ShareGPT Format\n", "\n", "If you prefer to use **LlamaFactory** (a widely used open-source fine-tuning\n", "framework with built-in agent tuning), run the cell below to export your\n", "data in ShareGPT format. 
Then use it with:\n", "\n", "```bash\n", "llamafactory-cli train --stage sft --dataset my_websearch_data \\\n", "  --model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \\\n", "  --template llama3 --finetuning_type lora\n", "```" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# ==========================================================\n", "# Convert to LlamaFactory ShareGPT format (optional export)\n", "# Reference: https://llamafactory.readthedocs.io/en/latest/getting_started/data_preparation.html\n", "# Compatible dataset: glaive_toolcall_en\n", "# ==========================================================\n", "\n", "import os\n", "\n", "def to_llamafactory_sharegpt(examples, system_prompt):\n", "    \"\"\"Convert examples to LlamaFactory ShareGPT format.\n", "    Roles: system, human, function_call, observation, gpt\n", "    \"\"\"\n", "    sharegpt_data = []\n", "    for ex in examples:\n", "        conv = [{\"from\": \"system\", \"value\": system_prompt}]\n", "        conv.append({\"from\": \"human\", \"value\": ex[\"query\"]})\n", "\n", "        if ex[\"function_call\"]:\n", "            fc_json = json.dumps(ex[\"function_call\"], ensure_ascii=False)\n", "            conv.append({\"from\": \"function_call\", \"value\": fc_json})\n", "            obs_json = json.dumps(ex[\"observations\"], ensure_ascii=False)\n", "            conv.append({\"from\": \"observation\", \"value\": obs_json})\n", "\n", "        conv.append({\"from\": \"gpt\", \"value\": ex[\"answer\"]})\n", "        sharegpt_data.append({\"conversations\": conv})\n", "\n", "    return sharegpt_data\n", "\n", "# Export (create OUTPUT_DIR first; otherwise it is only created in Section 7)\n", "os.makedirs(OUTPUT_DIR, exist_ok=True)\n", "sharegpt_data = to_llamafactory_sharegpt(ALL_EXAMPLES, CUSTOM_SYSTEM_PROMPT)\n", "with open(f\"{OUTPUT_DIR}/dataset_sharegpt.json\", \"w\", encoding=\"utf-8\") as f:\n", "    json.dump(sharegpt_data, f, indent=2, ensure_ascii=False)\n", "\n", "print(f\"✅ Exported {len(sharegpt_data)} examples to {OUTPUT_DIR}/dataset_sharegpt.json\")\n", "print(\"\\n--- Sample ShareGPT entry ---\")\n", "print(json.dumps(sharegpt_data[0], indent=2, ensure_ascii=False)[:600])" ], "execution_count": null, "outputs": [] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "---\n", "### 4E — (Optional) Load the Full xLAM Dataset for More Training Data\n", "\n", "The examples above are a starting set. For production quality, load the full\n", "[Salesforce xLAM-60k dataset](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)\n", "and filter for web-search-like examples." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# ==========================================================\n", "# OPTIONAL: Load the full xLAM dataset (60k verified examples)\n", "# Uncomment to use massive training data instead of synthetic examples.\n", "#\n", "# Reference: https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k\n", "# Reference: https://huggingface.co/learn/cookbook/en/function_calling_fine_tuning_llms_on_xlam\n", "# ==========================================================\n", "\n", "# from datasets import load_dataset\n", "#\n", "# xlam_ds = load_dataset(\"Salesforce/xlam-function-calling-60k\", split=\"train\")\n", "#\n", "# # Filter for examples that use a search-like tool\n", "# def is_search_example(ex):\n", "# tools = json.loads(ex[\"tools\"]) if isinstance(ex[\"tools\"], str) else ex[\"tools\"]\n", "# for t in tools:\n", "# name = t.get(\"name\", \"\") if isinstance(t, dict) else str(t)\n", "# if \"search\" in name.lower():\n", "# return True\n", "# return False\n", "#\n", "# search_examples = xlam_ds.filter(is_search_example)\n", "# print(f\"Found {len(search_examples)} search-related examples from xLAM-60k\")\n", "\n", "print(\"✅ (Skipped) Uncomment above to load full xLAM dataset\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 5 — Load Model & Tokenizer (QLoRA 4-bit)" ] }, { "cell_type": "code", "metadata": {}, "source": [ "import torch\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n", "from peft import LoraConfig, get_peft_model, 
prepare_model_for_kbit_training, TaskType\n", "\n", "# --- 4-bit quantization (NF4) ---\n", "# Reference: QLoRA paper — Dettmers et al. 2023\n", "# Enables training 7B+ models on a single GPU\n", "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_use_double_quant=True, # nested quantization for extra savings\n", " bnb_4bit_quant_type=\"nf4\", # normal-float 4\n", " bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,\n", ")\n", "\n", "# --- Tokenizer ---\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n", "tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.padding_side = \"right\"\n", "\n", "# Add special tokens for tool calling (Toolformer-style: tools as special tokens)\n", "# Reference: https://arxiv.org/abs/2302.04761 (Toolformer paper)\n", "SPECIAL_TOKENS = [\n", " \"<|tools|>\", # marks tool schema section\n", " \"<|observation|>\", # marks tool output section\n", " \"\", # marks tool call start\n", " \"\", # marks tool call end\n", "]\n", "num_added = tokenizer.add_special_tokens({\"additional_special_tokens\": SPECIAL_TOKENS})\n", "print(f\"Added {num_added} special tokens for tool calling\")\n", "\n", "# --- Model ---\n", "model = AutoModelForCausalLM.from_pretrained(\n", " MODEL_NAME,\n", " quantization_config=bnb_config,\n", " device_map=\"auto\",\n", " trust_remote_code=True,\n", ")\n", "model.resize_token_embeddings(len(tokenizer)) # accommodate new tokens\n", "model.config.pad_token_id = tokenizer.pad_token_id\n", "\n", "# --- LoRA ---\n", "# Following Fireworks.ai best practices for function-calling fine-tuning:\n", "# Apply LoRA to all attention + MLP layers for maximum adaptability\n", "# Reference: https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-019\n", "lora_config = LoraConfig(\n", " task_type=TaskType.CAUSAL_LM,\n", " r=LORA_R,\n", " lora_alpha=LORA_ALPHA,\n", " lora_dropout=LORA_DROPOUT,\n", " 
target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n", " \"gate_proj\", \"up_proj\", \"down_proj\"],\n", " bias=\"none\",\n", ")\n", "\n", "model = prepare_model_for_kbit_training(model)\n", "model = get_peft_model(model, lora_config)\n", "model.print_trainable_parameters()\n", "print(\"✅ Model loaded with QLoRA\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 6 — Tokenize Dataset" ] }, { "cell_type": "code", "metadata": {}, "source": [ "def tokenize_fn(example):\n", " result = tokenizer(\n", " example[\"text\"],\n", " truncation=True,\n", " max_length=MAX_SEQ_LEN,\n", " padding=False,\n", " )\n", " result[\"labels\"] = result[\"input_ids\"].copy()\n", " return result\n", "\n", "tokenized_ds = train_ds.map(tokenize_fn, batched=False, remove_columns=[\"text\"])\n", "print(f\"✅ Tokenized — {len(tokenized_ds)} rows\")\n", "print(f\"Sample lengths: {[len(x['input_ids']) for x in [tokenized_ds[0], tokenized_ds[5], tokenized_ds[10]]]}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 7 — Train with SFTTrainer" ] }, { "cell_type": "code", "metadata": {}, "source": [ "from transformers import TrainingArguments\n", "from trl import SFTTrainer\n", "import os\n", "\n", "os.makedirs(OUTPUT_DIR, exist_ok=True)\n", "\n", "# --- Training arguments ---\n", "# Following Microsoft's SLM function-calling guide recommendations:\n", "# - cosine LR schedule with warmup\n", "# - gradient checkpointing for memory efficiency\n", "# - 8-bit AdamW optimizer\n", "# Reference: https://techcommunity.microsoft.com/blog/azure-ai-foundry-blog/fine-tuning-small-language-models-for-function-calling-a-comprehensive-guide/4362539\n", "training_args = TrainingArguments(\n", " output_dir=OUTPUT_DIR,\n", " num_train_epochs=NUM_EPOCHS,\n", " per_device_train_batch_size=BATCH_SIZE,\n", " gradient_accumulation_steps=GRAD_ACCUM,\n", " 
learning_rate=LEARNING_RATE,\n", " lr_scheduler_type=\"cosine\",\n", " warmup_ratio=0.05,\n", " # fp16 and bf16 are mutually exclusive: use bf16 when supported, else fp16\n", " fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),\n", " bf16=torch.cuda.is_bf16_supported(),\n", " logging_steps=2,\n", " save_strategy=\"epoch\",\n", " gradient_checkpointing=True,\n", " optim=\"paged_adamw_8bit\",\n", " report_to=\"none\", # set \"wandb\" for Weights & Biases tracking\n", " ddp_find_unused_parameters=False,\n", " max_grad_norm=1.0,\n", ")\n", "\n", "trainer = SFTTrainer(\n", " model=model,\n", " train_dataset=tokenized_ds,\n", " args=training_args,\n", " dataset_text_field=None, # we pre-tokenized\n", " max_seq_length=MAX_SEQ_LEN,\n", " tokenizer=tokenizer,\n", ")\n", "\n", "print(\"✅ Trainer ready — starting training...\")\n", "trainer.train()\n", "print(\"\\n✅ Training complete!\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 8 — Save Adapter" ] }, { "cell_type": "code", "metadata": {}, "source": [ "adapter_path = f\"{OUTPUT_DIR}/adapter\"\n", "trainer.model.save_pretrained(adapter_path)\n", "tokenizer.save_pretrained(adapter_path)\n", "print(f\"✅ Adapter saved to {adapter_path}\")\n", "\n", "if HF_REPO:\n", " trainer.model.push_to_hub(HF_REPO)\n", " tokenizer.push_to_hub(HF_REPO)\n", " print(f\"✅ Pushed to https://huggingface.co/{HF_REPO}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 9 — Merge LoRA into Full Model (Needed for Sections 10–14)" ] }, { "cell_type": "code", "metadata": {}, "source": [ "import gc\n", "\n", "merged_path = f\"{OUTPUT_DIR}/merged\"\n", "\n", "del model, trainer\n", "gc.collect()\n", "torch.cuda.empty_cache()\n", "\n", "base_model = AutoModelForCausalLM.from_pretrained(\n", " MODEL_NAME,\n", " torch_dtype=torch.float16,\n", " device_map=\"auto\",\n", " trust_remote_code=True,\n", ")\n", "\n", "# The tokenizer gained special tokens before training, so resize the base\n", "# embeddings to the same vocabulary size before attaching the adapter.\n", "base_model.resize_token_embeddings(len(tokenizer))\n", "\n", "from peft import PeftModel\n", "merged_model = PeftModel.from_pretrained(base_model, adapter_path)\n", "merged_model = 
merged_model.merge_and_unload()\n", "\n", "merged_model.save_pretrained(merged_path)\n", "tokenizer.save_pretrained(merged_path)\n", "print(f\"✅ Merged model saved to {merged_path}\")\n", "\n", "if HF_REPO:\n", " merged_model.push_to_hub(f\"{HF_REPO}-merged\")\n", " tokenizer.push_to_hub(f\"{HF_REPO}-merged\")\n", " print(f\"✅ Pushed to https://huggingface.co/{HF_REPO}-merged\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 10 — Inference: Test Tool Calling\n", "\n", "Test whether the model correctly:\n", "- **Calls web_search** for questions needing current information\n", "- **Answers directly** for general knowledge / math" ] }, { "cell_type": "code", "metadata": {}, "source": [ "import re\n", "\n", "def generate_reply(prompt: str, system: str = CUSTOM_SYSTEM_PROMPT, max_new: int = 256):\n", " \"\"\"Generate reply. Returns (raw_text, tool_call_or_none).\"\"\"\n", " tools_str = json.dumps([WEB_SEARCH_TOOL], indent=2, ensure_ascii=False)\n", " chat = (\n", " f\"<|system|>\\n{system}\\n\"\n", " f\"<|tools|>\\n{tools_str}\\n\"\n", " f\"<|user|>\\n{prompt}\\n\"\n", " f\"<|assistant|>\\n\"\n", " )\n", " inputs = tokenizer(chat, return_tensors=\"pt\").to(merged_model.device)\n", " with torch.inference_mode():\n", " out = merged_model.generate(\n", " **inputs,\n", " max_new_tokens=max_new,\n", " temperature=0.3,\n", " top_p=0.9,\n", " do_sample=True,\n", " pad_token_id=tokenizer.eos_token_id,\n", " )\n", " decoded = tokenizer.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=False)\n", " decoded = decoded.replace(\"\", \"\").strip()\n", "\n", " # Parse tool call\n", " tool_call = None\n", " m = re.search(r'\\s*(\\{.*?\\})\\s*', decoded, re.DOTALL)\n", " if m:\n", " try:\n", " tool_call = json.loads(m.group(1))\n", " except json.JSONDecodeError:\n", " pass\n", "\n", " return decoded, tool_call\n", "\n", "\n", "# ====== Test questions ======\n", "test_qs = [\n", " (\"What is the latest news 
about AI?\", True),\n", " (\"Who is the current president of the US?\", True),\n", " (\"What is the weather in Tokyo today?\", True),\n", " (\"What is 25 * 4?\", False),\n", " (\"What is the capital of Japan?\", False),\n", " (\"Who wrote Harry Potter?\", False),\n", "]\n", "\n", "print(\"=\" * 70)\n", "for q, expect_tool in test_qs:\n", " reply, tc = generate_reply(q)\n", " has_tool = tc is not None\n", " status = \"✅\" if has_tool == expect_tool else \"❌\"\n", " print(f\"\\n{status} Q: {q}\")\n", " print(f\" Expected tool call: {expect_tool} | Got tool call: {has_tool}\")\n", " if tc:\n", " print(f\" 🔍 TOOL CALL → {tc.get('name', '?')}(query={tc.get('arguments', {}).get('query', '')})\")\n", " else:\n", " # Remove special tokens for display\n", " clean = re.sub(r'<[^>]+>', '', reply).strip()\n", " print(f\" 🤖 {clean[:200]}\")\n", "print(\"\\n\" + \"=\" * 70)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 11 — Google Custom Search API Integration\n", "\n", "### How to Set Up Google Custom Search API\n", "\n", "**Step 1**: Create a Google Cloud project\n", "1. Go to [Google Cloud Console](https://console.cloud.google.com)\n", "2. Click **Select a project** → **New Project** → name it\n", "\n", "**Step 2**: Enable the Custom Search API\n", "1. Go to **APIs & Services** → **Library**\n", "2. Search for **\"Custom Search API\"** → click **Enable**\n", "\n", "**Step 3**: Get your API Key\n", "1. Go to **APIs & Services** → **Credentials**\n", "2. Click **Create Credentials** → **API Key**\n", "3. Copy the key → paste it in `GOOGLE_API_KEY` above\n", "\n", "**Step 4**: Create a Programmable Search Engine\n", "1. Go to [Programmable Search Engine](https://programmablesearchengine.google.com)\n", "2. Click **Add** → name your search engine\n", "3. Under \"Sites to search\", select **\"Search the entire web\"**\n", "4. 
Click **Create** → copy the **Search engine ID** → paste it in `GOOGLE_CX_ID` above\n", "\n", "References:\n", "- [Google Custom Search JSON API Docs](https://developers.google.com/custom-search/v1/introduction)\n", "- [Google CSE Python Tutorial](https://thepythoncode.com/article/use-google-custom-search-engine-api-in-python)\n", "\n", "**Free tier**: 100 queries/day. For more, enable billing ($5 per 1,000 additional queries, up to 10,000 queries/day)." ] }, { "cell_type": "code", "metadata": {}, "source": [ "def google_search(query: str, api_key: str = GOOGLE_API_KEY, cx: str = GOOGLE_CX_ID,\n", " num_results: int = 5) -> list[dict]:\n", " \"\"\"Call Google Custom Search JSON API.\n", "\n", " Returns a list of dicts: [{\"title\": ..., \"snippet\": ..., \"url\": ...}]\n", "\n", " Setup: See Section 11 above for step-by-step instructions.\n", " Docs: https://developers.google.com/custom-search/v1/overview\n", "\n", " Free tier: 100 queries/day.\n", " \"\"\"\n", " import requests\n", "\n", " if not api_key or not cx:\n", " print(\"⚠️ Google API key or CX ID not set. 
Using mock results.\")\n", " print(\" Set GOOGLE_API_KEY and GOOGLE_CX_ID in Section 1 to use real search.\")\n", " return [\n", " {\"title\": f\"Mock result for: {query}\", \"snippet\": f\"Simulated search result for: {query}\", \"url\": \"#\"},\n", " {\"title\": f\"More results for: {query}\", \"snippet\": f\"Additional mock search info about: {query}\", \"url\": \"#\"},\n", " ]\n", "\n", " url = \"https://www.googleapis.com/customsearch/v1\"\n", " params = {\n", " \"key\": api_key,\n", " \"cx\": cx,\n", " \"q\": query,\n", " \"num\": num_results,\n", " }\n", "\n", " try:\n", " resp = requests.get(url, params=params, timeout=10)\n", " resp.raise_for_status()\n", " data = resp.json()\n", " results = []\n", " for item in data.get(\"items\", []):\n", " results.append({\n", " \"title\": item.get(\"title\", \"\"),\n", " \"snippet\": item.get(\"snippet\", \"\"),\n", " \"url\": item.get(\"link\", \"\"),\n", " })\n", " return results\n", " except requests.exceptions.RequestException as e:\n", " print(f\"❌ Search error: {e}\")\n", " return []\n", "\n", "# Quick test (will use mock if keys not set)\n", "test_results = google_search(\"test query\")\n", "print(f\"Search returned {len(test_results)} results:\")\n", "for r in test_results:\n", " print(f\" - {r['title']}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 12 — Full Pipeline: Chat with Web Search\n", "\n", "Complete loop: user asks → model decides → Google search → model answers." ] }, { "cell_type": "code", "metadata": {}, "source": [ "def chat_with_search(user_msg: str, system: str = CUSTOM_SYSTEM_PROMPT):\n", " \"\"\"Full pipeline: generate → tool call → Google search → final answer.\n", "\n", " This implements the standard function-calling loop:\n", " 1. User sends message with tool schema in context\n", " 2. Model either calls a tool or answers directly\n", " 3. If tool called: execute tool, feed observation back to model\n", " 4. 
Model generates final answer\n", "\n", " Reference: OpenAI Function Calling flow\n", " \"\"\"\n", " tools_str = json.dumps([WEB_SEARCH_TOOL], indent=2, ensure_ascii=False)\n", "\n", " # --- Step 1: Initial generation ---\n", " chat = (\n", " f\"<|system|>\\n{system}\\n\"\n", " f\"<|tools|>\\n{tools_str}\\n\"\n", " f\"<|user|>\\n{user_msg}\\n\"\n", " f\"<|assistant|>\\n\"\n", " )\n", " inputs = tokenizer(chat, return_tensors=\"pt\").to(merged_model.device)\n", " with torch.inference_mode():\n", " out = merged_model.generate(\n", " **inputs, max_new_tokens=200, temperature=0.3,\n", " top_p=0.9, do_sample=True, pad_token_id=tokenizer.eos_token_id,\n", " )\n", " decoded = tokenizer.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=False)\n", " decoded = decoded.replace(\"\", \"\").strip()\n", "\n", " # --- Step 2: Check for tool call ---\n", " m = re.search(r'\\s*(\\{.*?\\})\\s*', decoded, re.DOTALL)\n", " if not m:\n", " # No tool call — return direct answer\n", " clean = re.sub(r'<[^>]+>', '', decoded).strip()\n", " return clean\n", "\n", " # --- Step 3: Execute tool ---\n", " try:\n", " tool_call = json.loads(m.group(1))\n", " except json.JSONDecodeError:\n", " return decoded\n", "\n", " tool_name = tool_call.get(\"name\", \"\")\n", " tool_args = tool_call.get(\"arguments\", {})\n", " search_query = tool_args.get(\"query\", \"\") if isinstance(tool_args, dict) else str(tool_args)\n", "\n", " print(f\"🔍 Searching Google for: {search_query}\")\n", " observations = google_search(search_query)\n", " print(f\"📋 Got {len(observations)} results\")\n", "\n", " if not observations:\n", " return \"I searched but couldn't find relevant results. 
Please try rephrasing your question.\"\n", "\n", " # --- Step 4: Feed observation and get final answer ---\n", " obs_str = json.dumps(observations, indent=2, ensure_ascii=False)\n", " follow_up = (\n", " f\"<|system|>\\n{system}\\n\"\n", " f\"<|tools|>\\n{tools_str}\\n\"\n", " f\"<|user|>\\n{user_msg}\\n\"\n", " f\"<|assistant|>\\n {json.dumps(tool_call, ensure_ascii=False)} \\n\"\n", " f\"<|observation|>\\n{obs_str}\\n\"\n", " f\"<|assistant|>\\n\"\n", " )\n", " inputs2 = tokenizer(follow_up, return_tensors=\"pt\").to(merged_model.device)\n", " with torch.inference_mode():\n", " out2 = merged_model.generate(\n", " **inputs2, max_new_tokens=300, temperature=0.3,\n", " top_p=0.9, do_sample=True, pad_token_id=tokenizer.eos_token_id,\n", " )\n", " final = tokenizer.decode(out2[0][inputs2.input_ids.shape[-1]:], skip_special_tokens=True)\n", " return final.strip()\n", "\n", "\n", "# ====== Quick demo ======\n", "demo_questions = [\n", " \"What is 7 * 8?\", # no search\n", " \"What is the current price of Ethereum?\", # needs search\n", "]\n", "\n", "for q in demo_questions:\n", " print(f\"\\n🧑 {q}\")\n", " answer = chat_with_search(q)\n", " print(f\"🤖 {CUSTOM_MODEL_NAME}: {answer}\")\n", " print(\"-\" * 50)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 13 — Interactive Chat (Run this cell to chat!)" ] }, { "cell_type": "code", "metadata": {}, "source": [ "print(\"=\" * 60)\n", "print(f\"🤖 {CUSTOM_MODEL_NAME} — Web Search Chat\")\n", "print(\"Type 'quit' to exit.\")\n", "print(\"=\" * 60)\n", "\n", "while True:\n", " try:\n", " q = input(\"\\n🧑 You: \").strip()\n", " except (EOFError, KeyboardInterrupt):\n", " break\n", " if q.lower() in (\"quit\", \"exit\", \"q\"):\n", " break\n", " if not q:\n", " continue\n", " answer = chat_with_search(q)\n", " print(f\"🤖 {CUSTOM_MODEL_NAME}: {answer}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 
14 — Custom System Prompt & Model Name (No Retraining Needed)" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# --- Example: swap to a pirate coding assistant persona ---\n", "PIRATE_SYSTEM = (\n", " \"You are Captain Code, a pirate-themed coding assistant. \"\n", " \"Ye speak like a pirate and help with programming questions. \"\n", " \"When ye need up-to-date information, call the web_search tool. \"\n", " \"Answer concisely. No reasoning.\" \n", ")\n", "\n", "question = \"What is the latest version of React?\"\n", "print(f\"🧑 {question}\")\n", "print(f\"🎭 System: {PIRATE_SYSTEM}\\n\")\n", "\n", "answer = chat_with_search(question, system=PIRATE_SYSTEM)\n", "print(f\"🏴‍☠️ Captain Code: {answer}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 15 — Export for Production\n", "\n", "### Option A: Ollama (GGUF)\n", "```bash\n", "# Install llama.cpp\n", "!git clone https://github.com/ggerganov/llama.cpp && cd llama.cpp && make\n", "\n", "# Convert to GGUF (each `!` command runs in its own shell, so use the full script path)\n", "!python llama.cpp/convert_hf_to_gguf.py {merged_path} --outfile tinyllama-websearch.gguf --outtype f16\n", "\n", "# Create Ollama Modelfile\n", "# FROM ./tinyllama-websearch.gguf\n", "# PARAMETER temperature 0.3\n", "# SYSTEM \"You are WebSearchLlama, a helpful assistant with web search...\"\n", "\n", "# ollama create websearchllama -f Modelfile\n", "# ollama run websearchllama\n", "```\n", "\n", "### Option B: vLLM (Fast Serving)\n", "```python\n", "# Notebook cell syntax; from a plain shell, run `pip install vllm` instead\n", "!pip install vllm\n", "from vllm import LLM\n", "llm = LLM(model=merged_path)\n", "output = llm.generate([\"Hello\"])\n", "```\n", "\n", "### Option C: llama.cpp HTTP Server\n", "```bash\n", "./server -m tinyllama-websearch.gguf -c 2048 --host 0.0.0.0 --port 8080\n", "```\n", "\n", "### Option D: LlamaFactory (Alternative Training)\n", "If you want to retrain with more data using LlamaFactory:\n", "```bash\n", "# 1. 
Copy the ShareGPT dataset\n", "cp {OUTPUT_DIR}/dataset_sharegpt.json /path/to/LLaMA-Factory/data/\n", "\n", "# 2. Add to dataset_info.json\n", "# \"websearch_toolcall\": {\n", "# \"file_name\": \"dataset_sharegpt.json\"\n", "# }\n", "\n", "# 3. Train\n", "llamafactory-cli train \\\n", " --stage sft \\\n", " --model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \\\n", " --dataset websearch_toolcall \\\n", " --template default \\\n", " --finetuning_type lora \\\n", " --output_dir websearch_output\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 16 — How It Works: Architecture Summary\n", "\n", "```\n", "┌─────────────────────────────────────────────────────────────┐\n", "│ TRAINING PHASE │\n", "│ │\n", "│ xLAM-Format Examples │\n", "│ ┌──────────┐ ┌──────────┐ ┌───────────┐ ┌───────────┐ │\n", "│ │ Tools │ │ Query │ │ Tool Call │ │ Answer │ │\n", "│ │ Schema │ │ │ │ │ │ │ │\n", "│ └──────────┘ └──────────┘ └─────┬─────┘ └───────────┘ │\n", "│ │ │\n", "│ No-Tool Examples │\n", "│ (direct answers) │\n", "│ │\n", "│ ────────────── TinyLlama + QLoRA ──────────────▶ Model │\n", "└─────────────────────────────────────────────────────────────┘\n", "\n", "┌─────────────────────────────────────────────────────────────┐\n", "│ INFERENCE PHASE │\n", "│ │\n", "│ User: \"What is the weather in Tokyo?\" │\n", "│ │ │\n", "│ ▼ │\n", "│ Model ──▶ {\"name\": \"web_search\", ...} │\n", "│ │ │\n", "│ ▼ Your app parses the Action JSON │\n", "│ google_search(query=\"Tokyo weather today\") │\n", "│ │ │\n", "│ ▼ Feed observations back to model │\n", "│ Model ──▶ \"Tokyo is currently 18°C with clear skies.\" │\n", "│ │\n", "└─────────────────────────────────────────────────────────────┘\n", "```\n", "\n", "### Key Design Decisions (Research-Backed)\n", "\n", "| Decision | Rationale | Source |\n", "|---|---|---|\n", "| xLAM format for training data | Industry standard, 3-stage verified, used by HuggingFace cookbook | [Salesforce 
xLAM](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k) |\n", "| Tool schemas in every prompt | Model learns API shape; mirrors OpenAI function calling | [OpenAI Docs](https://platform.openai.com/docs/guides/function-calling) |\n", "| Special tokens for tools (e.g. `<\\|observation\\|>`) | Toolformer approach — encode tools as special tokens | [Toolformer paper](https://arxiv.org/abs/2302.04761) |\n", "| No-tool examples included | Model learns when NOT to call (critical for avoiding over-calling) | [Microsoft SLM Guide](https://techcommunity.microsoft.com/blog/azure-ai-foundry-blog/fine-tuning-small-language-models-for-function-calling-a-comprehensive-guide/4362539) |\n", "| No reasoning / CoT | Keeps responses short and fast, focused on direct answers | [TinyAgent (Berkeley)](https://bair.berkeley.edu/blog/2024/05/29/tiny-agent) |\n", "| LoRA on all attention + MLP layers | Maximum adaptability for tool calling while training only a small fraction of the parameters | [Fireworks.ai best practices](https://christianjmills.com/posts/mastering-llms-course-notes/conference-talk-019) |\n", "| LlamaFactory ShareGPT export | Interoperability with a widely used open-source fine-tuning framework | [LlamaFactory](https://github.com/hiyouga/LlamaFactory) |\n", "| Google Custom Search API | Industry-standard web search, 100 free queries/day | [Google CSE Docs](https://developers.google.com/custom-search/v1/introduction) |" ] } ] }