diff --git a/.gitignore b/.gitignore index 1bdf93a..495b931 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,12 @@ .DS_Store dataset_tools.egg-info dataset_tools.dataset_tools.egg-info -build \ No newline at end of file +build +dataset_tools/__pycache__/__init__.cpython-310.pyc +dataset_tools/__pycache__/access_disk.cpython-310.pyc +dataset_tools/__pycache__/correct_types.cpython-310.pyc +dataset_tools/__pycache__/logger.cpython-310.pyc +dataset_tools/__pycache__/main.cpython-310.pyc +dataset_tools/__pycache__/metadata_parser.cpython-310.pyc +dataset_tools/__pycache__/ui.cpython-310.pyc +dataset_tools/__pycache__/widgets.cpython-310.pyc diff --git a/dataset_tools/correct_types.py b/dataset_tools/correct_types.py index 2d685bf..31f3b01 100644 --- a/dataset_tools/correct_types.py +++ b/dataset_tools/correct_types.py @@ -4,7 +4,13 @@ """確認 Data Type""" from ast import Constant -from typing_extensions import TypedDict, Annotated, List + +from platform import python_version_tuple + +if float(python_version_tuple()[0]) == 3.0 and float(python_version_tuple()[1]) <= 12.0: + from typing_extensions import TypedDict, Annotated, List, Union +else: + from typing import TypedDict, Annotated, List, Union from pydantic import TypeAdapter, BaseModel, Field, AfterValidator, field_validator, ValidationError @@ -71,10 +77,14 @@ class NodeNames: "CLIPTextEncodeSD3", "CLIPTextEncodeSDXL", "CLIPTextEncodeHunyuanDiT", + "CLIPTextEncodePixArtAlpha", + "CLIPTextEncodeSDXLRefiner", "WildcardEncode //Inspire", "ImpactWildcardProcessor", - "CLIPTextEncode", + "ImpactWildcardEncodeCLIPTextEncode", ] + PROMPT_LABELS = ["Positive prompt", "Negative prompt", "Prompt"] + IGNORE_KEYS = [ "type", "link", @@ -84,6 +94,29 @@ class NodeNames: "size", ] + DATA_KEYS = { + "class_type": "inputs", + "nodes": "widget_values", + } + PROMPT_NODE_FIELDS = { + "text", + "t5xxl", + "clip-l", + "clip-g", + "mt5", + "mt5xl", + "bert", + "clip-h", + "wildcard", + "string", + "positive", + "negative", + "text_g", + "text_l", + "wildcard_text", + "populated_text", + } + EXC_INFO: bool = LOG_LEVEL != "i" @@ -109,7 +142,18 @@ def bracket_check(maybe_brackets: str | dict): class NodeDataMap(TypedDict): class_type: str - inputs: dict + inputs: Union[dict, float] + + +class NodeWorkflow(TypedDict): + last_node_id: int + last_link_id: Union[int, dict] + nodes: list + links: list + groups: list + config: dict + extra: dict + version: float class BracketedDict(BaseModel): @@ -130,6 +174,7 @@ class IsThisNode: """ data = TypeAdapter(NodeDataMap) + workflow = TypeAdapter(NodeWorkflow) class ListOfDelineatedStr(BaseModel): diff --git a/dataset_tools/logger.py b/dataset_tools/logger.py index 06e452b..a7ac7f1 100644 --- a/dataset_tools/logger.py +++ b/dataset_tools/logger.py @@ -12,7 +12,7 @@ from rich.logging import RichHandler from rich.style import Style -from dataset_tools.correct_types import LOG_LEVEL, EXC_INFO +from dataset_tools.correct_types import EXC_INFO, LOG_LEVEL msg_init = None # pylint: disable=invalid-name @@ -82,9 +82,9 @@ def wrapper(*args, **kwargs) -> None: return wrapper -def debug_message(message, *args): - logger.debug("%s", f"{message} {args}") +def debug_message(*args): + logger.debug(args) -def info_monitor(message, *args): - logger.info("%s", f"{message} {args}") +def info_monitor(*args): + logger.info(args, exc_info=EXC_INFO) diff --git a/dataset_tools/metadata_parser.py b/dataset_tools/metadata_parser.py index 945c496..7cb0af3 100644 --- a/dataset_tools/metadata_parser.py +++ b/dataset_tools/metadata_parser.py @@ -17,6 +17,7 @@ from dataset_tools.access_disk import MetadataFileReader from dataset_tools.correct_types import ( IsThisNode, + NodeWorkflow, BracketedDict, ListOfDelineatedStr, UpField, @@ -28,8 +29,8 @@ # /______________________________________________________________________________________________________________________ ComfyUI format -@debug_monitor -def clean_with_json(prestructured_data: dict) -> dict: +# @debug_monitor +def clean_with_json(prestructured_data: dict, first_key: str) -> dict: """ Use json loads to arrange half-formatted dict into valid dict\n :param prestructured_data: A dict with a single working key @@ -37,7 +38,7 @@ def clean_with_json(prestructured_data: dict) -> dict: :return: A formatted dictionary object """ try: - cleaned_data = json.loads(prestructured_data) + cleaned_data = json.loads(prestructured_data[first_key]) except JSONDecodeError as error_log: nfo("Attempted to parse invalid formatting on", prestructured_data, error_log) return None @@ -45,44 +46,93 @@ def clean_with_json(prestructured_data: dict) -> dict: @debug_monitor -def rename_next_keys_of(nested_map: dict, search_key: str, new_labels: list) -> dict: +def validate_typical(nested_map: dict, key_name: str): """ - Divide the next layer of a nested dictionary by search criteria\n - Then rename the matching values\n - :param nested_map: Where to start in the nested dictionary - :param search_key: Key to retrieve the data beneath - :param new_labels: A list of labels to apply to the new dictionary - :return: The combined dictionary with specified keys + Check metadata structure and ensure it meets expectations\n + :param nested_map: metadata structure + :type nested_map: dict + :return: The original node map if valid, or None """ - extracted_data = {} - for key_name in nested_map: - is_this_node = IsThisNode() - next_layer = nested_map[key_name] - is_this_node.data.validate_python(next_layer) # be sure we have the right data - if isinstance(new_labels, list): - class_type = next_layer.get("class_type", False) - if search_key in class_type: - class_type_field = next(iter(x for x in new_labels if x not in extracted_data), "") # Name of prompt - inputs_field = next_layer.get("inputs", UpField.PLACEHOLDER).items() - prompt_data = next(iter(v for k, v in inputs_field if isinstance(v, str))) - if prompt_data: - extracted_data[class_type_field] = prompt_data - debug_message(prompt_data) - - else: - node_inputs = next_layer.get("inputs", {"": UpField.PLACEHOLDER}).items() - gen_data = "\n".join(f"{k}: {v}" for k, v in node_inputs if not isinstance(v, list) and v is not None) - if gen_data: - extracted_data[class_type] = gen_data - + is_this_node = IsThisNode() + if next(iter(nested_map[key_name])) in NodeWorkflow.__annotations__.keys(): + try: + is_this_node.workflow.validate_python(nested_map) + except ValidationError as error_log: # + nfo("%s", f"Node workflow not found, returning NoneType {key_name}", error_log) + else: + return nested_map[key_name] + else: + try: + is_this_node.data.validate_python(nested_map[key_name]) # Be sure we have the right data + except ValidationError as error_log: + nfo("%s", "Node data not found", error_log) else: - raise ValidationError(f"Not list format in {new_labels}") + return nested_map[key_name] - return extracted_data + nfo("%s", f"Node workflow not found {key_name}") + raise KeyError(f"Unknown format for dictionary {key_name} in {nested_map}") + + +def search_for_prompt_in(this_layer, previously_extracted_data, node_id): + if node_id in this_layer and this_layer[node_id] in NodeNames.ENCODERS: + for field_name, contents in this_layer[NodeNames.DATA_KEYS[node_id]].items(): + if isinstance(contents, str): + if field_name == "text": + add_label = next((x for x in NodeNames.PROMPT_LABELS if x not in previously_extracted_data), "") + prompt_data = {add_label: contents} + else: + prompt_data = {field_name: contents} + return prompt_data + + elif "inputs" in this_layer: + return {k: v for k, v in this_layer["inputs"].items() if v and not isinstance(v, list)} + return {} + + +def rename_prompt_keys_of(normalized_clean_data): + previously_extracted_data = {} + for layer in normalized_clean_data: + this_layer = validate_typical(normalized_clean_data, layer) + + if this_layer: + search_terms = ((n, nc) for n in NodeNames.DATA_KEYS for nc in NodeNames.ENCODERS) + for node_id, _ in search_terms: + previously_extracted_data.update(search_for_prompt_in(this_layer, previously_extracted_data, node_id)) + + return previously_extracted_data + + +def redivide_nodeui_data_in(header: str, first_key: str) -> Tuple[dict]: + """ + Orchestrate tasks to recreate dictionary structure and extract relevant keys within\n + :param header: Embedded dictionary structure + :type variable: str + :param section_titles: Key names for relevant data segments + :type variable: list + :return: Metadata dict, or empty dicts if not found + """ + + def pack_prompt(mixed_data: dict): + """Convert dictionary to expected formatting""" + packed_data_pass_1 = {k: mixed_data.get(k) for k in NodeNames.PROMPT_LABELS if k in mixed_data} + packed_data_pass_2 = {k: mixed_data.get(k) for k in NodeNames.PROMPT_NODE_FIELDS if k in mixed_data} + prompt_data = packed_data_pass_1 | packed_data_pass_2 + return prompt_data + + try: + jsonified_header = clean_with_json(header, first_key) + if first_key == "workflow": + normalized_clean_data = {"1": jsonified_header} # To match normalized_prompt_structure format + else: + normalized_clean_data = jsonified_header + sorted_header_data = rename_prompt_keys_of(normalized_clean_data) + except KeyError as error_log: + nfo("%s", "No key found.", error_log) + return {"": UpField.PLACEHOLDER, " ": UpField.PLACEHOLDER} + return pack_prompt(sorted_header_data), {k: v for k, v in sorted_header_data.items() if k not in NodeNames.PROMPT_LABELS and k not in NodeNames.PROMPT_NODE_FIELDS} -@debug_monitor def arrange_nodeui_metadata(header_data: dict) -> dict: """ Using the header from a file, run formatting and parsing processes \n @@ -91,25 +141,12 @@ def arrange_nodeui_metadata(header_data: dict) -> dict: :return: Metadata in a standardized format """ - if header_data: - dirty_prompts = header_data.get("prompt") - clean_metadata = clean_with_json(dirty_prompts) - prompt_map = {} - generation_map = {} - prompt_keys = ["Positive prompt", "Negative prompt", "Prompt"] - for encoder_type in NodeNames.ENCODERS: - renamed_metadata = rename_next_keys_of(clean_metadata, encoder_type, prompt_keys) - - generation_map = renamed_metadata.copy() - for keys in prompt_keys: - if renamed_metadata.get(keys) is not None: - prompt_map[keys] = renamed_metadata[keys] - generation_map.pop(keys) - if not prompt_map: - prompt_map = {"": UpField.PLACEHOLDER} - - return {UpField.PROMPT: prompt_map, DownField.GENERATION_DATA: generation_map} - return {"": UpField.PLACEHOLDER} + extracted_prompt_pairs, generation_data_pairs = redivide_nodeui_data_in(header_data, "prompt") + if extracted_prompt_pairs == {}: + gen_pairs_copy = generation_data_pairs.copy() + extracted_prompt_pairs, second_gen_map = redivide_nodeui_data_in(header_data, "workflow") + generation_data_pairs = second_gen_map | gen_pairs_copy + return {UpField.PROMPT: extracted_prompt_pairs or {"": UpField.PLACEHOLDER}, DownField.GENERATION_DATA: generation_data_pairs} # /______________________________________________________________________________________________________________________ A4 format @@ -122,6 +159,7 @@ def delineate_by_esc_codes(text_chunks: dict, extra_delineation: str = "'Negativ :param text_chunk: Data from a file header :return: text data in a dict structure """ + segments = [] replace_strings = ["\xe2\x80\x8b", "\x00", "\u200b", "\n", extra_delineation] dirty_string = text_chunks.get("parameters", text_chunks.get("exif", False)) # Try parameters, then "exif" @@ -175,7 +213,7 @@ def extract_prompts(clean_segments: list) -> Tuple[dict, str]: @debug_monitor -def validate_dictionary_structure(possibly_invalid: str) -> str: +def validate_mapping_bracket_pair_structure_of(possibly_invalid: str) -> str: """ Take a string and prepare it for a conversion to a dict map\n :param possibly_invalid: The string to prepare @@ -212,10 +250,11 @@ def repair_flat_dict(traces_of_pairs: List[str]) -> dict: :type traces_of_kv_pairs: `list :return: A corrected dictionary structure from the kv pairs """ + delineated_str = ListOfDelineatedStr(convert=traces_of_pairs) key, _ = next(iter(traces_of_pairs)) reformed_sub_dict = getattr(delineated_str, next(iter(delineated_str.model_fields))) - validated_string = validate_dictionary_structure(reformed_sub_dict) + validated_string = validate_mapping_bracket_pair_structure_of(reformed_sub_dict) repaired_sub_dict[key] = validated_string return repaired_sub_dict @@ -247,6 +286,7 @@ def make_paired_str_dict(text_to_convert: str) -> dict: :param dehashed_data: Metadata tags with quote and bracket delineated features removed :return: A valid dictionary structure with identical information """ + segmented = text_to_convert.split(", ") delineated = [item.split(": ") for item in segmented if isinstance(item, str) and ": " in item] try: @@ -261,10 +301,11 @@ def make_paired_str_dict(text_to_convert: str) -> dict: def arrange_webui_metadata(header_data: str) -> dict: """ Using the header from a file, send to multiple formatting, cleaning, and parsing, processes \n - Return format : {"Prompts": , "Settings": , "System": } \n + Return format : {"Prompts": , "Settings": , "System": }\n :param header_data: Header data from a file :return: Metadata in a standardized format """ + cleaned_text = delineate_by_esc_codes(header_data) prompt_map, deprompted_text = extract_prompts(cleaned_text) system_map, generation_text = extract_dict_by_delineation(deprompted_text) @@ -276,6 +317,20 @@ def arrange_webui_metadata(header_data: str) -> dict: } +# /______________________________________________________________________________________________________________________ EXIF Tags + + +@debug_monitor +def arrange_exif_metadata(header_data: dict) -> dict: + """Arrange EXIF data into correct format""" + metadata = { + UpField.TAGS: {UpField.EXIF: {key: value} for key, value in header_data.items() if (key != "icc_profile" or key != "exif")}, + DownField.ICC: {DownField.DATA: header_data.get("icc_profile")}, + DownField.EXIF: {DownField.DATA: header_data.get("exif")}, + } + return metadata + + # /______________________________________________________________________________________________________________________ Module Interface @@ -283,37 +338,33 @@ def arrange_webui_metadata(header_data: str) -> dict: def coordinate_metadata_ops(header_data: dict | str, metadata: dict = None) -> dict: """ Process data based on identifying contents\n - :param header_data: metadata extracted from file + :param header_data: Metadata extracted from file :type header_data: dict - :param metadata: the filtered output extracted from header data + :param datatype: The kind of variable storing the metadata + :type datatype: str + :param metadata: The filtered output extracted from header data :type metadata: dict :return: A dict of the metadata inside header data """ - is_dict = isinstance(header_data, dict) - has_prompt = is_dict and header_data.get("prompt") - has_params = is_dict and header_data.get("parameters") - has_tags = is_dict and ("icc_profile" in header_data or "exif" in header_data) - is_str = isinstance(header_data, str) + + has_prompt = isinstance(header_data, dict) and header_data.get("prompt") + has_params = isinstance(header_data, dict) and header_data.get("parameters") + has_tags = isinstance(header_data, dict) and ("icc_profile" in header_data or "exif" in header_data) if has_prompt: metadata = arrange_nodeui_metadata(header_data) elif has_params: metadata = arrange_webui_metadata(header_data) elif has_tags: - metadata = { - UpField.TAGS: {UpField.EXIF: {key: value} for key, value in header_data.items() if (key != "icc_profile" or key != "exif")}, - DownField.ICC: {DownField.DATA: header_data.get("icc_profile")}, - DownField.EXIF: {DownField.DATA: header_data.get("exif")}, - } - elif is_dict: + metadata = arrange_exif_metadata(header_data) + elif isinstance(header_data, dict): try: metadata = {UpField.JSON_DATA: json.loads(f"{header_data}")} except JSONDecodeError as error_log: nfo("JSON Decode failed %s", error_log) - metadata = {UpField.DATA: header_data} - elif is_str: + if not metadata and isinstance(header_data, str): metadata = {UpField.DATA: header_data} - else: + elif not metadata: metadata = {UpField.PLACEHOLDER: {"": UpField.PLACEHOLDER}} return metadata @@ -332,7 +383,6 @@ def parse_metadata(file_path_named: str) -> dict: header_data = reader.read_header(file_path_named) if header_data is not None: metadata = coordinate_metadata_ops(header_data) - debug_message(metadata) if metadata == {UpField.PLACEHOLDER: {"": UpField.PLACEHOLDER}} or not isinstance(metadata, dict): nfo("Unexpected format", file_path_named) return metadata diff --git a/tests/test_md_ps.py b/tests/test_md_ps.py index b3d3474..ac1a83a 100644 --- a/tests/test_md_ps.py +++ b/tests/test_md_ps.py @@ -6,13 +6,15 @@ # pylint: disable=line-too-long, missing-class-docstring import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, call from pydantic import ValidationError +from types import NoneType + from dataset_tools.logger import logger from dataset_tools.access_disk import MetadataFileReader -from dataset_tools.correct_types import UpField, DownField +from dataset_tools.correct_types import IsThisNode, UpField, DownField, NodeNames from dataset_tools.metadata_parser import ( arrange_webui_metadata, delineate_by_esc_codes, @@ -21,9 +23,11 @@ extract_prompts, coordinate_metadata_ops, arrange_nodeui_metadata, - validate_dictionary_structure, - rename_next_keys_of, + validate_mapping_bracket_pair_structure_of, + rename_prompt_keys_of, parse_metadata, + redivide_nodeui_data_in, + validate_typical, # clean_with_json ) @@ -36,11 +40,88 @@ def setUp(self): "PonyXLV6_Scores: 4b8555f2fb80, GrungeOutfiPDXL_: b6af61969ec4, GlamorShots_PDXL: 4b8ee3d1bd12, PDXL_FLWRBOY: af38cbdc40f6, PonyXLV6_Scores: 4b8555f2fb80, GrungeOutfiPDXL_: b6af61969ec4, GlamorShots_PDXL: 4b8ee3d1bd12, PDXL_FLWRBOY: af38cbdc40f6", ] self.valid_metadata_sub_map = "{PonyXLV6_Scores: 4b8555f2fb80, GrungeOutfiPDXL_: b6af61969ec4, GlamorShots_PDXL: 4b8ee3d1bd12, PDXL_FLWRBOY: af38cbdc40f6, PonyXLV6_Scores: 4b8555f2fb80, GrungeOutfiPDXL_: b6af61969ec4, GlamorShots_PDXL: 4b8ee3d1bd12, PDXL_FLWRBOY: af38cbdc40f6}" + self.test_is_dict_data = {"prompt": {"2": {"class_type": "deblah", "inputs": {"even_more_blah": "meh"}}}} + self.test_delineate_mock_header = {"parameters": "1 2 3 4\u200b\u200b\u200b5\n6\n7\n8\xe2\x80\x8b\xe2\x80\x8b\xe2\x80\x8b9\n10\n11\n12\x00\x00\u200bbingpot\x00\n"} + self.test_extract_prompts_mock_extract_data = [ + "Lookie Lookie, all, the, terms,in, the prompt, wao", + "Negative prompt: no bad, only 5 fingers", + "theres some other data here whatever", + ] + self.extract_dict_mock_partial_map = 'A: long, test: string, With: {"Some": "useful", "although": "also"}, Some: useless, data: thrown, in: as, well: !!, Only: "The": "best", "Algorithm": "Will", "Successfully": "Match", All, Correctly! !' + self.mock_dict = { + "3": { + "inputs": { + "seed": 948476150837611, + "steps": 60, + "cfg": 12.0, + "sampler_name": "dpmpp_2m", + "scheduler": "karras", + "denoise": 1.0, + "model": ["14", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["5", 0], + }, + "class_type": "KSampler", + } + } + self.mock_bracket_dict_next_keys = { + "Positive prompt": {"clip_l": "Red gauze tape spun around an invisible hand", "t5xxl": "Red gauze tape spun around an invisible hand"}, + "Negative prompt": {" "}, + } + + self.mock_bracket_dict_gen_data = { + "inputs": { + "seed": 1944739425534, + "steps": 4, + "cfg": 1.0, + "sampler_name": "euler", + "scheduler": "simple", + "denoise": 1.0, + "model": ["136", 0], + "positive": ["110", 0], + "negative": ["110", 1], + "latent_image": ["88", 0], + }, + } + self.redivide_dict_test_data = { + "2": { + "inputs": { + "seed": 1944739425534, + "steps": 4, + "cfg": 1.0, + "sampler_name": "euler", + "scheduler": "simple", + "denoise": 1.0, + "model": ["136", 0], + "positive": ["110", 0], + "negative": ["110", 1], + "latent_image": ["88", 0], + }, + "class_type": "KSampler", + "_meta": {"title": "3 KSampler"}, + }, + "109": { + "inputs": { + "clip_l": "Red gauze tape spun around an invisible hand", + "t5xxl": "Red gauze tape spun around an invisible hand", + "guidance": 1.0, + "clip": ["136", 1], + }, + "class_type": "CLIPTextEncodeFlux", + "_meta": {"title": "CLIPTextEncodeFlux"}, + }, + "113": { + "inputs": {"text": "", "clip": ["136", 1]}, + "class_type": "CLIPTextEncode", + "_meta": {"title": "2b Negative [CLIP Text Encode (Prompt)]"}, + }, + } + self.new_prompt_dict_labels = ["Positive prompt", "Negative prompt", "Prompt"] def test_delineate_by_esc_codes(self): """test""" - mock_header_data = {"parameters": "1 2 3 4\u200b\u200b\u200b5\n6\n7\n8\xe2\x80\x8b\xe2\x80\x8b\xe2\x80\x8b9\n10\n11\n12\x00\x00\u200bbingpot\x00\n"} - formatted_chunk = delineate_by_esc_codes(mock_header_data) + formatted_chunk = delineate_by_esc_codes(self.test_delineate_mock_header) logger.debug("%s", f"{list(x for x in formatted_chunk)}") assert formatted_chunk == [ "1 2 3 4", @@ -57,12 +138,7 @@ def test_delineate_by_esc_codes(self): def test_extract_prompts(self): """test""" - mock_extract_data = [ - "Lookie Lookie, all, the, terms,in, the prompt, wao", - "Negative prompt: no bad, only 5 fingers", - "theres some other data here whatever", - ] - prompt, deprompted_segments = extract_prompts(mock_extract_data) + prompt, deprompted_segments = extract_prompts(self.test_extract_prompts_mock_extract_data) assert deprompted_segments == "theres some other data here whatever" assert prompt == { "Negative prompt": "no bad, only 5 fingers", @@ -71,8 +147,7 @@ def test_extract_prompts(self): def test_extract_dict_by_delineation(self): """test""" - mock_partial_map = 'A: long, test: string, With: {"Some": "useful", "although": "also"}, Some: useless, data: thrown, in: as, well: !!, Only: "The": "best", "Algorithm": "Will", "Successfully": "Match", All, Correctly! !' - hashes, dehashed_text = extract_dict_by_delineation(mock_partial_map) + hashes, dehashed_text = extract_dict_by_delineation(self.extract_dict_mock_partial_map) logger.debug(hashes) logger.debug(dehashed_text) assert hashes == { @@ -126,99 +201,112 @@ def test_arrange_webui_metadata(self): assert mock_delineate_by_esc_codes.call_count == 1 assert UpField.PROMPT in result and DownField.GENERATION_DATA in result and DownField.SYSTEM in result - def test_rename_next_keys_of_not_prompt(self): + def test_rename_prompt_keys_of_not_prompt(self): """test""" - mock_dict = { - "3": { - "inputs": { - "seed": 948476150837611, - "steps": 60, - "cfg": 12.0, - "sampler_name": "dpmpp_2m", - "scheduler": "karras", - "denoise": 1.0, - "model": ["14", 0], - "positive": ["6", 0], - "negative": ["7", 0], - "latent_image": ["5", 0], - }, - "class_type": "KSampler", - } + extracted = rename_prompt_keys_of(self.mock_dict) + assert extracted == { + "seed": 948476150837611, + "steps": 60, + "cfg": 12.0, + "sampler_name": "dpmpp_2m", + "scheduler": "karras", + "denoise": 1.0, } - extracted = rename_next_keys_of( - mock_dict, - "CLIPTextEncode", - ["Positive prompt", "Negative prompt", "Prompt"], - ) - expected_extracted = {"KSampler": "seed: 948476150837611\nsteps: 60\ncfg: 12.0\nsampler_name: dpmpp_2m\nscheduler: karras\ndenoise: 1.0"} - assert extracted == expected_extracted - def test_rename_next_keys_of_prompt(self): + def test_rename_prompt_keys_of_prompt(self): """test""" - mock_dict = { - "3": { - "inputs": {"text": "a long and thought out text prompt"}, - "class_type": "CLIPTextEncode", - } - } - extracted = rename_next_keys_of( - mock_dict, - "CLIPTextEncode", - ["Positive prompt", "Negative prompt", "Prompt"], + extracted = rename_prompt_keys_of( + self.redivide_dict_test_data, ) - expected_extracted = {"Positive prompt": "a long and thought out text prompt"} + expected_extracted = { + "clip_l": "Red gauze tape spun around an invisible hand", + "t5xxl": "Red gauze tape spun around an invisible hand", + "guidance": 1.0, + } | { + "seed": 1944739425534, + "steps": 4, + "cfg": 1.0, + "sampler_name": "euler", + "scheduler": "simple", + "denoise": 1.0, + } assert extracted == expected_extracted - def test_validate_dictionary_structure(self): + def test_validate_mapping_bracket_pair_structure_of(self): """test""" - possibly_valid = validate_dictionary_structure(self.actual_metadata_sub_map) + possibly_valid = validate_mapping_bracket_pair_structure_of(self.actual_metadata_sub_map) assert isinstance(possibly_valid, str) assert possibly_valid == self.valid_metadata_sub_map - def test_arrange_nodeui_metadata(self): - """test""" - self.assertEqual(arrange_nodeui_metadata({}), {"": "No Data"}) + @patch("dataset_tools.metadata_parser.redivide_nodeui_data_in") + def test_arrange_nodeui_metadata_prompt(self, mock_redivide): + mock_input = {"prompt": "prompt"} + mock_redivide.return_value = "prompt", "gen" + result = arrange_nodeui_metadata(mock_input) + mock_redivide.assert_called_with(mock_input, "prompt") + expected_result = {"Generation Data": "gen", "Prompt Data": "prompt"} + assert result == expected_result + + def test_validate_typical(self): + result = validate_typical(self.redivide_dict_test_data, "2") + expected_output = self.redivide_dict_test_data["2"] + assert result == expected_output - next_keys = {"Positive prompt": {"clip_l": "Red gauze tape spun around an invisible hand", "t5xxl": "Red gauze tape spun around an invisible hand"}, "Negative prompt": {" "}} + def test_validate_typical_fail(self): + subdict = {"mock": "data"} + mock_data = {"prompt": subdict} + with self.assertRaises(KeyError): + out = validate_typical(mock_data, "prompt") + assert out is None + + @patch("dataset_tools.metadata_parser.redivide_nodeui_data_in") + def test_arrange_nodeui_metadata_workflow(self, mock_redivide): + full_metadata_key_test = {"prompt": "data", "workflow": "two"} + mock_redivide.side_effect = [({}, {"data": "tokeep"}), ({"Positive": "three"}, {"gen_data": "one"})] + result = arrange_nodeui_metadata(full_metadata_key_test) + mock_redivide.assert_has_calls([call({"prompt": "data", "workflow": "two"}, "prompt"), call({"prompt": "data", "workflow": "two"}, "workflow")], any_order=False) + # mock_redivide.assert_called_with({"1": "dict_data"}, self.new_prompt_dict_labels) + assert result == {"Generation Data": {"data": "tokeep", "gen_data": "one"}, "Prompt Data": {"Positive": "three"}} @patch("dataset_tools.metadata_parser.clean_with_json") - @patch("dataset_tools.metadata_parser.rename_next_keys_of") - def test_arrange_nodeui_metadata_calls(self, mock_rename, mock_clean): + @patch("dataset_tools.metadata_parser.rename_prompt_keys_of") + def test_redivide_nodeui_data_in(self, mock_rename, mock_clean): """test""" - test_data = { - "2": { - "inputs": {"seed": 1944739425534, "steps": 4, "cfg": 1.0, "sampler_name": "euler", "scheduler": "simple", "denoise": 1.0, "model": ["136", 0], "positive": ["110", 0], "negative": ["110", 1], "latent_image": ["88", 0]}, - "class_type": "KSampler", - "_meta": {"title": "3 KSampler"}, - }, - "109": { - "inputs": {"clip_l": "Red gauze tape spun around an invisible hand", "t5xxl": "Red gauze tape spun around an invisible hand", "guidance": 1.0, "clip": ["136", 1]}, - "class_type": "CLIPTextEncodeFlux", - "_meta": {"title": "CLIPTextEncodeFlux"}, - }, - "113": {"inputs": {"text": "", "clip": ["136", 1]}, "class_type": "CLIPTextEncode", "_meta": {"title": "2b Negative [CLIP Text Encode (Prompt)]"}}, - } - gen_data = { - "inputs": {"seed": 1944739425534, "steps": 4, "cfg": 1.0, "sampler_name": "euler", "scheduler": "simple", "denoise": 1.0, "model": ["136", 0], "positive": ["110", 0], "negative": ["110", 1], "latent_image": ["88", 0]}, + + mock_clean.return_value = self.redivide_dict_test_data + mock_rename.return_value = { + "Positive prompt": {"clip_l": "Red gauze tape spun around an invisible hand", "t5xxl": "Red gauze tape spun around an invisible hand"}, + "Negative prompt": {" "}, + "inputs": self.mock_bracket_dict_gen_data["inputs"], } - mock_clean.return_value = test_data - mock_rename.return_value = {"Positive prompt": {"clip_l": "Red gauze tape spun around an invisible hand", "t5xxl": "Red gauze tape spun around an invisible hand"}, "Negative prompt": {" "}, "inputs": gen_data["inputs"]} - result = arrange_nodeui_metadata({"prompt": f"{test_data}"}) - expected_result = {"Generation Data": gen_data, "Prompt Data": self.next_keys} - mock_clean.assert_called_with(str(test_data)) + mock_data_with_prompt = {"prompt": f"{self.redivide_dict_test_data}"} + result = redivide_nodeui_data_in(mock_data_with_prompt, "prompt") + mock_clean.assert_called_with(mock_data_with_prompt, "prompt") + expected_result = ( + self.mock_bracket_dict_next_keys, + self.mock_bracket_dict_gen_data, + ) self.assertEqual(result, expected_result) @patch("dataset_tools.metadata_parser.clean_with_json") - def test_is_dict(self, mock_clean): + @patch("dataset_tools.metadata_parser.rename_prompt_keys_of") + def test_redivide_nodeui_data_empty_prompt(self, mock_rename, mock_clean): """test""" - data = {"prompt": {"2": {"class_type": "deblah", "inputs": {"even_more_blah": "meh"}}}} - mock_clean.return_value = {"2": {"class_type": "deblah", "inputs": {"even_more_blah": "meh"}}} - assert arrange_nodeui_metadata(data) + + mock_clean.return_value = self.mock_dict + mock_rename.return_value = self.mock_dict + result = redivide_nodeui_data_in(f"{self.mock_dict}", "prompt") + print(result) + mock_clean.assert_called_with(str(self.mock_dict), "prompt") + expected_result = ( + {}, + self.mock_dict, + ) + self.assertEqual(result, expected_result) @patch("dataset_tools.metadata_parser.clean_with_json") def test_fail_dict(self, mock_clean): """test""" - # data = {"prompt": {"2": {"class_type": "deblah", "inputs": {"even_more_blah":"meh"} } } } mock_clean.return_value = {"2": {"glass_type": "deblah", "inputs": {"even_more_blah": "meh"}}} assert ValidationError @@ -227,7 +315,7 @@ def test_coordinate_webui(self, mock_webui): """test""" data = {"parameters": {"random": "data"}} mock_webui.return_value = data - result = coordinate_metadata_ops(data) + result = coordinate_metadata_ops(data, dict) mock_webui.assert_called_with(data) assert result == data @@ -236,7 +324,7 @@ def test_coordinate_nodeui(self, mock_node): """test""" data = {"prompt": {"random": "data"}} mock_node.return_value = data - result = coordinate_metadata_ops(data) + result = coordinate_metadata_ops(data, dict) mock_node.assert_called_with(data) assert result == data