Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 319 additions & 3 deletions backend/go/stablediffusion-ggml/gosd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
#include <time.h>
#include <string>
#include <vector>
#include <map>
#include <filesystem>
#include <algorithm>
#include "gosd.h"

#define STB_IMAGE_IMPLEMENTATION
Expand All @@ -23,6 +25,7 @@
#define STB_IMAGE_RESIZE_STATIC
#include "stb_image_resize.h"
#include <stdlib.h>
#include <regex>

// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
const char* sample_method_str[] = {
Expand Down Expand Up @@ -133,6 +136,13 @@ static std::vector<sd_embedding_t> embedding_vec;
// Storage for embedding strings (needs to persist as long as embedding_vec references them)
static std::vector<std::string> embedding_strings;

// Storage for LoRAs (needs to persist for the lifetime of generation params).
// sd_img_gen_params_set_prompts() refills this vector and hands lora_vec.data() to params->loras.
static std::vector<sd_lora_t> lora_vec;
// Storage for LoRA path strings (needs to persist as long as lora_vec references them):
// parse_loras_from_prompt() stores c_str() pointers into these strings in sd_lora_t::path.
static std::vector<std::string> lora_strings;
// Storage for the lora_dir path recorded by load_model(); sd_img_gen_params_set_prompts()
// falls back to it when ctx_params.lora_model_dir is null or empty.
static std::string lora_dir_path;

// Build embeddings vector from directory, similar to upstream CLI
static void build_embedding_vec(const char* embedding_dir) {
embedding_vec.clear();
Expand Down Expand Up @@ -186,6 +196,229 @@ static void build_embedding_vec(const char* embedding_dir) {
fprintf(stderr, "Loaded %zu embeddings from %s\n", embedding_vec.size(), embedding_dir);
}

// Discover LoRA files in a directory and build a map of name -> path.
//
// @param lora_dir  Directory to scan (non-recursive). May be null or empty.
// @return          Map from file stem (file name without extension) to the full path of
//                  every regular file whose extension is .safetensors, .ckpt, .pt or .gguf.
//                  Empty when lora_dir is null/empty, is not a directory, or cannot be read.
//                  If two files share a stem, the one visited last wins.
//
// Only the non-throwing (std::error_code) filesystem overloads are used: this function is
// reached from C entry points, so a std::filesystem_error (e.g. permission denied) must not
// escape. Filesystem failures are logged to stderr and end the scan with whatever was found.
static std::map<std::string, std::string> discover_lora_files(const char* lora_dir) {
    namespace fs = std::filesystem;
    std::map<std::string, std::string> lora_map;

    if (!lora_dir || strlen(lora_dir) == 0) {
        fprintf(stderr, "LoRA directory not specified\n");
        return lora_map;
    }

    // is_directory(p, ec) yields false both for a missing path and on error.
    std::error_code ec;
    if (!fs::is_directory(lora_dir, ec)) {
        fprintf(stderr, "LoRA directory does not exist or is not a directory: %s\n", lora_dir);
        return lora_map;
    }

    static const std::vector<std::string> valid_ext = {".safetensors", ".ckpt", ".pt", ".gguf"};

    fprintf(stderr, "Discovering LoRA files in: %s\n", lora_dir);

    fs::directory_iterator it(lora_dir, ec);
    if (ec) {
        fprintf(stderr, "Failed to open LoRA directory %s: %s\n", lora_dir, ec.message().c_str());
        return lora_map;
    }

    const fs::directory_iterator end;
    while (it != end) {
        const fs::path path = it->path();
        std::error_code entry_ec;  // a failed status query just means "not a regular file"
        const bool regular = it->is_regular_file(entry_ec);

        if (regular &&
            std::find(valid_ext.begin(), valid_ext.end(), path.extension().string()) != valid_ext.end()) {
            // Store the name (without extension) -> full path mapping.
            // This allows users to specify just the name in <lora:name:strength>.
            std::string name = path.stem().string();
            std::string full_path = path.string();
            fprintf(stderr, "Found LoRA file: %s -> %s\n", name.c_str(), full_path.c_str());
            lora_map[std::move(name)] = std::move(full_path);
        }

        it.increment(ec);
        if (ec) {
            fprintf(stderr, "Error while scanning LoRA directory %s: %s\n", lora_dir, ec.message().c_str());
            break;
        }
    }

    fprintf(stderr, "Discovered %zu LoRA files in %s\n", lora_map.size(), lora_dir);
    return lora_map;
}

// Report whether `p` looks like an absolute path (same rule as upstream).
// Windows: a drive letter followed by ':' (C:/path or C:\path). Elsewhere: a leading '/'.
static bool is_absolute_path(const std::string& p) {
#ifdef _WIN32
    if (p.size() < 2) {
        return false;
    }
    const unsigned char drive_letter = static_cast<unsigned char>(p[0]);
    return std::isalpha(drive_letter) && p[1] == ':';
#else
    if (p.empty()) {
        return false;
    }
    return p.front() == '/';
#endif
}

// Parse LoRAs from prompt string (e.g., "<lora:name:1.0>" or "<lora:name>")
// Returns a vector of LoRA info and the cleaned prompt with LoRA tags removed
// Matches upstream implementation more closely
static std::pair<std::vector<sd_lora_t>, std::string> parse_loras_from_prompt(const std::string& prompt, const char* lora_dir) {
std::vector<sd_lora_t> loras;
std::string cleaned_prompt = prompt;

if (!lora_dir || strlen(lora_dir) == 0) {
fprintf(stderr, "LoRA directory not set, cannot parse LoRAs from prompt\n");
return {loras, cleaned_prompt};
}

// Discover LoRA files for name-based lookup
std::map<std::string, std::string> discovered_lora_map = discover_lora_files(lora_dir);

// Map to accumulate multipliers for the same LoRA (matches upstream)
std::map<std::string, float> lora_map;
std::map<std::string, float> high_noise_lora_map;

static const std::regex re(R"(<lora:([^:>]+):([^>]+)>)");
static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};
std::smatch m;

std::string tmp = prompt;

fprintf(stderr, "Parsing LoRAs from prompt: %s\n", prompt.c_str());

while (std::regex_search(tmp, m, re)) {
std::string raw_path = m[1].str();
const std::string raw_mul = m[2].str();

float mul = 0.f;
try {
mul = std::stof(raw_mul);
} catch (...) {
tmp = m.suffix().str();
cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only);
fprintf(stderr, "Invalid LoRA multiplier '%s', skipping\n", raw_mul.c_str());
continue;
}

bool is_high_noise = false;
static const std::string prefix = "|high_noise|";
if (raw_path.rfind(prefix, 0) == 0) {
raw_path.erase(0, prefix.size());
is_high_noise = true;
}

std::filesystem::path final_path;
if (is_absolute_path(raw_path)) {
final_path = raw_path;
} else {
// Try name-based lookup first
auto it = discovered_lora_map.find(raw_path);
if (it != discovered_lora_map.end()) {
final_path = it->second;
} else {
// Try case-insensitive lookup
bool found = false;
for (const auto& pair : discovered_lora_map) {
std::string lower_name = raw_path;
std::string lower_key = pair.first;
std::transform(lower_name.begin(), lower_name.end(), lower_name.begin(), ::tolower);
std::transform(lower_key.begin(), lower_key.end(), lower_key.begin(), ::tolower);
if (lower_name == lower_key) {
final_path = pair.second;
found = true;
break;
}
}
if (!found) {
// Try as relative path in lora_dir
final_path = std::filesystem::path(lora_dir) / raw_path;
}
}
}

// Try adding extensions if file doesn't exist
if (!std::filesystem::exists(final_path)) {
bool found = false;
for (const auto& ext : valid_ext) {
std::filesystem::path try_path = final_path;
try_path += ext;
if (std::filesystem::exists(try_path)) {
final_path = try_path;
found = true;
break;
}
}
if (!found) {
fprintf(stderr, "WARNING: LoRA file not found: %s\n", final_path.lexically_normal().string().c_str());
tmp = m.suffix().str();
cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only);
continue;
}
}

// Normalize path (matches upstream)
const std::string key = final_path.lexically_normal().string();

// Accumulate multiplier if same LoRA appears multiple times (matches upstream)
if (is_high_noise) {
high_noise_lora_map[key] += mul;
} else {
lora_map[key] += mul;
}

fprintf(stderr, "Parsed LoRA: path='%s', multiplier=%.2f, is_high_noise=%s\n",
key.c_str(), mul, is_high_noise ? "true" : "false");

cleaned_prompt = std::regex_replace(cleaned_prompt, re, "", std::regex_constants::format_first_only);
tmp = m.suffix().str();
}

