Instructions to use naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
Use Docker
docker model run hf.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B
- SGLang
How to use naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
- Docker Model Runner
How to use naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B with Docker Model Runner:
docker model run hf.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B
| import ast | |
| import contextlib | |
| import gc | |
| import json | |
| import math | |
| import os | |
| from dataclasses import dataclass | |
| from functools import partial | |
| from itertools import chain | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from timm.layers import LayerNorm, LayerNorm2d | |
| from timm.models.regnet import RegStage | |
| from torch.nn import CrossEntropyLoss | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModel, | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| PreTrainedModel, | |
| ) | |
| from transformers.generation.utils import GenerationMixin | |
| from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled | |
| from transformers.modeling_utils import ( | |
| is_fsdp_enabled, | |
| is_local_dist_rank_0, | |
| no_init_weights, | |
| ) | |
| from transformers.models.auto import CONFIG_MAPPING | |
| from transformers.utils import ModelOutput | |
| from .configuration_hyperclovax import HCXVisionConfig | |
| from .preprocessor import select_best_resolution | |
# End-of-turn marker string.
EOT = "<|endofturn|>"
# Placeholder token string; the forward() docstring in this file describes
# "<|dummy3|>" as the value found in input_ids at image positions.
IMG_LOC = "<|dummy3|>"
def get_rank():
    """Return this process's rank in the default distributed process group.

    Falls back to ``0`` when ``torch.distributed`` is not available in the
    current PyTorch build or when no process group has been initialized, so a
    plain single-process run is treated as rank 0.

    Returns:
        int: The distributed rank, or 0 outside of a distributed run.
    """
    # is_available() must be checked first: builds without distributed support
    # do not expose is_initialized() at all.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
def get_world_size():
    """Return the number of processes in the default distributed process group.

    Falls back to ``1`` when ``torch.distributed`` is not available in the
    current PyTorch build or when no process group has been initialized.

    Returns:
        int: The world size, or 1 outside of a distributed run.
    """
    # is_available() must be checked first: builds without distributed support
    # do not expose is_initialized() at all.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
