Skip to content

Commit 8143111

Browse files
committed
feat: improved line detection
1 parent 3ec8239 commit 8143111

File tree

2 files changed

+109
-22
lines changed

2 files changed

+109
-22
lines changed

decimer_segmentation/complete_structure.py

Lines changed: 106 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import cv2
2+
import math
23
import numpy as np
34
import matplotlib.pyplot as plt
45
import itertools
56
from skimage.color import rgb2gray
67
from skimage.filters import threshold_otsu
7-
from skimage.morphology import binary_erosion
8+
from skimage.morphology import binary_erosion, binary_dilation
89
from typing import List, Tuple
910
from scipy.ndimage import label
1011

@@ -98,8 +99,7 @@ def detect_horizontal_and_vertical_lines(
9899
) -> np.ndarray:
99100
"""
100101
This function takes an image and returns a binary mask that labels the pixels that
101-
are part of long horizontal or vertical lines. [Definition of long: 1/5 of the
102-
width/height of the image].
102+
are part of long horizontal or vertical lines.
103103
104104
Args:
105105
image (np.ndarray): binarised image (np.array; type bool) as it is returned by
@@ -130,11 +130,86 @@ def detect_horizontal_and_vertical_lines(
130130
return horizontal_mask + vertical_mask
131131

132132

133+
def find_equidistant_points(
    x1: int,
    y1: int,
    x2: int,
    y2: int,
    num_points: int = 5
) -> List[Tuple[float, float]]:
    """
    Returns evenly spaced points on the line segment from (x1, y1) to (x2, y2).

    The segment is split into num_points intervals of equal length. The returned
    list therefore holds num_points + 1 points: the two end points plus
    num_points - 1 interior points. Coordinates are obtained by linear
    interpolation, so they are floats even when the inputs are integers.

    Args:
        x1 (int): x coordinate of first point
        y1 (int): y coordinate of first point
        x2 (int): x coordinate of second point
        y2 (int): y coordinate of second point
        num_points (int, optional): Number of equal intervals the segment is
            split into (must be >= 1). Defaults to 5.

    Returns:
        List[Tuple[float, float]]: (x, y) points on the given line, ordered from
            the first to the second end point (both included)

    Raises:
        ValueError: If num_points is smaller than 1
    """
    # Guard against the division by zero below (and against the silently empty
    # result a negative value would otherwise produce).
    if num_points < 1:
        raise ValueError(f"num_points must be >= 1, got {num_points}")
    points = []
    for i in range(num_points + 1):
        # t runs from 0 (first end point) to 1 (second end point)
        t = i / num_points
        x = x1 * (1 - t) + x2 * t
        y = y1 * (1 - t) + y2 * t
        points.append((x, y))
    return points
160+
161+
162+
def detect_lines(
    image: np.ndarray,
    max_depiction_size: Tuple[int, int],
    segmentation_mask: np.ndarray
) -> np.ndarray:
    """
    This function runs a probabilistic Hough transform on the given binarised image
    and returns a mask in which every detected line segment is drawn, except for
    segments that pass through a chemical structure depiction (e.g. to exclude
    lines that belong to tables).

    Args:
        image (np.ndarray): binarised image (np.array; type bool); it is inverted
            and scaled to 0/255 before the Hough transform is applied
        max_depiction_size (Tuple[int, int]): height, width; a quarter of the
            smaller value is used as the minimal length of a detected line
        segmentation_mask (np.ndarray): Indicates whether or not a pixel is part of a
            chemical structure depiction (shape: (height, width))
    Returns:
        np.ndarray: Exclusion mask (uint8, same shape as image) where pixels on
            retained line segments are set to 255 and all others to 0
    """
    # The Hough transform operates on an 8-bit image
    hough_input = (~image * 255).astype("uint8")
    min_line_length = int(min(max_depiction_size) / 4)
    detected_segments = cv2.HoughLinesP(
        hough_input,
        1,
        np.pi / 180,
        threshold=5,
        minLineLength=min_line_length,
        maxLineGap=10,
    )
    exclusion_mask = np.zeros_like(hough_input)
    # HoughLinesP yields None when nothing was found
    if detected_segments is None:
        return exclusion_mask
    for segment in detected_segments:
        x1, y1, x2, y2 = segment[0]
        samples = find_equidistant_points(x1, y1, x2, y2, num_points=7)
        # Only the interior sample points are tested; a segment with any of them
        # inside a structure depiction is not added to the exclusion mask
        touches_structure = any(
            segmentation_mask[int(y), int(x)] for x, y in samples[1:-1]
        )
        if not touches_structure:
            cv2.line(exclusion_mask, (x1, y1), (x2, y2), 255, 2)
    return exclusion_mask
207+
208+
133209
def expand_masks(
134210
image_array: np.array,
135211
seed_pixels: List[Tuple[int, int]],
136212
mask_array: np.array,
137-
exclusion_mask: np.array,
138213
) -> np.array:
139214
"""
140215
This function generates a mask array where the given masks have been
@@ -144,20 +219,15 @@ def expand_masks(
144219
image_array (np.array): array that represents an image (float values)
145220
seed_pixels (List[Tuple[int, int]]): [(x, y), ...]
146221
mask_array (np.array): MRCNN output; shape: (y, x, mask_index)
147-
exclusion_mask (np.array]: indicates whether or not a pixel is excluded from
148-
expansion
149-
contour_expansion (bool, optional): Indicates whether or not to expand
150-
from contours. Defaults to False.
151222
152223
Returns:
153224
np.array: Expanded masks
154225
"""
155-
image_with_exclusion = np.invert(image_array) * np.invert(exclusion_mask)
156-
labeled_array, _ = label(image_with_exclusion)
226+
labeled_array, _ = label(image_array)
157227
mask_array = np.zeros_like(image_array)
158228
for seed_pixel in seed_pixels:
159229
x, y = seed_pixel
160-
if mask_array[y, x] or exclusion_mask[y, x]:
230+
if mask_array[y, x]:
161231
continue
162232
label_value = labeled_array[y, x]
163233
if label_value > 0:
@@ -176,7 +246,7 @@ def expansion_coordination(
176246
seed_pixels = get_seeds(image_array,
177247
mask_array,
178248
exclusion_mask)
179-
mask_array = expand_masks(image_array, seed_pixels, mask_array, exclusion_mask)
249+
mask_array = expand_masks(image_array, seed_pixels, mask_array)
180250
return mask_array
181251

182252

@@ -213,29 +283,45 @@ def complete_structure_mask(
213283
blur_factor = (
214284
int(image_array.shape[1] / 185) if image_array.shape[1] / 185 >= 2 else 2
215285
)
216-
if debug:
217-
plot_it(binarized_image_array)
218-
# Define kernel and apply
219286
kernel = np.ones((blur_factor, blur_factor))
220287
blurred_image_array = binary_erosion(binarized_image_array, footprint=kernel)
221288
if debug:
222289
plot_it(blurred_image_array)
290+
debug = True
291+
if debug:
292+
plot_it(binarized_image_array)
223293
# Slice mask array along third dimension into single masks
224294
split_mask_arrays = np.array(
225295
[mask_array[:, :, index] for index in range(mask_array.shape[2])]
226296
)
227-
exclusion_mask = detect_horizontal_and_vertical_lines(
297+
# Detect horizontal and vertical lines
298+
horizontal_vertical_lines = detect_horizontal_and_vertical_lines(
228299
blurred_image_array, max_depiction_size
229300
)
230-
# Run expansion the expansion
231-
image_repeat = itertools.repeat(blurred_image_array, mask_array.shape[2])
232-
exclusion_mask_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2])
301+
302+
hough_lines = detect_lines(
303+
binarized_image_array,
304+
max_depiction_size,
305+
segmentation_mask=np.any(mask_array, axis=2).astype(np.bool)
306+
)
307+
hough_lines = binary_dilation(hough_lines, footprint=kernel)
308+
exclusion_mask = horizontal_vertical_lines + hough_lines
309+
if debug:
310+
plot_it(horizontal_vertical_lines)
311+
plot_it(hough_lines)
312+
plot_it(exclusion_mask)
313+
plot_it(np.invert(binarized_image_array) * np.invert(exclusion_mask))
314+
image_with_exclusion = np.invert(blurred_image_array) * np.invert(exclusion_mask)
315+
316+
# Run expansion
317+
image_repeat = itertools.repeat(image_with_exclusion, mask_array.shape[2])
318+
exclusion_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2])
233319
# Faster with map function
234320
expanded_split_mask_arrays = map(
235321
expansion_coordination,
236322
split_mask_arrays,
237323
image_repeat,
238-
exclusion_mask_repeat,
324+
exclusion_repeat,
239325
)
240326
# Stack mask arrays to give the desired output format
241327
mask_array = np.stack(expanded_split_mask_arrays, -1)

decimer_segmentation/decimer_segmentation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class InferenceConfig(moldetect.MolDetectConfig):
3636
# Run detection on one image at a time
3737
GPU_COUNT = 1
3838
IMAGES_PER_GPU = 1
39+
DETECTION_MIN_CONFIDENCE = 0.7
3940

4041

4142
def segment_chemical_structures_from_file(
@@ -142,8 +143,8 @@ def determine_depiction_size_with_buffer(
142143
width = bbox[3] - bbox[1]
143144
heights.append(height)
144145
widths.append(width)
145-
height = int(1.1 * np.max(heights))
146-
width = int(1.1 * np.max(widths))
146+
height = int(1.05 * np.max(heights))
147+
width = int(1.05 * np.max(widths))
147148
return height, width
148149

149150

0 commit comments

Comments
 (0)