// Build final LoRA vector from accumulated maps (matches upstream)
// Store all path strings first to ensure they persist
for (const auto& kv : lora_map) {
lora_strings.push_back(kv.first);
}
for (const auto& kv : high_noise_lora_map) {
lora_strings.push_back(kv.first);
}

// Now build the LoRA vector with pointers to the stored strings
size_t string_idx = 0;
for (const auto& kv : lora_map) {
sd_lora_t item;
item.is_high_noise = false;
item.path = lora_strings[string_idx].c_str();
item.multiplier = kv.second;
loras.push_back(item);
string_idx++;
}

for (const auto& kv : high_noise_lora_map) {
sd_lora_t item;
item.is_high_noise = true;
item.path = lora_strings[string_idx].c_str();
item.multiplier = kv.second;
loras.push_back(item);
string_idx++;
}

// Clean up extra spaces
std::regex space_regex(R"(\s+)");
cleaned_prompt = std::regex_replace(cleaned_prompt, space_regex, " ");
// Trim leading/trailing spaces
size_t first = cleaned_prompt.find_first_not_of(" \t");
if (first != std::string::npos) {
cleaned_prompt.erase(0, first);
}
size_t last = cleaned_prompt.find_last_not_of(" \t");
if (last != std::string::npos) {
cleaned_prompt.erase(last + 1);
}

