ayushpfullstack committed
Commit c494a9e · verified · 1 parent: 893006f

Update main.py

Files changed (1): main.py (+51 -35)
main.py CHANGED
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 import numpy as np
 import cv2
 from PIL import Image
@@ -11,8 +12,8 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from contextlib import asynccontextmanager
 
-# Diffusers & Transformers Libraries
-from transformers import pipeline
+# Diffusers & Transformers Libraries - UPDATED IMPORTS
+from transformers import DPTForSemanticSegmentation, DPTImageProcessor, DPTForDepthEstimation
 from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
 
 # --- API Data Models ---
@@ -32,11 +33,17 @@ async def lifespan(app: FastAPI):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-    models['segmentation_pipeline'] = pipeline("image-segmentation", model="Intel/dpt-large-ade", device=device)
-    models['depth_estimator'] = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas", device=device)
+    # --- UPDATED: Load processors and models separately ---
+    # Segmentation model
+    models['seg_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
+    models['seg_model'] = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(device)
+
+    # Depth estimation model
+    models['depth_processor'] = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
+    models['depth_model'] = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
 
+    # ControlNet and Inpainting Pipeline
     controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch_dtype)
-
     models['inpainting_pipe'] = StableDiffusionControlNetInpaintPipeline.from_pretrained(
         "runwayml/stable-diffusion-v1-5",
         controlnet=controlnet,
@@ -55,40 +62,49 @@ app = FastAPI(lifespan=lifespan)
 
 # --- Helper Functions (Core Logic) ---
 def create_precise_mask(image_pil: Image.Image) -> Image.Image:
-    # REVERTED CHANGE: Pass the PIL image directly to the pipeline
-    segments = models['segmentation_pipeline'](image_pil)
-    W, H = image_pil.size
-    inclusion_mask_np = np.zeros((H, W), dtype=np.uint8)
-    exclusion_mask_np = np.zeros((H, W), dtype=np.uint8)
-    inclusion_labels = {"wall", "floor", "ceiling"}
-    base_exclusion_labels = {"door", "window", "windowpane", "window blind"}
-    insert_labels = {"painting", "picture", "shelf", "showcase", "cabinet", "mirror", "television", "radiator"}
-    walls, inserts = [], []
-    for segment in segments:
-        label, mask = segment['label'], np.array(segment['mask'])
-        if label in inclusion_labels:
-            inclusion_mask_np = np.maximum(inclusion_mask_np, mask)
-            if label == "wall": walls.append(mask)
-        if label in base_exclusion_labels:
-            exclusion_mask_np = np.maximum(exclusion_mask_np, mask)
-        if label in insert_labels:
-            inserts.append(mask)
-    for insert_mask in inserts:
-        for wall_mask in walls:
-            if np.all((wall_mask >= insert_mask)[insert_mask > 0]):
-                exclusion_mask_np = np.maximum(exclusion_mask_np, insert_mask)
-                break
-    raw_mask_np = np.copy(inclusion_mask_np); raw_mask_np[exclusion_mask_np > 0] = 0
+    # --- UPDATED: Manual processing and inference ---
+    processor = models['seg_processor']
+    model = models['seg_model']
+
+    inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    logits = outputs.logits
+    # ADE20k has 150 classes
+    upsampled_logits = F.interpolate(logits, size=image_pil.size[::-1], mode="bilinear", align_corners=False)
+    pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
+
+    # Use a simplified mapping for room structure labels
+    # Wall=2, Floor=3, Ceiling=5 (based on common ADE20k indices)
+    inclusion_indices = {2, 3, 5}
+    # Door=14, Window=17
+    exclusion_indices = {14, 17}
+
+    inclusion_mask_np = np.isin(pred_seg, list(inclusion_indices)).astype(np.uint8) * 255
+    exclusion_mask_np = np.isin(pred_seg, list(exclusion_indices)).astype(np.uint8) * 255
+
+    raw_mask_np = np.copy(inclusion_mask_np)
+    raw_mask_np[exclusion_mask_np > 0] = 0
     mask_filled_np = cv2.morphologyEx(raw_mask_np, cv2.MORPH_CLOSE, np.ones((10,10),np.uint8))
     return Image.fromarray(mask_filled_np)
 
 def generate_depth_map(image_pil: Image.Image) -> Image.Image:
-    # REVERTED CHANGE: Pass the PIL image directly to the pipeline
-    predicted_depth = models['depth_estimator'](image_pil)['predicted_depth']
-    depth_map_np = predicted_depth.cpu().numpy()
-    depth_map_np = (depth_map_np - depth_map_np.min()) / (depth_map_np.max() - depth_map_np.min()) * 255.0
-    depth_map_np = depth_map_np.astype(np.uint8)
-    return Image.fromarray(np.concatenate([depth_map_np[..., None]] * 3, axis=-1))
+    # --- UPDATED: Manual processing and inference ---
+    processor = models['depth_processor']
+    model = models['depth_model']
+
+    inputs = processor(images=image_pil, return_tensors="pt").to(model.device)
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    predicted_depth = outputs.predicted_depth
+    prediction = F.interpolate(predicted_depth.unsqueeze(1), size=image_pil.size[::-1], mode="bicubic", align_corners=False)
+
+    depth_map = prediction.squeeze().cpu().numpy()
+    depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255.0
+    depth_map = depth_map.astype(np.uint8)
+    return Image.fromarray(np.concatenate([depth_map[..., None]] * 3, axis=-1))
 
 # --- API Endpoints ---
 @app.get("/")
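Note on the new create_precise_mask: the class indices are hardcoded and only described as "based on common ADE20k indices" in the commit itself. The checkpoint's config.id2label is the authoritative mapping, so a safer variant looks the indices up by label name. A minimal sketch, not part of this commit; the indices_for helper is hypothetical:

# Sketch: derive ADE20K class indices from the checkpoint's own id2label
# mapping instead of hardcoding them. The helper name indices_for is
# hypothetical and not part of this commit.
from transformers import DPTForSemanticSegmentation

seg_model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")
label2id = {label: idx for idx, label in seg_model.config.id2label.items()}

def indices_for(names):
    # Keep only labels that actually exist in this checkpoint's mapping.
    return {label2id[n] for n in names if n in label2id}

inclusion_indices = indices_for(["wall", "floor", "ceiling"])  # room structure
exclusion_indices = indices_for(["door", "windowpane"])        # openings to mask out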
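For reference, the two helpers feed the ControlNet inpainting pipeline loaded in lifespan roughly as follows. This is a hypothetical usage sketch rather than code from this commit; the input path, prompt, and step count are placeholders:

from PIL import Image

image = Image.open("room.jpg").convert("RGB")  # placeholder input path
mask = create_precise_mask(image)              # white where inpainting is allowed
control = generate_depth_map(image)            # 3-channel depth image for ControlNet

result = models['inpainting_pipe'](
    prompt="placeholder style prompt",
    image=image,
    mask_image=mask,
    control_image=control,
    num_inference_steps=30,
).images[0]
result.save("restyled_room.png")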