From 639ef5873eb6efecb664981751790af02c6bf87c Mon Sep 17 00:00:00 2001
From: Dan Nelson
Date: Tue, 27 Aug 2024 20:04:15 +0000
Subject: [PATCH] linting, formatting

---
 .github/workflows/ci-cd.yaml    |  23 ++++++
 flux/modules/conditioner.py     |   8 ++-
 flux/sampling.py                |  24 +++----
 flux/util.py                    |   3 +-
 integration-tests/test-model.py |  67 +++++++-----------
 predict.py                      | 122 +++++++++++++++++++++-----------
 torch_compile.py                |   1 -
 7 files changed, 146 insertions(+), 102 deletions(-)

diff --git a/.github/workflows/ci-cd.yaml b/.github/workflows/ci-cd.yaml
index cf85aba..073e85d 100644
--- a/.github/workflows/ci-cd.yaml
+++ b/.github/workflows/ci-cd.yaml
@@ -10,6 +10,29 @@ on:
   workflow_dispatch:
 
 jobs:
+  lint:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.12'
+
+      - name: Install dependencies
+        run: |
+          pip install ruff
+
+      - name: Run ruff linter
+        run: |
+          ruff check
+
+      - name: Run ruff formatter
+        run: |
+          ruff format --diff
+
   build-and-push:
     runs-on: ubuntu-latest
     if: ${{ !contains(github.event.head_commit.message, '[skip-cd]') }}
diff --git a/flux/modules/conditioner.py b/flux/modules/conditioner.py
index 4001fdb..ec61eee 100644
--- a/flux/modules/conditioner.py
+++ b/flux/modules/conditioner.py
@@ -10,10 +10,14 @@ def __init__(self, version: str, max_length: int, is_clip=False, **hf_kwargs):
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
 
         if self.is_clip:
-            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version + "/tokenizer", max_length=max_length)
+            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
+                version + "/tokenizer", max_length=max_length
+            )
             self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version + "/model", **hf_kwargs)
         else:
-            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version + "/tokenizer", max_length=max_length)
+            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
+                version + "/tokenizer", max_length=max_length
+            )
             self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version + "/model", **hf_kwargs)
 
         self.hf_module = self.hf_module.eval().requires_grad_(False)
diff --git a/flux/sampling.py b/flux/sampling.py
index 7a87a25..eb7971a 100644
--- a/flux/sampling.py
+++ b/flux/sampling.py
@@ -104,7 +104,7 @@ def denoise_single_item(
     vec: Tensor,
     timesteps: list[float],
     guidance: float = 4.0,
-    compile_run: bool = False
+    compile_run: bool = False,
 ):
     img = img.unsqueeze(0)
     img_ids = img_ids.unsqueeze(0)
@@ -113,14 +113,14 @@
     vec = vec.unsqueeze(0)
     guidance_vec = torch.full((1,), guidance, device=img.device, dtype=img.dtype)
 
-    if compile_run: 
-        torch._dynamo.mark_dynamic(img, 1, min=256, max=8100) # needs at least torch 2.4
+    if compile_run:
+        torch._dynamo.mark_dynamic(img, 1, min=256, max=8100)  # needs at least torch 2.4
         torch._dynamo.mark_dynamic(img_ids, 1, min=256, max=8100)
         model = torch.compile(model)
 
     for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:])):
         t_vec = torch.full((1,), t_curr, dtype=img.dtype, device=img.device)
