Image-to-Text
Transformers
Safetensors
Khmer
khmer-ocr
feature-extraction
transformer
text-recognition
crnn
khmer-text-recognition
custom_code
Instructions to use Darayut/khmer-text-recognition with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Darayut/khmer-text-recognition with Transformers:
# Use a pipeline as a high-level helper # Warning: Pipeline type "image-to-text" is no longer supported in transformers v5. # You must load the model directly (see below) or downgrade to v4.x with: # 'pip install "transformers<5.0.0"' from transformers import pipeline pipe = pipeline("image-to-text", model="Darayut/khmer-text-recognition", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Darayut/khmer-text-recognition", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| # modeling_khmerocr.py | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn.utils.rnn import pad_sequence | |
| from transformers import PreTrainedModel | |
| from .configuration_khmerocr import KhmerOCRConfig | |
| # ========================================== | |
| # 1. HELPER CLASSES (SequenceSE, CNN, etc.) | |
| # ========================================== | |
class SequenceSE(nn.Module):
    """Squeeze-and-excitation gate computed independently per width position.

    The 4-D input is averaged over its height axis (dim 2), passed through a
    two-layer 1x1 ``Conv1d`` bottleneck that ends in a sigmoid, and the
    resulting gate is multiplied back onto the input, broadcast over height.

    Args:
        channels: Number of channels (dim 1) of the input.
        reduction: Divisor applied to ``channels`` to size the bottleneck.
            The bottleneck is clamped to at least one channel so that
            ``channels < reduction`` does not build an empty layer; for
            ``channels >= reduction`` the layer sizes are unchanged.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Integer division alone would give 0 hidden channels for small inputs.
        hidden = max(1, channels // reduction)
        self.fc = nn.Sequential(
            nn.Conv1d(channels, hidden, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by a per-channel, per-column gate.

        Args:
            x: 4-D tensor; unpacked as ``(b, c, h, w)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, _, w = x.size()
        # Collapse height so the gate depends only on (channel, column).
        y = torch.mean(x, dim=2).view(b, c, w)
        y = self.fc(y)
        # Re-insert a singleton height axis so the gate broadcasts over rows.
        return x * y.view(b, c, 1, w)
class ImprovedFeatureExtractor(nn.Module):
    """Convolutional backbone mapping 1-channel images to 512-channel features.

    Seven 3x3 convolutions (each followed by batch norm and ReLU) are
    interleaved with max pooling and three ``SequenceSE`` gates. The first
    two pools halve both height and width; the last two halve height only.
    A final adaptive average pool fixes the spatial size to ``(2, 32)``.
    """

    def __init__(self):
        super().__init__()

        def conv_block(c_in, c_out):
            # 3x3 kernel, stride 1, padding 1: spatial size is preserved.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, 1, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            )

        def height_pool():
            # Downsample rows only; the width axis is left untouched.
            return nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))

        # Attribute names are kept unchanged: they are the keys under which
        # the parameters appear in the module's state dict.
        self.conv1 = conv_block(1, 64)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = conv_block(128, 256)
        self.conv4 = conv_block(256, 256)
        self.se3 = SequenceSE(256)
        self.pool3 = height_pool()

        self.conv5 = conv_block(256, 512)
        self.conv6 = conv_block(512, 512)
        self.se4 = SequenceSE(512)
        self.pool4 = height_pool()

        # The last stage keeps conv / bn / relu as separate attributes.
        self.conv7 = nn.Conv2d(512, 512, 3, 1, 1)
        self.bn7 = nn.BatchNorm2d(512)
        self.relu7 = nn.ReLU(True)
        self.se5 = SequenceSE(512)

        self.final_pool = nn.AdaptiveAvgPool2d((2, 32))

    def forward(self, x):
        """Run every stage in order and return the pooled feature map."""
        stages = (
            self.conv1, self.pool1,
            self.conv2, self.pool2,
            self.conv3, self.conv4, self.se3, self.pool3,
            self.conv5, self.conv6, self.se4, self.pool4,
            self.conv7, self.bn7, self.relu7, self.se5,
            self.final_pool,
        )
        for stage in stages:
            x = stage(x)
        return x
class PatchEncoder(nn.Module):
    """Project a feature map into patch tokens and add learned positions.

    Args:
        in_channels: Channel count of the incoming feature map.
        emb_dim: Size of each output token.
        k1: Patch height (used as both kernel size and stride).
        k2: Patch width (used as both kernel size and stride).
        max_patches: Number of rows in the learned positional table.
    """

    def __init__(self, in_channels, emb_dim, k1=2, k2=1, max_patches=256):
        super().__init__()
        patch_size = (k1, k2)
        # Kernel == stride, so patches do not overlap.
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_emb = nn.Parameter(torch.zeros(max_patches, emb_dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)

    def forward(self, F):
        """Return ``(tokens, num_patches)`` for the feature map ``F``.

        ``tokens`` has shape ``(batch, num_patches, emb_dim)``, where
        ``num_patches`` is the number of spatial positions after projection.
        """
        projected = self.proj(F)
        _, _, rows, cols = projected.shape
        num_patches = rows * cols
        # (B, D, Hp, Wp) -> (B, D, N) -> (B, N, D)
        tokens = projected.flatten(2).transpose(1, 2)
        positions = self.pos_emb[:num_patches].unsqueeze(0)
        return tokens + positions, num_patches
def make_encoder(emb_dim=384, nhead=8, num_layers=3, dim_feedforward=1024, dropout=0.1):
    """Build a stack of ``num_layers`` ReLU Transformer encoder layers.

    ``batch_first`` is not set on the layer, so the encoder uses PyTorch's
    default input layout.
    """
    layer_options = {
        "d_model": emb_dim,
        "nhead": nhead,
        "dim_feedforward": dim_feedforward,
        "dropout": dropout,
        "activation": 'relu',
    }
    template_layer = nn.TransformerEncoderLayer(**layer_options)
    return nn.TransformerEncoder(template_layer, num_layers=num_layers)
class TransformerDecoderWrapper(nn.Module):
    """Token embedding, Transformer decoder stack and vocabulary projection.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        emb_dim: Model width shared by embeddings and decoder layers.
        nhead: Attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        pad_idx: Token id treated as padding (embedding row and key mask).
        max_len: Number of rows in the learned positional table.
    """

    def __init__(self, vocab_size, emb_dim, nhead=8, num_layers=3, pad_idx=0, max_len=256):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        dec_layer = nn.TransformerDecoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=emb_dim * 4, dropout=0.1
        )
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)
        self.pos_emb = nn.Parameter(torch.zeros(max_len, emb_dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.1)
        self.out_proj = nn.Linear(emb_dim, vocab_size)
        self.pad_idx = pad_idx

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return a float32 causal mask of shape ``(sz, sz)``.

        Entries on and below the diagonal are ``0.0``; entries above it are
        ``-inf``. ``device`` optionally places the mask directly on the
        target device instead of building it on the CPU.
        """
        mask = torch.full((sz, sz), float('-inf'), dtype=torch.float32, device=device)
        # triu zero-fills everything on and below the main diagonal.
        return torch.triu(mask, diagonal=1)

    def forward(self, tgt_tokens, memory, memory_key_padding_mask):
        """Decode ``tgt_tokens`` against ``memory`` and return logits.

        Args:
            tgt_tokens: Integer tensor of token ids, unpacked as ``(B, T)``.
            memory: Encoder output with batch as the first axis; it is
                transposed to sequence-first before entering the decoder.
            memory_key_padding_mask: Optional mask over memory positions;
                cast to bool when given (True marks positions to ignore).

        Returns:
            Tensor of shape ``(B, T, vocab_size)``.
        """
        B, T = tgt_tokens.size()
        device = tgt_tokens.device

        tok = self.tok_emb(tgt_tokens)
        pos = self.pos_emb[:T, :].unsqueeze(0).expand(B, -1, -1)
        tgt = (tok + pos).transpose(0, 1)  # decoder layers are sequence-first

        tgt_key_padding_mask = tgt_tokens == self.pad_idx
        if memory_key_padding_mask is not None:
            memory_key_padding_mask = memory_key_padding_mask.bool()

        # Boolean causal mask (True = blocked), built directly on the target
        # device. Using the same dtype as the boolean padding masks avoids
        # PyTorch's mismatched-mask deprecation warning and a per-call
        # CPU allocation plus host-to-device copy.
        tgt_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1
        )

        dec_out = self.decoder(
            tgt,
            memory.transpose(0, 1),
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.out_proj(dec_out.transpose(0, 1))
| # ========================================== | |
| # 2. MAIN MODEL WRAPPER | |
| # ========================================== | |
class KhmerOCR(PreTrainedModel):
    """Chunk-based OCR model: CNN -> patch tokens -> Transformer encoder ->
    per-image merge -> BiLSTM -> Transformer decoder.

    Every image is supplied as a list of chunk tensors. All chunks of the
    batch are encoded together, the encoded chunks belonging to one image
    are concatenated into a single sequence, and those sequences are padded
    to a common length before the BiLSTM and the decoder.
    """

    config_class = KhmerOCRConfig

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.pad_idx = config.pad_idx
        self.emb_dim = config.emb_dim

        self.cnn = ImprovedFeatureExtractor()
        self.patch = PatchEncoder(512, emb_dim=self.emb_dim, k1=2, k2=1)
        self.enc = make_encoder(
            emb_dim=self.emb_dim,
            nhead=config.nhead,
            num_layers=config.num_encoder_layers,
        )

        # Learned positions over the merged (multi-chunk) sequence.
        self.global_pos = nn.Parameter(torch.zeros(config.max_global_len, self.emb_dim))
        nn.init.trunc_normal_(self.global_pos, std=0.02)

        # Bidirectional with hidden size emb_dim // 2, so the two directions
        # concatenate to emb_dim when emb_dim is even.
        self.context_bilstm = nn.LSTM(
            input_size=self.emb_dim,
            hidden_size=self.emb_dim // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.dec = TransformerDecoderWrapper(
            self.vocab_size,
            emb_dim=self.emb_dim,
            nhead=config.nhead,
            num_layers=config.num_decoder_layers,
            pad_idx=self.pad_idx,
        )

    def forward(self, chunk_lists, tgt_tokens=None):
        """Encode chunked images and, if targets are given, decode logits.

        Args:
            chunk_lists: One sequence of chunk tensors per image. All chunks
                are stacked into a single batch, so they must share a shape
                (presumably ``(1, H, W)`` given the 1-channel CNN -- TODO
                confirm against the preprocessing code).
            tgt_tokens: Optional ``(B, T_tgt)`` tensor of target token ids.

        Returns:
            When ``tgt_tokens`` is ``None``: the padded encoder memory of
            shape ``(B, T, emb_dim)``. No padding mask is returned in this
            mode. Otherwise: decoder logits from ``self.dec``.
        """
        # 1. Flatten: every chunk of every image becomes one batch entry.
        chunk_sizes = [len(img_chunks) for img_chunks in chunk_lists]
        flat_input = torch.stack(
            [chunk for img_chunks in chunk_lists for chunk in img_chunks]
        )

        # 2. Pipeline: CNN -> patch tokens -> encoder (sequence-first inside).
        f = self.cnn(flat_input)
        p, _ = self.patch(f)
        enc_out = self.enc(p.transpose(0, 1).contiguous()).transpose(0, 1)

        # 3. Merge: concatenate the tokens of all chunks of the same image.
        feature_dim = enc_out.size(-1)
        batch_encoded_list = [
            img_chunks.reshape(-1, feature_dim)
            for img_chunks in torch.split(enc_out, chunk_sizes, dim=0)
        ]

        # 4. Pad & global positions. Sequences longer than the positional
        #    table are truncated to its length; slicing is a no-op otherwise.
        memory = pad_sequence(batch_encoded_list, batch_first=True, padding_value=0.0)
        T = min(memory.size(1), self.global_pos.size(0))
        memory = memory[:, :T, :] + self.global_pos[:T, :].unsqueeze(0)

        # 5. BiLSTM over the padded batch (padding positions are not packed
        #    away, so they are processed like any other time step).
        self.context_bilstm.flatten_parameters()
        memory, _ = self.context_bilstm(memory)

        # If inference (no targets), return memory for search.
        if tgt_tokens is None:
            return memory

        # 6. Decoder. True marks padded memory positions; built with a single
        #    comparison instead of one slice assignment per sample.
        lengths = torch.tensor(
            [seq.shape[0] for seq in batch_encoded_list], device=memory.device
        )
        positions = torch.arange(T, device=memory.device)
        memory_key_padding_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)

        return self.dec(tgt_tokens, memory, memory_key_padding_mask)