import argparse
import math
import os
import traceback

import torch
from flask import Flask, jsonify, request
from modelscope import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, PeftModel
from torch import nn
from werkzeug.utils import secure_filename

import preprocess

"""
python /data/zcwang/Protein/project/demo.py \
    --model_path /data/zcwang/Protein/llama_output/output_v3 \
    --lora_path /data/zcwang/Protein/project/2502/v12-2/lora_weight/lora_weights_1_400000.pth \
    --adapter_path /data/zcwang/Protein/project/2502/v12-2/adapt_weight/adapter_model_and_optimizer_1_400000.pth \
    --port 8888 \
    --gpu 0,1,2,3
"""


# Dynamic positional encoding
class DynamicPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(DynamicPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.max_len = max_len

        # Create a constant sinusoidal positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Shape: [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self):
        return self.pe


# Protein structure sequence adapter
class ProteinStructureSequenceAdapter(nn.Module):
    def __init__(self, input_dim, output_dim, num_heads, num_queries, max_len=512):
        super(ProteinStructureSequenceAdapter, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        self.num_queries = num_queries
        self.max_len = max_len

        # Linear projection layer
        self.linear_proj = nn.Linear(input_dim, output_dim)
        # Dynamic positional encoding layer
        self.pos_encoder = DynamicPositionalEncoding(output_dim, max_len)
        # Learnable queries
        self.learnable_queries = nn.Parameter(torch.randn(num_queries, output_dim))
        # Cross-attention layer
        self.cross_attention = nn.MultiheadAttention(embed_dim=output_dim, num_heads=num_heads, batch_first=True)
        # Output and question projection layers
        self.output_proj = nn.Linear(output_dim, output_dim)
        self.question_proj = nn.Linear(output_dim, output_dim)
        # Layer normalization
        self.layer_norm1 = nn.LayerNorm(output_dim)
        self.layer_norm2 = nn.LayerNorm(output_dim)
        self.layer_norm3 = nn.LayerNorm(output_dim)

    def forward(self, x, h_state):
        # x shape: [batch_size, seq_len, input_dim]
        batch_size, seq_len, _ = x.size()

        # Pad or truncate the input to max_len
        if seq_len < self.max_len:
            pad_size = self.max_len - seq_len
            padding = torch.zeros(batch_size, pad_size, self.input_dim, device=x.device)
            x = torch.cat([x, padding], dim=1)  # Shape: [batch_size, max_len, input_dim]
        elif seq_len > self.max_len:
            x = x[:, :self.max_len, :]

        # Project the input
        x_proj = self.linear_proj(x)  # Shape: [batch_size, max_len, output_dim]
        x_proj = self.layer_norm1(x_proj)

        # Apply positional encoding to the structure sequence
        pe = self.pos_encoder()
        x_pos_encoded = x_proj + pe[:, :x_proj.size(1), :]  # Shape: [batch_size, max_len, output_dim]

        # Prepare learnable queries, conditioned on the question hidden state
        queries = self.learnable_queries.unsqueeze(0).expand(batch_size, -1, -1)  # Shape: [batch_size, num_queries, output_dim]
        h_state = self.question_proj(h_state)
        queries = queries + h_state
        queries = self.layer_norm2(queries)

        # Apply positional encoding to queries (every second position of the table)
        queries_pos_encoded = pe[:, :self.num_queries * 2:2, :] + queries  # Shape: [batch_size, num_queries, output_dim]

        # Cross-attention: queries attend to the encoded structure sequence
        attn_output, _ = self.cross_attention(queries_pos_encoded, x_pos_encoded, x_pos_encoded)  # Shape: [batch_size, num_queries, output_dim]
        attn_output = self.layer_norm3(attn_output)

        # Project the output
        output = self.output_proj(attn_output)  # Shape: [batch_size, num_queries, output_dim]
        return output


# Global variables to store the loaded model and adapter
model = None
tokenizer = None
adapter = None
device = None


# Initialize the model and adapter
def initialize_models(model_path, lora_path, adapter_path):
    global model, tokenizer, adapter, device

    # Set the device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load model L (the LLaMA backbone)
    print(f"Loading model: {model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        device_map="auto",
    )

    # Load LoRA weights
    print(f"Loading LoRA weights: {lora_path}")
    config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        inference_mode=False,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = PeftModel.from_pretrained(model, model_id=lora_path, config=config)

    # Load tokenizer
    print("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Load adapter model A
    print(f"Loading adapter model: {adapter_path}")
    adapter = ProteinStructureSequenceAdapter(
        input_dim=1152, output_dim=4096, num_heads=16, num_queries=256, max_len=512
    ).to(device)
    checkpoint = torch.load(adapter_path, map_location=device)
    adapter.load_state_dict(checkpoint['adapter_model_weight'])
    adapter.eval()

    print("All models loaded successfully")
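
# Shape walk-through for the adapter (illustrative only, not executed; the
# dimensions follow the initialize_models() call above, and `x` / `h` are
# hypothetical tensors):
#
#   adapter = ProteinStructureSequenceAdapter(input_dim=1152, output_dim=4096,
#                                             num_heads=16, num_queries=256, max_len=512)
#   x = torch.randn(1, 300, 1152)  # per-residue structure embeddings; padded to 512 in forward()
#   h = torch.randn(1, 4096)       # last-token hidden state of the question
#   adapter(x, h).shape            # -> torch.Size([1, 256, 4096]): one LLaMA-sized
#                                  #    embedding per learnable query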
# Main function: process a PDB file and question, generate an answer
def generate_answer(pdb_file_path, question):
    global model, tokenizer, adapter, device

    # Ensure the models are loaded
    if model is None or tokenizer is None or adapter is None:
        return "Error: Models are not initialized"

    try:
        # Step 1: Process the PDB file to get the per-residue protein structure vector
        protein_vector = preprocess.get_mpnn_emb(pdb_file_path).unsqueeze(0).to(device)

        # Ensure the shape of the protein vector is correct
        max_length = 512
        if max_length > protein_vector.size(1):
            padding_length = max_length - protein_vector.size(1)
            padding = torch.zeros((1, padding_length, 1152), device=device)
            protein_vector = torch.cat([protein_vector, padding], dim=1)
        else:
            protein_vector = protein_vector[:, :max_length, :]

        # Step 2: Get the hidden state of the question text
        inputs = tokenizer(f"Human: {question}\nAssistant: ", return_tensors="pt").to(device)
        with torch.no_grad():
            hidden_states = model(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                output_hidden_states=True,
            ).hidden_states[-1]
        question_hidden_state = hidden_states[:, -1, :]

        # Step 3: Feed the protein vector and question hidden state into the adapter
        # to get the question-conditioned protein embedding
        with torch.no_grad():
            protein_embedding = adapter(protein_vector, question_hidden_state)

        # Step 4: Use the model's embedding layer to encode the question text
        inputs_embeds = model.base_model.model.model.embed_tokens(inputs.input_ids)

        # Step 5: Concatenate the protein and text embeddings and feed them to the model
        combined_embeds = torch.cat([protein_embedding, inputs_embeds], dim=1)
        combined_attention_mask = torch.cat(
            [
                torch.ones((protein_embedding.size(0), protein_embedding.size(1)), device=device),
                inputs.attention_mask,
            ],
            dim=1,
        )

        # Generate the answer
        with torch.no_grad():
            generated_ids = model.generate(
                inputs_embeds=combined_embeds,
                attention_mask=combined_attention_mask,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=128,
            )
        response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Extract the Assistant part of the answer
        if "Assistant:" in response:
            response = response.split("Assistant:")[1].strip()

        return response
    except Exception as e:
        # Capture and print detailed error information
        print(f"Error in generate_answer: {str(e)}")
        print(traceback.format_exc())
        return f"Error processing: {str(e)}"


# Create a Flask application
def create_app():
    app = Flask(__name__)

    # Serve a simple upload form. Minimal inline page (markup reconstructed; only
    # the title and status text are original) that posts to the /api/query
    # endpoint below.
    @app.route('/')
    def index():
        return '''<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Prot2Chat: Protein Q&amp;A System</title>
</head>
<body>
    <h1>Prot2Chat: Protein Q&amp;A System</h1>
    <form id="queryForm">
        <input type="file" id="pdbFile" name="pdbFile" accept=".pdb" required>
        <input type="text" id="question" name="question" placeholder="Ask a question about the protein" required>
        <button type="submit">Submit</button>
    </form>
    <p id="status" style="display:none">Processing, please wait...</p>
    <pre id="answer"></pre>
    <script>
        document.getElementById('queryForm').addEventListener('submit', async (event) => {
            event.preventDefault();
            const status = document.getElementById('status');
            status.style.display = 'block';
            const formData = new FormData();
            formData.append('pdbFile', document.getElementById('pdbFile').files[0]);
            formData.append('question', document.getElementById('question').value);
            const resp = await fetch('/api/query', {method: 'POST', body: formData});
            const data = await resp.json();
            status.style.display = 'none';
            document.getElementById('answer').textContent = data.answer || data.error;
        });
    </script>
</body>
</html>
'''
    # API endpoint
    @app.route('/api/query', methods=['POST'])
    def query():
        try:
            # Get the uploaded PDB file
            if 'pdbFile' not in request.files:
                return jsonify({'error': 'No PDB file uploaded'}), 400
            pdb_file = request.files['pdbFile']
            if pdb_file.filename == '':
                return jsonify({'error': 'No file selected'}), 400

            question = request.form.get('question', '')
            if not question:
                return jsonify({'error': 'Question cannot be empty'}), 400

            # Save the PDB file to a temporary location (sanitize the client-supplied name)
            temp_dir = os.path.join(os.getcwd(), 'temp')
            os.makedirs(temp_dir, exist_ok=True)
            temp_pdb_path = os.path.join(temp_dir, secure_filename(pdb_file.filename))
            pdb_file.save(temp_pdb_path)

            print(f"Processing file: {temp_pdb_path}")
            print(f"Question: {question}")

            # Generate the answer
            answer = generate_answer(temp_pdb_path, question)

            # Delete the temporary file
            if os.path.exists(temp_pdb_path):
                os.remove(temp_pdb_path)

            return jsonify({'answer': answer})
        except Exception as e:
            print(f"API error: {str(e)}")
            print(traceback.format_exc())
            return jsonify({'error': str(e)}), 500

    return app


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Protein Structure Q&A System')
    parser.add_argument('--model_path', type=str,
                        default="/data/zcwang/Protein/llama_output/output_v3",
                        help='Path to the LLaMA model')
    parser.add_argument('--lora_path', type=str,
                        default="/data/zcwang/Protein/project/2502/v12-2/lora_weight/lora_weights_1_400000.pth",
                        help='Path to the LoRA weights')
    parser.add_argument('--adapter_path', type=str,
                        default="/data/zcwang/Protein/project/2502/v12-2/adapt_weight/adapter_model_and_optimizer_1_400000.pth",
                        help='Path to the adapter model weights')
    parser.add_argument('--port', type=int, default=7777, help='Server port')
    parser.add_argument('--gpu', type=str, default="0,1", help='IDs of the GPUs to use, e.g., "0,1"')
    args = parser.parse_args()

    # Restrict the visible GPUs before any model is placed on a device
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    initialize_models(args.model_path, args.lora_path, args.adapter_path)
    app = create_app()
    app.run(host='localhost', port=args.port, debug=False)
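
# Example request against a running server (hypothetical file name; the endpoint
# expects multipart form data with a `pdbFile` part and a `question` field):
#
#   curl -X POST http://localhost:8888/api/query \
#        -F "pdbFile=@example.pdb" \
#        -F "question=What is the function of this protein?"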