-        
+
         pred = model(
             img=img,
             img_ids=img_ids,
@@ -135,6 +135,7 @@
 
     return img, model
 
+
 def denoise(
     model: Flux,
     # model input
@@ -146,28 +147,21 @@
     # sampling parameters
     timesteps: list[float],
     guidance: float = 4.0,
-    compile_run: bool = False
+    compile_run: bool = False,
 ):
     batch_size = img.shape[0]
     output_imgs = []
 
     for i in range(batch_size):
         denoised_img, model = denoise_single_item(
-            model,
-            img[i],
-            img_ids[i],
-            txt[i],
-            txt_ids[i],
-            vec[i],
-            timesteps,
-            guidance,
-            compile_run
+            model, img[i], img_ids[i], txt[i], txt_ids[i], vec[i], timesteps, guidance, compile_run
         )
         compile_run = False
         output_imgs.append(denoised_img)
-    
+
     return torch.cat(output_imgs), model
 
+
 def unpack(x: Tensor, height: int, width: int) -> Tensor:
     return rearrange(
         x,
diff --git a/flux/util.py b/flux/util.py
index 469ae00..e624ba2 100644
--- a/flux/util.py
+++ b/flux/util.py
@@ -23,6 +23,7 @@ class ModelSpec:
     ae_path: str | None
     ae_url: str | None
 
+
 T5_URL = "https://weights.replicate.delivery/default/official-models/flux/t5/t5-v1_1-xxl.tar"
 T5_CACHE = "./model-cache/t5"
 CLIP_URL = "https://weights.replicate.delivery/default/official-models/flux/clip/clip-vit-large-patch14.tar"
@@ -191,4 +192,4 @@ def download_weights(url: str, dest: str):
         subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
     else:
         subprocess.check_call(["pget", url, dest], close_fds=False)
-    print("downloading took: ", time.time() - start)
\ No newline at end of file
+    print("downloading took: ", time.time() - start)
diff --git a/integration-tests/test-model.py b/integration-tests/test-model.py
index 94f4552..4501b65 100644
--- a/integration-tests/test-model.py
+++ b/integration-tests/test-model.py
@@ -16,9 +16,9 @@
 from io import BytesIO
 import numpy as np
 
-ENV = os.getenv('TEST_ENV', 'local')
+ENV = os.getenv("TEST_ENV", "local")
 LOCAL_ENDPOINT = "http://localhost:5000/predictions"
-MODEL = os.getenv('MODEL', 'no model configured')
+MODEL = os.getenv("MODEL", "no model configured")
 IS_DEV = "dev" in MODEL
 
 
@@ -42,14 +42,11 @@ def local_run(model_endpoint: str, model_input: dict):
 
 
 def replicate_run(version: str, model_input: dict):
-    pred = replicate.predictions.create(
-        version=version,
-        input=model_input
-    )
+    pred = replicate.predictions.create(version=version, input=model_input)
 
     pred.wait()
-    
-    predict_time = pred.metrics['predict_time']
+
+    predict_time = pred.metrics["predict_time"]
     images = []
     for url in pred.output:
         response = requests.get(url)
@@ -75,9 +72,7 @@ def wait_for_server_to_be_ready(url, timeout=400):
                 if data["status"] == "READY":
                     return
                 elif data["status"] == "SETUP_FAILED":
-                    raise RuntimeError(
-                        "Server initialization failed with status: SETUP_FAILED"
-                    )
+                    raise RuntimeError("Server initialization failed with status: SETUP_FAILED")
         except requests.RequestException:
             pass
 
@@ -90,9 +85,9 @@ def wait_for_server_to_be_ready(url, timeout=400):
 
 @pytest.fixture(scope="session")
 def inference_func():
-    if ENV == 'local':
+    if ENV == "local":
         return partial(local_run, LOCAL_ENDPOINT)
-    elif ENV in {'test', 'prod'}:
+    elif ENV in {"test", "prod"}:
         model = replicate.models.get(MODEL)
         version = model.versions.list()[0]
         return partial(replicate_run, version)
@@ -102,23 +97,23 @@ def inference_func():
 
 @pytest.fixture(scope="session", autouse=True)
 def service():
-    if ENV == 'local':
+    if ENV == "local":
         print("building model")
         # starts local server if we're running things locally
-        build_command = 'cog build -t test-model'.split()
+        build_command = "cog build -t test-model".split()
         subprocess.run(build_command, check=True)
 
-        container_name = 'cog-test'
+        container_name = "cog-test"
         try:
-            subprocess.check_output(['docker', 'inspect', '--format="{{.State.Running}}"', container_name])
+            subprocess.check_output(["docker", "inspect", '--format="{{.State.Running}}"', container_name])
             print(f"Container '{container_name}' is running. Stopping and removing...")
-            subprocess.check_call(['docker', 'stop', container_name])
-            subprocess.check_call(['docker', 'rm', container_name])
+            subprocess.check_call(["docker", "stop", container_name])
+            subprocess.check_call(["docker", "rm", container_name])
             print(f"Container '{container_name}' stopped and removed.")
         except subprocess.CalledProcessError:
             # Container not found
             print(f"Container '{container_name}' not found or not running.")
 
-        run_command = f'docker run -d -p 5000:5000 --gpus all --name {container_name} test-model '.split()
+        run_command = f"docker run -d -p 5000:5000 --gpus all --name {container_name} test-model ".split()
         process = subprocess.Popen(run_command, stdout=sys.stdout, stderr=sys.stderr)
         wait_for_server_to_be_ready("http://localhost:5000/health-check")
@@ -137,11 +132,10 @@ def get_time_bound():
     return 20 if IS_DEV else 10
 
 
-
 def test_base_generation(inference_func):
     """standard generation for dev and schnell. assert that the output image has a dog in it with blip-2 or llava"""
     test_example = {
-        'prompt': "A cool dog",
+        "prompt": "A cool dog",
         "aspect ratio": "1:1",
         "num_outputs": 1,
     }
@@ -157,7 +151,7 @@ def test_num_outputs(inference_func):
     base_time = None
     for n_outputs in range(1, 5):
         test_example = {
-            'prompt': "A cool dog",
+            "prompt": "A cool dog",
             "aspect ratio": "1:1",
             "num_outputs": n_outputs,
         }
@@ -169,15 +163,9 @@ def test_num_outputs(inference_func):
         base_time = time
 
 
-
 def test_determinism(inference_func):
     """determinism - test with the same seed twice"""
-    test_example = {
-        'prompt': "A cool dog",
-        "aspect_ratio": "9:16",
-        "num_outputs": 1,
-        "seed": 112358
-    }
+    test_example = {"prompt": "A cool dog", "aspect_ratio": "9:16", "num_outputs": 1, "seed": 112358}
     time, out_one = inference_func(test_example)
     out_one = out_one[0]
     assert time < get_time_bound()
@@ -185,14 +173,14 @@ def test_determinism(inference_func):
     out_two = out_two[0]
     assert time_two < get_time_bound()
     assert out_one.size == (768, 1344)
-    
+
     one_array = np.array(out_one, dtype=np.uint16)
     two_array = np.array(out_two, dtype=np.uint16)
     assert np.allclose(one_array, two_array, atol=20)
 
 
 def test_resolutions(inference_func):
-    """changing resolutions - iterate through all resolutions and make sure that the output is """
+    """changing resolutions - iterate through all resolutions and make sure that the output is"""
     aspect_ratios = {
         "1:1": (1024, 1024),
         "16:9": (1344, 768),
@@ -206,12 +194,7 @@ def test_resolutions(inference_func):
     }
 
     for ratio, output in aspect_ratios.items():
-        test_example = {
-            'prompt': "A cool dog",
-            "aspect_ratio": ratio,
-            "num_outputs": 1,
-            "seed": 112358
-        }
+        test_example = {"prompt": "A cool dog", "aspect_ratio": ratio, "num_outputs": 1, "seed": 112358}
 
         time, img_out = inference_func(test_example)
         img_out = img_out[0]
@@ -223,11 +206,11 @@ def test_img2img(inference_func):
     """img2img. does it work?"""
     if not IS_DEV:
         assert True
-        return 
+        return
 
-    test_example= {
-        'prompt': 'a cool walrus',
-        'image': 'https://replicate.delivery/pbxt/IS6z50uYJFdFeh1vCmXe9zasYbG16HqOOMETljyUJ1hmlUXU/keanu.jpeg',
+    test_example = {
+        "prompt": "a cool walrus",
+        "image": "https://replicate.delivery/pbxt/IS6z50uYJFdFeh1vCmXe9zasYbG16HqOOMETljyUJ1hmlUXU/keanu.jpeg",
     }
 
     _, img_out = inference_func(test_example)
diff --git a/predict.py b/predict.py
index ea2fbc8..0b9002f 100644
--- a/predict.py
+++ b/predict.py
@@ -24,33 +24,37 @@
 SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
 MAX_IMAGE_SIZE = 1440
 
+
 @dataclass
 class SharedInputs:
     prompt: Input = Input(description="Prompt for generated image")
     aspect_ratio: Input = Input(
-            description="Aspect ratio for the generated image",
-            choices=["1:1", "16:9", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"],
-            default="1:1")
+        description="Aspect ratio for the generated image",
+        choices=["1:1", "16:9", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"],
+        default="1:1",
+    )
     num_outputs: Input = Input(description="Number of outputs to generate", default=1, le=4, ge=1)
     seed: Input = Input(description="Random seed. Set for reproducible generation", default=None)
     output_format: Input = Input(
-            description="Format of the output images",
-            choices=["webp", "jpg", "png"],
-            default="webp",
-        )
+        description="Format of the output images",
+        choices=["webp", "jpg", "png"],
+        default="webp",
+    )
     output_quality: Input = Input(
-            description="Quality when saving the output images, from 0 to 100. 100 is best quality, 0 is lowest quality. Not relevant for .png outputs",
-            default=80,
-            ge=0,
-            le=100,
-        )
+        description="Quality when saving the output images, from 0 to 100. 100 is best quality, 0 is lowest quality. Not relevant for .png outputs",
+        default=80,
+        ge=0,
+        le=100,
+    )
     disable_safety_checker: Input = Input(
-            description="Disable safety checker for generated images.",
-            default=False,
+        description="Disable safety checker for generated images.",
+        default=False,
     )
 
+
 SHARED_INPUTS = SharedInputs()
 
+
 class Predictor(BasePredictor):
     def setup(self) -> None:
         return
@@ -95,13 +99,12 @@ def base_setup(self, flow_model_name: str, compile: bool) -> None:
             num_outputs=1,
             num_inference_steps=self.num_steps,
             guidance=3.5,
-            output_format='png',
+            output_format="png",
             output_quality=80,
             disable_safety_checker=True,
-            seed=123
+            seed=123,
         )
 
-
     def aspect_ratio_to_width_height(self, aspect_ratio: str):
         aspect_ratios = {
             "1:1": (1024, 1024),
@@ -115,7 +118,7 @@ def aspect_ratio_to_width_height(self, aspect_ratio: str):
             "9:21": (640, 1536),
         }
         return aspect_ratios.get(aspect_ratio)
-    
+
     def get_image(self, image: str):
         if image is None:
             return None
@@ -128,7 +131,7 @@ def get_image(self, image: str):
         )
         img: torch.Tensor = transform(image)
         return img[None, ...]
