
Commit a054fae

damian (damian0815) authored and committed
attention maps and tokens being sent to web UI
1 parent 33182af commit a054fae

File tree

backend/invoke_ai_web_server.py
ldm/generate.py
ldm/invoke/CLI.py
ldm/invoke/conditioning.py
ldm/invoke/generator/base.py

5 files changed: 60 additions, 45 deletions

backend/invoke_ai_web_server.py

Lines changed: 8 additions & 5 deletions
@@ -18,8 +18,9 @@
 from uuid import uuid4
 from threading import Event

+from ldm.generate import Generate
 from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
-from ldm.invoke.conditioning import get_tokens_for_prompt
+from ldm.invoke.conditioning import get_tokens_for_prompt, get_prompt_structure
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
 from ldm.invoke.prompt_parser import split_weighted_subprompts
 from ldm.invoke.generator.inpaint import infill_methods
@@ -40,7 +41,7 @@


 class InvokeAIWebServer:
-    def __init__(self, generate, gfpgan, codeformer, esrgan) -> None:
+    def __init__(self, generate: Generate, gfpgan, codeformer, esrgan) -> None:
         self.host = args.host
         self.port = args.port

@@ -1092,8 +1093,10 @@ def image_done(image, seed, first_seed, attention_maps_image=None):
             self.socketio.emit("progressUpdate", progress.to_formatted_dict())
             eventlet.sleep(0)

-            attention_maps_image_base64_url, tokens = (None, None) if attention_maps_image is None \
-                else image_to_dataURL(attention_maps_image), get_tokens_for_prompt(generation_parameters["prompt"])
+            parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
+            tokens = get_tokens_for_prompt(self.generate.model, parsed_prompt)
+            attention_maps_image_base64_url = None if attention_maps_image is None \
+                else image_to_dataURL(attention_maps_image)

             self.socketio.emit(
                 "generationResult",
@@ -1108,7 +1111,7 @@ def image_done(image, seed, first_seed, attention_maps_image=None):
                     "boundingBox": original_bounding_box,
                     "generationMode": generation_parameters["generation_mode"],
                     "attentionMaps": attention_maps_image_base64_url,
-                    "tokens": tokens
+                    "tokens": tokens,
                 },
             )
             eventlet.sleep(0)
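
For context, a minimal sketch of the flow the new image_done() code implements, assuming the imports shown in the diff above. Here image_to_data_url is only a stand-in approximation of the web server's existing image_to_dataURL helper, and build_generation_extras is a hypothetical wrapper, not a function added by this commit.

import base64
from io import BytesIO

from PIL import Image

from ldm.generate import Generate
from ldm.invoke.conditioning import get_prompt_structure, get_tokens_for_prompt


def image_to_data_url(image: Image.Image, image_format: str = "PNG") -> str:
    # Approximation of the web server's image_to_dataURL helper: encode a PIL image
    # as a base64 data URL so it can be embedded directly in the socket.io payload.
    buffer = BytesIO()
    image.save(buffer, format=image_format)
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/{image_format.lower()};base64,{encoded}"


def build_generation_extras(generate: Generate, prompt_string: str, attention_maps_image=None) -> dict:
    # Hypothetical helper mirroring the new image_done() logic: parse the prompt once,
    # tokenize the positive half against the loaded model, and only encode the
    # attention maps image when the generator actually produced one.
    parsed_prompt, _ = get_prompt_structure(prompt_string)
    tokens = get_tokens_for_prompt(generate.model, parsed_prompt)
    attention_maps_url = None if attention_maps_image is None else image_to_data_url(attention_maps_image)
    return {"attentionMaps": attention_maps_url, "tokens": tokens}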

ldm/generate.py

Lines changed: 1 addition & 1 deletion
@@ -485,7 +485,7 @@ def process_image(image,seed):
                 'extractor':self.safety_feature_extractor
             } if self.safety_checker else None

-            results, attention_maps_images = generator.generate(
+            results = generator.generate(
                 prompt,
                 iterations=iterations,
                 seed=self.seed,

ldm/invoke/CLI.py

Lines changed: 25 additions & 24 deletions
@@ -8,6 +8,7 @@
 import traceback
 import yaml

+from ldm.generate import Generate
 from ldm.invoke.globals import Globals
 from ldm.invoke.prompt_parser import PromptParser
 from ldm.invoke.readline import get_completer, Completer
@@ -27,7 +28,7 @@ def main():
     """Initialize command-line parsers and the diffusion model"""
     global infile
     print('* Initializing, be patient...')
-
+
     opt = Args()
     args = opt.parse_args()
     if not args:
@@ -47,7 +48,7 @@ def main():
     # alert - setting globals here
     Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
     Globals.try_patchmatch = args.patchmatch
-
+
     print(f'>> InvokeAI runtime directory is "{Globals.root}"')

     # loading here to avoid long delays on startup
@@ -281,7 +282,7 @@ def main_loop(gen, opt):
         prefix = file_writer.unique_prefix()
         step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None

-        def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None):
+        def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None):
             # note the seed is the seed of the current image
             # the first_seed is the original seed that noise is added to
             # when the -v switch is used to generate variations
@@ -341,8 +342,8 @@ def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None,
                         filename,
                         tool,
                         formatted_dream_prompt,
-                    )
-
+                    )
+
                 if (not postprocessed) or opt.save_original:
                     # only append to results if we didn't overwrite an earlier output
                     results.append([path, formatted_dream_prompt])
@@ -432,7 +433,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
         add_embedding_terms(gen, completer)
         completer.add_history(command)
         operation = None
-
+
     elif command.startswith('!models'):
         gen.model_cache.print_models()
         completer.add_history(command)
@@ -533,7 +534,7 @@ def add_weights_to_config(model_path:str, gen, opt, completer):

     completer.complete_extensions(('.yaml','.yml'))
     completer.linebuffer = 'configs/stable-diffusion/v1-inference.yaml'
-
+
     done = False
     while not done:
         new_config['config'] = input('Configuration file for this model: ')
@@ -564,7 +565,7 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
             print('** Please enter a valid integer between 64 and 2048')

     make_default = input('Make this the default model? [n] ') in ('y','Y')
-
+
     if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
         completer.add_model(model_name)

@@ -577,14 +578,14 @@ def del_config(model_name:str, gen, opt, completer):
     gen.model_cache.commit(opt.conf)
     print(f'** {model_name} deleted')
     completer.del_model(model_name)
-
+
 def edit_config(model_name:str, gen, opt, completer):
     config = gen.model_cache.config
-
+
     if model_name not in config:
         print(f'** Unknown model {model_name}')
         return
-
+
     print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')

     conf = config[model_name]
@@ -597,10 +598,10 @@ def edit_config(model_name:str, gen, opt, completer):
     make_default = input('Make this the default model? [n] ') in ('y','Y')
     completer.complete_extensions(None)
     write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
-
+
 def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
     current_model = gen.model_name
-
+
     op = 'modify' if clobber else 'import'
     print('\n>> New configuration:')
     if make_default:
@@ -623,7 +624,7 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
         gen.model_cache.set_default_model(model_name)

     gen.model_cache.commit(conf_path)
-
+
     do_switch = input(f'Keep model loaded? [y]')
     if len(do_switch)==0 or do_switch[0] in ('y','Y'):
         pass
@@ -653,7 +654,7 @@ def do_postprocess (gen, opt, callback):
         opt.prompt = opt.new_prompt
     else:
         opt.prompt = None
-
+
     if os.path.dirname(file_path) == '': #basename given
         file_path = os.path.join(opt.outdir,file_path)

@@ -718,7 +719,7 @@ def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
     )
     meta['image']['postprocessing'] = pp
     write_metadata(new_file,meta)
-
+
 def prepare_image_metadata(
     opt,
     prefix,
@@ -789,28 +790,28 @@ def get_next_command(infile=None) -> str: # command string
         print(f'#{command}')
     return command

-def invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan):
+def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
     print('\n* --web was specified, starting web server...')
     from backend.invoke_ai_web_server import InvokeAIWebServer
     # Change working directory to the stable-diffusion directory
     os.chdir(
         os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
     )
-
+
     invoke_ai_web_server = InvokeAIWebServer(generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan)

     try:
         invoke_ai_web_server.run()
     except KeyboardInterrupt:
         pass
-
+
 def add_embedding_terms(gen,completer):
     '''
     Called after setting the model, updates the autocompleter with
     any terms loaded by the embedding manager.
     '''
     completer.add_embedding_terms(gen.model.embedding_manager.list_terms())
-
+
 def split_variations(variations_string) -> list:
     # shotgun parsing, woo
     parts = []
@@ -867,15 +868,15 @@ def callback(img, step):
         image = gen.sample_to_image(img)
         image.save(filename,'PNG')
     return callback
-
+
 def retrieve_dream_command(opt,command,completer):
     '''
     Given a full or partial path to a previously-generated image file,
     will retrieve and format the dream command used to generate the image,
     and pop it into the readline buffer (linux, Mac), or print out a comment
     for cut-and-paste (windows)

-    Given a wildcard path to a folder with image png files,
+    Given a wildcard path to a folder with image png files,
     will retrieve and format the dream command used to generate the images,
     and save them to a file commands.txt for further processing
     '''
@@ -911,7 +912,7 @@ def write_commands(opt, file_path:str, outfilepath:str):
     except ValueError:
         print(f'## "{basename}": unacceptable pattern')
         return
-
+
     commands = []
     cmd = None
     for path in paths:
@@ -940,7 +941,7 @@ def emergency_model_reconfigure():
     print(' After reconfiguration is done, please relaunch invoke.py. ')
     print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
     print('configure_invokeai is launching....\n')
-
+
     sys.argv = ['configure_invokeai','--interactive']
     import configure_invokeai
     configure_invokeai.main()
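
Aside from whitespace cleanup, the functional change in CLI.py is that image_writer() now accepts an attention_maps_image keyword, keeping it compatible with the generator's new callback invocation. Below is a minimal sketch of a callback written against that signature; saving the map to disk is purely illustrative and not something this commit does.

import os


def save_attention_maps_callback(image, seed, upscaled=False, first_seed=None,
                                 use_prefix=None, prompt_in=None, attention_maps_image=None):
    # Same keyword signature as the updated image_writer() above. The new
    # attention_maps_image argument is a PIL image (or None) supplied per result.
    if attention_maps_image is not None:
        outdir = 'outputs/attention_maps'  # illustrative path, not from the commit
        os.makedirs(outdir, exist_ok=True)
        attention_maps_image.save(os.path.join(outdir, f'{seed}.attention.png'))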

ldm/invoke/conditioning.py

Lines changed: 24 additions & 13 deletions
@@ -19,25 +19,33 @@


 def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
-
-    prompt, negative_prompt = get_prompt_structure(prompt_string, skip_normalize_legacy_blend=skip_normalize_legacy_blend)
+    prompt, negative_prompt = get_prompt_structure(prompt_string,
+                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
     conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)

     return conditioning

-def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool=False) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
+
+def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
+        Union[FlattenedPrompt, Blend], FlattenedPrompt):
     """
     parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
     """
-    prompt, negative_prompt = _parse_prompt_string(prompt_string, skip_normalize_legacy_blend=skip_normalize_legacy_blend)
+    prompt, negative_prompt = _parse_prompt_string(prompt_string,
+                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
     return prompt, negative_prompt

+
 def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt) -> [str]:
-    text_fragments = [(x.text if x is Fragment else x.original.text if x is CrossAttentionControlSubstitute else str(x))
+    text_fragments = [x.text if type(x) is Fragment else
+                      (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
+                       str(x))
                       for x in parsed_prompt.children]
-    tokens = model.cond_stage_model.tokenizer.tokenize(text_fragments)
+    text = " ".join(text_fragments)
+    tokens = model.cond_stage_model.tokenizer.tokenize(text)
     return tokens

+
 def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
     # Extract Unconditioned Words From Prompt
     unconditioned_words = ''
@@ -67,6 +75,7 @@ def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=Fa
     parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
     return parsed_prompt, parsed_negative_prompt

+
 def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
                                  model, log_tokens=False) \
     -> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
@@ -102,7 +111,8 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
         # hybrid conditioning is in play
         unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
         if cac_args is not None:
-            print(">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
+            print(
+                ">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
             cac_args = None

     return (
@@ -112,8 +122,7 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
     )


-
-def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool=True):
+def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
     original_prompt = FlattenedPrompt()
     edited_prompt = FlattenedPrompt()
     # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
@@ -185,7 +194,6 @@ def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt
     return conditioning, cac_args


-
 def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
     embeddings_to_blend = None
     for i, flattened_prompt in enumerate(blend.prompts):
@@ -201,7 +209,8 @@ def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
     return conditioning


-def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
+def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
+                                          log_display_label: str = None):
     if type(flattened_prompt) is not FlattenedPrompt:
         raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
     fragments = [x.text for x in flattened_prompt.children]
@@ -213,11 +222,13 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm

     return embeddings, tokens

+
 def _get_tokens_length(model, fragments: list[Fragment]):
     fragment_texts = [x.text for x in fragments]
     tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
     return sum([len(x) for x in tokens])

+
 def _flatten_hybrid_conditioning(uncond, cond):
     '''
     This handles the choice between a conditional conditioning
@@ -244,7 +255,7 @@ def log_tokenization(text, model, display_label=None):
     # but for readability it has been replaced with ' '
     """

-    tokens = model.cond_stage_model.tokenizer.tokenize(text)
+    tokens = model.cond_stage_model.tokenizer.tokenize(text)
     tokenized = ""
     discarded = ""
     usedTokens = 0
@@ -261,5 +272,5 @@ def log_tokenization(text, model, display_label=None):
     print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
     if discarded != "":
         print(
-            f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
+            f">> Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m"
         )
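
To illustrate what the revised get_tokens_for_prompt() now does with a parsed prompt, here is a minimal sketch of the fragment-flattening step on its own, assuming Fragment, CrossAttentionControlSubstitute, and FlattenedPrompt are importable from ldm.invoke.prompt_parser as the surrounding code assumes; the helper name is hypothetical.

from ldm.invoke.prompt_parser import (CrossAttentionControlSubstitute, Fragment,
                                      FlattenedPrompt)


def flatten_prompt_text(parsed_prompt: FlattenedPrompt) -> str:
    # Mirror of the new fragment handling: plain fragments contribute their text,
    # cross-attention-control substitutions contribute their *original* fragments'
    # text joined with spaces, and anything else falls back to str().
    text_fragments = []
    for child in parsed_prompt.children:
        if type(child) is Fragment:
            text_fragments.append(child.text)
        elif type(child) is CrossAttentionControlSubstitute:
            text_fragments.append(" ".join(f.text for f in child.original))
        else:
            text_fragments.append(str(child))
    # A single joined string is what now gets handed to the CLIP tokenizer,
    # rather than a list of fragments.
    return " ".join(text_fragments)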

ldm/invoke/generator/base.py

Lines changed: 2 additions & 2 deletions
@@ -103,11 +103,11 @@ def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None
             results.append([image, seed])

             if image_callback is not None:
-                image_callback(image, seed, first_seed=first_seed)
+                image_callback(image, seed, first_seed=first_seed, attention_maps_image=attention_maps_images[-1])

             seed = self.new_seed()

-        return results, attention_maps_images
+        return results

     def sample_to_image(self,samples)->Image.Image:
         """
