WizardForest commited on
Commit
7656c57
·
verified ·
1 Parent(s): 78ff8a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -19
app.py CHANGED
@@ -6,7 +6,10 @@ import numpy as np
6
 
7
  model = YOLO("model.pt")
8
 
9
- def predict(img, dot_color, text_color_choice, dot_size, font_size):
 
 
 
10
  path = img.split("\\")[-1].split(".")[0]
11
  print("path", path)
12
 
@@ -15,11 +18,7 @@ def predict(img, dot_color, text_color_choice, dot_size, font_size):
15
  if not isinstance(results, (list, tuple)):
16
  results = [results]
17
 
18
- count = 0
19
- annot_img = None
20
- annot_img_numbered = None
21
-
22
- # 轉換顏色文字為 RGB
23
  if dot_color == "紅色":
24
  dot_rgb = (255, 0, 0)
25
  else:
@@ -30,7 +29,7 @@ def predict(img, dot_color, text_color_choice, dot_size, font_size):
30
  else:
31
  text_rgb = (255, 255, 255)
32
 
33
- # 檢查輸入的大小是否合理
34
  try:
35
  r = int(dot_size)
36
  except:
@@ -43,34 +42,51 @@ def predict(img, dot_color, text_color_choice, dot_size, font_size):
43
  fsize = 36
44
  if fsize < 8: fsize = 8
45
 
 
 
 
 
 
 
 
 
 
 
46
  for i in results:
47
- count = len(i.boxes)
48
- img_pil = Image.fromarray(i.orig_img[..., ::-1]) # BGR → RGB
 
 
 
 
 
 
 
49
  img_pil_numbered = img_pil.copy()
50
 
51
  draw = ImageDraw.Draw(img_pil)
52
  draw_num = ImageDraw.Draw(img_pil_numbered)
53
 
54
- # 載入字體(若無 Arial 則 fallback)
55
  try:
56
- font = ImageFont.truetype("arial.ttf", fsize)
57
  except:
58
- font = ImageFont.load_default()
 
 
 
59
 
60
- for idx, box in enumerate(i.boxes.xyxy, start=1):
61
  x1, y1, x2, y2 = box[:4].tolist()
62
  cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
63
 
64
- # -------- 第一張(只有圓點)--------
65
  draw.ellipse((cx - r, cy - r, cx + r, cy + r), fill=dot_rgb, outline=None)
66
 
67
- # -------- 第二張(圓點 + 編號)--------
68
  draw_num.ellipse((cx - r, cy - r, cx + r, cy + r), fill=dot_rgb, outline=None)
69
 
70
- # 最後一顆用紅字,其餘依使用者設定
71
  num_color = (255, 0, 0) if idx == count else text_rgb
72
-
73
- # 在圓點右上方標註數字
74
  draw_num.text((cx + r + 4, cy - r), str(idx), fill=num_color, font=font)
75
 
76
  annot_img = np.array(img_pil)
@@ -78,6 +94,7 @@ def predict(img, dot_color, text_color_choice, dot_size, font_size):
78
 
79
  return str(count), annot_img, annot_img_numbered
80
 
 
81
  # 自定義 CSS 樣式(背景改為白色、簡潔風格)
82
  custom_css = """
83
  .gradio-container {
@@ -212,6 +229,13 @@ with gr.Blocks(
212
  label="請選擇或拖放圖片",
213
  elem_classes=["image-upload"]
214
  )
 
 
 
 
 
 
 
215
  gr.HTML('<div style="text-align: center; margin-top: 15px;">')
216
  button = gr.Button(
217
  "開始計算",
@@ -281,7 +305,7 @@ with gr.Blocks(
281
 
282
  button.click(
283
  fn=predict,
284
- inputs=[img, dot_color, text_color_choice, dot_size, font_size],
285
  outputs=[data_output, img_output, img_output_numbered]
286
  )
287
 
 
6
 
7
  model = YOLO("model.pt")
8
 
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ import numpy as np
11
+
12
+ def predict(img, target_class, dot_color, text_color_choice, dot_size, font_size):
13
  path = img.split("\\")[-1].split(".")[0]
14
  print("path", path)
15
 
 
18
  if not isinstance(results, (list, tuple)):
19
  results = [results]
20
 
21
+ # 顏色設定
 
 
 
 
22
  if dot_color == "紅色":
23
  dot_rgb = (255, 0, 0)
24
  else:
 
29
  else:
30
  text_rgb = (255, 255, 255)
31
 
32
+ # 大小設定
33
  try:
34
  r = int(dot_size)
35
  except:
 
42
  fsize = 36
43
  if fsize < 8: fsize = 8
44
 
45
+ count = 0
46
+ annot_img = None
47
+ annot_img_numbered = None
48
+
49
+ # 轉換成 YOLO 類別索引
50
+ if target_class == "膠囊":
51
+ class_idx = 0
52
+ else:
53
+ class_idx = 1
54
+
55
  for i in results:
56
+ # 過濾出該類別的框
57
+ selected_boxes = []
58
+ if hasattr(i.boxes, "cls"):
59
+ for b, c in zip(i.boxes.xyxy, i.boxes.cls):
60
+ if int(c) == class_idx:
61
+ selected_boxes.append(b)
62
+
63
+ count = len(selected_boxes)
64
+ img_pil = Image.fromarray(i.orig_img[..., ::-1])
65
  img_pil_numbered = img_pil.copy()
66
 
67
  draw = ImageDraw.Draw(img_pil)
68
  draw_num = ImageDraw.Draw(img_pil_numbered)
69
 
 
70
  try:
71
+ font = ImageFont.truetype("DejaVuSans.ttf", fsize)
72
  except:
73
+ try:
74
+ font = ImageFont.truetype("arial.ttf", fsize)
75
+ except:
76
+ font = ImageFont.load_default()
77
 
78
+ for idx, box in enumerate(selected_boxes, start=1):
79
  x1, y1, x2, y2 = box[:4].tolist()
80
  cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
81
 
82
+ # 第一張:只有圓點
83
  draw.ellipse((cx - r, cy - r, cx + r, cy + r), fill=dot_rgb, outline=None)
84
 
85
+ # 第二張:圓點 + 編號
86
  draw_num.ellipse((cx - r, cy - r, cx + r, cy + r), fill=dot_rgb, outline=None)
87
 
88
+ # 最後一顆改紅字
89
  num_color = (255, 0, 0) if idx == count else text_rgb
 
 
90
  draw_num.text((cx + r + 4, cy - r), str(idx), fill=num_color, font=font)
91
 
92
  annot_img = np.array(img_pil)
 
94
 
95
  return str(count), annot_img, annot_img_numbered
96
 
97
+
98
  # 自定義 CSS 樣式(背景改為白色、簡潔風格)
99
  custom_css = """
100
  .gradio-container {
 
229
  label="請選擇或拖放圖片",
230
  elem_classes=["image-upload"]
231
  )
232
+ target_class = gr.Radio(
233
+ choices=["膠囊", "錠劑"],
234
+ value="膠囊",
235
+ label="要計算的類別",
236
+ info="選擇只計算錠劑或膠囊的總數"
237
+ )
238
+
239
  gr.HTML('<div style="text-align: center; margin-top: 15px;">')
240
  button = gr.Button(
241
  "開始計算",
 
305
 
306
  button.click(
307
  fn=predict,
308
+ inputs=[img, target_class, dot_color, text_color_choice, dot_size, font_size],
309
  outputs=[data_output, img_output, img_output_numbered]
310
  )
311