lanczos committed · verified
Commit bd058ef · 1 Parent(s): bb71449

modify model scripts

Files changed (1):
  modeling_internvl_chat.py (+86 -15)
modeling_internvl_chat.py CHANGED

@@ -3,9 +3,10 @@
 # Copyright (c) 2024 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-
+from functools import wraps
 import warnings
 from typing import List, Optional, Tuple, Union
+from types import MethodType
 
 import torch.utils.checkpoint
 import transformers
@@ -16,6 +17,8 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers import LlamaForCausalLM, Qwen2ForCausalLM, Qwen3ForCausalLM, Qwen3MoeForCausalLM
+from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from .configuration_internvl_chat import InternVLChatConfig
 from .conversation import get_conv_template
@@ -31,6 +34,80 @@ def version_cmp(v1, v2, op='eq'):
     op_func = getattr(operator, op)
     return op_func(version.parse(v1), version.parse(v2))
 
+def transformers_seq_cls_forward(self, *args, origin_forward, **kwargs):
+    labels = kwargs.pop('labels', None)
+    return_dict = kwargs.pop('return_dict', None)
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    input_ids = kwargs.get('input_ids')
+    inputs_embeds = kwargs.get('inputs_embeds')
+
+    output = origin_forward(*args, **kwargs)
+    if hasattr(output, 'logits'):
+        output.logits = output.logits.to(self.score.weight.dtype)
+    elif 'last_hidden_state' in output:
+        output.logits = output['last_hidden_state'].to(self.score.weight.dtype)
+    logits = self.score(output.logits)
+    if input_ids is not None:
+        batch_size = input_ids.shape[0]
+    else:
+        batch_size = inputs_embeds.shape[0]
+
+    if self.config.pad_token_id is None and batch_size != 1:
+        raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
+    if self.config.pad_token_id is None:
+        sequence_lengths = -1
+    else:
+        if output.get('attention_mask') is not None:
+            # When padding_free is used for seq_cls tasks, `revert_padding_free` adds an attention_mask to the output
+            batch_size = output.get('attention_mask').shape[0]
+            sequence_lengths = output.get('attention_mask').sum(dim=1) - 1
+        elif input_ids is not None:
+            # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+            sequence_lengths = sequence_lengths % input_ids.shape[-1]
+        elif kwargs.get('attention_mask') is not None:
+            sequence_lengths = kwargs['attention_mask'].sum(dim=1) - 1
+        else:
+            sequence_lengths = -1
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.to(logits.device)
+
+    pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+    loss = None
+    if labels is not None:
+        labels = labels.to(logits.device)
+        if self.config.problem_type is None:
+            if self.num_labels == 1:
+                self.config.problem_type = 'regression'
+            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                self.config.problem_type = 'single_label_classification'
+            else:
+                self.config.problem_type = 'multi_label_classification'
+
+        if self.config.problem_type == 'regression':
+            loss_fct = MSELoss()
+            if self.num_labels == 1:
+                loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+            else:
+                loss = loss_fct(pooled_logits, labels)
+        elif self.config.problem_type == 'single_label_classification':
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+        elif self.config.problem_type == 'multi_label_classification':
+            loss_fct = BCEWithLogitsLoss()
+            loss = loss_fct(pooled_logits, labels)
+    if not return_dict:
+        output = (pooled_logits, ) + output[1:]
+        return ((loss, ) + output) if loss is not None else output
+
+    return SequenceClassifierOutputWithPast(
+        loss=loss,
+        logits=pooled_logits,
+        past_key_values=output.past_key_values,
+        hidden_states=output.hidden_states,
+        attentions=output.attentions,
+    )
 
 class InternVLChatModel(PreTrainedModel):
     config_class = InternVLChatConfig
@@ -122,6 +199,13 @@ class InternVLChatModel(PreTrainedModel):
         llm_model.set_output_embeddings(nn.Identity())
         #! <<< NEW
 
+        origin_forward = llm_model.forward
+        @wraps(origin_forward.__func__)
+        def new_forward(self, *args, **kwargs):
+            return transformers_seq_cls_forward(self, *args, origin_forward=origin_forward, **kwargs)
+
+        llm_model.forward = MethodType(new_forward, llm_model)
+
     def _build_or_load_user_table(self,
                                   user_ckpt_path: Optional[str],
                                   default_num_users: int,
@@ -322,20 +406,7 @@
             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
             query = query.replace('<image>', image_tokens, 1)
             queries.append(query)
-        # for idx, num_patches in enumerate(num_patches_list):
-
-        #     question = questions[idx]
-        #     if pixel_values is not None and '<image>' not in question:
-        #         question = '<image>\n' + question
-        #     template = get_conv_template(self.template)
-        #     template.system_message = self.system_message
-        #     template.append_message(template.roles[0], question)
-        #     template.append_message(template.roles[1], None)
-        #     query = template.get_prompt()
-
-        #     image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-        #     query = query.replace('<image>', image_tokens, 1)
-        #     queries.append(query)
+
 
         tokenizer.padding_side = 'left'
         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
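The core of this change is an instance-level forward patch: the original bound method is captured as origin_forward, wrapped, and rebound onto the same object via types.MethodType, so only this llm_model instance is affected while the original code path stays callable. A minimal sketch of the pattern, with a hypothetical toy module standing in for the LLM:

    from functools import wraps
    from types import MethodType

    import torch
    import torch.nn as nn

    class ToyLM(nn.Module):
        """Hypothetical stand-in for the wrapped causal LM."""
        def __init__(self, hidden: int = 8):
            super().__init__()
            self.proj = nn.Linear(hidden, hidden)

        def forward(self, x):
            return self.proj(x)

    model = ToyLM()
    origin_forward = model.forward  # bound method: still runs the original code

    @wraps(origin_forward.__func__)
    def new_forward(self, *args, **kwargs):
        hidden = origin_forward(*args, **kwargs)  # delegate to the original forward
        return hidden.mean(dim=-1)  # then post-process, e.g. pool or score

    # Rebind on this instance only; other ToyLM objects keep the stock forward.
    model.forward = MethodType(new_forward, model)

    print(model.forward(torch.randn(2, 8)).shape)  # torch.Size([2])

Wrapping with wraps(origin_forward.__func__) carries the original name and signature metadata over to the patch, which helps tooling that introspects forward.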
 
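transformers_seq_cls_forward pools one logit vector per sequence at the last non-padding position, the same convention as the *ForSequenceClassification heads in transformers. With an attention mask and right padding, the index arithmetic reduces to the following (toy shapes, assumed values):

    import torch

    logits = torch.randn(3, 5, 2)  # per-token scores: [batch, seq_len, num_labels]

    # 1 marks real tokens, 0 marks (right) padding
    attention_mask = torch.tensor([
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 0, 0, 0],
    ])

    # index of the last real token in each row, as in the patched forward
    sequence_lengths = attention_mask.sum(dim=1) - 1  # tensor([2, 4, 1])

    # gather one [num_labels] vector per sequence
    pooled = logits[torch.arange(logits.shape[0]), sequence_lengths]
    print(pooled.shape)  # torch.Size([3, 2])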
 
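Loss selection follows the standard transformers problem_type convention: when unset it is inferred from num_labels and the label dtype, then regression maps to MSELoss, single_label_classification to CrossEntropyLoss, and multi_label_classification to BCEWithLogitsLoss. Worked through with toy tensors:

    import torch
    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

    num_labels = 3
    pooled_logits = torch.randn(4, num_labels)

    # single_label_classification: integer class ids -> CrossEntropyLoss
    int_labels = torch.tensor([0, 2, 1, 2])
    ce = CrossEntropyLoss()(pooled_logits.view(-1, num_labels), int_labels.view(-1))

    # multi_label_classification: float multi-hot targets -> BCEWithLogitsLoss
    multi_hot = torch.randint(0, 2, (4, num_labels)).float()
    bce = BCEWithLogitsLoss()(pooled_logits, multi_hot)

    # regression (num_labels == 1): MSELoss on squeezed logits and labels
    reg_logits = torch.randn(4, 1)
    reg_targets = torch.randn(4)
    mse = MSELoss()(reg_logits.squeeze(), reg_targets.squeeze())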
 
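In the retained batch_chat path, each <image> placeholder is expanded to num_image_token * num_patches context tokens between start/end markers before tokenization. A sketch with assumed marker strings and illustrative counts (InternVL conventionally uses <img>, </img>, and <IMG_CONTEXT>, but these values are not read from this repo's config):

    # Assumed marker strings; counts are illustrative, not from this commit
    IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN = '<img>', '</img>', '<IMG_CONTEXT>'
    num_image_token = 256  # visual tokens per tile (hypothetical)
    num_patches = 2        # tiles for this image

    image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token * num_patches + IMG_END_TOKEN
    query = '<image>\nDescribe the image.'.replace('<image>', image_tokens, 1)
    # query now carries 512 <IMG_CONTEXT> slots that the vision features later overwrite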