|
|
import concurrent.futures
|
|
|
import math
|
|
|
import time
|
|
|
from io import BytesIO
|
|
|
from pathlib import Path
|
|
|
|
|
|
import requests
|
|
|
from convert_atf import ATFConverter
|
|
|
from datasets import Dataset
|
|
|
from PIL import Image
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
# Single ATF parser instance shared by get_dataset().
atf_converter = ATFConverter()

# Local cache of CDLI photos. Created eagerly at import time so that every
# later read/write in this module can assume the directory exists.
IMG_CACHE = Path("data", "cdli_images")
IMG_CACHE.mkdir(parents=True, exist_ok=True)

# Longest allowed image side (pixels) before smart_resize() snapping is applied.
MAX_IMG_RES = 2048

# True: get_dataset() fetches photos via get_image() (network + cache).
# False: get_dataset() only uses photos already present in IMG_CACHE.
DOWNLOAD_MODE = False
|
|
|
|
|
|
|
|
|
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 28 * 28 * 130,
    max_pixels: int = 28 * 28 * 1280,
):
    """Compute a target ``(height, width)`` for resizing an image.

    The returned size satisfies, as closely as possible:

    1. Both dimensions are divisible by ``factor``.
    2. The total number of pixels is within ``[min_pixels, max_pixels]``.
    3. The aspect ratio of the input is maintained.

    A side shorter than ``factor`` is first scaled up to ``factor`` (scaling
    the other side proportionally) before the rounding above is applied.

    Args:
        height: Source height in pixels; must be positive.
        width: Source width in pixels; must be positive.
        factor: Granularity that both output dimensions are rounded to.
        min_pixels: Lower bound on ``new_height * new_width``.
        max_pixels: Upper bound on ``new_height * new_width``. If honouring it
            would shrink a side below ``factor`` (only possible with
            non-default, very small values), that side is clamped to
            ``factor`` rather than returning a zero-sized dimension.

    Returns:
        Tuple ``(new_height, new_width)``.

    Raises:
        ValueError: If ``height`` or ``width`` is not positive, or if the
            aspect ratio (after the small-side adjustment) exceeds 200.
    """
    # Without this guard a zero side surfaces as an opaque ZeroDivisionError
    # from the proportional rescaling below.
    if height <= 0 or width <= 0:
        raise ValueError(
            f"height and width must be positive, got height={height}, width={width}"
        )

    if height < factor:
        print(f"smart_resize: height={height} < factor={factor}, reset height=factor")
        width = round((width * factor) / height)
        height = factor

    if width < factor:
        print(f"smart_resize: width={width} < factor={factor}, reset width=factor")
        height = round((height * factor) / width)
        width = factor

    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )

    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor

    if h_bar * w_bar > max_pixels:
        # Shrink uniformly; floor so the result stays under max_pixels, but
        # never below one factor-sized cell (floor could otherwise yield 0).
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Grow uniformly; ceil so the result reaches at least min_pixels.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    return h_bar, w_bar
|
|
|
|
|
|
|
|
|
def resize_image(img_path):
    """Resize the image stored at *img_path* in place, if needed.

    The longer side is first capped at ``MAX_IMG_RES`` (keeping the aspect
    ratio), then the size is snapped with ``smart_resize``. The file is only
    rewritten (as RGB) when the resulting size differs from the stored one,
    so already-conforming images are not needlessly re-encoded.

    Args:
        img_path: Path of the image file to resize; overwritten when resized.
    """
    # Decode fully and release the source file handle before we may overwrite
    # the same path; the original code only scoped the converted copy.
    with Image.open(img_path) as source:
        image = source.convert("RGB")

    width, height = image.size
    if width > MAX_IMG_RES or height > MAX_IMG_RES:
        scale = MAX_IMG_RES / max(width, height)
        height = int(height * scale)
        width = int(width * scale)

    new_height, new_width = smart_resize(height, width)
    if new_height != image.height or new_width != image.width:
        image = image.resize((new_width, new_height), Image.LANCZOS)
        image.save(img_path)
|
|
|
|
|
|
|
|
|
def resize_cached_images(max_workers: int = 20):
    """Resize every ``*.jpg`` in ``IMG_CACHE`` in place using a thread pool.

    Each file is handled by ``resize_image``. A failure on one file (e.g. a
    corrupt or truncated JPEG) is reported via ``tqdm.write`` and does not
    stop the remaining files from being processed.

    Args:
        max_workers: Number of worker threads (defaults to the previous
            hard-coded value of 20).
    """
    img_paths = list(IMG_CACHE.glob("*.jpg"))

    # Both context managers guarantee cleanup (progress bar closed, workers
    # joined) even if something below raises.
    with tqdm(total=len(img_paths)) as pbar, concurrent.futures.ThreadPoolExecutor(
        max_workers=max_workers
    ) as executor:
        futures = {executor.submit(resize_image, path): path for path in img_paths}
        for future in concurrent.futures.as_completed(futures):
            try:
                # result() re-raises the worker's exception; without this call
                # per-file errors were silently discarded.
                future.result()
            except Exception as e:
                tqdm.write(
                    f"Error resizing {futures[future].name}: {type(e).__name__}: {e}"
                )
            pbar.update(1)
