@@ -725,12 +725,11 @@ def _load_state_dict_into_meta_model(
725725 device_map_regex = "|" .join ([re .escape (k ) for k in sorted (device_map .keys (), reverse = True )])
726726
727727 is_quantized = hf_quantizer is not None
728- is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer .quantization_config .quant_method in {
728+ is_hqq_or_bnb = is_quantized and hf_quantizer .quantization_config .quant_method in {
729729 QuantizationMethod .HQQ ,
730730 QuantizationMethod .BITS_AND_BYTES ,
731- QuantizationMethod .QUARK ,
732731 }
733- is_meta_state_dict = shard_file .endswith (".safetensors" ) and not is_hqq_or_bnb_or_quark
732+ is_meta_state_dict = shard_file .endswith (".safetensors" ) and not is_hqq_or_bnb
734733 file_pointer = None
735734 if is_meta_state_dict :
736735 file_pointer = safe_open (shard_file , framework = "pt" , device = tensor_device )
@@ -4701,10 +4700,9 @@ def _load_pretrained_model(
47014700 QuantizationMethod .HQQ ,
47024701 QuantizationMethod .QUARK ,
47034702 }
4704- is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer .quantization_config .quant_method in {
4703+ is_hqq_or_bnb = is_quantized and hf_quantizer .quantization_config .quant_method in {
47054704 QuantizationMethod .HQQ ,
47064705 QuantizationMethod .BITS_AND_BYTES ,
4707- QuantizationMethod .QUARK ,
47084706 }
47094707
47104708 # Get all the keys of the state dicts that we have to initialize the model
@@ -4881,7 +4879,7 @@ def _load_pretrained_model(
48814879 map_location = "cpu"
48824880 if (
48834881 shard_file .endswith (".safetensors" )
4884- and not is_hqq_or_bnb_or_quark
4882+ and not is_hqq_or_bnb
48854883 and not (is_deepspeed_zero3_enabled () and not is_quantized )
48864884 ):
48874885 map_location = "meta"
0 commit comments