Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions yad2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
parser.add_argument(
'-flcl',
'--fully_convolutional',
help='Model is fully convolutional so set input shape to (None, None, 3). '
'WARNING: This experimental option does not work properly for YOLO_v2.',
help='Model is fully convolutional so set input shape to (None, None, 3). ',
action='store_true')


Expand Down
6 changes: 3 additions & 3 deletions yad2k/models/keras_yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def yolo_head(feats, anchors, num_classes):
# conv_dims = K.variable([conv_width, conv_height])

# Dynamic implementation of conv dims for fully convolutional model.
conv_dims = K.shape(feats)[1:3] # assuming channels last
conv_dims = K.gather(K.shape(feats), [2, 1]) # assuming channels last; [width, height]
# In YOLO the height index is the inner most iteration.
conv_height_index = K.arange(0, stop=conv_dims[0])
conv_width_index = K.arange(0, stop=conv_dims[1])
Expand All @@ -108,11 +108,11 @@ def yolo_head(feats, anchors, num_classes):
K.expand_dims(conv_width_index, 0), [conv_dims[0], 1])
conv_width_index = K.flatten(K.transpose(conv_width_index))
conv_index = K.transpose(K.stack([conv_height_index, conv_width_index]))
conv_index = K.reshape(conv_index, [1, conv_dims[0], conv_dims[1], 1, 2])
conv_index = K.reshape(conv_index, [1, conv_dims[1], conv_dims[0], 1, 2])
conv_index = K.cast(conv_index, K.dtype(feats))

feats = K.reshape(
feats, [-1, conv_dims[0], conv_dims[1], num_anchors, num_classes + 5])
feats, [-1, conv_dims[1], conv_dims[0], num_anchors, num_classes + 5])
conv_dims = K.cast(K.reshape(conv_dims, [1, 1, 1, 1, 2]), K.dtype(feats))

# Static generation of conv_index:
Expand Down