nroggendorff committed on
Commit
c446fbc
·
verified ·
1 Parent(s): 789627e

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +20 -17
train.py CHANGED
@@ -50,35 +50,38 @@ def getTemplate(processor):
50
  )
51
 
52
 
53
- def preprocess_example(example, processor):
54
- image = example["image"]
55
- if isinstance(image, Image.Image):
56
- if image.mode != "RGB":
57
- image = image.convert("RGB")
58
- else:
59
- raise ValueError("Image must be a PIL Image")
 
 
 
60
 
61
- text = getTemplate(processor)
62
  return {
63
- "image": image,
64
- "text_prompt": text,
65
  }
66
 
67
 
68
- def run_preprocessing(input_dataset, output_dir, num_proc=4):
69
  print("Loading dataset for preprocessing...")
70
  ds = datasets.load_dataset(input_dataset, split="train")
71
 
72
  print("Loading processor...")
73
  processor = AutoProcessor.from_pretrained("datalab-to/chandra")
 
74
 
75
  print("Running preprocessing...")
76
  processed_ds = ds.map(
77
- lambda ex: preprocess_example(ex, processor),
78
- remove_columns=[
79
- col for col in ds.column_names if col not in ["image", "text_prompt"]
80
- ],
81
  num_proc=num_proc,
 
 
82
  )
83
 
84
  print(f"Saving preprocessed dataset to {output_dir}...")
@@ -88,7 +91,7 @@ def run_preprocessing(input_dataset, output_dir, num_proc=4):
88
 
89
  def caption_batch(batch, processor, model):
90
  images = batch["image"]
91
- texts = batch["text_prompt"]
92
 
93
  inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
94
 
@@ -141,7 +144,7 @@ def process_shard(
141
  lambda batch: caption_batch(batch, processor, model),
142
  batched=True,
143
  batch_size=batch_size,
144
- remove_columns=["text_prompt"],
145
  )
146
 
147
  print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)
 
50
  )
51
 
52
 
53
+ def preprocess_example_batch(examples, text):
54
+ processed_images = []
55
+
56
+ for image in examples["image"]:
57
+ if isinstance(image, Image.Image):
58
+ if image.mode != "RGB":
59
+ image = image.convert("RGB")
60
+ processed_images.append(image)
61
+ else:
62
+ raise ValueError("Image must be a PIL Image")
63
 
 
64
  return {
65
+ "image": processed_images,
66
+ "text": [text] * len(processed_images),
67
  }
68
 
69
 
70
+ def run_preprocessing(input_dataset, output_dir, num_proc=32, batch_size=100):
71
  print("Loading dataset for preprocessing...")
72
  ds = datasets.load_dataset(input_dataset, split="train")
73
 
74
  print("Loading processor...")
75
  processor = AutoProcessor.from_pretrained("datalab-to/chandra")
76
+ text = getTemplate(processor)
77
 
78
  print("Running preprocessing...")
79
  processed_ds = ds.map(
80
+ lambda ex: preprocess_example_batch(ex, text),
81
+ remove_columns=[col for col in ds.column_names if col not in ["image", "text"]],
 
 
82
  num_proc=num_proc,
83
+ batched=True,
84
+ batch_size=batch_size,
85
  )
86
 
87
  print(f"Saving preprocessed dataset to {output_dir}...")
 
91
 
92
  def caption_batch(batch, processor, model):
93
  images = batch["image"]
94
+ texts = batch["text"]
95
 
96
  inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)
97
 
 
144
  lambda batch: caption_batch(batch, processor, model),
145
  batched=True,
146
  batch_size=batch_size,
147
+ remove_columns=["text"],
148
  )
149
 
150
  print(f"[GPU {gpu_id}] Saving results to {output_file}...", flush=True)