# NeoLLM / configuration_neollm.py
# KitsuVp's picture
# Update configuration_neollm.py
# e60f9fc verified
# ==================== configuration_neollm.py ====================
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
logger = logging.get_logger(__name__)
class NeoLLMConfig(PretrainedConfig):
r"""
Configuration class for the NeoLLM model architecture.
Instantiates a NeoLLM model according to the specified arguments, defining the
full architecture including attention mechanisms, normalization, periodicity
modeling, an optional Leviathan continuous token embedding generator, an
optional Leviathan-JTok-M token-indexed modulation module, and optional
Spelling Bee character-level embedding augmentation.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and
can be used to control the model outputs. Read the documentation from
:class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, *optional*, defaults to 64402):
Vocabulary size of the NeoLLM model. Defines the number of different
tokens that can be represented by the ``input_ids``.
hidden_size (:obj:`int`, *optional*, defaults to 512):
Dimensionality of the hidden representations.
intermediate_size (:obj:`int`, *optional*, defaults to 1536):
Dimensionality of the MLP feed-forward intermediate representations.
num_hidden_layers (:obj:`int`, *optional*, defaults to 12):
Number of decoder Transformer layers.
num_attention_heads (:obj:`int`, *optional*, defaults to 8):
Number of query attention heads per layer.
num_key_value_heads (:obj:`int`, *optional*, defaults to 4):
Number of key/value attention heads (GQA). Must divide
``num_attention_heads`` evenly.
hidden_act (:obj:`str`, *optional*, defaults to ``"xielu"``):
Non-linear activation function used in the MLP layers.
max_position_embeddings (:obj:`int`, *optional*, defaults to 32768):
Maximum sequence length supported by the positional encoding.
initializer_range (:obj:`float`, *optional*, defaults to 0.02):
Standard deviation of the truncated-normal weight initializer.
rms_norm_eps (:obj:`float`, *optional*, defaults to 1e-6):
Epsilon for RMS normalization layers.
tie_word_embeddings (:obj:`bool`, *optional*, defaults to ``False``):
Whether to share weights between the input embedding matrix and the
output language-model head. Automatically forced to ``False`` when
``use_token_generator=True``, because the generator produces input
representations via a learned smooth surface and the output head must
remain an independent dense projection.
rope_theta (:obj:`float`, *optional*, defaults to 10000.0):
Base period for Rotary Position Embeddings (RoPE).
rope_scaling (:obj:`dict`, *optional*):
Dictionary containing the RoPE scaling configuration. Must contain at
least the key ``"rope_type"`` (or ``"type"``). Validated by
:func:`~transformers.modeling_rope_utils.rope_config_validation`.
partial_rotary_factor (:obj:`float`, *optional*, defaults to 0.25):
Fraction of each attention head's dimension to rotate with RoPE.
attention_bias (:obj:`bool`, *optional*, defaults to ``False``):
Whether to add a bias term to Q, K, V, and output projections.
attention_dropout (:obj:`float`, *optional*, defaults to 0.1):
Dropout probability applied to attention weights during training.
head_dim (:obj:`int`, *optional*, defaults to 64):
Dimensionality of each attention head.
use_momentum_attention (:obj:`bool`, *optional*, defaults to ``True``):
Enable post-RoPE Momentum Attention: applies a causal first-difference
shear to Q and K before the dot-product score computation.
momentum_gamma (:obj:`float`, *optional*, defaults to 0.10):
Mixing coefficient for the Momentum Attention shear. Ignored when
``use_momentum_attention=False``.
use_mea_attention (:obj:`bool`, *optional*, defaults to ``False``):
Enable Multi-head Explicit Attention (MEA), which applies a learned
head-level linear composition over K and V initialized as identity,
allowing inter-head interaction to emerge freely from step 0.
mea_component_key_value_heads (:obj:`int`, *optional*):
Number of component K/V heads used by MEA. Defaults to
``num_key_value_heads`` when ``None``.
mea_groupnorm_eps (:obj:`float`, *optional*, defaults to 1e-6):
Epsilon for the GQA-grouped MEA output normalisation.
use_lucid_attention (:obj:`bool`, *optional*, defaults to ``False``):
Enable LUCID attention: applies a lower-triangular solve to
precondition the value states using the causal key-key similarity
matrix in RKHS, decorrelating keys to reduce attentional noise in
long-context settings (Duvvuri et al., 2026).
lucid_attention_eps (:obj:`float`, *optional*, defaults to 1e-6):
Epsilon for RMS key normalization inside the LUCID preconditioner.
use_affine_scaled_attention (:obj:`bool`, *optional*, defaults to ``False``):
Enable Affine-Scaled Attention (Bae et al., 2026). Applies an
input-dependent per-head scaling factor α and a moving-average bias β
directly to the softmax-normalized attention weights before the
weighted sum with V:
[α(X) · softmax(QKᵀ/√dk) + β(X)] V
This relaxes the unit-sum constraint of softmax, reduces first-token
bias, increases attention entropy, and promotes head diversity.
Orthogonal to Gated Attention: the gate modulates the post-SDPA
output, while Affine-Scaled modulates the softmax weights directly.
Only active in eager attention mode (flash kernels do not expose
intermediate softmax weights).
affine_momentum (:obj:`float`, *optional*, defaults to 0.9):
EMA momentum coefficient ρ for the running average α_ma used to
compute the bias term β(X) = (α_ma − α(X)) / N. Controls the
trade-off between the running estimate and the current batch
statistic. Ignored when ``use_affine_scaled_attention=False``.
use_xsa (:obj:`bool`, *optional*, defaults to ``True``):
Enable Exclusive Self Attention (Zhai, 2026). After the SDPA output
is computed, removes from each head's output the component that falls
along the direction of the token's own value vector, forcing the
attention layer to carry only contextual information orthogonal to
self-position.
Two paths depending on active components:
- **MEA active or LUCID active**: ``v_ref`` is the value vector
after MEA mixing and after LUCID preconditioning — the vector that
actually participated in the SDPA aggregation.
- **MEA and LUCID inactive**: ``v_ref`` is the raw value projection
— standard XSA as described in the paper.
Applied after MEAHeadRMSNorm and before the Gated Attention gate.
Gains increase with sequence length.
xsa_eps (:obj:`float`, *optional*, defaults to 1e-6):
Epsilon for the denominator of the XSA projection to prevent
division by zero: ``‖v_ref‖² + xsa_eps``. Ignored when
``use_xsa=False``.
use_stack_memory (:obj:`bool`, *optional*, defaults to ``False``):
Enable the StackMemory / STACKTRANS hidden-state stack between
decoder layers (Zhang et al., NeurIPS 2025). The module keeps the
standard attention path intact and applies differentiable soft
``push``, ``pop`` and ``no-op`` stack updates to the layer input
before the attention block.
stack_d_model (:obj:`int`, *optional*, defaults to ``32``):
Low-rank stack width ``d_s`` used by StackMemory. Hidden states are
projected ``hidden_size → stack_d_model`` before stack operations
and projected back afterwards.
num_mem_heads (:obj:`int`, *optional*, defaults to ``4``):
Number of independent stack heads. Must divide ``stack_d_model``.
stack_slots (:obj:`int`, *optional*, defaults to ``16``):
Number of slots kept by each differentiable stack head.
stack_memory_cache_size (:obj:`int`, *optional*, defaults to ``2048``):
Cache length kept for parity with the released StackTrans source.
The default training path keeps this cache disabled.
fan_ratio (:obj:`float`, *optional*, defaults to 0.125):
Ratio controlling the periodic-dimension size in FANformer attention.
The transformed representation has dimension
``hidden_size * (1 + fan_ratio)``.
fan_ratio_ffn (:obj:`float`, *optional*, defaults to 0.0625):
Ratio controlling the periodic-dimension size in FANformer MLP
layers. Set to half of ``fan_ratio`` to model complementary
periodicities in the feature space.
dropout_rate (:obj:`float`, *optional*, defaults to 0.1):
General dropout probability for attention outputs and MLP states.
use_learnable_multipliers (:obj:`bool`, *optional*, defaults to ``True``):
Enable Learnable Multipliers on the matrix layers where the
Velikanov et al. placement recommends them. When ``True``, the
model keeps row multipliers on ``q_proj`` and MLP ``gate_proj``,
row+column multipliers on dense ``o_proj`` and MLP ``down_proj``,
and no extra multipliers on ``k_proj``, ``v_proj`` or ``up_proj``.
When ``False``, those multiplier parameters are not instantiated
and the wrapped linear layers behave as ordinary ``nn.Linear``
modules.
use_embedding_multipliers (:obj:`bool`, *optional*, defaults to ``False``):
Add vocabulary-row and hidden-channel multipliers to the standard
``nn.Embedding`` matrix, matching the Learnable Multipliers paper's
recommendation for embeddings. This flag is effective only when
``use_learnable_multipliers=True``, ``use_token_generator=False``
and ``tie_word_embeddings=False``. It is intentionally disabled for
tied input/output embeddings and for the Leviathan generator path so
it does not break parameter sharing or wrap a non-matrix generator.
use_lns (:obj:`bool`, *optional*, defaults to ``False``):
Instantiate and apply LayerNorm Scaling in decoder layers. When
``False``, no LNS modules are constructed and the corresponding
path is identity.
use_gpas (:obj:`bool`, *optional*, defaults to ``False``):
Instantiate and apply GPAS at decoder residual junctions. When
``False``, no GPAS modules or parameters are constructed.
use_siamesenorm (:obj:`bool`, *optional*, defaults to ``True``):
Enable SiameseNorm as a two-stream residual topology using RMSNorm
only. When ``True``, decoder layers instantiate the SiameseNorm RMS
stream norms instead of the standard decoder pre-norm modules.
siamese_normalized_input (:obj:`bool`, *optional*, defaults to ``True``):
Apply an additional RMSNorm to the fused SiameseNorm input before
the shared Attention and MLP blocks. This corresponds to the
normalized-input mechanism used by the strongest SiameseNorm
variant.
siamese_depth_scaling (:obj:`bool`, *optional*, defaults to ``True``):
Scale the residual update added to the bounded SiameseNorm stream
by ``1 / sqrt(2 * (layer_idx + 1))``. The unbounded stream receives
the unscaled shared update.
siamese_attn_x_scale_init (:obj:`float`, *optional*, defaults to ``1.0``):
Initial value for the learnable elementwise scale applied to the
bounded stream in the SiameseNorm attention input.
use_embedding_input_norm (:obj:`bool`, *optional*, defaults to ``True``):
Controls only the first pre-attention normalisation applied to the
raw embedding stream. When ``True``, layer 0 instantiates and applies
``nn.RMSNorm``. When ``False``, layer 0
does not instantiate that input-normalisation module, so raw token
embeddings enter the first attention block directly. All deeper
decoder layers keep their normal pre-norm flow unchanged.
use_token_generator (:obj:`bool`, *optional*, defaults to ``False``):
Replace the discrete vocabulary embedding lookup table with a
**Leviathan** continuous token generator (Batley & Saha, 2026).
When enabled:
- ``tie_word_embeddings`` is forced to ``False``.
- ``model.embed_tokens`` is replaced by ``model.token_generator``
(:class:`LeviathanGenerator`).
- The input-embedding parameter budget scales as
``O(k · ⌈V^{1/k}⌉ · d_seed)`` instead of ``O(V · D)``.
- When ``use_jtokm=True``, the generator additionally returns
``z_tilde`` and ``B_vals`` for reuse by every decoder layer,
avoiding redundant B-spline evaluation.
See :class:`LeviathanGenerator` in ``modeling_neollm.py``.
generator_d_seed (:obj:`int`, *optional*, defaults to 128):
Dimensionality of the latent seed space ``z̃ ∈ [0,1]^{d_seed}``.
Also used as the per-dimension input to each generator head's
B-spline expansion and as the residual dimension for JTok-M
surfaces when ``use_jtokm=True``.
generator_num_modes (:obj:`int`, *optional*, defaults to 8):
Number of independent per-head generator modes. Each mode has its
own preprocessing (Dense without bias + LayerNorm + sigmoid(x/2)),
its own learnable per-dimension scale, its own delta-parameterized
spline weights ``1 + wd_i``, and its own output projection. Head
outputs are summed to form the embedding.
generator_num_knots (:obj:`int`, *optional*, defaults to 32):
Number of B-spline knot points on ``[0, 1]``. The quadratic basis
is explicitly normalized across the knot dimension after evaluation.
Shared between the input generator heads and all JTok-M surfaces.
generator_spline_degree (:obj:`int`, *optional*, defaults to 2):
Polynomial degree of the B-spline basis. Kept for documentation;
the closed-form KHRONOS quadratic kernel is used in practice.
generator_k (:obj:`int`, *optional*, defaults to 3):
Number of coordinate dimensions for latent compositional indexing.
generator_krank (:obj:`int`, *optional*, defaults to 64):
Output rank of each per-head KHRONOS tensor-product kernel inside
the Leviathan generator. Each of the ``generator_num_modes`` heads
produces a vector of this dimensionality via the tensor-product
aggregation, which is then projected independently to
``hidden_size`` via ``head_out[i]``. ``64`` matches the author
clarification for ``wd_i`` with shape ``(128, 32, 64)``. Lower
values, such as ``32``, are valid memory-saving variants but are
not the full paper-faithful default. Ignored when
``use_token_generator=False``.
use_jtokm (:obj:`bool`, *optional*, defaults to ``False``):
Enable the **Leviathan-JTok-M** token-indexed modulation module
(Yang et al., 2026; fused with Leviathan geometry).
Unlike the original JTok paper which maintains discrete embedding
tables of size ``V × d`` per layer — reintroducing the vocabulary
tax in every decoder layer — this implementation operates over the
Leviathan latent coordinate ``z̃_x``. Parameter cost scales with
``n_e × M_mod × d_seed × n_knots`` per layer rather than ``V × d``,
breaking the linear dependency on vocabulary size. Additionally,
tokens with nearby latent coordinates receive structurally related
modulations, introducing continuity that the discrete formulation
cannot express.
Architecture per decoder layer when active:
1. **Surface pool**: ``n_e`` independent CP-separable surfaces, each
sharing the same ``z̃_x`` and ``B(z̃_x)`` produced by the
generator. Surface ``i`` computes:
.. math::
m^{\ell}_{x,i} = W^{\ell,i}_{\text{out}}
[M^{\ell,i}_1, \ldots, M^{\ell,i}_{M_{\text{mod}}}]^\top
+ W^{\ell,i}_{\text{res}}\, \tilde{z}_x
2. **Context router**: a linear projection of
``RMSNorm(h̃^ℓ_x)`` — the hidden state *after* attention —
produces ``n_e`` routing logits. TopK selects K surfaces;
Sigmoid-normalized weights (not Softmax) avoid inter-surface
competition:
.. math::
w^{\ell}_i = \frac{\sigma(g^{\ell}_i)}{\sum_{j \in
\mathcal{G}^{\ell}_x} \sigma(g^{\ell}_j)}
3. **Additive injection** with LNS-coordinated scaling:
.. math::
\Delta r^{\ell}_x = \frac{1}{\sqrt{2\ell}} \cdot
s^{\ell} \odot \text{Norm}_{\varepsilon}(e^{\ell}_x)
.. math::
h^{\ell+1}_x = \tilde{h}^{\ell}_x + \Delta m^{\ell}_x
+ \Delta r^{\ell}_x
The ``1/√(2ℓ)`` factor — where ``ℓ`` is the 1-indexed layer
index — is coordinated with the existing LNS factor ``1/√ℓ``
to maintain a **constant JTok-M / backbone contribution ratio**
of ``1/√2 ≈ 0.707`` at every depth.
4. **Load-balancing loss** (averaged over all layers):
.. math::
\mathcal{L}_{\text{aux}} = \lambda \cdot n_e
\sum_{i=1}^{n_e} p_i f_i
Requires ``use_token_generator=True``.
jtokm_num_experts (:obj:`int`, *optional*, defaults to 4):
Number of independent CP-separable modulation surfaces ``n_e`` per
decoder layer.
jtokm_top_k (:obj:`int`, *optional*, defaults to 2):
Number of surfaces selected by the router per token per layer (K).
Must satisfy ``1 ≤ jtokm_top_k < jtokm_num_experts``.
jtokm_num_modes (:obj:`int`, *optional*, defaults to 4):
Number of rank-1 separable modes ``M_mod`` per JTok-M surface.
jtokm_aux_loss_weight (:obj:`float`, *optional*, defaults to 1e-4):
Coefficient ``λ`` for the load-balancing auxiliary loss.
jtokm_norm_eps (:obj:`float`, *optional*, defaults to 1e-6):
Epsilon for L2 normalisation of modulation vectors.
use_spelling_bee_embeddings (:obj:`bool`, *optional*, defaults to ``True``):
Augment token embeddings with character-level byte information
(Rabe, Clymo & Dong, 2026).
Each token's UTF-8 encoding (up to 16 bytes) is embedded through a
shared ``nn.Embedding(256, d)`` table. Byte embeddings are
position-encoded with RoPE using intra-token byte positions (not
sequence positions), summed and normalised by ``√byte_len``, then
averaged with the standard token embedding:
.. math::
e_{\text{bee}}(t) = \frac{1}{2}\left(e_{\text{tok}}(t) +
\frac{1}{\sqrt{|t|}}\sum_{i=1}^{16}
\text{RoPE}(e_{\text{byte}}[b_i], i)\right)
Adds ``256 × hidden_size`` parameters (≈0.13M for d=512).
Zero inference overhead when ``bake_inference_table()`` is called
after training.
Compatible with all four combinations of ``use_token_generator``
and ``use_spelling_bee_embeddings``.
**Setup required**: call
``model.model.spelling_bee.set_byte_table(tokenizer)`` once after
model instantiation (handled automatically by ``setup_model`` in
``train.py``).
Reference: Rabe, Clymo & Dong (2026). *Spelling Bee Embeddings for
Language Modeling.* arXiv:2601.18030.
use_hadamard_o_proj (:obj:`bool`, *optional*, defaults to ``True``):
Replace the dense ``W_O ∈ R^{d×d}`` output projection in every
multi-head attention block with a fixed Walsh–Hadamard Transform
followed by a learnable per-channel affine rescaling
``α ⊙ FWHT(x)/√d + β``.
The WHT is a parameter-free orthogonal matrix whose singular values
are all identically 1, so the effective condition number is
``κ = 1`` by construction and cannot grow during training. This
directly addresses the high-κ pathology (κ up to 10^5) observed in
the dense ``o_proj`` matrices, which causes FP8 per-tensor
quantisation to lose low-magnitude directions entirely.
Parameter reduction: replaces ``d²`` weights with ``2d``
(``α`` and ``β``), saving ≈25% of attention parameters per block.
Requires ``hidden_size`` to be a power of 2 (512 ✓, 1024 ✓,
768 ✗).
Reference: Aggarwal & Kumar (2026). *Rethinking Attention Output
Projection: Structured Hadamard Transforms for Efficient
Transformers.* arXiv:2603.08343.
use_repo (:obj:`bool`, *optional*, defaults to ``True``):
Enable Context Re-Positioning (REPO) in attention layers at or
above ``repo_start_layer`` (Li et al., 2026).
REPO replaces the fixed linear position indices ``0…L-1`` fed to
RoPE with continuous, data-dependent positions ``z_i = f_ϕ(h_i)``
learned end-to-end. The attention score between tokens ``i`` and
``j`` becomes:
.. math::
A^{\text{REPO}}_{i,j} = q_i^\top\, g_\theta(z_j - z_i)\, k_j
where ``g_θ`` is the standard RoPE rotation and ``z_i`` is
predicted from the hidden state ``h_i`` by a lightweight SwiGLU
sub-layer ``f_ϕ``:
.. math::
r_i = \text{Swish}(h_i W_g) \odot (h_i W_c), \quad
z_i^{(h)} = r_i w_z^{(h)}
``W_g, W_c \in \mathbb{R}^{d \times d_p}`` are shared across
all query heads within a layer; ``w_z^{(h)} \in \mathbb{R}^{d_p}``
is learned independently per head. The assigned positions are
real-valued and unconstrained — the model may learn constant
(NoPE-like), monotonic (RoPE-like), or hybrid patterns as needed.
Lower layers (``layer_idx < repo_start_layer``) retain the
standard integer RoPE positions because they primarily capture
surface-level, locally-dependent features that benefit less from
re-positioning (Li et al., 2026, §3).
Overhead: +0.9% parameters; inference latency negligible.
repo_start_layer (:obj:`int`, *optional*, defaults to
``num_hidden_layers // 3``):
Index of the first decoder layer to which REPO is applied.
Layers ``[0, repo_start_layer)`` continue to use standard
integer RoPE positions. Must satisfy
``0 <= repo_start_layer < num_hidden_layers``.
Ignored when ``use_repo=False``.
repo_d_p (:obj:`int`, *optional*, defaults to
``hidden_size // 8``):
Dimensionality of the intermediate position representation
``r_i \in \mathbb{R}^{d_p}`` inside ``f_ϕ``. The paper sets
``d_p = d/8`` on the assumption that positional information
is less rich than the full hidden representation. Ignored
when ``use_repo=False``.
use_repo_grape (:obj:`bool`, *optional*, defaults to ``True``):
Enable the proposed REPO-GRAPE positional operator. This reuses
the REPO coordinate module ``f_ϕ`` and replaces the RoPE action
with a GRAPE-M action using a learned per-head/per-plane angular
spectrum:
.. math::
u_i^{(h)} = z_i^{(h)}, \quad
\theta_{h,r}=\theta^0_r\exp(s_{h,r})
When enabled, REPO-GRAPE activates the REPO coordinate path at
``repo_start_layer`` even if ``use_repo=False``. Lower layers keep
standard integer RoPE.
use_laurel (:obj:`bool`, *optional*, defaults to ``False``):
Enable the Learned Augmented Residual Layer (LAUREL) framework
(Menghani, Kumar & Kumar, ICML 2025). LAUREL generalises the
canonical residual connection:
.. math::
x_{i+1} = \alpha \cdot f(x_i) + g(x_i)
where :math:`g` is a learned linear function operating on the
residual stream. Applied independently to both the attention
and MLP sublayers of every decoder layer.
At least one of ``use_laurel_rw`` or ``use_laurel_lr`` must be
``True`` when this flag is active; both may be active
simultaneously, producing the combined **LAUREL-RW+LR** variant
(paper eq. 5).
Incompatible with ``use_attn_res=True`` — both methods modify
the residual stream and their interaction is undefined.
Reference: Menghani, G., Kumar, R. & Kumar, S. (2025).
*LAUREL: Learned Augmented Residual Layer.* ICML 2025.
use_laurel_rw (:obj:`bool`, *optional*, defaults to ``False``):
Enable the **LAUREL-RW** (Residual Weights) variant. Assigns
independent learned scalars :math:`\alpha, \beta` to the
sublayer output and residual respectively:
.. math::
x_{i+1} = \alpha_s \cdot f(x_i) + \beta_s \cdot x_i
:math:`\alpha_s, \beta_s = \text{softmax}([\tilde{\alpha},
\tilde{\beta}])` so that they are non-negative and sum to 1,
preventing unbounded growth (paper §2.1). Adds **2 parameters
per sublayer** (4 per decoder layer).
When combined with ``use_laurel_lr=True`` (LAUREL-RW+LR,
paper eq. 5):
.. math::
x_{i+1} = \alpha_s \cdot f(x_i)
+ \beta_s \cdot (B A x_i + x_i)
Ignored when ``use_laurel=False``.
use_laurel_lr (:obj:`bool`, *optional*, defaults to ``False``):
Enable the **LAUREL-LR** (Low-Rank) variant. Augments the
residual with a rank-``laurel_lr_rank`` correction:
.. math::
x_{i+1} = f(x_i) + B A x_i + x_i
where :math:`A \in \mathbb{R}^{D \times r}` and
:math:`B \in \mathbb{R}^{r \times D}` are learnable matrices
(paper eq. 3). :math:`A` is initialised with column-orthogonal
values :math:`A_{i,j} = 1/\sqrt{rD}` if :math:`i \bmod r = j`
else 0; :math:`B` is initialised to zero — matching the LoRA
convention and ensuring the residual starts as identity
(paper §3.3). Adds **2·r·D parameters per sublayer**
(4·r·D per decoder layer).
Ignored when ``use_laurel=False``.
laurel_lr_rank (:obj:`int`, *optional*, defaults to ``32``):
Rank ``r`` of the low-rank matrices in LAUREL-LR. The paper
recommends :math:`r \in \{32, 48, 64\}` for LLMs
(paper §3.3). Ignored when ``use_laurel=False`` or
``use_laurel_lr=False``.
Constraints:
- ``use_jtokm=True`` requires ``use_token_generator=True``.
- ``1 ≤ jtokm_top_k < jtokm_num_experts`` when ``use_jtokm=True``.
- ``use_spelling_bee_embeddings=True`` requires calling
``model.model.spelling_bee.set_byte_table(tokenizer)`` before
training (handled automatically by ``setup_model``).
- ``repo_start_layer`` must satisfy
``0 <= repo_start_layer < num_hidden_layers`` when
``use_repo=True``, ``use_repo_grape=True`` or
``use_repo_goat_prior=True``.
- At most one of ``use_laurel=True``, ``use_attn_res=True``,
``use_stack_memory=True`` or ``use_siamesenorm=True`` may be active.
- When ``use_laurel=True``, at least one of ``use_laurel_rw`` or
``use_laurel_lr`` must be ``True``.
- ``stack_d_model`` must be divisible by ``num_mem_heads`` when
``use_stack_memory=True``.
Examples::
>>> from configuration_neollm import NeoLLMConfig
>>> from modeling_neollm import NeoLLMForCausalLM
>>> # Standard dense-embedding model
>>> config = NeoLLMConfig(use_token_generator=False,
... tie_word_embeddings=True)
>>> model = NeoLLMForCausalLM(config)
>>> # Full attention stack
>>> config_full = NeoLLMConfig(
... use_affine_scaled_attention=True,
... use_xsa=True,
... use_lucid_attention=True,
... )
>>> model_full = NeoLLMForCausalLM(config_full)
>>> # Leviathan generator + JTok-M
>>> config_jtokm = NeoLLMConfig(
... use_token_generator=True,
... use_jtokm=True,
... )
>>> # REPO: context re-positioning from layer 4 onward (default for 12 layers)
>>> config_repo = NeoLLMConfig(
... use_repo=True,
... # repo_start_layer defaults to num_hidden_layers // 3 = 4
... # repo_d_p defaults to hidden_size // 8 = 64
... )
>>> # Proposed REPO-GRAPE-M: REPO coordinates + learned GRAPE spectrum
>>> config_repo_grape = NeoLLMConfig(
... use_repo_grape=True,
... )
>>> # REPO-GRAPE-M + GOAT-style factorised log-prior
>>> config_repo_grape_goat = NeoLLMConfig(
... use_repo_grape=True,
... use_repo_goat_prior=True,
... repo_goat_num_frequencies=3,
... )
References:
Bae, J. et al. (2026). *Affine-Scaled Attention: Towards Flexible and
Stable Transformer Attention.* arXiv:2602.23057.
Duvvuri, S. et al. (2026). *LUCID: Attention with Preconditioned
Representations.* arXiv:2602.10410.
Zhai, S. (2026). *Exclusive Self Attention.* arXiv:2603.09078.
Batley, R. T. & Saha, S. (2026). *A Separable Architecture for Continuous
Token Representation in Language Models.* arXiv:2601.22040.
Yang, Y. et al. (2026). *JTok: On Token Embedding as Another Axis of
Scaling Law via Joint Token Self-Modulation.* arXiv:2602.00800.
Robinson, M. et al. (2025). *Token Embeddings Violate the Manifold
Hypothesis.* arXiv:2504.01002.
Rabe, M. N., Clymo, J. & Dong, Z. (2026). *Spelling Bee Embeddings for
Language Modeling.* arXiv:2601.18030.
Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). *REPO: Language Models
with Context Re-Positioning.* arXiv:2512.14391.
Zhang, Y. et al. (2026). *Group Representational Position Encoding.*
ICLR 2026 / arXiv:2512.07805.
Litman, E. & Guo, G. (2026). *You Need Better Attention Priors.*
arXiv:2601.15380.
Menghani, G., Kumar, R. & Kumar, S. (2025). *LAUREL: Learned Augmented
Residual Layer.* ICML 2025. arXiv:2411.07501.
"""
model_type = "neollm"
keys_to_ignore_at_inference = []
def __init__(
self,
vocab_size=64402,
hidden_size=512,
intermediate_size=1536,
num_hidden_layers=12,
num_attention_heads=8,
num_key_value_heads=4,
hidden_act="xielu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
partial_rotary_factor=0.25,
attention_bias=False,
attention_dropout=0.1,
head_dim=64,
use_momentum_attention=True,
momentum_gamma=0.10,
use_mea_attention=False,
mea_component_key_value_heads=None,
mea_groupnorm_eps=1e-6,
use_lucid_attention=False,
lucid_attention_eps=1e-6,
use_affine_scaled_attention=False,
affine_momentum=0.9,
use_xsa=True,
xsa_eps=1e-6,
# ── Directional Routing (Taylor, 2026) ────────────────────────────
use_directional_routing=False,
directional_routing_k=4,
directional_routing_temp=3.0,
# ── Attention Residuals (Kimi Team, 2026) ─────────────────────────
use_attn_res=False,
attn_res_num_blocks=4,
# ── StackMemory / STACKTRANS (Zhang et al., NeurIPS 2025) ───────
use_stack_memory=False,
stack_d_model=32, # H * d_s = 4 * 16
num_mem_heads=4, # H = 4
stack_slots=16, # S = 24
stack_memory_cache_size=2048,
# ── ResFormer cross-layer FAN residual (He et al., 2023) ─────────
use_fan_residual=False,
fan_ratio=0.125,
fan_ratio_ffn=0.0625,
dropout_rate=0.1,
# ── Learnable Multipliers (Velikanov et al., 2026) ─────────────────
use_learnable_multipliers=True,
use_embedding_multipliers=False,
# ── Normalization controls ────────────────────────────────────────
use_lns=False,
use_gpas=False,
use_siamesenorm=True,
siamese_normalized_input=True,
siamese_depth_scaling=True,
siamese_attn_x_scale_init=1.0,
# ── Embedding input normalization ─────────────────────────────────
use_embedding_input_norm=True,
# ── Leviathan continuous token generator ──────────────────────────
use_token_generator=False,
generator_d_seed=128,
generator_num_modes=8,
generator_num_knots=32,
generator_spline_degree=2,
generator_k=3,
generator_krank=64,
# ── Leviathan-JTok-M token-indexed modulation ─────────────────────
use_jtokm=False,
jtokm_num_experts=4,
jtokm_top_k=2,
jtokm_num_modes=4,
jtokm_aux_loss_weight=1e-4,
jtokm_norm_eps=1e-6,
# ── Hadamard output projection (Aggarwal & Kumar, 2026) ───────────
use_hadamard_o_proj=True,
# ── PolyNorm exclusivity ──────────────────────────────────────────
polynorm_exclusive=False,
# ── Spelling Bee Embeddings (Rabe et al., 2026) ───────────────────
use_spelling_bee_embeddings=True,
# ── Context Re-Positioning (Li et al., 2026) ──────────────────────
use_repo=True,
repo_start_layer=None,
repo_d_p=None,
use_repo_grape=True,
use_repo_goat_prior=False,
repo_goat_num_frequencies=3,
repo_goat_sink_decay=4.0,
# ── LAuReL: Learned Augmented Residual Layer (Menghani et al., 2025) ─
use_laurel=False,
use_laurel_rw=False,
use_laurel_lr=False,
laurel_lr_rank=32,
# ── Interleaved Head Attention (Duvvuri et al., 2026) ─────────────
use_iha=True,
iha_num_pseudo_heads=2, # P=2 β†’ 2Γ—2=4 patrones por head
iha_local_global_pattern="LLLLG", # 4 locales + 1 global (paper Β§5.1)
iha_sliding_window=None, # auto = N // (2*P^2) usando la longitud real del batch
iha_global_layers_use_iha=False, # False replica el paper: la capa G es global estΓ‘ndar, sin IHA
**kwargs,
):
# ── Generator / tying consistency ─────────────────────────────────
if use_token_generator and tie_word_embeddings:
logger.warning(
"`use_token_generator=True` is incompatible with "
"`tie_word_embeddings=True`. "
"Automatically setting `tie_word_embeddings=False`. "
"The continuous generator replaces the discrete lookup table "
"with a learned smooth surface, so input and output parameters "
"are always structurally decoupled."
)
tie_word_embeddings = False
# ── Embedding multiplier applicability ─────────────────────────────
if use_embedding_multipliers:
if not use_learnable_multipliers:
logger.warning(
"`use_embedding_multipliers=True` is inactive because "
"`use_learnable_multipliers=False`. No embedding multiplier "
"parameters will be instantiated."
)
elif use_token_generator:
logger.warning(
"`use_embedding_multipliers=True` is inactive because "
"`use_token_generator=True` replaces the standard embedding "
"matrix with LeviathanGenerator."
)
elif tie_word_embeddings:
logger.warning(
"`use_embedding_multipliers=True` is inactive because "
"`tie_word_embeddings=True`; embedding multipliers would "
"break clean input/output parameter sharing."
)
# ── JTok-M / generator dependency ─────────────────────────────────
if use_jtokm and not use_token_generator:
raise ValueError(
"`use_jtokm=True` requires `use_token_generator=True`. "
"The JTok-M surfaces are defined over the Leviathan latent "
"coordinate z̃_x, which is only produced when the generator "
"is active. Set `use_token_generator=True` or disable JTok-M."
)
# ── JTok-M top-k sanity ────────────────────────────────────────────
if use_jtokm and not (1 <= jtokm_top_k < jtokm_num_experts):
raise ValueError(
f"`jtokm_top_k` must satisfy 1 <= jtokm_top_k < jtokm_num_experts, "
f"got jtokm_top_k={jtokm_top_k}, jtokm_num_experts={jtokm_num_experts}."
)
# ── REPO: resolve defaults and validate ───────────────────────────
# repo_start_layer defaults to num_hidden_layers // 3, matching the
# paper's 1/3-of-depth heuristic (Li et al., 2026, §3).
# repo_d_p defaults to hidden_size // 8, matching the paper's
# assumption that positional information is less rich than the full
# hidden representation (Li et al., 2026, §3.2).
if repo_start_layer is None:
repo_start_layer = num_hidden_layers // 3
if repo_d_p is None:
repo_d_p = hidden_size // 8
if (use_repo or use_repo_grape or use_repo_goat_prior) and not (0 <= repo_start_layer < num_hidden_layers):
raise ValueError(
f"`repo_start_layer` must satisfy "
f"0 <= repo_start_layer < num_hidden_layers, "
f"got repo_start_layer={repo_start_layer}, "
f"num_hidden_layers={num_hidden_layers}."
)
# ── REPO-GOAT prior: validate factorised prior dimensions ─────────
if use_repo_goat_prior:
if repo_goat_num_frequencies < 0:
raise ValueError(
f"`repo_goat_num_frequencies` must be >= 0, "
f"got {repo_goat_num_frequencies}."
)
if repo_goat_sink_decay <= 0.0:
raise ValueError(
f"`repo_goat_sink_decay` must be > 0, "
f"got {repo_goat_sink_decay}."
)
_repo_goat_prior_dim = 2 * int(repo_goat_num_frequencies) + 2
if (head_dim + _repo_goat_prior_dim) % 8 != 0:
logger.warning(
"`use_repo_goat_prior=True` appends %d prior channels to "
"head_dim=%d, giving internal attention dim=%d. Some "
"FlashAttention builds prefer dimensions divisible by 8; "
"the default repo_goat_num_frequencies=3 gives +8.",
_repo_goat_prior_dim, head_dim, head_dim + _repo_goat_prior_dim,
)
# ── SiameseNorm: RMS-only topology by construction ────────────────
# ── IHA / MEA compatibility ───────────────────────────────────────
# The implementation keeps both modules in-place:
# IHA acts first on Q/K/V component heads.
# MEA then applies its [H_comp, H_kv] mixing independently inside
# each IHA pseudo-slot on K/V.
# This preserves IHA's pseudo-head structure and the GQA ratio
# (H_q*P) / (H_kv*P) = H_q / H_kv without moving other attention ops.
if use_iha and iha_num_pseudo_heads < 1:
raise ValueError(
f"`iha_num_pseudo_heads` must be >= 1, got {iha_num_pseudo_heads}."
)
if use_iha:
_iha_pattern = str(iha_local_global_pattern).upper().strip()
if not _iha_pattern:
raise ValueError("`iha_local_global_pattern` must not be empty when `use_iha=True`.")
_bad_iha_pattern_chars = sorted(set(_iha_pattern) - {"L", "G"})
if _bad_iha_pattern_chars:
raise ValueError(
"`iha_local_global_pattern` only accepts 'L' and 'G'. "
f"Got invalid characters: {_bad_iha_pattern_chars}."
)
# ── Residual-flow mechanism mutex ─────────────────────────────────
# AttnRes, LAuReL, StackMemory and SiameseNorm all alter the residual
# topology. Keep at most one experimental residual-flow mechanism
# active at a time so ablations remain interpretable.
_residual_flow_flags = {
"use_laurel": bool(use_laurel),
"use_attn_res": bool(use_attn_res),
"use_stack_memory": bool(use_stack_memory),
"use_siamesenorm": bool(use_siamesenorm),
}
_active_residual_flow = [name for name, enabled in _residual_flow_flags.items() if enabled]
if len(_active_residual_flow) > 1:
raise ValueError(
"Enable at most one residual-flow mechanism among "
"`use_laurel`, `use_attn_res`, `use_stack_memory`, "
"and `use_siamesenorm`. "
f"Active flags: {_active_residual_flow}."
)
# ── LAuReL: sub-flag consistency ──────────────────────────────────
if use_laurel and not use_laurel_rw and not use_laurel_lr:
raise ValueError(
"`use_laurel=True` requires at least one sub-variant to be active. "
"Set `use_laurel_rw=True` and/or `use_laurel_lr=True`."
)
# ── StackMemory: source-code shape constraints ────────────────────
if use_stack_memory:
if stack_d_model <= 0:
raise ValueError(f"`stack_d_model` must be > 0, got {stack_d_model}.")
if num_mem_heads <= 0:
raise ValueError(f"`num_mem_heads` must be > 0, got {num_mem_heads}.")
if stack_slots <= 0:
raise ValueError(f"`stack_slots` must be > 0, got {stack_slots}.")
if stack_d_model % num_mem_heads != 0:
raise ValueError(
f"`stack_d_model` must be divisible by `num_mem_heads`, "
f"got stack_d_model={stack_d_model}, num_mem_heads={num_mem_heads}."
)
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
# ── Core Transformer ──────────────────────────────────────────────
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
# ── Positional encoding ───────────────────────────────────────────
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor
# ── Attention ─────────────────────────────────────────────────────
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.head_dim = head_dim
self.use_momentum_attention = use_momentum_attention
self.momentum_gamma = momentum_gamma
self.use_mea_attention = use_mea_attention
self.mea_component_key_value_heads = (
num_key_value_heads
if mea_component_key_value_heads is None
else int(mea_component_key_value_heads)
)
self.mea_groupnorm_eps = mea_groupnorm_eps
self.use_lucid_attention = use_lucid_attention
self.lucid_attention_eps = lucid_attention_eps
self.use_affine_scaled_attention = use_affine_scaled_attention
self.affine_momentum = affine_momentum
self.use_xsa = use_xsa
self.xsa_eps = xsa_eps
# ── Directional Routing ───────────────────────────────────────────
self.use_directional_routing = use_directional_routing
self.directional_routing_k = directional_routing_k
self.directional_routing_temp = directional_routing_temp
# ── Attention Residuals ───────────────────────────────────────────
# use_attn_res=True: replace fixed residual accumulation with learned
# depth-wise softmax attention over preceding layer outputs.
# attn_res_num_blocks=0: Full AttnRes — all previous layer outputs
# are kept as sources (N grows to num_hidden_layers+1).
# attn_res_num_blocks=4: Block AttnRes — 4 block summaries maximum,
# block_size = num_hidden_layers // 4 = 3 layers per block.
# Memory cost: O(num_blocks × batch × seq × hidden) instead of
# O(num_layers × batch × seq × hidden).
self.use_attn_res = use_attn_res
self.attn_res_num_blocks = attn_res_num_blocks
# ── StackMemory / STACKTRANS ──────────────────────────────────────
self.use_stack_memory = use_stack_memory
self.stack_d_model = stack_d_model
self.num_mem_heads = num_mem_heads
self.stack_slots = stack_slots
self.stack_memory_cache_size = stack_memory_cache_size
rope_config_validation(self)
# ── FANformer periodicity ─────────────────────────────────────────
self.use_fan_residual = use_fan_residual
self.fan_ratio = fan_ratio
self.fan_ratio_ffn = fan_ratio_ffn
# ── Regularization ────────────────────────────────────────────────
self.dropout_rate = dropout_rate
# ── Learnable Multipliers ─────────────────────────────────────────
self.use_learnable_multipliers = use_learnable_multipliers
self.use_embedding_multipliers = use_embedding_multipliers
# ── Normalization controls ────────────────────────────────────────
self.use_lns = bool(use_lns)
self.use_gpas = bool(use_gpas)
self.use_siamesenorm = bool(use_siamesenorm)
self.siamese_normalized_input = bool(siamese_normalized_input)
self.siamese_depth_scaling = bool(siamese_depth_scaling)
self.siamese_attn_x_scale_init = float(siamese_attn_x_scale_init)
# ── Embedding input normalization ─────────────────────────────────
# True preserves the previous layer-0 pre-norm path, using the active
# norm implementation:
# embeddings -> RMSNorm -> LNS -> attention.
# False removes that first input-norm module entirely, so raw embeddings
# enter the first attention block and the layer-0 input-norm parameters
# are not counted; layers >= 1 keep their pre-norm unchanged.
self.use_embedding_input_norm = use_embedding_input_norm
# ── Leviathan generator ───────────────────────────────────────────
self.use_token_generator = use_token_generator
self.generator_d_seed = generator_d_seed
self.generator_num_modes = generator_num_modes
self.generator_num_knots = generator_num_knots
self.generator_spline_degree = generator_spline_degree
self.generator_k = generator_k
self.generator_krank = generator_krank
# ── Leviathan-JTok-M ─────────────────────────────────────────────
self.use_jtokm = use_jtokm
self.jtokm_num_experts = jtokm_num_experts
self.jtokm_top_k = jtokm_top_k
self.jtokm_num_modes = jtokm_num_modes
self.jtokm_aux_loss_weight = jtokm_aux_loss_weight
self.jtokm_norm_eps = jtokm_norm_eps
# ── Hadamard output projection (Aggarwal & Kumar, 2026) ───────────
self.use_hadamard_o_proj = use_hadamard_o_proj
# ── PolyNorm exclusivity ──────────────────────────────────────────
self.polynorm_exclusive = polynorm_exclusive
# ── Spelling Bee Embeddings (Rabe et al., 2026) ───────────────────
self.use_spelling_bee_embeddings = use_spelling_bee_embeddings
# ── Context Re-Positioning (Li et al., 2026) ──────────────────────
self.use_repo = use_repo
self.repo_start_layer = repo_start_layer
self.repo_d_p = repo_d_p
self.use_repo_grape = use_repo_grape
self.use_repo_goat_prior = use_repo_goat_prior
self.repo_goat_num_frequencies = int(repo_goat_num_frequencies)
self.repo_goat_sink_decay = float(repo_goat_sink_decay)
# ── LAuReL: Learned Augmented Residual Layer (Menghani et al., 2025) ─
self.use_laurel = use_laurel
self.use_laurel_rw = use_laurel_rw
self.use_laurel_lr = use_laurel_lr
self.laurel_lr_rank = laurel_lr_rank
# ── Interleaved Head Attention (Duvvuri et al., 2026) ─────────────
# use_iha=True: enables learned cross-head mixing of Q, K, V.
# iha_num_pseudo_heads (P): number of pseudo-heads per original head.
# P=1: lightweight cross-head linear mixing, fully shape-preserving,
# compatible with all other attention flags.
# P>1: full IHA with pseudo-head expansion and collapse.
# If MEA is active, MEA composes K/V independently inside each
# pseudo-slot after IHA, so both remain compatible.
# iha_local_global_pattern: paper Sec. 5.1 / Appendix C schedule.
# "LLLLG" → 4 local IHA layers + 1 global transport layer per cycle.
# 'L' always means local IHA with sliding-window compensation.
# 'G' means global full-sequence attention. By default, 'G' is
# standard non-IHA attention to match the paper's FLOP argument:
# 4·O(HN²d/2)+1·O(HN²d). Set iha_global_layers_use_iha=True
# only for the more expensive ablation where global layers also
# use full IHA, whose attention cost scales as O(HP²N²d).
# Applied only when P>1 (P=1 never needs FLOP compensation).
# iha_sliding_window: window size W for local-IHA layers.
# None → auto = N/(2P²) with N = actual sequence length at forward time
# (paper Sec. 5.1 / Appendix C exact recipe).
# int → use the provided explicit window size as-is.
# iha_global_layers_use_iha:
# False → paper-faithful default: G layers do not instantiate or run
# IHA parameters, so they are ordinary global attention layers.
# True → G layers instantiate and run global IHA over the expanded
# sequence; useful only as an explicit high-cost ablation.
# Init: identity (IHA ≡ MHA at step 0, Theorem 2 inclusion proof).
self.use_iha = use_iha
self.iha_num_pseudo_heads = iha_num_pseudo_heads
self.iha_local_global_pattern = iha_local_global_pattern
self.iha_sliding_window = iha_sliding_window
self.iha_global_layers_use_iha = bool(iha_global_layers_use_iha)
self.auto_map = {
"AutoConfig": "configuration_neollm.NeoLLMConfig",
"AutoModel": "modeling_neollm.NeoLLMModel",
"AutoModelForCausalLM": "modeling_neollm.NeoLLMForCausalLM",
}
# Public API of this module: `NeoLLMConfig` is the only name exported by
# `from configuration_neollm import *`.
__all__ = ["NeoLLMConfig"]