# Pill-Counter / app.py
# WizardForest's picture
# Update app.py
# 7656c57 verified
import gradio as gr
from ultralytics import YOLO
from PIL import Image, ImageDraw, ImageFont
import numpy as np
# Load the YOLO weights once at import time; predict() reuses this single instance for every request.
model = YOLO("model.pt")
from PIL import Image, ImageDraw, ImageFont
import numpy as np
def _parse_int(value, default, minimum):
    """Parse *value* as an int, clamped so it is never below *minimum*.

    ``None`` or text that int() rejects (e.g. "", "abc", "10.5") yields
    *default*, which is how the UI textboxes have always been interpreted.
    """
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        parsed = default
    return max(parsed, minimum)


def _load_font(size):
    """Return a TrueType font at *size*, falling back to PIL's built-in font.

    Tries "DejaVuSans.ttf" first, then "arial.ttf". Pillow raises OSError when
    a font file cannot be found/read, ValueError for a non-positive size, and
    ImportError when it was built without FreeType support.
    """
    for font_file in ("DejaVuSans.ttf", "arial.ttf"):
        try:
            return ImageFont.truetype(font_file, size)
        except (OSError, ValueError, ImportError):
            continue
    return ImageFont.load_default()


def predict(img, target_class, dot_color, text_color_choice, dot_size, font_size):
    """Detect pills of one class in an image and mark each one with a dot.

    Args:
        img: Path of the uploaded image (the UI uses ``gr.Image(type="filepath")``),
            or ``None``/"" when nothing has been uploaded.
        target_class: UI label; "膠囊" (capsule) selects YOLO class index 0,
            any other value selects class index 1.
        dot_color: "紅色" draws red dots; anything else draws fluorescent green.
        text_color_choice: "黑色" draws black numbers; anything else white.
        dot_size: Dot radius as text; non-integers fall back to 10, minimum 2.
        font_size: Number font size as text; non-integers fall back to 36, minimum 8.

    Returns:
        Tuple ``(count_as_str, dots_image, dots_and_numbers_image)`` where the
        images are numpy arrays. Returns ``("0", None, None)`` when no image is
        supplied or the model yields no results.
    """
    if not img:
        # Nothing uploaded yet: report zero instead of crashing on the path.
        return "0", None, None

    from pathlib import Path  # local import keeps this block self-contained

    # Debug trace of the uploaded file's base name (platform-aware, unlike
    # splitting on a hard-coded backslash).
    print("path", Path(img).stem)
    results = model.predict(source=img, save=False, show_labels=False, show_conf=False)
    if not isinstance(results, (list, tuple)):
        results = [results]

    # Colour choices come straight from the UI radio labels.
    dot_rgb = (255, 0, 0) if dot_color == "紅色" else (57, 255, 20)  # red / fluorescent green
    text_rgb = (0, 0, 0) if text_color_choice == "黑色" else (255, 255, 255)  # black / white
    r = _parse_int(dot_size, default=10, minimum=2)
    fsize = _parse_int(font_size, default=36, minimum=8)
    # UI label -> YOLO class index.
    class_idx = 0 if target_class == "膠囊" else 1
    # The font depends only on fsize, so load it once instead of per result.
    font = _load_font(fsize)

    count = 0
    annot_img = None
    annot_img_numbered = None
    for result in results:
        boxes = result.boxes
        selected_boxes = []
        if hasattr(boxes, "cls"):
            selected_boxes = [
                box for box, cls_id in zip(boxes.xyxy, boxes.cls) if int(cls_id) == class_idx
            ]
        # Only the last result's count/images are returned; the UI sends one image.
        count = len(selected_boxes)

        # Reverse the channel axis (presumably OpenCV BGR -> RGB for PIL).
        img_pil = Image.fromarray(result.orig_img[..., ::-1])
        img_pil_numbered = img_pil.copy()
        draw = ImageDraw.Draw(img_pil)
        draw_num = ImageDraw.Draw(img_pil_numbered)
        for idx, box in enumerate(selected_boxes, start=1):
            x1, y1, x2, y2 = box[:4].tolist()
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            dot_bbox = (cx - r, cy - r, cx + r, cy + r)
            # First image: dots only.
            draw.ellipse(dot_bbox, fill=dot_rgb, outline=None)
            # Second image: dots plus running numbers.
            draw_num.ellipse(dot_bbox, fill=dot_rgb, outline=None)
            # The final number is always red so the total is easy to spot.
            num_color = (255, 0, 0) if idx == count else text_rgb
            draw_num.text((cx + r + 4, cy - r), str(idx), fill=num_color, font=font)
        annot_img = np.array(img_pil)
        annot_img_numbered = np.array(img_pil_numbered)
    return str(count), annot_img, annot_img_numbered
