Commit 017e384: cleanup

daanelson committed Oct 31, 2024
1 parent c8f9712
Showing 8 changed files with 18 additions and 79 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/push-lora.yaml
@@ -50,7 +50,7 @@ jobs:
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
         run: |
-          cog-safe-push -vv ${{ github.event.inputs.no_push_lora == 'true' && '--no-push' || '' }} --config=cog-safe-push-dev-lora.yaml
+          cog-safe-push -vv ${{ github.event.inputs.no_push_lora == 'true' && '--no-push' || '' }} --config=safe-push-configs/cog-safe-push-dev-lora.yaml
       - name: Select schnell-lora
         run: |
@@ -61,4 +61,4 @@ jobs:
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
         run: |
-          cog-safe-push -vv ${{ github.event.inputs.no_push_lora == 'true' && '--no-push' || '' }} --config=cog-safe-push-schnell-lora.yaml
+          cog-safe-push -vv ${{ github.event.inputs.no_push_lora == 'true' && '--no-push' || '' }} --config=safe-push-configs/cog-safe-push-schnell-lora.yaml
51 changes: 2 additions & 49 deletions .github/workflows/push.yaml
@@ -44,7 +44,7 @@ jobs:
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
         run: |
-          cog-safe-push -vv ${{ github.event.inputs.no_push == 'true' && '--no-push' || '' }} --config=cog-safe-push-schnell.yaml
+          cog-safe-push -vv ${{ github.event.inputs.no_push == 'true' && '--no-push' || '' }} --config=safe-push-configs/cog-safe-push-schnell.yaml
       - name: Select dev
         run: |
@@ -55,51 +55,4 @@ jobs:
           ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
           REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
         run: |
-          cog-safe-push -vv ${{ github.event.inputs.no_push == 'true' && '--no-push' || '' }} --config=cog-safe-push-dev.yaml
-  cog-safe-push-lora:
-    runs-on: ubuntu-latest-4-cores
-    if: github.event.workflow == 'Push LORA Models' # Only run for the LORA workflow
-
-    steps:
-      - uses: actions/checkout@v3
-
-      - name: Set up Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: "3.12"
-
-      - name: Install Cog
-        run: |
-          sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)"
-          sudo chmod +x /usr/local/bin/cog
-      - name: cog login
-        run: |
-          echo ${{ secrets.COG_TOKEN }} | cog login --token-stdin
-      - name: Install cog-safe-push
-        run: |
-          pip install git+https://github.com/replicate/cog-safe-push.git
-      - name: Select dev-lora
-        run: |
-          ./script/select.sh dev-lora
-      - name: Run cog-safe-push on flux-dev-lora
-        env:
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
-        run: |
-          cog-safe-push -vv ${{ github.event.inputs.no_push_lora == 'true' && '--no-push' || '' }} --config=cog-safe-push-dev-lora.yaml
-      - name: Select schnell-lora
-        run: |
-          ./script/select.sh schnell-lora
-      - name: Run cog-safe-push on flux-schnell-lora
-        env:
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
-        run: |
-          cog-safe-push -vv ${{ github.event.inputs.no_push_lora == 'true' && '--no-push' || '' }} --config=cog-safe-push-schnell-lora.yaml
+          cog-safe-push -vv ${{ github.event.inputs.no_push == 'true' && '--no-push' || '' }} --config=safe-push-configs/cog-safe-push-dev.yaml
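
Both workflows rely on the same GitHub Actions idiom to make pushing optional: workflow expressions have no ternary operator, so `condition && 'value' || fallback` is used to emit either the --no-push flag or an empty string. A minimal Python sketch of the equivalent logic (the function and variable names are illustrative, not part of this repo):

    def no_push_flag(no_push_input: str) -> str:
        # Equivalent of ${{ github.event.inputs.no_push == 'true' && '--no-push' || '' }};
        # workflow inputs arrive as strings, hence the comparison against "true".
        return "--no-push" if no_push_input == "true" else ""

    assert no_push_flag("true") == "--no-push"
    assert no_push_flag("false") == ""
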
40 changes: 13 additions & 27 deletions fp8/lora_loading.py
@@ -384,11 +384,9 @@ def apply_linear1_lora_weight_to_module(
         alpha = rank
     else:
         alpha = alpha
-    w_dtype = module_weight.dtype
-    device = module_weight.device
     w_orig = module_weight
-    w_up = lora_A  # .to(dtype=w_dtype, device=device)
-    w_down = lora_B  # .to(dtype=w_dtype, device=device)
+    w_up = lora_A
+    w_down = lora_B

     if alpha != rank:
         w_up = w_up * alpha / rank
@@ -406,18 +404,13 @@
     v_down = w_down[3072 * 2 : 3072 * 3]
     mlp_down = w_down[3072 * 3 :]

-    # q += lora_scale * torch.mm(q_down, q_up).to(torch.bfloat16)
-    # k += lora_scale * torch.mm(k_down, k_up).to(torch.bfloat16)
-    # v += lora_scale * torch.mm(v_down, v_up).to(torch.bfloat16)
-    # mlp += lora_scale * torch.mm(mlp_down, mlp_up).to(torch.bfloat16)
-
     q = (q.float() + lora_scale * torch.mm(q_down, q_up)).to(torch.bfloat16)
     k = (k.float() + lora_scale * torch.mm(k_down, k_up)).to(torch.bfloat16)
     v = (v.float() + lora_scale * torch.mm(v_down, v_up)).to(torch.bfloat16)
     mlp = (mlp.float() + lora_scale * torch.mm(mlp_down, mlp_up)).to(torch.bfloat16)

     fused_weight = torch.cat([q, k, v, mlp], dim=0)
-    return fused_weight  # .to(dtype=w_dtype, device=device)
+    return fused_weight


 @torch.inference_mode()
@@ -437,11 +430,9 @@ def apply_attn_qkv_lora_weight_to_module(
         alpha = rank
     else:
         alpha = alpha
-    w_dtype = module_weight.dtype
-    device = module_weight.device
     w_orig = module_weight
-    w_up = lora_A  # .to(dtype=w_dtype, device=device)
-    w_down = lora_B  # .to(dtype=w_dtype, device=device)
+    w_up = lora_A
+    w_down = lora_B

     if alpha != rank:
         w_up = w_up * alpha / rank
@@ -456,15 +447,12 @@
     k_down = w_down[3072 : 3072 * 2]
     v_down = w_down[3072 * 2 : 3072 * 3]

-    # q += lora_scale * torch.mm(q_down, q_up).to(torch.bfloat16)
-    # k += lora_scale * torch.mm(k_down, k_up).to(torch.bfloat16)
-    # v += lora_scale * torch.mm(v_down, v_up).to(torch.bfloat16)
     q = (q.float() + lora_scale * torch.mm(q_down, q_up)).to(torch.bfloat16)
     k = (k.float() + lora_scale * torch.mm(k_down, k_up)).to(torch.bfloat16)
     v = (v.float() + lora_scale * torch.mm(v_down, v_up)).to(torch.bfloat16)

     fused_weight = torch.cat([q, k, v], dim=0)
-    return fused_weight  # .to(dtype=w_dtype, device=device)
+    return fused_weight


 @torch.inference_mode()
@@ -489,18 +477,17 @@ def apply_lora_weight_to_module(
         alpha = alpha

     w_orig = module_weight
-    w_up = lora_A  # .to(dtype=w_dtype, device=device)
-    w_down = lora_B  # .to(dtype=w_dtype, device=device)
+    w_up = lora_A
+    w_down = lora_B

     if alpha != rank:
         w_up = w_up * alpha / rank
     if uneven_rank:
         w_down = w_down.repeat_interleave(int(rank_diff), dim=1)
-    #fused_lora = lora_scale * torch.mm(w_down, w_up).to(torch.bfloat16)
-    #fused_weight = w_orig + fused_lora
+
     fused_lora = lora_scale * torch.mm(w_down, w_up)
     fused_weight = (w_orig.float() + fused_lora).to(torch.bfloat16)
-    return fused_weight  # .to(dtype=w_dtype, device=device)
+    return fused_weight


 @torch.inference_mode()
@@ -511,7 +498,7 @@ def load_lora(model: Flux, lora_path: str | Path, lora_scale: float = 1.0):
     lora_weights = load_file(lora_path, device="cuda")
     is_kohya = any(".lora_down.weight" in k for k in lora_weights)

-    # this is a bit circuitous at the moment but it works
+    # converting to diffusers to convert from diffusers is a bit circuitous at the moment but it works
     if is_kohya:
         lora_weights = _convert_kohya_flux_lora_to_diffusers(lora_weights)

@@ -589,16 +576,15 @@ def apply_lora_to_model(model: Flux, lora_weights: dict, lora_scale: float = 1.0
             weight_is_f8 = True
             weight_f16 = (
                 module.float8_data.clone()
-                #.to(torch.bfloat16)
+                .float()
                 .mul(module.scale_reciprocal)
                 .to(module.weight.device)
                 .to(torch.bfloat16)
             )
         elif isinstance(module, torch.nn.Linear):
-            weight_f16 = module.weight.clone()  # .detach()
+            weight_f16 = module.weight.clone()
         elif isinstance(module, CublasLinear):
-            weight_f16 = module.weight.clone()  # .detach()
+            weight_f16 = module.weight.clone()
         lora_sd = get_lora_for_key(key, lora_weights)

         assert weight_f16.dtype == torch.bfloat16, f"{key} is {weight_f16.dtype}, not torch.bfloat16"
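
The recurring change in this file is numerical: every merge now accumulates the LoRA delta in float32 and casts back to bfloat16 only at the end, replacing the commented-out variants that did the addition in bfloat16, and the dead w_dtype/device bookkeeping is dropped. A minimal, self-contained sketch of the resulting pattern for a plain linear layer (the wrapper name and shapes are illustrative; the linear1 and qkv variants additionally split the fused weight into 3072-row chunks and apply a per-chunk delta the same way):

    import torch

    def merge_lora_weight(
        w_orig: torch.Tensor,  # (out_features, in_features), bfloat16 model weight
        lora_A: torch.Tensor,  # LoRA "up" matrix, (rank, in_features)
        lora_B: torch.Tensor,  # LoRA "down" matrix, (out_features, rank)
        lora_scale: float,
        alpha: float,
        rank: int,
    ) -> torch.Tensor:
        w_up, w_down = lora_A, lora_B
        if alpha != rank:
            # Standard LoRA convention: the learned delta is scaled by alpha / rank.
            w_up = w_up * alpha / rank
        fused_lora = lora_scale * torch.mm(w_down, w_up)
        # Do the addition in float32 so small deltas are not lost to bfloat16
        # rounding, then cast the fused weight back down.
        return (w_orig.float() + fused_lora).to(torch.bfloat16)

    # Illustrative shapes: one 3072-wide linear layer with a rank-16 LoRA.
    w = torch.randn(3072, 3072, dtype=torch.bfloat16)
    A, B = torch.randn(16, 3072), torch.randn(3072, 16)
    fused = merge_lora_weight(w, A, B, lora_scale=1.0, alpha=16.0, rank=16)
    assert fused.dtype == torch.bfloat16 and fused.shape == w.shape
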
2 changes: 1 addition & 1 deletion predict.py
@@ -104,7 +104,7 @@ class SharedInputs:
     )
     lora_scale: Input = (
         Input(
-            description="Determines how strongly the main LoRA should be applied. Sane results between 0 and 1 for base inference; for go_fast we apply a 1.5x multiplier to this value because we've seen best performance with that. You may still need to experiment to find the best value for your particular lora. ",
+            description="Determines how strongly the main LoRA should be applied. Sane results between 0 and 1 for base inference. For go_fast we apply a 1.5x multiplier to this value; we've generally seen good performance when scaling the base value by that amount. You may still need to experiment to find the best value for your particular lora.",
             default=1.0,
             le=5.0,
             ge=-5.0,
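
The reworded description pins down one behavior worth noting: with go_fast enabled, the user-supplied lora_scale is multiplied by 1.5 before being applied. A small sketch of that arithmetic (the constant and helper names are hypothetical, not taken from predict.py):

    GO_FAST_LORA_SCALE_MULTIPLIER = 1.5  # the multiplier described in the input text

    def effective_lora_scale(lora_scale: float, go_fast: bool) -> float:
        # go_fast applies a 1.5x multiplier to the requested scale;
        # base inference uses the value unchanged.
        return lora_scale * GO_FAST_LORA_SCALE_MULTIPLIER if go_fast else lora_scale

    print(effective_lora_scale(1.0, go_fast=True))   # 1.5
    print(effective_lora_scale(1.0, go_fast=False))  # 1.0
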
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
