AMP-Oriented Multi-task Model (AOMM)

Pure Era is the project completed by the team XJTLU-Software 2025 for the iGEM 2025 competition. The project focuses on closely integrating antimicrobial peptides (AMPs) with data science. The AMP-Oriented Multi-task Model (AOMM) is a state-of-the-art language model for AMPs.

  • The number of parameters is approximately 124M.
  • The model is open source on HuggingFace.
from transformers import AutoModel
model = AutoModel.from_pretrained(
    "muskwff/amp4multitask_124M",
    task_name="mic_regression",     # choose the task ["amp_classification", "hemolysis_regression", "mic_regression", "half_life_regression"]
    trust_remote_code=True      # required for loading the model
)
  • The code is available on GitHub.

Figure 1: The model architecture of AOMM

News

  • 2025.10: The model parameters were updated.
  • 2025.08: The model is released on HuggingFace.

How to use

import os
import numpy as np
import json
import pandas as pd

import torch
from transformers import EsmTokenizer, AutoModel

class Inferencer:
    def __init__(self, model_root, norm_params_file):
        if model_root is None:
            self.model_root = "muskwff/amp4multitask_124M"
        else:
            self.model_root = model_root

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[Info] Device: {self.device}")
        self.norm_params_file = norm_params_file

        # load task config
        self.task_config = {
            "masked_lm":{
                "num_organisms": None
            },
            "amp_classification":{
                "num_organisms": None
            },
            "bioactivity_classification":{
                "num_organisms": None
            },
            "half_life_regression":{
                "num_organisms": None
            },
            "hemolysis_regression":{
                "num_organisms": None
            },
            "mic_regression":{
                "num_organisms": 12
            }
        }
        self.tasks = list(self.task_config.keys())      # six tasks in total
        self.tasks.remove("masked_lm")                  # but remove the masked_lm task
        self.tasks.remove("bioactivity_classification") # remove the bioactivity_classification task   

        # load tokenizer
        self.tokenizer = EsmTokenizer.from_pretrained(
            self.model_root,
            padding_side="left"  # pad on the left, so <cls> is the first non-padding token
        )
        
        # load normalization parameters
        self.norm_params = self.load_normalization_parameters()
        self.organism_to_idx = self.build_organism_mapping()
        self.idx_to_organism = {idx: org for org, idx in self.organism_to_idx.items()}
    
    def load_normalization_parameters(self):
        # Expected columns: parameter_type, organism, mean, std
        if not os.path.exists(self.norm_params_file):
            raise FileNotFoundError(f"[Error] Normalization parameters file not found: {self.norm_params_file}")
        return pd.read_csv(self.norm_params_file)
    
    def build_organism_mapping(self):
        """
        Build a mapping from organism to id
        """
        mic_params = self.norm_params[self.norm_params['parameter_type'] == 'mic_regression']
        return {org: idx for idx, org in enumerate(mic_params['organism'].tolist())}
    
    def upper_str(self, seq):
        """
        Convert amino acid letters to uppercase while keeping special tokens (e.g. <mask>) in lowercase
        """
        in_tag = False
        result = []
        for char in seq:
            if char == "<":
                in_tag = True
            if not in_tag and char.isalpha():
                char = char.upper()
            result.append(char)
            if char == ">":
                in_tag = False
        return "".join(result)
    
    def filter(self, seq):
        s = self.upper_str(seq)
        valid_aminos = ["L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", "N", "Q",
                        "F", "Y", "M", "H", "W", "C", "B", "U", "Z", "O"]  # "J" occurs in the data but is not in the vocabulary

        if len(s) == 0 or len(s) >= 100:  # MAX_SEQ_LENGTH = 100
            return False

        # check that every character is valid
        for char in s:
            if char in ("<", ">"):  # special-token delimiters
                continue
            # upper() lets the lowercase letters inside special tokens pass the check
            if char.upper() not in valid_aminos and char not in (".", "-"):
                return False
        return True
        
    def inference_single_sample(self, task_name, sequence, organism_id=None):
        """
        Perform inference on a single sample
        """
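        # Check that the task is supported and the sequence is valid
        if task_name not in self.tasks:
            raise ValueError(f"[Error] Unsupported task: {task_name}. Choose from {self.tasks}")
        if task_name == "mic_regression" and organism_id is None:
            raise ValueError("[Error] organism_id is required for mic_regression")
        if not self.filter(sequence):
            raise ValueError(f"[Error] Invalid sequence: {sequence}")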
        
        # Tokenize the sequence
        encoding = self.tokenizer(
            sequence,
            padding='max_length',
            max_length=128,
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        
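        # Load the task-specific model once and cache it; reloading it on every call is slow
        if not hasattr(self, "_model_cache"):
            self._model_cache = {}
        if task_name not in self._model_cache:
            model = AutoModel.from_pretrained(
                self.model_root, task_name=task_name, trust_remote_code=True
            ).to(self.device)
            model.eval()
            self._model_cache[task_name] = model
        model = self._model_cache[task_name]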

        # Prepare organism_ids for mic_regression task
        if task_name == "mic_regression" and organism_id is not None:
            # Ensure organism_id is a tensor with shape [batch_size]
            if not isinstance(organism_id, torch.Tensor):
                organism_id = torch.tensor([organism_id], dtype=torch.long, device=self.device)
            elif organism_id.dim() == 0:  # scalar tensor
                organism_id = organism_id.unsqueeze(0)
            # Ensure it's on the correct device
            organism_id = organism_id.to(self.device)
        else:
            organism_id = None
        
        with torch.inference_mode():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                task_name=task_name,
                organism_ids=organism_id,
                return_dict=True
            )
            
        logits = outputs.logits
        
        # Process output based on task type
        if task_name == "amp_classification":
            probs = torch.softmax(logits, dim=-1)
            positive_prob = probs[0][1].item()  # Probability of positive class (AMP)
            return positive_prob
        else:  # Regression tasks
            # Get the prediction value
            pred_value = logits.squeeze().cpu().item()
            # Denormalize if needed
            if task_name in ["mic_regression", "half_life_regression", "hemolysis_regression"]:
                if task_name == "mic_regression":
                    return self.denormalize(task_name, pred_value, organism_id)
                else:
                    return self.denormalize(task_name, pred_value)
            return pred_value
    
    def denormalize(self, task_name, normalized_value, organism_id=None):
        """
        Denormalize the value
        """
        if isinstance(organism_id, torch.Tensor):
            organism_id = organism_id.cpu().item()
        
        if task_name == "mic_regression":
            # Get organism name from organism_id
            org_name = self.idx_to_organism.get(organism_id)
            if not org_name:
                # Fallback: try to find in norm_params
                mic_params = self.norm_params[self.norm_params['parameter_type'] == 'mic_regression']
                if organism_id < len(mic_params):
                    org_name = mic_params.iloc[organism_id]['organism']
                else:
                    raise ValueError(f"[Error] Invalid organism_id: {organism_id}")
            
            # Get normalization parameters for this organism
            params = self.norm_params[
                (self.norm_params['parameter_type'] == 'mic_regression') &
                (self.norm_params['organism'] == org_name)
            ]
            if len(params) == 0:
                raise ValueError(f"[Error] No normalization parameters found for organism: {org_name}")
            
            mean = params['mean'].values[0]
            std = params['std'].values[0]
            log_val = normalized_value * std + mean
            # In merged_dataloader.py, we used math.log10
            return 10 ** log_val
        elif task_name == "half_life_regression":
            params = self.norm_params[self.norm_params['parameter_type'] == 'half_life_regression']
            if len(params) == 0:
                raise ValueError("[Error] No normalization parameters found for half_life_regression")
            
            mean = params['mean'].values[0]
            std = params['std'].values[0]
            log_val = normalized_value * std + mean
            # In merged_dataloader.py, we used math.log10
            return 10 ** log_val    
        elif task_name == "hemolysis_regression":
            params = self.norm_params[self.norm_params['parameter_type'] == 'hemolysis_regression']
            if len(params) == 0:
                raise ValueError("[Error] No normalization parameters found for hemolysis_regression")
            
            mean = params['mean'].values[0]
            std = params['std'].values[0]
            log_val = normalized_value * std + mean
            # In merged_dataloader.py, we used math.log10
            return 10 ** log_val
        
        return normalized_value  # the rest of tasks are not normalized

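if __name__ == "__main__":
    inferencer = Inferencer(
        model_root="AOMM",  # path to a local copy of the model files; pass None to load "muskwff/amp4multitask_124M" from the Hub
        norm_params_file="normalization_parameters.csv"  # columns: parameter_type, organism, mean, std
    )
    sample_seq = "GFGCPGDQYECNRHCRSIGCRAGYCDAVTLWLRCTCTGCSGKK"
    organism_id = 10  # row index among the 12 mic_regression organisms in the CSV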

    # Get predictions
    amp_prob = inferencer.inference_single_sample(sequence=sample_seq, task_name="amp_classification")
    mic_value = inferencer.inference_single_sample(sequence=sample_seq, task_name="mic_regression", organism_id=organism_id)
    half_life_value = inferencer.inference_single_sample(sequence=sample_seq, task_name="half_life_regression")
    hemolysis_value = inferencer.inference_single_sample(sequence=sample_seq, task_name="hemolysis_regression")
    print(f"AMP Probability: {amp_prob:.4f}")
    print(f"MIC Value: {mic_value:.4f}")
    print(f"Half-Life Value: {half_life_value:.4f}")
    print(f"Hemolysis Value: {hemolysis_value:.4f}")
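The regression heads output standardized log10 values. Going by the denormalize method above, each prediction z is mapped back with y = 10^(z * std + mean), using per-organism parameters for MIC and a single parameter row per task otherwise. The stdlib-only sketch below re-implements that lookup on a small in-memory CSV so the formula can be checked without the model. The helper names, organism names, and numbers are made up for illustration; only the column names (parameter_type, organism, mean, std) and the organism-indexing rule come from the code.

```python
import csv
import io

# Hypothetical rows in the column layout the Inferencer reads; values are illustrative only.
NORM_CSV = """parameter_type,organism,mean,std
mic_regression,organism_A,1.0,0.5
mic_regression,organism_B,0.8,0.4
half_life_regression,,1.2,0.6
"""


def load_rows(text):
    """Parse the CSV text into a list of dicts keyed by column name."""
    return list(csv.DictReader(io.StringIO(text)))


def organism_index(rows):
    """Same rule as build_organism_mapping: organism id = order among mic_regression rows."""
    mic = [r for r in rows if r["parameter_type"] == "mic_regression"]
    return {r["organism"]: i for i, r in enumerate(mic)}


def denormalize(rows, task_name, z, organism=None):
    """Invert z = (log10(y) - mean) / std, i.e. return y = 10 ** (z * std + mean)."""
    for r in rows:
        if r["parameter_type"] != task_name:
            continue
        if task_name == "mic_regression" and r["organism"] != organism:
            continue  # MIC parameters are organism-specific
        return 10 ** (z * float(r["std"]) + float(r["mean"]))
    raise ValueError(f"no normalization parameters for {task_name!r} / {organism!r}")
```

For example, with these made-up rows a normalized MIC prediction of 2.0 for organism_A maps back to 10^(2.0 * 0.5 + 1.0) = 100.0.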

Model Introduction

AOMM is a deep learning model trained on six AMP-related tasks, two of which are auxiliary:

  1. Masked sequence prediction (auxiliary task): masked-token training, a common pre-training task in NLP
  2. AMP classification: determine whether a peptide is an antimicrobial peptide
  3. Microorganism-specific minimum inhibitory concentration (MIC) regression
  4. Half-life regression
  5. Hemolytic activity regression
  6. Bioactivity classification (auxiliary task)
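Only four of these tasks are exposed for inference; the Inferencer in the How to use section drops masked_lm and bioactivity_classification. As a hedged sketch, the per-task output handling from inference_single_sample can be generalized to a batch. The helper name postprocess and the logit shapes ([batch, 2] for classification, [batch] or [batch, 1] for regression) are assumptions for illustration, not part of the released code.

```python
import torch

# Tasks usable at inference time (the two auxiliary tasks are excluded)
INFERENCE_TASKS = ("amp_classification", "mic_regression",
                   "half_life_regression", "hemolysis_regression")


def postprocess(task_name: str, logits: torch.Tensor) -> torch.Tensor:
    """Turn a batch of head outputs into one value per sequence (hypothetical helper)."""
    if task_name not in INFERENCE_TASKS:
        raise ValueError(f"unsupported task: {task_name}")
    if task_name == "amp_classification":
        # Assumed logits shape [batch, 2]; column 1 is the AMP (positive) class
        return torch.softmax(logits, dim=-1)[:, 1]
    # Regression: assumed shape [batch] or [batch, 1]; values stay in normalized log10 space
    return logits.reshape(-1)
```

Regression values returned here still need denormalization, as in the Inferencer's denormalize method.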
  • Model Architecture

The model consists of two parts, an encoder and a decoder, so AOMM follows an encoder-decoder design. Its main hyperparameters are listed below:

model_max_length   num_hidden_layers   hidden_size   num_attention_heads
128                18                  768           24

For the masked-prediction task, the decoder takes the features of the whole sequence as input; for all other tasks, it takes only the feature of the CLS token.
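The sketch below illustrates this input selection. It is a hypothetical helper, not AOMM's actual code, and assumes encoder outputs of shape [batch, seq_len, hidden]. Because the tokenizer in the How to use example pads on the left, the <cls> token sits at the first unmasked position of each row rather than at index 0.

```python
import torch


def select_decoder_input(hidden_states: torch.Tensor,
                         attention_mask: torch.Tensor,
                         task_name: str) -> torch.Tensor:
    """Pick the encoder features a task decoder receives (hypothetical helper).

    hidden_states: [batch, seq_len, hidden]; attention_mask: [batch, seq_len] of 0/1.
    """
    if task_name == "masked_lm":
        # Token-level task: the decoder sees every position
        return hidden_states
    # Left padding: <cls> is the first position where the mask is 1.
    # argmax returns the first maximal index, i.e. the first 1 in each row.
    cls_pos = attention_mask.long().argmax(dim=1)
    batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
    return hidden_states[batch_idx, cls_pos]  # [batch, hidden]
```

Locating <cls> through the attention mask keeps the selection correct regardless of how much padding each sequence received.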

  • Tokenizer Configuration

AOMM uses the same tokenizer as the ESM-2 model, which is available on HuggingFace. Because antimicrobial peptides are generally 1 to 100 amino acids (AA) long, we set max_length to 128, which also leaves room for the <cls> and <eos> tokens the tokenizer adds.

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("muskwff/amp4multitask_124M")
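To make the length budget concrete, the simplified sketch below mimics how a peptide is wrapped in <cls>/<eos> and left-padded to 128 tokens. It is an illustration under stated assumptions, not the real EsmTokenizer: it treats each residue as one token string and does not map tokens to ids. The helper name left_pad is hypothetical, and MAX_RESIDUES = 99 reflects the Inferencer's filter, which rejects sequences of 100 or more characters.

```python
MAX_LENGTH = 128   # tokenizer model_max_length
MAX_RESIDUES = 99  # the Inferencer's filter rejects sequences of length >= 100


def left_pad(residues, max_length=MAX_LENGTH, pad="<pad>"):
    """Wrap residues in <cls>/<eos> and pad on the left (simplified, hypothetical helper)."""
    tokens = ["<cls>", *residues, "<eos>"]
    n_real = len(tokens)
    if n_real > max_length:
        raise ValueError(f"{n_real} tokens exceed max_length={max_length}")
    n_pad = max_length - n_real
    padded = [pad] * n_pad + tokens
    attention_mask = [0] * n_pad + [1] * n_real  # 0 marks padding
    return padded, attention_mask
```

Even the longest accepted peptide (99 residues) becomes only 101 tokens, well within the 128-token limit.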

The configuration file of the tokenizer is as follows:

{
  "model_max_length": 128,
  "tokenizer_class": "EsmTokenizer",
  "special_tokens_map": {
    "cls_token": "<cls>",
    "pad_token": "<pad>",
    "eos_token": "<eos>",
    "unk_token": "<unk>",
    "mask_token": "<mask>"
  }
}