# hysts (HF Staff) - "Update" - commit 5cbc605
"""AnimeGANv3 Portrait Sketch.
This code is adapted from https://github.com/TachibanaYoshino/AnimeGANv3
https://colab.research.google.com/drive/1XYNWwM8Xq-U7KaTOqNap6A-Yq1f-V-FB?usp=sharing
The original license is the following:
```
License
This repo is made freely available to academic and
non-academic entities for non-commercial purposes such
as academic research, teaching, scientific publications.
Permission is granted to use the AnimeGANv3 given
that you agree to my license terms. Regarding the
request for commercial use, please contact us via
email to help you obtain the authorization letter.
Author
Asher Chan
```
"""
import cv2
import huggingface_hub
import numpy as np
import onnxruntime as ort
import torch
from facenet_pytorch import MTCNN
class Model:
def __init__(self) -> None:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.face_detector = MTCNN(
image_size=256, margin=80, min_face_size=128, thresholds=[0.7, 0.8, 0.9], device=self.device
)
self.model = self._load_model()
@staticmethod
def _load_model() -> ort.InferenceSession:
path = huggingface_hub.hf_hub_download(
"public-data/AnimeGANv3-portrait-sketch", "AnimeGANv3_PortraitSketch_25.onnx"
)
return ort.InferenceSession(path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
def detect_face(self, img: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
batch_boxes, batch_probs, batch_points = self.face_detector.detect(img, landmarks=True)
return batch_boxes, batch_points
@staticmethod
def margin_face(box: np.ndarray, img_hw: tuple[int, int], margin: float | None = None) -> list[int]:
x1, y1, x2, y2 = box
if margin is None:
w = x2 - x1
h = y2 - y1
x1 = max(0, x1 - w / 2)
y1 = max(0, y1 - h / 2)
x2 = min(img_hw[1], x2 + w / 2)
y2 = min(img_hw[0], y2 + h / 2)
else:
x1 = max(0, x1 - margin)
y1 = max(0, y1 - margin)
x2 = min(img_hw[1], x2 + margin)
y2 = min(img_hw[0], y2 + margin)
return list(map(int, [x1, y1, x2, y2]))
@staticmethod
def resize_image(img: np.ndarray, x32: bool = True) -> np.ndarray:
h, w = img.shape[:2]
ratio = h / w
if x32: # resize image to multiple of 32s
def to_32s(x: int) -> int:
return 256 if x < 256 else x - x % 32 # noqa: PLR2004
new_h = to_32s(h)
new_w = int(new_h / ratio) - int(new_h / ratio) % 32
img = cv2.resize(img, (new_w, new_h))
return img
def preprocess_image(self, image: np.ndarray) -> np.ndarray:
# BGR
batch_boxes, batch_points = self.detect_face(image)
if batch_boxes is None:
print("No face detected!") # noqa: T201
return None
[x1, y1, x2, y2] = self.margin_face(batch_boxes[0], image.shape[:2])
face_mat = image[y1:y2, x1:x2]
face_mat = self.resize_image(face_mat)
# BGR -> RGB
face_mat = cv2.cvtColor(face_mat, cv2.COLOR_BGR2RGB)
face_mat = face_mat.astype(np.float32) / 127.5 - 1.0
return np.expand_dims(face_mat, axis=0)
def run_model(self, face_mat: np.ndarray) -> np.ndarray:
x = self.model.get_inputs()[0].name
fake_img = self.model.run(None, {x: face_mat})
output_image = (np.squeeze(fake_img) + 1.0) / 2 * 255
return np.clip(output_image, 0, 255).astype(np.uint8)
def run(self, image: np.ndarray) -> np.ndarray:
# BGR -> RGB
data = self.preprocess_image(image[:, :, ::-1])
return self.run_model(data) # RGB