diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 76ee6cef52ac0..3af4a7043b7bd 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -65,6 +65,8 @@ class Model:
     model_name: str | None
     metadata_override: Path | None
     dir_model_card: Path
+    enable_t_mac: bool
+    kcfg_file: Path | None
 
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
@@ -73,7 +75,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
-                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
+                 small_first_shard: bool = False, hparams: dict[str, Any] | None = None,
+                 enable_t_mac: bool = False, kcfg_file: Path | None = None):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
 
@@ -95,7 +98,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
-
+        self.enable_t_mac = enable_t_mac
+        self.kcfg_file = kcfg_file
         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
             # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
@@ -265,6 +269,185 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(self.map_tensor_name(name), data_torch)]
 
+    _gptq_quant_dict: dict[str, dict[str, Tensor]] | None = None
+    _t_mac_bits: int = 0
+    _t_mac_raw_shape: tuple[int, ...] | None = None
+
+    # Repack and merge qweight, scales, and qzeros into a single tensor
+    # Currently, this logic is nearly impossible to implement in quants.py
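+    # Rough flow (per quantized layer): buffer .qweight/.scales/.qzeros until all three
+    # have arrived, unpack them, apply the same permutations as the fp16 path, then either
+    # repack for T-MAC (MOSTLY_INT_N) or dequantize back to float.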
+    def _modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if not self.enable_t_mac:
+            return self.modify_tensors(data_torch, name, bid)
+
+        # bits = 0 means not quantized
+        self._t_mac_bits = 0
+        self._t_mac_raw_shape = None
+
+        from t_mac.model_utils import get_quantization_config, preprocess_for_t_mac
+        quantization_config = get_quantization_config(self.dir_model)
+
+        if quantization_config["quant_method"] == "gptq":  # AutoGPTQ/GPTQModel
+            if name.endswith(".g_idx"):
+                return []
+
+            if name.endswith(".qweight") or name.endswith(".scales") or name.endswith(".qzeros"):
+                if self._gptq_quant_dict is None:
+                    self._gptq_quant_dict = {}
+                suffix = "." + name.split(".")[-1]
+                base_name = name.replace(suffix, "")
+                self._gptq_quant_dict.setdefault(base_name, {})[suffix] = data_torch
+                if len(self._gptq_quant_dict[base_name]) < 3:
+                    return []
+
+                qweight = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qweight"]).numpy()
+                scales = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".scales"]).numpy()
+                qzeros = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qzeros"]).numpy()
+                name = base_name + ".weight"
+                from t_mac.model_utils import unpack_gptqv2
+                w, scales, zeros, bits, group_size = unpack_gptqv2(qweight, scales, qzeros, "gptqmodel" in quantization_config["quantizer"])
+                self._t_mac_bits = bits
+                self._t_mac_raw_shape = w.shape
+                if bits != quantization_config["bits"] or group_size != quantization_config["group_size"]:
+                    logger.warning("Parsed bits/group_size do not match quantization_config: {}".format(quantization_config))
+
+                # For permutation in, e.g., LlamaModel
+                w = self.modify_tensors(torch.from_numpy(w), name, bid)[0][1].numpy()
+                scales = self.modify_tensors(torch.from_numpy(scales), name, bid)[0][1].numpy()
+                zeros = self.modify_tensors(torch.from_numpy(zeros), name, bid)[0][1].numpy()
+
+                if self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
+                    if quantization_config["sym"]:
+                        if not np.allclose(zeros, np.zeros_like(zeros)):
+                            logger.warning("The model claims symmetric quantization, but its zero points are not all zero")
+                        else:
+                            zeros = None
+                    data_torch = torch.from_numpy(preprocess_for_t_mac(self.kcfg_file, w, scales, zeros, bits=bits))
+                else:
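+                    # Dequantize for float output types. As in the BitDistiller path below, zeros is
+                    # assumed to be (zero_point - 2 ** (bits - 1)) * scales in the float domain, so
+                    # (w - (zeros / scales + 2 ** (bits - 1))) * scales recovers (w - zero_point) * scales.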
+                    old_shape = w.shape
+                    w = w.astype("float32").reshape(-1, group_size)
+                    scales = scales.astype("float32").reshape(-1, 1)
+                    zeros = zeros.astype("float32").reshape(-1, 1)
+                    data = (w - (zeros / scales + (2 ** (bits - 1)))) * scales
+                    data_torch = torch.from_numpy(data.reshape(old_shape))
+                    if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+                        data_torch = data_torch.to(torch.float16)
+
+                return [(self.map_tensor_name(name), data_torch)]
+        elif quantization_config["quant_method"] == "bitdistiller":
+            new_name = self.map_tensor_name(name, try_suffixes=(".weight", ".bias"))
+            extra_f32 = any(self.match_model_tensor_name(new_name, key, bid) for key in (
+                gguf.MODEL_TENSOR.FFN_GATE_INP,
+                gguf.MODEL_TENSOR.POS_EMBD,
+                gguf.MODEL_TENSOR.TOKEN_TYPES,
+            ))
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            data = data_torch.numpy()
+            n_dims = len(data.shape)
+            extra_f16 = any(cond for cond in (
+                (name.endswith(".weight") and n_dims >= 2),
+            ))
+
+            to_dtype = gguf.GGMLQuantizationType.F32
+
+            if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
+                if self.ftype == gguf.LlamaFileType.MOSTLY_INT_N and any(self.match_model_tensor_name(new_name, key, bid) for key in [
+                    gguf.MODEL_TENSOR.ATTN_Q,
+                    gguf.MODEL_TENSOR.ATTN_K,
+                    gguf.MODEL_TENSOR.ATTN_V,
+                    gguf.MODEL_TENSOR.ATTN_QKV,
+                    gguf.MODEL_TENSOR.ATTN_OUT,
+                    gguf.MODEL_TENSOR.FFN_UP,
+                    gguf.MODEL_TENSOR.FFN_DOWN,
+                    gguf.MODEL_TENSOR.FFN_GATE,
+                ]):
+                    # I2 here is just a symbol for INT_N type.
+                    to_dtype = gguf.GGMLQuantizationType.I2
+                else:
+                    to_dtype = gguf.GGMLQuantizationType.F16
+
+            if to_dtype == gguf.GGMLQuantizationType.I2:
+                bits = quantization_config["bits"]
+                group_size = quantization_config["group_size"]
+                w, scales, zeros = self._t_mac_quantize_tensor_bitdistiller(
+                    LazyTorchTensor.to_eager(data_torch),
+                    n_bit=bits,
+                    zero_point=True,
+                    q_group_size=group_size,
+                )
+
+                self._t_mac_bits = quantization_config["bits"]
+                self._t_mac_raw_shape = w.shape
+
+                # For permutation in, e.g., LlamaModel
+                w = self.modify_tensors(torch.from_numpy(w), name, bid)[0][1].numpy()
+                scales = self.modify_tensors(torch.from_numpy(scales), name, bid)[0][1].numpy()
+                zeros = self.modify_tensors(torch.from_numpy(zeros), name, bid)[0][1].numpy()
+
+                if self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
+                    if quantization_config["sym"]:
+                        if not np.allclose(zeros, np.zeros_like(zeros)):
+                            logger.warning("The model claims symmetric quantization, but its zero points are not all zero")
+                        else:
+                            zeros = None
+                    data_torch = torch.from_numpy(preprocess_for_t_mac(self.kcfg_file, w, scales, zeros, bits=bits))
+                else:
+                    old_shape = w.shape
+                    w = w.astype("float32").reshape(-1, group_size)
+                    scales = scales.astype("float32").reshape(-1, 1)
+                    zeros = zeros.astype("float32").reshape(-1, 1)
+                    data = (w - (zeros / scales + (2 ** (bits - 1)))) * scales
+                    data_torch = torch.from_numpy(data.reshape(old_shape))
+                    if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+                        data_torch = data_torch.to(torch.float16)
+
+                return [(self.map_tensor_name(name), data_torch)]
+
+        return self.modify_tensors(data_torch, name, bid)
+
+    # Modified version of BitDistiller pseudo_quantize_tensor
+    # core quantization method (simulated quantization)
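+    # Worked example (assumed inputs) for n_bit=2, q_group_size=4, zero_point=True and a
+    # group w = [0.1, -0.2, 0.3, 0.0]:
+    #   scales = (0.3 - (-0.2)) / 3 ≈ 0.167, zeros = -round(-0.2 / 0.167) = 1,
+    #   q = clamp(round(w / scales) + zeros, 0, 3) = [2, 0, 3, 1],
+    # and the returned zeros are re-centered as (zeros - 2 ** (n_bit - 1)) * scales.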
+    def _t_mac_quantize_tensor_bitdistiller(self, w, n_bit=8, zero_point=True, q_group_size=-1):
+        org_w_shape = w.shape
+        if q_group_size > 0:
+            assert org_w_shape[-1] % q_group_size == 0
+            w = w.reshape(-1, q_group_size)
+        elif q_group_size == -1:
+            w = w.reshape(-1, w.shape[-1])
+        assert w.dim() == 2
+        if zero_point:
+            max_val = w.amax(dim=1, keepdim=True)
+            min_val = w.amin(dim=1, keepdim=True)
+            max_int = 2 ** n_bit - 1
+            min_int = 0
+            scales = (max_val - min_val).clamp(min=1e-5) / max_int
+            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
+        else:  # we actually never used this
+            max_val = w.abs().amax(dim=1, keepdim=True)
+            max_val = max_val.clamp(min=1e-5)
+            max_int = 2 ** (n_bit - 1) - 1
+            min_int = - 2 ** (n_bit - 1)
+            scales = max_val / max_int
+            zeros = 0
+
+        assert torch.isnan(scales).sum() == 0
+        assert torch.isnan(w).sum() == 0
+
+        w = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int)
+
+        w = w.reshape(org_w_shape).numpy()
+        scales = scales.numpy().reshape(w.shape[0], -1)
+        zeros = zeros.numpy().reshape(w.shape[0], -1) if zero_point else None
+
+        if zero_point:
+            w = w.astype(np.uint8)
+            zeros = (zeros - (2 ** (n_bit - 1))) * scales
+            return w, scales, zeros
+        else:
+            w = (w - min_int).astype(np.uint8)
+            return w, scales, zeros
+
     def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
 
@@ -285,7 +468,7 @@ def prepare_tensors(self):
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
-            if data_torch.dtype not in (torch.float16, torch.float32):
+            if data_torch.dtype not in (torch.float16, torch.float32) and not self.enable_t_mac:
                 data_torch = data_torch.to(torch.float32)
 
             # use the first number-like part of the tensor name as the block id
@@ -295,7 +478,13 @@ def prepare_tensors(self):
                     bid = int(part)
                     break
 
-            for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
+            for new_name, data_torch in (self._modify_tensors(data_torch, name, bid)):
+                # Some GPTQ models have empty bias tensors which are not in the model architecture.
+                # These tensors would cause the tensor-count check to fail, so we have to skip them.
+                if new_name.endswith(".bias") and np.all(LazyTorchTensor.to_eager(data_torch).numpy() == 0):
+                    logger.info(f"Skipping empty bias tensor: {new_name}")
+                    continue
+
                 data = data_torch.squeeze().numpy()
 
                 # if data ends up empty, it means data_torch was a scalar tensor -> restore
@@ -344,6 +533,19 @@ def prepare_tensors(self):
                         # TODO: use Q4_K and Q6_K
                         data_qtype = gguf.GGMLQuantizationType.F16
 
+                # If self._t_mac_bits > 0, the tensor has already been quantized for T-MAC (by the GPTQ or BitDistiller path above)
+                if self.enable_t_mac and self._t_mac_bits > 0 and self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
+                    if self._t_mac_bits == 1:
+                        data_qtype = gguf.GGMLQuantizationType.I1
+                    elif self._t_mac_bits == 2:
+                        data_qtype = gguf.GGMLQuantizationType.I2
+                    elif self._t_mac_bits == 3:
+                        data_qtype = gguf.GGMLQuantizationType.I3
+                    elif self._t_mac_bits == 4:
+                        data_qtype = gguf.GGMLQuantizationType.I4
+                    else:
+                        raise ValueError(f"Unsupported number of bits: {self._t_mac_bits}")
+
                 # No override (data_qtype is False), or wants to be quantized (data_qtype is True)
                 if isinstance(data_qtype, bool):
                     if self.ftype == gguf.LlamaFileType.ALL_F32:
@@ -358,6 +560,12 @@ def prepare_tensors(self):
                         data_qtype = gguf.GGMLQuantizationType.TQ1_0
                     elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
                         data_qtype = gguf.GGMLQuantizationType.TQ2_0
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
+                        # If the tensor was successfully quantized, data_qtype is already I1/2/3/4.
+                        # If data_qtype is still a bool here, the tensor should not be quantized.
+                        # In practice, this tensor is `output.weight` for GPTQ models.
+                        # TODO: consider quantizing it?
+                        data_qtype = gguf.GGMLQuantizationType.F16
                     else:
                         raise ValueError(f"Unknown file type: {self.ftype.name}")
 
@@ -369,6 +577,7 @@ def prepare_tensors(self):
                     data = gguf.quants.quantize(data, data_qtype)
 
                 shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
+                shape = self._t_mac_raw_shape or shape
 
                 # reverse shape to make it similar to the internal ggml dimension order
                 shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"
@@ -376,7 +585,8 @@ def prepare_tensors(self):
                 # n_dims is implicit in the shape
                 logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
 
-                self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
+                raw_shape = gguf.quant_shape_to_byte_shape(self._t_mac_raw_shape, data_qtype) if self.ftype == gguf.LlamaFileType.MOSTLY_INT_N and self._t_mac_raw_shape else None
+                self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype, raw_shape=raw_shape)
 
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.MODEL)
@@ -1700,6 +1910,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         ]):
             # transform weight into 1/0/-1 (in fp32)
             data_torch = self.weight_quant(data_torch)
+            if self.enable_t_mac and self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
+                # transform weight into T-MAC INT_N format
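+                # The ternary values {-s, 0, +s} are normalized by the per-tensor max and
+                # shifted by 2, giving unsigned 2-bit codes {1, 2, 3} for preprocess_for_t_mac.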
+                from t_mac.model_utils import preprocess_for_t_mac
+                data = LazyTorchTensor.to_eager(data_torch).numpy()
+                scale = np.max(np.abs(data))
+                w = np.round(data / scale + 2).astype(np.uint8)
+                data_torch = torch.from_numpy(preprocess_for_t_mac(self.kcfg_file, w, scale.reshape(1), bits=2))
+                self._t_mac_bits = 2
+                self._t_mac_raw_shape = w.shape
 
         yield (new_name, data_torch)
 
@@ -4297,8 +4516,8 @@ def parse_args() -> argparse.Namespace:
         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
-        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
+        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "int_n", "auto"], default="f16",
+        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, int_n for int1/2/3/4, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
     )
     parser.add_argument(
         "--bigendian", action="store_true",
@@ -4344,6 +4563,14 @@ def parse_args() -> argparse.Namespace:
         "--metadata", type=Path,
         help="Specify the path for an authorship metadata override file"
     )
+    parser.add_argument(
+        "--enable-t-mac", action="store_true",
+        help="Enable the T-MAC quantization format (disabled by default). Supports GPTQ, GPTQv2, BitNet, and BitDistiller models."
+    )
+    parser.add_argument(
+        "--kcfg", type=Path,
+        help="Specify the path to the T-MAC kernel configuration file (kcfg.ini)"
+    )
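+    # Example invocation (hypothetical paths):
+    #   python convert_hf_to_gguf.py /path/to/model --outtype int_n --enable-t-mac --kcfg /path/to/kcfg.ini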
 
     return parser.parse_args()
 
@@ -4387,6 +4614,7 @@ def main() -> None:
         "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
         "tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0,
         "tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0,
+        "int_n": gguf.LlamaFileType.MOSTLY_INT_N,
         "auto": gguf.LlamaFileType.GUESSED,
     }
 
@@ -4413,14 +4641,14 @@ def main() -> None:
         except NotImplementedError:
             logger.error(f"Model {model_architecture} is not supported")
             sys.exit(1)
-
         model_instance = model_class(dir_model=dir_model, ftype=output_type, fname_out=fname_out,
                                      is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
                                      eager=args.no_lazy,
                                      metadata_override=args.metadata, model_name=args.model_name,
                                      split_max_tensors=args.split_max_tensors,
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
-                                     small_first_shard=args.no_tensor_first_split)
+                                     small_first_shard=args.no_tensor_first_split,
+                                     enable_t_mac=args.enable_t_mac, kcfg_file=args.kcfg)
 
         if args.vocab_only:
             logger.info("Exporting model vocab...")
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index cfa6e3f70e4a3..8c586640bbf9c 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -166,6 +166,9 @@ option(GGML_SYCL                            "ggml: use SYCL"
 option(GGML_SYCL_F16                        "ggml: use 16 bit floats for sycl calculations"   OFF)
 set   (GGML_SYCL_TARGET "INTEL" CACHE STRING
                                             "ggml: sycl target device")
+option(GGML_TMAC                            "ggml: use TMAC"                                  OFF)
+option(GGML_TMAC_SYSLIB                     "ggml: use TMAC system library"                   OFF)
+option(GGML_TMAC_TVM_THREADPOOL             "ggml: use TVM threadpool for TMAC"               OFF)
+option(GGML_TMAC_RECHUNK                    "ggml: rechunk T-MAC mul_mat across threads"      OFF)
 
 # extra artifacts
 option(GGML_BUILD_TESTS    "ggml: build tests"    ${GGML_STANDALONE})
diff --git a/ggml/include/ggml-tmac.h b/ggml/include/ggml-tmac.h
new file mode 100644
index 0000000000000..f79b674455dc6
--- /dev/null
+++ b/ggml/include/ggml-tmac.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+typedef float16_t tmac_float_type;
+#else
+typedef float tmac_float_type;
+#endif
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+struct tmac_tensor_extra {
+    int lut_scales_size;
+    int scales_size;
+    int n_tile_num;
+    uint8_t * qweights;
+    tmac_float_type * scales;
+};
+
+GGML_API void ggml_tmac_init(void);
+GGML_API void ggml_tmac_free(void);
+// src0->type == Q4_0/I1/I2/I3/I4/TQ1_0/TQ2_0 (see is_type_supported in ggml-tmac.cpp)
+// T-MAC currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros)
+// If i-quant GGUF models are used, the results will be wrong
+// TODO: add customized block types Q2_0/Q3_0
+GGML_API bool ggml_tmac_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
+GGML_API size_t ggml_tmac_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
+GGML_API void ggml_tmac_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits);
+GGML_API void ggml_tmac_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits);
+GGML_API void ggml_tmac_transform_tensor(struct ggml_tensor * tensor);
+GGML_API int ggml_tmac_get_type_bits(enum ggml_type type);
+GGML_API void ggml_tmac_set_n_threads(int n_threads);
+GGML_API size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor);
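+
+// Expected call order (a sketch of how this patch uses the API): ggml_tmac_init() once,
+// ggml_tmac_transform_tensor() on each supported weight tensor, ggml_tmac_mul_mat_get_wsize()
+// during graph planning, and ggml_tmac_mul_mat_task_init()/ggml_tmac_mul_mat_task_compute() inside mul_mat.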
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 8a0bcbff8c61a..385330c7ae679 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -389,6 +389,10 @@ extern "C" {
         GGML_TYPE_Q4_0_8_8 = 33,
         GGML_TYPE_TQ1_0   = 34,
         GGML_TYPE_TQ2_0   = 35,
+        GGML_TYPE_I1      = 36,
+        GGML_TYPE_I2      = 37,
+        GGML_TYPE_I3      = 38,
+        GGML_TYPE_I4      = 39,
         GGML_TYPE_COUNT,
     };
 
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 34b81bd7fdda1..00b39d455266b 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -870,6 +870,42 @@ if (GGML_KOMPUTE)
     endif()
 endif()
 
+if (GGML_TMAC)
+    find_package(TMAC)
+
+    if (TMAC_FOUND)
+        message(STATUS "TMAC found")
+
+        list(APPEND GGML_CDEF_PUBLIC GGML_USE_TMAC)
+
+        set(GGML_HEADERS_TMAC ../include/ggml-tmac.h)
+        set(GGML_SOURCES_TMAC ggml-tmac.cpp)
+
+        link_directories(${TMAC_LIB_DIR})
+        file(COPY ${TMAC_LIB_DIR}/kcfg.ini DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+        # TODO: link t_mac_object when GGML_TMAC_SYSLIB
+
+        if (GGML_TMAC_TVM_THREADPOOL)
+            add_compile_definitions(TMAC_USE_TVM_THREADPOOL)
+            set(GGML_EXTRA_LIBS_PRIVATE ${GGML_EXTRA_LIBS_PRIVATE} t_mac)
+        else()
+            if ((NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") OR
+                (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang"))
+                message(FATAL_ERROR "Clang is required for T-MAC compilation")
+            endif()
+
+            set(GGML_EXTRA_LIBS_PRIVATE ${GGML_EXTRA_LIBS_PRIVATE} t_mac_no_tvm)
+            set(GGML_SOURCES_TMAC ${GGML_SOURCES_TMAC} ${TMAC_KERNELS_SOURCE})
+        endif()
+
+        if (GGML_TMAC_RECHUNK)
+            add_compile_definitions(TMAC_RECHUNK)
+        endif()
+    else()
+        message(WARNING "TMAC not found")
+    endif()
+endif()
+
 if (GGML_CPU_HBM)
     find_library(memkind memkind REQUIRED)
 
@@ -1170,6 +1206,26 @@ if (CMAKE_OSX_ARCHITECTURES      STREQUAL "arm64" OR
             # Raspberry Pi 3, 4, Zero 2 (32-bit)
             list(APPEND ARCH_FLAGS -mno-unaligned-access)
         endif()
+        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64" AND GGML_TMAC AND TMAC_FOUND)
+            # We need fullfp16 for T-MAC
+            # TODO: we need to simplify this logic through check_cxx_source_compiles or Presets?
+            check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
+            if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
+                # Device with armv8.7a+ cpu, e.g., WSL on Surface Laptop 7
+                # based on arm64-windows-llvm.cmake
+                list(APPEND ARCH_FLAGS -march=armv8.7-a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only)
+                add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
+            else ()
+                # Jetson AGX Orin, Raspberry Pi 5
+                list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16)
+            endif ()
+        endif()
+        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ARM64" AND GGML_TMAC AND TMAC_FOUND)
+            # ARM Windows with LLVM clang GNU interface
+            # We need fullfp16 for T-MAC
+            # TODO: check_cxx_source_compiles
+            list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16)
+        endif()
         if (GGML_SVE)
             list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
         endif()
@@ -1184,7 +1240,9 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
             # TODO: improve, should not reference files from the parent folder
             include(../cmake/FindSIMD.cmake)
         endif ()
-        if (GGML_AVX512)
+        # Can't use GGML_AVX512 with Clang for MSVC,
+        # which fails with: conflicting types for '_m_prefetchw'
+        if (GGML_AVX512 AND (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") AND (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang"))
             list(APPEND ARCH_FLAGS /arch:AVX512)
             # MSVC has no compile-time flags enabling specific
             # AVX512 extensions, neither it defines the
@@ -1388,6 +1446,7 @@ add_library(ggml
             ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
             ${GGML_SOURCES_AMX}       ${GGML_HEADERS_AMX}
             ${GGML_SOURCES_CANN}      ${GGML_HEADERS_CANN}
+            ${GGML_SOURCES_TMAC}      ${GGML_HEADERS_TMAC}
             ggml-aarch64.c            ggml-aarch64.h
             )
 
diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c
index 4b8ffb629afbb..44e552aee2145 100644
--- a/ggml/src/ggml-cpu.c
+++ b/ggml/src/ggml-cpu.c
@@ -84,6 +84,10 @@
 #include <Accelerate/Accelerate.h>
 #endif
 
+#if defined(GGML_USE_TMAC)
+#include "ggml-tmac.h"
+#endif
+
 // floating point type used to accumulate sums
 typedef double ggml_float;
 
@@ -423,6 +427,26 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
         .nrows                    = 1,
     },
+    [GGML_TYPE_I1] = {
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type             = GGML_TYPE_F32,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_I2] = {
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type             = GGML_TYPE_F32,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_I3] = {
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type             = GGML_TYPE_F32,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_I4] = {
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type             = GGML_TYPE_F32,
+        .nrows                    = 1,
+    },
 };
 
 const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@@ -7478,6 +7502,156 @@ static void ggml_compute_forward_mul_mat(
 UseGgmlGemm1:;
 #endif
 
+// TODO: refactor T-MAC as a ggml backend,
+//       as ggml-blas.cpp has already been moved to the backend interface
+#if defined(GGML_USE_TMAC)
+    if (ggml_tmac_can_mul_mat(src0, src1, dst)) {
+        const int bits = ggml_tmac_get_type_bits(type);
+        // src0: weight,     ne00 = k, ne01 = n
+        // src1: activation, ne10 = k, ne11 = m
+        char * wdata = params->wdata;
+
+        struct tmac_tensor_extra * wt = src0->extra;
+        char * cur_wdata = wdata;
+        tmac_float_type * tmac_f_ptr = (tmac_float_type *) wdata;
+        if (sizeof(tmac_float_type) == 2) {
+            cur_wdata = wdata + MAX(ne10, ne01) * ne11 * sizeof(tmac_float_type);
+        }
+        int8_t * qlut = (int8_t *) cur_wdata;
+        tmac_float_type * lut_scales = (tmac_float_type *) (qlut + ne10 * ne11 * 4);
+        tmac_float_type * lut_biases = (tmac_float_type *) (lut_scales + wt->lut_scales_size * ne11);
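+        // Workspace layout (must match ggml_tmac_mul_mat_get_wsize):
+        //   [optional fp16 staging: MAX(ne10, ne01) * ne11] [qlut: ne10 * ne11 * 4 int8]
+        //   [lut_scales: lut_scales_size * ne11] [lut_biases: lut_scales_size * ne11]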
+
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+        tmac_float_type * act_input;
+        if (sizeof(tmac_float_type) == 2) {
+            act_input = tmac_f_ptr;
+        } else {
+            act_input = src1->data;
+        }
+        for (int ine11 = ith; ine11 < ne11; ine11 += nth) {
+            if (sizeof(tmac_float_type) == 2) {
+                ggml_fp32_to_fp16_row((const float *) src1->data + ne10 * ine11, act_input + ne10 * ine11, ne10);
+            }
+            ggml_tmac_mul_mat_task_init(act_input + ne10 * ine11,
+                                        qlut + ne10 * ine11 * 4,
+                                        lut_scales + wt->lut_scales_size * ine11,
+                                        lut_biases + wt->lut_scales_size * ine11,
+                                        ne01, ne00, 1, bits);
+        }
+
+        if (ith == 0) {
+            // Every thread starts at ith, so the first unprocessed chunk is nth.  This saves a bit of coordination right at the start.
+            atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        tmac_float_type * act_output;
+        if (sizeof(tmac_float_type) == 2) {
+            act_output = tmac_f_ptr;
+        } else {
+            act_output = dst->data;
+        }
+// TODO: remove the TVM threadpool path once it is confirmed to be unused
+#if defined(TMAC_USE_TVM_THREADPOOL)
+        if (ith != 0) {
+            return;
+        }
+        // TODO: schedule ne11(m) in T-MAC
+        for (int ine11 = 0; ine11 < ne11; ine11++) {
+            const int qlut_offset       = ne10 * ine11 * 4;
+            const int lut_scales_offset = wt->lut_scales_size * ine11;
+            const int dst_offset        = ne0 * ine11;
+
+            ggml_tmac_mul_mat_task_compute(wt->qweights,
+                                           wt->scales,
+                                           qlut + qlut_offset,
+                                           lut_scales + lut_scales_offset,
+                                           lut_biases + lut_scales_offset,
+                                           act_output + dst_offset,
+                                           ne01, ne00, 1, bits);
+        }
+        if (sizeof(tmac_float_type) == 2) {
+            ggml_fp16_to_fp32_row(tmac_f_ptr, dst->data, ne0 * ne11);
+        }
+#else  // #if defined(TMAC_USE_TVM_THREADPOOL)
+        const int n_tile_num = wt->n_tile_num;
+        // Currently, T-MAC requires ne0 divisible by n_tile_num
+        GGML_ASSERT(ne0 % n_tile_num == 0);
+
+        const int64_t w_size       = ne00 * ne01 * bits / 8;
+        const int64_t w_chunk_size = w_size / n_tile_num;
+
+        const int64_t nr0 = ne0;
+        const int64_t nr1 = ne1 * ne2 * ne3;
+
+        // Adopt the same chunked scheduling style as the current llama.cpp impl,
+        // but with different chunk sizes for dims 0/1.
+        // No scrap: dim 0 divides evenly into chunks.
+        const int chunk_size0 = ne0 / n_tile_num;
+        const int chunk_size1 = 8;  // TODO: tune in T-MAC
+
+        // nchunk0 == n_tile_num
+        int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
+        int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1;
+
+        int64_t dr0 = chunk_size0;
+        int64_t dr1 = chunk_size1;
+#if defined(TMAC_RECHUNK)
+        // Rechunk
+        if ((nchunk1 == 1) && (nchunk0 > nth * 4)) {
+            // dr0 should be divisible by chunk_size0
+            dr0 = (ne0 / (nth * 4) / chunk_size0) * chunk_size0;
+            nchunk0 = (nr0 + dr0 - 1) / dr0;
+        }
+#endif
+
+        int current_chunk = ith;
+
+        while (current_chunk < nchunk0 * nchunk1) {
+            const int64_t ith0 = current_chunk % nchunk0;
+            const int64_t ith1 = current_chunk / nchunk0;
+
+            const int64_t ir0_start = dr0 * ith0;
+            const int64_t ir0_end   = MIN(ir0_start + dr0, nr0);
+
+            const int64_t ir1_start = dr1 * ith1;
+            const int64_t ir1_end   = MIN(ir1_start + dr1, nr1);
+
+            // inline ggml_compute_forward_mul_mat_one_chunk here for simplicity
+            for (int64_t ichunk0 = ir0_start / chunk_size0; ichunk0 < ir0_end / chunk_size0; ichunk0++) {
+                const int64_t w_offset      = ichunk0 * w_chunk_size;
+                const int64_t scales_offset = ichunk0 * wt->scales_size / n_tile_num;
+
+                for (int64_t ine11 = ir1_start; ine11 < ir1_end; ine11++) {
+                    const int64_t qlut_offset       = ne10 * ine11 * 4;
+                    const int64_t lut_scales_offset = wt->lut_scales_size * ine11;
+                    const int64_t dst_offset        = ne0 * ine11 + ichunk0 * chunk_size0;
+
+                    ggml_tmac_mul_mat_task_compute(wt->qweights + w_offset,
+                                                   wt->scales + scales_offset,
+                                                   qlut + qlut_offset,
+                                                   lut_scales + lut_scales_offset,
+                                                   lut_biases + lut_scales_offset,
+                                                   act_output + dst_offset,
+                                                   chunk_size0, ne00, 1, bits);
+                    if (sizeof(tmac_float_type) == 2) {
+                        ggml_fp16_to_fp32_row(act_output + dst_offset, (float *) dst->data + dst_offset, chunk_size0);
+                    }
+                }
+            }
+
+            if (nth >= nchunk0 * nchunk1) {
+                break;
+            }
+
+            current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
+        }
+#endif  // #if defined(TMAC_USE_TVM_THREADPOOL)
+        return;
+    }  // if (ggml_tmac_can_mul_mat(src0, src1, dst))
+#endif  // #if defined(GGML_USE_TMAC)
+
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
@@ -9123,6 +9297,10 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_I32:
         case GGML_TYPE_I64:
         case GGML_TYPE_F64:
+        case GGML_TYPE_I1:
+        case GGML_TYPE_I2:
+        case GGML_TYPE_I3:
+        case GGML_TYPE_I4:
         case GGML_TYPE_COUNT:
             {
                 GGML_ABORT("fatal error");
@@ -13172,6 +13350,11 @@ struct ggml_cplan ggml_graph_plan(
                 {
                     const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
 
+#if defined(GGML_USE_TMAC)
+                    if (ggml_tmac_can_mul_mat(node->src[0], node->src[1], node)) {
+                        cur = ggml_tmac_mul_mat_get_wsize(node->src[0], node->src[1], node);
+                    } else
+#endif
                     if (node->src[1]->type != vec_dot_type) {
                         cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
                     }
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 7aa6dce8907f5..8d828b8c0180b 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -15741,6 +15741,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
         case GGML_TYPE_I64:
             // nothing to validate
             break;
+        case GGML_TYPE_I1:
+        case GGML_TYPE_I2:
+        case GGML_TYPE_I3:
+        case GGML_TYPE_I4:
+            // nothing to validate
+            break;
         default:
             {
                 fprintf(stderr, "%s: invalid type %d\n", __func__, type);
diff --git a/ggml/src/ggml-tmac.cpp b/ggml/src/ggml-tmac.cpp
new file mode 100644
index 0000000000000..18d5bd3be35fd
--- /dev/null
+++ b/ggml/src/ggml-tmac.cpp
@@ -0,0 +1,526 @@
+#include <vector>
+#include <type_traits>
+
+#include "ggml-tmac.h"
+#include "ggml-quants.h"
+
+#include "t-mac/tmac_gemm_wrapper.h"
+
+#define GGML_TMAC_MAX_NODES 8192
+
+static bool initialized = false;
+
+static TMAC::TMACGeMMWrapper<tmac_float_type> * wrapper = nullptr;
+
+static tmac_tensor_extra * tmac_tensor_extras = nullptr;
+
+static size_t tmac_tensor_extras_index = 0;
+
+static void * aligned_malloc(size_t size) {
+#if defined(_WIN32)
+    return _aligned_malloc(size, TMAC::kAllocAlignment);
+#else
+    void * ptr = nullptr;
+    posix_memalign(&ptr, TMAC::kAllocAlignment, size);
+    return ptr;
+#endif
+}
+
+static void aligned_free(void * ptr) {
+#if defined(_WIN32)
+    _aligned_free(ptr);
+#else
+    free(ptr);
+#endif
+}
+
+void ggml_tmac_init(void) {
+    LOG(INFO) << "ggml_tmac_init";
+
+    if (initialized) {
+        return;
+    }
+    initialized = true;
+
+    if (wrapper == nullptr) {
+        wrapper = new TMAC::TMACGeMMWrapper<tmac_float_type>();
+    }
+    if (tmac_tensor_extras == nullptr) {
+        tmac_tensor_extras = new tmac_tensor_extra[GGML_TMAC_MAX_NODES];
+    }
+    tmac_tensor_extras_index = 0;
+}
+
+void ggml_tmac_free(void) {
+    LOG(INFO) << "ggml_tmac_free";
+
+    if (!initialized) {
+        return;
+    }
+    initialized = false;
+
+    delete wrapper;
+    wrapper = nullptr;
+    for (size_t i = 0; i < tmac_tensor_extras_index; i++) {
+        // aligned_free(tmac_tensor_extras[i].qweights);
+        // aligned_free(tmac_tensor_extras[i].scales);
+    }
+    delete[] tmac_tensor_extras;
+    tmac_tensor_extras = nullptr;
+}
+
+static bool is_type_supported(enum ggml_type type) {
+    if (type == GGML_TYPE_Q4_0 ||
+        type == GGML_TYPE_I1 ||
+        type == GGML_TYPE_I2 ||
+        type == GGML_TYPE_I3 ||
+        type == GGML_TYPE_I4 ||
+        type == GGML_TYPE_TQ1_0 ||
+        type == GGML_TYPE_TQ2_0) {
+        return true;
+    } else {
+        return false;
+    }
+}
+
+static bool do_permutate(enum ggml_type type) {
+    if (type == GGML_TYPE_I1 ||
+        type == GGML_TYPE_I2 ||
+        type == GGML_TYPE_I3 ||
+        type == GGML_TYPE_I4) {
+        // TODO: add an argument to choose between the permuted and the naive I2 layout
+        return false;
+    } else {
+        return true;
+    }
+}
+
+struct BlockQ40TypeAccessor {
+    using block_t = block_q4_0;
+
+    static constexpr int BITS = 4;
+    static constexpr int SIMD_LEN = 16;
+    static constexpr int group_size = (sizeof(block_t) - sizeof(ggml_fp16_t)) * 8 / BITS;
+    static constexpr int simd_n_elem = SIMD_LEN * 8 / BITS;
+
+    static uint8_t get_q(const void * data, int idx) {
+        const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs);
+        int internal_idx = idx % group_size;
+        const uint8_t * simd_qs = qs + internal_idx / simd_n_elem * SIMD_LEN;
+        int simd_idx = internal_idx % simd_n_elem;
+        return simd_qs[simd_idx % SIMD_LEN] >> (simd_idx / SIMD_LEN * BITS);
+    }
+
+    static tmac_float_type get_scale(const void * data, int idx) {
+        ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d;
+        if (sizeof(tmac_float_type) == 2) {
+            tmac_float_type * fp16dp = reinterpret_cast<tmac_float_type *>(&d);
+            return *fp16dp;
+        } else {
+            return ggml_fp16_to_fp32(d);
+        }
+    }
+};
+
+struct BlockI2TypeAccessor {
+    static constexpr int BITS = 2;
+    static constexpr int n_elem = 8 / BITS;
+
+    static uint8_t get_q(const void * data, int idx) {
+        const uint8_t * qs = (const uint8_t *) data;
+        int elem_idx = idx % n_elem;
+        return qs[idx / n_elem] >> (elem_idx * BITS);
+    }
+
+    static tmac_float_type get_scale(const void * data, int idx, int group_size) {
+        const float * ss = (const float *) data;
+        float s = ss[idx / group_size];
+        return (tmac_float_type) s;
+    }
+};
+
+struct BlockTQ10TypeAccessor {
+    using block_t = block_tq1_0;
+
+    static constexpr int elements_qs = 5;    // 5 elements per byte
+    static constexpr int elements_qh = 4;    // 4 elements per byte
+    static constexpr int BITS = 2;
+    static constexpr int group_size_qs = sizeof(((block_t *)0)->qs) * elements_qs;
+    static constexpr int group_size_qh = sizeof(((block_t *)0)->qh) * elements_qh;
+    static constexpr int group_size = group_size_qs + group_size_qh;
+    static constexpr int SIMD_LEN_qs_1 = 32;
+    static constexpr int SIMD_LEN_qs_2 = 16;
+    static constexpr int SIMD_LEN_qh = 4;
+    static constexpr int simd_n_elem_qs_1 = SIMD_LEN_qs_1 * elements_qs;        // 160
+    static constexpr int simd_n_elem_qs_2 = SIMD_LEN_qs_2 * elements_qs;        // 80
+    static constexpr int simd_n_elem_qh = SIMD_LEN_qh * elements_qh;            // 16
+
+    static constexpr uint8_t pow3[5] = {1, 3, 9, 27, 81};
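+    // TQ1_0 stores 5 trits per byte; multiplying the byte by 3^i (mod 256) promotes the
+    // i-th trit so that (q * 3) >> 8 extracts it as a value in {0, 1, 2} (then +1 -> unsigned).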
+
+    static uint8_t get_q(const void * data, int idx) {
+        const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs);
+        uint8_t cur_qs;
+        uint8_t trit;
+        int internal_idx = idx % group_size;
+
+        if (internal_idx < simd_n_elem_qs_1) {
+            const int internal_offset = 0;
+            const uint8_t * simd_qs = qs + internal_offset;
+            int simd_idx = internal_idx;
+            int simd_byte = simd_idx % SIMD_LEN_qs_1;
+            int simd_trit = simd_idx / SIMD_LEN_qs_1;
+
+            cur_qs = simd_qs[simd_byte] * pow3[simd_trit];
+            trit = ((uint16_t) cur_qs * 3) >> 8;
+        }
+        else if (internal_idx < simd_n_elem_qs_1 + simd_n_elem_qs_2) {
+            const int internal_offset = SIMD_LEN_qs_1;
+            const uint8_t * simd_qs = qs + internal_offset;
+            int simd_idx = internal_idx - simd_n_elem_qs_1;
+            int simd_byte = simd_idx % SIMD_LEN_qs_2;
+            int simd_trit = simd_idx / SIMD_LEN_qs_2;
+
+            cur_qs = simd_qs[simd_byte] * pow3[simd_trit];
+            trit = ((uint16_t) cur_qs * 3) >> 8;
+        }
+        else {
+            const int internal_offset = SIMD_LEN_qs_1 + SIMD_LEN_qs_2;
+            const uint8_t * simd_qs = qs + internal_offset;
+            int simd_idx = internal_idx - simd_n_elem_qs_1 - simd_n_elem_qs_2;
+            int simd_byte = simd_idx % SIMD_LEN_qh;
+            int simd_trit = simd_idx / SIMD_LEN_qh;
+
+            cur_qs = simd_qs[simd_byte] * pow3[simd_trit];
+            trit = ((uint16_t) cur_qs * 3) >> 8;
+        }
+
+        return trit + 1;
+    }
+
+    static tmac_float_type get_scale(const void * data, int idx, int group_size) {
+        ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d;
+        if (sizeof(tmac_float_type) == 2) {
+            tmac_float_type * fp16dp = reinterpret_cast<tmac_float_type *>(&d);
+            return *fp16dp;
+        } else {
+            return ggml_fp16_to_fp32(d);
+        }
+    }
+};
+
+struct BlockTQ20TypeAccessor {
+    using block_t = block_tq2_0;
+
+    static constexpr int BITS = 2;
+    static constexpr int SIMD_LEN = 32;
+    static constexpr int group_size = (sizeof(block_t) - sizeof(ggml_fp16_t)) * 8 / BITS;   // 256
+    static constexpr int simd_n_elem = SIMD_LEN * 8 / BITS;                                 // 128
+
+    static uint8_t get_q(const void * data, int idx) {
+        const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs);
+        int internal_idx = idx % group_size;
+        const uint8_t * simd_qs = qs + internal_idx / simd_n_elem * SIMD_LEN;
+        int simd_idx = internal_idx % simd_n_elem;
+        return (simd_qs[simd_idx % SIMD_LEN] >> (simd_idx / SIMD_LEN * BITS)) + 1;
+    }
+
+    static tmac_float_type get_scale(const void * data, int idx, int group_size) {
+        ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d;
+        if (sizeof(tmac_float_type) == 2) {
+            tmac_float_type * fp16dp = reinterpret_cast<tmac_float_type *>(&d);
+            return *fp16dp;
+        } else {
+            return ggml_fp16_to_fp32(d);
+        }
+    }
+};
+
+bool ggml_tmac_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
+    if ((is_type_supported(src0->type)) &&
+        src1->type == GGML_TYPE_F32 &&
+        dst->type == GGML_TYPE_F32 &&
+        src0->backend == GGML_BACKEND_TYPE_CPU &&
+        strcmp(src0->name, "token_embd.weight") &&  // means not equal
+        strcmp(src0->name, "output.weight")) {
+        return true;
+    }
+    return false;
+}
+
+size_t ggml_tmac_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
+    const size_t ne01 = src0->ne[1];
+    const size_t ne10 = src1->ne[0];
+    const size_t ne11 = src1->ne[1];
+    const int bits = ggml_tmac_get_type_bits(src0->type);
+
+    TMAC::TMACGeMMConfig kcfg = wrapper->get_kcfg(ne01, ne10, 1, bits);
+
+    size_t wsize = ne10 * ne11 * 4 * sizeof(int8_t) + kcfg.lut_scales_size * ne11 * 2 * sizeof(tmac_float_type);
+    if (sizeof(tmac_float_type) == 2) {
+        // Need fp32 to fp16 conversion
+        wsize += std::max(ne10, ne01) * ne11 * sizeof(tmac_float_type);
+    }
+    wsize = ((wsize - 1) / TMAC::kAllocAlignment + 1) * TMAC::kAllocAlignment;
+    return wsize;
+}
+
+// m = batch_size
+// n = output_dim
+void ggml_tmac_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits) {
+    // T-MAC and llama.cpp swap the meaning of n and m
+    wrapper->llama_cpp_init(src1, qlut, lut_scales, lut_biases, n, k, m, bits);
+}
+
+void ggml_tmac_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits) {
+    wrapper->llama_cpp_compute(src0, scales, qlut, lut_scales, lut_biases, dst, n, k, m, bits);
+}
+
+size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor) {
+    const int bits = ggml_tmac_get_type_bits(tensor->type);
+
+    int k = tensor->ne[0];
+    int m = tensor->ne[1];  // `n` in llama.cpp
+
+    TMAC::TMACGeMMConfig kcfg = wrapper->get_kcfg(m, k, 1, bits);
+    // Currently, I2 always uses float to store scales or zero points
+    size_t nbytes = k * m / 8 * bits + kcfg.scales_size * sizeof(float);
+    return nbytes;
+}
+
+void ggml_tmac_transform_tensor(struct ggml_tensor * tensor) {
+    if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
+        return;
+    }
+
+    const int bits = ggml_tmac_get_type_bits(tensor->type);
+    const int g = 4;
+    const int ngroups_per_elem = 2;
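+    // Layout constants: each bit-plane of the weights is packed g = 4 elements (along K)
+    // per 4-bit group, and ngroups_per_elem = 2 such groups share one byte of qweights.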
+
+    int k = tensor->ne[0];
+    int m = tensor->ne[1];  // `n` in llama.cpp
+
+    TMAC::TMACGeMMConfig kcfg = wrapper->get_kcfg(m, k, 1, bits);
+    const int bm              = kcfg.bm;
+    const int simd_n_in       = kcfg.simd_n_in;
+    const int simd_n_out      = kcfg.simd_n_out;
+    const int kfactor         = kcfg.kfactor;
+    const int group_size      = kcfg.group_size;  // could be different from block size in llama.cpp
+    const int lut_scales_size = kcfg.lut_scales_size;
+    const int scales_size     = kcfg.scales_size;
+    const int n_tile_num      = kcfg.n_tile_num;
+    DLOG(INFO) << "Transforming tensor: " << tensor->name << " (m: " << m << ", k: " << k << ", bits: " << bits << ")";
+    DLOG(INFO) << "kcfg (bm=" << bm << ", simd_n_in=" << simd_n_in << ", simd_n_out=" << simd_n_out << ", kfactor=" << kfactor
+               << ", group_size=" << group_size << ", lut_scales_size=" << lut_scales_size << ", scales_size=" << scales_size << ", n_tile_num=" << n_tile_num << ")";
+    if (bm == 0) {
+        // TODO: warn if token_embd is not supported
+        if (!strcmp(tensor->name, "token_embd.weight") || !strcmp(tensor->name, "output.weight")) {
+            LOG(WARNING) << "Did not find kcfg for " << tensor->name << ". Consider compiling a T-MAC kernel for it if the vocab size is a multiple of 128 or 320; detected " << tensor->ne[1] << ".";
+            return;
+        }
+        else {
+            // TODO: instead of a fatal error, fall back to the non-T-MAC path?
+            LOG(FATAL) << "Failed to find kcfg. Aborting transform";
+            return;
+        }
+    }
+
+    const int mgroup = ngroups_per_elem * simd_n_in;
+    m = m * bits;
+
+    uint8_t * qweights;
+    tmac_float_type * scales;
+
+    if (do_permutate(tensor->type)) {
+        scales = (tmac_float_type *) aligned_malloc(scales_size * sizeof(tmac_float_type));
+        qweights = (uint8_t *) aligned_malloc(k * m / 8);
+    } else {
+        /* scales could be either float32 or float16, so an in-place cast is feasible. */
+        GGML_ASSERT(sizeof(tmac_float_type) <= sizeof(float));
+        qweights = (uint8_t *) tensor->data;
+        scales = (tmac_float_type *) (qweights + k * m / 8);
+        float * i2_scales = (float * )(qweights + k * m / 8);
+        for (int i = 0; i < scales_size; i++) {
+            scales[i] = (tmac_float_type) i2_scales[i];
+        }
+    }
+
+    tensor->extra = tmac_tensor_extras + tmac_tensor_extras_index;
+    tmac_tensor_extras[tmac_tensor_extras_index++] = {
+        /* .lut_scales_size = */ lut_scales_size,
+        /* .scales_size     = */ scales_size,
+        /* .n_tile_num      = */ n_tile_num,
+        /* .qweights        = */ qweights,
+        /* .scales          = */ scales
+    };
+
+    if (do_permutate(tensor->type)) {
+// for fast testing
+// #define TMAC_EMPTY_WEIGHTS
+#ifndef TMAC_EMPTY_WEIGHTS
+        // TODO: optimize to accelerate weight loading
+        uint8_t * buf1 = new uint8_t[m * k];
+        uint8_t * buf2 = new uint8_t[m * k / g];
+
+        // # (M // bits, K, bits)
+        // w = np.stack([(w >> ib) & 1 for ib in range(bits)], axis=-1)
+        for (int im = 0; im < m / bits; im++) {
+            for (int ik = 0; ik < k; ik++) {
+                uint8_t v;
+                if (tensor->type == GGML_TYPE_Q4_0) {
+                    v = BlockQ40TypeAccessor::get_q(tensor->data, im * k + ik);
+                } else if (tensor->type == GGML_TYPE_I2) {
+                    v = BlockI2TypeAccessor::get_q(tensor->data, im * k + ik);
+                } else if (tensor->type == GGML_TYPE_TQ1_0) {
+                    v = BlockTQ10TypeAccessor::get_q(tensor->data, im * k + ik);
+                } else if (tensor->type == GGML_TYPE_TQ2_0) {
+                    v = BlockTQ20TypeAccessor::get_q(tensor->data, im * k + ik);
+                } else {
+                    LOG(FATAL) << "Unsupported type";
+                }
+
+                for (int ib = 0; ib < bits; ib++) {
+                    buf1[im * k * bits + ik * bits + ib] = (v >> ib) & 1;
+                }
+            }
+        }
+
+        // # (M // bits, K, bits) -> (M // bits, bits, K) -> (M // bits, bits, K // g, g) -> (M // bits, bits, K // g)
+        // w = w.transpose(0, 2, 1).reshape(M // bits, bits, K // g, g)
+        // w = sum([(w[:, :, :, ig] << ig) for ig in range(g)])
+        memset(buf2, 0, m * k / g);
+        for (int im = 0; im < m / bits; im++) {
+            for (int ik = 0; ik < k; ik++) {
+                for (int ib = 0; ib < bits; ib++) {
+                    int new_im = im;
+                    int new_ib = ib;
+                    int new_ik = ik / g;
+                    int new_ig = ik % g;
+                    buf2[new_im * bits * k / g + new_ib * k / g + new_ik] += buf1[im * k * bits + ik * bits + ib] << new_ig;
+                }
+            }
+        }
+
+        // # 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31
+        // # for bits=3
+        // # bit0: [0, 8), bit1: [8, 16), bit2: [16, 24), bit0: [24, 32)
+        // # (M // bits // simd_n_float16, bits, simd_n_float16, K // g)
+        // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
+        // mgroup = ngroups_per_elem * simd_n_in
+        // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
+        // #             0        1             2             3                 4                  5
+        // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
+        // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
+        memset(qweights, 0, m * k / g / ngroups_per_elem);
+        for (int im = 0; im < m / bits; im++) {
+            for (int ib = 0; ib < bits; ib++) {
+                for (int ik = 0; ik < k / g; ik++) {
+                    int new_im = im / simd_n_out;
+                    int new_isno = im % simd_n_out;
+                    int new_ib = ib;
+                    int new_ik = ik;
+                    // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
+                    int new_idx = new_im * bits * simd_n_out * k / g + new_ib * simd_n_out * k / g + new_isno * k / g + new_ik;
+                    // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
+                    int nb2 = k / g;
+                    int nb1 = simd_n_in * nb2;
+                    int nb0 = ngroups_per_elem * nb1;
+                    new_im = new_idx / nb0;
+                    int new_ing = (new_idx % nb0) / nb1;
+                    int new_isni = (new_idx % nb1) / nb2;
+                    new_ik = (new_idx % nb2);
+                    new_idx = new_im * ngroups_per_elem * simd_n_in * k / g + new_isni * ngroups_per_elem * k / g + new_ing * k / g + new_ik;
+                    // #             0        1             2             3                 4                  5
+                    // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
+                    int nb4 = kfactor;
+                    int nb3 = k / g / kfactor * nb4;
+                    nb2 = ngroups_per_elem * nb3;
+                    nb1 = simd_n_in * nb2;
+                    nb0 = bm / mgroup * nb1;
+                    new_im = new_idx / nb0;
+                    int new_ibm = (new_idx % nb0) / nb1;
+                    new_isni = (new_idx % nb1) / nb2;
+                    new_ing = (new_idx % nb2) / nb3;
+                    new_ik = (new_idx % nb3) / nb4;
+                    int new_ikf = (new_idx % nb4);
+                    new_idx = new_im * k / g / kfactor * bm / mgroup * kfactor * simd_n_in * ngroups_per_elem +
+                            new_ik * bm / mgroup * kfactor * simd_n_in * ngroups_per_elem +
+                            new_ibm * kfactor * simd_n_in * ngroups_per_elem +
+                            new_ikf * simd_n_in * ngroups_per_elem +
+                            new_isni * ngroups_per_elem +
+                            new_ing;
+                    new_idx = new_idx / ngroups_per_elem;
+                    // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
+                    qweights[new_idx] += buf2[im * bits * k / g + ib * k / g + ik] << (new_ing * g);
+                }
+            }
+        }
+
+        const float * i2_scales = (const float * ) ((const uint8_t *) tensor->data + k * m / 8);
+        if (scales_size < m / bits) {  // BitNet-like scale (m_groups,)
+            for (int i = 0; i < scales_size; i++) {
+                scales[i] = (tmac_float_type) i2_scales[i];
+            }
+        } else {  // GPTQ-like scale (m / bits, k / group_size)
+            GGML_ASSERT(scales_size == m / bits * k / group_size);
+            // scales = scales.reshape(M // bm, bm // bits, K // group_size).transpose(0, 2, 1)
+            for (int im = 0; im < m / bits; im += 1) {
+                for (int ik = 0; ik < k; ik += group_size) {
+                    tmac_float_type scale;
+                    int idx = im * k + ik;
+                    if (tensor->type == GGML_TYPE_Q4_0) {
+                        scale = BlockQ40TypeAccessor::get_scale(tensor->data, idx);
+                    } else if (tensor->type == GGML_TYPE_I2) {
+                        scale = BlockI2TypeAccessor::get_scale(i2_scales, idx, group_size);
+                    } else if (tensor->type == GGML_TYPE_TQ1_0) {
+                        scale = BlockTQ10TypeAccessor::get_scale(tensor->data, idx, group_size);
+                    } else if (tensor->type == GGML_TYPE_TQ2_0) {
+                        scale = BlockTQ20TypeAccessor::get_scale(tensor->data, idx, group_size);
+                    } else {
+                        LOG(FATAL) << "Unsupported type";
+                    }
+                    int new_idx;
+                    idx = idx / group_size;
+                    int new_im = idx / (bm / bits * k / group_size);
+                    int new_ibm = (idx % (bm / bits * k / group_size)) / (k / group_size);
+                    int new_ik = (idx % (k / group_size));
+                    new_idx = new_im * k / group_size * bm / bits + new_ik * bm / bits + new_ibm;
+                    scales[new_idx] = scale;
+                }
+            }
+        }
+
+        delete[] buf1;
+        delete[] buf2;
+#else
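+        // Fallback: fill qweights with a fixed placeholder pattern and use unit scales,
+        // so the buffers stay well-defined when the real preprocessing path is compiled out.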
+        memset(qweights, 0x88, k * m / 8);
+        for (int i = 0; i < scales_size; i++) {
+            scales[i] = 1.0f;
+        }
+#endif
+    }  // if (do_permutate(tensor->type))
+}
+
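+// Number of weight bits T-MAC packs for each supported type (0 = not a T-MAC type).
+// The ternary TQ1_0/TQ2_0 types are handled as 2-bit inside T-MAC.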
+int ggml_tmac_get_type_bits(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_I1:
+            return 1;
+        case GGML_TYPE_I2:
+            return 2;
+        case GGML_TYPE_I3:
+            return 3;
+        case GGML_TYPE_I4:
+            return 4;
+        case GGML_TYPE_Q4_0:
+            return 4;
+        case GGML_TYPE_TQ1_0:
+            return 2;
+        case GGML_TYPE_TQ2_0:
+            return 2;
+        default:
+            return 0;
+    }
+}
+
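+// Forward the thread count to the T-MAC/TVM threadpool (used when TMAC_USE_TVM_THREADPOOL
+// is enabled; ggml then runs the graph single-threaded and T-MAC parallelizes internally).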
+void ggml_tmac_set_n_threads(int n_threads) {
+    wrapper->set_num_threads(n_threads);
+}
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 7dc3340a1e749..1bc193bf442e7 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -8,6 +8,10 @@
 #include "ggml.h"
 #include "ggml-aarch64.h"
 
+#if defined(GGML_USE_TMAC)
+#include "ggml-tmac.h"
+#endif
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -541,6 +545,30 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
 static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
 
 static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
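+    // T-MAC low-bit weight types: blck_size weights are packed into type_size bytes.
+    // Scales/zero-points are stored out of band, so the true tensor size comes from
+    // ggml_tmac_get_nbytes() (see the GGML_USE_TMAC branch in ggml_nbytes).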
+    [GGML_TYPE_I1] = {
+        .type_name                = "i1",
+        .blck_size                = 8,
+        .type_size                = sizeof(int8_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_I2] = {
+        .type_name                = "i2",
+        .blck_size                = 4,
+        .type_size                = sizeof(int8_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_I3] = {
+        .type_name                = "i3",
+        .blck_size                = 2,
+        .type_size                = sizeof(int8_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_I4] = {
+        .type_name                = "i4",
+        .blck_size                = 2,
+        .type_size                = sizeof(int8_t),
+        .is_quantized             = false,
+    },
     [GGML_TYPE_I8] = {
         .type_name                = "i8",
         .blck_size                = 1,
@@ -1161,6 +1189,14 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) {
             nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
         }
     }
+#if defined(GGML_USE_TMAC)
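+    // The generic size computed above does not include the out-of-band scales of the
+    // T-MAC types, so ask T-MAC for the actual number of bytes.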
+    if (tensor->type == GGML_TYPE_I1 ||
+        tensor->type == GGML_TYPE_I2 ||
+        tensor->type == GGML_TYPE_I3 ||
+        tensor->type == GGML_TYPE_I4) {
+        nbytes = ggml_tmac_get_nbytes(tensor);
+    }
+#endif
 
     return nbytes;
 }
@@ -1417,6 +1453,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             } u = {i};
             ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
         }
+
+#if defined(GGML_USE_TMAC)
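+        // one-time T-MAC initialization, done with the rest of the first-call setup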
+        ggml_tmac_init();
+#endif
+
         is_first_call = true;
     }
 
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 7ab08b036e527..ed52e0c2100fb 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -1404,6 +1404,10 @@ class GGMLQuantizationType(IntEnum):
     Q4_0_8_8 = 33
     TQ1_0   = 34
     TQ2_0   = 35
+    I1      = 36
+    I2      = 37
+    I3      = 38
+    I4      = 39
 
 
 # TODO: add GGMLFileType from ggml_ftype in ggml.h
@@ -1450,6 +1454,7 @@ class LlamaFileType(IntEnum):
     MOSTLY_Q4_0_8_8      = 35  # except 1d tensors
     MOSTLY_TQ1_0         = 36  # except 1d tensors
     MOSTLY_TQ2_0         = 37  # except 1d tensors
+    MOSTLY_INT_N         = 38  # except 1d tensors
 
     GUESSED              = 1024  # not specified in the model file
 
@@ -1528,6 +1533,15 @@ def get_type(val: Any) -> GGUFValueType:
     GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
     GGMLQuantizationType.TQ1_0:   (256, 2 + 4 * 13),
     GGMLQuantizationType.TQ2_0:   (256, 2 + 64),
+    # Note: the sizes below are intentionally approximate
+    # - The block size does not include scales or zero_points, because group_size is configurable
+    # - The reported size is therefore slightly smaller than the real size
+    # - The n_bytes computed in gguf_reader.py is correspondingly inaccurate for these types
+    # - During inference, the accurate nbytes is obtained through ggml_tmac_get_nbytes
+    GGMLQuantizationType.I1:      (8, 1),
+    GGMLQuantizationType.I2:      (4, 1),
+    GGMLQuantizationType.I3:      (8, 3),
+    GGMLQuantizationType.I4:      (2, 1),
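+    # (these entries give 1/2/3/4 bits per weight: e.g. I2 packs 4 weights per 1-byte block, I3 packs 8 weights per 3 bytes)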
 }
 
 
diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py
index 3c8ba82e19d3d..6f2c5e10d881f 100644
--- a/gguf-py/gguf/quants.py
+++ b/gguf-py/gguf/quants.py
@@ -60,6 +60,15 @@ def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
         return data.astype(np.float16, copy=False)
     elif (q := _type_traits.get(qtype)) is not None:
         return q.quantize(data)
+    # Do nothing for I1/2/3/4, as they are already quantized
+    elif qtype in (
+        GGMLQuantizationType.I1,
+        GGMLQuantizationType.I2,
+        GGMLQuantizationType.I3,
+        GGMLQuantizationType.I4,
+    ):
+        return data
     else:
         raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented")
 
diff --git a/include/llama.h b/include/llama.h
index ccb48f73cef5c..825d282d35495 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -176,6 +176,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q4_0_8_8      = 35, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_TQ1_0         = 36, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_TQ2_0         = 37, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_INT_N         = 38, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
diff --git a/src/llama.cpp b/src/llama.cpp
index 3e563d811b77c..7529f3e04c33a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9,6 +9,10 @@
 #include "ggml-backend.h"
 #include "ggml-cpp.h"
 
+#ifdef GGML_USE_TMAC
+#  include "ggml-tmac.h"
+#endif
+
 // TODO: replace with ggml API call
 #define QK_K 256
 
@@ -4434,6 +4438,10 @@ struct llama_model_loader {
                 case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break;
                 case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break;
                 case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break;
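+                // all of the T-MAC low-bit types are reported as the single INT_N file type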
+                case GGML_TYPE_I1:      ftype = LLAMA_FTYPE_MOSTLY_INT_N;   break;
+                case GGML_TYPE_I2:      ftype = LLAMA_FTYPE_MOSTLY_INT_N;   break;
+                case GGML_TYPE_I3:      ftype = LLAMA_FTYPE_MOSTLY_INT_N;   break;
+                case GGML_TYPE_I4:      ftype = LLAMA_FTYPE_MOSTLY_INT_N;   break;
                 default:
                     {
                         LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@@ -5032,6 +5040,11 @@ struct llama_model_loader {
             }
 
             size_done += n_size;
+
+#if defined(GGML_USE_TMAC)
+            // Pre-transform T-MAC tensors at load time to reduce first-run latency
+            ggml_tmac_transform_tensor(cur);
+#endif
         }
 
         // free temporary resources used for async uploads
@@ -5171,6 +5184,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_Q5_0:     return "Q5_0";
         case LLAMA_FTYPE_MOSTLY_Q5_1:     return "Q5_1";
         case LLAMA_FTYPE_MOSTLY_Q8_0:     return "Q8_0";
+        case LLAMA_FTYPE_MOSTLY_INT_N:    return "INT_N";
         case LLAMA_FTYPE_MOSTLY_Q2_K:     return "Q2_K - Medium";
         case LLAMA_FTYPE_MOSTLY_Q2_K_S:   return "Q2_K - Small";
         case LLAMA_FTYPE_MOSTLY_Q3_K_S:   return "Q3_K - Small";
@@ -17183,6 +17197,13 @@ static void llama_graph_compute(
             ggml_cgraph * gf,
                     int   n_threads,
         ggml_threadpool * threadpool) {
+#if defined(GGML_USE_TMAC) && defined(TMAC_USE_TVM_THREADPOOL)
+    // the TVM threadpool manages its own worker threads, so pass it the thread count
+    // and run the ggml graph single-threaded
+    ggml_tmac_set_n_threads(n_threads);
+    n_threads = 1;
+#endif
+
     if (lctx.backend_cpu != nullptr) {
         ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
         ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
@@ -18747,6 +18768,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
                 new_type = params->output_tensor_type;
             }
+            if (tensor->type == GGML_TYPE_I1 ||
+                tensor->type == GGML_TYPE_I2 ||
+                tensor->type == GGML_TYPE_I3 ||
+                tensor->type == GGML_TYPE_I4) {
+                // T-MAC IN tensors are already quantized; keep the existing type
+                new_type = tensor->type;
+            }
 
             // If we've decided to quantize to the same type the tensor is already
             // in then there's nothing to do.