Add prestartup script #1136
Changes from 2 commits (base: main)
@@ -0,0 +1,17 @@

```python
import os
import sys
import importlib

ONEDIFF_COMFY_NODES_DIR = os.path.dirname(os.path.abspath(__file__))
ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR = os.path.join(
    ONEDIFF_COMFY_NODES_DIR, "prestartup_scripts"
)

sys.path.append(ONEDIFF_COMFY_NODES_DIR)

for filename in sorted(os.listdir(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR)):
    if filename.endswith(".py") and filename[0] != "_":
        importlib.import_module(f"prestartup_scripts.{filename[:-3]}")
    elif filename.endswith(".so"):
        importlib.import_module(f"prestartup_scripts.{filename.split('.')[0]}")
```
@@ -0,0 +1,5 @@

```python
try:
    import torch_gcu
    import torch_gcu.transfer_to_gcu
except:
    pass
```
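No review comment targets this file, but the same hardening the reviewer suggests for the NPU check below would apply here as well. A hedged sketch, not part of the PR:

```python
# Sketch: same availability check with a narrow exception and a log line.
import logging

try:
    import torch_gcu  # noqa: F401
    import torch_gcu.transfer_to_gcu  # noqa: F401  (side-effect import)
except ImportError as e:
    logging.info("GCU support not available: %s", e)
```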
@@ -0,0 +1,59 @@

```python
_IS_NPU_AVAILABLE = False
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu

    _IS_NPU_AVAILABLE = True
except:
    pass
```
Comment on lines +1 to +9:

🛠️ Refactor suggestion: Improve error handling and imports

Two improvements are needed in the NPU availability check: the bare `except` silently swallows every failure, and `transfer_to_npu` is imported but never referenced. Apply this diff:

```diff
 _IS_NPU_AVAILABLE = False
 try:
     import torch_npu
-    from torch_npu.contrib import transfer_to_npu
     _IS_NPU_AVAILABLE = True
-except:
+except ImportError as e:
+    import logging
+    logging.info(f"NPU support not available: {e}")
     pass
```

🧰 Tools: 🪛 Ruff

4-4: Remove unused import (F401)
7-7: Do not use bare `except` (E722)
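One caveat on the suggested diff: `transfer_to_npu` is typically imported for its side effect (it redirects `torch.cuda` calls to the NPU backend), so dropping it may change behavior. A hedged sketch of a variant that keeps the import while still narrowing the exception; this is not part of the PR:

```python
_IS_NPU_AVAILABLE = False
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu  # noqa: F401  (side-effect import)

    _IS_NPU_AVAILABLE = True
except ImportError as e:
    import logging

    logging.info(f"NPU support not available: {e}")
```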
```python
if _IS_NPU_AVAILABLE:
    import comfy
    from comfy.model_management import (
        is_device_cpu,
        is_intel_xpu,
        ENABLE_PYTORCH_ATTENTION,
    )

    torch_npu.npu.set_compile_mode(jit_compile=False)

    def patch_pytorch_attention_flash_attention():
        if ENABLE_PYTORCH_ATTENTION:
            return True
        return False

    def patch_get_free_memory(dev=None, torch_free_too=False):
        # stats = torch.npu.memory_stats(dev)
        # mem_active = stats['active_bytes.all.current']
        # mem_reserved = stats['reserved_bytes.all.current']
        # mem_free_npu, _ = torch.npu.mem_get_info(dev)
        # mem_free_torch = mem_reserved - mem_active
        # mem_free_total = mem_free_npu + mem_free_torch
        mem_free_total = 48 * 1024 * 1024 * 1024  # TODO
        mem_free_torch = mem_free_total

        if torch_free_too:
            return (mem_free_total, mem_free_torch)
        else:
            return mem_free_total
```
Comment on lines +26 to +40:

💡 Codebase verification — Critical: Implement proper NPU memory management

The function returns a hardcoded 48 GB instead of querying the memory actually available on the NPU; the real calculation is left commented out with a TODO. The suggested fix restores the query via `torch.npu.memory_stats` and `torch.npu.mem_get_info`, clamps the derived values to be non-negative, and fails loudly when the device cannot be queried rather than reporting memory that may not exist. Please implement proper memory calculation:

```diff
 def patch_get_free_memory(dev=None, torch_free_too=False):
-    # stats = torch.npu.memory_stats(dev)
-    # mem_active = stats['active_bytes.all.current']
-    # mem_reserved = stats['reserved_bytes.all.current']
-    # mem_free_npu, _ = torch.npu.mem_get_info(dev)
-    # mem_free_torch = mem_reserved - mem_active
-    # mem_free_total = mem_free_npu + mem_free_torch
-    mem_free_total = 48 * 1024 * 1024 * 1024  # TODO
-    mem_free_torch = mem_free_total
+    try:
+        stats = torch.npu.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_npu, _ = torch.npu.mem_get_info(dev)
+        mem_free_torch = max(0, mem_reserved - mem_active)
+        mem_free_total = max(0, mem_free_npu + mem_free_torch)
+    except Exception as e:
+        import logging
+        logging.warning(f"Failed to get NPU memory info: {e}")
+        # Fallback to a conservative estimate or raise error
+        raise RuntimeError("Cannot determine available NPU memory")
```

🔗 Analysis chain

The following scripts were executed to gather context: the first locates other memory-related implementations, the second looks for the original `get_free_memory` this patch overrides and for memory-related constants, and the third checks for NPU-specific memory management and other uses of the hardcoded value.

```sh
#!/bin/bash
# Search for other memory-related implementations that might help
rg -l "mem_get_info|memory_stats" --type py
```

Length of output: 4097

```sh
#!/bin/bash
# Search for the original implementation or similar memory management code
rg -B 5 -A 5 "def get_free_memory" --type py

# Look for any configuration or constants related to memory management
rg -l "MEMORY|memory" --type py
```

Length of output: 1017

```sh
#!/bin/bash
# Look for NPU-specific memory management
rg -B 3 -A 3 "torch.npu" --type py

# Check for any memory-related constants or configurations in NPU files
rg "1024.*1024.*1024" --type py
```

Length of output: 2176
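An alternative to raising is a conservative fallback when the query fails. A hedged sketch, assuming `torch_npu` mirrors the `torch.cuda` device API (`mem_get_info`, `memory_stats`, `get_device_properties`); this is not verified against the PR:

```python
import logging

import torch
import torch_npu  # noqa: F401  (registers the torch.npu namespace)


def get_free_memory_with_fallback(dev=None, torch_free_too=False):
    """Query free NPU memory; fall back to half of total memory on failure."""
    try:
        mem_free_npu, _ = torch.npu.mem_get_info(dev)
        stats = torch.npu.memory_stats(dev)
        mem_free_torch = max(
            0, stats["reserved_bytes.all.current"] - stats["active_bytes.all.current"]
        )
        mem_free_total = mem_free_npu + mem_free_torch
    except Exception as e:
        # Assumption: get_device_properties(...).total_memory exists on
        # torch_npu as it does on torch.cuda; halve it as a safety margin.
        logging.warning("NPU memory query failed (%s); using half of total", e)
        mem_free_total = torch.npu.get_device_properties(dev).total_memory // 2
        mem_free_torch = mem_free_total
    return (mem_free_total, mem_free_torch) if torch_free_too else mem_free_total
```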
```python
    # (continued: still inside the "if _IS_NPU_AVAILABLE:" block)
    def patch_should_use_fp16(
        device=None, model_params=0, prioritize_performance=True, manual_cast=False
    ):
        if device is not None:
            if is_device_cpu(device):
                return False
        return True

    def patch_should_use_bf16(
        device=None, model_params=0, prioritize_performance=True, manual_cast=False
    ):
        return False

    comfy.model_management.pytorch_attention_flash_attention = (
        patch_pytorch_attention_flash_attention
    )
    comfy.model_management.get_free_memory = patch_get_free_memory
    comfy.model_management.should_use_fp16 = patch_should_use_fp16
    comfy.model_management.should_use_bf16 = patch_should_use_bf16
```
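These assignments rebind functions on the `comfy.model_management` module object, so any caller that goes through the module attribute (e.g. `comfy.model_management.get_free_memory(...)`) picks up the patch, while code that ran `from comfy.model_management import get_free_memory` before this prestartup script keeps the original binding. A minimal self-contained illustration of that distinction (hypothetical module, not ComfyUI code):

```python
import types

mod = types.ModuleType("m")
mod.f = lambda: "original"

g = mod.f                    # early "from m import f"-style binding
mod.f = lambda: "patched"    # what the prestartup script does

print(mod.f())  # "patched"  -- attribute lookups see the patch
print(g())      # "original" -- pre-patch bindings do not
```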