| | import { |
| | SamModel, |
| | AutoProcessor, |
| | RawImage, |
| | Tensor, |
| | } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.3.3"; |
| |
|
| | |
// Resolve every element the demo interacts with once, at start-up.
const byId = (id) => document.getElementById(id);

const statusLabel = byId("status");
const fileUpload = byId("upload");
const imageContainer = byId("container");
const example = byId("example");
const uploadButton = byId("upload-button");
const resetButton = byId("reset-image");
const clearButton = byId("clear-points");
const cutButton = byId("cut-mask");
const starIcon = byId("star-icon");
const crossIcon = byId("cross-icon");
const maskCanvas = byId("mask-output");
const maskContext = maskCanvas.getContext("2d");
| |
|
/** Image that is encoded when the user clicks the "example" link. */
const EXAMPLE_URL =
  "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg";

// ---- Mutable demo state --------------------------------------------------

// Re-entrancy guards for the (slow) encoder and decoder calls.
let isEncoding = false;
let isDecoding = false;
// Set when decode() is requested while a decode is still running.
let decodePending = false;
// Prompt points ({ position: [x, y] as container fractions, label }) to decode.
let lastPoints = null;
// True after the first click; disables hover previews until points are cleared.
let isMultiMaskMode = false;
// The loaded image and the processor/encoder outputs derived from it.
let imageInput = null;
let imageProcessed = null;
let imageEmbeddings = null;
| |
|
/**
 * Runs the SAM mask decoder for the current prompt points (`lastPoints`)
 * against the current image embeddings and draws the resulting mask.
 *
 * Requests are coalesced: while a decode is running, further calls only set
 * `decodePending`, and a single follow-up decode runs when the current one
 * finishes, using whatever `lastPoints` holds at that moment.
 *
 * Callers fire this without awaiting it, so failures are reported in the
 * status label (and logged) instead of escaping as unhandled rejections, and
 * the decoder is always unlocked again.
 */
async function decode() {
  if (isDecoding) {
    decodePending = true;
    return;
  }

  // Nothing to decode: no image yet, image was reset, or points were cleared
  // while this call was queued.
  if (
    !imageEmbeddings ||
    !imageProcessed ||
    !lastPoints ||
    lastPoints.length === 0
  ) {
    return;
  }

  isDecoding = true;

  // Snapshot the inputs so a reset / new image during the awaits is detectable.
  const embeddings = imageEmbeddings;
  const processed = imageProcessed;
  const points = lastPoints;

  try {
    // Points are stored as fractions of the container; scale them to the
    // processor's resized input ([height, width], as the indexing implies).
    const [height, width] = processed.reshaped_input_sizes[0];
    const coords = points.flatMap(({ position: [x, y] }) => [
      x * width,
      y * height,
    ]);
    const labels = points.map(({ label }) => BigInt(label));

    const numPoints = points.length;
    const input_points = new Tensor("float32", coords, [1, 1, numPoints, 2]);
    const input_labels = new Tensor("int64", labels, [1, 1, numPoints]);

    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });

    const masks = await processor.post_process_masks(
      pred_masks,
      processed.original_sizes,
      processed.reshaped_input_sizes,
    );

    // Skip drawing if the image was reset/replaced or the points were cleared
    // while we were waiting; the result no longer matches what is on screen.
    if (embeddings === imageEmbeddings && lastPoints !== null) {
      updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
    }
  } catch (err) {
    console.error(err);
    statusLabel.textContent = "Failed to decode mask";
  } finally {
    isDecoding = false;
  }

  // Serve the most recent request that arrived while we were busy.
  if (decodePending) {
    decodePending = false;
    decode();
  }
}
| |
|
/**
 * Paints the highest-scoring mask onto the overlay canvas and shows its score.
 *
 * `mask.data` holds one value per mask per pixel, interleaved pixel by pixel:
 * the value of mask `k` at pixel `p` is `mask.data[numMasks * p + k]`, and a
 * value of 1 marks the pixel as part of that mask.
 *
 * @param {{width: number, height: number, data: ArrayLike<number>}} mask
 * @param {ArrayLike<number>} scores - One score per mask; the largest is drawn.
 */