|
|
|
|
|
|
|
|
|
def get_image(id: int):
    """Return the CDLI photo for artifact number *id* as an RGB PIL image.

    The photo is read from ``IMG_CACHE`` when present, otherwise downloaded
    from ``https://cdli.earth/dl/photo/P<id>.jpg``. In both cases the longer
    side is capped at ``MAX_IMG_RES`` and the size is snapped with
    ``smart_resize``. The cache file is written only when the image was just
    downloaded or its size changed, so cached JPEGs are not lossily
    re-encoded on every call.

    Args:
        id: Numeric CDLI artifact id, zero-padded to at least six digits in
            the file name. (Shadows the ``id`` builtin; kept for callers.)

    Returns:
        The processed image, or ``None`` if downloading or processing failed;
        the reason is logged via ``tqdm.write``.
    """
    file_name = f"P{str(id).rjust(6, '0')}.jpg"
    url = f"https://cdli.earth/dl/photo/{file_name}"
    cache_file = IMG_CACHE / file_name

    try:
        downloaded = not cache_file.exists()
        if downloaded:
            response = requests.get(url, timeout=5)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
            tqdm.write(f"Downloaded {file_name}")
        else:
            tqdm.write(f"Found {file_name} in cache")
            # Scope the file handle so it is released before we may overwrite
            # the same cache file below.
            with Image.open(cache_file) as cached:
                image = cached.convert("RGB")

        width, height = image.size
        if width > MAX_IMG_RES or height > MAX_IMG_RES:
            scale = MAX_IMG_RES / max(width, height)
            height = int(height * scale)
            width = int(width * scale)

        new_height, new_width = smart_resize(height, width)
        resized = new_height != image.height or new_width != image.width
        if resized:
            image = image.resize((new_width, new_height), Image.LANCZOS)

        if downloaded or resized:
            image.save(cache_file)
        if downloaded:
            # Throttle consecutive downloads; cache hits make no request.
            time.sleep(0.02)
    except requests.exceptions.Timeout:
        tqdm.write(f"Timeout downloading {file_name}")
        return None
    except requests.exceptions.RequestException as e:
        tqdm.write(f"Error downloading {file_name}: {e}")
        return None
    except Exception as e:
        tqdm.write(f"Error processing {file_name}: {type(e).__name__}: {e}")
        return None

    return image
|
|
|
|
|
|
|
|
|
def count_repetitions(text: str) -> int:
    """Return how many characters of *text* repeat a character seen earlier.

    Every occurrence of a character beyond its first one counts once, which is
    the same as ``len(text) - len(set(text))``. For a ``str`` this operates on
    individual characters.

    >>> count_repetitions("122233")  # '2' repeats twice, '3' once
    3
    """
    if len(text) <= 1:
        return 0

    seen = set()
    duplicates = 0
    for char in text:
        if char in seen:
            duplicates += 1
        else:
            seen.add(char)
    return duplicates
|
|
|
|
|
|
|
|
|
def _image_skip_reason(image):
    """Return why *image* should be excluded from the dataset, or ``None`` to keep it.

    The checks run on a copy downscaled to 150 px high:

    - shorter side under 100 px -> ``"lowres"``;
    - share of achromatic pixels (``r == g == b``) above 95% or below 10%
      -> ``"bw <pct>%"``;
    - share of pure-black pixels below 15% -> ``"not on black background"``.

    Assumes *image* is RGB (pixels unpack as ``r, g, b``). The caller owns
    *image* and must close it. PIL errors (e.g. a degenerate resize target)
    propagate to the caller.
    """
    if min(image.size) < 100:
        return "lowres"

    scale = 150 / image.height
    small_image = image.resize(
        (int(image.width * scale), int(image.height * scale)), Image.LANCZOS
    )
    try:
        pixels = list(small_image.getdata())
    finally:
        small_image.close()

    bw_percent = sum(1 for r, g, b in pixels if r == g == b) / len(pixels)
    if bw_percent > 0.95 or bw_percent < 0.1:
        return f"bw {bw_percent*100:.1f}%"

    black_share = sum(1 for r, g, b in pixels if r == g == b == 0) / len(pixels)
    if black_share < 0.15:
        return "not on black background"

    return None


