nroggendorff committed on
Commit 789627e · verified · 1 Parent(s): c79fe94

Update train.py

Files changed (1)
train.py +95 -124
train.py CHANGED
@@ -31,7 +31,7 @@ def load_model(model_name, device_id=0):
     return processor, model
 
 
-def build_template(processor):
+def getTemplate(processor):
     msg = [
         {
             "role": "user",
@@ -44,92 +44,54 @@ def build_template(processor):
             ],
         }
     ]
+
     return processor.apply_chat_template(
         msg, add_generation_prompt=True, tokenize=False
     )
 
 
-def iterable_to_map(ds, chunk_size=10000):
-    buffer = []
-    for ex in ds:
-        buffer.append(ex)
-        if len(buffer) >= chunk_size:
-            yield buffer
-            buffer = []
-
-
-def cpu_preprocess(input_dataset, output_folder, model_name):
-    print("CPU preprocessing…")
-
-    processor = AutoProcessor.from_pretrained(model_name)
-    template = build_template(processor)
-
-    def _pp(batch):
-        out_images = []
-        for img in batch["image"]:
-            if isinstance(img, Image.Image):
-                if img.mode != "RGB":
-                    img = img.convert("RGB")
-            out_images.append(img)
-
-        prompts = [template] * len(out_images)
-        return {
-            "image": out_images,
-            "prompt": prompts,
-        }
+def preprocess_example(example, processor):
+    image = example["image"]
+    if isinstance(image, Image.Image):
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+    else:
+        raise ValueError("Image must be a PIL Image")
+
+    text = getTemplate(processor)
+    return {
+        "image": image,
+        "text_prompt": text,
+    }
 
+
+def run_preprocessing(input_dataset, output_dir, num_proc=4):
+    print("Loading dataset for preprocessing...")
     ds = datasets.load_dataset(input_dataset, split="train")
 
-    if ds is None:
-        raise ValueError(
-            f"Failed to load dataset '{input_dataset}' with split 'train'. Check the dataset name or available splits."
-        )
-
-    if isinstance(ds, datasets.DatasetDict):
-        if "train" in ds:
-            ds = ds["train"]
-        else:
-            raise ValueError(
-                f"'{input_dataset}' does not contain a 'train' split. Available splits: {list(ds.keys())}"
-            )
-
-    if not isinstance(ds, datasets.Dataset):
-        raise TypeError(f"Expected a Dataset instance, got {type(ds)}")
-
-    print(f"Dataset loaded: {len(ds)} examples")
+    print("Loading processor...")
+    processor = AutoProcessor.from_pretrained("datalab-to/chandra")
 
-    ds2 = ds.map(
-        _pp,
-        batched=True,
-        remove_columns=[c for c in ds.column_names if c not in ("image",)],
+    print("Running preprocessing...")
+    processed_ds = ds.map(
+        lambda ex: preprocess_example(ex, processor),
+        remove_columns=[
+            col for col in ds.column_names if col not in ["image", "text_prompt"]
+        ],
+        num_proc=num_proc,
     )
 
-    print("Saving CPU-preprocessed dataset")
-    parts = []
-    for chunk in iterable_to_map(ds2):
-        part = Dataset.from_list(chunk)
-        parts.append(part)
-
-    ds2 = datasets.concatenate_datasets(parts)
-    ds2.save_to_disk(output_folder)
-
-    print("CPU preprocessing done.")
+    print(f"Saving preprocessed dataset to {output_dir}...")
+    processed_ds.save_to_disk(output_dir)
+    print("Preprocessing done.")
 
 
-def caption_batch(batch, processor, model):
-    imgs = batch["image"]
-    prompts = batch["prompt"]
-
-    pil_images = []
-    for image in imgs:
-        if isinstance(image, Image.Image):
-            if image.mode != "RGB":
-                image = image.convert("RGB")
-        pil_images.append(image)
-
-    inputs = processor(
-        text=prompts, images=pil_images, return_tensors="pt", padding=True
-    )
+def caption_batch(batch, processor, model):
+    images = batch["image"]
+    texts = batch["text_prompt"]
+
+    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
     inputs = {
         k: v.pin_memory().to(model.device, non_blocking=True) for k, v in inputs.items()
     }
@@ -144,47 +106,49 @@ def caption_batch(batch, processor, model):
     decoded = processor.batch_decode(generated, skip_special_tokens=False)
 
     captions = []
-    special = set(processor.tokenizer.all_special_tokens)
-
+    special_tokens = set(processor.tokenizer.all_special_tokens)
     for d in decoded:
         if "<|im_start|>assistant" in d:
             d = d.split("<|im_start|>assistant")[-1]
-        for token in special:
+
+        for token in special_tokens:
             d = d.replace(token, "")
-        captions.append(d.strip())
 
-    return {"text": captions}
+        d = d.strip()
+        captions.append(d)
+
+    return {
+        "text": captions,
+    }
 
 
 def process_shard(
-    gpu_id, start, end, model_name, batch_size, prepped_folder, output_file
+    gpu_id, start, end, model_name, batch_size, input_dataset, output_file
 ):
     try:
         torch.cuda.set_device(gpu_id)
 
-        print(f"[GPU {gpu_id}] Loading model", flush=True)
+        print(f"[GPU {gpu_id}] Loading model...", flush=True)
         processor, model = load_model(model_name, gpu_id)
 
-        print(f"[GPU {gpu_id}] Loading preprocessed shard [{start}:{end}]", flush=True)
-        shard = datasets.load_from_disk(prepped_folder)
-        if isinstance(shard, datasets.DatasetDict):
-            shard = shard["train"]
-        shard = shard.select(range(start, end))
+        print(f"[GPU {gpu_id}] Loading data shard [{start}:{end}]...", flush=True)
+        loaded = datasets.load_from_disk(input_dataset).select(range(start, end))
+
+        shard = cast(Dataset, loaded)
 
-        print(f"[GPU {gpu_id}] Captioning {len(shard)} examples", flush=True)
+        print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
         result = shard.map(
             lambda batch: caption_batch(batch, processor, model),
             batched=True,
             batch_size=batch_size,
-            remove_columns=["image", "prompt"],
+            remove_columns=["text_prompt"],
        )
 
-        print(f"[GPU {gpu_id}] Saving {output_file}", flush=True)
+        print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
         result.save_to_disk(output_file)
 
-        print(f"[GPU {gpu_id}] Done.", flush=True)
+        print(f"[GPU {gpu_id}] Done!", flush=True)
         return output_file
 
     except Exception as e:
         print(f"[GPU {gpu_id}] Error: {e}", flush=True)
         raise
@@ -194,37 +158,44 @@ def main():
     mp.set_start_method("spawn", force=True)
 
     input_dataset = "none-yet/anime-captions"
-    prepped_folder = "cpu_preprocessed"
+    preprocessed_dataset = "temp_preprocessed"
     output_dataset = "nroggendorff/anime-captions"
     model_name = "datalab-to/chandra"
     batch_size = 20
 
-    if not os.path.exists(prepped_folder):
-        cpu_preprocess(input_dataset, prepped_folder, model_name)
-
-    ds = datasets.load_from_disk(prepped_folder)
-    total = len(ds)
+    if not os.path.exists(preprocessed_dataset):
+        run_preprocessing(input_dataset, preprocessed_dataset)
 
+    print("Loading preprocessed dataset...")
+    ds = datasets.load_from_disk(preprocessed_dataset)
     num_gpus = torch.cuda.device_count()
-    shard = total // num_gpus
+    total_size = len(ds)
+    shard_size = total_size // num_gpus
 
-    print(f"Dataset size: {total}")
+    print(f"Dataset size: {total_size}")
     print(f"Using {num_gpus} GPUs")
-    print(f"Shard size: {shard}")
+    print(f"Shard size: {shard_size}")
 
     processes = []
     temp_files = []
 
     for i in range(num_gpus):
-        s = i * shard
-        e = s + shard if i < num_gpus - 1 else total
-
-        of = f"temp_shard_{i}"
-        temp_files.append(of)
+        start = i * shard_size
+        end = start + shard_size if i < num_gpus - 1 else total_size
+        output_file = f"temp_shard_{i}"
+        temp_files.append(output_file)
 
         p = mp.Process(
             target=process_shard,
-            args=(i, s, e, model_name, batch_size, prepped_folder, of),
+            args=(
+                i,
+                start,
+                end,
+                model_name,
+                batch_size,
+                preprocessed_dataset,
+                output_file,
+            ),
         )
         p.start()
         processes.append(p)
@@ -232,32 +203,32 @@ def main():
     for p in processes:
         p.join()
         if p.exitcode != 0:
-            print("A process failed, aborting…")
-            for q in processes:
-                if q.is_alive():
-                    q.terminate()
-            for q in processes:
-                q.join()
-            raise RuntimeError("GPU worker failed.")
-
-    print("Merging shards…")
-    parts = []
-    for f in temp_files:
-        ds = datasets.load_from_disk(f)
-        if isinstance(ds, datasets.DatasetDict):
-            ds = ds["train"]
-        parts.append(ds)
-
-    final_ds = datasets.concatenate_datasets(parts)
-
-    print(f"Pushing final dataset to {output_dataset}…")
+            print(f"\nProcess failed with exit code {p.exitcode}", flush=True)
+            print("Terminating all processes...", flush=True)
+            for proc in processes:
+                if proc.is_alive():
+                    proc.terminate()
+            for proc in processes:
+                proc.join()
+            raise RuntimeError(f"At least one process failed")
+
+    print("\nAll processes completed. Loading and concatenating results...")
+
+    shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
+    final_ds = datasets.concatenate_datasets(shards)
+
+    print(f"Final dataset size: {len(final_ds)}")
+    print("Pushing to hub...")
     final_ds.push_to_hub(output_dataset, create_pr=False)
 
-    print("Cleaning up")
+    print("Cleaning up temporary files...")
     for f in temp_files:
-        shutil.rmtree(f, ignore_errors=True)
+        if os.path.exists(f):
+            shutil.rmtree(f)
+    if os.path.exists(preprocessed_dataset):
+        shutil.rmtree(preprocessed_dataset)
 
-    print("Done.")
+    print("Done!")
 
 
 if __name__ == "__main__":
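A few notes on patterns this commit relies on.

First, the host-to-device copy in caption_batch pins each tensor before moving it, so the transfer can be issued asynchronously. A minimal sketch of the same pattern, assuming the processor output is a dict of CPU tensors; the helper name to_device is ours, not the script's:

import torch

def to_device(tensors, device):
    # Pin each CPU tensor (page-locked memory) so the host-to-device copy
    # can run asynchronously, then move it with non_blocking=True.
    return {k: v.pin_memory().to(device, non_blocking=True) for k, v in tensors.items()}

batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long)}
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
moved = to_device(batch, device)

Note that pin_memory() expects CPU tensors, and non_blocking=True only overlaps the copy with compute when the source memory is pinned.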
 
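The shard arithmetic in main() gives each GPU a contiguous slice of the preprocessed dataset, with the last GPU absorbing the remainder. A standalone sketch of that logic; shard_bounds is a hypothetical helper for illustration only:

def shard_bounds(total_size, num_gpus):
    # GPU i takes rows [i * shard_size, (i + 1) * shard_size); the last GPU
    # also gets the remainder when total_size % num_gpus != 0.
    shard_size = total_size // num_gpus
    return [
        (i * shard_size, (i + 1) * shard_size if i < num_gpus - 1 else total_size)
        for i in range(num_gpus)
    ]

assert shard_bounds(10, 3) == [(0, 3), (3, 6), (6, 10)]

Each worker then materializes its slice with datasets.load_from_disk(...).select(range(start, end)).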
 
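The process-per-GPU layout depends on the spawn start method: a forked child cannot safely reuse the parent's CUDA context, which is why main() calls mp.set_start_method("spawn", force=True) before launching workers. A minimal sketch of the same structure, assuming mp is torch.multiprocessing (the import lies outside the hunks shown):

import torch
import torch.multiprocessing as mp

def worker(gpu_id):
    # Each spawned process binds to its own device before any CUDA work.
    torch.cuda.set_device(gpu_id)
    print(f"[GPU {gpu_id}] ready", flush=True)

def launch():
    mp.set_start_method("spawn", force=True)
    procs = [mp.Process(target=worker, args=(i,)) for i in range(torch.cuda.device_count())]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
        if p.exitcode != 0:
            raise RuntimeError(f"worker exited with code {p.exitcode}")

if __name__ == "__main__":
    launch()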
 
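Finally, the prompt side: getTemplate wraps processor.apply_chat_template(msg, add_generation_prompt=True, tokenize=False), which returns the prompt as a string ending in an assistant header for the model to complete; that is also why caption_batch later splits decoded output on "<|im_start|>assistant". A sketch assuming a transformers version whose AutoProcessor exposes apply_chat_template; the message payload is a placeholder, since the content list is truncated in the diff:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("datalab-to/chandra")

# Placeholder payload; the actual content entries are not shown in the hunk.
msg = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

prompt = processor.apply_chat_template(msg, add_generation_prompt=True, tokenize=False)
print(prompt)  # one prompt string, reused for every image in the batch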