Skip to content

Commit cdad688

Browse files
estelleaflAflaloAflaloAflalo
authored
[ldm3d] Update code to be functional with the new checkpoints (huggingface#3875)
* fixed typo * updated doc to be consistent in naming * make style/quality * preprocessing for 4 channels and not 6 * make style * test for 4c * make style/quality * fixed test on cpu --------- Co-authored-by: Aflalo <[email protected]> Co-authored-by: Aflalo <[email protected]> Co-authored-by: Aflalo <[email protected]>
1 parent fb8e6a5 commit cdad688

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

image_processor.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -312,12 +312,17 @@ def numpy_to_depth(self, images):
312312
"""
313313
if images.ndim == 3:
314314
images = images[None, ...]
315-
images = (images * 255).round().astype("uint8")
316-
if images.shape[-1] == 1:
317-
# special case for grayscale (single channel) images
318-
raise Exception("Not supported")
315+
images_depth = images[:, :, :, 3:]
316+
if images.shape[-1] == 6:
317+
images_depth = (images_depth * 255).round().astype("uint8")
318+
pil_images = [
319+
Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
320+
]
321+
elif images.shape[-1] == 4:
322+
images_depth = (images_depth * 65535.0).astype(np.uint16)
323+
pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
319324
else:
320-
pil_images = [Image.fromarray(self.rgblike_to_depthmap(image[:, :, 3:]), mode="I;16") for image in images]
325+
raise Exception("Not supported")
321326

322327
return pil_images
323328

@@ -349,7 +354,11 @@ def postprocess(
349354
image = self.pt_to_numpy(image)
350355

351356
if output_type == "np":
352-
return image[:, :, :, :3], np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
357+
if image.shape[-1] == 6:
358+
image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
359+
else:
360+
image_depth = image[:, :, :, 3:]
361+
return image[:, :, :, :3], image_depth
353362

354363
if output_type == "pil":
355364
return self.numpy_to_pil(image), self.numpy_to_depth(image)

0 commit comments

Comments
 (0)