from collections import defaultdict
from typing import Callable, Optional, Tuple, Union

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.processing_utils import Unpack
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, load_state_dict
from transformers.generation import GenerationMixin
from transformers.utils import logging, TransformersKwargs

from .moondream3_moe_fused.moe_fused_linear import MoeFusedLinear
from .moondream3_moe_fused.kernels.indexing import get_expert_counts_and_idx
from .configuration_moondream3 import (
    Moondream3Config,
    Moondream3TextConfig,
    Moondream3VisionConfig,
    Moondream3RegionConfig,
)

from . import modeling_moondream3

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "Moondream3Config"


class Moondream3FusedSparseMoeBlock(nn.Module):
    """Sparse MoE block that runs all selected experts through fused grouped-GEMM linears."""

    def __init__(self, config: Moondream3TextConfig) -> None:
        super().__init__()
        self.num_experts = config.num_experts
        self.num_selected = config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.moe_intermediate_size = config.moe_intermediate_size

        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.gate_proj = MoeFusedLinear(self.hidden_size, self.moe_intermediate_size, config.num_experts)
        self.up_proj = MoeFusedLinear(self.hidden_size, self.moe_intermediate_size, config.num_experts)
        self.down_proj = MoeFusedLinear(self.moe_intermediate_size, self.hidden_size, config.num_experts)

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        M = batch_size * sequence_length

        hidden_states = hidden_states.view(M, hidden_dim)
        router_logits = self.gate(hidden_states)

        # Top-k routing: softmax in float32 for numerical stability, then
        # renormalize the selected weights so they sum to 1 per token.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
        routing_weights, selected_experts = torch.topk(routing_weights, self.num_selected, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        # Duplicate each token once per selected expert: (M, H) -> (M * k, H).
        hidden_states = hidden_states.unsqueeze(1).expand(M, self.num_selected, hidden_dim)
        hidden_states = hidden_states.reshape(M * self.num_selected, hidden_dim)
        selected_experts = selected_experts.view(M * self.num_selected)

        # Sort token copies by expert id so each expert's rows are contiguous,
        # which is the layout the fused grouped-GEMM kernels expect.
        m_sizes, sort_idx, inv_sort_idx = get_expert_counts_and_idx(selected_experts, self.num_experts)
        hidden_states = hidden_states[sort_idx]

        # Gated MLP through the fused per-expert weights.
        gate_h = self.gate_proj(hidden_states, m_sizes)
        up_h = self.up_proj(hidden_states, m_sizes)
        hidden_states = F.gelu(up_h) * (gate_h + 1)
        del gate_h, up_h
        hidden_states = self.down_proj(hidden_states, m_sizes)

        # Undo the sort, then combine the k expert outputs per token using the
        # routing weights.
        hidden_states = hidden_states[inv_sort_idx]
        hidden_states = hidden_states.view(M, self.num_selected, hidden_dim)
        hidden_states = torch.einsum("beo,be->bo", hidden_states, routing_weights)

        hidden_states = hidden_states.view(batch_size, sequence_length, hidden_dim)
        return hidden_states, router_logits


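# A minimal, hedged sketch of exercising the fused block on its own (the config
# values here are illustrative, not the shipped Moondream3 hyperparameters, and
# the fused kernels presumably require a CUDA device):
#
#   cfg = Moondream3TextConfig(hidden_size=64, num_experts=8,
#                              num_experts_per_tok=2, moe_intermediate_size=128)
#   block = Moondream3FusedSparseMoeBlock(cfg).cuda()
#   out, router_logits = block(torch.randn(2, 16, 64, device="cuda"))
#   assert out.shape == (2, 16, 64) and router_logits.shape == (32, 8)
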
# Swap the reference implementation's MoE block for the fused one; decoder layers
# built after this point resolve Moondream3SparseMoeBlock to the fused class.
modeling_moondream3.Moondream3SparseMoeBlock = Moondream3FusedSparseMoeBlock
from .modeling_moondream3 import (  # noqa: E402
    Moondream3Config,
    Moondream3TextConfig,
    Moondream3VisionConfig,
    Moondream3RegionConfig,
    Moondream3PreTrainedModel,
    Moondream3Model,
    Moondream3TextModel,
    Moondream3VisionModel,
    Moondream3ForConditionalGeneration,
)


class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Moondream3Config):
        super().__init__(config)
        self.model = Moondream3Model(config)
        self.vocab_size = config.text_config.vocab_size
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=True)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.text_model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.text_model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.text_model = decoder

    def get_decoder(self):
        return self.model.text_model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        tiling: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: int = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        model_outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            tiling=tiling,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
        )

        hidden_states = model_outputs.last_hidden_state

        # Only project the requested suffix of positions through the LM head;
        # during generation this is typically just the last position.
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hs = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, slice):
            hs = hidden_states[:, logits_to_keep, :]
        else:
            hs = hidden_states

        logits = self.lm_head(hs)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=getattr(model_outputs, "past_key_values", None),
            hidden_states=getattr(model_outputs, "hidden_states", None),
            attentions=getattr(model_outputs, "attentions", None),
        )

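    # A hedged sketch of calling the model directly (the checkpoint id and the
    # shapes of `pixel_values`/`tiling` are assumptions for illustration):
    #
    #   model = Moondream3ForConditionalGeneration.from_pretrained(
    #       "<moondream3-checkpoint>", torch_dtype=torch.bfloat16
    #   )
    #   out = model(input_ids=ids, pixel_values=px, tiling=tiling)
    #   next_token = out.logits[:, -1, :].argmax(dim=-1)
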
    @classmethod
    def _load_pretrained_model(
        cls,
        model: "PreTrainedModel",
        state_dict: Optional[dict],
        checkpoint_files: Optional[list[str]],
        pretrained_model_name_or_path,
        weights_only: bool = True,
        **kwargs,
    ):
        if checkpoint_files is not None:
            # Merge all checkpoint shards into a single CPU state dict.
            state_dict = {}
            for file in checkpoint_files:
                sd = load_state_dict(file, map_location="cpu", weights_only=weights_only)
                for key, value in sd.items():
                    state_dict[key] = value

            # Discover which layers store per-expert weights and how many
            # experts each of them has.
            moe_layer_experts = defaultdict(set)
            for key in state_dict.keys():
                if key.startswith("model.text_model.layers."):
                    parts = key.split(".")
                    # e.g. model.text_model.layers.<L>.mlp.experts.<E>.down_proj.weight
                    if len(parts) > 6 and parts[5] == "experts" and parts[3].isdigit() and parts[6].isdigit():
                        layer_idx = int(parts[3])
                        expert_idx = int(parts[6])
                        moe_layer_experts[layer_idx].add(expert_idx)

            # Stack the per-expert projection weights into the single fused
            # (num_experts, out, in) tensors that MoeFusedLinear expects.
            moe_layers = {layer: len(experts) for layer, experts in moe_layer_experts.items()}
            for layer_idx, num_experts in moe_layers.items():
                prefix = f"model.text_model.layers.{layer_idx}.mlp"
                for proj in ("down_proj", "up_proj", "gate_proj"):
                    state_dict[f"{prefix}.{proj}.weight"] = torch.stack(
                        [state_dict[f"{prefix}.experts.{i}.{proj}.weight"] for i in range(num_experts)]
                    )
                    for i in range(num_experts):
                        del state_dict[f"{prefix}.experts.{i}.{proj}.weight"]
            checkpoint_files = None

        model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs = super()._load_pretrained_model(
            model,
            state_dict,
            checkpoint_files,
            pretrained_model_name_or_path,
            **kwargs,
        )
        return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs

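    # Example of the load-time key rewrite above (shapes illustrative; H is the
    # hidden size, I the expert intermediate size, E the number of experts):
    #
    #   ...layers.5.mlp.experts.0.down_proj.weight    [H, I]
    #   ...layers.5.mlp.experts.1.down_proj.weight    [H, I]
    #   ...
    #   ...layers.5.mlp.experts.E-1.down_proj.weight  [H, I]
    #     -> ...layers.5.mlp.down_proj.weight         [E, H, I]
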
    def _fix_state_dict_keys_on_save(self, state_dict: dict):
        # Inverse of the load-time fusing: split each fused (num_experts, out, in)
        # tensor back into per-expert weights so saved checkpoints keep the
        # original key layout.
        for layer_idx in range(self.config.text_config.moe_start_layer, self.config.text_config.num_hidden_layers):
            layer_key = f"model.text_model.layers.{layer_idx}"
            for proj in ("down_proj", "up_proj", "gate_proj"):
                tensor = state_dict.pop(f"{layer_key}.mlp.{proj}.weight").cpu()
                for i, t in enumerate(torch.unbind(tensor)):
                    state_dict[f"{layer_key}.mlp.experts.{i}.{proj}.weight"] = t.contiguous()
        return state_dict

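    # Round-trip sketch: the save-time split above inverts the load-time
    # stacking in `_load_pretrained_model`, so (assuming a writable `tmp_dir`)
    # saving and reloading should preserve the per-expert checkpoint layout:
    #
    #   model.save_pretrained(tmp_dir)
    #   reloaded = Moondream3ForConditionalGeneration.from_pretrained(tmp_dir)
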
    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Legacy tuple-format cache reordering for beam search.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


__all__ = [
    "Moondream3Config",
    "Moondream3TextConfig",
    "Moondream3VisionConfig",
    "Moondream3RegionConfig",
    "Moondream3PreTrainedModel",
    "Moondream3Model",
    "Moondream3TextModel",
    "Moondream3VisionModel",
    "Moondream3ForConditionalGeneration",
]