@@ -1652,46 +1652,81 @@ static void ggml_vk_load_shaders(vk_device& device) {
1652
1652
1653
1653
// Create up to 2 variants, {f16,f32} accumulator; each is built only if the device supports that accumulator type
1654
1654
#define CREATE_MM2 (PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID ) \
1655
- CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1656
- CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1655
+ if (device->coopmat_acc_f16_support ) { \
1656
+ CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1657
+ } \
1658
+ if (device->coopmat_acc_f32_support ) { \
1659
+ CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1660
+ } \
1657
1661
1658
1662
CREATE_MM (pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1659
1663
CREATE_MM (pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1660
1664
CREATE_MM2 (pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1661
1665
CREATE_MM2 (pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1662
1666
1663
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1664
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1665
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1666
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc , matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1667
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc , matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1667
+ if (device->coopmat_acc_f16_support ) {
1668
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1669
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1670
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1671
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc , matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1672
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc , matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1673
+
1674
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc , matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1675
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc , matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1676
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc , matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1677
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc , matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1678
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1679
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1680
+ } else {
1681
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc , matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1682
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc , matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1683
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc , matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1684
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc , matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1685
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc , matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1668
1686
1669
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc , matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1670
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc , matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1671
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc , matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1672
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc , matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1673
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1674
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1687
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc , matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1688
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc , matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1689
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc , matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1690
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc , matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1691
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc , matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1692
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc , matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 , );
1693
+ }
1675
1694
1676
1695
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1677
1696
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l ) {
1678
1697
CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 , _id);
1679
1698
CREATE_MM2 (pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4 , _id);
1680
1699
CREATE_MM2 (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4 , _id);
1681
1700
1682
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1683
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1684
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1685
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1686
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1687
-
1688
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1689
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1690
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1691
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1692
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1693
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1701
+ if (device->coopmat_acc_f16_support ) {
1702
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1703
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1704
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1705
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1706
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1707
+
1708
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1709
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1710
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1711
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1712
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1713
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1714
+ } else {
1715
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc , matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1716
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc , matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1717
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc , matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1718
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc , matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1719
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc , matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1720
+
1721
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc , matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1722
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc , matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1723
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc , matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1724
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc , matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1725
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1726
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1727
+ }
1694
1728
}
1729
+ #undef CREATE_MM2
1695
1730
#undef CREATE_MM
1696
1731
} else if (device->fp16 ) {
1697
1732
// Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -1709,6 +1744,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
1709
1744
if (device->mul_mat ## ID ## _s) \
1710
1745
ggml_vk_create_pipeline (device, device-> PIPELINE_NAME ->a_s , #NAMELC #F16ACC " _aligned_s" , NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, " main" , PARAMCOUNT, sizeof (PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1711
1746
1747
+ // Create 2 variants, {f16,f32} accumulator
1748
+ #define CREATE_MM2 (PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID ) \
1749
+ CREATE_MM (PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1750
+ CREATE_MM (PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1751
+
1712
1752
CREATE_MM (pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1713
1753
CREATE_MM (pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
1714
1754
CREATE_MM2 (pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3 , );
@@ -1746,6 +1786,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1746
1786
CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc , matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1747
1787
CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc , matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1748
1788
}
1789
+ #undef CREATE_MM2
1749
1790
#undef CREATE_MM
1750
1791
} else {
1751
1792
// Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -1800,7 +1841,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
1800
1841
CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc , matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1801
1842
CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc , matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 , _id);
1802
1843
}
1803
- #undef CREATE_MM2
1804
1844
#undef CREATE_MM
1805
1845
}
1806
1846
@@ -2109,11 +2149,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
2109
2149
2110
2150
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
2111
2151
2112
- // if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2113
- // // Intel drivers don't support coopmat properly yet
2114
- // // Only RADV supports coopmat properly on AMD
2115
- // device->coopmat_support = false;
2116
- // }
2152
+ if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2153
+ // Intel drivers don't support coopmat properly yet
2154
+ // Only RADV supports coopmat properly on AMD
2155
+ device->coopmat_support = false ;
2156
+ }
2117
2157
2118
2158
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device .getQueueFamilyProperties ();
2119
2159
@@ -2204,8 +2244,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
2204
2244
device->pipeline_robustness = pl_robustness_features.pipelineRobustness ;
2205
2245
2206
2246
device->subgroup_size_control = device->subgroup_size_control &&
2207
- (!( subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) ||
2208
- ! subgroup_size_control_features.subgroupSizeControl ) ;
2247
+ (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
2248
+ subgroup_size_control_features.subgroupSizeControl ;
2209
2249
2210
2250
if (device->subgroup_size_control ) {
2211
2251
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize ;
@@ -2363,7 +2403,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2363
2403
}
2364
2404
}
2365
2405
2366
- if (device->coopmat_m == 0 ) {
2406
+ if (device->coopmat_m == 0 || !device-> coopmat_acc_f32_support ) {
2367
2407
// No suitable matmul mode found, or f32 accumulation is not supported
2368
2408
GGML_LOG_DEBUG (" ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n " );
2369
2409
device->coopmat_support = false ;
@@ -2496,11 +2536,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2496
2536
}
2497
2537
}
2498
2538
2499
- // if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2500
- // // Intel drivers don't support coopmat properly yet
2501
- // // Only RADV supports coopmat properly on AMD
2502
- // coopmat_support = false;
2503
- // }
2539
+ if (props2.properties .vendorID == VK_VENDOR_ID_INTEL || (props2.properties .vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2540
+ // Intel drivers don't support coopmat properly yet
2541
+ // Only RADV supports coopmat properly on AMD
2542
+ coopmat_support = false ;
2543
+ }
2504
2544
2505
2545
const char * GGML_VK_DISABLE_F16 = getenv (" GGML_VK_DISABLE_F16" );
2506
2546
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr ;
@@ -2783,7 +2823,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2783
2823
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
2784
2824
return ctx->device ->pipeline_matmul_f32_f16 ;
2785
2825
}
2786
- if (prec == GGML_PREC_DEFAULT && ctx->device ->fp16 ) {
2826
+ if (prec == GGML_PREC_DEFAULT && ctx->device ->fp16 && !(ctx-> device -> coopmat_support && !ctx-> device -> coopmat_acc_f16_support ) ) {
2787
2827
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2788
2828
return ctx->device ->pipeline_matmul_f16_f32 .f16acc ;
2789
2829
}
@@ -2858,7 +2898,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
2858
2898
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2859
2899
return ctx->device ->pipeline_matmul_id_f32 ;
2860
2900
}
2861
- if (prec == GGML_PREC_DEFAULT && ctx->device ->fp16 ) {
2901
+ if (prec == GGML_PREC_DEFAULT && ctx->device ->fp16 && !(ctx-> device -> coopmat_support && !ctx-> device -> coopmat_acc_f16_support ) ) {
2862
2902
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2863
2903
return ctx->device ->pipeline_matmul_id_f16_f32 .f16acc ;
2864
2904
}
0 commit comments