# NabuOCR / training / get_cdli_dataset.py
# Uploaded by boatbomber -- commit 819ced0 ("Upload training code"), verified.
import concurrent.futures
import math
import time
from io import BytesIO
from pathlib import Path
import requests
from convert_atf import ATFConverter
from datasets import Dataset
from PIL import Image
from tqdm.auto import tqdm
# Shared converter instance; get_dataset() calls its parse() on each tablet's ATF text.
atf_converter = ATFConverter()
# Local directory holding tablet photos (one "P<6-digit id>.jpg" per tablet).
# Created at import time, including parent directories, if it does not exist.
IMG_CACHE = Path("./data/cdli_images")
IMG_CACHE.mkdir(exist_ok=True, parents=True)
# Longest allowed image side in pixels; larger images are scaled down before
# their dimensions are snapped by smart_resize().
MAX_IMG_RES = 2048
# When True, get_dataset() obtains photos through get_image() (which may hit the
# network); when False it only uses photos already present in IMG_CACHE.
DOWNLOAD_MODE = False
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 28 * 28 * 130,
    max_pixels: int = 28 * 28 * 1280,
):
    """Compute target dimensions for an image so that:

    1. Both dimensions (height and width) are divisible by ``factor``.
    2. The total number of pixels is within ``[min_pixels, max_pixels]``.
    3. The aspect ratio of the image is maintained as closely as possible.

    A side shorter than ``factor`` is first scaled up to ``factor`` (the other
    side is scaled by the same ratio) and a notice is printed.

    Neither returned side is ever smaller than ``factor``. For very elongated
    inputs combined with a small ``max_pixels`` this floor takes precedence,
    so the result may then exceed ``max_pixels``.

    Args:
        height: Source height in pixels (must be positive).
        width: Source width in pixels (must be positive).
        factor: Value both returned dimensions are multiples of.
        min_pixels: Lower bound on ``h_bar * w_bar``.
        max_pixels: Upper bound on ``h_bar * w_bar``.

    Returns:
        ``(h_bar, w_bar)`` -- note the height-first order.

    Raises:
        ValueError: If a dimension is not positive, or if the aspect ratio
            (after the minimum-side adjustment) exceeds 200.
    """
    if height <= 0 or width <= 0:
        # Without this guard a zero side surfaces as an opaque ZeroDivisionError below.
        raise ValueError(f"height:{height} and width:{width} must be positive")

    if height < factor:
        print(f"smart_resize: height={height} < factor={factor}, reset height=factor")
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        print(f"smart_resize: width={width} < factor={factor}, reset width=factor")
        height = round((height * factor) / width)
        width = factor

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {aspect_ratio}"
        )

    # Snap each side to the nearest multiple of `factor`.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor

    if h_bar * w_bar > max_pixels:
        # Shrink uniformly, rounding down so the pixel budget is respected.
        # Clamp to `factor` so flooring can never produce a zero-sized side.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Grow uniformly, rounding up so the minimum pixel count is reached.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    return h_bar, w_bar
def resize_image(img_path):
    """Resize the image file at *img_path* in place to model-compatible dimensions.

    The image is first limited to ``MAX_IMG_RES`` on its longest side, then its
    dimensions are snapped by ``smart_resize``. The file is only rewritten when
    the resulting size differs from the current one, so running this repeatedly
    over an already-processed file does not re-encode it.

    Args:
        img_path: Path of the image file to process (overwritten when resized).
    """
    # Decode into memory and release the file handle before a possible
    # overwrite of the same path.
    with Image.open(img_path) as source:
        image = source.convert("RGB")

    with image:
        width, height = image.size

        # Scale down if larger than MAX_IMG_RES
        if width > MAX_IMG_RES or height > MAX_IMG_RES:
            scale = MAX_IMG_RES / max(width, height)
            height = int(height * scale)
            width = int(width * scale)

        # Always ensure dimensions are multiples of 28 for vision model compatibility
        new_height, new_width = smart_resize(height, width)
        if (new_width, new_height) != image.size:
            with image.resize((new_width, new_height), Image.LANCZOS) as resized:
                resized.save(img_path)
def resize_cached_images():
    """Run ``resize_image`` over every ``*.jpg`` in ``IMG_CACHE`` using a thread pool.

    Progress is shown with a tqdm bar. A failure on one file is reported and
    does not stop the remaining files from being processed.
    """
    img_paths = list(IMG_CACHE.glob("*.jpg"))

    with tqdm(total=len(img_paths)) as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
            # Map each future back to its path so failures can be attributed.
            futures = {
                executor.submit(resize_image, img_path): img_path
                for img_path in img_paths
            }
            for future in concurrent.futures.as_completed(futures):
                try:
                    # result() re-raises whatever the worker raised; without this
                    # call, errors would vanish silently.
                    future.result()
                except Exception as e:
                    tqdm.write(
                        f"Error resizing {futures[future].name}: {type(e).__name__}: {e}"
                    )
                pbar.update(1)
