@@ -214,10 +214,16 @@ def register_fake(
214214)
215215
# Operator schemas for the quantized 2D max-pool, one pair per memory
# layout. The layout is encoded in the op name; the signatures are
# otherwise identical:
#   *_nchw — input laid out as (N, C, H, W)
#   *_nhwc — input laid out as (N, H, W, C)
# Each op has a functional variant and an `.out` variant writing into a
# caller-provided tensor.
lib.define(
    "quantized_max_pool2d_nchw(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
)
lib.define(
    "quantized_max_pool2d_nchw.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_max_pool2d_nhwc(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor"
)
lib.define(
    "quantized_max_pool2d_nhwc.out(Tensor input, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode, *, Tensor(a!) out) -> Tensor(a!)"
)
222228
223229lib .define (
@@ -2277,8 +2283,8 @@ def quantized_relu_asym8u_asym8u_per_tensor_meta(
22772283 return input .new_empty (input .size (), dtype = input .dtype )
22782284
22792285
2280- @register_fake ("cadence::quantized_max_pool2d " )
2281- def quantized_max_pool2d_meta (
2286+ @register_fake ("cadence::quantized_max_pool2d_nchw " )
2287+ def quantized_max_pool2d_nchw_meta (
22822288 input : torch .Tensor ,
22832289 kernel_size : list [int ],
22842290 stride : list [int ],
@@ -2318,6 +2324,47 @@ def quantized_max_pool2d_meta(
23182324 return input .new_empty ([batch , channels , height_out , width_out ], dtype = input .dtype )
23192325
23202326
@register_fake("cadence::quantized_max_pool2d_nhwc")
def quantized_max_pool2d_nhwc_meta(
    input: torch.Tensor,
    kernel_size: list[int],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> torch.Tensor:
    """Fake (meta) kernel for ``cadence::quantized_max_pool2d_nhwc``.

    Computes only the output shape/dtype of a 2D max pool over an NHWC
    input — no data is touched. Per-dimension output extent follows the
    standard pooling formula:

        out = op((in + 2*pad - dilation*(k-1) - 1) / stride) + 1

    where ``op`` is ceil when ``ceil_mode`` is set, floor otherwise.

    Args:
        input: 4D tensor laid out as (N, H, W, C).
        kernel_size, stride, padding, dilation: two-element lists, each
            ordered (height, width).
        ceil_mode: round output extents up instead of down.

    Returns:
        An empty tensor of shape (N, H_out, W_out, C) with ``input``'s dtype.
    """
    assert (
        len(kernel_size) == 2
    ), f"kernel_size must have 2 elements, got {len(kernel_size)}"
    assert len(stride) == 2, f"stride must have 2 elements, got {len(stride)}"
    assert len(padding) == 2, f"padding must have 2 elements, got {len(padding)}"
    assert len(dilation) == 2, f"dilation must have 2 elements, got {len(dilation)}"
    assert (
        len(input.size()) == 4
    ), f"input must be 4D (N, H, W, C), got {len(input.size())}D"

    batch = input.size(0)
    height_in = input.size(1)
    width_in = input.size(2)
    channels = input.size(3)

    def _pooled_extent(size_in: int, dim: int) -> int:
        # Exact integer arithmetic: float true-division followed by
        # int()/ceil() can be off by one near exact multiples (e.g.
        # a quotient materializing as 6.999999999999997), and int()
        # truncates toward zero rather than flooring.
        numer = size_in + 2 * padding[dim] - dilation[dim] * (kernel_size[dim] - 1) - 1
        if ceil_mode:
            return -(-numer // stride[dim]) + 1  # ceiling division
        return numer // stride[dim] + 1

    height_out = _pooled_extent(height_in, 0)
    width_out = _pooled_extent(width_in, 1)

    return input.new_empty([batch, height_out, width_out, channels], dtype=input.dtype)
2366+
2367+
23212368@register_fake ("cadence::fully_connected" )
23222369def fully_connected_meta (
23232370 src : torch .Tensor ,
0 commit comments