def unpad_image(tensor: torch.Tensor, original_size: Tuple[int, int]) -> torch.Tensor:
    """Crop away the symmetric padding of a resized-and-padded image tensor.

    The aspect ratio of the original image is compared with the aspect ratio of
    ``tensor``. If the original is relatively wider, the padding is assumed to
    sit above and below the content and rows are cropped; otherwise columns are
    cropped. The same margin is removed from both sides, so when the total
    padding is odd the result keeps one extra row/column.

    Args:
        tensor: Image tensor in CxHxW layout.
        original_size: Size of the original image as (width, height).

    Returns:
        A view of ``tensor`` with the padded margin sliced off.

    Examples:
        >>> import torch
        >>> unpad_image(torch.randn(1, 64, 48), (32, 32)).shape
        torch.Size([1, 48, 48])
        >>> unpad_image(torch.randn(1, 48, 64), (32, 32)).shape
        torch.Size([1, 48, 48])
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h > cur_w / cur_h:
        # Content fills the full width; padding is on the top and bottom.
        ratio = cur_w / orig_w
        content_h = int(orig_h * ratio)
        margin = (cur_h - content_h) // 2
        return tensor[:, margin : cur_h - margin, :]

    # Content fills the full height; padding is on the left and right.
    ratio = cur_h / orig_h
    content_w = int(orig_w * ratio)
    margin = (cur_w - content_w) // 2
    return tensor[:, :, margin : cur_w - margin]
def get_anyres_image_grid_shape(
    image_size: Tuple[int, int],
    grid_pinpoints: Union[str, List[Tuple[int, int]]],
    patch_size: int,
) -> Tuple[int, int]:
    """Compute the patch-grid shape chosen for an image by any-resolution preprocessing.

    The candidate resolutions are handed to ``select_best_resolution`` together
    with the image size in (height, width) order; the (height, width) it returns
    is then integer-divided by ``patch_size``.

    Args:
        image_size (Tuple[int, int]): Original image size as (width, height).
        grid_pinpoints (Union[str, List[Tuple[int, int]]]): Candidate
            resolutions, either as a sequence of (height, width) pairs or as the
            string form of such a list (e.g. ``"[(224, 224), (336, 336)]"``).
            Only strings are parsed; any other value is passed through unchanged.
        patch_size (int): Side length used to divide the selected resolution.

    Returns:
        Tuple[int, int]: Grid shape as (num_patches_width, num_patches_height).

    Raises:
        ValueError, SyntaxError: If ``grid_pinpoints`` is a string that
            ``ast.literal_eval`` cannot parse as a Python literal.

    Example:
        If ``select_best_resolution`` returns ``(448, 672)`` as (height, width)
        and ``patch_size`` is 112, the result is ``(6, 4)``. Which candidate is
        selected depends entirely on ``select_best_resolution``.
    """
    # Parse only genuine strings. Previously every non-list value (for example a
    # tuple of resolutions) was sent to ast.literal_eval and failed there.
    if isinstance(grid_pinpoints, str):
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    else:
        possible_resolutions = grid_pinpoints

    original_width, original_height = image_size
    # select_best_resolution is called with (height, width) ordering.
    height, width = select_best_resolution((original_height, original_width), possible_resolutions)
    return width // patch_size, height // patch_size
def reshape_and_unpad_image_features(
    image_feature: torch.Tensor,
    height: int,
    width: int,
    image_size: Tuple[int, int],
    possible_resolutions: List[Tuple[int, int]],
    grid_size: int,
    unpad: bool,
    image_newline: torch.Tensor,
) -> torch.Tensor:
    """Merge the per-grid features of one image into a single token sequence.

    The first entry of ``image_feature`` is kept as-is as the base feature. The
    remaining entries are arranged as a (rows, cols, height, width, channels)
    grid, where rows/cols come from ``get_anyres_image_grid_shape``. Then either

    * ``unpad=True``: the grid is stitched into one (channels, rows*height,
      cols*width) map, cropped with ``unpad_image``, extended by one
      ``image_newline`` column at the end of every row, and flattened to
      (tokens, channels); or
    * ``unpad=False``: the grid is flattened directly to (tokens, channels).

    The base feature is prepended to the result in both cases.

    Args:
        image_feature: Features of one image; index 0 is the base feature and
            must have ``height * width`` rows, indices 1.. are the grid cells.
        height: Number of token rows inside a single grid cell.
        width: Number of token columns inside a single grid cell.
        image_size: Original image size as (width, height).
        possible_resolutions: Candidate (height, width) resolutions forwarded to
            ``get_anyres_image_grid_shape``.
        grid_size: Divisor forwarded to ``get_anyres_image_grid_shape`` as its
            ``patch_size`` argument.
        unpad: Whether to crop padding and append newline columns.
        image_newline: 1-D embedding (one value per channel) used as the
            row separator when ``unpad`` is true.

    Returns:
        torch.Tensor: Base feature followed by the merged grid tokens.

    Raises:
        AssertionError: If the base feature does not have ``height * width`` rows.
    """
    global_tokens, grid_tokens = image_feature[0], image_feature[1:]
    assert (
        height * width == global_tokens.shape[0]
    ), f"height: {height}, width: {width}, base_image_feature.shape[0]: {global_tokens.shape[0]}"

    cols, rows = get_anyres_image_grid_shape(image_size, possible_resolutions, grid_size)
    grid_tokens = grid_tokens.view(rows, cols, height, width, -1)

    if not unpad:
        # (rows, height, cols, width, C) -> (rows*height*cols*width, C)
        merged = grid_tokens.permute(0, 2, 1, 3, 4).contiguous().flatten(0, 3)
        return torch.cat((global_tokens, merged), dim=0)

    # (C, rows, height, cols, width) -> (C, rows*height, cols*width)
    canvas = grid_tokens.permute(4, 0, 2, 1, 3).contiguous()
    canvas = canvas.flatten(1, 2).flatten(2, 3)
    canvas = unpad_image(canvas, image_size)

    # One newline embedding is appended as an extra column to every row.
    newline_column = image_newline[:, None, None].expand(*canvas.shape[:-1], 1).to(canvas.device)
    canvas = torch.cat((canvas, newline_column), dim=-1)

    merged = canvas.flatten(1, 2).transpose(0, 1)
    return torch.cat((global_tokens, merged), dim=0)
def anyres_postprocessing(
    image_forward_outs: torch.FloatTensor,
    split_sizes: List[int],
    image_sizes: List[List[int]],
    possible_resolutions: List[Tuple[int, int]],
    is_videos: List[bool],
    patch_size: int,
    grid_size: int,
    image_newline: torch.FloatTensor,
    num_queries_vis_abstractor: int = -1,
    unpad: bool = False,
) -> List[torch.FloatTensor]:
    """Turn batched grid features into one 1-D token sequence per sample.

    ``image_forward_outs`` is split along dim 0 by ``split_sizes``; each chunk
    is one sample and is handled as follows:

    * more than one grid, image: delegated to ``reshape_and_unpad_image_features``;
    * more than one grid, video: the first two dimensions are flattened;
    * a single grid: that grid is used directly and, for images with
      ``unpad=True``, one ``image_newline`` row is appended.

    Args:
        image_forward_outs (torch.FloatTensor): Visual features stacked along
            dim 0 for all samples.
        split_sizes (List[int]): Number of dim-0 entries belonging to each sample.
        image_sizes (List[List[int]]): Original ``[width, height]`` per sample.
        possible_resolutions (List[Tuple[int, int]]): Candidate (height, width)
            resolutions forwarded to ``reshape_and_unpad_image_features``.
        is_videos (List[bool]): Whether each sample is a video.
        patch_size (int): Divisor of ``grid_size`` giving the per-grid token side
            when no abstractor query count is supplied.
        grid_size (int): Grid side; also forwarded to the reshape helper.
        image_newline (torch.FloatTensor): 1-D newline embedding (one value per
            feature channel).
        num_queries_vis_abstractor (int, optional): If positive, the per-grid
            token side is its square root instead of ``grid_size // patch_size``.
            Defaults to -1.
        unpad (bool, optional): Enables padding removal and newline insertion
            for images. Defaults to False.

    Returns:
        List[torch.FloatTensor]: One processed feature tensor per sample.

    Raises:
        AssertionError: If ``num_queries_vis_abstractor`` is positive but not a
            perfect square.
    """
    side = grid_size // patch_size
    if num_queries_vis_abstractor > 0:
        assert (num_queries_vis_abstractor**0.5).is_integer(), "n_queries must be square number"
        side = int(num_queries_vis_abstractor**0.5)

    per_sample_features = torch.split(image_forward_outs, split_sizes, dim=0)

    # post-processing (unpad, add newline)
    processed = []
    for sample_idx, (feature, is_video) in enumerate(zip(per_sample_features, is_videos)):
        if feature.shape[0] <= 1:
            # Single grid: no spatial merge needed.
            feature = feature[0]
            if unpad and not is_video:
                feature = torch.cat((feature, image_newline[None].to(feature.device)), dim=0)
        elif is_video:
            feature = feature.flatten(0, 1)
        else:
            feature = reshape_and_unpad_image_features(
                image_feature=feature,
                height=side,
                width=side,
                image_size=image_sizes[sample_idx],
                possible_resolutions=possible_resolutions,
                grid_size=grid_size,
                unpad=unpad,
                image_newline=image_newline,
            )
        processed.append(feature)
    return processed
def adaptive_anyres_postprocessing(
    image_forward_outs: torch.FloatTensor,
    image_sizes: List[List[int]],
    possible_resolutions: List[Tuple[int, int]],
    is_videos: List[bool],
    group_ids: List[List[int]],
    num_queries_vis_abstractors: List[List[int]],
    grid_size: int,
    image_newline: torch.FloatTensor,
    unpad: bool = False,
) -> List[torch.FloatTensor]:
    """Post-process per-entry visual features and concatenate them by group.

    Each entry of ``image_forward_outs`` is processed with its own query count
    (token side = square root of that count), using the same three cases as
    ``anyres_postprocessing``: multi-grid image, multi-grid video, single grid.
    The processed entries are then concatenated along dim 0 according to
    ``group_ids``.

    Args:
        image_forward_outs (torch.FloatTensor): Iterable of per-entry feature
            tensors; entry ``i`` is paired with ``is_videos[i]``.
        image_sizes (List[List[int]]): Original sizes, ``[[width, height], ...]``.
        possible_resolutions (List[Tuple[int, int]]): Candidate resolutions,
            ``[[height, width], ...]``.
        is_videos (List[bool]): Whether each entry is a video.
        group_ids (List[List[int]]): For every output, the indices of the
            processed entries to concatenate.
        num_queries_vis_abstractors (List[List[int]]): Query count per entry.
            It is indexed here with a single entry index and the value is used
            as a number, so a flat sequence is expected at this point.
        grid_size (int): Grid side forwarded to the reshape helper.
        image_newline (torch.FloatTensor): 1-D newline embedding.
        unpad (bool, optional): Enable padding removal. Defaults to False.

    Returns:
        List[torch.FloatTensor]: One concatenated feature tensor per group.

    Raises:
        AssertionError: If any query count is not a perfect square.
    """
    # post-processing (unpad, add newline)
    processed = []
    for entry_idx, (feature, is_video) in enumerate(zip(image_forward_outs, is_videos)):
        num_queries = num_queries_vis_abstractors[entry_idx]
        assert (num_queries**0.5).is_integer(), "n_queries must be square number"
        side = int(num_queries**0.5)

        if feature.shape[0] <= 1:
            # Single grid: no spatial merge needed.
            feature = feature[0]
            if unpad and not is_video:
                feature = torch.cat((feature, image_newline[None].to(feature.device)), dim=0)
        elif is_video:
            feature = feature.flatten(0, 1)
        else:
            feature = reshape_and_unpad_image_features(
                image_feature=feature,
                height=side,
                width=side,
                image_size=image_sizes[entry_idx],
                possible_resolutions=possible_resolutions,
                grid_size=grid_size,
                unpad=unpad,
                image_newline=image_newline,
            )
        processed.append(feature)

    grouped = []
    for members in group_ids:
        grouped.append(torch.cat([processed[member] for member in members], dim=0))
    return grouped
@dataclass
class HCXVisionOutput(ModelOutput):
    """Container for the results returned by ``HCXVisionForCausalLM.forward``.

    Declared as a dataclass because ``ModelOutput`` subclasses must be
    dataclasses: the fields below then act as keyword arguments, attributes and
    dictionary keys of the output object.

    Args:
        loss (Optional[torch.FloatTensor], optional): Cross-entropy loss averaged
            over the label positions that are not ignored.
        loss_per_sample (Optional[torch.FloatTensor], optional): Per-sample loss
            values; ``forward`` only fills this on rank 0.
        logits (Optional[torch.FloatTensor]): Prediction scores produced by the
            language-model head (before SoftMax).
        past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]], optional):
            Precomputed key/value states that can be passed back as
            ``past_key_values`` to speed up sequential decoding.
        hidden_states (Optional[Tuple[torch.FloatTensor]], optional): Hidden
            states returned by the language model when requested.
        attentions (Optional[Tuple[torch.FloatTensor]], optional): Attention
            weights returned by the language model when requested.
    """

    loss: Optional[torch.FloatTensor] = None
    loss_per_sample: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
| class HCXVisionForCausalLM(PreTrainedModel, GenerationMixin): | |
| """HCX Vision model for causal language modeling with vision-language capabilities. | |
| This class combines a vision model with a language model to create a multimodal model | |
| capable of processing images or videos and generating text based on the visual inputs. | |
| Attributes: | |
| config_class: Configuration class for the model. | |
| vision_model_name: Name of the vision model component. | |
| _no_split_modules: List of modules that should not be split during parallel processing. | |
| supports_gradient_checkpointing: Whether the model supports gradient checkpointing. | |
| _skip_keys_device_placement: Keys to skip during device placement. | |
| """ | |
| config_class = HCXVisionConfig | |
| vision_model_name = "vision_model" | |
| _no_split_modules = ["CLIPAttention", "SiglipVisionModel"] | |
| supports_gradient_checkpointing = True | |
| _skip_keys_device_placement = "past_key_values" | |
def __init__(
    self,
    config: HCXVisionConfig,
    **kwargs: Optional[Any],
) -> None:
    """Build the vision encoder, the Llama language model and the projector between them.

    Args:
        config: Model configuration. This constructor reads ``vision_config``,
            ``vision_model_name_or_path``, ``anyres``, ``max_num_grids``,
            ``use_1x1_grid``, ``language_config``, ``num_queries_vis_abstractor``,
            ``proj_pos_emb``, ``proj_prenorm``, ``use_nth_layer``,
            ``decoder_max_length`` and ``unpad`` from it.
        **kwargs: Additional keyword arguments:
            - use_liger: Stored as ``self.use_liger`` (default False).
            - use_fused_ce: Stored as ``self.use_fused_ce`` (default False).
            - use_sum_loss: If true ``self.reduction`` is "sum", otherwise "mean".
            - attn_implementation: Attention implementation assigned to the
              language config (default "sdpa").
            - is_safetensor_save: Stored as ``self.is_safetensor_save``
              (default True).

    Raises:
        ValueError: If the vision model type is not in ``CONFIG_MAPPING`` and no
            vision model path is available to load a config from.
        AssertionError: If ``anyres`` is enabled with ``max_num_grids <= 0``, or
            if the language model type is not "llama".
    """
    super().__init__(config)

    self.flag_changed_max_position_embeddings = False

    # Resolve the vision config: build it from the registered config class when
    # the model type is known, otherwise load it from a model path.
    vision_model_type = config.vision_config["model_type"]
    if vision_model_type in CONFIG_MAPPING:
        vision_config = CONFIG_MAPPING[vision_model_type](**config.vision_config)
        vision_config.auto_map = {}
    else:
        if config.vision_model_name_or_path is not None:
            vision_config = AutoConfig.from_pretrained(config.vision_model_name_or_path, trust_remote_code=True)
        elif config.vision_config["_name_or_path"] is not None:
            vision_config = AutoConfig.from_pretrained(
                config.vision_config["_name_or_path"], trust_remote_code=True
            )
        else:
            raise ValueError("vision_config is not defined")

    # These three options are popped, i.e. removed from kwargs once consumed.
    self.use_liger = kwargs.pop("use_liger", False)
    self.use_fused_ce = kwargs.pop("use_fused_ce", False)
    self.reduction = "sum" if kwargs.pop("use_sum_loss", False) else "mean"

    self.vision_config = vision_config
    vision_config.anyres = config.anyres
    vision_config.max_num_grids = config.max_num_grids

    # Enumerate every (rows, cols) grid layout with at most max_num_grids cells
    # (optionally skipping 1x1) and scale it by the vision model's image size.
    possible_resolutions = []
    if config.anyres:
        assert config.max_num_grids > 0
        for i in range(1, config.max_num_grids + 1):
            for j in range(1, config.max_num_grids + 1):
                if i == 1 and j == 1 and not config.use_1x1_grid:
                    continue
                if i * j <= config.max_num_grids:
                    possible_resolutions.append([i, j])

        possible_resolutions = [
            [ys * vision_config.image_size, xs * vision_config.image_size] for ys, xs in possible_resolutions
        ]

    self.possible_resolutions = possible_resolutions

    # NOTE(review): only the vision tower is built under no_init_weights() here;
    # confirm the language model below is meant to be constructed outside it.
    with no_init_weights():
        self.vision_model = AutoModel.from_config(
            vision_config, trust_remote_code=True
        )  # weight will be loaded in from_pretrained

    # Only Llama language models are supported.
    assert config.language_config["model_type"] == "llama"
    language_config = CONFIG_MAPPING["llama"](**config.language_config)
    # Defaults to "sdpa" unless the caller passes attn_implementation.
    language_config._attn_implementation = kwargs.get("attn_implementation", "sdpa")  # activate flash attention
    language_config.logits_scaling = 1.0

    self.language_config = language_config
    self.language_model = AutoModelForCausalLM.from_config(language_config)

    # Gradient checkpointing is switched on unconditionally for the language model.
    self.language_model.gradient_checkpointing_enable()
    self.num_queries_vis_abstractor = config.num_queries_vis_abstractor

    # mm_projector (== connector); vision_model_hidden_size -> LLM embedding size
    input_hidden_size = vision_config.hidden_size
    self.mm_projector = HCXVisionCAbstractor(
        num_queries=self.num_queries_vis_abstractor,
        num_input_tokens=(self.vision_config.image_size // self.vision_config.patch_size) ** 2,
        encoder_hidden_size=input_hidden_size,
        hidden_size=input_hidden_size,
        output_hidden_size=language_config.hidden_size,
        pos_emb=config.proj_pos_emb,
        prenorm=config.proj_prenorm,
    )
    self.use_nth_layer = config.use_nth_layer

    # Write the sub-model configs actually instantiated back into the top-level config.
    self.config.update({"vision_config": self.vision_model.config.to_dict()})
    self.config.update({"language_config": self.language_model.config.to_dict()})
    # Prefer padded_vocab_size when the language config defines it.
    self.lm_head_vocab_size = (
        language_config.padded_vocab_size
        if hasattr(language_config, "padded_vocab_size")
        else language_config.vocab_size
    )
    # Replace the LM head with a bias-free linear layer of the chosen vocab size.
    self.language_model.lm_head = nn.Linear(language_config.hidden_size, self.lm_head_vocab_size, bias=False)
    self.model_parallel = False
    self.device_map = None
    self.use_no_grad = None
    self.decoder_max_length = config.decoder_max_length

    self.anyres = config.anyres
    self.unpad = config.unpad
    if self.anyres:
        # 1-D newline embedding with one value per LLM hidden channel; created
        # uninitialized via torch.empty.
        self.image_newline = nn.Parameter(torch.empty(language_config.hidden_size, dtype=self.dtype))

    self.is_safetensor_save = kwargs.get("is_safetensor_save", True)
    self._backward_compatibility_gradient_checkpointing()
| def _init_weights(self, module): | |
| # copies from https://github.com/kakaobrain/honeybee/blob/main/honeybee/common_layers.py#L55 | |
| if ( | |
| isinstance(module, nn.Conv2d) # noqa: SIM101 | |
| or isinstance(module, nn.Embedding) | |
| or isinstance(module, nn.Linear) | |
| ): | |
| module.weight.data.normal_(mean=0.0, std=0.02) | |
| if hasattr(module, "bias") and module.bias is not None: | |
| module.bias.data.zero_() | |
| elif isinstance(module, nn.LayerNorm): | |
| module.bias.data.zero_() | |
| module.weight.data.fill_(1.0) | |
| elif isinstance(module, nn.Parameter): | |
| embed_std = 1 / torch.sqrt(torch.tensor(module.size(0), dtype=torch.float)).to(module.dtype) | |
| module.data.normal_(mean=0.0, std=embed_std) | |
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    image_sizes: Optional[List[List[List[int]]]] = None,
    vision_query_lengths: Optional[List[List[int]]] = None,
    non_vision_query_lengths: Optional[List[int]] = None,
    img_start_ids_list: Optional[List[List[int]]] = None,
    num_queries_vis_abstractors: Optional[List[List[int]]] = None,
    num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
    first_last_frames_slows: Optional[List[bool]] = None,
    is_video_list: Optional[List[bool]] = None,
    **kwargs,
) -> Union[Tuple, HCXVisionOutput]:
    """Run the multimodal model and optionally compute the language-modeling loss.

    When neither ``inputs_embeds`` nor ``past_key_values`` is given, the input
    embeddings are built by ``self.extract_inputs_embeds`` from the token IDs
    and visual inputs. The language model's base model is then run, its first
    output is scaled by ``language_config.logits_scaling`` and projected by the
    LM head. If ``labels`` are given, a shifted token-level cross-entropy is
    computed and averaged over positions whose label differs from
    ``config.ignore_index``.

    Args:
        input_ids: Input token IDs. Image positions hold the "<|dummy3|>" token.
        pixel_values: Per-sample lists of image tensors.
        past_key_values: Cached key/value states for faster decoding.
        attention_mask: Mask that excludes padding positions from attention.
        inputs_embeds: Precomputed input embeddings; when present (or once
            computed here) ``input_ids`` is not passed to the language model.
        labels: Target token IDs for the language-modeling loss.
        use_cache: Whether the language model should return key/value states.
        output_attentions: Whether to return attention weights. Defaults to
            ``config.vision_config["output_attentions"]``.
        output_hidden_states: Whether to return hidden states. Defaults to
            ``config.vision_config["output_hidden_states"]``.
        return_dict: Whether to return an ``HCXVisionOutput`` instead of a
            tuple. Defaults to ``config.use_return_dict``.
        image_sizes: Image dimensions as (width, height), per sample and image.
        vision_query_lengths: Number of visual tokens each image expands to.
        non_vision_query_lengths: Number of text tokens per sample.
        img_start_ids_list: Indices of the image start tokens per sample.
        num_queries_vis_abstractors: Number of visual tokens per image grid
            (for video frames, the fast part).
        num_queries_vis_abstractors_slow: Number of visual tokens for the slow
            part when the slowfast scheme is applied to video frames.
        first_last_frames_slows: Whether the slowfast scheme is applied to the
            first or last frames of each video.
        is_video_list: Which inputs are videos.
        **kwargs: Accepted for compatibility; not used here.

    Returns:
        With ``return_dict`` true, an ``HCXVisionOutput`` holding ``loss``,
        ``loss_per_sample`` (filled on rank 0 only), ``logits``,
        ``past_key_values``, ``hidden_states`` and ``attentions``. Otherwise a
        tuple ``(loss?, logits, *remaining language-model outputs)`` where the
        loss is included only when it was computed.
    """
    if output_attentions is None:
        output_attentions = self.config.vision_config["output_attentions"]
    if output_hidden_states is None:
        output_hidden_states = self.config.vision_config["output_hidden_states"]
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Build multimodal embeddings only for the first (uncached) pass.
    if inputs_embeds is None and past_key_values is None:
        inputs_embeds = self.extract_inputs_embeds(
            input_ids=input_ids,
            pixel_values=pixel_values,
            past_key_values=past_key_values,
            image_sizes=image_sizes,
            vision_query_lengths=vision_query_lengths,
            non_vision_query_lengths=non_vision_query_lengths,
            img_start_ids_list=img_start_ids_list,
            num_queries_vis_abstractors=num_queries_vis_abstractors,
            num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
            first_last_frames_slows=first_last_frames_slows,
            is_videos=is_video_list,
        )

    # Embeddings take precedence over token IDs.
    if inputs_embeds is not None:
        input_ids = None

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    decoder_outputs = self.language_model.base_model(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    scaled_hidden = decoder_outputs[0] * self.language_config.logits_scaling
    logits = self.language_model.lm_head(scaled_hidden)

    loss = None
    loss_per_sample = None
    if labels is not None:
        batch_size = logits.shape[0]
        # Shift so that tokens < n predict n, then flatten the tokens.
        flat_logits = logits[..., :-1, :].contiguous().view(-1, self.lm_head_vocab_size)
        # Move labels to the logits' device to enable model/pipeline parallelism.
        flat_labels = labels[..., 1:].contiguous().view(-1).to(flat_logits.device)

        criterion = CrossEntropyLoss(reduction="none")  # ignore IGNORE_INDEX(-100)
        token_losses = criterion(flat_logits, flat_labels)

        if get_rank() == 0:
            summed = token_losses.view(batch_size, -1).sum(axis=1)
            counted = (flat_labels.view(batch_size, -1) != self.config.ignore_index).sum(axis=1)
            loss_per_sample = summed / counted
        loss = token_losses[flat_labels != self.config.ignore_index].mean()

    if return_dict:
        return HCXVisionOutput(
            loss=loss,
            loss_per_sample=loss_per_sample,
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
        )

    output = (logits,) + decoder_outputs[1:]
    if loss is None:
        return output
    return (loss,) + output
def determine_non_vision_query_lengths(
    self, input_ids: torch.LongTensor, pad_id: int, img_start_id: int
) -> List[int]:
    """Compute, per sample, how many tokens are text rather than image placeholders.

    For every row of ``input_ids`` the text end is the index of the first ``pad_id``
    (or the full sequence length when the row holds no ``pad_id``). The number of
    ``img_start_id`` occurrences in the whole row is then subtracted from that index.

    Args:
        input_ids: 2-D tensor of token IDs; rows are samples.
        pad_id: Token ID used for right-padding.
        img_start_id: Token ID marking an image position.

    Returns:
        One length per sample. When every row contains ``pad_id`` all lengths are
        incremented by one.
    """
    seq_len = input_ids.size(1)
    lengths = []
    every_row_padded = True
    for row in input_ids:
        pad_positions = (row == pad_id).nonzero()
        if pad_positions.size(0) > 0:
            text_end = pad_positions[0, 0].item()
        else:
            text_end = seq_len
            every_row_padded = False
        image_count = (row == img_start_id).sum().item()
        lengths.append(text_end - image_count)

    # NOTE(review): the +1 applies only when *all* rows are padded; presumably it keeps
    # one trailing pad/end token per sample -- confirm against the collator.
    if every_row_padded:
        lengths = [length + 1 for length in lengths]
    return lengths
def determine_vision_query_lengths(
    self, image_features: List[List[torch.Tensor]], image_cnts: List[int]
) -> List[List[int]]:
    """Collect the visual-token count of every image, grouped by sample.

    The count for an image is the size of the first dimension of its feature tensor.
    A sample whose entry in ``image_cnts`` is 0 is expected to carry exactly one
    (dummy) feature tensor; its list of counts is replaced with an empty list.

    Args:
        image_features: Per-sample lists of image feature tensors.
        image_cnts: Number of real images for each sample in the batch.

    Returns:
        Per-sample lists with the visual-token count of each image.
    """
    token_counts = []
    for sample_features in image_features:
        token_counts.append([feature.size(0) for feature in sample_features])

    for sample_idx, image_cnt in enumerate(image_cnts):
        if image_cnt != 0:
            continue
        # An image-free sample currently holds a single black placeholder image.
        assert len(token_counts[sample_idx]) == 1
        # Replace the placeholder's entry with an empty list.
        token_counts[sample_idx] = []
    return token_counts
| # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings | |
def get_input_embeddings(self):
    """Delegate to the wrapped language model and return its input embeddings."""
    input_embeddings = self.language_model.get_input_embeddings()
    return input_embeddings
| # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings | |
def set_input_embeddings(self, value):
    """Forward ``value`` to the wrapped language model's ``set_input_embeddings``."""
    language_model = self.language_model
    language_model.set_input_embeddings(value)
| # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings | |
def get_output_embeddings(self):
    """Delegate to the wrapped language model and return its output embeddings."""
    output_embeddings = self.language_model.get_output_embeddings()
    return output_embeddings
| # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings | |
def set_output_embeddings(self, new_embeddings):
    """Forward ``new_embeddings`` to the wrapped language model's ``set_output_embeddings``."""
    language_model = self.language_model
    language_model.set_output_embeddings(new_embeddings)
| # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder | |
def set_decoder(self, decoder):
    """Forward ``decoder`` to the wrapped language model's ``set_decoder``."""
    language_model = self.language_model
    language_model.set_decoder(decoder)
| # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder | |
def get_decoder(self):
    """Delegate to the wrapped language model and return its decoder."""
    decoder = self.language_model.get_decoder()
    return decoder
| # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights | |
def tie_weights(self):
    """Call ``tie_weights`` on the wrapped language model and return its result."""
    result = self.language_model.tie_weights()
    return result
| # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings | |
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
    """Resize the language model's token embeddings and record the new vocabulary size.

    Args:
        new_num_tokens: Passed through to the language model's ``resize_token_embeddings``.
        pad_to_multiple_of: Passed through to the language model's ``resize_token_embeddings``.

    Returns:
        The object returned by the language model's ``resize_token_embeddings``. Its
        ``num_embeddings`` is written to ``self.config.text_config.vocab_size`` and
        ``self.vocab_size``.
    """
    resized = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
    new_vocab_size = resized.num_embeddings
    self.config.text_config.vocab_size = new_vocab_size
    self.vocab_size = new_vocab_size
    return resized
