Noa committed on
Commit
80f71ae
·
1 Parent(s): 5e51f81

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +20 -17
train.py CHANGED
@@ -14,7 +14,7 @@ def load_model(model_name, device_id=0):
14
  load_in_4bit=True,
15
  bnb_4bit_compute_dtype=torch.bfloat16,
16
  bnb_4bit_quant_type="nf4",
17
- bnb_4bit_use_double_quant=False,
18
  )
19
 
20
  processor = AutoProcessor.from_pretrained(model_name)
@@ -31,16 +31,7 @@ def load_model(model_name, device_id=0):
31
  return processor, model
32
 
33
 
34
- def caption_batch(batch, processor, model):
35
- images = batch["image"]
36
-
37
- pil_images = []
38
- for image in images:
39
- if isinstance(image, Image.Image):
40
- if image.mode != "RGB":
41
- image = image.convert("RGB")
42
- pil_images.append(image)
43
-
44
  msg = [
45
  {
46
  "role": "user",
@@ -54,14 +45,26 @@ def caption_batch(batch, processor, model):
54
  }
55
  ]
56
 
57
- text = processor.apply_chat_template(
58
  msg, add_generation_prompt=True, tokenize=False
59
  )
 
 
 
 
 
 
 
 
 
 
 
 
60
  texts = [text] * len(pil_images)
61
 
62
  inputs = processor(text=texts, images=pil_images, return_tensors="pt", padding=True)
63
 
64
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
65
 
66
  with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
67
  generated = model.generate(
@@ -106,7 +109,7 @@ def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, out
106
 
107
  print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
108
  result = shard.map(
109
- lambda batch: caption_batch(batch, processor, model),
110
  batched=True,
111
  batch_size=batch_size,
112
  remove_columns=[col for col in shard.column_names if col != "image"],
@@ -124,11 +127,11 @@ def process_shard(gpu_id, start, end, model_name, batch_size, input_dataset, out
124
 
125
  def main():
126
  mp.set_start_method('spawn', force=True)
127
-
128
  input_dataset = "none-yet/anime-captions"
129
  output_dataset = "nroggendorff/anime-captions"
130
  model_name = "datalab-to/chandra"
131
- batch_size = 16
132
 
133
  print("Loading dataset info...")
134
  loaded = datasets.load_dataset(input_dataset, split="train")
@@ -192,4 +195,4 @@ def main():
192
 
193
 
194
  if __name__ == "__main__":
195
- main()
 
14
  load_in_4bit=True,
15
  bnb_4bit_compute_dtype=torch.bfloat16,
16
  bnb_4bit_quant_type="nf4",
17
+ bnb_4bit_use_double_quant=True,
18
  )
19
 
20
  processor = AutoProcessor.from_pretrained(model_name)
 
31
  return processor, model
32
 
33
 
34
+ def getTemplate(processor):
 
 
 
 
 
 
 
 
 
35
  msg = [
36
  {
37
  "role": "user",
 
45
  }
46
  ]
47
 
48
+ return processor.apply_chat_template(
49
  msg, add_generation_prompt=True, tokenize=False
50
  )
51
+
52
+
53
+ def caption_batch(batch, processor, model, text):
54
+ images = batch["image"]
55
+
56
+ pil_images = []
57
+ for image in images:
58
+ if isinstance(image, Image.Image):
59
+ if image.mode != "RGB":
60
+ image = image.convert("RGB")
61
+ pil_images.append(image)
62
+
63
  texts = [text] * len(pil_images)
64
 
65
  inputs = processor(text=texts, images=pil_images, return_tensors="pt", padding=True)
66
 
67
+ inputs = {k: v.pin_memory().to(model.device, non_blocking=True) for k, v in inputs.items()}
68
 
69
  with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
70
  generated = model.generate(
 
109
 
110
  print(f"[GPU {gpu_id}] Processing {len(shard)} examples...", flush=True)
111
  result = shard.map(
112
+ lambda batch: caption_batch(batch, processor, model, getTemplate(processor)),
113
  batched=True,
114
  batch_size=batch_size,
115
  remove_columns=[col for col in shard.column_names if col != "image"],
 
127
 
128
  def main():
129
  mp.set_start_method('spawn', force=True)
130
+
131
  input_dataset = "none-yet/anime-captions"
132
  output_dataset = "nroggendorff/anime-captions"
133
  model_name = "datalab-to/chandra"
134
+ batch_size = 20
135
 
136
  print("Loading dataset info...")
137
  loaded = datasets.load_dataset(input_dataset, split="train")
 
195
 
196
 
197
  if __name__ == "__main__":
198
+ main()