import os

# --- Environment Variables for Debugging and Optimization ---
# These settings are preserved from the original script to provide detailed logs
# and enable optimizations within PyTorch, ZenTorch, Quark, and underlying libraries.

# Enhanced logging configuration for ZenTorch 5.1
#os.environ["TORCHDYNAMO_VERBOSE"] = "1"
#os.environ["TRANSFORMERS_VERBOSITY"] = "debug"
#os.environ["TRANSFORMERS_DEBUG"] = "1"
#os.environ["TORCH_LOGS"] = "+dynamo,graph_code,graph_breaks"
#os.environ["TORCH_COMPILE_DEBUG"] = "1"

# ZenTorch 5.1 specific environment variables
os.environ["ZENTORCH_VERSION"] = "5.1"
#os.environ["ZENTORCH_PY_LOG_LEVEL"] = "DEBUG"
os.environ["ZENTORCH_VERBOSE"] = "1"
#os.environ["ZENTORCH_LOG_LEVEL"] = "debug"
os.environ["ZENTORCH_OPTIMIZATION_LEVEL"] = "3"
#os.environ["ZENTORCH_ENABLE_GRAPH_OPTIMIZATION"] = "1"

# ZenDNN configuration
os.environ["ZENDNN_PRIMITIVE_LOG_ENABLE"] = "1"
os.environ["ZENDNN_LOG_OPTS"] = "ALL:3"
os.environ["ZENDNN_VERBOSE"] = "1"
os.environ["ZENDNN_DEBUG"] = "1"
#os.environ["ZENDNN_GEMM_ALGO"] = "3"
os.environ["ZENDNN_MATMUL_ALGO"] = "INT8:4,FP32:3"  # "BF16:0"

# BLIS configuration
os.environ["BLIS_VERBOSE"] = "1"
os.environ["BLIS_NUM_THREADS"] = "64"
os.environ["OMP_NUM_THREADS"] = "64"

# Quark debugging and optimization
os.environ["QUARK_DEBUG"] = "1"
os.environ["QUARK_DEBUG_NAN"] = "1"
os.environ["QUARK_OPTIMIZATION_LEVEL"] = "2"
os.environ["QUARK_ENABLE_PROFILING"] = "1"
os.environ["QUARK_LOG_LEVEL"] = "DEBUG"

# --- Main Script Imports ---
try:
    import quark
    from quark.torch.quantization import QuantizationConfig
    print("✓ Quark imports successful")
except ImportError as e:
    print(f"⚠ Quark import failed: {e}. Please ensure Quark is installed correctly.")
    # Exit if Quark is not available, as it is critical for this script.
    exit()

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import torch
import zentorch
import time
import numpy as np

# Verify library versions
print(f"PyTorch version: {torch.__version__}")
print(f"ZenTorch version: {zentorch.__version__}")
print(f"Transformers version: {getattr(__import__('transformers'), '__version__', 'unknown')}")

# --- Configuration ---
#MODEL_DIR = "/models"  # Directory containing the base model files
MODEL_DIR = "/media/nvme1/models/hf/int8/epoch_0/"
torch_device = torch.device("cpu")


def load_quantized_model():
    """
    Load a pre-quantized model from disk.

    The quantization configuration is read directly from the model's config.json file.
    """
    try:
        print("Loading pre-quantized model...")
        # The transformers library reads the quantization_config from config.json
        # and loads the model accordingly. torch_dtype sets the data type for the
        # parts of the model that are not quantized (e.g., embeddings) and for
        # the activation compute data type.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_DIR,
            torch_dtype=torch.float32,
            device_map="cpu",  # "auto" if torch.cuda.is_available() else "cpu"
            trust_remote_code=False,
        )
        model.eval()
        print("✓ Pre-quantized model loaded successfully.")
        return model
    except Exception as e:
        print(f"Error loading pre-quantized model: {e}")
        print(f"Could not load the model. Please check the model files and config in {MODEL_DIR}.")
        exit(1)  # Exit with an error code


def optimize_with_zentorch(model):
    """Apply ZenTorch 5.1 optimizations to the quantized model."""
    print("Applying ZenTorch 5.1 optimizations...")
    try:
        # Unsupported keyword arguments have been removed; ZenTorch proceeds with
        # its default optimizations for the given dtype.
        optimized_model = zentorch.llm.optimize(model, dtype=torch.float32)
        print("✓ ZenTorch optimization completed successfully")
        return optimized_model
    except Exception as e:
        print(f"⚠ ZenTorch optimization failed: {e}")
        print("Exiting as ZenTorch optimization is required and fallback is disabled.")
        exit(1)


def main():
    """Main execution function."""
    # Load the pre-quantized model
    q_model = load_quantized_model()

    # Load the tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=False)
    # Add a padding token if it is missing to avoid errors during tokenization
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Optimize the quantized model with ZenTorch 5.1
    q_model = optimize_with_zentorch(q_model)

    # Prepare input for inference
    question = "What is 2+2?"
    print(f"Processing question: {question}")
    inputs = tokenizer(
        question,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    input_ids = inputs["input_ids"].to(torch_device)
    attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)).to(torch_device)

    # --- Performance timing and inference ---
    start_time = time.time()

    # Generate a response using the optimized W8A8 model.
    # torch.inference_mode() already disables autograd, so a nested torch.no_grad()
    # is unnecessary.
    with torch.inference_mode():
        q_model.forward = torch.compile(q_model.forward, backend="zentorch")
        output = q_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )

    end_time = time.time()

    # --- Decode and display results ---
    response = tokenizer.decode(output[0], skip_special_tokens=True)

    # Calculate tokens per second (TPS)
    elapsed_time = end_time - start_time
    num_input_tokens = input_ids.shape[1]
    num_output_tokens = output.shape[1]
    num_new_tokens = num_output_tokens - num_input_tokens
    tokens_per_second = num_new_tokens / elapsed_time

    print("\n--- Results ---")
    print(f"Prompt: {question}")
    print(f"Response: {response}")
    print(f"Generation time: {elapsed_time:.2f} seconds")

    # Performance metrics
    print("\n--- Performance ---")
    print(f"Generated {num_new_tokens} new tokens in {elapsed_time:.2f} seconds.")
    print(f"Tokens Per Second (TPS): {tokens_per_second:.2f}")

    if torch.cuda.is_available():
        print(f"GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

    print("\n✓ W8A8 inference completed successfully.")


if __name__ == "__main__":
    main()