def get_dataset(file="./data/cdli_dataset.parquet"):
    """Load or build the CDLI tablet dataset and split it into train/test.

    If *file* exists it is loaded as-is. Otherwise ``./data/cdli.atf`` is
    parsed entry by entry. An entry is kept only if:

    - it converts via ``atf_converter``;
    - its per-face unicode text totals 20-300 characters;
    - that text contains fewer than two ``"x"`` characters;
    - it is not too repetitive;
    - it has a photo that passes ``_image_skip_reason``.

    Kept rows (``id``, ``atf``, ``unicode``) are written to *file* before
    splitting.

    Args:
        file: Parquet path used as both the cache to read and the output.

    Returns:
        The result of ``Dataset.train_test_split(test_size=1000, seed=42)``.
    """
    if Path(file).exists():
        return Dataset.from_parquet(file).train_test_split(test_size=1000, seed=42)

    cdli_raw = Path("./data/cdli.atf").read_text(encoding="utf-8").split("&P")
    # Keep non-empty, moderately sized tablet entries mentioning "sux"/"akk"
    # (presumably the Sumerian/Akkadian language codes - plain substring test).
    cdli_filtered = [
        section.strip()
        for section in cdli_raw
        if section.strip()
        and "@tablet" in section
        and 50 < len(section) < 1000
        and any(lang in section for lang in ["sux", "akk"])
    ]

    ids = []
    atfs = []
    unicodes = []

    for section in tqdm(cdli_filtered, desc="Parsing CDLI dump"):
        lines = section.splitlines()
        # The "&P" prefix was consumed by split(), so the header starts with digits.
        id_part = lines[0].split("=")[0].strip()
        if not id_part.isdigit():
            continue
        artifact_id = int(id_part)

        atf = "\n".join(
            line
            for line in lines[1:]
            if not line.startswith(("# ", ">>", "<<", "||"))
        )
        parsed = atf_converter.parse(atf)
        if parsed is None:
            tqdm.write(f"=====\033[91m {id_part} skip (parse fail) \033[0m=====")
            continue

        unicode_parts = []
        for face in parsed.ALL_FACES:
            face_text = parsed.get_unicode(face)
            if face_text:
                unicode_parts.append(f"@{face}\n{face_text}")

        unicode_len = sum(len(part) for part in unicode_parts)
        if unicode_len > 300 or unicode_len < 20:
            tqdm.write(f"=====\033[91m {id_part} skip (too short/long) \033[0m=====")
            continue

        if sum(part.count("x") for part in unicode_parts) >= 2:
            tqdm.write(f"=====\033[91m {id_part} skip (missing symbols) \033[0m=====")
            continue

        unicode = "\n".join(unicode_parts)
        # unicode_len >= 20 here, so len(unicode) is never zero.
        if count_repetitions(unicode) / len(unicode) > 0.7:
            tqdm.write(f"=====\033[91m {id_part} skip (too repetitive) \033[0m=====")
            continue

        image_file = IMG_CACHE / f"P{str(artifact_id).rjust(6, '0')}.jpg"
        if DOWNLOAD_MODE:
            image = get_image(artifact_id)
        elif image_file.exists():
            try:
                with Image.open(image_file) as raw_image:
                    image = raw_image.convert("RGB")
            except OSError as e:
                # A single corrupt cached file should not abort the whole build.
                tqdm.write(
                    f"=====\033[91m {id_part} skip (err img load: {e}) \033[0m====="
                )
                continue
        else:
            image = None

        if image is None:
            tqdm.write(f"=====\033[91m {id_part} skip (no img) \033[0m=====")
            continue

        # Always release the image, including on the skip and error paths
        # (previously leaked for "lowres" and on exceptions).
        try:
            skip_reason = _image_skip_reason(image)
        except Exception as e:
            skip_reason = f"err img check: {e}"
        finally:
            image.close()
        if skip_reason is not None:
            tqdm.write(f"=====\033[91m {id_part} skip ({skip_reason}) \033[0m=====")
            continue

        ids.append(artifact_id)
        atfs.append(atf)
        unicodes.append(unicode)

        tqdm.write(f"=====\033[32m {id_part} unicode (len {unicode_len}) \033[0m=====")

    dataset = Dataset.from_dict(
        {
            "id": ids,
            "atf": atfs,
            "unicode": unicodes,
        }
    )

    dataset.to_parquet(file)
    return dataset.train_test_split(test_size=1000, seed=42)
|
|
|
|