nroggendorff committed
Commit 87153fd · verified · Parent: 97901ed

Update train.py

Files changed (1):
  1. train.py +48 -22
train.py CHANGED
@@ -26,35 +26,61 @@ def load_model(model_name="datalab-to/chandra", device_id=0):
 def caption_batch(batch, processor, model):
     images = batch["image"]
 
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "image", "image": image},
-                {
-                    "type": "text",
-                    "text": "Describe the image, and skip mentioning that it's illustrated or from anime.",
-                },
-            ],
-        }
-        for image in images
-    ]
+    encoded_list = []
+    for image in images:
+        msg = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image},
+                    {
+                        "type": "text",
+                        "text": "Describe the image, and skip mentioning that it's illustrated or from anime.",
+                    },
+                ],
+            }
+        ]
+
+        enc = processor.apply_chat_template(
+            msg,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+
+        encoded_list.append(enc)
+
+    input_ids = torch.nn.utils.rnn.pad_sequence(
+        [e.input_ids[0] for e in encoded_list],
+        batch_first=True,
+        padding_value=processor.tokenizer.pad_token_id,
+    ).to(model.device)
 
-    inputs = processor.apply_chat_template(
-        messages,
-        tokenize=True,
-        add_generation_prompt=True,
-        return_dict=True,
-        return_tensors="pt",
+    attention_mask = torch.nn.utils.rnn.pad_sequence(
+        [e.attention_mask[0] for e in encoded_list],
+        batch_first=True,
+        padding_value=0,
     ).to(model.device)
 
     with torch.no_grad():
-        generated = model.generate(**inputs)
+        generated = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )
 
     decoded = processor.batch_decode(generated)
-    captions = [d.split("<|im_start|>assistant\n")[-1] for d in decoded]
 
-    return {"image": images, "text": captions}
+    captions = []
+    for d in decoded:
+        if "<|im_start|>assistant" in d:
+            d = d.split("<|im_start|>assistant")[-1].strip()
+        captions.append(d)
+
+    return {
+        "image": images,
+        "text": captions,
+    }
 
 # %%
 import datasets
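
The core of the change replaces one batched `apply_chat_template` call with per-image encoding followed by manual padding. A minimal sketch of what the two `pad_sequence` calls do, with made-up token values and a made-up pad id (the real code uses `processor.tokenizer.pad_token_id`):

import torch

# Two encodings of different lengths, standing in for e.input_ids[0];
# the token values here are illustrative only.
seqs = [
    torch.tensor([101, 7, 8, 9]),  # 4 tokens
    torch.tensor([101, 7]),        # 2 tokens
]

# pad_sequence right-pads every sequence to the batch maximum (4 here).
padded = torch.nn.utils.rnn.pad_sequence(
    seqs,
    batch_first=True,
    padding_value=0,  # stand-in for processor.tokenizer.pad_token_id
)
# padded == tensor([[101, 7, 8, 9],
#                   [101, 7, 0, 0]])

# The matching attention mask is padded with 0 so the model ignores
# the positions that were filled in above.
mask = torch.nn.utils.rnn.pad_sequence(
    [torch.ones_like(s) for s in seqs],
    batch_first=True,
    padding_value=0,
)
# mask == tensor([[1, 1, 1, 1],
#                 [1, 1, 0, 0]])

One caveat: `pad_sequence` only pads on the right, while decoder-only chat models are usually run with left padding at generation time, so whether right padding is safe here depends on the model.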
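
Given the trailing `import datasets` cell and the `load_model` default visible in the hunk header, `caption_batch` is presumably applied with `Dataset.map` in batched mode. A hypothetical wiring; the dataset id, batch size, and the assumption that `load_model` returns a `(processor, model)` pair are not taken from this commit:

import datasets

# load_model and caption_batch come from train.py above; the
# (processor, model) return shape is an assumption.
processor, model = load_model(model_name="datalab-to/chandra", device_id=0)

# Placeholder dataset id; substitute the real one.
ds = datasets.load_dataset("user/some-image-dataset", split="train")

captioned = ds.map(
    caption_batch,
    batched=True,
    batch_size=8,  # arbitrary; bounded by GPU memory
    fn_kwargs={"processor": processor, "model": model},
)

Binding `processor` and `model` through `fn_kwargs` keeps `caption_batch`'s signature unchanged while letting `map` call it with only the batch dict.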