function updateMaskOverlay(mask, scores) {
  // Resize the canvas only when needed (resizing also clears it).
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }

  // Fresh, fully transparent RGBA buffer.
  const imageData = maskContext.createImageData(
    maskCanvas.width,
    maskCanvas.height,
  );

  // Pick the mask with the highest score.
  const numMasks = scores.length;
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // Iterate over pixels, not bytes: the RGBA buffer has 4 bytes per pixel.
  const pixelData = imageData.data;
  const numPixels = pixelData.length / 4;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      const offset = 4 * i;
      pixelData[offset] = 0; // red
      pixelData[offset + 1] = 114; // green
      pixelData[offset + 2] = 189; // blue
      pixelData[offset + 3] = 255; // alpha (opaque)
    }
  }

  maskContext.putImageData(imageData, 0, 0);
}
| |
|
/**
 * Discards all prompt points, leaves multi-point mode, removes every `.icon`
 * element (presumably the point markers), disables the cut button and wipes
 * the mask overlay.
 */
function clearPointsAndMask() {
  isMultiMaskMode = false;
  lastPoints = null;
  // A queued decode would otherwise run against the points just discarded.
  decodePending = false;

  document.querySelectorAll(".icon").forEach((e) => e.remove());

  // No selection any more, so there is nothing to cut out.
  cutButton.disabled = true;

  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
}
clearButton.addEventListener("click", clearPointsAndMask);
| |
|
// Reset: forget the current image and return the UI to its upload state.
resetButton.addEventListener("click", () => {
  // Drop the image and everything derived from it.
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;
  isDecoding = false;
  // Make sure a queued decode does not run against the discarded image.
  decodePending = false;

  clearPointsAndMask();

  cutButton.disabled = true;
  imageContainer.style.backgroundImage = "none";
  uploadButton.style.display = "flex";
  // Clear the file input so selecting the same file again still fires "change".
  fileUpload.value = "";
  statusLabel.textContent = "Ready";
});
| |
|
/**
 * Loads the image at `url`, displays it, and computes the image embeddings
 * that later point prompts are decoded against.
 *
 * Ignored while another encode is running. On failure the status label
 * reports the error, any half-derived state is cleared, and the encoder is
 * unlocked so the user can try again.
 *
 * @param {string} url - Image URL or data URL.
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";

  try {
    const input = await RawImage.fromURL(url);
    imageInput = input;

    // Show the new image immediately. Drop the old encoder outputs so nothing
    // decodes the previous image's embeddings against the new one.
    imageContainer.style.backgroundImage = `url(${url})`;
    uploadButton.style.display = "none";
    cutButton.disabled = true;
    imageProcessed = null;
    imageEmbeddings = null;

    const processed = await processor(input);
    const embeddings = await model.get_image_embeddings(processed);

    // Publish both together so decode() never mixes sizes and embeddings.
    imageProcessed = processed;
    imageEmbeddings = embeddings;

    statusLabel.textContent = "Embedding extracted!";
  } catch (err) {
    console.error(err);
    imageProcessed = null;
    imageEmbeddings = null;
    statusLabel.textContent = "Failed to extract image embedding";
  } finally {
    isEncoding = false;
  }
}
| |
|
| | |
// When a file is chosen, read it into a data URL and hand that to the encoder.
fileUpload.addEventListener("change", (event) => {
  const file = event.target.files[0];
  if (!file) {
    return;
  }

  const reader = new FileReader();
  reader.onload = (loadEvent) => encode(loadEvent.target.result);
  reader.readAsDataURL(file);
});
| |
|
// Encode the bundled example image instead of following the link.
example.addEventListener("click", (event) => {
  event.preventDefault();
  encode(EXAMPLE_URL);
});
| |
|
| | |
// Left/right clicks add prompt points (label 1 / label 0, see getPoint).
// The first click switches from hover previews to multi-point mode.
imageContainer.addEventListener("mousedown", (e) => {
  const isLeftOrRight = e.button === 0 || e.button === 2;
  if (!isLeftOrRight || !imageEmbeddings) return;

  if (!isMultiMaskMode) {
    // Start a fresh point list (dropping the hover point) and allow cutting.
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }

  const point = getPoint(e);
  lastPoints.push(point);

  // Place a marker (star for label 1, cross otherwise) at the click position.
  const template = point.label === 1 ? starIcon : crossIcon;
  const marker = template.cloneNode();
  const [x, y] = point.position;
  marker.style.left = `${x * 100}%`;
  marker.style.top = `${y * 100}%`;
  imageContainer.appendChild(marker);

  decode();
});
| |
|
| | |
/**
 * Limits `x` to the range [min, max] (defaults: [0, 1]).
 */
