Update app.py
app.py CHANGED
@@ -23,12 +23,12 @@ from sentence_transformers import SentenceTransformer
 # Global variables for pipelines and settings.
 TEXT_PIPELINE = None
 COMPARISON_PIPELINE = None
-NUM_EXAMPLES =
+NUM_EXAMPLES = 100
 
 @spaces.GPU(duration=300)
 def finetune_small_subset():
     """
-    Fine-tunes the custom
+    Fine-tunes the custom R1 model on a small subset of the ServiceNow-AI/R1-Distill-SFT dataset.
     Steps:
     1) Loads the model from "wuhp/myr1" (using files from the "myr1" subfolder via trust_remote_code).
     2) Applies 4-bit quantization and prepares for QLoRA training.
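The docstring's steps 1-2 describe a standard QLoRA setup. A minimal sketch of what that looks like with transformers and peft, assuming illustrative LoRA hyperparameters (the r, alpha, and target modules below are not taken from the app):

```python
# Sketch of steps 1-2: load "wuhp/myr1" with 4-bit quantization, then prepare
# it for QLoRA. Hyperparameters here are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize base weights to 4-bit NF4
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bf16 for stability
)

model = AutoModelForCausalLM.from_pretrained(
    "wuhp/myr1",
    subfolder="myr1",            # per the docstring, files live in this subfolder
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map="auto",
)

# Freeze the quantized base weights and attach trainable low-rank adapters.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # assumed attention projections
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
```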
@@ -163,7 +163,7 @@ def ensure_pipeline():
 
 def ensure_comparison_pipeline():
     """
-    Loads
+    Loads the official R1 model pipeline if not already loaded.
     """
     global COMPARISON_PIPELINE
     if COMPARISON_PIPELINE is None:
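ensure_comparison_pipeline() follows the same lazy-singleton pattern as ensure_pipeline(): build the pipeline once, cache it in a module-level global, and reuse it on later calls. A sketch of the pattern; the diff does not name the official R1 checkpoint, so the model id below is an assumption:

```python
from transformers import pipeline

COMPARISON_PIPELINE = None

def ensure_comparison_pipeline():
    global COMPARISON_PIPELINE
    if COMPARISON_PIPELINE is None:
        # First call pays the load cost; later calls return the cached pipeline.
        COMPARISON_PIPELINE = pipeline(
            "text-generation",
            model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",  # assumed model id
            trust_remote_code=True,
        )
    return COMPARISON_PIPELINE
```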
@@ -180,7 +180,7 @@ def ensure_comparison_pipeline():
 @spaces.GPU(duration=120)
 def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Direct generation without retrieval.
+    Direct generation without retrieval using the custom R1 model.
     """
     pipe = ensure_pipeline()
     out = pipe(
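predict() forwards the UI's sampling controls straight into the pipeline call. A sketch of the likely call shape, assuming the app returns the pipeline's generated_text field:

```python
def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    pipe = ensure_pipeline()
    out = pipe(
        prompt,
        do_sample=True,               # temperature/top_p only apply when sampling
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
    )
    # transformers text-generation pipelines return a list of dicts
    # with a "generated_text" key.
    return out[0]["generated_text"]
```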
@@ -196,7 +196,7 @@ def predict(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
 @spaces.GPU(duration=120)
 def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
     """
-    Compare outputs between your custom model and
+    Compare outputs between your custom R1 model and the official R1 model.
     """
     local_pipe = ensure_pipeline()
     comp_pipe = ensure_comparison_pipeline()
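compare_models() runs the same prompt and sampling settings through both pipelines. A sketch, assuming it returns one string per output Textbox (Gradio distributes a returned tuple across the components listed in outputs):

```python
def compare_models(prompt, temperature, top_p, min_new_tokens, max_new_tokens):
    local_pipe = ensure_pipeline()             # custom R1
    comp_pipe = ensure_comparison_pipeline()   # official R1
    gen_kwargs = dict(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
    )
    local_out = local_pipe(prompt, **gen_kwargs)[0]["generated_text"]
    official_out = comp_pipe(prompt, **gen_kwargs)[0]["generated_text"]
    # One string per output component: (out_custom, out_official).
    return local_out, official_out
```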
@@ -299,34 +299,34 @@ def chat_rag(user_input, history, temperature, top_p, min_new_tokens, max_new_tokens):
 
 # Build the Gradio interface.
 with gr.Blocks() as demo:
-    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom
+    gr.Markdown("# QLoRA Fine-tuning & RAG-based Chat Demo using Custom R1 Model")
 
     finetune_btn = gr.Button("Finetune 4-bit (QLoRA) on ServiceNow-AI/R1-Distill-SFT subset (up to 5 min)")
     status_box = gr.Textbox(label="Finetune Status")
     finetune_btn.click(fn=finetune_small_subset, outputs=status_box)
 
-    gr.Markdown("## Direct Generation (No Retrieval)")
+    gr.Markdown("## Direct Generation (No Retrieval) using Custom R1")
     prompt_in = gr.Textbox(lines=3, label="Prompt")
     temperature = gr.Slider(0.0, 1.5, step=0.1, value=0.7, label="Temperature")
     top_p = gr.Slider(0.0, 1.0, step=0.05, value=0.9, label="Top-p")
     min_tokens = gr.Slider(1, 2500, value=50, step=10, label="Min New Tokens")
     max_tokens = gr.Slider(1, 2500, value=200, step=50, label="Max New Tokens")
-    output_box = gr.Textbox(label="
-    gen_btn = gr.Button("Generate with
+    output_box = gr.Textbox(label="Custom R1 Output", lines=8)
+    gen_btn = gr.Button("Generate with Custom R1")
     gen_btn.click(
         fn=predict,
         inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
         outputs=output_box
     )
 
-    gr.Markdown("## Compare
+    gr.Markdown("## Compare Custom R1 vs Official R1")
     compare_btn = gr.Button("Compare")
-
-
+    out_custom = gr.Textbox(label="Custom R1 Output", lines=6)
+    out_official = gr.Textbox(label="Official R1 Output", lines=6)
     compare_btn.click(
         fn=compare_models,
         inputs=[prompt_in, temperature, top_p, min_tokens, max_tokens],
-        outputs=[
+        outputs=[out_custom, out_official]
     )
 
     gr.Markdown("## Chat with Retrieval-Augmented Memory")
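The key fix in this last hunk is giving compare_btn.click a complete outputs list. A self-contained sketch of the wiring rule it relies on, with a toy function standing in for compare_models (component names and the lambda are hypothetical):

```python
import gradio as gr

with gr.Blocks() as demo:
    prompt_in = gr.Textbox(lines=3, label="Prompt")
    out_custom = gr.Textbox(label="First Output")
    out_official = gr.Textbox(label="Second Output")
    compare_btn = gr.Button("Compare")
    # The 2-tuple returned by fn fills out_custom, then out_official, in order.
    compare_btn.click(
        fn=lambda p: (p.upper(), p.lower()),  # toy stand-in for compare_models
        inputs=[prompt_in],
        outputs=[out_custom, out_official],
    )

if __name__ == "__main__":
    demo.launch()
```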