# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import copy
import os

import yaml

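# Generate a docker-compose.yaml from an OPEA mega-service YAML description.
#
# Expected input layout, inferred from the lookups this script performs
# (a sketch, not a formal schema):
#
#   environment_variables:
#     common: {...}                  # shared by every generated service
#     HUGGINGFACEHUB_API_TOKEN: ...
#   opea_micro_services:
#     <service_name>:                # embedding, retrieval, reranking, llm, ...
#       <image_name>:                # e.g. opea/embedding-tei
#         tag: <image tag>
#         environment: {...}
#         volume: [...]
#         dependency:
#           <dependency image>:
#             tag: <image tag>
#             requirements: {model_id: <model>}
#   opea_mega_service:
#     opea/<example>: {tag: ..., environment: {...}}
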
def convert_to_docker_compose(mega_yaml, output_file, device="cpu"):
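    """Convert a mega-service YAML description into a docker-compose file.

    Args:
        mega_yaml: Path to the input mega-service YAML file.
        output_file: Path of the docker-compose YAML file to write.
        device: Target device, "cpu" or "gaudi"; selects which inference
            backends (TEI/TGI or their Gaudi variants) are emitted.
    """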
    with open(mega_yaml, "r") as f:
        mega_config = yaml.safe_load(f)

    services = {}
    env_vars = mega_config.get("environment_variables", {})

    # Per-service renames: map a variable name used in the mega YAML to the
    # name the container actually expects.
    env_var_rename = {"data_prep": {"TEI_EMBEDDING_ENDPOINT": "TEI_ENDPOINT"}}

    for service_name, service_config in mega_config["opea_micro_services"].items():
        for container_name, container_info in service_config.items():
            safe_container_name = container_name.replace("/", "-")

            # Start from the 'common' variables, then layer on the
            # service-specific environment (the YAML anchors such as redis,
            # tei_embedding, etc.).
            environment = copy.deepcopy(env_vars.get("common", {}))
            environment.update(container_info.get("environment", {}))

            # Apply the per-service renames from env_var_rename.
            renamed_environment = {}
            for key, value in environment.items():
                if key in env_var_rename.get(service_name, {}):
                    renamed_environment[env_var_rename[service_name][key]] = value
                else:
                    renamed_environment[key] = value

            # Resolve ${VAR} placeholders from the host environment, keeping
            # the placeholder verbatim when the variable is unset.
            for key in renamed_environment:
                value = renamed_environment[key]
                if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                    renamed_environment[key] = os.getenv(value[2:-1], value)

            service_entry = {
                "image": f"{container_name}:{container_info['tag']}",
                "container_name": f"{safe_container_name}-server",
                "ports": [],
                "ipc": "host",
                "restart": "unless-stopped",
                "environment": renamed_environment,
            }

            # Expose the fixed host port for each micro-service type
            if service_name == "embedding":
                service_entry["ports"].append("6000:6000")
            elif service_name == "retrieval":
                service_entry["ports"].append("7000:7000")
            elif service_name == "reranking":
                service_entry["ports"].append("8000:8000")
            elif service_name == "llm":
                service_entry["ports"].append("9000:9000")

            # Wire up dependencies between generated services
            if container_name == "opea/dataprep-redis":
                service_entry["depends_on"] = ["redis-vector-db"]
                service_entry["ports"].append("6007:6007")
            elif container_name == "opea/embedding-tei":
                # Must match the key this script registers for the TEI backend below
                service_entry["depends_on"] = ["text-embeddings-inference-service"]

            # Mount any volumes declared for the container
            if "volume" in container_info:
                service_entry["volumes"] = container_info["volume"]

            services[safe_container_name] = service_entry

    # Additional backing services, e.g. redis
    services["redis-vector-db"] = {
        "image": "redis/redis-stack:7.2.0-v9",
        "container_name": "redis-vector-db",
        "ports": ["6379:6379", "8001:8001"],
    }

    # Embedding backend: TEI on CPU, or its Gaudi variant
    embedding_service = mega_config["opea_micro_services"].get("embedding", {}).get("opea/embedding-tei", {})
    if embedding_service:
        embedding_dependencies = embedding_service.get("dependency", {})
        for dep_name, dep_info in embedding_dependencies.items():
            if dep_name == "ghcr.io/huggingface/text-embeddings-inference" and device == "cpu":
                model_id = dep_info.get("requirements", {}).get("model_id", "")
                services["text-embeddings-inference-service"] = {
                    "image": f"{dep_name}:{dep_info['tag']}",
                    "container_name": "text-embeddings-inference-server",
                    "ports": ["8090:80"],
                    "ipc": "host",
                    "environment": {
                        **env_vars.get("common", {}),
                        "HUGGINGFACEHUB_API_TOKEN": env_vars.get("HUGGINGFACEHUB_API_TOKEN", ""),
                    },
                    "command": f"--model-id {model_id} --auto-truncate",
                }
            elif dep_name == "opea/tei-gaudi" and device == "gaudi":
                model_id = dep_info.get("requirements", {}).get("model_id", "")
                services["text-embeddings-inference-service"] = {
                    "image": f"{dep_name}:{dep_info['tag']}",
                    "container_name": "text-embeddings-inference-server",
                    "ports": ["8090:80"],
                    "ipc": "host",
                    "environment": {
                        **env_vars.get("common", {}),
                        "HUGGINGFACEHUB_API_TOKEN": env_vars.get("HUGGINGFACEHUB_API_TOKEN", ""),
                    },
                    "command": f"--model-id {model_id} --auto-truncate",
                }
                # Habana (Gaudi) specific runtime settings
                services["text-embeddings-inference-service"]["runtime"] = "habana"
                services["text-embeddings-inference-service"]["cap_add"] = ["SYS_NICE"]
                services["text-embeddings-inference-service"]["environment"].update(
                    {
                        "HABANA_VISIBLE_DEVICES": "all",
                        "OMPI_MCA_btl_vader_single_copy_mechanism": "none",
                        "MAX_WARMUP_SEQUENCE_LENGTH": "512",
                        "INIT_HCCL_ON_ACQUIRE": "0",
                        "ENABLE_EXPERIMENTAL_FLAGS": "true",
                    }
                )

    # Reranking backend: TEI on CPU, or its Gaudi variant
    reranking_service = mega_config["opea_micro_services"].get("reranking", {}).get("opea/reranking-tei", {})
    if reranking_service:
        rerank_dependencies = reranking_service.get("dependency", {})
        for dep_name, dep_info in rerank_dependencies.items():
            if dep_name == "ghcr.io/huggingface/text-embeddings-inference" and device == "cpu":
                model_id = dep_info.get("requirements", {}).get("model_id", "")
                services["tei-reranking-service"] = {
                    "image": f"{dep_name}:{dep_info['tag']}",
                    "container_name": "tei-reranking-server",
                    "ports": ["8808:80"],
                    "volumes": ["./data:/data"],
                    "shm_size": "1g",
                    "environment": {
                        **env_vars.get("common", {}),
                        "HUGGINGFACEHUB_API_TOKEN": env_vars.get("HUGGINGFACEHUB_API_TOKEN", ""),
                        "HF_HUB_DISABLE_PROGRESS_BARS": "1",
                        "HF_HUB_ENABLE_HF_TRANSFER": "0",
                    },
                    "command": f"--model-id {model_id} --auto-truncate",
                }
            elif dep_name == "opea/tei-gaudi" and device == "gaudi":
                model_id = dep_info.get("requirements", {}).get("model_id", "")
                services["tei-reranking-service"] = {
                    "image": f"{dep_name}:{dep_info['tag']}",
                    "container_name": "tei-reranking-gaudi-server",
                    "ports": ["8808:80"],
                    "volumes": ["./data:/data"],
                    "shm_size": "1g",
                    "environment": {
                        **env_vars.get("common", {}),
                        "HUGGINGFACEHUB_API_TOKEN": env_vars.get("HUGGINGFACEHUB_API_TOKEN", ""),
                        "HF_HUB_DISABLE_PROGRESS_BARS": "1",
                        "HF_HUB_ENABLE_HF_TRANSFER": "0",
                    },
                    "command": f"--model-id {model_id} --auto-truncate",
                }
                # Habana (Gaudi) specific runtime settings
                services["tei-reranking-service"]["runtime"] = "habana"
                services["tei-reranking-service"]["cap_add"] = ["SYS_NICE"]
                services["tei-reranking-service"]["environment"].update(
                    {
                        "HABANA_VISIBLE_DEVICES": "all",
                        "OMPI_MCA_btl_vader_single_copy_mechanism": "none",
                        "MAX_WARMUP_SEQUENCE_LENGTH": "512",
                        "INIT_HCCL_ON_ACQUIRE": "0",
                        "ENABLE_EXPERIMENTAL_FLAGS": "true",
                    }
                )

    # LLM backend: TGI on CPU, or its Gaudi variant
    llm_service = mega_config["opea_micro_services"].get("llm", {}).get("opea/llm-tgi", {})
    if llm_service:
        llm_dependencies = llm_service.get("dependency", {})
        for dep_name, dep_info in llm_dependencies.items():
            if dep_name == "ghcr.io/huggingface/text-generation-inference" and device == "cpu":
                model_id = dep_info.get("requirements", {}).get("model_id", "")
                services["llm-service"] = {
                    "image": f"{dep_name}:{dep_info['tag']}",
                    "container_name": "llm-server",
                    "ports": ["9001:80"],
                    "environment": {
                        **env_vars.get("common", {}),
                        "HUGGINGFACEHUB_API_TOKEN": env_vars.get("HUGGINGFACEHUB_API_TOKEN", ""),
                    },
                    "command": f"--model-id {model_id} --max-input-length 1024 --max-total-tokens 2048",
                }
            elif dep_name == "ghcr.io/huggingface/tgi-gaudi" and device == "gaudi":
                model_id = dep_info.get("requirements", {}).get("model_id", "")
                services["llm-service"] = {
                    "image": f"{dep_name}:{dep_info['tag']}",
                    "container_name": "llm-server",
                    "ports": ["9001:80"],
                    "environment": {
                        **env_vars.get("common", {}),
                        "HUGGINGFACEHUB_API_TOKEN": env_vars.get("HUGGINGFACEHUB_API_TOKEN", ""),
                    },
                    "command": f"--model-id {model_id} --max-input-length 1024 --max-total-tokens 2048",
                }
                # Habana (Gaudi) specific runtime settings
                services["llm-service"]["runtime"] = "habana"
                services["llm-service"]["cap_add"] = ["SYS_NICE"]
                services["llm-service"]["environment"].update(
                    {
                        "HABANA_VISIBLE_DEVICES": "all",
                        "OMPI_MCA_btl_vader_single_copy_mechanism": "none",
                    }
                )

    # Mega-service and UI containers for each supported example, taken from
    # 'opea_mega_service'
    examples = ["chatqna", "faqgen", "audioqna", "visualqna", "codegen", "codetrans"]
    for example in examples:
        service_name = f"opea/{example}"
        ui_service_name = f"opea/{example}-ui"

        # Process both the main service and its UI, when declared in the YAML
        for service in [service_name, ui_service_name]:
            if service in mega_config.get("opea_mega_service", {}):
                service_config = mega_config["opea_mega_service"][service]
                safe_container_name = service.replace("/", "-")
                tag = service_config.get("tag", "latest")
                environment = {**env_vars.get("common", {}), **service_config.get("environment", {})}

                services[safe_container_name] = {
                    "image": f"{service}:{tag}",
                    "container_name": f"{safe_container_name}-server",
                    "ports": ["5173:5173"] if "-ui" in service else ["8888:8888"],
                    "ipc": "host",
                    "restart": "unless-stopped",
                    "environment": environment,
                }

    docker_compose = {
        "version": "3.8",
        "services": services,
        "networks": {"default": {"driver": "bridge"}},
    }

    # Write the generated compose definition
    with open(output_file, "w") as f:
        yaml.dump(docker_compose, f, default_flow_style=False)

    print(f"Docker Compose file generated: {output_file}")
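

# Minimal CLI wrapper: a sketch of how the converter might be invoked. The
# flag names and default paths below are assumptions for illustration, not
# part of the original script's documented interface.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Convert an OPEA mega-service YAML into a docker-compose file.")
    parser.add_argument("mega_yaml", nargs="?", default="mega.yaml", help="input mega-service YAML file")
    parser.add_argument("-o", "--output", default="docker-compose.yaml", help="output docker-compose file")
    parser.add_argument("--device", choices=["cpu", "gaudi"], default="cpu", help="target device")
    args = parser.parse_args()

    convert_to_docker_compose(args.mega_yaml, args.output, device=args.device)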