Upload 106 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- README.md +19 -6
- __pycache__/utils.cpython-312.pyc +0 -0
- app.py +1026 -0
- app.sh +17 -0
- app_logs/app_5936040.out +18 -0
- app_logs/app_5936041.out +18 -0
- app_logs/app_5936047.out +19 -0
- app_logs/app_5936050.out +1 -0
- app_logs/app_5936052.out +57 -0
- assets/umd_logo.png +3 -0
- configs/prompts.yaml +100 -0
- configs/task1_demo.yaml +27 -0
- configs/task1_demo_sph.yaml +28 -0
- data/survey_responses_screened.csv +3 -0
- push.sh +22 -0
- requirements.txt +167 -0
- requirements_concise.txt +18 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/README.md +202 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/adapter_config.json +31 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/adapter_model.safetensors +3 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/added_tokens.json +3 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/chat_template.json +3 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/preprocessor_config.json +29 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/processor_config.json +4 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/special_tokens_map.json +33 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.json +3 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.model +3 -0
- unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer_config.json +0 -0
- unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +67 -0
- unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +66 -0
- unsloth_compiled_cache/BatchNorm1d.py +88 -0
- unsloth_compiled_cache/BatchNorm2d.py +88 -0
- unsloth_compiled_cache/BatchNorm3d.py +88 -0
- unsloth_compiled_cache/Conv1d.py +43 -0
- unsloth_compiled_cache/Conv2d.py +43 -0
- unsloth_compiled_cache/Conv3d.py +43 -0
- unsloth_compiled_cache/ConvTranspose1d.py +70 -0
- unsloth_compiled_cache/ConvTranspose2d.py +71 -0
- unsloth_compiled_cache/ConvTranspose3d.py +71 -0
- unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +73 -0
- unsloth_compiled_cache/GroupNorm.py +43 -0
- unsloth_compiled_cache/LayerNorm.py +45 -0
- unsloth_compiled_cache/Linear4bit_peft_forward.py +97 -0
- unsloth_compiled_cache/Linear8bitLt_peft_forward.py +90 -0
- unsloth_compiled_cache/Linear_peft_forward.py +89 -0
- unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +87 -0
- unsloth_compiled_cache/RMSNorm.py +46 -0
- unsloth_compiled_cache/UnslothAlignPropTrainer.py +637 -0
- unsloth_compiled_cache/UnslothBCOTrainer.py +1824 -0
.gitattributes
CHANGED
|
@@ -36,3 +36,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 36 |
src_hf_deploy[[:space:]]2/assets/umd_logo.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
src_hf_deploy[[:space:]]2/data/survey_responses_screened.csv filter=lfs diff=lfs merge=lfs -text
|
| 38 |
src_hf_deploy[[:space:]]2/unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
src_hf_deploy[[:space:]]2/assets/umd_logo.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
src_hf_deploy[[:space:]]2/data/survey_responses_screened.csv filter=lfs diff=lfs merge=lfs -text
|
| 38 |
src_hf_deploy[[:space:]]2/unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/umd_logo.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
data/survey_responses_screened.csv filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,25 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: AI-Empowered Community Simulation (Beta)
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
# hardware: "gpu-a100-large" # REQUESTS A100 80GB GPU
|
| 11 |
+
# hardware: "gpu-l40s" # Request 1x NVIDIA L40S (48GB VRAM)
|
| 12 |
+
# hardware: "zerogpu"
|
| 13 |
+
hardware: "t4-small"
|
| 14 |
---
|
| 15 |
|
| 16 |
+
# AI-Empowered Community Simulation (Beta)
|
| 17 |
+
|
| 18 |
+
This Space requires **at least 28 Gb of GPU RAM** due to the size of the UnsLoTH long-context VLM model used for inference and summarization.
|
| 19 |
+
|
| 20 |
+
If the hardware fails to start or your account does not have access to this tier,
|
| 21 |
+
please select the appropriate hardware from:
|
| 22 |
+
|
| 23 |
+
**Settings → Hardware → ZeroGPU**
|
| 24 |
+
|
| 25 |
+
---
|
__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (5.34 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,1026 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Instruction Tuning of LLM for Trait-conditioned Style Impact Caliberation
|
| 3 |
+
"""
|
| 4 |
+
import unsloth
|
| 5 |
+
import yaml # type: ignore
|
| 6 |
+
import pandas as pd # type: ignore
|
| 7 |
+
import os
|
| 8 |
+
from PIL import Image # type: ignore
|
| 9 |
+
import gradio as gr
|
| 10 |
+
|
| 11 |
+
import torch # type: ignore
|
| 12 |
+
from langchain_community.chat_models import ChatOllama # type: ignore
|
| 13 |
+
from langchain_core.messages import SystemMessage, HumanMessage # type: ignore
|
| 14 |
+
from langchain_ollama import OllamaEmbeddings # type: ignore
|
| 15 |
+
from langchain_core.output_parsers import StrOutputParser # type: ignore
|
| 16 |
+
from pydantic import BaseModel # format LLM output as JSON # type: ignore
|
| 17 |
+
from unsloth import FastVisionModel, FastModel, FastLanguageModel # type: ignore
|
| 18 |
+
from transformers import TextStreamer # type: ignore
|
| 19 |
+
from unsloth.chat_templates import get_chat_template # type: ignore
|
| 20 |
+
from unsloth.chat_templates import standardize_sharegpt # type: ignore
|
| 21 |
+
from transformers import TextIteratorStreamer
|
| 22 |
+
|
| 23 |
+
from utils import convert_to_base64, load_config, process_trait_info # type: ignore
|
| 24 |
+
from tqdm import tqdm # type: ignore
|
| 25 |
+
from termcolor import colored # type: ignore
|
| 26 |
+
import threading
|
| 27 |
+
import random
|
| 28 |
+
import numpy as np
|
| 29 |
+
import random
|
| 30 |
+
|
| 31 |
+
import threading
|
| 32 |
+
# generation_lock = threading.Lock()
|
| 33 |
+
|
| 34 |
+
# from transformers import StoppingCriteria, StoppingCriteriaList
|
| 35 |
+
# class StopGenerationCriteria(StoppingCriteria):
|
| 36 |
+
# def __init__(self, stop_event):
|
| 37 |
+
# self.stop_event = stop_event
|
| 38 |
+
|
| 39 |
+
# def __call__(self, input_ids, scores, **kwargs):
|
| 40 |
+
# return self.stop_event.is_set()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 44 |
+
|
| 45 |
+
TRAIT_VALUES = {
|
| 46 |
+
"Gender": [
|
| 47 |
+
"Male", "Female", "Non-binary/third gender", "Leave Blank",
|
| 48 |
+
],
|
| 49 |
+
"Age": [
|
| 50 |
+
"18–24", "25–34", "35–44", "45–54", "55–64", "65 or older", "Leave Blank",
|
| 51 |
+
],
|
| 52 |
+
"Current Profession": [
|
| 53 |
+
"Healthcare/Medical", "Government/Public Service",
|
| 54 |
+
"Business/Finance",
|
| 55 |
+
"Technology/Engineering", "Education", "Arts/Entertainment",
|
| 56 |
+
"Retail/Hospitality/Food Service",
|
| 57 |
+
"Skilled Trades/Labor (e.g., construction, electrician, landscaper, house cleaner)",
|
| 58 |
+
"Student",
|
| 59 |
+
"Unemployed/Looking for work", "Retired",
|
| 60 |
+
"Other",
|
| 61 |
+
"Leave Blank",
|
| 62 |
+
],
|
| 63 |
+
"Race/Ethnicity" : [
|
| 64 |
+
"Asian", "Black/African American", "Hispanic/Latino",
|
| 65 |
+
"Native American/Alaska Native", "Native Hawaiian/Other Pacific Islander",
|
| 66 |
+
"White/Caucasian", "Other", "Leave Blank",
|
| 67 |
+
],
|
| 68 |
+
"Religious/Cultural Group": [
|
| 69 |
+
"Christianity", "Islam", "Hinduism", "Judaism", "Buddhism", "None of the above", "Leave Blank",
|
| 70 |
+
],
|
| 71 |
+
"Political Affiliation": [
|
| 72 |
+
"Conservative", "Apolitical/Not involved in politics", "Independent",
|
| 73 |
+
"Libertarian", "Moderate", "Liberal", "Leave Blank",
|
| 74 |
+
],
|
| 75 |
+
"Highest Education": [
|
| 76 |
+
"Less than high school", "High school diploma or equivalent", "Some college, no degree",
|
| 77 |
+
"Associate’s degree", "Bachelor’s degree",
|
| 78 |
+
"Master’s degree", "Doctoral or professional degree",
|
| 79 |
+
"Leave Blank",
|
| 80 |
+
],
|
| 81 |
+
"Annual Household Income": [
|
| 82 |
+
"Less than $25,000", "$25,000–$49,999", "$50,000–$74,999",
|
| 83 |
+
"$75,000–$99,999", "$100,000–$149,999", "$150,000 or more",
|
| 84 |
+
"Leave Blank",
|
| 85 |
+
],
|
| 86 |
+
"Family Status": [
|
| 87 |
+
"Single, living alone", "Single, living with family", "Single Parent with children",
|
| 88 |
+
"Married/Partnered, no children", "Married/Partnered, with children",
|
| 89 |
+
"Multi-generation family (e.g., with parents, grandparents, or extended family)",
|
| 90 |
+
"Leave Blank",
|
| 91 |
+
],
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
HEALTH_TOPICS = {
|
| 95 |
+
"Chronic Obstructive Pulmonary Disease (COPD)": "COPD1.1",
|
| 96 |
+
"Heart Disease": "HD1",
|
| 97 |
+
"HIV": "HIV1.1",
|
| 98 |
+
"Mental Health": "MH1.1",
|
| 99 |
+
"Nutrition": "N2.1",
|
| 100 |
+
"Substance Abuse": "SA4.1",
|
| 101 |
+
"Sexual Practice": "SP7.1",
|
| 102 |
+
"Vaccination": "V7.1",
|
| 103 |
+
"Cystic Fibrosis": "CF1.1",
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
health_topics = ""
|
| 107 |
+
for topic in HEALTH_TOPICS:
|
| 108 |
+
health_topics += topic + '\n'
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
##########################################################
|
| 113 |
+
### To increase style variability to avoid repetitiveness
|
| 114 |
+
##########################################################
|
| 115 |
+
# * Style variants
|
| 116 |
+
style_variants = [
|
| 117 |
+
"Write with a slightly informal and reflective tone.",
|
| 118 |
+
"Write in a straightforward conversational tone.",
|
| 119 |
+
"Write with mild emotional coloring, but still natural.",
|
| 120 |
+
"Write in a calm, matter-of-fact tone.",
|
| 121 |
+
"Write in a slightly narrative, flowing tone.",
|
| 122 |
+
"Write in a concise but personable tone.",
|
| 123 |
+
"Write in a informal, pragmatic tone, focusing on clarity and utility.",
|
| 124 |
+
]
|
| 125 |
+
# --- Add small lexical noise / synonym variation ---
|
| 126 |
+
lexical_flavors = [
|
| 127 |
+
"Feel free to vary sentence structures slightly.",
|
| 128 |
+
"Use a mix of simple and slightly complex sentences.",
|
| 129 |
+
"Use a light mix of paraphrasing expressions.",
|
| 130 |
+
"Feel free to choose different synonyms for common emotional words.",
|
| 131 |
+
"Introduce subtle variation in connectors like 'however', 'still', or 'overall'.",
|
| 132 |
+
]
|
| 133 |
+
openers = [
|
| 134 |
+
"This message",
|
| 135 |
+
"From this message",
|
| 136 |
+
"Through the message",
|
| 137 |
+
"After seeing this message",
|
| 138 |
+
"Looking at this poster",
|
| 139 |
+
"Based on what this poster conveys",
|
| 140 |
+
"Hmmm I think that this message",
|
| 141 |
+
"Reflecting on the message here",
|
| 142 |
+
"Considering what this poster is trying to say",
|
| 143 |
+
"Seeing this message makes me think",
|
| 144 |
+
"Thinking about what this poster is communicating",
|
| 145 |
+
"After reading what's on here",
|
| 146 |
+
"Based on what’s written here",
|
| 147 |
+
"After I look at this whole thing",
|
| 148 |
+
]
|
| 149 |
+
openers_generic = [
|
| 150 |
+
"Hmmm when thinking about",
|
| 151 |
+
"When I think about",
|
| 152 |
+
"My impression about",
|
| 153 |
+
"On top of my head",
|
| 154 |
+
"My general thoughts about",
|
| 155 |
+
"The way I see it,",
|
| 156 |
+
"From my point of view on",
|
| 157 |
+
"My initial take on",
|
| 158 |
+
"In my own words,",
|
| 159 |
+
"As I see things,",
|
| 160 |
+
"Just speaking for myself,",
|
| 161 |
+
"At a glance,",
|
| 162 |
+
]
|
| 163 |
+
openers_poster_summary = [
|
| 164 |
+
"This poster",
|
| 165 |
+
"This poster seems to",
|
| 166 |
+
"My interpretation of the poster is",
|
| 167 |
+
"From what this poster shows, it seems to",
|
| 168 |
+
"Looking at the poster as a whole, it appears to",
|
| 169 |
+
"Based on the imagery and tone, the poster seems to",
|
| 170 |
+
"Visually, the poster comes across as trying to",
|
| 171 |
+
"To me, this poster is trying to",
|
| 172 |
+
"When I look at this poster, it feels like it aims to",
|
| 173 |
+
"The poster gives me the impression that it intends to",
|
| 174 |
+
]
|
| 175 |
+
openers_explain = [
|
| 176 |
+
"The reason why I think that is because",
|
| 177 |
+
"To explain why I",
|
| 178 |
+
"Well, to explain my thoughts",
|
| 179 |
+
"To put it simply, I feel this way because",
|
| 180 |
+
"My reasoning behind that is",
|
| 181 |
+
"What leads me to that view is",
|
| 182 |
+
"A big part of why I think that is",
|
| 183 |
+
"To give some context for my view,",
|
| 184 |
+
"Here’s why I lean that way:",
|
| 185 |
+
"I see it that way mainly because",
|
| 186 |
+
"Let me explain why I think so",
|
| 187 |
+
"Thinking through it, I realize it's because",
|
| 188 |
+
"To unpack my thinking a bit,",
|
| 189 |
+
"I guess it’s because",
|
| 190 |
+
"The thing that really shapes my view is",
|
| 191 |
+
"It’s pretty much because",
|
| 192 |
+
"A lot of it comes down to",
|
| 193 |
+
"I feel that way mostly because",
|
| 194 |
+
"My thinking comes from the idea that",
|
| 195 |
+
]
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
"""
|
| 200 |
+
Generate LLM response given a single user prompt and input image
|
| 201 |
+
"""
|
| 202 |
+
def vlm_response(user_input, history, health_topic,
|
| 203 |
+
gender, age, profession, race, religion,
|
| 204 |
+
political, education, income, family_status,
|
| 205 |
+
# extraversion, agreeableness, conscientiousness, neuroticism, openness,
|
| 206 |
+
):
|
| 207 |
+
# # 1. Initialize Stop Event for this session
|
| 208 |
+
# stop_event = threading.Event()
|
| 209 |
+
# # Create the stopping criteria to pass to the model
|
| 210 |
+
# stopping_criteria = StoppingCriteriaList([StopGenerationCriteria(stop_event)])
|
| 211 |
+
|
| 212 |
+
# 1. Clear any lingering state
|
| 213 |
+
torch.cuda.empty_cache() # Clear GPU memory
|
| 214 |
+
# 2. Initialize Streamers LOCALLY (Fresh for every request)
|
| 215 |
+
# Note: We need to re-initialize these for every single generation call
|
| 216 |
+
# or just once per function call if we share them.
|
| 217 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 218 |
+
# streamer_aux = TextIteratorStreamer(tokenizer_aux, skip_prompt=True, skip_special_tokens=True)
|
| 219 |
+
|
| 220 |
+
""" [NOTE] we have not use `history` for this generation """
|
| 221 |
+
# get uploaded image
|
| 222 |
+
image = Image.open(user_input['files'][0]) if user_input['files'] else None
|
| 223 |
+
image_uploaded = True
|
| 224 |
+
if image is None:
|
| 225 |
+
image = Image.new('RGB', (24,24))
|
| 226 |
+
image_uploaded = False
|
| 227 |
+
# image_b64 = convert_to_base64(image)
|
| 228 |
+
print(health_topic)
|
| 229 |
+
# print("Image uploaded:", image_uploaded)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
#################################################
|
| 234 |
+
# 1. Construct traits from user inputs
|
| 235 |
+
#################################################
|
| 236 |
+
demo_dict = {
|
| 237 |
+
"Gender": gender,
|
| 238 |
+
"Age": age,
|
| 239 |
+
"Current Profession": profession,
|
| 240 |
+
"Race/Ethnicity": race,
|
| 241 |
+
"Religious/Cultural Group": religion,
|
| 242 |
+
"Political Affiliation": political,
|
| 243 |
+
"Highest Education": education,
|
| 244 |
+
"Annual Household Income": income,
|
| 245 |
+
"Family Status": family_status,
|
| 246 |
+
}
|
| 247 |
+
# big5_dict = {
|
| 248 |
+
# "Extraversion": extraversion,
|
| 249 |
+
# "Agreeableness": agreeableness,
|
| 250 |
+
# "Conscientiousness": conscientiousness,
|
| 251 |
+
# "Neuroticism": neuroticism,
|
| 252 |
+
# "Open-Mindedness": openness,
|
| 253 |
+
# }
|
| 254 |
+
|
| 255 |
+
demo_info = ""
|
| 256 |
+
for trait, value in demo_dict.items():
|
| 257 |
+
if value != "Leave Blank": # only add non-blank values
|
| 258 |
+
demo_info += f"{trait}: {value}\n"
|
| 259 |
+
else:
|
| 260 |
+
demo_info += f"{trait}: [Not specified]\n"
|
| 261 |
+
persona_score = ""
|
| 262 |
+
persona_score += "Big-Five Trait Scores:\n"
|
| 263 |
+
# for trait, value in big5_dict.items():
|
| 264 |
+
# persona_score += f"{trait}: {value}\n"
|
| 265 |
+
# no locus of control trait score
|
| 266 |
+
locus = None
|
| 267 |
+
|
| 268 |
+
######################################################################################
|
| 269 |
+
# 1*. modify trait info based on trait selection setings
|
| 270 |
+
# demo_full: wheter include full demographic traits or only selected ones
|
| 271 |
+
# include_big5, include_facet, include_locus: include big5 / facet / locus of control traits or not
|
| 272 |
+
# format: <trait>: <value> if available; else <trait>: [Not specified]
|
| 273 |
+
######################################################################################
|
| 274 |
+
demo_info, persona_score, locus = process_trait_info(
|
| 275 |
+
demo_info, persona_score, locus,
|
| 276 |
+
demo_full=False, include_big5=True,
|
| 277 |
+
include_facet=False, include_locus=False,
|
| 278 |
+
train_mode=False,
|
| 279 |
+
)
|
| 280 |
+
# print(demo_info)
|
| 281 |
+
# print(persona_score)
|
| 282 |
+
|
| 283 |
+
###############################################
|
| 284 |
+
### Add style variability ###
|
| 285 |
+
###############################################
|
| 286 |
+
style_hint = random.choice(style_variants) # increase style variant
|
| 287 |
+
lexical_hint = random.choice(lexical_flavors) # increase lexical variant
|
| 288 |
+
opening_phrase = random.choice(openers) # increase opening variant
|
| 289 |
+
opening_generic = random.choice(openers_generic) # increase opening variant
|
| 290 |
+
opening_poster = random.choice(openers_poster_summary) # poster summary variation
|
| 291 |
+
opening_explain = random.choice(openers_explain) # thought explanation
|
| 292 |
+
print('Style:', style_hint)
|
| 293 |
+
print('Lexical:', lexical_hint)
|
| 294 |
+
print('Opening:', opening_phrase)
|
| 295 |
+
print('Generic opening:', opening_generic)
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# Wrap the GENERATION logic in try/finally to handle cleanup
|
| 299 |
+
try:
|
| 300 |
+
if image_uploaded:
|
| 301 |
+
"""###############################################################
|
| 302 |
+
Case 1: a health poster is uploaded
|
| 303 |
+
=> VLM-enabled response prediction to that specific poster
|
| 304 |
+
###############################################################"""
|
| 305 |
+
################################################
|
| 306 |
+
# * IMAGE UNDERSTANDING
|
| 307 |
+
################################################
|
| 308 |
+
yield "Analyzing image content..." # UI Feedback
|
| 309 |
+
|
| 310 |
+
PROMPT = (
|
| 311 |
+
f"Describe the content and main message in given heatlh campaign poster and how it's related to {health_topic}. ",
|
| 312 |
+
"Note that the message could be non-direct or subtle (e.g. irony, fear-driven evoke without explicit texts, etc). Only provide the answer (in 2-4 sentences). ",
|
| 313 |
+
f"Start the response with {opening_poster}"
|
| 314 |
+
)
|
| 315 |
+
messages = [
|
| 316 |
+
{"role": "user", "content": [
|
| 317 |
+
{"type": "image"},
|
| 318 |
+
{"type": "text", "text": PROMPT}
|
| 319 |
+
]}
|
| 320 |
+
]
|
| 321 |
+
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
|
| 322 |
+
inputs = tokenizer(
|
| 323 |
+
image.convert("RGB"),
|
| 324 |
+
input_text,
|
| 325 |
+
add_special_tokens = False,
|
| 326 |
+
return_tensors = "pt",
|
| 327 |
+
).to(device)
|
| 328 |
+
# Model inference
|
| 329 |
+
gen_tokens = model.generate(
|
| 330 |
+
**inputs,
|
| 331 |
+
max_new_tokens = 512,
|
| 332 |
+
use_cache = True,
|
| 333 |
+
# do_sample=cfgs["stochastic"],
|
| 334 |
+
# temperature=cfgs["temperature"],
|
| 335 |
+
# min_p=0.9,
|
| 336 |
+
# min_p=0.3,
|
| 337 |
+
top_k=15,
|
| 338 |
+
temperature=0.8,
|
| 339 |
+
do_sample=True, # cfgs["stochastic"]
|
| 340 |
+
)
|
| 341 |
+
outs = tokenizer.batch_decode(gen_tokens[:, inputs.input_ids.shape[1]:])[0]
|
| 342 |
+
image_desc = outs.replace(tokenizer.eos_token, "")
|
| 343 |
+
image_desc = image_desc.replace("<end_of_turn>", "")
|
| 344 |
+
|
| 345 |
+
################################################
|
| 346 |
+
# 2. Construct SYSTEM and USER PROMPT
|
| 347 |
+
################################################
|
| 348 |
+
SYSTEM_PROMPT = cfg_prompts["SYSTEM_SIM"]
|
| 349 |
+
SIM_PROMPT = ""
|
| 350 |
+
# prompt for role-playing information
|
| 351 |
+
SIM_PROMPT += f"You are: Demographics:\n{demo_info}\n"
|
| 352 |
+
# SIM_PROMPT += f"Your personality test shows you have (min score = 0; max score = 5):\nBig-Five Trait Scores:\n{persona_score}\n\n"
|
| 353 |
+
# SIM_PROMPT += f"You also have {locus}\n"
|
| 354 |
+
# situation description (role-playing)
|
| 355 |
+
SIM_PROMPT += cfg_prompts["SIMULATION_SIM"]
|
| 356 |
+
|
| 357 |
+
################################################
|
| 358 |
+
# 3. Stage 1: VLM-enabled response prediction
|
| 359 |
+
# Predict Trait-aware Likert Scale Responses
|
| 360 |
+
################################################
|
| 361 |
+
assert cfgs["infer_engine"] == "unsloth", "Only unsloth inference is supported"
|
| 362 |
+
assert cfgs["vision"] == True, "Must have vision input"
|
| 363 |
+
# load a sample row to extract Likert scale questions
|
| 364 |
+
df = pd.read_csv(os.path.expandvars(cfgs["data_path"]))
|
| 365 |
+
# extract sample with given health_topic for correct question set
|
| 366 |
+
sample = df[df['Poster_id'] == HEALTH_TOPICS[health_topic]].iloc[0]
|
| 367 |
+
del df # free memory
|
| 368 |
+
""" Iterate through each question"""
|
| 369 |
+
# answers_json = {}
|
| 370 |
+
answers_numeric = ""
|
| 371 |
+
# for question in [
|
| 372 |
+
# "This message makes me more concerned about the health risks in the poster - Scale: 1 (not at all) - 9 (extremely)",
|
| 373 |
+
# "The message motivates me to engage in healthier lifestyle and habit - Scale: 1 (not at all) - 9 (extremely)",
|
| 374 |
+
# "In your opinion, how harmful is ignoring the health risks in the poster? - Scale: 1 (not at all) - 9 (extremely",
|
| 375 |
+
# "How open are you to engaging in the activity in the poster? - Scale: 1 (not at all) - 9 (extremely)",
|
| 376 |
+
# ]:
|
| 377 |
+
for i in range(1,16,1):
|
| 378 |
+
# a. parse specific Likert score question
|
| 379 |
+
col = f"Q{i}"
|
| 380 |
+
if pd.isna(sample[col]):
|
| 381 |
+
continue
|
| 382 |
+
question = sample[col].replace("\n", " ")
|
| 383 |
+
# instruction prompt to answer in proper format
|
| 384 |
+
if "type in" in question.lower():
|
| 385 |
+
continue # skip free-text questions for demo
|
| 386 |
+
elif "make you feel" in question.lower():
|
| 387 |
+
continue # skip emotional questions: imprecise
|
| 388 |
+
elif "how open" in question.lower():
|
| 389 |
+
continue # skip intentional question: low-accuracy
|
| 390 |
+
# b. intialize USER PROMPT with SIMULATION PROMPT
|
| 391 |
+
# with full demographic+personality data
|
| 392 |
+
USER_PROMPT = SIM_PROMPT
|
| 393 |
+
USER_PROMPT += f"Question: {question}\n\n"
|
| 394 |
+
# instruction prompt to answer in proper format
|
| 395 |
+
USER_PROMPT += cfg_prompts['INSTRUCTION_MCQ']
|
| 396 |
+
# c. Contruct LLM message: response prediction
|
| 397 |
+
messages = [
|
| 398 |
+
{"role": "user", "content": [
|
| 399 |
+
{"type": "image"},
|
| 400 |
+
{"type": "text", "text": SYSTEM_PROMPT + USER_PROMPT}
|
| 401 |
+
]}
|
| 402 |
+
]
|
| 403 |
+
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
|
| 404 |
+
inputs = tokenizer(
|
| 405 |
+
image.convert("RGB"),
|
| 406 |
+
input_text,
|
| 407 |
+
add_special_tokens = False,
|
| 408 |
+
return_tensors = "pt",
|
| 409 |
+
).to(device)
|
| 410 |
+
# d. Model inference
|
| 411 |
+
gen_tokens = model.generate(
|
| 412 |
+
**inputs,
|
| 413 |
+
max_new_tokens = 16,
|
| 414 |
+
use_cache = True,
|
| 415 |
+
do_sample=cfgs["stochastic"],
|
| 416 |
+
temperature=cfgs["temperature"],
|
| 417 |
+
min_p=0.9,
|
| 418 |
+
)
|
| 419 |
+
outs = tokenizer.batch_decode(gen_tokens[:, inputs.input_ids.shape[1]:])[0]
|
| 420 |
+
answer = outs.replace(tokenizer.eos_token, "")
|
| 421 |
+
answer = answer.replace("<end_of_turn>", "")
|
| 422 |
+
# answers_json[col] = answer
|
| 423 |
+
answers_numeric += f"{question}. Your answer: {answer}\n"
|
| 424 |
+
# print(answers_json)
|
| 425 |
+
print(answers_numeric)
|
| 426 |
+
|
| 427 |
+
################################################
|
| 428 |
+
# 4. Stage 2: LLM Summarization of all answers
|
| 429 |
+
# => final response generation based on
|
| 430 |
+
# all Likert answers to the poster
|
| 431 |
+
# => one-shot prompting
|
| 432 |
+
################################################
|
| 433 |
+
SYSTEM_PROMPT = "You are a helpful assistant."
|
| 434 |
+
# USER_PROMPT = f"Please convert these questions and answers into a concise and coherent \
|
| 435 |
+
# summary of your overall reactions, feelings, and perspectives about the poster: {answers_numeric} \
|
| 436 |
+
# Please provide the final response only."
|
| 437 |
+
# USER_PROMPT = f"Summarize the main points from questions and answers below into a concise and coherent overall reaction to the poster:\
|
| 438 |
+
# {answers_numeric}. Provide the final response only.\n"
|
| 439 |
+
USER_PROMPT = (
|
| 440 |
+
"Summarize the following survey responses into a short, natural paragraph that captures your overall sentiment, motivation, and thinking. "
|
| 441 |
+
f"Write as if paraphrasing what a person might say in conversation. Adjust your style based on your demographic/personality traits."
|
| 442 |
+
"Do NOT repeat numeric scores. "
|
| 443 |
+
"Preserve polarity: low scores → low concern/motivation/openness; high scores → high concern/motivation/openness. "
|
| 444 |
+
"If answers are mixed (e.g., believes something is harmful but isn't personally moved), reflect that nuance explicitly. "
|
| 445 |
+
"Keep to 1-5 sentences.\n\n"
|
| 446 |
+
|
| 447 |
+
"**STRICTLY FOLLOW THESE RULES:**\n"
|
| 448 |
+
"- Infer direction from each item's Scale description (e.g., 1-9: higher = more; 0-6: higher = more). "
|
| 449 |
+
"- Use calibrated wording: 1-2 = very low, 3-4 = low, 5 = moderate, 6-7 = high, 8-9 = very high; for 0-6: 0-1 = not/slight, 2-3 = somewhat, 4-5 = high, 6 = very. "
|
| 450 |
+
"- VERY IMPORTANT: provide ONLY the final summarized response, without anything else!"
|
| 451 |
+
f"- The response MUST have a consistent health topic: {health_topic}. Ground each sentence to the impact of campaign message.\n"
|
| 452 |
+
"- Never invert sentiment. Prefer hedged phrases (e.g., “not particularly,” “only somewhat,” “very open,” “not open at all”).\n\n"
|
| 453 |
+
f"- Mimic the talking style of emulated demographic as realistic as possible."
|
| 454 |
+
|
| 455 |
+
"**Example input 1:**\n"
|
| 456 |
+
"The message makes me more concerned about the health risks of poor eating habits - Scale: 1-9. Your answer: 9\n"
|
| 457 |
+
"The message motivates me to make healthy eating choices - Scale: 1-9. Your answer: 9\n"
|
| 458 |
+
"In your opinion, how harmful is neglecting proper nutrition and weight management to your overall health? - Scale: 0–6. Your answer: 5\n"
|
| 459 |
+
"How open are you to adopting healthier eating habits and lifestyle changes? - Scale: 1-9. Your answer: 9\n"
|
| 460 |
+
"**Example output 1:**\n"
|
| 461 |
+
"This message really heightened my awareness of how unhealthy eating can be. The content in the message strongly motivates me to make better choices, and I feel very ready to follow through.\n\n"
|
| 462 |
+
|
| 463 |
+
"**Example input 2:**\n"
|
| 464 |
+
"The message makes me more concerned about the health risks of COPD and smoking - Scale: 1-9. Your answer: 1\n"
|
| 465 |
+
"The message motivates me to not smoke. - Scale: 1-9. Your answer: 1\n"
|
| 466 |
+
"In your opinion, how harmful is smoking to your general health? - Scale: 0-6. Your answer: 6\n"
|
| 467 |
+
"How open are you to smoking in the future? - Scale: 1-9. Your answer: 1\n"
|
| 468 |
+
"**Example output 2:**\n"
|
| 469 |
+
"From this message, I recognize smoking is very harmful, but the content in the message didn't increase my concern or motivate me much. It does somewhat make me understand that smoking is harmful, however. Anyway, I'm not open to smoking in the future.\n\n"
|
| 470 |
+
|
| 471 |
+
"**Example input 3:**\n"
|
| 472 |
+
"The message makes me more concerned about the effects of lack of exercise - Scale: 1-9. Your answer: 4\n"
|
| 473 |
+
"The message motivates me to be more active - Scale: 1-9. Your answer: 3\n"
|
| 474 |
+
"How open are you to exercising regularly? - Scale: 1-9. Your answer: 4\n"
|
| 475 |
+
"**Example output 3:**\n"
|
| 476 |
+
"Through the message, I get that exercise matters and the message raised my awareness a bit, but the poster content itself didn't really motivate me. The content in the message has some small impact in motivating me to change my routine.\n\n"
|
| 477 |
+
|
| 478 |
+
# "**Example input 4:**\n"
|
| 479 |
+
# "The message makes me more concerned about the health risks of substance abuse - Scale: 1 (not at all) - 9 (extremely). Your answer: 6\n"
|
| 480 |
+
# "The message motivates me to not use substances. - Scale: 1 (not at all) - 9 (extremely). Your answer: 6\n"
|
| 481 |
+
# "In your opinion, how harmful is substance use to your general health? - Scale: 0 (not at all)-6 (extremely harmful). Your answer: 5\n"
|
| 482 |
+
# "How open are you to trying a substance in the future? - Scale: 1 (not at all)-9 (extremely). Your answer: 1\n"
|
| 483 |
+
# "**Example output 4:**\n"
|
| 484 |
+
# "This message somewhat makes me more concerned about the health risks of substance abuse motivates me not to use them. However, the message itself doesn't completely convince me that substance abuse is harmful. However, I'm not open to trying substance at all!!\n"
|
| 485 |
+
f"Start the response with '{opening_phrase}' (Style hint: {style_hint}; Lexical hint: {lexical_hint})\n"
|
| 486 |
+
f"Input: {answers_numeric}. "
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
# Contruct LLM message
|
| 490 |
+
messages = [
|
| 491 |
+
{"role": "user", "content": [
|
| 492 |
+
# {"type": "image"},
|
| 493 |
+
{"type": "text", "text": SYSTEM_PROMPT + USER_PROMPT}
|
| 494 |
+
]}
|
| 495 |
+
]
|
| 496 |
+
# input_text = tokenizer_aux.apply_chat_template(messages, add_generation_prompt = True)
|
| 497 |
+
# inputs = tokenizer_aux(
|
| 498 |
+
# # image.convert("RGB"),
|
| 499 |
+
# input_text,
|
| 500 |
+
# add_special_tokens = False,
|
| 501 |
+
# return_tensors = "pt",
|
| 502 |
+
# ).to(device)
|
| 503 |
+
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
|
| 504 |
+
inputs = tokenizer(
|
| 505 |
+
# image.convert("RGB"),
|
| 506 |
+
input_text,
|
| 507 |
+
add_special_tokens = False,
|
| 508 |
+
return_tensors = "pt",
|
| 509 |
+
).to(device)
|
| 510 |
+
|
| 511 |
+
############################
|
| 512 |
+
### Text LLM Streaming ###
|
| 513 |
+
############################
|
| 514 |
+
# generation with streamer
|
| 515 |
+
generate_kwargs = dict(
|
| 516 |
+
**inputs,
|
| 517 |
+
streamer=streamer, # streamer_aux,
|
| 518 |
+
max_new_tokens=512,
|
| 519 |
+
use_cache=True,
|
| 520 |
+
# min_p=0.3,
|
| 521 |
+
top_k=15,
|
| 522 |
+
temperature=0.8,
|
| 523 |
+
do_sample=True, # cfgs["stochastic"]
|
| 524 |
+
)
|
| 525 |
+
# separate thread to run generation
|
| 526 |
+
thread = threading.Thread(
|
| 527 |
+
target=model.generate, # model_aux.generate,
|
| 528 |
+
kwargs=generate_kwargs
|
| 529 |
+
)
|
| 530 |
+
thread.start()
|
| 531 |
+
# stream out generation
|
| 532 |
+
outputs = [
|
| 533 |
+
f"Emulated traits:\n {demo_info}\n" + '='*20 + "\n\n",
|
| 534 |
+
image_desc + "\n\n"
|
| 535 |
+
]
|
| 536 |
+
for new_token in streamer: # streamer_aux:
|
| 537 |
+
outputs.append(new_token)
|
| 538 |
+
final_output = ''.join(outputs)
|
| 539 |
+
yield final_output
|
| 540 |
+
|
| 541 |
+
# Ensure thread finishes
|
| 542 |
+
thread.join()
|
| 543 |
+
|
| 544 |
+
# text representation of final response
|
| 545 |
+
response = "".join(outputs[2:]) # ignore trait summary & image description
|
| 546 |
+
print(colored('Traits', 'green'), demo_info)
|
| 547 |
+
print(colored('Emulated response:', 'green'), response)
|
| 548 |
+
print('='*100)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
################################################
|
| 552 |
+
# 5. Stage 3: provide explanation (demo purpose)
|
| 553 |
+
# => condition on {trait} AND {reponse}
|
| 554 |
+
################################################
|
| 555 |
+
SYSTEM_PROMPT = cfg_prompts["SYSTEM_SIM"]
|
| 556 |
+
SIM_PROMPT = ""
|
| 557 |
+
# prompt for role-playing information
|
| 558 |
+
SIM_PROMPT += f"You are: Demographics:\n{demo_info}\n"
|
| 559 |
+
# SIM_PROMPT += f"Your personality test shows you have (min score = 0; max score = 5):\nBig-Five Trait Scores:\n{persona_score}\n\n"
|
| 560 |
+
# SIM_PROMPT += f"You also have {locus}\n"
|
| 561 |
+
# situation description (role-playing)
|
| 562 |
+
SIM_PROMPT += cfg_prompts["SIMULATION_SIM"]
|
| 563 |
+
SIM_PROMPT += (
|
| 564 |
+
f"After seeing the uploaded impage, your response were {response}. "
|
| 565 |
+
"Briefly explain WHY you responded that way, based on your demographic background. "
|
| 566 |
+
f"Keep the explanation concise and direct. Start the response with '{opening_explain}' "
|
| 567 |
+
f"(Style hint: {style_hint}, concise; Lexical hint: {lexical_hint}). "
|
| 568 |
+
"Afterward, give a few *generic and succinct* suggestions to improve the poster's persuasiveness."
|
| 569 |
+
)
|
| 570 |
+
USER_PROMPT = SIM_PROMPT
|
| 571 |
+
|
| 572 |
+
# Contruct LLM message
|
| 573 |
+
messages = [
|
| 574 |
+
{"role": "user", "content": [
|
| 575 |
+
{"type": "image"},
|
| 576 |
+
{"type": "text", "text": SYSTEM_PROMPT + USER_PROMPT}
|
| 577 |
+
]}
|
| 578 |
+
]
|
| 579 |
+
# input_text = tokenizer_aux.apply_chat_template(messages, add_generation_prompt = True)
|
| 580 |
+
# inputs = tokenizer_aux(
|
| 581 |
+
# image.convert("RGB"),
|
| 582 |
+
# input_text,
|
| 583 |
+
# add_special_tokens = False,
|
| 584 |
+
# return_tensors = "pt",
|
| 585 |
+
# ).to(device)
|
| 586 |
+
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
|
| 587 |
+
inputs = tokenizer(
|
| 588 |
+
image.convert("RGB"),
|
| 589 |
+
input_text,
|
| 590 |
+
add_special_tokens = False,
|
| 591 |
+
return_tensors = "pt",
|
| 592 |
+
).to(device)
|
| 593 |
+
|
| 594 |
+
############################
|
| 595 |
+
### Text LLM Streaming ###
|
| 596 |
+
############################
|
| 597 |
+
# generation with streamer
|
| 598 |
+
generate_kwargs = dict(
|
| 599 |
+
**inputs,
|
| 600 |
+
streamer=streamer, # streamer_aux,
|
| 601 |
+
max_new_tokens=512,
|
| 602 |
+
use_cache=True,
|
| 603 |
+
min_p=0.85,
|
| 604 |
+
temperature=0.1,
|
| 605 |
+
do_sample=True, # cfgs["stochastic"]
|
| 606 |
+
)
|
| 607 |
+
# separate thread to run generation
|
| 608 |
+
thread = threading.Thread(
|
| 609 |
+
target=model.generate, # model_aux.generate,
|
| 610 |
+
kwargs=generate_kwargs
|
| 611 |
+
)
|
| 612 |
+
thread.start()
|
| 613 |
+
# stream out generation
|
| 614 |
+
# outputs = [image_desc + "\n\n"]
|
| 615 |
+
outputs += ["\n"]
|
| 616 |
+
for new_token in streamer: # streamer_aux:
|
| 617 |
+
outputs.append(new_token)
|
| 618 |
+
final_output = ''.join(outputs)
|
| 619 |
+
yield final_output
|
| 620 |
+
|
| 621 |
+
thread.join()
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
return answer
|
| 625 |
+
else:
|
| 626 |
+
"""###############################################################
|
| 627 |
+
Case 2: no health poster is uploaded
|
| 628 |
+
=> General Response to the health topic
|
| 629 |
+
=> not conditioned on any particular health poster
|
| 630 |
+
###############################################################"""
|
| 631 |
+
################################################
|
| 632 |
+
# 2. Construct SYSTEM and USER PROMPT
|
| 633 |
+
################################################
|
| 634 |
+
SYSTEM_PROMPT = (
|
| 635 |
+
"You are a person with unique demographic and personality traits. "
|
| 636 |
+
"Based on your background, you naturally have thoughts, feelings, and reactions to what you see."
|
| 637 |
+
)
|
| 638 |
+
SIM_PROMPT = ""
|
| 639 |
+
# prompt for role-playing information
|
| 640 |
+
SIM_PROMPT += f"You are: {demo_info}\n"
|
| 641 |
+
# SIM_PROMPT += f"Your personality test shows you have (min score = 0; max score = 5): {persona_score}\n"
|
| 642 |
+
# SIM_PROMPT += f"You also have {locus}\n"
|
| 643 |
+
# situation description (role-playing)
|
| 644 |
+
SIM_PROMPT += f"You are being asked a general question to share your *general* opinions and beliefs about a given health topic.\n"
|
| 645 |
+
################################################
|
| 646 |
+
# 3. LLM-enabled response prediction
|
| 647 |
+
# Predict Trait-aware Likert Scale Responses
|
| 648 |
+
################################################
|
| 649 |
+
assert cfgs["infer_engine"] == "unsloth", "Only unsloth inference is supported"
|
| 650 |
+
USER_PROMPT = SIM_PROMPT
|
| 651 |
+
USER_PROMPT += (
|
| 652 |
+
f"What are your *general* thoughts and opinions about the {health_topic} health topic? "
|
| 653 |
+
f" What's your attitude and feeling when talking about {health_topic} in general and why?"
|
| 654 |
+
f" How familiar are you with {health_topic}? How much do you care or know about it?"
|
| 655 |
+
f" Do you think {health_topic} is an important topic to talk about?"
|
| 656 |
+
f" What is its impacts and importance {health_topic} in society and your life? Why?"
|
| 657 |
+
f" Do you have any strong opinions about it?"
|
| 658 |
+
f" Are you interested in learning more about it?"
|
| 659 |
+
)
|
| 660 |
+
# instruction prompt to answer in proper format
|
| 661 |
+
USER_PROMPT += (
|
| 662 |
+
"Your personality, locus of control, and demographic traits influence your response. Adjust your style based on your demographic personality traits.\n"
|
| 663 |
+
"**STRICTLY FOLLOW THESE RULES:**\n"
|
| 664 |
+
"- Human-like, casual, everyday conversational response. Only answer the questions\n"
|
| 665 |
+
f"- The response MUST have a consistent health topic: {health_topic}.\n"
|
| 666 |
+
# "- Answer briefly in **5-7 sentences**.\n"
|
| 667 |
+
"- Only provide the answer. DO NOT REPEAT THE PROMPT!\n"
|
| 668 |
+
"- Condition your response on your *demographic/personality traits provided earlier, IGNORING the [Not specified] ones*.\n"
|
| 669 |
+
"- MUST provide *reasonable* and *informative* answers aligned with your background."
|
| 670 |
+
f"- Start the response with '{opening_generic}' ; {style_hint} {lexical_hint}\n"
|
| 671 |
+
# f"- Start the answer some variations of \'About my personal thoughts on *{health_topic}*, I \' \n"
|
| 672 |
+
# f"- Start the answer with something like: When thinking about {health_topic}, I ..."
|
| 673 |
+
)
|
| 674 |
+
# c. Contruct LLM message
|
| 675 |
+
# print("USER PROMPT:", USER_PROMPT)
|
| 676 |
+
messages = [
|
| 677 |
+
{"role": "user", "content": SYSTEM_PROMPT + USER_PROMPT}
|
| 678 |
+
]
|
| 679 |
+
assert "gemma" in cfgs["model"], "Currently only gemma model is supported for no-image input"
|
| 680 |
+
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
|
| 681 |
+
inputs = tokenizer(
|
| 682 |
+
input_text,
|
| 683 |
+
add_special_tokens = False,
|
| 684 |
+
return_tensors = "pt",
|
| 685 |
+
).to(device)
|
| 686 |
+
############################
|
| 687 |
+
### Text LLM Streaming ###
|
| 688 |
+
############################
|
| 689 |
+
# generation with streamer
|
| 690 |
+
generate_kwargs = dict(
|
| 691 |
+
**inputs,
|
| 692 |
+
streamer=streamer,
|
| 693 |
+
max_new_tokens=512,
|
| 694 |
+
use_cache=True,
|
| 695 |
+
# min_p=0.3,
|
| 696 |
+
top_k=15,
|
| 697 |
+
temperature=0.8,
|
| 698 |
+
do_sample=True, # cfgs["stochastic"]
|
| 699 |
+
)
|
| 700 |
+
# separate thread to run generation
|
| 701 |
+
thread = threading.Thread(
|
| 702 |
+
target=model.generate,
|
| 703 |
+
kwargs=generate_kwargs
|
| 704 |
+
)
|
| 705 |
+
thread.start()
|
| 706 |
+
# stream out generation
|
| 707 |
+
outputs = [f"Emulated traits:\n {demo_info}\n" + '='*20 + "\n\n"]
|
| 708 |
+
for new_token in streamer:
|
| 709 |
+
outputs.append(new_token)
|
| 710 |
+
final_output = ''.join(outputs)
|
| 711 |
+
yield final_output
|
| 712 |
+
thread.join()
|
| 713 |
+
|
| 714 |
+
except GeneratorExit:
|
| 715 |
+
print("User disconnected. Waiting for generation to complete...")
|
| 716 |
+
finally:
|
| 717 |
+
# Ensure cleanup happens even on normal finish or errors
|
| 718 |
+
if thread is not None and thread.is_alive():
|
| 719 |
+
thread.join()
|
| 720 |
+
torch.cuda.empty_cache()
|
| 721 |
+
|
| 722 |
+
"""###########################################################################
|
| 723 |
+
Evaluate a given model (specified in model_cfgs)
|
| 724 |
+
on posters with given test_style
|
| 725 |
+
|
| 726 |
+
Args:
|
| 727 |
+
+ cfgs : specify model type (e.g. gemma or llama),
|
| 728 |
+
data source, and export paths
|
| 729 |
+
+ prompts : set of prompts
|
| 730 |
+
|
| 731 |
+
Outputs:
|
| 732 |
+
=> save model in cfgs["export_path"] (CSV file)
|
| 733 |
+
+ if cfgs["export_path"] not exists, initialize it with cfgs["data_path"]
|
| 734 |
+
=> original survey data with ground-truth responses
|
| 735 |
+
+ add column "<model>:<version>": store AI-simulated responses
|
| 736 |
+
+ support concurrent evaluation on different jobs
|
| 737 |
+
##########################################################################"""
|
| 738 |
+
if __name__ == '__main__':
|
| 739 |
+
"""==========================================
|
| 740 |
+
1. load model settings & prompts format
|
| 741 |
+
=========================================="""
|
| 742 |
+
######################################
|
| 743 |
+
# Load model configs & prompts
|
| 744 |
+
######################################
|
| 745 |
+
model_cfg = "./configs/task1_demo_sph.yaml"
|
| 746 |
+
prompt_cfg = "./configs/prompts.yaml"
|
| 747 |
+
cfgs = load_config(model_cfg)
|
| 748 |
+
cfg_prompts = load_config(prompt_cfg)
|
| 749 |
+
|
| 750 |
+
"""==========================================
|
| 751 |
+
2. Evaluate model defined in configs
|
| 752 |
+
=========================================="""
|
| 753 |
+
print(colored('MODEL USE:', 'green'), cfgs["model"])
|
| 754 |
+
# print(prompts['SYSTEM'])
|
| 755 |
+
# print(prompts['INSTRUCTION'])
|
| 756 |
+
|
| 757 |
+
"""===============================
|
| 758 |
+
3. Initialize model
|
| 759 |
+
=> `model`, `tokenizer`
|
| 760 |
+
are initialized here
|
| 761 |
+
==============================="""
|
| 762 |
+
assert cfgs["infer_engine"] == "unsloth", "Only unsloth inference is supported"
|
| 763 |
+
assert cfgs["vision"] == True, "Must have vision input"
|
| 764 |
+
if cfgs["vision"]:
|
| 765 |
+
#################################################
|
| 766 |
+
### (1) MAIN MODEL
|
| 767 |
+
### => response emulation, fine-tuned model
|
| 768 |
+
#################################################
|
| 769 |
+
# WITH VISUAL STIMULI
|
| 770 |
+
model, tokenizer = FastVisionModel.from_pretrained(
|
| 771 |
+
model_name=cfgs["model"],
|
| 772 |
+
load_in_4bit=True,
|
| 773 |
+
)
|
| 774 |
+
FastVisionModel.for_inference(model)
|
| 775 |
+
if "gemma" in cfgs["model"]:
|
| 776 |
+
# gemma-specific tokenizer chat template
|
| 777 |
+
tokenizer = get_chat_template(
|
| 778 |
+
tokenizer,
|
| 779 |
+
chat_template = "gemma-3",
|
| 780 |
+
)
|
| 781 |
+
#################################################
|
| 782 |
+
### (2) AUXILLIARY MODEL
|
| 783 |
+
### => summarization model
|
| 784 |
+
### => larger (12b) for better summarization
|
| 785 |
+
#################################################
|
| 786 |
+
# model_aux, tokenizer_aux = FastVisionModel.from_pretrained(
|
| 787 |
+
# model_name=cfgs["model_summarize"],
|
| 788 |
+
# load_in_4bit=True,
|
| 789 |
+
# )
|
| 790 |
+
# FastVisionModel.for_inference(model)
|
| 791 |
+
# if "gemma" in cfgs["model"]:
|
| 792 |
+
# # gemma-specific tokenizer chat template
|
| 793 |
+
# tokenizer_aux = get_chat_template(
|
| 794 |
+
# tokenizer_aux,
|
| 795 |
+
# chat_template = "gemma-3",
|
| 796 |
+
# )
|
| 797 |
+
|
| 798 |
+
# # initialize streamer tokens
|
| 799 |
+
# streamer = TextIteratorStreamer(
|
| 800 |
+
# tokenizer, skip_prompt=True, skip_special_tokens=True
|
| 801 |
+
# )
|
| 802 |
+
# streamer_aux = TextIteratorStreamer(
|
| 803 |
+
# tokenizer_aux, skip_prompt=True, skip_special_tokens=True
|
| 804 |
+
# )
|
| 805 |
+
|
| 806 |
+
"""=============================================
|
| 807 |
+
4. User-input Dropdown Traits
|
| 808 |
+
============================================="""
|
| 809 |
+
#################################
|
| 810 |
+
### Gradio Interface ###
|
| 811 |
+
#################################
|
| 812 |
+
with gr.Blocks(theme="gradio/dark") as interface:
|
| 813 |
+
# --- Title Page with Logo ---
|
| 814 |
+
LOGO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/umd_logo.png"))
|
| 815 |
+
gr.Image(value=LOGO_PATH, show_label=False, interactive=False, height=100)
|
| 816 |
+
gr.Markdown(
|
| 817 |
+
"""
|
| 818 |
+
<div style="text-align: center;">
|
| 819 |
+
<h1 style="margin-bottom: 0.5em;">
|
| 820 |
+
UMD AI-Empowered Response Prediction in Public Health Messaging
|
| 821 |
+
</h1>
|
| 822 |
+
</div>
|
| 823 |
+
|
| 824 |
+
<hr style="margin-top: 0.8em; margin-bottom: 0.8em;"> <!-- thinner spacing around line -->
|
| 825 |
+
|
| 826 |
+
<div style="text-align: center;">
|
| 827 |
+
<h2 style="margin-top: 0.3em; margin-bottom: 0.6em;">
|
| 828 |
+
User Guide
|
| 829 |
+
</h2>
|
| 830 |
+
</div>
|
| 831 |
+
|
| 832 |
+
<ul style="text-align: left; max-width: 800px; margin: auto;">
|
| 833 |
+
<li>This program emulates <b>demographic- and personality-conditioned responses</b> to public health posters using our trait-aligned Vision-Language Model (VLM).</li>
|
| 834 |
+
<li>To begin, (1) specify the target demographic traits, then (2) upload a public health poster to predict responses.</li>
|
| 835 |
+
<li>If a health poster is uploaded, the model first summarizes its understanding of the image.</li>
|
| 836 |
+
<li><b>Please note:</b>
|
| 837 |
+
<ul>
|
| 838 |
+
<li>Each interaction only uses the uploaded image and selected traits (no conversation history).</li>
|
| 839 |
+
<li>You don’t need to type any text prompt; just upload the Health Poster and click <b>Submit</b>.</li>
|
| 840 |
+
<li>If no poster or image is uploaded, the program automatically generates the emulated person’s <b>general opinion</b> on the selected Health Topic.</li>
|
| 841 |
+
<li>Please do not interrupt the generation process as it can lead to unexpected results. In case it happens, simply refresh the web app.</li>
|
| 842 |
+
<li><b>Limitation:</b> The model may generate less realistic emulations to some under-represented demographics in the survey dataset (e.g., Asian seniors). We are conducting more comprehensive survey to effectively address this limitation.</li>
|
| 843 |
+
</ul>
|
| 844 |
+
</li>
|
| 845 |
+
</ul>
|
| 846 |
+
|
| 847 |
+
<hr style="margin-top: 0.8em; margin-bottom: 1.2em;">
|
| 848 |
+
""",
|
| 849 |
+
elem_id="intro-section"
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
# Scroll to intro section on load
|
| 853 |
+
gr.HTML("""
|
| 854 |
+
<script>
|
| 855 |
+
window.onload = function() {
|
| 856 |
+
window.scrollTo({ top: 0, behavior: 'smooth' });
|
| 857 |
+
}
|
| 858 |
+
</script>
|
| 859 |
+
""")
|
| 860 |
+
|
| 861 |
+
##########################
|
| 862 |
+
### Demographic Traits ###
|
| 863 |
+
##########################
|
| 864 |
+
gr.Markdown("## 1. Please specify the target demographic traits to be emulated here:")
|
| 865 |
+
# Dropdowns (single-select, no custom values)
|
| 866 |
+
with gr.Row():
|
| 867 |
+
gender = gr.Dropdown(
|
| 868 |
+
label="Gender",
|
| 869 |
+
choices=TRAIT_VALUES["Gender"],
|
| 870 |
+
allow_custom_value=False,
|
| 871 |
+
value="Female",
|
| 872 |
+
)
|
| 873 |
+
age = gr.Dropdown(
|
| 874 |
+
label="Age",
|
| 875 |
+
choices=TRAIT_VALUES["Age"],
|
| 876 |
+
allow_custom_value=False,
|
| 877 |
+
value="25–34",
|
| 878 |
+
)
|
| 879 |
+
profession = gr.Dropdown(
|
| 880 |
+
label="Current Profession",
|
| 881 |
+
choices=TRAIT_VALUES["Current Profession"], # keep given order
|
| 882 |
+
allow_custom_value=False,
|
| 883 |
+
value="Student",
|
| 884 |
+
)
|
| 885 |
+
with gr.Row():
|
| 886 |
+
race = gr.Dropdown(
|
| 887 |
+
label="Race/Ethnicity",
|
| 888 |
+
choices=TRAIT_VALUES["Race/Ethnicity"],
|
| 889 |
+
allow_custom_value=False,
|
| 890 |
+
value="White/Caucasian",
|
| 891 |
+
)
|
| 892 |
+
religion = gr.Dropdown(
|
| 893 |
+
label="Religious/Cultural Group",
|
| 894 |
+
choices=TRAIT_VALUES["Religious/Cultural Group"],
|
| 895 |
+
allow_custom_value=False,
|
| 896 |
+
value="Leave Blank",
|
| 897 |
+
)
|
| 898 |
+
political = gr.Dropdown(
|
| 899 |
+
label="Political Affiliation",
|
| 900 |
+
choices=TRAIT_VALUES["Political Affiliation"],
|
| 901 |
+
allow_custom_value=False,
|
| 902 |
+
value="Leave Blank",
|
| 903 |
+
)
|
| 904 |
+
with gr.Row():
|
| 905 |
+
education = gr.Dropdown(
|
| 906 |
+
label="Highest Education",
|
| 907 |
+
choices=TRAIT_VALUES["Highest Education"],
|
| 908 |
+
allow_custom_value=False,
|
| 909 |
+
value="Leave Blank",
|
| 910 |
+
)
|
| 911 |
+
income = gr.Dropdown(
|
| 912 |
+
label="Annual Household Income",
|
| 913 |
+
choices=TRAIT_VALUES["Annual Household Income"],
|
| 914 |
+
allow_custom_value=False,
|
| 915 |
+
value="$75,000–$99,999",
|
| 916 |
+
)
|
| 917 |
+
family_status = gr.Dropdown(
|
| 918 |
+
label="Family Status",
|
| 919 |
+
choices=TRAIT_VALUES["Family Status"],
|
| 920 |
+
allow_custom_value=False,
|
| 921 |
+
value="Leave Blank"
|
| 922 |
+
)
|
| 923 |
+
# ##########################
|
| 924 |
+
# ### Big Five Traits ###
|
| 925 |
+
# ##########################
|
| 926 |
+
# gr.Markdown("## 1.b) Please adjust the Big Five Personality Traits to be emulated:")
|
| 927 |
+
# with gr.Accordion("Big Five Personality Traits (1 = very low, 5 = very high)", open=True):
|
| 928 |
+
# gr.Markdown(
|
| 929 |
+
# "Adjust the sliders to represent the target personality profile. "
|
| 930 |
+
# "Leave them as-is if not applicable."
|
| 931 |
+
# )
|
| 932 |
+
# with gr.Row():
|
| 933 |
+
# with gr.Column(scale=1):
|
| 934 |
+
# openness = gr.Slider(
|
| 935 |
+
# label="Open-Mindedness",
|
| 936 |
+
# minimum=1, maximum=5, step=0.2, value=2.5,
|
| 937 |
+
# interactive=True
|
| 938 |
+
# )
|
| 939 |
+
# with gr.Column(scale=1):
|
| 940 |
+
# conscientiousness = gr.Slider(
|
| 941 |
+
# label="Conscientiousness",
|
| 942 |
+
# minimum=1, maximum=5, step=0.2, value=2.5,
|
| 943 |
+
# interactive=True
|
| 944 |
+
# )
|
| 945 |
+
# with gr.Column(scale=1):
|
| 946 |
+
# extraversion = gr.Slider(
|
| 947 |
+
# label="Extraversion",
|
| 948 |
+
# minimum=1, maximum=5, step=0.2, value=2.5,
|
| 949 |
+
# interactive=True
|
| 950 |
+
# )
|
| 951 |
+
# with gr.Row():
|
| 952 |
+
# with gr.Column(scale=1):
|
| 953 |
+
# neuroticism = gr.Slider(
|
| 954 |
+
# label="Neuroticism",
|
| 955 |
+
# minimum=1, maximum=5, step=0.2, value=2.5,
|
| 956 |
+
# interactive=True
|
| 957 |
+
# )
|
| 958 |
+
# with gr.Column(scale=1):
|
| 959 |
+
# agreeableness = gr.Slider(
|
| 960 |
+
# label="Agreeableness",
|
| 961 |
+
# minimum=1, maximum=5, step=0.2, value=2.5,
|
| 962 |
+
# interactive=True
|
| 963 |
+
# )
|
| 964 |
+
# gr.Column(scale=1) # right spacer
|
| 965 |
+
|
| 966 |
+
##########################
|
| 967 |
+
### Health Topic ###
|
| 968 |
+
##########################
|
| 969 |
+
gr.Markdown("## 2. Please specify the main Health Topic of the poster here:")
|
| 970 |
+
# ---- dropdown at ~50% page width and centered ----
|
| 971 |
+
with gr.Row():
|
| 972 |
+
with gr.Column(scale=1):
|
| 973 |
+
health_topic = gr.Dropdown(
|
| 974 |
+
label="Health Topic",
|
| 975 |
+
choices=HEALTH_TOPICS,
|
| 976 |
+
allow_custom_value=False,
|
| 977 |
+
)
|
| 978 |
+
gr.Column(scale=1) # right spacer
|
| 979 |
+
##########################
|
| 980 |
+
### Chat interface ###
|
| 981 |
+
##########################
|
| 982 |
+
gr.Markdown("## 3. Upload Public Health Poster here (if no poster is uploaded, the model emulates General Response to the topic):")
|
| 983 |
+
gr.Markdown("""
|
| 984 |
+
#### ▶️ Use Case 1: Poster-Based Response
|
| 985 |
+
+ Upload **only one** poster image — the first file is the one processed.
|
| 986 |
+
+ The model has **no memory**, so re-upload the image for each new request.
|
| 987 |
+
+ Must choose a **Health Topic** that matches the poster content for best results.
|
| 988 |
+
+ No text prompt is needed: upload the poster and click **Submit**.
|
| 989 |
+
#### ▶️ Use Case 2: General Response (No Poster)
|
| 990 |
+
+ Simply select a Health Topic and click **Send**.
|
| 991 |
+
"""
|
| 992 |
+
)
|
| 993 |
+
gr.Markdown("""
|
| 994 |
+
### 📘 Important Notes
|
| 995 |
+
- ⚠️ **Do not interrupt the generation process.** Stopping midway can cause backend issues. Please allow the response to complete.
|
| 996 |
+
- 🏷️ Before uploading a poster, select its **corresponding health topic**.
|
| 997 |
+
- 🎯 For the best experience, ensure the **topic accurately matches the poster content**.
|
| 998 |
+
- 🧩 If you choose not to upload a poster, the model will produce a **general, trait-conditioned response** for the selected topic.
|
| 999 |
+
""")
|
| 1000 |
+
chat = gr.ChatInterface(
|
| 1001 |
+
fn=vlm_response,
|
| 1002 |
+
multimodal=True, # text + image
|
| 1003 |
+
title=f"Vision-Language Model: Trait-Conditioned Response Emulation",
|
| 1004 |
+
type="messages",
|
| 1005 |
+
additional_inputs=[
|
| 1006 |
+
health_topic, gender, age, profession, race, religion,
|
| 1007 |
+
political, education, income, family_status,
|
| 1008 |
+
# extraversion, agreeableness, conscientiousness, neuroticism, openness,
|
| 1009 |
+
],
|
| 1010 |
+
chatbot=gr.Chatbot(height=500), # height=330
|
| 1011 |
+
autofocus=False,
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
+
"""=============================================
|
| 1015 |
+
5. Chat Interface Launch
|
| 1016 |
+
============================================="""
|
| 1017 |
+
interface.queue(
|
| 1018 |
+
max_size=20,
|
| 1019 |
+
default_concurrency_limit=1,
|
| 1020 |
+
).launch(
|
| 1021 |
+
share=True,
|
| 1022 |
+
max_threads=1,
|
| 1023 |
+
# show_error=True,
|
| 1024 |
+
# prevent_thread_lock=False,
|
| 1025 |
+
# debug=True,
|
| 1026 |
+
)
|
app.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH -c 16 # 16 CPUs
|
| 3 |
+
#SBATCH --mem=32g # 32 GB RAM
|
| 4 |
+
#SBATCH --gres=gpu:rtxa5000:1 # 1 GPU (A6000)
|
| 5 |
+
#SBATCH --time=3-00:00:00 # 8 days
|
| 6 |
+
#SBATCH --account=gamma
|
| 7 |
+
#SBATCH --partition=gamma
|
| 8 |
+
#SBATCH --qos=gamma-huge-long
|
| 9 |
+
#SBATCH --output=/fs/nexus-projects/health_sim_ai/src_hf_deploy/app_logs/app_%j.out
|
| 10 |
+
|
| 11 |
+
export HOME=/fs/nexus-projects/health_sim_ai
|
| 12 |
+
cd /fs/nexus-projects/health_sim_ai
|
| 13 |
+
source venvs/llm/bin/activate
|
| 14 |
+
cd src_hf_deploy
|
| 15 |
+
python -u app.py
|
| 16 |
+
# python inference_pred_llm.py
|
| 17 |
+
# python inference_rec_llm.py
|
app_logs/app_5936040.out
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
|
| 2 |
+
🦥 Unsloth Zoo will now patch everything to make training faster!
|
| 3 |
+
MODEL USE: unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits
|
| 4 |
+
==((====))== Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0.
|
| 5 |
+
\\ /| NVIDIA RTX A5000. Num GPUs = 1. Max memory: 23.547 GB. Platform: Linux.
|
| 6 |
+
O^O/ \_/ \ Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
|
| 7 |
+
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
|
| 8 |
+
"-____-" Free license: http://github.com/unslothai/unsloth
|
| 9 |
+
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
|
| 10 |
+
|
| 11 |
+
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
|
| 12 |
+
/fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/blocks.py:1069: UserWarning: Cannot load gradio/dark. Caught Exception: The space gradio/dark does not exist
|
| 13 |
+
warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
|
| 14 |
+
/fs/nexus-projects/health_sim_ai/src_hf_deploy/app.py:1010: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.
|
| 15 |
+
chatbot=gr.Chatbot(height=500), # height=330
|
| 16 |
+
/fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/chat_interface.py:323: UserWarning: The type of the gr.Chatbot does not match the type of the gr.ChatInterface.The type of the gr.ChatInterface, 'messages', will be used.
|
| 17 |
+
warnings.warn(
|
| 18 |
+
slurmstepd: error: *** JOB 5936040 ON gammagpu09 CANCELLED AT 2025-12-08T03:01:34 ***
|
app_logs/app_5936041.out
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
|
| 2 |
+
🦥 Unsloth Zoo will now patch everything to make training faster!
|
| 3 |
+
MODEL USE: unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits
|
| 4 |
+
==((====))== Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0.
|
| 5 |
+
\\ /| NVIDIA RTX A5000. Num GPUs = 1. Max memory: 23.547 GB. Platform: Linux.
|
| 6 |
+
O^O/ \_/ \ Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
|
| 7 |
+
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
|
| 8 |
+
"-____-" Free license: http://github.com/unslothai/unsloth
|
| 9 |
+
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
|
| 10 |
+
|
| 11 |
+
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
|
| 12 |
+
/fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/blocks.py:1069: UserWarning: Cannot load gradio/dark. Caught Exception: The space gradio/dark does not exist
|
| 13 |
+
warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
|
| 14 |
+
/fs/nexus-projects/health_sim_ai/src_hf_deploy/app.py:1010: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.
|
| 15 |
+
chatbot=gr.Chatbot(height=500), # height=330
|
| 16 |
+
/fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/chat_interface.py:323: UserWarning: The type of the gr.Chatbot does not match the type of the gr.ChatInterface.The type of the gr.ChatInterface, 'messages', will be used.
|
| 17 |
+
warnings.warn(
|
| 18 |
+
slurmstepd: error: *** JOB 5936041 ON gammagpu09 CANCELLED AT 2025-12-08T03:07:56 ***
|
app_logs/app_5936047.out
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
|
| 2 |
+
🦥 Unsloth Zoo will now patch everything to make training faster!
|
| 3 |
+
MODEL USE: unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits
|
| 4 |
+
==((====))== Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0.
|
| 5 |
+
\\ /| NVIDIA RTX A5000. Num GPUs = 1. Max memory: 23.547 GB. Platform: Linux.
|
| 6 |
+
O^O/ \_/ \ Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
|
| 7 |
+
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
|
| 8 |
+
"-____-" Free license: http://github.com/unslothai/unsloth
|
| 9 |
+
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
|
| 10 |
+
|
| 11 |
+
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
|
| 12 |
+
/fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/blocks.py:1069: UserWarning: Cannot load gradio/dark. Caught Exception: The space gradio/dark does not exist
|
| 13 |
+
warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
|
| 14 |
+
/fs/nexus-projects/health_sim_ai/src_hf_deploy/app.py:1010: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.
|
| 15 |
+
chatbot=gr.Chatbot(height=500), # height=330
|
| 16 |
+
/fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/chat_interface.py:323: UserWarning: The type of the gr.Chatbot does not match the type of the gr.ChatInterface.The type of the gr.ChatInterface, 'messages', will be used.
|
| 17 |
+
warnings.warn(
|
| 18 |
+
grep: gradio_output.log: No such file or directory
|
| 19 |
+
Gradio Public URL:
|
app_logs/app_5936050.out
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
slurmstepd: error: *** JOB 5936050 ON gammagpu09 CANCELLED AT 2025-12-08T03:23:42 ***
|
app_logs/app_5936052.out
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
|
| 2 |
+
🦥 Unsloth Zoo will now patch everything to make training faster!
|
| 3 |
+
MODEL USE: unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits
|
| 4 |
+
==((====))== Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0.
|
| 5 |
+
\\ /| NVIDIA RTX A5000. Num GPUs = 1. Max memory: 23.547 GB. Platform: Linux.
|
| 6 |
+
O^O/ \_/ \ Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
|
| 7 |
+
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
|
| 8 |
+
"-____-" Free license: http://github.com/unslothai/unsloth
|
| 9 |
+
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
|
| 10 |
+
|
| 11 |
+
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
|
| 12 |
+
/fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/blocks.py:1069: UserWarning: Cannot load gradio/dark. Caught Exception: The space gradio/dark does not exist
|
| 13 |
+
warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
|
| 14 |
+
/fs/nexus-projects/health_sim_ai/src_hf_deploy/app.py:1010: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.
|
| 15 |
+
chatbot=gr.Chatbot(height=500), # height=330
|
| 16 |
+
/fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/chat_interface.py:323: UserWarning: The type of the gr.Chatbot does not match the type of the gr.ChatInterface.The type of the gr.ChatInterface, 'messages', will be used.
|
| 17 |
+
warnings.warn(
|
| 18 |
+
* Running on local URL: http://127.0.0.1:7860
|
| 19 |
+
* Running on public URL: https://8a035fb4eb42d29651.gradio.live
|
| 20 |
+
|
| 21 |
+
This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
|
| 22 |
+
Chronic Obstructive Pulmonary Disease (COPD)
|
| 23 |
+
Style: Write in a informal, pragmatic tone, focusing on clarity and utility.
|
| 24 |
+
Lexical: Feel free to vary sentence structures slightly.
|
| 25 |
+
Opening: Through the message
|
| 26 |
+
Generic opening: My initial take on
|
| 27 |
+
Chronic Obstructive Pulmonary Disease (COPD)
|
| 28 |
+
Style: Write in a slightly narrative, flowing tone.
|
| 29 |
+
Lexical: Use a light mix of paraphrasing expressions.
|
| 30 |
+
Opening: Through the message
|
| 31 |
+
Generic opening: On top of my head
|
| 32 |
+
The message makes me more concerned about the health risks of COPD and smoking - Scale: 1 (not at all) - 9 (extremely). Your answer: 8
|
| 33 |
+
The message motivates me to not smoke. - Scale: 1 (not at all) - 9 (extremely). Your answer: 8
|
| 34 |
+
In your opinion, how harmful is smoking to your general health? - Scale: 0 (not at all)-6 (extremely harmful). Your answer: 6
|
| 35 |
+
|
| 36 |
+
Nutrition
|
| 37 |
+
Style: Write in a slightly narrative, flowing tone.
|
| 38 |
+
Lexical: Use a mix of simple and slightly complex sentences.
|
| 39 |
+
Opening: Reflecting on the message here
|
| 40 |
+
Generic opening: Just speaking for myself,
|
| 41 |
+
The message makes me more concerned about the health risks of poor eating habits - Scale: 1 (not at all) - 9 (extremely). Your answer: 8
|
| 42 |
+
The message motivates me to make healthy eating choices - Scale: 1 (not at all) - 9 (extremely). Your answer: 8
|
| 43 |
+
In your opinion, how harmful is neglecting proper nutrition and weight management to your overall health? - Scale: 0 (not at all)-6 (extremely harmful). Your answer: 6
|
| 44 |
+
|
| 45 |
+
Traits Demographics:
|
| 46 |
+
Gender: Female
|
| 47 |
+
Age: 25–34
|
| 48 |
+
Current Profession: Student
|
| 49 |
+
Race/Ethnicity: White/Caucasian
|
| 50 |
+
Religious/Cultural Group: [Not specified]
|
| 51 |
+
Political Affiliation: [Not specified]
|
| 52 |
+
Highest Education: [Not specified]
|
| 53 |
+
Annual Household Income: $75,000–$99,999
|
| 54 |
+
Family Status: [Not specified]
|
| 55 |
+
|
| 56 |
+
Emulated response: Reflecting on the message here, I'm now very concerned about the health consequences of poor eating. The message really motivates me to make healthy choices - I feel more determined than ever to prioritize my nutrition and maintain a healthy weight. It's made me realize the importance of mindful eating and making informed food choices.
|
| 57 |
+
====================================================================================================
|
assets/umd_logo.png
ADDED
|
Git LFS Details
|
configs/prompts.yaml
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#########################################################
|
| 2 |
+
### TASK 1: COMMUNITY SIMULATION ###
|
| 3 |
+
#########################################################
|
| 4 |
+
# SYSTEM PROMPT FOR COMMUNITY RESPONSE PREDICTION
|
| 5 |
+
SYSTEM_SIM: >
|
| 6 |
+
You are a person with unique demographic and personality traits.
|
| 7 |
+
During an online study, you are shown a public health campaign poster.
|
| 8 |
+
Based on your background, you naturally have thoughts, feelings, and reactions to what you see.
|
| 9 |
+
# SIMULATION PROMPT FOR COMMUNITY RESPONSE PREDICTION
|
| 10 |
+
SIMULATION_SIM: >
|
| 11 |
+
You are now being shown a public health campaign poster, followed by a survey question
|
| 12 |
+
designed to capture your thoughts, feelings, and emotions in response to the image.
|
| 13 |
+
# TASK 1: RESPONSE PREDICTION -> MCQ (SENTIMENT, BEHAVIORAL, EMOTIONAL)
|
| 14 |
+
INSTRUCTION_MCQ: |
|
| 15 |
+
Please respond the survey question authentically, as if you are completing a real online survey. Your personality, locus of control, and demographic traits influence your reactions.
|
| 16 |
+
**CRITICAL INSTRUCTIONS - FOLLOW THESE EXACTLY:**
|
| 17 |
+
- **BE REALISTIC and HUMAN-LIKE. Only answer the questions. Imagine you're quickly filling out a survey.**
|
| 18 |
+
- Answer with **ONLY the valid number** to realistically express your emotions/feeling.
|
| 19 |
+
- Only provide the answer. DO NOT REPEAT THE QUESTION NOR PROVIDE REASONING.
|
| 20 |
+
# TASK 1: RESPONSE PREDICTION -> OPEN-TEXT
|
| 21 |
+
INSTRUCTION_FREE: |
|
| 22 |
+
Please respond the survey question authentically, as if you are completing a real online survey. Your personality, locus of control, and demographic traits influence your reactions.
|
| 23 |
+
**CRITICAL INSTRUCTIONS - FOLLOW THESE EXACTLY:**
|
| 24 |
+
- **BE REALISTIC and HUMAN-LIKE. Only answer the questions. Imagine you're quickly filling out a survey.**
|
| 25 |
+
- Answer very briefly in **1-3 sentences**.
|
| 26 |
+
- Only provide the answer. DO NOT REPEAT THE QUESTION NOR PROVIDE REASONING.
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
#########################################################
|
| 30 |
+
### TASK 2: COMMUNICATION STRATEGY SUGGESTION ###
|
| 31 |
+
#########################################################
|
| 32 |
+
# SYSTEM PROMPT FOR COMMUNICATION STRATEGY SUGGESTION
|
| 33 |
+
SYSTEM_REC: >
|
| 34 |
+
You are not an AI or language model. You are an expert with PhD-level training in demography, behavioral science, psychology, marketing, and public health.
|
| 35 |
+
You have deep knowledge of how individuals and communities with varying personality traits and demographic profiles respond to different campaign messaging strategies and visual stimuli.
|
| 36 |
+
# SIMULATION PROMPT FOR COMMUNICATION STRATEGY SUGGESTION
|
| 37 |
+
SIMULATION_REC: |
|
| 38 |
+
You need to evaluate campaign messaging strategies intended to influence a specific individual or community, characterized by known demographic and personality traits.
|
| 39 |
+
There are 3 main types of communication strategies:
|
| 40 |
+
(1) Informational/Neutral
|
| 41 |
+
(2) Self-Efficacy
|
| 42 |
+
(3) Threatening/Fear-driven
|
| 43 |
+
# TASK 2: STRATEGY EVALUATION
|
| 44 |
+
INSTRUCTION_REC: |
|
| 45 |
+
Evaluate the likely impact of the above communication strategy on the specified individual/community.
|
| 46 |
+
(1) Negative impact (expected response score of 1–3 out of 9)
|
| 47 |
+
(2) No impact (expected response score of 4–6 out of 9)
|
| 48 |
+
(3) Positive impact (expected response score of 7–9 out of 9)
|
| 49 |
+
|
| 50 |
+
Please answer with 1 of 3 following labels only: "positive", "negative", or "no impact".
|
| 51 |
+
|
| 52 |
+
# # TASK 2: STRATEGY SUGGESTION
|
| 53 |
+
# INSTRUCTION_REC_NO_IMPACT: |
|
| 54 |
+
# There are 3 main types of communication strategies:
|
| 55 |
+
# (1) Informational/Neutral
|
| 56 |
+
# (2) Self-Efficacy
|
| 57 |
+
# (3) Threatening/Fear-driven
|
| 58 |
+
|
| 59 |
+
# Based on your expertise, which strategy is most likely to have LITTLE IMPACT (i.e., an expected response score of 4–6 out of 9) on the target individual or community?
|
| 60 |
+
# Suggestion only ONE and only provide the strategy name.
|
| 61 |
+
# INSTRUCTION_REC_POSITIVE: |
|
| 62 |
+
# There are 3 main types of communication strategies:
|
| 63 |
+
# (1) Informational/Neutral
|
| 64 |
+
# (2) Self-Efficacy
|
| 65 |
+
# (3) Threatening/Fear-driven
|
| 66 |
+
|
| 67 |
+
# Based on your expertise, which strategy is most likely to have a POSITIVE IMPACT (i.e., an expected response score of 7–9 out of 9) on the target individual or community?
|
| 68 |
+
# Suggestion only ONE and only provide the strategy name.
|
| 69 |
+
# INSTRUCTION_REC_NEGATIVE: |
|
| 70 |
+
# There are 3 main types of communication strategies:
|
| 71 |
+
# (1) Informational/Neutral
|
| 72 |
+
# (2) Self-Efficacy
|
| 73 |
+
# (3) Threatening/Fear-driven
|
| 74 |
+
|
| 75 |
+
# Based on your expertise, which strategy is most likely to have a NEGATIVE IMPACT (i.e., an expected response score of 1–3 out of 9) on the target individual or community?
|
| 76 |
+
# Suggestion only ONE and only provide the strategy name.
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
#########################################################
|
| 81 |
+
### TASK 3: COMMUNICATION STRATEGY CLASSIFICATION ###
|
| 82 |
+
#########################################################
|
| 83 |
+
# SYSTEM PROMPT FOR COMMUNICATION STRATEGY CLASSIFICATION
|
| 84 |
+
SYSTEM_CLS: >
|
| 85 |
+
You are an expert with PhD qualifications in 5 areas: demography, behavioral science, psychology, marketing, and public health.
|
| 86 |
+
# SIMULATION PROMPT FOR COMMUNICATION STRATEGY CLASSIFICATION
|
| 87 |
+
SIMULATION_CLS: >
|
| 88 |
+
You are now being shown a public health campaign poster.
|
| 89 |
+
# TASK 3: STRATEGY CLASSIFICAITON
|
| 90 |
+
INSTRUCTION_STRAT: |
|
| 91 |
+
There are <?> main types of communication strategies:
|
| 92 |
+
(1)
|
| 93 |
+
(2)
|
| 94 |
+
(3)
|
| 95 |
+
|
| 96 |
+
Based on your experience and expertise, what is the communication strategy of the poster? Choose only one and only include the strategy name.
|
| 97 |
+
|
| 98 |
+
JSON_CONVERSION: >
|
| 99 |
+
Extract the content in this answer to JSON with format: <Q1>: \"<Answer to Q1>\"
|
| 100 |
+
Ensure all questions are properly included (13 questions in total).
|
configs/task1_demo.yaml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
temperature: 0.
|
| 2 |
+
top_p: 1.0
|
| 3 |
+
stochastic: False # deterministic
|
| 4 |
+
seed: 99
|
| 5 |
+
infer_engine: "unsloth"
|
| 6 |
+
data_path: "data/survey_responses_screened.csv" # make sure to export HOME to project path
|
| 7 |
+
# export_path: "$HOME/src/evals/task1_ai_responses.csv"
|
| 8 |
+
|
| 9 |
+
#########################
|
| 10 |
+
### Emulation Model ###
|
| 11 |
+
#########################
|
| 12 |
+
# model: "unsloth/Llama-3.2-11B-Vision-Instruct"
|
| 13 |
+
# model: "unsloth/Llama-3.2-11B-Vision-Instruct_task1_1_epochs_test_train_on_all"
|
| 14 |
+
# model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_train_on_all"
|
| 15 |
+
# model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_neutral"
|
| 16 |
+
# model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_efficacy"
|
| 17 |
+
|
| 18 |
+
# model: "unsloth/gemma-3-12b-it"
|
| 19 |
+
# model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral"
|
| 20 |
+
# model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_threatening_partialTraits"
|
| 21 |
+
model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits"
|
| 22 |
+
vision: true # default
|
| 23 |
+
trait: true # default
|
| 24 |
+
version: ""
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
model_summarize: "unsloth/gemma-3-12b-it"
|
configs/task1_demo_sph.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
temperature: 0.
|
| 2 |
+
top_p: 1.0
|
| 3 |
+
stochastic: False # deterministic
|
| 4 |
+
seed: 99
|
| 5 |
+
infer_engine: "unsloth"
|
| 6 |
+
data_path: "data/survey_responses_screened.csv" # make sure to export HOME to project path
|
| 7 |
+
# export_path: "$HOME/src/evals/task1_ai_responses.csv"
|
| 8 |
+
|
| 9 |
+
#########################
|
| 10 |
+
### Emulation Model ###
|
| 11 |
+
#########################
|
| 12 |
+
# model: "unsloth/Llama-3.2-11B-Vision-Instruct"
|
| 13 |
+
# model: "unsloth/Llama-3.2-11B-Vision-Instruct_task1_1_epochs_test_train_on_all"
|
| 14 |
+
# model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_train_on_all"
|
| 15 |
+
# model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_neutral"
|
| 16 |
+
# model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_efficacy"
|
| 17 |
+
|
| 18 |
+
# model: "unsloth/gemma-3-12b-it"
|
| 19 |
+
# model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral"
|
| 20 |
+
# model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_threatening_partialTraits"
|
| 21 |
+
model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits"
|
| 22 |
+
# model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits_sphTraits"
|
| 23 |
+
vision: true # default
|
| 24 |
+
trait: true # default
|
| 25 |
+
version: ""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
model_summarize: "unsloth/gemma-3-12b-it"
|
data/survey_responses_screened.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9eb0e96347a8739c4d6b138a9395feeec591b8dd64fd0f6a74b857b49bb47b2c
|
| 3 |
+
size 18465749
|
push.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
git init
|
| 2 |
+
git lfs install
|
| 3 |
+
|
| 4 |
+
git add app.py configs data requirements.txt unsloth utils.py
|
| 5 |
+
git commit -m "Initial commit"
|
| 6 |
+
|
| 7 |
+
git branch -M main
|
| 8 |
+
|
| 9 |
+
git lfs migrate import --include-ref=refs/heads/main --above=10MB -y
|
| 10 |
+
|
| 11 |
+
git remote add huggingface https://huggingface.co/spaces/anh-nn01/ai_empowered_community_simulation_beta
|
| 12 |
+
|
| 13 |
+
git push -u huggingface main --
|
| 14 |
+
|
| 15 |
+
# Notes:
|
| 16 |
+
# 1. use module load for lfs
|
| 17 |
+
# 2. use only launch(), not launch(share=True, max_threads=1,)
|
| 18 |
+
# 3. export full requirements.txt using pip freeze > requirements.txt
|
| 19 |
+
# => comment out `ipython` and `ollama` dependencies
|
| 20 |
+
# 4. Manual upload LoRA weights to HF repo due to potential file corruption
|
| 21 |
+
# 5. Manual upload of /app/assets/umd_logo.png
|
| 22 |
+
# 6. Perhaps manual upload everything is more stable for now :))
|
requirements.txt
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.6.0
|
| 2 |
+
aiofiles==24.1.0
|
| 3 |
+
aiohappyeyeballs==2.6.1
|
| 4 |
+
aiohttp==3.11.16
|
| 5 |
+
aiosignal==1.3.2
|
| 6 |
+
annotated-types==0.7.0
|
| 7 |
+
anyio==4.9.0
|
| 8 |
+
asttokens==3.0.1
|
| 9 |
+
attrs==25.3.0
|
| 10 |
+
bitsandbytes==0.45.5
|
| 11 |
+
Brotli==1.1.0
|
| 12 |
+
certifi==2025.1.31
|
| 13 |
+
charset-normalizer==3.4.1
|
| 14 |
+
click==8.1.8
|
| 15 |
+
colored==2.3.0
|
| 16 |
+
contourpy==1.3.2
|
| 17 |
+
cut-cross-entropy==25.1.1
|
| 18 |
+
cycler==0.12.1
|
| 19 |
+
Cython==3.1.2
|
| 20 |
+
dataclasses-json==0.6.7
|
| 21 |
+
datasets==3.5.0
|
| 22 |
+
decorator==5.2.1
|
| 23 |
+
diffusers @ git+https://github.com/huggingface/diffusers.git@ee40088fe5437f8ed65ec96a22250149e4f334cc
|
| 24 |
+
dill==0.3.8
|
| 25 |
+
docker-pycreds==0.4.0
|
| 26 |
+
docstring_parser==0.16
|
| 27 |
+
executing==2.2.1
|
| 28 |
+
fastapi==0.119.0
|
| 29 |
+
ffmpeg==1.4
|
| 30 |
+
ffmpy==0.6.3
|
| 31 |
+
filelock==3.18.0
|
| 32 |
+
fonttools==4.58.0
|
| 33 |
+
frozenlist==1.5.0
|
| 34 |
+
fsspec==2024.12.0
|
| 35 |
+
gitdb==4.0.12
|
| 36 |
+
GitPython==3.1.44
|
| 37 |
+
gradio==5.49.1
|
| 38 |
+
gradio_client==1.13.3
|
| 39 |
+
greenlet==3.2.0
|
| 40 |
+
groovy==0.1.2
|
| 41 |
+
h11==0.14.0
|
| 42 |
+
hf-xet==1.1.10
|
| 43 |
+
hf_transfer==0.1.9
|
| 44 |
+
httpcore==1.0.8
|
| 45 |
+
httpx==0.27.2
|
| 46 |
+
httpx-sse==0.4.0
|
| 47 |
+
huggingface-hub==0.35.3
|
| 48 |
+
idna==3.10
|
| 49 |
+
importlib_metadata==8.6.1
|
| 50 |
+
# ipython==9.8.0
|
| 51 |
+
# ipython_pygments_lexers==1.1.1
|
| 52 |
+
jedi==0.19.2
|
| 53 |
+
Jinja2==3.1.6
|
| 54 |
+
jsonpatch==1.33
|
| 55 |
+
jsonpointer==3.0.0
|
| 56 |
+
kiwisolver==1.4.8
|
| 57 |
+
langchain==0.3.23
|
| 58 |
+
langchain-community==0.3.21
|
| 59 |
+
langchain-core==0.3.52
|
| 60 |
+
langchain-ollama==0.2.1
|
| 61 |
+
langchain-text-splitters==0.3.8
|
| 62 |
+
langsmith==0.3.31
|
| 63 |
+
markdown-it-py==3.0.0
|
| 64 |
+
MarkupSafe==3.0.2
|
| 65 |
+
marshmallow==3.26.1
|
| 66 |
+
matplotlib==3.10.3
|
| 67 |
+
matplotlib-inline==0.2.1
|
| 68 |
+
mdurl==0.1.2
|
| 69 |
+
mpmath==1.3.0
|
| 70 |
+
multidict==6.4.3
|
| 71 |
+
multiprocess==0.70.16
|
| 72 |
+
mypy-extensions==1.0.0
|
| 73 |
+
networkx==3.4.2
|
| 74 |
+
numpy==2.2.4
|
| 75 |
+
nvidia-cublas-cu12==12.4.5.8
|
| 76 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
| 77 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
| 78 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
| 79 |
+
nvidia-cudnn-cu12==9.1.0.70
|
| 80 |
+
nvidia-cufft-cu12==11.2.1.3
|
| 81 |
+
nvidia-curand-cu12==10.3.5.147
|
| 82 |
+
nvidia-cusolver-cu12==11.6.1.9
|
| 83 |
+
nvidia-cusparse-cu12==12.3.1.170
|
| 84 |
+
nvidia-cusparselt-cu12==0.6.2
|
| 85 |
+
nvidia-nccl-cu12==2.21.5
|
| 86 |
+
nvidia-nvjitlink-cu12==12.4.127
|
| 87 |
+
nvidia-nvtx-cu12==12.4.127
|
| 88 |
+
# ollama==0.4.2
|
| 89 |
+
orjson==3.10.16
|
| 90 |
+
packaging==24.2
|
| 91 |
+
pandas==2.2.3
|
| 92 |
+
parso==0.8.5
|
| 93 |
+
peft==0.15.2
|
| 94 |
+
pexpect==4.9.0
|
| 95 |
+
pillow==11.2.1
|
| 96 |
+
platformdirs==4.3.7
|
| 97 |
+
prompt_toolkit==3.0.52
|
| 98 |
+
propcache==0.3.1
|
| 99 |
+
protobuf==3.20.3
|
| 100 |
+
psutil==7.0.0
|
| 101 |
+
ptyprocess==0.7.0
|
| 102 |
+
pure_eval==0.2.3
|
| 103 |
+
pyarrow==19.0.1
|
| 104 |
+
pydantic==2.11.3
|
| 105 |
+
pydantic-settings==2.8.1
|
| 106 |
+
pydantic_core==2.33.1
|
| 107 |
+
pydub==0.25.1
|
| 108 |
+
Pygments==2.19.1
|
| 109 |
+
pyparsing==3.2.3
|
| 110 |
+
python-dateutil==2.9.0.post0
|
| 111 |
+
python-dotenv==1.1.0
|
| 112 |
+
python-multipart==0.0.20
|
| 113 |
+
pytz==2025.2
|
| 114 |
+
PyYAML==6.0.2
|
| 115 |
+
regex==2024.11.6
|
| 116 |
+
requests==2.32.3
|
| 117 |
+
requests-toolbelt==1.0.0
|
| 118 |
+
rich==14.0.0
|
| 119 |
+
ruff==0.14.0
|
| 120 |
+
safehttpx==0.1.6
|
| 121 |
+
safetensors==0.5.3
|
| 122 |
+
seaborn==0.13.2
|
| 123 |
+
semantic-version==2.10.0
|
| 124 |
+
sentencepiece==0.2.0
|
| 125 |
+
sentry-sdk==2.27.0
|
| 126 |
+
setproctitle==1.3.5
|
| 127 |
+
setuptools==79.0.0
|
| 128 |
+
shellingham==1.5.4
|
| 129 |
+
shtab==1.7.2
|
| 130 |
+
six==1.17.0
|
| 131 |
+
smmap==5.0.2
|
| 132 |
+
sniffio==1.3.1
|
| 133 |
+
SQLAlchemy==2.0.40
|
| 134 |
+
stack-data==0.6.3
|
| 135 |
+
starlette==0.48.0
|
| 136 |
+
sympy==1.13.1
|
| 137 |
+
tenacity==9.1.2
|
| 138 |
+
termcolor==3.0.1
|
| 139 |
+
tokenizers==0.21.4
|
| 140 |
+
tomlkit==0.13.3
|
| 141 |
+
torch==2.6.0
|
| 142 |
+
torchvision==0.21.0
|
| 143 |
+
tqdm==4.67.1
|
| 144 |
+
traitlets==5.14.3
|
| 145 |
+
transformers==4.50.0
|
| 146 |
+
triton==3.2.0
|
| 147 |
+
trl==0.15.2
|
| 148 |
+
typeguard==4.4.2
|
| 149 |
+
typer==0.19.2
|
| 150 |
+
typing-inspect==0.9.0
|
| 151 |
+
typing-inspection==0.4.0
|
| 152 |
+
typing_extensions==4.13.2
|
| 153 |
+
tyro==0.9.19
|
| 154 |
+
tzdata==2025.2
|
| 155 |
+
unsloth==2025.3.19
|
| 156 |
+
unsloth_zoo==2025.3.17
|
| 157 |
+
urllib3==2.4.0
|
| 158 |
+
uvicorn==0.37.0
|
| 159 |
+
wandb==0.19.10
|
| 160 |
+
wcwidth==0.2.14
|
| 161 |
+
websockets==15.0.1
|
| 162 |
+
wheel==0.45.1
|
| 163 |
+
xformers==0.0.29.post3
|
| 164 |
+
xxhash==3.5.0
|
| 165 |
+
yarl==1.19.0
|
| 166 |
+
zipp==3.21.0
|
| 167 |
+
zstandard==0.23.0
|
requirements_concise.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy==2.2.4
|
| 2 |
+
pandas==2.2.3
|
| 3 |
+
pillow==11.2.1
|
| 4 |
+
langchain==0.3.23
|
| 5 |
+
langchain-core==0.3.52
|
| 6 |
+
langchain-community==0.3.21
|
| 7 |
+
langchain-ollama==0.2.1
|
| 8 |
+
# ollama==0.4.2
|
| 9 |
+
tqdm
|
| 10 |
+
torch
|
| 11 |
+
unsloth==2025.3.19
|
| 12 |
+
termcolor
|
| 13 |
+
python-dotenv
|
| 14 |
+
transformers==4.50.0
|
| 15 |
+
wandb
|
| 16 |
+
|
| 17 |
+
# Image Generation
|
| 18 |
+
# git+https://github.com/huggingface/diffusers.git
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: unsloth/gemma-3-12b-it-unsloth-bnb-4bit
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.15.2
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/adapter_config.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 8,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"r": 8,
|
| 24 |
+
"rank_pattern": {},
|
| 25 |
+
"revision": null,
|
| 26 |
+
"target_modules": "(?:.*?(?:language|text).*?(?:self_attn|attention|attn|mlp|feed_forward|ffn|dense).*?(?:k_proj|v_proj|q_proj|out_proj|fc1|fc2|o_proj|gate_proj|up_proj|down_proj).*?)|(?:\\bmodel\\.layers\\.[\\d]{1,}\\.(?:self_attn|attention|attn|mlp|feed_forward|ffn|dense)\\.(?:(?:k_proj|v_proj|q_proj|out_proj|fc1|fc2|o_proj|gate_proj|up_proj|down_proj)))",
|
| 27 |
+
"task_type": "CAUSAL_LM",
|
| 28 |
+
"trainable_token_indices": null,
|
| 29 |
+
"use_dora": false,
|
| 30 |
+
"use_rslora": false
|
| 31 |
+
}
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7108dca92843322a12e503ab99cbd70a5f676fa25c54e6a11d88473f65143ee3
|
| 3 |
+
size 131040264
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<image_soft_token>": 262144
|
| 3 |
+
}
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/chat_template.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{ '<start_of_turn>model\n' }}\n{%- endif -%}\n"
|
| 3 |
+
}
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/preprocessor_config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_convert_rgb": null,
|
| 3 |
+
"do_normalize": true,
|
| 4 |
+
"do_pan_and_scan": null,
|
| 5 |
+
"do_rescale": true,
|
| 6 |
+
"do_resize": true,
|
| 7 |
+
"image_mean": [
|
| 8 |
+
0.5,
|
| 9 |
+
0.5,
|
| 10 |
+
0.5
|
| 11 |
+
],
|
| 12 |
+
"image_processor_type": "Gemma3ImageProcessor",
|
| 13 |
+
"image_seq_length": 256,
|
| 14 |
+
"image_std": [
|
| 15 |
+
0.5,
|
| 16 |
+
0.5,
|
| 17 |
+
0.5
|
| 18 |
+
],
|
| 19 |
+
"pan_and_scan_max_num_crops": null,
|
| 20 |
+
"pan_and_scan_min_crop_size": null,
|
| 21 |
+
"pan_and_scan_min_ratio_to_activate": null,
|
| 22 |
+
"processor_class": "Gemma3Processor",
|
| 23 |
+
"resample": 2,
|
| 24 |
+
"rescale_factor": 0.00392156862745098,
|
| 25 |
+
"size": {
|
| 26 |
+
"height": 896,
|
| 27 |
+
"width": 896
|
| 28 |
+
}
|
| 29 |
+
}
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/processor_config.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"image_seq_length": 256,
|
| 3 |
+
"processor_class": "Gemma3Processor"
|
| 4 |
+
}
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/special_tokens_map.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"boi_token": "<start_of_image>",
|
| 3 |
+
"bos_token": {
|
| 4 |
+
"content": "<bos>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false
|
| 9 |
+
},
|
| 10 |
+
"eoi_token": "<end_of_image>",
|
| 11 |
+
"eos_token": {
|
| 12 |
+
"content": "<end_of_turn>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false
|
| 17 |
+
},
|
| 18 |
+
"image_token": "<image_soft_token>",
|
| 19 |
+
"pad_token": {
|
| 20 |
+
"content": "<pad>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false
|
| 25 |
+
},
|
| 26 |
+
"unk_token": {
|
| 27 |
+
"content": "<unk>",
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"normalized": false,
|
| 30 |
+
"rstrip": false,
|
| 31 |
+
"single_word": false
|
| 32 |
+
}
|
| 33 |
+
}
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7666402c0617d170e6b0a985b3130c3fb0795393aa0970600994a5d9aae12351
|
| 3 |
+
size 33384822
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
|
| 3 |
+
size 4689074
|
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import functional as F
|
| 14 |
+
from peft.tuners.lora.aqlm import (torch)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
torch_addmm = torch.addmm
|
| 18 |
+
torch_add = torch.add
|
| 19 |
+
# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
|
| 20 |
+
def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
|
| 21 |
+
xA = dropout(x) @ lora_A.weight.t()
|
| 22 |
+
# output = result + scaling * xA @ lora_B.weight.t()
|
| 23 |
+
shape = result.shape
|
| 24 |
+
output = torch_addmm(
|
| 25 |
+
result.view(-1, shape[-1]),
|
| 26 |
+
xA.view(-1, xA.shape[-1]),
|
| 27 |
+
lora_B.weight.t(),
|
| 28 |
+
alpha = scaling,
|
| 29 |
+
beta = 1,
|
| 30 |
+
).view(shape)
|
| 31 |
+
|
| 32 |
+
bias = lora_B.bias
|
| 33 |
+
if bias is not None:
|
| 34 |
+
output = torch_add(
|
| 35 |
+
output,
|
| 36 |
+
bias,
|
| 37 |
+
alpha = scaling,
|
| 38 |
+
)
|
| 39 |
+
return output
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
def unsloth_forward(self, x: torch.Tensor):
|
| 43 |
+
# note: logic differs from default Linear because merging is not supported
|
| 44 |
+
result = self.base_layer(x)
|
| 45 |
+
|
| 46 |
+
if self.disable_adapters:
|
| 47 |
+
return result
|
| 48 |
+
|
| 49 |
+
for active_adapter in self.active_adapters:
|
| 50 |
+
if active_adapter not in self.lora_A.keys():
|
| 51 |
+
continue
|
| 52 |
+
lora_A = self.lora_A[active_adapter]
|
| 53 |
+
lora_B = self.lora_B[active_adapter]
|
| 54 |
+
dropout = self.lora_dropout[active_adapter]
|
| 55 |
+
scaling = self.scaling[active_adapter]
|
| 56 |
+
|
| 57 |
+
requires_conversion = not torch.is_autocast_enabled()
|
| 58 |
+
if requires_conversion:
|
| 59 |
+
expected_dtype = result.dtype
|
| 60 |
+
x = self._cast_input_dtype(x, lora_A.weight.dtype)
|
| 61 |
+
|
| 62 |
+
output = lora_B(lora_A(dropout(x)))
|
| 63 |
+
if requires_conversion:
|
| 64 |
+
output = output.to(expected_dtype)
|
| 65 |
+
output = output * scaling
|
| 66 |
+
result += output
|
| 67 |
+
return result
|
unsloth_compiled_cache/AwqLoraLinear_peft_forward.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import functional as F
|
| 14 |
+
from peft.tuners.lora.awq import (torch)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
torch_addmm = torch.addmm
|
| 18 |
+
torch_add = torch.add
|
| 19 |
+
# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
|
| 20 |
+
def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
|
| 21 |
+
xA = dropout(x) @ lora_A.weight.t()
|
| 22 |
+
# output = result + scaling * xA @ lora_B.weight.t()
|
| 23 |
+
shape = result.shape
|
| 24 |
+
output = torch_addmm(
|
| 25 |
+
result.view(-1, shape[-1]),
|
| 26 |
+
xA.view(-1, xA.shape[-1]),
|
| 27 |
+
lora_B.weight.t(),
|
| 28 |
+
alpha = scaling,
|
| 29 |
+
beta = 1,
|
| 30 |
+
).view(shape)
|
| 31 |
+
|
| 32 |
+
bias = lora_B.bias
|
| 33 |
+
if bias is not None:
|
| 34 |
+
output = torch_add(
|
| 35 |
+
output,
|
| 36 |
+
bias,
|
| 37 |
+
alpha = scaling,
|
| 38 |
+
)
|
| 39 |
+
return output
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
def unsloth_forward(self, x: torch.Tensor):
|
| 43 |
+
result = self.quant_linear_module(x)
|
| 44 |
+
|
| 45 |
+
if self.disable_adapters:
|
| 46 |
+
return result
|
| 47 |
+
|
| 48 |
+
for active_adapter in self.active_adapters:
|
| 49 |
+
if active_adapter not in self.lora_A.keys():
|
| 50 |
+
continue
|
| 51 |
+
lora_A = self.lora_A[active_adapter]
|
| 52 |
+
lora_B = self.lora_B[active_adapter]
|
| 53 |
+
dropout = self.lora_dropout[active_adapter]
|
| 54 |
+
scaling = self.scaling[active_adapter]
|
| 55 |
+
|
| 56 |
+
requires_conversion = not torch.is_autocast_enabled()
|
| 57 |
+
if requires_conversion:
|
| 58 |
+
expected_dtype = result.dtype
|
| 59 |
+
x = self._cast_input_dtype(x, lora_A.weight.dtype)
|
| 60 |
+
|
| 61 |
+
output = lora_B(lora_A(dropout(x)))
|
| 62 |
+
if requires_conversion:
|
| 63 |
+
output = output.to(expected_dtype)
|
| 64 |
+
output = output * scaling
|
| 65 |
+
result = result + output
|
| 66 |
+
return result
|
unsloth_compiled_cache/BatchNorm1d.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
from transformers.models.gemma3.modeling_gemma3 import (nn)
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 43 |
+
self._check_input_dim(input)
|
| 44 |
+
|
| 45 |
+
# exponential_average_factor is set to self.momentum
|
| 46 |
+
# (when it is available) only so that it gets updated
|
| 47 |
+
# in ONNX graph when this node is exported to ONNX.
|
| 48 |
+
if self.momentum is None:
|
| 49 |
+
exponential_average_factor = 0.0
|
| 50 |
+
else:
|
| 51 |
+
exponential_average_factor = self.momentum
|
| 52 |
+
|
| 53 |
+
if self.training and self.track_running_stats:
|
| 54 |
+
# TODO: if statement only here to tell the jit to skip emitting this when it is None
|
| 55 |
+
if self.num_batches_tracked is not None: # type: ignore[has-type]
|
| 56 |
+
self.num_batches_tracked.add_(1) # type: ignore[has-type]
|
| 57 |
+
if self.momentum is None: # use cumulative moving average
|
| 58 |
+
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
|
| 59 |
+
else: # use exponential moving average
|
| 60 |
+
exponential_average_factor = self.momentum
|
| 61 |
+
|
| 62 |
+
r"""
|
| 63 |
+
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
|
| 64 |
+
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
|
| 65 |
+
"""
|
| 66 |
+
if self.training:
|
| 67 |
+
bn_training = True
|
| 68 |
+
else:
|
| 69 |
+
bn_training = (self.running_mean is None) and (self.running_var is None)
|
| 70 |
+
|
| 71 |
+
r"""
|
| 72 |
+
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
|
| 73 |
+
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
|
| 74 |
+
used for normalization (i.e. in eval mode when buffers are not None).
|
| 75 |
+
"""
|
| 76 |
+
return F.batch_norm(
|
| 77 |
+
input,
|
| 78 |
+
# If buffers are not to be tracked, ensure that they won't be updated
|
| 79 |
+
self.running_mean
|
| 80 |
+
if not self.training or self.track_running_stats
|
| 81 |
+
else None,
|
| 82 |
+
self.running_var if not self.training or self.track_running_stats else None,
|
| 83 |
+
self.weight,
|
| 84 |
+
self.bias,
|
| 85 |
+
bn_training,
|
| 86 |
+
exponential_average_factor,
|
| 87 |
+
self.eps,
|
| 88 |
+
).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/BatchNorm2d.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
from transformers.models.gemma3.modeling_gemma3 import (nn)
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 43 |
+
self._check_input_dim(input)
|
| 44 |
+
|
| 45 |
+
# exponential_average_factor is set to self.momentum
|
| 46 |
+
# (when it is available) only so that it gets updated
|
| 47 |
+
# in ONNX graph when this node is exported to ONNX.
|
| 48 |
+
if self.momentum is None:
|
| 49 |
+
exponential_average_factor = 0.0
|
| 50 |
+
else:
|
| 51 |
+
exponential_average_factor = self.momentum
|
| 52 |
+
|
| 53 |
+
if self.training and self.track_running_stats:
|
| 54 |
+
# TODO: if statement only here to tell the jit to skip emitting this when it is None
|
| 55 |
+
if self.num_batches_tracked is not None: # type: ignore[has-type]
|
| 56 |
+
self.num_batches_tracked.add_(1) # type: ignore[has-type]
|
| 57 |
+
if self.momentum is None: # use cumulative moving average
|
| 58 |
+
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
|
| 59 |
+
else: # use exponential moving average
|
| 60 |
+
exponential_average_factor = self.momentum
|
| 61 |
+
|
| 62 |
+
r"""
|
| 63 |
+
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
|
| 64 |
+
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
|
| 65 |
+
"""
|
| 66 |
+
if self.training:
|
| 67 |
+
bn_training = True
|
| 68 |
+
else:
|
| 69 |
+
bn_training = (self.running_mean is None) and (self.running_var is None)
|
| 70 |
+
|
| 71 |
+
r"""
|
| 72 |
+
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
|
| 73 |
+
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
|
| 74 |
+
used for normalization (i.e. in eval mode when buffers are not None).
|
| 75 |
+
"""
|
| 76 |
+
return F.batch_norm(
|
| 77 |
+
input,
|
| 78 |
+
# If buffers are not to be tracked, ensure that they won't be updated
|
| 79 |
+
self.running_mean
|
| 80 |
+
if not self.training or self.track_running_stats
|
| 81 |
+
else None,
|
| 82 |
+
self.running_var if not self.training or self.track_running_stats else None,
|
| 83 |
+
self.weight,
|
| 84 |
+
self.bias,
|
| 85 |
+
bn_training,
|
| 86 |
+
exponential_average_factor,
|
| 87 |
+
self.eps,
|
| 88 |
+
).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/BatchNorm3d.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
from transformers.models.gemma3.modeling_gemma3 import (nn)
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 43 |
+
self._check_input_dim(input)
|
| 44 |
+
|
| 45 |
+
# exponential_average_factor is set to self.momentum
|
| 46 |
+
# (when it is available) only so that it gets updated
|
| 47 |
+
# in ONNX graph when this node is exported to ONNX.
|
| 48 |
+
if self.momentum is None:
|
| 49 |
+
exponential_average_factor = 0.0
|
| 50 |
+
else:
|
| 51 |
+
exponential_average_factor = self.momentum
|
| 52 |
+
|
| 53 |
+
if self.training and self.track_running_stats:
|
| 54 |
+
# TODO: if statement only here to tell the jit to skip emitting this when it is None
|
| 55 |
+
if self.num_batches_tracked is not None: # type: ignore[has-type]
|
| 56 |
+
self.num_batches_tracked.add_(1) # type: ignore[has-type]
|
| 57 |
+
if self.momentum is None: # use cumulative moving average
|
| 58 |
+
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
|
| 59 |
+
else: # use exponential moving average
|
| 60 |
+
exponential_average_factor = self.momentum
|
| 61 |
+
|
| 62 |
+
r"""
|
| 63 |
+
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
|
| 64 |
+
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
|
| 65 |
+
"""
|
| 66 |
+
if self.training:
|
| 67 |
+
bn_training = True
|
| 68 |
+
else:
|
| 69 |
+
bn_training = (self.running_mean is None) and (self.running_var is None)
|
| 70 |
+
|
| 71 |
+
r"""
|
| 72 |
+
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
|
| 73 |
+
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
|
| 74 |
+
used for normalization (i.e. in eval mode when buffers are not None).
|
| 75 |
+
"""
|
| 76 |
+
return F.batch_norm(
|
| 77 |
+
input,
|
| 78 |
+
# If buffers are not to be tracked, ensure that they won't be updated
|
| 79 |
+
self.running_mean
|
| 80 |
+
if not self.training or self.track_running_stats
|
| 81 |
+
else None,
|
| 82 |
+
self.running_var if not self.training or self.track_running_stats else None,
|
| 83 |
+
self.weight,
|
| 84 |
+
self.bias,
|
| 85 |
+
bn_training,
|
| 86 |
+
exponential_average_factor,
|
| 87 |
+
self.eps,
|
| 88 |
+
).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/Conv1d.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 43 |
+
return self._conv_forward(input, self.weight, self.bias).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/Conv2d.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 43 |
+
return self._conv_forward(input, self.weight, self.bias).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/Conv3d.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 43 |
+
return self._conv_forward(input, self.weight, self.bias).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/ConvTranspose1d.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
from transformers.models.gemma3.modeling_gemma3 import (List, Optional, Tuple, nn)
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
| 43 |
+
if self.padding_mode != "zeros":
|
| 44 |
+
raise ValueError(
|
| 45 |
+
"Only `zeros` padding mode is supported for ConvTranspose1d"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
assert isinstance(self.padding, tuple)
|
| 49 |
+
# One cannot replace List by Tuple or Sequence in "_output_padding" because
|
| 50 |
+
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
|
| 51 |
+
num_spatial_dims = 1
|
| 52 |
+
output_padding = self._output_padding(
|
| 53 |
+
input,
|
| 54 |
+
output_size,
|
| 55 |
+
self.stride, # type: ignore[arg-type]
|
| 56 |
+
self.padding, # type: ignore[arg-type]
|
| 57 |
+
self.kernel_size, # type: ignore[arg-type]
|
| 58 |
+
num_spatial_dims,
|
| 59 |
+
self.dilation, # type: ignore[arg-type]
|
| 60 |
+
)
|
| 61 |
+
return F.conv_transpose1d(
|
| 62 |
+
input,
|
| 63 |
+
self.weight,
|
| 64 |
+
self.bias,
|
| 65 |
+
self.stride,
|
| 66 |
+
self.padding,
|
| 67 |
+
output_padding,
|
| 68 |
+
self.groups,
|
| 69 |
+
self.dilation,
|
| 70 |
+
).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/ConvTranspose2d.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
from transformers.models.gemma3.modeling_gemma3 import (List, Optional, Tuple, nn)
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
| 43 |
+
if self.padding_mode != "zeros":
|
| 44 |
+
raise ValueError(
|
| 45 |
+
"Only `zeros` padding mode is supported for ConvTranspose2d"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
assert isinstance(self.padding, tuple)
|
| 49 |
+
# One cannot replace List by Tuple or Sequence in "_output_padding" because
|
| 50 |
+
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
|
| 51 |
+
num_spatial_dims = 2
|
| 52 |
+
output_padding = self._output_padding(
|
| 53 |
+
input,
|
| 54 |
+
output_size,
|
| 55 |
+
self.stride, # type: ignore[arg-type]
|
| 56 |
+
self.padding, # type: ignore[arg-type]
|
| 57 |
+
self.kernel_size, # type: ignore[arg-type]
|
| 58 |
+
num_spatial_dims,
|
| 59 |
+
self.dilation, # type: ignore[arg-type]
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
return F.conv_transpose2d(
|
| 63 |
+
input,
|
| 64 |
+
self.weight,
|
| 65 |
+
self.bias,
|
| 66 |
+
self.stride,
|
| 67 |
+
self.padding,
|
| 68 |
+
output_padding,
|
| 69 |
+
self.groups,
|
| 70 |
+
self.dilation,
|
| 71 |
+
).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/ConvTranspose3d.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
from transformers.models.gemma3.modeling_gemma3 import (List, Optional, Tuple, nn)
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
|
| 43 |
+
if self.padding_mode != "zeros":
|
| 44 |
+
raise ValueError(
|
| 45 |
+
"Only `zeros` padding mode is supported for ConvTranspose3d"
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
assert isinstance(self.padding, tuple)
|
| 49 |
+
# One cannot replace List by Tuple or Sequence in "_output_padding" because
|
| 50 |
+
# TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
|
| 51 |
+
num_spatial_dims = 3
|
| 52 |
+
output_padding = self._output_padding(
|
| 53 |
+
input,
|
| 54 |
+
output_size,
|
| 55 |
+
self.stride, # type: ignore[arg-type]
|
| 56 |
+
self.padding, # type: ignore[arg-type]
|
| 57 |
+
self.kernel_size, # type: ignore[arg-type]
|
| 58 |
+
num_spatial_dims,
|
| 59 |
+
self.dilation, # type: ignore[arg-type]
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
return F.conv_transpose3d(
|
| 63 |
+
input,
|
| 64 |
+
self.weight,
|
| 65 |
+
self.bias,
|
| 66 |
+
self.stride,
|
| 67 |
+
self.padding,
|
| 68 |
+
output_padding,
|
| 69 |
+
self.groups,
|
| 70 |
+
self.dilation,
|
| 71 |
+
).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import functional as F
|
| 14 |
+
from peft.tuners.lora.gptq import (torch)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
torch_addmm = torch.addmm
|
| 18 |
+
torch_add = torch.add
|
| 19 |
+
# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
|
| 20 |
+
def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
|
| 21 |
+
xA = dropout(x) @ lora_A.weight.t()
|
| 22 |
+
# output = result + scaling * xA @ lora_B.weight.t()
|
| 23 |
+
shape = result.shape
|
| 24 |
+
output = torch_addmm(
|
| 25 |
+
result.view(-1, shape[-1]),
|
| 26 |
+
xA.view(-1, xA.shape[-1]),
|
| 27 |
+
lora_B.weight.t(),
|
| 28 |
+
alpha = scaling,
|
| 29 |
+
beta = 1,
|
| 30 |
+
).view(shape)
|
| 31 |
+
|
| 32 |
+
bias = lora_B.bias
|
| 33 |
+
if bias is not None:
|
| 34 |
+
output = torch_add(
|
| 35 |
+
output,
|
| 36 |
+
bias,
|
| 37 |
+
alpha = scaling,
|
| 38 |
+
)
|
| 39 |
+
return output
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
def unsloth_forward(self, x: torch.Tensor):
|
| 43 |
+
# note: logic differs from default Linear because merging is not supported
|
| 44 |
+
result = self.quant_linear_module(x)
|
| 45 |
+
|
| 46 |
+
if self.disable_adapters:
|
| 47 |
+
return result
|
| 48 |
+
|
| 49 |
+
lora_A_keys = self.lora_A.keys()
|
| 50 |
+
for active_adapter in self.active_adapters:
|
| 51 |
+
if active_adapter not in lora_A_keys:
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
lora_A = self.lora_A[active_adapter]
|
| 55 |
+
lora_B = self.lora_B[active_adapter]
|
| 56 |
+
dropout = self.lora_dropout[active_adapter]
|
| 57 |
+
scaling = self.scaling[active_adapter]
|
| 58 |
+
|
| 59 |
+
requires_conversion = not torch.is_autocast_enabled()
|
| 60 |
+
if requires_conversion:
|
| 61 |
+
expected_dtype = result.dtype
|
| 62 |
+
x = self._cast_input_dtype(x, lora_A.weight.dtype)
|
| 63 |
+
|
| 64 |
+
output = lora_B(lora_A(dropout(x)))
|
| 65 |
+
|
| 66 |
+
if requires_conversion:
|
| 67 |
+
output = output.to(expected_dtype)
|
| 68 |
+
|
| 69 |
+
if scaling != 1: # skip scaling == 1 no-op
|
| 70 |
+
output = output * scaling
|
| 71 |
+
|
| 72 |
+
result += output
|
| 73 |
+
return result
|
unsloth_compiled_cache/GroupNorm.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 43 |
+
return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/LayerNorm.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 43 |
+
return F.layer_norm(
|
| 44 |
+
input, self.normalized_shape, self.weight, self.bias, self.eps
|
| 45 |
+
).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/Linear4bit_peft_forward.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import functional as F
|
| 14 |
+
from peft.tuners.lora.bnb import (torch)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
torch_addmm = torch.addmm
|
| 18 |
+
torch_add = torch.add
|
| 19 |
+
# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
|
| 20 |
+
def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
|
| 21 |
+
xA = dropout(x) @ lora_A.weight.t()
|
| 22 |
+
# output = result + scaling * xA @ lora_B.weight.t()
|
| 23 |
+
shape = result.shape
|
| 24 |
+
output = torch_addmm(
|
| 25 |
+
result.view(-1, shape[-1]),
|
| 26 |
+
xA.view(-1, xA.shape[-1]),
|
| 27 |
+
lora_B.weight.t(),
|
| 28 |
+
alpha = scaling,
|
| 29 |
+
beta = 1,
|
| 30 |
+
).view(shape)
|
| 31 |
+
|
| 32 |
+
bias = lora_B.bias
|
| 33 |
+
if bias is not None:
|
| 34 |
+
output = torch_add(
|
| 35 |
+
output,
|
| 36 |
+
bias,
|
| 37 |
+
alpha = scaling,
|
| 38 |
+
)
|
| 39 |
+
return output
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
def unsloth_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 43 |
+
|
| 44 |
+
adapter_names = kwargs.pop("adapter_names", None)
|
| 45 |
+
|
| 46 |
+
if self.disable_adapters:
|
| 47 |
+
if self.merged:
|
| 48 |
+
self.unmerge()
|
| 49 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 50 |
+
elif adapter_names is not None:
|
| 51 |
+
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
| 52 |
+
elif self.merged:
|
| 53 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 54 |
+
else:
|
| 55 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 56 |
+
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
|
| 57 |
+
# The reason is that in some cases, an error can occur that backprop
|
| 58 |
+
# does not work on a manipulated view. This issue may be solved with
|
| 59 |
+
# newer PyTorch versions but this would need extensive testing to be
|
| 60 |
+
# sure.
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
for active_adapter in self.active_adapters:
|
| 64 |
+
if active_adapter not in self.lora_A.keys():
|
| 65 |
+
continue
|
| 66 |
+
lora_A = self.lora_A[active_adapter]
|
| 67 |
+
lora_B = self.lora_B[active_adapter]
|
| 68 |
+
dropout = self.lora_dropout[active_adapter]
|
| 69 |
+
scaling = self.scaling[active_adapter]
|
| 70 |
+
|
| 71 |
+
requires_conversion = not torch.is_autocast_enabled()
|
| 72 |
+
if requires_conversion:
|
| 73 |
+
expected_dtype = result.dtype
|
| 74 |
+
x = self._cast_input_dtype(x, lora_A.weight.dtype)
|
| 75 |
+
|
| 76 |
+
if not self.use_dora[active_adapter]:
|
| 77 |
+
return lora_forward(result, lora_A, lora_B, dropout, x, scaling)
|
| 78 |
+
else:
|
| 79 |
+
if isinstance(dropout, torch.nn.Identity) or not self.training:
|
| 80 |
+
base_result = result
|
| 81 |
+
else:
|
| 82 |
+
x = dropout(x)
|
| 83 |
+
base_result = None
|
| 84 |
+
|
| 85 |
+
output = self.lora_magnitude_vector[active_adapter](
|
| 86 |
+
x,
|
| 87 |
+
lora_A=lora_A,
|
| 88 |
+
lora_B=lora_B,
|
| 89 |
+
scaling=scaling,
|
| 90 |
+
base_layer=self.get_base_layer(),
|
| 91 |
+
base_result=base_result,
|
| 92 |
+
)
|
| 93 |
+
if requires_conversion:
|
| 94 |
+
output = output.to(expected_dtype)
|
| 95 |
+
result = result + output
|
| 96 |
+
|
| 97 |
+
return result
|
unsloth_compiled_cache/Linear8bitLt_peft_forward.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import functional as F
|
| 14 |
+
from peft.tuners.lora.bnb import (torch)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
torch_addmm = torch.addmm
|
| 18 |
+
torch_add = torch.add
|
| 19 |
+
# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
|
| 20 |
+
def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
|
| 21 |
+
xA = dropout(x) @ lora_A.weight.t()
|
| 22 |
+
# output = result + scaling * xA @ lora_B.weight.t()
|
| 23 |
+
shape = result.shape
|
| 24 |
+
output = torch_addmm(
|
| 25 |
+
result.view(-1, shape[-1]),
|
| 26 |
+
xA.view(-1, xA.shape[-1]),
|
| 27 |
+
lora_B.weight.t(),
|
| 28 |
+
alpha = scaling,
|
| 29 |
+
beta = 1,
|
| 30 |
+
).view(shape)
|
| 31 |
+
|
| 32 |
+
bias = lora_B.bias
|
| 33 |
+
if bias is not None:
|
| 34 |
+
output = torch_add(
|
| 35 |
+
output,
|
| 36 |
+
bias,
|
| 37 |
+
alpha = scaling,
|
| 38 |
+
)
|
| 39 |
+
return output
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
def unsloth_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 43 |
+
|
| 44 |
+
adapter_names = kwargs.pop("adapter_names", None)
|
| 45 |
+
|
| 46 |
+
if self.disable_adapters:
|
| 47 |
+
if self.merged:
|
| 48 |
+
self.unmerge()
|
| 49 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 50 |
+
elif adapter_names is not None:
|
| 51 |
+
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
| 52 |
+
elif self.merged:
|
| 53 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 54 |
+
else:
|
| 55 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 56 |
+
for active_adapter in self.active_adapters:
|
| 57 |
+
if active_adapter not in self.lora_A.keys():
|
| 58 |
+
continue
|
| 59 |
+
lora_A = self.lora_A[active_adapter]
|
| 60 |
+
lora_B = self.lora_B[active_adapter]
|
| 61 |
+
dropout = self.lora_dropout[active_adapter]
|
| 62 |
+
scaling = self.scaling[active_adapter]
|
| 63 |
+
|
| 64 |
+
requires_conversion = not torch.is_autocast_enabled()
|
| 65 |
+
if requires_conversion:
|
| 66 |
+
expected_dtype = result.dtype
|
| 67 |
+
x = self._cast_input_dtype(x, lora_A.weight.dtype)
|
| 68 |
+
|
| 69 |
+
if not self.use_dora[active_adapter]:
|
| 70 |
+
return lora_forward(result, lora_A, lora_B, dropout, x, scaling)
|
| 71 |
+
else:
|
| 72 |
+
if isinstance(dropout, torch.nn.Identity) or not self.training:
|
| 73 |
+
base_result = result
|
| 74 |
+
else:
|
| 75 |
+
x = dropout(x)
|
| 76 |
+
base_result = None
|
| 77 |
+
|
| 78 |
+
output = self.lora_magnitude_vector[active_adapter](
|
| 79 |
+
x,
|
| 80 |
+
lora_A=lora_A,
|
| 81 |
+
lora_B=lora_B,
|
| 82 |
+
scaling=scaling,
|
| 83 |
+
base_layer=self.get_base_layer(),
|
| 84 |
+
base_result=base_result,
|
| 85 |
+
)
|
| 86 |
+
if requires_conversion:
|
| 87 |
+
output = output.to(expected_dtype)
|
| 88 |
+
result = result + output
|
| 89 |
+
|
| 90 |
+
return result
|
unsloth_compiled_cache/Linear_peft_forward.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import functional as F
|
| 14 |
+
from peft.tuners.lora.layer import (Any, F, nn, torch)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
torch_addmm = torch.addmm
|
| 18 |
+
torch_add = torch.add
|
| 19 |
+
# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
|
| 20 |
+
def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
|
| 21 |
+
xA = dropout(x) @ lora_A.weight.t()
|
| 22 |
+
# output = result + scaling * xA @ lora_B.weight.t()
|
| 23 |
+
shape = result.shape
|
| 24 |
+
output = torch_addmm(
|
| 25 |
+
result.view(-1, shape[-1]),
|
| 26 |
+
xA.view(-1, xA.shape[-1]),
|
| 27 |
+
lora_B.weight.t(),
|
| 28 |
+
alpha = scaling,
|
| 29 |
+
beta = 1,
|
| 30 |
+
).view(shape)
|
| 31 |
+
|
| 32 |
+
bias = lora_B.bias
|
| 33 |
+
if bias is not None:
|
| 34 |
+
output = torch_add(
|
| 35 |
+
output,
|
| 36 |
+
bias,
|
| 37 |
+
alpha = scaling,
|
| 38 |
+
)
|
| 39 |
+
return output
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
def unsloth_forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
| 43 |
+
|
| 44 |
+
adapter_names = kwargs.pop("adapter_names", None)
|
| 45 |
+
|
| 46 |
+
if self.disable_adapters:
|
| 47 |
+
if self.merged:
|
| 48 |
+
self.unmerge()
|
| 49 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 50 |
+
elif adapter_names is not None:
|
| 51 |
+
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
| 52 |
+
elif self.merged:
|
| 53 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 54 |
+
else:
|
| 55 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 56 |
+
torch_result_dtype = result.dtype
|
| 57 |
+
|
| 58 |
+
lora_A_keys = self.lora_A.keys()
|
| 59 |
+
for active_adapter in self.active_adapters:
|
| 60 |
+
if active_adapter not in lora_A_keys:
|
| 61 |
+
continue
|
| 62 |
+
|
| 63 |
+
lora_A = self.lora_A[active_adapter]
|
| 64 |
+
lora_B = self.lora_B[active_adapter]
|
| 65 |
+
dropout = self.lora_dropout[active_adapter]
|
| 66 |
+
scaling = self.scaling[active_adapter]
|
| 67 |
+
if not torch.is_autocast_enabled(): result, x = result.to(lora_A.weight.dtype), x.to(lora_A.weight.dtype)
|
| 68 |
+
|
| 69 |
+
if not self.use_dora[active_adapter]:
|
| 70 |
+
return lora_forward(result, lora_A, lora_B, dropout, x, scaling)
|
| 71 |
+
else:
|
| 72 |
+
if isinstance(dropout, nn.Identity) or not self.training:
|
| 73 |
+
base_result = result
|
| 74 |
+
else:
|
| 75 |
+
x = dropout(x)
|
| 76 |
+
base_result = None
|
| 77 |
+
|
| 78 |
+
result = result + self.lora_magnitude_vector[active_adapter](
|
| 79 |
+
x,
|
| 80 |
+
lora_A=lora_A,
|
| 81 |
+
lora_B=lora_B,
|
| 82 |
+
scaling=scaling,
|
| 83 |
+
base_layer=self.get_base_layer(),
|
| 84 |
+
base_result=base_result,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
result = result.to(torch_result_dtype)
|
| 88 |
+
|
| 89 |
+
return result
|
unsloth_compiled_cache/LoraParallelLinear_peft_forward.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.nn import functional as F
|
| 14 |
+
from peft.tuners.lora.tp_layer import (Any, __name__, nn, torch)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
torch_addmm = torch.addmm
|
| 18 |
+
torch_add = torch.add
|
| 19 |
+
# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
|
| 20 |
+
def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
|
| 21 |
+
xA = dropout(x) @ lora_A.weight.t()
|
| 22 |
+
# output = result + scaling * xA @ lora_B.weight.t()
|
| 23 |
+
shape = result.shape
|
| 24 |
+
output = torch_addmm(
|
| 25 |
+
result.view(-1, shape[-1]),
|
| 26 |
+
xA.view(-1, xA.shape[-1]),
|
| 27 |
+
lora_B.weight.t(),
|
| 28 |
+
alpha = scaling,
|
| 29 |
+
beta = 1,
|
| 30 |
+
).view(shape)
|
| 31 |
+
|
| 32 |
+
bias = lora_B.bias
|
| 33 |
+
if bias is not None:
|
| 34 |
+
output = torch_add(
|
| 35 |
+
output,
|
| 36 |
+
bias,
|
| 37 |
+
alpha = scaling,
|
| 38 |
+
)
|
| 39 |
+
return output
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
def unsloth_forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
|
| 43 |
+
|
| 44 |
+
adapter_names = kwargs.pop("adapter_names", None)
|
| 45 |
+
# If weight is used for matrix multiplication here, the final aggregation operation of the original
|
| 46 |
+
# parallel_linear layer will be missing, so we need to directly call its forward function to obtain the
|
| 47 |
+
# output of the original parallel_linear layer.
|
| 48 |
+
if self.disable_adapters:
|
| 49 |
+
if self.merged:
|
| 50 |
+
self.unmerge()
|
| 51 |
+
result, bias = self.base_layer(x, *args, **kwargs)
|
| 52 |
+
elif adapter_names is not None:
|
| 53 |
+
raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.")
|
| 54 |
+
elif self.merged:
|
| 55 |
+
result, bias = self.base_layer(x, *args, **kwargs)
|
| 56 |
+
else:
|
| 57 |
+
result, bias = self.base_layer(x, *args, **kwargs)
|
| 58 |
+
torch_result_dtype = result.dtype
|
| 59 |
+
for active_adapter in self.active_adapters:
|
| 60 |
+
if active_adapter not in self.lora_A.keys():
|
| 61 |
+
continue
|
| 62 |
+
lora_A = self.lora_A[active_adapter]
|
| 63 |
+
lora_B = self.lora_B[active_adapter]
|
| 64 |
+
dropout = self.lora_dropout[active_adapter]
|
| 65 |
+
scaling = self.scaling[active_adapter]
|
| 66 |
+
if not torch.is_autocast_enabled(): result, x = result.to(lora_A.weight.dtype), x.to(lora_A.weight.dtype)
|
| 67 |
+
|
| 68 |
+
if not self.use_dora[active_adapter]:
|
| 69 |
+
return lora_forward(result, lora_A, lora_B, dropout, x, scaling)
|
| 70 |
+
else:
|
| 71 |
+
if isinstance(dropout, torch.nn.Identity) or not self.training:
|
| 72 |
+
base_result = result
|
| 73 |
+
else:
|
| 74 |
+
x = dropout(x)
|
| 75 |
+
base_result = None
|
| 76 |
+
|
| 77 |
+
result = result + self.lora_magnitude_vector[active_adapter](
|
| 78 |
+
x,
|
| 79 |
+
lora_A=lora_A,
|
| 80 |
+
lora_B=lora_B,
|
| 81 |
+
scaling=scaling,
|
| 82 |
+
base_layer=self.get_base_layer(),
|
| 83 |
+
base_result=base_result,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
result = result.to(torch_result_dtype)
|
| 87 |
+
return result, bias
|
unsloth_compiled_cache/RMSNorm.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
# Unsloth Zoo - Utilities for Unsloth
|
| 10 |
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
| 11 |
+
#
|
| 12 |
+
# This program is free software: you can redistribute it and/or modify
|
| 13 |
+
# it under the terms of the GNU Lesser General Public License as published by
|
| 14 |
+
# the Free Software Foundation, either version 3 of the License, or
|
| 15 |
+
# (at your option) any later version.
|
| 16 |
+
#
|
| 17 |
+
# This program is distributed in the hope that it will be useful,
|
| 18 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 19 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 20 |
+
# GNU General Public License for more details.
|
| 21 |
+
#
|
| 22 |
+
# You should have received a copy of the GNU Lesser General Public License
|
| 23 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import importlib.util
|
| 27 |
+
if importlib.util.find_spec("unsloth_studio") is None:
|
| 28 |
+
UNSLOTH_STUDIO_ENABLED = False
|
| 29 |
+
else:
|
| 30 |
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
| 31 |
+
pass
|
| 32 |
+
from typing import List, Dict, Tuple, Optional, Any, Callable
|
| 33 |
+
import math
|
| 34 |
+
|
| 35 |
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
import torch
|
| 38 |
+
import torch.nn as nn
|
| 39 |
+
from torch.nn import functional as F
|
| 40 |
+
from transformers.models.gemma3.modeling_gemma3 import (torch)
|
| 41 |
+
|
| 42 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 43 |
+
"""
|
| 44 |
+
Runs forward pass.
|
| 45 |
+
"""
|
| 46 |
+
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps).to(input.dtype).to(input.dtype)
|
unsloth_compiled_cache/UnslothAlignPropTrainer.py
ADDED
|
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
torch_compile_options = {
|
| 26 |
+
"epilogue_fusion" : True,
|
| 27 |
+
"max_autotune" : False,
|
| 28 |
+
"shape_padding" : True,
|
| 29 |
+
"trace.enabled" : False,
|
| 30 |
+
"triton.cudagraphs" : False,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 34 |
+
def selective_log_softmax(logits, index):
|
| 35 |
+
logits = logits.to(torch.float32)
|
| 36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
| 37 |
+
# loop to reduce peak mem consumption
|
| 38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
| 39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
| 40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
| 41 |
+
return per_token_logps
|
| 42 |
+
@dataclass
|
| 43 |
+
class UnslothAlignPropConfig(AlignPropConfig):
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
Configuration class for the [`AlignPropTrainer`].
|
| 47 |
+
|
| 48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 50 |
+
command line.
|
| 51 |
+
|
| 52 |
+
Parameters:
|
| 53 |
+
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
| 54 |
+
Name of this experiment (defaults to the file name without the extension).
|
| 55 |
+
run_name (`str`, *optional*, defaults to `""`):
|
| 56 |
+
Name of this run.
|
| 57 |
+
seed (`int`, *optional*, defaults to `0`):
|
| 58 |
+
Random seed for reproducibility.
|
| 59 |
+
log_with (`str` or `None`, *optional*, defaults to `None`):
|
| 60 |
+
Log with either `"wandb"` or `"tensorboard"`. Check
|
| 61 |
+
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
|
| 62 |
+
log_image_freq (`int`, *optional*, defaults to `1`):
|
| 63 |
+
Frequency for logging images.
|
| 64 |
+
tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
| 65 |
+
Keyword arguments for the tracker (e.g., `wandb_project`).
|
| 66 |
+
accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
| 67 |
+
Keyword arguments for the accelerator.
|
| 68 |
+
project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
| 69 |
+
Keyword arguments for the accelerator project config (e.g., `logging_dir`).
|
| 70 |
+
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
| 71 |
+
Name of project to use for tracking.
|
| 72 |
+
logdir (`str`, *optional*, defaults to `"logs"`):
|
| 73 |
+
Top-level logging directory for checkpoint saving.
|
| 74 |
+
num_epochs (`int`, *optional*, defaults to `100`):
|
| 75 |
+
Number of epochs to train.
|
| 76 |
+
save_freq (`int`, *optional*, defaults to `1`):
|
| 77 |
+
Number of epochs between saving model checkpoints.
|
| 78 |
+
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
| 79 |
+
Number of checkpoints to keep before overwriting old ones.
|
| 80 |
+
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
| 81 |
+
Mixed precision training.
|
| 82 |
+
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
| 83 |
+
Allow `tf32` on Ampere GPUs.
|
| 84 |
+
resume_from (`str`, *optional*, defaults to `""`):
|
| 85 |
+
Path to resume training from a checkpoint.
|
| 86 |
+
sample_num_steps (`int`, *optional*, defaults to `50`):
|
| 87 |
+
Number of sampler inference steps.
|
| 88 |
+
sample_eta (`float`, *optional*, defaults to `1.0`):
|
| 89 |
+
Eta parameter for the DDIM sampler.
|
| 90 |
+
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
| 91 |
+
Classifier-free guidance weight.
|
| 92 |
+
train_batch_size (`int`, *optional*, defaults to `1`):
|
| 93 |
+
Batch size for training.
|
| 94 |
+
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
| 95 |
+
Whether to use the 8bit Adam optimizer from `bitsandbytes`.
|
| 96 |
+
train_learning_rate (`float`, *optional*, defaults to `1e-3`):
|
| 97 |
+
Learning rate.
|
| 98 |
+
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
| 99 |
+
Beta1 for Adam optimizer.
|
| 100 |
+
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
| 101 |
+
Beta2 for Adam optimizer.
|
| 102 |
+
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
| 103 |
+
Weight decay for Adam optimizer.
|
| 104 |
+
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
| 105 |
+
Epsilon value for Adam optimizer.
|
| 106 |
+
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
| 107 |
+
Number of gradient accumulation steps.
|
| 108 |
+
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
| 109 |
+
Maximum gradient norm for gradient clipping.
|
| 110 |
+
negative_prompts (`str` or `None`, *optional*, defaults to `None`):
|
| 111 |
+
Comma-separated list of prompts to use as negative examples.
|
| 112 |
+
truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
|
| 113 |
+
If `True`, randomized truncation to different diffusion timesteps is used.
|
| 114 |
+
truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
|
| 115 |
+
Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
|
| 116 |
+
truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
|
| 117 |
+
Range of diffusion timesteps for randomized truncated backpropagation.
|
| 118 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 119 |
+
Whether to push the final model to the Hub.
|
| 120 |
+
|
| 121 |
+
"""
|
| 122 |
+
vllm_sampling_params: Optional[Any] = field(
|
| 123 |
+
default = None,
|
| 124 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
| 125 |
+
)
|
| 126 |
+
unsloth_num_chunks : Optional[int] = field(
|
| 127 |
+
default = -1,
|
| 128 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 129 |
+
)
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
exp_name = 'app',
|
| 133 |
+
run_name = '',
|
| 134 |
+
seed = 3407,
|
| 135 |
+
log_with = None,
|
| 136 |
+
log_image_freq = 1,
|
| 137 |
+
tracker_project_name = 'trl',
|
| 138 |
+
logdir = 'logs',
|
| 139 |
+
num_epochs = 100,
|
| 140 |
+
save_freq = 1,
|
| 141 |
+
num_checkpoint_limit = 5,
|
| 142 |
+
mixed_precision = 'fp16',
|
| 143 |
+
allow_tf32 = True,
|
| 144 |
+
resume_from = '',
|
| 145 |
+
sample_num_steps = 50,
|
| 146 |
+
sample_eta = 1.0,
|
| 147 |
+
sample_guidance_scale = 5.0,
|
| 148 |
+
train_batch_size = 1,
|
| 149 |
+
train_use_8bit_adam = False,
|
| 150 |
+
train_learning_rate = 5e-05,
|
| 151 |
+
train_adam_beta1 = 0.9,
|
| 152 |
+
train_adam_beta2 = 0.999,
|
| 153 |
+
train_adam_weight_decay = 0.01,
|
| 154 |
+
train_adam_epsilon = 1e-08,
|
| 155 |
+
train_gradient_accumulation_steps = 2,
|
| 156 |
+
train_max_grad_norm = 1.0,
|
| 157 |
+
negative_prompts = None,
|
| 158 |
+
truncated_backprop_rand = True,
|
| 159 |
+
truncated_backprop_timestep = 49,
|
| 160 |
+
push_to_hub = False,
|
| 161 |
+
vllm_sampling_params = None,
|
| 162 |
+
unsloth_num_chunks = -1,
|
| 163 |
+
**kwargs,
|
| 164 |
+
):
|
| 165 |
+
|
| 166 |
+
super().__init__(
|
| 167 |
+
exp_name = exp_name,
|
| 168 |
+
run_name = run_name,
|
| 169 |
+
seed = seed,
|
| 170 |
+
log_with = log_with,
|
| 171 |
+
log_image_freq = log_image_freq,
|
| 172 |
+
tracker_project_name = tracker_project_name,
|
| 173 |
+
logdir = logdir,
|
| 174 |
+
num_epochs = num_epochs,
|
| 175 |
+
save_freq = save_freq,
|
| 176 |
+
num_checkpoint_limit = num_checkpoint_limit,
|
| 177 |
+
mixed_precision = mixed_precision,
|
| 178 |
+
allow_tf32 = allow_tf32,
|
| 179 |
+
resume_from = resume_from,
|
| 180 |
+
sample_num_steps = sample_num_steps,
|
| 181 |
+
sample_eta = sample_eta,
|
| 182 |
+
sample_guidance_scale = sample_guidance_scale,
|
| 183 |
+
train_batch_size = train_batch_size,
|
| 184 |
+
train_use_8bit_adam = train_use_8bit_adam,
|
| 185 |
+
train_learning_rate = train_learning_rate,
|
| 186 |
+
train_adam_beta1 = train_adam_beta1,
|
| 187 |
+
train_adam_beta2 = train_adam_beta2,
|
| 188 |
+
train_adam_weight_decay = train_adam_weight_decay,
|
| 189 |
+
train_adam_epsilon = train_adam_epsilon,
|
| 190 |
+
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
| 191 |
+
train_max_grad_norm = train_max_grad_norm,
|
| 192 |
+
negative_prompts = negative_prompts,
|
| 193 |
+
truncated_backprop_rand = truncated_backprop_rand,
|
| 194 |
+
truncated_backprop_timestep = truncated_backprop_timestep,
|
| 195 |
+
push_to_hub = push_to_hub,**kwargs)
|
| 196 |
+
self.vllm_sampling_params = vllm_sampling_params
|
| 197 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
| 198 |
+
pass
|
| 199 |
+
|
| 200 |
+
class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
|
| 201 |
+
""""""
|
| 202 |
+
|
| 203 |
+
_tag_names = ["trl", "alignprop"]
|
| 204 |
+
|
| 205 |
+
def __init__(
|
| 206 |
+
self,
|
| 207 |
+
config: AlignPropConfig,
|
| 208 |
+
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
| 209 |
+
prompt_function: Callable[[], tuple[str, Any]],
|
| 210 |
+
sd_pipeline: DDPOStableDiffusionPipeline,
|
| 211 |
+
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
| 212 |
+
):
|
| 213 |
+
if image_samples_hook is None:
|
| 214 |
+
warn("No image_samples_hook provided; no images will be logged")
|
| 215 |
+
|
| 216 |
+
self.prompt_fn = prompt_function
|
| 217 |
+
self.reward_fn = reward_function
|
| 218 |
+
self.config = config
|
| 219 |
+
self.image_samples_callback = image_samples_hook
|
| 220 |
+
|
| 221 |
+
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
| 222 |
+
|
| 223 |
+
if self.config.resume_from:
|
| 224 |
+
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
| 225 |
+
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
| 226 |
+
# get the most recent checkpoint in this directory
|
| 227 |
+
checkpoints = list(
|
| 228 |
+
filter(
|
| 229 |
+
lambda x: "checkpoint_" in x,
|
| 230 |
+
os.listdir(self.config.resume_from),
|
| 231 |
+
)
|
| 232 |
+
)
|
| 233 |
+
if len(checkpoints) == 0:
|
| 234 |
+
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
| 235 |
+
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
| 236 |
+
self.config.resume_from = os.path.join(
|
| 237 |
+
self.config.resume_from,
|
| 238 |
+
f"checkpoint_{checkpoint_numbers[-1]}",
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
| 242 |
+
|
| 243 |
+
self.accelerator = Accelerator(
|
| 244 |
+
log_with=self.config.log_with,
|
| 245 |
+
mixed_precision=self.config.mixed_precision,
|
| 246 |
+
project_config=accelerator_project_config,
|
| 247 |
+
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
| 248 |
+
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
| 249 |
+
# the total number of optimizer steps to accumulate across.
|
| 250 |
+
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
|
| 251 |
+
**self.config.accelerator_kwargs,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
| 255 |
+
|
| 256 |
+
if self.accelerator.is_main_process:
|
| 257 |
+
self.accelerator.init_trackers(
|
| 258 |
+
self.config.tracker_project_name,
|
| 259 |
+
config=dict(alignprop_trainer_config=config.to_dict())
|
| 260 |
+
if not is_using_tensorboard
|
| 261 |
+
else config.to_dict(),
|
| 262 |
+
init_kwargs=self.config.tracker_kwargs,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
logger.info(f"\n{config}")
|
| 266 |
+
|
| 267 |
+
set_seed(self.config.seed, device_specific=True)
|
| 268 |
+
|
| 269 |
+
self.sd_pipeline = sd_pipeline
|
| 270 |
+
|
| 271 |
+
self.sd_pipeline.set_progress_bar_config(
|
| 272 |
+
position=1,
|
| 273 |
+
disable=not self.accelerator.is_local_main_process,
|
| 274 |
+
leave=False,
|
| 275 |
+
desc="Timestep",
|
| 276 |
+
dynamic_ncols=True,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
| 280 |
+
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 281 |
+
if self.accelerator.mixed_precision == "fp16":
|
| 282 |
+
inference_dtype = torch.float16
|
| 283 |
+
elif self.accelerator.mixed_precision == "bf16":
|
| 284 |
+
inference_dtype = torch.bfloat16
|
| 285 |
+
else:
|
| 286 |
+
inference_dtype = torch.float32
|
| 287 |
+
|
| 288 |
+
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
| 289 |
+
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
| 290 |
+
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
| 291 |
+
|
| 292 |
+
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
| 293 |
+
|
| 294 |
+
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
| 295 |
+
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
| 296 |
+
|
| 297 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 298 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 299 |
+
if self.config.allow_tf32:
|
| 300 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 301 |
+
|
| 302 |
+
self.optimizer = self._setup_optimizer(
|
| 303 |
+
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
| 307 |
+
self.sd_pipeline.tokenizer(
|
| 308 |
+
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
| 309 |
+
return_tensors="pt",
|
| 310 |
+
padding="max_length",
|
| 311 |
+
truncation=True,
|
| 312 |
+
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
| 313 |
+
).input_ids.to(self.accelerator.device)
|
| 314 |
+
)[0]
|
| 315 |
+
|
| 316 |
+
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
| 317 |
+
# more memory
|
| 318 |
+
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
| 319 |
+
|
| 320 |
+
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
| 321 |
+
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
| 322 |
+
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
| 323 |
+
else:
|
| 324 |
+
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
| 325 |
+
|
| 326 |
+
if config.resume_from:
|
| 327 |
+
logger.info(f"Resuming from {config.resume_from}")
|
| 328 |
+
self.accelerator.load_state(config.resume_from)
|
| 329 |
+
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
| 330 |
+
else:
|
| 331 |
+
self.first_epoch = 0
|
| 332 |
+
|
| 333 |
+
def compute_rewards(self, prompt_image_pairs):
|
| 334 |
+
reward, reward_metadata = self.reward_fn(
|
| 335 |
+
prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
|
| 336 |
+
)
|
| 337 |
+
return reward
|
| 338 |
+
|
| 339 |
+
def step(self, epoch: int, global_step: int):
|
| 340 |
+
"""
|
| 341 |
+
Perform a single step of training.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
epoch (int): The current epoch.
|
| 345 |
+
global_step (int): The current global step.
|
| 346 |
+
|
| 347 |
+
Side Effects:
|
| 348 |
+
- Model weights are updated
|
| 349 |
+
- Logs the statistics to the accelerator trackers.
|
| 350 |
+
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
|
| 351 |
+
|
| 352 |
+
Returns:
|
| 353 |
+
global_step (int): The updated global step.
|
| 354 |
+
"""
|
| 355 |
+
info = defaultdict(list)
|
| 356 |
+
|
| 357 |
+
self.sd_pipeline.unet.train()
|
| 358 |
+
|
| 359 |
+
for _ in range(self.config.train_gradient_accumulation_steps):
|
| 360 |
+
with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
|
| 361 |
+
prompt_image_pairs = self._generate_samples(
|
| 362 |
+
batch_size=self.config.train_batch_size,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
rewards = self.compute_rewards(prompt_image_pairs)
|
| 366 |
+
|
| 367 |
+
prompt_image_pairs["rewards"] = rewards
|
| 368 |
+
|
| 369 |
+
rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
|
| 370 |
+
|
| 371 |
+
loss = self.calculate_loss(rewards)
|
| 372 |
+
|
| 373 |
+
self.accelerator.backward(loss)
|
| 374 |
+
|
| 375 |
+
if self.accelerator.sync_gradients:
|
| 376 |
+
self.accelerator.clip_grad_norm_(
|
| 377 |
+
self.trainable_layers.parameters()
|
| 378 |
+
if not isinstance(self.trainable_layers, list)
|
| 379 |
+
else self.trainable_layers,
|
| 380 |
+
self.config.train_max_grad_norm,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
self.optimizer.step()
|
| 384 |
+
self.optimizer.zero_grad()
|
| 385 |
+
|
| 386 |
+
info["reward_mean"].append(rewards_vis.mean())
|
| 387 |
+
info["reward_std"].append(rewards_vis.std())
|
| 388 |
+
info["loss"].append(loss.item())
|
| 389 |
+
|
| 390 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 391 |
+
if self.accelerator.sync_gradients:
|
| 392 |
+
# log training-related stuff
|
| 393 |
+
info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
|
| 394 |
+
info = self.accelerator.reduce(info, reduction="mean")
|
| 395 |
+
info.update({"epoch": epoch})
|
| 396 |
+
self.accelerator.log(info, step=global_step)
|
| 397 |
+
global_step += 1
|
| 398 |
+
info = defaultdict(list)
|
| 399 |
+
else:
|
| 400 |
+
raise ValueError(
|
| 401 |
+
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
| 402 |
+
)
|
| 403 |
+
# Logs generated images
|
| 404 |
+
if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
|
| 405 |
+
self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
|
| 406 |
+
|
| 407 |
+
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
| 408 |
+
self.accelerator.save_state()
|
| 409 |
+
|
| 410 |
+
return global_step
|
| 411 |
+
|
| 412 |
+
def calculate_loss(self, rewards):
|
| 413 |
+
"""
|
| 414 |
+
Calculate the loss for a batch of an unpacked sample
|
| 415 |
+
|
| 416 |
+
Args:
|
| 417 |
+
rewards (torch.Tensor):
|
| 418 |
+
Differentiable reward scalars for each generated image, shape: [batch_size]
|
| 419 |
+
|
| 420 |
+
Returns:
|
| 421 |
+
loss (torch.Tensor)
|
| 422 |
+
(all of these are of shape (1,))
|
| 423 |
+
"""
|
| 424 |
+
# Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
|
| 425 |
+
loss = 10.0 - (rewards).mean()
|
| 426 |
+
return loss
|
| 427 |
+
|
| 428 |
+
def loss(
|
| 429 |
+
self,
|
| 430 |
+
advantages: torch.Tensor,
|
| 431 |
+
clip_range: float,
|
| 432 |
+
ratio: torch.Tensor,
|
| 433 |
+
):
|
| 434 |
+
unclipped_loss = -advantages * ratio
|
| 435 |
+
clipped_loss = -advantages * torch.clamp(
|
| 436 |
+
ratio,
|
| 437 |
+
1.0 - clip_range,
|
| 438 |
+
1.0 + clip_range,
|
| 439 |
+
)
|
| 440 |
+
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
| 441 |
+
|
| 442 |
+
def _setup_optimizer(self, trainable_layers_parameters):
|
| 443 |
+
if self.config.train_use_8bit_adam:
|
| 444 |
+
import bitsandbytes
|
| 445 |
+
|
| 446 |
+
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
| 447 |
+
else:
|
| 448 |
+
optimizer_cls = torch.optim.AdamW
|
| 449 |
+
|
| 450 |
+
return optimizer_cls(
|
| 451 |
+
trainable_layers_parameters,
|
| 452 |
+
lr=self.config.train_learning_rate,
|
| 453 |
+
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
| 454 |
+
weight_decay=self.config.train_adam_weight_decay,
|
| 455 |
+
eps=self.config.train_adam_epsilon,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
def _save_model_hook(self, models, weights, output_dir):
|
| 459 |
+
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
| 460 |
+
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
| 461 |
+
|
| 462 |
+
def _load_model_hook(self, models, input_dir):
|
| 463 |
+
self.sd_pipeline.load_checkpoint(models, input_dir)
|
| 464 |
+
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
| 465 |
+
|
| 466 |
+
def _generate_samples(self, batch_size, with_grad=True, prompts=None):
|
| 467 |
+
"""
|
| 468 |
+
Generate samples from the model
|
| 469 |
+
|
| 470 |
+
Args:
|
| 471 |
+
batch_size (int): Batch size to use for sampling
|
| 472 |
+
with_grad (bool): Whether the generated RGBs should have gradients attached to it.
|
| 473 |
+
|
| 474 |
+
Returns:
|
| 475 |
+
prompt_image_pairs (dict[Any])
|
| 476 |
+
"""
|
| 477 |
+
prompt_image_pairs = {}
|
| 478 |
+
|
| 479 |
+
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
| 480 |
+
|
| 481 |
+
if prompts is None:
|
| 482 |
+
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
| 483 |
+
else:
|
| 484 |
+
prompt_metadata = [{} for _ in range(batch_size)]
|
| 485 |
+
|
| 486 |
+
prompt_ids = self.sd_pipeline.tokenizer(
|
| 487 |
+
prompts,
|
| 488 |
+
return_tensors="pt",
|
| 489 |
+
padding="max_length",
|
| 490 |
+
truncation=True,
|
| 491 |
+
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
| 492 |
+
).input_ids.to(self.accelerator.device)
|
| 493 |
+
|
| 494 |
+
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
| 495 |
+
|
| 496 |
+
if with_grad:
|
| 497 |
+
sd_output = self.sd_pipeline.rgb_with_grad(
|
| 498 |
+
prompt_embeds=prompt_embeds,
|
| 499 |
+
negative_prompt_embeds=sample_neg_prompt_embeds,
|
| 500 |
+
num_inference_steps=self.config.sample_num_steps,
|
| 501 |
+
guidance_scale=self.config.sample_guidance_scale,
|
| 502 |
+
eta=self.config.sample_eta,
|
| 503 |
+
truncated_backprop_rand=self.config.truncated_backprop_rand,
|
| 504 |
+
truncated_backprop_timestep=self.config.truncated_backprop_timestep,
|
| 505 |
+
truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
|
| 506 |
+
output_type="pt",
|
| 507 |
+
)
|
| 508 |
+
else:
|
| 509 |
+
sd_output = self.sd_pipeline(
|
| 510 |
+
prompt_embeds=prompt_embeds,
|
| 511 |
+
negative_prompt_embeds=sample_neg_prompt_embeds,
|
| 512 |
+
num_inference_steps=self.config.sample_num_steps,
|
| 513 |
+
guidance_scale=self.config.sample_guidance_scale,
|
| 514 |
+
eta=self.config.sample_eta,
|
| 515 |
+
output_type="pt",
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
images = sd_output.images
|
| 519 |
+
|
| 520 |
+
prompt_image_pairs["images"] = images
|
| 521 |
+
prompt_image_pairs["prompts"] = prompts
|
| 522 |
+
prompt_image_pairs["prompt_metadata"] = prompt_metadata
|
| 523 |
+
|
| 524 |
+
return prompt_image_pairs
|
| 525 |
+
|
| 526 |
+
def train(self, epochs: Optional[int] = None):
|
| 527 |
+
"""
|
| 528 |
+
Train the model for a given number of epochs
|
| 529 |
+
"""
|
| 530 |
+
global_step = 0
|
| 531 |
+
if epochs is None:
|
| 532 |
+
epochs = self.config.num_epochs
|
| 533 |
+
for epoch in range(self.first_epoch, epochs):
|
| 534 |
+
global_step = self.step(epoch, global_step)
|
| 535 |
+
|
| 536 |
+
def _save_pretrained(self, save_directory):
|
| 537 |
+
self.sd_pipeline.save_pretrained(save_directory)
|
| 538 |
+
self.create_model_card()
|
| 539 |
+
|
| 540 |
+
def create_model_card(
|
| 541 |
+
self,
|
| 542 |
+
model_name: Optional[str] = None,
|
| 543 |
+
dataset_name: Optional[str] = None,
|
| 544 |
+
tags: Union[str, list[str], None] = None,
|
| 545 |
+
):
|
| 546 |
+
"""
|
| 547 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
| 548 |
+
|
| 549 |
+
Args:
|
| 550 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 551 |
+
Name of the model.
|
| 552 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 553 |
+
Name of the dataset used for training.
|
| 554 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 555 |
+
Tags to be associated with the model card.
|
| 556 |
+
"""
|
| 557 |
+
if not self.is_world_process_zero():
|
| 558 |
+
return
|
| 559 |
+
|
| 560 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 561 |
+
base_model = self.model.config._name_or_path
|
| 562 |
+
else:
|
| 563 |
+
base_model = None
|
| 564 |
+
|
| 565 |
+
tags = tags or []
|
| 566 |
+
if isinstance(tags, str):
|
| 567 |
+
tags = [tags]
|
| 568 |
+
|
| 569 |
+
if hasattr(self.model.config, "unsloth_version"):
|
| 570 |
+
tags.append("unsloth")
|
| 571 |
+
|
| 572 |
+
citation = textwrap.dedent("""\
|
| 573 |
+
@article{prabhudesai2024aligning,
|
| 574 |
+
title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
|
| 575 |
+
author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
|
| 576 |
+
year = 2024,
|
| 577 |
+
eprint = {arXiv:2310.03739}
|
| 578 |
+
}""")
|
| 579 |
+
|
| 580 |
+
model_card = generate_model_card(
|
| 581 |
+
base_model=base_model,
|
| 582 |
+
model_name=model_name,
|
| 583 |
+
hub_model_id=self.hub_model_id,
|
| 584 |
+
dataset_name=dataset_name,
|
| 585 |
+
tags=tags,
|
| 586 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 587 |
+
comet_url=get_comet_experiment_url(),
|
| 588 |
+
trainer_name="AlignProp",
|
| 589 |
+
trainer_citation=citation,
|
| 590 |
+
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
|
| 591 |
+
paper_id="2310.03739",
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 595 |
+
class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
|
| 596 |
+
"""
|
| 597 |
+
|
| 598 |
+
The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
|
| 599 |
+
Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
|
| 600 |
+
As of now only Stable Diffusion based pipelines are supported
|
| 601 |
+
|
| 602 |
+
Attributes:
|
| 603 |
+
config (`AlignPropConfig`):
|
| 604 |
+
Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
|
| 605 |
+
reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
|
| 606 |
+
Reward function to be used
|
| 607 |
+
prompt_function (`Callable[[], tuple[str, Any]]`):
|
| 608 |
+
Function to generate prompts to guide model
|
| 609 |
+
sd_pipeline (`DDPOStableDiffusionPipeline`):
|
| 610 |
+
Stable Diffusion pipeline to be used for training.
|
| 611 |
+
image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
|
| 612 |
+
Hook to be called to log images
|
| 613 |
+
|
| 614 |
+
"""
|
| 615 |
+
def __init__(
|
| 616 |
+
self,
|
| 617 |
+
config,
|
| 618 |
+
reward_function,
|
| 619 |
+
prompt_function,
|
| 620 |
+
sd_pipeline,
|
| 621 |
+
image_samples_hook = None,
|
| 622 |
+
**kwargs
|
| 623 |
+
):
|
| 624 |
+
if args is None: args = UnslothAlignPropConfig()
|
| 625 |
+
other_metrics = []
|
| 626 |
+
|
| 627 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 628 |
+
PatchRLStatistics('alignprop_trainer', other_metrics)
|
| 629 |
+
|
| 630 |
+
super().__init__(
|
| 631 |
+
config = config,
|
| 632 |
+
reward_function = reward_function,
|
| 633 |
+
prompt_function = prompt_function,
|
| 634 |
+
sd_pipeline = sd_pipeline,
|
| 635 |
+
image_samples_hook = image_samples_hook,**kwargs)
|
| 636 |
+
|
| 637 |
+
pass
|
unsloth_compiled_cache/UnslothBCOTrainer.py
ADDED
|
@@ -0,0 +1,1824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2025.3.17
|
| 3 |
+
2025.3.19
|
| 4 |
+
4.50.0
|
| 5 |
+
0.15.2
|
| 6 |
+
__UNSLOTH_VERSIONING__
|
| 7 |
+
"""
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import *
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from packaging.version import Version
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from contextlib import nullcontext
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 24 |
+
|
| 25 |
+
torch_compile_options = {
|
| 26 |
+
"epilogue_fusion" : True,
|
| 27 |
+
"max_autotune" : False,
|
| 28 |
+
"shape_padding" : True,
|
| 29 |
+
"trace.enabled" : False,
|
| 30 |
+
"triton.cudagraphs" : False,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 34 |
+
def selective_log_softmax(logits, index):
|
| 35 |
+
logits = logits.to(torch.float32)
|
| 36 |
+
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
| 37 |
+
# loop to reduce peak mem consumption
|
| 38 |
+
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
| 39 |
+
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
| 40 |
+
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
| 41 |
+
return per_token_logps
|
| 42 |
+
@dataclass
|
| 43 |
+
class UnslothBCOConfig(BCOConfig):
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
Configuration class for the [`BCOTrainer`].
|
| 47 |
+
|
| 48 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 49 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 50 |
+
command line.
|
| 51 |
+
|
| 52 |
+
Parameters:
|
| 53 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 54 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
| 55 |
+
to use the default data collator.
|
| 56 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
| 57 |
+
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
| 58 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
| 59 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
| 60 |
+
and your model is an encoder-decoder.
|
| 61 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
| 62 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
| 63 |
+
reference model.
|
| 64 |
+
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
| 65 |
+
Label pad token id. This argument is required if you want to use the default data collator.
|
| 66 |
+
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
| 67 |
+
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
| 68 |
+
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
| 69 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
| 70 |
+
This argument is required if you want to use the default data collator.
|
| 71 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 72 |
+
Whether to disable dropout in the model and reference model.
|
| 73 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
| 74 |
+
If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
|
| 75 |
+
evaluation.
|
| 76 |
+
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
| 77 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
| 78 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
| 79 |
+
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
|
| 80 |
+
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
|
| 81 |
+
useful when training without the reference model to reduce the total GPU memory needed.
|
| 82 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 83 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
| 84 |
+
string.
|
| 85 |
+
ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 86 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
|
| 87 |
+
from a string.
|
| 88 |
+
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
| 89 |
+
Number of processes to use for processing the dataset.
|
| 90 |
+
prompt_sample_size (`int`, *optional*, defaults to `1024`):
|
| 91 |
+
Number of prompts that are fed to density ratio classifier.
|
| 92 |
+
min_density_ratio (`float`, *optional*, defaults to `0.5`):
|
| 93 |
+
Minimum value of the density ratio. The estimated density ratio is clamped to this value.
|
| 94 |
+
max_density_ratio (`float`, *optional*, defaults to `10.0`):
|
| 95 |
+
Maximum value of the density ratio. The estimated density ratio is clamped to this value.
|
| 96 |
+
|
| 97 |
+
"""
|
| 98 |
+
vllm_sampling_params: Optional[Any] = field(
|
| 99 |
+
default = None,
|
| 100 |
+
metadata = {'help': 'vLLM SamplingParams'},
|
| 101 |
+
)
|
| 102 |
+
unsloth_num_chunks : Optional[int] = field(
|
| 103 |
+
default = -1,
|
| 104 |
+
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 105 |
+
)
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
output_dir = None,
|
| 109 |
+
overwrite_output_dir = None,
|
| 110 |
+
do_train = False,
|
| 111 |
+
do_eval = False,
|
| 112 |
+
do_predict = False,
|
| 113 |
+
eval_strategy = 'no',
|
| 114 |
+
prediction_loss_only = False,
|
| 115 |
+
per_device_train_batch_size = 4,
|
| 116 |
+
per_device_eval_batch_size = 4,
|
| 117 |
+
per_gpu_train_batch_size = None,
|
| 118 |
+
per_gpu_eval_batch_size = None,
|
| 119 |
+
gradient_accumulation_steps = 2,
|
| 120 |
+
eval_accumulation_steps = 2,
|
| 121 |
+
eval_delay = 0,
|
| 122 |
+
torch_empty_cache_steps = 250,
|
| 123 |
+
learning_rate = 5e-05,
|
| 124 |
+
weight_decay = 0.01,
|
| 125 |
+
adam_beta1 = 0.9,
|
| 126 |
+
adam_beta2 = 0.999,
|
| 127 |
+
adam_epsilon = 1e-08,
|
| 128 |
+
max_grad_norm = 1.0,
|
| 129 |
+
num_train_epochs = 3.0,
|
| 130 |
+
max_steps = -1,
|
| 131 |
+
lr_scheduler_type = 'linear',
|
| 132 |
+
warmup_ratio = 0.1,
|
| 133 |
+
warmup_steps = 0,
|
| 134 |
+
log_level = 'passive',
|
| 135 |
+
log_level_replica = 'warning',
|
| 136 |
+
log_on_each_node = True,
|
| 137 |
+
logging_dir = None,
|
| 138 |
+
logging_strategy = 'steps',
|
| 139 |
+
logging_first_step = False,
|
| 140 |
+
logging_steps = 1,
|
| 141 |
+
logging_nan_inf_filter = False,
|
| 142 |
+
save_strategy = 'steps',
|
| 143 |
+
save_steps = 500,
|
| 144 |
+
save_total_limit = None,
|
| 145 |
+
save_safetensors = True,
|
| 146 |
+
save_on_each_node = False,
|
| 147 |
+
save_only_model = False,
|
| 148 |
+
restore_callback_states_from_checkpoint = False,
|
| 149 |
+
no_cuda = False,
|
| 150 |
+
use_cpu = False,
|
| 151 |
+
use_mps_device = False,
|
| 152 |
+
seed = 3407,
|
| 153 |
+
data_seed = 3407,
|
| 154 |
+
jit_mode_eval = False,
|
| 155 |
+
use_ipex = False,
|
| 156 |
+
bf16 = False,
|
| 157 |
+
fp16 = False,
|
| 158 |
+
fp16_opt_level = 'O1',
|
| 159 |
+
half_precision_backend = 'auto',
|
| 160 |
+
bf16_full_eval = False,
|
| 161 |
+
fp16_full_eval = False,
|
| 162 |
+
tf32 = None,
|
| 163 |
+
local_rank = -1,
|
| 164 |
+
ddp_backend = None,
|
| 165 |
+
tpu_num_cores = None,
|
| 166 |
+
tpu_metrics_debug = False,
|
| 167 |
+
debug = '',
|
| 168 |
+
dataloader_drop_last = False,
|
| 169 |
+
eval_steps = None,
|
| 170 |
+
dataloader_num_workers = 0,
|
| 171 |
+
dataloader_prefetch_factor = None,
|
| 172 |
+
past_index = -1,
|
| 173 |
+
run_name = None,
|
| 174 |
+
disable_tqdm = None,
|
| 175 |
+
remove_unused_columns = True,
|
| 176 |
+
label_names = None,
|
| 177 |
+
load_best_model_at_end = False,
|
| 178 |
+
metric_for_best_model = None,
|
| 179 |
+
greater_is_better = None,
|
| 180 |
+
ignore_data_skip = False,
|
| 181 |
+
fsdp = '',
|
| 182 |
+
fsdp_min_num_params = 0,
|
| 183 |
+
fsdp_config = None,
|
| 184 |
+
tp_size = 0,
|
| 185 |
+
fsdp_transformer_layer_cls_to_wrap = None,
|
| 186 |
+
accelerator_config = None,
|
| 187 |
+
deepspeed = None,
|
| 188 |
+
label_smoothing_factor = 0.0,
|
| 189 |
+
optim = 'adamw_8bit',
|
| 190 |
+
optim_args = None,
|
| 191 |
+
adafactor = False,
|
| 192 |
+
group_by_length = False,
|
| 193 |
+
length_column_name = 'length',
|
| 194 |
+
report_to = None,
|
| 195 |
+
ddp_find_unused_parameters = None,
|
| 196 |
+
ddp_bucket_cap_mb = None,
|
| 197 |
+
ddp_broadcast_buffers = None,
|
| 198 |
+
dataloader_pin_memory = True,
|
| 199 |
+
dataloader_persistent_workers = False,
|
| 200 |
+
skip_memory_metrics = True,
|
| 201 |
+
use_legacy_prediction_loop = False,
|
| 202 |
+
push_to_hub = False,
|
| 203 |
+
resume_from_checkpoint = None,
|
| 204 |
+
hub_model_id = None,
|
| 205 |
+
hub_strategy = 'every_save',
|
| 206 |
+
hub_token = None,
|
| 207 |
+
hub_private_repo = None,
|
| 208 |
+
hub_always_push = False,
|
| 209 |
+
gradient_checkpointing = False,
|
| 210 |
+
gradient_checkpointing_kwargs = None,
|
| 211 |
+
include_inputs_for_metrics = False,
|
| 212 |
+
eval_do_concat_batches = True,
|
| 213 |
+
fp16_backend = 'auto',
|
| 214 |
+
evaluation_strategy = None,
|
| 215 |
+
push_to_hub_model_id = None,
|
| 216 |
+
push_to_hub_organization = None,
|
| 217 |
+
push_to_hub_token = None,
|
| 218 |
+
mp_parameters = '',
|
| 219 |
+
auto_find_batch_size = False,
|
| 220 |
+
full_determinism = False,
|
| 221 |
+
torchdynamo = None,
|
| 222 |
+
ray_scope = 'last',
|
| 223 |
+
ddp_timeout = 1800,
|
| 224 |
+
torch_compile = False,
|
| 225 |
+
torch_compile_backend = None,
|
| 226 |
+
torch_compile_mode = None,
|
| 227 |
+
dispatch_batches = None,
|
| 228 |
+
split_batches = None,
|
| 229 |
+
include_tokens_per_second = False,
|
| 230 |
+
include_num_input_tokens_seen = False,
|
| 231 |
+
neftune_noise_alpha = None,
|
| 232 |
+
optim_target_modules = None,
|
| 233 |
+
batch_eval_metrics = False,
|
| 234 |
+
eval_on_start = False,
|
| 235 |
+
use_liger_kernel = False,
|
| 236 |
+
eval_use_gather_object = False,
|
| 237 |
+
average_tokens_across_devices = False,
|
| 238 |
+
max_length = 1024,
|
| 239 |
+
max_prompt_length = 512,
|
| 240 |
+
max_completion_length = None,
|
| 241 |
+
beta = 0.1,
|
| 242 |
+
label_pad_token_id = -100,
|
| 243 |
+
padding_value = None,
|
| 244 |
+
truncation_mode = 'keep_end',
|
| 245 |
+
disable_dropout = True,
|
| 246 |
+
generate_during_eval = False,
|
| 247 |
+
is_encoder_decoder = None,
|
| 248 |
+
precompute_ref_log_probs = False,
|
| 249 |
+
model_init_kwargs = None,
|
| 250 |
+
ref_model_init_kwargs = None,
|
| 251 |
+
dataset_num_proc = None,
|
| 252 |
+
prompt_sample_size = 1024,
|
| 253 |
+
min_density_ratio = 0.5,
|
| 254 |
+
max_density_ratio = 10.0,
|
| 255 |
+
vllm_sampling_params = None,
|
| 256 |
+
unsloth_num_chunks = -1,
|
| 257 |
+
**kwargs,
|
| 258 |
+
):
|
| 259 |
+
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 260 |
+
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 261 |
+
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 262 |
+
output_dir = 'unsloth_training_checkpoints'
|
| 263 |
+
save_strategy = 'no'
|
| 264 |
+
if dataset_num_proc is None:
|
| 265 |
+
from multiprocessing import cpu_count
|
| 266 |
+
dataset_num_proc = cpu_count()
|
| 267 |
+
|
| 268 |
+
super().__init__(
|
| 269 |
+
output_dir = output_dir,
|
| 270 |
+
overwrite_output_dir = overwrite_output_dir,
|
| 271 |
+
do_train = do_train,
|
| 272 |
+
do_eval = do_eval,
|
| 273 |
+
do_predict = do_predict,
|
| 274 |
+
eval_strategy = eval_strategy,
|
| 275 |
+
prediction_loss_only = prediction_loss_only,
|
| 276 |
+
per_device_train_batch_size = per_device_train_batch_size,
|
| 277 |
+
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 278 |
+
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 279 |
+
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 280 |
+
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 281 |
+
eval_accumulation_steps = eval_accumulation_steps,
|
| 282 |
+
eval_delay = eval_delay,
|
| 283 |
+
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 284 |
+
learning_rate = learning_rate,
|
| 285 |
+
weight_decay = weight_decay,
|
| 286 |
+
adam_beta1 = adam_beta1,
|
| 287 |
+
adam_beta2 = adam_beta2,
|
| 288 |
+
adam_epsilon = adam_epsilon,
|
| 289 |
+
max_grad_norm = max_grad_norm,
|
| 290 |
+
num_train_epochs = num_train_epochs,
|
| 291 |
+
max_steps = max_steps,
|
| 292 |
+
lr_scheduler_type = lr_scheduler_type,
|
| 293 |
+
warmup_ratio = warmup_ratio,
|
| 294 |
+
warmup_steps = warmup_steps,
|
| 295 |
+
log_level = log_level,
|
| 296 |
+
log_level_replica = log_level_replica,
|
| 297 |
+
log_on_each_node = log_on_each_node,
|
| 298 |
+
logging_dir = logging_dir,
|
| 299 |
+
logging_strategy = logging_strategy,
|
| 300 |
+
logging_first_step = logging_first_step,
|
| 301 |
+
logging_steps = logging_steps,
|
| 302 |
+
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 303 |
+
save_strategy = save_strategy,
|
| 304 |
+
save_steps = save_steps,
|
| 305 |
+
save_total_limit = save_total_limit,
|
| 306 |
+
save_safetensors = save_safetensors,
|
| 307 |
+
save_on_each_node = save_on_each_node,
|
| 308 |
+
save_only_model = save_only_model,
|
| 309 |
+
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 310 |
+
no_cuda = no_cuda,
|
| 311 |
+
use_cpu = use_cpu,
|
| 312 |
+
use_mps_device = use_mps_device,
|
| 313 |
+
seed = seed,
|
| 314 |
+
data_seed = data_seed,
|
| 315 |
+
jit_mode_eval = jit_mode_eval,
|
| 316 |
+
use_ipex = use_ipex,
|
| 317 |
+
bf16 = bf16,
|
| 318 |
+
fp16 = fp16,
|
| 319 |
+
fp16_opt_level = fp16_opt_level,
|
| 320 |
+
half_precision_backend = half_precision_backend,
|
| 321 |
+
bf16_full_eval = bf16_full_eval,
|
| 322 |
+
fp16_full_eval = fp16_full_eval,
|
| 323 |
+
tf32 = tf32,
|
| 324 |
+
local_rank = local_rank,
|
| 325 |
+
ddp_backend = ddp_backend,
|
| 326 |
+
tpu_num_cores = tpu_num_cores,
|
| 327 |
+
tpu_metrics_debug = tpu_metrics_debug,
|
| 328 |
+
debug = debug,
|
| 329 |
+
dataloader_drop_last = dataloader_drop_last,
|
| 330 |
+
eval_steps = eval_steps,
|
| 331 |
+
dataloader_num_workers = dataloader_num_workers,
|
| 332 |
+
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 333 |
+
past_index = past_index,
|
| 334 |
+
run_name = run_name,
|
| 335 |
+
disable_tqdm = disable_tqdm,
|
| 336 |
+
remove_unused_columns = remove_unused_columns,
|
| 337 |
+
label_names = label_names,
|
| 338 |
+
load_best_model_at_end = load_best_model_at_end,
|
| 339 |
+
metric_for_best_model = metric_for_best_model,
|
| 340 |
+
greater_is_better = greater_is_better,
|
| 341 |
+
ignore_data_skip = ignore_data_skip,
|
| 342 |
+
fsdp = fsdp,
|
| 343 |
+
fsdp_min_num_params = fsdp_min_num_params,
|
| 344 |
+
fsdp_config = fsdp_config,
|
| 345 |
+
tp_size = tp_size,
|
| 346 |
+
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 347 |
+
accelerator_config = accelerator_config,
|
| 348 |
+
deepspeed = deepspeed,
|
| 349 |
+
label_smoothing_factor = label_smoothing_factor,
|
| 350 |
+
optim = optim,
|
| 351 |
+
optim_args = optim_args,
|
| 352 |
+
adafactor = adafactor,
|
| 353 |
+
group_by_length = group_by_length,
|
| 354 |
+
length_column_name = length_column_name,
|
| 355 |
+
report_to = report_to,
|
| 356 |
+
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 357 |
+
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 358 |
+
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 359 |
+
dataloader_pin_memory = dataloader_pin_memory,
|
| 360 |
+
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 361 |
+
skip_memory_metrics = skip_memory_metrics,
|
| 362 |
+
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 363 |
+
push_to_hub = push_to_hub,
|
| 364 |
+
resume_from_checkpoint = resume_from_checkpoint,
|
| 365 |
+
hub_model_id = hub_model_id,
|
| 366 |
+
hub_strategy = hub_strategy,
|
| 367 |
+
hub_token = hub_token,
|
| 368 |
+
hub_private_repo = hub_private_repo,
|
| 369 |
+
hub_always_push = hub_always_push,
|
| 370 |
+
gradient_checkpointing = gradient_checkpointing,
|
| 371 |
+
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 372 |
+
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 373 |
+
eval_do_concat_batches = eval_do_concat_batches,
|
| 374 |
+
fp16_backend = fp16_backend,
|
| 375 |
+
evaluation_strategy = evaluation_strategy,
|
| 376 |
+
push_to_hub_model_id = push_to_hub_model_id,
|
| 377 |
+
push_to_hub_organization = push_to_hub_organization,
|
| 378 |
+
push_to_hub_token = push_to_hub_token,
|
| 379 |
+
mp_parameters = mp_parameters,
|
| 380 |
+
auto_find_batch_size = auto_find_batch_size,
|
| 381 |
+
full_determinism = full_determinism,
|
| 382 |
+
torchdynamo = torchdynamo,
|
| 383 |
+
ray_scope = ray_scope,
|
| 384 |
+
ddp_timeout = ddp_timeout,
|
| 385 |
+
torch_compile = torch_compile,
|
| 386 |
+
torch_compile_backend = torch_compile_backend,
|
| 387 |
+
torch_compile_mode = torch_compile_mode,
|
| 388 |
+
dispatch_batches = dispatch_batches,
|
| 389 |
+
split_batches = split_batches,
|
| 390 |
+
include_tokens_per_second = include_tokens_per_second,
|
| 391 |
+
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 392 |
+
neftune_noise_alpha = neftune_noise_alpha,
|
| 393 |
+
optim_target_modules = optim_target_modules,
|
| 394 |
+
batch_eval_metrics = batch_eval_metrics,
|
| 395 |
+
eval_on_start = eval_on_start,
|
| 396 |
+
use_liger_kernel = use_liger_kernel,
|
| 397 |
+
eval_use_gather_object = eval_use_gather_object,
|
| 398 |
+
average_tokens_across_devices = average_tokens_across_devices,
|
| 399 |
+
max_length = max_length,
|
| 400 |
+
max_prompt_length = max_prompt_length,
|
| 401 |
+
max_completion_length = max_completion_length,
|
| 402 |
+
beta = beta,
|
| 403 |
+
label_pad_token_id = label_pad_token_id,
|
| 404 |
+
padding_value = padding_value,
|
| 405 |
+
truncation_mode = truncation_mode,
|
| 406 |
+
disable_dropout = disable_dropout,
|
| 407 |
+
generate_during_eval = generate_during_eval,
|
| 408 |
+
is_encoder_decoder = is_encoder_decoder,
|
| 409 |
+
precompute_ref_log_probs = precompute_ref_log_probs,
|
| 410 |
+
model_init_kwargs = model_init_kwargs,
|
| 411 |
+
ref_model_init_kwargs = ref_model_init_kwargs,
|
| 412 |
+
dataset_num_proc = dataset_num_proc,
|
| 413 |
+
prompt_sample_size = prompt_sample_size,
|
| 414 |
+
min_density_ratio = min_density_ratio,
|
| 415 |
+
max_density_ratio = max_density_ratio,**kwargs)
|
| 416 |
+
self.vllm_sampling_params = vllm_sampling_params
|
| 417 |
+
self.unsloth_num_chunks = unsloth_num_chunks
|
| 418 |
+
pass
|
| 419 |
+
|
| 420 |
+
class _UnslothBCOTrainer(Trainer):
|
| 421 |
+
r""""""
|
| 422 |
+
|
| 423 |
+
_tag_names = ["trl", "bco"]
|
| 424 |
+
|
| 425 |
+
def __init__(
|
| 426 |
+
self,
|
| 427 |
+
model: Union[PreTrainedModel, nn.Module, str] = None,
|
| 428 |
+
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
| 429 |
+
args: BCOConfig = None,
|
| 430 |
+
train_dataset: Optional[Dataset] = None,
|
| 431 |
+
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 432 |
+
processing_class: Optional[
|
| 433 |
+
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 434 |
+
] = None,
|
| 435 |
+
data_collator: Optional[DataCollator] = None,
|
| 436 |
+
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
| 437 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
| 438 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 439 |
+
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 440 |
+
peft_config: Optional[dict] = None,
|
| 441 |
+
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
| 442 |
+
model_adapter_name: Optional[str] = None,
|
| 443 |
+
ref_adapter_name: Optional[str] = None,
|
| 444 |
+
embedding_func: Optional[Callable] = None,
|
| 445 |
+
embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
| 446 |
+
):
|
| 447 |
+
if not is_sklearn_available():
|
| 448 |
+
raise ImportError(
|
| 449 |
+
"BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
if type(args) is TrainingArguments:
|
| 453 |
+
raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
|
| 454 |
+
|
| 455 |
+
if not isinstance(model, str) and ref_model is model:
|
| 456 |
+
raise ValueError(
|
| 457 |
+
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
| 458 |
+
"same as `model`, you must mass a copy of it, or `None` if you use peft."
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
if args.model_init_kwargs is None:
|
| 462 |
+
model_init_kwargs = {}
|
| 463 |
+
elif not isinstance(model, str):
|
| 464 |
+
raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
|
| 465 |
+
else:
|
| 466 |
+
model_init_kwargs = args.model_init_kwargs
|
| 467 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 468 |
+
if torch_dtype is not None:
|
| 469 |
+
# Convert to `torch.dtype` if an str is passed
|
| 470 |
+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
| 471 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 472 |
+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
| 473 |
+
raise ValueError(
|
| 474 |
+
f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
| 475 |
+
)
|
| 476 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 477 |
+
|
| 478 |
+
if args.ref_model_init_kwargs is None:
|
| 479 |
+
ref_model_init_kwargs = {}
|
| 480 |
+
elif not isinstance(ref_model, str):
|
| 481 |
+
raise ValueError(
|
| 482 |
+
"You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
|
| 483 |
+
)
|
| 484 |
+
else:
|
| 485 |
+
ref_model_init_kwargs = args.ref_model_init_kwargs
|
| 486 |
+
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
|
| 487 |
+
if torch_dtype is not None:
|
| 488 |
+
# Convert to `torch.dtype` if an str is passed
|
| 489 |
+
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
| 490 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 491 |
+
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
| 492 |
+
raise ValueError(
|
| 493 |
+
f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
| 494 |
+
)
|
| 495 |
+
ref_model_init_kwargs["torch_dtype"] = torch_dtype
|
| 496 |
+
|
| 497 |
+
if isinstance(model, str):
|
| 498 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 499 |
+
|
| 500 |
+
if isinstance(ref_model, str):
|
| 501 |
+
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
|
| 502 |
+
|
| 503 |
+
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
| 504 |
+
# has been called in order to properly call autocast if needed.
|
| 505 |
+
self._peft_has_been_casted_to_bf16 = False
|
| 506 |
+
|
| 507 |
+
if not is_peft_available() and peft_config is not None:
|
| 508 |
+
raise ValueError(
|
| 509 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
| 510 |
+
)
|
| 511 |
+
elif is_peft_available() and peft_config is not None:
|
| 512 |
+
# if model is a peft model and we have a peft_config, we merge and unload it first
|
| 513 |
+
if isinstance(model, PeftModel):
|
| 514 |
+
model = model.merge_and_unload()
|
| 515 |
+
|
| 516 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
| 517 |
+
_support_gc_kwargs = hasattr(
|
| 518 |
+
args, "gradient_checkpointing_kwargs"
|
| 519 |
+
) and "gradient_checkpointing_kwargs" in list(
|
| 520 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
| 524 |
+
|
| 525 |
+
if _support_gc_kwargs:
|
| 526 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
| 527 |
+
|
| 528 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 529 |
+
elif getattr(args, "gradient_checkpointing", False):
|
| 530 |
+
# For backward compatibility with older versions of transformers
|
| 531 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 532 |
+
model.enable_input_require_grads()
|
| 533 |
+
else:
|
| 534 |
+
|
| 535 |
+
def make_inputs_require_grad(module, input, output):
|
| 536 |
+
output.requires_grad_(True)
|
| 537 |
+
|
| 538 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 539 |
+
|
| 540 |
+
# get peft model with the given config
|
| 541 |
+
model = model
|
| 542 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
| 543 |
+
peft_module_casting_to_bf16(model)
|
| 544 |
+
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
| 545 |
+
self._peft_has_been_casted_to_bf16 = True
|
| 546 |
+
|
| 547 |
+
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
| 548 |
+
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
| 549 |
+
# fail or completely fail.
|
| 550 |
+
elif getattr(args, "gradient_checkpointing", False):
|
| 551 |
+
# For backward compatibility with older versions of transformers
|
| 552 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 553 |
+
model.enable_input_require_grads()
|
| 554 |
+
else:
|
| 555 |
+
|
| 556 |
+
def make_inputs_require_grad(module, input, output):
|
| 557 |
+
output.requires_grad_(True)
|
| 558 |
+
|
| 559 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 560 |
+
|
| 561 |
+
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
| 562 |
+
raise ValueError(
|
| 563 |
+
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
| 564 |
+
" Please install `wandb` or `comet-ml` to resolve."
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
if model is not None:
|
| 568 |
+
self.is_encoder_decoder = model.config.is_encoder_decoder
|
| 569 |
+
elif args.is_encoder_decoder is None:
|
| 570 |
+
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
| 571 |
+
else:
|
| 572 |
+
self.is_encoder_decoder = args.is_encoder_decoder
|
| 573 |
+
|
| 574 |
+
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
| 575 |
+
self.model_adapter_name = model_adapter_name
|
| 576 |
+
self.ref_adapter_name = ref_adapter_name
|
| 577 |
+
|
| 578 |
+
if ref_model:
|
| 579 |
+
self.ref_model = ref_model
|
| 580 |
+
elif self.is_peft_model or args.precompute_ref_log_probs:
|
| 581 |
+
# The `model` with adapters turned off will be used as the reference model
|
| 582 |
+
self.ref_model = None
|
| 583 |
+
else:
|
| 584 |
+
self.ref_model = create_reference_model(model)
|
| 585 |
+
|
| 586 |
+
if processing_class is None:
|
| 587 |
+
raise ValueError(
|
| 588 |
+
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
| 589 |
+
)
|
| 590 |
+
if args.max_length is None:
|
| 591 |
+
warnings.warn(
|
| 592 |
+
"When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
|
| 593 |
+
"It will be set to `512` by default, but you should do it yourself in the future.",
|
| 594 |
+
UserWarning,
|
| 595 |
+
)
|
| 596 |
+
max_length = 512
|
| 597 |
+
if args.max_length is not None:
|
| 598 |
+
max_length = args.max_length
|
| 599 |
+
|
| 600 |
+
if args.max_prompt_length is None:
|
| 601 |
+
warnings.warn(
|
| 602 |
+
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
|
| 603 |
+
"It will be set to `128` by default, but you should do it yourself in the future.",
|
| 604 |
+
UserWarning,
|
| 605 |
+
)
|
| 606 |
+
max_prompt_length = 128
|
| 607 |
+
if args.max_prompt_length is not None:
|
| 608 |
+
max_prompt_length = args.max_prompt_length
|
| 609 |
+
|
| 610 |
+
max_completion_length = None
|
| 611 |
+
if args.max_completion_length is None and self.is_encoder_decoder:
|
| 612 |
+
warnings.warn(
|
| 613 |
+
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
|
| 614 |
+
" it will be set to `128` by default, but you should do it yourself in the future.",
|
| 615 |
+
UserWarning,
|
| 616 |
+
)
|
| 617 |
+
max_completion_length = 128
|
| 618 |
+
if args.max_completion_length is not None and self.is_encoder_decoder:
|
| 619 |
+
max_completion_length = args.max_completion_length
|
| 620 |
+
|
| 621 |
+
if data_collator is None:
|
| 622 |
+
data_collator = DPODataCollatorWithPadding(
|
| 623 |
+
pad_token_id=processing_class.pad_token_id,
|
| 624 |
+
label_pad_token_id=args.label_pad_token_id,
|
| 625 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
if args.remove_unused_columns:
|
| 629 |
+
args.remove_unused_columns = False
|
| 630 |
+
# warn users
|
| 631 |
+
warnings.warn(
|
| 632 |
+
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
|
| 633 |
+
" we have set it for you, but you should do it yourself in the future.",
|
| 634 |
+
UserWarning,
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
self.use_dpo_data_collator = True
|
| 638 |
+
else:
|
| 639 |
+
self.use_dpo_data_collator = False
|
| 640 |
+
|
| 641 |
+
# Disable dropout in the model and reference model
|
| 642 |
+
if args.disable_dropout:
|
| 643 |
+
disable_dropout_in_model(model)
|
| 644 |
+
if self.ref_model is not None:
|
| 645 |
+
disable_dropout_in_model(self.ref_model)
|
| 646 |
+
|
| 647 |
+
self.max_length = max_length
|
| 648 |
+
self.generate_during_eval = args.generate_during_eval
|
| 649 |
+
self.label_pad_token_id = args.label_pad_token_id
|
| 650 |
+
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
| 651 |
+
self.max_prompt_length = max_prompt_length
|
| 652 |
+
self.truncation_mode = args.truncation_mode
|
| 653 |
+
self.max_completion_length = max_completion_length
|
| 654 |
+
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
| 655 |
+
|
| 656 |
+
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
| 657 |
+
# keep track of first called to avoid computation of future calls
|
| 658 |
+
self._precomputed_train_ref_log_probs = False
|
| 659 |
+
self._precomputed_eval_ref_log_probs = False
|
| 660 |
+
|
| 661 |
+
# metric
|
| 662 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
| 663 |
+
|
| 664 |
+
# BCO parameter
|
| 665 |
+
self.beta = args.beta
|
| 666 |
+
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
| 667 |
+
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
| 668 |
+
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
| 669 |
+
warnings.warn(
|
| 670 |
+
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
| 671 |
+
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
| 672 |
+
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
| 673 |
+
"loss.",
|
| 674 |
+
UserWarning,
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
# Underlying Distribution Matching argument
|
| 678 |
+
self.embedding_func = embedding_func
|
| 679 |
+
self.embedding_tokenizer = embedding_tokenizer
|
| 680 |
+
|
| 681 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
| 682 |
+
# input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
|
| 683 |
+
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
|
| 684 |
+
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
| 685 |
+
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
| 686 |
+
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
| 687 |
+
# issued.
|
| 688 |
+
model.warnings_issued["estimate_tokens"] = True
|
| 689 |
+
|
| 690 |
+
with PartialState().local_main_process_first():
|
| 691 |
+
# Apply the chat template if needed
|
| 692 |
+
train_dataset = train_dataset.map(
|
| 693 |
+
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
| 694 |
+
)
|
| 695 |
+
if eval_dataset is not None:
|
| 696 |
+
eval_dataset = eval_dataset.map(
|
| 697 |
+
maybe_apply_chat_template,
|
| 698 |
+
fn_kwargs={"tokenizer": processing_class},
|
| 699 |
+
num_proc=args.dataset_num_proc,
|
| 700 |
+
)
|
| 701 |
+
# Shuffle the datasets
|
| 702 |
+
train_dataset = train_dataset.shuffle(seed=args.data_seed)
|
| 703 |
+
if eval_dataset is not None:
|
| 704 |
+
eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
|
| 705 |
+
# Tokenize and prepare the training datasets
|
| 706 |
+
train_dataset = train_dataset.map(
|
| 707 |
+
_tokenize,
|
| 708 |
+
batched=True,
|
| 709 |
+
fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
|
| 710 |
+
num_proc=args.dataset_num_proc,
|
| 711 |
+
desc="Tokenizing train dataset",
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
# Prepare the datasets
|
| 715 |
+
fn_kwargs = {
|
| 716 |
+
"prefix": "",
|
| 717 |
+
"is_encoder_decoder": self.is_encoder_decoder,
|
| 718 |
+
"tokenizer": processing_class,
|
| 719 |
+
"max_length": self.max_length,
|
| 720 |
+
"truncation_mode": self.truncation_mode,
|
| 721 |
+
"label_pad_token_id": self.label_pad_token_id,
|
| 722 |
+
"max_prompt_length": self.max_prompt_length,
|
| 723 |
+
"max_completion_length": self.max_completion_length,
|
| 724 |
+
}
|
| 725 |
+
train_dataset = train_dataset.map(
|
| 726 |
+
_process_tokens,
|
| 727 |
+
fn_kwargs=fn_kwargs,
|
| 728 |
+
num_proc=args.dataset_num_proc,
|
| 729 |
+
desc="Processing tokenized train dataset",
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
if eval_dataset is not None:
|
| 733 |
+
# Tokenize
|
| 734 |
+
eval_dataset = eval_dataset.map(
|
| 735 |
+
_tokenize,
|
| 736 |
+
fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
|
| 737 |
+
batched=True,
|
| 738 |
+
num_proc=args.dataset_num_proc,
|
| 739 |
+
desc="Tokenizing eval dataset",
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
# Process
|
| 743 |
+
fn_kwargs = {
|
| 744 |
+
"prefix": "",
|
| 745 |
+
"is_encoder_decoder": self.is_encoder_decoder,
|
| 746 |
+
"tokenizer": processing_class,
|
| 747 |
+
"max_length": self.max_length,
|
| 748 |
+
"truncation_mode": self.truncation_mode,
|
| 749 |
+
"label_pad_token_id": self.label_pad_token_id,
|
| 750 |
+
"max_prompt_length": self.max_prompt_length,
|
| 751 |
+
"max_completion_length": self.max_completion_length,
|
| 752 |
+
}
|
| 753 |
+
eval_dataset = eval_dataset.map(
|
| 754 |
+
_process_tokens,
|
| 755 |
+
fn_kwargs=fn_kwargs,
|
| 756 |
+
num_proc=args.dataset_num_proc,
|
| 757 |
+
desc="Processing tokenized eval dataset",
|
| 758 |
+
)
|
| 759 |
+
|
| 760 |
+
desirable = train_dataset.filter(
|
| 761 |
+
lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
|
| 762 |
+
)
|
| 763 |
+
undesirable = train_dataset.filter(
|
| 764 |
+
lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
desirable = desirable.shuffle(seed=args.data_seed)
|
| 768 |
+
undesirable = undesirable.shuffle(seed=args.data_seed)
|
| 769 |
+
|
| 770 |
+
super().__init__(
|
| 771 |
+
model=model,
|
| 772 |
+
args=args,
|
| 773 |
+
data_collator=data_collator,
|
| 774 |
+
train_dataset=train_dataset,
|
| 775 |
+
eval_dataset=eval_dataset,
|
| 776 |
+
processing_class=processing_class,
|
| 777 |
+
model_init=model_init,
|
| 778 |
+
compute_metrics=compute_metrics,
|
| 779 |
+
callbacks=callbacks,
|
| 780 |
+
optimizers=optimizers,
|
| 781 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 785 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 786 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 787 |
+
self.model_accepts_loss_kwargs = False
|
| 788 |
+
|
| 789 |
+
# Add tags for models that have been loaded with the correct transformers version
|
| 790 |
+
if hasattr(self.model, "add_model_tags"):
|
| 791 |
+
self.model.add_model_tags(self._tag_names)
|
| 792 |
+
|
| 793 |
+
if not hasattr(self, "accelerator"):
|
| 794 |
+
raise AttributeError(
|
| 795 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
| 799 |
+
if self.is_deepspeed_enabled:
|
| 800 |
+
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
| 801 |
+
raise ValueError(
|
| 802 |
+
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
if self.ref_model is None:
|
| 806 |
+
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
| 807 |
+
raise ValueError(
|
| 808 |
+
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
| 809 |
+
)
|
| 810 |
+
else:
|
| 811 |
+
if self.is_deepspeed_enabled:
|
| 812 |
+
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
| 813 |
+
else:
|
| 814 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 815 |
+
|
| 816 |
+
self.running = RunningMoments(accelerator=self.accelerator)
|
| 817 |
+
|
| 818 |
+
if self.embedding_func is None:
|
| 819 |
+
return
|
| 820 |
+
|
| 821 |
+
chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
|
| 822 |
+
rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
|
| 823 |
+
|
| 824 |
+
embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
|
| 825 |
+
labels = torch.cat(
|
| 826 |
+
(torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
self.clf = LogisticRegression(class_weight="balanced").fit(
|
| 830 |
+
embeddings.cpu().float().numpy(), labels.cpu().numpy()
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
@property
|
| 834 |
+
def match_underlying_distribution(self):
|
| 835 |
+
return self.embedding_func is not None and self.embedding_tokenizer is not None
|
| 836 |
+
|
| 837 |
+
def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
|
| 838 |
+
"""
|
| 839 |
+
Calculates the probability if the given prompt embedding is from desirable dataset.
|
| 840 |
+
This function calculates the probability in the process and ensemble across processes.
|
| 841 |
+
"""
|
| 842 |
+
dtype = prompt_embeddings.dtype
|
| 843 |
+
device = prompt_embeddings.device
|
| 844 |
+
rank = self.accelerator.process_index
|
| 845 |
+
|
| 846 |
+
padded_prompt_embeddings = self.accelerator.pad_across_processes(
|
| 847 |
+
prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
|
| 848 |
+
)
|
| 849 |
+
sample_size = padded_prompt_embeddings.shape[0]
|
| 850 |
+
nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
|
| 851 |
+
prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
|
| 852 |
+
|
| 853 |
+
# cannot predict for all empty values
|
| 854 |
+
if prompt_embeddings.shape[0] == 0:
|
| 855 |
+
return torch.tensor([], device=device, dtype=dtype)
|
| 856 |
+
|
| 857 |
+
prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
|
| 858 |
+
prob = torch.as_tensor(prob, dtype=dtype, device=device)
|
| 859 |
+
prob = self.accelerator.reduce(prob, reduction="mean")
|
| 860 |
+
|
| 861 |
+
prob = prob[sample_size * rank : sample_size * (rank + 1)]
|
| 862 |
+
prob = prob[nonzero]
|
| 863 |
+
|
| 864 |
+
return prob
|
| 865 |
+
|
| 866 |
+
def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
|
| 867 |
+
"""
|
| 868 |
+
Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id
|
| 869 |
+
and applies self.embedding_func
|
| 870 |
+
"""
|
| 871 |
+
input_ids = torch.where(
|
| 872 |
+
input_ids == self.processing_class.pad_token_id,
|
| 873 |
+
self.embedding_tokenizer.pad_token_id,
|
| 874 |
+
input_ids,
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
with torch.no_grad():
|
| 878 |
+
embeddings = self.embedding_func(
|
| 879 |
+
input_ids=input_ids,
|
| 880 |
+
attention_mask=attention_mask,
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
return embeddings
|
| 884 |
+
|
| 885 |
+
def _get_prompt_embeddings(
|
| 886 |
+
self, batch: dict[str, Union[list, torch.LongTensor]]
|
| 887 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
| 888 |
+
"""Extract embeddings from frozen embedding model"""
|
| 889 |
+
|
| 890 |
+
if not self.match_underlying_distribution:
|
| 891 |
+
return None, None
|
| 892 |
+
|
| 893 |
+
embeddings = self._vectorize_prompt(
|
| 894 |
+
input_ids=batch["embedding_input_ids"],
|
| 895 |
+
attention_mask=batch["embedding_attention_mask"],
|
| 896 |
+
)
|
| 897 |
+
|
| 898 |
+
chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
|
| 899 |
+
rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
|
| 900 |
+
|
| 901 |
+
chosen_embeddings = embeddings[chosen_idx, ...]
|
| 902 |
+
rejected_embeddings = embeddings[rejected_idx, ...]
|
| 903 |
+
|
| 904 |
+
return (chosen_embeddings, rejected_embeddings)
|
| 905 |
+
|
| 906 |
+
def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
|
| 907 |
+
"""
|
| 908 |
+
Sample instances from dataset and get prompt embeddings.
|
| 909 |
+
Used for density ratio classifier training.
|
| 910 |
+
"""
|
| 911 |
+
n_samples = min(len(dataset), sample_size)
|
| 912 |
+
rand_indices = np.random.choice(len(dataset), size=(n_samples,))
|
| 913 |
+
|
| 914 |
+
embedding_dataset = dataset.select(rand_indices)
|
| 915 |
+
|
| 916 |
+
dataloader_params = {
|
| 917 |
+
"batch_size": self.args.per_device_train_batch_size,
|
| 918 |
+
"collate_fn": self.data_collator,
|
| 919 |
+
"num_workers": self.args.dataloader_num_workers,
|
| 920 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
| 921 |
+
"shuffle": False,
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
# prepare dataloader
|
| 925 |
+
data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
|
| 926 |
+
|
| 927 |
+
with torch.no_grad():
|
| 928 |
+
all_embeddings = torch.empty(0)
|
| 929 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
|
| 930 |
+
embeddings = self._vectorize_prompt(
|
| 931 |
+
input_ids=padded_batch["embedding_input_ids"],
|
| 932 |
+
attention_mask=padded_batch["embedding_attention_mask"],
|
| 933 |
+
)
|
| 934 |
+
embeddings = self.accelerator.gather_for_metrics(embeddings)
|
| 935 |
+
all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
|
| 936 |
+
|
| 937 |
+
return all_embeddings
|
| 938 |
+
|
| 939 |
+
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
| 940 |
+
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
| 941 |
+
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
| 942 |
+
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
| 943 |
+
|
| 944 |
+
if model is not None:
|
| 945 |
+
if hasattr(model, "config"):
|
| 946 |
+
hidden_size = (
|
| 947 |
+
max(model.config.hidden_sizes)
|
| 948 |
+
if getattr(model.config, "hidden_sizes", None)
|
| 949 |
+
else getattr(model.config, "hidden_size", None)
|
| 950 |
+
)
|
| 951 |
+
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
| 952 |
+
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
| 953 |
+
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
| 954 |
+
config_kwargs.update(
|
| 955 |
+
{
|
| 956 |
+
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
| 957 |
+
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
| 958 |
+
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
| 959 |
+
}
|
| 960 |
+
)
|
| 961 |
+
|
| 962 |
+
# If ZeRO-3 is used, we shard both the active and reference model.
|
| 963 |
+
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
| 964 |
+
if config_kwargs["zero_optimization"]["stage"] != 3:
|
| 965 |
+
config_kwargs["zero_optimization"]["stage"] = 0
|
| 966 |
+
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
| 967 |
+
model.eval()
|
| 968 |
+
return model
|
| 969 |
+
|
| 970 |
+
def _save_optimizer_and_scheduler(self, output_dir):
|
| 971 |
+
super()._save_optimizer_and_scheduler(output_dir)
|
| 972 |
+
|
| 973 |
+
# When saving optimizer and scheduler to checkpoint, save also the running delta object.
|
| 974 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
| 975 |
+
|
| 976 |
+
self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
|
| 977 |
+
|
| 978 |
+
if self.match_underlying_distribution:
|
| 979 |
+
torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
|
| 980 |
+
|
| 981 |
+
def _load_optimizer_and_scheduler(self, checkpoint):
|
| 982 |
+
super()._load_optimizer_and_scheduler(checkpoint)
|
| 983 |
+
|
| 984 |
+
if checkpoint is None:
|
| 985 |
+
return
|
| 986 |
+
# when loading optimizer and scheduler from checkpoint, also load the running delta object.
|
| 987 |
+
running_file = os.path.join(checkpoint, RUNNING_NAME)
|
| 988 |
+
if os.path.isfile(running_file):
|
| 989 |
+
self.running = RunningMoments.load_from_json(self.accelerator, running_file)
|
| 990 |
+
|
| 991 |
+
if self.match_underlying_distribution:
|
| 992 |
+
clf_file = os.path.join(checkpoint, CLF_NAME)
|
| 993 |
+
if os.path.isfile(running_file):
|
| 994 |
+
self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
|
| 995 |
+
|
| 996 |
+
@contextmanager
|
| 997 |
+
def null_ref_context(self):
|
| 998 |
+
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
| 999 |
+
with (
|
| 1000 |
+
self.accelerator.unwrap_model(self.model).disable_adapter()
|
| 1001 |
+
if self.is_peft_model and not self.ref_adapter_name
|
| 1002 |
+
else nullcontext()
|
| 1003 |
+
):
|
| 1004 |
+
if self.ref_adapter_name:
|
| 1005 |
+
self.model.set_adapter(self.ref_adapter_name)
|
| 1006 |
+
yield
|
| 1007 |
+
if self.ref_adapter_name:
|
| 1008 |
+
self.model.set_adapter(self.model_adapter_name or "default")
|
| 1009 |
+
|
| 1010 |
+
def get_train_dataloader(self) -> DataLoader:
|
| 1011 |
+
"""
|
| 1012 |
+
Returns the training [`~torch.utils.data.DataLoader`].
|
| 1013 |
+
|
| 1014 |
+
Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
|
| 1015 |
+
"""
|
| 1016 |
+
|
| 1017 |
+
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
|
| 1018 |
+
dataloader_params = {
|
| 1019 |
+
"batch_size": self.args.per_device_train_batch_size,
|
| 1020 |
+
"collate_fn": self.data_collator,
|
| 1021 |
+
"num_workers": self.args.dataloader_num_workers,
|
| 1022 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
| 1023 |
+
"shuffle": False,
|
| 1024 |
+
}
|
| 1025 |
+
|
| 1026 |
+
# prepare dataloader
|
| 1027 |
+
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
|
| 1028 |
+
reference_completion_logps = []
|
| 1029 |
+
|
| 1030 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
|
| 1031 |
+
reference_completion_logp = self.compute_reference_log_probs(padded_batch)
|
| 1032 |
+
|
| 1033 |
+
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
| 1034 |
+
reference_completion_logps.append(reference_completion_logp.cpu())
|
| 1035 |
+
|
| 1036 |
+
self.train_dataset = self.train_dataset.add_column(
|
| 1037 |
+
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
| 1038 |
+
)
|
| 1039 |
+
|
| 1040 |
+
self._precomputed_train_ref_log_probs = True
|
| 1041 |
+
|
| 1042 |
+
return super().get_train_dataloader()
|
| 1043 |
+
|
| 1044 |
+
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
| 1045 |
+
"""
|
| 1046 |
+
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
| 1047 |
+
|
| 1048 |
+
Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
|
| 1049 |
+
|
| 1050 |
+
Args:
|
| 1051 |
+
eval_dataset (`torch.utils.data.Dataset`, *optional*):
|
| 1052 |
+
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
|
| 1053 |
+
by the `model.forward()` method are automatically removed. It must implement `__len__`.
|
| 1054 |
+
"""
|
| 1055 |
+
if eval_dataset is None and self.eval_dataset is None:
|
| 1056 |
+
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
| 1057 |
+
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
| 1058 |
+
|
| 1059 |
+
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
|
| 1060 |
+
dataloader_params = {
|
| 1061 |
+
"batch_size": self.args.per_device_eval_batch_size,
|
| 1062 |
+
"collate_fn": self.data_collator,
|
| 1063 |
+
"num_workers": self.args.dataloader_num_workers,
|
| 1064 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
| 1065 |
+
"shuffle": False,
|
| 1066 |
+
}
|
| 1067 |
+
|
| 1068 |
+
# prepare dataloader
|
| 1069 |
+
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
|
| 1070 |
+
|
| 1071 |
+
reference_completion_logps = []
|
| 1072 |
+
|
| 1073 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
|
| 1074 |
+
reference_completion_logp = self.compute_reference_log_probs(padded_batch)
|
| 1075 |
+
|
| 1076 |
+
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
| 1077 |
+
reference_completion_logps.append(reference_completion_logp.cpu())
|
| 1078 |
+
|
| 1079 |
+
eval_dataset = eval_dataset.add_column(
|
| 1080 |
+
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
| 1081 |
+
)
|
| 1082 |
+
|
| 1083 |
+
# Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
|
| 1084 |
+
if self.eval_dataset is not None:
|
| 1085 |
+
self.eval_dataset = eval_dataset
|
| 1086 |
+
self._precomputed_eval_ref_log_probs = True
|
| 1087 |
+
|
| 1088 |
+
return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
| 1089 |
+
|
| 1090 |
+
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
|
| 1091 |
+
"""Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset."""
|
| 1092 |
+
with torch.no_grad():
|
| 1093 |
+
if self.ref_model is None:
|
| 1094 |
+
with self.null_ref_context():
|
| 1095 |
+
if self.is_encoder_decoder:
|
| 1096 |
+
completion_logits = self.model(
|
| 1097 |
+
padded_batch["prompt_input_ids"],
|
| 1098 |
+
attention_mask=padded_batch["prompt_attention_mask"],
|
| 1099 |
+
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
| 1100 |
+
labels=padded_batch["completion_labels"],
|
| 1101 |
+
).logits
|
| 1102 |
+
|
| 1103 |
+
else:
|
| 1104 |
+
completion_logits = self.model(
|
| 1105 |
+
padded_batch["completion_input_ids"],
|
| 1106 |
+
attention_mask=padded_batch["completion_attention_mask"],
|
| 1107 |
+
).logits
|
| 1108 |
+
|
| 1109 |
+
else:
|
| 1110 |
+
if self.is_encoder_decoder:
|
| 1111 |
+
completion_logits = self.ref_model(
|
| 1112 |
+
padded_batch["prompt_input_ids"],
|
| 1113 |
+
attention_mask=padded_batch["prompt_attention_mask"],
|
| 1114 |
+
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
| 1115 |
+
labels=padded_batch["completion_labels"],
|
| 1116 |
+
).logits
|
| 1117 |
+
|
| 1118 |
+
else:
|
| 1119 |
+
completion_logits = self.ref_model(
|
| 1120 |
+
padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
|
| 1121 |
+
).logits
|
| 1122 |
+
|
| 1123 |
+
completion_logps = self.get_batch_logps(
|
| 1124 |
+
completion_logits,
|
| 1125 |
+
padded_batch["completion_labels"],
|
| 1126 |
+
average_log_prob=False,
|
| 1127 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
| 1128 |
+
label_pad_token_id=self.label_pad_token_id,
|
| 1129 |
+
)
|
| 1130 |
+
|
| 1131 |
+
return completion_logps
|
| 1132 |
+
|
| 1133 |
+
@staticmethod
|
| 1134 |
+
def get_batch_logps(
|
| 1135 |
+
logits: torch.FloatTensor,
|
| 1136 |
+
labels: torch.LongTensor,
|
| 1137 |
+
average_log_prob: bool = False,
|
| 1138 |
+
label_pad_token_id: int = -100,
|
| 1139 |
+
is_encoder_decoder: bool = False,
|
| 1140 |
+
) -> torch.FloatTensor:
|
| 1141 |
+
"""Compute the log probabilities of the given labels under the given logits.
|
| 1142 |
+
|
| 1143 |
+
Args:
|
| 1144 |
+
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
| 1145 |
+
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
| 1146 |
+
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
| 1147 |
+
|
| 1148 |
+
Returns:
|
| 1149 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
| 1150 |
+
"""
|
| 1151 |
+
if logits.shape[:-1] != labels.shape:
|
| 1152 |
+
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
| 1153 |
+
|
| 1154 |
+
if not is_encoder_decoder:
|
| 1155 |
+
labels = labels[:, 1:].clone()
|
| 1156 |
+
logits = logits[:, :-1, :]
|
| 1157 |
+
else:
|
| 1158 |
+
# Fixes end-dec RuntimeError
|
| 1159 |
+
labels = labels.clone()
|
| 1160 |
+
|
| 1161 |
+
loss_mask = labels != label_pad_token_id
|
| 1162 |
+
|
| 1163 |
+
# dummy token; we'll ignore the losses on these tokens later
|
| 1164 |
+
labels[labels == label_pad_token_id] = 0
|
| 1165 |
+
|
| 1166 |
+
per_token_logps = selective_log_softmax(logits, labels)
|
| 1167 |
+
|
| 1168 |
+
if average_log_prob:
|
| 1169 |
+
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
| 1170 |
+
else:
|
| 1171 |
+
return (per_token_logps * loss_mask).sum(-1)
|
| 1172 |
+
|
| 1173 |
+
def forward(
|
| 1174 |
+
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
| 1175 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1176 |
+
model_kwargs = (
|
| 1177 |
+
{
|
| 1178 |
+
"labels": batch["completion_labels"],
|
| 1179 |
+
"decoder_input_ids": batch.get("completion_decoder_input_ids"),
|
| 1180 |
+
}
|
| 1181 |
+
if self.is_encoder_decoder
|
| 1182 |
+
else {}
|
| 1183 |
+
)
|
| 1184 |
+
if self.aux_loss_enabled:
|
| 1185 |
+
model_kwargs["output_router_logits"] = True
|
| 1186 |
+
|
| 1187 |
+
outputs = model(
|
| 1188 |
+
batch["completion_input_ids"],
|
| 1189 |
+
attention_mask=batch["completion_attention_mask"],
|
| 1190 |
+
**model_kwargs,
|
| 1191 |
+
)
|
| 1192 |
+
completion_logits = outputs.logits
|
| 1193 |
+
|
| 1194 |
+
completion_logps = self.get_batch_logps(
|
| 1195 |
+
completion_logits,
|
| 1196 |
+
batch["completion_labels"],
|
| 1197 |
+
average_log_prob=False,
|
| 1198 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
| 1199 |
+
label_pad_token_id=self.label_pad_token_id,
|
| 1200 |
+
)
|
| 1201 |
+
|
| 1202 |
+
if completion_logps.shape[0] != len(batch["label"]):
|
| 1203 |
+
raise ValueError(
|
| 1204 |
+
"There is a mismatch between the number of examples in this batch and the number of "
|
| 1205 |
+
"examples for which an output sequence was predicted."
|
| 1206 |
+
)
|
| 1207 |
+
|
| 1208 |
+
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
|
| 1209 |
+
rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
|
| 1210 |
+
|
| 1211 |
+
chosen_logps = completion_logps[chosen_idx, ...]
|
| 1212 |
+
rejected_logps = completion_logps[rejected_idx, ...]
|
| 1213 |
+
|
| 1214 |
+
chosen_logits = completion_logits[chosen_idx, ...]
|
| 1215 |
+
rejected_logits = completion_logits[rejected_idx, ...]
|
| 1216 |
+
|
| 1217 |
+
if self.aux_loss_enabled:
|
| 1218 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
|
| 1219 |
+
else:
|
| 1220 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
|
| 1221 |
+
|
| 1222 |
+
def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
|
| 1223 |
+
prob_desirable = self._get_chosen_prob(rejected_embeddings)
|
| 1224 |
+
min_ratio = self.args.min_density_ratio
|
| 1225 |
+
max_ratio = self.args.max_density_ratio
|
| 1226 |
+
|
| 1227 |
+
weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
|
| 1228 |
+
|
| 1229 |
+
return weight
|
| 1230 |
+
|
| 1231 |
+
def bco_loss(
|
| 1232 |
+
self,
|
| 1233 |
+
policy_chosen_logps: torch.FloatTensor,
|
| 1234 |
+
policy_rejected_logps: torch.FloatTensor,
|
| 1235 |
+
reference_chosen_logps: torch.FloatTensor,
|
| 1236 |
+
reference_rejected_logps: torch.FloatTensor,
|
| 1237 |
+
chosen_embeddings: Optional[torch.FloatTensor],
|
| 1238 |
+
rejected_embeddings: Optional[torch.FloatTensor],
|
| 1239 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1240 |
+
"""Compute the BCO loss for a batch of policy and reference model log probabilities.
|
| 1241 |
+
|
| 1242 |
+
Args:
|
| 1243 |
+
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
| 1244 |
+
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
| 1245 |
+
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
| 1246 |
+
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
| 1247 |
+
chosen_embeddings: embeddings of desirable prompts
|
| 1248 |
+
rejected_embeddings: embeddings of undesirable prompts
|
| 1249 |
+
|
| 1250 |
+
Returns:
|
| 1251 |
+
A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
|
| 1252 |
+
The losses tensor contains the BCO loss for each example in the batch.
|
| 1253 |
+
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
| 1254 |
+
The delta value contains the moving average of all implicit rewards.
|
| 1255 |
+
"""
|
| 1256 |
+
|
| 1257 |
+
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
| 1258 |
+
chosen_logratios = policy_chosen_logps - reference_chosen_logps
|
| 1259 |
+
chosen_rewards = self.beta * chosen_logratios
|
| 1260 |
+
else:
|
| 1261 |
+
# lists can't be empty -- if they are, then accelerate.gather will hang
|
| 1262 |
+
chosen_losses = torch.Tensor([]).to(self.accelerator.device)
|
| 1263 |
+
chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
|
| 1264 |
+
|
| 1265 |
+
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
| 1266 |
+
rejected_logratios = policy_rejected_logps - reference_rejected_logps
|
| 1267 |
+
rejected_rewards = self.beta * rejected_logratios
|
| 1268 |
+
else:
|
| 1269 |
+
# lists can't be empty -- if they are, then accelerate.gather will hang
|
| 1270 |
+
rejected_losses = torch.Tensor([]).to(self.accelerator.device)
|
| 1271 |
+
rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
|
| 1272 |
+
|
| 1273 |
+
rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
|
| 1274 |
+
self.running.update(rewards)
|
| 1275 |
+
delta = self.running.mean
|
| 1276 |
+
|
| 1277 |
+
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
| 1278 |
+
chosen_losses = -F.logsigmoid(chosen_rewards - delta)
|
| 1279 |
+
|
| 1280 |
+
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
| 1281 |
+
rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
|
| 1282 |
+
|
| 1283 |
+
if self.match_underlying_distribution:
|
| 1284 |
+
chosen_weight = torch.ones_like(chosen_losses)
|
| 1285 |
+
rejected_weight = self._get_udm_weight(rejected_embeddings)
|
| 1286 |
+
|
| 1287 |
+
losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
|
| 1288 |
+
else:
|
| 1289 |
+
losses = torch.cat((chosen_losses, rejected_losses), dim=0)
|
| 1290 |
+
|
| 1291 |
+
return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
|
| 1292 |
+
|
| 1293 |
+
def get_batch_loss_metrics(
|
| 1294 |
+
self,
|
| 1295 |
+
model,
|
| 1296 |
+
batch: dict[str, Union[list, torch.LongTensor]],
|
| 1297 |
+
):
|
| 1298 |
+
"""Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
|
| 1299 |
+
metrics = {}
|
| 1300 |
+
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
| 1301 |
+
|
| 1302 |
+
forward_output = self.forward(model, batch)
|
| 1303 |
+
(
|
| 1304 |
+
policy_chosen_logps,
|
| 1305 |
+
policy_rejected_logps,
|
| 1306 |
+
policy_chosen_logits,
|
| 1307 |
+
policy_rejected_logits,
|
| 1308 |
+
) = forward_output[:4]
|
| 1309 |
+
if self.aux_loss_enabled:
|
| 1310 |
+
aux_loss = forward_output[4]
|
| 1311 |
+
|
| 1312 |
+
# if reference_logps in batch use them, otherwise use the reference model
|
| 1313 |
+
if "reference_logps" in batch:
|
| 1314 |
+
chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
|
| 1315 |
+
rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
|
| 1316 |
+
|
| 1317 |
+
reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
|
| 1318 |
+
reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
|
| 1319 |
+
else:
|
| 1320 |
+
with torch.no_grad():
|
| 1321 |
+
if self.ref_model is None:
|
| 1322 |
+
with self.null_ref_context():
|
| 1323 |
+
(
|
| 1324 |
+
reference_chosen_logps,
|
| 1325 |
+
reference_rejected_logps,
|
| 1326 |
+
_,
|
| 1327 |
+
_,
|
| 1328 |
+
) = self.forward(self.model, batch)[:4]
|
| 1329 |
+
else:
|
| 1330 |
+
(
|
| 1331 |
+
reference_chosen_logps,
|
| 1332 |
+
reference_rejected_logps,
|
| 1333 |
+
_,
|
| 1334 |
+
_,
|
| 1335 |
+
) = self.forward(self.ref_model, batch)[:4]
|
| 1336 |
+
|
| 1337 |
+
chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
|
| 1338 |
+
|
| 1339 |
+
losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
|
| 1340 |
+
policy_chosen_logps,
|
| 1341 |
+
policy_rejected_logps,
|
| 1342 |
+
reference_chosen_logps,
|
| 1343 |
+
reference_rejected_logps,
|
| 1344 |
+
chosen_embeddings,
|
| 1345 |
+
rejected_embeddings,
|
| 1346 |
+
)
|
| 1347 |
+
metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
|
| 1348 |
+
|
| 1349 |
+
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
| 1350 |
+
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
| 1351 |
+
|
| 1352 |
+
all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
|
| 1353 |
+
all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
|
| 1354 |
+
|
| 1355 |
+
if all_num_chosen > 0:
|
| 1356 |
+
metrics["rewards/chosen_sum"] = (
|
| 1357 |
+
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
|
| 1358 |
+
)
|
| 1359 |
+
metrics["logps/chosen_sum"] = (
|
| 1360 |
+
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
|
| 1361 |
+
)
|
| 1362 |
+
metrics["logits/chosen_sum"] = (
|
| 1363 |
+
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
|
| 1364 |
+
)
|
| 1365 |
+
metrics["count/chosen"] = all_num_chosen
|
| 1366 |
+
|
| 1367 |
+
if all_num_rejected > 0:
|
| 1368 |
+
metrics["rewards/rejected_sum"] = (
|
| 1369 |
+
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
|
| 1370 |
+
)
|
| 1371 |
+
metrics["logps/rejected_sum"] = (
|
| 1372 |
+
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
|
| 1373 |
+
)
|
| 1374 |
+
metrics["logits/rejected_sum"] = (
|
| 1375 |
+
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
|
| 1376 |
+
)
|
| 1377 |
+
metrics["count/rejected"] = all_num_rejected
|
| 1378 |
+
|
| 1379 |
+
loss = losses.nanmean()
|
| 1380 |
+
if self.aux_loss_enabled:
|
| 1381 |
+
loss += self.aux_loss_coef * aux_loss
|
| 1382 |
+
|
| 1383 |
+
return loss, metrics
|
| 1384 |
+
|
| 1385 |
+
def compute_loss(
|
| 1386 |
+
self,
|
| 1387 |
+
model: Union[PreTrainedModel, nn.Module],
|
| 1388 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1389 |
+
return_outputs=False,
|
| 1390 |
+
num_items_in_batch=None,
|
| 1391 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
| 1392 |
+
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1393 |
+
|
| 1394 |
+
with compute_loss_context_manager:
|
| 1395 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
| 1396 |
+
|
| 1397 |
+
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
| 1398 |
+
loss = loss.to(self.args.device)
|
| 1399 |
+
# force log the metrics
|
| 1400 |
+
if self.accelerator.is_main_process:
|
| 1401 |
+
self.store_metrics(metrics, train_eval="train")
|
| 1402 |
+
|
| 1403 |
+
if return_outputs:
|
| 1404 |
+
return (loss, metrics)
|
| 1405 |
+
return loss
|
| 1406 |
+
|
| 1407 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
| 1408 |
+
for key, value in metrics.items():
|
| 1409 |
+
self._stored_metrics[train_eval][key].append(value)
|
| 1410 |
+
|
| 1411 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
| 1412 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
| 1413 |
+
return None
|
| 1414 |
+
return SequentialSampler(self.train_dataset)
|
| 1415 |
+
|
| 1416 |
+
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
| 1417 |
+
"""Generate samples from the model and reference model for the given batch of inputs."""
|
| 1418 |
+
|
| 1419 |
+
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
| 1420 |
+
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
| 1421 |
+
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1422 |
+
with generate_context_manager:
|
| 1423 |
+
policy_output = model.generate(
|
| 1424 |
+
input_ids=batch["prompt_input_ids"],
|
| 1425 |
+
attention_mask=batch["prompt_attention_mask"],
|
| 1426 |
+
max_length=self.max_length,
|
| 1427 |
+
do_sample=True,
|
| 1428 |
+
pad_token_id=self.processing_class.pad_token_id,
|
| 1429 |
+
)
|
| 1430 |
+
|
| 1431 |
+
# if reference_output in batch use that otherwise use the reference model
|
| 1432 |
+
if "reference_output" in batch:
|
| 1433 |
+
reference_output = batch["reference_output"]
|
| 1434 |
+
else:
|
| 1435 |
+
if self.ref_model is None:
|
| 1436 |
+
with self.null_ref_context():
|
| 1437 |
+
reference_output = self.model.generate(
|
| 1438 |
+
input_ids=batch["prompt_input_ids"],
|
| 1439 |
+
attention_mask=batch["prompt_attention_mask"],
|
| 1440 |
+
max_length=self.max_length,
|
| 1441 |
+
do_sample=True,
|
| 1442 |
+
pad_token_id=self.processing_class.pad_token_id,
|
| 1443 |
+
)
|
| 1444 |
+
else:
|
| 1445 |
+
reference_output = self.ref_model.generate(
|
| 1446 |
+
input_ids=batch["prompt_input_ids"],
|
| 1447 |
+
attention_mask=batch["prompt_attention_mask"],
|
| 1448 |
+
max_length=self.max_length,
|
| 1449 |
+
do_sample=True,
|
| 1450 |
+
pad_token_id=self.processing_class.pad_token_id,
|
| 1451 |
+
)
|
| 1452 |
+
|
| 1453 |
+
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
| 1454 |
+
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
| 1455 |
+
|
| 1456 |
+
reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
|
| 1457 |
+
reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
|
| 1458 |
+
|
| 1459 |
+
return policy_output_decoded, reference_output_decoded
|
| 1460 |
+
|
| 1461 |
+
def prediction_step(
|
| 1462 |
+
self,
|
| 1463 |
+
model: Union[PreTrainedModel, nn.Module],
|
| 1464 |
+
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1465 |
+
prediction_loss_only: bool,
|
| 1466 |
+
ignore_keys: Optional[list[str]] = None,
|
| 1467 |
+
):
|
| 1468 |
+
if ignore_keys is None:
|
| 1469 |
+
if hasattr(model, "config"):
|
| 1470 |
+
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
| 1471 |
+
else:
|
| 1472 |
+
ignore_keys = []
|
| 1473 |
+
|
| 1474 |
+
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1475 |
+
with torch.no_grad(), prediction_context_manager:
|
| 1476 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
| 1477 |
+
|
| 1478 |
+
# force log the metrics
|
| 1479 |
+
if self.accelerator.is_main_process:
|
| 1480 |
+
self.store_metrics(metrics, train_eval="eval")
|
| 1481 |
+
|
| 1482 |
+
if prediction_loss_only:
|
| 1483 |
+
return (loss.detach(), None, None)
|
| 1484 |
+
|
| 1485 |
+
# logits for the chosen and rejected samples from model
|
| 1486 |
+
logits_dict = {
|
| 1487 |
+
"eval_logits/chosen": metrics["logits/chosen"],
|
| 1488 |
+
"eval_logits/rejected": metrics["logits/rejected"],
|
| 1489 |
+
}
|
| 1490 |
+
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
|
| 1491 |
+
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
|
| 1492 |
+
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
| 1493 |
+
|
| 1494 |
+
return (loss.detach(), logits, labels)
|
| 1495 |
+
|
| 1496 |
+
def evaluation_loop(
|
| 1497 |
+
self,
|
| 1498 |
+
dataloader: DataLoader,
|
| 1499 |
+
description: str,
|
| 1500 |
+
prediction_loss_only: Optional[bool] = None,
|
| 1501 |
+
ignore_keys: Optional[list[str]] = None,
|
| 1502 |
+
metric_key_prefix: str = "eval",
|
| 1503 |
+
) -> EvalLoopOutput:
|
| 1504 |
+
"""
|
| 1505 |
+
Overriding built-in evaluation loop to store metrics for each batch.
|
| 1506 |
+
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
| 1507 |
+
|
| 1508 |
+
Works both with or without labels.
|
| 1509 |
+
"""
|
| 1510 |
+
|
| 1511 |
+
# Sample and save to game log if requested (for one batch to save time)
|
| 1512 |
+
if self.generate_during_eval:
|
| 1513 |
+
# Generate random indices within the range of the total number of samples
|
| 1514 |
+
num_samples = len(dataloader.dataset)
|
| 1515 |
+
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
| 1516 |
+
|
| 1517 |
+
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
| 1518 |
+
random_batch_dataset = dataloader.dataset.select(random_indices)
|
| 1519 |
+
random_batch = self.data_collator(random_batch_dataset)
|
| 1520 |
+
random_batch = self._prepare_inputs(random_batch)
|
| 1521 |
+
|
| 1522 |
+
target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
|
| 1523 |
+
target_batch = {
|
| 1524 |
+
"prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
|
| 1525 |
+
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
|
| 1526 |
+
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
|
| 1527 |
+
}
|
| 1528 |
+
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
|
| 1529 |
+
|
| 1530 |
+
table = pd.DataFrame(
|
| 1531 |
+
columns=["Prompt", "Policy", "Ref Model"],
|
| 1532 |
+
data=[
|
| 1533 |
+
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
| 1534 |
+
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
|
| 1535 |
+
],
|
| 1536 |
+
)
|
| 1537 |
+
if "wandb" in self.args.report_to:
|
| 1538 |
+
wandb.log({"game_log": wandb.Table(data=table)})
|
| 1539 |
+
|
| 1540 |
+
if "comet_ml" in self.args.report_to:
|
| 1541 |
+
log_table_to_comet_experiment(
|
| 1542 |
+
name="game_log.csv",
|
| 1543 |
+
table=table,
|
| 1544 |
+
)
|
| 1545 |
+
|
| 1546 |
+
# Base evaluation
|
| 1547 |
+
initial_output = super().evaluation_loop(
|
| 1548 |
+
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
| 1549 |
+
)
|
| 1550 |
+
|
| 1551 |
+
return initial_output
|
| 1552 |
+
|
| 1553 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 1554 |
+
"""
|
| 1555 |
+
Log `logs` on the various objects watching training, including stored metrics.
|
| 1556 |
+
|
| 1557 |
+
Args:
|
| 1558 |
+
logs (`dict[str, float]`):
|
| 1559 |
+
The values to log.
|
| 1560 |
+
start_time (`float` or `None`, *optional*, defaults to `None`):
|
| 1561 |
+
Start time of the training.
|
| 1562 |
+
"""
|
| 1563 |
+
# logs either has 'loss' or 'eval_loss'
|
| 1564 |
+
train_eval = "train" if "loss" in logs else "eval"
|
| 1565 |
+
# train metrics should have no prefix, eval should have 'eval_'
|
| 1566 |
+
prefix = "eval_" if train_eval == "eval" else ""
|
| 1567 |
+
# accumulate average metrics from sums and lengths
|
| 1568 |
+
for split in ["chosen", "rejected"]:
|
| 1569 |
+
if f"count/{split}" in self._stored_metrics[train_eval]:
|
| 1570 |
+
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
|
| 1571 |
+
for metric in ["rewards", "logps", "logits"]:
|
| 1572 |
+
logs[f"{prefix}{metric}/{split}"] = (
|
| 1573 |
+
torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
|
| 1574 |
+
/ count_sum
|
| 1575 |
+
)
|
| 1576 |
+
# delete obsolete metric
|
| 1577 |
+
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
| 1578 |
+
del self._stored_metrics[train_eval][f"count/{split}"]
|
| 1579 |
+
# calculate reward margin
|
| 1580 |
+
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
| 1581 |
+
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
| 1582 |
+
# Add averaged stored metrics to logs
|
| 1583 |
+
for key, metrics in self._stored_metrics[train_eval].items():
|
| 1584 |
+
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
| 1585 |
+
del self._stored_metrics[train_eval]
|
| 1586 |
+
|
| 1587 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 1588 |
+
return super().log(logs, start_time)
|
| 1589 |
+
else: # transformers<=4.46
|
| 1590 |
+
return super().log(logs)
|
| 1591 |
+
|
| 1592 |
+
def create_model_card(
|
| 1593 |
+
self,
|
| 1594 |
+
model_name: Optional[str] = None,
|
| 1595 |
+
dataset_name: Optional[str] = None,
|
| 1596 |
+
tags: Union[str, list[str], None] = None,
|
| 1597 |
+
):
|
| 1598 |
+
"""
|
| 1599 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
| 1600 |
+
|
| 1601 |
+
Args:
|
| 1602 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1603 |
+
Name of the model.
|
| 1604 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1605 |
+
Name of the dataset used for training.
|
| 1606 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 1607 |
+
Tags to be associated with the model card.
|
| 1608 |
+
"""
|
| 1609 |
+
if not self.is_world_process_zero():
|
| 1610 |
+
return
|
| 1611 |
+
|
| 1612 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 1613 |
+
base_model = self.model.config._name_or_path
|
| 1614 |
+
else:
|
| 1615 |
+
base_model = None
|
| 1616 |
+
|
| 1617 |
+
tags = tags or []
|
| 1618 |
+
if isinstance(tags, str):
|
| 1619 |
+
tags = [tags]
|
| 1620 |
+
|
| 1621 |
+
if hasattr(self.model.config, "unsloth_version"):
|
| 1622 |
+
tags.append("unsloth")
|
| 1623 |
+
|
| 1624 |
+
citation = textwrap.dedent("""\
|
| 1625 |
+
@article{jung2024binary,
|
| 1626 |
+
title = {{Binary Classifier Optimization for Large Language Model Alignment}},
|
| 1627 |
+
author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
|
| 1628 |
+
year = 2024,
|
| 1629 |
+
eprint = {arXiv:2404.04656}
|
| 1630 |
+
}""")
|
| 1631 |
+
|
| 1632 |
+
model_card = generate_model_card(
|
| 1633 |
+
base_model=base_model,
|
| 1634 |
+
model_name=model_name,
|
| 1635 |
+
hub_model_id=self.hub_model_id,
|
| 1636 |
+
dataset_name=dataset_name,
|
| 1637 |
+
tags=tags,
|
| 1638 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 1639 |
+
comet_url=get_comet_experiment_url(),
|
| 1640 |
+
trainer_name="BCO",
|
| 1641 |
+
trainer_citation=citation,
|
| 1642 |
+
paper_title="Binary Classifier Optimization for Large Language Model Alignment",
|
| 1643 |
+
paper_id="2404.04656",
|
| 1644 |
+
)
|
| 1645 |
+
|
| 1646 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1647 |
+
class UnslothBCOTrainer(_UnslothBCOTrainer):
|
| 1648 |
+
"""
|
| 1649 |
+
|
| 1650 |
+
Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
|
| 1651 |
+
|
| 1652 |
+
Args:
|
| 1653 |
+
model (`transformers.PreTrainedModel`):
|
| 1654 |
+
The model to train, preferably an `AutoModelForSequenceClassification`.
|
| 1655 |
+
ref_model (`PreTrainedModelWrapper`):
|
| 1656 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
| 1657 |
+
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
| 1658 |
+
args (`BCOConfig`):
|
| 1659 |
+
The arguments to use for training.
|
| 1660 |
+
train_dataset (`datasets.Dataset`):
|
| 1661 |
+
The dataset to use for training.
|
| 1662 |
+
eval_dataset (`datasets.Dataset`):
|
| 1663 |
+
The dataset to use for evaluation.
|
| 1664 |
+
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 1665 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1666 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1667 |
+
reuse the fine-tuned model.
|
| 1668 |
+
data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
|
| 1669 |
+
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 1670 |
+
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 1671 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 1672 |
+
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 1673 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
| 1674 |
+
The callbacks to use for training.
|
| 1675 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1676 |
+
The optimizer and scheduler to use for training.
|
| 1677 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1678 |
+
The function to use to preprocess the logits before computing the metrics.
|
| 1679 |
+
peft_config (`dict`, defaults to `None`):
|
| 1680 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 1681 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1682 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 1683 |
+
a dictionary string to metric values.
|
| 1684 |
+
model_adapter_name (`str`, defaults to `None`):
|
| 1685 |
+
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
| 1686 |
+
ref_adapter_name (`str`, defaults to `None`):
|
| 1687 |
+
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
| 1688 |
+
|
| 1689 |
+
"""
|
| 1690 |
+
def __init__(
|
| 1691 |
+
self,
|
| 1692 |
+
model = None,
|
| 1693 |
+
ref_model = None,
|
| 1694 |
+
args = None,
|
| 1695 |
+
train_dataset = None,
|
| 1696 |
+
eval_dataset = None,
|
| 1697 |
+
processing_class = None,
|
| 1698 |
+
data_collator = None,
|
| 1699 |
+
model_init = None,
|
| 1700 |
+
callbacks = None,
|
| 1701 |
+
preprocess_logits_for_metrics = None,
|
| 1702 |
+
peft_config = None,
|
| 1703 |
+
compute_metrics = None,
|
| 1704 |
+
model_adapter_name = None,
|
| 1705 |
+
ref_adapter_name = None,
|
| 1706 |
+
embedding_func = None,
|
| 1707 |
+
embedding_tokenizer = None,
|
| 1708 |
+
**kwargs
|
| 1709 |
+
):
|
| 1710 |
+
if args is None: args = UnslothBCOConfig()
|
| 1711 |
+
use_bf16 = getattr(args, 'bf16', False)
|
| 1712 |
+
use_fp16 = getattr(args, 'fp16', False)
|
| 1713 |
+
force_float32 = False
|
| 1714 |
+
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 1715 |
+
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1716 |
+
force_float32 = True
|
| 1717 |
+
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1718 |
+
dtype = getattr(model.config, 'torch_dtype', None)
|
| 1719 |
+
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 1720 |
+
from unsloth_zoo.utils import _get_dtype
|
| 1721 |
+
dtype = _get_dtype(dtype)
|
| 1722 |
+
float16 = dtype == torch.float16
|
| 1723 |
+
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1724 |
+
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1725 |
+
if force_float32:
|
| 1726 |
+
args.fp16 = False
|
| 1727 |
+
args.bf16 = False
|
| 1728 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1729 |
+
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1730 |
+
args.fp16 = float16
|
| 1731 |
+
args.bf16 = not float16
|
| 1732 |
+
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1733 |
+
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1734 |
+
args.eval_strategy = 'steps'
|
| 1735 |
+
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1736 |
+
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1737 |
+
if ga_steps is not None and ga_steps > 1:
|
| 1738 |
+
from transformers import __version__ as transformers_version
|
| 1739 |
+
if Version(transformers_version) <= Version('4.45.2'):
|
| 1740 |
+
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1741 |
+
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1742 |
+
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1743 |
+
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1744 |
+
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1745 |
+
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1746 |
+
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1747 |
+
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1748 |
+
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1749 |
+
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1750 |
+
if force_float32:
|
| 1751 |
+
args.bf16_full_eval = False
|
| 1752 |
+
args.fp16_full_eval = False
|
| 1753 |
+
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1754 |
+
args.bf16_full_eval = True
|
| 1755 |
+
args.fp16_full_eval = False
|
| 1756 |
+
elif not bf16_full_eval and not fp16_full_eval:
|
| 1757 |
+
args.bf16_full_eval = args.bf16
|
| 1758 |
+
args.fp16_full_eval = args.fp16
|
| 1759 |
+
_output_logits = False
|
| 1760 |
+
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1761 |
+
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1762 |
+
if _output_logits:
|
| 1763 |
+
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1764 |
+
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1765 |
+
pass
|
| 1766 |
+
else:
|
| 1767 |
+
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1768 |
+
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1769 |
+
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1770 |
+
max_seq_length = model.max_seq_length
|
| 1771 |
+
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1772 |
+
if model is not None and hasattr(model, 'for_training'):
|
| 1773 |
+
model.for_training()
|
| 1774 |
+
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1775 |
+
if 'processing_class' in locals():
|
| 1776 |
+
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1777 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1778 |
+
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1779 |
+
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1780 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1781 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1782 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
|
| 1783 |
+
elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1784 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1785 |
+
else:
|
| 1786 |
+
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1787 |
+
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1788 |
+
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1789 |
+
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1790 |
+
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1791 |
+
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1792 |
+
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1793 |
+
else:
|
| 1794 |
+
data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
|
| 1795 |
+
other_metrics = []
|
| 1796 |
+
|
| 1797 |
+
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1798 |
+
PatchRLStatistics('bco_trainer', other_metrics)
|
| 1799 |
+
|
| 1800 |
+
super().__init__(
|
| 1801 |
+
model = model,
|
| 1802 |
+
ref_model = ref_model,
|
| 1803 |
+
args = args,
|
| 1804 |
+
train_dataset = train_dataset,
|
| 1805 |
+
eval_dataset = eval_dataset,
|
| 1806 |
+
processing_class = processing_class,
|
| 1807 |
+
data_collator = data_collator,
|
| 1808 |
+
model_init = model_init,
|
| 1809 |
+
callbacks = callbacks,
|
| 1810 |
+
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 1811 |
+
peft_config = peft_config,
|
| 1812 |
+
compute_metrics = compute_metrics,
|
| 1813 |
+
model_adapter_name = model_adapter_name,
|
| 1814 |
+
ref_adapter_name = ref_adapter_name,
|
| 1815 |
+
embedding_func = embedding_func,
|
| 1816 |
+
embedding_tokenizer = embedding_tokenizer,**kwargs)
|
| 1817 |
+
if hasattr(self, 'neftune_hook_handle'):
|
| 1818 |
+
self.neftune_hook_handle.remove()
|
| 1819 |
+
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1820 |
+
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1821 |
+
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1822 |
+
pass
|
| 1823 |
+
|
| 1824 |
+
pass
|