Skip to content

Commit c1467f4

Browse files
estelleaflAflaloAflaloAflalo
authored andcommitted
[ldm3d] Update code to be functional with the new checkpoints (huggingface#3875)
* fixed typo * updated doc to be consistent in naming * make style/quality * preprocessing for 4 channels and not 6 * make style * test for 4c * make style/quality * fixed test on cpu --------- Co-authored-by: Aflalo <[email protected]> Co-authored-by: Aflalo <[email protected]> Co-authored-by: Aflalo <[email protected]>
1 parent c6c25b2 commit c1467f4

File tree

2 files changed

+41
-12
lines changed

2 files changed

+41
-12
lines changed

src/diffusers/image_processor.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -312,12 +312,17 @@ def numpy_to_depth(self, images):
312312
"""
313313
if images.ndim == 3:
314314
images = images[None, ...]
315-
images = (images * 255).round().astype("uint8")
316-
if images.shape[-1] == 1:
317-
# special case for grayscale (single channel) images
318-
raise Exception("Not supported")
315+
images_depth = images[:, :, :, 3:]
316+
if images.shape[-1] == 6:
317+
images_depth = (images_depth * 255).round().astype("uint8")
318+
pil_images = [
319+
Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
320+
]
321+
elif images.shape[-1] == 4:
322+
images_depth = (images_depth * 65535.0).astype(np.uint16)
323+
pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
319324
else:
320-
pil_images = [Image.fromarray(self.rgblike_to_depthmap(image[:, :, 3:]), mode="I;16") for image in images]
325+
raise Exception("Not supported")
321326

322327
return pil_images
323328

@@ -349,7 +354,11 @@ def postprocess(
349354
image = self.pt_to_numpy(image)
350355

351356
if output_type == "np":
352-
return image[:, :, :, :3], np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
357+
if image.shape[-1] == 6:
358+
image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
359+
else:
360+
image_depth = image[:, :, :, 3:]
361+
return image[:, :, :, :3], image_depth
353362

354363
if output_type == "pil":
355364
return self.numpy_to_pil(image), self.numpy_to_depth(image)

tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ def test_stable_diffusion_ddim(self):
130130
assert depth.shape == (1, 64, 64)
131131

132132
expected_slice_rgb = np.array(
133-
[0.37301102, 0.7023895, 0.7418312, 0.5163375, 0.5825485, 0.60929704, 0.4188174, 0.48407027, 0.46555096]
133+
[0.37338176, 0.70247, 0.74203193, 0.51643604, 0.58256793, 0.60932136, 0.4181095, 0.48355877, 0.46535262]
134134
)
135-
expected_slice_depth = np.array([103.4673, 85.81202, 87.84926])
135+
expected_slice_depth = np.array([103.46727, 85.812004, 87.849236])
136136

137137
assert np.abs(image_slice_rgb.flatten() - expected_slice_rgb).max() < 1e-2
138138
assert np.abs(image_slice_depth.flatten() - expected_slice_depth).max() < 1e-2
@@ -280,10 +280,30 @@ def test_ldm3d(self):
280280
output = ldm3d_pipe(**inputs)
281281
rgb, depth = output.rgb, output.depth
282282

283-
expected_rgb_mean = 0.54461557
284-
expected_rgb_std = 0.2806707
285-
expected_depth_mean = 143.64595
286-
expected_depth_std = 83.491776
283+
expected_rgb_mean = 0.495586
284+
expected_rgb_std = 0.33795515
285+
expected_depth_mean = 112.48518
286+
expected_depth_std = 98.489746
287+
assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3
288+
assert np.abs(expected_rgb_std - rgb.std()) < 1e-3
289+
assert np.abs(expected_depth_mean - depth.mean()) < 1e-3
290+
assert np.abs(expected_depth_std - depth.std()) < 1e-3
291+
292+
def test_ldm3d_v2(self):
293+
ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c").to(torch_device)
294+
ldm3d_pipe.set_progress_bar_config(disable=None)
295+
296+
inputs = self.get_inputs(torch_device)
297+
output = ldm3d_pipe(**inputs)
298+
rgb, depth = output.rgb, output.depth
299+
300+
expected_rgb_mean = 0.4194127
301+
expected_rgb_std = 0.35375586
302+
expected_depth_mean = 0.5638502
303+
expected_depth_std = 0.34686103
304+
305+
assert rgb.shape == (1, 512, 512, 3)
306+
assert depth.shape == (1, 512, 512, 1)
287307
assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3
288308
assert np.abs(expected_rgb_std - rgb.std()) < 1e-3
289309
assert np.abs(expected_depth_mean - depth.mean()) < 1e-3

0 commit comments

Comments
 (0)