def extract_inputs_embeds(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[List[List[torch.FloatTensor]]] = None,  # list of list of 4D tensors
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    image_sizes: Optional[List[List[List[int]]]] = None,
    vision_query_lengths: Optional[List[List[int]]] = None,
    non_vision_query_lengths: Optional[List[int]] = None,
    img_start_ids_list: Optional[List[List[int]]] = None,
    num_queries_vis_abstractors: Optional[List[List[int]]] = None,
    num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
    first_last_frames_slows: Optional[List[bool]] = None,
    is_videos: Optional[List[List[bool]]] = None,
) -> Optional[torch.Tensor]:
    """Extract input embeddings by processing text tokens and visual features.

    This method encodes all images with the vision model, projects them with
    ``self.mm_projector``, and writes the resulting visual features into the text token
    embeddings at the positions of the ``img_start_id`` tokens, producing one embedding
    tensor for the language model.

    Side effects visible in this method:
        - ``self.use_no_grad`` is computed once (when it is ``None``) from the
          ``requires_grad`` flags of the vision encoder parameters and then reused.
        - When ``self.use_nth_layer == -1`` the vision model's ``post_layernorm`` is
          replaced with ``nn.Identity()``.

    Args:
        input_ids: Input token IDs with img_start_id markers for image positions.
        pixel_values: List of lists of image tensors.
        past_key_values: Pre-computed key and value states. When truthy, nothing is
            computed and ``None`` is returned.
        image_sizes: List of lists of image dimensions (width, height).
        vision_query_lengths: List of lists of lengths when each image is converted to visual tokens.
            Computed with ``determine_vision_query_lengths`` when ``None``.
        non_vision_query_lengths: List of lengths of text tokens (excluding visual tokens) for each sample.
            Computed with ``determine_non_vision_query_lengths`` when ``None``.
        img_start_ids_list: List of lists containing indices of img_start_id tokens for each sample.
            Located in ``input_ids`` when ``None``.
        num_queries_vis_abstractors: List of lists containing number of visual tokens for each image grid.
            ``None`` selects the non-adaptive path.
        num_queries_vis_abstractors_slow: List of lists containing number of visual tokens for
            the slow part when applying the slowfast algorithm to video frames.
        first_last_frames_slows: List of booleans indicating whether the slowfast algorithm is
            applied to the first or last frames of the video.
        is_videos: Per-sample lists of booleans indicating which inputs are videos
            (flattened with ``chain`` inside this method).

    Returns:
        Combined embeddings of text tokens and visual features, or ``None`` when
        ``past_key_values`` is truthy.
    """
    inputs_embeds = None
    if past_key_values:
        # A truthy cache means there is nothing to embed here; fall through and return None.
        pass
    else:
        # Flatten the per-sample image lists into one batch for the vision encoder and the
        # projector; the per-sample grouping is restored further below via len_pixel_values.
        len_pixel_values = [len(pixel_value) for pixel_value in pixel_values]
        concat_pixel_values = torch.cat(list(chain(*pixel_values)), dim=0)  # list of list of 4D Tensor
        # Encoder outputs are sliced from this index: 0 for "siglip" model types, 1 otherwise.
        # NOTE(review): presumably skips a leading class token for non-siglip encoders -- confirm.
        visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
        # Decide once whether the vision encoder is fully frozen (no parameter requires grad).
        if self.use_no_grad is None:
            self.use_no_grad = all(not p.requires_grad for p in self.vision_model.vision_model.encoder.parameters())
        context = torch.no_grad() if self.use_no_grad else contextlib.nullcontext()
        with context:
            if self.use_no_grad:
                # Chunking is effectively disabled: both branches use a single chunk.
                n_chunks = 1
            else:
                n_chunks = 1
            total_len = concat_pixel_values.size(0)
            # Size of each chunk; with n_chunks == 1 this is the whole batch (1 when the batch is empty).
            chunk_size = math.ceil(total_len / n_chunks) if total_len > 0 else 1
            image_forward_outs_chunks = []
            for i in range(n_chunks):
                start = i * chunk_size
                end = (i + 1) * chunk_size
                # Current chunk slice (could be an empty tensor if there's no data)
                chunk = concat_pixel_values[start:end].to(self.vision_model.dtype)
                # If the current chunk is smaller than chunk_size, pad it with zero images.
                if chunk.size(0) < chunk_size:
                    pad_size = chunk_size - chunk.size(0)
                    # Create dummy tensor based on concat_pixel_values shape
                    # NOTE(review): the dummy uses concat_pixel_values.dtype while `chunk` was cast
                    # to the vision model dtype -- confirm the two always agree.
                    dummy_shape = (pad_size,) + tuple(concat_pixel_values.shape[1:])
                    dummy = torch.zeros(
                        dummy_shape,
                        dtype=concat_pixel_values.dtype,
                        device=concat_pixel_values.device,
                    )
                    chunk = torch.cat([chunk, dummy], dim=0)
                # Pass the chunk through the vision model (processed according to use_nth_layer)
                if self.use_nth_layer == -1:
                    # Replace post_layernorm with Identity; this mutates the vision model in place.
                    self.vision_model.vision_model.post_layernorm = nn.Identity()
                    outs = self.vision_model(chunk)
                    outs = outs.last_hidden_state[:, visual_token_idx:]
                else:
                    outs = self.vision_model(chunk, output_hidden_states=True)
                    outs = outs.hidden_states[self.use_nth_layer][:, visual_token_idx:]
                image_forward_outs_chunks.append(outs)
            # Concatenate results from all chunks
            image_forward_outs = torch.cat(image_forward_outs_chunks, dim=0).to(image_forward_outs_chunks[0].dtype)
        if num_queries_vis_abstractors is None:
            # Non-adaptive path: flatten the per-sample metadata and project all features at once.
            assert num_queries_vis_abstractors_slow is None
            image_sizes = list(chain(*image_sizes))
            if is_videos is not None:
                is_videos = list(chain(*is_videos))
            group_ids = None
            image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype)
            image_forward_outs = self.mm_projector(image_forward_outs)
        else:
            # adaptive anyres is only implemented in HCXVisionCAbstractor
            assert isinstance(self.mm_projector, HCXVisionCAbstractor)
            (
                num_queries_vis_abstractors,
                num_grids,
                image_sizes,
                is_videos,
                group_ids,
            ) = self.compute_adaptive_params(
                pixel_values,
                num_queries_vis_abstractors,
                num_queries_vis_abstractors_slow,
                image_sizes,
                is_videos,
                first_last_frames_slows,
            )
            image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype)
            image_forward_outs = self.mm_projector(
                image_forward_outs,
                num_queries_vis_abstractors=num_queries_vis_abstractors,
                num_grids=num_grids,
            )
        if self.anyres:
            # Number of grids contributed by each image, in flattened order.
            split_sizes = [pixel_value.shape[0] for pixel_value in chain(*pixel_values)]
            if num_queries_vis_abstractors is None:
                image_features = anyres_postprocessing(
                    image_forward_outs=image_forward_outs,
                    split_sizes=split_sizes,
                    image_sizes=image_sizes,
                    num_queries_vis_abstractor=self.num_queries_vis_abstractor,
                    unpad=self.unpad,
                    is_videos=is_videos,
                    patch_size=self.vision_model.config.patch_size,
                    grid_size=self.vision_model.config.image_size,
                    image_newline=self.image_newline,
                    possible_resolutions=self.possible_resolutions,
                )
            else:
                image_features = adaptive_anyres_postprocessing(
                    image_forward_outs=image_forward_outs,
                    image_sizes=image_sizes,
                    num_queries_vis_abstractors=num_queries_vis_abstractors,
                    unpad=self.unpad,
                    is_videos=is_videos,
                    grid_size=self.vision_model.config.image_size,
                    image_newline=self.image_newline,
                    possible_resolutions=self.possible_resolutions,
                    group_ids=group_ids,
                )
        else:
            if num_queries_vis_abstractors is None:
                image_features = [image_forward_out for image_forward_out in image_forward_outs]
            else:
                image_features = [image_forward_out.unsqueeze(0) for image_forward_out in image_forward_outs]
        # Regroup the flat feature list into per-sample lists using the original image counts.
        image_features = [
            image_features[sum(len_pixel_values[:i]) : sum(len_pixel_values[: i + 1])]
            for i in range(len(len_pixel_values))
        ]
        batch_size = input_ids.size(0)
        image_feature_dim = image_features[0][0].size(1)
        image_feature_dtype = image_features[0][0].dtype
        if img_start_ids_list is None:
            image_cnts = (input_ids == self.config.img_start_id).sum(dim=1).tolist()
        else:
            image_cnts = [len(img_start_ids) for img_start_ids in img_start_ids_list]
        if non_vision_query_lengths is None:
            non_vision_query_lengths = self.determine_non_vision_query_lengths(
                input_ids, self.tokenizer.pad_token_id, self.config.img_start_id
            )
        if vision_query_lengths is None:
            vision_query_lengths = self.determine_vision_query_lengths(image_features, image_cnts)
        # Pre-allocate the output at the longest (text + visual) length in the batch, capped at
        # decoder_max_length, and fill it by slicing (slicing is faster than concatenation).
        len_inputs_embeds = max(
            [
                sum(vision_query_length) + non_vision_query_length
                for non_vision_query_length, vision_query_length in zip(
                    non_vision_query_lengths, vision_query_lengths
                )
            ]
        )
        len_inputs_embeds = min(self.decoder_max_length, len_inputs_embeds)
        # NOTE(review): the .clone() presumably makes the buffer a non-leaf tensor so the in-place
        # slice assignments below are permitted under autograd -- confirm.
        inputs_embeds = torch.zeros(
            [batch_size, len_inputs_embeds, image_feature_dim],
            dtype=image_feature_dtype,
            device=self.device,
            requires_grad=True,
        ).clone()
        # Text embeddings for every input token, including the image placeholder tokens.
        temp_embeds = self.get_input_embeddings()(input_ids)
        # The complete format is <PROMPT><USER_PREFIX><VISION_QUERIES>Sentence
        for batch_idx, sample in enumerate(input_ids):
            non_vision_query_length = non_vision_query_lengths[batch_idx]
            # Keep only the text tokens plus one placeholder token per image.
            sample = sample[: non_vision_query_length + image_cnts[batch_idx]]
            if image_cnts[batch_idx] == 0:  # Text instruction data doesn't insert image features
                temp_idx = 0
                # Reference: https://github.com/haotian-liu/LLaVA/commit/44e0562f9497fb79f042427307472a87d266d90a#diff-4477387d506ccb1897a13972cba26c9da3fad4d3e1c32ec4b8bd8ff7acd3f292
                # https://github.com/intel/intel-extension-for-transformers/issues/1201#issuecomment-1915875119
                inputs_embeds[batch_idx, :non_vision_query_length] = temp_embeds[batch_idx][
                    :non_vision_query_length
                ]
                # Zero-length assignment from the first (dummy) image of this sample.
                # NOTE(review): presumably keeps the vision branch in the autograd graph for
                # text-only samples (see the references above) -- confirm.
                inputs_embeds[batch_idx, temp_idx:temp_idx] = image_features[batch_idx][0][
                    0:0
                ]  # First image of batch_idx sample (dummy image)
            else:
                if img_start_ids_list is None:
                    img_start_ids = (sample == self.config.img_start_id).nonzero()
                else:
                    img_start_ids = img_start_ids_list[batch_idx]
                assert len(img_start_ids) == image_cnts[batch_idx] == len(image_features[batch_idx])
                # input_start indexes the output buffer; temp_start indexes the text embeddings.
                input_start, temp_start = 0, 0
                # Iterate through each image starting point in the sample
                for multi_img_idx, img_start_idx in enumerate(img_start_ids):
                    # Number of text tokens between the previous image and this one
                    token_len = img_start_idx - temp_start
                    # Copy those text tokens to inputs_embeds
                    inputs_embeds[batch_idx, input_start : input_start + token_len] = temp_embeds[
                        batch_idx, temp_start : temp_start + token_len
                    ]
                    # Write this image's visual tokens directly after them
                    inputs_embeds[
                        batch_idx,
                        input_start
                        + token_len : input_start
                        + token_len
                        + vision_query_lengths[batch_idx][multi_img_idx],
                    ] = image_features[batch_idx][multi_img_idx]
                    # Update starting points for next token processing
                    input_start += token_len + vision_query_lengths[batch_idx][multi_img_idx]
                    temp_start += token_len + 1  # Increase by 1 to skip the image start token
                # Copy the text tokens after the last image, truncated to the room left in the buffer
                token_len = min(sample[temp_start:].size(0), inputs_embeds.size(1) - input_start)
                inputs_embeds[batch_idx, input_start : input_start + token_len] = temp_embeds[
                    batch_idx, temp_start : temp_start + token_len
                ]
    return inputs_embeds
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
    image_sizes: Optional[List[List[List[int]]]] = None,
    vision_query_lengths: Optional[List[List[int]]] = None,
    non_vision_query_lengths: Optional[List[int]] = None,
    num_queries_vis_abstractors: Optional[List[List[int]]] = None,
    num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
    first_last_frames_slows: Optional[List[bool]] = None,
    is_videos: Optional[List[bool]] = None,
    img_start_ids_list: Optional[List[List[int]]] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    bad_words_ids: Optional[List[List[int]]] = None,
    max_length: int = 196,
    min_length: int = 2,
    do_sample: bool = True,
    num_beams: int = 1,
    top_p: float = 0.6,
    top_k: int = 0,
    temperature: float = 0.5,
    repetition_penalty: float = 1.0,
    length_penalty: int = 1,
    use_cache: bool = True,
    **kwargs,
) -> torch.LongTensor:
    """Generate token IDs from text tokens and, optionally, images.

    Two paths exist:

    * ``pixel_values is None``: ``input_ids`` are handed to ``self.language_model.generate``
      together with ``pad_token_id``, ``eos_token_id``, ``bad_words_ids`` and ``**kwargs``.
      The sampling parameters declared in this signature are NOT forwarded on this path.
    * otherwise: embeddings are built with ``extract_inputs_embeds`` and passed to
      ``self.language_model.generate`` along with the sampling parameters below.
      ``**kwargs`` is NOT forwarded on this path.

    Args:
        input_ids: Input token IDs with img_start_id markers for image positions.
        pixel_values: List of lists of image tensors; moved to the vision model's device.
        image_sizes: List of lists of image dimensions (width, height).
        vision_query_lengths: List of lists of lengths when each image is converted to visual tokens.
        non_vision_query_lengths: List of lengths of text tokens (excluding visual tokens) for each sample.
        num_queries_vis_abstractors: List of lists containing number of visual tokens for each image grid.
        num_queries_vis_abstractors_slow: List of lists containing number of visual tokens for the slow
            part when applying the slowfast algorithm to video frames.
        first_last_frames_slows: List of booleans indicating whether the slowfast algorithm is applied
            to the first or last frames of the video.
        is_videos: List of booleans indicating which inputs are videos.
        img_start_ids_list: List of lists containing indices of img_start_id tokens for each sample.
        pad_token_id: Padding token ID; defaults to ``self.tokenizer.pad_token_id``.
        eos_token_id: End-of-sequence token ID; defaults to the first ID produced by encoding
            ``"<|endofturn|>"`` with ``self.tokenizer``.
        bad_words_ids: Token ID sequences that must not be generated; defaults to the BOS and EOS
            IDs stored in ``self.config.language_config``.
        max_length: Forwarded to the language model as ``max_new_tokens``.
        min_length: Forwarded to the language model as ``min_length``.
        do_sample: Whether to sample; forced to ``False`` when ``temperature == 0.0``.
        num_beams: Number of beams; ``early_stopping`` is enabled only when this is above 1.
        top_p: Nucleus sampling threshold.
        top_k: Top-k filtering size.
        temperature: Value used to modulate the next token probabilities.
        repetition_penalty: Penalty applied to tokens that have already appeared in the sequence.
        length_penalty: Exponential penalty applied to sequence length.
        use_cache: Whether to use past key/values for faster inference.
        **kwargs: Extra arguments, forwarded only on the text-only path.

    Returns:
        Whatever ``self.language_model.generate`` returns (generated token IDs).
    """
    # Resolve token-id defaults, in the same order the checks are declared.
    if pad_token_id is None:
        pad_token_id = self.tokenizer.pad_token_id
    if eos_token_id is None:
        eos_token_id = self.tokenizer.encode("<|endofturn|>")[0]
    if bad_words_ids is None:
        language_config = self.config.language_config
        bad_words_ids = [[language_config["bos_token_id"]], [language_config["eos_token_id"]]]

    token_id_kwargs = {
        "pad_token_id": pad_token_id,
        "eos_token_id": eos_token_id,
        "bad_words_ids": bad_words_ids,
    }

    # Text-only request: no embedding extraction is needed.
    if pixel_values is None:
        return self.language_model.generate(input_ids, **token_id_kwargs, **kwargs)

    # inputs_embeds: [batch size, variable length covering visual tokens, text tokens and system prompt]
    vision_inputs = self.to_vision_model_device(pixel_values)
    inputs_embeds = self.extract_inputs_embeds(
        input_ids=input_ids,
        pixel_values=vision_inputs,
        image_sizes=image_sizes,
        vision_query_lengths=vision_query_lengths,
        non_vision_query_lengths=non_vision_query_lengths,
        img_start_ids_list=img_start_ids_list,
        num_queries_vis_abstractors=num_queries_vis_abstractors,
        num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
        first_last_frames_slows=first_last_frames_slows,
        is_videos=is_videos,
    )
    if isinstance(inputs_embeds, torch.Tensor):
        inputs_embeds = inputs_embeds.to(self.base_model.device)

    # A temperature of exactly 0.0 is not valid for sampling, so sampling is switched off.
    if temperature == 0.0:
        effective_do_sample = False
    else:
        effective_do_sample = do_sample
    # early_stopping only applies to beam search.
    use_early_stopping = not num_beams <= 1

    # Returned value: [batch size, generated token length]
    return self.language_model.generate(
        inputs_embeds=inputs_embeds,
        **token_id_kwargs,
        max_new_tokens=max_length,
        min_length=min_length,
        num_beams=num_beams,
        do_sample=effective_do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        early_stopping=use_early_stopping,
        use_cache=use_cache,
    )