-    
+
     def predict():
         raise Exception("You need to instantiate a predictor for a specific flux model")
 
@@ -141,8 +144,8 @@ def base_predict(
         output_quality: int,
         disable_safety_checker: bool,
         num_inference_steps: int,
-        guidance: float = 3.5, # schnell ignores guidance within the model, fine to have default
-        image: Path = None, # img2img for flux-dev
+        guidance: float = 3.5,  # schnell ignores guidance within the model, fine to have default
+        image: Path = None,  # img2img for flux-dev
         prompt_strength: float = 0.8,
         seed: Optional[int] = None,
     ) -> List[Path]:
@@ -197,7 +200,7 @@ def base_predict(
        if self.offload:
             self.t5, self.clip = self.t5.to(torch_device), self.clip.to(torch_device)
 
-        inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=[prompt]*num_outputs)
+        inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=[prompt] * num_outputs)
 
         if self.offload:
             self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
@@ -208,7 +211,9 @@ def base_predict(
             print("Compiling")
             st = time.time()
 
-        x, flux = denoise(self.flux, **inp, timesteps=timesteps, guidance=guidance, compile_run=self.compile_run)
+        x, flux = denoise(
+            self.flux, **inp, timesteps=timesteps, guidance=guidance, compile_run=self.compile_run
+        )
 
         if self.compile_run:
             print(f"Compiled in {time.time() - st}")
@@ -219,7 +224,7 @@ def base_predict(
             self.flux.cpu()
             torch.cuda.empty_cache()
             self.ae.decoder.to(x.device)
-        
+
         x = unpack(x.float(), height, width)
         with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
             x = self.ae.decode(x)
@@ -227,12 +232,17 @@ def base_predict(
         if self.offload:
             self.ae.decoder.cpu()
             torch.cuda.empty_cache()
-        
-        images = [Image.fromarray((127.5 * (rearrange(x[i], "c h w -> h w c").clamp(-1, 1) + 1.0)).cpu().byte().numpy()) for i in range(num_outputs)]
+
+        images = [
+            Image.fromarray(
+                (127.5 * (rearrange(x[i], "c h w -> h w c").clamp(-1, 1) + 1.0)).cpu().byte().numpy()
+            )
+            for i in range(num_outputs)
+        ]
         has_nsfw_content = [False] * len(images)
         if not disable_safety_checker:
-            _, has_nsfw_content = self.run_safety_checker(images) # always on gpu
-        
+            _, has_nsfw_content = self.run_safety_checker(images)  # always on gpu
+
         output_paths = []
         for i, (img, is_nsfw) in enumerate(zip(images, has_nsfw_content)):
             if is_nsfw:
@@ -240,16 +250,18 @@ def base_predict(
                 continue
 
             output_path = f"out-{i}.{output_format}"
-            save_params = {'quality': output_quality, 'optimize': True} if output_format != 'png' else {}
+            save_params = {"quality": output_quality, "optimize": True} if output_format != "png" else {}
             img.save(output_path, **save_params)
             output_paths.append(Path(output_path))
 
         if not output_paths:
-            raise Exception("All generated images contained NSFW content. Try running it again with a different prompt.")
+            raise Exception(
+                "All generated images contained NSFW content. Try running it again with a different prompt."
+            )
 
         print(f"Total safe images: {len(output_paths)} out of {len(images)}")
         return output_paths
-    
+
     def run_safety_checker(self, images):
         safety_checker_input = self.feature_extractor(images, return_tensors="pt").to("cuda")
         np_images = [np.array(img) for img in images]
@@ -259,10 +271,11 @@ def run_safety_checker(self, images):
         )
         return image, has_nsfw_concept
 
+
 class SchnellPredictor(Predictor):
     def setup(self) -> None:
         self.base_setup("flux-schnell", compile=False)
-    
+
     @torch.inference_mode()
     def predict(
         self,
@@ -274,30 +287,57 @@ def predict(
         output_quality: int = SHARED_INPUTS.output_quality,
         disable_safety_checker: bool = SHARED_INPUTS.disable_safety_checker,
     ) -> List[Path]:
+        return self.base_predict(
+            prompt,
+            aspect_ratio,
+            num_outputs,
+            output_format,
+            output_quality,
+            disable_safety_checker,
+            num_inference_steps=self.num_steps,
+            seed=seed,
+        )
 
-        return self.base_predict(prompt, aspect_ratio, num_outputs, output_format, output_quality, disable_safety_checker, num_inference_steps=self.num_steps, seed=seed)
-
 
 class DevPredictor(Predictor):
     def setup(self) -> None:
         self.base_setup("flux-dev", compile=True)
-    
+
     @torch.inference_mode()
     def predict(
         self,
         prompt: str = SHARED_INPUTS.prompt,
         aspect_ratio: str = SHARED_INPUTS.aspect_ratio,
-        image: Path = Input(description="Input image for image to image mode. The aspect ratio of your output will match this image", default=None),
-        prompt_strength: float = Input(description="Prompt strength when using img2img. 1.0 corresponds to full destruction of information in image",
-            ge=0.0, le=1.0, default=0.80,
+        image: Path = Input(
+            description="Input image for image to image mode. The aspect ratio of your output will match this image",
+            default=None,
+        ),
+        prompt_strength: float = Input(
+            description="Prompt strength when using img2img. 1.0 corresponds to full destruction of information in image",
+            ge=0.0,
+            le=1.0,
+            default=0.80,
         ),
         num_outputs: int = SHARED_INPUTS.num_outputs,
-        num_inference_steps: int = Input(description="Number of denoising steps. Recommended range is 28-50", ge=1, le=50, default=28),
+        num_inference_steps: int = Input(
+            description="Number of denoising steps. Recommended range is 28-50", ge=1, le=50, default=28
+        ),
         guidance: float = Input(description="Guidance for generated image", ge=0, le=10, default=3),
         seed: int = SHARED_INPUTS.seed,
         output_format: str = SHARED_INPUTS.output_format,
         output_quality: int = SHARED_INPUTS.output_quality,
         disable_safety_checker: bool = SHARED_INPUTS.disable_safety_checker,
     ) -> List[Path]:
-
-        return self.base_predict(prompt, aspect_ratio, num_outputs, output_format, output_quality, disable_safety_checker, guidance=guidance, image=image, prompt_strength=prompt_strength, num_inference_steps=num_inference_steps, seed=seed)
+        return self.base_predict(
+            prompt,
+            aspect_ratio,
+            num_outputs,
+            output_format,
+            output_quality,
+            disable_safety_checker,
+            guidance=guidance,
+            image=image,
+            prompt_strength=prompt_strength,
+            num_inference_steps=num_inference_steps,
+            seed=seed,
+        )
diff --git a/torch_compile.py b/torch_compile.py
index a4f5eb8..37171a6 100644
--- a/torch_compile.py
+++ b/torch_compile.py
@@ -1,4 +1,3 @@
-
 # Imports flux-schnell model from predict.py
 # from flux.util import load_flow_model
 # flux = load_flow_model("flux-schnell", device="cuda")