Segment Anything in a 3D scene (using THREE.js)

I've been doing plenty of R&D lately, and simply toying with various models (mostly computer vision and language models). I've been particularly fascinated by the capabilities of Segment Anything (SAM), and attempted to bring its magic into 3D.

SAM?

Segment Anything (SAM) is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, without the need for additional training.

Using a prompt such as a cursor point or a box (the paper also explores text prompts), it can segment an object within an image and produce several candidate masks for it.

/static/images/sam-three/masks2.jpg

SAM is a sharp tool focused on doing one thing (segmentation) extremely well. To keep the model's total payload down, it does not keep a natural-language record of the labels of the objects it identifies. Instead, it relies on the high interoperability of the masks it outputs, which can be composed relatively easily with other computer vision programs and workflows.

Segmentation in itself is, in my humble opinion, the holy grail of computer vision capabilities, and a cornerstone for most of the automation we can imagine in our wildest sci-fi dreams, as it effectively comes down to the ability to understand space. The more performant and efficient such models are, the closer we get to groundbreaking technological changes such as automated medical anomaly detection, self-driving cars, and Dragon Ball Z chi-reading devices.

The experiment

Most of the projects I've seen pairing SAM with 3D focus on generating a 3D mesh from a segmented 2D image. Having developed extensively in the 3D & XR space @ PHORIA, I thought it made sense to give segmentation inside a 3D scene a shot.

/static/images/sam-three/pass.png

The experiment works as follows:

  • 2D pixel coordinates (currently taken from the pointer) are sent to the processing API along with a 2D render of the scene
  • SAM creates masks of the object being pointed at
  • The mask is optimised with OpenCV (edge extraction) to reduce the number of points to render
  • The 2D mask is projected back into 3D space and rendered as a bounding box
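As a rough sketch of that last step (not the project's actual code), once the edge pixels have been projected to 3D points, an axis-aligned bounding box is just the per-axis minimum and maximum. The `bounding_box` name and the tuple-based point type are assumptions made for illustration; on the client, `THREE.Box3.setFromPoints` does the equivalent.

```python
from typing import Iterable, Tuple

Point3 = Tuple[float, float, float]


def bounding_box(points: Iterable[Point3]) -> Tuple[Point3, Point3]:
    """Return (min_corner, max_corner) of the axis-aligned box enclosing the points.

    Hypothetical helper for illustration only.
    """
    pts = list(points)
    if not pts:
        raise ValueError("need at least one point")
    # Split the points into three per-axis sequences.
    xs, ys, zs = zip(*pts)
    return (min(xs), min(ys), min(zs)), (max(xs), max(ys), max(zs))


# Three made-up world-space points.
box_min, box_max = bounding_box([(0, 1, 2), (3, -1, 5), (1, 0, 0)])
```

Only the extreme points matter here, which is part of why sending the mask outline instead of the full mask loses nothing for this rendering.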

Implementation tidbits

Once SAM has produced a mask from the pixel coordinate, NumPy and OpenCV make it relatively easy to extract the edges of the mask (which keeps the raycasting efficient, since only the outline has to be projected).

The implementation of the edge extraction looks like this:

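import cv2
import numpy as np
from PIL import Image

# Method excerpt; temp_folder() is a helper defined elsewhere in the project.
def extract_edges(self, image: np.ndarray) -> np.ndarray:
    # Round-trip the mask through a PNG; cv2.imread returns it as a 3-channel uint8 image
    im = Image.fromarray(image)
    im.save(temp_folder("mask.png"))
    img_data = cv2.imread(temp_folder("mask.png"))
    img_data = img_data > 128

    # Extract edges from mask: any pixel with a non-zero gradient sits on the outline
    img_data = np.asarray(img_data[:, :, 0], dtype=np.double)
    gx, gy = np.gradient(img_data)
    temp_edge = gy * gy + gx * gx
    temp_edge[temp_edge != 0.0] = 255.0
    temp_edge = np.asarray(temp_edge, dtype=np.uint8)
    cv2.imwrite(temp_folder("mask_edge.png"), temp_edge)

    # Transform mask edge to pixel coordinates
    img_data = cv2.imread(temp_folder("mask_edge.png"))
    img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2GRAY)
    img = img_data.astype(np.uint8)
    coord = cv2.findNonZero(img)  # (x, y) pairs, shape (N, 1, 2), or None if nothing is set
    if coord is None:
        return np.empty((0, 2), dtype=np.int32)

    # reshape (rather than squeeze) keeps the (N, 2) shape even when N == 1
    return coord.reshape(-1, 2)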
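To give a feel for how much the outline shrinks the payload, here is a small dependency-free sketch. Note that it swaps in a simpler technique than the `np.gradient` approach above: it keeps the mask pixels that have at least one background 4-neighbour. The `boundary_pixels` name and the toy 8x8 mask are made up for illustration.

```python
from typing import List, Tuple


def boundary_pixels(mask: List[List[bool]]) -> List[Tuple[int, int]]:
    """Return (x, y) coordinates of mask pixels that touch the background.

    Illustrative 4-neighbour test, not the gradient-based method used in the project.
    """
    height = len(mask)
    width = len(mask[0]) if height else 0

    def is_set(x: int, y: int) -> bool:
        # Pixels outside the image count as background.
        return 0 <= x < width and 0 <= y < height and mask[y][x]

    edges = []
    for y in range(height):
        for x in range(width):
            if not mask[y][x]:
                continue
            neighbours = ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1))
            if not all(is_set(nx, ny) for nx, ny in neighbours):
                edges.append((x, y))
    return edges


# Toy 8x8 image with a filled 6x6 square in the middle.
mask = [[1 <= x <= 6 and 1 <= y <= 6 for x in range(8)] for y in range(8)]
filled = sum(sum(row) for row in mask)
edges = boundary_pixels(mask)
```

The filled area grows with the square of the object's size while the outline grows linearly, so the saving in raycasts gets larger as the segmented object takes up more of the render.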

Once the pixel array is sent back to the client, we iterate over each pixel and project it back into 3D space (one tricky bit was making sure to include the screen's pixel ratio in the screen-space calculation).

private pixelToWorldPosition(pixelCoordinate: [number, number]): THREE.Vector3 {
  const pixelRatio = window.devicePixelRatio;

  // Convert render pixels to CSS pixels, then to normalised device coordinates (-1 to +1)
  const ndc = new THREE.Vector2(
    ((pixelCoordinate[0] / pixelRatio) / this.sceneEl.clientWidth) * 2 - 1,
    -((pixelCoordinate[1] / pixelRatio) / this.sceneEl.clientHeight) * 2 + 1
  );

  this.raycaster.setFromCamera(ndc, this.camera);

  const intersects = this.raycaster.intersectObjects(this.scene.children, true);

  // Intersections are sorted by distance, so the first hit is the closest surface
  if (intersects.length > 0) {
    return intersects[0].point;
  }

  // No hit: fall back to the origin
  return new THREE.Vector3();
}
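To see why the pixel ratio matters, here is the same screen-space arithmetic as a small Python sketch with made-up numbers. It assumes the render sent to the API was captured at device resolution (CSS size multiplied by the pixel ratio), which is what the division in the function above suggests; `pixel_to_ndc` is a hypothetical name.

```python
from typing import Tuple


def pixel_to_ndc(
    px: float,
    py: float,
    client_width: float,
    client_height: float,
    pixel_ratio: float,
) -> Tuple[float, float]:
    """Map a render pixel to normalised device coordinates in [-1, 1].

    Mirrors the arithmetic of pixelToWorldPosition; illustrative only.
    """
    x = (px / pixel_ratio) / client_width * 2 - 1
    # Screen y grows downwards while NDC y grows upwards, hence the sign flip.
    y = -((py / pixel_ratio) / client_height) * 2 + 1
    return x, y


# An 800x600 CSS-pixel canvas on a 2x display renders at 1600x1200.
centre = pixel_to_ndc(800, 600, 800, 600, pixel_ratio=2)
top_left = pixel_to_ndc(0, 0, 800, 600, pixel_ratio=2)

# Forgetting the ratio sends the centre of the render to the bottom-right corner.
wrong = pixel_to_ndc(800, 600, 800, 600, pixel_ratio=1)
```

With the ratio applied, the centre of the render maps to (0, 0) as expected. Without it, every ray is cast up to twice as far from the top-left corner as it should be, so the projected outline lands off the object.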

Results

/static/images/sam-three/demo.png

The results are promising but relatively limited:

  • It does a great job on relatively flat surfaces, and is surprisingly good on curved objects if the camera perspective is right!
  • Since we're using a single perspective, parts of the segmented object are bound to be missing or misrepresented. To improve accuracy, we would need to render the scene from several different angles and then combine the masks to get a more accurate set of points (whether in processing using a virtual 3D grid, or client side with an offscreen canvas).
  • As is common with computer vision models, it is quite slow on CPU (most of the time is spent on segmentation/identification)

That's a wrap!

It was a fun little experiment with lots of potential improvements. You can check out the source code here