fprintf(stderr, "Parsed %zu LoRA(s) from prompt. Cleaned prompt: %s\n", loras.size(), cleaned_prompt.c_str());

return {loras, cleaned_prompt};
}

// Copied from the upstream CLI
static void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
//SDParams* params = (SDParams*)data;
Expand Down Expand Up @@ -304,11 +537,17 @@ int load_model(const char *model, char *model_path, char* options[], int threads
std::filesystem::path lora_path(optval);
std::filesystem::path full_lora_path = model_path_str / lora_path;
lora_dir = strdup(full_lora_path.string().c_str());
fprintf(stderr, "Lora dir resolved to: %s\n", lora_dir);
lora_dir_path = full_lora_path.string();
fprintf(stderr, "LoRA dir resolved to: %s\n", lora_dir);
} else {
lora_dir = strdup(optval);
lora_dir_path = std::string(optval);
fprintf(stderr, "No model path provided, using lora dir as-is: %s\n", lora_dir);
}
// Discover LoRAs immediately when directory is set
if (lora_dir && strlen(lora_dir) > 0) {
discover_lora_files(lora_dir);
}
}

// New parsing
Expand Down Expand Up @@ -450,6 +689,14 @@ int load_model(const char *model, char *model_path, char* options[], int threads
ctx_params.taesd_path = taesd_path;
ctx_params.control_net_path = control_net_path;
ctx_params.lora_model_dir = lora_dir;
if (lora_dir && strlen(lora_dir) > 0) {
lora_dir_path = std::string(lora_dir);
fprintf(stderr, "LoRA model directory set to: %s\n", lora_dir);
// Discover LoRAs at load time for logging
discover_lora_files(lora_dir);
} else {
fprintf(stderr, "WARNING: LoRA model directory not set. LoRAs in prompts will not be loaded.\n");
}
// Set embeddings array and count
ctx_params.embeddings = embedding_vec.empty() ? NULL : embedding_vec.data();
ctx_params.embedding_count = static_cast<uint32_t>(embedding_vec.size());
Expand Down Expand Up @@ -546,9 +793,63 @@ sd_img_gen_params_t* sd_img_gen_params_new(void) {
return params;
}

// Storage for cleaned prompt strings (LoRA tags removed). These need to persist because
// sd_img_gen_params_set_prompts() stores their c_str() pointers in params->prompt and
// params->negative_prompt.
static std::string cleaned_prompt_storage;
static std::string cleaned_negative_prompt_storage;

// Set the prompt and negative prompt on `params`, extracting "<lora:name:multiplier>" tags.
//
// Both prompts are run through parse_loras_from_prompt(); the cleaned texts are kept in
// cleaned_prompt_storage / cleaned_negative_prompt_storage and params->prompt /
// params->negative_prompt point at those. LoRAs found in the positive prompt are stored in
// lora_vec and exposed through params->loras / params->lora_count. LoRAs found in the
// negative prompt are only reported, not applied.
//
// @param params           Generation params to update.
// @param prompt           Positive prompt; null is treated as "".
// @param negative_prompt  Negative prompt; null is treated as "".
//
// The pointers written into `params` stay valid only until the next call of this function.
void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt) {
    // Clear previous LoRA data
    lora_vec.clear();
    lora_strings.clear();

    const std::string prompt_str = prompt ? prompt : "";
    const std::string negative_prompt_str = negative_prompt ? negative_prompt : "";

    // Get lora_dir from ctx_params if available, otherwise use stored path
    const char* lora_dir_to_use = ctx_params.lora_model_dir;
    if (!lora_dir_to_use || strlen(lora_dir_to_use) == 0) {
        lora_dir_to_use = lora_dir_path.empty() ? nullptr : lora_dir_path.c_str();
    }

    // Handle the negative prompt first. Its LoRAs are not applied, and every parse appends
    // to lora_strings; running it after the positive parse could reallocate lora_strings and
    // invalidate the path pointers already captured in lora_vec.
    {
        auto [neg_loras, cleaned_negative] = parse_loras_from_prompt(negative_prompt_str, lora_dir_to_use);
        if (!neg_loras.empty()) {
            fprintf(stderr, "Note: Found %zu LoRAs in negative prompt (may not be supported)\n", neg_loras.size());
        }
        cleaned_negative_prompt_storage = std::move(cleaned_negative);
    }
    // Drop the negative prompt's path strings so lora_strings holds exactly the strings
    // referenced by lora_vec below.
    lora_strings.clear();

    // Parse LoRAs from prompt
    auto [loras, cleaned_prompt] = parse_loras_from_prompt(prompt_str, lora_dir_to_use);
    lora_vec = std::move(loras);
    cleaned_prompt_storage = std::move(cleaned_prompt);

    // Set the cleaned prompts
    params->prompt = cleaned_prompt_storage.c_str();
    params->negative_prompt = cleaned_negative_prompt_storage.c_str();

    // Set LoRAs in params
    params->loras = lora_vec.empty() ? nullptr : lora_vec.data();
    params->lora_count = static_cast<uint32_t>(lora_vec.size());

    fprintf(stderr, "Set prompts with %zu LoRAs. Original prompt: %s\n", lora_vec.size(), prompt ? prompt : "(null)");
    fprintf(stderr, "Cleaned prompt: %s\n", cleaned_prompt_storage.c_str());

    // Debug: Verify LoRAs are set correctly
    if (params->loras && params->lora_count > 0) {
        fprintf(stderr, "DEBUG: LoRAs set in params structure:\n");
        for (uint32_t i = 0; i < params->lora_count; i++) {
            fprintf(stderr, "  params->loras[%u]: path='%s' (ptr=%p), multiplier=%.2f, is_high_noise=%s\n",
                    i,
                    params->loras[i].path ? params->loras[i].path : "(null)",
                    (void*)params->loras[i].path,
                    params->loras[i].multiplier,
                    params->loras[i].is_high_noise ? "true" : "false");
        }
    } else {
        fprintf(stderr, "DEBUG: No LoRAs set in params structure (loras=%p, lora_count=%u)\n",
                (void*)params->loras, params->lora_count);
    }
}

void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height) {
Expand Down Expand Up @@ -740,6 +1041,20 @@ int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, cha
}
}

// Log LoRA information
if (p->loras && p->lora_count > 0) {
fprintf(stderr, "Using %u LoRA(s) in generation:\n", p->lora_count);
for (uint32_t i = 0; i < p->lora_count; i++) {
fprintf(stderr, " LoRA[%u]: path='%s', multiplier=%.2f, is_high_noise=%s\n",
i,
p->loras[i].path ? p->loras[i].path : "(null)",
p->loras[i].multiplier,
p->loras[i].is_high_noise ? "true" : "false");
}
} else {
fprintf(stderr, "No LoRAs specified for this generation\n");
}

fprintf(stderr, "Generating image with params: \nctx\n---\n%s\ngen\n---\n%s\n",
sd_ctx_params_to_str(&ctx_params),
sd_img_gen_params_to_str(p));
Expand Down Expand Up @@ -802,3 +1117,4 @@ int unload() {
free_sd_ctx(sd_c);
return 0;
}

Loading