From a35bba5701ff3cc10c82ed1c6c4c4ebb4293cf3e Mon Sep 17 00:00:00 2001
From: neon
Date: Sat, 29 Apr 2023 12:48:35 +0200
Subject: [PATCH 01/10] init
---
deepfloyd_if/modules/stage_I.py | 4 +-
deepfloyd_if/modules/t5.py | 19 +-
deepfloyd_if/pipelines/optimized_dream.py | 94 ++++++
run_ui.py | 371 ++++++++++++++++++++++
ui_files/style.css | 238 ++++++++++++++
ui_files/utils.py | 72 +++++
6 files changed, 792 insertions(+), 6 deletions(-)
create mode 100644 deepfloyd_if/pipelines/optimized_dream.py
create mode 100644 run_ui.py
create mode 100644 ui_files/style.css
create mode 100644 ui_files/utils.py
diff --git a/deepfloyd_if/modules/stage_I.py b/deepfloyd_if/modules/stage_I.py
index a9c62cc..8b1c5f3 100644
--- a/deepfloyd_if/modules/stage_I.py
+++ b/deepfloyd_if/modules/stage_I.py
@@ -24,11 +24,13 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs):
self.model = self.load_checkpoint(self.model, self.dir_or_name)
self.model.eval().to(self.device)
+ def to(self, x):
+ self.model.to(x)
+
def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None,
batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25,
sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0,
aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs):
-
return super().embeddings_to_image(
t5_embs=t5_embs,
style_t5_embs=style_t5_embs,
diff --git a/deepfloyd_if/modules/t5.py b/deepfloyd_if/modules/t5.py
index 7443426..63ec1c7 100644
--- a/deepfloyd_if/modules/t5.py
+++ b/deepfloyd_if/modules/t5.py
@@ -12,9 +12,9 @@
class T5Embedder:
-
available_models = ['t5-v1_1-xxl']
- bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
+ bad_punct_regex = re.compile(
+ r'[' + '#®•©™&@·º½¾¿¡§~' + '\)' + '\(' + '\]' + '\[' + '\}' + '\{' + '\|' + '\\' + '\/' + '\*' + r']{1,}') # noqa
def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_token=None, use_text_preprocessing=True,
t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None):
@@ -76,6 +76,12 @@ def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_toke
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
+ def to(self, x):
+ self.model.to(x)
+
+ def cpu(self):
+ self.model.base_model.to(torch.device("cpu"))
+
def get_text_embeddings(self, texts):
texts = [self.text_preprocessing(text) for text in texts]
@@ -121,10 +127,12 @@ def clean_caption(self, caption):
caption = re.sub('<person>', 'person', caption)
# urls:
caption = re.sub(
- r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
+ r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',
+ # noqa
'', caption) # regex for urls
caption = re.sub(
- r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
+ r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',
+ # noqa
'', caption) # regex for urls
# html:
caption = BeautifulSoup(caption, features='html.parser').text
@@ -150,7 +158,8 @@ def clean_caption(self, caption):
# все виды тире / all types of dash --> "-"
caption = re.sub(
- r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
+ r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',
+ # noqa
'-', caption)
# кавычки к одному стандарту / quotation marks to one standard
diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py
new file mode 100644
index 0000000..6620b4b
--- /dev/null
+++ b/deepfloyd_if/pipelines/optimized_dream.py
@@ -0,0 +1,94 @@
+import gc
+import numpy as np
+
+import torch.cuda
+from PIL import Image
+
+
+def run_garbage_collection():
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
+def to_pil_images(images: torch.Tensor) -> list[Image.Image]:
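+ # Map outputs from [-1, 1] to [0, 255] uint8 (NCHW -> NHWC) before building PIL images.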
+ images = (images / 2 + 0.5).clamp(0, 1)
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+ images = np.round(images * 255).astype(np.uint8)
+ return [Image.fromarray(image) for image in images]
+
+
+def run_stage1(
+ model,
+ t5_embs,
+ negative_t5_embs,
+ seed: int = 0,
+ num_images: int = 1,
+ guidance_scale_1: float = 7.0,
+ custom_timesteps_1: str = 'smart100',
+ num_inference_steps_1: int = 100,
+):
+ run_garbage_collection()
+
+ images, _ = model.embeddings_to_image(t5_embs=t5_embs,
+ negative_t5_embs=negative_t5_embs,
+ num_images_per_prompt=num_images,
+ guidance_scale=guidance_scale_1,
+ sample_timestep_respacing=custom_timesteps_1,
+ seed=seed
+ )
+ pil_images_I = model.to_images(images, disable_watermark=True)
+
+ return pil_images_I
+
+
+def run_stage2(
+ model,
+ stage1_result,
+ stage2_index: int,
+ seed_2: int = 0,
+ guidance_scale_2: float = 4.0,
+ custom_timesteps_2: str = 'smart50',
+ num_inference_steps_2: int = 50,
+ disable_watermark: bool = True,
+) -> list[Image.Image]:
+ run_garbage_collection()
+
+ prompt_embeds = stage1_result['prompt_embeds']
+ negative_embeds = stage1_result['negative_embeds']
+ images = stage1_result['images']
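+ # Index with a list so the single image picked in the gallery keeps its batch dimension.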
+ images = images[[stage2_index]]
+
+ stageII_generations, _ = model.embeddings_to_image(low_res=images,
+ t5_embs=prompt_embeds,
+ negative_t5_embs=negative_embeds,
+ guidance_scale=guidance_scale_2,
+ sample_timestep_respacing=custom_timesteps_2,
+ seed=seed_2)
+ pil_images_II = model.to_images(stageII_generations, disable_watermark=disable_watermark)
+
+ return pil_images_II
+
+
+def run_stage3(
+ model,
+ image: Image.Image,
+ t5_embs,
+ negative_t5_embs,
+ seed_3: int = 0,
+ guidance_scale_3: float = 9.0,
+ sample_timestep_respacing='super40',
+ disable_watermark=True
+) -> list[Image.Image]:
+ run_garbage_collection()
+
+ _stageIII_generations, _ = model.embeddings_to_image(image=image,
+ t5_embs=t5_embs,
+ negative_t5_embs=negative_t5_embs,
+ num_images_per_prompt=1,
+ guidance_scale=guidance_scale_3,
+ noise_level=100,
+ sample_timestep_respacing=sample_timestep_respacing,
+ seed=seed_3)
+ pil_image_III = model.to_images(_stageIII_generations, disable_watermark=disable_watermark)
+
+ return pil_image_III
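+
+
+# Rough usage sketch (not wired up in run_ui.py yet; the stage-II module and the
+# `stage1_tensor` bookkeeping below are only illustrative): run_stage1 returns PIL images
+# for the gallery, while run_stage2/run_stage3 expect the caller to keep the raw stage-I
+# tensors and T5 embeddings around in a dict, e.g.
+#
+#     pil_images = run_stage1(if_I, t5_embs, negative_t5_embs, seed=0, num_images=4)
+#     stage1_result = {'prompt_embeds': t5_embs,
+#                      'negative_embeds': negative_t5_embs,
+#                      'images': stage1_tensor}  # raw stage-I output kept by the caller
+#     upscaled = run_stage2(if_II, stage1_result, stage2_index=0)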
diff --git a/run_ui.py b/run_ui.py
new file mode 100644
index 0000000..8e09648
--- /dev/null
+++ b/run_ui.py
@@ -0,0 +1,371 @@
+import argparse
+import gc
+import os
+
+import numpy as np
+
+from deepfloyd_if.modules import IFStageI, IFStageII, StableStageIII
+from deepfloyd_if.modules.t5 import T5Embedder
+from deepfloyd_if.pipelines.optimized_dream import run_stage1, run_stage2, run_stage3
+
+import torch
+
+import gradio as gr
+
+from ui_files.utils import randomize_seed_fn, show_gallery_view, update_upscale_button, get_stage2_index, \
+ check_if_stage2_selected, show_upscaled_view, get_device_map
+
+try:
+ import xformers
+
+ os.environ["FORCE_MEM_EFFICIENT_ATTN"] = "1"
+except ImportError:
+ pass
+
+device = torch.device(0)
+if_I = IFStageI('IF-I-XL-v1.0', device=torch.device("cpu"))
+if_I.to(torch.float16) # half
+# # if_II = IFStageII('IF-II-L-v1.0', device=torch.device("cpu"))
+# # if_III = StableStageIII('stable-diffusion-x4-upscaler', device=torch.device("cpu"))
+t5_device = torch.device(0)
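+# Split the T5-XXL encoder across GPU and CPU (see ui_files.utils.get_device_map) so that
+# text encoding fits into limited VRAM; the encoder is dropped again before stage I runs.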
+t5 = T5Embedder(device=t5_device, t5_model_kwargs={"low_cpu_mem_usage": True,
+ "torch_dtype": torch.float16,
+ "device_map": get_device_map(t5_device),
+ "offload_folder": True})
+
+
+def switch_devices(stage):
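+ # Stage 1: free the T5 encoder and move the stage-I UNet onto the GPU before sampling.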
+ if stage == 1:
+ # t5.model.cpu()
+ del t5.model
+ if_I.to(torch.device(0))
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+
+def process_and_run_stage1(prompt,
+ negative_prompt,
+ seed_1,
+ num_images,
+ guidance_scale_1,
+ custom_timesteps_1,
+ num_inference_steps_1):
+ print("Encoding prompts..")
+ prompt = t5.get_text_embeddings(prompt)
+ if negative_prompt == "":
+ negative_prompt = torch.zeros_like(prompt)
+ else:
+ negative_prompt = t5.get_text_embeddings(negative_prompt)
+ switch_devices(stage=1)
+ prompt = prompt.to(if_I.device)
+ negative_prompt = negative_prompt.to(if_I.device)
+ print("Encoded. Running 1st stage")
+ return run_stage1(
+ if_I,
+ t5_embs=prompt,
+ negative_t5_embs=negative_prompt,
+ seed=seed_1,
+ num_images=num_images,
+ guidance_scale_1=guidance_scale_1,
+ custom_timesteps_1=custom_timesteps_1,
+ num_inference_steps_1=num_inference_steps_1
+ )
+
+
+def create_ui(args):
+ with gr.Blocks(css='ui_files/style.css') as demo:
+ with gr.Box():
+ with gr.Row(elem_id='prompt-container').style(equal_height=True):
+ with gr.Column():
+ prompt = gr.Text(
+ label='Prompt',
+ show_label=False,
+ max_lines=1,
+ placeholder='Enter your prompt',
+ elem_id='prompt-text-input',
+ ).style(container=False)
+ negative_prompt = gr.Text(
+ label='Negative prompt',
+ show_label=False,
+ max_lines=1,
+ placeholder='Enter a negative prompt',
+ elem_id='negative-prompt-text-input',
+ ).style(container=False)
+ generate_button = gr.Button('Generate').style(full_width=False)
+
+ with gr.Column() as gallery_view:
+ gallery = gr.Gallery(label='Stage 1 results',
+ show_label=False,
+ elem_id='gallery').style(
+ columns=args.GALLERY_COLUMN_NUM,
+ object_fit='contain')
+ gr.Markdown('Pick your favorite generation to upscale.')
+ with gr.Row():
+ upscale_to_256_button = gr.Button(
+ 'Upscale to 256px',
+ visible=args.DISABLE_SD_X4_UPSCALER,
+ interactive=False)
+ upscale_button = gr.Button('Upscale',
+ interactive=False,
+ visible=not args.DISABLE_SD_X4_UPSCALER)
+ with gr.Column(visible=False) as upscale_view:
+ result = gr.Image(label='Result',
+ show_label=False,
+ type='filepath',
+ interactive=False,
+ elem_id='upscaled-image').style(height=640)
+ back_to_selection_button = gr.Button('Back to selection')
+ with gr.Accordion('Advanced options',
+ open=False,
+ visible=args.SHOW_ADVANCED_OPTIONS):
+ with gr.Tabs():
+ with gr.Tab(label='Generation'):
+ seed_1 = gr.Slider(label='Seed',
+ minimum=0,
+ maximum=args.MAX_SEED,
+ step=1,
+ value=0)
+ randomize_seed_1 = gr.Checkbox(label='Randomize seed',
+ value=True)
+ guidance_scale_1 = gr.Slider(label='Guidance scale',
+ minimum=1,
+ maximum=20,
+ step=0.1,
+ value=7.0)
+ custom_timesteps_1 = gr.Dropdown(
+ label='Custom timesteps 1',
+ choices=[
+ 'none',
+ 'fast27',
+ 'smart27',
+ 'smart50',
+ 'smart100',
+ 'smart185',
+ ],
+ value="smart100",
+ visible=True)
+ num_inference_steps_1 = gr.Slider(
+ label='Number of inference steps',
+ minimum=1,
+ maximum=200,
+ step=1,
+ value=100,
+ visible=True)
+ num_images = gr.Slider(label='Number of images',
+ minimum=1,
+ maximum=4,
+ step=1,
+ value=4,
+ visible=True)
+ with gr.Tab(label='Super-resolution 1'):
+ seed_2 = gr.Slider(label='Seed',
+ minimum=0,
+ maximum=args.MAX_SEED,
+ step=1,
+ value=0)
+ randomize_seed_2 = gr.Checkbox(label='Randomize seed',
+ value=True)
+ guidance_scale_2 = gr.Slider(label='Guidance scale',
+ minimum=1,
+ maximum=20,
+ step=0.1,
+ value=4.0)
+ custom_timesteps_2 = gr.Dropdown(
+ label='Custom timesteps 2',
+ choices=[
+ 'none',
+ 'fast27',
+ 'smart27',
+ 'smart50',
+ 'smart100',
+ 'smart185',
+ ],
+ value="smart50",
+ visible=True)
+ num_inference_steps_2 = gr.Slider(
+ label='Number of inference steps',
+ minimum=1,
+ maximum=200,
+ step=1,
+ value=50,
+ visible=True)
+ with gr.Tab(label='Super-resolution 2'):
+ seed_3 = gr.Slider(label='Seed',
+ minimum=0,
+ maximum=args.MAX_SEED,
+ step=1,
+ value=0)
+ randomize_seed_3 = gr.Checkbox(label='Randomize seed',
+ value=True)
+ guidance_scale_3 = gr.Slider(label='Guidance scale',
+ minimum=1,
+ maximum=20,
+ step=0.1,
+ value=9.0)
+ num_inference_steps_3 = gr.Slider(
+ label='Number of inference steps',
+ minimum=1,
+ maximum=200,
+ step=1,
+ value=40,
+ visible=True)
+ with gr.Box():
+ with gr.Row():
+ with gr.Accordion(label='Hidden params'):
+ selected_index_for_stage2 = gr.Number(
+ label='Selected index for Stage 2', value=-1, precision=0)
+
+ generate_button.click(
+ process_and_run_stage1,
+ [prompt,
+ negative_prompt,
+ seed_1,
+ num_images,
+ guidance_scale_1,
+ custom_timesteps_1,
+ num_inference_steps_1],
+ gallery
+ )
+
+ gallery.select(
+ fn=get_stage2_index,
+ outputs=selected_index_for_stage2,
+ queue=False,
+ )
+ #
+ # selected_index_for_stage2.change(
+ # fn=update_upscale_button,
+ # inputs=selected_index_for_stage2,
+ # outputs=[
+ # upscale_button,
+ # upscale_to_256_button,
+ # ],
+ # queue=False,
+ # )
+ #
+ # stage2_inputs = [
+ # stage1_result_path,
+ # selected_index_for_stage2,
+ # seed_2,
+ # guidance_scale_2,
+ # custom_timesteps_2,
+ # num_inference_steps_2,
+ # ]
+ #
+ # upscale_to_256_button.click(
+ # fn=check_if_stage2_selected,
+ # inputs=selected_index_for_stage2,
+ # queue=False,
+ # ).then(
+ # fn=randomize_seed_fn,
+ # inputs=[seed_2, randomize_seed_2],
+ # outputs=seed_2,
+ # queue=False,
+ # ).then(
+ # fn=show_upscaled_view,
+ # outputs=[
+ # gallery_view,
+ # upscale_view,
+ # ],
+ # queue=False,
+ # ).then(
+ # fn=run_stage2,
+ # inputs=stage2_inputs,
+ # outputs=result,
+ # api_name='upscale256',
+ # ) # .success(
+ # # fn=upload_stage2_info,
+ # # inputs=[
+ # # stage1_param_file_hash_name,
+ # # result,
+ # # selected_index_for_stage2,
+ # # seed_2,
+ # # guidance_scale_2,
+ # # custom_timesteps_2,
+ # # num_inference_steps_2,
+ # # ],
+ # # queue=False,
+ # # )
+ #
+ # stage2_3_inputs = [
+ # stage1_result_path,
+ # selected_index_for_stage2,
+ # seed_2,
+ # guidance_scale_2,
+ # custom_timesteps_2,
+ # num_inference_steps_2,
+ # prompt,
+ # negative_prompt,
+ # seed_3,
+ # guidance_scale_3,
+ # num_inference_steps_3,
+ # ]
+ #
+ # upscale_button.click(
+ # fn=check_if_stage2_selected,
+ # inputs=selected_index_for_stage2,
+ # queue=False,
+ # ).then(
+ # fn=randomize_seed_fn,
+ # inputs=[seed_2, randomize_seed_2],
+ # outputs=seed_2,
+ # queue=False,
+ # ).then(
+ # fn=randomize_seed_fn,
+ # inputs=[seed_3, randomize_seed_3],
+ # outputs=seed_3,
+ # queue=False,
+ # ).then(
+ # fn=show_upscaled_view,
+ # outputs=[
+ # gallery_view,
+ # upscale_view,
+ # ],
+ # queue=False,
+ # ).then(
+ # fn=run_stage3,
+ # inputs=stage2_3_inputs,
+ # outputs=result,
+ # api_name='upscale1024',
+ # ) # .success(
+ # # fn=upload_stage2_3_info,
+ # # inputs=[
+ # # stage1_param_file_hash_name,
+ # # result,
+ # # selected_index_for_stage2,
+ # # seed_2,
+ # # guidance_scale_2,
+ # # custom_timesteps_2,
+ # # num_inference_steps_2,
+ # # prompt,
+ # # negative_prompt,
+ # # seed_3,
+ # # guidance_scale_3,
+ # # num_inference_steps_3,
+ # # ],
+ # # queue=False,
+ # # )
+ #
+ # back_to_selection_button.click(
+ # fn=show_gallery_view,
+ # outputs=[
+ # gallery_view,
+ # upscale_view,
+ # ],
+ # queue=False,
+ # )
+ return demo
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description='IF UI settings')
+ parser.add_argument('--GALLERY_COLUMN_NUM', type=int, default=4)
+ # argparse's type=bool treats any non-empty string (including 'False') as True,
+ # so parse boolean flags explicitly.
+ str2bool = lambda s: str(s).lower() in ('1', 'true', 'yes')
+ parser.add_argument('--DISABLE_SD_X4_UPSCALER', type=str2bool, default=True)
+ parser.add_argument('--SHOW_ADVANCED_OPTIONS', type=str2bool, default=True)
+ parser.add_argument('--MAX_SEED', type=int, default=np.iinfo(np.int32).max)
+
+ demo = create_ui(parser.parse_args())
+ demo.launch()
diff --git a/ui_files/style.css b/ui_files/style.css
new file mode 100644
index 0000000..17fb109
--- /dev/null
+++ b/ui_files/style.css
@@ -0,0 +1,238 @@
+/*
+This CSS file is modified from:
+https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/2794a3c3ba66115c307075098e713f572b08bf80/app.py
+*/
+
+h1 {
+ text-align: center;
+}
+
+.gradio-container {
+ font-family: 'IBM Plex Sans', sans-serif;
+}
+
+.gr-button {
+ color: white;
+ border-color: black;
+ background: black;
+}
+
+input[type='range'] {
+ accent-color: black;
+}
+
+.dark input[type='range'] {
+ accent-color: #dfdfdf;
+}
+
+.container {
+ max-width: 730px;
+ margin: auto;
+ padding-top: 1.5rem;
+}
+
+#gallery {
+ min-height: auto;
+ height: 185px;
+ margin-top: 15px;
+ margin-left: auto;
+ margin-right: auto;
+ border-bottom-right-radius: .5rem !important;
+ border-bottom-left-radius: .5rem !important;
+}
+#gallery .grid-wrap, #gallery .empty{
+ height: 185px;
+ min-height: 185px;
+}
+#gallery .preview{
+ height: 185px;
+ min-height: 185px!important;
+}
+#gallery>div>.h-full {
+ min-height: 20rem;
+}
+
+.details:hover {
+ text-decoration: underline;
+}
+
+.gr-button {
+ white-space: nowrap;
+}
+
+.gr-button:focus {
+ border-color: rgb(147 197 253 / var(--tw-border-opacity));
+ outline: none;
+ box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
+ --tw-border-opacity: 1;
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
+ --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
+ --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
+ --tw-ring-opacity: .5;
+}
+
+#advanced-btn {
+ font-size: .7rem !important;
+ line-height: 19px;
+ margin-top: 12px;
+ margin-bottom: 12px;
+ padding: 2px 8px;
+ border-radius: 14px !important;
+}
+
+#advanced-options {
+ display: none;
+ margin-bottom: 20px;
+}
+
+.footer {
+ margin-bottom: 45px;
+ margin-top: 35px;
+ text-align: center;
+ border-bottom: 1px solid #e5e5e5;
+}
+
+.footer>p {
+ font-size: .8rem;
+ display: inline-block;
+ padding: 0 10px;
+ transform: translateY(10px);
+ background: white;
+}
+
+.dark .footer {
+ border-color: #303030;
+}
+
+.dark .footer>p {
+ background: #0b0f19;
+}
+
+.acknowledgments h4 {
+ margin: 1.25em 0 .25em 0;
+ font-weight: bold;
+ font-size: 115%;
+}
+
+.animate-spin {
+ animation: spin 1s linear infinite;
+}
+
+@keyframes spin {
+ from {
+ transform: rotate(0deg);
+ }
+
+ to {
+ transform: rotate(360deg);
+ }
+}
+
+#share-btn-container {
+ display: flex;
+ padding-left: 0.5rem !important;
+ padding-right: 0.5rem !important;
+ background-color: #000000;
+ justify-content: center;
+ align-items: center;
+ border-radius: 9999px !important;
+ width: 13rem;
+ margin-top: 10px;
+ margin-left: auto;
+}
+
+#share-btn {
+ all: initial;
+ color: #ffffff;
+ font-weight: 600;
+ cursor: pointer;
+ font-family: 'IBM Plex Sans', sans-serif;
+ margin-left: 0.5rem !important;
+ padding-top: 0.25rem !important;
+ padding-bottom: 0.25rem !important;
+ right: 0;
+}
+
+#share-btn * {
+ all: unset;
+}
+
+#share-btn-container div:nth-child(-n+2) {
+ width: auto !important;
+ min-height: 0px !important;
+}
+
+#share-btn-container .wrap {
+ display: none !important;
+}
+
+.gr-form {
+ flex: 1 1 50%;
+ border-top-right-radius: 0;
+ border-bottom-right-radius: 0;
+}
+
+#prompt-container {
+ gap: 0;
+}
+
+#prompt-text-input,
+#negative-prompt-text-input {
+ padding: .45rem 0.625rem
+}
+
+#component-16 {
+ border-top-width: 1px !important;
+ margin-top: 1em
+}
+
+.image_duplication {
+ position: absolute;
+ width: 100px;
+ left: 50px
+}
+
+#component-0 {
+ max-width: 730px;
+ margin: auto;
+ padding-top: 1.5rem;
+}
+
+#upscaled-image img {
+ object-fit: scale-down;
+}
+/* share button */
+#share-btn-container {
+ display: flex;
+ padding-left: 0.5rem !important;
+ padding-right: 0.5rem !important;
+ background-color: #000000;
+ justify-content: center;
+ align-items: center;
+ border-radius: 9999px !important;
+ width: 13rem;
+ margin-top: 10px;
+ margin-left: auto;
+ flex: unset !important;
+}
+#share-btn {
+ all: initial;
+ color: #ffffff;
+ font-weight: 600;
+ cursor: pointer;
+ font-family: 'IBM Plex Sans', sans-serif;
+ margin-left: 0.5rem !important;
+ padding-top: 0.25rem !important;
+ padding-bottom: 0.25rem !important;
+ right:0;
+}
+#share-btn * {
+ all: unset !important;
+}
+#share-btn-container div:nth-child(-n+2){
+ width: auto !important;
+ min-height: 0px !important;
+}
+#share-btn-container .wrap {
+ display: none !important;
+}
\ No newline at end of file
diff --git a/ui_files/utils.py b/ui_files/utils.py
new file mode 100644
index 0000000..a80d80d
--- /dev/null
+++ b/ui_files/utils.py
@@ -0,0 +1,72 @@
+import gradio as gr
+import numpy as np
+import random
+
+
+def _update_result_view(show_gallery: bool) -> tuple[dict, dict]:
+ return gr.update(visible=show_gallery), gr.update(visible=not show_gallery)
+
+
+def show_gallery_view() -> tuple[dict, dict]:
+ return _update_result_view(True)
+
+
+def show_upscaled_view() -> tuple[dict, dict]:
+ return _update_result_view(False)
+
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+ if randomize_seed:
+ seed = random.randint(0, np.iinfo(np.int32).max)
+ return seed
+
+
+def update_upscale_button(selected_index: int) -> tuple[dict, dict]:
+ if selected_index == -1:
+ return gr.update(interactive=False), gr.update(interactive=False)
+ else:
+ return gr.update(interactive=True), gr.update(interactive=True)
+
+
+def get_stage2_index(evt: gr.SelectData) -> int:
+ return evt.index
+
+
+def check_if_stage2_selected(index: int) -> None:
+ if index == -1:
+ raise gr.Error(
+ 'You need to select the image you would like to upscale from the Stage 1 results by clicking.'
+ )
+
+
+def get_device_map(device):
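+ # Manual device map for the T5-v1.1-XXL encoder: the first half of the encoder blocks
+ # live on `device`, the second half (and the final layer norm) stay on the CPU.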
+ return {
+ 'shared': device,
+ 'encoder.embed_tokens': device,
+ 'encoder.block.0': device,
+ 'encoder.block.1': device,
+ 'encoder.block.2': device,
+ 'encoder.block.3': device,
+ 'encoder.block.4': device,
+ 'encoder.block.5': device,
+ 'encoder.block.6': device,
+ 'encoder.block.7': device,
+ 'encoder.block.8': device,
+ 'encoder.block.9': device,
+ 'encoder.block.10': device,
+ 'encoder.block.11': device,
+ 'encoder.block.12': 'cpu',
+ 'encoder.block.13': 'cpu',
+ 'encoder.block.14': 'cpu',
+ 'encoder.block.15': 'cpu',
+ 'encoder.block.16': 'cpu',
+ 'encoder.block.17': 'cpu',
+ 'encoder.block.18': 'cpu',
+ 'encoder.block.19': 'cpu',
+ 'encoder.block.20': 'cpu',
+ 'encoder.block.21': 'cpu',
+ 'encoder.block.22': 'cpu',
+ 'encoder.block.23': 'cpu',
+ 'encoder.final_layer_norm': 'cpu',
+ 'encoder.dropout': 'cpu',
+ }
From 2df8326a81a2f5a8561b23877e95cf0f4a37ce55 Mon Sep 17 00:00:00 2001
From: neon
Date: Sat, 29 Apr 2023 17:24:04 +0200
Subject: [PATCH 02/10] kinda works
---
deepfloyd_if/model/__init__.py | 3 +-
deepfloyd_if/model/gaussian_diffusion.py | 300 ++++-----
deepfloyd_if/model/unet.py | 31 +-
deepfloyd_if/model/unet_split.py | 750 ++++++++++++++++++++++
deepfloyd_if/modules/base.py | 68 +-
deepfloyd_if/modules/stage_I.py | 13 +-
deepfloyd_if/modules/stage_III_sd_x4.py | 4 +-
deepfloyd_if/pipelines/dream.py | 42 +-
deepfloyd_if/pipelines/optimized_dream.py | 2 +-
run_ui.py | 20 +-
ui_files/utils.py | 5 +-
11 files changed, 995 insertions(+), 243 deletions(-)
create mode 100644 deepfloyd_if/model/unet_split.py
diff --git a/deepfloyd_if/model/__init__.py b/deepfloyd_if/model/__init__.py
index 332da3e..55e58ca 100644
--- a/deepfloyd_if/model/__init__.py
+++ b/deepfloyd_if/model/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from .unet import UNetModel, SuperResUNetModel
+from .unet_split import UNetSplitModel
-__all__ = ['UNetModel', 'SuperResUNetModel']
+__all__ = ['UNetModel', 'SuperResUNetModel', 'UNetSplitModel']
diff --git a/deepfloyd_if/model/gaussian_diffusion.py b/deepfloyd_if/model/gaussian_diffusion.py
index e058fbc..596c1b7 100644
--- a/deepfloyd_if/model/gaussian_diffusion.py
+++ b/deepfloyd_if/model/gaussian_diffusion.py
@@ -110,13 +110,13 @@ class GaussianDiffusion:
"""
def __init__(
- self,
- *,
- betas,
- model_mean_type,
- model_var_type,
- loss_type,
- rescale_timesteps=False,
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ rescale_timesteps=False,
):
self.model_mean_type = model_mean_type
self.model_var_type = model_var_type
@@ -146,7 +146,7 @@ def __init__(
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
- betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
# log calculation clipped because the posterior variance is 0 at the
# beginning of the diffusion chain.
@@ -154,12 +154,12 @@ def __init__(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = (
- betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
- (1.0 - self.alphas_cumprod_prev)
- * np.sqrt(alphas)
- / (1.0 - self.alphas_cumprod)
+ (1.0 - self.alphas_cumprod_prev)
+ * np.sqrt(alphas)
+ / (1.0 - self.alphas_cumprod)
)
def dynamic_thresholding(self, x, p=0.995, c=1.7):
@@ -189,7 +189,7 @@ def q_mean_variance(self, x_start, t):
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (
- _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
)
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(
@@ -210,9 +210,9 @@ def q_sample(self, x_start, t, noise=None):
noise = torch.randn_like(x_start)
assert noise.shape == x_start.shape
return (
- _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
- + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
- * noise
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
@@ -222,24 +222,24 @@ def q_posterior_mean_variance(self, x_start, x_t, t):
"""
assert x_start.shape == x_t.shape
posterior_mean = (
- _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
- + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
assert (
- posterior_mean.shape[0]
- == posterior_variance.shape[0]
- == posterior_log_variance_clipped.shape[0]
- == x_start.shape[0]
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(
- self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,
- denoised_fn=None, model_kwargs=None
+ self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,
+ denoised_fn=None, model_kwargs=None
):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
@@ -280,7 +280,7 @@ def p_mean_variance(
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
- model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_log_variance = frac * max_log.to(frac.device) + (1 - frac) * min_log.to(frac.device)
model_variance = torch.exp(model_log_variance)
else:
model_variance, model_log_variance = {
@@ -306,6 +306,7 @@ def process_xstart(x):
return x # x.clamp(-1, 1)
return x
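+ # With the split UNet, x and t may be on a different device than model_output;
+ # align them before deriving pred_xstart.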
+ x, t = x.to(model_output.device), t.to(model_output.device)
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
pred_xstart = process_xstart(
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
@@ -325,7 +326,7 @@ def process_xstart(x):
raise NotImplementedError(self.model_mean_type)
assert (
- model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
)
return {
'mean': model_mean,
@@ -337,25 +338,25 @@ def process_xstart(x):
def _predict_xstart_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
- _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
)
def _predict_xstart_from_xprev(self, x_t, t, xprev):
assert x_t.shape == xprev.shape
return ( # (xprev - coef2*x_t) / coef1
- _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
- - _extract_into_tensor(
- self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
- )
- * x_t
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
+ - _extract_into_tensor(
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
+ )
+ * x_t
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
- _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- - pred_xstart
- ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def _scale_timesteps(self, t):
if self.rescale_timesteps:
@@ -363,8 +364,8 @@ def _scale_timesteps(self, t):
return t
def p_sample(
- self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,
- denoised_fn=None, model_kwargs=None, inpainting_mask=None,
+ self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7,
+ denoised_fn=None, model_kwargs=None, inpainting_mask=None,
):
"""
Sample x_{t-1} from the model at the given timestep.
@@ -390,31 +391,36 @@ def p_sample(
denoised_fn=denoised_fn,
model_kwargs=model_kwargs,
)
- noise = torch.randn_like(x)
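+ # Sampling tensors follow the device of the model output, since the split UNet may
+ # produce it on a different device than the one x was created on.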
+ device = out['mean'].device
+ noise = torch.randn_like(x, device=device)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
- ) # no noise when t == 0
+ ).to(device) # no noise when t == 0
if inpainting_mask is None:
- inpainting_mask = torch.ones_like(x, device=x.device)
+ inpainting_mask = torch.ones_like(x, device=device)
+
+ x, t = x.to(device), t.to(device)
- sample = out['mean'] + nonzero_mask * torch.exp(0.5 * out['log_variance']) * noise
- sample = (1 - inpainting_mask)*x + inpainting_mask*sample
+ noise = (torch.exp(0.5 * out['log_variance']) * noise).to(device)
+
+ sample = out['mean'] + nonzero_mask * noise
+ sample = (1 - inpainting_mask) * x + inpainting_mask * sample
return {'sample': sample, 'pred_xstart': out['pred_xstart']}
def p_sample_loop(
- self,
- model,
- shape,
- noise=None,
- clip_denoised=True,
- dynamic_thresholding_p=0.99,
- dynamic_thresholding_c=1.7,
- inpainting_mask=None,
- denoised_fn=None,
- model_kwargs=None,
- device=None,
- progress=False,
- sample_fn=None,
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ dynamic_thresholding_p=0.99,
+ dynamic_thresholding_c=1.7,
+ inpainting_mask=None,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ sample_fn=None,
):
"""
Generate samples from the model.
@@ -434,17 +440,17 @@ def p_sample_loop(
"""
final = None
for step_idx, sample in enumerate(self.p_sample_loop_progressive(
- model,
- shape,
- noise=noise,
- clip_denoised=clip_denoised,
- dynamic_thresholding_p=dynamic_thresholding_p,
- dynamic_thresholding_c=dynamic_thresholding_c,
- denoised_fn=denoised_fn,
- inpainting_mask=inpainting_mask,
- model_kwargs=model_kwargs,
- device=device,
- progress=progress,
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ dynamic_thresholding_p=dynamic_thresholding_p,
+ dynamic_thresholding_c=dynamic_thresholding_c,
+ denoised_fn=denoised_fn,
+ inpainting_mask=inpainting_mask,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
)):
if sample_fn is not None:
sample = sample_fn(step_idx, sample)
@@ -452,18 +458,18 @@ def p_sample_loop(
return final['sample']
def p_sample_loop_progressive(
- self,
- model,
- shape,
- inpainting_mask=None,
- noise=None,
- clip_denoised=True,
- dynamic_thresholding_p=0.99,
- dynamic_thresholding_c=1.7,
- denoised_fn=None,
- model_kwargs=None,
- device=None,
- progress=False,
+ self,
+ model,
+ shape,
+ inpainting_mask=None,
+ noise=None,
+ clip_denoised=True,
+ dynamic_thresholding_p=0.99,
+ dynamic_thresholding_c=1.7,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
):
"""
Generate samples from the model and yield intermediate samples from
@@ -472,8 +478,6 @@ def p_sample_loop_progressive(
Returns a generator over dicts, where each dict is the return value of
p_sample().
"""
- if device is None:
- device = next(model.parameters()).device
assert isinstance(shape, (tuple, list))
if noise is not None:
img = noise
@@ -505,16 +509,16 @@ def p_sample_loop_progressive(
img = out['sample']
def ddim_sample(
- self,
- model,
- x,
- t,
- clip_denoised=True,
- dynamic_thresholding_p=0.99,
- dynamic_thresholding_c=1.7,
- denoised_fn=None,
- model_kwargs=None,
- eta=0.0,
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ dynamic_thresholding_p=0.99,
+ dynamic_thresholding_c=1.7,
+ denoised_fn=None,
+ model_kwargs=None,
+ eta=0.0,
):
"""
Sample x_{t-1} from the model using DDIM.
@@ -536,15 +540,15 @@ def ddim_sample(
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = (
- eta
- * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
- * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
+ eta
+ * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
)
# Equation 12.
noise = torch.randn_like(x)
mean_pred = (
- out['pred_xstart'] * torch.sqrt(alpha_bar_prev)
- + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ out['pred_xstart'] * torch.sqrt(alpha_bar_prev)
+ + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
@@ -553,16 +557,16 @@ def ddim_sample(
return {'sample': sample, 'pred_xstart': out['pred_xstart']}
def ddim_reverse_sample(
- self,
- model,
- x,
- t,
- clip_denoised=True,
- dynamic_thresholding_p=0.99,
- dynamic_thresholding_c=1.7,
- denoised_fn=None,
- model_kwargs=None,
- eta=0.0,
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ dynamic_thresholding_p=0.99,
+ dynamic_thresholding_c=1.7,
+ denoised_fn=None,
+ model_kwargs=None,
+ eta=0.0,
):
"""
Sample x_{t+1} from the model using DDIM reverse ODE.
@@ -581,33 +585,33 @@ def ddim_reverse_sample(
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
- _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- - out['pred_xstart']
- ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out['pred_xstart']
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = (
- out['pred_xstart'] * torch.sqrt(alpha_bar_next)
- + torch.sqrt(1 - alpha_bar_next) * eps
+ out['pred_xstart'] * torch.sqrt(alpha_bar_next)
+ + torch.sqrt(1 - alpha_bar_next) * eps
)
return {'sample': mean_pred, 'pred_xstart': out['pred_xstart']}
def ddim_sample_loop(
- self,
- model,
- shape,
- noise=None,
- clip_denoised=True,
- dynamic_thresholding_p=0.99,
- dynamic_thresholding_c=1.7,
- denoised_fn=None,
- model_kwargs=None,
- device=None,
- progress=False,
- eta=0.0,
- sample_fn=None,
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ dynamic_thresholding_p=0.99,
+ dynamic_thresholding_c=1.7,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ sample_fn=None,
):
"""
Generate samples from the model using DDIM.
@@ -615,17 +619,17 @@ def ddim_sample_loop(
"""
final = None
for step_idx, sample in enumerate(self.ddim_sample_loop_progressive(
- model,
- shape,
- noise=noise,
- clip_denoised=clip_denoised,
- denoised_fn=denoised_fn,
- dynamic_thresholding_p=dynamic_thresholding_p,
- dynamic_thresholding_c=dynamic_thresholding_c,
- model_kwargs=model_kwargs,
- device=device,
- progress=progress,
- eta=eta,
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ dynamic_thresholding_p=dynamic_thresholding_p,
+ dynamic_thresholding_c=dynamic_thresholding_c,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
)):
if sample_fn is not None:
sample = sample_fn(step_idx, sample)
@@ -633,18 +637,18 @@ def ddim_sample_loop(
return final['sample']
def ddim_sample_loop_progressive(
- self,
- model,
- shape,
- noise=None,
- clip_denoised=True,
- dynamic_thresholding_p=0.99,
- dynamic_thresholding_c=1.7,
- denoised_fn=None,
- model_kwargs=None,
- device=None,
- progress=False,
- eta=0.0,
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ dynamic_thresholding_p=0.99,
+ dynamic_thresholding_c=1.7,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
):
"""
Use DDIM to sample from the model and yield intermediate samples from
@@ -684,7 +688,7 @@ def ddim_sample_loop_progressive(
img = out['sample']
def _vb_terms_bpd(
- self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
"""
Get a term for the variational lower-bound.
diff --git a/deepfloyd_if/model/unet.py b/deepfloyd_if/model/unet.py
index bb83590..2fbc4aa 100644
--- a/deepfloyd_if/model/unet.py
+++ b/deepfloyd_if/model/unet.py
@@ -11,10 +11,7 @@
from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, get_activation, \
AttentionPooling
-_FORCE_MEM_EFFICIENT_ATTN = int(os.environ.get('FORCE_MEM_EFFICIENT_ATTN', 0))
-print('FORCE_MEM_EFFICIENT_ATTN=', _FORCE_MEM_EFFICIENT_ATTN, '@UNET:QKVATTENTION')
-if _FORCE_MEM_EFFICIENT_ATTN:
- from xformers.ops import memory_efficient_attention # noqa
+from xformers.ops import memory_efficient_attention # noqa
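+# NOTE: memory-efficient attention is now used unconditionally, so xformers becomes
+# a hard dependency of this module.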
class TimestepBlock(nn.Module):
@@ -246,7 +243,7 @@ def __init__(
self.num_heads = num_heads
else:
assert (
- channels % num_head_channels == 0
+ channels % num_head_channels == 0
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
self.num_heads = channels // num_head_channels
self.norm = normalization(channels, dtype=self.dtype)
@@ -310,16 +307,16 @@ def forward(self, qkv, encoder_kv=None):
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
- if _FORCE_MEM_EFFICIENT_ATTN:
- q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v))
- a = memory_efficient_attention(q, k, v)
- a = a.permute(0, 2, 1)
- else:
- weight = torch.einsum(
- 'bct,bcs->bts', q * scale, k * scale
- ) # More stable with f16 than dividing afterwards
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
- a = torch.einsum('bts,bcs->bct', weight, v)
+ # if True: # legacy
+ q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v))
+ a = memory_efficient_attention(q, k, v)
+ a = a.permute(0, 2, 1)
+ # else:
+ # weight = torch.einsum(
+ # 'bct,bcs->bts', q * scale, k * scale
+ # ) # More stable with f16 than dividing afterwards
+ # weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ # a = torch.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length)
@@ -456,7 +453,7 @@ def __init__(
ds = 1
if isinstance(num_res_blocks, int):
- num_res_blocks = [num_res_blocks]*len(self.channel_mult)
+ num_res_blocks = [num_res_blocks] * len(self.channel_mult)
self.num_res_blocks = num_res_blocks
for level, mult in enumerate(self.channel_mult):
@@ -690,7 +687,7 @@ def forward(self, x, timesteps, low_res, aug_level=None, **kwargs):
)
if aug_level is None:
- aug_steps = (np.random.random(bs)*1000).astype(np.int64) # uniform [0, 1)
+ aug_steps = (np.random.random(bs) * 1000).astype(np.int64) # uniform [0, 1)
aug_steps = torch.from_numpy(aug_steps).to(x.device, dtype=torch.long)
else:
aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long)
diff --git a/deepfloyd_if/model/unet_split.py b/deepfloyd_if/model/unet_split.py
new file mode 100644
index 0000000..1d64e3b
--- /dev/null
+++ b/deepfloyd_if/model/unet_split.py
@@ -0,0 +1,750 @@
+# -*- coding: utf-8 -*-
+import gc
+import os
+import math
+from abc import abstractmethod
+
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, get_activation, \
+ AttentionPooling
+
+from xformers.ops import memory_efficient_attention # noqa
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, encoder_out=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, AttentionBlock):
+ x = layer(x, encoder_out)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ self.dtype = dtype
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1, dtype=self.dtype)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest')
+ else:
+ if self.dtype == torch.bfloat16:
+ x = x.type(torch.float32 if x.device.type == 'cpu' else torch.float16)
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
+ if self.dtype == torch.bfloat16:
+ x = x.type(torch.bfloat16)
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ self.dtype = dtype
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1, dtype=self.dtype)
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ activation,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ up=False,
+ down=False,
+ dtype=None,
+ efficient_activation=False,
+ scale_skip_connection=False,
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.efficient_activation = efficient_activation
+ self.scale_skip_connection = scale_skip_connection
+
+ self.in_layers = nn.Sequential(
+ normalization(channels, dtype=self.dtype),
+ get_activation(activation),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims, dtype=self.dtype)
+ self.x_upd = Upsample(channels, False, dims, dtype=self.dtype)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims, dtype=self.dtype)
+ self.x_upd = Downsample(channels, False, dims, dtype=self.dtype)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.Identity() if self.efficient_activation else get_activation(activation),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ dtype=self.dtype
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels, dtype=self.dtype),
+ get_activation(activation),
+ nn.Dropout(p=dropout),
+ zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=self.dtype)),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype)
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=self.dtype)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+
+ res = self.skip_connection(x) + h
+ if self.scale_skip_connection:
+ res *= 0.7071 # 1 / sqrt(2), https://arxiv.org/pdf/2104.07636.pdf
+ return res
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ disable_self_attention=False,
+ encoder_channels=None,
+ dtype=None,
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.channels = channels
+ self.disable_self_attention = disable_self_attention
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
+ self.num_heads = channels // num_head_channels
+ self.norm = normalization(channels, dtype=self.dtype)
+ self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype)
+ if self.disable_self_attention:
+ self.qkv = conv_nd(1, channels, channels, 1, dtype=self.dtype)
+ else:
+ self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype)
+ self.attention = QKVAttention(self.num_heads, disable_self_attention=disable_self_attention)
+
+ if encoder_channels is not None:
+ self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1, dtype=self.dtype)
+ self.norm_encoder = normalization(encoder_channels, dtype=self.dtype)
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1, dtype=self.dtype))
+
+ def forward(self, x, encoder_out=None):
+ b, c, *spatial = x.shape
+ qkv = self.qkv(self.norm(x).view(b, c, -1))
+ if encoder_out is not None:
+ # from imagen article: https://arxiv.org/abs/2205.11487
+ encoder_out = self.norm_encoder(encoder_out)
+ # # #
+ encoder_out = self.encoder_kv(encoder_out)
+ h = self.attention(qkv, encoder_out)
+ else:
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return x + h.reshape(b, c, *spatial)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads, disable_self_attention=False):
+ super().__init__()
+ self.n_heads = n_heads
+ self.disable_self_attention = disable_self_attention
+
+ def forward(self, qkv, encoder_kv=None):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ if self.disable_self_attention:
+ ch = width // (1 * self.n_heads)
+ q, = qkv.reshape(bs * self.n_heads, ch * 1, length).split(ch, dim=1)
+ else:
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ if encoder_kv is not None:
+ assert encoder_kv.shape[1] == self.n_heads * ch * 2
+ if self.disable_self_attention:
+ k, v = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+ else:
+ ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+ k = torch.cat([ek, k], dim=-1)
+ v = torch.cat([ev, v], dim=-1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ # if _FORCE_MEM_EFFICIENT_ATTN:
+ q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v))
+ a = memory_efficient_attention(q, k, v)
+ a = a.permute(0, 2, 1)
+ # else:
+ # weight = torch.einsum(
+ # 'bct,bcs->bts', q * scale, k * scale
+ # ) # More stable with f16 than dividing afterwards
+ # weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ # a = torch.einsum('bts,bcs->bct', weight, v)
+ return a.reshape(bs, -1, length)
+
+
+class UNetSplitModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ activation,
+ encoder_dim,
+ att_pool_heads,
+ encoder_channels,
+ image_size,
+ disable_self_attentions=None,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ precision='32',
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ efficient_activation=False,
+ scale_skip_connection=False,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.encoder_channels = encoder_channels
+ self.encoder_dim = encoder_dim
+ self.efficient_activation = efficient_activation
+ self.scale_skip_connection = scale_skip_connection
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.dropout = dropout
+
+ self.secondary_device = torch.device("cpu")
+
+ # adapt attention resolutions
+ if isinstance(attention_resolutions, str):
+ self.attention_resolutions = []
+ for res in attention_resolutions.split(','):
+ self.attention_resolutions.append(image_size // int(res))
+ else:
+ self.attention_resolutions = attention_resolutions
+ self.attention_resolutions = tuple(self.attention_resolutions)
+ #
+
+ # adapt disable self attention resolutions
+ if not disable_self_attentions:
+ self.disable_self_attentions = []
+ elif disable_self_attentions is True:
+ self.disable_self_attentions = attention_resolutions
+ elif isinstance(disable_self_attentions, str):
+ self.disable_self_attentions = []
+ for res in disable_self_attentions.split(','):
+ self.disable_self_attentions.append(image_size // int(res))
+ else:
+ self.disable_self_attentions = disable_self_attentions
+ self.disable_self_attentions = tuple(self.disable_self_attentions)
+ #
+
+ # adapt channel mult
+ if isinstance(channel_mult, str):
+ self.channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(','))
+ else:
+ self.channel_mult = tuple(channel_mult)
+ #
+
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.dtype = torch.float32
+
+ self.precision = str(precision)
+ self.use_fp16 = precision == '16'
+ if self.precision == '16':
+ self.dtype = torch.float16
+ elif self.precision == 'bf16':
+ self.dtype = torch.bfloat16
+
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ self.time_embed_dim = model_channels * max(self.channel_mult)
+ self.time_embed = nn.Sequential(
+ linear(model_channels, self.time_embed_dim, dtype=self.dtype),
+ get_activation(activation),
+ linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, self.time_embed_dim)
+
+ ch = input_ch = int(self.channel_mult[0] * model_channels)
+ self.input_blocks = nn.ModuleList(
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1, dtype=self.dtype))]
+ )
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+
+ if isinstance(num_res_blocks, int):
+ num_res_blocks = [num_res_blocks] * len(self.channel_mult)
+ self.num_res_blocks = num_res_blocks
+
+ for level, mult in enumerate(self.channel_mult):
+ for _ in range(num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ self.time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype,
+ activation=activation,
+ efficient_activation=self.efficient_activation,
+ scale_skip_connection=self.scale_skip_connection,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in self.attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ encoder_channels=encoder_channels,
+ dtype=self.dtype,
+ disable_self_attention=ds in self.disable_self_attentions,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(self.channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ self.time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ dtype=self.dtype,
+ activation=activation,
+ efficient_activation=self.efficient_activation,
+ scale_skip_connection=self.scale_skip_connection,
+ )
+ if resblock_updown
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ self.time_embed_dim,
+ dropout,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype,
+ activation=activation,
+ efficient_activation=self.efficient_activation,
+ scale_skip_connection=self.scale_skip_connection,
+ ),
+ AttentionBlock(
+ ch,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ encoder_channels=encoder_channels,
+ dtype=self.dtype,
+ disable_self_attention=ds in self.disable_self_attentions,
+ ),
+ ResBlock(
+ ch,
+ self.time_embed_dim,
+ dropout,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype,
+ activation=activation,
+ efficient_activation=self.efficient_activation,
+ scale_skip_connection=self.scale_skip_connection,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(self.channel_mult))[::-1]:
+ for i in range(num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ self.time_embed_dim,
+ dropout,
+ out_channels=int(model_channels * mult),
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dtype=self.dtype,
+ activation=activation,
+ efficient_activation=self.efficient_activation,
+ scale_skip_connection=self.scale_skip_connection,
+ )
+ ]
+ ch = int(model_channels * mult)
+ if ds in self.attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ num_heads=num_heads_upsample,
+ num_head_channels=num_head_channels,
+ encoder_channels=encoder_channels,
+ dtype=self.dtype,
+ disable_self_attention=ds in self.disable_self_attentions,
+ )
+ )
+ if level and i == num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ self.time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ dtype=self.dtype,
+ activation=activation,
+ efficient_activation=self.efficient_activation,
+ scale_skip_connection=self.scale_skip_connection,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch, dtype=self.dtype),
+ get_activation(activation),
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1, dtype=self.dtype)),
+ )
+
+ self.activation_layer = get_activation(activation) if self.efficient_activation else nn.Identity()
+
+ self.encoder_pooling = nn.Sequential(
+ nn.LayerNorm(encoder_dim, dtype=self.dtype),
+ AttentionPooling(att_pool_heads, encoder_dim, dtype=self.dtype),
+ nn.Linear(encoder_dim, self.time_embed_dim, dtype=self.dtype),
+ nn.LayerNorm(self.time_embed_dim, dtype=self.dtype)
+ )
+
+ if encoder_dim != encoder_channels:
+ self.encoder_proj = nn.Linear(encoder_dim, encoder_channels, dtype=self.dtype)
+ else:
+ self.encoder_proj = nn.Identity()
+
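+        # Cache for (encoder_out, encoder_pool); filled by forward() when use_cache=True.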
+ self.cache = None
+
+ def collect(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
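+    # Stage-wise device placement used by forward(): stage 1 keeps the input blocks and the
+    # conditioning layers on the primary device, stage 2 the middle block, and stage 3 the
+    # output blocks, moving everything else to the secondary device in between. Non-device
+    # arguments (e.g. a dtype) fall through to super().to().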
+ def to(self, x, stage=1): # 0, 1, 2, 3
+ if isinstance(x, torch.device):
+ secondary_device = self.secondary_device
+ if stage == 1:
+ self.middle_block.to(secondary_device)
+ self.output_blocks.to(secondary_device)
+ self.out.to(secondary_device)
+ self.collect()
+ self.time_embed.to(x)
+ self.encoder_proj.to(x)
+ self.encoder_pooling.to(x)
+ self.input_blocks.to(x)
+ elif stage == 2:
+ self.time_embed.to(secondary_device)
+ self.encoder_proj.to(secondary_device)
+ self.encoder_pooling.to(secondary_device)
+ self.input_blocks.to(secondary_device)
+ self.output_blocks.to(secondary_device)
+ self.out.to(secondary_device)
+ self.collect()
+ self.middle_block.to(x)
+ elif stage == 3:
+ self.time_embed.to(secondary_device)
+ self.encoder_proj.to(secondary_device)
+ self.encoder_pooling.to(secondary_device)
+ self.input_blocks.to(secondary_device)
+ self.middle_block.to(secondary_device)
+ self.collect()
+ self.output_blocks.to(x)
+ self.out.to(x)
+ else:
+ super().to(x)
+
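+    # The forward pass runs in three phases, calling self.to(..., stage=n) between them so that
+    # only the currently active part of the UNet resides on the primary device.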
+ def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs):
+ hs = []
+ self.to(self.primary_device, stage=1)
+ emb = self.time_embed(timestep_embedding(timesteps.to(torch.float32), self.model_channels,
+ dtype=torch.float32).to(self.primary_device).to(self.dtype))
+
+ if use_cache and self.cache is not None:
+ encoder_out, encoder_pool = self.cache
+ else:
+ text_emb = text_emb.type(self.dtype).to(self.primary_device)
+ encoder_out = self.encoder_proj(text_emb)
+ encoder_out = encoder_out.permute(0, 2, 1) # NLC -> NCL
+ if timestep_text_emb is None:
+ timestep_text_emb = text_emb
+ encoder_pool = self.encoder_pooling(timestep_text_emb)
+ if use_cache:
+ self.cache = (encoder_out, encoder_pool)
+
+ emb = emb + encoder_pool.to(emb)
+
+ if aug_emb is not None:
+ emb = emb + aug_emb.to(emb)
+
+ emb = self.activation_layer(emb)
+
+ h = x.type(self.dtype).to(self.primary_device)
+
+ for module in self.input_blocks:
+ h = module(h, emb, encoder_out)
+ hs.append(h)
+
+ self.to(self.primary_device, stage=2)
+
+ h = self.middle_block(h, emb, encoder_out)
+
+ self.to(self.primary_device, stage=3)
+
+ for module in self.output_blocks:
+ h = torch.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, encoder_out)
+ h = h.type(self.dtype)
+ h = self.out(h)
+ return h
+
+
+class SuperResUNetModel(UNetSplitModel):
+ """
+ A text2im model that performs super-resolution.
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
+ """
+
+ def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwargs):
+ self.low_res_diffusion = low_res_diffusion
+ self.interpolate_mode = interpolate_mode
+ super().__init__(*args, **kwargs)
+
+ self.aug_proj = nn.Sequential(
+ linear(self.model_channels, self.time_embed_dim, dtype=self.dtype),
+ get_activation(kwargs['activation']),
+ linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype),
+ )
+
+ def forward(self, x, timesteps, low_res, aug_level=None, **kwargs):
+ bs, _, new_height, new_width = x.shape
+
+ align_corners = True
+ if self.interpolate_mode == 'nearest':
+ align_corners = None
+
+ upsampled = F.interpolate(
+ low_res, (new_height, new_width), mode=self.interpolate_mode, align_corners=align_corners
+ )
+
+ if aug_level is None:
+            aug_steps = (np.random.random(bs) * 1000).astype(np.int64)  # uniform in [0, 1), scaled to integer steps in [0, 1000)
+ aug_steps = torch.from_numpy(aug_steps).to(x.device, dtype=torch.long)
+ else:
+ aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long)
+
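+        # Noise the upsampled low-res conditioning image at the sampled augmentation timestep
+        # before concatenating it with the model input along the channel dimension.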
+ upsampled = self.low_res_diffusion.q_sample(upsampled, aug_steps)
+ x = torch.cat([x, upsampled], dim=1)
+
+ aug_emb = self.aug_proj(
+ timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype)
+ )
+ return super().forward(x, timesteps, aug_emb=aug_emb, **kwargs)
diff --git a/deepfloyd_if/modules/base.py b/deepfloyd_if/modules/base.py
index c808a3c..797ed32 100644
--- a/deepfloyd_if/modules/base.py
+++ b/deepfloyd_if/modules/base.py
@@ -14,14 +14,12 @@
from huggingface_hub import hf_hub_download
from accelerate.utils import set_module_tensor_to_device
-
from .. import utils
from ..model.respace import create_gaussian_diffusion
from .utils import load_model_weights, predict_proba, clip_process_generations
class IFBaseModule:
-
stage = '-'
available_models = []
@@ -68,29 +66,29 @@ def use_diffusers(self):
return False
def embeddings_to_image(
- self, t5_embs, low_res=None, *,
- style_t5_embs=None,
- positive_t5_embs=None,
- negative_t5_embs=None,
- batch_repeat=1,
- dynamic_thresholding_p=0.95,
- sample_loop='ddpm',
- sample_timestep_respacing='smart185',
- dynamic_thresholding_c=1.5,
- guidance_scale=7.0,
- aug_level=0.25,
- positive_mixer=0.15,
- blur_sigma=None,
- img_size=None,
- img_scale=4.0,
- aspect_ratio='1:1',
- progress=True,
- seed=None,
- sample_fn=None,
- support_noise=None,
- support_noise_less_qsample_steps=0,
- inpainting_mask=None,
- **kwargs,
+ self, t5_embs, low_res=None, *,
+ style_t5_embs=None,
+ positive_t5_embs=None,
+ negative_t5_embs=None,
+ batch_repeat=1,
+ dynamic_thresholding_p=0.95,
+ sample_loop='ddpm',
+ sample_timestep_respacing='smart185',
+ dynamic_thresholding_c=1.5,
+ guidance_scale=7.0,
+ aug_level=0.25,
+ positive_mixer=0.15,
+ blur_sigma=None,
+ img_size=None,
+ img_scale=4.0,
+ aspect_ratio='1:1',
+ progress=True,
+ seed=None,
+ sample_fn=None,
+ support_noise=None,
+ support_noise_less_qsample_steps=0,
+ inpainting_mask=None,
+ **kwargs,
):
self._clear_cache()
image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale)
@@ -100,13 +98,13 @@ def embeddings_to_image(
def model_fn(x_t, ts, **kwargs):
half = x_t[: len(x_t) // bs_scale]
- combined = torch.cat([half]*bs_scale, dim=0)
+ combined = torch.cat([half] * bs_scale, dim=0)
model_out = self.model(combined, ts, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
if bs_scale == 3:
cond_eps, pos_cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0)
half_eps = uncond_eps + guidance_scale * (
- cond_eps * (1 - positive_mixer) + pos_cond_eps * positive_mixer - uncond_eps)
+ cond_eps * (1 - positive_mixer) + pos_cond_eps * positive_mixer - uncond_eps)
pos_half_eps = uncond_eps + guidance_scale * (pos_cond_eps - uncond_eps)
eps = torch.cat([half_eps, pos_half_eps, half_eps], dim=0)
else:
@@ -170,7 +168,7 @@ def model_fn(x_t, ts, **kwargs):
if low_res is not None:
if blur_sigma is not None:
low_res = T.GaussianBlur(3, sigma=(blur_sigma, blur_sigma))(low_res)
- model_kwargs['low_res'] = torch.cat([low_res]*bs_scale, dim=0).to(self.device)
+ model_kwargs['low_res'] = torch.cat([low_res] * bs_scale, dim=0).to(self.device)
model_kwargs['aug_level'] = aug_level
if support_noise is None:
@@ -186,7 +184,7 @@ def model_fn(x_t, ts, **kwargs):
support_noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...],
q_sample_steps,
)
- noise = noise.repeat(batch_size*bs_scale, 1, 1, 1).to(device=self.device, dtype=self.model.dtype)
+ noise = noise.repeat(batch_size * bs_scale, 1, 1, 1).to(device=self.device, dtype=self.model.dtype)
if inpainting_mask is not None:
inpainting_mask = inpainting_mask.to(device=self.device, dtype=torch.long)
@@ -202,7 +200,7 @@ def model_fn(x_t, ts, **kwargs):
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
inpainting_mask=inpainting_mask,
- device=self.device,
+ device=self.model.primary_device,
progress=progress,
sample_fn=sample_fn,
)[:batch_size]
@@ -216,7 +214,7 @@ def model_fn(x_t, ts, **kwargs):
model_kwargs=model_kwargs,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
- device=self.device,
+ device=self.model.primary_device,
progress=progress,
sample_fn=sample_fn,
)[:batch_size]
@@ -311,7 +309,7 @@ def to_images(self, generations, disable_watermark=False):
def show(self, pil_images, nrow=None, size=10):
if nrow is None:
- nrow = round(len(pil_images)**0.5)
+ nrow = round(len(pil_images) ** 0.5)
imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
if not isinstance(imgs, list):
@@ -333,16 +331,16 @@ def _clear_cache(self):
def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale):
if low_res is not None:
bs, c, h, w = low_res.shape
- image_h, image_w = int((h*img_scale)//32)*32, int((w*img_scale//32))*32
+ image_h, image_w = int((h * img_scale) // 32) * 32, int((w * img_scale // 32)) * 32
else:
scale_w, scale_h = aspect_ratio.split(':')
scale_w, scale_h = int(scale_w), int(scale_h)
coef = scale_w / scale_h
image_h, image_w = img_size, img_size
if coef >= 1:
- image_w = int(round(img_size/8 * coef) * 8)
+ image_w = int(round(img_size / 8 * coef) * 8)
else:
- image_h = int(round(img_size/8 / coef) * 8)
+ image_h = int(round(img_size / 8 / coef) * 8)
assert image_h % 8 == 0
assert image_w % 8 == 0
diff --git a/deepfloyd_if/modules/stage_I.py b/deepfloyd_if/modules/stage_I.py
index 8b1c5f3..8b3c62d 100644
--- a/deepfloyd_if/modules/stage_I.py
+++ b/deepfloyd_if/modules/stage_I.py
@@ -1,15 +1,16 @@
# -*- coding: utf-8 -*-
import accelerate
+import torch
from .base import IFBaseModule
-from ..model import UNetModel
+from ..model import UNetModel, UNetSplitModel
class IFStageI(IFBaseModule):
stage = 'I'
available_models = ['IF-I-M-v1.0', 'IF-I-L-v1.0', 'IF-I-XL-v1.0']
- def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs):
+ def __init__(self, *args, model_kwargs=None, pil_img_size=64, use_split=True, **kwargs):
"""
:param conf_or_path:
:param device:
@@ -19,12 +20,16 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs):
super().__init__(*args, pil_img_size=pil_img_size, **kwargs)
model_params = dict(self.conf.params)
model_params.update(model_kwargs or {})
+ UNetClass = UNetSplitModel if use_split else UNetModel
with accelerate.init_empty_weights():
- self.model = UNetModel(**model_params)
+ self.model = UNetClass(**model_params)
self.model = self.load_checkpoint(self.model, self.dir_or_name)
self.model.eval().to(self.device)
- def to(self, x):
+ def to(self, x, stage=1, secondary_device=torch.device("cpu")): # 0, 1, 2, 3
+ if isinstance(x, torch.device):
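+            # Record the device pair; UNetSplitModel.to() moves blocks between them stage by stage.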
+ self.model.primary_device = x
+ self.model.secondary_device = secondary_device
self.model.to(x)
def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None,
diff --git a/deepfloyd_if/modules/stage_III_sd_x4.py b/deepfloyd_if/modules/stage_III_sd_x4.py
index 307fad2..7599317 100644
--- a/deepfloyd_if/modules/stage_III_sd_x4.py
+++ b/deepfloyd_if/modules/stage_III_sd_x4.py
@@ -34,8 +34,8 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs):
self.model = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, token=self.hf_token)
self.model.to(self.device)
- if bool(os.environ.get('FORCE_MEM_EFFICIENT_ATTN')):
- self.model.enable_xformers_memory_efficient_attention()
+ # if bool(os.environ.get('FORCE_MEM_EFFICIENT_ATTN')):
+ self.model.enable_xformers_memory_efficient_attention()
@property
def use_diffusers(self):
diff --git a/deepfloyd_if/pipelines/dream.py b/deepfloyd_if/pipelines/dream.py
index 13d7479..0b98621 100644
--- a/deepfloyd_if/pipelines/dream.py
+++ b/deepfloyd_if/pipelines/dream.py
@@ -5,22 +5,22 @@
def dream(
- t5,
- if_I,
- if_II=None,
- if_III=None,
- *,
- prompt,
- style_prompt=None,
- negative_prompt=None,
- seed=None,
- aspect_ratio='1:1',
- if_I_kwargs=None,
- if_II_kwargs=None,
- if_III_kwargs=None,
- progress=True,
- return_tensors=False,
- disable_watermark=False,
+ t5,
+ if_I,
+ if_II=None,
+ if_III=None,
+ *,
+ prompt,
+ style_prompt=None,
+ negative_prompt=None,
+ seed=None,
+ aspect_ratio='1:1',
+ if_I_kwargs=None,
+ if_II_kwargs=None,
+ if_III_kwargs=None,
+ progress=True,
+ return_tensors=False,
+ disable_watermark=False,
):
"""
Generate pictures using text description!
@@ -108,18 +108,18 @@ def dream(
stageIII_generations = []
for idx in range(len(stageII_generations)):
if if_III.use_diffusers:
- if_III_kwargs['prompt'] = prompt[idx: idx+1]
+ if_III_kwargs['prompt'] = prompt[idx: idx + 1]
- if_III_kwargs['low_res'] = stageII_generations[idx:idx+1]
+ if_III_kwargs['low_res'] = stageII_generations[idx:idx + 1]
if_III_kwargs['seed'] = seed
- if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1]
+ if_III_kwargs['t5_embs'] = t5_embs[idx:idx + 1]
if_III_kwargs['progress'] = progress
style_t5_embs = if_I_kwargs.get('style_t5_embs')
if style_t5_embs is not None:
- style_t5_embs = style_t5_embs[idx:idx+1]
+ style_t5_embs = style_t5_embs[idx:idx + 1]
positive_t5_embs = if_I_kwargs.get('positive_t5_embs')
if positive_t5_embs is not None:
- positive_t5_embs = positive_t5_embs[idx:idx+1]
+ positive_t5_embs = positive_t5_embs[idx:idx + 1]
if_III_kwargs['style_t5_embs'] = style_t5_embs
if_III_kwargs['positive_t5_embs'] = positive_t5_embs
diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py
index 6620b4b..2eeb0df 100644
--- a/deepfloyd_if/pipelines/optimized_dream.py
+++ b/deepfloyd_if/pipelines/optimized_dream.py
@@ -35,7 +35,7 @@ def run_stage1(
guidance_scale=guidance_scale_1,
sample_timestep_respacing=custom_timesteps_1,
seed=seed
- ).images
+ )
pil_images_I = model.to_images(images, disable_watermark=True)
return pil_images_I
diff --git a/run_ui.py b/run_ui.py
index 8e09648..73f5dec 100644
--- a/run_ui.py
+++ b/run_ui.py
@@ -1,8 +1,10 @@
import argparse
import gc
import os
+import time
import numpy as np
+from accelerate import dispatch_model
from deepfloyd_if.modules import IFStageI, IFStageII, StableStageIII
from deepfloyd_if.modules.t5 import T5Embedder
@@ -15,13 +17,6 @@
from ui_files.utils import randomize_seed_fn, show_gallery_view, update_upscale_button, get_stage2_index, \
check_if_stage2_selected, show_upscaled_view, get_device_map
-try:
- import xformers
-
- os.environ["FORCE_MEM_EFFICIENT_ATTN"] = "1"
-except:
- pass
-
device = torch.device(0)
if_I = IFStageI('IF-I-XL-v1.0', device=torch.device("cpu"))
if_I.to(torch.float16) # half
@@ -37,13 +32,12 @@
def switch_devices(stage):
if stage == 1:
# t5.model.cpu()
- del t5.model
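+        # Offload the T5 encoder to the CPU with accelerate's dispatch_model instead of deleting it.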
+ dispatch_model(t5.model, get_device_map(t5_device, all2cpu=True))
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
if_I.to(torch.device(0))
- gc.collect()
- torch.cuda.empty_cache()
- torch.cuda.synchronize()
-
def process_and_run_stage1(prompt,
negative_prompt,
@@ -157,7 +151,7 @@ def create_ui(args):
minimum=1,
maximum=4,
step=1,
- value=4,
+ value=1,
visible=True)
with gr.Tab(label='Super-resolution 1'):
seed_2 = gr.Slider(label='Seed',
diff --git a/ui_files/utils.py b/ui_files/utils.py
index a80d80d..541a577 100644
--- a/ui_files/utils.py
+++ b/ui_files/utils.py
@@ -2,6 +2,8 @@
import numpy as np
import random
+import torch
+
def _update_result_view(show_gallery: bool) -> tuple[dict, dict]:
return gr.update(visible=show_gallery), gr.update(visible=not show_gallery)
@@ -39,7 +41,8 @@ def check_if_stage2_selected(index: int) -> None:
)
-def get_device_map(device):
+def get_device_map(device, all2cpu=False):
+ device = device if not all2cpu else torch.device("cpu")
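+    # With all2cpu=True every entry of the returned device map points to the CPU, so
+    # dispatch_model() can fully offload the T5 encoder between stages.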
return {
'shared': device,
'encoder.embed_tokens': device,
From d34388f32384dd26f72ee636ee5631c06daad540 Mon Sep 17 00:00:00 2001
From: neon
Date: Sat, 29 Apr 2023 21:26:38 +0200
Subject: [PATCH 03/10] step 2
---
deepfloyd_if/model/gaussian_diffusion.py | 2 +-
deepfloyd_if/model/nn.py | 9 +-
deepfloyd_if/model/unet.py | 12 +-
deepfloyd_if/modules/base.py | 9 +-
deepfloyd_if/modules/stage_II.py | 6 +-
deepfloyd_if/pipelines/optimized_dream.py | 32 +++--
run_ui.py | 158 ++++++++++++----------
7 files changed, 125 insertions(+), 103 deletions(-)
diff --git a/deepfloyd_if/model/gaussian_diffusion.py b/deepfloyd_if/model/gaussian_diffusion.py
index 596c1b7..2394b31 100644
--- a/deepfloyd_if/model/gaussian_diffusion.py
+++ b/deepfloyd_if/model/gaussian_diffusion.py
@@ -875,7 +875,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
- res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps.to(torch.long)].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
diff --git a/deepfloyd_if/model/nn.py b/deepfloyd_if/model/nn.py
index 4f1a0f0..a957825 100644
--- a/deepfloyd_if/model/nn.py
+++ b/deepfloyd_if/model/nn.py
@@ -171,17 +171,16 @@ def timestep_embedding(timesteps, dim, max_period=10000, dtype=None):
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
- if dtype is None:
- dtype = torch.float32
+ dtype2 = torch.float32
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
- ).to(device=timesteps.device, dtype=dtype)
- args = timesteps[:, None].type(dtype) * freqs[None]
+ ).to(device=timesteps.device, dtype=dtype2)
+ args = timesteps[:, None].type(dtype2) * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
+ return embedding.to(dtype)
def attention(q, k, v, d_k):
diff --git a/deepfloyd_if/model/unet.py b/deepfloyd_if/model/unet.py
index 2fbc4aa..d82fdeb 100644
--- a/deepfloyd_if/model/unet.py
+++ b/deepfloyd_if/model/unet.py
@@ -629,7 +629,7 @@ def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None,
if use_cache and self.cache is not None:
encoder_out, encoder_pool = self.cache
else:
- text_emb = text_emb.type(self.dtype)
+ text_emb = text_emb.type(self.dtype).to(x.device)
encoder_out = self.encoder_proj(text_emb)
encoder_out = encoder_out.permute(0, 2, 1) # NLC -> NCL
if timestep_text_emb is None:
@@ -674,6 +674,7 @@ def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwar
get_activation(kwargs['activation']),
linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype),
)
+ self.primary_device = torch.device(0)
def forward(self, x, timesteps, low_res, aug_level=None, **kwargs):
bs, _, new_height, new_width = x.shape
@@ -684,7 +685,7 @@ def forward(self, x, timesteps, low_res, aug_level=None, **kwargs):
upsampled = F.interpolate(
low_res, (new_height, new_width), mode=self.interpolate_mode, align_corners=align_corners
- )
+ ).to(x.device)
if aug_level is None:
aug_steps = (np.random.random(bs) * 1000).astype(np.int64) # uniform [0, 1)
@@ -692,10 +693,11 @@ def forward(self, x, timesteps, low_res, aug_level=None, **kwargs):
else:
aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long)
+ aug_steps = aug_steps.to(self.dtype)
+
upsampled = self.low_res_diffusion.q_sample(upsampled, aug_steps)
x = torch.cat([x, upsampled], dim=1)
-
aug_emb = self.aug_proj(
- timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype)
- )
+ timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype).to(self.dtype)
+ ).to(x.device)
return super().forward(x, timesteps, aug_emb=aug_emb, **kwargs)
diff --git a/deepfloyd_if/modules/base.py b/deepfloyd_if/modules/base.py
index 797ed32..d08453b 100644
--- a/deepfloyd_if/modules/base.py
+++ b/deepfloyd_if/modules/base.py
@@ -88,8 +88,11 @@ def embeddings_to_image(
support_noise=None,
support_noise_less_qsample_steps=0,
inpainting_mask=None,
+ device=None,
**kwargs,
):
+ if device is None:
+ device = self.model.primary_device
self._clear_cache()
image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale)
diffusion = self.get_diffusion(sample_timestep_respacing)
@@ -98,7 +101,7 @@ def embeddings_to_image(
def model_fn(x_t, ts, **kwargs):
half = x_t[: len(x_t) // bs_scale]
- combined = torch.cat([half] * bs_scale, dim=0)
+ combined = torch.cat([half] * bs_scale, dim=0).to(device)
model_out = self.model(combined, ts, **kwargs)
eps, rest = model_out[:, :3], model_out[:, 3:]
if bs_scale == 3:
@@ -200,7 +203,7 @@ def model_fn(x_t, ts, **kwargs):
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
inpainting_mask=inpainting_mask,
- device=self.model.primary_device,
+ device=device,
progress=progress,
sample_fn=sample_fn,
)[:batch_size]
@@ -214,7 +217,7 @@ def model_fn(x_t, ts, **kwargs):
model_kwargs=model_kwargs,
dynamic_thresholding_p=dynamic_thresholding_p,
dynamic_thresholding_c=dynamic_thresholding_c,
- device=self.model.primary_device,
+ device=device,
progress=progress,
sample_fn=sample_fn,
)[:batch_size]
diff --git a/deepfloyd_if/modules/stage_II.py b/deepfloyd_if/modules/stage_II.py
index d14b838..d4eb0af 100644
--- a/deepfloyd_if/modules/stage_II.py
+++ b/deepfloyd_if/modules/stage_II.py
@@ -22,7 +22,7 @@ def embeddings_to_image(
self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
aug_level=0.25, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, sample_loop='ddpm',
sample_timestep_respacing='smart50', guidance_scale=4.0, img_scale=4.0, positive_mixer=0.5,
- progress=True, seed=None, sample_fn=None, **kwargs):
+ progress=True, seed=None, sample_fn=None, device=None, **kwargs):
return super().embeddings_to_image(
t5_embs=t5_embs,
low_res=low_res,
@@ -42,5 +42,9 @@ def embeddings_to_image(
progress=progress,
seed=seed,
sample_fn=sample_fn,
+ device=device,
**kwargs
)
+
+ def to(self, x):
+ self.model.to(x)
diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py
index 2eeb0df..dfc9e15 100644
--- a/deepfloyd_if/pipelines/optimized_dream.py
+++ b/deepfloyd_if/pipelines/optimized_dream.py
@@ -29,6 +29,9 @@ def run_stage1(
):
run_garbage_collection()
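+    # Fall back to a plain "<steps>" respacing string when no named schedule is selected.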
+ if custom_timesteps_1 == "none":
+ custom_timesteps_1 = str(num_inference_steps_1)
+
images, _ = model.embeddings_to_image(t5_embs=t5_embs,
negative_t5_embs=negative_t5_embs,
num_images_per_prompt=num_images,
@@ -38,34 +41,33 @@ def run_stage1(
)
pil_images_I = model.to_images(images, disable_watermark=True)
- return pil_images_I
+ return images, pil_images_I
def run_stage2(
model,
- stage1_result,
- stage2_index: int,
- seed_2: int = 0,
- guidance_scale_2: float = 4.0,
+ t5_embs,
+ negative_t5_embs,
+ images,
+ seed: int = 0,
+ guidance_scale: float = 4.0,
custom_timesteps_2: str = 'smart50',
num_inference_steps_2: int = 50,
disable_watermark: bool = True,
+ device=None
) -> Image:
run_garbage_collection()
- prompt_embeds = stage1_result['prompt_embeds']
- negative_embeds = stage1_result['negative_embeds']
- images = stage1_result['images']
- images = images[[stage2_index]]
-
+ if custom_timesteps_2 == "none":
+ custom_timesteps_2 = str(num_inference_steps_2)
stageII_generations, _ = model.embeddings_to_image(low_res=images,
- t5_embs=prompt_embeds,
- negative_t5_embs=negative_embeds,
- guidance_scale=guidance_scale_2,
+ t5_embs=t5_embs,
+ negative_t5_embs=negative_t5_embs,
+ guidance_scale=guidance_scale,
sample_timestep_respacing=custom_timesteps_2,
- seed=seed_2)
+ seed=seed, device=device)
pil_images_II = model.to_images(stageII_generations, disable_watermark=disable_watermark)
-
+    print(pil_images_II)
return pil_images_II
diff --git a/run_ui.py b/run_ui.py
index 73f5dec..c102b8c 100644
--- a/run_ui.py
+++ b/run_ui.py
@@ -20,7 +20,8 @@
device = torch.device(0)
if_I = IFStageI('IF-I-XL-v1.0', device=torch.device("cpu"))
if_I.to(torch.float16) # half
-# # if_II = IFStageII('IF-II-L-v1.0', device=torch.device("cpu"))
+if_II = IFStageII('IF-II-L-v1.0', device=torch.device("cpu"))
+if_II.to(torch.float16)  # half
# # if_III = StableStageIII('stable-diffusion-x4-upscaler', device=torch.device("cpu"))
t5_device = torch.device(0)
t5 = T5Embedder(device=t5_device, t5_model_kwargs={"low_cpu_mem_usage": True,
@@ -37,6 +38,12 @@ def switch_devices(stage):
torch.cuda.empty_cache()
torch.cuda.synchronize()
if_I.to(torch.device(0))
+ elif stage == 2:
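+        # Move Stage I off the GPU and clear the CUDA cache before loading Stage II.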
+ if_I.to(torch.device("cpu"))
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ if_II.to(torch.device(0))
def process_and_run_stage1(prompt,
@@ -46,26 +53,47 @@ def process_and_run_stage1(prompt,
guidance_scale_1,
custom_timesteps_1,
num_inference_steps_1):
+ global t5_embs, negative_t5_embs, images
print("Encoding prompts..")
- prompt = t5.get_text_embeddings(prompt)
- if negative_prompt == "":
- negative_prompt = torch.zeros_like(prompt)
- else:
- negative_prompt = t5.get_text_embeddings(negative_prompt)
+ t5_embs = t5.get_text_embeddings([prompt])
+ negative_t5_embs = t5.get_text_embeddings([negative_prompt])
switch_devices(stage=1)
- prompt = prompt.to(if_I.device)
- negative_prompt = negative_prompt.to(if_I.device)
+ t5_embs = t5_embs.to(if_I.device)
+ negative_t5_embs = negative_t5_embs.to(if_I.device)
print("Encoded. Running 1st stage")
- return run_stage1(
+ images, images_ret = run_stage1(
if_I,
- t5_embs=prompt,
- negative_t5_embs=negative_prompt,
+ t5_embs=t5_embs,
+ negative_t5_embs=negative_t5_embs,
seed=seed_1,
num_images=num_images,
guidance_scale_1=guidance_scale_1,
custom_timesteps_1=custom_timesteps_1,
num_inference_steps_1=num_inference_steps_1
)
+ return images_ret
+
+
+def process_and_run_stage2(
+ index,
+ seed_2,
+ guidance_scale_2,
+ custom_timesteps_2,
+ num_inference_steps_2):
+ global t5_embs, negative_t5_embs, images
+ print("Stage 2..")
+ switch_devices(stage=2)
+ return run_stage2(
+ if_II,
+ t5_embs=t5_embs,
+ negative_t5_embs=negative_t5_embs,
+ images=images[index].unsqueeze(0).to(device),
+ seed=seed_2,
+ guidance_scale=guidance_scale_2,
+ custom_timesteps_2=custom_timesteps_2,
+ num_inference_steps_2=num_inference_steps_2,
+ device=device
+ )
def create_ui(args):
@@ -105,11 +133,11 @@ def create_ui(args):
interactive=False,
visible=not args.DISABLE_SD_X4_UPSCALER)
with gr.Column(visible=False) as upscale_view:
- result = gr.Image(label='Result',
- show_label=False,
- type='filepath',
- interactive=False,
- elem_id='upscaled-image').style(height=640)
+ result = gr.Gallery(label='Result',
+ show_label=False,
+ elem_id='upscaled-image').style(
+ columns=args.GALLERY_COLUMN_NUM,
+ object_fit='contain')
back_to_selection_button = gr.Button('Back to selection')
with gr.Accordion('Advanced options',
open=False,
@@ -138,7 +166,7 @@ def create_ui(args):
'smart100',
'smart185',
],
- value="smart100",
+ value="fast27",
visible=True)
num_inference_steps_1 = gr.Slider(
label='Number of inference steps',
@@ -176,7 +204,7 @@ def create_ui(args):
'smart100',
'smart185',
],
- value="smart50",
+ value="smart27",
visible=True)
num_inference_steps_2 = gr.Slider(
label='Number of inference steps',
@@ -228,61 +256,45 @@ def create_ui(args):
outputs=selected_index_for_stage2,
queue=False,
)
- #
- # selected_index_for_stage2.change(
- # fn=update_upscale_button,
- # inputs=selected_index_for_stage2,
- # outputs=[
- # upscale_button,
- # upscale_to_256_button,
- # ],
- # queue=False,
- # )
- #
- # stage2_inputs = [
- # stage1_result_path,
- # selected_index_for_stage2,
- # seed_2,
- # guidance_scale_2,
- # custom_timesteps_2,
- # num_inference_steps_2,
- # ]
- #
- # upscale_to_256_button.click(
- # fn=check_if_stage2_selected,
- # inputs=selected_index_for_stage2,
- # queue=False,
- # ).then(
- # fn=randomize_seed_fn,
- # inputs=[seed_2, randomize_seed_2],
- # outputs=seed_2,
- # queue=False,
- # ).then(
- # fn=show_upscaled_view,
- # outputs=[
- # gallery_view,
- # upscale_view,
- # ],
- # queue=False,
- # ).then(
- # fn=run_stage2,
- # inputs=stage2_inputs,
- # outputs=result,
- # api_name='upscale256',
- # ) # .success(
- # # fn=upload_stage2_info,
- # # inputs=[
- # # stage1_param_file_hash_name,
- # # result,
- # # selected_index_for_stage2,
- # # seed_2,
- # # guidance_scale_2,
- # # custom_timesteps_2,
- # # num_inference_steps_2,
- # # ],
- # # queue=False,
- # # )
- #
+
+ selected_index_for_stage2.change(
+ fn=update_upscale_button,
+ inputs=selected_index_for_stage2,
+ outputs=[
+ upscale_button,
+ upscale_to_256_button,
+ ],
+ queue=False,
+ )
+
+ upscale_to_256_button.click(
+ fn=check_if_stage2_selected,
+ inputs=selected_index_for_stage2,
+ queue=False,
+ ).then(
+ fn=randomize_seed_fn,
+ inputs=[seed_2, randomize_seed_2],
+ outputs=seed_2,
+ queue=False,
+ ).then(
+ fn=show_upscaled_view,
+ outputs=[
+ gallery_view,
+ upscale_view,
+ ],
+ queue=False,
+ ).then(
+ fn=process_and_run_stage2,
+ inputs=[
+ selected_index_for_stage2,
+ seed_2,
+ guidance_scale_2,
+ custom_timesteps_2,
+ num_inference_steps_2,
+ ],
+ outputs=result,
+ )
+
# stage2_3_inputs = [
# stage1_result_path,
# selected_index_for_stage2,
From e4af95d08c42f17cbd3b8201606b94821d708ec7 Mon Sep 17 00:00:00 2001
From: neon
Date: Sat, 29 Apr 2023 22:11:49 +0200
Subject: [PATCH 04/10] stage 3
---
README.md | 126 ++++++++++++-----
deepfloyd_if/modules/stage_III_sd_x4.py | 14 +-
deepfloyd_if/pipelines/optimized_dream.py | 37 +++--
run_ui.py | 157 ++++++++++++----------
4 files changed, 205 insertions(+), 129 deletions(-)
diff --git a/README.md b/README.md
index 4a8b803..e1a72c1 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,21 @@
+# Optimized DeepFloyd IF by neonsecret
+
+Tested on an RTX 3070 with 8 GB of VRAM.
+
+- stage 1: ~3.5 sec per iteration
+- stage 2: 11 sec for 27 steps
+- stage 3: 2 sec for 40 steps
+
+### To run the UI:
+
+```bash
+python run_ui.py
+```
+
+All the models are automatically downloaded.
+
+Original README:
+
[](LICENSE)
[](LICENSE-MODEL)
[](https://pepy.tech/project/deepfloyd_if)
@@ -11,21 +29,32 @@
-We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding. DeepFloyd IF is a modular composed of a frozen text encoder and three cascaded pixel diffusion modules: a base model that generates 64x64 px image based on text prompt and two super-resolution models, each designed to generate images of increasing resolution: 256x256 px and 1024x1024 px. All stages of the model utilize a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture enhanced with cross-attention and attention pooling. The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis.
+We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image model with a high degree of photorealism
+and language understanding. DeepFloyd IF is modularly composed of a frozen text encoder and three cascaded pixel
+diffusion modules: a base model that generates a 64x64 px image based on a text prompt and two super-resolution models, each
+designed to generate images of increasing resolution: 256x256 px and 1024x1024 px. All stages of the model utilize a
+frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture
+enhanced with cross-attention and attention pooling. The result is a highly efficient model that outperforms current
+state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. Our work underscores the potential
+of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for
+text-to-image synthesis.
-*Inspired by* [*Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding*](https://arxiv.org/pdf/2205.11487.pdf)
+*Inspired by* [*Photorealistic Text-to-Image Diffusion Models with Deep Language
+Understanding*](https://arxiv.org/pdf/2205.11487.pdf)
## Minimum requirements to use all IF models:
+
- 16GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module)
-- 24GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) & Stable x4 (to 1024x1024 upscaler)
+- 24GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) & Stable x4 (to
+ 1024x1024 upscaler)
- `xformers` and set env variable `FORCE_MEM_EFFICIENT_ATTN=1`
-
## Quick Start
+
[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)
[](https://huggingface.co/spaces/DeepFloyd/IF)
@@ -36,25 +65,28 @@ pip install git+https://github.com/openai/CLIP.git --no-deps
```
## Local notebooks
+
[](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb)
[](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures)
-The Dream, Style Transfer, Super Resolution or Inpainting modes are avaliable in a Jupyter Notebook [here](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb).
-
-
+The Dream, Style Transfer, Super Resolution or Inpainting modes are available in a Jupyter
+Notebook [here](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb).
## Integration with 🤗 Diffusers
IF is also integrated with the 🤗 Hugging Face [Diffusers library](https://github.com/huggingface/diffusers/).
-Diffusers runs each stage individually allowing the user to customize the image generation process as well as allowing to inspect intermediate results easily.
+Diffusers runs each stage individually, allowing the user to customize the image generation process as well as
+to inspect intermediate results easily.
### Example
Before you can use IF, you need to accept its usage conditions. To do so:
+
1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be loggin in
2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)
3. Make sure to login locally. Install `huggingface_hub`
+
```sh
pip install huggingface_hub --upgrade
```
@@ -67,7 +99,8 @@ from huggingface_hub import login
login()
```
-and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).
+and enter
+your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).
Next we install `diffusers` and dependencies:
@@ -77,7 +110,9 @@ pip install diffusers accelerate transformers safetensors
And we can now run the model locally.
-By default `diffusers` makes use of [model cpu offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings) to run the whole IF pipeline with as little as 14 GB of VRAM.
+By default `diffusers` makes use
+of [model cpu offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings)
+to run the whole IF pipeline with as little as 14 GB of VRAM.
If you are using `torch>=2.0.0`, make sure to **delete all** `enable_xformers_memory_efficient_attention()`
functions.
@@ -100,8 +135,10 @@ stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__v
stage_2.enable_model_cpu_offload()
# stage 3
-safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker}
-stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16)
+safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker,
+ "watermarker": stage_1.watermarker}
+stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules,
+ torch_dtype=torch.float16)
stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0
stage_3.enable_model_cpu_offload()
@@ -113,12 +150,14 @@ prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
generator = torch.manual_seed(0)
# stage 1
-image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images
+image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator,
+ output_type="pt").images
pt_to_pil(image)[0].save("./if_stage_I.png")
# stage 2
image = stage_2(
- image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
+ image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator,
+ output_type="pt"
).images
pt_to_pil(image)[0].save("./if_stage_II.png")
@@ -127,12 +166,16 @@ image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100
image[0].save("./if_stage_III.png")
```
- There are multiple ways to speed up the inference time and lower the memory consumption even more with `diffusers`. To do so, please have a look at the Diffusers docs:
+There are multiple ways to speed up the inference time and lower the memory consumption even more with `diffusers`. To
+do so, please have a look at the Diffusers docs:
- 🚀 [Optimizing for inference time](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-speed)
-- ⚙️ [Optimizing for low memory during inference](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-memory)
+- ⚙️ [Optimizing for low memory during inference](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-memory)
-For more in-detail information about how to use IF, please have a look at [the IF blog post](https://huggingface.co/blog/if) and [the documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/if) 📖.
+For more in-detail information about how to use IF, please have a look
+at [the IF blog post](https://huggingface.co/blog/if)
+and [the documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/if) 📖.
## Run the code locally
@@ -150,6 +193,7 @@ t5 = T5Embedder(device="cpu")
```
### I. Dream
+
Dream is the text-to-image mode of the IF model
```python
@@ -160,7 +204,7 @@ count = 4
result = dream(
t5=t5, if_I=if_I, if_II=if_II, if_III=if_III,
- prompt=[prompt]*count,
+ prompt=[prompt] * count,
seed=42,
if_I_kwargs={
"guidance_scale": 7.0,
@@ -179,6 +223,7 @@ result = dream(
if_III.show(result['III'], size=14)
```
+

## II. Zero-shot Image-to-Image Translation
@@ -186,6 +231,7 @@ if_III.show(result['III'], size=14)

In Style Transfer mode, the output of your prompt comes out at the style of the `support_pil_img`
+
```python
from deepfloyd_if.pipelines import style_transfer
@@ -215,9 +261,10 @@ if_I.show(result['II'], 1, 20)

-
## III. Super Resolution
-For super-resolution, users can run `IF-II` and `IF-III` or 'Stable x4' on an image that was not necessarely generated by IF (two cascades):
+
+For super-resolution, users can run `IF-II` and `IF-III` or 'Stable x4' on an image that was not necessarily generated
+by IF (two cascades):
```python
from deepfloyd_if.pipelines import super_resolution
@@ -253,7 +300,6 @@ show_superres(raw_pil_image, high_res['III'][0])

-
### IV. Zero-shot Inpainting
```python
@@ -289,9 +335,11 @@ if_I.show(result['I'], 2, 3)
if_I.show(result['II'], 2, 6)
if_I.show(result['III'], 2, 14)
```
+

### 🤗 Model Zoo 🤗
+
The link to download the weights as well as the model cards will be available soon on each model of the model zoo
#### Original
@@ -305,7 +353,7 @@ The link to download the weights as well as the model cards will be available so
| [IF-II-L](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)* | II | 1.2B | - | 1536 | 2.5M |
| IF-III-L* _(soon)_ | III | 700M | - | 3072 | 1.25M |
- *best modules
+*best modules
### Quantitative Evaluation
@@ -315,16 +363,19 @@ The link to download the weights as well as the model cards will be available so
## License
-The code in this repository is released under the bespoke license (see added [point two](https://github.com/deep-floyd/IF/blob/main/LICENSE#L13)).
+The code in this repository is released under the bespoke license (see
+added [point two](https://github.com/deep-floyd/IF/blob/main/LICENSE#L13)).
-The weights will be available soon via [the DeepFloyd organization at Hugging Face](https://huggingface.co/DeepFloyd) and have their own LICENSE.
+The weights will be available soon via [the DeepFloyd organization at Hugging Face](https://huggingface.co/DeepFloyd)
+and have their own LICENSE.
-**Disclaimer:** *The initial release of the IF model is under a restricted research-purposes-only license temporarily to gather feedback, and after that we intend to release a fully open-source model in line with other Stability AI models.*
+**Disclaimer:** *The initial release of the IF model is under a restricted research-purposes-only license temporarily to
+gather feedback, and after that we intend to release a fully open-source model in line with other Stability AI models.*
## Limitations and Biases
-The models available in this codebase have known limitations and biases. Please refer to [the model card](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) for more information.
-
+The models available in this codebase have known limitations and biases. Please refer
+to [the model card](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) for more information.
## 🎓 DeepFloyd IF creators:
@@ -335,17 +386,26 @@ The models available in this codebase have known limitations and biases. Please
- Ksenia Ivanova [GitHub](https://github.com/ivksu) | [Twitter](https://twitter.com/susiaiv)
- Nadiia Klokova [GitHub](https://github.com/vauimpuls) | [Twitter](https://twitter.com/vauimpuls)
-
## 📄 Research Paper (Soon)
## Acknowledgements
-Special thanks to [StabilityAI](http://stability.ai) and its CEO [Emad Mostaque](https://twitter.com/emostaque) for invaluable support, providing GPU compute and infrastructure to train the models (our gratitude goes to [Richard Vencu](https://github.com/rvencu)); thanks to [LAION](https://laion.ai) and [Christoph Schuhmann](https://github.com/christophschuhmann) in particular for contribution to the project and well-prepared datasets; thanks to [Huggingface](https://huggingface.co) teams for optimizing models' speed and memory consumption during inference, creating demos and giving cool advice!
+Special thanks to [StabilityAI](http://stability.ai) and its CEO [Emad Mostaque](https://twitter.com/emostaque) for
+invaluable support, providing GPU compute and infrastructure to train the models (our gratitude goes
+to [Richard Vencu](https://github.com/rvencu)); thanks to [LAION](https://laion.ai)
+and [Christoph Schuhmann](https://github.com/christophschuhmann) in particular for contribution to the project and
+well-prepared datasets; thanks to [Huggingface](https://huggingface.co) teams for optimizing models' speed and memory
+consumption during inference, creating demos and giving cool advice!
## 🚀 External Contributors 🚀
-- The Biggest Thanks [@Apolinário](https://github.com/apolinario), for ideas, consultations, help and support on all stages to make IF available in open-source; for writing a lot of documentation and instructions; for creating a friendly atmosphere in difficult moments 🦉;
+
+- The Biggest Thanks [@Apolinário](https://github.com/apolinario), for ideas, consultations, help and support on all
+ stages to make IF available in open-source; for writing a lot of documentation and instructions; for creating a
+ friendly atmosphere in difficult moments 🦉;
- Thanks, [@patrickvonplaten](https://github.com/patrickvonplaten), for improving loading time of unet models by 80%;
-for integration Stable-Diffusion-x4 as native pipeline 💪;
-- Thanks, [@williamberman](https://github.com/williamberman) and [@patrickvonplaten](https://github.com/patrickvonplaten) for diffusers integration 🙌;
-- Thanks, [@hysts](https://github.com/hysts) and [@Apolinário](https://github.com/apolinario) for creating [the best gradio demo with IF](https://huggingface.co/spaces/DeepFloyd/IF) 🚀;
+ for integration Stable-Diffusion-x4 as native pipeline 💪;
+- Thanks, [@williamberman](https://github.com/williamberman)
+ and [@patrickvonplaten](https://github.com/patrickvonplaten) for diffusers integration 🙌;
+- Thanks, [@hysts](https://github.com/hysts) and [@Apolinário](https://github.com/apolinario) for
+ creating [the best gradio demo with IF](https://huggingface.co/spaces/DeepFloyd/IF) 🚀;
- Thanks, [@Dango233](https://github.com/Dango233), for adapting IF with xformers memory efficient attention 💪;
diff --git a/deepfloyd_if/modules/stage_III_sd_x4.py b/deepfloyd_if/modules/stage_III_sd_x4.py
index 7599317..2148f94 100644
--- a/deepfloyd_if/modules/stage_III_sd_x4.py
+++ b/deepfloyd_if/modules/stage_III_sd_x4.py
@@ -9,7 +9,6 @@
class StableStageIII(IFBaseModule):
-
available_models = ['stable-diffusion-x4-upscaler']
def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs):
@@ -20,7 +19,7 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs):
' Please run `pip install diffusers --upgrade`'
)
- model_id = os.path.join('stabilityai', self.dir_or_name)
+ model_id = 'stabilityai' + "/" + self.dir_or_name.strip()
model_kwargs = model_kwargs or {}
precision = str(model_kwargs.get('precision', '16'))
@@ -46,12 +45,11 @@ def use_diffusers(self):
return False
def embeddings_to_image(
- self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
+ self, low_res, prompt, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1,
aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5,
sample_loop='ddpm', sample_timestep_respacing='75', guidance_scale=4.0, img_scale=4.0,
- progress=True, seed=None, sample_fn=None, **kwargs):
+ progress=True, seed=None, sample_fn=None, device=None, **kwargs):
- prompt = kwargs.pop('prompt')
noise_level = kwargs.pop('noise_level', 20)
if sample_loop == 'ddpm':
@@ -64,7 +62,8 @@ def embeddings_to_image(
self.model.set_progress_bar_config(disable=not progress)
generator = torch.manual_seed(seed)
- prompt = sum([batch_repeat * [p] for p in prompt], [])
+ prompt = [prompt]
+ print(prompt)
low_res = low_res.repeat(batch_repeat, 1, 1, 1)
metadata = {
@@ -82,3 +81,6 @@ def embeddings_to_image(
sample = self._IFBaseModule__validate_generations(images)
return sample, metadata
+
+ def to(self, x):
+ self.model.to(x)
diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py
index dfc9e15..0824539 100644
--- a/deepfloyd_if/pipelines/optimized_dream.py
+++ b/deepfloyd_if/pipelines/optimized_dream.py
@@ -67,30 +67,29 @@ def run_stage2(
sample_timestep_respacing=custom_timesteps_2,
seed=seed, device=device)
pil_images_II = model.to_images(stageII_generations, disable_watermark=disable_watermark)
-    print(pil_images_II)
return pil_images_II
def run_stage3(
model,
- image: Image,
- t5_embs,
+ prompt,
negative_t5_embs,
- seed_3: int = 0,
- guidance_scale_3: float = 9.0,
- sample_timestep_respacing='super40',
- disable_watermark=True
+ images,
+ seed: int = 0,
+ guidance_scale: float = 4.0,
+ custom_timesteps_2: str = 'smart50',
+ num_inference_steps_2: int = 50,
+ disable_watermark: bool = True,
+ device=None
) -> Image:
run_garbage_collection()
-
- _stageIII_generations, _ = model.embeddings_to_image(image=image,
- t5_embs=t5_embs,
- negative_t5_embs=negative_t5_embs,
- num_images_per_prompt=1,
- guidance_scale=guidance_scale_3,
- noise_level=100,
- sample_timestep_respacing=sample_timestep_respacing,
- seed=seed_3)
- pil_image_III = model.to_images(_stageIII_generations, disable_watermark=disable_watermark)
-
- return pil_image_III
+ stageII_generations, _ = model.embeddings_to_image(low_res=images,
+ prompt=prompt,
+ negative_t5_embs=negative_t5_embs,
+ guidance_scale=guidance_scale,
+ sample_timestep_respacing=num_inference_steps_2,
+ num_images_per_prompt=1,
+ noise_level=100,
+ seed=seed, device=device)
+ pil_images_III = model.to_images(stageII_generations, disable_watermark=disable_watermark)
+ return pil_images_III
diff --git a/run_ui.py b/run_ui.py
index c102b8c..6e03b4f 100644
--- a/run_ui.py
+++ b/run_ui.py
@@ -22,7 +22,7 @@
if_I.to(torch.float16) # half
if_II = IFStageII('IF-II-L-v1.0', device=torch.device("cpu"))
 if_II.to(torch.float16)  # half
-# # if_III = StableStageIII('stable-diffusion-x4-upscaler', device=torch.device("cpu"))
+if_III = StableStageIII('stable-diffusion-x4-upscaler', device=torch.device("cpu"))
t5_device = torch.device(0)
t5 = T5Embedder(device=t5_device, t5_model_kwargs={"low_cpu_mem_usage": True,
"torch_dtype": torch.float16,
@@ -31,6 +31,16 @@
def switch_devices(stage):
+ if stage == 0:
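+        # Park every diffusion stage on the CPU so the GPU is free for the T5 encoder.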
+ if_I.to(torch.device("cpu"))
+ if_II.to(torch.device("cpu"))
+ if_III.to(torch.device("cpu"))
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ # dispatch_model(t5.model, get_device_map(t5_device, all2cpu=False))
if stage == 1:
# t5.model.cpu()
dispatch_model(t5.model, get_device_map(t5_device, all2cpu=True))
@@ -44,6 +54,12 @@ def switch_devices(stage):
torch.cuda.empty_cache()
torch.cuda.synchronize()
if_II.to(torch.device(0))
+ elif stage == 3:
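+        # Swap Stage II out for the Stable x4 upscaler.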
+ if_II.to(torch.device("cpu"))
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ if_III.to(torch.device(0))
def process_and_run_stage1(prompt,
@@ -55,6 +71,7 @@ def process_and_run_stage1(prompt,
num_inference_steps_1):
global t5_embs, negative_t5_embs, images
print("Encoding prompts..")
+ switch_devices(stage=0)
t5_embs = t5.get_text_embeddings([prompt])
negative_t5_embs = t5.get_text_embeddings([negative_prompt])
switch_devices(stage=1)
@@ -96,6 +113,29 @@ def process_and_run_stage2(
)
+def process_and_run_stage3(
+ index,
+ prompt,
+ seed_2,
+ guidance_scale_2,
+ custom_timesteps_2,
+ num_inference_steps_2):
+ global t5_embs, negative_t5_embs, images
+ print("Stage 3..")
+ switch_devices(stage=3)
+ return run_stage3(
+ if_III,
+ prompt=prompt,
+ negative_t5_embs=negative_t5_embs,
+ images=images[index].unsqueeze(0).to(device),
+ seed=seed_2,
+ guidance_scale=guidance_scale_2,
+ custom_timesteps_2=custom_timesteps_2,
+ num_inference_steps_2=num_inference_steps_2,
+ device=device
+ )
+
+
def create_ui(args):
with gr.Blocks(css='ui_files/style.css') as demo:
with gr.Box():
@@ -129,9 +169,6 @@ def create_ui(args):
'Upscale to 256px',
visible=args.DISABLE_SD_X4_UPSCALER,
interactive=False)
- upscale_button = gr.Button('Upscale',
- interactive=False,
- visible=not args.DISABLE_SD_X4_UPSCALER)
with gr.Column(visible=False) as upscale_view:
result = gr.Gallery(label='Result',
show_label=False,
@@ -139,6 +176,9 @@ def create_ui(args):
columns=args.GALLERY_COLUMN_NUM,
object_fit='contain')
back_to_selection_button = gr.Button('Back to selection')
+ upscale_button = gr.Button('Upscale 4x',
+ interactive=False,
+ visible=True)
with gr.Accordion('Advanced options',
open=False,
visible=args.SHOW_ADVANCED_OPTIONS):
@@ -295,73 +335,48 @@ def create_ui(args):
outputs=result,
)
- # stage2_3_inputs = [
- # stage1_result_path,
- # selected_index_for_stage2,
- # seed_2,
- # guidance_scale_2,
- # custom_timesteps_2,
- # num_inference_steps_2,
- # prompt,
- # negative_prompt,
- # seed_3,
- # guidance_scale_3,
- # num_inference_steps_3,
- # ]
- #
- # upscale_button.click(
- # fn=check_if_stage2_selected,
- # inputs=selected_index_for_stage2,
- # queue=False,
- # ).then(
- # fn=randomize_seed_fn,
- # inputs=[seed_2, randomize_seed_2],
- # outputs=seed_2,
- # queue=False,
- # ).then(
- # fn=randomize_seed_fn,
- # inputs=[seed_3, randomize_seed_3],
- # outputs=seed_3,
- # queue=False,
- # ).then(
- # fn=show_upscaled_view,
- # outputs=[
- # gallery_view,
- # upscale_view,
- # ],
- # queue=False,
- # ).then(
- # fn=run_stage3,
- # inputs=stage2_3_inputs,
- # outputs=result,
- # api_name='upscale1024',
- # ) # .success(
- # # fn=upload_stage2_3_info,
- # # inputs=[
- # # stage1_param_file_hash_name,
- # # result,
- # # selected_index_for_stage2,
- # # seed_2,
- # # guidance_scale_2,
- # # custom_timesteps_2,
- # # num_inference_steps_2,
- # # prompt,
- # # negative_prompt,
- # # seed_3,
- # # guidance_scale_3,
- # # num_inference_steps_3,
- # # ],
- # # queue=False,
- # # )
- #
- # back_to_selection_button.click(
- # fn=show_gallery_view,
- # outputs=[
- # gallery_view,
- # upscale_view,
- # ],
- # queue=False,
- # )
+ upscale_button.click(
+ fn=check_if_stage2_selected,
+ inputs=selected_index_for_stage2,
+ queue=False,
+ ).then(
+ fn=randomize_seed_fn,
+ inputs=[seed_2, randomize_seed_2],
+ outputs=seed_2,
+ queue=False,
+ ).then(
+ fn=randomize_seed_fn,
+ inputs=[seed_3, randomize_seed_3],
+ outputs=seed_3,
+ queue=False,
+ ).then(
+ fn=show_upscaled_view,
+ outputs=[
+ gallery_view,
+ upscale_view,
+ ],
+ queue=False,
+ ).then(
+ fn=process_and_run_stage3,
+ inputs=[
+ selected_index_for_stage2,
+ prompt,
+ seed_2,
+ guidance_scale_3,
+ custom_timesteps_2,
+ num_inference_steps_3,
+ ],
+ outputs=result,
+ )
+
+ back_to_selection_button.click(
+ fn=show_gallery_view,
+ outputs=[
+ gallery_view,
+ upscale_view,
+ ],
+ queue=False,
+ )
return demo
From bcda682b78e1430b14ac6fb413a6ca116534cc10 Mon Sep 17 00:00:00 2001
From: neon
Date: Sun, 30 Apr 2023 09:38:21 +0200
Subject: [PATCH 05/10] t5 reloading
---
deepfloyd_if/modules/t5.py | 10 ++++++++++
run_ui.py | 4 ++++
2 files changed, 14 insertions(+)
diff --git a/deepfloyd_if/modules/t5.py b/deepfloyd_if/modules/t5.py
index 63ec1c7..1cbd3dd 100644
--- a/deepfloyd_if/modules/t5.py
+++ b/deepfloyd_if/modules/t5.py
@@ -75,6 +75,16 @@ def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_toke
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
+ self.saved_path = path
+ self.saved_kwargs = t5_model_kwargs
+ self.loaded = True
+
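+    # Re-instantiate the encoder from the saved checkpoint path with a new device map;
+    # used to bring T5 back after it has been offloaded to free VRAM for the UNet stages.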
+ def reload(self, dmap):
+ del self.model
+ torch.cuda.empty_cache()
+ self.saved_kwargs["device_map"] = dmap
+ self.model = T5EncoderModel.from_pretrained(self.saved_path, **self.saved_kwargs).eval()
+ self.loaded = True
def to(self, x):
self.model.to(x)
diff --git a/run_ui.py b/run_ui.py
index 6e03b4f..8312777 100644
--- a/run_ui.py
+++ b/run_ui.py
@@ -40,10 +40,14 @@ def switch_devices(stage):
torch.cuda.empty_cache()
torch.cuda.synchronize()
+ if not t5.loaded:
+ print("Reloading t5")
+ t5.reload(get_device_map(t5_device, all2cpu=False))
# dispatch_model(t5.model, get_device_map(t5_device, all2cpu=False))
if stage == 1:
# t5.model.cpu()
dispatch_model(t5.model, get_device_map(t5_device, all2cpu=True))
+ t5.model.loaded = False
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
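A minimal, self-contained sketch of the offload-then-reload pattern this patch adds, assuming the transformers T5EncoderModel API (plus accelerate when a device_map is passed); the class and method names are illustrative, not the repository's:

import gc
import torch
from transformers import T5EncoderModel

class ReloadableEncoder:
    def __init__(self, path, **model_kwargs):
        self.path = path
        self.kwargs = dict(model_kwargs)
        self.model = T5EncoderModel.from_pretrained(path, **self.kwargs).eval()
        self.loaded = True

    def unload(self):
        # Drop the weights entirely instead of just moving them, then reclaim VRAM.
        del self.model
        gc.collect()
        torch.cuda.empty_cache()
        self.loaded = False

    def reload(self, device_map):
        # Re-instantiate from disk with a (possibly different) device placement.
        self.kwargs['device_map'] = device_map
        self.model = T5EncoderModel.from_pretrained(self.path, **self.kwargs).eval()
        self.loaded = True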
From 9fc7a2ac0fc7599317cddab6524783f4aa28c13d Mon Sep 17 00:00:00 2001
From: neon
Date: Sun, 30 Apr 2023 09:55:44 +0200
Subject: [PATCH 06/10] fixes
---
deepfloyd_if/model/unet_split.py | 3 +++
deepfloyd_if/modules/stage_III_sd_x4.py | 1 -
deepfloyd_if/pipelines/optimized_dream.py | 4 ++--
run_ui.py | 8 +++-----
4 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/deepfloyd_if/model/unet_split.py b/deepfloyd_if/model/unet_split.py
index 1d64e3b..623a378 100644
--- a/deepfloyd_if/model/unet_split.py
+++ b/deepfloyd_if/model/unet_split.py
@@ -2,6 +2,7 @@
import gc
import os
import math
+import time
from abc import abstractmethod
import torch
@@ -661,6 +662,8 @@ def to(self, x, stage=1): # 0, 1, 2, 3
self.out.to(x)
else:
super().to(x)
+ # time.sleep(3)
+ # print(stage)
def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs):
hs = []
diff --git a/deepfloyd_if/modules/stage_III_sd_x4.py b/deepfloyd_if/modules/stage_III_sd_x4.py
index 2148f94..e3af3a4 100644
--- a/deepfloyd_if/modules/stage_III_sd_x4.py
+++ b/deepfloyd_if/modules/stage_III_sd_x4.py
@@ -63,7 +63,6 @@ def embeddings_to_image(
generator = torch.manual_seed(seed)
prompt = [prompt]
- print(prompt)
low_res = low_res.repeat(batch_repeat, 1, 1, 1)
metadata = {
diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py
index 0824539..7240559 100644
--- a/deepfloyd_if/pipelines/optimized_dream.py
+++ b/deepfloyd_if/pipelines/optimized_dream.py
@@ -34,7 +34,7 @@ def run_stage1(
images, _ = model.embeddings_to_image(t5_embs=t5_embs,
negative_t5_embs=negative_t5_embs,
- num_images_per_prompt=num_images,
+ batch_repeat=num_images,
guidance_scale=guidance_scale_1,
sample_timestep_respacing=custom_timesteps_1,
seed=seed
@@ -89,7 +89,7 @@ def run_stage3(
guidance_scale=guidance_scale,
sample_timestep_respacing=num_inference_steps_2,
num_images_per_prompt=1,
- noise_level=100,
+ noise_level=60,
seed=seed, device=device)
pil_images_III = model.to_images(stageII_generations, disable_watermark=disable_watermark)
return pil_images_III
diff --git a/run_ui.py b/run_ui.py
index 8312777..01eae4d 100644
--- a/run_ui.py
+++ b/run_ui.py
@@ -1,7 +1,5 @@
import argparse
import gc
-import os
-import time
import numpy as np
from accelerate import dispatch_model
@@ -47,7 +45,7 @@ def switch_devices(stage):
if stage == 1:
# t5.model.cpu()
dispatch_model(t5.model, get_device_map(t5_device, all2cpu=True))
- t5.model.loaded = False
+ t5.loaded = False
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
@@ -76,8 +74,8 @@ def process_and_run_stage1(prompt,
global t5_embs, negative_t5_embs, images
print("Encoding prompts..")
switch_devices(stage=0)
- t5_embs = t5.get_text_embeddings([prompt])
- negative_t5_embs = t5.get_text_embeddings([negative_prompt])
+ t5_embs = t5.get_text_embeddings([prompt] * num_images)
+ negative_t5_embs = t5.get_text_embeddings([negative_prompt] * num_images)
switch_devices(stage=1)
t5_embs = t5_embs.to(if_I.device)
negative_t5_embs = negative_t5_embs.to(if_I.device)
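The hunk above encodes the prompt num_images times; an alternative sketch (plain PyTorch, not the patch's code) is to encode once and repeat the embedding along the batch dimension. The 77/4096 shape below is only a placeholder:

import torch

def expand_embeddings(t5_embs: torch.Tensor, num_images: int) -> torch.Tensor:
    # (1, seq_len, hidden) -> (num_images, seq_len, hidden), all rows sharing the same prompt.
    return t5_embs.repeat(num_images, 1, 1)

embs = torch.randn(1, 77, 4096)          # stand-in for a single-prompt T5 embedding
print(expand_embeddings(embs, 4).shape)  # torch.Size([4, 77, 4096])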
From 2a5c212b99a43ad2d129626dd8ce2421d5c76d2b Mon Sep 17 00:00:00 2001
From: neon
Date: Sun, 30 Apr 2023 10:02:47 +0200
Subject: [PATCH 07/10] small optimization
---
README.md | 2 +-
deepfloyd_if/model/unet_split.py | 25 ++++++++-----------------
2 files changed, 9 insertions(+), 18 deletions(-)
diff --git a/README.md b/README.md
index e1a72c1..8c63b75 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
Tested on rtx 3070, 8 gb vram
-stage 1: ~3.5 sec per iteration
+stage 1: ~1 min 30 sec for 27 steps, ~3.35 sec per iteration
stage 2: 11 seconds for 27 steps
stage 3: 2 seconds for 40 steps
diff --git a/deepfloyd_if/model/unet_split.py b/deepfloyd_if/model/unet_split.py
index 623a378..bdcdb8d 100644
--- a/deepfloyd_if/model/unet_split.py
+++ b/deepfloyd_if/model/unet_split.py
@@ -634,36 +634,29 @@ def to(self, x, stage=1): # 0, 1, 2, 3
if isinstance(x, torch.device):
secondary_device = self.secondary_device
if stage == 1:
- self.middle_block.to(secondary_device)
self.output_blocks.to(secondary_device)
self.out.to(secondary_device)
- self.collect()
+
+ # self.collect()
+
self.time_embed.to(x)
self.encoder_proj.to(x)
self.encoder_pooling.to(x)
self.input_blocks.to(x)
- elif stage == 2:
- self.time_embed.to(secondary_device)
- self.encoder_proj.to(secondary_device)
- self.encoder_pooling.to(secondary_device)
- self.input_blocks.to(secondary_device)
- self.output_blocks.to(secondary_device)
- self.out.to(secondary_device)
- self.collect()
self.middle_block.to(x)
- elif stage == 3:
+ elif stage == 2:
self.time_embed.to(secondary_device)
self.encoder_proj.to(secondary_device)
self.encoder_pooling.to(secondary_device)
self.input_blocks.to(secondary_device)
self.middle_block.to(secondary_device)
- self.collect()
+
+ # self.collect()
+
self.output_blocks.to(x)
self.out.to(x)
else:
super().to(x)
- # time.sleep(3)
- # print(stage)
def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs):
hs = []
@@ -696,11 +689,9 @@ def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None,
h = module(h, emb, encoder_out)
hs.append(h)
- self.to(self.primary_device, stage=2)
-
h = self.middle_block(h, emb, encoder_out)
- self.to(self.primary_device, stage=3)
+ self.to(self.primary_device, stage=2)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
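The reshuffled to() above keeps the middle block resident alongside the input blocks during stage 1 and only swaps the decoder in for stage 2, so the middle block no longer needs its own transfer in every forward pass. A toy sketch of the general pattern (illustrative module names, not the repository's UNet):

import torch
import torch.nn as nn

class StagedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(64, 64)
        self.middle = nn.Linear(64, 64)
        self.decoder = nn.Linear(64, 64)

    def place(self, device, stage, offload=torch.device('cpu')):
        if stage == 1:       # encoder + middle resident, decoder offloaded
            self.decoder.to(offload)
            self.encoder.to(device)
            self.middle.to(device)
        elif stage == 2:     # decoder resident, everything else offloaded
            self.encoder.to(offload)
            self.middle.to(offload)
            self.decoder.to(device)

    def forward(self, x, device):
        self.place(device, stage=1)
        h = self.middle(self.encoder(x.to(device)))
        self.place(device, stage=2)
        return self.decoder(h)

net = StagedNet()
print(net(torch.randn(2, 64), torch.device('cpu')).shape)  # torch.Size([2, 64])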
From 5aa32581638a466023198775620c3e239be0303d Mon Sep 17 00:00:00 2001
From: neon
Date: Sun, 30 Apr 2023 16:59:39 +0200
Subject: [PATCH 08/10] fixed stage 3, default params changed
---
README.md | 8 ++++----
deepfloyd_if/modules/stage_III_sd_x4.py | 8 +-------
deepfloyd_if/pipelines/optimized_dream.py | 4 ++--
run_ui.py | 7 ++++---
4 files changed, 11 insertions(+), 16 deletions(-)
diff --git a/README.md b/README.md
index 8c63b75..58c86d9 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,9 @@
Tested on rtx 3070, 8 gb vram
-stage 1: ~1 min 30 sec for 27 steps, ~3.35 sec per iteration
-stage 2: 11 seconds for 27 steps
-stage 3: 2 seconds for 40 steps
+stage 1: ~1 min 30 sec for 27 steps, ~3.35 sec per iteration \
+stage 2: 11 seconds for 27 steps \
+stage 3: 30 seconds for 40 steps
### To run the ui:
@@ -14,7 +14,7 @@ python run_ui.py
All the models are automatically downloaded.
-original readme:
+Original readme:
[](LICENSE)
[](LICENSE-MODEL)
diff --git a/deepfloyd_if/modules/stage_III_sd_x4.py b/deepfloyd_if/modules/stage_III_sd_x4.py
index e3af3a4..2f6d6ef 100644
--- a/deepfloyd_if/modules/stage_III_sd_x4.py
+++ b/deepfloyd_if/modules/stage_III_sd_x4.py
@@ -50,8 +50,6 @@ def embeddings_to_image(
sample_loop='ddpm', sample_timestep_respacing='75', guidance_scale=4.0, img_scale=4.0,
progress=True, seed=None, sample_fn=None, device=None, **kwargs):
- noise_level = kwargs.pop('noise_level', 20)
-
if sample_loop == 'ddpm':
self.model.scheduler = DDPMScheduler.from_config(self.model.scheduler.config)
else:
@@ -62,13 +60,9 @@ def embeddings_to_image(
self.model.set_progress_bar_config(disable=not progress)
generator = torch.manual_seed(seed)
- prompt = [prompt]
- low_res = low_res.repeat(batch_repeat, 1, 1, 1)
-
metadata = {
- 'image': low_res,
+ 'image': low_res, # shape: (1, 3, 256, 256)
'prompt': prompt,
- 'noise_level': noise_level,
'generator': generator,
'guidance_scale': guidance_scale,
'num_inference_steps': num_inference_steps,
diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py
index 7240559..77c1e74 100644
--- a/deepfloyd_if/pipelines/optimized_dream.py
+++ b/deepfloyd_if/pipelines/optimized_dream.py
@@ -67,7 +67,7 @@ def run_stage2(
sample_timestep_respacing=custom_timesteps_2,
seed=seed, device=device)
pil_images_II = model.to_images(stageII_generations, disable_watermark=disable_watermark)
- return pil_images_II
+ return stageII_generations, pil_images_II
def run_stage3(
@@ -89,7 +89,7 @@ def run_stage3(
guidance_scale=guidance_scale,
sample_timestep_respacing=num_inference_steps_2,
num_images_per_prompt=1,
- noise_level=60,
+ noise_level=20,
seed=seed, device=device)
pil_images_III = model.to_images(stageII_generations, disable_watermark=disable_watermark)
return pil_images_III
diff --git a/run_ui.py b/run_ui.py
index 01eae4d..9b0a6d8 100644
--- a/run_ui.py
+++ b/run_ui.py
@@ -102,7 +102,7 @@ def process_and_run_stage2(
global t5_embs, negative_t5_embs, images
print("Stage 2..")
switch_devices(stage=2)
- return run_stage2(
+ images, images_ret = run_stage2(
if_II,
t5_embs=t5_embs,
negative_t5_embs=negative_t5_embs,
@@ -113,6 +113,7 @@ def process_and_run_stage2(
num_inference_steps_2=num_inference_steps_2,
device=device
)
+ return images_ret
def process_and_run_stage3(
@@ -208,7 +209,7 @@ def create_ui(args):
'smart100',
'smart185',
],
- value="fast27",
+ value="smart50",
visible=True)
num_inference_steps_1 = gr.Slider(
label='Number of inference steps',
@@ -273,7 +274,7 @@ def create_ui(args):
minimum=1,
maximum=200,
step=1,
- value=40,
+ value=60,
visible=True)
with gr.Box():
with gr.Row():
From 103adb5b21b1513adb2bb706a05dd6e6a7e80698 Mon Sep 17 00:00:00 2001
From: neon
Date: Sun, 30 Apr 2023 17:11:03 +0200
Subject: [PATCH 09/10] added aspect ratio option
---
deepfloyd_if/pipelines/optimized_dream.py | 3 ++-
run_ui.py | 12 +++++++++---
2 files changed, 11 insertions(+), 4 deletions(-)
diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py
index 77c1e74..48ced8d 100644
--- a/deepfloyd_if/pipelines/optimized_dream.py
+++ b/deepfloyd_if/pipelines/optimized_dream.py
@@ -26,6 +26,7 @@ def run_stage1(
guidance_scale_1: float = 7.0,
custom_timesteps_1: str = 'smart100',
num_inference_steps_1: int = 100,
+ aspect_ratio='1:1'
):
run_garbage_collection()
@@ -37,7 +38,7 @@ def run_stage1(
batch_repeat=num_images,
guidance_scale=guidance_scale_1,
sample_timestep_respacing=custom_timesteps_1,
- seed=seed
+ seed=seed, aspect_ratio=aspect_ratio
)
pil_images_I = model.to_images(images, disable_watermark=True)
diff --git a/run_ui.py b/run_ui.py
index 9b0a6d8..e8663f7 100644
--- a/run_ui.py
+++ b/run_ui.py
@@ -70,7 +70,8 @@ def process_and_run_stage1(prompt,
num_images,
guidance_scale_1,
custom_timesteps_1,
- num_inference_steps_1):
+ num_inference_steps_1,
+ aspect_ratio):
global t5_embs, negative_t5_embs, images
print("Encoding prompts..")
switch_devices(stage=0)
@@ -88,7 +89,8 @@ def process_and_run_stage1(prompt,
num_images=num_images,
guidance_scale_1=guidance_scale_1,
custom_timesteps_1=custom_timesteps_1,
- num_inference_steps_1=num_inference_steps_1
+ num_inference_steps_1=num_inference_steps_1,
+ aspect_ratio=aspect_ratio
)
return images_ret
@@ -158,6 +160,9 @@ def create_ui(args):
placeholder='Enter a negative prompt',
elem_id='negative-prompt-text-input',
).style(container=False)
+ aspect_ratio_1 = gr.Radio(
+ ["16:9", "4:3", "1:1", "3:4", "9:16"], value="1:1", label="Aspect ratio"
+ ).style(container=False)
generate_button = gr.Button('Generate').style(full_width=False)
with gr.Column() as gallery_view:
@@ -290,7 +295,8 @@ def create_ui(args):
num_images,
guidance_scale_1,
custom_timesteps_1,
- num_inference_steps_1],
+ num_inference_steps_1,
+ aspect_ratio_1],
gallery
)
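As the hunk above shows, adding a control touches two places: the layout and the inputs list of the click handler, with positions matching the handler's signature. A stripped-down sketch with stub names, not run_ui.py's:

import gradio as gr

def run_stage1_stub(prompt, aspect_ratio):
    return f'{prompt} @ {aspect_ratio}'

with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Prompt')
    aspect_ratio_1 = gr.Radio(['16:9', '4:3', '1:1', '3:4', '9:16'],
                              value='1:1', label='Aspect ratio')
    out = gr.Textbox(label='Result')
    gr.Button('Generate').click(
        fn=run_stage1_stub,
        inputs=[prompt, aspect_ratio_1],  # order must match run_stage1_stub's parameters
        outputs=out,
    )

demo.launch()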
From 3e9808d97317c7b4bace22cc0a4fecfcdf56895d Mon Sep 17 00:00:00 2001
From: neon
Date: Mon, 1 May 2023 09:22:44 +0200
Subject: [PATCH 10/10] replaced aspect ratio with width/height params, fixed
number of images
---
deepfloyd_if/modules/base.py | 7 +++++--
deepfloyd_if/modules/stage_I.py | 6 ++++--
deepfloyd_if/pipelines/optimized_dream.py | 24 +++++++++++++----------
run_ui.py | 19 ++++++++++++------
4 files changed, 36 insertions(+), 20 deletions(-)
diff --git a/deepfloyd_if/modules/base.py b/deepfloyd_if/modules/base.py
index d08453b..6cad9f3 100644
--- a/deepfloyd_if/modules/base.py
+++ b/deepfloyd_if/modules/base.py
@@ -89,12 +89,13 @@ def embeddings_to_image(
support_noise_less_qsample_steps=0,
inpainting_mask=None,
device=None,
+ force_size=False,
**kwargs,
):
if device is None:
device = self.model.primary_device
self._clear_cache()
- image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale)
+ image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale, force_size=force_size)
diffusion = self.get_diffusion(sample_timestep_respacing)
bs_scale = 2 if positive_t5_embs is None else 3
@@ -331,11 +332,13 @@ def show(self, pil_images, nrow=None, size=10):
def _clear_cache(self):
self.model.cache = None
- def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale):
+ def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale, force_size=True):
if low_res is not None:
bs, c, h, w = low_res.shape
image_h, image_w = int((h * img_scale) // 32) * 32, int((w * img_scale // 32)) * 32
else:
+ if force_size:
+ return img_size[0], img_size[1]
scale_w, scale_h = aspect_ratio.split(':')
scale_w, scale_h = int(scale_w), int(scale_h)
coef = scale_w / scale_h
diff --git a/deepfloyd_if/modules/stage_I.py b/deepfloyd_if/modules/stage_I.py
index 8b3c62d..92883ac 100644
--- a/deepfloyd_if/modules/stage_I.py
+++ b/deepfloyd_if/modules/stage_I.py
@@ -35,7 +35,8 @@ def to(self, x, stage=1, secondary_device=torch.device("cpu")): # 0, 1, 2, 3
def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None,
batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25,
sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0,
- aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs):
+ img_size=(64, 64), aspect_ratio='1:1', progress=True, seed=None, sample_fn=None,
+ force_size=False, **kwargs):
return super().embeddings_to_image(
t5_embs=t5_embs,
style_t5_embs=style_t5_embs,
@@ -47,11 +48,12 @@ def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None
sample_loop=sample_loop,
sample_timestep_respacing=sample_timestep_respacing,
guidance_scale=guidance_scale,
- img_size=64,
+ img_size=img_size,
aspect_ratio=aspect_ratio,
progress=progress,
seed=seed,
sample_fn=sample_fn,
positive_mixer=positive_mixer,
+ force_size=force_size,
**kwargs
)
diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py
index 48ced8d..2267db2 100644
--- a/deepfloyd_if/pipelines/optimized_dream.py
+++ b/deepfloyd_if/pipelines/optimized_dream.py
@@ -26,23 +26,27 @@ def run_stage1(
guidance_scale_1: float = 7.0,
custom_timesteps_1: str = 'smart100',
num_inference_steps_1: int = 100,
- aspect_ratio='1:1'
+ aspect_ratio='1:1',
+ img_size=(64, 64)
):
run_garbage_collection()
if custom_timesteps_1 == "none":
custom_timesteps_1 = str(num_inference_steps_1)
- images, _ = model.embeddings_to_image(t5_embs=t5_embs,
- negative_t5_embs=negative_t5_embs,
- batch_repeat=num_images,
- guidance_scale=guidance_scale_1,
- sample_timestep_respacing=custom_timesteps_1,
- seed=seed, aspect_ratio=aspect_ratio
- )
- pil_images_I = model.to_images(images, disable_watermark=True)
+ ret_images1, ret_images2 = [], []
+ for _ in range(num_images):
+ images, _ = model.embeddings_to_image(t5_embs=t5_embs,
+ negative_t5_embs=negative_t5_embs,
+ guidance_scale=guidance_scale_1, img_size=img_size,
+ sample_timestep_respacing=custom_timesteps_1,
+ seed=seed, aspect_ratio=aspect_ratio, force_size=True
+ )
+ pil_images_I = model.to_images(images, disable_watermark=True)
+ ret_images1.append(pil_images_I[0])
+ ret_images2.append(images[0])
- return images, pil_images_I
+ return ret_images2, ret_images1
def run_stage2(
diff --git a/run_ui.py b/run_ui.py
index e8663f7..ff63bca 100644
--- a/run_ui.py
+++ b/run_ui.py
@@ -71,7 +71,9 @@ def process_and_run_stage1(prompt,
guidance_scale_1,
custom_timesteps_1,
num_inference_steps_1,
- aspect_ratio):
+ # aspect_ratio,
+ width,
+ height):
global t5_embs, negative_t5_embs, images
print("Encoding prompts..")
switch_devices(stage=0)
@@ -90,7 +92,8 @@ def process_and_run_stage1(prompt,
guidance_scale_1=guidance_scale_1,
custom_timesteps_1=custom_timesteps_1,
num_inference_steps_1=num_inference_steps_1,
- aspect_ratio=aspect_ratio
+ # aspect_ratio=aspect_ratio,
+ img_size=(width, height)
)
return images_ret
@@ -160,9 +163,11 @@ def create_ui(args):
placeholder='Enter a negative prompt',
elem_id='negative-prompt-text-input',
).style(container=False)
- aspect_ratio_1 = gr.Radio(
- ["16:9", "4:3", "1:1", "3:4", "9:16"], value="1:1", label="Aspect ratio"
- ).style(container=False)
+ width = gr.Slider(32, 128, value=64, step=8, label="Width").style(container=False)
+ height = gr.Slider(32, 128, value=64, step=8, label="Height").style(container=False)
+ # aspect_ratio_1 = gr.Radio(
+ # ["16:9", "4:3", "1:1", "3:4", "9:16"], value="1:1", label="Aspect ratio"
+ # ).style(container=False)
generate_button = gr.Button('Generate').style(full_width=False)
with gr.Column() as gallery_view:
@@ -296,7 +301,9 @@ def create_ui(args):
guidance_scale_1,
custom_timesteps_1,
num_inference_steps_1,
- aspect_ratio_1],
+ # aspect_ratio_1,
+ width,
+ height],
gallery
)
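A sketch of the size-selection rule this final patch introduces: with force_size set, the explicit (width, height) from the sliders is used as-is; otherwise an aspect-ratio branch (approximated here, since its tail is not shown in the hunk) derives dimensions rounded down to multiples of 32. Hypothetical helper, not the repository's _get_image_sizes:

def get_image_size(img_size, aspect_ratio='1:1', base=64, force_size=False):
    if force_size:
        return img_size  # (width, height) straight from the UI sliders
    scale_w, scale_h = (int(v) for v in aspect_ratio.split(':'))
    coef = scale_w / scale_h
    if coef >= 1:
        return int(base * coef) // 32 * 32, base
    return base, int(base / coef) // 32 * 32

print(get_image_size((96, 64), force_size=True))      # (96, 64)
print(get_image_size((64, 64), aspect_ratio='16:9'))  # (96, 64)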