import argparse
import hashlib
import inspect
import itertools
import math
import os
import random
import re
from pathlib import Path
from typing import Optional, List, Literal

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import wandb
import fire

from lora_diffusion import (
    PivotalTuningDatasetCapation,
    extract_lora_ups_down,
    inject_trainable_lora,
    inject_trainable_lora_extended,
    inspect_lora,
    save_lora_weight,
    save_all,
    prepare_clip_model_sets,
    evaluate_pipe,
    UNET_EXTENDED_TARGET_REPLACE,
)

def get_models(
    pretrained_model_name_or_path,
    pretrained_vae_name_or_path,
    revision,
    placeholder_tokens: List[str],
    initializer_tokens: List[str],
    device="cuda:0",
):

    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=revision,
    )

    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )

    placeholder_token_ids = []

    for token, init_tok in zip(placeholder_tokens, initializer_tokens):
        num_added_tokens = tokenizer.add_tokens(token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

        placeholder_token_id = tokenizer.convert_tokens_to_ids(token)

        placeholder_token_ids.append(placeholder_token_id)

        # Resize the embedding table to fit the new token, then initialize its embedding.
        text_encoder.resize_token_embeddings(len(tokenizer))
        token_embeds = text_encoder.get_input_embeddings().weight.data
        if init_tok.startswith("<rand"):
            # <rand-sigma>, e.g. <rand-0.5>: initialize with Gaussian noise of the given sigma.
            sigma_val = float(re.findall(r"<rand-(.*)>", init_tok)[0])

            token_embeds[placeholder_token_id] = (
                torch.randn_like(token_embeds[0]) * sigma_val
            )
            print(
                f"Initialized {token} with random noise (sigma={sigma_val}), "
                f"empirically {token_embeds[placeholder_token_id].mean().item():.3f} "
                f"+- {token_embeds[placeholder_token_id].std().item():.3f}"
            )
            print(f"Norm : {token_embeds[placeholder_token_id].norm():.4f}")

        elif init_tok == "<zero>":
            token_embeds[placeholder_token_id] = torch.zeros_like(token_embeds[0])
        else:
            # Initialize from an existing token; it must map to exactly one token id.
            token_ids = tokenizer.encode(init_tok, add_special_tokens=False)
            if len(token_ids) > 1:
                raise ValueError("The initializer token must be a single token.")

            initializer_token_id = token_ids[0]
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

    vae = AutoencoderKL.from_pretrained(
        pretrained_vae_name_or_path or pretrained_model_name_or_path,
        subfolder=None if pretrained_vae_name_or_path else "vae",
        revision=None if pretrained_vae_name_or_path else revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        revision=revision,
    )

    return (
        text_encoder.to(device),
        vae.to(device),
        unet.to(device),
        tokenizer,
        placeholder_token_ids,
    )

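# Illustrative call (values are assumptions, shown only to document the token formats):
# placeholder tokens are the new tokens to be learned, and each initializer token
# controls how the corresponding embedding starts out. "<rand-sigma>" draws Gaussian
# noise with that sigma, "<zero>" starts from zeros, and any other value must encode
# to a single existing token whose embedding is copied.
#
#   text_encoder, vae, unet, tokenizer, ids = get_models(
#       "runwayml/stable-diffusion-v1-5",
#       None,
#       None,
#       placeholder_tokens=["<s1>", "<s2>"],
#       initializer_tokens=["<rand-0.017>", "man"],
#       device="cuda:0",
#   )
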
@torch.no_grad()
def text2img_dataloader(
    train_dataset,
    train_batch_size,
    tokenizer,
    vae,
    text_encoder,
    cached_latents: bool = False,
):

    if cached_latents:
        # Pre-encode every training image into VAE latents once, so the VAE can be
        # dropped from memory during training.
        cached_latents_dataset = []
        for idx in tqdm(range(len(train_dataset))):
            batch = train_dataset[idx]
            latents = vae.encode(
                batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
            ).latent_dist.sample()
            latents = latents * 0.18215
            batch["instance_images"] = latents.squeeze(0)
            cached_latents_dataset.append(batch)

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }

        if examples[0].get("mask", None) is not None:
            batch["mask"] = torch.stack([example["mask"] for example in examples])

        return batch

    if cached_latents:
        train_dataloader = torch.utils.data.DataLoader(
            cached_latents_dataset,
            batch_size=train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
        )
        print("PTI : Using cached latents.")
    else:
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
        )

    return train_dataloader

def inpainting_dataloader(
    train_dataset, train_batch_size, tokenizer, vae, text_encoder
):
    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        mask_values = [example["instance_masks"] for example in examples]
        masked_image_values = [
            example["instance_masked_images"] for example in examples
        ]

        # Concatenate class and instance examples (prior preservation) into a single
        # batch to avoid doing two forward passes.
        if examples[0].get("class_prompt_ids", None) is not None:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            mask_values += [example["class_masks"] for example in examples]
            masked_image_values += [
                example["class_masked_images"] for example in examples
            ]

        pixel_values = (
            torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
        )
        mask_values = (
            torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
        )
        masked_image_values = (
            torch.stack(masked_image_values)
            .to(memory_format=torch.contiguous_format)
            .float()
        )

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "mask_values": mask_values,
            "masked_image_values": masked_image_values,
        }

        if examples[0].get("mask", None) is not None:
            batch["mask"] = torch.stack([example["mask"] for example in examples])

        return batch

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

    return train_dataloader

def loss_step(
    batch,
    unet,
    vae,
    text_encoder,
    scheduler,
    train_inpainting=False,
    t_multiplier=1.0,
    mixed_precision=False,
    mask_temperature=1.0,
    cached_latents: bool = False,
):
    weight_dtype = torch.float32
    if not cached_latents:
        latents = vae.encode(
            batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
        ).latent_dist.sample()
        latents = latents * 0.18215

        if train_inpainting:
            masked_image_latents = vae.encode(
                batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
            ).latent_dist.sample()
            masked_image_latents = masked_image_latents * 0.18215
            mask = F.interpolate(
                batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
                scale_factor=1 / 8,
            )
    else:
        # Latents were pre-computed by the dataloader, so "pixel_values" already holds them.
        latents = batch["pixel_values"]

        if train_inpainting:
            # Unused in practice: cached latents are disabled for inpainting in train().
            masked_image_latents = batch["masked_image_latents"]
            mask = batch["mask_values"]

    noise = torch.randn_like(latents)
    bsz = latents.shape[0]

    # Sample a random timestep for each image.
    timesteps = torch.randint(
        0,
        int(scheduler.config.num_train_timesteps * t_multiplier),
        (bsz,),
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (forward diffusion process).
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    if train_inpainting:
        latent_model_input = torch.cat(
            [noisy_latents, mask, masked_image_latents], dim=1
        )
    else:
        latent_model_input = noisy_latents

    if mixed_precision:
        with torch.cuda.amp.autocast():
            encoder_hidden_states = text_encoder(
                batch["input_ids"].to(text_encoder.device)
            )[0]

            model_pred = unet(
                latent_model_input, timesteps, encoder_hidden_states
            ).sample
    else:
        encoder_hidden_states = text_encoder(
            batch["input_ids"].to(text_encoder.device)
        )[0]

        model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample

    if scheduler.config.prediction_type == "epsilon":
        target = noise
    elif scheduler.config.prediction_type == "v_prediction":
        target = scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

    if batch.get("mask", None) is not None:
        # Weight the loss by the (optionally sharpened) mask so masked regions dominate.
        mask = (
            batch["mask"]
            .to(model_pred.device)
            .reshape(
                model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
            )
        )
        # Resize the mask down to the latent resolution of the prediction.
        mask = F.interpolate(
            mask.float(),
            size=model_pred.shape[-2:],
            mode="nearest",
        )

        mask = (mask + 0.01).pow(mask_temperature)

        mask = mask / mask.max()

        model_pred = model_pred * mask

        target = target * mask

    loss = (
        F.mse_loss(model_pred.float(), target.float(), reduction="none")
        .mean([1, 2, 3])
        .mean()
    )

    return loss

def train_inversion(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps: int,
    scheduler,
    index_no_updates,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path: str,
    tokenizer,
    lr_scheduler,
    test_image_path: str,
    cached_latents: bool,
    accum_iter: int = 1,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
    mixed_precision: bool = False,
    clip_ti_decay: bool = True,
):

    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    # Keep a copy of the original embeddings so every token except the placeholders
    # can be restored after each optimizer step.
    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    index_updates = ~index_no_updates
    loss_sum = 0.0

    for epoch in range(math.ceil(num_steps / len(dataloader))):
        unet.eval()
        text_encoder.train()
        for batch in dataloader:

            lr_scheduler.step()

            with torch.set_grad_enabled(True):
                loss = (
                    loss_step(
                        batch,
                        unet,
                        vae,
                        text_encoder,
                        scheduler,
                        train_inpainting=train_inpainting,
                        mixed_precision=mixed_precision,
                        cached_latents=cached_latents,
                    )
                    / accum_iter
                )

                loss.backward()
                loss_sum += loss.detach().item()

                if global_step % accum_iter == 0:
                    # Mean gradient norm of the placeholder embeddings (debugging aid).
                    print(
                        text_encoder.get_input_embeddings()
                        .weight.grad[index_updates, :]
                        .norm(dim=-1)
                        .mean()
                    )
                    optimizer.step()
                    optimizer.zero_grad()

                    with torch.no_grad():

                        # Normalize the learned embeddings and decay their norm towards
                        # 0.4, which keeps the textual-inversion tokens well-behaved.
                        if clip_ti_decay:
                            pre_norm = (
                                text_encoder.get_input_embeddings()
                                .weight[index_updates, :]
                                .norm(dim=-1, keepdim=True)
                            )

                            lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
                            text_encoder.get_input_embeddings().weight[
                                index_updates
                            ] = F.normalize(
                                text_encoder.get_input_embeddings().weight[
                                    index_updates, :
                                ],
                                dim=-1,
                            ) * (
                                pre_norm + lambda_ * (0.4 - pre_norm)
                            )
                            print(pre_norm)

                        current_norm = (
                            text_encoder.get_input_embeddings()
                            .weight[index_updates, :]
                            .norm(dim=-1)
                        )

                        # Restore every embedding that should not be trained.
                        text_encoder.get_input_embeddings().weight[
                            index_no_updates
                        ] = orig_embeds_params[index_no_updates]

                        print(f"Current Norm : {current_norm}")

                global_step += 1
                progress_bar.update(1)

                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                }
                progress_bar.set_postfix(**logs)

            if global_step % save_steps == 0:
                save_all(
                    unet=unet,
                    text_encoder=text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_inv_{global_step}.safetensors"
                    ),
                    save_lora=False,
                )
                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # Open all training images for CLIP-based evaluation.
                        images = []
                        for file in os.listdir(test_image_path):
                            if (
                                file.lower().endswith(".png")
                                or file.lower().endswith(".jpg")
                                or file.lower().endswith(".jpeg")
                            ):
                                images.append(
                                    Image.open(os.path.join(test_image_path, file))
                                )

                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            if global_step >= num_steps:
                return

def perform_tuning(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps,
    scheduler,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path,
    lr_scheduler_lora,
    lora_unet_target_modules,
    lora_clip_target_modules,
    mask_temperature,
    out_name: str,
    tokenizer,
    test_image_path: str,
    cached_latents: bool,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
):

    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    weight_dtype = torch.float16

    unet.train()
    text_encoder.train()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    loss_sum = 0.0

    for epoch in range(math.ceil(num_steps / len(dataloader))):
        for batch in dataloader:
            lr_scheduler_lora.step()

            optimizer.zero_grad()

            loss = loss_step(
                batch,
                unet,
                vae,
                text_encoder,
                scheduler,
                train_inpainting=train_inpainting,
                t_multiplier=0.8,
                mixed_precision=True,
                mask_temperature=mask_temperature,
                cached_latents=cached_latents,
            )
            loss_sum += loss.detach().item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
            )
            optimizer.step()
            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler_lora.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            global_step += 1

            if global_step % save_steps == 0:
                save_all(
                    unet,
                    text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_{global_step}.safetensors"
                    ),
                    target_replace_module_text=lora_clip_target_modules,
                    target_replace_module_unet=lora_unet_target_modules,
                )
                # Report how far the LoRA weights have moved from initialization.
                moved = (
                    torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
                    .mean()
                    .item()
                )

                print("LORA Unet Moved", moved)
                moved = (
                    torch.tensor(
                        list(itertools.chain(*inspect_lora(text_encoder).values()))
                    )
                    .mean()
                    .item()
                )

                print("LORA CLIP Moved", moved)

                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # Open all training images for CLIP-based evaluation.
                        images = []
                        for file in os.listdir(test_image_path):
                            if file.endswith(".png") or file.endswith(".jpg"):
                                images.append(
                                    Image.open(os.path.join(test_image_path, file))
                                )

                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            if global_step >= num_steps:
                break

    save_all(
        unet,
        text_encoder,
        placeholder_token_ids=placeholder_token_ids,
        placeholder_tokens=placeholder_tokens,
        save_path=os.path.join(save_path, f"{out_name}.safetensors"),
        target_replace_module_text=lora_clip_target_modules,
        target_replace_module_unet=lora_unet_target_modules,
    )

def train(
    instance_data_dir: str,
    pretrained_model_name_or_path: str,
    output_dir: str,
    train_text_encoder: bool = True,
    pretrained_vae_name_or_path: Optional[str] = None,
    revision: Optional[str] = None,
    perform_inversion: bool = True,
    use_template: Literal[None, "object", "style"] = None,
    train_inpainting: bool = False,
    placeholder_tokens: str = "",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: Optional[str] = None,
    seed: int = 42,
    resolution: int = 512,
    color_jitter: bool = True,
    train_batch_size: int = 1,
    sample_batch_size: int = 1,
    max_train_steps_tuning: int = 1000,
    max_train_steps_ti: int = 1000,
    save_steps: int = 100,
    gradient_accumulation_steps: int = 4,
    gradient_checkpointing: bool = False,
    lora_rank: int = 4,
    lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
    lora_clip_target_modules={"CLIPAttention"},
    lora_dropout_p: float = 0.0,
    lora_scale: float = 1.0,
    use_extended_lora: bool = False,
    clip_ti_decay: bool = True,
    learning_rate_unet: float = 1e-4,
    learning_rate_text: float = 1e-5,
    learning_rate_ti: float = 5e-4,
    continue_inversion: bool = False,
    continue_inversion_lr: Optional[float] = None,
    use_face_segmentation_condition: bool = False,
    cached_latents: bool = True,
    use_mask_captioned_data: bool = False,
    mask_temperature: float = 1.0,
    scale_lr: bool = False,
    lr_scheduler: str = "linear",
    lr_warmup_steps: int = 0,
    lr_scheduler_lora: str = "linear",
    lr_warmup_steps_lora: int = 0,
    weight_decay_ti: float = 0.00,
    weight_decay_lora: float = 0.001,
    use_8bit_adam: bool = False,
    device="cuda:0",
    extra_args: Optional[dict] = None,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    wandb_project_name: str = "new_pti_project",
    wandb_entity: str = "new_pti_entity",
    proxy_token: str = "person",
    enable_xformers_memory_efficient_attention: bool = False,
    out_name: str = "final_lora",
):
    torch.manual_seed(seed)

    if log_wandb:
        wandb.init(
            project=wandb_project_name,
            entity=wandb_entity,
            name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}",
            reinit=True,
            config={
                **(extra_args if extra_args is not None else {}),
            },
        )

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    if len(placeholder_tokens) == 0:
        placeholder_tokens = []
        print("PTI : Placeholder Tokens not given, using null token")
    else:
        placeholder_tokens = placeholder_tokens.split("|")

        assert (
            sorted(placeholder_tokens) == placeholder_tokens
        ), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}"

    if initializer_tokens is None:
        print("PTI : Initializer Tokens not given, doing random inits")
        initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
    else:
        initializer_tokens = initializer_tokens.split("|")

    assert len(initializer_tokens) == len(
        placeholder_tokens
    ), "Unequal Initializer tokens for Placeholder tokens."

    # Token used in evaluation prompts: the proxy token if given, otherwise the
    # concatenation of the initializer tokens.
    if proxy_token is not None:
        class_token = proxy_token
    else:
        class_token = "".join(initializer_tokens)

    if placeholder_token_at_data is not None:
        tok, pat = placeholder_token_at_data.split("|")
        token_map = {tok: pat}
    else:
        token_map = {"DUMMY": "".join(placeholder_tokens)}

    print("PTI : Placeholder Tokens", placeholder_tokens)
    print("PTI : Initializer Tokens", initializer_tokens)

    # Load tokenizer, text encoder, VAE and UNet, and register the placeholder tokens.
    text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
        pretrained_model_name_or_path,
        pretrained_vae_name_or_path,
        revision,
        placeholder_tokens,
        initializer_tokens,
        device=device,
    )

    noise_scheduler = DDPMScheduler.from_config(
        pretrained_model_name_or_path, subfolder="scheduler"
    )

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if enable_xformers_memory_efficient_attention:
        from diffusers.utils.import_utils import is_xformers_available

        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if scale_lr:
        unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
        text_encoder_lr = (
            learning_rate_text * gradient_accumulation_steps * train_batch_size
        )
        ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size
    else:
        unet_lr = learning_rate_unet
        text_encoder_lr = learning_rate_text
        ti_lr = learning_rate_ti

    train_dataset = PivotalTuningDatasetCapation(
        instance_data_root=instance_data_dir,
        token_map=token_map,
        use_template=use_template,
        tokenizer=tokenizer,
        size=resolution,
        color_jitter=color_jitter,
        use_face_segmentation_condition=use_face_segmentation_condition,
        use_mask_captioned_data=use_mask_captioned_data,
        train_inpainting=train_inpainting,
    )

    train_dataset.blur_amount = 200

    if train_inpainting:
        assert not cached_latents, "Cached latents not supported for inpainting"

        train_dataloader = inpainting_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )
    else:
        train_dataloader = text2img_dataloader(
            train_dataset,
            train_batch_size,
            tokenizer,
            vae,
            text_encoder,
            cached_latents=cached_latents,
        )

    # Boolean mask over the vocabulary: True for tokens whose embeddings stay frozen,
    # False for the placeholder tokens that will be optimized.
    index_no_updates = torch.arange(len(tokenizer)) != -1

    for tok_id in placeholder_token_ids:
        index_no_updates[tok_id] = False

    unet.requires_grad_(False)
    vae.requires_grad_(False)

    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    for param in params_to_freeze:
        param.requires_grad = False

    if cached_latents:
        vae = None

    # STEP 1 : textual inversion of the placeholder tokens.
    if perform_inversion:
        ti_optimizer = optim.AdamW(
            text_encoder.get_input_embeddings().parameters(),
            lr=ti_lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=weight_decay_ti,
        )

        lr_scheduler = get_scheduler(
            lr_scheduler,
            optimizer=ti_optimizer,
            num_warmup_steps=lr_warmup_steps,
            num_training_steps=max_train_steps_ti,
        )

        train_inversion(
            unet,
            vae,
            text_encoder,
            train_dataloader,
            max_train_steps_ti,
            cached_latents=cached_latents,
            accum_iter=gradient_accumulation_steps,
            scheduler=noise_scheduler,
            index_no_updates=index_no_updates,
            optimizer=ti_optimizer,
            lr_scheduler=lr_scheduler,
            save_steps=save_steps,
            placeholder_tokens=placeholder_tokens,
            placeholder_token_ids=placeholder_token_ids,
            save_path=output_dir,
            test_image_path=instance_data_dir,
            log_wandb=log_wandb,
            wandb_log_prompt_cnt=wandb_log_prompt_cnt,
            class_token=class_token,
            train_inpainting=train_inpainting,
            mixed_precision=False,
            tokenizer=tokenizer,
            clip_ti_decay=clip_ti_decay,
        )

        del ti_optimizer

    # STEP 2 : fine-tune the UNet (and optionally the text encoder) with LoRA.
    if not use_extended_lora:
        unet_lora_params, _ = inject_trainable_lora(
            unet,
            r=lora_rank,
            target_replace_module=lora_unet_target_modules,
            dropout_p=lora_dropout_p,
            scale=lora_scale,
        )
    else:
        print("PTI : USING EXTENDED UNET!!!")
        lora_unet_target_modules = (
            lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
        )
        print("PTI : Will replace modules: ", lora_unet_target_modules)

        unet_lora_params, _ = inject_trainable_lora_extended(
            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
        )
    print(f"PTI : has {len(unet_lora_params)} lora")

    print("PTI : Before training:")
    inspect_lora(unet)

    params_to_optimize = [
        {"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
    ]

    text_encoder.requires_grad_(False)

    if continue_inversion:
        # Keep optimizing the placeholder embeddings during LoRA tuning.
        params_to_optimize += [
            {
                "params": text_encoder.get_input_embeddings().parameters(),
                "lr": continue_inversion_lr
                if continue_inversion_lr is not None
                else ti_lr,
            }
        ]
        text_encoder.requires_grad_(True)
        params_to_freeze = itertools.chain(
            text_encoder.text_model.encoder.parameters(),
            text_encoder.text_model.final_layer_norm.parameters(),
            text_encoder.text_model.embeddings.position_embedding.parameters(),
        )
        for param in params_to_freeze:
            param.requires_grad = False
    else:
        text_encoder.requires_grad_(False)
    if train_text_encoder:
        text_encoder_lora_params, _ = inject_trainable_lora(
            text_encoder,
            target_replace_module=lora_clip_target_modules,
            r=lora_rank,
        )
        params_to_optimize += [
            {
                "params": itertools.chain(*text_encoder_lora_params),
                "lr": text_encoder_lr,
            }
        ]
        inspect_lora(text_encoder)

    lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)

    unet.train()
    if train_text_encoder:
        text_encoder.train()

    train_dataset.blur_amount = 70

    lr_scheduler_lora = get_scheduler(
        lr_scheduler_lora,
        optimizer=lora_optimizers,
        num_warmup_steps=lr_warmup_steps_lora,
        num_training_steps=max_train_steps_tuning,
    )

    perform_tuning(
        unet,
        vae,
        text_encoder,
        train_dataloader,
        max_train_steps_tuning,
        cached_latents=cached_latents,
        scheduler=noise_scheduler,
        optimizer=lora_optimizers,
        save_steps=save_steps,
        placeholder_tokens=placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        save_path=output_dir,
        lr_scheduler_lora=lr_scheduler_lora,
        lora_unet_target_modules=lora_unet_target_modules,
        lora_clip_target_modules=lora_clip_target_modules,
        mask_temperature=mask_temperature,
        tokenizer=tokenizer,
        out_name=out_name,
        test_image_path=instance_data_dir,
        log_wandb=log_wandb,
        wandb_log_prompt_cnt=wandb_log_prompt_cnt,
        class_token=class_token,
        train_inpainting=train_inpainting,
    )

def main():
    fire.Fire(train)
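
# Illustrative invocation (the script name, model id and paths below are assumptions,
# not part of this repository's documented interface). Fire exposes every keyword
# argument of `train` as a command-line flag, e.g.:
#
#   python this_script.py \
#       --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#       --instance_data_dir="./data/my_subject" \
#       --output_dir="./output" \
#       --use_template="object" \
#       --placeholder_tokens="<s1>|<s2>" \
#       --max_train_steps_ti=1000 \
#       --max_train_steps_tuning=1000 \
#       --save_steps=100 \
#       --lora_rank=4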