Skip to content

Commit 92882b6

Browse files
authored
disable tests for vit op counts (#7874)
1 parent 02d3d6d commit 92882b6

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

test/common_extended_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
140140
return flop_count
141141

142142

143+
def scaled_dot_product_flash_attention_flop(inputs: List[Any], outputs: List[Any]):
144+
# FIXME: this needs to count the flops of this kernel
145+
# https://github.com/pytorch/pytorch/blob/207b06d099def9d9476176a1842e88636c1f714f/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp#L52-L267
146+
return 0
147+
148+
143149
flop_mapping = {
144150
aten.mm: matmul_flop,
145151
aten.matmul: matmul_flop,
@@ -150,6 +156,7 @@ def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
150156
aten.convolution_backward: conv_backward_flop,
151157
quantized.conv2d: quant_conv_flop,
152158
quantized.conv2d_relu: quant_conv_flop,
159+
aten._scaled_dot_product_flash_attention: scaled_dot_product_flash_attention_flop,
153160
}
154161

155162
unmapped_ops = set()

test/test_extended_models.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ def test_naming_conventions(model_fn):
242242
)
243243
@run_if_test_with_extended
244244
def test_schema_meta_validation(model_fn):
245-
246245
if model_fn.__name__ == "maskrcnn_resnet50_fpn_v2":
247246
pytest.skip(reason="FIXME https://github.com/pytorch/vision/issues/7349")
248247

@@ -326,9 +325,11 @@ def test_schema_meta_validation(model_fn):
326325
height, width = detection_models_input_dims[model_name]
327326
kwargs = {"height": height, "width": width}
328327

329-
calculated_ops = get_ops(model=model, weight=w, **kwargs)
330-
if calculated_ops != w.meta["_ops"]:
331-
incorrect_meta.append((w, "_ops"))
328+
if not model_fn.__name__.startswith("vit"):
329+
# FIXME: https://github.com/pytorch/vision/issues/7871
330+
calculated_ops = get_ops(model=model, weight=w, **kwargs)
331+
if calculated_ops != w.meta["_ops"]:
332+
incorrect_meta.append((w, "_ops"))
332333

333334
if not w.name.isupper():
334335
bad_names.append(w)

0 commit comments

Comments
 (0)