Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
SkwalExe committed Mar 29, 2024
1 parent 26b1d2c commit 00e60cd
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 45 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[flake8]
max-line-length = 120
32 changes: 19 additions & 13 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import os
from time import time
from utils import *

from utils import BASE_DIR, style_names, styles, logger
from wizard import TextQuestion, SelectQuestion, Wizard, inq_ask
from textual.app import App
from textual.widgets import Header, Footer, Label, LoadingIndicator
from textual.validation import Length
from textual import log
from PIL import Image
from textual.events import Key
from click_extra import extra_command, option, ExtraContext, Parameter
# from textual import log

VERSION = "1.1.1"

from wizard import *

BASIC_INFO_QUESTIONS = [
TextQuestion("name", "Your project's name", [Length(1, failure_description="Your project's name cannot be blank")], "super-octo-project" ),
TextQuestion(
"name",
"Your project's name",
[Length(1, failure_description="Your project's name cannot be blank")],
"super-octo-project"),
SelectQuestion("style", "Logo Style", style_names, "first_letter_underlined")
]


class OctoLogoApp(App):
BINDINGS = [
("ctrl+q", "quit", "Quit"),
Expand All @@ -38,7 +42,6 @@ async def on_key(self, event: Key):
elif event.key == "v" and self.finished:
self.result.show()


def on_wizard_finished(self, message: Wizard.Finished):
# Get the wizard answers and the wizard's id
self.answers.update(message.answers)
Expand All @@ -53,7 +56,7 @@ def on_wizard_finished(self, message: Wizard.Finished):
style_wizard.questions = styles[self.answers['style']].module.questions
style_wizard.title = "Style Settings"
self.mount(style_wizard)
# When the style-specific wizard is finished, create the image and save it
# When the style-specific wizard is finished, create the image and save it
elif finished_wizard_id == "style_wizard":
style = styles[self.answers['style']].module
self.result = style.get_image(self.answers)
Expand All @@ -64,10 +67,12 @@ def on_wizard_finished(self, message: Wizard.Finished):
# Final message
def final_message(self):
self.loading_wid.add_class("hidden")
self.mount(Label(f"Logo saved to [bold]{self.save_to}[/bold].\n[blue blink]-> Press v to view the result[/blue blink]\n[red]Press enter to quit[/red]"))
self.mount(Label(
f"Logo saved to [bold]{self.save_to}[/bold].\n"
f"[blue blink]-> Press v to view the result[/blue blink]\n"
f"[red]Press enter to quit[/red]"))
self.result.save(self.save_to)
self.finished = True


def compose(self):
self.app.title = f"Octo Logo v{VERSION}"
Expand All @@ -84,10 +89,11 @@ def compose(self):

def disable_ansi(ctx: ExtraContext, param: Parameter, val: bool):
ctx.color = not val

# We must return the value for the main function no_ansi parameter not to be None
return val


@extra_command(params=[])
@option("-t", "--no-tui", is_flag=True, help="Dont use the Textual Terminal User Interface")
def main(no_tui: bool):
Expand All @@ -108,10 +114,10 @@ def main(no_tui: bool):
style = styles[answers['style']].module
result = style.get_image(answers)
save_to = f'output/{answers["name"]}_{int(time())}.png'

result.save(save_to)
logger.success(f"Image saved to : {save_to}")


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion src/styles/all_underlined.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
active = True
questions = underline_core.questions


def get_image(answers):
return underline_core.get_image(answers, "all")
return underline_core.get_image(answers, "all")
3 changes: 2 additions & 1 deletion src/styles/first_letter_underlined.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
active = True
questions = underline_core.questions


def get_image(answers):
return underline_core.get_image(answers, "first_letter")
return underline_core.get_image(answers, "first_letter")
16 changes: 12 additions & 4 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from sys import stdout

logger.remove()
logger.add(stdout, format="[ <green>{time:HH:mm:ss}</green> ]"
logger.add(
stdout,
format="[ <green>{time:HH:mm:ss}</green> ]"
" - <level>{level}</level> -> "
"<level>{message}</level>")

Expand All @@ -32,6 +34,7 @@ def remove_ext(filename):
"""
return filename.split(".")[0]


class Style():
display_name: str
module: Any
Expand All @@ -40,18 +43,20 @@ def __init__(self, display_name: str, module: Any) -> None:
self.display_name = display_name
self.module = module



BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
FONTS_DIR = os.path.join(BASE_DIR, "fonts")
COLORS_DIR = os.path.join(BASE_DIR, "colors")


# Get all the fonts in the fonts directory
def get_font_list() -> list[str]:
return os.listdir(FONTS_DIR)


font_list = get_font_list()


# Get all the color schemes in the colors directory
# keep only files with the .toml extension
# and remove the extension
Expand All @@ -68,8 +73,11 @@ def get_color_schemes() -> dict[str, dict[str, str]]:

return colors


color_schemes = get_color_schemes()
color_scheme_names: dict[str, str] = [(color_schemes[color_scheme]['name'], color_scheme) for color_scheme in color_schemes]
color_scheme_names: dict[str, str] = [
(color_schemes[color_scheme]['name'], color_scheme) for color_scheme in color_schemes]


def get_styles() -> dict[str, Style]:
"""
Expand All @@ -96,4 +104,4 @@ def get_styles() -> dict[str, Style]:

styles = get_styles()

style_names: dict[str, str] = [(styles[style].display_name, style) for style in styles]
style_names: dict[str, str] = [(styles[style].display_name, style) for style in styles]
68 changes: 42 additions & 26 deletions src/wizard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from utils import logger
from textual import log
from typing import Any
from textual.validation import Validator
from textual.widgets import Input, Select, Static, Button, Label
Expand All @@ -8,6 +7,8 @@
from textual.message import Message
from textual import on
import inquirer as inq
# from textual import log


class QuestionBase():
name: str
Expand All @@ -18,6 +19,7 @@ def __init__(self, name: str, label: str) -> None:
self.type = type
self.label = label


class TextQuestion(QuestionBase):
validators: list[Validator] | None
placeholder: str
Expand All @@ -27,8 +29,8 @@ class TextQuestion(QuestionBase):
def __init__(
self,
name: str,
label: str,
validators: list[Validator]| None = None,
label: str,
validators: list[Validator] | None = None,
placeholder: str = "",
default_value: str = ""
) -> None:
Expand All @@ -43,14 +45,15 @@ def as_widget(self):
_input.value = self.default_value
return _input


class SelectQuestion(QuestionBase):
options: list
type: str = "select"
default_value: Any | None = None

def __init__(
self,
name: str,
self,
name: str,
label: str,
options: list,
default_value: str | None = None
Expand All @@ -60,16 +63,24 @@ def __init__(
self.default_value = default_value

def as_widget(self):
_select = Select(classes="full-width", id=self.name, options=self.options, allow_blank=False, value=self.default_value)
_select = Select(
classes="full-width",
id=self.name,
options=self.options,
allow_blank=False,
value=self.default_value)

_select.border_title = self.label

return _select


class BackNextButtons(Static):
def compose(self):
yield Button("Back", variant="warning", id="back")
yield Button("Next", variant="success", id="next")


class Wizard(Static):
question_index = reactive(-1)
answers: dict = dict()
Expand All @@ -82,6 +93,7 @@ class Wizard(Static):
class Finished(Message):
answers = dict
wizard_id: str

def __init__(self, answers, wizard_id):
self.answers = answers
self.wizard_id = wizard_id
Expand All @@ -95,7 +107,7 @@ def on_back(self):

@on(Button.Pressed, "#next")
async def on_next(self):
# If the selected question is an input then fire the submit event so that
# If the selected question is an input then fire the submit event so that
# validation is made and the next question is shown.
# Else, just go to the next question since a select cannot be invalid
if isinstance(self.selected_question, Input):
Expand All @@ -112,43 +124,43 @@ def handle_validation_result(self, validation_result: ValidationResult):
self.input_message.add_class("hidden")

else:
# If the validation comports an error then disable the next button,
# If the validation comports an error then disable the next button,
# Show and set the content of the error message and set the input's color to red
self.query_one("#next").disabled = True
self.input_message.remove_class("hidden")
self.input_message.renderable = validation_result.failure_descriptions[0]
self.selected_question.add_class("invalid")


def on_input_changed(self, message: Input.Changed):
# When an input is changed, save its new value into the self.answers dict
self.answers[message.input.id] = message.value

# Show error messages if any
self.handle_validation_result(message.validation_result)

def on_input_submitted(self, message: Input.Submitted):
# Handle the validation result to show
# Handle the validation result to show
# a message if there are any errors
self.handle_validation_result(message.validation_result)

# When the input is submitted, if it is valid then go to the next question
if (message.validation_result.is_valid):
self.question_index += 1



def on_select_changed(self, message: Select.Changed):
# When a select is changed update the value in the self.answers dict
self.answers[message.select.id] = message.value

def compose(self):
# Render directly every input
# They are all hidden by default
# They are all hidden by default
for i, question in enumerate(self.questions):
wid = question.as_widget()

# For every select, the value in the answers dict will not be updated if the user just keeps the default value
# and click next without changing the value, the on_select_changed function will not be called and the
# For every select, the value in the answers dict
# will not be updated if the user just keeps the default value
# and click next without changing the value,
# the on_select_changed function will not be called and the
# answers dict will not contain the key corresponding to the select which can result in a KeyError
if isinstance(wid, Select):
self.answers[wid.id] = question.default_value
Expand All @@ -165,31 +177,29 @@ def compose(self):
self.input_message.styles.color = "tomato"
self.input_message.styles.max_width = "100%"
yield self.input_message

# ----------------------------
yield BackNextButtons()


def on_mount(self):
# Trigger the watch_question_index function to make the first input appear
self.question_index = 0

def watch_question_index(self):

# Remove the selected class from the previous shown input if any
if not self.selected_question is None:
if self.selected_question is not None:
self.selected_question.add_class("hidden")

# If the question index has been incremented but it is now out of bound then
# the user clicked next on the last question
if self.question_index == len(self.questions):
self.post_message(self.Finished(self.answers, self.id))
return

# Put the question index in the border title
self.border_title = f"{self.title} [{self.question_index + 1}/{len(self.questions)}]"


# Show the input corresponding to the new value of self.question_index
self.selected_question = self.query_one(f"#{self.questions[self.question_index].name}")
self.selected_question.remove_class("hidden")
Expand All @@ -199,21 +209,27 @@ def watch_question_index(self):
# the "Back" button since there arent any questions before
self.query_one("#back").disabled = self.question_index == 0


def validate_text_question(question: SelectQuestion | TextQuestion, value: str) -> bool:
for validator in question.validators:
validation = validator.validate(value)
if not validation.is_valid:
logger.error(validation.failure_descriptions[0])
return False

return True


def inq_ask(questions: list[SelectQuestion | TextQuestion]) -> dict[str, Any]:
answers = dict()

for question in questions:
if isinstance(question, SelectQuestion):
answers[question.name] = inq.list_input(question.label, choices=question.options, default=question.default_value)
answers[question.name] = inq.list_input(
question.label,
choices=question.options,
default=question.default_value)

elif isinstance(question, TextQuestion):
while True:
temp = inq.text(question.label, default=question.default_value)
Expand All @@ -223,4 +239,4 @@ def inq_ask(questions: list[SelectQuestion | TextQuestion]) -> dict[str, Any]:
answers[question.name] = temp
break

return answers
return answers

0 comments on commit 00e60cd

Please sign in to comment.