linting, formatting
daanelson committed Aug 27, 2024
1 parent fafb1c8 commit 639ef58
Showing 7 changed files with 146 additions and 102 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/ci-cd.yaml
@@ -10,6 +10,29 @@ on:
   workflow_dispatch:

 jobs:
+  lint:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.12'
+
+      - name: Install dependencies
+        run: |
+          pip install ruff
+      - name: Run ruff linter
+        run: |
+          ruff check
+      - name: Run ruff formatter
+        run: |
+          ruff format --diff
   build-and-push:
     runs-on: ubuntu-latest
     if: ${{ !contains(github.event.head_commit.message, '[skip-cd]') }}
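The new lint job is check-only: ruff check lints the tree and ruff format --diff reports formatting drift without rewriting files, so CI fails on unformatted code but never mutates it. A minimal sketch of running the same gate locally (assumes ruff is installed; the script is illustrative, not part of this commit):

# lint_local.py - hypothetical local mirror of the CI lint job above.
import subprocess
import sys

CHECKS = [
    ["ruff", "check"],             # same linter invocation as the workflow
    ["ruff", "format", "--diff"],  # non-zero exit if any file would be reformatted
]

for cmd in CHECKS:
    if subprocess.run(cmd).returncode != 0:
        sys.exit(1)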
8 changes: 6 additions & 2 deletions flux/modules/conditioner.py
@@ -10,10 +10,14 @@ def __init__(self, version: str, max_length: int, is_clip=False, **hf_kwargs):
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

         if self.is_clip:
-            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version + "/tokenizer", max_length=max_length)
+            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
+                version + "/tokenizer", max_length=max_length
+            )
             self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version + "/model", **hf_kwargs)
         else:
-            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version + "/tokenizer", max_length=max_length)
+            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
+                version + "/tokenizer", max_length=max_length
+            )
             self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version + "/model", **hf_kwargs)

         self.hf_module = self.hf_module.eval().requires_grad_(False)
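The conditioner edits are pure ruff line wrapping; both tokenizer calls behave exactly as before. For orientation, a hedged sketch of constructing this embedder (the HFEmbedder class name is assumed from the upstream flux repo, and the cache paths and max_length values are assumptions, not shown in this diff):

from flux.modules.conditioner import HFEmbedder  # class name assumed

# `version` is a local weights directory with tokenizer/ and model/ subfolders.
clip = HFEmbedder("./model-cache/clip", max_length=77, is_clip=True)  # pooled CLIP output
t5 = HFEmbedder("./model-cache/t5", max_length=512)                   # T5 last_hidden_state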
24 changes: 9 additions & 15 deletions flux/sampling.py
@@ -104,7 +104,7 @@ def denoise_single_item(
     vec: Tensor,
     timesteps: list[float],
     guidance: float = 4.0,
-    compile_run: bool = False
+    compile_run: bool = False,
 ):
     img = img.unsqueeze(0)
     img_ids = img_ids.unsqueeze(0)
@@ -113,14 +113,14 @@ def denoise_single_item(
     vec = vec.unsqueeze(0)
     guidance_vec = torch.full((1,), guidance, device=img.device, dtype=img.dtype)

-    if compile_run:
-        torch._dynamo.mark_dynamic(img, 1, min=256, max=8100) # needs at least torch 2.4
+    if compile_run:
+        torch._dynamo.mark_dynamic(img, 1, min=256, max=8100)  # needs at least torch 2.4
         torch._dynamo.mark_dynamic(img_ids, 1, min=256, max=8100)
         model = torch.compile(model)

     for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:])):
         t_vec = torch.full((1,), t_curr, dtype=img.dtype, device=img.device)

         pred = model(
             img=img,
             img_ids=img_ids,
@@ -135,6 +135,7 @@ def denoise_single_item(

     return img, model

+
 def denoise(
     model: Flux,
     # model input
@@ -146,28 +147,21 @@ def denoise(
     # sampling parameters
     timesteps: list[float],
     guidance: float = 4.0,
-    compile_run: bool = False
+    compile_run: bool = False,
 ):
     batch_size = img.shape[0]
     output_imgs = []

     for i in range(batch_size):
         denoised_img, model = denoise_single_item(
-            model,
-            img[i],
-            img_ids[i],
-            txt[i],
-            txt_ids[i],
-            vec[i],
-            timesteps,
-            guidance,
-            compile_run
+            model, img[i], img_ids[i], txt[i], txt_ids[i], vec[i], timesteps, guidance, compile_run
         )
         compile_run = False
         output_imgs.append(denoised_img)

     return torch.cat(output_imgs), model

+
 def unpack(x: Tensor, height: int, width: int) -> Tensor:
     return rearrange(
         x,
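Beyond trailing commas and blank lines, the code being reformatted implements a compile-once pattern worth noting: denoise passes compile_run=True only while processing the first item, flips it to False afterward, and denoise_single_item marks the token dimension dynamic before torch.compile, so later items with different sequence lengths reuse the compiled graph instead of recompiling. A stripped-down sketch of that pattern (illustrative names; the min/max bounds on mark_dynamic need torch >= 2.4, as the inline comment says):

import torch

def run_batch(model, items, compile_run=True):
    # Sketch only: compile on the first item, then reuse the compiled model.
    outputs = []
    for x in items:
        if compile_run:
            # Declare dim 1 dynamic so later items with other lengths
            # hit the same compiled graph instead of recompiling.
            torch._dynamo.mark_dynamic(x, 1, min=256, max=8100)
            model = torch.compile(model)
            compile_run = False
        outputs.append(model(x))
    return torch.cat(outputs), model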
3 changes: 2 additions & 1 deletion flux/util.py
@@ -23,6 +23,7 @@ class ModelSpec:
     ae_path: str | None
     ae_url: str | None

+
 T5_URL = "https://weights.replicate.delivery/default/official-models/flux/t5/t5-v1_1-xxl.tar"
 T5_CACHE = "./model-cache/t5"
 CLIP_URL = "https://weights.replicate.delivery/default/official-models/flux/clip/clip-vit-large-patch14.tar"
@@ -191,4 +192,4 @@ def download_weights(url: str, dest: str):
         subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
     else:
         subprocess.check_call(["pget", url, dest], close_fds=False)
-    print("downloading took: ", time.time() - start)
\ No newline at end of file
+    print("downloading took: ", time.time() - start)
67 changes: 25 additions & 42 deletions integration-tests/test-model.py
@@ -16,9 +16,9 @@
 from io import BytesIO
 import numpy as np

-ENV = os.getenv('TEST_ENV', 'local')
+ENV = os.getenv("TEST_ENV", "local")
 LOCAL_ENDPOINT = "http://localhost:5000/predictions"
-MODEL = os.getenv('MODEL', 'no model configured')
+MODEL = os.getenv("MODEL", "no model configured")
 IS_DEV = "dev" in MODEL

@@ -42,14 +42,11 @@ def local_run(model_endpoint: str, model_input: dict):


 def replicate_run(version: str, model_input: dict):
-    pred = replicate.predictions.create(
-        version=version,
-        input=model_input
-    )
+    pred = replicate.predictions.create(version=version, input=model_input)

     pred.wait()
-    predict_time = pred.metrics['predict_time']
+    predict_time = pred.metrics["predict_time"]

     images = []
     for url in pred.output:
         response = requests.get(url)
@@ -75,9 +72,7 @@ def wait_for_server_to_be_ready(url, timeout=400):
                 if data["status"] == "READY":
                     return
                 elif data["status"] == "SETUP_FAILED":
-                    raise RuntimeError(
-                        "Server initialization failed with status: SETUP_FAILED"
-                    )
+                    raise RuntimeError("Server initialization failed with status: SETUP_FAILED")

         except requests.RequestException:
             pass
@@ -90,9 +85,9 @@ def wait_for_server_to_be_ready(url, timeout=400):

 @pytest.fixture(scope="session")
 def inference_func():
-    if ENV == 'local':
+    if ENV == "local":
         return partial(local_run, LOCAL_ENDPOINT)
-    elif ENV in {'test', 'prod'}:
+    elif ENV in {"test", "prod"}:
         model = replicate.models.get(MODEL)
         version = model.versions.list()[0]
         return partial(replicate_run, version)
@@ -102,23 +97,23 @@ def inference_func():

 @pytest.fixture(scope="session", autouse=True)
 def service():
-    if ENV == 'local':
+    if ENV == "local":
         print("building model")
         # starts local server if we're running things locally
-        build_command = 'cog build -t test-model'.split()
+        build_command = "cog build -t test-model".split()
         subprocess.run(build_command, check=True)
-        container_name = 'cog-test'
+        container_name = "cog-test"
         try:
-            subprocess.check_output(['docker', 'inspect', '--format="{{.State.Running}}"', container_name])
+            subprocess.check_output(["docker", "inspect", '--format="{{.State.Running}}"', container_name])
             print(f"Container '{container_name}' is running. Stopping and removing...")
-            subprocess.check_call(['docker', 'stop', container_name])
-            subprocess.check_call(['docker', 'rm', container_name])
+            subprocess.check_call(["docker", "stop", container_name])
+            subprocess.check_call(["docker", "rm", container_name])
             print(f"Container '{container_name}' stopped and removed.")
         except subprocess.CalledProcessError:
             # Container not found
             print(f"Container '{container_name}' not found or not running.")

-        run_command = f'docker run -d -p 5000:5000 --gpus all --name {container_name} test-model '.split()
+        run_command = f"docker run -d -p 5000:5000 --gpus all --name {container_name} test-model ".split()
         process = subprocess.Popen(run_command, stdout=sys.stdout, stderr=sys.stderr)

         wait_for_server_to_be_ready("http://localhost:5000/health-check")
@@ -137,11 +132,10 @@ def get_time_bound():
     return 20 if IS_DEV else 10


-
 def test_base_generation(inference_func):
     """standard generation for dev and schnell. assert that the output image has a dog in it with blip-2 or llava"""
     test_example = {
-        'prompt': "A cool dog",
+        "prompt": "A cool dog",
         "aspect ratio": "1:1",
         "num_outputs": 1,
     }
@@ -157,7 +151,7 @@ def test_num_outputs(inference_func):
     base_time = None
     for n_outputs in range(1, 5):
         test_example = {
-            'prompt': "A cool dog",
+            "prompt": "A cool dog",
             "aspect ratio": "1:1",
             "num_outputs": n_outputs,
         }
@@ -169,30 +163,24 @@ def test_num_outputs(inference_func):
         base_time = time


-
 def test_determinism(inference_func):
     """determinism - test with the same seed twice"""
-    test_example = {
-        'prompt': "A cool dog",
-        "aspect_ratio": "9:16",
-        "num_outputs": 1,
-        "seed": 112358
-    }
+    test_example = {"prompt": "A cool dog", "aspect_ratio": "9:16", "num_outputs": 1, "seed": 112358}
     time, out_one = inference_func(test_example)
     out_one = out_one[0]
     assert time < get_time_bound()
     time_two, out_two = inference_func(test_example)
     out_two = out_two[0]
     assert time_two < get_time_bound()
     assert out_one.size == (768, 1344)

     one_array = np.array(out_one, dtype=np.uint16)
     two_array = np.array(out_two, dtype=np.uint16)
     assert np.allclose(one_array, two_array, atol=20)


 def test_resolutions(inference_func):
-    """changing resolutions - iterate through all resolutions and make sure that the output is """
+    """changing resolutions - iterate through all resolutions and make sure that the output is"""
     aspect_ratios = {
         "1:1": (1024, 1024),
         "16:9": (1344, 768),
@@ -206,12 +194,7 @@ def test_resolutions(inference_func):
     }

     for ratio, output in aspect_ratios.items():
-        test_example = {
-            'prompt': "A cool dog",
-            "aspect_ratio": ratio,
-            "num_outputs": 1,
-            "seed": 112358
-        }
+        test_example = {"prompt": "A cool dog", "aspect_ratio": ratio, "num_outputs": 1, "seed": 112358}

         time, img_out = inference_func(test_example)
         img_out = img_out[0]
@@ -223,11 +206,11 @@ def test_img2img(inference_func):
     """img2img. does it work?"""
     if not IS_DEV:
         assert True
-        return
+        return

-    test_example= {
-        'prompt': 'a cool walrus',
-        'image': 'https://replicate.delivery/pbxt/IS6z50uYJFdFeh1vCmXe9zasYbG16HqOOMETljyUJ1hmlUXU/keanu.jpeg',
+    test_example = {
+        "prompt": "a cool walrus",
+        "image": "https://replicate.delivery/pbxt/IS6z50uYJFdFeh1vCmXe9zasYbG16HqOOMETljyUJ1hmlUXU/keanu.jpeg",
     }

     _, img_out = inference_func(test_example)
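One detail worth calling out in the reformatted tests: test_determinism compares two same-seed generations with np.allclose(..., atol=20) rather than exact equality, so each pixel may drift by up to 20 (out of 255) between runs, which tolerates minor GPU nondeterminism while still failing on a genuinely different image. The comparison in isolation, as a sketch:

import numpy as np
from PIL import Image

def images_match(a: Image.Image, b: Image.Image, atol: int = 20) -> bool:
    # Per-pixel absolute tolerance: small numeric drift passes,
    # a different seed or a broken sampler fails.
    one = np.array(a, dtype=np.uint16)
    two = np.array(b, dtype=np.uint16)
    return bool(np.allclose(one, two, atol=atol))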
(Diffs for the remaining 2 changed files are not shown.)
