import argparse
import math
import os
import traceback

import torch
from torch import nn
from flask import Flask, request, jsonify, render_template
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline
from modelscope import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, PeftModel

import preprocess

"""
python /data/zcwang/Protein/project/demo.py --model_path /data/zcwang/Protein/llama_output/output_v3 --lora_path /data/zcwang/Protein/project/2502/v12-2/lora_weight/lora_weights_1_400000.pth --adapter_path /data/zcwang/Protein/project/2502/v12-2/adapt_weight/adapter_model_and_optimizer_1_400000.pth --port 8888 --gpu 0,1,2,3
"""
| | |
class DynamicPositionalEncoding(nn.Module):
    """Precomputed sinusoidal positional encoding (as in "Attention Is All
    You Need").

    A fixed table of shape (1, max_len, d_model) is built once at
    construction time; ``forward`` simply returns the whole table and the
    caller slices out the positions it needs.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # positions: (max_len, 1); inverse frequencies: (d_model // 2,)
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)  # even channels
        table[:, 1::2] = torch.cos(positions * inv_freq)  # odd channels

        # Registered as a buffer: moves with .to(device)/.cuda() but is
        # excluded from gradient updates.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self):
        """Return the full (1, max_len, d_model) positional-encoding table."""
        return self.pe
| |
|
| | |
class ProteinStructureSequenceAdapter(nn.Module):
    """Cross-attention adapter that compresses a variable-length protein
    structure embedding sequence into a fixed number of query tokens in the
    LLM embedding space.

    Per-residue structure features of size ``input_dim`` are projected to
    ``output_dim``, positionally encoded, and pooled by ``num_queries``
    learnable queries that are additionally conditioned on a question
    hidden state from the language model.
    """

    def __init__(self, input_dim, output_dim, num_heads, num_queries, max_len=512):
        super(ProteinStructureSequenceAdapter, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        self.num_queries = num_queries
        self.max_len = max_len

        # Project per-residue structure features into the LLM hidden size.
        self.linear_proj = nn.Linear(input_dim, output_dim)

        # Fixed sinusoidal table shared by the keys/values and the queries.
        self.pos_encoder = DynamicPositionalEncoding(output_dim, max_len)

        # Learnable query tokens that pool the structure sequence.
        self.learnable_queries = nn.Parameter(torch.randn(num_queries, output_dim))

        self.cross_attention = nn.MultiheadAttention(
            embed_dim=output_dim, num_heads=num_heads, batch_first=True
        )

        self.output_proj = nn.Linear(output_dim, output_dim)

        # Projects the question hidden state before it conditions the queries.
        self.question_proj = nn.Linear(output_dim, output_dim)

        self.layer_norm1 = nn.LayerNorm(output_dim)
        self.layer_norm2 = nn.LayerNorm(output_dim)
        self.layer_norm3 = nn.LayerNorm(output_dim)

    def forward(self, x, h_state):
        """Pool a structure sequence into ``num_queries`` output tokens.

        Args:
            x: (batch, seq_len, input_dim) structure embeddings; padded or
                truncated to exactly ``max_len`` internally.
            h_state: (batch, output_dim) or (batch, 1, output_dim) question
                hidden state used to condition the queries.

        Returns:
            Tensor of shape (batch, num_queries, output_dim).
        """
        batch_size, seq_len, _ = x.size()

        # Pad (with zeros) or truncate the sequence to exactly max_len.
        if seq_len < self.max_len:
            pad_size = self.max_len - seq_len
            # BUGFIX: match x's dtype — previously the float32 default made
            # torch.cat fail for half-precision inputs.
            padding = torch.zeros(batch_size, pad_size, self.input_dim,
                                  device=x.device, dtype=x.dtype)
            x = torch.cat([x, padding], dim=1)
        elif seq_len > self.max_len:
            x = x[:, :self.max_len, :]

        x_proj = self.linear_proj(x)
        x_proj = self.layer_norm1(x_proj)

        pe = self.pos_encoder()  # (1, max_len, output_dim)

        x_pos_encoded = x_proj + pe[:, :x_proj.size(1), :]

        queries = self.learnable_queries.unsqueeze(0).expand(batch_size, -1, -1)

        h_state = self.question_proj(h_state)
        # BUGFIX: broadcast explicitly over the query dimension. The previous
        # implicit broadcast of a 2-D h_state against (B, Q, D) queries only
        # worked for batch size 1; unsqueeze keeps batch-1 behavior identical
        # and makes batch > 1 work.
        if h_state.dim() == 2:
            h_state = h_state.unsqueeze(1)

        queries = queries + h_state
        queries = self.layer_norm2(queries)

        # Queries take every other positional slot (stride 2); assumes
        # num_queries * 2 <= max_len — holds for the defaults (256*2 == 512).
        queries_pos_encoded = pe[:, :self.num_queries * 2:2, :] + queries

        attn_output, _ = self.cross_attention(queries_pos_encoded, x_pos_encoded, x_pos_encoded)
        attn_output = self.layer_norm3(attn_output)

        output = self.output_proj(attn_output)

        return output
| |
|
| | |
# Globals populated once by initialize_models() and shared by the Flask
# request handlers (the demo serves one model per process).
model = None      # base causal LM with LoRA weights applied
tokenizer = None  # tokenizer matching the base model
adapter = None    # ProteinStructureSequenceAdapter instance
device = None     # "cuda" or "cpu"
| |
|
| | |
def initialize_models(model_path, lora_path, adapter_path):
    """Load the base LLM, apply the LoRA weights, and restore the structure
    adapter, storing all of them in module-level globals for the server.

    Args:
        model_path: directory of the base causal LM (also the tokenizer).
        lora_path: directory/checkpoint of the trained LoRA weights.
        adapter_path: path to the adapter checkpoint (.pth).
    """
    global model, tokenizer, adapter, device

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Base causal LM, sharded across the visible GPUs by device_map="auto".
    print(f"Loading model: {model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        device_map="auto",
    )

    # Attach the trained LoRA weights on top of the base model.
    print(f"Loading LoRA weights: {lora_path}")
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        inference_mode=False,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = PeftModel.from_pretrained(model, model_id=lora_path, config=lora_config)

    print("Loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Structure adapter; dims must match training (MPNN emb 1152 -> LLM 4096).
    print(f"Loading adapter model: {adapter_path}")
    adapter = ProteinStructureSequenceAdapter(
        input_dim=1152, output_dim=4096, num_heads=16, num_queries=256, max_len=512
    ).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects — assumes the
    # checkpoint file is trusted.
    adapter_ckpt = torch.load(adapter_path, map_location=device)
    adapter.load_state_dict(adapter_ckpt['adapter_model_weight'])
    adapter.eval()

    print("All models loaded successfully")
| |
|
| | |
def generate_answer(pdb_file_path, question):
    """Run one round of protein Q&A inference.

    Embeds the PDB structure, conditions the adapter on the question's final
    hidden state, prepends the resulting protein tokens to the prompt
    embeddings, and generates an answer with the LoRA-augmented LLM.

    Args:
        pdb_file_path: path to a local PDB file.
        question: user question as plain text.

    Returns:
        The generated answer string, or an "Error ..." string on failure
        (this function never raises; errors are logged and reported inline).
    """
    global model, tokenizer, adapter, device

    # Guard: the server must have called initialize_models() first.
    if model is None or tokenizer is None or adapter is None:
        return "Error: Models are not initialized"

    try:
        # Per-residue structure embedding; assumes get_mpnn_emb returns a
        # (seq_len, 1152) tensor — TODO confirm against preprocess module.
        protein_vector = preprocess.get_mpnn_emb(pdb_file_path).unsqueeze(0).to(device)

        # Pad with zeros or truncate to exactly 512 residues to match the
        # adapter's max_len (1152 matches the adapter's input_dim).
        max_length = 512
        if max_length > protein_vector.size(1):
            padding_length = max_length - protein_vector.size(1)
            padding = torch.zeros((1, padding_length, 1152), device=device)
            protein_vector = torch.cat([protein_vector, padding], dim=1)
        else:
            protein_vector = protein_vector[:, :max_length, :]

        # Take the last-layer hidden state of the final prompt token as the
        # question representation that conditions the adapter queries.
        inputs = tokenizer(f"Human: {question}\nAssistant: ", return_tensors="pt").to(device)
        with torch.no_grad():
            hidden_states = model(inputs.input_ids, attention_mask=inputs.attention_mask, output_hidden_states=True).hidden_states[-1]
            question_hidden_state = hidden_states[:, -1, :]

        # Compress the structure sequence into 256 protein tokens.
        with torch.no_grad():
            protein_embedding = adapter(protein_vector, question_hidden_state)

        # NOTE(review): reaches through PEFT wrapper internals
        # (PeftModel -> base_model -> LlamaForCausalLM -> LlamaModel) to get
        # the token embedding table — fragile across peft/transformers versions.
        inputs_embeds = model.base_model.model.model.embed_tokens(inputs.input_ids)

        # Prepend protein tokens to the prompt embeddings, with an all-ones
        # attention mask over the protein segment.
        combined_embeds = torch.cat([protein_embedding, inputs_embeds], dim=1)
        combined_attention_mask = torch.cat([torch.ones((protein_embedding.size(0), protein_embedding.size(1)), device=device), inputs.attention_mask], dim=1)

        with torch.no_grad():
            generated_ids = model.generate(
                inputs_embeds=combined_embeds,
                attention_mask=combined_attention_mask,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=128
            )

        response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Strip the echoed prompt if the decoded text still contains it
        # (generation from inputs_embeds may return only the new tokens).
        if "Assistant:" in response:
            response = response.split("Assistant:")[1].strip()

        return response
    except Exception as e:
        # Boundary handler: log the full traceback, surface a short message.
        print(f"Error in generate_answer: {str(e)}")
        print(traceback.format_exc())
        return f"Error processing: {str(e)}"
| |
|
| | |
def create_app():
    """Build the Flask app: a self-contained demo page at '/' and a JSON
    endpoint at '/api/query' accepting a PDB upload plus a question.

    Returns:
        The configured Flask application (call app.run(...) to serve).
    """
    app = Flask(__name__)

    @app.route('/')
    def index():
        # Inline single-page UI — no template files required.
        return '''
        <!DOCTYPE html>
        <html>
        <head>
            <title>Prot2Chat: Protein Q&A System</title>
            <style>
                body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
                .form-group { margin-bottom: 15px; }
                label { display: block; margin-bottom: 5px; }
                input[type="file"], input[type="text"] { width: 100%; padding: 8px; }
                button { padding: 10px 15px; background-color: #4CAF50; color: white; border: none; cursor: pointer; }
                #result { margin-top: 20px; padding: 15px; border: 1px solid #ddd; border-radius: 4px; }
                #loading { display: none; margin-top: 20px; text-align: center; }
            </style>
        </head>
        <body>
            <h1>Prot2Chat: Protein Q&A System</h1>
            <form id="queryForm" enctype="multipart/form-data">
                <div class="form-group">
                    <label for="pdbFile">Upload PDB file:</label>
                    <input type="file" id="pdbFile" name="pdbFile" accept=".pdb" required>
                </div>
                <div class="form-group">
                    <label for="question">Enter your question:</label>
                    <textarea id="question" name="question" placeholder="Please enter your question about protein..." required style="width: 100%; height: 100px;"></textarea>
                </div>
                <button type="submit">Submit</button>
            </form>
            <div id="loading">Processing, please wait...</div>
            <div id="result" style="display: none;">
                <h2>Answer:</h2>
                <p id="answer"></p>
            </div>

            <script>
                document.getElementById('queryForm').addEventListener('submit', async function(e) {
                    e.preventDefault();

                    const formData = new FormData();
                    formData.append('pdbFile', document.getElementById('pdbFile').files[0]);
                    formData.append('question', document.getElementById('question').value);

                    document.getElementById('result').style.display = 'none';
                    document.getElementById('loading').style.display = 'block';

                    try {
                        const response = await fetch('/api/query', {
                            method: 'POST',
                            body: formData
                        });

                        const data = await response.json();
                        document.getElementById('loading').style.display = 'none';

                        if (data.error) {
                            document.getElementById('answer').textContent = "Error: " + data.error;
                        } else {
                            document.getElementById('answer').textContent = data.answer;
                        }
                        document.getElementById('result').style.display = 'block';
                    } catch (error) {
                        document.getElementById('loading').style.display = 'none';
                        console.error('Error:', error);
                        document.getElementById('answer').textContent = "An error occurred, please try again: " + error;
                        document.getElementById('result').style.display = 'block';
                    }
                });
            </script>
        </body>
        </html>
        '''

    @app.route('/api/query', methods=['POST'])
    def query():
        """Validate the upload, stash it in ./temp, run inference, clean up."""
        try:
            if 'pdbFile' not in request.files:
                return jsonify({'error': 'No PDB file uploaded'}), 400

            pdb_file = request.files['pdbFile']
            if pdb_file.filename == '':
                return jsonify({'error': 'No file selected'}), 400

            question = request.form.get('question', '')
            if not question:
                return jsonify({'error': 'Question cannot be empty'}), 400

            # SECURITY FIX: strip any directory components from the
            # client-supplied filename to prevent path traversal outside the
            # temp directory (e.g. "../../etc/passwd"). Backslashes are
            # normalized first so Windows-style paths are also sanitized.
            safe_name = os.path.basename(pdb_file.filename.replace('\\', '/'))
            if not safe_name:
                return jsonify({'error': 'No file selected'}), 400

            temp_dir = os.path.join(os.getcwd(), 'temp')
            os.makedirs(temp_dir, exist_ok=True)
            temp_pdb_path = os.path.join(temp_dir, safe_name)
            pdb_file.save(temp_pdb_path)

            print(f"Processing file: {temp_pdb_path}")
            print(f"Question: {question}")

            # Always remove the temp file, even if inference raises.
            try:
                answer = generate_answer(temp_pdb_path, question)
            finally:
                if os.path.exists(temp_pdb_path):
                    os.remove(temp_pdb_path)

            return jsonify({'answer': answer})
        except Exception as e:
            # Top-level API boundary: log the traceback, return JSON error.
            print(f"API error: {str(e)}")
            print(traceback.format_exc())
            return jsonify({'error': str(e)}), 500

    return app
| |
|
| |
|
if __name__ == '__main__':
    # Command-line configuration for the demo server.
    cli = argparse.ArgumentParser(description='Protein Structure Q&A System')
    cli.add_argument('--model_path', type=str,
                     default="/data/zcwang/Protein/llama_output/output_v3",
                     help='Path to the LLaMA model')
    cli.add_argument('--lora_path', type=str,
                     default="/data/zcwang/Protein/project/2502/v12-2/lora_weight/lora_weights_1_400000.pth",
                     help='Path to the LoRA weights')
    cli.add_argument('--adapter_path', type=str,
                     default="/data/zcwang/Protein/project/2502/v12-2/adapt_weight/adapter_model_and_optimizer_1_400000.pth",
                     help='Path to the adapter model weights')
    cli.add_argument('--port', type=int, default=7777,
                     help='Server port')
    cli.add_argument('--gpu', type=str, default="0,1",
                     help='IDs of the GPUs to use, e.g., "0,1"')
    opts = cli.parse_args()

    # Restrict visible GPUs before any CUDA context is created by model
    # loading below (torch initializes CUDA lazily).
    os.environ["CUDA_VISIBLE_DEVICES"] = opts.gpu

    # Load all models into module globals, then serve.
    initialize_models(opts.model_path, opts.lora_path, opts.adapter_path)

    app = create_app()
    app.run(host='localhost', port=opts.port, debug=False)