Text Ranking
sentence-transformers
Safetensors
English
modernbert
ecommerce
e-commerce
retail
marketplace
shopping
amazon
ebay
alibaba
google
rakuten
bestbuy
walmart
flipkart
wayfair
shein
target
etsy
shopify
taobao
asos
carrefour
costco
overstock
pretraining
encoder
language-modeling
foundation-model
text-embeddings-inference
Instructions to use thebajajra/RexReranker-mini with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use thebajajra/RexReranker-mini with sentence-transformers:
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("thebajajra/RexReranker-mini")

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
```
- Notebooks
- Google Colab
- Kaggle
| """RexReranker Model for HuggingFace. | |
| Compatible with: | |
| - Transformers: AutoModel.from_pretrained(..., trust_remote_code=True) | |
| - Sentence Transformers: CrossEncoder(..., trust_remote_code=True) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, List, Union | |
| from dataclasses import dataclass | |
| from transformers import PretrainedConfig, PreTrainedModel, AutoModel | |
| from transformers.modeling_outputs import SequenceClassifierOutput | |
@dataclass
class RexRerankerOutput(SequenceClassifierOutput):
    """Output class for RexReranker with additional distributional information.

    Extends ``SequenceClassifierOutput`` with the extra tensors produced by
    the distributional head. The ``@dataclass`` decorator is required here:
    without it the annotated attributes below are ordinary class attributes
    rather than dataclass fields, so they cannot be passed as keyword
    arguments when the output object is constructed.
    """

    # Training loss; left as None when no labels are supplied.
    loss: Optional[torch.Tensor] = None
    # Single relevance score [B, 1] for CrossEncoder compatibility.
    logits: Optional[torch.Tensor] = None
    # Full distribution logits [B, num_bins].
    distribution_logits: Optional[torch.Tensor] = None
    # Convenience: same values as logits.squeeze(-1).
    relevance: Optional[torch.Tensor] = None
    # Prediction variance.
    variance: Optional[torch.Tensor] = None
    # Distribution entropy.
    entropy: Optional[torch.Tensor] = None
class RexRerankerConfig(PretrainedConfig):
    """Hyper-parameters for the RexReranker model.

    Every constructor argument is stored verbatim as an attribute of the
    same name; any remaining keyword arguments are forwarded to
    ``PretrainedConfig``.

    Args:
        backbone_name: Identifier of the encoder checkpoint to load.
        num_bins: Number of bins of the predicted relevance distribution.
        dropout: Dropout probability applied before the scoring head.
        pooling_strategy: Pooling mode name (the model accepts "cls" or "mean").
        hidden_size: Encoder width; when None the model reads it from the
            backbone's own config.
        num_labels: Number of labels (kept at 1 for CrossEncoder compatibility).
        transitions: List of floats; falls back to ``[0.2, 0.5, 0.8]`` when
            None or empty. Only stored here.
        sigma_min: Stored on the config only.
        sigma_max: Stored on the config only.
        sigma_delta: Stored on the config only.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "rex_reranker"

    def __init__(
        self,
        backbone_name: str = "thebajajra/RexBERT-mini",
        num_bins: int = 11,
        dropout: float = 0.0,
        pooling_strategy: str = "mean",
        hidden_size: Optional[int] = None,
        num_labels: int = 1,  # CrossEncoder compatibility
        transitions: Optional[List[float]] = None,
        sigma_min: float = 0.04,
        sigma_max: float = 0.12,
        sigma_delta: float = 0.08,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Encoder / head architecture.
        self.backbone_name = backbone_name
        self.num_bins = num_bins
        self.dropout = dropout
        self.pooling_strategy = pooling_strategy
        self.hidden_size = hidden_size
        self.num_labels = num_labels

        # A falsy value (None or an empty list) selects the default; the
        # default list is built inside the call so instances never share it.
        self.transitions = transitions if transitions else [0.2, 0.5, 0.8]

        # Sigma settings are only recorded on the config.
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_delta = sigma_delta
class RexRerankerModel(PreTrainedModel):
    """
    RexBERT-based distributional reranker.

    Predicts a categorical distribution over K bins in [0, 1] representing
    relevance scores. The output logits contain a single relevance score
    (the expected value of that distribution) for CrossEncoder
    compatibility, while the full distribution is available via
    distribution_logits or predict_with_uncertainty().

    Compatible with:
    - sentence_transformers.CrossEncoder
    - transformers.AutoModelForSequenceClassification
    """

    config_class = RexRerankerConfig
    base_model_prefix = "rex_reranker"
    supports_gradient_checkpointing = True

    def __init__(self, config: RexRerankerConfig):
        """Build the backbone encoder, the scoring head and the bin centers.

        Args:
            config: Model configuration.

        Raises:
            ValueError: If ``config.pooling_strategy`` is not "cls" or "mean",
                or if the hidden size cannot be determined.
        """
        super().__init__(config)

        # Raise explicitly instead of using `assert`, which is stripped when
        # Python runs with -O and would let an invalid strategy fall through
        # to mean pooling.
        if config.pooling_strategy not in ("cls", "mean"):
            raise ValueError(
                f"Unsupported pooling_strategy {config.pooling_strategy!r}; "
                "expected 'cls' or 'mean'."
            )
        self.pooling_strategy = config.pooling_strategy
        self.num_bins = config.num_bins

        # NOTE(review): the backbone checkpoint is fetched with
        # from_pretrained on every construction of this class — confirm this
        # is intended when the model is restored from its own checkpoint.
        self.backbone = AutoModel.from_pretrained(
            config.backbone_name,
            trust_remote_code=True,
        )
        if hasattr(self.backbone, "config") and hasattr(self.backbone.config, "use_cache"):
            self.backbone.config.use_cache = False

        # An explicit config value wins; otherwise use the backbone's width.
        hidden_size = config.hidden_size or getattr(self.backbone.config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not infer hidden_size.")

        self.dropout = nn.Dropout(config.dropout)
        self.score_head = nn.Linear(hidden_size, config.num_bins)

        # Evenly spaced bin centers on [0, 1]; non-persistent, so they are
        # rebuilt here rather than stored in the state dict.
        self.register_buffer(
            "bin_centers",
            torch.linspace(0.0, 1.0, config.num_bins),
            persistent=False,
        )
        self.post_init()

    def _init_weights(self, module):
        """Initialize nn.Linear modules with N(0, 0.02) weights and zero bias."""
        # NOTE(review): this matches every nn.Linear it is applied to, not
        # only score_head — confirm pretrained backbone layers are never
        # passed through here.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        output_distribution: bool = False,
        **kwargs,  # Accept extra kwargs for CrossEncoder compatibility
    ) -> Union[RexRerankerOutput, tuple]:
        """
        Forward pass.

        Args:
            input_ids: Token IDs [B, T]
            attention_mask: Attention mask [B, T]
            labels: Optional relevance labels [B] (a trailing singleton
                dimension, i.e. [B, 1], is also accepted)
            return_dict: Whether to return a dataclass
            output_distribution: If True, include full distribution info in output
            **kwargs: Ignored; accepted for CrossEncoder compatibility

        Returns:
            If return_dict is falsy, a tuple ``(logits,)`` or
            ``(loss, logits)`` when labels are given. Otherwise a
            RexRerankerOutput with:
            - loss: MSE between relevance and labels (only if labels given)
            - logits: [B, 1] single relevance score (CrossEncoder compatible)
            - relevance: [B] same score without the trailing dimension
              (always populated)
            - distribution_logits: [B, num_bins] full distribution
              (if output_distribution=True)
            - variance, entropy: uncertainty statistics
              (if output_distribution=True)
        """
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        last_hidden = out.last_hidden_state

        if self.pooling_strategy == "cls":
            pooled = last_hidden[:, 0, :]
        else:
            # Masked mean over the sequence dimension, accumulated in float32.
            mask = attention_mask.unsqueeze(-1).float()
            summed = (last_hidden * mask).sum(dim=1)
            lengths = mask.sum(dim=1).clamp(min=1e-9)
            # Cast back to the backbone dtype: the float32 mask promotes the
            # result, which would otherwise not match a half-precision
            # score_head. This is a no-op for float32 models.
            pooled = (summed / lengths).to(last_hidden.dtype)

        # Get distribution logits
        dist_logits = self.score_head(self.dropout(pooled))  # [B, num_bins]

        # Convert to single relevance score (expected value)
        probs = F.softmax(dist_logits, dim=-1)
        centers = self.bin_centers.view(1, -1)
        relevance = (probs * centers).sum(dim=-1)  # [B]

        # Output single score as logits for CrossEncoder compatibility [B, 1]
        logits = relevance.unsqueeze(-1)

        loss = None
        if labels is not None:
            target = labels.float()
            # Labels shaped [B, 1] would broadcast against relevance [B] into
            # a [B, B] comparison; drop the trailing singleton dimension.
            if target.dim() == 2 and target.size(-1) == 1:
                target = target.squeeze(-1)
            # Compute the loss in float32 so both operands share a dtype.
            loss = F.mse_loss(relevance.float(), target)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        # Compute additional stats if requested
        variance = None
        entropy = None
        if output_distribution:
            variance = (probs * (centers - relevance.unsqueeze(-1)) ** 2).sum(dim=-1)
            # Clamp before the log so zero probabilities do not produce NaN.
            entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1)

        return RexRerankerOutput(
            loss=loss,
            logits=logits,
            distribution_logits=dist_logits if output_distribution else None,
            relevance=relevance,
            variance=variance,
            entropy=entropy,
        )

    def predict_relevance(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Get relevance scores directly. Returns [B] tensor."""
        # Call the module itself rather than .forward so registered hooks run.
        outputs = self(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.relevance

    def predict_with_uncertainty(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> dict:
        """
        Get relevance prediction with full uncertainty estimates.

        Returns:
            dict with:
            - relevance: [B] predicted relevance scores
            - variance: [B] prediction variance (higher = more uncertain)
            - entropy: [B] distribution entropy (higher = more uncertain)
            - probs: [B, num_bins] full probability distribution
            - distribution_logits: [B, num_bins] raw logits
        """
        # Call the module itself rather than .forward so registered hooks run.
        outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_distribution=True,
        )
        # Probabilities are not part of the output object; derive them here.
        probs = F.softmax(outputs.distribution_logits, dim=-1)
        return {
            "relevance": outputs.relevance,
            "variance": outputs.variance,
            "entropy": outputs.entropy,
            "probs": probs,
            "distribution_logits": outputs.distribution_logits,
        }