@@ -1332,6 +1332,12 @@ def parse_image(image, input_layer_name, image_conf):
13321332 get_img_size (input_layer_name , image_conf .channels )
13331333
13341334
1335+ def parse_image3d (image , input_layer_name , image_conf ):
1336+ image_conf .channels = image .channels
1337+ image_conf .img_size , image_conf .img_size_y , image_conf .img_size_z = \
1338+ get_img3d_size (input_layer_name , image_conf .channels )
1339+
1340+
13351341def parse_norm (norm , input_layer_name , norm_conf ):
13361342 norm_conf .norm_type = norm .norm_type
13371343 config_assert (
@@ -2365,6 +2371,7 @@ def __init__(self,
23652371 name ,
23662372 inputs ,
23672373 bias = True ,
2374+ img3D = False ,
23682375 use_global_stats = True ,
23692376 moving_average_fraction = 0.9 ,
23702377 batch_norm_type = None ,
@@ -2410,15 +2417,33 @@ def __init__(self,
24102417
24112418 input_layer = self .get_input_layer (0 )
24122419 image_conf = self .config .inputs [0 ].image_conf
2413- parse_image (self .inputs [0 ].image , input_layer .name , image_conf )
2414-
2415- # Only pass the width and height of input to batch_norm layer
2416- # when either of it is non-zero.
2417- if input_layer .width != 0 or input_layer .height != 0 :
2418- self .set_cnn_layer (name , image_conf .img_size_y , image_conf .img_size ,
2419- image_conf .channels , False )
2420+ if img3D :
2421+ parse_image3d (self .inputs [0 ].image , input_layer .name , image_conf )
2422+ # Only pass the width and height of input to batch_norm layer
2423+ # when either of it is non-zero.
2424+ if input_layer .width != 0 or input_layer .height != 0 :
2425+ self .set_cnn_layer (
2426+ input_layer_name = name ,
2427+ depth = image_conf .img_size_z ,
2428+ height = image_conf .img_size_y ,
2429+ width = image_conf .img_size ,
2430+ channels = image_conf .channels ,
2431+ is_print = True )
2432+ else :
2433+ self .set_layer_size (input_layer .size )
24202434 else :
2421- self .set_layer_size (input_layer .size )
2435+ parse_image (self .inputs [0 ].image , input_layer .name , image_conf )
2436+ # Only pass the width and height of input to batch_norm layer
2437+ # when either of it is non-zero.
2438+ if input_layer .width != 0 or input_layer .height != 0 :
2439+ self .set_cnn_layer (
2440+ input_layer_name = name ,
2441+ height = image_conf .img_size_y ,
2442+ width = image_conf .img_size ,
2443+ channels = image_conf .channels ,
2444+ is_print = True )
2445+ else :
2446+ self .set_layer_size (input_layer .size )
24222447
24232448 psize = self .calc_parameter_size (image_conf )
24242449 dims = [1 , psize ]
@@ -2433,6 +2458,28 @@ def __init__(self,
24332458
24342459 self .create_bias_parameter (bias , psize )
24352460
2461+ def set_cnn_layer (self ,
2462+ input_layer_name ,
2463+ depth = None ,
2464+ height = None ,
2465+ width = None ,
2466+ channels = None ,
2467+ is_print = True ):
2468+ depthIsNone = False
2469+ if depth is None :
2470+ depth = 1
2471+ depthIsNone = True
2472+ size = depth * height * width * channels
2473+ self .set_layer_size (size )
2474+ self .set_layer_height_width (height , width )
2475+ self .set_layer_depth (depth )
2476+ if is_print and depthIsNone :
2477+ print ("output for %s: c = %d, h = %d, w = %d, size = %d" %
2478+ (input_layer_name , channels , height , width , size ))
2479+ elif is_print :
2480+ print ("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" %
2481+ (input_layer_name , channels , depth , height , width , size ))
2482+
24362483 def calc_parameter_size (self , image_conf ):
24372484 return image_conf .channels
24382485
@@ -2694,9 +2741,20 @@ def __init__(self, name, inputs, bias=True, **xargs):
26942741 super (AddToLayer , self ).__init__ (
26952742 name , 'addto' , 0 , inputs = inputs , ** xargs )
26962743 config_assert (len (inputs ) > 0 , 'inputs cannot be empty for AddToLayer' )
2697- for input_index in xrange (len (self .inputs )):
2698- input_layer = self .get_input_layer (input_index )
2699- self .set_layer_size (input_layer .size )
2744+
2745+ if len (self .inputs ) > 1 :
2746+ for input_index in xrange (len (self .inputs )):
2747+ assert self .get_input_layer (0 ).height == self .get_input_layer (
2748+ input_index ).height
2749+ assert self .get_input_layer (0 ).width == self .get_input_layer (
2750+ input_index ).width
2751+ assert self .get_input_layer (0 ).depth == self .get_input_layer (
2752+ input_index ).depth
2753+
2754+ self .set_layer_size (self .get_input_layer (0 ).size )
2755+ self .set_layer_height_width (self .get_input_layer (0 ).height , \
2756+ self .get_input_layer (0 ).width )
2757+ self .set_layer_depth (self .get_input_layer (0 ).depth )
27002758 self .create_bias_parameter (bias , self .config .size )
27012759
27022760
@@ -3376,11 +3434,20 @@ def __init__(self, name, inputs, bias=False, **xargs):
33763434 name , 'concat' , 0 , inputs = inputs , ** xargs )
33773435 size = 0
33783436 for input_index in xrange (len (self .inputs )):
3437+ assert self .get_input_layer (0 ).height == self .get_input_layer (
3438+ input_index ).height
3439+ assert self .get_input_layer (0 ).width == self .get_input_layer (
3440+ input_index ).width
3441+ assert self .get_input_layer (0 ).depth == self .get_input_layer (
3442+ input_index ).depth
33793443 input_layer = self .get_input_layer (input_index )
33803444 input = self .inputs [input_index ]
33813445 if self .config .size == 0 :
33823446 size += input_layer .size
33833447
3448+ self .set_layer_height_width (self .get_input_layer (0 ).height , \
3449+ self .get_input_layer (0 ).width )
3450+ self .set_layer_depth (self .get_input_layer (0 ).depth )
33843451 self .set_layer_size (size )
33853452
33863453
0 commit comments