import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from tqdm import tqdm
import logging

from src.model import ParallelT5Small
from src.data_preprocess import DataProcessor

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TranslationDataset(Dataset):
    """Thin Dataset wrapper around preprocessed, tokenized translation examples"""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'labels': item['labels']
        }


def setup_training(data_dir="./data/processed"):
    """Setup model, tokenizer, and dataloaders from preprocessed data"""
    # Load preprocessed data
    processor = DataProcessor()
    processed_data, tokenizer = processor.load_processed_data(data_dir)

    if not processed_data:
        logger.error("No processed data found. Please run data_preprocess.py first.")
        return None, None, None, None

    # Create dataloaders
    dataloaders = {}
    for split_name, data in processed_data.items():
        dataset = TranslationDataset(data)
        batch_size = 16 if split_name == 'train' else 32
        dataloaders[split_name] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split_name == 'train'),
            num_workers=2,
            pin_memory=True
        )

    # Initialize model
    model = ParallelT5Small(vocab_size=tokenizer.vocab_size)

    logger.info(f"Train samples: {len(processed_data['train'])}")
    logger.info(f"Validation samples: {len(processed_data['validation'])}")
    logger.info(f"Test samples: {len(processed_data['test'])}")

    return model, tokenizer, dataloaders['train'], dataloaders['validation']


def calculate_accuracy(logits, labels, ignore_index=-100):
    """Calculate token-level accuracy, ignoring positions equal to ignore_index"""
    predictions = logits.argmax(dim=-1)
    mask = labels != ignore_index
    correct = (predictions == labels) & mask
    accuracy = correct.sum().float() / mask.sum().float()
    return accuracy


def train_epoch(model, dataloader, optimizer, criterion, device, tokenizer):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    total_accuracy = 0

    progress_bar = tqdm(dataloader, desc="Training")
    for batch in progress_bar:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        logits = model(input_ids, attention_mask)

        # Calculate loss
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))

        # Calculate accuracy (padding tokens are excluded, matching the loss)
        accuracy = calculate_accuracy(logits, labels, tokenizer.pad_token_id)

        # Backward pass with gradient clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()

        progress_bar.set_postfix({
            'loss': loss.item(),
            'acc': accuracy.item()
        })

    avg_loss = total_loss / len(dataloader)
    avg_accuracy = total_accuracy / len(dataloader)
    return avg_loss, avg_accuracy


def validate_epoch(model, dataloader, criterion, device, tokenizer):
    """Validate for one epoch"""
    model.eval()
    total_loss = 0
    total_accuracy = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids, attention_mask)
            loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
            accuracy = calculate_accuracy(logits, labels, tokenizer.pad_token_id)

            total_loss += loss.item()
            total_accuracy += accuracy.item()
    avg_loss = total_loss / len(dataloader)
    avg_accuracy = total_accuracy / len(dataloader)
    return avg_loss, avg_accuracy


def train_model():
    """Main training function"""
    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {device}")

    # Setup training components
    model, tokenizer, train_loader, val_loader = setup_training()
    if model is None:
        return

    model.to(device)

    # Training parameters
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    # Training loop
    num_epochs = 50
    best_val_loss = float('inf')

    # Training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    for epoch in range(num_epochs):
        logger.info(f"\nEpoch {epoch+1}/{num_epochs}")

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, criterion, device, tokenizer
        )

        # Validate
        val_loss, val_acc = validate_epoch(
            model, val_loader, criterion, device, tokenizer
        )

        # Update history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        logger.info(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        logger.info(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'history': history
            }, 'checkpoints/best_model.pth')
            logger.info("Saved best model!")

        # Save checkpoint every 2 epochs
        if (epoch + 1) % 2 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'history': history
            }, f'checkpoints/checkpoint_epoch_{epoch+1}.pth')

        scheduler.step()

    # Save final model
    torch.save(model.state_dict(), 'checkpoints/final_model.pth')

    # Save training history
    torch.save(history, 'checkpoints/training_history.pth')

    logger.info("Training completed!")
    logger.info(f"Best validation loss: {best_val_loss:.4f}")


if __name__ == "__main__":
    # Create directories
    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('data/processed', exist_ok=True)

    # Start training
    train_model()