X/dev (#16)

Add CLI command, logging, adjust parser, add tests, make package, folder
changes, try/except catches on invalid metadata reads, pyproject.toml

duskfallcrew authored Jan 4, 2025
2 parents 94da0f3 + df2704c commit 6364d8d
Showing 15 changed files with 350 additions and 243 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1 +1,4 @@
.DS_Store
dataset_tools.egg-info
dataset_tools.dataset_tools.egg-info
build
7 changes: 7 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,7 @@

Dev maintenance notes:

- Update version: `python -m _version`
- Run with debug-level logging: `python -m dataset_tools.main --log debug`
- Run with warning-level logging: `python -m dataset_tools.main --log warn`
- Run with info-level logging: `python -m dataset_tools.main --log info`
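To check which version setuptools-scm will derive, a minimal sketch (assuming setuptools-scm is configured in pyproject.toml, as the commit message suggests):

```sh
# Print the version derived from git tags and commit distance
python -m setuptools_scm

# Reinstalling in editable mode regenerates _version.py
pip install -e .
```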
4 changes: 2 additions & 2 deletions README.md
@@ -13,12 +13,12 @@ Dataset-Tools is a desktop application designed to help users browse and manage
3. In the Dataset-Tools folder, install the required dependencies:

```sh
pip install .
```
4. Run the application using Python:

```sh
dataset-tools
```

### User Interface Overview
Empty file removed __init__.py
Empty file.
16 changes: 16 additions & 0 deletions _version.py
@@ -0,0 +1,16 @@
# file generated by setuptools_scm
# don't change, don't track in version control
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '0.6.dev2+ga1b3e8a.d20250103'
__version_tuple__ = version_tuple = (0, 6, 'dev2', 'ga1b3e8a.d20250103')
21 changes: 8 additions & 13 deletions main.py → dataset_tools/__init__.py
@@ -1,11 +1,13 @@
from importlib.metadata import version, PackageNotFoundError  # setuptools-scm versioning

try:
    __version__ = version("dataset-tools")
except PackageNotFoundError:
    # package is not installed
    pass

import logging
from logging import Logger
import rich
import sys
log_level = "INFO"
@@ -16,7 +18,7 @@
handler = RichHandler(console=Console(stderr=True))

if handler is None:
    handler = logging.StreamHandler(sys.stdout)
handler.propagate = False

formatter = logging.Formatter(
@@ -34,10 +36,3 @@
log_level = getattr(logging, log_level)
logger = logging.getLogger(__name__)


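The hunks above collapse part of the logger construction, so here is a self-contained sketch of the pattern they appear to implement: a Rich handler writing to stderr, with a plain stdlib fallback when Rich is unavailable (the logger name and format string are illustrative, not the package's exact values):

```python
import logging
import sys

try:
    from rich.console import Console
    from rich.logging import RichHandler
    handler = RichHandler(console=Console(stderr=True))
except ImportError:
    # Fall back to plain stdlib logging on stdout
    handler = logging.StreamHandler(sys.stdout)

handler.setFormatter(logging.Formatter("%(levelname)s | %(name)s | %(message)s"))

logger = logging.getLogger("dataset_tools")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False  # avoid duplicate output via the root logger
```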
24 changes: 24 additions & 0 deletions dataset_tools/main.py
@@ -0,0 +1,24 @@
import sys
import logging
import argparse

from dataset_tools import logger
from dataset_tools.ui import MainWindow  # Import our main window class

def main():
    parser = argparse.ArgumentParser(description="Set the logging level via command line")

    parser.add_argument('--log', default='WARNING',
                        help='Set the logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)')

    args = parser.parse_args()

    # Fall back to WARNING on unrecognized level names, then apply the level
    log_level = getattr(logging, args.log.upper(), logging.WARNING)
    logger.setLevel(log_level)

    from PyQt6.QtWidgets import QApplication
    app = QApplication(sys.argv)
    window = MainWindow()  # Initialize our main window.
    window.show()
    sys.exit(app.exec())

if __name__ == "__main__":
    main()
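Once installed with `pip install .`, the entry point can be invoked directly; a usage sketch (assuming `dataset-tools` is registered as a console script in pyproject.toml, as the README's run step implies):

```sh
# Launch with verbose logging
dataset-tools --log debug

# Equivalent module invocation without the entry point
python -m dataset_tools.main --log debug
```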
143 changes: 143 additions & 0 deletions dataset_tools/metadata_parser.py
@@ -0,0 +1,143 @@

from png import Reader as pngReader
import zlib
import re
import json

from dataset_tools import logger

def open_jpg_header(file_path_named: str) -> dict:
    """
    Open jpg format files\n
    :param file_path_named: `str` The path and file name of the jpg file
    :return: `dict` Exif data with tag IDs mapped to readable names
    """
    from PIL import Image
    from PIL.ExifTags import TAGS

    pil_img = Image.open(file_path_named)
    exif_info = pil_img._getexif()
    exif = {TAGS.get(k, k): v for k, v in exif_info.items()}
    return exif

def open_png_header(file_path_named: str):
    """
    Open png format files\n
    :param file_path_named: `str` The path and file name of the png file
    :return: `Generator` Yields (chunk name, chunk data) pairs, or None on a failed read
    """
    try:
        with open(file_path_named, "rb") as f:
            png_data = f.read()
        reader = pngReader(bytes=png_data)
        header_chunks = reader.chunks()
    except Exception as error_log:
        logger.info("Error reading png file: %s %s", file_path_named, error_log)
        logger.debug("%s", file_path_named, exc_info=True)
    else:
        return header_chunks


def extract_metadata_chunks(header_chunks,
                            text_prefix: tuple = (b"tEXt", b"iTXt"),
                            search_key: tuple = (b"parameters", b"prompt")
                            ) -> bytes:
    """
    Scan metadata chunks, then extract relevant data\n
    :param header_chunks: `Generator` (chunk name, chunk data) pairs from the file header
    :param text_prefix: `tuple` Chunk types that carry text data
    :param search_key: `tuple` Keywords that precede the text we are looking for
    :return: `bytes` Byte string from the header of relevant data
    """
    for chunk_name, chunk_data in header_chunks:
        if chunk_name in text_prefix:
            parts = chunk_data.split(b"\x00", 3)
            key, *_, text_chunk = parts
            # Simplified iTXt handling: after the split, parts[1] carries the
            # compression flag (b"\x01" means the payload is zlib-compressed)
            if chunk_name == b"iTXt" and parts[1] == b"\x01":
                try:
                    text_chunk = zlib.decompress(text_chunk)
                except Exception as error_log:
                    logger.info("Error decompressing: %s", error_log)
                    logger.debug("Error decompressing", exc_info=True)
                    continue

            if key in search_key:
                return text_chunk

def clean_string_with_json(formatted_text: str) -> dict:
    """
    Convert data into a clean dictionary\n
    :param formatted_text: `str` Unstructured utf-8 formatted string
    :return: `dict` Dictionary formatted data
    """
    if not formatted_text.startswith("{"):
        formatted_text = restructure_metadata(formatted_text)
    formatted_text = str(formatted_text).replace("\'", "\"").replace('\n', '').strip()
    try:
        json_structured_text = json.loads(formatted_text)
    except Exception as error_log:
        logger.info("Error parsing json directly: %s", error_log)
    else:
        return json_structured_text

def format_chunk(text_chunk: bytes) -> dict:
    """
    Turn raw bytes into utf8 formatted text\n
    :param text_chunk: `bytes` Data from a file header
    :return: `dict` text data in a dict structure
    """
    try:
        formatted_text = text_chunk.decode('utf-8', errors='ignore')
    except Exception as error_log:
        logger.info("Error decoding: %s", error_log)
        logger.debug("Error decoding", exc_info=True)
    else:
        json_structured_text = clean_string_with_json(formatted_text)
        logger.debug("Decoded Text: %s", json_structured_text)
        return json_structured_text

def restructure_metadata(formatted_text: str) -> dict:
    """
    Reconstruct metadata header format into a dict\n
    :param formatted_text: `str` Unstructured utf-8 formatted text
    :return: `dict` The text formatted into a valid dictionary structure
    """
    start_idx = formatted_text.find("POS\"") + 1
    end_idx = formatted_text.find("\"", start_idx)
    positive_string = formatted_text[start_idx:end_idx].strip()

    start_idx = formatted_text.find("Neg") + 1
    end_idx = formatted_text.find("\"", start_idx)
    negative_string = formatted_text[start_idx:end_idx].strip()

    start_idx = formatted_text.find("Hashes") + len("Hashes:")
    end_idx = formatted_text.find("\"", start_idx)
    hash_string = formatted_text[start_idx:end_idx].strip()

    # Accumulate key/numeric-value pairs (e.g. "Steps: 20, CFG scale: 7.5")
    # across every line, rather than keeping only the last line's matches
    mapped_metadata = {}
    for line in formatted_text.strip().split('\n'):
        for key, value in re.findall(r'(\w+):\s*(\d+(?:\.\d+)?)', line):
            mapped_metadata.setdefault(key.strip(), value.strip())

    pre_cleaned_text = mapped_metadata | {"Hashes": hash_string,
                                          "Positive prompt": positive_string,
                                          "Negative prompt": negative_string}
    return pre_cleaned_text


def parse_metadata(file_path_named: str) -> dict:
    """
    Extract the metadata from the header of an image file\n
    :param file_path_named: `str` The file to open
    :return: `dict` The metadata from the header of the file
    """
    header_chunks = open_png_header(file_path_named)
    if header_chunks is not None:
        text_chunk = extract_metadata_chunks(header_chunks)
        if text_chunk is not None:
            json_structured_text = format_chunk(text_chunk)
            return json_structured_text
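A quick usage sketch of the parser (the sample path is hypothetical; the returned keys depend on what the generating tool wrote into the PNG header):

```python
from dataset_tools.metadata_parser import parse_metadata

# Hypothetical image whose tEXt/iTXt chunk holds generation parameters
metadata = parse_metadata("samples/example_image.png")
if metadata is not None:
    print(metadata.get("Positive prompt"))
else:
    print("No readable metadata found")
```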
35 changes: 35 additions & 0 deletions dataset_tools/test_md_ps.py
@@ -0,0 +1,35 @@

import unittest
from unittest.mock import patch, mock_open
import zlib

from dataset_tools.metadata_parser import open_png_header, extract_metadata_chunks

class TestParseMetadata(unittest.TestCase):

    @patch('dataset_tools.metadata_parser.pngReader')
    def test_parse_metadata_success(self, mock_reader):
        mock_reader.return_value.chunks.return_value = [(b'tEXt', b'data')]
        mock_file = mock_open(read_data=b'\x89PNG\r\n\x1a\nIHDR...')
        with patch('builtins.open', mock_file, create=True):
            chunks = open_png_header("mock_path")
        self.assertIsNotNone(chunks)
        self.assertTrue(list(chunks))  # Confirm it's not empty

    @patch('builtins.open', side_effect=IOError)
    def test_parse_metadata_failure(self, mock_file):
        self.assertIsNone(open_png_header("nonexistent.png"))

    def mock_header_chunks(self, data):
        # Helper returning mock (chunk name, chunk data) pairs
        return ((b"tEXt", b"parameters\x00example"), (b"iTXt", b"prompt\x00\x00\x00CompVal"))

    def test_extract_metadata(self):
        compressed_val = zlib.compress(b'Test Value')
        mock_data = [(b"tEXt", b"parameters\x00metadata"),
                     (b"iTXt", b"prompt\x00\x00\x00\0" + compressed_val)]
        # The tEXt "parameters" chunk matches first, so its payload is returned
        assert extract_metadata_chunks(iter(mock_data)) == b"metadata"

if __name__ == '__main__':
    unittest.main()
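The suite runs with the standard unittest runner (no extra test tooling assumed):

```sh
python -m unittest dataset_tools.test_md_ps -v
```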