-
Notifications
You must be signed in to change notification settings - Fork 0
/
sdxl_process_data_dir_06.py
104 lines (83 loc) · 3.63 KB
/
sdxl_process_data_dir_06.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#process_data_dir.py
#scans data_dir (including subdirectories) for image/caption.txt pairs
#caches every image-caption pair it finds
#mirrors the data_dir directory structure inside cache_dir
#for each image/caption.txt pair: creates the needed cache files and a json file
#for each data_dir: creates data_dir.txt, which contains the list of json files
#resolution handling:
#without upscale_to_resolution:
#  image >= max_resolution: image is downscaled to max_resolution
#  image < min_resolution: image is deleted
#  training resolution range: min_resolution to max_resolution
#with upscale_to_resolution:
#  image >= min_resolution and image < upscale_to_resolution: image is upscaled to upscale_to_resolution
#  training resolution range: upscale_to_resolution to max_resolution
import argparse
import logging
import os
from accelerate import Accelerator
from sdxl_data_functions_18 import data_dir_search, cache_image_caption_pair
#welcome message
print("\nprocess_data_dir: initializing")
#accelerator configured for fp16 mixed precision; both it and its device are
#passed to cache_image_caption_pair further down
accelerator = Accelerator(mixed_precision="fp16")
device = accelerator.device
##arguments
#command-line options; the parsed values are copied into module-level names afterwards
parser = argparse.ArgumentParser(
    description="Scan data_dir for image/caption.txt pairs and cache them into cache_dir.",
)
parser.add_argument("--basename", type=str, default="data", help="The name of the dataset folder: ie '/mnt/storage/comics/' basename would be 'comics'")
parser.add_argument("--data_dir", type=str, default="data", help="'path/to/data_dir' --image-caption.txt directory location")
parser.add_argument("--cache_dir", type=str, default="cache", help="'path/to/cache_dir' --location to store cached images/captions")
parser.add_argument("--pretrained_model_name_or_path", type=str, default="stabilityai/stable-diffusion-xl-base-1.0", help="huggingface model, or path to local model")
parser.add_argument("--max_resolution", type=int, default=1536, help="maximum image resolution")
parser.add_argument("--min_resolution", type=int, default=512, help="minimum image resolution")
#optional: when omitted this is None, i.e. no upscaling is requested
parser.add_argument("--upscale_to_resolution", type=int, help="upscale image to resolution for caching, use original_size parameter")
parser.add_argument("--upscale_use_GFPGAN", action='store_true', help="after upscale image, use GFPGAN to fix face (use for photos only)")
parser.add_argument("--save_upscale_samples", action='store_true', help="after upscale image, save_upscale_samples")
args = parser.parse_args()
#an inverted resolution window is almost certainly a typo; fail fast with argparse's
#standard usage error (exit status 2) instead of starting the expensive caching step
#with contradictory limits
if args.min_resolution > args.max_resolution:
    parser.error(
        f"--min_resolution ({args.min_resolution}) must not be greater than "
        f"--max_resolution ({args.max_resolution})"
    )
##variables
#copy the parsed CLI options into the module-level names the caching call uses
#dirs
basename, data_dir, cache_dir = args.basename, args.data_dir, args.cache_dir
#make sure the cache root exists (no error if it is already there)
os.makedirs(cache_dir, exist_ok=True)
#models
pretrained_model_name_or_path = args.pretrained_model_name_or_path
#fixed VAE checkpoint, not exposed as a CLI option; presumably the fp16-fix VAE is
#chosen because the Accelerator runs in fp16 — confirm
pretrained_vae_model_name_or_path = "madebyollin/sdxl-vae-fp16-fix"
#resolution & upscale
max_resolution, min_resolution = args.max_resolution, args.min_resolution
upscale_to_resolution = args.upscale_to_resolution
upscale_use_GFPGAN, save_upscale_samples = args.upscale_use_GFPGAN, args.save_upscale_samples
##error logging
#route ERROR-level (and more severe) log records to error_log.txt; the relative
#filename resolves against the current working directory
logging.basicConfig(
    level=logging.ERROR,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename="error_log.txt",
)
#search data_dir and all of its subdirectories for image/caption.txt pairs
#input: data_dir
#returns: list of image/caption file-path tuples
image_caption_files_tuple_list = data_dir_search(
    data_dir
)
##preprocess images/captions and cache latents / encoder hidden states
#input: the image/caption pairs found above plus model, directory and resolution settings
#hashes and caches files, and saves the json list to disk
#returns: list of json file paths
#NOTE: arguments are passed positionally, so this order must match the signature of
#cache_image_caption_pair
json_file_paths_list = cache_image_caption_pair(
    #discovered image/caption pairs
    image_caption_files_tuple_list,
    #models
    pretrained_model_name_or_path,
    pretrained_vae_model_name_or_path,
    #directories / dataset name
    cache_dir,
    data_dir,
    basename,
    #accelerate runtime
    accelerator,
    device,
    #resolution & upscale settings
    max_resolution,
    min_resolution,
    upscale_to_resolution,
    upscale_use_GFPGAN,
    save_upscale_samples,
)