def to_vision_model_device(self, input_tensor: Union[torch.Tensor, List]) -> Union[torch.Tensor, List]:
    """Move a tensor, or an arbitrarily nested list of tensors, to the vision model's device.

    Args:
        input_tensor: A tensor, or a list whose items are tensors or further lists.

    Returns:
        The same structure with every tensor moved to ``self.vision_model.device``.
        Lists are rebuilt as new lists.

    Raises:
        TypeError: If a value that is neither a tensor nor a list is encountered.
    """
    if isinstance(input_tensor, torch.Tensor):
        return input_tensor.to(self.vision_model.device)
    if isinstance(input_tensor, list):
        # Recurse so nested lists keep their shape.
        return [self.to_vision_model_device(element) for element in input_tensor]
    raise TypeError("Unsupported data type. Only tensors and lists are allowed.")
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Assemble the model inputs for one generation step.

    Args:
        input_ids: Input token IDs; overridden by ``kwargs["decoder_input_ids"]`` when present.
        past_key_values: Pre-computed key and value states. When truthy, only the last
            position of the token IDs is kept.
        attention_mask: Mask to avoid performing attention on padding token indices.
        inputs_embeds: Input embeddings. Used instead of token IDs only when
            ``past_key_values`` is ``None``.
        **kwargs: ``use_cache`` and ``pixel_values`` are read from here (``None`` if absent).

    Returns:
        Dictionary with either ``inputs_embeds`` or ``input_ids``, followed by
        ``past_key_values``, ``use_cache``, ``attention_mask`` and ``pixel_values``.
    """
    token_ids = kwargs.get("decoder_input_ids", input_ids)
    if past_key_values:
        token_ids = token_ids[:, -1:]

    # `inputs_embeds` is only meant for the first generation step (no cache yet).
    use_embeds = inputs_embeds is not None and past_key_values is None
    primary_input = {"inputs_embeds": inputs_embeds} if use_embeds else {"input_ids": token_ids}

    return {
        **primary_input,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "attention_mask": attention_mask,
        "pixel_values": kwargs.get("pixel_values", None),
    }
@classmethod
def from_config(cls, config, vision_model_name_or_path):
    """Alternate constructor: build an instance of ``cls`` from a config.

    Declared as a classmethod so that ``cls`` is bound to the class. Without the
    decorator, ``Model.from_config(config, path)`` would bind ``config`` to ``cls``
    and fail.

    Args:
        config: Passed as the first positional argument to ``cls``.
        vision_model_name_or_path: Passed as the second positional argument to ``cls``.

    Returns:
        The newly constructed instance.
    """
    return cls(config, vision_model_name_or_path)
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
    *model_args,
    **kwargs,
) -> "HCXVisionForCausalLM":
    """Load a pretrained model and attach its tokenizer and save options.

    Declared as a classmethod so it can be called on the class itself. Without the
    decorator, the path argument would be bound to ``cls``.

    Args:
        pretrained_model_name_or_path: Model identifier or path; must not be ``None``.
            Also used to load the tokenizer.
        *model_args: Forwarded to the parent ``from_pretrained``.
        **kwargs: Forwarded to the parent ``from_pretrained`` after removing
            ``save_only_vision`` (default ``False``), ``save_only_qformer``
            (default ``False``) and ``save_shard_size`` (default ``"5GB"``),
            which are stored on the returned model.

    Returns:
        The loaded model, with ``tokenizer``, ``config.img_start_id``,
        ``save_only_vision``, ``save_only_qformer`` and ``save_shard_size`` set.

    Raises:
        AssertionError: If no path is given, or if ``IMG_LOC`` is not encoded by the
            tokenizer into exactly one token.
    """
    assert pretrained_model_name_or_path is not None

    # These options are consumed here and must not reach the parent loader.
    save_only_vision = kwargs.pop("save_only_vision", False)
    save_only_qformer = kwargs.pop("save_only_qformer", False)
    save_shard_size = kwargs.pop("save_shard_size", "5GB")

    model: HCXVisionForCausalLM = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
    model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

    # The image placeholder must map to a single token ID, which is recorded on the config.
    img_start_id = model.tokenizer.encode(IMG_LOC, add_special_tokens=False)
    assert (
        len(img_start_id) == 1
    ), f'"<|dummy3|>" was not encoded into a single special token. Encoding result: {img_start_id}'
    model.config.img_start_id = img_start_id[0]

    model.save_only_vision = save_only_vision
    model.save_only_qformer = save_only_qformer
    model.save_shard_size = save_shard_size
    return model
def get_language_model(self):
    """Return the ``base_model`` attribute of the wrapped language model."""
    wrapped_lm = self.language_model
    return wrapped_lm.base_model
def get_vision_model(self):
    """Return the vision model held by this instance."""
    vision_model = self.vision_model
    return vision_model
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    *args,
    **kwargs,
):
    """Save the model, filtering the state dict according to the instance's save options.

    The state dict (``kwargs["state_dict"]`` when given, otherwise ``self.state_dict()``)
    is passed through ``get_pretrained_state_dict`` before being handed to the parent
    ``save_pretrained``. ``safe_serialization`` is always overwritten with
    ``self.is_safetensor_save``; ``max_shard_size`` falls back to ``self.save_shard_size``
    only when the caller did not supply it.

    Args:
        save_directory: Target directory, forwarded to the parent ``save_pretrained``.
        *args: Forwarded to the parent ``save_pretrained``.
        **kwargs: Forwarded to the parent ``save_pretrained`` after the adjustments above.
    """
    if "state_dict" in kwargs:
        full_state = kwargs["state_dict"]
    else:
        full_state = self.state_dict()

    kwargs["state_dict"] = self.get_pretrained_state_dict(full_state, save_directory)
    kwargs["safe_serialization"] = self.is_safetensor_save

    default_shard_size = self.save_shard_size
    if "max_shard_size" not in kwargs:
        kwargs["max_shard_size"] = default_shard_size

    super().save_pretrained(save_directory, *args, **kwargs)
def get_pretrained_state_dict(self, state_dict, save_dir):
    """Drop entries from ``state_dict`` according to the instance's save options.

    * ``self.save_only_vision`` truthy: remove keys containing ``"language_model."``
      and keys starting with ``"lm_head."``.
    * otherwise, ``self.save_only_qformer`` truthy: remove keys containing
      ``"vision_model."``.
    * neither: nothing is removed.

    Args:
        state_dict: Mapping of parameter names to values; modified in place.
        save_dir: Not used by this method.

    Returns:
        The same ``state_dict`` object after filtering.
    """
    vision_marker = "vision_model."
    language_markers = ("language_model.",)
    head_prefix = "lm_head."

    # Iterate over a snapshot of the keys because entries are removed during the loop.
    for name in tuple(state_dict.keys()):
        if self.save_only_vision:
            for marker in language_markers:
                if marker in name:
                    del state_dict[name]
            if name.startswith(head_prefix):
                del state_dict[name]
            continue
        if self.save_only_qformer and vision_marker in name:
            del state_dict[name]
    return state_dict
| def compute_adaptive_params( | |
| self, | |
| pixel_values: Optional[List[List[torch.FloatTensor]]] = None, | |
| num_queries_vis_abstractors: Optional[List[List[int]]] = None, | |
| num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, | |
| image_sizes: Optional[List[List[List[int]]]] = None, | |
| is_videos: Optional[List[bool]] = None, | |
| first_last_frames_slows: Optional[List[bool]] = None, | |
| ) -> Tuple[List[int], List[int], List[List[int]], List[bool], List[List[int]]]: | |
| """Compute adaptive parameters for processing different image and video inputs. | |
| This method calculates parameters needed for adaptive processing, especially when handling | |
| variable resolutions or applying the slowfast algorithm to video frames. It flattens | |
| batch-level inputs (lists of lists) into single lists representing all images/frames | |
| in the batch. Based on slowfast configuration, it may split video frames into 'slow' | |
| and 'fast' components, adjusting query counts and grid indices accordingly. | |
| Args: | |
| pixel_values: List of lists of image tensors (per sample). Used to determine the initial number of grids per | |
| image/frame. | |
| num_queries_vis_abstractors: List of lists (per sample) containing the base number of visual tokens | |
| generated by the visual abstractor for each image grid | |
| (e.g., 81 for a full grid, 9 for a subsampled/fast grid). | |
| num_queries_vis_abstractors_slow: List of lists (per sample) containing the number of visual tokens for the | |
| 'slow' path when applying slowfast. Non-zero values here trigger the slowfast processing logic. | |
| image_sizes: List of lists (per sample) of original image dimensions ([width, height]). | |
| is_videos: List of lists (per sample) of booleans indicating if each input item is part of a video sequence. | |
| first_last_frames_slows: List (per sample) of booleans. If True, slowfast logic | |
| (if active based on `num_queries_vis_abstractors_slow`) is applied only to the first or last frame(s) | |
| within each video sequence. | |
| Returns: | |
| Tuple containing: | |
| - num_queries_vis_abstractors: Flattened list of final query counts per processed grid. | |
| Values might be adjusted based on slow/fast splitting | |
| (e.g., using values from `num_queries_vis_abstractors_slow` for slow frames). | |
| Example: [81, 81, 81, 9, 81, 9, ...] (Image, Image, Vid_Slow, Vid_Fast, Vid_Slow, Vid_Fast...) | |
| - num_grids: Flattened list representing cumulative grid counts, acting as end indices for slicing the | |
| flattened `image_forward_outs`. Adjusted for slow/fast splits. | |
| Example: [0, 1, 9, 10, 18, 19, 27, ...] (Indices after Grid0_Slow(1), | |
| Grid1_Fast(8), Grid2_Slow(1), Grid3_Fast(8)...). | |
| - image_sizes: Flattened list of image dimensions ([width, height]), potentially duplicated if slow/fast | |
| splitting occurred. | |
| - is_videos: Flattened list of booleans indicating video status, potentially duplicated for | |
| slow/fast splits. Example: [False, False, True, True, True, True, ...] | |
| (Image1, Image2, Vid_grid1_slow, Vid_grid1_fast, Vid_grid2_slow, Vid_grid2_fast...) | |
| - group_ids: List of lists, grouping indices that correspond to the same original image or frame. | |
| If a frame is split into slow/fast, its group will contain multiple indices. | |
| Example: [[0], [1], [2, 3], [4, 5], ...] | |
| (Group for Image1, Group for Image2, Group for Vid1_Slow+Fast, Group for Vid2_Slow+Fast...). | |
| Raises: | |
| AssertionError: If input validation fails (e.g., negative query counts). | |
| Exception: If an unexpected case is encountered during slowfast processing. | |
| """ | |
| # Check if all elements are integers greater than or equal to 0 | |
| assert all( | |
| all(isinstance(value, int) and value >= 0 for value in sublist) for sublist in num_queries_vis_abstractors | |
| ), "All values in num_queries_vis_abstractors must be integers >= 0." | |
| assert all( | |
| all(isinstance(value, int) and value >= 0 for value in sublist) | |
| for sublist in num_queries_vis_abstractors_slow | |
| ), "All values in num_queries_vis_abstractors_slow must be integers >= 0." | |
| assert is_videos is not None | |
| # Is it the first or last image? (for applying slowfast to video processing) | |
| is_first_images = [] | |
| is_last_images = [] | |
| for is_video in is_videos: | |
| for idx, is_video_item in enumerate(is_video): | |
| if idx == 0: | |
| is_first_images.append(True) | |
| else: | |
| is_first_images.append(False) | |
| if idx == len(is_video) - 1: | |
| is_last_images.append(True) | |
| else: | |
| is_last_images.append(False) | |
| num_queries_vis_abstractors = list(chain(*num_queries_vis_abstractors)) | |
| num_queries_vis_abstractors_slow = list(chain(*num_queries_vis_abstractors_slow)) | |
| image_sizes = list(chain(*image_sizes)) | |
| is_videos = list(chain(*is_videos)) | |
| first_last_frames_slows = list(chain(*first_last_frames_slows)) | |
| # Use slowfast mode if there's at least one visual token count greater than 0 in num_queries_vis_abstractors_slow | |
| use_slowfast = any([num_query > 0 for num_query in num_queries_vis_abstractors_slow]) | |
| num_grids = [pixel_value.shape[0] for pixel_value in chain(*pixel_values)] | |
| num_grids = [0] + num_grids | |
| group_ids = [] | |
| if use_slowfast: | |
| new_num_grids = [num_grids[0]] | |
| new_num_queries = [] | |
| new_image_sizes = [] | |
| new_is_videos = [] | |
| # When using slowfast, split more finely | |
| # 0th local grid is slow frame, remaining local grids are fast frames | |
| for ( | |
| num_query, | |
| num_query_slow, | |
| num_grid, | |
| image_size, | |
| is_video, | |
| first_last_frames_slow, | |
| is_first_image, | |
| is_last_image, | |
| ) in zip( | |
| num_queries_vis_abstractors, | |
| num_queries_vis_abstractors_slow, | |
| num_grids[1:], | |
| image_sizes, | |
| is_videos, | |
| first_last_frames_slows, | |
| is_first_images, | |
| is_last_images, | |
| ): | |
| if not first_last_frames_slow and num_query_slow > 0: # Process all image in slowfast mode | |
| assert is_video # slowfast mode is only applied to videos | |
| this_group_ids = [group_ids[-1][-1] + 1 if group_ids else 0] | |
| # slow frame (first grid) | |
| new_num_grids.append(new_num_grids[-1] + 1) | |
| new_num_queries.append(num_query_slow) | |
| new_image_sizes.append(image_size) | |
| new_is_videos.append(is_video) | |
| if num_grid >= 2: | |
| # fast frames | |
| new_num_grids.append(new_num_grids[-1] + num_grid - 1) | |
| new_num_queries.append(num_query) | |
| new_image_sizes.append(image_size) | |
| new_is_videos.append(is_video) | |
| this_group_ids.append(this_group_ids[-1] + 1) | |
| group_ids.append(this_group_ids) | |
| elif ( | |
| first_last_frames_slow and num_query_slow > 0 and (is_first_image or is_last_image) | |
| ): # Process only first/last image in slowfast mode | |
| # Case for special treatment of first/last frames in slow mode | |
| assert is_video # slowfast mode is only applied to videos | |
| this_group_ids = [group_ids[-1][-1] + 1 if group_ids else 0] | |
| if num_grid == 1: | |
| # Simply process with slow since there's only one grid | |
| new_num_grids.append(new_num_grids[-1] + 1) | |
| new_num_queries.append(num_query_slow) | |
| new_image_sizes.append(image_size) | |
| new_is_videos.append(is_video) | |
| if num_grid >= 2: | |
| # Special treatment for first or last grid depending on is_first_image or is_last_image | |
| if is_first_image: # includes both first and last | |
| # slow frame (first grid) | |
| new_num_grids.append(new_num_grids[-1] + 1) | |
| new_num_queries.append(num_query_slow) | |
| new_image_sizes.append(image_size) | |
| new_is_videos.append(is_video) | |
| # fast frames | |
| new_num_grids.append(new_num_grids[-1] + num_grid - 1) | |
| new_num_queries.append(num_query) | |
| new_image_sizes.append(image_size) | |
| new_is_videos.append(is_video) | |
| this_group_ids.append(this_group_ids[-1] + 1) | |
| elif is_last_image: | |
| # fast frames | |
| new_num_grids.append(new_num_grids[-1] + num_grid - 1) | |
| new_num_queries.append(num_query) | |
| new_image_sizes.append(image_size) | |
| new_is_videos.append(is_video) | |
| # slow frame (last grid) | |
| new_num_grids.append(new_num_grids[-1] + 1) | |
| new_num_queries.append(num_query_slow) | |
| new_image_sizes.append(image_size) | |
| new_is_videos.append(is_video) | |
| this_group_ids.append(this_group_ids[-1] + 1) | |
| else: | |
| raise Exception("This case should not be reached.") | |
| group_ids.append(this_group_ids) | |
| else: | |
| # Not in slowfast mode, so reduce all by num_query (fast) | |
| new_num_grids.append(new_num_grids[-1] + num_grid) | |
| new_num_queries.append(num_query) | |
| new_image_sizes.append(image_size) | |
| new_is_videos.append(is_video) | |
| start_group_id = group_ids[-1][-1] + 1 if group_ids else 0 | |
| group_ids.append([start_group_id]) | |
| num_grids = new_num_grids | |
| num_queries_vis_abstractors = new_num_queries | |
| image_sizes = new_image_sizes | |
| is_videos = new_is_videos | |
| else: | |
| num_grids = [sum(num_grids[:i]) for i in range(1, len(num_grids) + 1)] | |
| group_ids = [[group_id] for group_id in range(len(is_videos))] | |
| return num_queries_vis_abstractors, num_grids, image_sizes, is_videos, group_ids | |
def load_state_dict_into_model(model_to_load, state_dict, strict=True, start_prefix=""):
    """Recursively load ``state_dict`` into ``model_to_load`` and return the error messages.

    Adapted from
    https://github.com/huggingface/transformers/blob/0a55d9f7376f72ad3ff296d4249840021b03bcc4/src/transformers/modeling_utils.py#L517

    Before loading, every key that contains ``"gamma"`` or ``"beta"`` is renamed
    to use ``"weight"`` / ``"bias"`` instead (old checkpoint format). The rename is
    applied **in place on the mapping passed by the caller**; only afterwards is a
    shallow copy taken for the actual loading.

    NOTE(review): the rename is a plain substring match, so any parameter whose
    name merely contains "gamma" or "beta" is renamed too.

    Args:
        model_to_load: Root module (``nn.Module``) that receives the values.
        state_dict: Mapping from parameter name to value. Its keys may be renamed
            in place as described above.
        strict: Forwarded unchanged to ``Module._load_from_state_dict``.
        start_prefix: Key prefix that corresponds to ``model_to_load`` itself.

    Returns:
        The list of error-message strings collected by ``_load_from_state_dict``
        over all visited modules.
    """
    # Convert old format to new format if needed from a PyTorch state_dict.
    # Both substitutions are chained on the same string so that a key containing
    # "gamma" as well as "beta" has both parts renamed.
    renames = []
    for key in state_dict.keys():
        new_key = key
        if "gamma" in new_key:
            new_key = new_key.replace("gamma", "weight")
        if "beta" in new_key:
            new_key = new_key.replace("beta", "bias")
        if new_key != key:
            renames.append((key, new_key))
    # Apply the renames after the scan; the mapping must not change size while iterated.
    for old_key, new_key in renames:
        state_dict[new_key] = state_dict.pop(old_key)

    # Copy state_dict so _load_from_state_dict can modify it.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, strict, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are
        # none in this state_dict; `any` stops at the first match instead of building a list.
        if any(key.startswith(prefix) for key in state_dict):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # In sharded models, each shard has only part of the full state_dict, so only gather
                # parameters that are in the current state_dict.
                named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
                params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
                if len(params_to_gather) > 0:
                    # Because zero3 puts placeholders in model params, this context manager
                    # gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again.
                    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy
    # of the argument, so it's safe to delete it.
    del state_dict
    return error_msgs
class HCXVisionCAbstractor(nn.Module):
    """
    Convolutional visual projector derived from C-Abstractor (apache-2.0 licensed).

    The original code lives at
    https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py
    and this version carries the modifications needed here.
    """

    def __init__(
        self,
        num_queries: int,
        num_input_tokens: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        pos_emb: bool = True,
        prenorm: bool = False,
    ):
        """
        Args:
            num_queries: Output grid size of the default sampler; must be a square number.
            num_input_tokens: Length of the learned positional embedding.
            encoder_hidden_size: Channel width of the incoming features.
            hidden_size: Channel width used inside both conv stages.
            output_hidden_size: Width produced by the readout MLP.
            pos_emb: If True, a learned positional embedding is added to the input.
            prenorm: If True, a LayerNorm is applied to the input first.
        """
        super().__init__()
        self.num_input_tokens = num_input_tokens
        self.output_hidden_size = output_hidden_size

        # Learned positional embedding (drawn from N(0, 0.02)), or None when disabled.
        positional = None
        if pos_emb:
            positional = torch.nn.Parameter(torch.zeros(1, num_input_tokens, encoder_hidden_size))
            positional.data.normal_(mean=0.0, std=0.02)
        self.pos_emb = positional

        # Optional normalization applied before the positional embedding is added.
        self.prenorm = LayerNorm(encoder_hidden_size) if prenorm else None

        self.build_net(num_queries, encoder_hidden_size, hidden_size, output_hidden_size)
        # Snapshot of the first parameter's dtype taken at construction time.
        self.dtype = next(self.parameters()).dtype

    def forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: Optional[List[List[int]]] = None,
        num_grids: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """
        Apply the optional pre-norm and positional embedding, then project.

        Args:
            x: (B, L, encoder_hidden_size) tensor from the visual backbone (e.g. CLIP visual
                encoder). L is reshaped into an (hw, hw) grid downstream.
                NOTE(review): the original docstring says x includes the cls token; confirm
                that against the square-grid reshape in `_forward`.
            num_queries_vis_abstractors: When given, one query count per slice of the batch;
                each entry is used as a single number (its square root sets the pooled grid).
            num_grids: Slice boundaries along the batch axis; required when
                `num_queries_vis_abstractors` is given.

        Returns:
            A (B, num_tokens, output_hidden_size) tensor on the default path, or a list of
            such tensors (one per slice) when `num_queries_vis_abstractors` is given.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)
        if self.pos_emb is not None:
            x = x + self.pos_emb
        return self._forward(
            x,
            num_queries_vis_abstractors=num_queries_vis_abstractors,
            num_grids=num_grids,
        )

    def _forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: Optional[List[List[int]]] = None,
        num_grids: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Reshape the token sequence into a 2-D feature map and run the projector."""
        # x: [B, L, dim]; the unpacking also enforces a 3-D input.
        _, seq_len, _ = x.shape
        side = int(seq_len ** 0.5)
        feature_map = rearrange(x, "b (h w) d -> b d h w", h=side, w=side)

        # Per-slice query counts take a separate path that returns a list.
        if num_queries_vis_abstractors is not None:
            assert num_grids is not None
            return self._forward_adaptive_num_query(feature_map, num_queries_vis_abstractors, num_grids)

        pooled = self.net(feature_map)
        tokens = rearrange(pooled, "b d h w -> b (h w) d")
        return self.readout(tokens)

    def _forward_adaptive_num_query(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: Optional[List[List[int]]] = None,
        num_grids: Optional[List[int]] = None,
    ) -> List[torch.Tensor]:
        """
        Project consecutive batch slices, each pooled to its own grid size.

        Slice ``i`` covers ``x[num_grids[i]:num_grids[i + 1]]`` and is pooled to a
        ``(hw, hw)`` grid with ``hw = int(num_queries_vis_abstractors[i] ** 0.5)``.
        The built-in sampler in ``self.net[1]`` is bypassed on this path.
        """
        # self.net consists of 3 layers (s1, sampler, s2)
        assert len(self.net) == 3
        first_stage = self.net[0]
        second_stage = self.net[2]

        features = first_stage(x)
        outputs = []
        for idx, n_query in enumerate(num_queries_vis_abstractors):
            side = int(n_query**0.5)
            chunk = features[num_grids[idx] : num_grids[idx + 1], :]
            pooled = nn.AdaptiveAvgPool2d((side, side))(chunk)
            pooled = second_stage(pooled)
            tokens = rearrange(pooled, "b d h w -> b (h w) d")
            outputs.append(self.readout(tokens))
        return outputs

    def build_net(
        self,
        n_queries: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        depth: int = 3,
        mlp_depth: int = 2,
    ):
        """Create ``self.net`` (stage, adaptive average pool, stage) and ``self.readout``."""
        assert (n_queries ** 0.5).is_integer(), f"n_queries must be square number. n_queries: {n_queries}"
        side = int(n_queries ** 0.5)

        # RegBlock = ResBlock + SE
        make_stage = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )
        # Modules are created in order (s1, sampler, s2) so parameter init order is stable.
        self.net = nn.Sequential(
            make_stage(depth, encoder_hidden_size, hidden_size),
            nn.AdaptiveAvgPool2d((side, side)),
            make_stage(depth, hidden_size, hidden_size),
        )
        self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size)

    def build_mlp(
        self,
        depth: int,
        hidden_size: int,
        output_hidden_size: int,
    ):
        """Return a Linear followed by ``depth - 1`` (SiLU, Linear) pairs as an ``nn.Sequential``."""
        modules = [nn.Linear(hidden_size, output_hidden_size)]
        for _ in range(depth - 1):
            modules += [nn.SiLU(), nn.Linear(output_hidden_size, output_hidden_size)]
        return nn.Sequential(*modules)
def _swap_key_prefix(key, prefix, replacement=""):
    """Return ``key`` with a leading ``prefix`` replaced by ``replacement``; other keys pass through."""
    if key.startswith(prefix):
        return replacement + key[len(prefix) :]
    return key


def load_sharded_checkpoint(
    model, folder, pick_prefix="", replace_prefix_list=None, replace_prefix_dict=None, print_info=True
):
    """Load every checkpoint shard found in ``folder`` into ``model`` (non-strict).

    Files ending in ``.safetensors`` take precedence; otherwise ``pytorch_model*.bin``
    files are loaded with ``torch.load``. If a ``*.index.json`` file exists, the
    set of checkpoint keys is read from its ``"weight_map"``; otherwise it is
    accumulated from the shards that were actually loaded.

    Key rewriting is applied in this order, both to the index keys and to each
    shard's keys:
      1. ``pick_prefix``: keep only keys with this prefix and strip it.
      2. ``replace_prefix_list``: strip each listed prefix where present.
      3. ``replace_prefix_dict``: swap each leading prefix for its mapped value.

    Args:
        model: Module to load into; must provide ``state_dict()`` and ``load_state_dict()``.
        folder: Directory holding the checkpoint files. ``None`` is a no-op.
        pick_prefix: Prefix selecting (and stripped from) the keys to load.
        replace_prefix_list: Prefixes to strip. ``None`` means no prefixes.
        replace_prefix_dict: Mapping of old prefix to new prefix. ``None`` means no mapping.
        print_info: If True, the missing/unexpected keys are printed when ``get_rank() == 0``.

    Returns:
        ``{}`` when ``folder`` is None, otherwise a dict with ``"missing_keys"``
        (model keys absent from the checkpoint) and ``"unexpected_keys"``
        (checkpoint keys absent from the model).
    """
    if folder is None:
        return {}

    # None defaults avoid sharing mutable default objects between calls.
    replace_prefix_list = replace_prefix_list or []
    replace_prefix_dict = replace_prefix_dict or {}

    files = os.listdir(folder)

    # find relevant files
    pytorch_bin_files = [file for file in files if file.startswith("pytorch_model") and file.endswith(".bin")]
    safetensor_files = [file for file in files if file.endswith(".safetensors")]
    shard_index_file = [file for file in files if file.endswith(".index.json")]

    index_present = len(shard_index_file) > 0
    is_safetensor = len(safetensor_files) > 0

    model_keys = model.state_dict().keys()

    if is_safetensor:
        from safetensors.torch import load_file

        load_function = load_file
        shard_files = safetensor_files
    else:
        # NOTE(security): torch.load unpickles the file; only point this at trusted checkpoints.
        load_function = partial(torch.load, map_location="cpu")
        shard_files = pytorch_bin_files

    # Always defined, so a folder without any index or shard files reports every
    # model key as missing instead of failing on an unbound name.
    loaded_keys = []

    # sharded case: take the key list from the index file
    if index_present:
        index_file = os.path.join(folder, shard_index_file[0])
        with open(index_file, "r", encoding="utf-8") as f:
            index = json.load(f)
        loaded_keys = list(index["weight_map"].keys())
        if pick_prefix:
            loaded_keys = [k[len(pick_prefix) :] for k in loaded_keys if k.startswith(pick_prefix)]
        for rep_prefix in replace_prefix_list:
            loaded_keys = [_swap_key_prefix(k, rep_prefix) for k in loaded_keys]
        for rep_prefix, replacement in replace_prefix_dict.items():
            loaded_keys = [_swap_key_prefix(k, rep_prefix, replacement) for k in loaded_keys]

    for i, shard_file in enumerate(shard_files):
        state_dict = load_function(os.path.join(folder, shard_file))

        # if pick_prefix, use only pick
        if pick_prefix:
            state_dict = {k[len(pick_prefix) :]: v for k, v in state_dict.items() if k.startswith(pick_prefix)}

        # Each rewrite rebuilds the dict in turn, so colliding keys resolve step by step.
        for rep_prefix in replace_prefix_list:
            state_dict = {_swap_key_prefix(k, rep_prefix): v for k, v in state_dict.items()}

        # Only the leading prefix is swapped; later occurrences of the same text are left alone.
        for rep_prefix, replacement in replace_prefix_dict.items():
            state_dict = {_swap_key_prefix(k, rep_prefix, replacement): v for k, v in state_dict.items()}

        if is_deepspeed_zero3_enabled():
            # torch.distributed.barrier()
            rank = torch.distributed.get_rank()
            print(f"# [info] ZeRo3 - load sharded no {i}, rank {rank}")
            load_state_dict_into_model(model, state_dict, strict=False)
        elif is_fsdp_enabled():
            # Under FSDP only the process for which is_local_dist_rank_0() is true loads the shard.
            if is_local_dist_rank_0():
                model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict, strict=False)

        # Without an index, collect the keys of every shard (not only the last one).
        # Captured after loading so any renaming done by the loader is reflected.
        if not index_present:
            loaded_keys.extend(state_dict.keys())

        # Make sure memory is freed before we load the next state dict.
        del state_dict
        gc.collect()

    # A set keeps the missing-key scan linear in the number of keys.
    loaded_key_set = set(loaded_keys)
    missing_keys = [key for key in model_keys if key not in loaded_key_set]
    unexpected_keys = [key for key in loaded_keys if key not in model_keys]

    # print_info is tested first so the rank lookup is skipped when reporting is off.
    if print_info and get_rank() == 0:
        print(f"[info] missing_keys: {missing_keys}")
        print(f"[info] unexpected_keys: {unexpected_keys}")

    return {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}