From a35bba5701ff3cc10c82ed1c6c4c4ebb4293cf3e Mon Sep 17 00:00:00 2001 From: neon Date: Sat, 29 Apr 2023 12:48:35 +0200 Subject: [PATCH 01/10] init --- deepfloyd_if/modules/stage_I.py | 4 +- deepfloyd_if/modules/t5.py | 19 +- deepfloyd_if/pipelines/optimized_dream.py | 94 ++++++ run_ui.py | 371 ++++++++++++++++++++++ ui_files/style.css | 238 ++++++++++++++ ui_files/utils.py | 72 +++++ 6 files changed, 792 insertions(+), 6 deletions(-) create mode 100644 deepfloyd_if/pipelines/optimized_dream.py create mode 100644 run_ui.py create mode 100644 ui_files/style.css create mode 100644 ui_files/utils.py diff --git a/deepfloyd_if/modules/stage_I.py b/deepfloyd_if/modules/stage_I.py index a9c62cc..8b1c5f3 100644 --- a/deepfloyd_if/modules/stage_I.py +++ b/deepfloyd_if/modules/stage_I.py @@ -24,11 +24,13 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs): self.model = self.load_checkpoint(self.model, self.dir_or_name) self.model.eval().to(self.device) + def to(self, x): + self.model.to(x) + def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25, sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0, aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs): - return super().embeddings_to_image( t5_embs=t5_embs, style_t5_embs=style_t5_embs, diff --git a/deepfloyd_if/modules/t5.py b/deepfloyd_if/modules/t5.py index 7443426..63ec1c7 100644 --- a/deepfloyd_if/modules/t5.py +++ b/deepfloyd_if/modules/t5.py @@ -12,9 +12,9 @@ class T5Embedder: - available_models = ['t5-v1_1-xxl'] - bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa + bad_punct_regex = re.compile( + r'[' + '#®•©™&@·º½¾¿¡§~' + '\)' + '\(' + '\]' + '\[' + '\}' + '\{' + '\|' + '\\' + '\/' + '\*' + r']{1,}') # noqa def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_token=None, use_text_preprocessing=True, t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None): @@ -76,6 +76,12 @@ def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_toke self.tokenizer = AutoTokenizer.from_pretrained(path) self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() + def to(self, x): + self.model.to(x) + + def cpu(self): + self.model.base_model().to(torch.device("cpu")) + def get_text_embeddings(self, texts): texts = [self.text_preprocessing(text) for text in texts] @@ -121,10 +127,12 @@ def clean_caption(self, caption): caption = re.sub('', 'person', caption) # urls: caption = re.sub( - r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', + # noqa '', caption) # regex for urls caption = re.sub( - r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', + # noqa '', caption) # regex for urls # html: caption = BeautifulSoup(caption, features='html.parser').text @@ -150,7 +158,8 @@ def clean_caption(self, caption): # все виды тире / all types of dash --> "-" caption = re.sub( - 
r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', + # noqa '-', caption) # кавычки к одному стандарту diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py new file mode 100644 index 0000000..6620b4b --- /dev/null +++ b/deepfloyd_if/pipelines/optimized_dream.py @@ -0,0 +1,94 @@ +import gc +import numpy as np + +import torch.cuda +from PIL import Image + + +def run_garbage_collection(): + gc.collect() + torch.cuda.empty_cache() + + +def to_pil_images(images: torch.Tensor) -> list[Image]: + images = (images / 2 + 0.5).clamp(0, 1) + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + images = np.round(images * 255).astype(np.uint8) + return [Image.fromarray(image) for image in images] + + +def run_stage1( + model, + t5_embs, + negative_t5_embs, + seed: int = 0, + num_images: int = 1, + guidance_scale_1: float = 7.0, + custom_timesteps_1: str = 'smart100', + num_inference_steps_1: int = 100, +): + run_garbage_collection() + + images, _ = model.embeddings_to_image(t5_embs=t5_embs, + negative_t5_embs=negative_t5_embs, + num_images_per_prompt=num_images, + guidance_scale=guidance_scale_1, + sample_timestep_respacing=custom_timesteps_1, + seed=seed + ).images + pil_images_I = model.to_images(images, disable_watermark=True) + + return pil_images_I + + +def run_stage2( + model, + stage1_result, + stage2_index: int, + seed_2: int = 0, + guidance_scale_2: float = 4.0, + custom_timesteps_2: str = 'smart50', + num_inference_steps_2: int = 50, + disable_watermark: bool = True, +) -> Image: + run_garbage_collection() + + prompt_embeds = stage1_result['prompt_embeds'] + negative_embeds = stage1_result['negative_embeds'] + images = stage1_result['images'] + images = images[[stage2_index]] + + stageII_generations, _ = model.embeddings_to_image(low_res=images, + t5_embs=prompt_embeds, + negative_t5_embs=negative_embeds, + guidance_scale=guidance_scale_2, + sample_timestep_respacing=custom_timesteps_2, + seed=seed_2) + pil_images_II = model.to_images(stageII_generations, disable_watermark=disable_watermark) + + return pil_images_II + + +def run_stage3( + model, + image: Image, + t5_embs, + negative_t5_embs, + seed_3: int = 0, + guidance_scale_3: float = 9.0, + sample_timestep_respacing='super40', + disable_watermark=True +) -> Image: + run_garbage_collection() + + _stageIII_generations, _ = model.embeddings_to_image(image=image, + t5_embs=t5_embs, + negative_t5_embs=negative_t5_embs, + num_images_per_prompt=1, + guidance_scale=guidance_scale_3, + noise_level=100, + sample_timestep_respacing=sample_timestep_respacing, + seed=seed_3) + pil_image_III = model.to_images(_stageIII_generations, disable_watermark=disable_watermark) + + return pil_image_III diff --git a/run_ui.py b/run_ui.py new file mode 100644 index 0000000..8e09648 --- /dev/null +++ b/run_ui.py @@ -0,0 +1,371 @@ +import argparse +import gc +import os + +import numpy as np + +from deepfloyd_if.modules import IFStageI, IFStageII, StableStageIII +from deepfloyd_if.modules.t5 import T5Embedder +from deepfloyd_if.pipelines.optimized_dream import run_stage1, run_stage2, run_stage3 + +import torch + +import gradio as gr + +from ui_files.utils import randomize_seed_fn, show_gallery_view, update_upscale_button, get_stage2_index, \ + check_if_stage2_selected, show_upscaled_view, get_device_map 
+ +try: + import xformers + + os.environ["FORCE_MEM_EFFICIENT_ATTN"] = "1" +except: + pass + +device = torch.device(0) +if_I = IFStageI('IF-I-XL-v1.0', device=torch.device("cpu")) +if_I.to(torch.float16) # half +# # if_II = IFStageII('IF-II-L-v1.0', device=torch.device("cpu")) +# # if_III = StableStageIII('stable-diffusion-x4-upscaler', device=torch.device("cpu")) +t5_device = torch.device(0) +t5 = T5Embedder(device=t5_device, t5_model_kwargs={"low_cpu_mem_usage": True, + "torch_dtype": torch.float16, + "device_map": get_device_map(t5_device), + "offload_folder": True}) + + +def switch_devices(stage): + if stage == 1: + # t5.model.cpu() + del t5.model + if_I.to(torch.device(0)) + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def process_and_run_stage1(prompt, + negative_prompt, + seed_1, + num_images, + guidance_scale_1, + custom_timesteps_1, + num_inference_steps_1): + print("Encoding prompts..") + prompt = t5.get_text_embeddings(prompt) + if negative_prompt == "": + negative_prompt = torch.zeros_like(prompt) + else: + negative_prompt = t5.get_text_embeddings(negative_prompt) + switch_devices(stage=1) + prompt = prompt.to(if_I.device) + negative_prompt = negative_prompt.to(if_I.device) + print("Encoded. Running 1st stage") + return run_stage1( + if_I, + t5_embs=prompt, + negative_t5_embs=negative_prompt, + seed=seed_1, + num_images=num_images, + guidance_scale_1=guidance_scale_1, + custom_timesteps_1=custom_timesteps_1, + num_inference_steps_1=num_inference_steps_1 + ) + + +def create_ui(args): + with gr.Blocks(css='ui_files/style.css') as demo: + with gr.Box(): + with gr.Row(elem_id='prompt-container').style(equal_height=True): + with gr.Column(): + prompt = gr.Text( + label='Prompt', + show_label=False, + max_lines=1, + placeholder='Enter your prompt', + elem_id='prompt-text-input', + ).style(container=False) + negative_prompt = gr.Text( + label='Negative prompt', + show_label=False, + max_lines=1, + placeholder='Enter a negative prompt', + elem_id='negative-prompt-text-input', + ).style(container=False) + generate_button = gr.Button('Generate').style(full_width=False) + + with gr.Column() as gallery_view: + gallery = gr.Gallery(label='Stage 1 results', + show_label=False, + elem_id='gallery').style( + columns=args.GALLERY_COLUMN_NUM, + object_fit='contain') + gr.Markdown('Pick your favorite generation to upscale.') + with gr.Row(): + upscale_to_256_button = gr.Button( + 'Upscale to 256px', + visible=args.DISABLE_SD_X4_UPSCALER, + interactive=False) + upscale_button = gr.Button('Upscale', + interactive=False, + visible=not args.DISABLE_SD_X4_UPSCALER) + with gr.Column(visible=False) as upscale_view: + result = gr.Image(label='Result', + show_label=False, + type='filepath', + interactive=False, + elem_id='upscaled-image').style(height=640) + back_to_selection_button = gr.Button('Back to selection') + with gr.Accordion('Advanced options', + open=False, + visible=args.SHOW_ADVANCED_OPTIONS): + with gr.Tabs(): + with gr.Tab(label='Generation'): + seed_1 = gr.Slider(label='Seed', + minimum=0, + maximum=args.MAX_SEED, + step=1, + value=0) + randomize_seed_1 = gr.Checkbox(label='Randomize seed', + value=True) + guidance_scale_1 = gr.Slider(label='Guidance scale', + minimum=1, + maximum=20, + step=0.1, + value=7.0) + custom_timesteps_1 = gr.Dropdown( + label='Custom timesteps 1', + choices=[ + 'none', + 'fast27', + 'smart27', + 'smart50', + 'smart100', + 'smart185', + ], + value="smart100", + visible=True) + num_inference_steps_1 = gr.Slider( + label='Number of 
inference steps', + minimum=1, + maximum=200, + step=1, + value=100, + visible=True) + num_images = gr.Slider(label='Number of images', + minimum=1, + maximum=4, + step=1, + value=4, + visible=True) + with gr.Tab(label='Super-resolution 1'): + seed_2 = gr.Slider(label='Seed', + minimum=0, + maximum=args.MAX_SEED, + step=1, + value=0) + randomize_seed_2 = gr.Checkbox(label='Randomize seed', + value=True) + guidance_scale_2 = gr.Slider(label='Guidance scale', + minimum=1, + maximum=20, + step=0.1, + value=4.0) + custom_timesteps_2 = gr.Dropdown( + label='Custom timesteps 2', + choices=[ + 'none', + 'fast27', + 'smart27', + 'smart50', + 'smart100', + 'smart185', + ], + value="smart50", + visible=True) + num_inference_steps_2 = gr.Slider( + label='Number of inference steps', + minimum=1, + maximum=200, + step=1, + value=50, + visible=True) + with gr.Tab(label='Super-resolution 2'): + seed_3 = gr.Slider(label='Seed', + minimum=0, + maximum=args.MAX_SEED, + step=1, + value=0) + randomize_seed_3 = gr.Checkbox(label='Randomize seed', + value=True) + guidance_scale_3 = gr.Slider(label='Guidance scale', + minimum=1, + maximum=20, + step=0.1, + value=9.0) + num_inference_steps_3 = gr.Slider( + label='Number of inference steps', + minimum=1, + maximum=200, + step=1, + value=40, + visible=True) + with gr.Box(): + with gr.Row(): + with gr.Accordion(label='Hidden params'): + selected_index_for_stage2 = gr.Number( + label='Selected index for Stage 2', value=-1, precision=0) + + generate_button.click( + process_and_run_stage1, + [prompt, + negative_prompt, + seed_1, + num_images, + guidance_scale_1, + custom_timesteps_1, + num_inference_steps_1], + gallery + ) + + gallery.select( + fn=get_stage2_index, + outputs=selected_index_for_stage2, + queue=False, + ) + # + # selected_index_for_stage2.change( + # fn=update_upscale_button, + # inputs=selected_index_for_stage2, + # outputs=[ + # upscale_button, + # upscale_to_256_button, + # ], + # queue=False, + # ) + # + # stage2_inputs = [ + # stage1_result_path, + # selected_index_for_stage2, + # seed_2, + # guidance_scale_2, + # custom_timesteps_2, + # num_inference_steps_2, + # ] + # + # upscale_to_256_button.click( + # fn=check_if_stage2_selected, + # inputs=selected_index_for_stage2, + # queue=False, + # ).then( + # fn=randomize_seed_fn, + # inputs=[seed_2, randomize_seed_2], + # outputs=seed_2, + # queue=False, + # ).then( + # fn=show_upscaled_view, + # outputs=[ + # gallery_view, + # upscale_view, + # ], + # queue=False, + # ).then( + # fn=run_stage2, + # inputs=stage2_inputs, + # outputs=result, + # api_name='upscale256', + # ) # .success( + # # fn=upload_stage2_info, + # # inputs=[ + # # stage1_param_file_hash_name, + # # result, + # # selected_index_for_stage2, + # # seed_2, + # # guidance_scale_2, + # # custom_timesteps_2, + # # num_inference_steps_2, + # # ], + # # queue=False, + # # ) + # + # stage2_3_inputs = [ + # stage1_result_path, + # selected_index_for_stage2, + # seed_2, + # guidance_scale_2, + # custom_timesteps_2, + # num_inference_steps_2, + # prompt, + # negative_prompt, + # seed_3, + # guidance_scale_3, + # num_inference_steps_3, + # ] + # + # upscale_button.click( + # fn=check_if_stage2_selected, + # inputs=selected_index_for_stage2, + # queue=False, + # ).then( + # fn=randomize_seed_fn, + # inputs=[seed_2, randomize_seed_2], + # outputs=seed_2, + # queue=False, + # ).then( + # fn=randomize_seed_fn, + # inputs=[seed_3, randomize_seed_3], + # outputs=seed_3, + # queue=False, + # ).then( + # fn=show_upscaled_view, + # outputs=[ + # 
gallery_view, + # upscale_view, + # ], + # queue=False, + # ).then( + # fn=run_stage3, + # inputs=stage2_3_inputs, + # outputs=result, + # api_name='upscale1024', + # ) # .success( + # # fn=upload_stage2_3_info, + # # inputs=[ + # # stage1_param_file_hash_name, + # # result, + # # selected_index_for_stage2, + # # seed_2, + # # guidance_scale_2, + # # custom_timesteps_2, + # # num_inference_steps_2, + # # prompt, + # # negative_prompt, + # # seed_3, + # # guidance_scale_3, + # # num_inference_steps_3, + # # ], + # # queue=False, + # # ) + # + # back_to_selection_button.click( + # fn=show_gallery_view, + # outputs=[ + # gallery_view, + # upscale_view, + # ], + # queue=False, + # ) + return demo + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='IF UI settings') + parser.add_argument('--GALLERY_COLUMN_NUM', type=int, default=4) + parser.add_argument('--DISABLE_SD_X4_UPSCALER', type=bool, default=True) + parser.add_argument('--SHOW_ADVANCED_OPTIONS', type=bool, default=True) + parser.add_argument('--MAX_SEED', type=int, default=np.iinfo(np.int32).max) + + demo = create_ui(parser.parse_args()) + demo.launch() diff --git a/ui_files/style.css b/ui_files/style.css new file mode 100644 index 0000000..17fb109 --- /dev/null +++ b/ui_files/style.css @@ -0,0 +1,238 @@ +/* +This CSS file is modified from: +https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/2794a3c3ba66115c307075098e713f572b08bf80/app.py +*/ + +h1 { + text-align: center; +} + +.gradio-container { + font-family: 'IBM Plex Sans', sans-serif; +} + +.gr-button { + color: white; + border-color: black; + background: black; +} + +input[type='range'] { + accent-color: black; +} + +.dark input[type='range'] { + accent-color: #dfdfdf; +} + +.container { + max-width: 730px; + margin: auto; + padding-top: 1.5rem; +} + +#gallery { + min-height: auto; + height: 185px; + margin-top: 15px; + margin-left: auto; + margin-right: auto; + border-bottom-right-radius: .5rem !important; + border-bottom-left-radius: .5rem !important; +} +#gallery .grid-wrap, #gallery .empty{ + height: 185px; + min-height: 185px; +} +#gallery .preview{ + height: 185px; + min-height: 185px!important; +} +#gallery>div>.h-full { + min-height: 20rem; +} + +.details:hover { + text-decoration: underline; +} + +.gr-button { + white-space: nowrap; +} + +.gr-button:focus { + border-color: rgb(147 197 253 / var(--tw-border-opacity)); + outline: none; + box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); + --tw-border-opacity: 1; + --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); + --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color); + --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity)); + --tw-ring-opacity: .5; +} + +#advanced-btn { + font-size: .7rem !important; + line-height: 19px; + margin-top: 12px; + margin-bottom: 12px; + padding: 2px 8px; + border-radius: 14px !important; +} + +#advanced-options { + display: none; + margin-bottom: 20px; +} + +.footer { + margin-bottom: 45px; + margin-top: 35px; + text-align: center; + border-bottom: 1px solid #e5e5e5; +} + +.footer>p { + font-size: .8rem; + display: inline-block; + padding: 0 10px; + transform: translateY(10px); + background: white; +} + +.dark .footer { + border-color: #303030; +} + +.dark .footer>p { + background: #0b0f19; +} + +.acknowledgments h4 { + margin: 1.25em 0 .25em 0; + font-weight: bold; + font-size: 115%; +} + 
+.animate-spin { + animation: spin 1s linear infinite; +} + +@keyframes spin { + from { + transform: rotate(0deg); + } + + to { + transform: rotate(360deg); + } +} + +#share-btn-container { + display: flex; + padding-left: 0.5rem !important; + padding-right: 0.5rem !important; + background-color: #000000; + justify-content: center; + align-items: center; + border-radius: 9999px !important; + width: 13rem; + margin-top: 10px; + margin-left: auto; +} + +#share-btn { + all: initial; + color: #ffffff; + font-weight: 600; + cursor: pointer; + font-family: 'IBM Plex Sans', sans-serif; + margin-left: 0.5rem !important; + padding-top: 0.25rem !important; + padding-bottom: 0.25rem !important; + right: 0; +} + +#share-btn * { + all: unset; +} + +#share-btn-container div:nth-child(-n+2) { + width: auto !important; + min-height: 0px !important; +} + +#share-btn-container .wrap { + display: none !important; +} + +.gr-form { + flex: 1 1 50%; + border-top-right-radius: 0; + border-bottom-right-radius: 0; +} + +#prompt-container { + gap: 0; +} + +#prompt-text-input, +#negative-prompt-text-input { + padding: .45rem 0.625rem +} + +#component-16 { + border-top-width: 1px !important; + margin-top: 1em +} + +.image_duplication { + position: absolute; + width: 100px; + left: 50px +} + +#component-0 { + max-width: 730px; + margin: auto; + padding-top: 1.5rem; +} + +#upscaled-image img { + object-fit: scale-down; +} +/* share button */ +#share-btn-container { + display: flex; + padding-left: 0.5rem !important; + padding-right: 0.5rem !important; + background-color: #000000; + justify-content: center; + align-items: center; + border-radius: 9999px !important; + width: 13rem; + margin-top: 10px; + margin-left: auto; + flex: unset !important; +} +#share-btn { + all: initial; + color: #ffffff; + font-weight: 600; + cursor: pointer; + font-family: 'IBM Plex Sans', sans-serif; + margin-left: 0.5rem !important; + padding-top: 0.25rem !important; + padding-bottom: 0.25rem !important; + right:0; +} +#share-btn * { + all: unset !important; +} +#share-btn-container div:nth-child(-n+2){ + width: auto !important; + min-height: 0px !important; +} +#share-btn-container .wrap { + display: none !important; +} \ No newline at end of file diff --git a/ui_files/utils.py b/ui_files/utils.py new file mode 100644 index 0000000..a80d80d --- /dev/null +++ b/ui_files/utils.py @@ -0,0 +1,72 @@ +import gradio as gr +import numpy as np +import random + + +def _update_result_view(show_gallery: bool) -> tuple[dict, dict]: + return gr.update(visible=show_gallery), gr.update(visible=not show_gallery) + + +def show_gallery_view() -> tuple[dict, dict]: + return _update_result_view(True) + + +def show_upscaled_view() -> tuple[dict, dict]: + return _update_result_view(False) + + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, np.iinfo(np.int32).max) + return seed + + +def update_upscale_button(selected_index: int) -> tuple[dict, dict]: + if selected_index == -1: + return gr.update(interactive=False), gr.update(interactive=False) + else: + return gr.update(interactive=True), gr.update(interactive=True) + + +def get_stage2_index(evt: gr.SelectData) -> int: + return evt.index + + +def check_if_stage2_selected(index: int) -> None: + if index == -1: + raise gr.Error( + 'You need to select the image you would like to upscale from the Stage 1 results by clicking.' 
+ ) + + +def get_device_map(device): + return { + 'shared': device, + 'encoder.embed_tokens': device, + 'encoder.block.0': device, + 'encoder.block.1': device, + 'encoder.block.2': device, + 'encoder.block.3': device, + 'encoder.block.4': device, + 'encoder.block.5': device, + 'encoder.block.6': device, + 'encoder.block.7': device, + 'encoder.block.8': device, + 'encoder.block.9': device, + 'encoder.block.10': device, + 'encoder.block.11': device, + 'encoder.block.12': 'cpu', + 'encoder.block.13': 'cpu', + 'encoder.block.14': 'cpu', + 'encoder.block.15': 'cpu', + 'encoder.block.16': 'cpu', + 'encoder.block.17': 'cpu', + 'encoder.block.18': 'cpu', + 'encoder.block.19': 'cpu', + 'encoder.block.20': 'cpu', + 'encoder.block.21': 'cpu', + 'encoder.block.22': 'cpu', + 'encoder.block.23': 'cpu', + 'encoder.final_layer_norm': 'cpu', + 'encoder.dropout': 'cpu', + } From 2df8326a81a2f5a8561b23877e95cf0f4a37ce55 Mon Sep 17 00:00:00 2001 From: neon Date: Sat, 29 Apr 2023 17:24:04 +0200 Subject: [PATCH 02/10] kinda works --- deepfloyd_if/model/__init__.py | 3 +- deepfloyd_if/model/gaussian_diffusion.py | 300 ++++----- deepfloyd_if/model/unet.py | 31 +- deepfloyd_if/model/unet_split.py | 750 ++++++++++++++++++++++ deepfloyd_if/modules/base.py | 68 +- deepfloyd_if/modules/stage_I.py | 13 +- deepfloyd_if/modules/stage_III_sd_x4.py | 4 +- deepfloyd_if/pipelines/dream.py | 42 +- deepfloyd_if/pipelines/optimized_dream.py | 2 +- run_ui.py | 20 +- ui_files/utils.py | 5 +- 11 files changed, 995 insertions(+), 243 deletions(-) create mode 100644 deepfloyd_if/model/unet_split.py diff --git a/deepfloyd_if/model/__init__.py b/deepfloyd_if/model/__init__.py index 332da3e..55e58ca 100644 --- a/deepfloyd_if/model/__init__.py +++ b/deepfloyd_if/model/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from .unet import UNetModel, SuperResUNetModel +from .unet_split import UNetSplitModel -__all__ = ['UNetModel', 'SuperResUNetModel'] +__all__ = ['UNetModel', 'SuperResUNetModel', 'UNetSplitModel'] diff --git a/deepfloyd_if/model/gaussian_diffusion.py b/deepfloyd_if/model/gaussian_diffusion.py index e058fbc..596c1b7 100644 --- a/deepfloyd_if/model/gaussian_diffusion.py +++ b/deepfloyd_if/model/gaussian_diffusion.py @@ -110,13 +110,13 @@ class GaussianDiffusion: """ def __init__( - self, - *, - betas, - model_mean_type, - model_var_type, - loss_type, - rescale_timesteps=False, + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type @@ -146,7 +146,7 @@ def __init__( # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( - betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. 
@@ -154,12 +154,12 @@ def __init__( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( - betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( - (1.0 - self.alphas_cumprod_prev) - * np.sqrt(alphas) - / (1.0 - self.alphas_cumprod) + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) ) def dynamic_thresholding(self, x, p=0.995, c=1.7): @@ -189,7 +189,7 @@ def q_mean_variance(self, x_start, t): :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( - _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( @@ -210,9 +210,9 @@ def q_sample(self, x_start, t, noise=None): noise = torch.randn_like(x_start) assert noise.shape == x_start.shape return ( - _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise ) def q_posterior_mean_variance(self, x_start, x_t, t): @@ -222,24 +222,24 @@ def q_posterior_mean_variance(self, x_start, x_t, t): """ assert x_start.shape == x_t.shape posterior_mean = ( - _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start - + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( - posterior_mean.shape[0] - == posterior_variance.shape[0] - == posterior_log_variance_clipped.shape[0] - == x_start.shape[0] + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance( - self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, - denoised_fn=None, model_kwargs=None + self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, + denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of @@ -280,7 +280,7 @@ def p_mean_variance( max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) # The model_var_values is [-1, 1] for [min_var, max_var]. 
frac = (model_var_values + 1) / 2 - model_log_variance = frac * max_log + (1 - frac) * min_log + model_log_variance = frac * max_log.to(frac.device) + (1 - frac) * min_log.to(frac.device) model_variance = torch.exp(model_log_variance) else: model_variance, model_log_variance = { @@ -306,6 +306,7 @@ def process_xstart(x): return x # x.clamp(-1, 1) return x + x, t = x.to(model_output.device), t.to(model_output.device) if self.model_mean_type == ModelMeanType.PREVIOUS_X: pred_xstart = process_xstart( self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) @@ -325,7 +326,7 @@ def process_xstart(x): raise NotImplementedError(self.model_mean_type) assert ( - model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape ) return { 'mean': model_mean, @@ -337,25 +338,25 @@ def process_xstart(x): def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) def _predict_xstart_from_xprev(self, x_t, t, xprev): assert x_t.shape == xprev.shape return ( # (xprev - coef2*x_t) / coef1 - _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - - _extract_into_tensor( - self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape - ) - * x_t + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - pred_xstart - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _scale_timesteps(self, t): if self.rescale_timesteps: @@ -363,8 +364,8 @@ def _scale_timesteps(self, t): return t def p_sample( - self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, - denoised_fn=None, model_kwargs=None, inpainting_mask=None, + self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, + denoised_fn=None, model_kwargs=None, inpainting_mask=None, ): """ Sample x_{t-1} from the model at the given timestep. 
@@ -390,31 +391,36 @@ def p_sample( denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) - noise = torch.randn_like(x) + device = out['mean'].device + noise = torch.randn_like(x, device=device) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) - ) # no noise when t == 0 + ).to(device) # no noise when t == 0 if inpainting_mask is None: - inpainting_mask = torch.ones_like(x, device=x.device) + inpainting_mask = torch.ones_like(x, device=device) + + x, t = x.to(device), t.to(device) - sample = out['mean'] + nonzero_mask * torch.exp(0.5 * out['log_variance']) * noise - sample = (1 - inpainting_mask)*x + inpainting_mask*sample + noise = (torch.exp(0.5 * out['log_variance']) * noise).to(device) + + sample = out['mean'] + nonzero_mask * noise + sample = (1 - inpainting_mask) * x + inpainting_mask * sample return {'sample': sample, 'pred_xstart': out['pred_xstart']} def p_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - inpainting_mask=None, - denoised_fn=None, - model_kwargs=None, - device=None, - progress=False, - sample_fn=None, + self, + model, + shape, + noise=None, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + inpainting_mask=None, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + sample_fn=None, ): """ Generate samples from the model. @@ -434,17 +440,17 @@ def p_sample_loop( """ final = None for step_idx, sample in enumerate(self.p_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - dynamic_thresholding_p=dynamic_thresholding_p, - dynamic_thresholding_c=dynamic_thresholding_c, - denoised_fn=denoised_fn, - inpainting_mask=inpainting_mask, - model_kwargs=model_kwargs, - device=device, - progress=progress, + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + dynamic_thresholding_p=dynamic_thresholding_p, + dynamic_thresholding_c=dynamic_thresholding_c, + denoised_fn=denoised_fn, + inpainting_mask=inpainting_mask, + model_kwargs=model_kwargs, + device=device, + progress=progress, )): if sample_fn is not None: sample = sample_fn(step_idx, sample) @@ -452,18 +458,18 @@ def p_sample_loop( return final['sample'] def p_sample_loop_progressive( - self, - model, - shape, - inpainting_mask=None, - noise=None, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - denoised_fn=None, - model_kwargs=None, - device=None, - progress=False, + self, + model, + shape, + inpainting_mask=None, + noise=None, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, ): """ Generate samples from the model and yield intermediate samples from @@ -472,8 +478,6 @@ def p_sample_loop_progressive( Returns a generator over dicts, where each dict is the return value of p_sample(). """ - if device is None: - device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise @@ -505,16 +509,16 @@ def p_sample_loop_progressive( img = out['sample'] def ddim_sample( - self, - model, - x, - t, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - denoised_fn=None, - model_kwargs=None, - eta=0.0, + self, + model, + x, + t, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + denoised_fn=None, + model_kwargs=None, + eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. 
@@ -536,15 +540,15 @@ def ddim_sample( alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( - eta - * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) - * torch.sqrt(1 - alpha_bar / alpha_bar_prev) + eta + * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * torch.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. noise = torch.randn_like(x) mean_pred = ( - out['pred_xstart'] * torch.sqrt(alpha_bar_prev) - + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + out['pred_xstart'] * torch.sqrt(alpha_bar_prev) + + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) @@ -553,16 +557,16 @@ def ddim_sample( return {'sample': sample, 'pred_xstart': out['pred_xstart']} def ddim_reverse_sample( - self, - model, - x, - t, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - denoised_fn=None, - model_kwargs=None, - eta=0.0, + self, + model, + x, + t, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + denoised_fn=None, + model_kwargs=None, + eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. @@ -581,33 +585,33 @@ def ddim_reverse_sample( # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - - out['pred_xstart'] - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out['pred_xstart'] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. reversed mean_pred = ( - out['pred_xstart'] * torch.sqrt(alpha_bar_next) - + torch.sqrt(1 - alpha_bar_next) * eps + out['pred_xstart'] * torch.sqrt(alpha_bar_next) + + torch.sqrt(1 - alpha_bar_next) * eps ) return {'sample': mean_pred, 'pred_xstart': out['pred_xstart']} def ddim_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - denoised_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, - sample_fn=None, + self, + model, + shape, + noise=None, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + sample_fn=None, ): """ Generate samples from the model using DDIM. 
@@ -615,17 +619,17 @@ def ddim_sample_loop( """ final = None for step_idx, sample in enumerate(self.ddim_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - dynamic_thresholding_p=dynamic_thresholding_p, - dynamic_thresholding_c=dynamic_thresholding_c, - model_kwargs=model_kwargs, - device=device, - progress=progress, - eta=eta, + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + dynamic_thresholding_p=dynamic_thresholding_p, + dynamic_thresholding_c=dynamic_thresholding_c, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, )): if sample_fn is not None: sample = sample_fn(step_idx, sample) @@ -633,18 +637,18 @@ def ddim_sample_loop( return final['sample'] def ddim_sample_loop_progressive( - self, - model, - shape, - noise=None, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - denoised_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, + self, + model, + shape, + noise=None, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, ): """ Use DDIM to sample from the model and yield intermediate samples from @@ -684,7 +688,7 @@ def ddim_sample_loop_progressive( img = out['sample'] def _vb_terms_bpd( - self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. diff --git a/deepfloyd_if/model/unet.py b/deepfloyd_if/model/unet.py index bb83590..2fbc4aa 100644 --- a/deepfloyd_if/model/unet.py +++ b/deepfloyd_if/model/unet.py @@ -11,10 +11,7 @@ from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, get_activation, \ AttentionPooling -_FORCE_MEM_EFFICIENT_ATTN = int(os.environ.get('FORCE_MEM_EFFICIENT_ATTN', 0)) -print('FORCE_MEM_EFFICIENT_ATTN=', _FORCE_MEM_EFFICIENT_ATTN, '@UNET:QKVATTENTION') -if _FORCE_MEM_EFFICIENT_ATTN: - from xformers.ops import memory_efficient_attention # noqa +from xformers.ops import memory_efficient_attention # noqa class TimestepBlock(nn.Module): @@ -246,7 +243,7 @@ def __init__( self.num_heads = num_heads else: assert ( - channels % num_head_channels == 0 + channels % num_head_channels == 0 ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}' self.num_heads = channels // num_head_channels self.norm = normalization(channels, dtype=self.dtype) @@ -310,16 +307,16 @@ def forward(self, qkv, encoder_kv=None): k = torch.cat([ek, k], dim=-1) v = torch.cat([ev, v], dim=-1) scale = 1 / math.sqrt(math.sqrt(ch)) - if _FORCE_MEM_EFFICIENT_ATTN: - q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v)) - a = memory_efficient_attention(q, k, v) - a = a.permute(0, 2, 1) - else: - weight = torch.einsum( - 'bct,bcs->bts', q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - a = torch.einsum('bts,bcs->bct', weight, v) + # if True: # legacy + q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v)) + a = memory_efficient_attention(q, k, v) + a = a.permute(0, 2, 1) + # else: + # weight = torch.einsum( + # 'bct,bcs->bts', q * scale, k * scale + # ) # More stable with f16 than dividing afterwards + # weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + # a = 
torch.einsum('bts,bcs->bct', weight, v) return a.reshape(bs, -1, length) @@ -456,7 +453,7 @@ def __init__( ds = 1 if isinstance(num_res_blocks, int): - num_res_blocks = [num_res_blocks]*len(self.channel_mult) + num_res_blocks = [num_res_blocks] * len(self.channel_mult) self.num_res_blocks = num_res_blocks for level, mult in enumerate(self.channel_mult): @@ -690,7 +687,7 @@ def forward(self, x, timesteps, low_res, aug_level=None, **kwargs): ) if aug_level is None: - aug_steps = (np.random.random(bs)*1000).astype(np.int64) # uniform [0, 1) + aug_steps = (np.random.random(bs) * 1000).astype(np.int64) # uniform [0, 1) aug_steps = torch.from_numpy(aug_steps).to(x.device, dtype=torch.long) else: aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long) diff --git a/deepfloyd_if/model/unet_split.py b/deepfloyd_if/model/unet_split.py new file mode 100644 index 0000000..1d64e3b --- /dev/null +++ b/deepfloyd_if/model/unet_split.py @@ -0,0 +1,750 @@ +# -*- coding: utf-8 -*- +import gc +import os +import math +from abc import abstractmethod + +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, get_activation, \ + AttentionPooling + +from xformers.ops import memory_efficient_attention # noqa + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, encoder_out=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, AttentionBlock): + x = layer(x, encoder_out) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining a convolution is applied. + :param dims: determines the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.dtype = dtype + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1, dtype=self.dtype) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest') + else: + if self.dtype == torch.bfloat16: + x = x.type(torch.float32 if x.device.type == 'cpu' else torch.float16) + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.dtype == torch.bfloat16: + x = x.type(torch.bfloat16) + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining a convolution is applied. + :param dims: determines the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.dtype = dtype + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1, dtype=self.dtype) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: specified, the number of out channels. + :param use_conv: True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines the signal is 1D, 2D, or 3D. + :param up: True, use this block for upsampling. + :param down: True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + activation, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + dtype=None, + efficient_activation=False, + scale_skip_connection=False, + ): + super().__init__() + self.dtype = dtype + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.efficient_activation = efficient_activation + self.scale_skip_connection = scale_skip_connection + + self.in_layers = nn.Sequential( + normalization(channels, dtype=self.dtype), + get_activation(activation), + conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims, dtype=self.dtype) + self.x_upd = Upsample(channels, False, dims, dtype=self.dtype) + elif down: + self.h_upd = Downsample(channels, False, dims, dtype=self.dtype) + self.x_upd = Downsample(channels, False, dims, dtype=self.dtype) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.Identity() if self.efficient_activation else get_activation(activation), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + dtype=self.dtype + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels, dtype=self.dtype), + get_activation(activation), + nn.Dropout(p=dropout), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=self.dtype)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=self.dtype) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + + res = self.skip_connection(x) + h + if self.scale_skip_connection: + res *= 0.7071 # 1 / sqrt(2), https://arxiv.org/pdf/2104.07636.pdf + return res + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + disable_self_attention=False, + encoder_channels=None, + dtype=None, + ): + super().__init__() + self.dtype = dtype + self.channels = channels + self.disable_self_attention = disable_self_attention + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}' + self.num_heads = channels // num_head_channels + self.norm = normalization(channels, dtype=self.dtype) + self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype) + if self.disable_self_attention: + self.qkv = conv_nd(1, channels, channels, 1, dtype=self.dtype) + else: + self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype) + self.attention = QKVAttention(self.num_heads, disable_self_attention=disable_self_attention) + + if encoder_channels is not None: + self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1, dtype=self.dtype) + self.norm_encoder = normalization(encoder_channels, dtype=self.dtype) + self.proj_out = zero_module(conv_nd(1, channels, channels, 1, dtype=self.dtype)) + + def forward(self, x, encoder_out=None): + b, c, *spatial = x.shape + qkv = self.qkv(self.norm(x).view(b, c, -1)) + if encoder_out is not None: + # from imagen article: https://arxiv.org/pdf/2205.11487.abs + encoder_out = self.norm_encoder(encoder_out) + # # # + encoder_out = self.encoder_kv(encoder_out) + h = self.attention(qkv, encoder_out) + else: + h = self.attention(qkv) + h = self.proj_out(h) + return x + h.reshape(b, c, *spatial) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads, disable_self_attention=False): + super().__init__() + self.n_heads = n_heads + self.disable_self_attention = disable_self_attention + + def forward(self, qkv, encoder_kv=None): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + if self.disable_self_attention: + ch = width // (1 * self.n_heads) + q, = qkv.reshape(bs * self.n_heads, ch * 1, length).split(ch, dim=1) + else: + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + if encoder_kv is not None: + assert encoder_kv.shape[1] == self.n_heads * ch * 2 + if self.disable_self_attention: + k, v = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) + else: + ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) + k = torch.cat([ek, k], dim=-1) + v = torch.cat([ev, v], dim=-1) + scale = 1 / math.sqrt(math.sqrt(ch)) + # if _FORCE_MEM_EFFICIENT_ATTN: + q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v)) + a = memory_efficient_attention(q, k, v) + a = a.permute(0, 2, 1) + # else: + # weight = torch.einsum( + # 'bct,bcs->bts', q * scale, k * scale + # ) # More stable with f16 than dividing afterwards + # weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + # a = torch.einsum('bts,bcs->bct', weight, v) + return a.reshape(bs, -1, length) + + +class UNetSplitModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: True, use learned convolutions for upsampling and + downsampling. + :param dims: determines the signal is 1D, 2D, or 3D. + :param num_classes: specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. 
+ """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + activation, + encoder_dim, + att_pool_heads, + encoder_channels, + image_size, + disable_self_attentions=None, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + precision='32', + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + efficient_activation=False, + scale_skip_connection=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.encoder_channels = encoder_channels + self.encoder_dim = encoder_dim + self.efficient_activation = efficient_activation + self.scale_skip_connection = scale_skip_connection + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.dropout = dropout + + self.secondary_device = torch.device("cpu") + + # adapt attention resolutions + if isinstance(attention_resolutions, str): + self.attention_resolutions = [] + for res in attention_resolutions.split(','): + self.attention_resolutions.append(image_size // int(res)) + else: + self.attention_resolutions = attention_resolutions + self.attention_resolutions = tuple(self.attention_resolutions) + # + + # adapt disable self attention resolutions + if not disable_self_attentions: + self.disable_self_attentions = [] + elif disable_self_attentions is True: + self.disable_self_attentions = attention_resolutions + elif isinstance(disable_self_attentions, str): + self.disable_self_attentions = [] + for res in disable_self_attentions.split(','): + self.disable_self_attentions.append(image_size // int(res)) + else: + self.disable_self_attentions = disable_self_attentions + self.disable_self_attentions = tuple(self.disable_self_attentions) + # + + # adapt channel mult + if isinstance(channel_mult, str): + self.channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(',')) + else: + self.channel_mult = tuple(channel_mult) + # + + self.conv_resample = conv_resample + self.num_classes = num_classes + self.dtype = torch.float32 + + self.precision = str(precision) + self.use_fp16 = precision == '16' + if self.precision == '16': + self.dtype = torch.float16 + elif self.precision == 'bf16': + self.dtype = torch.bfloat16 + + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + self.time_embed_dim = model_channels * max(self.channel_mult) + self.time_embed = nn.Sequential( + linear(model_channels, self.time_embed_dim, dtype=self.dtype), + get_activation(activation), + linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, self.time_embed_dim) + + ch = input_ch = int(self.channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1, dtype=self.dtype))] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + + if isinstance(num_res_blocks, int): + num_res_blocks = [num_res_blocks] * len(self.channel_mult) + self.num_res_blocks = num_res_blocks + + for level, mult in enumerate(self.channel_mult): + for _ in range(num_res_blocks[level]): + layers = [ + ResBlock( + ch, + self.time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + 
activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ) + ] + ch = int(mult * model_channels) + if ds in self.attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=num_heads, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + dtype=self.dtype, + disable_self_attention=ds in self.disable_self_attentions, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(self.channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + self.time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + dtype=self.dtype, + activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + self.time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ), + AttentionBlock( + ch, + num_heads=num_heads, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + dtype=self.dtype, + disable_self_attention=ds in self.disable_self_attentions, + ), + ResBlock( + ch, + self.time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(self.channel_mult))[::-1]: + for i in range(num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + self.time_embed_dim, + dropout, + out_channels=int(model_channels * mult), + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ) + ] + ch = int(model_channels * mult) + if ds in self.attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + dtype=self.dtype, + disable_self_attention=ds in self.disable_self_attentions, + ) + ) + if level and i == num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + self.time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + dtype=self.dtype, + activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch, dtype=self.dtype), + get_activation(activation), + zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1, dtype=self.dtype)), + ) + + self.activation_layer = 
get_activation(activation) if self.efficient_activation else nn.Identity() + + self.encoder_pooling = nn.Sequential( + nn.LayerNorm(encoder_dim, dtype=self.dtype), + AttentionPooling(att_pool_heads, encoder_dim, dtype=self.dtype), + nn.Linear(encoder_dim, self.time_embed_dim, dtype=self.dtype), + nn.LayerNorm(self.time_embed_dim, dtype=self.dtype) + ) + + if encoder_dim != encoder_channels: + self.encoder_proj = nn.Linear(encoder_dim, encoder_channels, dtype=self.dtype) + else: + self.encoder_proj = nn.Identity() + + self.cache = None + + def collect(self): + gc.collect() + torch.cuda.empty_cache() + + def to(self, x, stage=1): # 0, 1, 2, 3 + if isinstance(x, torch.device): + secondary_device = self.secondary_device + if stage == 1: + self.middle_block.to(secondary_device) + self.output_blocks.to(secondary_device) + self.out.to(secondary_device) + self.collect() + self.time_embed.to(x) + self.encoder_proj.to(x) + self.encoder_pooling.to(x) + self.input_blocks.to(x) + elif stage == 2: + self.time_embed.to(secondary_device) + self.encoder_proj.to(secondary_device) + self.encoder_pooling.to(secondary_device) + self.input_blocks.to(secondary_device) + self.output_blocks.to(secondary_device) + self.out.to(secondary_device) + self.collect() + self.middle_block.to(x) + elif stage == 3: + self.time_embed.to(secondary_device) + self.encoder_proj.to(secondary_device) + self.encoder_pooling.to(secondary_device) + self.input_blocks.to(secondary_device) + self.middle_block.to(secondary_device) + self.collect() + self.output_blocks.to(x) + self.out.to(x) + else: + super().to(x) + + def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs): + hs = [] + self.to(self.primary_device, stage=1) + emb = self.time_embed(timestep_embedding(timesteps.to(torch.float32), self.model_channels, + dtype=torch.float32).to(self.primary_device).to(self.dtype)) + + if use_cache and self.cache is not None: + encoder_out, encoder_pool = self.cache + else: + text_emb = text_emb.type(self.dtype).to(self.primary_device) + encoder_out = self.encoder_proj(text_emb) + encoder_out = encoder_out.permute(0, 2, 1) # NLC -> NCL + if timestep_text_emb is None: + timestep_text_emb = text_emb + encoder_pool = self.encoder_pooling(timestep_text_emb) + if use_cache: + self.cache = (encoder_out, encoder_pool) + + emb = emb + encoder_pool.to(emb) + + if aug_emb is not None: + emb = emb + aug_emb.to(emb) + + emb = self.activation_layer(emb) + + h = x.type(self.dtype).to(self.primary_device) + + for module in self.input_blocks: + h = module(h, emb, encoder_out) + hs.append(h) + + self.to(self.primary_device, stage=2) + + h = self.middle_block(h, emb, encoder_out) + + self.to(self.primary_device, stage=3) + + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, encoder_out) + h = h.type(self.dtype) + h = self.out(h) + return h + + +class SuperResUNetModel(UNetSplitModel): + """ + A text2im model that performs super-resolution. + Expects an extra kwarg `low_res` to condition on a low-resolution image. 
+ """ + + def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwargs): + self.low_res_diffusion = low_res_diffusion + self.interpolate_mode = interpolate_mode + super().__init__(*args, **kwargs) + + self.aug_proj = nn.Sequential( + linear(self.model_channels, self.time_embed_dim, dtype=self.dtype), + get_activation(kwargs['activation']), + linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype), + ) + + def forward(self, x, timesteps, low_res, aug_level=None, **kwargs): + bs, _, new_height, new_width = x.shape + + align_corners = True + if self.interpolate_mode == 'nearest': + align_corners = None + + upsampled = F.interpolate( + low_res, (new_height, new_width), mode=self.interpolate_mode, align_corners=align_corners + ) + + if aug_level is None: + aug_steps = (np.random.random(bs) * 1000).astype(np.int64) # uniform [0, 1) + aug_steps = torch.from_numpy(aug_steps).to(x.device, dtype=torch.long) + else: + aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long) + + upsampled = self.low_res_diffusion.q_sample(upsampled, aug_steps) + x = torch.cat([x, upsampled], dim=1) + + aug_emb = self.aug_proj( + timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype) + ) + return super().forward(x, timesteps, aug_emb=aug_emb, **kwargs) diff --git a/deepfloyd_if/modules/base.py b/deepfloyd_if/modules/base.py index c808a3c..797ed32 100644 --- a/deepfloyd_if/modules/base.py +++ b/deepfloyd_if/modules/base.py @@ -14,14 +14,12 @@ from huggingface_hub import hf_hub_download from accelerate.utils import set_module_tensor_to_device - from .. import utils from ..model.respace import create_gaussian_diffusion from .utils import load_model_weights, predict_proba, clip_process_generations class IFBaseModule: - stage = '-' available_models = [] @@ -68,29 +66,29 @@ def use_diffusers(self): return False def embeddings_to_image( - self, t5_embs, low_res=None, *, - style_t5_embs=None, - positive_t5_embs=None, - negative_t5_embs=None, - batch_repeat=1, - dynamic_thresholding_p=0.95, - sample_loop='ddpm', - sample_timestep_respacing='smart185', - dynamic_thresholding_c=1.5, - guidance_scale=7.0, - aug_level=0.25, - positive_mixer=0.15, - blur_sigma=None, - img_size=None, - img_scale=4.0, - aspect_ratio='1:1', - progress=True, - seed=None, - sample_fn=None, - support_noise=None, - support_noise_less_qsample_steps=0, - inpainting_mask=None, - **kwargs, + self, t5_embs, low_res=None, *, + style_t5_embs=None, + positive_t5_embs=None, + negative_t5_embs=None, + batch_repeat=1, + dynamic_thresholding_p=0.95, + sample_loop='ddpm', + sample_timestep_respacing='smart185', + dynamic_thresholding_c=1.5, + guidance_scale=7.0, + aug_level=0.25, + positive_mixer=0.15, + blur_sigma=None, + img_size=None, + img_scale=4.0, + aspect_ratio='1:1', + progress=True, + seed=None, + sample_fn=None, + support_noise=None, + support_noise_less_qsample_steps=0, + inpainting_mask=None, + **kwargs, ): self._clear_cache() image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale) @@ -100,13 +98,13 @@ def embeddings_to_image( def model_fn(x_t, ts, **kwargs): half = x_t[: len(x_t) // bs_scale] - combined = torch.cat([half]*bs_scale, dim=0) + combined = torch.cat([half] * bs_scale, dim=0) model_out = self.model(combined, ts, **kwargs) eps, rest = model_out[:, :3], model_out[:, 3:] if bs_scale == 3: cond_eps, pos_cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0) half_eps = uncond_eps + guidance_scale * ( - cond_eps * (1 - 
positive_mixer) + pos_cond_eps * positive_mixer - uncond_eps) + cond_eps * (1 - positive_mixer) + pos_cond_eps * positive_mixer - uncond_eps) pos_half_eps = uncond_eps + guidance_scale * (pos_cond_eps - uncond_eps) eps = torch.cat([half_eps, pos_half_eps, half_eps], dim=0) else: @@ -170,7 +168,7 @@ def model_fn(x_t, ts, **kwargs): if low_res is not None: if blur_sigma is not None: low_res = T.GaussianBlur(3, sigma=(blur_sigma, blur_sigma))(low_res) - model_kwargs['low_res'] = torch.cat([low_res]*bs_scale, dim=0).to(self.device) + model_kwargs['low_res'] = torch.cat([low_res] * bs_scale, dim=0).to(self.device) model_kwargs['aug_level'] = aug_level if support_noise is None: @@ -186,7 +184,7 @@ def model_fn(x_t, ts, **kwargs): support_noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...], q_sample_steps, ) - noise = noise.repeat(batch_size*bs_scale, 1, 1, 1).to(device=self.device, dtype=self.model.dtype) + noise = noise.repeat(batch_size * bs_scale, 1, 1, 1).to(device=self.device, dtype=self.model.dtype) if inpainting_mask is not None: inpainting_mask = inpainting_mask.to(device=self.device, dtype=torch.long) @@ -202,7 +200,7 @@ def model_fn(x_t, ts, **kwargs): dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, inpainting_mask=inpainting_mask, - device=self.device, + device=self.model.primary_device, progress=progress, sample_fn=sample_fn, )[:batch_size] @@ -216,7 +214,7 @@ def model_fn(x_t, ts, **kwargs): model_kwargs=model_kwargs, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, - device=self.device, + device=self.model.primary_device, progress=progress, sample_fn=sample_fn, )[:batch_size] @@ -311,7 +309,7 @@ def to_images(self, generations, disable_watermark=False): def show(self, pil_images, nrow=None, size=10): if nrow is None: - nrow = round(len(pil_images)**0.5) + nrow = round(len(pil_images) ** 0.5) imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow) if not isinstance(imgs, list): @@ -333,16 +331,16 @@ def _clear_cache(self): def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale): if low_res is not None: bs, c, h, w = low_res.shape - image_h, image_w = int((h*img_scale)//32)*32, int((w*img_scale//32))*32 + image_h, image_w = int((h * img_scale) // 32) * 32, int((w * img_scale // 32)) * 32 else: scale_w, scale_h = aspect_ratio.split(':') scale_w, scale_h = int(scale_w), int(scale_h) coef = scale_w / scale_h image_h, image_w = img_size, img_size if coef >= 1: - image_w = int(round(img_size/8 * coef) * 8) + image_w = int(round(img_size / 8 * coef) * 8) else: - image_h = int(round(img_size/8 / coef) * 8) + image_h = int(round(img_size / 8 / coef) * 8) assert image_h % 8 == 0 assert image_w % 8 == 0 diff --git a/deepfloyd_if/modules/stage_I.py b/deepfloyd_if/modules/stage_I.py index 8b1c5f3..8b3c62d 100644 --- a/deepfloyd_if/modules/stage_I.py +++ b/deepfloyd_if/modules/stage_I.py @@ -1,15 +1,16 @@ # -*- coding: utf-8 -*- import accelerate +import torch from .base import IFBaseModule -from ..model import UNetModel +from ..model import UNetModel, UNetSplitModel class IFStageI(IFBaseModule): stage = 'I' available_models = ['IF-I-M-v1.0', 'IF-I-L-v1.0', 'IF-I-XL-v1.0'] - def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs): + def __init__(self, *args, model_kwargs=None, pil_img_size=64, use_split=True, **kwargs): """ :param conf_or_path: :param device: @@ -19,12 +20,16 @@ def __init__(self, *args, 
model_kwargs=None, pil_img_size=64, **kwargs): super().__init__(*args, pil_img_size=pil_img_size, **kwargs) model_params = dict(self.conf.params) model_params.update(model_kwargs or {}) + UNetClass = UNetSplitModel if use_split else UNetModel with accelerate.init_empty_weights(): - self.model = UNetModel(**model_params) + self.model = UNetClass(**model_params) self.model = self.load_checkpoint(self.model, self.dir_or_name) self.model.eval().to(self.device) - def to(self, x): + def to(self, x, stage=1, secondary_device=torch.device("cpu")): # 0, 1, 2, 3 + if isinstance(x, torch.device): + self.model.primary_device = x + self.model.secondary_device = secondary_device self.model.to(x) def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, diff --git a/deepfloyd_if/modules/stage_III_sd_x4.py b/deepfloyd_if/modules/stage_III_sd_x4.py index 307fad2..7599317 100644 --- a/deepfloyd_if/modules/stage_III_sd_x4.py +++ b/deepfloyd_if/modules/stage_III_sd_x4.py @@ -34,8 +34,8 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs): self.model = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, token=self.hf_token) self.model.to(self.device) - if bool(os.environ.get('FORCE_MEM_EFFICIENT_ATTN')): - self.model.enable_xformers_memory_efficient_attention() + # if bool(os.environ.get('FORCE_MEM_EFFICIENT_ATTN')): + self.model.enable_xformers_memory_efficient_attention() @property def use_diffusers(self): diff --git a/deepfloyd_if/pipelines/dream.py b/deepfloyd_if/pipelines/dream.py index 13d7479..0b98621 100644 --- a/deepfloyd_if/pipelines/dream.py +++ b/deepfloyd_if/pipelines/dream.py @@ -5,22 +5,22 @@ def dream( - t5, - if_I, - if_II=None, - if_III=None, - *, - prompt, - style_prompt=None, - negative_prompt=None, - seed=None, - aspect_ratio='1:1', - if_I_kwargs=None, - if_II_kwargs=None, - if_III_kwargs=None, - progress=True, - return_tensors=False, - disable_watermark=False, + t5, + if_I, + if_II=None, + if_III=None, + *, + prompt, + style_prompt=None, + negative_prompt=None, + seed=None, + aspect_ratio='1:1', + if_I_kwargs=None, + if_II_kwargs=None, + if_III_kwargs=None, + progress=True, + return_tensors=False, + disable_watermark=False, ): """ Generate pictures using text description! 
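The hunks above introduce the split UNet for stage I, which keeps groups of blocks on a secondary device and pulls them onto the primary device only while they are needed, and the later `run_ui.py` changes apply the same idea at the pipeline level by parking idle stages on the CPU. Below is a minimal sketch of that offloading pattern; the helper names `free_vram` and `activate` are illustrative and not part of the diffs, only the `.to(device)` moves and cache clearing mirror what the `switch_devices` helper added to `run_ui.py` does.

```python
import gc
import torch


def free_vram() -> None:
    # Drop dangling Python references first, then release cached CUDA blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def activate(active: torch.nn.Module, parked: list[torch.nn.Module],
             gpu: torch.device = torch.device("cuda:0"),
             cpu: torch.device = torch.device("cpu")) -> torch.nn.Module:
    # Park every stage that is not needed right now on the CPU,
    # free the VRAM it held, then move the active stage onto the GPU.
    for module in parked:
        module.to(cpu)
    free_vram()
    return active.to(gpu)
```

Calling, say, `activate(stage_2, [stage_1, stage_3])` before the super-resolution pass keeps only one stage resident in VRAM at a time, which is what allows the full cascade to run on the 8 GB card mentioned in the README changes.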
@@ -108,18 +108,18 @@ def dream( stageIII_generations = [] for idx in range(len(stageII_generations)): if if_III.use_diffusers: - if_III_kwargs['prompt'] = prompt[idx: idx+1] + if_III_kwargs['prompt'] = prompt[idx: idx + 1] - if_III_kwargs['low_res'] = stageII_generations[idx:idx+1] + if_III_kwargs['low_res'] = stageII_generations[idx:idx + 1] if_III_kwargs['seed'] = seed - if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1] + if_III_kwargs['t5_embs'] = t5_embs[idx:idx + 1] if_III_kwargs['progress'] = progress style_t5_embs = if_I_kwargs.get('style_t5_embs') if style_t5_embs is not None: - style_t5_embs = style_t5_embs[idx:idx+1] + style_t5_embs = style_t5_embs[idx:idx + 1] positive_t5_embs = if_I_kwargs.get('positive_t5_embs') if positive_t5_embs is not None: - positive_t5_embs = positive_t5_embs[idx:idx+1] + positive_t5_embs = positive_t5_embs[idx:idx + 1] if_III_kwargs['style_t5_embs'] = style_t5_embs if_III_kwargs['positive_t5_embs'] = positive_t5_embs diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py index 6620b4b..2eeb0df 100644 --- a/deepfloyd_if/pipelines/optimized_dream.py +++ b/deepfloyd_if/pipelines/optimized_dream.py @@ -35,7 +35,7 @@ def run_stage1( guidance_scale=guidance_scale_1, sample_timestep_respacing=custom_timesteps_1, seed=seed - ).images + ) pil_images_I = model.to_images(images, disable_watermark=True) return pil_images_I diff --git a/run_ui.py b/run_ui.py index 8e09648..73f5dec 100644 --- a/run_ui.py +++ b/run_ui.py @@ -1,8 +1,10 @@ import argparse import gc import os +import time import numpy as np +from accelerate import dispatch_model from deepfloyd_if.modules import IFStageI, IFStageII, StableStageIII from deepfloyd_if.modules.t5 import T5Embedder @@ -15,13 +17,6 @@ from ui_files.utils import randomize_seed_fn, show_gallery_view, update_upscale_button, get_stage2_index, \ check_if_stage2_selected, show_upscaled_view, get_device_map -try: - import xformers - - os.environ["FORCE_MEM_EFFICIENT_ATTN"] = "1" -except: - pass - device = torch.device(0) if_I = IFStageI('IF-I-XL-v1.0', device=torch.device("cpu")) if_I.to(torch.float16) # half @@ -37,13 +32,12 @@ def switch_devices(stage): if stage == 1: # t5.model.cpu() - del t5.model + dispatch_model(t5.model, get_device_map(t5_device, all2cpu=True)) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() if_I.to(torch.device(0)) - gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() - def process_and_run_stage1(prompt, negative_prompt, @@ -157,7 +151,7 @@ def create_ui(args): minimum=1, maximum=4, step=1, - value=4, + value=1, visible=True) with gr.Tab(label='Super-resolution 1'): seed_2 = gr.Slider(label='Seed', diff --git a/ui_files/utils.py b/ui_files/utils.py index a80d80d..541a577 100644 --- a/ui_files/utils.py +++ b/ui_files/utils.py @@ -2,6 +2,8 @@ import numpy as np import random +import torch + def _update_result_view(show_gallery: bool) -> tuple[dict, dict]: return gr.update(visible=show_gallery), gr.update(visible=not show_gallery) @@ -39,7 +41,8 @@ def check_if_stage2_selected(index: int) -> None: ) -def get_device_map(device): +def get_device_map(device, all2cpu=False): + device = device if not all2cpu else torch.device("cpu") return { 'shared': device, 'encoder.embed_tokens': device, From d34388f32384dd26f72ee636ee5631c06daad540 Mon Sep 17 00:00:00 2001 From: neon Date: Sat, 29 Apr 2023 21:26:38 +0200 Subject: [PATCH 03/10] step 2 --- deepfloyd_if/model/gaussian_diffusion.py | 2 +- deepfloyd_if/model/nn.py | 9 +- 
deepfloyd_if/model/unet.py | 12 +- deepfloyd_if/modules/base.py | 9 +- deepfloyd_if/modules/stage_II.py | 6 +- deepfloyd_if/pipelines/optimized_dream.py | 32 +++-- run_ui.py | 158 ++++++++++++---------- 7 files changed, 125 insertions(+), 103 deletions(-) diff --git a/deepfloyd_if/model/gaussian_diffusion.py b/deepfloyd_if/model/gaussian_diffusion.py index 596c1b7..2394b31 100644 --- a/deepfloyd_if/model/gaussian_diffusion.py +++ b/deepfloyd_if/model/gaussian_diffusion.py @@ -875,7 +875,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): dimension equal to the length of timesteps. :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. """ - res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps.to(torch.long)].float() while len(res.shape) < len(broadcast_shape): res = res[..., None] return res.expand(broadcast_shape) diff --git a/deepfloyd_if/model/nn.py b/deepfloyd_if/model/nn.py index 4f1a0f0..a957825 100644 --- a/deepfloyd_if/model/nn.py +++ b/deepfloyd_if/model/nn.py @@ -171,17 +171,16 @@ def timestep_embedding(timesteps, dim, max_period=10000, dtype=None): :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. """ - if dtype is None: - dtype = torch.float32 + dtype2 = torch.float32 half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device, dtype=dtype) - args = timesteps[:, None].type(dtype) * freqs[None] + ).to(device=timesteps.device, dtype=dtype2) + args = timesteps[:, None].type(dtype2) * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding + return embedding.to(dtype) def attention(q, k, v, d_k): diff --git a/deepfloyd_if/model/unet.py b/deepfloyd_if/model/unet.py index 2fbc4aa..d82fdeb 100644 --- a/deepfloyd_if/model/unet.py +++ b/deepfloyd_if/model/unet.py @@ -629,7 +629,7 @@ def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, if use_cache and self.cache is not None: encoder_out, encoder_pool = self.cache else: - text_emb = text_emb.type(self.dtype) + text_emb = text_emb.type(self.dtype).to(x.device) encoder_out = self.encoder_proj(text_emb) encoder_out = encoder_out.permute(0, 2, 1) # NLC -> NCL if timestep_text_emb is None: @@ -674,6 +674,7 @@ def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwar get_activation(kwargs['activation']), linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype), ) + self.primary_device = torch.device(0) def forward(self, x, timesteps, low_res, aug_level=None, **kwargs): bs, _, new_height, new_width = x.shape @@ -684,7 +685,7 @@ def forward(self, x, timesteps, low_res, aug_level=None, **kwargs): upsampled = F.interpolate( low_res, (new_height, new_width), mode=self.interpolate_mode, align_corners=align_corners - ) + ).to(x.device) if aug_level is None: aug_steps = (np.random.random(bs) * 1000).astype(np.int64) # uniform [0, 1) @@ -692,10 +693,11 @@ def forward(self, x, timesteps, low_res, aug_level=None, **kwargs): else: aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long) + aug_steps = aug_steps.to(self.dtype) + upsampled = self.low_res_diffusion.q_sample(upsampled, aug_steps) x = torch.cat([x, upsampled], dim=1) - aug_emb = 
self.aug_proj( - timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype) - ) + timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype).to(self.dtype) + ).to(x.device) return super().forward(x, timesteps, aug_emb=aug_emb, **kwargs) diff --git a/deepfloyd_if/modules/base.py b/deepfloyd_if/modules/base.py index 797ed32..d08453b 100644 --- a/deepfloyd_if/modules/base.py +++ b/deepfloyd_if/modules/base.py @@ -88,8 +88,11 @@ def embeddings_to_image( support_noise=None, support_noise_less_qsample_steps=0, inpainting_mask=None, + device=None, **kwargs, ): + if device is None: + device = self.model.primary_device self._clear_cache() image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale) diffusion = self.get_diffusion(sample_timestep_respacing) @@ -98,7 +101,7 @@ def embeddings_to_image( def model_fn(x_t, ts, **kwargs): half = x_t[: len(x_t) // bs_scale] - combined = torch.cat([half] * bs_scale, dim=0) + combined = torch.cat([half] * bs_scale, dim=0).to(device) model_out = self.model(combined, ts, **kwargs) eps, rest = model_out[:, :3], model_out[:, 3:] if bs_scale == 3: @@ -200,7 +203,7 @@ def model_fn(x_t, ts, **kwargs): dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, inpainting_mask=inpainting_mask, - device=self.model.primary_device, + device=device, progress=progress, sample_fn=sample_fn, )[:batch_size] @@ -214,7 +217,7 @@ def model_fn(x_t, ts, **kwargs): model_kwargs=model_kwargs, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, - device=self.model.primary_device, + device=device, progress=progress, sample_fn=sample_fn, )[:batch_size] diff --git a/deepfloyd_if/modules/stage_II.py b/deepfloyd_if/modules/stage_II.py index d14b838..d4eb0af 100644 --- a/deepfloyd_if/modules/stage_II.py +++ b/deepfloyd_if/modules/stage_II.py @@ -22,7 +22,7 @@ def embeddings_to_image( self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, aug_level=0.25, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, sample_loop='ddpm', sample_timestep_respacing='smart50', guidance_scale=4.0, img_scale=4.0, positive_mixer=0.5, - progress=True, seed=None, sample_fn=None, **kwargs): + progress=True, seed=None, sample_fn=None, device=None, **kwargs): return super().embeddings_to_image( t5_embs=t5_embs, low_res=low_res, @@ -42,5 +42,9 @@ def embeddings_to_image( progress=progress, seed=seed, sample_fn=sample_fn, + device=device, **kwargs ) + + def to(self, x): + self.model.to(x) diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py index 2eeb0df..dfc9e15 100644 --- a/deepfloyd_if/pipelines/optimized_dream.py +++ b/deepfloyd_if/pipelines/optimized_dream.py @@ -29,6 +29,9 @@ def run_stage1( ): run_garbage_collection() + if custom_timesteps_1 == "none": + custom_timesteps_1 = str(num_inference_steps_1) + images, _ = model.embeddings_to_image(t5_embs=t5_embs, negative_t5_embs=negative_t5_embs, num_images_per_prompt=num_images, @@ -38,34 +41,33 @@ def run_stage1( ) pil_images_I = model.to_images(images, disable_watermark=True) - return pil_images_I + return images, pil_images_I def run_stage2( model, - stage1_result, - stage2_index: int, - seed_2: int = 0, - guidance_scale_2: float = 4.0, + t5_embs, + negative_t5_embs, + images, + seed: int = 0, + guidance_scale: float = 4.0, custom_timesteps_2: str = 'smart50', num_inference_steps_2: int = 50, disable_watermark: bool = True, + device=None ) 
-> Image: run_garbage_collection() - prompt_embeds = stage1_result['prompt_embeds'] - negative_embeds = stage1_result['negative_embeds'] - images = stage1_result['images'] - images = images[[stage2_index]] - + if custom_timesteps_2 == "none": + custom_timesteps_2 = str(num_inference_steps_2) stageII_generations, _ = model.embeddings_to_image(low_res=images, - t5_embs=prompt_embeds, - negative_t5_embs=negative_embeds, - guidance_scale=guidance_scale_2, + t5_embs=t5_embs, + negative_t5_embs=negative_t5_embs, + guidance_scale=guidance_scale, sample_timestep_respacing=custom_timesteps_2, - seed=seed_2) + seed=seed, device=device) pil_images_II = model.to_images(stageII_generations, disable_watermark=disable_watermark) - + print(pil_images_II) return pil_images_II diff --git a/run_ui.py b/run_ui.py index 73f5dec..c102b8c 100644 --- a/run_ui.py +++ b/run_ui.py @@ -20,7 +20,8 @@ device = torch.device(0) if_I = IFStageI('IF-I-XL-v1.0', device=torch.device("cpu")) if_I.to(torch.float16) # half -# # if_II = IFStageII('IF-II-L-v1.0', device=torch.device("cpu")) +if_II = IFStageII('IF-II-L-v1.0', device=torch.device("cpu")) +if_I.to(torch.float16) # half # # if_III = StableStageIII('stable-diffusion-x4-upscaler', device=torch.device("cpu")) t5_device = torch.device(0) t5 = T5Embedder(device=t5_device, t5_model_kwargs={"low_cpu_mem_usage": True, @@ -37,6 +38,12 @@ def switch_devices(stage): torch.cuda.empty_cache() torch.cuda.synchronize() if_I.to(torch.device(0)) + elif stage == 2: + if_I.to(torch.device("cpu")) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + if_II.to(torch.device(0)) def process_and_run_stage1(prompt, @@ -46,26 +53,47 @@ def process_and_run_stage1(prompt, guidance_scale_1, custom_timesteps_1, num_inference_steps_1): + global t5_embs, negative_t5_embs, images print("Encoding prompts..") - prompt = t5.get_text_embeddings(prompt) - if negative_prompt == "": - negative_prompt = torch.zeros_like(prompt) - else: - negative_prompt = t5.get_text_embeddings(negative_prompt) + t5_embs = t5.get_text_embeddings([prompt]) + negative_t5_embs = t5.get_text_embeddings([negative_prompt]) switch_devices(stage=1) - prompt = prompt.to(if_I.device) - negative_prompt = negative_prompt.to(if_I.device) + t5_embs = t5_embs.to(if_I.device) + negative_t5_embs = negative_t5_embs.to(if_I.device) print("Encoded. 
Running 1st stage") - return run_stage1( + images, images_ret = run_stage1( if_I, - t5_embs=prompt, - negative_t5_embs=negative_prompt, + t5_embs=t5_embs, + negative_t5_embs=negative_t5_embs, seed=seed_1, num_images=num_images, guidance_scale_1=guidance_scale_1, custom_timesteps_1=custom_timesteps_1, num_inference_steps_1=num_inference_steps_1 ) + return images_ret + + +def process_and_run_stage2( + index, + seed_2, + guidance_scale_2, + custom_timesteps_2, + num_inference_steps_2): + global t5_embs, negative_t5_embs, images + print("Stage 2..") + switch_devices(stage=2) + return run_stage2( + if_II, + t5_embs=t5_embs, + negative_t5_embs=negative_t5_embs, + images=images[index].unsqueeze(0).to(device), + seed=seed_2, + guidance_scale=guidance_scale_2, + custom_timesteps_2=custom_timesteps_2, + num_inference_steps_2=num_inference_steps_2, + device=device + ) def create_ui(args): @@ -105,11 +133,11 @@ def create_ui(args): interactive=False, visible=not args.DISABLE_SD_X4_UPSCALER) with gr.Column(visible=False) as upscale_view: - result = gr.Image(label='Result', - show_label=False, - type='filepath', - interactive=False, - elem_id='upscaled-image').style(height=640) + result = gr.Gallery(label='Result', + show_label=False, + elem_id='upscaled-image').style( + columns=args.GALLERY_COLUMN_NUM, + object_fit='contain') back_to_selection_button = gr.Button('Back to selection') with gr.Accordion('Advanced options', open=False, @@ -138,7 +166,7 @@ def create_ui(args): 'smart100', 'smart185', ], - value="smart100", + value="fast27", visible=True) num_inference_steps_1 = gr.Slider( label='Number of inference steps', @@ -176,7 +204,7 @@ def create_ui(args): 'smart100', 'smart185', ], - value="smart50", + value="smart27", visible=True) num_inference_steps_2 = gr.Slider( label='Number of inference steps', @@ -228,61 +256,45 @@ def create_ui(args): outputs=selected_index_for_stage2, queue=False, ) - # - # selected_index_for_stage2.change( - # fn=update_upscale_button, - # inputs=selected_index_for_stage2, - # outputs=[ - # upscale_button, - # upscale_to_256_button, - # ], - # queue=False, - # ) - # - # stage2_inputs = [ - # stage1_result_path, - # selected_index_for_stage2, - # seed_2, - # guidance_scale_2, - # custom_timesteps_2, - # num_inference_steps_2, - # ] - # - # upscale_to_256_button.click( - # fn=check_if_stage2_selected, - # inputs=selected_index_for_stage2, - # queue=False, - # ).then( - # fn=randomize_seed_fn, - # inputs=[seed_2, randomize_seed_2], - # outputs=seed_2, - # queue=False, - # ).then( - # fn=show_upscaled_view, - # outputs=[ - # gallery_view, - # upscale_view, - # ], - # queue=False, - # ).then( - # fn=run_stage2, - # inputs=stage2_inputs, - # outputs=result, - # api_name='upscale256', - # ) # .success( - # # fn=upload_stage2_info, - # # inputs=[ - # # stage1_param_file_hash_name, - # # result, - # # selected_index_for_stage2, - # # seed_2, - # # guidance_scale_2, - # # custom_timesteps_2, - # # num_inference_steps_2, - # # ], - # # queue=False, - # # ) - # + + selected_index_for_stage2.change( + fn=update_upscale_button, + inputs=selected_index_for_stage2, + outputs=[ + upscale_button, + upscale_to_256_button, + ], + queue=False, + ) + + upscale_to_256_button.click( + fn=check_if_stage2_selected, + inputs=selected_index_for_stage2, + queue=False, + ).then( + fn=randomize_seed_fn, + inputs=[seed_2, randomize_seed_2], + outputs=seed_2, + queue=False, + ).then( + fn=show_upscaled_view, + outputs=[ + gallery_view, + upscale_view, + ], + queue=False, + ).then( + 
fn=process_and_run_stage2, + inputs=[ + selected_index_for_stage2, + seed_2, + guidance_scale_2, + custom_timesteps_2, + num_inference_steps_2, + ], + outputs=result, + ) + # stage2_3_inputs = [ # stage1_result_path, # selected_index_for_stage2, From e4af95d08c42f17cbd3b8201606b94821d708ec7 Mon Sep 17 00:00:00 2001 From: neon Date: Sat, 29 Apr 2023 22:11:49 +0200 Subject: [PATCH 04/10] stage 3 --- README.md | 126 ++++++++++++----- deepfloyd_if/modules/stage_III_sd_x4.py | 14 +- deepfloyd_if/pipelines/optimized_dream.py | 37 +++-- run_ui.py | 157 ++++++++++++---------- 4 files changed, 205 insertions(+), 129 deletions(-) diff --git a/README.md b/README.md index 4a8b803..e1a72c1 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,21 @@ +# Optimized DeepFloyd IF by neonsecret + +Tested on rtx 3070, 8 gb vram + +stage 1: ~3.5 sec per iteration +stage 2: 11 seconds for 27 steps +stage 3: 2 seconds for 40 steps + +### To run the ui: + +```bash +python run_ui.py +``` + +All the models are automatically downloaded. + +original readme: + [![License](https://img.shields.io/badge/Code_License-Modified_MIT-blue.svg)](LICENSE) [![License](https://img.shields.io/badge/Weights_License-DeepFloyd_IF-orange.svg)](LICENSE-MODEL) [![Downloads](https://pepy.tech/badge/deepfloyd_if)](https://pepy.tech/project/deepfloyd_if) @@ -11,21 +29,32 @@

-We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding. DeepFloyd IF is a modular composed of a frozen text encoder and three cascaded pixel diffusion modules: a base model that generates 64x64 px image based on text prompt and two super-resolution models, each designed to generate images of increasing resolution: 256x256 px and 1024x1024 px. All stages of the model utilize a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture enhanced with cross-attention and attention pooling. The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis. +We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image model with a high degree of photorealism +and language understanding. DeepFloyd IF is a modular composed of a frozen text encoder and three cascaded pixel +diffusion modules: a base model that generates 64x64 px image based on text prompt and two super-resolution models, each +designed to generate images of increasing resolution: 256x256 px and 1024x1024 px. All stages of the model utilize a +frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture +enhanced with cross-attention and attention pooling. The result is a highly efficient model that outperforms current +state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. Our work underscores the potential +of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for +text-to-image synthesis.

-*Inspired by* [*Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding*](https://arxiv.org/pdf/2205.11487.pdf) +*Inspired by* [*Photorealistic Text-to-Image Diffusion Models with Deep Language +Understanding*](https://arxiv.org/pdf/2205.11487.pdf) ## Minimum requirements to use all IF models: + - 16GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) -- 24GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) & Stable x4 (to 1024x1024 upscaler) +- 24GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) & Stable x4 (to + 1024x1024 upscaler) - `xformers` and set env variable `FORCE_MEM_EFFICIENT_ATTN=1` - ## Quick Start + [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/DeepFloyd/IF) @@ -36,25 +65,28 @@ pip install git+https://github.com/openai/CLIP.git --no-deps ``` ## Local notebooks + [![Jupyter Notebook](https://img.shields.io/badge/jupyter_notebook-%23FF7A01.svg?logo=jupyter&logoColor=white)](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb) [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures) -The Dream, Style Transfer, Super Resolution or Inpainting modes are avaliable in a Jupyter Notebook [here](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb). - - +The Dream, Style Transfer, Super Resolution or Inpainting modes are avaliable in a Jupyter +Notebook [here](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb). ## Integration with 🤗 Diffusers IF is also integrated with the 🤗 Hugging Face [Diffusers library](https://github.com/huggingface/diffusers/). -Diffusers runs each stage individually allowing the user to customize the image generation process as well as allowing to inspect intermediate results easily. +Diffusers runs each stage individually allowing the user to customize the image generation process as well as allowing +to inspect intermediate results easily. ### Example Before you can use IF, you need to accept its usage conditions. To do so: + 1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be loggin in 2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) 3. Make sure to login locally. Install `huggingface_hub` + ```sh pip install huggingface_hub --upgrade ``` @@ -67,7 +99,8 @@ from huggingface_hub import login login() ``` -and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens). +and enter +your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens). Next we install `diffusers` and dependencies: @@ -77,7 +110,9 @@ pip install diffusers accelerate transformers safetensors And we can now run the model locally. -By default `diffusers` makes use of [model cpu offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings) to run the whole IF pipeline with as little as 14 GB of VRAM. 
+By default `diffusers` makes use +of [model cpu offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings) +to run the whole IF pipeline with as little as 14 GB of VRAM. If you are using `torch>=2.0.0`, make sure to **delete all** `enable_xformers_memory_efficient_attention()` functions. @@ -100,8 +135,10 @@ stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__v stage_2.enable_model_cpu_offload() # stage 3 -safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker} -stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16) +safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, + "watermarker": stage_1.watermarker} +stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, + torch_dtype=torch.float16) stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 stage_3.enable_model_cpu_offload() @@ -113,12 +150,14 @@ prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt) generator = torch.manual_seed(0) # stage 1 -image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images +image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, + output_type="pt").images pt_to_pil(image)[0].save("./if_stage_I.png") # stage 2 image = stage_2( - image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt" + image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, + output_type="pt" ).images pt_to_pil(image)[0].save("./if_stage_II.png") @@ -127,12 +166,16 @@ image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100 image[0].save("./if_stage_III.png") ``` - There are multiple ways to speed up the inference time and lower the memory consumption even more with `diffusers`. To do so, please have a look at the Diffusers docs: +There are multiple ways to speed up the inference time and lower the memory consumption even more with `diffusers`. To +do so, please have a look at the Diffusers docs: - 🚀 [Optimizing for inference time](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-speed) -- ⚙️ [Optimizing for low memory during inference](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-memory) +- +⚙️ [Optimizing for low memory during inference](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-memory) -For more in-detail information about how to use IF, please have a look at [the IF blog post](https://huggingface.co/blog/if) and [the documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/if) 📖. +For more in-detail information about how to use IF, please have a look +at [the IF blog post](https://huggingface.co/blog/if) +and [the documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/if) 📖. ## Run the code locally @@ -150,6 +193,7 @@ t5 = T5Embedder(device="cpu") ``` ### I. 
Dream + Dream is the text-to-image mode of the IF model ```python @@ -160,7 +204,7 @@ count = 4 result = dream( t5=t5, if_I=if_I, if_II=if_II, if_III=if_III, - prompt=[prompt]*count, + prompt=[prompt] * count, seed=42, if_I_kwargs={ "guidance_scale": 7.0, @@ -179,6 +223,7 @@ result = dream( if_III.show(result['III'], size=14) ``` + ![](./pics/dream-III.jpg) ## II. Zero-shot Image-to-Image Translation @@ -186,6 +231,7 @@ if_III.show(result['III'], size=14) ![](./pics/img_to_img_scheme.jpeg) In Style Transfer mode, the output of your prompt comes out at the style of the `support_pil_img` + ```python from deepfloyd_if.pipelines import style_transfer @@ -215,9 +261,10 @@ if_I.show(result['II'], 1, 20) ![Alternative Text](./pics/deep_floyd_if_image_2_image.gif) - ## III. Super Resolution -For super-resolution, users can run `IF-II` and `IF-III` or 'Stable x4' on an image that was not necessarely generated by IF (two cascades): + +For super-resolution, users can run `IF-II` and `IF-III` or 'Stable x4' on an image that was not necessarely generated +by IF (two cascades): ```python from deepfloyd_if.pipelines import super_resolution @@ -253,7 +300,6 @@ show_superres(raw_pil_image, high_res['III'][0]) ![](./pics/if_as_upscaler.jpg) - ### IV. Zero-shot Inpainting ```python @@ -289,9 +335,11 @@ if_I.show(result['I'], 2, 3) if_I.show(result['II'], 2, 6) if_I.show(result['III'], 2, 14) ``` + ![](./pics/deep_floyd_if_inpainting.gif) ### 🤗 Model Zoo 🤗 + The link to download the weights as well as the model cards will be available soon on each model of the model zoo #### Original @@ -305,7 +353,7 @@ The link to download the weights as well as the model cards will be available so | [IF-II-L](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)* | II | 1.2B | - | 1536 | 2.5M | | IF-III-L* _(soon)_ | III | 700M | - | 3072 | 1.25M | - *best modules +*best modules ### Quantitative Evaluation @@ -315,16 +363,19 @@ The link to download the weights as well as the model cards will be available so ## License -The code in this repository is released under the bespoke license (see added [point two](https://github.com/deep-floyd/IF/blob/main/LICENSE#L13)). +The code in this repository is released under the bespoke license (see +added [point two](https://github.com/deep-floyd/IF/blob/main/LICENSE#L13)). -The weights will be available soon via [the DeepFloyd organization at Hugging Face](https://huggingface.co/DeepFloyd) and have their own LICENSE. +The weights will be available soon via [the DeepFloyd organization at Hugging Face](https://huggingface.co/DeepFloyd) +and have their own LICENSE. -**Disclaimer:** *The initial release of the IF model is under a restricted research-purposes-only license temporarily to gather feedback, and after that we intend to release a fully open-source model in line with other Stability AI models.* +**Disclaimer:** *The initial release of the IF model is under a restricted research-purposes-only license temporarily to +gather feedback, and after that we intend to release a fully open-source model in line with other Stability AI models.* ## Limitations and Biases -The models available in this codebase have known limitations and biases. Please refer to [the model card](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) for more information. - +The models available in this codebase have known limitations and biases. Please refer +to [the model card](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) for more information. 
## 🎓 DeepFloyd IF creators: @@ -335,17 +386,26 @@ The models available in this codebase have known limitations and biases. Please - Ksenia Ivanova [GitHub](https://github.com/ivksu) | [Twitter](https://twitter.com/susiaiv) - Nadiia Klokova [GitHub](https://github.com/vauimpuls) | [Twitter](https://twitter.com/vauimpuls) - ## 📄 Research Paper (Soon) ## Acknowledgements -Special thanks to [StabilityAI](http://stability.ai) and its CEO [Emad Mostaque](https://twitter.com/emostaque) for invaluable support, providing GPU compute and infrastructure to train the models (our gratitude goes to [Richard Vencu](https://github.com/rvencu)); thanks to [LAION](https://laion.ai) and [Christoph Schuhmann](https://github.com/christophschuhmann) in particular for contribution to the project and well-prepared datasets; thanks to [Huggingface](https://huggingface.co) teams for optimizing models' speed and memory consumption during inference, creating demos and giving cool advice! +Special thanks to [StabilityAI](http://stability.ai) and its CEO [Emad Mostaque](https://twitter.com/emostaque) for +invaluable support, providing GPU compute and infrastructure to train the models (our gratitude goes +to [Richard Vencu](https://github.com/rvencu)); thanks to [LAION](https://laion.ai) +and [Christoph Schuhmann](https://github.com/christophschuhmann) in particular for contribution to the project and +well-prepared datasets; thanks to [Huggingface](https://huggingface.co) teams for optimizing models' speed and memory +consumption during inference, creating demos and giving cool advice! ## 🚀 External Contributors 🚀 -- The Biggest Thanks [@Apolinário](https://github.com/apolinario), for ideas, consultations, help and support on all stages to make IF available in open-source; for writing a lot of documentation and instructions; for creating a friendly atmosphere in difficult moments 🦉; + +- The Biggest Thanks [@Apolinário](https://github.com/apolinario), for ideas, consultations, help and support on all + stages to make IF available in open-source; for writing a lot of documentation and instructions; for creating a + friendly atmosphere in difficult moments 🦉; - Thanks, [@patrickvonplaten](https://github.com/patrickvonplaten), for improving loading time of unet models by 80%; -for integration Stable-Diffusion-x4 as native pipeline 💪; -- Thanks, [@williamberman](https://github.com/williamberman) and [@patrickvonplaten](https://github.com/patrickvonplaten) for diffusers integration 🙌; -- Thanks, [@hysts](https://github.com/hysts) and [@Apolinário](https://github.com/apolinario) for creating [the best gradio demo with IF](https://huggingface.co/spaces/DeepFloyd/IF) 🚀; + for integration Stable-Diffusion-x4 as native pipeline 💪; +- Thanks, [@williamberman](https://github.com/williamberman) + and [@patrickvonplaten](https://github.com/patrickvonplaten) for diffusers integration 🙌; +- Thanks, [@hysts](https://github.com/hysts) and [@Apolinário](https://github.com/apolinario) for + creating [the best gradio demo with IF](https://huggingface.co/spaces/DeepFloyd/IF) 🚀; - Thanks, [@Dango233](https://github.com/Dango233), for adapting IF with xformers memory efficient attention 💪; diff --git a/deepfloyd_if/modules/stage_III_sd_x4.py b/deepfloyd_if/modules/stage_III_sd_x4.py index 7599317..2148f94 100644 --- a/deepfloyd_if/modules/stage_III_sd_x4.py +++ b/deepfloyd_if/modules/stage_III_sd_x4.py @@ -9,7 +9,6 @@ class StableStageIII(IFBaseModule): - available_models = ['stable-diffusion-x4-upscaler'] def __init__(self, *args, 
model_kwargs=None, pil_img_size=1024, **kwargs): @@ -20,7 +19,7 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs): ' Please run `pip install diffusers --upgrade`' ) - model_id = os.path.join('stabilityai', self.dir_or_name) + model_id = 'stabilityai' + "/" + self.dir_or_name.strip() model_kwargs = model_kwargs or {} precision = str(model_kwargs.get('precision', '16')) @@ -46,12 +45,11 @@ def use_diffusers(self): return False def embeddings_to_image( - self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, + self, low_res, prompt, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5, sample_loop='ddpm', sample_timestep_respacing='75', guidance_scale=4.0, img_scale=4.0, - progress=True, seed=None, sample_fn=None, **kwargs): + progress=True, seed=None, sample_fn=None, device=None, **kwargs): - prompt = kwargs.pop('prompt') noise_level = kwargs.pop('noise_level', 20) if sample_loop == 'ddpm': @@ -64,7 +62,8 @@ def embeddings_to_image( self.model.set_progress_bar_config(disable=not progress) generator = torch.manual_seed(seed) - prompt = sum([batch_repeat * [p] for p in prompt], []) + prompt = [prompt] + print(prompt) low_res = low_res.repeat(batch_repeat, 1, 1, 1) metadata = { @@ -82,3 +81,6 @@ def embeddings_to_image( sample = self._IFBaseModule__validate_generations(images) return sample, metadata + + def to(self, x): + self.model.to(x) diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py index dfc9e15..0824539 100644 --- a/deepfloyd_if/pipelines/optimized_dream.py +++ b/deepfloyd_if/pipelines/optimized_dream.py @@ -67,30 +67,29 @@ def run_stage2( sample_timestep_respacing=custom_timesteps_2, seed=seed, device=device) pil_images_II = model.to_images(stageII_generations, disable_watermark=disable_watermark) - print(pil_images_II) return pil_images_II def run_stage3( model, - image: Image, - t5_embs, + prompt, negative_t5_embs, - seed_3: int = 0, - guidance_scale_3: float = 9.0, - sample_timestep_respacing='super40', - disable_watermark=True + images, + seed: int = 0, + guidance_scale: float = 4.0, + custom_timesteps_2: str = 'smart50', + num_inference_steps_2: int = 50, + disable_watermark: bool = True, + device=None ) -> Image: run_garbage_collection() - - _stageIII_generations, _ = model.embeddings_to_image(image=image, - t5_embs=t5_embs, - negative_t5_embs=negative_t5_embs, - num_images_per_prompt=1, - guidance_scale=guidance_scale_3, - noise_level=100, - sample_timestep_respacing=sample_timestep_respacing, - seed=seed_3) - pil_image_III = model.to_images(_stageIII_generations, disable_watermark=disable_watermark) - - return pil_image_III + stageII_generations, _ = model.embeddings_to_image(low_res=images, + prompt=prompt, + negative_t5_embs=negative_t5_embs, + guidance_scale=guidance_scale, + sample_timestep_respacing=num_inference_steps_2, + num_images_per_prompt=1, + noise_level=100, + seed=seed, device=device) + pil_images_III = model.to_images(stageII_generations, disable_watermark=disable_watermark) + return pil_images_III diff --git a/run_ui.py b/run_ui.py index c102b8c..6e03b4f 100644 --- a/run_ui.py +++ b/run_ui.py @@ -22,7 +22,7 @@ if_I.to(torch.float16) # half if_II = IFStageII('IF-II-L-v1.0', device=torch.device("cpu")) if_I.to(torch.float16) # half -# # if_III = StableStageIII('stable-diffusion-x4-upscaler', 
device=torch.device("cpu")) +if_III = StableStageIII('stable-diffusion-x4-upscaler', device=torch.device("cpu")) t5_device = torch.device(0) t5 = T5Embedder(device=t5_device, t5_model_kwargs={"low_cpu_mem_usage": True, "torch_dtype": torch.float16, @@ -31,6 +31,16 @@ def switch_devices(stage): + if stage == 0: + if_I.to(torch.device("cpu")) + if_II.to(torch.device("cpu")) + if_III.to(torch.device("cpu")) + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # dispatch_model(t5.model, get_device_map(t5_device, all2cpu=False)) if stage == 1: # t5.model.cpu() dispatch_model(t5.model, get_device_map(t5_device, all2cpu=True)) @@ -44,6 +54,12 @@ def switch_devices(stage): torch.cuda.empty_cache() torch.cuda.synchronize() if_II.to(torch.device(0)) + elif stage == 3: + if_II.to(torch.device("cpu")) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + if_III.to(torch.device(0)) def process_and_run_stage1(prompt, @@ -55,6 +71,7 @@ def process_and_run_stage1(prompt, num_inference_steps_1): global t5_embs, negative_t5_embs, images print("Encoding prompts..") + switch_devices(stage=0) t5_embs = t5.get_text_embeddings([prompt]) negative_t5_embs = t5.get_text_embeddings([negative_prompt]) switch_devices(stage=1) @@ -96,6 +113,29 @@ def process_and_run_stage2( ) +def process_and_run_stage3( + index, + prompt, + seed_2, + guidance_scale_2, + custom_timesteps_2, + num_inference_steps_2): + global t5_embs, negative_t5_embs, images + print("Stage 3..") + switch_devices(stage=3) + return run_stage3( + if_III, + prompt=prompt, + negative_t5_embs=negative_t5_embs, + images=images[index].unsqueeze(0).to(device), + seed=seed_2, + guidance_scale=guidance_scale_2, + custom_timesteps_2=custom_timesteps_2, + num_inference_steps_2=num_inference_steps_2, + device=device + ) + + def create_ui(args): with gr.Blocks(css='ui_files/style.css') as demo: with gr.Box(): @@ -129,9 +169,6 @@ def create_ui(args): 'Upscale to 256px', visible=args.DISABLE_SD_X4_UPSCALER, interactive=False) - upscale_button = gr.Button('Upscale', - interactive=False, - visible=not args.DISABLE_SD_X4_UPSCALER) with gr.Column(visible=False) as upscale_view: result = gr.Gallery(label='Result', show_label=False, @@ -139,6 +176,9 @@ def create_ui(args): columns=args.GALLERY_COLUMN_NUM, object_fit='contain') back_to_selection_button = gr.Button('Back to selection') + upscale_button = gr.Button('Upscale 4x', + interactive=False, + visible=True) with gr.Accordion('Advanced options', open=False, visible=args.SHOW_ADVANCED_OPTIONS): @@ -295,73 +335,48 @@ def create_ui(args): outputs=result, ) - # stage2_3_inputs = [ - # stage1_result_path, - # selected_index_for_stage2, - # seed_2, - # guidance_scale_2, - # custom_timesteps_2, - # num_inference_steps_2, - # prompt, - # negative_prompt, - # seed_3, - # guidance_scale_3, - # num_inference_steps_3, - # ] - # - # upscale_button.click( - # fn=check_if_stage2_selected, - # inputs=selected_index_for_stage2, - # queue=False, - # ).then( - # fn=randomize_seed_fn, - # inputs=[seed_2, randomize_seed_2], - # outputs=seed_2, - # queue=False, - # ).then( - # fn=randomize_seed_fn, - # inputs=[seed_3, randomize_seed_3], - # outputs=seed_3, - # queue=False, - # ).then( - # fn=show_upscaled_view, - # outputs=[ - # gallery_view, - # upscale_view, - # ], - # queue=False, - # ).then( - # fn=run_stage3, - # inputs=stage2_3_inputs, - # outputs=result, - # api_name='upscale1024', - # ) # .success( - # # fn=upload_stage2_3_info, - # # inputs=[ - # # stage1_param_file_hash_name, - # # result, - # 
# selected_index_for_stage2, - # # seed_2, - # # guidance_scale_2, - # # custom_timesteps_2, - # # num_inference_steps_2, - # # prompt, - # # negative_prompt, - # # seed_3, - # # guidance_scale_3, - # # num_inference_steps_3, - # # ], - # # queue=False, - # # ) - # - # back_to_selection_button.click( - # fn=show_gallery_view, - # outputs=[ - # gallery_view, - # upscale_view, - # ], - # queue=False, - # ) + upscale_button.click( + fn=check_if_stage2_selected, + inputs=selected_index_for_stage2, + queue=False, + ).then( + fn=randomize_seed_fn, + inputs=[seed_2, randomize_seed_2], + outputs=seed_2, + queue=False, + ).then( + fn=randomize_seed_fn, + inputs=[seed_3, randomize_seed_3], + outputs=seed_3, + queue=False, + ).then( + fn=show_upscaled_view, + outputs=[ + gallery_view, + upscale_view, + ], + queue=False, + ).then( + fn=process_and_run_stage3, + inputs=[ + selected_index_for_stage2, + prompt, + seed_2, + guidance_scale_3, + custom_timesteps_2, + num_inference_steps_3, + ], + outputs=result, + ) + + back_to_selection_button.click( + fn=show_gallery_view, + outputs=[ + gallery_view, + upscale_view, + ], + queue=False, + ) return demo From bcda682b78e1430b14ac6fb413a6ca116534cc10 Mon Sep 17 00:00:00 2001 From: neon Date: Sun, 30 Apr 2023 09:38:21 +0200 Subject: [PATCH 05/10] t5 reloading --- deepfloyd_if/modules/t5.py | 10 ++++++++++ run_ui.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/deepfloyd_if/modules/t5.py b/deepfloyd_if/modules/t5.py index 63ec1c7..1cbd3dd 100644 --- a/deepfloyd_if/modules/t5.py +++ b/deepfloyd_if/modules/t5.py @@ -75,6 +75,16 @@ def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_toke self.tokenizer = AutoTokenizer.from_pretrained(path) self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() + self.saved_path = path + self.saved_kwargs = t5_model_kwargs + self.loaded = True + + def reload(self, dmap): + del self.model + torch.cuda.empty_cache() + self.saved_kwargs["device_map"] = dmap + self.model = T5EncoderModel.from_pretrained(self.saved_path, **self.saved_kwargs).eval() + self.loaded = True def to(self, x): self.model.to(x) diff --git a/run_ui.py b/run_ui.py index 6e03b4f..8312777 100644 --- a/run_ui.py +++ b/run_ui.py @@ -40,10 +40,14 @@ def switch_devices(stage): torch.cuda.empty_cache() torch.cuda.synchronize() + if not t5.loaded: + print("Reloading t5") + t5.reload(get_device_map(t5_device, all2cpu=False)) # dispatch_model(t5.model, get_device_map(t5_device, all2cpu=False)) if stage == 1: # t5.model.cpu() dispatch_model(t5.model, get_device_map(t5_device, all2cpu=True)) + t5.model.loaded = False gc.collect() torch.cuda.empty_cache() torch.cuda.synchronize() From 9fc7a2ac0fc7599317cddab6524783f4aa28c13d Mon Sep 17 00:00:00 2001 From: neon Date: Sun, 30 Apr 2023 09:55:44 +0200 Subject: [PATCH 06/10] fixes --- deepfloyd_if/model/unet_split.py | 3 +++ deepfloyd_if/modules/stage_III_sd_x4.py | 1 - deepfloyd_if/pipelines/optimized_dream.py | 4 ++-- run_ui.py | 8 +++----- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/deepfloyd_if/model/unet_split.py b/deepfloyd_if/model/unet_split.py index 1d64e3b..623a378 100644 --- a/deepfloyd_if/model/unet_split.py +++ b/deepfloyd_if/model/unet_split.py @@ -2,6 +2,7 @@ import gc import os import math +import time from abc import abstractmethod import torch @@ -661,6 +662,8 @@ def to(self, x, stage=1): # 0, 1, 2, 3 self.out.to(x) else: super().to(x) + # time.sleep(3) + # print(stage) def forward(self, x, timesteps, text_emb, timestep_text_emb=None, 
aug_emb=None, use_cache=False, **kwargs): hs = [] diff --git a/deepfloyd_if/modules/stage_III_sd_x4.py b/deepfloyd_if/modules/stage_III_sd_x4.py index 2148f94..e3af3a4 100644 --- a/deepfloyd_if/modules/stage_III_sd_x4.py +++ b/deepfloyd_if/modules/stage_III_sd_x4.py @@ -63,7 +63,6 @@ def embeddings_to_image( generator = torch.manual_seed(seed) prompt = [prompt] - print(prompt) low_res = low_res.repeat(batch_repeat, 1, 1, 1) metadata = { diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py index 0824539..7240559 100644 --- a/deepfloyd_if/pipelines/optimized_dream.py +++ b/deepfloyd_if/pipelines/optimized_dream.py @@ -34,7 +34,7 @@ def run_stage1( images, _ = model.embeddings_to_image(t5_embs=t5_embs, negative_t5_embs=negative_t5_embs, - num_images_per_prompt=num_images, + batch_repeat=num_images, guidance_scale=guidance_scale_1, sample_timestep_respacing=custom_timesteps_1, seed=seed @@ -89,7 +89,7 @@ def run_stage3( guidance_scale=guidance_scale, sample_timestep_respacing=num_inference_steps_2, num_images_per_prompt=1, - noise_level=100, + noise_level=60, seed=seed, device=device) pil_images_III = model.to_images(stageII_generations, disable_watermark=disable_watermark) return pil_images_III diff --git a/run_ui.py b/run_ui.py index 8312777..01eae4d 100644 --- a/run_ui.py +++ b/run_ui.py @@ -1,7 +1,5 @@ import argparse import gc -import os -import time import numpy as np from accelerate import dispatch_model @@ -47,7 +45,7 @@ def switch_devices(stage): if stage == 1: # t5.model.cpu() dispatch_model(t5.model, get_device_map(t5_device, all2cpu=True)) - t5.model.loaded = False + t5.loaded = False gc.collect() torch.cuda.empty_cache() torch.cuda.synchronize() @@ -76,8 +74,8 @@ def process_and_run_stage1(prompt, global t5_embs, negative_t5_embs, images print("Encoding prompts..") switch_devices(stage=0) - t5_embs = t5.get_text_embeddings([prompt]) - negative_t5_embs = t5.get_text_embeddings([negative_prompt]) + t5_embs = t5.get_text_embeddings([prompt] * num_images) + negative_t5_embs = t5.get_text_embeddings([negative_prompt] * num_images) switch_devices(stage=1) t5_embs = t5_embs.to(if_I.device) negative_t5_embs = negative_t5_embs.to(if_I.device) From 2a5c212b99a43ad2d129626dd8ce2421d5c76d2b Mon Sep 17 00:00:00 2001 From: neon Date: Sun, 30 Apr 2023 10:02:47 +0200 Subject: [PATCH 07/10] small optimization --- README.md | 2 +- deepfloyd_if/model/unet_split.py | 25 ++++++++----------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index e1a72c1..8c63b75 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Tested on rtx 3070, 8 gb vram -stage 1: ~3.5 sec per iteration +stage 1: 1.30 min for 27 steps, ~3.35 sec per iteration stage 2: 11 seconds for 27 steps stage 3: 2 seconds for 40 steps diff --git a/deepfloyd_if/model/unet_split.py b/deepfloyd_if/model/unet_split.py index 623a378..bdcdb8d 100644 --- a/deepfloyd_if/model/unet_split.py +++ b/deepfloyd_if/model/unet_split.py @@ -634,36 +634,29 @@ def to(self, x, stage=1): # 0, 1, 2, 3 if isinstance(x, torch.device): secondary_device = self.secondary_device if stage == 1: - self.middle_block.to(secondary_device) self.output_blocks.to(secondary_device) self.out.to(secondary_device) - self.collect() + + # self.collect() + self.time_embed.to(x) self.encoder_proj.to(x) self.encoder_pooling.to(x) self.input_blocks.to(x) - elif stage == 2: - self.time_embed.to(secondary_device) - self.encoder_proj.to(secondary_device) - self.encoder_pooling.to(secondary_device) - 
self.input_blocks.to(secondary_device) - self.output_blocks.to(secondary_device) - self.out.to(secondary_device) - self.collect() self.middle_block.to(x) - elif stage == 3: + elif stage == 2: self.time_embed.to(secondary_device) self.encoder_proj.to(secondary_device) self.encoder_pooling.to(secondary_device) self.input_blocks.to(secondary_device) self.middle_block.to(secondary_device) - self.collect() + + # self.collect() + self.output_blocks.to(x) self.out.to(x) else: super().to(x) - # time.sleep(3) - # print(stage) def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs): hs = [] @@ -696,11 +689,9 @@ def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, h = module(h, emb, encoder_out) hs.append(h) - self.to(self.primary_device, stage=2) - h = self.middle_block(h, emb, encoder_out) - self.to(self.primary_device, stage=3) + self.to(self.primary_device, stage=2) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) From 5aa32581638a466023198775620c3e239be0303d Mon Sep 17 00:00:00 2001 From: neon Date: Sun, 30 Apr 2023 16:59:39 +0200 Subject: [PATCH 08/10] fixed stage 3, default params changed --- README.md | 8 ++++---- deepfloyd_if/modules/stage_III_sd_x4.py | 8 +------- deepfloyd_if/pipelines/optimized_dream.py | 4 ++-- run_ui.py | 7 ++++--- 4 files changed, 11 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 8c63b75..58c86d9 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,9 @@ Tested on rtx 3070, 8 gb vram -stage 1: 1.30 min for 27 steps, ~3.35 sec per iteration -stage 2: 11 seconds for 27 steps -stage 3: 2 seconds for 40 steps +stage 1: 1.30 min for 27 steps, ~3.35 sec per iteration \ +stage 2: 11 seconds for 27 steps \ +stage 3: 30 seconds for 40 steps ### To run the ui: @@ -14,7 +14,7 @@ python run_ui.py All the models are automatically downloaded. 
-original readme: +Original readme: [![License](https://img.shields.io/badge/Code_License-Modified_MIT-blue.svg)](LICENSE) [![License](https://img.shields.io/badge/Weights_License-DeepFloyd_IF-orange.svg)](LICENSE-MODEL) diff --git a/deepfloyd_if/modules/stage_III_sd_x4.py b/deepfloyd_if/modules/stage_III_sd_x4.py index e3af3a4..2f6d6ef 100644 --- a/deepfloyd_if/modules/stage_III_sd_x4.py +++ b/deepfloyd_if/modules/stage_III_sd_x4.py @@ -50,8 +50,6 @@ def embeddings_to_image( sample_loop='ddpm', sample_timestep_respacing='75', guidance_scale=4.0, img_scale=4.0, progress=True, seed=None, sample_fn=None, device=None, **kwargs): - noise_level = kwargs.pop('noise_level', 20) - if sample_loop == 'ddpm': self.model.scheduler = DDPMScheduler.from_config(self.model.scheduler.config) else: @@ -62,13 +60,9 @@ def embeddings_to_image( self.model.set_progress_bar_config(disable=not progress) generator = torch.manual_seed(seed) - prompt = [prompt] - low_res = low_res.repeat(batch_repeat, 1, 1, 1) - metadata = { - 'image': low_res, + 'image': low_res, # 1 3 256 256 'prompt': prompt, - 'noise_level': noise_level, 'generator': generator, 'guidance_scale': guidance_scale, 'num_inference_steps': num_inference_steps, diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py index 7240559..77c1e74 100644 --- a/deepfloyd_if/pipelines/optimized_dream.py +++ b/deepfloyd_if/pipelines/optimized_dream.py @@ -67,7 +67,7 @@ def run_stage2( sample_timestep_respacing=custom_timesteps_2, seed=seed, device=device) pil_images_II = model.to_images(stageII_generations, disable_watermark=disable_watermark) - return pil_images_II + return stageII_generations, pil_images_II def run_stage3( @@ -89,7 +89,7 @@ def run_stage3( guidance_scale=guidance_scale, sample_timestep_respacing=num_inference_steps_2, num_images_per_prompt=1, - noise_level=60, + noise_level=20, seed=seed, device=device) pil_images_III = model.to_images(stageII_generations, disable_watermark=disable_watermark) return pil_images_III diff --git a/run_ui.py b/run_ui.py index 01eae4d..9b0a6d8 100644 --- a/run_ui.py +++ b/run_ui.py @@ -102,7 +102,7 @@ def process_and_run_stage2( global t5_embs, negative_t5_embs, images print("Stage 2..") switch_devices(stage=2) - return run_stage2( + images, images_ret = run_stage2( if_II, t5_embs=t5_embs, negative_t5_embs=negative_t5_embs, @@ -113,6 +113,7 @@ def process_and_run_stage2( num_inference_steps_2=num_inference_steps_2, device=device ) + return images_ret def process_and_run_stage3( @@ -208,7 +209,7 @@ def create_ui(args): 'smart100', 'smart185', ], - value="fast27", + value="smart50", visible=True) num_inference_steps_1 = gr.Slider( label='Number of inference steps', @@ -273,7 +274,7 @@ def create_ui(args): minimum=1, maximum=200, step=1, - value=40, + value=60, visible=True) with gr.Box(): with gr.Row(): From 103adb5b21b1513adb2bb706a05dd6e6a7e80698 Mon Sep 17 00:00:00 2001 From: neon Date: Sun, 30 Apr 2023 17:11:03 +0200 Subject: [PATCH 09/10] added aspect ratio possibility --- deepfloyd_if/pipelines/optimized_dream.py | 3 ++- run_ui.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py index 77c1e74..48ced8d 100644 --- a/deepfloyd_if/pipelines/optimized_dream.py +++ b/deepfloyd_if/pipelines/optimized_dream.py @@ -26,6 +26,7 @@ def run_stage1( guidance_scale_1: float = 7.0, custom_timesteps_1: str = 'smart100', num_inference_steps_1: int = 100, + 
aspect_ratio='1:1' ): run_garbage_collection() @@ -37,7 +38,7 @@ def run_stage1( batch_repeat=num_images, guidance_scale=guidance_scale_1, sample_timestep_respacing=custom_timesteps_1, - seed=seed + seed=seed, aspect_ratio=aspect_ratio ) pil_images_I = model.to_images(images, disable_watermark=True) diff --git a/run_ui.py b/run_ui.py index 9b0a6d8..e8663f7 100644 --- a/run_ui.py +++ b/run_ui.py @@ -70,7 +70,8 @@ def process_and_run_stage1(prompt, num_images, guidance_scale_1, custom_timesteps_1, - num_inference_steps_1): + num_inference_steps_1, + aspect_ratio): global t5_embs, negative_t5_embs, images print("Encoding prompts..") switch_devices(stage=0) @@ -88,7 +89,8 @@ def process_and_run_stage1(prompt, num_images=num_images, guidance_scale_1=guidance_scale_1, custom_timesteps_1=custom_timesteps_1, - num_inference_steps_1=num_inference_steps_1 + num_inference_steps_1=num_inference_steps_1, + aspect_ratio=aspect_ratio ) return images_ret @@ -158,6 +160,9 @@ def create_ui(args): placeholder='Enter a negative prompt', elem_id='negative-prompt-text-input', ).style(container=False) + aspect_ratio_1 = gr.Radio( + ["16:9", "4:3", "1:1", "3:4", "9:16"], value="1:1", label="Aspect ratio" + ).style(container=False) generate_button = gr.Button('Generate').style(full_width=False) with gr.Column() as gallery_view: @@ -290,7 +295,8 @@ def create_ui(args): num_images, guidance_scale_1, custom_timesteps_1, - num_inference_steps_1], + num_inference_steps_1, + aspect_ratio_1], gallery ) From 3e9808d97317c7b4bace22cc0a4fecfcdf56895d Mon Sep 17 00:00:00 2001 From: neon Date: Mon, 1 May 2023 09:22:44 +0200 Subject: [PATCH 10/10] replaced aspect ratio with width/height params, fixed number of images --- deepfloyd_if/modules/base.py | 7 +++++-- deepfloyd_if/modules/stage_I.py | 6 ++++-- deepfloyd_if/pipelines/optimized_dream.py | 24 +++++++++++++---------- run_ui.py | 19 ++++++++++++------ 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/deepfloyd_if/modules/base.py b/deepfloyd_if/modules/base.py index d08453b..6cad9f3 100644 --- a/deepfloyd_if/modules/base.py +++ b/deepfloyd_if/modules/base.py @@ -89,12 +89,13 @@ def embeddings_to_image( support_noise_less_qsample_steps=0, inpainting_mask=None, device=None, + force_size=False, **kwargs, ): if device is None: device = self.model.primary_device self._clear_cache() - image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale) + image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale, force_size=force_size) diffusion = self.get_diffusion(sample_timestep_respacing) bs_scale = 2 if positive_t5_embs is None else 3 @@ -331,11 +332,13 @@ def show(self, pil_images, nrow=None, size=10): def _clear_cache(self): self.model.cache = None - def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale): + def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale, force_size=True): if low_res is not None: bs, c, h, w = low_res.shape image_h, image_w = int((h * img_scale) // 32) * 32, int((w * img_scale // 32)) * 32 else: + if force_size: + return img_size[0], img_size[1] scale_w, scale_h = aspect_ratio.split(':') scale_w, scale_h = int(scale_w), int(scale_h) coef = scale_w / scale_h diff --git a/deepfloyd_if/modules/stage_I.py b/deepfloyd_if/modules/stage_I.py index 8b3c62d..92883ac 100644 --- a/deepfloyd_if/modules/stage_I.py +++ b/deepfloyd_if/modules/stage_I.py @@ -35,7 +35,8 @@ def to(self, x, stage=1, secondary_device=torch.device("cpu")): # 0, 1, 2, 3 def embeddings_to_image(self, 
t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25, sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0, - aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs): + img_size=(64, 64), aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, + force_size=False, **kwargs): return super().embeddings_to_image( t5_embs=t5_embs, style_t5_embs=style_t5_embs, @@ -47,11 +48,12 @@ def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None sample_loop=sample_loop, sample_timestep_respacing=sample_timestep_respacing, guidance_scale=guidance_scale, - img_size=64, + img_size=img_size, aspect_ratio=aspect_ratio, progress=progress, seed=seed, sample_fn=sample_fn, positive_mixer=positive_mixer, + force_size=force_size, **kwargs ) diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py index 48ced8d..2267db2 100644 --- a/deepfloyd_if/pipelines/optimized_dream.py +++ b/deepfloyd_if/pipelines/optimized_dream.py @@ -26,23 +26,27 @@ def run_stage1( guidance_scale_1: float = 7.0, custom_timesteps_1: str = 'smart100', num_inference_steps_1: int = 100, - aspect_ratio='1:1' + aspect_ratio='1:1', + img_size=(64, 64) ): run_garbage_collection() if custom_timesteps_1 == "none": custom_timesteps_1 = str(num_inference_steps_1) - images, _ = model.embeddings_to_image(t5_embs=t5_embs, - negative_t5_embs=negative_t5_embs, - batch_repeat=num_images, - guidance_scale=guidance_scale_1, - sample_timestep_respacing=custom_timesteps_1, - seed=seed, aspect_ratio=aspect_ratio - ) - pil_images_I = model.to_images(images, disable_watermark=True) + ret_images1, ret_images2 = [], [] + for _ in range(num_images): + images, _ = model.embeddings_to_image(t5_embs=t5_embs, + negative_t5_embs=negative_t5_embs, + guidance_scale=guidance_scale_1, img_size=img_size, + sample_timestep_respacing=custom_timesteps_1, + seed=seed, aspect_ratio=aspect_ratio, force_size=True + ) + pil_images_I = model.to_images(images, disable_watermark=True) + ret_images1.append(pil_images_I[0]) + ret_images2.append(images[0]) - return images, pil_images_I + return ret_images2, ret_images1 def run_stage2( diff --git a/run_ui.py b/run_ui.py index e8663f7..ff63bca 100644 --- a/run_ui.py +++ b/run_ui.py @@ -71,7 +71,9 @@ def process_and_run_stage1(prompt, guidance_scale_1, custom_timesteps_1, num_inference_steps_1, - aspect_ratio): + # aspect_ratio, + width, + height): global t5_embs, negative_t5_embs, images print("Encoding prompts..") switch_devices(stage=0) @@ -90,7 +92,8 @@ def process_and_run_stage1(prompt, guidance_scale_1=guidance_scale_1, custom_timesteps_1=custom_timesteps_1, num_inference_steps_1=num_inference_steps_1, - aspect_ratio=aspect_ratio + # aspect_ratio=aspect_ratio, + img_size=(width, height) ) return images_ret @@ -160,9 +163,11 @@ def create_ui(args): placeholder='Enter a negative prompt', elem_id='negative-prompt-text-input', ).style(container=False) - aspect_ratio_1 = gr.Radio( - ["16:9", "4:3", "1:1", "3:4", "9:16"], value="1:1", label="Aspect ratio" - ).style(container=False) + width = gr.Slider(32, 128, value=64, step=8, label="Width").style(container=False) + height = gr.Slider(32, 128, value=64, step=8, label="Height").style(container=False) + # aspect_ratio_1 = gr.Radio( + # ["16:9", "4:3", "1:1", "3:4", "9:16"], value="1:1", label="Aspect ratio" + # ).style(container=False) generate_button = 
gr.Button('Generate').style(full_width=False) with gr.Column() as gallery_view: @@ -296,7 +301,9 @@ def create_ui(args): guidance_scale_1, custom_timesteps_1, num_inference_steps_1, - aspect_ratio_1], + # aspect_ratio_1, + width, + height], gallery )
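
---

Note for readers skimming the series: the 8 GB-VRAM numbers quoted in the README changes above come from keeping only part of each model resident on the GPU at a time — patch 05 drops and later reloads the T5 encoder around stage 1, and patches 06-07 swap halves of the stage-I UNet (`input_blocks`/`middle_block` first, `output_blocks` second) in the middle of a single forward pass. The sketch below is an illustrative toy, not the repository's `UNetModel`: the `PhaseSwappedUNet` class name, the placeholder conv blocks, and the `_swap` helper are invented for the example, and only the phase-swapping pattern itself is taken from the patches.

```python
import gc

import torch
import torch.nn as nn


class PhaseSwappedUNet(nn.Module):
    """Toy stand-in for the split UNet in the patches above.

    Only half of the network lives on the GPU at any time: the encoder
    plus middle block during the first phase of a forward pass, the
    decoder during the second phase. The other half is parked on the CPU.
    """

    def __init__(self, primary_device="cuda", secondary_device="cpu"):
        super().__init__()
        self.primary_device = torch.device(primary_device)
        self.secondary_device = torch.device(secondary_device)
        # Placeholder blocks; the real model uses large attention/resnet stacks,
        # which is what makes the swapping worthwhile.
        self.input_blocks = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.SiLU())
        self.middle_block = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.SiLU())
        self.output_blocks = nn.Sequential(nn.Conv2d(32, 3, 3, padding=1))

    def _swap(self, stage: int) -> None:
        # Stage 1: encoder + middle block on the GPU, decoder on the CPU.
        # Stage 2: decoder on the GPU, everything else back on the CPU.
        if stage == 1:
            self.output_blocks.to(self.secondary_device)
            self.input_blocks.to(self.primary_device)
            self.middle_block.to(self.primary_device)
        else:
            self.input_blocks.to(self.secondary_device)
            self.middle_block.to(self.secondary_device)
            self.output_blocks.to(self.primary_device)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._swap(stage=1)
        h = self.input_blocks(x.to(self.primary_device))
        h = self.middle_block(h)
        # Change device residency mid-forward, then finish with the decoder.
        self._swap(stage=2)
        return self.output_blocks(h)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = PhaseSwappedUNet(primary_device=device)
    out = model(torch.randn(1, 3, 64, 64))
    print(out.shape)  # torch.Size([1, 3, 64, 64])
```

The cost of this pattern is a host-to-device copy on every swap, which is why patch 07 merges the original three-phase swap into two phases by moving `middle_block` together with `input_blocks`.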