def get_image(id: int):
    """Return the photo of CDLI tablet *id* as an RGB image, or None on failure.

    The photo is read from ``IMG_CACHE`` when present, otherwise downloaded
    from cdli.earth. In both cases the image is limited to ``MAX_IMG_RES`` on
    its longest side and snapped to the dimensions given by ``smart_resize``.

    The cache file is written when the image was just downloaded or when its
    size had to be changed; an already-conforming cached file is left untouched.
    A short pause follows each download as rate limiting.

    Args:
        id: Numeric tablet id; formatted as ``P`` plus the id zero-padded to 6 digits.

    Returns:
        The processed image, or None if the download timed out, the request
        failed, or any other error occurred while processing.
    """
    file_name = f"P{str(id).rjust(6, '0')}.jpg"
    url = f"https://cdli.earth/dl/photo/{file_name}"
    cache_file = IMG_CACHE / file_name

    try:
        from_cache = cache_file.exists()
        if from_cache:
            tqdm.write(f"Found {file_name} in cache")
            # Decode fully and close the file before a possible overwrite below.
            with Image.open(cache_file) as source:
                image = source.convert("RGB")
        else:
            response = requests.get(url, timeout=5)
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as source:
                image = source.convert("RGB")
            tqdm.write(f"Downloaded {file_name}")

        width, height = image.size

        # Scale down if larger than MAX_IMG_RES
        if width > MAX_IMG_RES or height > MAX_IMG_RES:
            scale = MAX_IMG_RES / max(width, height)
            height = int(height * scale)
            width = int(width * scale)

        # Always ensure dimensions are multiples of 28 for vision model compatibility
        new_height, new_width = smart_resize(height, width)
        size_changed = (new_width, new_height) != image.size
        if size_changed:
            image = image.resize((new_width, new_height), Image.LANCZOS)

        # Persist new downloads and corrected sizes; skip needless re-encoding
        # of cache entries that already conform.
        if size_changed or not from_cache:
            image.save(cache_file)

        if not from_cache:
            time.sleep(0.02)  # Rate limiting (only after an actual request)

    except requests.exceptions.Timeout:
        tqdm.write(f"Timeout downloading {file_name}")
        return None
    except requests.exceptions.RequestException as e:
        tqdm.write(f"Error downloading {file_name}: {e}")
        return None
    except Exception as e:
        tqdm.write(f"Error processing {file_name}: {type(e).__name__}: {e}")
        return None

    return image
def count_repetitions(text: str) -> int:
    """
    Return how many characters of *text* are repeats of an earlier character.

    Each distinct character is "free" once; every further occurrence counts.
    E.g., "122233" -> 3 (two extra 2s and one extra 3).
    """
    total = len(text)
    distinct = len(set(text))
    # Fewer than two characters cannot contain a repeat.
    return total - distinct if total >= 2 else 0
# ATF lines starting with any of these prefixes are dropped before conversion.
_ATF_DROPPED_PREFIXES = ("# ", ">>", "<<", "||")


def _log_skip(id_part, reason):
    """Report (in red) that tablet *id_part* was skipped, with the reason."""
    tqdm.write(f"=====\033[91m {id_part} skip ({reason}) \033[0m=====")


def _image_rejection_reason(image):
    """Return a short reason why *image* is unsuitable, or None if it passes.

    Rejects low-resolution photos, photos that are almost entirely or almost
    never gray-valued (r == g == b), and photos without enough pure-black pixels.
    """
    if min(image.size) < 100:
        return "lowres"

    # Work on a thumbnail 150 px high to keep the per-pixel scan cheap.
    scale = 150 / image.height
    small_image = image.resize(
        (int(image.width * scale), int(image.height * scale)), Image.LANCZOS
    )
    pixels = list(small_image.getdata())
    small_image.close()

    total = len(pixels)
    gray_pixels = 0
    black_pixels = 0
    for r, g, b in pixels:
        if r == g == b:
            gray_pixels += 1
            if r == 0:
                black_pixels += 1

    bw_percent = gray_pixels / total
    if bw_percent > 0.95 or bw_percent < 0.1:
        return f"bw {bw_percent*100:.1f}%"
    if black_pixels / total < 0.15:
        return "not on black background"
    return None


