Skip to content

Commit 2f2fe8e

Browse files
committed
feat: adaptive line kernel based on structure size
1 parent 70a5b2b commit 2f2fe8e

File tree

2 files changed

+35
-12
lines changed

2 files changed

+35
-12
lines changed

decimer_segmentation/complete_structure.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,10 @@ def get_neighbour_pixels(
366366
return neighbour_pixels
367367

368368

369-
def detect_horizontal_and_vertical_lines(image: np.ndarray) -> np.ndarray:
369+
def detect_horizontal_and_vertical_lines(
370+
image: np.ndarray,
371+
average_depiction_size: Tuple[int, int]
372+
) -> np.ndarray:
370373
"""
371374
This function takes an image and returns a binary mask that labels the pixels that
372375
are part of long horizontal or vertical lines. [Definition of long: 1/5 of the
@@ -382,19 +385,19 @@ def detect_horizontal_and_vertical_lines(image: np.ndarray) -> np.ndarray:
382385
"""
383386
binarised_im = ~image * 255
384387
binarised_im = binarised_im.astype("uint8")
388+
389+
structure_height, structure_width = average_depiction_size
385390

386-
horizontal_kernel_size = int(binarised_im.shape[1] / 7)
387391
horizontal_kernel = cv2.getStructuringElement(
388-
cv2.MORPH_RECT, (horizontal_kernel_size, 1)
392+
cv2.MORPH_RECT, (structure_width, 1)
389393
)
390394
horizontal_mask = cv2.morphologyEx(
391395
binarised_im, cv2.MORPH_OPEN, horizontal_kernel, iterations=2
392396
)
393397
horizontal_mask = horizontal_mask == 255
394398

395-
vertical_kernel_size = int(binarised_im.shape[0] / 7)
396399
vertical_kernel = cv2.getStructuringElement(
397-
cv2.MORPH_RECT, (1, vertical_kernel_size)
400+
cv2.MORPH_RECT, (1, structure_height)
398401
)
399402
vertical_mask = cv2.morphologyEx(
400403
binarised_im, cv2.MORPH_OPEN, vertical_kernel, iterations=2
@@ -472,13 +475,30 @@ def expansion_coordination(
472475

473476

474477
def complete_structure_mask(
475-
image_array: np.array, mask_array: np.array, debug=False
478+
image_array: np.array,
479+
mask_array: np.array,
480+
average_depiction_size: Tuple[int, int],
481+
debug=False
476482
) -> np.array:
477483
"""
478-
This funtion takes an image (array) and an array containing the masks (shape:
484+
This funtion takes an image (np.array) and an array containing the masks (shape:
479485
x,y,n where n is the amount of masks and x and y are the pixel coordinates).
486+
Additionally, it takes the average depiction size of the structures in the image
487+
which is used to define the kernel size for the vertical and horizontal line
488+
detection for the exclusion masks. The exclusion mask is used to exclude pixels
489+
from the mask expansion to avoid including whole tables.
480490
It detects objects on the contours of the mask and expands it until it frames the
481-
complete object in the image. It returns the expanded mask array"""
491+
complete object in the image. It returns the expanded mask array
492+
493+
Args:
494+
image_array (np.array): input image
495+
mask_array (np.array): shape: y, x, n where n is the amount of masks
496+
average_depiction_size (Tuple[int, int]): height, width
497+
debug (bool, optional): More verbose if True. Defaults to False.
498+
499+
Returns:
500+
np.array: expanded mask array
501+
"""
482502

483503
if mask_array.size != 0:
484504
# Binarization of input image
@@ -498,7 +518,8 @@ def complete_structure_mask(
498518
split_mask_arrays = np.array(
499519
[mask_array[:, :, index] for index in range(mask_array.shape[2])]
500520
)
501-
exclusion_mask = detect_horizontal_and_vertical_lines(blurred_image_array)
521+
exclusion_mask = detect_horizontal_and_vertical_lines(blurred_image_array,
522+
average_depiction_size)
502523
# Run expansion the expansion
503524
image_repeat = itertools.repeat(blurred_image_array, mask_array.shape[2])
504525
exclusion_mask_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2])

decimer_segmentation/decimer_segmentation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def segment_chemical_structures(
9898
if not expand:
9999
masks, bboxes, _ = get_mrcnn_results(image)
100100
else:
101-
average_height, average_width = determine_average_depiction_size(bboxes)
102101
masks = get_expanded_masks(image)
103102

104103
segments, bboxes = apply_masks(image, masks)
@@ -227,9 +226,12 @@ def get_expanded_masks(image: np.array) -> np.array:
227226
np.array: expanded masks (shape: (h, w, num_masks))
228227
"""
229228
# Structure detection with MRCNN
230-
masks, _, _ = get_mrcnn_results(image)
229+
masks, bboxes, _ = get_mrcnn_results(image)
230+
size = determine_average_depiction_size(bboxes)
231231
# Mask expansion
232-
expanded_masks = complete_structure_mask(image_array=image, mask_array=masks)
232+
expanded_masks = complete_structure_mask(image_array=image,
233+
mask_array=masks,
234+
average_depiction_size=size,)
233235
return expanded_masks
234236

235237

0 commit comments

Comments
 (0)