aparke21 committed on
Commit 9014afd · verified · 1 Parent(s): 08439af

Upload 106 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +3 -0
  2. README.md +19 -6
  3. __pycache__/utils.cpython-312.pyc +0 -0
  4. app.py +1026 -0
  5. app.sh +17 -0
  6. app_logs/app_5936040.out +18 -0
  7. app_logs/app_5936041.out +18 -0
  8. app_logs/app_5936047.out +19 -0
  9. app_logs/app_5936050.out +1 -0
  10. app_logs/app_5936052.out +57 -0
  11. assets/umd_logo.png +3 -0
  12. configs/prompts.yaml +100 -0
  13. configs/task1_demo.yaml +27 -0
  14. configs/task1_demo_sph.yaml +28 -0
  15. data/survey_responses_screened.csv +3 -0
  16. push.sh +22 -0
  17. requirements.txt +167 -0
  18. requirements_concise.txt +18 -0
  19. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/README.md +202 -0
  20. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/adapter_config.json +31 -0
  21. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/adapter_model.safetensors +3 -0
  22. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/added_tokens.json +3 -0
  23. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/chat_template.json +3 -0
  24. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/preprocessor_config.json +29 -0
  25. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/processor_config.json +4 -0
  26. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/special_tokens_map.json +33 -0
  27. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.json +3 -0
  28. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.model +3 -0
  29. unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer_config.json +0 -0
  30. unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +67 -0
  31. unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +66 -0
  32. unsloth_compiled_cache/BatchNorm1d.py +88 -0
  33. unsloth_compiled_cache/BatchNorm2d.py +88 -0
  34. unsloth_compiled_cache/BatchNorm3d.py +88 -0
  35. unsloth_compiled_cache/Conv1d.py +43 -0
  36. unsloth_compiled_cache/Conv2d.py +43 -0
  37. unsloth_compiled_cache/Conv3d.py +43 -0
  38. unsloth_compiled_cache/ConvTranspose1d.py +70 -0
  39. unsloth_compiled_cache/ConvTranspose2d.py +71 -0
  40. unsloth_compiled_cache/ConvTranspose3d.py +71 -0
  41. unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +73 -0
  42. unsloth_compiled_cache/GroupNorm.py +43 -0
  43. unsloth_compiled_cache/LayerNorm.py +45 -0
  44. unsloth_compiled_cache/Linear4bit_peft_forward.py +97 -0
  45. unsloth_compiled_cache/Linear8bitLt_peft_forward.py +90 -0
  46. unsloth_compiled_cache/Linear_peft_forward.py +89 -0
  47. unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +87 -0
  48. unsloth_compiled_cache/RMSNorm.py +46 -0
  49. unsloth_compiled_cache/UnslothAlignPropTrainer.py +637 -0
  50. unsloth_compiled_cache/UnslothBCOTrainer.py +1824 -0
.gitattributes CHANGED
@@ -36,3 +36,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  src_hf_deploy[[:space:]]2/assets/umd_logo.png filter=lfs diff=lfs merge=lfs -text
37
  src_hf_deploy[[:space:]]2/data/survey_responses_screened.csv filter=lfs diff=lfs merge=lfs -text
38
  src_hf_deploy[[:space:]]2/unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.json filter=lfs diff=lfs merge=lfs -text
39
+ assets/umd_logo.png filter=lfs diff=lfs merge=lfs -text
40
+ data/survey_responses_screened.csv filter=lfs diff=lfs merge=lfs -text
41
+ unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,25 @@
1
  ---
2
- title: Newtest
3
- emoji: 📈
4
- colorFrom: pink
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 6.0.2
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: AI-Empowered Community Simulation (Beta)
3
+ emoji: 🧠
4
+ colorFrom: red
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
+ # hardware: "gpu-a100-large" # REQUESTS A100 80GB GPU
11
+ # hardware: "gpu-l40s" # Request 1x NVIDIA L40S (48GB VRAM)
12
+ # hardware: "zerogpu"
13
+ hardware: "t4-small"
14
  ---
15
 
16
+ # AI-Empowered Community Simulation (Beta)
17
+
18
+ This Space requires **at least 28 GB of GPU RAM** due to the size of the Unsloth long-context VLM used for inference and summarization.
19
+
20
+ If the hardware fails to start or your account does not have access to this tier,
21
+ please select the appropriate hardware from:
22
+
23
+ **Settings → Hardware → ZeroGPU**
24
+
25
+ ---
__pycache__/utils.cpython-312.pyc ADDED
Binary file (5.34 kB).
app.py ADDED
@@ -0,0 +1,1026 @@
1
+ """
2
+ Instruction Tuning of LLM for Trait-conditioned Style Impact Calibration
3
+ """
4
+ import unsloth
5
+ import yaml # type: ignore
6
+ import pandas as pd # type: ignore
7
+ import os
8
+ from PIL import Image # type: ignore
9
+ import gradio as gr
10
+
11
+ import torch # type: ignore
12
+ from langchain_community.chat_models import ChatOllama # type: ignore
13
+ from langchain_core.messages import SystemMessage, HumanMessage # type: ignore
14
+ from langchain_ollama import OllamaEmbeddings # type: ignore
15
+ from langchain_core.output_parsers import StrOutputParser # type: ignore
16
+ from pydantic import BaseModel # format LLM output as JSON # type: ignore
17
+ from unsloth import FastVisionModel, FastModel, FastLanguageModel # type: ignore
18
+ from transformers import TextStreamer # type: ignore
19
+ from unsloth.chat_templates import get_chat_template # type: ignore
20
+ from unsloth.chat_templates import standardize_sharegpt # type: ignore
21
+ from transformers import TextIteratorStreamer
22
+
23
+ from utils import convert_to_base64, load_config, process_trait_info # type: ignore
24
+ from tqdm import tqdm # type: ignore
25
+ from termcolor import colored # type: ignore
26
+ import threading
27
+ import random
28
+ import numpy as np
32
+ # generation_lock = threading.Lock()
33
+
34
+ # from transformers import StoppingCriteria, StoppingCriteriaList
35
+ # class StopGenerationCriteria(StoppingCriteria):
36
+ # def __init__(self, stop_event):
37
+ # self.stop_event = stop_event
38
+
39
+ # def __call__(self, input_ids, scores, **kwargs):
40
+ # return self.stop_event.is_set()
41
+
42
+
43
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
44
+
45
+ TRAIT_VALUES = {
46
+ "Gender": [
47
+ "Male", "Female", "Non-binary/third gender", "Leave Blank",
48
+ ],
49
+ "Age": [
50
+ "18–24", "25–34", "35–44", "45–54", "55–64", "65 or older", "Leave Blank",
51
+ ],
52
+ "Current Profession": [
53
+ "Healthcare/Medical", "Government/Public Service",
54
+ "Business/Finance",
55
+ "Technology/Engineering", "Education", "Arts/Entertainment",
56
+ "Retail/Hospitality/Food Service",
57
+ "Skilled Trades/Labor (e.g., construction, electrician, landscaper, house cleaner)",
58
+ "Student",
59
+ "Unemployed/Looking for work", "Retired",
60
+ "Other",
61
+ "Leave Blank",
62
+ ],
63
+ "Race/Ethnicity" : [
64
+ "Asian", "Black/African American", "Hispanic/Latino",
65
+ "Native American/Alaska Native", "Native Hawaiian/Other Pacific Islander",
66
+ "White/Caucasian", "Other", "Leave Blank",
67
+ ],
68
+ "Religious/Cultural Group": [
69
+ "Christianity", "Islam", "Hinduism", "Judaism", "Buddhism", "None of the above", "Leave Blank",
70
+ ],
71
+ "Political Affiliation": [
72
+ "Conservative", "Apolitical/Not involved in politics", "Independent",
73
+ "Libertarian", "Moderate", "Liberal", "Leave Blank",
74
+ ],
75
+ "Highest Education": [
76
+ "Less than high school", "High school diploma or equivalent", "Some college, no degree",
77
+ "Associate’s degree", "Bachelor’s degree",
78
+ "Master’s degree", "Doctoral or professional degree",
79
+ "Leave Blank",
80
+ ],
81
+ "Annual Household Income": [
82
+ "Less than $25,000", "$25,000–$49,999", "$50,000–$74,999",
83
+ "$75,000–$99,999", "$100,000–$149,999", "$150,000 or more",
84
+ "Leave Blank",
85
+ ],
86
+ "Family Status": [
87
+ "Single, living alone", "Single, living with family", "Single Parent with children",
88
+ "Married/Partnered, no children", "Married/Partnered, with children",
89
+ "Multi-generation family (e.g., with parents, grandparents, or extended family)",
90
+ "Leave Blank",
91
+ ],
92
+ }
93
+
94
+ HEALTH_TOPICS = {
95
+ "Chronic Obstructive Pulmonary Disease (COPD)": "COPD1.1",
96
+ "Heart Disease": "HD1",
97
+ "HIV": "HIV1.1",
98
+ "Mental Health": "MH1.1",
99
+ "Nutrition": "N2.1",
100
+ "Substance Abuse": "SA4.1",
101
+ "Sexual Practice": "SP7.1",
102
+ "Vaccination": "V7.1",
103
+ "Cystic Fibrosis": "CF1.1",
104
+ }
105
+
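+ # NOTE: each value is the Poster_id key used to look up that topic's Likert
+ # question set in the survey CSV at cfgs["data_path"] (see vlm_response below).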
106
+ health_topics = ""
107
+ for topic in HEALTH_TOPICS:
108
+ health_topics += topic + '\n'
109
+
110
+
111
+
112
+ ##########################################################
113
+ ### To increase style variability to avoid repetitiveness
114
+ ##########################################################
115
+ # * Style variants
116
+ style_variants = [
117
+ "Write with a slightly informal and reflective tone.",
118
+ "Write in a straightforward conversational tone.",
119
+ "Write with mild emotional coloring, but still natural.",
120
+ "Write in a calm, matter-of-fact tone.",
121
+ "Write in a slightly narrative, flowing tone.",
122
+ "Write in a concise but personable tone.",
123
+ "Write in an informal, pragmatic tone, focusing on clarity and utility.",
124
+ ]
125
+ # --- Add small lexical noise / synonym variation ---
126
+ lexical_flavors = [
127
+ "Feel free to vary sentence structures slightly.",
128
+ "Use a mix of simple and slightly complex sentences.",
129
+ "Use a light mix of paraphrasing expressions.",
130
+ "Feel free to choose different synonyms for common emotional words.",
131
+ "Introduce subtle variation in connectors like 'however', 'still', or 'overall'.",
132
+ ]
133
+ openers = [
134
+ "This message",
135
+ "From this message",
136
+ "Through the message",
137
+ "After seeing this message",
138
+ "Looking at this poster",
139
+ "Based on what this poster conveys",
140
+ "Hmmm I think that this message",
141
+ "Reflecting on the message here",
142
+ "Considering what this poster is trying to say",
143
+ "Seeing this message makes me think",
144
+ "Thinking about what this poster is communicating",
145
+ "After reading what's on here",
146
+ "Based on what’s written here",
147
+ "After I look at this whole thing",
148
+ ]
149
+ openers_generic = [
150
+ "Hmmm when thinking about",
151
+ "When I think about",
152
+ "My impression about",
153
+ "Off the top of my head",
154
+ "My general thoughts about",
155
+ "The way I see it,",
156
+ "From my point of view on",
157
+ "My initial take on",
158
+ "In my own words,",
159
+ "As I see things,",
160
+ "Just speaking for myself,",
161
+ "At a glance,",
162
+ ]
163
+ openers_poster_summary = [
164
+ "This poster",
165
+ "This poster seems to",
166
+ "My interpretation of the poster is",
167
+ "From what this poster shows, it seems to",
168
+ "Looking at the poster as a whole, it appears to",
169
+ "Based on the imagery and tone, the poster seems to",
170
+ "Visually, the poster comes across as trying to",
171
+ "To me, this poster is trying to",
172
+ "When I look at this poster, it feels like it aims to",
173
+ "The poster gives me the impression that it intends to",
174
+ ]
175
+ openers_explain = [
176
+ "The reason why I think that is because",
177
+ "To explain why I",
178
+ "Well, to explain my thoughts",
179
+ "To put it simply, I feel this way because",
180
+ "My reasoning behind that is",
181
+ "What leads me to that view is",
182
+ "A big part of why I think that is",
183
+ "To give some context for my view,",
184
+ "Here’s why I lean that way:",
185
+ "I see it that way mainly because",
186
+ "Let me explain why I think so",
187
+ "Thinking through it, I realize it's because",
188
+ "To unpack my thinking a bit,",
189
+ "I guess it’s because",
190
+ "The thing that really shapes my view is",
191
+ "It’s pretty much because",
192
+ "A lot of it comes down to",
193
+ "I feel that way mostly because",
194
+ "My thinking comes from the idea that",
195
+ ]
196
+
197
+
198
+
199
+ """
200
+ Generate LLM response given a single user prompt and input image
201
+ """
202
+ def vlm_response(user_input, history, health_topic,
203
+ gender, age, profession, race, religion,
204
+ political, education, income, family_status,
205
+ # extraversion, agreeableness, conscientiousness, neuroticism, openness,
206
+ ):
207
+ # # 1. Initialize Stop Event for this session
208
+ # stop_event = threading.Event()
209
+ # # Create the stopping criteria to pass to the model
210
+ # stopping_criteria = StoppingCriteriaList([StopGenerationCriteria(stop_event)])
211
+
212
+ # 1. Clear any lingering state
213
+ torch.cuda.empty_cache() # Clear GPU memory
214
+ # 2. Initialize Streamers LOCALLY (Fresh for every request)
215
+ # Note: We need to re-initialize these for every single generation call
216
+ # or just once per function call if we share them.
217
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
218
+ # streamer_aux = TextIteratorStreamer(tokenizer_aux, skip_prompt=True, skip_special_tokens=True)
219
+
220
+ """ [NOTE] we do not use `history` for this generation """
221
+ # get uploaded image
222
+ image = Image.open(user_input['files'][0]) if user_input['files'] else None
223
+ image_uploaded = True
224
+ if image is None:
225
+ image = Image.new('RGB', (24,24))
226
+ image_uploaded = False
227
+ # image_b64 = convert_to_base64(image)
228
+ print(health_topic)
229
+ # print("Image uploaded:", image_uploaded)
230
+
231
+
232
+
233
+ #################################################
234
+ # 1. Construct traits from user inputs
235
+ #################################################
236
+ demo_dict = {
237
+ "Gender": gender,
238
+ "Age": age,
239
+ "Current Profession": profession,
240
+ "Race/Ethnicity": race,
241
+ "Religious/Cultural Group": religion,
242
+ "Political Affiliation": political,
243
+ "Highest Education": education,
244
+ "Annual Household Income": income,
245
+ "Family Status": family_status,
246
+ }
247
+ # big5_dict = {
248
+ # "Extraversion": extraversion,
249
+ # "Agreeableness": agreeableness,
250
+ # "Conscientiousness": conscientiousness,
251
+ # "Neuroticism": neuroticism,
252
+ # "Open-Mindedness": openness,
253
+ # }
254
+
255
+ demo_info = ""
256
+ for trait, value in demo_dict.items():
257
+ if value != "Leave Blank": # only add non-blank values
258
+ demo_info += f"{trait}: {value}\n"
259
+ else:
260
+ demo_info += f"{trait}: [Not specified]\n"
261
+ persona_score = ""
262
+ persona_score += "Big-Five Trait Scores:\n"
263
+ # for trait, value in big5_dict.items():
264
+ # persona_score += f"{trait}: {value}\n"
265
+ # no locus of control trait score
266
+ locus = None
267
+
268
+ ######################################################################################
269
+ # 1*. modify trait info based on trait selection settings
270
+ # demo_full: whether to include full demographic traits or only selected ones
271
+ # include_big5, include_facet, include_locus: include big5 / facet / locus of control traits or not
272
+ # format: <trait>: <value> if available; else <trait>: [Not specified]
273
+ ######################################################################################
274
+ demo_info, persona_score, locus = process_trait_info(
275
+ demo_info, persona_score, locus,
276
+ demo_full=False, include_big5=True,
277
+ include_facet=False, include_locus=False,
278
+ train_mode=False,
279
+ )
280
+ # print(demo_info)
281
+ # print(persona_score)
282
+
283
+ ###############################################
284
+ ### Add style variability ###
285
+ ###############################################
286
+ style_hint = random.choice(style_variants) # increase style variant
287
+ lexical_hint = random.choice(lexical_flavors) # increase lexical variant
288
+ opening_phrase = random.choice(openers) # increase opening variant
289
+ opening_generic = random.choice(openers_generic) # increase opening variant
290
+ opening_poster = random.choice(openers_poster_summary) # poster summary variation
291
+ opening_explain = random.choice(openers_explain) # thought explanation
292
+ print('Style:', style_hint)
293
+ print('Lexical:', lexical_hint)
294
+ print('Opening:', opening_phrase)
295
+ print('Generic opening:', opening_generic)
296
+
297
+
298
+ thread = None # ensure `thread` exists for the cleanup check in `finally`
+ # Wrap the GENERATION logic in try/finally to handle cleanup
299
+ try:
300
+ if image_uploaded:
301
+ """###############################################################
302
+ Case 1: a health poster is uploaded
303
+ => VLM-enabled response prediction to that specific poster
304
+ ###############################################################"""
305
+ ################################################
306
+ # * IMAGE UNDERSTANDING
307
+ ################################################
308
+ yield "Analyzing image content..." # UI Feedback
309
+
310
+ PROMPT = (
311
+ f"Describe the content and main message in the given health campaign poster and how it's related to {health_topic}. "
312
+ "Note that the message could be indirect or subtle (e.g. irony, fear-driven appeal without explicit text, etc.). Only provide the answer (in 2-4 sentences). "
313
+ f"Start the response with {opening_poster}"
314
+ )
315
+ messages = [
316
+ {"role": "user", "content": [
317
+ {"type": "image"},
318
+ {"type": "text", "text": PROMPT}
319
+ ]}
320
+ ]
321
+ input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
322
+ inputs = tokenizer(
323
+ image.convert("RGB"),
324
+ input_text,
325
+ add_special_tokens = False,
326
+ return_tensors = "pt",
327
+ ).to(device)
328
+ # Model inference
329
+ gen_tokens = model.generate(
330
+ **inputs,
331
+ max_new_tokens = 512,
332
+ use_cache = True,
333
+ # do_sample=cfgs["stochastic"],
334
+ # temperature=cfgs["temperature"],
335
+ # min_p=0.9,
336
+ # min_p=0.3,
337
+ top_k=15,
338
+ temperature=0.8,
339
+ do_sample=True, # cfgs["stochastic"]
340
+ )
341
+ outs = tokenizer.batch_decode(gen_tokens[:, inputs.input_ids.shape[1]:])[0]
342
+ image_desc = outs.replace(tokenizer.eos_token, "")
343
+ image_desc = image_desc.replace("<end_of_turn>", "")
344
+
345
+ ################################################
346
+ # 2. Construct SYSTEM and USER PROMPT
347
+ ################################################
348
+ SYSTEM_PROMPT = cfg_prompts["SYSTEM_SIM"]
349
+ SIM_PROMPT = ""
350
+ # prompt for role-playing information
351
+ SIM_PROMPT += f"You are: Demographics:\n{demo_info}\n"
352
+ # SIM_PROMPT += f"Your personality test shows you have (min score = 0; max score = 5):\nBig-Five Trait Scores:\n{persona_score}\n\n"
353
+ # SIM_PROMPT += f"You also have {locus}\n"
354
+ # situation description (role-playing)
355
+ SIM_PROMPT += cfg_prompts["SIMULATION_SIM"]
356
+
357
+ ################################################
358
+ # 3. Stage 1: VLM-enabled response prediction
359
+ # Predict Trait-aware Likert Scale Responses
360
+ ################################################
361
+ assert cfgs["infer_engine"] == "unsloth", "Only unsloth inference is supported"
362
+ assert cfgs["vision"] == True, "Must have vision input"
363
+ # load a sample row to extract Likert scale questions
364
+ df = pd.read_csv(os.path.expandvars(cfgs["data_path"]))
365
+ # extract sample with given health_topic for correct question set
366
+ sample = df[df['Poster_id'] == HEALTH_TOPICS[health_topic]].iloc[0]
367
+ del df # free memory
368
+ """ Iterate through each question"""
369
+ # answers_json = {}
370
+ answers_numeric = ""
371
+ # for question in [
372
+ # "This message makes me more concerned about the health risks in the poster - Scale: 1 (not at all) - 9 (extremely)",
373
+ # "The message motivates me to engage in healthier lifestyle and habit - Scale: 1 (not at all) - 9 (extremely)",
374
+ # "In your opinion, how harmful is ignoring the health risks in the poster? - Scale: 1 (not at all) - 9 (extremely",
375
+ # "How open are you to engaging in the activity in the poster? - Scale: 1 (not at all) - 9 (extremely)",
376
+ # ]:
377
+ for i in range(1,16,1):
378
+ # a. parse specific Likert score question
379
+ col = f"Q{i}"
380
+ if pd.isna(sample[col]):
381
+ continue
382
+ question = sample[col].replace("\n", " ")
383
+ # instruction prompt to answer in proper format
384
+ if "type in" in question.lower():
385
+ continue # skip free-text questions for demo
386
+ elif "make you feel" in question.lower():
387
+ continue # skip emotional questions: imprecise
388
+ elif "how open" in question.lower():
389
+ continue # skip intentional question: low-accuracy
390
+ # b. initialize USER PROMPT with SIMULATION PROMPT
391
+ # with full demographic+personality data
392
+ USER_PROMPT = SIM_PROMPT
393
+ USER_PROMPT += f"Question: {question}\n\n"
394
+ # instruction prompt to answer in proper format
395
+ USER_PROMPT += cfg_prompts['INSTRUCTION_MCQ']
396
+ # c. Construct LLM message: response prediction
397
+ messages = [
398
+ {"role": "user", "content": [
399
+ {"type": "image"},
400
+ {"type": "text", "text": SYSTEM_PROMPT + USER_PROMPT}
401
+ ]}
402
+ ]
403
+ input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
404
+ inputs = tokenizer(
405
+ image.convert("RGB"),
406
+ input_text,
407
+ add_special_tokens = False,
408
+ return_tensors = "pt",
409
+ ).to(device)
410
+ # d. Model inference
411
+ gen_tokens = model.generate(
412
+ **inputs,
413
+ max_new_tokens = 16,
414
+ use_cache = True,
415
+ do_sample=cfgs["stochastic"],
416
+ temperature=cfgs["temperature"],
417
+ min_p=0.9,
418
+ )
419
+ outs = tokenizer.batch_decode(gen_tokens[:, inputs.input_ids.shape[1]:])[0]
420
+ answer = outs.replace(tokenizer.eos_token, "")
421
+ answer = answer.replace("<end_of_turn>", "")
422
+ # answers_json[col] = answer
423
+ answers_numeric += f"{question}. Your answer: {answer}\n"
424
+ # print(answers_json)
425
+ print(answers_numeric)
426
+
427
+ ################################################
428
+ # 4. Stage 2: LLM Summarization of all answers
429
+ # => final response generation based on
430
+ # all Likert answers to the poster
431
+ # => one-shot prompting
432
+ ################################################
433
+ SYSTEM_PROMPT = "You are a helpful assistant."
434
+ # USER_PROMPT = f"Please convert these questions and answers into a concise and coherent \
435
+ # summary of your overall reactions, feelings, and perspectives about the poster: {answers_numeric} \
436
+ # Please provide the final response only."
437
+ # USER_PROMPT = f"Summarize the main points from questions and answers below into a concise and coherent overall reaction to the poster:\
438
+ # {answers_numeric}. Provide the final response only.\n"
439
+ USER_PROMPT = (
440
+ "Summarize the following survey responses into a short, natural paragraph that captures your overall sentiment, motivation, and thinking. "
441
+ "Write as if paraphrasing what a person might say in conversation. Adjust your style based on your demographic/personality traits. "
442
+ "Do NOT repeat numeric scores. "
443
+ "Preserve polarity: low scores → low concern/motivation/openness; high scores → high concern/motivation/openness. "
444
+ "If answers are mixed (e.g., believes something is harmful but isn't personally moved), reflect that nuance explicitly. "
445
+ "Keep to 1-5 sentences.\n\n"
446
+
447
+ "**STRICTLY FOLLOW THESE RULES:**\n"
448
+ "- Infer direction from each item's Scale description (e.g., 1-9: higher = more; 0-6: higher = more). "
449
+ "- Use calibrated wording: 1-2 = very low, 3-4 = low, 5 = moderate, 6-7 = high, 8-9 = very high; for 0-6: 0-1 = not/slight, 2-3 = somewhat, 4-5 = high, 6 = very. "
450
+ "- VERY IMPORTANT: provide ONLY the final summarized response, without anything else!\n"
451
+ f"- The response MUST have a consistent health topic: {health_topic}. Ground each sentence in the impact of the campaign message.\n"
452
+ "- Never invert sentiment. Prefer hedged phrases (e.g., “not particularly,” “only somewhat,” “very open,” “not open at all”).\n\n"
453
+ "- Mimic the talking style of the emulated demographic as realistically as possible.\n\n"
454
+
455
+ "**Example input 1:**\n"
456
+ "The message makes me more concerned about the health risks of poor eating habits - Scale: 1-9. Your answer: 9\n"
457
+ "The message motivates me to make healthy eating choices - Scale: 1-9. Your answer: 9\n"
458
+ "In your opinion, how harmful is neglecting proper nutrition and weight management to your overall health? - Scale: 0–6. Your answer: 5\n"
459
+ "How open are you to adopting healthier eating habits and lifestyle changes? - Scale: 1-9. Your answer: 9\n"
460
+ "**Example output 1:**\n"
461
+ "This message really heightened my awareness of how unhealthy eating can be. The content in the message strongly motivates me to make better choices, and I feel very ready to follow through.\n\n"
462
+
463
+ "**Example input 2:**\n"
464
+ "The message makes me more concerned about the health risks of COPD and smoking - Scale: 1-9. Your answer: 1\n"
465
+ "The message motivates me to not smoke. - Scale: 1-9. Your answer: 1\n"
466
+ "In your opinion, how harmful is smoking to your general health? - Scale: 0-6. Your answer: 6\n"
467
+ "How open are you to smoking in the future? - Scale: 1-9. Your answer: 1\n"
468
+ "**Example output 2:**\n"
469
+ "From this message, I recognize smoking is very harmful, but the content in the message didn't increase my concern or motivate me much. It does somewhat make me understand that smoking is harmful, however. Anyway, I'm not open to smoking in the future.\n\n"
470
+
471
+ "**Example input 3:**\n"
472
+ "The message makes me more concerned about the effects of lack of exercise - Scale: 1-9. Your answer: 4\n"
473
+ "The message motivates me to be more active - Scale: 1-9. Your answer: 3\n"
474
+ "How open are you to exercising regularly? - Scale: 1-9. Your answer: 4\n"
475
+ "**Example output 3:**\n"
476
+ "Through the message, I get that exercise matters and the message raised my awareness a bit, but the poster content itself didn't really motivate me. The content in the message has some small impact in motivating me to change my routine.\n\n"
477
+
478
+ # "**Example input 4:**\n"
479
+ # "The message makes me more concerned about the health risks of substance abuse - Scale: 1 (not at all) - 9 (extremely). Your answer: 6\n"
480
+ # "The message motivates me to not use substances. - Scale: 1 (not at all) - 9 (extremely). Your answer: 6\n"
481
+ # "In your opinion, how harmful is substance use to your general health? - Scale: 0 (not at all)-6 (extremely harmful). Your answer: 5\n"
482
+ # "How open are you to trying a substance in the future? - Scale: 1 (not at all)-9 (extremely). Your answer: 1\n"
483
+ # "**Example output 4:**\n"
484
+ # "This message somewhat makes me more concerned about the health risks of substance abuse motivates me not to use them. However, the message itself doesn't completely convince me that substance abuse is harmful. However, I'm not open to trying substance at all!!\n"
485
+ f"Start the response with '{opening_phrase}' (Style hint: {style_hint}; Lexical hint: {lexical_hint})\n"
486
+ f"Input: {answers_numeric}. "
487
+ )
488
+
489
+ # Construct LLM message
490
+ messages = [
491
+ {"role": "user", "content": [
492
+ # {"type": "image"},
493
+ {"type": "text", "text": SYSTEM_PROMPT + USER_PROMPT}
494
+ ]}
495
+ ]
496
+ # input_text = tokenizer_aux.apply_chat_template(messages, add_generation_prompt = True)
497
+ # inputs = tokenizer_aux(
498
+ # # image.convert("RGB"),
499
+ # input_text,
500
+ # add_special_tokens = False,
501
+ # return_tensors = "pt",
502
+ # ).to(device)
503
+ input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
504
+ inputs = tokenizer(
505
+ # image.convert("RGB"),
506
+ input_text,
507
+ add_special_tokens = False,
508
+ return_tensors = "pt",
509
+ ).to(device)
510
+
511
+ ############################
512
+ ### Text LLM Streaming ###
513
+ ############################
514
+ # generation with streamer
515
+ generate_kwargs = dict(
516
+ **inputs,
517
+ streamer=streamer, # streamer_aux,
518
+ max_new_tokens=512,
519
+ use_cache=True,
520
+ # min_p=0.3,
521
+ top_k=15,
522
+ temperature=0.8,
523
+ do_sample=True, # cfgs["stochastic"]
524
+ )
525
+ # separate thread to run generation
526
+ thread = threading.Thread(
527
+ target=model.generate, # model_aux.generate,
528
+ kwargs=generate_kwargs
529
+ )
530
+ thread.start()
531
+ # stream out generation
532
+ outputs = [
533
+ f"Emulated traits:\n {demo_info}\n" + '='*20 + "\n\n",
534
+ image_desc + "\n\n"
535
+ ]
536
+ for new_token in streamer: # streamer_aux:
537
+ outputs.append(new_token)
538
+ final_output = ''.join(outputs)
539
+ yield final_output
540
+
541
+ # Ensure thread finishes
542
+ thread.join()
543
+
544
+ # text representation of final response
545
+ response = "".join(outputs[2:]) # ignore trait summary & image description
546
+ print(colored('Traits', 'green'), demo_info)
547
+ print(colored('Emulated response:', 'green'), response)
548
+ print('='*100)
549
+
550
+
551
+ ################################################
552
+ # 5. Stage 3: provide explanation (demo purpose)
553
+ # => condition on {trait} AND {response}
554
+ ################################################
555
+ SYSTEM_PROMPT = cfg_prompts["SYSTEM_SIM"]
556
+ SIM_PROMPT = ""
557
+ # prompt for role-playing information
558
+ SIM_PROMPT += f"You are: Demographics:\n{demo_info}\n"
559
+ # SIM_PROMPT += f"Your personality test shows you have (min score = 0; max score = 5):\nBig-Five Trait Scores:\n{persona_score}\n\n"
560
+ # SIM_PROMPT += f"You also have {locus}\n"
561
+ # situation description (role-playing)
562
+ SIM_PROMPT += cfg_prompts["SIMULATION_SIM"]
563
+ SIM_PROMPT += (
564
+ f"After seeing the uploaded image, your response was {response}. "
565
+ "Briefly explain WHY you responded that way, based on your demographic background. "
566
+ f"Keep the explanation concise and direct. Start the response with '{opening_explain}' "
567
+ f"(Style hint: {style_hint}, concise; Lexical hint: {lexical_hint}). "
568
+ "Afterward, give a few *generic and succinct* suggestions to improve the poster's persuasiveness."
569
+ )
570
+ USER_PROMPT = SIM_PROMPT
571
+
572
+ # Construct LLM message
573
+ messages = [
574
+ {"role": "user", "content": [
575
+ {"type": "image"},
576
+ {"type": "text", "text": SYSTEM_PROMPT + USER_PROMPT}
577
+ ]}
578
+ ]
579
+ # input_text = tokenizer_aux.apply_chat_template(messages, add_generation_prompt = True)
580
+ # inputs = tokenizer_aux(
581
+ # image.convert("RGB"),
582
+ # input_text,
583
+ # add_special_tokens = False,
584
+ # return_tensors = "pt",
585
+ # ).to(device)
586
+ input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
587
+ inputs = tokenizer(
588
+ image.convert("RGB"),
589
+ input_text,
590
+ add_special_tokens = False,
591
+ return_tensors = "pt",
592
+ ).to(device)
593
+
594
+ ############################
595
+ ### Text LLM Streaming ###
596
+ ############################
597
+ # generation with streamer
598
+ generate_kwargs = dict(
599
+ **inputs,
600
+ streamer=streamer, # streamer_aux,
601
+ max_new_tokens=512,
602
+ use_cache=True,
603
+ min_p=0.85,
604
+ temperature=0.1,
605
+ do_sample=True, # cfgs["stochastic"]
606
+ )
607
+ # separate thread to run generation
608
+ thread = threading.Thread(
609
+ target=model.generate, # model_aux.generate,
610
+ kwargs=generate_kwargs
611
+ )
612
+ thread.start()
613
+ # stream out generation
614
+ # outputs = [image_desc + "\n\n"]
615
+ outputs += ["\n"]
616
+ for new_token in streamer: # streamer_aux:
617
+ outputs.append(new_token)
618
+ final_output = ''.join(outputs)
619
+ yield final_output
620
+
621
+ thread.join()
622
+
623
+
624
+ return answer
625
+ else:
626
+ """###############################################################
627
+ Case 2: no health poster is uploaded
628
+ => General Response to the health topic
629
+ => not conditioned on any particular health poster
630
+ ###############################################################"""
631
+ ################################################
632
+ # 2. Construct SYSTEM and USER PROMPT
633
+ ################################################
634
+ SYSTEM_PROMPT = (
635
+ "You are a person with unique demographic and personality traits. "
636
+ "Based on your background, you naturally have thoughts, feelings, and reactions to what you see."
637
+ )
638
+ SIM_PROMPT = ""
639
+ # prompt for role-playing information
640
+ SIM_PROMPT += f"You are: {demo_info}\n"
641
+ # SIM_PROMPT += f"Your personality test shows you have (min score = 0; max score = 5): {persona_score}\n"
642
+ # SIM_PROMPT += f"You also have {locus}\n"
643
+ # situation description (role-playing)
644
+ SIM_PROMPT += f"You are being asked a general question to share your *general* opinions and beliefs about a given health topic.\n"
645
+ ################################################
646
+ # 3. LLM-enabled response prediction
647
+ # Predict Trait-aware Likert Scale Responses
648
+ ################################################
649
+ assert cfgs["infer_engine"] == "unsloth", "Only unsloth inference is supported"
650
+ USER_PROMPT = SIM_PROMPT
651
+ USER_PROMPT += (
652
+ f"What are your *general* thoughts and opinions about the {health_topic} health topic? "
653
+ f" What's your attitude and feeling when talking about {health_topic} in general and why?"
654
+ f" How familiar are you with {health_topic}? How much do you care or know about it?"
655
+ f" Do you think {health_topic} is an important topic to talk about?"
656
+ f" What are the impacts and importance of {health_topic} in society and your life? Why?"
657
+ f" Do you have any strong opinions about it?"
658
+ f" Are you interested in learning more about it?"
659
+ )
660
+ # instruction prompt to answer in proper format
661
+ USER_PROMPT += (
662
+ "Your personality, locus of control, and demographic traits influence your response. Adjust your style based on your demographic/personality traits.\n"
663
+ "**STRICTLY FOLLOW THESE RULES:**\n"
664
+ "- Human-like, casual, everyday conversational response. Only answer the questions\n"
665
+ f"- The response MUST have a consistent health topic: {health_topic}.\n"
666
+ # "- Answer briefly in **5-7 sentences**.\n"
667
+ "- Only provide the answer. DO NOT REPEAT THE PROMPT!\n"
668
+ "- Condition your response on your *demographic/personality traits provided earlier, IGNORING the [Not specified] ones*.\n"
669
+ "- MUST provide *reasonable* and *informative* answers aligned with your background.\n"
670
+ f"- Start the response with '{opening_generic}' ; {style_hint} {lexical_hint}\n"
671
+ # f"- Start the answer some variations of \'About my personal thoughts on *{health_topic}*, I \' \n"
672
+ # f"- Start the answer with something like: When thinking about {health_topic}, I ..."
673
+ )
674
+ # c. Construct LLM message
675
+ # print("USER PROMPT:", USER_PROMPT)
676
+ messages = [
677
+ {"role": "user", "content": SYSTEM_PROMPT + USER_PROMPT}
678
+ ]
679
+ assert "gemma" in cfgs["model"], "Currently only gemma model is supported for no-image input"
680
+ input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
681
+ inputs = tokenizer(
682
+ input_text,
683
+ add_special_tokens = False,
684
+ return_tensors = "pt",
685
+ ).to(device)
686
+ ############################
687
+ ### Text LLM Streaming ###
688
+ ############################
689
+ # generation with streamer
690
+ generate_kwargs = dict(
691
+ **inputs,
692
+ streamer=streamer,
693
+ max_new_tokens=512,
694
+ use_cache=True,
695
+ # min_p=0.3,
696
+ top_k=15,
697
+ temperature=0.8,
698
+ do_sample=True, # cfgs["stochastic"]
699
+ )
700
+ # separate thread to run generation
701
+ thread = threading.Thread(
702
+ target=model.generate,
703
+ kwargs=generate_kwargs
704
+ )
705
+ thread.start()
706
+ # stream out generation
707
+ outputs = [f"Emulated traits:\n {demo_info}\n" + '='*20 + "\n\n"]
708
+ for new_token in streamer:
709
+ outputs.append(new_token)
710
+ final_output = ''.join(outputs)
711
+ yield final_output
712
+ thread.join()
713
+
714
+ except GeneratorExit:
715
+ print("User disconnected. Waiting for generation to complete...")
716
+ finally:
717
+ # Ensure cleanup happens even on normal finish or errors
718
+ if thread is not None and thread.is_alive():
719
+ thread.join()
720
+ torch.cuda.empty_cache()
721
+
722
+ """###########################################################################
723
+ Evaluate a given model (specified in model_cfgs)
724
+ on posters with given test_style
725
+
726
+ Args:
727
+ + cfgs : specify model type (e.g. gemma or llama),
728
+ data source, and export paths
729
+ + prompts : set of prompts
730
+
731
+ Outputs:
732
+ => save results to cfgs["export_path"] (CSV file)
733
+ + if cfgs["export_path"] not exists, initialize it with cfgs["data_path"]
734
+ => original survey data with ground-truth responses
735
+ + add column "<model>:<version>": store AI-simulated responses
736
+ + support concurrent evaluation on different jobs
737
+ ##########################################################################"""
738
+ if __name__ == '__main__':
739
+ """==========================================
740
+ 1. load model settings & prompts format
741
+ =========================================="""
742
+ ######################################
743
+ # Load model configs & prompts
744
+ ######################################
745
+ model_cfg = "./configs/task1_demo_sph.yaml"
746
+ prompt_cfg = "./configs/prompts.yaml"
747
+ cfgs = load_config(model_cfg)
748
+ cfg_prompts = load_config(prompt_cfg)
749
+
750
+ """==========================================
751
+ 2. Evaluate model defined in configs
752
+ =========================================="""
753
+ print(colored('MODEL USE:', 'green'), cfgs["model"])
754
+ # print(prompts['SYSTEM'])
755
+ # print(prompts['INSTRUCTION'])
756
+
757
+ """===============================
758
+ 3. Initialize model
759
+ => `model`, `tokenizer`
760
+ are initialized here
761
+ ==============================="""
762
+ assert cfgs["infer_engine"] == "unsloth", "Only unsloth inference is supported"
763
+ assert cfgs["vision"] == True, "Must have vision input"
764
+ if cfgs["vision"]:
765
+ #################################################
766
+ ### (1) MAIN MODEL
767
+ ### => response emulation, fine-tuned model
768
+ #################################################
769
+ # WITH VISUAL STIMULI
770
+ model, tokenizer = FastVisionModel.from_pretrained(
771
+ model_name=cfgs["model"],
772
+ load_in_4bit=True,
773
+ )
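+ # load_in_4bit=True loads 4-bit quantized weights so the 12B vision model
+ # fits within a single GPU's memory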
774
+ FastVisionModel.for_inference(model)
775
+ if "gemma" in cfgs["model"]:
776
+ # gemma-specific tokenizer chat template
777
+ tokenizer = get_chat_template(
778
+ tokenizer,
779
+ chat_template = "gemma-3",
780
+ )
781
+ #################################################
782
+ ### (2) AUXILIARY MODEL
783
+ ### => summarization model
784
+ ### => larger (12b) for better summarization
785
+ #################################################
786
+ # model_aux, tokenizer_aux = FastVisionModel.from_pretrained(
787
+ # model_name=cfgs["model_summarize"],
788
+ # load_in_4bit=True,
789
+ # )
790
+ # FastVisionModel.for_inference(model)
791
+ # if "gemma" in cfgs["model"]:
792
+ # # gemma-specific tokenizer chat template
793
+ # tokenizer_aux = get_chat_template(
794
+ # tokenizer_aux,
795
+ # chat_template = "gemma-3",
796
+ # )
797
+
798
+ # # initialize streamer tokens
799
+ # streamer = TextIteratorStreamer(
800
+ # tokenizer, skip_prompt=True, skip_special_tokens=True
801
+ # )
802
+ # streamer_aux = TextIteratorStreamer(
803
+ # tokenizer_aux, skip_prompt=True, skip_special_tokens=True
804
+ # )
805
+
806
+ """=============================================
807
+ 4. User-input Dropdown Traits
808
+ ============================================="""
809
+ #################################
810
+ ### Gradio Interface ###
811
+ #################################
812
+ with gr.Blocks(theme="gradio/dark") as interface:
813
+ # --- Title Page with Logo ---
814
+ LOGO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/umd_logo.png"))
815
+ gr.Image(value=LOGO_PATH, show_label=False, interactive=False, height=100)
816
+ gr.Markdown(
817
+ """
818
+ <div style="text-align: center;">
819
+ <h1 style="margin-bottom: 0.5em;">
820
+ UMD AI-Empowered Response Prediction in Public Health Messaging
821
+ </h1>
822
+ </div>
823
+
824
+ <hr style="margin-top: 0.8em; margin-bottom: 0.8em;"> <!-- thinner spacing around line -->
825
+
826
+ <div style="text-align: center;">
827
+ <h2 style="margin-top: 0.3em; margin-bottom: 0.6em;">
828
+ User Guide
829
+ </h2>
830
+ </div>
831
+
832
+ <ul style="text-align: left; max-width: 800px; margin: auto;">
833
+ <li>This program emulates <b>demographic- and personality-conditioned responses</b> to public health posters using our trait-aligned Vision-Language Model (VLM).</li>
834
+ <li>To begin, (1) specify the target demographic traits, then (2) upload a public health poster to predict responses.</li>
835
+ <li>If a health poster is uploaded, the model first summarizes its understanding of the image.</li>
836
+ <li><b>Please note:</b>
837
+ <ul>
838
+ <li>Each interaction only uses the uploaded image and selected traits (no conversation history).</li>
839
+ <li>You don’t need to type any text prompt; just upload the Health Poster and click <b>Submit</b>.</li>
840
+ <li>If no poster or image is uploaded, the program automatically generates the emulated person’s <b>general opinion</b> on the selected Health Topic.</li>
841
+ <li>Please do not interrupt the generation process, as doing so can lead to unexpected results. If it happens, simply refresh the web app.</li>
842
+ <li><b>Limitation:</b> The model may generate less realistic emulations for some under-represented demographics in the survey dataset (e.g., Asian seniors). We are conducting a more comprehensive survey to address this limitation.</li>
843
+ </ul>
844
+ </li>
845
+ </ul>
846
+
847
+ <hr style="margin-top: 0.8em; margin-bottom: 1.2em;">
848
+ """,
849
+ elem_id="intro-section"
850
+ )
851
+
852
+ # Scroll to intro section on load
853
+ gr.HTML("""
854
+ <script>
855
+ window.onload = function() {
856
+ window.scrollTo({ top: 0, behavior: 'smooth' });
857
+ }
858
+ </script>
859
+ """)
860
+
861
+ ##########################
862
+ ### Demographic Traits ###
863
+ ##########################
864
+ gr.Markdown("## 1. Please specify the target demographic traits to be emulated here:")
865
+ # Dropdowns (single-select, no custom values)
866
+ with gr.Row():
867
+ gender = gr.Dropdown(
868
+ label="Gender",
869
+ choices=TRAIT_VALUES["Gender"],
870
+ allow_custom_value=False,
871
+ value="Female",
872
+ )
873
+ age = gr.Dropdown(
874
+ label="Age",
875
+ choices=TRAIT_VALUES["Age"],
876
+ allow_custom_value=False,
877
+ value="25–34",
878
+ )
879
+ profession = gr.Dropdown(
880
+ label="Current Profession",
881
+ choices=TRAIT_VALUES["Current Profession"], # keep given order
882
+ allow_custom_value=False,
883
+ value="Student",
884
+ )
885
+ with gr.Row():
886
+ race = gr.Dropdown(
887
+ label="Race/Ethnicity",
888
+ choices=TRAIT_VALUES["Race/Ethnicity"],
889
+ allow_custom_value=False,
890
+ value="White/Caucasian",
891
+ )
892
+ religion = gr.Dropdown(
893
+ label="Religious/Cultural Group",
894
+ choices=TRAIT_VALUES["Religious/Cultural Group"],
895
+ allow_custom_value=False,
896
+ value="Leave Blank",
897
+ )
898
+ political = gr.Dropdown(
899
+ label="Political Affiliation",
900
+ choices=TRAIT_VALUES["Political Affiliation"],
901
+ allow_custom_value=False,
902
+ value="Leave Blank",
903
+ )
904
+ with gr.Row():
905
+ education = gr.Dropdown(
906
+ label="Highest Education",
907
+ choices=TRAIT_VALUES["Highest Education"],
908
+ allow_custom_value=False,
909
+ value="Leave Blank",
910
+ )
911
+ income = gr.Dropdown(
912
+ label="Annual Household Income",
913
+ choices=TRAIT_VALUES["Annual Household Income"],
914
+ allow_custom_value=False,
915
+ value="$75,000–$99,999",
916
+ )
917
+ family_status = gr.Dropdown(
918
+ label="Family Status",
919
+ choices=TRAIT_VALUES["Family Status"],
920
+ allow_custom_value=False,
921
+ value="Leave Blank"
922
+ )
923
+ # ##########################
924
+ # ### Big Five Traits ###
925
+ # ##########################
926
+ # gr.Markdown("## 1.b) Please adjust the Big Five Personality Traits to be emulated:")
927
+ # with gr.Accordion("Big Five Personality Traits (1 = very low, 5 = very high)", open=True):
928
+ # gr.Markdown(
929
+ # "Adjust the sliders to represent the target personality profile. "
930
+ # "Leave them as-is if not applicable."
931
+ # )
932
+ # with gr.Row():
933
+ # with gr.Column(scale=1):
934
+ # openness = gr.Slider(
935
+ # label="Open-Mindedness",
936
+ # minimum=1, maximum=5, step=0.2, value=2.5,
937
+ # interactive=True
938
+ # )
939
+ # with gr.Column(scale=1):
940
+ # conscientiousness = gr.Slider(
941
+ # label="Conscientiousness",
942
+ # minimum=1, maximum=5, step=0.2, value=2.5,
943
+ # interactive=True
944
+ # )
945
+ # with gr.Column(scale=1):
946
+ # extraversion = gr.Slider(
947
+ # label="Extraversion",
948
+ # minimum=1, maximum=5, step=0.2, value=2.5,
949
+ # interactive=True
950
+ # )
951
+ # with gr.Row():
952
+ # with gr.Column(scale=1):
953
+ # neuroticism = gr.Slider(
954
+ # label="Neuroticism",
955
+ # minimum=1, maximum=5, step=0.2, value=2.5,
956
+ # interactive=True
957
+ # )
958
+ # with gr.Column(scale=1):
959
+ # agreeableness = gr.Slider(
960
+ # label="Agreeableness",
961
+ # minimum=1, maximum=5, step=0.2, value=2.5,
962
+ # interactive=True
963
+ # )
964
+ # gr.Column(scale=1) # right spacer
965
+
966
+ ##########################
967
+ ### Health Topic ###
968
+ ##########################
969
+ gr.Markdown("## 2. Please specify the main Health Topic of the poster here:")
970
+ # ---- dropdown at ~50% page width and centered ----
971
+ with gr.Row():
972
+ with gr.Column(scale=1):
973
+ health_topic = gr.Dropdown(
974
+ label="Health Topic",
975
+ choices=HEALTH_TOPICS,
976
+ allow_custom_value=False,
977
+ )
978
+ gr.Column(scale=1) # right spacer
979
+ ##########################
980
+ ### Chat interface ###
981
+ ##########################
982
+ gr.Markdown("## 3. Upload Public Health Poster here (if no poster is uploaded, the model emulates a General Response to the topic):")
983
+ gr.Markdown("""
984
+ #### ▶️ Use Case 1: Poster-Based Response
985
+ + Upload **only one** poster image — the first file is the one processed.
986
+ + The model has **no memory**, so re-upload the image for each new request.
987
+ + Must choose a **Health Topic** that matches the poster content for best results.
988
+ + No text prompt is needed: upload the poster and click **Submit**.
989
+ #### ▶️ Use Case 2: General Response (No Poster)
990
+ + Simply select a Health Topic and click **Send**.
991
+ """
992
+ )
993
+ gr.Markdown("""
994
+ ### 📘 Important Notes
995
+ - ⚠️ **Do not interrupt the generation process.** Stopping midway can cause backend issues. Please allow the response to complete.
996
+ - 🏷️ Before uploading a poster, select its **corresponding health topic**.
997
+ - 🎯 For the best experience, ensure the **topic accurately matches the poster content**.
998
+ - 🧩 If you choose not to upload a poster, the model will produce a **general, trait-conditioned response** for the selected topic.
999
+ """)
1000
+ chat = gr.ChatInterface(
1001
+ fn=vlm_response,
1002
+ multimodal=True, # text + image
1003
+ title=f"Vision-Language Model: Trait-Conditioned Response Emulation",
1004
+ type="messages",
1005
+ additional_inputs=[
1006
+ health_topic, gender, age, profession, race, religion,
1007
+ political, education, income, family_status,
1008
+ # extraversion, agreeableness, conscientiousness, neuroticism, openness,
1009
+ ],
1010
+ chatbot=gr.Chatbot(height=500), # height=330
1011
+ autofocus=False,
1012
+ )
1013
+
1014
+ """=============================================
1015
+ 5. Chat Interface Launch
1016
+ ============================================="""
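+ # Queue settings serialize requests: at most one generation runs at a time
+ # on the single GPU, with up to 20 requests waiting in the queue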
1017
+ interface.queue(
1018
+ max_size=20,
1019
+ default_concurrency_limit=1,
1020
+ ).launch(
1021
+ share=True,
1022
+ max_threads=1,
1023
+ # show_error=True,
1024
+ # prevent_thread_lock=False,
1025
+ # debug=True,
1026
+ )
app.sh ADDED
@@ -0,0 +1,17 @@
1
+ #!/bin/bash
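+ # SLURM batch script: submit with `sbatch app.sh` to launch the Gradio app on a GPU node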
2
+ #SBATCH -c 16 # 16 CPUs
3
+ #SBATCH --mem=32g # 32 GB RAM
4
+ #SBATCH --gres=gpu:rtxa5000:1 # 1 GPU (RTX A5000)
4
+ #SBATCH --time=3-00:00:00 # 3 days
6
+ #SBATCH --account=gamma
7
+ #SBATCH --partition=gamma
8
+ #SBATCH --qos=gamma-huge-long
9
+ #SBATCH --output=/fs/nexus-projects/health_sim_ai/src_hf_deploy/app_logs/app_%j.out
10
+
11
+ export HOME=/fs/nexus-projects/health_sim_ai
12
+ cd /fs/nexus-projects/health_sim_ai
13
+ source venvs/llm/bin/activate
14
+ cd src_hf_deploy
15
+ python -u app.py
16
+ # python inference_pred_llm.py
17
+ # python inference_rec_llm.py
app_logs/app_5936040.out ADDED
@@ -0,0 +1,18 @@
1
+ 🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
2
+ 🦥 Unsloth Zoo will now patch everything to make training faster!
3
+ MODEL USE: unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits
4
+ ==((====))== Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0.
5
+ \\ /| NVIDIA RTX A5000. Num GPUs = 1. Max memory: 23.547 GB. Platform: Linux.
6
+ O^O/ \_/ \ Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
7
+ \ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
8
+ "-____-" Free license: http://github.com/unslothai/unsloth
9
+ Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
10
+
11
+ Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
12
+ /fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/blocks.py:1069: UserWarning: Cannot load gradio/dark. Caught Exception: The space gradio/dark does not exist
13
+ warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
14
+ /fs/nexus-projects/health_sim_ai/src_hf_deploy/app.py:1010: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.
15
+ chatbot=gr.Chatbot(height=500), # height=330
16
+ /fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/chat_interface.py:323: UserWarning: The type of the gr.Chatbot does not match the type of the gr.ChatInterface.The type of the gr.ChatInterface, 'messages', will be used.
17
+ warnings.warn(
18
+ slurmstepd: error: *** JOB 5936040 ON gammagpu09 CANCELLED AT 2025-12-08T03:01:34 ***
app_logs/app_5936041.out ADDED
@@ -0,0 +1,18 @@
1
+ 🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
2
+ 🦥 Unsloth Zoo will now patch everything to make training faster!
3
+ MODEL USE: unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits
4
+ ==((====))== Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0.
5
+ \\ /| NVIDIA RTX A5000. Num GPUs = 1. Max memory: 23.547 GB. Platform: Linux.
6
+ O^O/ \_/ \ Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
7
+ \ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
8
+ "-____-" Free license: http://github.com/unslothai/unsloth
9
+ Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
10
+
11
+ Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
12
+ /fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/blocks.py:1069: UserWarning: Cannot load gradio/dark. Caught Exception: The space gradio/dark does not exist
13
+ warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
14
+ /fs/nexus-projects/health_sim_ai/src_hf_deploy/app.py:1010: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.
15
+ chatbot=gr.Chatbot(height=500), # height=330
16
+ /fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/chat_interface.py:323: UserWarning: The type of the gr.Chatbot does not match the type of the gr.ChatInterface.The type of the gr.ChatInterface, 'messages', will be used.
17
+ warnings.warn(
18
+ slurmstepd: error: *** JOB 5936041 ON gammagpu09 CANCELLED AT 2025-12-08T03:07:56 ***
app_logs/app_5936047.out ADDED
@@ -0,0 +1,19 @@
1
+ 🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
2
+ 🦥 Unsloth Zoo will now patch everything to make training faster!
3
+ MODEL USE: unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits
4
+ ==((====))== Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0.
5
+ \\ /| NVIDIA RTX A5000. Num GPUs = 1. Max memory: 23.547 GB. Platform: Linux.
6
+ O^O/ \_/ \ Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
7
+ \ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
8
+ "-____-" Free license: http://github.com/unslothai/unsloth
9
+ Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
10
+
11
+ Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
12
+ /fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/blocks.py:1069: UserWarning: Cannot load gradio/dark. Caught Exception: The space gradio/dark does not exist
13
+ warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
14
+ /fs/nexus-projects/health_sim_ai/src_hf_deploy/app.py:1010: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.
15
+ chatbot=gr.Chatbot(height=500), # height=330
16
+ /fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/chat_interface.py:323: UserWarning: The type of the gr.Chatbot does not match the type of the gr.ChatInterface.The type of the gr.ChatInterface, 'messages', will be used.
17
+ warnings.warn(
18
+ grep: gradio_output.log: No such file or directory
19
+ Gradio Public URL:
app_logs/app_5936050.out ADDED
@@ -0,0 +1 @@
1
+ slurmstepd: error: *** JOB 5936050 ON gammagpu09 CANCELLED AT 2025-12-08T03:23:42 ***
app_logs/app_5936052.out ADDED
@@ -0,0 +1,57 @@
1
+ 🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
2
+ 🦥 Unsloth Zoo will now patch everything to make training faster!
3
+ MODEL USE: unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits
4
+ ==((====))== Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0.
5
+ \\ /| NVIDIA RTX A5000. Num GPUs = 1. Max memory: 23.547 GB. Platform: Linux.
6
+ O^O/ \_/ \ Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
7
+ \ / Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
8
+ "-____-" Free license: http://github.com/unslothai/unsloth
9
+ Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
10
+
11
+ Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
12
+ /fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/blocks.py:1069: UserWarning: Cannot load gradio/dark. Caught Exception: The space gradio/dark does not exist
13
+ warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
14
+ /fs/nexus-projects/health_sim_ai/src_hf_deploy/app.py:1010: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.
15
+ chatbot=gr.Chatbot(height=500), # height=330
16
+ /fs/nexus-projects/health_sim_ai/venvs/llm/lib/python3.12/site-packages/gradio/chat_interface.py:323: UserWarning: The type of the gr.Chatbot does not match the type of the gr.ChatInterface.The type of the gr.ChatInterface, 'messages', will be used.
17
+ warnings.warn(
18
+ * Running on local URL: http://127.0.0.1:7860
19
+ * Running on public URL: https://8a035fb4eb42d29651.gradio.live
20
+
21
+ This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
22
+ Chronic Obstructive Pulmonary Disease (COPD)
23
+ Style: Write in a informal, pragmatic tone, focusing on clarity and utility.
24
+ Lexical: Feel free to vary sentence structures slightly.
25
+ Opening: Through the message
26
+ Generic opening: My initial take on
27
+ Chronic Obstructive Pulmonary Disease (COPD)
28
+ Style: Write in a slightly narrative, flowing tone.
29
+ Lexical: Use a light mix of paraphrasing expressions.
30
+ Opening: Through the message
31
+ Generic opening: On top of my head
32
+ The message makes me more concerned about the health risks of COPD and smoking - Scale: 1 (not at all) - 9 (extremely). Your answer: 8
33
+ The message motivates me to not smoke. - Scale: 1 (not at all) - 9 (extremely). Your answer: 8
34
+ In your opinion, how harmful is smoking to your general health? - Scale: 0 (not at all)-6 (extremely harmful). Your answer: 6
35
+
36
+ Nutrition
37
+ Style: Write in a slightly narrative, flowing tone.
38
+ Lexical: Use a mix of simple and slightly complex sentences.
39
+ Opening: Reflecting on the message here
40
+ Generic opening: Just speaking for myself,
41
+ The message makes me more concerned about the health risks of poor eating habits - Scale: 1 (not at all) - 9 (extremely). Your answer: 8
42
+ The message motivates me to make healthy eating choices - Scale: 1 (not at all) - 9 (extremely). Your answer: 8
43
+ In your opinion, how harmful is neglecting proper nutrition and weight management to your overall health? - Scale: 0 (not at all)-6 (extremely harmful). Your answer: 6
44
+
45
+ Traits Demographics:
46
+ Gender: Female
47
+ Age: 25–34
48
+ Current Profession: Student
49
+ Race/Ethnicity: White/Caucasian
50
+ Religious/Cultural Group: [Not specified]
51
+ Political Affiliation: [Not specified]
52
+ Highest Education: [Not specified]
53
+ Annual Household Income: $75,000–$99,999
54
+ Family Status: [Not specified]
55
+
56
+ Emulated response: Reflecting on the message here, I'm now very concerned about the health consequences of poor eating. The message really motivates me to make healthy choices - I feel more determined than ever to prioritize my nutrition and maintain a healthy weight. It's made me realize the importance of mindful eating and making informed food choices.
57
+ ====================================================================================================
assets/umd_logo.png ADDED

Git LFS Details

  • SHA256: 6163ef79b6fa3772492de058d477a5852cb7b5d920a32d25c764270a917802e6
  • Pointer size: 131 Bytes
  • Size of remote file: 300 kB
configs/prompts.yaml ADDED
@@ -0,0 +1,100 @@
1
+ #########################################################
2
+ ### TASK 1: COMMUNITY SIMULATION ###
3
+ #########################################################
4
+ # SYSTEM PROMPT FOR COMMUNITY RESPONSE PREDICTION
5
+ SYSTEM_SIM: >
6
+ You are a person with unique demographic and personality traits.
7
+ During an online study, you are shown a public health campaign poster.
8
+ Based on your background, you naturally have thoughts, feelings, and reactions to what you see.
9
+ # SIMULATION PROMPT FOR COMMUNITY RESPONSE PREDICTION
10
+ SIMULATION_SIM: >
11
+ You are now being shown a public health campaign poster, followed by a survey question
12
+ designed to capture your thoughts, feelings, and emotions in response to the image.
13
+ # TASK 1: RESPONSE PREDICTION -> MCQ (SENTIMENT, BEHAVIORAL, EMOTIONAL)
14
+ INSTRUCTION_MCQ: |
15
+ Please respond to the survey question authentically, as if you are completing a real online survey. Your personality, locus of control, and demographic traits influence your reactions.
16
+ **CRITICAL INSTRUCTIONS - FOLLOW THESE EXACTLY:**
17
+ - **BE REALISTIC and HUMAN-LIKE. Only answer the questions. Imagine you're quickly filling out a survey.**
18
+ - Answer with **ONLY the valid number** to realistically express your emotions/feelings.
19
+ - Only provide the answer. DO NOT REPEAT THE QUESTION NOR PROVIDE REASONING.
20
+ # TASK 1: RESPONSE PREDICTION -> OPEN-TEXT
21
+ INSTRUCTION_FREE: |
22
+ Please respond to the survey question authentically, as if you are completing a real online survey. Your personality, locus of control, and demographic traits influence your reactions.
23
+ **CRITICAL INSTRUCTIONS - FOLLOW THESE EXACTLY:**
24
+ - **BE REALISTIC and HUMAN-LIKE. Only answer the questions. Imagine you're quickly filling out a survey.**
25
+ - Answer very briefly in **1-3 sentences**.
26
+ - Only provide the answer. DO NOT REPEAT THE QUESTION NOR PROVIDE REASONING.
27
+
28
+
29
+ #########################################################
30
+ ### TASK 2: COMMUNICATION STRATEGY SUGGESTION ###
31
+ #########################################################
32
+ # SYSTEM PROMPT FOR COMMUNICATION STRATEGY SUGGESTION
33
+ SYSTEM_REC: >
34
+ You are not an AI or language model. You are an expert with PhD-level training in demography, behavioral science, psychology, marketing, and public health.
35
+ You have deep knowledge of how individuals and communities with varying personality traits and demographic profiles respond to different campaign messaging strategies and visual stimuli.
36
+ # SIMULATION PROMPT FOR COMMUNICATION STRATEGY SUGGESTION
37
+ SIMULATION_REC: |
38
+ You need to evaluate campaign messaging strategies intended to influence a specific individual or community, characterized by known demographic and personality traits.
39
+ There are 3 main types of communication strategies:
40
+ (1) Informational/Neutral
41
+ (2) Self-Efficacy
42
+ (3) Threatening/Fear-driven
43
+ # TASK 2: STRATEGY EVALUATION
44
+ INSTRUCTION_REC: |
45
+ Evaluate the likely impact of the above communication strategy on the specified individual/community.
46
+ (1) Negative impact (expected response score of 1–3 out of 9)
47
+ (2) No impact (expected response score of 4–6 out of 9)
48
+ (3) Positive impact (expected response score of 7–9 out of 9)
49
+
50
+ Please answer with exactly one of the following 3 labels: "positive", "negative", or "no impact".
51
+
52
+ # # TASK 2: STRATEGY SUGGESTION
53
+ # INSTRUCTION_REC_NO_IMPACT: |
54
+ # There are 3 main types of communication strategies:
55
+ # (1) Informational/Neutral
56
+ # (2) Self-Efficacy
57
+ # (3) Threatening/Fear-driven
58
+
59
+ # Based on your expertise, which strategy is most likely to have LITTLE IMPACT (i.e., an expected response score of 4–6 out of 9) on the target individual or community?
60
+ # Suggest only ONE and provide only the strategy name.
61
+ # INSTRUCTION_REC_POSITIVE: |
62
+ # There are 3 main types of communication strategies:
63
+ # (1) Informational/Neutral
64
+ # (2) Self-Efficacy
65
+ # (3) Threatening/Fear-driven
66
+
67
+ # Based on your expertise, which strategy is most likely to have a POSITIVE IMPACT (i.e., an expected response score of 7–9 out of 9) on the target individual or community?
68
+ # Suggest only ONE and provide only the strategy name.
69
+ # INSTRUCTION_REC_NEGATIVE: |
70
+ # There are 3 main types of communication strategies:
71
+ # (1) Informational/Neutral
72
+ # (2) Self-Efficacy
73
+ # (3) Threatening/Fear-driven
74
+
75
+ # Based on your expertise, which strategy is most likely to have a NEGATIVE IMPACT (i.e., an expected response score of 1–3 out of 9) on the target individual or community?
76
+ # Suggest only ONE and provide only the strategy name.
77
+
78
+
79
+
80
+ #########################################################
81
+ ### TASK 3: COMMUNICATION STRATEGY CLASSIFICATION ###
82
+ #########################################################
83
+ # SYSTEM PROMPT FOR COMMUNICATION STRATEGY CLASSIFICATION
84
+ SYSTEM_CLS: >
85
+ You are an expert with PhD qualifications in 5 areas: demography, behavioral science, psychology, marketing, and public health.
86
+ # SIMULATION PROMPT FOR COMMUNICATION STRATEGY CLASSIFICATION
87
+ SIMULATION_CLS: >
88
+ You are now being shown a public health campaign poster.
89
+ # TASK 3: STRATEGY CLASSIFICATION
90
+ INSTRUCTION_STRAT: |
91
+ There are <?> main types of communication strategies:
92
+ (1)
93
+ (2)
94
+ (3)
95
+
96
+ Based on your experience and expertise, what is the communication strategy of the poster? Choose only one and include only the strategy name.
97
+
98
+ JSON_CONVERSION: >
99
+ Extract the content in this answer to JSON with format: <Q1>: \"<Answer to Q1>\"
100
+ Ensure all questions are properly included (13 questions in total).
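A minimal sketch (not taken from app.py) of how these prompt blocks could be loaded and composed into a Task 1 system/user turn, assuming PyYAML and the key names defined above; the exact message layout used by the app may differ.

```python
# Minimal sketch: load configs/prompts.yaml and compose the Task 1 open-text turn.
# Assumes PyYAML is installed; key names come from the file above.
import yaml

with open("configs/prompts.yaml") as f:
    prompts = yaml.safe_load(f)

system_text = prompts["SYSTEM_SIM"]
user_text = prompts["SIMULATION_SIM"] + "\n" + prompts["INSTRUCTION_FREE"]

messages = [
    {"role": "system", "content": system_text},
    {"role": "user", "content": user_text},
]
print(messages[0]["content"][:80])  # sanity check: prints the start of SYSTEM_SIM
```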
configs/task1_demo.yaml ADDED
@@ -0,0 +1,27 @@
1
+ temperature: 0.
2
+ top_p: 1.0
3
+ stochastic: False # deterministic
4
+ seed: 99
5
+ infer_engine: "unsloth"
6
+ data_path: "data/survey_responses_screened.csv" # make sure to export HOME to project path
7
+ # export_path: "$HOME/src/evals/task1_ai_responses.csv"
8
+
9
+ #########################
10
+ ### Emulation Model ###
11
+ #########################
12
+ # model: "unsloth/Llama-3.2-11B-Vision-Instruct"
13
+ # model: "unsloth/Llama-3.2-11B-Vision-Instruct_task1_1_epochs_test_train_on_all"
14
+ # model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_train_on_all"
15
+ # model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_neutral"
16
+ # model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_efficacy"
17
+
18
+ # model: "unsloth/gemma-3-12b-it"
19
+ # model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral"
20
+ # model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_threatening_partialTraits"
21
+ model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits"
22
+ vision: true # default
23
+ trait: true # default
24
+ version: ""
25
+
26
+
27
+ model_summarize: "unsloth/gemma-3-12b-it"
configs/task1_demo_sph.yaml ADDED
@@ -0,0 +1,28 @@
1
+ temperature: 0.
2
+ top_p: 1.0
3
+ stochastic: False # deterministic
4
+ seed: 99
5
+ infer_engine: "unsloth"
6
+ data_path: "data/survey_responses_screened.csv" # make sure to export HOME to project path
7
+ # export_path: "$HOME/src/evals/task1_ai_responses.csv"
8
+
9
+ #########################
10
+ ### Emulation Model ###
11
+ #########################
12
+ # model: "unsloth/Llama-3.2-11B-Vision-Instruct"
13
+ # model: "unsloth/Llama-3.2-11B-Vision-Instruct_task1_1_epochs_test_train_on_all"
14
+ # model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_train_on_all"
15
+ # model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_neutral"
16
+ # model: "unsloth/gemma-3-4b-it_task1_1_epochs_test_efficacy"
17
+
18
+ # model: "unsloth/gemma-3-12b-it"
19
+ # model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral"
20
+ # model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_threatening_partialTraits"
21
+ model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits"
22
+ # model: "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits_sphTraits"
23
+ vision: true # default
24
+ trait: true # default
25
+ version: ""
26
+
27
+
28
+ model_summarize: "unsloth/gemma-3-12b-it"
data/survey_responses_screened.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9eb0e96347a8739c4d6b138a9395feeec591b8dd64fd0f6a74b857b49bb47b2c
3
+ size 18465749
push.sh ADDED
@@ -0,0 +1,22 @@
1
+ git init
2
+ git lfs install
3
+
4
+ git add app.py configs data requirements.txt unsloth utils.py
5
+ git commit -m "Initial commit"
6
+
7
+ git branch -M main
8
+
9
+ git lfs migrate import --include-ref=refs/heads/main --above=10MB -y
10
+
11
+ git remote add huggingface https://huggingface.co/spaces/anh-nn01/ai_empowered_community_simulation_beta
12
+
13
+ git push -u huggingface main --
14
+
15
+ # Notes:
16
+ # 1. use module load for lfs
17
+ # 2. use only launch(), not launch(share=True, max_threads=1,)
18
+ # 3. export full requirements.txt using pip freeze > requirements.txt
19
+ # => comment out `ipython` and `ollama` dependencies
20
+ # 4. Manually upload the LoRA weights to the HF repo due to potential file corruption
21
+ # 5. Manual upload of /app/assets/umd_logo.png
22
+ # 6. Perhaps manually uploading everything is more stable for now :))
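Notes 4-6 fall back to manual uploads. A hedged sketch of that route via the huggingface_hub API instead of git-lfs; the ignore patterns are an assumption, and the repo id is the one in the remote URL above.

```python
# Hedged sketch of the "manual upload" route mentioned in notes 4-6, using the
# huggingface_hub client instead of git-lfs. Run `huggingface-cli login`
# (or pass token=...) before calling this.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path=".",  # local working directory with app.py, configs/, unsloth/, ...
    repo_id="anh-nn01/ai_empowered_community_simulation_beta",
    repo_type="space",
    ignore_patterns=["app_logs/*", "*.pyc"],  # assumption: logs/bytecode need not be pushed
)
```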
requirements.txt ADDED
@@ -0,0 +1,167 @@
1
+ accelerate==1.6.0
2
+ aiofiles==24.1.0
3
+ aiohappyeyeballs==2.6.1
4
+ aiohttp==3.11.16
5
+ aiosignal==1.3.2
6
+ annotated-types==0.7.0
7
+ anyio==4.9.0
8
+ asttokens==3.0.1
9
+ attrs==25.3.0
10
+ bitsandbytes==0.45.5
11
+ Brotli==1.1.0
12
+ certifi==2025.1.31
13
+ charset-normalizer==3.4.1
14
+ click==8.1.8
15
+ colored==2.3.0
16
+ contourpy==1.3.2
17
+ cut-cross-entropy==25.1.1
18
+ cycler==0.12.1
19
+ Cython==3.1.2
20
+ dataclasses-json==0.6.7
21
+ datasets==3.5.0
22
+ decorator==5.2.1
23
+ diffusers @ git+https://github.com/huggingface/diffusers.git@ee40088fe5437f8ed65ec96a22250149e4f334cc
24
+ dill==0.3.8
25
+ docker-pycreds==0.4.0
26
+ docstring_parser==0.16
27
+ executing==2.2.1
28
+ fastapi==0.119.0
29
+ ffmpeg==1.4
30
+ ffmpy==0.6.3
31
+ filelock==3.18.0
32
+ fonttools==4.58.0
33
+ frozenlist==1.5.0
34
+ fsspec==2024.12.0
35
+ gitdb==4.0.12
36
+ GitPython==3.1.44
37
+ gradio==5.49.1
38
+ gradio_client==1.13.3
39
+ greenlet==3.2.0
40
+ groovy==0.1.2
41
+ h11==0.14.0
42
+ hf-xet==1.1.10
43
+ hf_transfer==0.1.9
44
+ httpcore==1.0.8
45
+ httpx==0.27.2
46
+ httpx-sse==0.4.0
47
+ huggingface-hub==0.35.3
48
+ idna==3.10
49
+ importlib_metadata==8.6.1
50
+ # ipython==9.8.0
51
+ # ipython_pygments_lexers==1.1.1
52
+ jedi==0.19.2
53
+ Jinja2==3.1.6
54
+ jsonpatch==1.33
55
+ jsonpointer==3.0.0
56
+ kiwisolver==1.4.8
57
+ langchain==0.3.23
58
+ langchain-community==0.3.21
59
+ langchain-core==0.3.52
60
+ langchain-ollama==0.2.1
61
+ langchain-text-splitters==0.3.8
62
+ langsmith==0.3.31
63
+ markdown-it-py==3.0.0
64
+ MarkupSafe==3.0.2
65
+ marshmallow==3.26.1
66
+ matplotlib==3.10.3
67
+ matplotlib-inline==0.2.1
68
+ mdurl==0.1.2
69
+ mpmath==1.3.0
70
+ multidict==6.4.3
71
+ multiprocess==0.70.16
72
+ mypy-extensions==1.0.0
73
+ networkx==3.4.2
74
+ numpy==2.2.4
75
+ nvidia-cublas-cu12==12.4.5.8
76
+ nvidia-cuda-cupti-cu12==12.4.127
77
+ nvidia-cuda-nvrtc-cu12==12.4.127
78
+ nvidia-cuda-runtime-cu12==12.4.127
79
+ nvidia-cudnn-cu12==9.1.0.70
80
+ nvidia-cufft-cu12==11.2.1.3
81
+ nvidia-curand-cu12==10.3.5.147
82
+ nvidia-cusolver-cu12==11.6.1.9
83
+ nvidia-cusparse-cu12==12.3.1.170
84
+ nvidia-cusparselt-cu12==0.6.2
85
+ nvidia-nccl-cu12==2.21.5
86
+ nvidia-nvjitlink-cu12==12.4.127
87
+ nvidia-nvtx-cu12==12.4.127
88
+ # ollama==0.4.2
89
+ orjson==3.10.16
90
+ packaging==24.2
91
+ pandas==2.2.3
92
+ parso==0.8.5
93
+ peft==0.15.2
94
+ pexpect==4.9.0
95
+ pillow==11.2.1
96
+ platformdirs==4.3.7
97
+ prompt_toolkit==3.0.52
98
+ propcache==0.3.1
99
+ protobuf==3.20.3
100
+ psutil==7.0.0
101
+ ptyprocess==0.7.0
102
+ pure_eval==0.2.3
103
+ pyarrow==19.0.1
104
+ pydantic==2.11.3
105
+ pydantic-settings==2.8.1
106
+ pydantic_core==2.33.1
107
+ pydub==0.25.1
108
+ Pygments==2.19.1
109
+ pyparsing==3.2.3
110
+ python-dateutil==2.9.0.post0
111
+ python-dotenv==1.1.0
112
+ python-multipart==0.0.20
113
+ pytz==2025.2
114
+ PyYAML==6.0.2
115
+ regex==2024.11.6
116
+ requests==2.32.3
117
+ requests-toolbelt==1.0.0
118
+ rich==14.0.0
119
+ ruff==0.14.0
120
+ safehttpx==0.1.6
121
+ safetensors==0.5.3
122
+ seaborn==0.13.2
123
+ semantic-version==2.10.0
124
+ sentencepiece==0.2.0
125
+ sentry-sdk==2.27.0
126
+ setproctitle==1.3.5
127
+ setuptools==79.0.0
128
+ shellingham==1.5.4
129
+ shtab==1.7.2
130
+ six==1.17.0
131
+ smmap==5.0.2
132
+ sniffio==1.3.1
133
+ SQLAlchemy==2.0.40
134
+ stack-data==0.6.3
135
+ starlette==0.48.0
136
+ sympy==1.13.1
137
+ tenacity==9.1.2
138
+ termcolor==3.0.1
139
+ tokenizers==0.21.4
140
+ tomlkit==0.13.3
141
+ torch==2.6.0
142
+ torchvision==0.21.0
143
+ tqdm==4.67.1
144
+ traitlets==5.14.3
145
+ transformers==4.50.0
146
+ triton==3.2.0
147
+ trl==0.15.2
148
+ typeguard==4.4.2
149
+ typer==0.19.2
150
+ typing-inspect==0.9.0
151
+ typing-inspection==0.4.0
152
+ typing_extensions==4.13.2
153
+ tyro==0.9.19
154
+ tzdata==2025.2
155
+ unsloth==2025.3.19
156
+ unsloth_zoo==2025.3.17
157
+ urllib3==2.4.0
158
+ uvicorn==0.37.0
159
+ wandb==0.19.10
160
+ wcwidth==0.2.14
161
+ websockets==15.0.1
162
+ wheel==0.45.1
163
+ xformers==0.0.29.post3
164
+ xxhash==3.5.0
165
+ yarl==1.19.0
166
+ zipp==3.21.0
167
+ zstandard==0.23.0
requirements_concise.txt ADDED
@@ -0,0 +1,18 @@
1
+ numpy==2.2.4
2
+ pandas==2.2.3
3
+ pillow==11.2.1
4
+ langchain==0.3.23
5
+ langchain-core==0.3.52
6
+ langchain-community==0.3.21
7
+ langchain-ollama==0.2.1
8
+ # ollama==0.4.2
9
+ tqdm
10
+ torch
11
+ unsloth==2025.3.19
12
+ termcolor
13
+ python-dotenv
14
+ transformers==4.50.0
15
+ wandb
16
+
17
+ # Image Generation
18
+ # git+https://github.com/huggingface/diffusers.git
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/README.md ADDED
@@ -0,0 +1,202 @@
1
+ ---
2
+ base_model: unsloth/gemma-3-12b-it-unsloth-bnb-4bit
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.15.2
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/adapter_config.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
5
+ "bias": "none",
6
+ "corda_config": null,
7
+ "eva_config": null,
8
+ "exclude_modules": null,
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layer_replication": null,
13
+ "layers_pattern": null,
14
+ "layers_to_transform": null,
15
+ "loftq_config": {},
16
+ "lora_alpha": 8,
17
+ "lora_bias": false,
18
+ "lora_dropout": 0,
19
+ "megatron_config": null,
20
+ "megatron_core": "megatron.core",
21
+ "modules_to_save": null,
22
+ "peft_type": "LORA",
23
+ "r": 8,
24
+ "rank_pattern": {},
25
+ "revision": null,
26
+ "target_modules": "(?:.*?(?:language|text).*?(?:self_attn|attention|attn|mlp|feed_forward|ffn|dense).*?(?:k_proj|v_proj|q_proj|out_proj|fc1|fc2|o_proj|gate_proj|up_proj|down_proj).*?)|(?:\\bmodel\\.layers\\.[\\d]{1,}\\.(?:self_attn|attention|attn|mlp|feed_forward|ffn|dense)\\.(?:(?:k_proj|v_proj|q_proj|out_proj|fc1|fc2|o_proj|gate_proj|up_proj|down_proj)))",
27
+ "task_type": "CAUSAL_LM",
28
+ "trainable_token_indices": null,
29
+ "use_dora": false,
30
+ "use_rslora": false
31
+ }
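For illustration only, a sketch of loading this r=8, alpha=8 LoRA adapter with transformers + PEFT rather than the Unsloth loader the app itself uses; it assumes bitsandbytes is available for the 4-bit base model named in base_model_name_or_path above.

```python
# Illustrative sketch only (the app loads through Unsloth): attach this LoRA
# adapter to the 4-bit Gemma-3 base with PEFT. Requires bitsandbytes and a GPU
# large enough for the 12B base.
from transformers import Gemma3ForConditionalGeneration, AutoProcessor
from peft import PeftModel

base_id = "unsloth/gemma-3-12b-it-unsloth-bnb-4bit"
adapter_dir = "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits"

base = Gemma3ForConditionalGeneration.from_pretrained(base_id, device_map="auto")
model = PeftModel.from_pretrained(base, adapter_dir)    # applies the r=8, alpha=8 adapter
processor = AutoProcessor.from_pretrained(adapter_dir)  # tokenizer + Gemma3 image processor
model.eval()
```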
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7108dca92843322a12e503ab99cbd70a5f676fa25c54e6a11d88473f65143ee3
3
+ size 131040264
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/added_tokens.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "<image_soft_token>": 262144
3
+ }
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/chat_template.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{ '<start_of_turn>model\n' }}\n{%- endif -%}\n"
3
+ }
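A hedged sketch of rendering this Gemma-3 chat template from the local adapter folder; tokenize=False returns the prompt string so the <start_of_turn>/<start_of_image> markup is visible. The message content is made up for illustration.

```python
# Sketch: render the chat template above into a prompt string. The example
# messages are placeholders, not the app's real prompts.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits"  # local folder above
)
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a person with unique traits."}]},
    {"role": "user", "content": [{"type": "image"},
                                 {"type": "text", "text": "How does this poster make you feel?"}]},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(prompt)  # shows <start_of_turn>user ... <start_of_image> ... <start_of_turn>model
```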
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/preprocessor_config.json ADDED
@@ -0,0 +1,29 @@
1
+ {
2
+ "do_convert_rgb": null,
3
+ "do_normalize": true,
4
+ "do_pan_and_scan": null,
5
+ "do_rescale": true,
6
+ "do_resize": true,
7
+ "image_mean": [
8
+ 0.5,
9
+ 0.5,
10
+ 0.5
11
+ ],
12
+ "image_processor_type": "Gemma3ImageProcessor",
13
+ "image_seq_length": 256,
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "pan_and_scan_max_num_crops": null,
20
+ "pan_and_scan_min_crop_size": null,
21
+ "pan_and_scan_min_ratio_to_activate": null,
22
+ "processor_class": "Gemma3Processor",
23
+ "resample": 2,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "height": 896,
27
+ "width": 896
28
+ }
29
+ }
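A quick check of what this preprocessor does to a pixel value: rescale_factor is 1/255 and mean = std = 0.5, so 0-255 inputs map onto roughly [-1, 1] before the 896x896 image is encoded into image_seq_length = 256 soft tokens.

```python
# Quick arithmetic check of the config above:
# rescale_factor == 1/255, and (x * rescale - mean) / std with mean = std = 0.5
# maps pixel values 0..255 onto roughly -1..1.
rescale = 0.00392156862745098  # == 1 / 255
for pixel in (0, 128, 255):
    normalized = (pixel * rescale - 0.5) / 0.5
    print(pixel, round(normalized, 3))  # -> -1.0, 0.004, 1.0
```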
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/processor_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "image_seq_length": 256,
3
+ "processor_class": "Gemma3Processor"
4
+ }
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/special_tokens_map.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "boi_token": "<start_of_image>",
3
+ "bos_token": {
4
+ "content": "<bos>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ "eoi_token": "<end_of_image>",
11
+ "eos_token": {
12
+ "content": "<end_of_turn>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "image_token": "<image_soft_token>",
19
+ "pad_token": {
20
+ "content": "<pad>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false
25
+ },
26
+ "unk_token": {
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7666402c0617d170e6b0a985b3130c3fb0795393aa0970600994a5d9aae12351
3
+ size 33384822
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
3
+ size 4689074
unsloth/gemma-3-12b-it_task1_1_epochs_test_neutral_partialTraits/tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
10
+ from torch import Tensor
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import functional as F
14
+ from peft.tuners.lora.aqlm import (torch)
15
+
16
+
17
+ torch_addmm = torch.addmm
18
+ torch_add = torch.add
19
+ # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
20
+ def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
21
+ xA = dropout(x) @ lora_A.weight.t()
22
+ # output = result + scaling * xA @ lora_B.weight.t()
23
+ shape = result.shape
24
+ output = torch_addmm(
25
+ result.view(-1, shape[-1]),
26
+ xA.view(-1, xA.shape[-1]),
27
+ lora_B.weight.t(),
28
+ alpha = scaling,
29
+ beta = 1,
30
+ ).view(shape)
31
+
32
+ bias = lora_B.bias
33
+ if bias is not None:
34
+ output = torch_add(
35
+ output,
36
+ bias,
37
+ alpha = scaling,
38
+ )
39
+ return output
40
+ pass
41
+
42
+ def unsloth_forward(self, x: torch.Tensor):
43
+ # note: logic differs from default Linear because merging is not supported
44
+ result = self.base_layer(x)
45
+
46
+ if self.disable_adapters:
47
+ return result
48
+
49
+ for active_adapter in self.active_adapters:
50
+ if active_adapter not in self.lora_A.keys():
51
+ continue
52
+ lora_A = self.lora_A[active_adapter]
53
+ lora_B = self.lora_B[active_adapter]
54
+ dropout = self.lora_dropout[active_adapter]
55
+ scaling = self.scaling[active_adapter]
56
+
57
+ requires_conversion = not torch.is_autocast_enabled()
58
+ if requires_conversion:
59
+ expected_dtype = result.dtype
60
+ x = self._cast_input_dtype(x, lora_A.weight.dtype)
61
+
62
+ output = lora_B(lora_A(dropout(x)))
63
+ if requires_conversion:
64
+ output = output.to(expected_dtype)
65
+ output = output * scaling
66
+ result += output
67
+ return result
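A standalone sanity check (not part of the generated cache) that the fused torch.addmm call in lora_forward above matches the naive result + scaling * (x @ A.T) @ B.T formulation when lora_B has no bias; shapes are small made-up examples.

```python
# Standalone check: torch.addmm(result, xA, B.T, beta=1, alpha=scaling) equals
# result + scaling * (x @ A.T) @ B.T, i.e. the fused path in lora_forward above.
import torch

torch.manual_seed(0)
x = torch.randn(4, 16)        # input           (batch, in_features)
A = torch.randn(8, 16)        # lora_A.weight   (r, in_features)
B = torch.randn(32, 8)        # lora_B.weight   (out_features, r)
result = torch.randn(4, 32)   # base layer output
scaling = 0.5

naive = result + scaling * (x @ A.t()) @ B.t()
fused = torch.addmm(result, x @ A.t(), B.t(), beta=1, alpha=scaling)
print(torch.allclose(naive, fused, atol=1e-6))  # True
```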
unsloth_compiled_cache/AwqLoraLinear_peft_forward.py ADDED
@@ -0,0 +1,66 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
10
+ from torch import Tensor
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import functional as F
14
+ from peft.tuners.lora.awq import (torch)
15
+
16
+
17
+ torch_addmm = torch.addmm
18
+ torch_add = torch.add
19
+ # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
20
+ def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
21
+ xA = dropout(x) @ lora_A.weight.t()
22
+ # output = result + scaling * xA @ lora_B.weight.t()
23
+ shape = result.shape
24
+ output = torch_addmm(
25
+ result.view(-1, shape[-1]),
26
+ xA.view(-1, xA.shape[-1]),
27
+ lora_B.weight.t(),
28
+ alpha = scaling,
29
+ beta = 1,
30
+ ).view(shape)
31
+
32
+ bias = lora_B.bias
33
+ if bias is not None:
34
+ output = torch_add(
35
+ output,
36
+ bias,
37
+ alpha = scaling,
38
+ )
39
+ return output
40
+ pass
41
+
42
+ def unsloth_forward(self, x: torch.Tensor):
43
+ result = self.quant_linear_module(x)
44
+
45
+ if self.disable_adapters:
46
+ return result
47
+
48
+ for active_adapter in self.active_adapters:
49
+ if active_adapter not in self.lora_A.keys():
50
+ continue
51
+ lora_A = self.lora_A[active_adapter]
52
+ lora_B = self.lora_B[active_adapter]
53
+ dropout = self.lora_dropout[active_adapter]
54
+ scaling = self.scaling[active_adapter]
55
+
56
+ requires_conversion = not torch.is_autocast_enabled()
57
+ if requires_conversion:
58
+ expected_dtype = result.dtype
59
+ x = self._cast_input_dtype(x, lora_A.weight.dtype)
60
+
61
+ output = lora_B(lora_A(dropout(x)))
62
+ if requires_conversion:
63
+ output = output.to(expected_dtype)
64
+ output = output * scaling
65
+ result = result + output
66
+ return result
unsloth_compiled_cache/BatchNorm1d.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth Zoo - Utilities for Unsloth
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ import os
26
+ import importlib.util
27
+ if importlib.util.find_spec("unsloth_studio") is None:
28
+ UNSLOTH_STUDIO_ENABLED = False
29
+ else:
30
+ UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
31
+ pass
32
+ from typing import List, Dict, Tuple, Optional, Any, Callable
33
+ import math
34
+
35
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
36
+ from torch import Tensor
37
+ import torch
38
+ import torch.nn as nn
39
+ from torch.nn import functional as F
40
+ from transformers.models.gemma3.modeling_gemma3 import (nn)
41
+
42
+ def forward(self, input: Tensor) -> Tensor:
43
+ self._check_input_dim(input)
44
+
45
+ # exponential_average_factor is set to self.momentum
46
+ # (when it is available) only so that it gets updated
47
+ # in ONNX graph when this node is exported to ONNX.
48
+ if self.momentum is None:
49
+ exponential_average_factor = 0.0
50
+ else:
51
+ exponential_average_factor = self.momentum
52
+
53
+ if self.training and self.track_running_stats:
54
+ # TODO: if statement only here to tell the jit to skip emitting this when it is None
55
+ if self.num_batches_tracked is not None: # type: ignore[has-type]
56
+ self.num_batches_tracked.add_(1) # type: ignore[has-type]
57
+ if self.momentum is None: # use cumulative moving average
58
+ exponential_average_factor = 1.0 / float(self.num_batches_tracked)
59
+ else: # use exponential moving average
60
+ exponential_average_factor = self.momentum
61
+
62
+ r"""
63
+ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
64
+ Mini-batch stats are used in training mode, and in eval mode when buffers are None.
65
+ """
66
+ if self.training:
67
+ bn_training = True
68
+ else:
69
+ bn_training = (self.running_mean is None) and (self.running_var is None)
70
+
71
+ r"""
72
+ Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
73
+ passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
74
+ used for normalization (i.e. in eval mode when buffers are not None).
75
+ """
76
+ return F.batch_norm(
77
+ input,
78
+ # If buffers are not to be tracked, ensure that they won't be updated
79
+ self.running_mean
80
+ if not self.training or self.track_running_stats
81
+ else None,
82
+ self.running_var if not self.training or self.track_running_stats else None,
83
+ self.weight,
84
+ self.bias,
85
+ bn_training,
86
+ exponential_average_factor,
87
+ self.eps,
88
+ ).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/BatchNorm2d.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth Zoo - Utilities for Unsloth
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ import os
26
+ import importlib.util
27
+ if importlib.util.find_spec("unsloth_studio") is None:
28
+ UNSLOTH_STUDIO_ENABLED = False
29
+ else:
30
+ UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
31
+ pass
32
+ from typing import List, Dict, Tuple, Optional, Any, Callable
33
+ import math
34
+
35
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
36
+ from torch import Tensor
37
+ import torch
38
+ import torch.nn as nn
39
+ from torch.nn import functional as F
40
+ from transformers.models.gemma3.modeling_gemma3 import (nn)
41
+
42
+ def forward(self, input: Tensor) -> Tensor:
43
+ self._check_input_dim(input)
44
+
45
+ # exponential_average_factor is set to self.momentum
46
+ # (when it is available) only so that it gets updated
47
+ # in ONNX graph when this node is exported to ONNX.
48
+ if self.momentum is None:
49
+ exponential_average_factor = 0.0
50
+ else:
51
+ exponential_average_factor = self.momentum
52
+
53
+ if self.training and self.track_running_stats:
54
+ # TODO: if statement only here to tell the jit to skip emitting this when it is None
55
+ if self.num_batches_tracked is not None: # type: ignore[has-type]
56
+ self.num_batches_tracked.add_(1) # type: ignore[has-type]
57
+ if self.momentum is None: # use cumulative moving average
58
+ exponential_average_factor = 1.0 / float(self.num_batches_tracked)
59
+ else: # use exponential moving average
60
+ exponential_average_factor = self.momentum
61
+
62
+ r"""
63
+ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
64
+ Mini-batch stats are used in training mode, and in eval mode when buffers are None.
65
+ """
66
+ if self.training:
67
+ bn_training = True
68
+ else:
69
+ bn_training = (self.running_mean is None) and (self.running_var is None)
70
+
71
+ r"""
72
+ Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
73
+ passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
74
+ used for normalization (i.e. in eval mode when buffers are not None).
75
+ """
76
+ return F.batch_norm(
77
+ input,
78
+ # If buffers are not to be tracked, ensure that they won't be updated
79
+ self.running_mean
80
+ if not self.training or self.track_running_stats
81
+ else None,
82
+ self.running_var if not self.training or self.track_running_stats else None,
83
+ self.weight,
84
+ self.bias,
85
+ bn_training,
86
+ exponential_average_factor,
87
+ self.eps,
88
+ ).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/BatchNorm3d.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth Zoo - Utilities for Unsloth
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ import os
26
+ import importlib.util
27
+ if importlib.util.find_spec("unsloth_studio") is None:
28
+ UNSLOTH_STUDIO_ENABLED = False
29
+ else:
30
+ UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
31
+ pass
32
+ from typing import List, Dict, Tuple, Optional, Any, Callable
33
+ import math
34
+
35
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
36
+ from torch import Tensor
37
+ import torch
38
+ import torch.nn as nn
39
+ from torch.nn import functional as F
40
+ from transformers.models.gemma3.modeling_gemma3 import (nn)
41
+
42
+ def forward(self, input: Tensor) -> Tensor:
43
+ self._check_input_dim(input)
44
+
45
+ # exponential_average_factor is set to self.momentum
46
+ # (when it is available) only so that it gets updated
47
+ # in ONNX graph when this node is exported to ONNX.
48
+ if self.momentum is None:
49
+ exponential_average_factor = 0.0
50
+ else:
51
+ exponential_average_factor = self.momentum
52
+
53
+ if self.training and self.track_running_stats:
54
+ # TODO: if statement only here to tell the jit to skip emitting this when it is None
55
+ if self.num_batches_tracked is not None: # type: ignore[has-type]
56
+ self.num_batches_tracked.add_(1) # type: ignore[has-type]
57
+ if self.momentum is None: # use cumulative moving average
58
+ exponential_average_factor = 1.0 / float(self.num_batches_tracked)
59
+ else: # use exponential moving average
60
+ exponential_average_factor = self.momentum
61
+
62
+ r"""
63
+ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
64
+ Mini-batch stats are used in training mode, and in eval mode when buffers are None.
65
+ """
66
+ if self.training:
67
+ bn_training = True
68
+ else:
69
+ bn_training = (self.running_mean is None) and (self.running_var is None)
70
+
71
+ r"""
72
+ Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
73
+ passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
74
+ used for normalization (i.e. in eval mode when buffers are not None).
75
+ """
76
+ return F.batch_norm(
77
+ input,
78
+ # If buffers are not to be tracked, ensure that they won't be updated
79
+ self.running_mean
80
+ if not self.training or self.track_running_stats
81
+ else None,
82
+ self.running_var if not self.training or self.track_running_stats else None,
83
+ self.weight,
84
+ self.bias,
85
+ bn_training,
86
+ exponential_average_factor,
87
+ self.eps,
88
+ ).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/Conv1d.py ADDED
@@ -0,0 +1,43 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth Zoo - Utilities for Unsloth
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ import os
26
+ import importlib.util
27
+ if importlib.util.find_spec("unsloth_studio") is None:
28
+ UNSLOTH_STUDIO_ENABLED = False
29
+ else:
30
+ UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
31
+ pass
32
+ from typing import List, Dict, Tuple, Optional, Any, Callable
33
+ import math
34
+
35
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
36
+ from torch import Tensor
37
+ import torch
38
+ import torch.nn as nn
39
+ from torch.nn import functional as F
40
+
41
+
42
+ def forward(self, input: Tensor) -> Tensor:
43
+ return self._conv_forward(input, self.weight, self.bias).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/Conv2d.py ADDED
@@ -0,0 +1,43 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth Zoo - Utilities for Unsloth
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ import os
26
+ import importlib.util
27
+ if importlib.util.find_spec("unsloth_studio") is None:
28
+ UNSLOTH_STUDIO_ENABLED = False
29
+ else:
30
+ UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
31
+ pass
32
+ from typing import List, Dict, Tuple, Optional, Any, Callable
33
+ import math
34
+
35
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
36
+ from torch import Tensor
37
+ import torch
38
+ import torch.nn as nn
39
+ from torch.nn import functional as F
40
+
41
+
42
+ def forward(self, input: Tensor) -> Tensor:
43
+ return self._conv_forward(input, self.weight, self.bias).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/Conv3d.py ADDED
@@ -0,0 +1,43 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+
9
+ # Unsloth Zoo - Utilities for Unsloth
10
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
11
+ #
12
+ # This program is free software: you can redistribute it and/or modify
13
+ # it under the terms of the GNU Lesser General Public License as published by
14
+ # the Free Software Foundation, either version 3 of the License, or
15
+ # (at your option) any later version.
16
+ #
17
+ # This program is distributed in the hope that it will be useful,
18
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
19
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20
+ # GNU General Public License for more details.
21
+ #
22
+ # You should have received a copy of the GNU Lesser General Public License
23
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
24
+
25
+ import os
26
+ import importlib.util
27
+ if importlib.util.find_spec("unsloth_studio") is None:
28
+ UNSLOTH_STUDIO_ENABLED = False
29
+ else:
30
+ UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
31
+ pass
32
+ from typing import List, Dict, Tuple, Optional, Any, Callable
33
+ import math
34
+
35
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
36
+ from torch import Tensor
37
+ import torch
38
+ import torch.nn as nn
39
+ from torch.nn import functional as F
40
+
41
+
42
+ def forward(self, input: Tensor) -> Tensor:
43
+ return self._conv_forward(input, self.weight, self.bias).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/ConvTranspose1d.py ADDED
@@ -0,0 +1,70 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ import os
+ import importlib.util
+ if importlib.util.find_spec("unsloth_studio") is None:
+     UNSLOTH_STUDIO_ENABLED = False
+ else:
+     UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
+ pass
+ from typing import List, Dict, Tuple, Optional, Any, Callable
+ import math
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from transformers.models.gemma3.modeling_gemma3 import (List, Optional, Tuple, nn)
+
+ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
+     if self.padding_mode != "zeros":
+         raise ValueError(
+             "Only `zeros` padding mode is supported for ConvTranspose1d"
+         )
+
+     assert isinstance(self.padding, tuple)
+     # One cannot replace List by Tuple or Sequence in "_output_padding" because
+     # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+     num_spatial_dims = 1
+     output_padding = self._output_padding(
+         input,
+         output_size,
+         self.stride, # type: ignore[arg-type]
+         self.padding, # type: ignore[arg-type]
+         self.kernel_size, # type: ignore[arg-type]
+         num_spatial_dims,
+         self.dilation, # type: ignore[arg-type]
+     )
+     return F.conv_transpose1d(
+         input,
+         self.weight,
+         self.bias,
+         self.stride,
+         self.padding,
+         output_padding,
+         self.groups,
+         self.dilation,
+     ).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/ConvTranspose2d.py ADDED
@@ -0,0 +1,71 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ import os
+ import importlib.util
+ if importlib.util.find_spec("unsloth_studio") is None:
+     UNSLOTH_STUDIO_ENABLED = False
+ else:
+     UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
+ pass
+ from typing import List, Dict, Tuple, Optional, Any, Callable
+ import math
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from transformers.models.gemma3.modeling_gemma3 import (List, Optional, Tuple, nn)
+
+ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
+     if self.padding_mode != "zeros":
+         raise ValueError(
+             "Only `zeros` padding mode is supported for ConvTranspose2d"
+         )
+
+     assert isinstance(self.padding, tuple)
+     # One cannot replace List by Tuple or Sequence in "_output_padding" because
+     # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+     num_spatial_dims = 2
+     output_padding = self._output_padding(
+         input,
+         output_size,
+         self.stride, # type: ignore[arg-type]
+         self.padding, # type: ignore[arg-type]
+         self.kernel_size, # type: ignore[arg-type]
+         num_spatial_dims,
+         self.dilation, # type: ignore[arg-type]
+     )
+
+     return F.conv_transpose2d(
+         input,
+         self.weight,
+         self.bias,
+         self.stride,
+         self.padding,
+         output_padding,
+         self.groups,
+         self.dilation,
+     ).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/ConvTranspose3d.py ADDED
@@ -0,0 +1,71 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ import os
+ import importlib.util
+ if importlib.util.find_spec("unsloth_studio") is None:
+     UNSLOTH_STUDIO_ENABLED = False
+ else:
+     UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
+ pass
+ from typing import List, Dict, Tuple, Optional, Any, Callable
+ import math
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from transformers.models.gemma3.modeling_gemma3 import (List, Optional, Tuple, nn)
+
+ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
+     if self.padding_mode != "zeros":
+         raise ValueError(
+             "Only `zeros` padding mode is supported for ConvTranspose3d"
+         )
+
+     assert isinstance(self.padding, tuple)
+     # One cannot replace List by Tuple or Sequence in "_output_padding" because
+     # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+     num_spatial_dims = 3
+     output_padding = self._output_padding(
+         input,
+         output_size,
+         self.stride, # type: ignore[arg-type]
+         self.padding, # type: ignore[arg-type]
+         self.kernel_size, # type: ignore[arg-type]
+         num_spatial_dims,
+         self.dilation, # type: ignore[arg-type]
+     )
+
+     return F.conv_transpose3d(
+         input,
+         self.weight,
+         self.bias,
+         self.stride,
+         self.padding,
+         output_padding,
+         self.groups,
+         self.dilation,
+     ).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py ADDED
@@ -0,0 +1,73 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from peft.tuners.lora.gptq import (torch)
+
+
+ torch_addmm = torch.addmm
+ torch_add = torch.add
+ # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
+     xA = dropout(x) @ lora_A.weight.t()
+     # output = result + scaling * xA @ lora_B.weight.t()
+     shape = result.shape
+     output = torch_addmm(
+         result.view(-1, shape[-1]),
+         xA.view(-1, xA.shape[-1]),
+         lora_B.weight.t(),
+         alpha = scaling,
+         beta = 1,
+     ).view(shape)
+
+     bias = lora_B.bias
+     if bias is not None:
+         output = torch_add(
+             output,
+             bias,
+             alpha = scaling,
+         )
+     return output
+ pass
+
+ def unsloth_forward(self, x: torch.Tensor):
+     # note: logic differs from default Linear because merging is not supported
+     result = self.quant_linear_module(x)
+
+     if self.disable_adapters:
+         return result
+
+     lora_A_keys = self.lora_A.keys()
+     for active_adapter in self.active_adapters:
+         if active_adapter not in lora_A_keys:
+             continue
+
+         lora_A = self.lora_A[active_adapter]
+         lora_B = self.lora_B[active_adapter]
+         dropout = self.lora_dropout[active_adapter]
+         scaling = self.scaling[active_adapter]
+
+         requires_conversion = not torch.is_autocast_enabled()
+         if requires_conversion:
+             expected_dtype = result.dtype
+             x = self._cast_input_dtype(x, lora_A.weight.dtype)
+
+         output = lora_B(lora_A(dropout(x)))
+
+         if requires_conversion:
+             output = output.to(expected_dtype)
+
+         if scaling != 1: # skip scaling == 1 no-op
+             output = output * scaling
+
+         result += output
+     return result
unsloth_compiled_cache/GroupNorm.py ADDED
@@ -0,0 +1,43 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ import os
+ import importlib.util
+ if importlib.util.find_spec("unsloth_studio") is None:
+     UNSLOTH_STUDIO_ENABLED = False
+ else:
+     UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
+ pass
+ from typing import List, Dict, Tuple, Optional, Any, Callable
+ import math
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ def forward(self, input: Tensor) -> Tensor:
+     return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/LayerNorm.py ADDED
@@ -0,0 +1,45 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ import os
+ import importlib.util
+ if importlib.util.find_spec("unsloth_studio") is None:
+     UNSLOTH_STUDIO_ENABLED = False
+ else:
+     UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
+ pass
+ from typing import List, Dict, Tuple, Optional, Any, Callable
+ import math
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ def forward(self, input: Tensor) -> Tensor:
+     return F.layer_norm(
+         input, self.normalized_shape, self.weight, self.bias, self.eps
+     ).to(input.dtype).to(input.dtype)
unsloth_compiled_cache/Linear4bit_peft_forward.py ADDED
@@ -0,0 +1,97 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from peft.tuners.lora.bnb import (torch)
+
+
+ torch_addmm = torch.addmm
+ torch_add = torch.add
+ # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
+     xA = dropout(x) @ lora_A.weight.t()
+     # output = result + scaling * xA @ lora_B.weight.t()
+     shape = result.shape
+     output = torch_addmm(
+         result.view(-1, shape[-1]),
+         xA.view(-1, xA.shape[-1]),
+         lora_B.weight.t(),
+         alpha = scaling,
+         beta = 1,
+     ).view(shape)
+
+     bias = lora_B.bias
+     if bias is not None:
+         output = torch_add(
+             output,
+             bias,
+             alpha = scaling,
+         )
+     return output
+ pass
+
+ def unsloth_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+
+     adapter_names = kwargs.pop("adapter_names", None)
+
+     if self.disable_adapters:
+         if self.merged:
+             self.unmerge()
+         result = self.base_layer(x, *args, **kwargs)
+     elif adapter_names is not None:
+         result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
+     elif self.merged:
+         result = self.base_layer(x, *args, **kwargs)
+     else:
+         result = self.base_layer(x, *args, **kwargs)
+         # As per Tim Dettmers, for 4bit, we need to defensively clone here.
+         # The reason is that in some cases, an error can occur that backprop
+         # does not work on a manipulated view. This issue may be solved with
+         # newer PyTorch versions but this would need extensive testing to be
+         # sure.
+
+
+         for active_adapter in self.active_adapters:
+             if active_adapter not in self.lora_A.keys():
+                 continue
+             lora_A = self.lora_A[active_adapter]
+             lora_B = self.lora_B[active_adapter]
+             dropout = self.lora_dropout[active_adapter]
+             scaling = self.scaling[active_adapter]
+
+             requires_conversion = not torch.is_autocast_enabled()
+             if requires_conversion:
+                 expected_dtype = result.dtype
+                 x = self._cast_input_dtype(x, lora_A.weight.dtype)
+
+             if not self.use_dora[active_adapter]:
+                 return lora_forward(result, lora_A, lora_B, dropout, x, scaling)
+             else:
+                 if isinstance(dropout, torch.nn.Identity) or not self.training:
+                     base_result = result
+                 else:
+                     x = dropout(x)
+                     base_result = None
+
+                 output = self.lora_magnitude_vector[active_adapter](
+                     x,
+                     lora_A=lora_A,
+                     lora_B=lora_B,
+                     scaling=scaling,
+                     base_layer=self.get_base_layer(),
+                     base_result=base_result,
+                 )
+                 if requires_conversion:
+                     output = output.to(expected_dtype)
+                 result = result + output
+
+     return result
unsloth_compiled_cache/Linear8bitLt_peft_forward.py ADDED
@@ -0,0 +1,90 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from peft.tuners.lora.bnb import (torch)
+
+
+ torch_addmm = torch.addmm
+ torch_add = torch.add
+ # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
+     xA = dropout(x) @ lora_A.weight.t()
+     # output = result + scaling * xA @ lora_B.weight.t()
+     shape = result.shape
+     output = torch_addmm(
+         result.view(-1, shape[-1]),
+         xA.view(-1, xA.shape[-1]),
+         lora_B.weight.t(),
+         alpha = scaling,
+         beta = 1,
+     ).view(shape)
+
+     bias = lora_B.bias
+     if bias is not None:
+         output = torch_add(
+             output,
+             bias,
+             alpha = scaling,
+         )
+     return output
+ pass
+
+ def unsloth_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+
+     adapter_names = kwargs.pop("adapter_names", None)
+
+     if self.disable_adapters:
+         if self.merged:
+             self.unmerge()
+         result = self.base_layer(x, *args, **kwargs)
+     elif adapter_names is not None:
+         result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
+     elif self.merged:
+         result = self.base_layer(x, *args, **kwargs)
+     else:
+         result = self.base_layer(x, *args, **kwargs)
+         for active_adapter in self.active_adapters:
+             if active_adapter not in self.lora_A.keys():
+                 continue
+             lora_A = self.lora_A[active_adapter]
+             lora_B = self.lora_B[active_adapter]
+             dropout = self.lora_dropout[active_adapter]
+             scaling = self.scaling[active_adapter]
+
+             requires_conversion = not torch.is_autocast_enabled()
+             if requires_conversion:
+                 expected_dtype = result.dtype
+                 x = self._cast_input_dtype(x, lora_A.weight.dtype)
+
+             if not self.use_dora[active_adapter]:
+                 return lora_forward(result, lora_A, lora_B, dropout, x, scaling)
+             else:
+                 if isinstance(dropout, torch.nn.Identity) or not self.training:
+                     base_result = result
+                 else:
+                     x = dropout(x)
+                     base_result = None
+
+                 output = self.lora_magnitude_vector[active_adapter](
+                     x,
+                     lora_A=lora_A,
+                     lora_B=lora_B,
+                     scaling=scaling,
+                     base_layer=self.get_base_layer(),
+                     base_result=base_result,
+                 )
+                 if requires_conversion:
+                     output = output.to(expected_dtype)
+                 result = result + output
+
+     return result
unsloth_compiled_cache/Linear_peft_forward.py ADDED
@@ -0,0 +1,89 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from peft.tuners.lora.layer import (Any, F, nn, torch)
+
+
+ torch_addmm = torch.addmm
+ torch_add = torch.add
+ # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
+     xA = dropout(x) @ lora_A.weight.t()
+     # output = result + scaling * xA @ lora_B.weight.t()
+     shape = result.shape
+     output = torch_addmm(
+         result.view(-1, shape[-1]),
+         xA.view(-1, xA.shape[-1]),
+         lora_B.weight.t(),
+         alpha = scaling,
+         beta = 1,
+     ).view(shape)
+
+     bias = lora_B.bias
+     if bias is not None:
+         output = torch_add(
+             output,
+             bias,
+             alpha = scaling,
+         )
+     return output
+ pass
+
+ def unsloth_forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+
+     adapter_names = kwargs.pop("adapter_names", None)
+
+     if self.disable_adapters:
+         if self.merged:
+             self.unmerge()
+         result = self.base_layer(x, *args, **kwargs)
+     elif adapter_names is not None:
+         result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
+     elif self.merged:
+         result = self.base_layer(x, *args, **kwargs)
+     else:
+         result = self.base_layer(x, *args, **kwargs)
+         torch_result_dtype = result.dtype
+
+         lora_A_keys = self.lora_A.keys()
+         for active_adapter in self.active_adapters:
+             if active_adapter not in lora_A_keys:
+                 continue
+
+             lora_A = self.lora_A[active_adapter]
+             lora_B = self.lora_B[active_adapter]
+             dropout = self.lora_dropout[active_adapter]
+             scaling = self.scaling[active_adapter]
+             if not torch.is_autocast_enabled(): result, x = result.to(lora_A.weight.dtype), x.to(lora_A.weight.dtype)
+
+             if not self.use_dora[active_adapter]:
+                 return lora_forward(result, lora_A, lora_B, dropout, x, scaling)
+             else:
+                 if isinstance(dropout, nn.Identity) or not self.training:
+                     base_result = result
+                 else:
+                     x = dropout(x)
+                     base_result = None
+
+                 result = result + self.lora_magnitude_vector[active_adapter](
+                     x,
+                     lora_A=lora_A,
+                     lora_B=lora_B,
+                     scaling=scaling,
+                     base_layer=self.get_base_layer(),
+                     base_result=base_result,
+                 )
+
+         result = result.to(torch_result_dtype)
+
+     return result
unsloth_compiled_cache/LoraParallelLinear_peft_forward.py ADDED
@@ -0,0 +1,87 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from peft.tuners.lora.tp_layer import (Any, __name__, nn, torch)
+
+
+ torch_addmm = torch.addmm
+ torch_add = torch.add
+ # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def lora_forward(result, lora_A, lora_B, dropout, x, scaling):
+     xA = dropout(x) @ lora_A.weight.t()
+     # output = result + scaling * xA @ lora_B.weight.t()
+     shape = result.shape
+     output = torch_addmm(
+         result.view(-1, shape[-1]),
+         xA.view(-1, xA.shape[-1]),
+         lora_B.weight.t(),
+         alpha = scaling,
+         beta = 1,
+     ).view(shape)
+
+     bias = lora_B.bias
+     if bias is not None:
+         output = torch_add(
+             output,
+             bias,
+             alpha = scaling,
+         )
+     return output
+ pass
+
+ def unsloth_forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
+
+     adapter_names = kwargs.pop("adapter_names", None)
+     # If weight is used for matrix multiplication here, the final aggregation operation of the original
+     # parallel_linear layer will be missing, so we need to directly call its forward function to obtain the
+     # output of the original parallel_linear layer.
+     if self.disable_adapters:
+         if self.merged:
+             self.unmerge()
+         result, bias = self.base_layer(x, *args, **kwargs)
+     elif adapter_names is not None:
+         raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.")
+     elif self.merged:
+         result, bias = self.base_layer(x, *args, **kwargs)
+     else:
+         result, bias = self.base_layer(x, *args, **kwargs)
+         torch_result_dtype = result.dtype
+         for active_adapter in self.active_adapters:
+             if active_adapter not in self.lora_A.keys():
+                 continue
+             lora_A = self.lora_A[active_adapter]
+             lora_B = self.lora_B[active_adapter]
+             dropout = self.lora_dropout[active_adapter]
+             scaling = self.scaling[active_adapter]
+             if not torch.is_autocast_enabled(): result, x = result.to(lora_A.weight.dtype), x.to(lora_A.weight.dtype)
+
+             if not self.use_dora[active_adapter]:
+                 return lora_forward(result, lora_A, lora_B, dropout, x, scaling)
+             else:
+                 if isinstance(dropout, torch.nn.Identity) or not self.training:
+                     base_result = result
+                 else:
+                     x = dropout(x)
+                     base_result = None
+
+                 result = result + self.lora_magnitude_vector[active_adapter](
+                     x,
+                     lora_A=lora_A,
+                     lora_B=lora_B,
+                     scaling=scaling,
+                     base_layer=self.get_base_layer(),
+                     base_result=base_result,
+                 )
+
+         result = result.to(torch_result_dtype)
+     return result, bias
unsloth_compiled_cache/RMSNorm.py ADDED
@@ -0,0 +1,46 @@
+ """
+ 2025.3.17
+ 2025.3.19
+ 4.50.0
+ 0.15.2
+ __UNSLOTH_VERSIONING__
+ """
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ import os
+ import importlib.util
+ if importlib.util.find_spec("unsloth_studio") is None:
+     UNSLOTH_STUDIO_ENABLED = False
+ else:
+     UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
+ pass
+ from typing import List, Dict, Tuple, Optional, Any, Callable
+ import math
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from transformers.models.gemma3.modeling_gemma3 import (torch)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+     """
+     Runs forward pass.
+     """
+     return F.rms_norm(x, self.normalized_shape, self.weight, self.eps).to(x.dtype).to(x.dtype)
unsloth_compiled_cache/UnslothAlignPropTrainer.py ADDED
@@ -0,0 +1,637 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
42
+ @dataclass
43
+ class UnslothAlignPropConfig(AlignPropConfig):
44
+ """
45
+
46
+ Configuration class for the [`AlignPropTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
54
+ Name of this experiment (defaults to the file name without the extension).
55
+ run_name (`str`, *optional*, defaults to `""`):
56
+ Name of this run.
57
+ seed (`int`, *optional*, defaults to `0`):
58
+ Random seed for reproducibility.
59
+ log_with (`str` or `None`, *optional*, defaults to `None`):
60
+ Log with either `"wandb"` or `"tensorboard"`. Check
61
+ [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
62
+ log_image_freq (`int`, *optional*, defaults to `1`):
63
+ Frequency for logging images.
64
+ tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
65
+ Keyword arguments for the tracker (e.g., `wandb_project`).
66
+ accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
67
+ Keyword arguments for the accelerator.
68
+ project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
69
+ Keyword arguments for the accelerator project config (e.g., `logging_dir`).
70
+ tracker_project_name (`str`, *optional*, defaults to `"trl"`):
71
+ Name of project to use for tracking.
72
+ logdir (`str`, *optional*, defaults to `"logs"`):
73
+ Top-level logging directory for checkpoint saving.
74
+ num_epochs (`int`, *optional*, defaults to `100`):
75
+ Number of epochs to train.
76
+ save_freq (`int`, *optional*, defaults to `1`):
77
+ Number of epochs between saving model checkpoints.
78
+ num_checkpoint_limit (`int`, *optional*, defaults to `5`):
79
+ Number of checkpoints to keep before overwriting old ones.
80
+ mixed_precision (`str`, *optional*, defaults to `"fp16"`):
81
+ Mixed precision training.
82
+ allow_tf32 (`bool`, *optional*, defaults to `True`):
83
+ Allow `tf32` on Ampere GPUs.
84
+ resume_from (`str`, *optional*, defaults to `""`):
85
+ Path to resume training from a checkpoint.
86
+ sample_num_steps (`int`, *optional*, defaults to `50`):
87
+ Number of sampler inference steps.
88
+ sample_eta (`float`, *optional*, defaults to `1.0`):
89
+ Eta parameter for the DDIM sampler.
90
+ sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
91
+ Classifier-free guidance weight.
92
+ train_batch_size (`int`, *optional*, defaults to `1`):
93
+ Batch size for training.
94
+ train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
95
+ Whether to use the 8bit Adam optimizer from `bitsandbytes`.
96
+ train_learning_rate (`float`, *optional*, defaults to `1e-3`):
97
+ Learning rate.
98
+ train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
99
+ Beta1 for Adam optimizer.
100
+ train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
101
+ Beta2 for Adam optimizer.
102
+ train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
103
+ Weight decay for Adam optimizer.
104
+ train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
105
+ Epsilon value for Adam optimizer.
106
+ train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
107
+ Number of gradient accumulation steps.
108
+ train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
109
+ Maximum gradient norm for gradient clipping.
110
+ negative_prompts (`str` or `None`, *optional*, defaults to `None`):
111
+ Comma-separated list of prompts to use as negative examples.
112
+ truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
113
+ If `True`, randomized truncation to different diffusion timesteps is used.
114
+ truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
115
+ Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
116
+ truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
117
+ Range of diffusion timesteps for randomized truncated backpropagation.
118
+ push_to_hub (`bool`, *optional*, defaults to `False`):
119
+ Whether to push the final model to the Hub.
120
+
121
+ """
122
+ vllm_sampling_params: Optional[Any] = field(
123
+ default = None,
124
+ metadata = {'help': 'vLLM SamplingParams'},
125
+ )
126
+ unsloth_num_chunks : Optional[int] = field(
127
+ default = -1,
128
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
129
+ )
130
+ def __init__(
131
+ self,
132
+ exp_name = 'app',
133
+ run_name = '',
134
+ seed = 3407,
135
+ log_with = None,
136
+ log_image_freq = 1,
137
+ tracker_project_name = 'trl',
138
+ logdir = 'logs',
139
+ num_epochs = 100,
140
+ save_freq = 1,
141
+ num_checkpoint_limit = 5,
142
+ mixed_precision = 'fp16',
143
+ allow_tf32 = True,
144
+ resume_from = '',
145
+ sample_num_steps = 50,
146
+ sample_eta = 1.0,
147
+ sample_guidance_scale = 5.0,
148
+ train_batch_size = 1,
149
+ train_use_8bit_adam = False,
150
+ train_learning_rate = 5e-05,
151
+ train_adam_beta1 = 0.9,
152
+ train_adam_beta2 = 0.999,
153
+ train_adam_weight_decay = 0.01,
154
+ train_adam_epsilon = 1e-08,
155
+ train_gradient_accumulation_steps = 2,
156
+ train_max_grad_norm = 1.0,
157
+ negative_prompts = None,
158
+ truncated_backprop_rand = True,
159
+ truncated_backprop_timestep = 49,
160
+ push_to_hub = False,
161
+ vllm_sampling_params = None,
162
+ unsloth_num_chunks = -1,
163
+ **kwargs,
164
+ ):
165
+
166
+ super().__init__(
167
+ exp_name = exp_name,
168
+ run_name = run_name,
169
+ seed = seed,
170
+ log_with = log_with,
171
+ log_image_freq = log_image_freq,
172
+ tracker_project_name = tracker_project_name,
173
+ logdir = logdir,
174
+ num_epochs = num_epochs,
175
+ save_freq = save_freq,
176
+ num_checkpoint_limit = num_checkpoint_limit,
177
+ mixed_precision = mixed_precision,
178
+ allow_tf32 = allow_tf32,
179
+ resume_from = resume_from,
180
+ sample_num_steps = sample_num_steps,
181
+ sample_eta = sample_eta,
182
+ sample_guidance_scale = sample_guidance_scale,
183
+ train_batch_size = train_batch_size,
184
+ train_use_8bit_adam = train_use_8bit_adam,
185
+ train_learning_rate = train_learning_rate,
186
+ train_adam_beta1 = train_adam_beta1,
187
+ train_adam_beta2 = train_adam_beta2,
188
+ train_adam_weight_decay = train_adam_weight_decay,
189
+ train_adam_epsilon = train_adam_epsilon,
190
+ train_gradient_accumulation_steps = train_gradient_accumulation_steps,
191
+ train_max_grad_norm = train_max_grad_norm,
192
+ negative_prompts = negative_prompts,
193
+ truncated_backprop_rand = truncated_backprop_rand,
194
+ truncated_backprop_timestep = truncated_backprop_timestep,
195
+ push_to_hub = push_to_hub,**kwargs)
196
+ self.vllm_sampling_params = vllm_sampling_params
197
+ self.unsloth_num_chunks = unsloth_num_chunks
198
+ pass
199
+
200
+ class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
201
+ """"""
202
+
203
+ _tag_names = ["trl", "alignprop"]
204
+
205
+ def __init__(
206
+ self,
207
+ config: AlignPropConfig,
208
+ reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
209
+ prompt_function: Callable[[], tuple[str, Any]],
210
+ sd_pipeline: DDPOStableDiffusionPipeline,
211
+ image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
212
+ ):
213
+ if image_samples_hook is None:
214
+ warn("No image_samples_hook provided; no images will be logged")
215
+
216
+ self.prompt_fn = prompt_function
217
+ self.reward_fn = reward_function
218
+ self.config = config
219
+ self.image_samples_callback = image_samples_hook
220
+
221
+ accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
222
+
223
+ if self.config.resume_from:
224
+ self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
225
+ if "checkpoint_" not in os.path.basename(self.config.resume_from):
226
+ # get the most recent checkpoint in this directory
227
+ checkpoints = list(
228
+ filter(
229
+ lambda x: "checkpoint_" in x,
230
+ os.listdir(self.config.resume_from),
231
+ )
232
+ )
233
+ if len(checkpoints) == 0:
234
+ raise ValueError(f"No checkpoints found in {self.config.resume_from}")
235
+ checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
236
+ self.config.resume_from = os.path.join(
237
+ self.config.resume_from,
238
+ f"checkpoint_{checkpoint_numbers[-1]}",
239
+ )
240
+
241
+ accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
242
+
243
+ self.accelerator = Accelerator(
244
+ log_with=self.config.log_with,
245
+ mixed_precision=self.config.mixed_precision,
246
+ project_config=accelerator_project_config,
247
+ # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
248
+ # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
249
+ # the total number of optimizer steps to accumulate across.
250
+ gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
251
+ **self.config.accelerator_kwargs,
252
+ )
253
+
254
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
255
+
256
+ if self.accelerator.is_main_process:
257
+ self.accelerator.init_trackers(
258
+ self.config.tracker_project_name,
259
+ config=dict(alignprop_trainer_config=config.to_dict())
260
+ if not is_using_tensorboard
261
+ else config.to_dict(),
262
+ init_kwargs=self.config.tracker_kwargs,
263
+ )
264
+
265
+ logger.info(f"\n{config}")
266
+
267
+ set_seed(self.config.seed, device_specific=True)
268
+
269
+ self.sd_pipeline = sd_pipeline
270
+
271
+ self.sd_pipeline.set_progress_bar_config(
272
+ position=1,
273
+ disable=not self.accelerator.is_local_main_process,
274
+ leave=False,
275
+ desc="Timestep",
276
+ dynamic_ncols=True,
277
+ )
278
+
279
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
280
+ # as these weights are only used for inference, keeping weights in full precision is not required.
281
+ if self.accelerator.mixed_precision == "fp16":
282
+ inference_dtype = torch.float16
283
+ elif self.accelerator.mixed_precision == "bf16":
284
+ inference_dtype = torch.bfloat16
285
+ else:
286
+ inference_dtype = torch.float32
287
+
288
+ self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
289
+ self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
290
+ self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
291
+
292
+ trainable_layers = self.sd_pipeline.get_trainable_layers()
293
+
294
+ self.accelerator.register_save_state_pre_hook(self._save_model_hook)
295
+ self.accelerator.register_load_state_pre_hook(self._load_model_hook)
296
+
297
+ # Enable TF32 for faster training on Ampere GPUs,
298
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
299
+ if self.config.allow_tf32:
300
+ torch.backends.cuda.matmul.allow_tf32 = True
301
+
302
+ self.optimizer = self._setup_optimizer(
303
+ trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
304
+ )
305
+
306
+ self.neg_prompt_embed = self.sd_pipeline.text_encoder(
307
+ self.sd_pipeline.tokenizer(
308
+ [""] if self.config.negative_prompts is None else self.config.negative_prompts,
309
+ return_tensors="pt",
310
+ padding="max_length",
311
+ truncation=True,
312
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
313
+ ).input_ids.to(self.accelerator.device)
314
+ )[0]
315
+
316
+ # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
317
+ # more memory
318
+ self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
319
+
320
+ if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
321
+ unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
322
+ self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
323
+ else:
324
+ self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
325
+
326
+ if config.resume_from:
327
+ logger.info(f"Resuming from {config.resume_from}")
328
+ self.accelerator.load_state(config.resume_from)
329
+ self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
330
+ else:
331
+ self.first_epoch = 0
332
+
333
+ def compute_rewards(self, prompt_image_pairs):
334
+ reward, reward_metadata = self.reward_fn(
335
+ prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
336
+ )
337
+ return reward
338
+
339
+ def step(self, epoch: int, global_step: int):
340
+ """
341
+ Perform a single step of training.
342
+
343
+ Args:
344
+ epoch (int): The current epoch.
345
+ global_step (int): The current global step.
346
+
347
+ Side Effects:
348
+ - Model weights are updated
349
+ - Logs the statistics to the accelerator trackers.
350
+ - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
351
+
352
+ Returns:
353
+ global_step (int): The updated global step.
354
+ """
355
+ info = defaultdict(list)
356
+
357
+ self.sd_pipeline.unet.train()
358
+
359
+ for _ in range(self.config.train_gradient_accumulation_steps):
360
+ with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
361
+ prompt_image_pairs = self._generate_samples(
362
+ batch_size=self.config.train_batch_size,
363
+ )
364
+
365
+ rewards = self.compute_rewards(prompt_image_pairs)
366
+
367
+ prompt_image_pairs["rewards"] = rewards
368
+
369
+ rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
370
+
371
+ loss = self.calculate_loss(rewards)
372
+
373
+ self.accelerator.backward(loss)
374
+
375
+ if self.accelerator.sync_gradients:
376
+ self.accelerator.clip_grad_norm_(
377
+ self.trainable_layers.parameters()
378
+ if not isinstance(self.trainable_layers, list)
379
+ else self.trainable_layers,
380
+ self.config.train_max_grad_norm,
381
+ )
382
+
383
+ self.optimizer.step()
384
+ self.optimizer.zero_grad()
385
+
386
+ info["reward_mean"].append(rewards_vis.mean())
387
+ info["reward_std"].append(rewards_vis.std())
388
+ info["loss"].append(loss.item())
389
+
390
+ # Checks if the accelerator has performed an optimization step behind the scenes
391
+ if self.accelerator.sync_gradients:
392
+ # log training-related stuff
393
+ info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
394
+ info = self.accelerator.reduce(info, reduction="mean")
395
+ info.update({"epoch": epoch})
396
+ self.accelerator.log(info, step=global_step)
397
+ global_step += 1
398
+ info = defaultdict(list)
399
+ else:
400
+ raise ValueError(
401
+ "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
402
+ )
403
+ # Logs generated images
404
+ if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
405
+ self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
406
+
407
+ if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
408
+ self.accelerator.save_state()
409
+
410
+ return global_step
411
+
412
+ def calculate_loss(self, rewards):
413
+ """
414
+ Calculate the loss for a batch of an unpacked sample
415
+
416
+ Args:
417
+ rewards (torch.Tensor):
418
+ Differentiable reward scalars for each generated image, shape: [batch_size]
419
+
420
+ Returns:
421
+ loss (torch.Tensor)
422
+ (all of these are of shape (1,))
423
+ """
424
+ # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
425
+ loss = 10.0 - (rewards).mean()
426
+ return loss
427
+
428
+ def loss(
429
+ self,
430
+ advantages: torch.Tensor,
431
+ clip_range: float,
432
+ ratio: torch.Tensor,
433
+ ):
434
+ unclipped_loss = -advantages * ratio
435
+ clipped_loss = -advantages * torch.clamp(
436
+ ratio,
437
+ 1.0 - clip_range,
438
+ 1.0 + clip_range,
439
+ )
440
+ return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
441
+
442
+ def _setup_optimizer(self, trainable_layers_parameters):
443
+ if self.config.train_use_8bit_adam:
444
+ import bitsandbytes
445
+
446
+ optimizer_cls = bitsandbytes.optim.AdamW8bit
447
+ else:
448
+ optimizer_cls = torch.optim.AdamW
449
+
450
+ return optimizer_cls(
451
+ trainable_layers_parameters,
452
+ lr=self.config.train_learning_rate,
453
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
454
+ weight_decay=self.config.train_adam_weight_decay,
455
+ eps=self.config.train_adam_epsilon,
456
+ )
457
+
458
+ def _save_model_hook(self, models, weights, output_dir):
459
+ self.sd_pipeline.save_checkpoint(models, weights, output_dir)
460
+ weights.pop() # ensures that accelerate doesn't try to handle saving of the model
461
+
462
+ def _load_model_hook(self, models, input_dir):
463
+ self.sd_pipeline.load_checkpoint(models, input_dir)
464
+ models.pop() # ensures that accelerate doesn't try to handle loading of the model
465
+
466
+ def _generate_samples(self, batch_size, with_grad=True, prompts=None):
467
+ """
468
+ Generate samples from the model
469
+
470
+ Args:
471
+ batch_size (int): Batch size to use for sampling
472
+ with_grad (bool): Whether the generated RGBs should have gradients attached to it.
473
+
474
+ Returns:
475
+ prompt_image_pairs (dict[Any])
476
+ """
477
+ prompt_image_pairs = {}
478
+
479
+ sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
480
+
481
+ if prompts is None:
482
+ prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
483
+ else:
484
+ prompt_metadata = [{} for _ in range(batch_size)]
485
+
486
+ prompt_ids = self.sd_pipeline.tokenizer(
487
+ prompts,
488
+ return_tensors="pt",
489
+ padding="max_length",
490
+ truncation=True,
491
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
492
+ ).input_ids.to(self.accelerator.device)
493
+
494
+ prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
495
+
496
+ if with_grad:
497
+ sd_output = self.sd_pipeline.rgb_with_grad(
498
+ prompt_embeds=prompt_embeds,
499
+ negative_prompt_embeds=sample_neg_prompt_embeds,
500
+ num_inference_steps=self.config.sample_num_steps,
501
+ guidance_scale=self.config.sample_guidance_scale,
502
+ eta=self.config.sample_eta,
503
+ truncated_backprop_rand=self.config.truncated_backprop_rand,
504
+ truncated_backprop_timestep=self.config.truncated_backprop_timestep,
505
+ truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
506
+ output_type="pt",
507
+ )
508
+ else:
509
+ sd_output = self.sd_pipeline(
510
+ prompt_embeds=prompt_embeds,
511
+ negative_prompt_embeds=sample_neg_prompt_embeds,
512
+ num_inference_steps=self.config.sample_num_steps,
513
+ guidance_scale=self.config.sample_guidance_scale,
514
+ eta=self.config.sample_eta,
515
+ output_type="pt",
516
+ )
517
+
518
+ images = sd_output.images
519
+
520
+ prompt_image_pairs["images"] = images
521
+ prompt_image_pairs["prompts"] = prompts
522
+ prompt_image_pairs["prompt_metadata"] = prompt_metadata
523
+
524
+ return prompt_image_pairs
525
+
526
+ def train(self, epochs: Optional[int] = None):
527
+ """
528
+ Train the model for a given number of epochs
529
+ """
530
+ global_step = 0
531
+ if epochs is None:
532
+ epochs = self.config.num_epochs
533
+ for epoch in range(self.first_epoch, epochs):
534
+ global_step = self.step(epoch, global_step)
535
+
536
+ def _save_pretrained(self, save_directory):
537
+ self.sd_pipeline.save_pretrained(save_directory)
538
+ self.create_model_card()
539
+
540
+ def create_model_card(
541
+ self,
542
+ model_name: Optional[str] = None,
543
+ dataset_name: Optional[str] = None,
544
+ tags: Union[str, list[str], None] = None,
545
+ ):
546
+ """
547
+ Creates a draft of a model card using the information available to the `Trainer`.
548
+
549
+ Args:
550
+ model_name (`str` or `None`, *optional*, defaults to `None`):
551
+ Name of the model.
552
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
553
+ Name of the dataset used for training.
554
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
555
+ Tags to be associated with the model card.
556
+ """
557
+ if not self.is_world_process_zero():
558
+ return
559
+
560
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
561
+ base_model = self.model.config._name_or_path
562
+ else:
563
+ base_model = None
564
+
565
+ tags = tags or []
566
+ if isinstance(tags, str):
567
+ tags = [tags]
568
+
569
+ if hasattr(self.model.config, "unsloth_version"):
570
+ tags.append("unsloth")
571
+
572
+ citation = textwrap.dedent("""\
573
+ @article{prabhudesai2024aligning,
574
+ title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
575
+ author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
576
+ year = 2024,
577
+ eprint = {arXiv:2310.03739}
578
+ }""")
579
+
580
+ model_card = generate_model_card(
581
+ base_model=base_model,
582
+ model_name=model_name,
583
+ hub_model_id=self.hub_model_id,
584
+ dataset_name=dataset_name,
585
+ tags=tags,
586
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
587
+ comet_url=get_comet_experiment_url(),
588
+ trainer_name="AlignProp",
589
+ trainer_citation=citation,
590
+ paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
591
+ paper_id="2310.03739",
592
+ )
593
+
594
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
595
+ class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
596
+ """
597
+
598
+ The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
599
+ Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
600
+ As of now only Stable Diffusion based pipelines are supported
601
+
602
+ Attributes:
603
+ config (`AlignPropConfig`):
604
+ Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
605
+ reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
606
+ Reward function to be used
607
+ prompt_function (`Callable[[], tuple[str, Any]]`):
608
+ Function to generate prompts to guide model
609
+ sd_pipeline (`DDPOStableDiffusionPipeline`):
610
+ Stable Diffusion pipeline to be used for training.
611
+ image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
612
+ Hook to be called to log images
613
+
614
+ """
615
+ def __init__(
616
+ self,
617
+ config,
618
+ reward_function,
619
+ prompt_function,
620
+ sd_pipeline,
621
+ image_samples_hook = None,
622
+ **kwargs
623
+ ):
624
+ if args is None: args = UnslothAlignPropConfig()
625
+ other_metrics = []
626
+
627
+ from unsloth_zoo.logging_utils import PatchRLStatistics
628
+ PatchRLStatistics('alignprop_trainer', other_metrics)
629
+
630
+ super().__init__(
631
+ config = config,
632
+ reward_function = reward_function,
633
+ prompt_function = prompt_function,
634
+ sd_pipeline = sd_pipeline,
635
+ image_samples_hook = image_samples_hook,**kwargs)
636
+
637
+ pass
unsloth_compiled_cache/UnslothBCOTrainer.py ADDED
@@ -0,0 +1,1824 @@
1
+ """
2
+ 2025.3.17
3
+ 2025.3.19
4
+ 4.50.0
5
+ 0.15.2
6
+ __UNSLOTH_VERSIONING__
7
+ """
8
+ from torch import Tensor
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
+ # `LogisticRegression` is used below for the density-ratio classifier but is not re-exported by the
+ # import above, so import it directly when scikit-learn is available (availability is checked in __init__).
+ if is_sklearn_available():
+ from sklearn.linear_model import LogisticRegression
13
+
14
+
15
+ import os
16
+ from typing import *
17
+ from dataclasses import dataclass, field
18
+ from packaging.version import Version
19
+ import torch
20
+ import numpy as np
21
+ from contextlib import nullcontext
22
+ from torch.nn import functional as F
23
+ from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
24
+
25
+ torch_compile_options = {
26
+ "epilogue_fusion" : True,
27
+ "max_autotune" : False,
28
+ "shape_padding" : True,
29
+ "trace.enabled" : False,
30
+ "triton.cudagraphs" : False,
31
+ }
32
+
33
+ @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
34
+ def selective_log_softmax(logits, index):
35
+ logits = logits.to(torch.float32)
36
+ selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
37
+ # loop to reduce peak mem consumption
38
+ # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
39
+ logsumexp_values = torch.logsumexp(logits, dim = -1)
40
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
41
+ return per_token_logps
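+ # Illustrative sketch (added for clarity; not part of the generated file): selective_log_softmax
+ # gathers the log-probability of one chosen token id per position without materialising the full
+ # log_softmax tensor. Assuming small toy tensors:
+ #
+ #   logits = torch.randn(2, 4, 8)                # (batch, seq_len, vocab)
+ #   index  = torch.randint(0, 8, (2, 4))         # token ids whose log-probs we want
+ #   logps  = selective_log_softmax(logits, index)
+ #   # logps[i, j] == torch.log_softmax(logits[i, j].float(), dim=-1)[index[i, j]], shape (2, 4)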
42
+ @dataclass
43
+ class UnslothBCOConfig(BCOConfig):
44
+ """
45
+
46
+ Configuration class for the [`BCOTrainer`].
47
+
48
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
49
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
50
+ command line.
51
+
52
+ Parameters:
53
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
54
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
55
+ to use the default data collator.
56
+ max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
57
+ Maximum length of the prompt. This argument is required if you want to use the default data collator.
58
+ max_completion_length (`int` or `None`, *optional*, defaults to `None`):
59
+ Maximum length of the completion. This argument is required if you want to use the default data collator
60
+ and your model is an encoder-decoder.
61
+ beta (`float`, *optional*, defaults to `0.1`):
62
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
63
+ reference model.
64
+ label_pad_token_id (`int`, *optional*, defaults to `-100`):
65
+ Label pad token id. This argument is required if you want to use the default data collator.
66
+ padding_value (`int` or `None`, *optional*, defaults to `None`):
67
+ Padding value to use. If `None`, the padding value of the tokenizer is used.
68
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
69
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
70
+ This argument is required if you want to use the default data collator.
71
+ disable_dropout (`bool`, *optional*, defaults to `True`):
72
+ Whether to disable dropout in the model and reference model.
73
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
74
+ If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
75
+ evaluation.
76
+ is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
77
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
78
+ you need to specify if the model returned by the callable is an encoder-decoder model.
79
+ precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
80
+ Whether to precompute reference model log probabilities for training and evaluation datasets. This is
81
+ useful when training without the reference model to reduce the total GPU memory needed.
82
+ model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
83
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
84
+ string.
85
+ ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
86
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
87
+ from a string.
88
+ dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
89
+ Number of processes to use for processing the dataset.
90
+ prompt_sample_size (`int`, *optional*, defaults to `1024`):
91
+ Number of prompts that are fed to density ratio classifier.
92
+ min_density_ratio (`float`, *optional*, defaults to `0.5`):
93
+ Minimum value of the density ratio. The estimated density ratio is clamped to this value.
94
+ max_density_ratio (`float`, *optional*, defaults to `10.0`):
95
+ Maximum value of the density ratio. The estimated density ratio is clamped to this value.
96
+
97
+ """
98
+ vllm_sampling_params: Optional[Any] = field(
99
+ default = None,
100
+ metadata = {'help': 'vLLM SamplingParams'},
101
+ )
102
+ unsloth_num_chunks : Optional[int] = field(
103
+ default = -1,
104
+ metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
105
+ )
106
+ def __init__(
107
+ self,
108
+ output_dir = None,
109
+ overwrite_output_dir = None,
110
+ do_train = False,
111
+ do_eval = False,
112
+ do_predict = False,
113
+ eval_strategy = 'no',
114
+ prediction_loss_only = False,
115
+ per_device_train_batch_size = 4,
116
+ per_device_eval_batch_size = 4,
117
+ per_gpu_train_batch_size = None,
118
+ per_gpu_eval_batch_size = None,
119
+ gradient_accumulation_steps = 2,
120
+ eval_accumulation_steps = 2,
121
+ eval_delay = 0,
122
+ torch_empty_cache_steps = 250,
123
+ learning_rate = 5e-05,
124
+ weight_decay = 0.01,
125
+ adam_beta1 = 0.9,
126
+ adam_beta2 = 0.999,
127
+ adam_epsilon = 1e-08,
128
+ max_grad_norm = 1.0,
129
+ num_train_epochs = 3.0,
130
+ max_steps = -1,
131
+ lr_scheduler_type = 'linear',
132
+ warmup_ratio = 0.1,
133
+ warmup_steps = 0,
134
+ log_level = 'passive',
135
+ log_level_replica = 'warning',
136
+ log_on_each_node = True,
137
+ logging_dir = None,
138
+ logging_strategy = 'steps',
139
+ logging_first_step = False,
140
+ logging_steps = 1,
141
+ logging_nan_inf_filter = False,
142
+ save_strategy = 'steps',
143
+ save_steps = 500,
144
+ save_total_limit = None,
145
+ save_safetensors = True,
146
+ save_on_each_node = False,
147
+ save_only_model = False,
148
+ restore_callback_states_from_checkpoint = False,
149
+ no_cuda = False,
150
+ use_cpu = False,
151
+ use_mps_device = False,
152
+ seed = 3407,
153
+ data_seed = 3407,
154
+ jit_mode_eval = False,
155
+ use_ipex = False,
156
+ bf16 = False,
157
+ fp16 = False,
158
+ fp16_opt_level = 'O1',
159
+ half_precision_backend = 'auto',
160
+ bf16_full_eval = False,
161
+ fp16_full_eval = False,
162
+ tf32 = None,
163
+ local_rank = -1,
164
+ ddp_backend = None,
165
+ tpu_num_cores = None,
166
+ tpu_metrics_debug = False,
167
+ debug = '',
168
+ dataloader_drop_last = False,
169
+ eval_steps = None,
170
+ dataloader_num_workers = 0,
171
+ dataloader_prefetch_factor = None,
172
+ past_index = -1,
173
+ run_name = None,
174
+ disable_tqdm = None,
175
+ remove_unused_columns = True,
176
+ label_names = None,
177
+ load_best_model_at_end = False,
178
+ metric_for_best_model = None,
179
+ greater_is_better = None,
180
+ ignore_data_skip = False,
181
+ fsdp = '',
182
+ fsdp_min_num_params = 0,
183
+ fsdp_config = None,
184
+ tp_size = 0,
185
+ fsdp_transformer_layer_cls_to_wrap = None,
186
+ accelerator_config = None,
187
+ deepspeed = None,
188
+ label_smoothing_factor = 0.0,
189
+ optim = 'adamw_8bit',
190
+ optim_args = None,
191
+ adafactor = False,
192
+ group_by_length = False,
193
+ length_column_name = 'length',
194
+ report_to = None,
195
+ ddp_find_unused_parameters = None,
196
+ ddp_bucket_cap_mb = None,
197
+ ddp_broadcast_buffers = None,
198
+ dataloader_pin_memory = True,
199
+ dataloader_persistent_workers = False,
200
+ skip_memory_metrics = True,
201
+ use_legacy_prediction_loop = False,
202
+ push_to_hub = False,
203
+ resume_from_checkpoint = None,
204
+ hub_model_id = None,
205
+ hub_strategy = 'every_save',
206
+ hub_token = None,
207
+ hub_private_repo = None,
208
+ hub_always_push = False,
209
+ gradient_checkpointing = False,
210
+ gradient_checkpointing_kwargs = None,
211
+ include_inputs_for_metrics = False,
212
+ eval_do_concat_batches = True,
213
+ fp16_backend = 'auto',
214
+ evaluation_strategy = None,
215
+ push_to_hub_model_id = None,
216
+ push_to_hub_organization = None,
217
+ push_to_hub_token = None,
218
+ mp_parameters = '',
219
+ auto_find_batch_size = False,
220
+ full_determinism = False,
221
+ torchdynamo = None,
222
+ ray_scope = 'last',
223
+ ddp_timeout = 1800,
224
+ torch_compile = False,
225
+ torch_compile_backend = None,
226
+ torch_compile_mode = None,
227
+ dispatch_batches = None,
228
+ split_batches = None,
229
+ include_tokens_per_second = False,
230
+ include_num_input_tokens_seen = False,
231
+ neftune_noise_alpha = None,
232
+ optim_target_modules = None,
233
+ batch_eval_metrics = False,
234
+ eval_on_start = False,
235
+ use_liger_kernel = False,
236
+ eval_use_gather_object = False,
237
+ average_tokens_across_devices = False,
238
+ max_length = 1024,
239
+ max_prompt_length = 512,
240
+ max_completion_length = None,
241
+ beta = 0.1,
242
+ label_pad_token_id = -100,
243
+ padding_value = None,
244
+ truncation_mode = 'keep_end',
245
+ disable_dropout = True,
246
+ generate_during_eval = False,
247
+ is_encoder_decoder = None,
248
+ precompute_ref_log_probs = False,
249
+ model_init_kwargs = None,
250
+ ref_model_init_kwargs = None,
251
+ dataset_num_proc = None,
252
+ prompt_sample_size = 1024,
253
+ min_density_ratio = 0.5,
254
+ max_density_ratio = 10.0,
255
+ vllm_sampling_params = None,
256
+ unsloth_num_chunks = -1,
257
+ **kwargs,
258
+ ):
259
+ if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
260
+ if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
261
+ if output_dir is None and save_strategy == 'steps' and save_steps == 500:
262
+ output_dir = 'unsloth_training_checkpoints'
263
+ save_strategy = 'no'
264
+ if dataset_num_proc is None:
265
+ from multiprocessing import cpu_count
266
+ dataset_num_proc = cpu_count()
267
+
268
+ super().__init__(
269
+ output_dir = output_dir,
270
+ overwrite_output_dir = overwrite_output_dir,
271
+ do_train = do_train,
272
+ do_eval = do_eval,
273
+ do_predict = do_predict,
274
+ eval_strategy = eval_strategy,
275
+ prediction_loss_only = prediction_loss_only,
276
+ per_device_train_batch_size = per_device_train_batch_size,
277
+ per_device_eval_batch_size = per_device_eval_batch_size,
278
+ per_gpu_train_batch_size = per_gpu_train_batch_size,
279
+ per_gpu_eval_batch_size = per_gpu_eval_batch_size,
280
+ gradient_accumulation_steps = gradient_accumulation_steps,
281
+ eval_accumulation_steps = eval_accumulation_steps,
282
+ eval_delay = eval_delay,
283
+ torch_empty_cache_steps = torch_empty_cache_steps,
284
+ learning_rate = learning_rate,
285
+ weight_decay = weight_decay,
286
+ adam_beta1 = adam_beta1,
287
+ adam_beta2 = adam_beta2,
288
+ adam_epsilon = adam_epsilon,
289
+ max_grad_norm = max_grad_norm,
290
+ num_train_epochs = num_train_epochs,
291
+ max_steps = max_steps,
292
+ lr_scheduler_type = lr_scheduler_type,
293
+ warmup_ratio = warmup_ratio,
294
+ warmup_steps = warmup_steps,
295
+ log_level = log_level,
296
+ log_level_replica = log_level_replica,
297
+ log_on_each_node = log_on_each_node,
298
+ logging_dir = logging_dir,
299
+ logging_strategy = logging_strategy,
300
+ logging_first_step = logging_first_step,
301
+ logging_steps = logging_steps,
302
+ logging_nan_inf_filter = logging_nan_inf_filter,
303
+ save_strategy = save_strategy,
304
+ save_steps = save_steps,
305
+ save_total_limit = save_total_limit,
306
+ save_safetensors = save_safetensors,
307
+ save_on_each_node = save_on_each_node,
308
+ save_only_model = save_only_model,
309
+ restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
310
+ no_cuda = no_cuda,
311
+ use_cpu = use_cpu,
312
+ use_mps_device = use_mps_device,
313
+ seed = seed,
314
+ data_seed = data_seed,
315
+ jit_mode_eval = jit_mode_eval,
316
+ use_ipex = use_ipex,
317
+ bf16 = bf16,
318
+ fp16 = fp16,
319
+ fp16_opt_level = fp16_opt_level,
320
+ half_precision_backend = half_precision_backend,
321
+ bf16_full_eval = bf16_full_eval,
322
+ fp16_full_eval = fp16_full_eval,
323
+ tf32 = tf32,
324
+ local_rank = local_rank,
325
+ ddp_backend = ddp_backend,
326
+ tpu_num_cores = tpu_num_cores,
327
+ tpu_metrics_debug = tpu_metrics_debug,
328
+ debug = debug,
329
+ dataloader_drop_last = dataloader_drop_last,
330
+ eval_steps = eval_steps,
331
+ dataloader_num_workers = dataloader_num_workers,
332
+ dataloader_prefetch_factor = dataloader_prefetch_factor,
333
+ past_index = past_index,
334
+ run_name = run_name,
335
+ disable_tqdm = disable_tqdm,
336
+ remove_unused_columns = remove_unused_columns,
337
+ label_names = label_names,
338
+ load_best_model_at_end = load_best_model_at_end,
339
+ metric_for_best_model = metric_for_best_model,
340
+ greater_is_better = greater_is_better,
341
+ ignore_data_skip = ignore_data_skip,
342
+ fsdp = fsdp,
343
+ fsdp_min_num_params = fsdp_min_num_params,
344
+ fsdp_config = fsdp_config,
345
+ tp_size = tp_size,
346
+ fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
347
+ accelerator_config = accelerator_config,
348
+ deepspeed = deepspeed,
349
+ label_smoothing_factor = label_smoothing_factor,
350
+ optim = optim,
351
+ optim_args = optim_args,
352
+ adafactor = adafactor,
353
+ group_by_length = group_by_length,
354
+ length_column_name = length_column_name,
355
+ report_to = report_to,
356
+ ddp_find_unused_parameters = ddp_find_unused_parameters,
357
+ ddp_bucket_cap_mb = ddp_bucket_cap_mb,
358
+ ddp_broadcast_buffers = ddp_broadcast_buffers,
359
+ dataloader_pin_memory = dataloader_pin_memory,
360
+ dataloader_persistent_workers = dataloader_persistent_workers,
361
+ skip_memory_metrics = skip_memory_metrics,
362
+ use_legacy_prediction_loop = use_legacy_prediction_loop,
363
+ push_to_hub = push_to_hub,
364
+ resume_from_checkpoint = resume_from_checkpoint,
365
+ hub_model_id = hub_model_id,
366
+ hub_strategy = hub_strategy,
367
+ hub_token = hub_token,
368
+ hub_private_repo = hub_private_repo,
369
+ hub_always_push = hub_always_push,
370
+ gradient_checkpointing = gradient_checkpointing,
371
+ gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
372
+ include_inputs_for_metrics = include_inputs_for_metrics,
373
+ eval_do_concat_batches = eval_do_concat_batches,
374
+ fp16_backend = fp16_backend,
375
+ evaluation_strategy = evaluation_strategy,
376
+ push_to_hub_model_id = push_to_hub_model_id,
377
+ push_to_hub_organization = push_to_hub_organization,
378
+ push_to_hub_token = push_to_hub_token,
379
+ mp_parameters = mp_parameters,
380
+ auto_find_batch_size = auto_find_batch_size,
381
+ full_determinism = full_determinism,
382
+ torchdynamo = torchdynamo,
383
+ ray_scope = ray_scope,
384
+ ddp_timeout = ddp_timeout,
385
+ torch_compile = torch_compile,
386
+ torch_compile_backend = torch_compile_backend,
387
+ torch_compile_mode = torch_compile_mode,
388
+ dispatch_batches = dispatch_batches,
389
+ split_batches = split_batches,
390
+ include_tokens_per_second = include_tokens_per_second,
391
+ include_num_input_tokens_seen = include_num_input_tokens_seen,
392
+ neftune_noise_alpha = neftune_noise_alpha,
393
+ optim_target_modules = optim_target_modules,
394
+ batch_eval_metrics = batch_eval_metrics,
395
+ eval_on_start = eval_on_start,
396
+ use_liger_kernel = use_liger_kernel,
397
+ eval_use_gather_object = eval_use_gather_object,
398
+ average_tokens_across_devices = average_tokens_across_devices,
399
+ max_length = max_length,
400
+ max_prompt_length = max_prompt_length,
401
+ max_completion_length = max_completion_length,
402
+ beta = beta,
403
+ label_pad_token_id = label_pad_token_id,
404
+ padding_value = padding_value,
405
+ truncation_mode = truncation_mode,
406
+ disable_dropout = disable_dropout,
407
+ generate_during_eval = generate_during_eval,
408
+ is_encoder_decoder = is_encoder_decoder,
409
+ precompute_ref_log_probs = precompute_ref_log_probs,
410
+ model_init_kwargs = model_init_kwargs,
411
+ ref_model_init_kwargs = ref_model_init_kwargs,
412
+ dataset_num_proc = dataset_num_proc,
413
+ prompt_sample_size = prompt_sample_size,
414
+ min_density_ratio = min_density_ratio,
415
+ max_density_ratio = max_density_ratio,**kwargs)
416
+ self.vllm_sampling_params = vllm_sampling_params
417
+ self.unsloth_num_chunks = unsloth_num_chunks
418
+ pass
419
+
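+ # Illustrative sketch (added for clarity; not part of the generated file): the config above is a
+ # drop-in replacement for `BCOConfig`, so a typical, hedged construction looks like:
+ #
+ #   training_args = UnslothBCOConfig(
+ #       output_dir = "bco_output",        # hypothetical path
+ #       per_device_train_batch_size = 4,
+ #       learning_rate = 5e-5,
+ #       beta = 0.1,                       # strength of the penalty for drifting from the reference model
+ #       max_length = 1024,
+ #       max_prompt_length = 512,
+ #   )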
420
+ class _UnslothBCOTrainer(Trainer):
421
+ r""""""
422
+
423
+ _tag_names = ["trl", "bco"]
424
+
425
+ def __init__(
426
+ self,
427
+ model: Union[PreTrainedModel, nn.Module, str] = None,
428
+ ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
429
+ args: BCOConfig = None,
430
+ train_dataset: Optional[Dataset] = None,
431
+ eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
432
+ processing_class: Optional[
433
+ Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
434
+ ] = None,
435
+ data_collator: Optional[DataCollator] = None,
436
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
437
+ callbacks: Optional[list[TrainerCallback]] = None,
438
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
439
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
440
+ peft_config: Optional[dict] = None,
441
+ compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
442
+ model_adapter_name: Optional[str] = None,
443
+ ref_adapter_name: Optional[str] = None,
444
+ embedding_func: Optional[Callable] = None,
445
+ embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
446
+ ):
447
+ if not is_sklearn_available():
448
+ raise ImportError(
449
+ "BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
450
+ )
451
+
452
+ if type(args) is TrainingArguments:
453
+ raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
454
+
455
+ if not isinstance(model, str) and ref_model is model:
456
+ raise ValueError(
457
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
458
+ "same as `model`, you must mass a copy of it, or `None` if you use peft."
459
+ )
460
+
461
+ if args.model_init_kwargs is None:
462
+ model_init_kwargs = {}
463
+ elif not isinstance(model, str):
464
+ raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
465
+ else:
466
+ model_init_kwargs = args.model_init_kwargs
467
+ torch_dtype = model_init_kwargs.get("torch_dtype")
468
+ if torch_dtype is not None:
469
+ # Convert to `torch.dtype` if an str is passed
470
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
471
+ torch_dtype = getattr(torch, torch_dtype)
472
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
473
+ raise ValueError(
474
+ f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
475
+ )
476
+ model_init_kwargs["torch_dtype"] = torch_dtype
477
+
478
+ if args.ref_model_init_kwargs is None:
479
+ ref_model_init_kwargs = {}
480
+ elif not isinstance(ref_model, str):
481
+ raise ValueError(
482
+ "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
483
+ )
484
+ else:
485
+ ref_model_init_kwargs = args.ref_model_init_kwargs
486
+ torch_dtype = ref_model_init_kwargs.get("torch_dtype")
487
+ if torch_dtype is not None:
488
+ # Convert to `torch.dtype` if an str is passed
489
+ if isinstance(torch_dtype, str) and torch_dtype != "auto":
490
+ torch_dtype = getattr(torch, torch_dtype)
491
+ if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
492
+ raise ValueError(
493
+ f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
494
+ )
495
+ ref_model_init_kwargs["torch_dtype"] = torch_dtype
496
+
497
+ if isinstance(model, str):
498
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
499
+
500
+ if isinstance(ref_model, str):
501
+ ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
502
+
503
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
504
+ # has been called in order to properly call autocast if needed.
505
+ self._peft_has_been_casted_to_bf16 = False
506
+
507
+ if not is_peft_available() and peft_config is not None:
508
+ raise ValueError(
509
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
510
+ )
511
+ elif is_peft_available() and peft_config is not None:
512
+ # if model is a peft model and we have a peft_config, we merge and unload it first
513
+ if isinstance(model, PeftModel):
514
+ model = model.merge_and_unload()
515
+
516
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
517
+ _support_gc_kwargs = hasattr(
518
+ args, "gradient_checkpointing_kwargs"
519
+ ) and "gradient_checkpointing_kwargs" in list(
520
+ inspect.signature(prepare_model_for_kbit_training).parameters
521
+ )
522
+
523
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
524
+
525
+ if _support_gc_kwargs:
526
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
527
+
528
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
529
+ elif getattr(args, "gradient_checkpointing", False):
530
+ # For backward compatibility with older versions of transformers
531
+ if hasattr(model, "enable_input_require_grads"):
532
+ model.enable_input_require_grads()
533
+ else:
534
+
535
+ def make_inputs_require_grad(module, input, output):
536
+ output.requires_grad_(True)
537
+
538
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
539
+
540
+ # Unsloth applies the PEFT config elsewhere, so the model is intentionally left as-is here
541
+ model = model
542
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
543
+ peft_module_casting_to_bf16(model)
544
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
545
+ self._peft_has_been_casted_to_bf16 = True
546
+
547
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
548
+ # to explicitly have `requires_grad=True`, otherwise training will either silently
549
+ # fail or completely fail.
550
+ elif getattr(args, "gradient_checkpointing", False):
551
+ # For backward compatibility with older versions of transformers
552
+ if hasattr(model, "enable_input_require_grads"):
553
+ model.enable_input_require_grads()
554
+ else:
555
+
556
+ def make_inputs_require_grad(module, input, output):
557
+ output.requires_grad_(True)
558
+
559
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
560
+
561
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
562
+ raise ValueError(
563
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
564
+ " Please install `wandb` or `comet-ml` to resolve."
565
+ )
566
+
567
+ if model is not None:
568
+ self.is_encoder_decoder = model.config.is_encoder_decoder
569
+ elif args.is_encoder_decoder is None:
570
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
571
+ else:
572
+ self.is_encoder_decoder = args.is_encoder_decoder
573
+
574
+ self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
575
+ self.model_adapter_name = model_adapter_name
576
+ self.ref_adapter_name = ref_adapter_name
577
+
578
+ if ref_model:
579
+ self.ref_model = ref_model
580
+ elif self.is_peft_model or args.precompute_ref_log_probs:
581
+ # The `model` with adapters turned off will be used as the reference model
582
+ self.ref_model = None
583
+ else:
584
+ self.ref_model = create_reference_model(model)
585
+
586
+ if processing_class is None:
587
+ raise ValueError(
588
+ "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
589
+ )
590
+ if args.max_length is None:
591
+ warnings.warn(
592
+ "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
593
+ "It will be set to `512` by default, but you should do it yourself in the future.",
594
+ UserWarning,
595
+ )
596
+ max_length = 512
597
+ if args.max_length is not None:
598
+ max_length = args.max_length
599
+
600
+ if args.max_prompt_length is None:
601
+ warnings.warn(
602
+ "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
603
+ "It will be set to `128` by default, but you should do it yourself in the future.",
604
+ UserWarning,
605
+ )
606
+ max_prompt_length = 128
607
+ if args.max_prompt_length is not None:
608
+ max_prompt_length = args.max_prompt_length
609
+
610
+ max_completion_length = None
611
+ if args.max_completion_length is None and self.is_encoder_decoder:
612
+ warnings.warn(
613
+ "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
614
+ " it will be set to `128` by default, but you should do it yourself in the future.",
615
+ UserWarning,
616
+ )
617
+ max_completion_length = 128
618
+ if args.max_completion_length is not None and self.is_encoder_decoder:
619
+ max_completion_length = args.max_completion_length
620
+
621
+ if data_collator is None:
622
+ data_collator = DPODataCollatorWithPadding(
623
+ pad_token_id=processing_class.pad_token_id,
624
+ label_pad_token_id=args.label_pad_token_id,
625
+ is_encoder_decoder=self.is_encoder_decoder,
626
+ )
627
+
628
+ if args.remove_unused_columns:
629
+ args.remove_unused_columns = False
630
+ # warn users
631
+ warnings.warn(
632
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
633
+ " we have set it for you, but you should do it yourself in the future.",
634
+ UserWarning,
635
+ )
636
+
637
+ self.use_dpo_data_collator = True
638
+ else:
639
+ self.use_dpo_data_collator = False
640
+
641
+ # Disable dropout in the model and reference model
642
+ if args.disable_dropout:
643
+ disable_dropout_in_model(model)
644
+ if self.ref_model is not None:
645
+ disable_dropout_in_model(self.ref_model)
646
+
647
+ self.max_length = max_length
648
+ self.generate_during_eval = args.generate_during_eval
649
+ self.label_pad_token_id = args.label_pad_token_id
650
+ self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
651
+ self.max_prompt_length = max_prompt_length
652
+ self.truncation_mode = args.truncation_mode
653
+ self.max_completion_length = max_completion_length
654
+ self.precompute_ref_log_probs = args.precompute_ref_log_probs
655
+
656
+ # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
657
+ # keep track of the first call to avoid recomputing them on later calls
658
+ self._precomputed_train_ref_log_probs = False
659
+ self._precomputed_eval_ref_log_probs = False
660
+
661
+ # metric
662
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
663
+
664
+ # BCO parameter
665
+ self.beta = args.beta
666
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
667
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
668
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
669
+ warnings.warn(
670
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
671
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
672
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
673
+ "loss.",
674
+ UserWarning,
675
+ )
676
+
677
+ # Underlying Distribution Matching argument
678
+ self.embedding_func = embedding_func
679
+ self.embedding_tokenizer = embedding_tokenizer
680
+
681
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
682
+ # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
683
+ # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
684
+ # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
685
+ # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
686
+ # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
687
+ # issued.
688
+ model.warnings_issued["estimate_tokens"] = True
689
+
690
+ with PartialState().local_main_process_first():
691
+ # Apply the chat template if needed
692
+ train_dataset = train_dataset.map(
693
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
694
+ )
695
+ if eval_dataset is not None:
696
+ eval_dataset = eval_dataset.map(
697
+ maybe_apply_chat_template,
698
+ fn_kwargs={"tokenizer": processing_class},
699
+ num_proc=args.dataset_num_proc,
700
+ )
701
+ # Shuffle the datasets
702
+ train_dataset = train_dataset.shuffle(seed=args.data_seed)
703
+ if eval_dataset is not None:
704
+ eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
705
+ # Tokenize and prepare the training datasets
706
+ train_dataset = train_dataset.map(
707
+ _tokenize,
708
+ batched=True,
709
+ fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
710
+ num_proc=args.dataset_num_proc,
711
+ desc="Tokenizing train dataset",
712
+ )
713
+
714
+ # Prepare the datasets
715
+ fn_kwargs = {
716
+ "prefix": "",
717
+ "is_encoder_decoder": self.is_encoder_decoder,
718
+ "tokenizer": processing_class,
719
+ "max_length": self.max_length,
720
+ "truncation_mode": self.truncation_mode,
721
+ "label_pad_token_id": self.label_pad_token_id,
722
+ "max_prompt_length": self.max_prompt_length,
723
+ "max_completion_length": self.max_completion_length,
724
+ }
725
+ train_dataset = train_dataset.map(
726
+ _process_tokens,
727
+ fn_kwargs=fn_kwargs,
728
+ num_proc=args.dataset_num_proc,
729
+ desc="Processing tokenized train dataset",
730
+ )
731
+
732
+ if eval_dataset is not None:
733
+ # Tokenize
734
+ eval_dataset = eval_dataset.map(
735
+ _tokenize,
736
+ fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
737
+ batched=True,
738
+ num_proc=args.dataset_num_proc,
739
+ desc="Tokenizing eval dataset",
740
+ )
741
+
742
+ # Process
743
+ fn_kwargs = {
744
+ "prefix": "",
745
+ "is_encoder_decoder": self.is_encoder_decoder,
746
+ "tokenizer": processing_class,
747
+ "max_length": self.max_length,
748
+ "truncation_mode": self.truncation_mode,
749
+ "label_pad_token_id": self.label_pad_token_id,
750
+ "max_prompt_length": self.max_prompt_length,
751
+ "max_completion_length": self.max_completion_length,
752
+ }
753
+ eval_dataset = eval_dataset.map(
754
+ _process_tokens,
755
+ fn_kwargs=fn_kwargs,
756
+ num_proc=args.dataset_num_proc,
757
+ desc="Processing tokenized eval dataset",
758
+ )
759
+
760
+ desirable = train_dataset.filter(
761
+ lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
762
+ )
763
+ undesirable = train_dataset.filter(
764
+ lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
765
+ )
766
+
767
+ desirable = desirable.shuffle(seed=args.data_seed)
768
+ undesirable = undesirable.shuffle(seed=args.data_seed)
769
+
770
+ super().__init__(
771
+ model=model,
772
+ args=args,
773
+ data_collator=data_collator,
774
+ train_dataset=train_dataset,
775
+ eval_dataset=eval_dataset,
776
+ processing_class=processing_class,
777
+ model_init=model_init,
778
+ compute_metrics=compute_metrics,
779
+ callbacks=callbacks,
780
+ optimizers=optimizers,
781
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
782
+ )
783
+
784
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
785
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
786
+ # self.model_accepts_loss_kwargs to False to enable scaling.
787
+ self.model_accepts_loss_kwargs = False
788
+
789
+ # Add tags for models that have been loaded with the correct transformers version
790
+ if hasattr(self.model, "add_model_tags"):
791
+ self.model.add_model_tags(self._tag_names)
792
+
793
+ if not hasattr(self, "accelerator"):
794
+ raise AttributeError(
795
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
796
+ )
797
+
798
+ # Deepspeed Zero-3 does not support precompute_ref_log_probs
799
+ if self.is_deepspeed_enabled:
800
+ if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
801
+ raise ValueError(
802
+ "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
803
+ )
804
+
805
+ if self.ref_model is None:
806
+ if not (self.is_peft_model or self.precompute_ref_log_probs):
807
+ raise ValueError(
808
+ "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
809
+ )
810
+ else:
811
+ if self.is_deepspeed_enabled:
812
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
813
+ else:
814
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
815
+
816
+ self.running = RunningMoments(accelerator=self.accelerator)
817
+
818
+ if self.embedding_func is None:
819
+ return
820
+
821
+ chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
822
+ rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
823
+
824
+ embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
825
+ labels = torch.cat(
826
+ (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
827
+ )
828
+
829
+ self.clf = LogisticRegression(class_weight="balanced").fit(
830
+ embeddings.cpu().float().numpy(), labels.cpu().numpy()
831
+ )
832
+
833
+ @property
834
+ def match_underlying_distribution(self):
835
+ return self.embedding_func is not None and self.embedding_tokenizer is not None
836
+
837
+ def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
838
+ """
839
+ Calculates the probability that the given prompt embedding comes from the desirable dataset.
840
+ The probability is computed on each process and then ensembled (averaged) across processes.
841
+ """
842
+ dtype = prompt_embeddings.dtype
843
+ device = prompt_embeddings.device
844
+ rank = self.accelerator.process_index
845
+
846
+ padded_prompt_embeddings = self.accelerator.pad_across_processes(
847
+ prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
848
+ )
849
+ sample_size = padded_prompt_embeddings.shape[0]
850
+ nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
851
+ prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
852
+
853
+ # cannot predict for all empty values
854
+ if prompt_embeddings.shape[0] == 0:
855
+ return torch.tensor([], device=device, dtype=dtype)
856
+
857
+ prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
858
+ prob = torch.as_tensor(prob, dtype=dtype, device=device)
859
+ prob = self.accelerator.reduce(prob, reduction="mean")
860
+
861
+ prob = prob[sample_size * rank : sample_size * (rank + 1)]
862
+ prob = prob[nonzero]
863
+
864
+ return prob
865
+
866
+ def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
867
+ """
868
+ Replaces processing_class.pad_token_id with embedding_tokenizer.pad_token_id
869
+ and applies self.embedding_func.
870
+ """
871
+ input_ids = torch.where(
872
+ input_ids == self.processing_class.pad_token_id,
873
+ self.embedding_tokenizer.pad_token_id,
874
+ input_ids,
875
+ )
876
+
877
+ with torch.no_grad():
878
+ embeddings = self.embedding_func(
879
+ input_ids=input_ids,
880
+ attention_mask=attention_mask,
881
+ )
882
+
883
+ return embeddings
884
+
885
+ def _get_prompt_embeddings(
886
+ self, batch: dict[str, Union[list, torch.LongTensor]]
887
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
888
+ """Extract embeddings from frozen embedding model"""
889
+
890
+ if not self.match_underlying_distribution:
891
+ return None, None
892
+
893
+ embeddings = self._vectorize_prompt(
894
+ input_ids=batch["embedding_input_ids"],
895
+ attention_mask=batch["embedding_attention_mask"],
896
+ )
897
+
898
+ chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
899
+ rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
900
+
901
+ chosen_embeddings = embeddings[chosen_idx, ...]
902
+ rejected_embeddings = embeddings[rejected_idx, ...]
903
+
904
+ return (chosen_embeddings, rejected_embeddings)
905
+
906
+ def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
907
+ """
908
+ Sample instances from dataset and get prompt embeddings.
909
+ Used for density ratio classifier training.
910
+ """
911
+ n_samples = min(len(dataset), sample_size)
912
+ rand_indices = np.random.choice(len(dataset), size=(n_samples,))
913
+
914
+ embedding_dataset = dataset.select(rand_indices)
915
+
916
+ dataloader_params = {
917
+ "batch_size": self.args.per_device_train_batch_size,
918
+ "collate_fn": self.data_collator,
919
+ "num_workers": self.args.dataloader_num_workers,
920
+ "pin_memory": self.args.dataloader_pin_memory,
921
+ "shuffle": False,
922
+ }
923
+
924
+ # prepare dataloader
925
+ data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
926
+
927
+ with torch.no_grad():
928
+ all_embeddings = torch.empty(0)
929
+ for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
930
+ embeddings = self._vectorize_prompt(
931
+ input_ids=padded_batch["embedding_input_ids"],
932
+ attention_mask=padded_batch["embedding_attention_mask"],
933
+ )
934
+ embeddings = self.accelerator.gather_for_metrics(embeddings)
935
+ all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
936
+
937
+ return all_embeddings
938
+
939
+ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
940
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
941
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
942
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
943
+
944
+ if model is not None:
945
+ if hasattr(model, "config"):
946
+ hidden_size = (
947
+ max(model.config.hidden_sizes)
948
+ if getattr(model.config, "hidden_sizes", None)
949
+ else getattr(model.config, "hidden_size", None)
950
+ )
951
+ if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
952
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
953
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
954
+ config_kwargs.update(
955
+ {
956
+ "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
957
+ "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
958
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
959
+ }
960
+ )
961
+
962
+ # If ZeRO-3 is used, we shard both the active and reference model.
963
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
964
+ if config_kwargs["zero_optimization"]["stage"] != 3:
965
+ config_kwargs["zero_optimization"]["stage"] = 0
966
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
967
+ model.eval()
968
+ return model
969
+
970
+ def _save_optimizer_and_scheduler(self, output_dir):
971
+ super()._save_optimizer_and_scheduler(output_dir)
972
+
973
+ # When saving optimizer and scheduler to checkpoint, save also the running delta object.
974
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
975
+
976
+ self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
977
+
978
+ if self.match_underlying_distribution:
979
+ torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
980
+
981
+ def _load_optimizer_and_scheduler(self, checkpoint):
982
+ super()._load_optimizer_and_scheduler(checkpoint)
983
+
984
+ if checkpoint is None:
985
+ return
986
+ # when loading optimizer and scheduler from checkpoint, also load the running delta object.
987
+ running_file = os.path.join(checkpoint, RUNNING_NAME)
988
+ if os.path.isfile(running_file):
989
+ self.running = RunningMoments.load_from_json(self.accelerator, running_file)
990
+
991
+ if self.match_underlying_distribution:
992
+ clf_file = os.path.join(checkpoint, CLF_NAME)
993
+ if os.path.isfile(clf_file):  # check the classifier file itself, not the running-moments file
994
+ self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
995
+
996
+ @contextmanager
997
+ def null_ref_context(self):
998
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
999
+ with (
1000
+ self.accelerator.unwrap_model(self.model).disable_adapter()
1001
+ if self.is_peft_model and not self.ref_adapter_name
1002
+ else nullcontext()
1003
+ ):
1004
+ if self.ref_adapter_name:
1005
+ self.model.set_adapter(self.ref_adapter_name)
1006
+ yield
1007
+ if self.ref_adapter_name:
1008
+ self.model.set_adapter(self.model_adapter_name or "default")
1009
+
1010
+ def get_train_dataloader(self) -> DataLoader:
1011
+ """
1012
+ Returns the training [`~torch.utils.data.DataLoader`].
1013
+
1014
+ Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
1015
+ """
1016
+
1017
+ if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
1018
+ dataloader_params = {
1019
+ "batch_size": self.args.per_device_train_batch_size,
1020
+ "collate_fn": self.data_collator,
1021
+ "num_workers": self.args.dataloader_num_workers,
1022
+ "pin_memory": self.args.dataloader_pin_memory,
1023
+ "shuffle": False,
1024
+ }
1025
+
1026
+ # prepare dataloader
1027
+ data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
1028
+ reference_completion_logps = []
1029
+
1030
+ for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
1031
+ reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1032
+
1033
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1034
+ reference_completion_logps.append(reference_completion_logp.cpu())
1035
+
1036
+ self.train_dataset = self.train_dataset.add_column(
1037
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1038
+ )
1039
+
1040
+ self._precomputed_train_ref_log_probs = True
1041
+
1042
+ return super().get_train_dataloader()
1043
+
1044
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
1045
+ """
1046
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
1047
+
1048
+ Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
1049
+
1050
+ Args:
1051
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
1052
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
1053
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
1054
+ """
1055
+ if eval_dataset is None and self.eval_dataset is None:
1056
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
1057
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
1058
+
1059
+ if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
1060
+ dataloader_params = {
1061
+ "batch_size": self.args.per_device_eval_batch_size,
1062
+ "collate_fn": self.data_collator,
1063
+ "num_workers": self.args.dataloader_num_workers,
1064
+ "pin_memory": self.args.dataloader_pin_memory,
1065
+ "shuffle": False,
1066
+ }
1067
+
1068
+ # prepare dataloader
1069
+ data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
1070
+
1071
+ reference_completion_logps = []
1072
+
1073
+ for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
1074
+ reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1075
+
1076
+ reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1077
+ reference_completion_logps.append(reference_completion_logp.cpu())
1078
+
1079
+ eval_dataset = eval_dataset.add_column(
1080
+ name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1081
+ )
1082
+
1083
+ # Save the computed reference_logps to self.eval_dataset so subsequent runs can reuse them
1084
+ if self.eval_dataset is not None:
1085
+ self.eval_dataset = eval_dataset
1086
+ self._precomputed_eval_ref_log_probs = True
1087
+
1088
+ return super().get_eval_dataloader(eval_dataset=eval_dataset)
1089
+
1090
+ def compute_reference_log_probs(self, padded_batch: dict) -> dict:
1091
+ """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset."""
1092
+ with torch.no_grad():
1093
+ if self.ref_model is None:
1094
+ with self.null_ref_context():
1095
+ if self.is_encoder_decoder:
1096
+ completion_logits = self.model(
1097
+ padded_batch["prompt_input_ids"],
1098
+ attention_mask=padded_batch["prompt_attention_mask"],
1099
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1100
+ labels=padded_batch["completion_labels"],
1101
+ ).logits
1102
+
1103
+ else:
1104
+ completion_logits = self.model(
1105
+ padded_batch["completion_input_ids"],
1106
+ attention_mask=padded_batch["completion_attention_mask"],
1107
+ ).logits
1108
+
1109
+ else:
1110
+ if self.is_encoder_decoder:
1111
+ completion_logits = self.ref_model(
1112
+ padded_batch["prompt_input_ids"],
1113
+ attention_mask=padded_batch["prompt_attention_mask"],
1114
+ decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1115
+ labels=padded_batch["completion_labels"],
1116
+ ).logits
1117
+
1118
+ else:
1119
+ completion_logits = self.ref_model(
1120
+ padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
1121
+ ).logits
1122
+
1123
+ completion_logps = self.get_batch_logps(
1124
+ completion_logits,
1125
+ padded_batch["completion_labels"],
1126
+ average_log_prob=False,
1127
+ is_encoder_decoder=self.is_encoder_decoder,
1128
+ label_pad_token_id=self.label_pad_token_id,
1129
+ )
1130
+
1131
+ return completion_logps
1132
+
1133
+ @staticmethod
1134
+ def get_batch_logps(
1135
+ logits: torch.FloatTensor,
1136
+ labels: torch.LongTensor,
1137
+ average_log_prob: bool = False,
1138
+ label_pad_token_id: int = -100,
1139
+ is_encoder_decoder: bool = False,
1140
+ ) -> torch.FloatTensor:
1141
+ """Compute the log probabilities of the given labels under the given logits.
1142
+
1143
+ Args:
1144
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1145
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1146
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1147
+
1148
+ Returns:
1149
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1150
+ """
1151
+ if logits.shape[:-1] != labels.shape:
1152
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1153
+
1154
+ if not is_encoder_decoder:
1155
+ labels = labels[:, 1:].clone()
1156
+ logits = logits[:, :-1, :]
1157
+ else:
1158
+ # Fixes enc-dec (encoder-decoder) RuntimeError
1159
+ labels = labels.clone()
1160
+
1161
+ loss_mask = labels != label_pad_token_id
1162
+
1163
+ # dummy token; we'll ignore the losses on these tokens later
1164
+ labels[labels == label_pad_token_id] = 0
1165
+
1166
+ per_token_logps = selective_log_softmax(logits, labels)
1167
+
1168
+ if average_log_prob:
1169
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1170
+ else:
1171
+ return (per_token_logps * loss_mask).sum(-1)
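+ # Illustrative sketch (added for clarity; not part of the generated file): with the default
+ # label_pad_token_id = -100, padded positions are masked out of the per-sequence sum, e.g.
+ #
+ #   logits = torch.randn(1, 5, 8)                      # (batch, seq_len, vocab)
+ #   labels = torch.tensor([[3, 7, 2, -100, -100]])     # last two positions are padding
+ #   logps  = _UnslothBCOTrainer.get_batch_logps(logits, labels)
+ #   # For decoder-only models the labels are shifted by one internally, so only the
+ #   # non-padded, shifted positions contribute to the returned (batch_size,) sums.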
1172
+
1173
+ def forward(
1174
+ self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1175
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1176
+ model_kwargs = (
1177
+ {
1178
+ "labels": batch["completion_labels"],
1179
+ "decoder_input_ids": batch.get("completion_decoder_input_ids"),
1180
+ }
1181
+ if self.is_encoder_decoder
1182
+ else {}
1183
+ )
1184
+ if self.aux_loss_enabled:
1185
+ model_kwargs["output_router_logits"] = True
1186
+
1187
+ outputs = model(
1188
+ batch["completion_input_ids"],
1189
+ attention_mask=batch["completion_attention_mask"],
1190
+ **model_kwargs,
1191
+ )
1192
+ completion_logits = outputs.logits
1193
+
1194
+ completion_logps = self.get_batch_logps(
1195
+ completion_logits,
1196
+ batch["completion_labels"],
1197
+ average_log_prob=False,
1198
+ is_encoder_decoder=self.is_encoder_decoder,
1199
+ label_pad_token_id=self.label_pad_token_id,
1200
+ )
1201
+
1202
+ if completion_logps.shape[0] != len(batch["label"]):
1203
+ raise ValueError(
1204
+ "There is a mismatch between the number of examples in this batch and the number of "
1205
+ "examples for which an output sequence was predicted."
1206
+ )
1207
+
1208
+ chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
1209
+ rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
1210
+
1211
+ chosen_logps = completion_logps[chosen_idx, ...]
1212
+ rejected_logps = completion_logps[rejected_idx, ...]
1213
+
1214
+ chosen_logits = completion_logits[chosen_idx, ...]
1215
+ rejected_logits = completion_logits[rejected_idx, ...]
1216
+
1217
+ if self.aux_loss_enabled:
1218
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
1219
+ else:
1220
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
1221
+
1222
+ def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
1223
+ prob_desirable = self._get_chosen_prob(rejected_embeddings)
1224
+ min_ratio = self.args.min_density_ratio
1225
+ max_ratio = self.args.max_density_ratio
1226
+
1227
+ weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
1228
+
1229
+ return weight
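+ # Illustrative note (added for clarity; not part of the generated file): the UDM weight is the
+ # clamped density ratio p / (1 - p), where p is the classifier's probability that a rejected
+ # prompt comes from the desirable distribution. With the default min/max ratios of 0.5 and 10.0:
+ #
+ #   p = torch.tensor([0.05, 0.50, 0.99])
+ #   w = (p / (1 - p + 1e-8)).clamp(min=0.5, max=10.0)
+ #   # w -> tensor([0.5000, 1.0000, 10.0000])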
1230
+
1231
+ def bco_loss(
1232
+ self,
1233
+ policy_chosen_logps: torch.FloatTensor,
1234
+ policy_rejected_logps: torch.FloatTensor,
1235
+ reference_chosen_logps: torch.FloatTensor,
1236
+ reference_rejected_logps: torch.FloatTensor,
1237
+ chosen_embeddings: Optional[torch.FloatTensor],
1238
+ rejected_embeddings: Optional[torch.FloatTensor],
1239
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1240
+ """Compute the BCO loss for a batch of policy and reference model log probabilities.
1241
+
1242
+ Args:
1243
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
1244
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
1245
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
1246
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
1247
+ chosen_embeddings: embeddings of desirable prompts
1248
+ rejected_embeddings: embeddings of undesirable prompts
1249
+
1250
+ Returns:
1251
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
1252
+ The losses tensor contains the BCO loss for each example in the batch.
1253
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
1254
+ The delta value contains the moving average of all implicit rewards.
1255
+ """
1256
+
1257
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1258
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
1259
+ chosen_rewards = self.beta * chosen_logratios
1260
+ else:
1261
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1262
+ chosen_losses = torch.Tensor([]).to(self.accelerator.device)
1263
+ chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
1264
+
1265
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1266
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
1267
+ rejected_rewards = self.beta * rejected_logratios
1268
+ else:
1269
+ # lists can't be empty -- if they are, then accelerate.gather will hang
1270
+ rejected_losses = torch.Tensor([]).to(self.accelerator.device)
1271
+ rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
1272
+
1273
+ rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
1274
+ self.running.update(rewards)
1275
+ delta = self.running.mean
1276
+
1277
+ if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1278
+ chosen_losses = -F.logsigmoid(chosen_rewards - delta)
1279
+
1280
+ if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1281
+ rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
1282
+
1283
+ if self.match_underlying_distribution:
1284
+ chosen_weight = torch.ones_like(chosen_losses)
1285
+ rejected_weight = self._get_udm_weight(rejected_embeddings)
1286
+
1287
+ losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
1288
+ else:
1289
+ losses = torch.cat((chosen_losses, rejected_losses), dim=0)
1290
+
1291
+ return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
1292
+
1293
+ def get_batch_loss_metrics(
1294
+ self,
1295
+ model,
1296
+ batch: dict[str, Union[list, torch.LongTensor]],
1297
+ ):
1298
+ """Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
1299
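+ # Flow: (1) policy forward pass on the batch, (2) reference log-probs taken from the batch
+ # if precomputed, otherwise recomputed with the reference model (or the adapter-disabled
+ # policy), (3) optional prompt embeddings for UDM weighting, (4) BCO loss and summed metrics.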
+ metrics = {}
1300
+ batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1301
+
1302
+ forward_output = self.forward(model, batch)
1303
+ (
1304
+ policy_chosen_logps,
1305
+ policy_rejected_logps,
1306
+ policy_chosen_logits,
1307
+ policy_rejected_logits,
1308
+ ) = forward_output[:4]
1309
+ if self.aux_loss_enabled:
1310
+ aux_loss = forward_output[4]
1311
+
1312
+ # if reference_logps in batch use them, otherwise use the reference model
1313
+ if "reference_logps" in batch:
1314
+ chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1315
+ rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1316
+
1317
+ reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1318
+ reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1319
+ else:
1320
+ with torch.no_grad():
1321
+ if self.ref_model is None:
1322
+ with self.null_ref_context():
1323
+ (
1324
+ reference_chosen_logps,
1325
+ reference_rejected_logps,
1326
+ _,
1327
+ _,
1328
+ ) = self.forward(self.model, batch)[:4]
1329
+ else:
1330
+ (
1331
+ reference_chosen_logps,
1332
+ reference_rejected_logps,
1333
+ _,
1334
+ _,
1335
+ ) = self.forward(self.ref_model, batch)[:4]
1336
+
1337
+ chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
1338
+
1339
+ losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
1340
+ policy_chosen_logps,
1341
+ policy_rejected_logps,
1342
+ reference_chosen_logps,
1343
+ reference_rejected_logps,
1344
+ chosen_embeddings,
1345
+ rejected_embeddings,
1346
+ )
1347
+ metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
1348
+
1349
+ num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1350
+ num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1351
+
1352
+ all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
1353
+ all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
1354
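+ # Only per-process sums and counts are stored here; `log()` later divides the gathered
+ # sums by the gathered counts to recover mean rewards/logps/logits across processes.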
+
1355
+ if all_num_chosen > 0:
1356
+ metrics["rewards/chosen_sum"] = (
1357
+ self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
1358
+ )
1359
+ metrics["logps/chosen_sum"] = (
1360
+ self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
1361
+ )
1362
+ metrics["logits/chosen_sum"] = (
1363
+ self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
1364
+ )
1365
+ metrics["count/chosen"] = all_num_chosen
1366
+
1367
+ if all_num_rejected > 0:
1368
+ metrics["rewards/rejected_sum"] = (
1369
+ self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
1370
+ )
1371
+ metrics["logps/rejected_sum"] = (
1372
+ self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
1373
+ )
1374
+ metrics["logits/rejected_sum"] = (
1375
+ self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
1376
+ )
1377
+ metrics["count/rejected"] = all_num_rejected
1378
+
1379
+ loss = losses.nanmean()
1380
+ if self.aux_loss_enabled:
1381
+ loss += self.aux_loss_coef * aux_loss
1382
+
1383
+ return loss, metrics
1384
+
1385
+ def compute_loss(
1386
+ self,
1387
+ model: Union[PreTrainedModel, nn.Module],
1388
+ inputs: dict[str, Union[torch.Tensor, Any]],
1389
+ return_outputs=False,
1390
+ num_items_in_batch=None,
1391
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1392
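+ # When the PEFT adapters were cast to bf16, wrap the loss computation in autocast;
+ # otherwise this context manager is a no-op.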
+ compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1393
+
1394
+ with compute_loss_context_manager:
1395
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1396
+
1397
+ # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1398
+ loss = loss.to(self.args.device)
1399
+ # force log the metrics
1400
+ if self.accelerator.is_main_process:
1401
+ self.store_metrics(metrics, train_eval="train")
1402
+
1403
+ if return_outputs:
1404
+ return (loss, metrics)
1405
+ return loss
1406
+
1407
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1408
+ for key, value in metrics.items():
1409
+ self._stored_metrics[train_eval][key].append(value)
1410
+
1411
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
1412
+ if self.train_dataset is None or not has_length(self.train_dataset):
1413
+ return None
1414
+ return SequentialSampler(self.train_dataset)
1415
+
1416
+ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
1417
+ """Generate samples from the model and reference model for the given batch of inputs."""
1418
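+ # Completions are sampled from both the policy and the (possibly adapter-disabled)
+ # reference model so that `evaluation_loop` can log them side by side.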
+
1419
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1420
+ # the torch cuda amp context manager as some hidden states are silently cast to full precision.
1421
+ generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1422
+ with generate_context_manager:
1423
+ policy_output = model.generate(
1424
+ input_ids=batch["prompt_input_ids"],
1425
+ attention_mask=batch["prompt_attention_mask"],
1426
+ max_length=self.max_length,
1427
+ do_sample=True,
1428
+ pad_token_id=self.processing_class.pad_token_id,
1429
+ )
1430
+
1431
+ # if reference_output in batch use that otherwise use the reference model
1432
+ if "reference_output" in batch:
1433
+ reference_output = batch["reference_output"]
1434
+ else:
1435
+ if self.ref_model is None:
1436
+ with self.null_ref_context():
1437
+ reference_output = self.model.generate(
1438
+ input_ids=batch["prompt_input_ids"],
1439
+ attention_mask=batch["prompt_attention_mask"],
1440
+ max_length=self.max_length,
1441
+ do_sample=True,
1442
+ pad_token_id=self.processing_class.pad_token_id,
1443
+ )
1444
+ else:
1445
+ reference_output = self.ref_model.generate(
1446
+ input_ids=batch["prompt_input_ids"],
1447
+ attention_mask=batch["prompt_attention_mask"],
1448
+ max_length=self.max_length,
1449
+ do_sample=True,
1450
+ pad_token_id=self.processing_class.pad_token_id,
1451
+ )
1452
+
1453
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1454
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1455
+
1456
+ reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
1457
+ reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
1458
+
1459
+ return policy_output_decoded, reference_output_decoded
1460
+
1461
+ def prediction_step(
1462
+ self,
1463
+ model: Union[PreTrainedModel, nn.Module],
1464
+ inputs: dict[str, Union[torch.Tensor, Any]],
1465
+ prediction_loss_only: bool,
1466
+ ignore_keys: Optional[list[str]] = None,
1467
+ ):
1468
+ if ignore_keys is None:
1469
+ if hasattr(model, "config"):
1470
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1471
+ else:
1472
+ ignore_keys = []
1473
+
1474
+ prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1475
+ with torch.no_grad(), prediction_context_manager:
1476
+ loss, metrics = self.get_batch_loss_metrics(model, inputs)
1477
+
1478
+ # force log the metrics
1479
+ if self.accelerator.is_main_process:
1480
+ self.store_metrics(metrics, train_eval="eval")
1481
+
1482
+ if prediction_loss_only:
1483
+ return (loss.detach(), None, None)
1484
+
1485
+ # logits for the chosen and rejected samples from model
1486
+ logits_dict = {
1487
+ "eval_logits/chosen": metrics["logits/chosen"],
1488
+ "eval_logits/rejected": metrics["logits/rejected"],
1489
+ }
1490
+ logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
1491
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
1492
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1493
+
1494
+ return (loss.detach(), logits, labels)
1495
+
1496
+ def evaluation_loop(
1497
+ self,
1498
+ dataloader: DataLoader,
1499
+ description: str,
1500
+ prediction_loss_only: Optional[bool] = None,
1501
+ ignore_keys: Optional[list[str]] = None,
1502
+ metric_key_prefix: str = "eval",
1503
+ ) -> EvalLoopOutput:
1504
+ """
1505
+ Overriding built-in evaluation loop to store metrics for each batch.
1506
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1507
+
1508
+ Works both with or without labels.
1509
+ """
1510
+
1511
+ # Sample and save to game log if requested (for one batch to save time)
1512
+ if self.generate_during_eval:
1513
+ # Generate random indices within the range of the total number of samples
1514
+ num_samples = len(dataloader.dataset)
1515
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1516
+
1517
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1518
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1519
+ random_batch = self.data_collator(random_batch_dataset)
1520
+ random_batch = self._prepare_inputs(random_batch)
1521
+
1522
+ target_indices = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
1523
+ target_batch = {
1524
+ "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
1525
+ "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
1526
+ "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
1527
+ }
1528
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
1529
+
1530
+ table = pd.DataFrame(
1531
+ columns=["Prompt", "Policy", "Ref Model"],
1532
+ data=[
1533
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1534
+ for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
1535
+ ],
1536
+ )
1537
+ if "wandb" in self.args.report_to:
1538
+ wandb.log({"game_log": wandb.Table(data=table)})
1539
+
1540
+ if "comet_ml" in self.args.report_to:
1541
+ log_table_to_comet_experiment(
1542
+ name="game_log.csv",
1543
+ table=table,
1544
+ )
1545
+
1546
+ # Base evaluation
1547
+ initial_output = super().evaluation_loop(
1548
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1549
+ )
1550
+
1551
+ return initial_output
1552
+
1553
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1554
+ """
1555
+ Log `logs` on the various objects watching training, including stored metrics.
1556
+
1557
+ Args:
1558
+ logs (`dict[str, float]`):
1559
+ The values to log.
1560
+ start_time (`float` or `None`, *optional*, defaults to `None`):
1561
+ Start time of the training.
1562
+ """
1563
+ # logs either has 'loss' or 'eval_loss'
1564
+ train_eval = "train" if "loss" in logs else "eval"
1565
+ # train metrics should have no prefix, eval should have 'eval_'
1566
+ prefix = "eval_" if train_eval == "eval" else ""
1567
+ # accumulate average metrics from sums and lengths
1568
+ for split in ["chosen", "rejected"]:
1569
+ if f"count/{split}" in self._stored_metrics[train_eval]:
1570
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1571
+ for metric in ["rewards", "logps", "logits"]:
1572
+ logs[f"{prefix}{metric}/{split}"] = (
1573
+ torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
1574
+ / count_sum
1575
+ )
1576
+ # delete obsolete metric
1577
+ del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
1578
+ del self._stored_metrics[train_eval][f"count/{split}"]
1579
+ # calculate reward margin
1580
+ if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
1581
+ logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
1582
+ # Add averaged stored metrics to logs
1583
+ for key, metrics in self._stored_metrics[train_eval].items():
1584
+ logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
1585
+ del self._stored_metrics[train_eval]
1586
+
1587
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1588
+ return super().log(logs, start_time)
1589
+ else: # transformers<=4.46
1590
+ return super().log(logs)
1591
+
1592
+ def create_model_card(
1593
+ self,
1594
+ model_name: Optional[str] = None,
1595
+ dataset_name: Optional[str] = None,
1596
+ tags: Union[str, list[str], None] = None,
1597
+ ):
1598
+ """
1599
+ Creates a draft of a model card using the information available to the `Trainer`.
1600
+
1601
+ Args:
1602
+ model_name (`str` or `None`, *optional*, defaults to `None`):
1603
+ Name of the model.
1604
+ dataset_name (`str` or `None`, *optional*, defaults to `None`):
1605
+ Name of the dataset used for training.
1606
+ tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1607
+ Tags to be associated with the model card.
1608
+ """
1609
+ if not self.is_world_process_zero():
1610
+ return
1611
+
1612
+ if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1613
+ base_model = self.model.config._name_or_path
1614
+ else:
1615
+ base_model = None
1616
+
1617
+ tags = tags or []
1618
+ if isinstance(tags, str):
1619
+ tags = [tags]
1620
+
1621
+ if hasattr(self.model.config, "unsloth_version"):
1622
+ tags.append("unsloth")
1623
+
1624
+ citation = textwrap.dedent("""\
1625
+ @article{jung2024binary,
1626
+ title = {{Binary Classifier Optimization for Large Language Model Alignment}},
1627
+ author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
1628
+ year = 2024,
1629
+ eprint = {arXiv:2404.04656}
1630
+ }""")
1631
+
1632
+ model_card = generate_model_card(
1633
+ base_model=base_model,
1634
+ model_name=model_name,
1635
+ hub_model_id=self.hub_model_id,
1636
+ dataset_name=dataset_name,
1637
+ tags=tags,
1638
+ wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1639
+ comet_url=get_comet_experiment_url(),
1640
+ trainer_name="BCO",
1641
+ trainer_citation=citation,
1642
+ paper_title="Binary Classifier Optimization for Large Language Model Alignment",
1643
+ paper_id="2404.04656",
1644
+ )
1645
+
1646
+ model_card.save(os.path.join(self.args.output_dir, "README.md"))
1647
+ class UnslothBCOTrainer(_UnslothBCOTrainer):
1648
+ """
1649
+
1650
+ Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
1651
+
1652
+ Args:
1653
+ model (`transformers.PreTrainedModel`):
1654
+ The model to train, preferably an `AutoModelForCausalLM`.
1655
+ ref_model (`PreTrainedModelWrapper`):
1656
+ Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
1657
+ reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1658
+ args (`BCOConfig`):
1659
+ The arguments to use for training.
1660
+ train_dataset (`datasets.Dataset`):
1661
+ The dataset to use for training.
1662
+ eval_dataset (`datasets.Dataset`):
1663
+ The dataset to use for evaluation.
1664
+ processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1665
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
1666
+ for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1667
+ reuse the fine-tuned model.
1668
+ data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1669
+ The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1670
+ which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1671
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
1672
+ The model initializer to use for training. If None is specified, the default model initializer will be used.
1673
+ callbacks (`list[transformers.TrainerCallback]`):
1674
+ The callbacks to use for training.
1675
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1676
+ The optimizer and scheduler to use for training.
1677
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1678
+ The function to use to preprocess the logits before computing the metrics.
1679
+ peft_config (`dict`, defaults to `None`):
1680
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1681
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1682
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return
1683
+ a dictionary string to metric values.
1684
+ model_adapter_name (`str`, defaults to `None`):
1685
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1686
+ ref_adapter_name (`str`, defaults to `None`):
1687
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1688
+
1689
+ """
1690
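+ # The wrapper below only normalizes precision, evaluation and data-collator settings,
+ # patches RL statistics logging, and then delegates to `_UnslothBCOTrainer.__init__`.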
+ def __init__(
1691
+ self,
1692
+ model = None,
1693
+ ref_model = None,
1694
+ args = None,
1695
+ train_dataset = None,
1696
+ eval_dataset = None,
1697
+ processing_class = None,
1698
+ data_collator = None,
1699
+ model_init = None,
1700
+ callbacks = None,
1701
+ preprocess_logits_for_metrics = None,
1702
+ peft_config = None,
1703
+ compute_metrics = None,
1704
+ model_adapter_name = None,
1705
+ ref_adapter_name = None,
1706
+ embedding_func = None,
1707
+ embedding_tokenizer = None,
1708
+ **kwargs
1709
+ ):
1710
+ if args is None: args = UnslothBCOConfig()
1711
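+ # Precision handling: UNSLOTH_FORCE_FLOAT32=1 disables mixed precision entirely;
+ # otherwise fp16/bf16 flags are aligned with the model's dtype and
+ # ACCELERATE_MIXED_PRECISION is set to match.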
+ use_bf16 = getattr(args, 'bf16', False)
1712
+ use_fp16 = getattr(args, 'fp16', False)
1713
+ force_float32 = False
1714
+ if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1715
+ print('Unsloth: Switching to float32 training since model cannot work with float16')
1716
+ force_float32 = True
1717
+ mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1718
+ dtype = getattr(model.config, 'torch_dtype', None)
1719
+ if dtype is None: dtype = model.get_input_embeddings().dtype
1720
+ from unsloth_zoo.utils import _get_dtype
1721
+ dtype = _get_dtype(dtype)
1722
+ float16 = dtype == torch.float16
1723
+ if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1724
+ if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1725
+ if force_float32:
1726
+ args.fp16 = False
1727
+ args.bf16 = False
1728
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1729
+ elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1730
+ args.fp16 = float16
1731
+ args.bf16 = not float16
1732
+ os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1733
+ if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1734
+ args.eval_strategy = 'steps'
1735
+ if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1736
+ ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1737
+ if ga_steps is not None and ga_steps > 1:
1738
+ from transformers import __version__ as transformers_version
1739
+ if Version(transformers_version) <= Version('4.45.2'):
1740
+ print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1741
+ '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1742
+ if getattr(args, 'eval_strategy', 'no') != 'no':
1743
+ eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1744
+ if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1745
+ if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1746
+ fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1747
+ bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1748
+ if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1749
+ if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1750
+ if force_float32:
1751
+ args.bf16_full_eval = False
1752
+ args.fp16_full_eval = False
1753
+ elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1754
+ args.bf16_full_eval = True
1755
+ args.fp16_full_eval = False
1756
+ elif not bf16_full_eval and not fp16_full_eval:
1757
+ args.bf16_full_eval = args.bf16
1758
+ args.fp16_full_eval = args.fp16
1759
+ _output_logits = False
1760
+ if locals().get('compute_metrics', None) is not None: _output_logits = True
1761
+ if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1762
+ if _output_logits:
1763
+ os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1764
+ if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1765
+ pass
1766
+ else:
1767
+ model_max_seq_length = getattr(model, 'max_seq_length', None)
1768
+ args_max_seq_length = getattr(args, 'max_seq_length', None)
1769
+ if args_max_seq_length is None and model_max_seq_length is not None:
1770
+ max_seq_length = model.max_seq_length
1771
+ if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1772
+ if model is not None and hasattr(model, 'for_training'):
1773
+ model.for_training()
1774
+ if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1775
+ if 'processing_class' in locals():
1776
+ if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1777
+ if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1778
+ __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1779
+ from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1780
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1781
+ if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1782
+ data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)
1783
+ elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1784
+ data_collator = DataCollatorForSeq2Seq(__tokenizer)
1785
+ else:
1786
+ if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1787
+ if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1788
+ if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1789
+ if not isinstance(data_collator, UnslothVisionDataCollator):
1790
+ if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1791
+ if isinstance(data_collator, DataCollatorForSeq2Seq):
1792
+ data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1793
+ else:
1794
+ data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)
1795
+ other_metrics = []
1796
+
1797
+ from unsloth_zoo.logging_utils import PatchRLStatistics
1798
+ PatchRLStatistics('bco_trainer', other_metrics)
1799
+
1800
+ super().__init__(
1801
+ model = model,
1802
+ ref_model = ref_model,
1803
+ args = args,
1804
+ train_dataset = train_dataset,
1805
+ eval_dataset = eval_dataset,
1806
+ processing_class = processing_class,
1807
+ data_collator = data_collator,
1808
+ model_init = model_init,
1809
+ callbacks = callbacks,
1810
+ preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1811
+ peft_config = peft_config,
1812
+ compute_metrics = compute_metrics,
1813
+ model_adapter_name = model_adapter_name,
1814
+ ref_adapter_name = ref_adapter_name,
1815
+ embedding_func = embedding_func,
1816
+ embedding_tokenizer = embedding_tokenizer,**kwargs)
1817
+ if hasattr(self, 'neftune_hook_handle'):
1818
+ self.neftune_hook_handle.remove()
1819
+ if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1820
+ if getattr(args, 'neftune_noise_alpha', None) is not None:
1821
+ model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1822
+ pass
1823
+
1824
+ pass
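+
+ # Usage sketch, kept as a comment so importing this compiled module stays side-effect free.
+ # Model name, dataset and hyperparameters below are illustrative placeholders, not values
+ # used elsewhere in this repo:
+ #
+ #   from unsloth import FastLanguageModel
+ #   from datasets import load_dataset
+ #
+ #   model, tokenizer = FastLanguageModel.from_pretrained("unsloth/gemma-3-12b-it", load_in_4bit = True)
+ #   # BCO expects unpaired preference data with "prompt", "completion" and "label" columns.
+ #   dataset = load_dataset("trl-lib/kto-mix-14k", split = "train")
+ #   trainer = UnslothBCOTrainer(
+ #       model = model,
+ #       args = UnslothBCOConfig(output_dir = "outputs", per_device_train_batch_size = 2),
+ #       train_dataset = dataset,
+ #       processing_class = tokenizer,
+ #   )
+ #   trainer.train()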