# Custom CSS styles (white background, minimal look), passed to gr.Blocks(css=...).
# The .predict-btn, .image-upload, .count-output, .result-image and .main-header
# selectors match elem_classes / HTML used in the layout.
# NOTE(review): .sub-header, .card-container, .input-section and .output-section
# are not referenced by the visible layout — confirm before removing.
# The /* ... */ comments inside the string are part of the emitted CSS and are left as-is.
custom_css = """
.gradio-container {
background: #ffffff;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
/* 主標題 */
.main-header {
text-align: center;
color: #333333;
font-size: 2.2em;
font-weight: bold;
margin-bottom: 5px;
}
/* 副標題 */
.sub-header {
text-align: center;
color: #555555;
font-size: 1.1em;
margin-bottom: 20px;
font-weight: 300;
}
/* 卡片容器 */
.card-container {
background: #f9f9f9;
border-radius: 12px;
padding: 20px;
box-shadow: 0 4px 12px rgba(0,0,0,0.05);
margin: 15px;
}
/* 輸入區 */
.input-section {
background: #ffffff;
border-radius: 10px;
padding: 20px;
margin-bottom: 15px;
border: 1px solid #e0e0e0;
}
/* 輸出區 */
.output-section {
background: #ffffff;
border-radius: 10px;
padding: 20px;
border: 1px solid #e0e0e0;
}
/* 按鈕 */
.predict-btn {
background: #4285f4 !important;
border: none !important;
border-radius: 20px !important;
padding: 12px 32px !important;
color: white !important;
font-size: 16px !important;
font-weight: 600 !important;
transition: background 0.2s ease !important;
cursor: pointer !important;
}
.predict-btn:hover {
background: #3367d6 !important;
}
/* 圖片上傳 */
.image-upload {
border: 2px dashed #cccccc !important;
border-radius: 10px !important;
background: #fcfcfc !important;
padding: 15px !important;
}
.image-upload:hover {
border-color: #bbbbbb !important;
background: #f5f5f5 !important;
}
/* 數量文字 */
.count-output {
font-size: 22px !important;
font-weight: 600 !important;
color: #222222 !important;
text-align: center !important;
background: #ffffff !important;
border: 1px solid #e0e0e0 !important;
border-radius: 8px !important;
padding: 12px !important;
}
/* 結果圖 */
.result-image {
border-radius: 8px !important;
box-shadow: 0 6px 18px rgba(0,0,0,0.04) !important;
border: 1px solid #e0e0e0 !important;
}
/* 響應式 */
@media (max-width: 768px) {
.main-header { font-size: 1.8em; }
.predict-btn { width: 100% !important; margin-top: 15px !important; }
}
"""
# Build the Gradio UI: a left column with the upload, class choice and drawing
# options, a right column with the count and the two annotated images, and a
# usage-instructions panel below. The button wires everything to predict().
with gr.Blocks(
    title="Pill Counter",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="green",
        neutral_hue="gray",
        font=gr.themes.GoogleFont("Noto Sans TC"),
    ),
    css=custom_css,
) as demo:
    gr.HTML('<div class="main-header">Pill Counter</div>')
    with gr.Row():
        # Left column: inputs and drawing options.
        with gr.Column(scale=1):
            gr.HTML('<h3 style="color: #4285f4; text-align: center; margin-bottom: 15px;">📤 上傳藥物圖片</h3>')
            # `format` is deliberately not set: in gr.Image it names the single
            # format Gradio saves images in, not a filter on accepted uploads.
            img = gr.Image(
                type="filepath",
                height=450,
                width=800,
                label="請選擇或拖放圖片",
                elem_classes=["image-upload"],
            )
            target_class = gr.Radio(
                choices=["膠囊", "錠劑"],
                value="膠囊",
                label="要計算的類別",
                info="選擇只計算錠劑或膠囊的總數",
            )
            button = gr.Button(
                "開始計算",
                variant="primary",
                elem_classes=["predict-btn"],
            )
            # Each gr.HTML is an independent component, so open/close tags
            # cannot span components; the hint is a self-contained element.
            gr.HTML('<div style="text-align:center; margin-top:8px; color:#888;">支援 JPG/PNG,建議 640×640 以上</div>')
            dot_color = gr.Radio(
                choices=["紅色", "螢光綠"],
                value="紅色",
                label="圓點顏色",
                info="選擇藥錠中心圓點顏色",
            )
            text_color_choice = gr.Radio(
                choices=["黑色", "白色"],
                value="黑色",
                label="文字顏色",
                info="選擇編號文字顏色",
            )
            dot_size = gr.Textbox(
                label="圓點大小",
                value="10",
                info="設定圓點半徑大小(建議範圍 5~20)",
            )
            font_size = gr.Textbox(
                label="字體大小",
                value="36",
                info="設定編號文字大小(建議範圍 20~60)",
            )
        # Right column: count and annotated results.
        with gr.Column(scale=1):
            gr.HTML('<h3 style="color: #34a853; text-align: center; margin-bottom: 15px;">📊 檢測結果</h3>')
            data_output = gr.Textbox(
                interactive=False,
                elem_classes=["count-output"],
            )
            gr.HTML('<div style="text-align:center; margin:15px 0;"><label style="font-size:14px;color:#555;">標註結果</label></div>')
            img_output = gr.Image(
                type="numpy",
                elem_classes=["result-image"],
            )
            gr.HTML('<div style="text-align:center; margin:20px 0;"><label style="font-size:14px;color:#555;">標註 + 編號結果</label></div>')
            img_output_numbered = gr.Image(
                type="numpy",
                elem_classes=["result-image"],
            )
    # Usage instructions shown below both columns.
    gr.HTML("""
    <div style="background:#ffffff; border-left:4px solid #4285f4; padding:15px; margin:15px; color:#666;">
    <strong>使用步驟:</strong>
    <ol style="margin-top:8px;">
    <li>上傳清晰的圖片,保持純色背景(白色或黑色),勿有反光</li>
    <li>點擊「開始計算」按鈕</li>
    <li>查看標註後圖片及藥錠總數</li>
    </ol>
    </div>
    """)
    # Event listeners must be attached inside the Blocks context.
    button.click(
        fn=predict,
        inputs=[img, target_class, dot_color, text_color_choice, dot_size, font_size],
        outputs=[data_output, img_output, img_output_numbered],
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    demo.launch()