function clamp(x, min = 0, max = 1) {
  return Math.max(min, Math.min(max, x));
}

/**
 * Converts a mouse event into a prompt point relative to the image container.
 *
 * @param {MouseEvent} e
 * @returns {{position: [number, number], label: number}} `position` is the
 *   pointer location as fractions of the container's width and height, clamped
 *   to [0, 1]; `label` is 0 for the right mouse button and 1 otherwise.
 */
function getPoint(e) {
  const { left, top, width, height } = imageContainer.getBoundingClientRect();
  const x = clamp((e.clientX - left) / width);
  const y = clamp((e.clientY - top) / height);
  const label = e.button === 2 ? 0 : 1;
  return { position: [x, y], label };
}
| |
|
| | |
// Suppress the browser context menu so right clicks can place points.
imageContainer.addEventListener("contextmenu", (e) => e.preventDefault());

// Until the first click, hovering previews the mask under the cursor.
imageContainer.addEventListener("mousemove", (e) => {
  const hoverPreviewActive = imageEmbeddings && !isMultiMaskMode;
  if (!hoverPreviewActive) return;

  lastPoints = [getPoint(e)];
  decode();
});
| |
|
| | |
// Export the pixels covered by the current mask as a transparent PNG download.
cutButton.addEventListener("click", async () => {
  const [w, h] = [maskCanvas.width, maskCanvas.height];

  // The overlay must line up pixel-for-pixel with the source image; bail out
  // if no mask for this image has been drawn yet (e.g. cut clicked mid-decode).
  if (!imageInput || imageInput.width !== w || imageInput.height !== h) {
    return;
  }

  const maskImageData = maskContext.getImageData(0, 0, w, h);

  const cutCanvas = new OffscreenCanvas(w, h);
  const cutContext = cutCanvas.getContext("2d");

  // Replace the overlay colour of every masked (non-transparent) pixel with the
  // source pixel's RGB; unmasked pixels stay fully transparent.
  const maskPixelData = maskImageData.data;
  const imagePixelData = imageInput.data;
  // Use the image's actual channel count (e.g. 3 for RGB, 4 for RGBA).
  const channels = imageInput.channels;
  for (let i = 0; i < w * h; ++i) {
    const sourceOffset = channels * i;
    const targetOffset = 4 * i;

    if (maskPixelData[targetOffset + 3] > 0) {
      for (let j = 0; j < 3; ++j) {
        maskPixelData[targetOffset + j] = imagePixelData[sourceOffset + j];
      }
    }
  }
  cutContext.putImageData(maskImageData, 0, 0);

  // Trigger the download through a temporary (never attached) link.
  const url = URL.createObjectURL(await cutCanvas.convertToBlob());
  const link = document.createElement("a");
  link.download = "image.png";
  link.href = url;
  link.click();
  // Release the blob once the download has been handed off; the short delay
  // avoids racing the navigation in browsers that start it asynchronously.
  setTimeout(() => URL.revokeObjectURL(url), 1000);
});
| |
|
// Load the SlimSAM model and its processor, then enable the image inputs.
const model_id = "Xenova/slimsam-77-uniform";
statusLabel.textContent = "Loading model...";

let model;
let processor;
try {
  // The two downloads are independent, so fetch them in parallel.
  [model, processor] = await Promise.all([
    SamModel.from_pretrained(model_id, {
      dtype: "fp16",
      device: "webgpu",
    }),
    AutoProcessor.from_pretrained(model_id),
  ]);
} catch (err) {
  // The lines below that enable the inputs are skipped; tell the user why.
  statusLabel.textContent = "Failed to load model";
  throw err;
}
statusLabel.textContent = "Ready";

// Model is ready: let the user upload an image or pick the example.
fileUpload.disabled = false;
uploadButton.style.opacity = 1;
example.style.pointerEvents = "auto";