modify model scripts
modeling_internvl_chat.py  (+86 -15)  CHANGED
@@ -3,9 +3,10 @@
 # Copyright (c) 2024 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-
+from functools import wraps
 import warnings
 from typing import List, Optional, Tuple, Union
+from types import MethodType
 
 import torch.utils.checkpoint
 import transformers

@@ -16,6 +17,8 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers import LlamaForCausalLM, Qwen2ForCausalLM, Qwen3ForCausalLM, Qwen3MoeForCausalLM
+from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from .configuration_internvl_chat import InternVLChatConfig
 from .conversation import get_conv_template

@@ -31,6 +34,80 @@ def version_cmp(v1, v2, op='eq'):
     op_func = getattr(operator, op)
     return op_func(version.parse(v1), version.parse(v2))
 
+def transformers_seq_cls_forward(self, *args, origin_forward, **kwargs):
+    labels = kwargs.pop('labels', None)
+    return_dict = kwargs.pop('return_dict', None)
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    input_ids = kwargs.get('input_ids')
+    inputs_embeds = kwargs.get('inputs_embeds')
+
+    output = origin_forward(*args, **kwargs)
+    if hasattr(output, 'logits'):
+        output.logits = output.logits.to(self.score.weight.dtype)
+    elif 'last_hidden_state' in output:
+        output.logits = output['last_hidden_state'].to(self.score.weight.dtype)
+    logits = self.score(output.logits)
+    if input_ids is not None:
+        batch_size = input_ids.shape[0]
+    else:
+        batch_size = inputs_embeds.shape[0]
+
+    if self.config.pad_token_id is None and batch_size != 1:
+        raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
+    if self.config.pad_token_id is None:
+        sequence_lengths = -1
+    else:
+        if output.get('attention_mask') is not None:
+            # When using padding_free in seq_cls tasks, `revert_padding_free` adds an attention_mask to the output
+            batch_size = output.get('attention_mask').shape[0]
+            sequence_lengths = output.get('attention_mask').sum(dim=1) - 1
+        elif input_ids is not None:
+            # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+            sequence_lengths = sequence_lengths % input_ids.shape[-1]
+        elif kwargs.get('attention_mask') is not None:
+            sequence_lengths = kwargs['attention_mask'].sum(dim=1) - 1
+        else:
+            sequence_lengths = -1
+    if isinstance(sequence_lengths, torch.Tensor):
+        sequence_lengths = sequence_lengths.to(logits.device)
+
+    pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+    loss = None
+    if labels is not None:
+        labels = labels.to(logits.device)
+        if self.config.problem_type is None:
+            if self.num_labels == 1:
+                self.config.problem_type = 'regression'
+            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                self.config.problem_type = 'single_label_classification'
+            else:
+                self.config.problem_type = 'multi_label_classification'
+
+        if self.config.problem_type == 'regression':
+            loss_fct = MSELoss()
+            if self.num_labels == 1:
+                loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+            else:
+                loss = loss_fct(pooled_logits, labels)
+        elif self.config.problem_type == 'single_label_classification':
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+        elif self.config.problem_type == 'multi_label_classification':
+            loss_fct = BCEWithLogitsLoss()
+            loss = loss_fct(pooled_logits, labels)
+    if not return_dict:
+        output = (pooled_logits, ) + output[1:]
+        return ((loss, ) + output) if loss is not None else output
+
+    return SequenceClassifierOutputWithPast(
+        loss=loss,
+        logits=pooled_logits,
+        past_key_values=output.past_key_values,
+        hidden_states=output.hidden_states,
+        attentions=output.attentions,
+    )
 
 class InternVLChatModel(PreTrainedModel):
     config_class = InternVLChatConfig
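
The new `transformers_seq_cls_forward` follows the standard Transformers sequence-classification recipe: project the backbone output through `self.score`, take the logits at the last non-padding position of each sequence, then pick the loss according to `config.problem_type`. The sketch below is a minimal, self-contained illustration of those steps with made-up tensors; the `pick_loss` helper and all values are hypothetical and not part of this commit.

# Minimal sketch (illustration only; names and tensors below are hypothetical).
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

batch_size, seq_len, num_labels = 2, 5, 3
logits = torch.randn(batch_size, seq_len, num_labels)   # stand-in for score(hidden_states): per-token class logits
attention_mask = torch.tensor([[1, 1, 1, 0, 0],          # right-padded example
                               [1, 1, 1, 1, 1]])

# Index of the last non-padding token in each sequence, then gather its logits.
sequence_lengths = attention_mask.sum(dim=1) - 1          # tensor([2, 4])
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]   # shape (2, num_labels)

# Loss selection follows the usual seq-cls convention: infer problem_type from
# num_labels and the label dtype, then apply MSE / cross-entropy / BCE-with-logits.
def pick_loss(pooled_logits, labels, num_labels, problem_type=None):
    if problem_type is None:
        if num_labels == 1:
            problem_type = 'regression'
        elif labels.dtype in (torch.long, torch.int):
            problem_type = 'single_label_classification'
        else:
            problem_type = 'multi_label_classification'
    if problem_type == 'regression':
        return MSELoss()(pooled_logits.squeeze(), labels.squeeze())
    if problem_type == 'single_label_classification':
        return CrossEntropyLoss()(pooled_logits.view(-1, num_labels), labels.view(-1))
    return BCEWithLogitsLoss()(pooled_logits, labels)

loss = pick_loss(pooled_logits, torch.tensor([0, 2]), num_labels)   # single-label example
print(pooled_logits.shape, loss.item())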

@@ -122,6 +199,13 @@ class InternVLChatModel(PreTrainedModel):
         llm_model.set_output_embeddings(nn.Identity())
         #! <<< NEW
 
+        origin_forward = llm_model.forward
+        @wraps(origin_forward.__func__)
+        def new_forward(self, *args, **kwargs):
+            return transformers_seq_cls_forward(self, *args, origin_forward=origin_forward, **kwargs)
+
+        llm_model.forward = MethodType(new_forward, llm_model)
+
     def _build_or_load_user_table(self,
                                   user_ckpt_path: Optional[str],
                                   default_num_users: int,
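
The seven added lines above rebind `forward` on the instance only: `origin_forward` stays a bound method of `llm_model`, `@wraps(origin_forward.__func__)` copies the original function's metadata (name, docstring, etc.) onto the wrapper, and `types.MethodType` attaches the wrapper so later calls run the sequence-classification pooling on top of the backbone output. Note that `transformers_seq_cls_forward` assumes `llm_model` also exposes a `score` head, `num_labels`, and `config.pad_token_id` / `config.problem_type`, none of which are added in this hunk. A toy sketch of the same pattern, with hypothetical classes and names:

# Hypothetical illustration of the instance-level monkey patch (not from this repo).
from functools import wraps
from types import MethodType

class Backbone:
    def forward(self, x):
        return x * 2                      # stand-in for the original LLM forward

def seq_cls_forward(self, x, origin_forward):
    hidden = origin_forward(x)            # call the original bound method
    return hidden + 1                     # stand-in for score() + last-token pooling

model = Backbone()
origin_forward = model.forward            # bound method; keeps `model` as self

@wraps(origin_forward.__func__)           # preserve the wrapped function's metadata
def new_forward(self, *args, **kwargs):
    return seq_cls_forward(self, *args, origin_forward=origin_forward, **kwargs)

model.forward = MethodType(new_forward, model)   # rebind on this instance only
print(model.forward(3))                   # (3 * 2) + 1 -> 7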

@@ -322,20 +406,7 @@ class InternVLChatModel(PreTrainedModel):
             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
             query = query.replace('<image>', image_tokens, 1)
             queries.append(query)
-
-
-            # question = questions[idx]
-            # if pixel_values is not None and '<image>' not in question:
-            #     question = '<image>\n' + question
-            # template = get_conv_template(self.template)
-            # template.system_message = self.system_message
-            # template.append_message(template.roles[0], question)
-            # template.append_message(template.roles[1], None)
-            # query = template.get_prompt()
-
-            # image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-            # query = query.replace('<image>', image_tokens, 1)
-            # queries.append(query)
+
 
         tokenizer.padding_side = 'left'
         model_inputs = tokenizer(queries, return_tensors='pt', padding=True)