from typing import List

import torch
from safetensors import safe_open
from diffusers import StableDiffusionPipeline

from .lora import (
    monkeypatch_or_replace_safeloras,
    apply_learned_embed_in_clip,
    set_lora_diag,
    parse_safeloras_embeds,
)


def lora_join(lora_safetensors: list):
    """Join several opened LoRA safetensors handles into one tensor dict + metadata."""
    metadatas = [dict(safelora.metadata()) for safelora in lora_safetensors]
    _total_metadata = {}
    total_metadata = {}
    total_tensor = {}
    total_rank = 0
    ranklist = []

    # Collect the rank of each LoRA; every weight inside one file must share a rank.
    for _metadata in metadatas:
        rankset = []
        for k, v in _metadata.items():
            if k.endswith("rank"):
                rankset.append(int(v))

        assert len(set(rankset)) <= 1, "Rank should be the same per model"
        if len(rankset) == 0:
            rankset = [0]

        total_rank += rankset[0]
        _total_metadata.update(_metadata)
        ranklist.append(rankset[0])

    # Carry over all metadata except the embedding markers; embeddings are re-keyed below.
    for k, v in _total_metadata.items():
        if v != "<embed>":
            total_metadata[k] = v

    tensorkeys = set()
    for safelora in lora_safetensors:
        tensorkeys.update(safelora.keys())

    for keys in tensorkeys:
        if keys.startswith("text_encoder") or keys.startswith("unet"):
            tensorset = [safelora.get_tensor(keys) for safelora in lora_safetensors]

            is_down = keys.endswith("down")

            # Down-projection weights hold the rank on dim 0, up-projections on dim 1,
            # so concatenating along those dims yields a single adapter of rank `total_rank`.
            if is_down:
                _tensor = torch.cat(tensorset, dim=0)
                assert _tensor.shape[0] == total_rank
            else:
                _tensor = torch.cat(tensorset, dim=1)
                assert _tensor.shape[1] == total_rank

            total_tensor[keys] = _tensor
            keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
            total_metadata[keys_rank] = str(total_rank)

    # Re-key learned embeddings as <s{model}-{token}> so tokens from different LoRAs cannot collide.
    token_size_list = []
    for idx, safelora in enumerate(lora_safetensors):
        tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
        for jdx, token in enumerate(sorted(tokens)):
            total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
            total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"

            print(f"Embedding {token} replaced with <s{idx}-{jdx}>")

        token_size_list.append(len(tokens))

    return total_tensor, total_metadata, ranklist, token_size_list
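
# Minimal usage sketch (not executed here): joining two LoRA files by hand.
# The .safetensors paths are placeholders; LoRAManager below normally does this
# for you via safe_open handles.
#
#   loras = [
#       safe_open(p, framework="pt", device="cpu")
#       for p in ("lora_a.safetensors", "lora_b.safetensors")
#   ]
#   tensors, metadata, ranklist, token_sizes = lora_join(loras)
#   # `tensors` now holds the rank-concatenated up/down weights plus re-keyed embeddings.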


class DummySafeTensorObject:
    """Minimal stand-in for a `safe_open` handle, backed by an in-memory dict."""

    def __init__(self, tensor: dict, metadata):
        self.tensor = tensor
        self._metadata = metadata

    def keys(self):
        return self.tensor.keys()

    def metadata(self):
        return self._metadata

    def get_tensor(self, key):
        return self.tensor[key]
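
# Interface note / sketch: DummySafeTensorObject mirrors only the parts of the
# `safe_open` handle API that lora_join and the .lora helpers rely on
# (keys(), metadata(), get_tensor()), e.g.:
#
#   dummy = DummySafeTensorObject({"<s0-0>": torch.zeros(768)}, {"<s0-0>": "<embed>"})
#   dummy.get_tensor("<s0-0>")  # -> the zeros tensor above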


class LoRAManager:
    def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):
        self.lora_paths_list = lora_paths_list
        self.pipe = pipe
        self._setup()

    def _setup(self):
        self._lora_safetensors = [
            safe_open(path, framework="pt", device="cpu")
            for path in self.lora_paths_list
        ]

        (
            total_tensor,
            total_metadata,
            self.ranklist,
            self.token_size_list,
        ) = lora_join(self._lora_safetensors)

        self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)

        # Patch the joined LoRA weights into the pipeline's UNet and text encoder.
        monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)

        # Register the re-keyed textual-inversion embeddings (<s{idx}-{jdx}>) with CLIP.
        tok_dict = parse_safeloras_embeds(self.total_safelora)
        apply_learned_embed_in_clip(
            tok_dict,
            self.pipe.text_encoder,
            self.pipe.tokenizer,
            token=None,
            idempotent=True,
        )

    def tune(self, scales):
        """Set a per-LoRA scale by writing it into the diagonal of the joined adapter."""
        assert len(scales) == len(
            self.ranklist
        ), "Scale list should be the same length as ranklist"

        # Expand each scale to cover the rank slots contributed by its LoRA.
        diags = []
        for scale, rank in zip(scales, self.ranklist):
            diags = diags + [scale] * rank

        set_lora_diag(self.pipe.unet, torch.tensor(diags))

    def prompt(self, prompt):
        """Expand placeholders <1>, <2>, ... into the re-keyed embedding tokens of each LoRA."""
        if prompt is not None:
            for idx, tok_size in enumerate(self.token_size_list):
                prompt = prompt.replace(
                    f"<{idx + 1}>",
                    "".join([f"<s{idx}-{jdx}>" for jdx in range(tok_size)]),
                )

        return prompt
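

# A hedged end-to-end sketch: the base model id and the LoRA paths below are
# placeholders, not files or defaults shipped with this repo.
if __name__ == "__main__":
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    manager = LoRAManager(["lora_a.safetensors", "lora_b.safetensors"], pipe)

    # Weight the first LoRA at 0.8 and the second at 0.4.
    manager.tune([0.8, 0.4])

    # "<1>" / "<2>" expand to the learned tokens of the first / second LoRA.
    text = manager.prompt("a photo of <1> in the style of <2>")
    image = pipe(text, num_inference_steps=30).images[0]
    image.save("joined_lora_sample.png")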