nroggendorff committed
Commit 7927c3b · verified · 1 Parent(s): 8412424

Update train.py

Files changed (1):
  1. train.py +39 -44
train.py CHANGED
@@ -14,7 +14,7 @@ def load_model(model_name, device_id=0):
         load_in_4bit=True,
         bnb_4bit_compute_dtype=torch.bfloat16,
         bnb_4bit_quant_type="nf4",
-        bnb_4bit_use_double_quant=True,
+        bnb_4bit_use_double_quant=False,
     )
 
     processor = AutoProcessor.from_pretrained(model_name)
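
Note: this hunk turns nested (double) quantization off. Double quantization also quantizes the quantization constants themselves, saving roughly 0.4 bits per parameter at the cost of an extra dequantization step; disabling it trades that memory back for a little speed. For reference, a self-contained sketch of the same 4-bit setup:

    import torch
    from transformers import BitsAndBytesConfig

    # NF4 4-bit weights with bf16 compute, double quantization disabled,
    # mirroring the config after this commit.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    )
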
@@ -24,18 +24,16 @@ def load_model(model_name, device_id=0):
         quantization_config=bnb_config,
         dtype=torch.bfloat16,
         device_map={"": device_id},
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
     )
 
     return processor, model
 
 
-processed_count = 0
-
 def caption_batch(batch, processor, model):
-    global processed_count
-
     images = batch["image"]
-
+
    pil_images = []
     for image in images:
         if not isinstance(image, Image.Image):
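
Note: the model is now loaded with FlashAttention 2, and in recent transformers releases torch_dtype and dtype are aliases of the same from_pretrained argument. A minimal sketch of the loading path, assuming an AutoModelForVision2Seq-style class (the actual model class isn't visible in this diff; quantization config omitted for brevity):

    import torch
    from transformers import AutoModelForVision2Seq, AutoProcessor

    def load_model_sketch(model_name, device_id=0):
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForVision2Seq.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,               # flash-attn needs fp16/bf16
            attn_implementation="flash_attention_2",  # requires the flash-attn package
            device_map={"": device_id},               # pin the whole model to one GPU
        )
        return processor, model
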
@@ -44,56 +42,51 @@ def caption_batch(batch, processor, model):
         image = image.convert("RGB")
         pil_images.append(image)
 
-    messages_list = []
-    for pil_image in pil_images:
-        msg = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image"},
-                    {"type": "text", "text": "Describe the image, and skip mentioning that it's illustrated or from anime."},
-                ],
-            }
-        ]
-        messages_list.append(msg)
-
-    texts = processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False)
-
-    inputs = processor(
-        text=texts,
-        images=pil_images,
-        return_tensors="pt",
-        padding=True
+    msg = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {
+                    "type": "text",
+                    "text": "Describe the image concisely, and skip mentioning that it's illustrated or from anime.",
+                },
+            ],
+        }
+    ]
+
+    text = processor.apply_chat_template(
+        msg, add_generation_prompt=True, tokenize=False
     )
+    texts = [text] * len(pil_images)
+
+    inputs = processor(text=texts, images=pil_images, return_tensors="pt", padding=True)
 
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
-    with torch.no_grad():
+    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
         generated = model.generate(
             **inputs,
-            max_new_tokens=256,
+            max_new_tokens=128,
+            do_sample=False,
+            use_cache=True,
         )
 
     decoded = processor.batch_decode(generated, skip_special_tokens=False)
 
     captions = []
+    special_tokens = set(processor.tokenizer.all_special_tokens)
     for d in decoded:
         if "<|im_start|>assistant" in d:
-            d = d.split("<|im_start|>assistant")[-1].strip()
-
-        special_tokens = set(processor.tokenizer.all_special_tokens)
+            d = d.split("<|im_start|>assistant")[-1]
+
         for token in special_tokens:
             d = d.replace(token, "")
-
+
         d = d.strip()
         captions.append(d)
 
-    processed_count += len(images)
-    if processed_count > 100:
-        print(f"Processed {processed_count} examples so far...")
-
     return {
-        "image": images,
         "text": captions,
     }
 
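Note: instead of templating one conversation per image, the rewrite renders the chat template once and broadcasts the resulting prompt string across the batch, then generates greedily under no_grad. A condensed sketch of that pattern for any chat-style vision-language processor (names are generic):

    import torch

    def caption_sketch(processor, model, pil_images, prompt):
        # Template a single-image user turn once...
        msg = [{"role": "user",
                "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
        text = processor.apply_chat_template(msg, add_generation_prompt=True, tokenize=False)
        # ...and reuse the same prompt string for every image in the batch.
        inputs = processor(text=[text] * len(pil_images), images=pil_images,
                           return_tensors="pt", padding=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=128, do_sample=False)
        return processor.batch_decode(out, skip_special_tokens=True)

One caveat: torch.cuda.amp.autocast is deprecated in current PyTorch in favor of torch.amp.autocast("cuda", dtype=...).
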
@@ -101,9 +94,6 @@ def caption_batch(batch, processor, model):
 def process_shard_worker(
     gpu_id, start, end, model_name, batch_size, input_dataset, output_file
 ):
-    global processed_count
-    processed_count = 0
-
     torch.cuda.set_device(gpu_id)
 
     print(f"[GPU {gpu_id}] Loading model...", flush=True)
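
Note: process_shard_worker pins each process to one GPU and handles a half-open [start, end) slice of the dataset; dropping the global counter removes per-process mutable state. The driver that spawns these workers isn't part of this diff; a minimal sketch of that pattern (argument order matches the worker signature above, output file naming is a hypothetical):

    import multiprocessing as mp

    def launch_workers_sketch(num_gpus, dataset_len, model_name, batch_size, input_dataset):
        ctx = mp.get_context("spawn")  # CUDA generally requires the spawn start method
        per_gpu = (dataset_len + num_gpus - 1) // num_gpus
        procs = []
        for gpu_id in range(num_gpus):
            start = gpu_id * per_gpu
            end = min(start + per_gpu, dataset_len)
            p = ctx.Process(
                target=process_shard_worker,
                args=(gpu_id, start, end, model_name, batch_size,
                      input_dataset, f"shard_{gpu_id}.arrow"),  # hypothetical file name
            )
            p.start()
            procs.append(p)
        for p in procs:
            p.join()
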
@@ -117,12 +107,17 @@ def process_shard_worker(
     else:
         shard = cast(Dataset, loaded)
 
+    shard = shard.with_format("torch")
+    shard.set_format(type="torch", columns=["image"])
+
     print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
     result = shard.map(
         lambda batch: caption_batch(batch, processor, model),
         batched=True,
         batch_size=batch_size,
-        remove_columns=shard.column_names,
+        remove_columns=[col for col in shard.column_names if col != "image"],
+        writer_batch_size=1000,
+        keep_in_memory=True,
     )
 
     print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
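
Note: remove_columns now preserves the image column, so caption_batch no longer has to echo images back in its return value. A toy example of how Dataset.map merges returned keys with surviving columns (data is made up):

    from datasets import Dataset

    ds = Dataset.from_dict({"image": [0, 1, 2], "meta": ["a", "b", "c"]})

    # Columns listed in remove_columns are dropped after the function runs;
    # keys of the returned dict become new columns. Leaving "image" out of
    # the list keeps it alongside the new "text" column.
    out = ds.map(
        lambda batch: {"text": [f"caption {i}" for i in batch["image"]]},
        batched=True,
        batch_size=2,
        remove_columns=["meta"],
    )
    print(out.column_names)  # ['image', 'text']
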
@@ -134,9 +129,9 @@ def process_shard_worker(
 
 def main():
     input_dataset = "none-yet/anime-captions"
-    output_dataset = input_dataset
+    output_dataset = "nroggendorff/anime-captions"
     model_name = "datalab-to/chandra"
-    batch_size = 12
+    batch_size = 32
 
     print("Loading dataset info...")
     loaded = datasets.load_dataset(input_dataset, split="train")
@@ -182,7 +177,7 @@ def main():
 
     print(f"Final dataset size: {len(final_ds)}")
     print("Pushing to hub...")
-    final_ds.push_to_hub(output_dataset, create_pr=True)
+    final_ds.push_to_hub(output_dataset, create_pr=False)
 
     print("Cleaning up temporary files...")
     for f in temp_files:
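
Note: with create_pr=False, push_to_hub commits directly to the repository's default branch instead of opening a Hub pull request, which fits the switch to a repo the committer controls. Usage sketch (repo id hypothetical):

    from datasets import Dataset

    ds = Dataset.from_dict({"text": ["example caption"]})
    # create_pr=True would open a reviewable pull request on the Hub;
    # create_pr=False writes straight to the default branch.
    ds.push_to_hub("some-user/some-dataset", create_pr=False)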