def get_dataset(file="./data/cdli_dataset.parquet"):
    """Load or build the CDLI ATF/unicode dataset and return a train/test split.

    If *file* exists it is loaded directly. Otherwise the dataset is built from
    ``./data/cdli.atf`` (source:
    https://github.com/cdli-gh/data/raw/refs/heads/master/cdliatf_unblocked.atf),
    written to *file*, and then split.

    A tablet is kept only if all of the following hold:
      * its raw section contains "@tablet", contains "sux" or "akk", and is
        between 50 and 1000 characters long;
      * its first line yields a numeric id and ``atf_converter.parse`` succeeds;
      * its unicode rendering is 20-300 characters long, contains fewer than
        two "x" characters, and is not more than 70% repeated characters;
      * a photo is available (via ``get_image`` in ``DOWNLOAD_MODE``, otherwise
        from ``IMG_CACHE``) and passes the image quality checks.

    Args:
        file: Parquet path used as the dataset cache.

    Returns:
        The result of ``train_test_split(test_size=1000, seed=42)`` on a dataset
        with columns ``id``, ``atf`` and ``unicode``.
    """
    if Path(file).exists():
        return Dataset.from_parquet(file).train_test_split(test_size=1000, seed=42)

    # 1. Get all the ids from cdli.atf; every tablet section starts with "&P<id>"
    cdli_raw = Path("./data/cdli.atf").read_text(encoding="utf-8").split("&P")
    cdli_filtered = [
        section.strip()
        for section in cdli_raw
        if section.strip()  # Ignore empty sections
        and "@tablet" in section  # Only include tablets
        and len(section) > 50  # Ignore short sections
        and len(section) < 1000  # Ignore long sections
        and any(lang in section for lang in ["sux", "akk"])  # Limit supported languages
    ]

    ids = []
    atfs = []
    unicodes = []
    for section in tqdm(cdli_filtered, desc="Parsing CDLI dump"):
        # The id is the part of the first line before "="; ignore if not numeric
        lines = section.splitlines()
        id_part = lines[0].split("=")[0].strip()
        if not id_part.isdigit():
            continue
        tablet_id = int(id_part)

        atf = "\n".join(
            line for line in lines[1:] if not line.startswith(_ATF_DROPPED_PREFIXES)
        )

        parsed = atf_converter.parse(atf)
        if parsed is None:
            _log_skip(id_part, "parse fail")
            continue

        unicode_parts = [
            f"@{face}\n{parsed.get_unicode(face)}"
            for face in parsed.ALL_FACES
            if parsed.get_unicode(face)
        ]

        # Skip massive tablets
        unicode_len = sum(len(part) for part in unicode_parts)
        if unicode_len > 300 or unicode_len < 20:
            _log_skip(id_part, "too short/long")
            continue

        # Skip tablets that are poorly translated to unicode
        if sum(part.count("x") for part in unicode_parts) >= 2:
            _log_skip(id_part, "missing symbols")
            continue

        unicode = "\n".join(unicode_parts)

        # Drop the super repetitive admin tablets (model ends up getting stuck repeating the common phrases)
        if count_repetitions(unicode) / len(unicode) > 0.7:
            _log_skip(id_part, "too repetitive")
            continue

        # Ignore if we don't have an image for this atf
        image = None
        if DOWNLOAD_MODE:
            image = get_image(tablet_id)
        else:
            cached_photo = IMG_CACHE / f"P{str(tablet_id).rjust(6, '0')}.jpg"
            if cached_photo.exists():
                with Image.open(cached_photo) as source:
                    image = source.convert("RGB")
        if image is None:
            _log_skip(id_part, "no img")
            continue

        # Drop low res, B&W, or non-isolated background
        try:
            reason = _image_rejection_reason(image)
        except Exception as e:
            reason = f"err img check: {e}"
        finally:
            # The image is only needed for the checks; release it on every path.
            image.close()
        if reason is not None:
            _log_skip(id_part, reason)
            continue

        ids.append(tablet_id)
        atfs.append(atf)
        unicodes.append(unicode)
        tqdm.write(f"=====\033[32m {id_part} unicode (len {unicode_len}) \033[0m=====")

    dataset = Dataset.from_dict(
        {
            "id": ids,
            "atf": atfs,
            "unicode": unicodes,
        }
    )
    # Make sure the target directory exists when a custom path is given.
    Path(file).parent.mkdir(parents=True, exist_ok=True)
    dataset.to_parquet(file)

    return dataset.train_test_split(test_size=1000, seed=42)