From 3eac82f05e8048abcf4db834270b9421af40582e Mon Sep 17 00:00:00 2001 From: Josh Date: Thu, 5 Sep 2024 12:10:50 -0400 Subject: [PATCH 1/8] experiment ui wrapper prototype --- src/social_norms_trees/ui_wrapper.py | 96 ++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 src/social_norms_trees/ui_wrapper.py diff --git a/src/social_norms_trees/ui_wrapper.py b/src/social_norms_trees/ui_wrapper.py new file mode 100644 index 00000000..a81beea3 --- /dev/null +++ b/src/social_norms_trees/ui_wrapper.py @@ -0,0 +1,96 @@ +import time +import click +from datetime import datetime +import json + +# Global database +db = {} + +def experiment_setup(): + + print("\n") + participant_id = participant_login() + + print("\n") + origin_tree = load_behavior_tree() + + experiment_id = initialize_experiment_record(participant_id, origin_tree) + + print("\nSetup Complete.\n") + + return participant_id, origin_tree, experiment_id + +def participant_login(): + global db + + participant_id = click.prompt("Please enter your participant id", type=str) + + if participant_id not in db: + db[participant_id] = {} + + return participant_id + + + +def get_behavior_trees(): + #TODO: Get behavior trees from respective data structure + + print("1. Original Tree") + + return ["Original_tree"] + +def load_behavior_tree(): + + tree_array = get_behavior_trees() + tree_index = click.prompt("Please select a behavior tree to load for the experiment (enter the number)", type=int) + return tree_array[tree_index - 1] + + + +def initialize_experiment_record(participant_id, origin_tree): + global db + + if "experiments" not in db[participant_id]: + db[participant_id]["experiments"] = {} + + experiment_id = len(db[participant_id]["experiments"]) + 1 + + experiment_record = { + "experiment_id": experiment_id, + "base_behavior_tree": origin_tree, + "start_date": datetime.now().isoformat(), + "actions": [], + } + + db[participant_id]["experiments"][experiment_id] = experiment_record + + return experiment_id + + +def run_experiment(participant_id, origin_tree, experiment_id): + global db + + #TODO: run actual experiment + print("Running experiment...\n") + db[participant_id]["experiments"][experiment_id]["final_behavior_tree"] = "updated tree" + db[participant_id]["experiments"][experiment_id]["actions"] = ["list", "of", "actions"] + time.sleep(3) + db[participant_id]["experiments"][experiment_id]["end_date"] = datetime.now().isoformat() + print("Experiment done!\n") + + + + + +def main(): + + print("AIT Prototype #1 Simulator") + + for _ in range(3): + participant_id, origin_tree, experiment_id = experiment_setup() + run_experiment(participant_id, origin_tree, experiment_id) + print(json.dumps(db, indent=4)) + + +if __name__ == "__main__": + main() \ No newline at end of file From d258d416da3b8f5ccbeeefa271b5d08220592055 Mon Sep 17 00:00:00 2001 From: Josh Date: Fri, 6 Sep 2024 11:07:38 -0400 Subject: [PATCH 2/8] remove spaces --- src/social_norms_trees/ui_wrapper.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/social_norms_trees/ui_wrapper.py b/src/social_norms_trees/ui_wrapper.py index a81beea3..9f6e1dcf 100644 --- a/src/social_norms_trees/ui_wrapper.py +++ b/src/social_norms_trees/ui_wrapper.py @@ -20,6 +20,7 @@ def experiment_setup(): return participant_id, origin_tree, experiment_id + def participant_login(): global db @@ -31,7 +32,6 @@ def participant_login(): return participant_id - def get_behavior_trees(): #TODO: Get behavior trees from respective data structure @@ -39,6 +39,7 @@ def get_behavior_trees(): return ["Original_tree"] + def load_behavior_tree(): tree_array = get_behavior_trees() @@ -46,7 +47,6 @@ def load_behavior_tree(): return tree_array[tree_index - 1] - def initialize_experiment_record(participant_id, origin_tree): global db @@ -79,9 +79,6 @@ def run_experiment(participant_id, origin_tree, experiment_id): print("Experiment done!\n") - - - def main(): print("AIT Prototype #1 Simulator") From ecdde1c29da01701a947c173cbef9449ae656291 Mon Sep 17 00:00:00 2001 From: Josh Date: Mon, 9 Sep 2024 10:29:18 -0400 Subject: [PATCH 3/8] put demo together --- src/social_norms_trees/mutate_tree.py | 450 +++++++++++++++++++++++ src/social_norms_trees/serialize_tree.py | 31 ++ src/social_norms_trees/ui_wrapper.py | 134 +++++-- 3 files changed, 584 insertions(+), 31 deletions(-) create mode 100644 src/social_norms_trees/mutate_tree.py create mode 100644 src/social_norms_trees/serialize_tree.py diff --git a/src/social_norms_trees/mutate_tree.py b/src/social_norms_trees/mutate_tree.py new file mode 100644 index 00000000..60c599c1 --- /dev/null +++ b/src/social_norms_trees/mutate_tree.py @@ -0,0 +1,450 @@ +"""Example of using worlds with just an integer for the state of the world""" + +import warnings +from functools import partial, wraps +from itertools import islice +from typing import TypeVar, Optional, List + +import click +import py_trees + + +T = TypeVar("T", bound=py_trees.behaviour.Behaviour) + + +def print_tree(tree: py_trees.behaviour.Behaviour): + tree_display = py_trees.display.unicode_tree(tree) + print(tree_display) + + +def iterate_nodes(tree: py_trees.behaviour.Behaviour): + """ + + Examples: + >>> list(iterate_nodes(py_trees.behaviours.Dummy())) # doctest: +ELLIPSIS + [] + + >>> list(iterate_nodes( + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]))) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + [, + ] + + >>> list(iterate_nodes( + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Dummy(), + ... py_trees.behaviours.Dummy(), + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]), + ... ]))) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + [, + , + , + , + ] + + """ + yield tree + for child in tree.children: + yield from iterate_nodes(child) + + +def enumerate_nodes(tree: py_trees.behaviour.Behaviour): + """ + + Examples: + >>> list(enumerate_nodes(py_trees.behaviours.Dummy())) # doctest: +ELLIPSIS + [(0, )] + + >>> list(enumerate_nodes( + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]))) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + [(0, ), + (1, )] + + >>> list(enumerate_nodes( + ... py_trees.composites.Sequence("s1", False, children=[ + ... py_trees.behaviours.Dummy(), + ... py_trees.behaviours.Success(), + ... py_trees.composites.Sequence("s2", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]), + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Failure(), + ... py_trees.behaviours.Periodic("p", n=1), + ... ]), + ... ]))) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + [(0, ), + (1, ), + (2, ), + (3, ), + (4, ), + (5, ), + (6, ), + (7, )] + + """ + return enumerate(iterate_nodes(tree)) + + +def label_tree_lines( + tree: py_trees.behaviour.Behaviour, + labels: List[str], + representation=py_trees.display.unicode_tree, +) -> str: + max_len = max([len(s) for s in labels]) + padded_labels = [s.rjust(max_len) for s in labels] + + tree_representation_lines = representation(tree).split("\n") + enumerated_tree_representation_lines = [ + f"{i}: {t}" for i, t in zip(padded_labels, tree_representation_lines) + ] + + output = "\n".join(enumerated_tree_representation_lines) + return output + + +def format_children_with_indices(composite: py_trees.composites.Composite) -> str: + """ + Examples: + >>> tree = py_trees.composites.Sequence("s1", False, children=[ + ... py_trees.behaviours.Dummy(), + ... py_trees.behaviours.Success(), + ... py_trees.composites.Sequence("s2", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]), + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Failure(), + ... py_trees.behaviours.Periodic("p", n=1), + ... ]), + ... ]) + >>> print(format_children_with_indices(tree)) # doctest: +NORMALIZE_WHITESPACE + _: [-] s1 + 0: --> Dummy + 1: --> Success + 2: [-] s2 + _: --> Dummy + 3: [-] + _: --> Failure + _: --> p + """ + index_strings = [] + i = 0 + for b in iterate_nodes(composite): + if b in composite.children: + index_strings.append(str(i)) + i += 1 + else: + index_strings.append("_") + + output = label_tree_lines(composite, index_strings) + return output + + +def format_tree_with_indices(tree: py_trees.behaviour.Behaviour, mode: str = "all") -> str: + """ + Examples: + >>> print(format_tree_with_indices(py_trees.behaviours.Dummy())) + 0: --> Dummy + + >>> tree = py_trees.composites.Sequence("s1", False, children=[ + ... py_trees.behaviours.Dummy(), + ... py_trees.behaviours.Success(), + ... py_trees.composites.Sequence("s2", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]), + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Failure(), + ... py_trees.behaviours.Periodic("p", n=1), + ... ]), + ... ]) + >>> print(format_tree_with_indices(tree)) # doctest: +NORMALIZE_WHITESPACE + 0: [-] s1 + 1: --> Dummy + 2: --> Success + 3: [-] s2 + 4: --> Dummy + 5: [-] + 6: --> Failure + 7: --> p + + """ + index_strings = [] + index = 0 + for i, node in enumerate_nodes(tree): + if mode == "children" and node.children: + # Do not number parent nodes in 'children' mode + index_strings.append('_') + elif mode == "parents" and not node.children: + # Do not number child nodes in 'parents' mode + index_strings.append('_') + else: + # Number all nodes in 'all' mode + index_strings.append(str(index)) + index += 1 + + + output = label_tree_lines(tree, index_strings) + + return output + + +def say(message): + print(message) + + +def prompt_identify_node( + tree: py_trees.behaviour.Behaviour, + message: str = "Which node?", + display_nodes: bool = True, + mode: str = "all" +) -> py_trees.behaviour.Behaviour: + node_index = prompt_identify_tree_iterator_index( + tree=tree, message=message, display_nodes=display_nodes, mode=mode + ) + node = next(islice(iterate_nodes(tree), node_index, node_index + 1)) + return node + + +def prompt_identify_tree_iterator_index( + tree: py_trees.behaviour.Behaviour, + message: str = "Which position?", + display_nodes: bool = True, + mode: str = "all" +) -> int: + if display_nodes: + text = f"{format_tree_with_indices(tree, mode)}\n{message}" + else: + text = f"{message}" + node_index = click.prompt( + text=text, + type=int, + ) + return node_index + + +def prompt_identify_child_index( + tree: py_trees.behaviour.Behaviour, + message: str = "Which position?", + display_nodes: bool = True, +) -> int: + if display_nodes: + text = f"{format_children_with_indices(tree)}\n{message}" + else: + text = f"{message}" + node_index = click.prompt( + text=text, + type=int, + ) + return node_index + + +def add_child( + tree: T, + parent: Optional[py_trees.composites.Composite] = None, + child: Optional[py_trees.behaviour.Behaviour] = None, +) -> T: + """Add a behaviour to the tree + + Examples: + >>> tree = py_trees.composites.Sequence("", False, children=[]) + >>> print(py_trees.display.ascii_tree(tree)) # doctest: +NORMALIZE_WHITESPACE + [-] + + >>> print(py_trees.display.ascii_tree(add_child(tree, py_trees.behaviours.Success()))) + ... # doctest: +NORMALIZE_WHITESPACE + [-] + --> Success + + """ + if parent is None: + parent = prompt_identify_node( + tree, f"Which parent node do you want to add the child to?" + ) + if child is None: + child_key = click.prompt( + text="What should the child do?", type=click.Choice(["say"]) + ) + match child_key: + case "say": + message = click.prompt(text="What should it say?", type=str) + + child_function = wraps(say)(partial(say, message)) + child_type = py_trees.meta.create_behaviour_from_function( + child_function + ) + child = child_type() + case _: + raise NotImplementedError() + parent.add_child(child) + return tree + + +def remove_node(tree: T, node: Optional[py_trees.behaviour.Behaviour] = None) -> T: + """Remove a behaviour from the tree + + Examples: + >>> tree = py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Success(), + ... failure_node := py_trees.behaviours.Failure()]) + >>> print(py_trees.display.ascii_tree(tree)) # doctest: +NORMALIZE_WHITESPACE + [-] + --> Success + --> Failure + + >>> print(py_trees.display.ascii_tree(remove_node(tree, failure_node))) + ... # doctest: +NORMALIZE_WHITESPACE + [-] + --> Success + + """ + if node is None: + node = prompt_identify_node(tree, f"Which node do you want to remove?", True, "children") + parent_node = node.parent + if parent_node is None: + warnings.warn( + f"{node}'s parent is None, so we can't remove it. You can't remove the root node." + ) + return tree + elif isinstance(parent_node, py_trees.composites.Composite): + parent_node.remove_child(node) + else: + raise NotImplementedError() + return tree + + +def move_node( + tree: T, + node: Optional[py_trees.behaviour.Behaviour] = None, + new_parent: Optional[py_trees.behaviour.Behaviour] = None, + index: int = None, +) -> T: + """Exchange two behaviours in the tree + + Examples: + >>> tree = py_trees.composites.Sequence("", False, children=[]) + + """ + + if node is None: + node = prompt_identify_node(tree, f"Which node do you want to move?") + if new_parent is None: + new_parent = prompt_identify_node( + tree, f"What should its parent be?", display_nodes=False + ) + if index is None: + index = prompt_identify_child_index(new_parent) + + assert isinstance(new_parent, py_trees.composites.Composite) + assert isinstance(node.parent, py_trees.composites.Composite) + + node.parent.remove_child(node) + new_parent.insert_child(node, index) + + return tree + + +def exchange_nodes( + tree: T, + node0: Optional[py_trees.behaviour.Behaviour] = None, + node1: Optional[py_trees.behaviour.Behaviour] = None, +) -> T: + """Exchange two behaviours in the tree + + Examples: + >>> tree = py_trees.composites.Sequence("", False, children=[ + ... s:=py_trees.behaviours.Success(), + ... f:=py_trees.behaviours.Failure(), + ... ]) + >>> print(py_trees.display.ascii_tree(tree)) # doctest: +NORMALIZE_WHITESPACE + [-] + --> Success + --> Failure + + >>> print(py_trees.display.ascii_tree(exchange_nodes(tree, s, f))) + ... # doctest: +NORMALIZE_WHITESPACE + [-] + --> Failure + --> Success + + >>> tree = py_trees.composites.Sequence("", False, children=[ + ... a:= py_trees.composites.Sequence("A", False, children=[ + ... py_trees.behaviours.Dummy() + ... ]), + ... py_trees.composites.Sequence("B", False, children=[ + ... py_trees.behaviours.Success(), + ... c := py_trees.composites.Sequence("C", False, children=[]) + ... ]) + ... ]) + >>> print(py_trees.display.ascii_tree(tree)) # doctest: +NORMALIZE_WHITESPACE + [-] + [-] A + --> Dummy + [-] B + --> Success + [-] C + >>> print(py_trees.display.ascii_tree(exchange_nodes(tree, a, c))) + ... # doctest: +NORMALIZE_WHITESPACE + [-] + [-] C + [-] B + --> Success + [-] A + --> Dummy + """ + + if node0 is None: + node0 = prompt_identify_node(tree, f"Which node do you want to switch?") + if node1 is None: + node1 = prompt_identify_node( + tree, f"Which node do you want to switch?", display_nodes=False + ) + + node0_parent, node0_index = node0.parent, node0.parent.children.index(node0) + node1_parent, node1_index = node1.parent, node1.parent.children.index(node1) + + tree = move_node(tree, node0, node1_parent, node1_index) + tree = move_node(tree, node1, node0_parent, node0_index) + + return tree + + +# if __name__ == "__main__": + + +# tree = py_trees.composites.Sequence( +# "", +# False, +# children=[ +# py_trees.behaviours.Dummy(), +# py_trees.behaviours.Dummy(), +# py_trees.composites.Sequence( +# "", +# False, +# children=[ +# py_trees.behaviours.Success(), +# py_trees.behaviours.Dummy(), +# ], +# ), +# py_trees.composites.Sequence( +# "", +# False, +# children=[ +# py_trees.behaviours.Dummy(), +# py_trees.behaviours.Failure(), +# py_trees.behaviours.Dummy(), +# py_trees.behaviours.Running(), +# ], +# ), +# ], +# ) + +# print(py_trees.display.ascii_tree(tree)) +# move_node(tree) +# exchange_nodes(tree) +# remove_node(tree) +# print(format_tree_with_indices(tree)) +# print("Done with demo!") \ No newline at end of file diff --git a/src/social_norms_trees/serialize_tree.py b/src/social_norms_trees/serialize_tree.py new file mode 100644 index 00000000..2db28d69 --- /dev/null +++ b/src/social_norms_trees/serialize_tree.py @@ -0,0 +1,31 @@ +import py_trees + + +def serialize_tree(tree): + def serialize_node(node): + data = { + "type": node.__class__.__name__, + "name": node.name, + "children": [serialize_node(child) for child in node.children], + } + return data + + return serialize_node(tree) + +def deserialize_tree(tree): + def deserialize_node(node): + node_type = node['type'] + children = [deserialize_node(child) for child in node['children']] + + if node_type == 'Sequence': + return py_trees.composites.Sequence(node['name'], False, children=children) + elif node_type == 'Dummy': + return py_trees.behaviours.Dummy(node['name']) + elif node_type == 'Success': + return py_trees.behaviours.Success(node['name']) + elif node_type == 'Failure': + return py_trees.behaviours.Failure(node['name']) + elif node_type == 'Running': + return py_trees.behaviours.Running(node['name']) + + return deserialize_node(tree) \ No newline at end of file diff --git a/src/social_norms_trees/ui_wrapper.py b/src/social_norms_trees/ui_wrapper.py index 9f6e1dcf..cbdb4a5e 100644 --- a/src/social_norms_trees/ui_wrapper.py +++ b/src/social_norms_trees/ui_wrapper.py @@ -2,11 +2,30 @@ import click from datetime import datetime import json +import os +import uuid +import py_trees -# Global database -db = {} +from social_norms_trees.mutate_tree import move_node, exchange_nodes, remove_node +from social_norms_trees.serialize_tree import serialize_tree, deserialize_tree -def experiment_setup(): +DB_FILE = "db.json" + + +def load_db(): + if os.path.exists(DB_FILE): + with open(DB_FILE, "r") as f: + return json.load(f) + else: + return {} + +def save_db(db): + """Saves the Python dictionary back to db.json.""" + + with open(DB_FILE, "w") as f: + json.dump(db, f, indent=4) + +def experiment_setup(db): print("\n") participant_id = participant_login() @@ -14,7 +33,7 @@ def experiment_setup(): print("\n") origin_tree = load_behavior_tree() - experiment_id = initialize_experiment_record(participant_id, origin_tree) + experiment_id = initialize_experiment_record(db, participant_id, origin_tree) print("\nSetup Complete.\n") @@ -22,13 +41,9 @@ def experiment_setup(): def participant_login(): - global db participant_id = click.prompt("Please enter your participant id", type=str) - if participant_id not in db: - db[participant_id] = {} - return participant_id @@ -36,8 +51,34 @@ def get_behavior_trees(): #TODO: Get behavior trees from respective data structure print("1. Original Tree") - - return ["Original_tree"] + return [ + py_trees.composites.Sequence( + "", + False, + children=[ + py_trees.behaviours.Dummy(), + py_trees.behaviours.Dummy(), + py_trees.composites.Sequence( + "", + False, + children=[ + py_trees.behaviours.Success(), + py_trees.behaviours.Dummy(), + ], + ), + py_trees.composites.Sequence( + "", + False, + children=[ + py_trees.behaviours.Dummy(), + py_trees.behaviours.Failure(), + py_trees.behaviours.Dummy(), + py_trees.behaviours.Running(), + ], + ), + ], + ) + ] def load_behavior_tree(): @@ -47,47 +88,78 @@ def load_behavior_tree(): return tree_array[tree_index - 1] -def initialize_experiment_record(participant_id, origin_tree): - global db +def initialize_experiment_record(db, participant_id, origin_tree): - if "experiments" not in db[participant_id]: - db[participant_id]["experiments"] = {} + experiment_id = str(uuid.uuid4()) - experiment_id = len(db[participant_id]["experiments"]) + 1 + #TODO: look into python data class + #TODO: flatten structure of db to simply collction of experiment runs, that will include a field for the participant_id + #instead of grouping by participants experiment_record = { "experiment_id": experiment_id, - "base_behavior_tree": origin_tree, + "participant_id": participant_id, + "base_behavior_tree": serialize_tree(origin_tree), "start_date": datetime.now().isoformat(), "actions": [], } - db[participant_id]["experiments"][experiment_id] = experiment_record + db[experiment_id] = experiment_record return experiment_id -def run_experiment(participant_id, origin_tree, experiment_id): - global db +def run_experiment(db, participant_id, origin_tree, experiment_id): - #TODO: run actual experiment - print("Running experiment...\n") - db[participant_id]["experiments"][experiment_id]["final_behavior_tree"] = "updated tree" - db[participant_id]["experiments"][experiment_id]["actions"] = ["list", "of", "actions"] - time.sleep(3) - db[participant_id]["experiments"][experiment_id]["end_date"] = datetime.now().isoformat() - print("Experiment done!\n") + # Loop for the actual experiment part, which takes user input to decide which action to take + print("\nExperiment begins.\n") + run_simulation = True + while(run_simulation): + + user_choice = click.prompt("Would you like to perform an action on the behavior tree? (y/n)") + + if user_choice == 'y': + print("1. move node") + print("2. exchange node") + print("3. remove node") + action = click.prompt("Please select an action to perform on the behavior tree (enter the number)", type=int) + + if action == 1: + db[experiment_id]["actions"].append("move node") + move_node(origin_tree) + elif action == 2: + db[experiment_id]["actions"].append("exchange node") + exchange_nodes(origin_tree) + elif action == 3: + db[experiment_id]["actions"].append("remove node") + remove_node(origin_tree) + else: + print("Wrong choice, please enter correct number.\n") + + else: + run_simulation = False + db[experiment_id]["final_behavior_tree"] = serialize_tree(origin_tree) + db[experiment_id]["end_date"] = datetime.now().isoformat() + print("\nSimulation has ended.") + -def main(): +def main(): + #TODO: load db from disc now. and each time program runs is 1 run of the program print("AIT Prototype #1 Simulator") + + db = load_db() - for _ in range(3): - participant_id, origin_tree, experiment_id = experiment_setup() - run_experiment(participant_id, origin_tree, experiment_id) - print(json.dumps(db, indent=4)) + + participant_id, origin_tree, experiment_id = experiment_setup(db) + run_experiment(db, participant_id, origin_tree, experiment_id) + + save_db(db) + #TODO: visualize the differences between old and new behavior trees after experiment. + # Potentially use git diff + if __name__ == "__main__": main() \ No newline at end of file From 8d28eb9bca61b1d939950e81ef8e0407bf05aa93 Mon Sep 17 00:00:00 2001 From: Josh Lu Date: Mon, 9 Sep 2024 14:54:41 -0400 Subject: [PATCH 4/8] Update src/social_norms_trees/ui_wrapper.py Co-authored-by: John Gerrard Holland --- src/social_norms_trees/ui_wrapper.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/social_norms_trees/ui_wrapper.py b/src/social_norms_trees/ui_wrapper.py index cbdb4a5e..ff262cc3 100644 --- a/src/social_norms_trees/ui_wrapper.py +++ b/src/social_norms_trees/ui_wrapper.py @@ -117,7 +117,11 @@ def run_experiment(db, participant_id, origin_tree, experiment_id): run_simulation = True while(run_simulation): - user_choice = click.prompt("Would you like to perform an action on the behavior tree? (y/n)") + user_choice = click.prompt( + "Would you like to perform an action on the behavior tree?", + show_choices=True, + type=click.Choice(['y', 'n'], case_sensitive=False), + ) if user_choice == 'y': print("1. move node") From 19bc6b4689e5a6a79b6de3b08521eac50f563ac7 Mon Sep 17 00:00:00 2001 From: Josh Date: Mon, 9 Sep 2024 15:58:12 -0400 Subject: [PATCH 5/8] update wording, fix some bugs --- src/social_norms_trees/ui_wrapper.py | 57 ++++++++++++++++------------ 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/src/social_norms_trees/ui_wrapper.py b/src/social_norms_trees/ui_wrapper.py index ff262cc3..4a56b3d5 100644 --- a/src/social_norms_trees/ui_wrapper.py +++ b/src/social_norms_trees/ui_wrapper.py @@ -12,17 +12,19 @@ DB_FILE = "db.json" -def load_db(): - if os.path.exists(DB_FILE): - with open(DB_FILE, "r") as f: +def load_db(db_file): + if os.path.exists(db_file): + with open(db_file, "r") as f: return json.load(f) else: return {} -def save_db(db): +def save_db(db, db_file): """Saves the Python dictionary back to db.json.""" - with open(DB_FILE, "w") as f: + print(f"\nWriting results of simulation to {db_file}...") + + with open(db_file, "w") as f: json.dump(db, f, indent=4) def experiment_setup(db): @@ -112,11 +114,11 @@ def initialize_experiment_record(db, participant_id, origin_tree): def run_experiment(db, participant_id, origin_tree, experiment_id): # Loop for the actual experiment part, which takes user input to decide which action to take - print("\nExperiment begins.\n") - - run_simulation = True - while(run_simulation): + print("\nExperiment beginning...\n") + while(True): + + print(py_trees.display.ascii_tree(origin_tree)) user_choice = click.prompt( "Would you like to perform an action on the behavior tree?", show_choices=True, @@ -124,44 +126,51 @@ def run_experiment(db, participant_id, origin_tree, experiment_id): ) if user_choice == 'y': - print("1. move node") - print("2. exchange node") - print("3. remove node") - action = click.prompt("Please select an action to perform on the behavior tree (enter the number)", type=int) - - if action == 1: + action = click.prompt( + "1. move node\n" + + "2. exchange node\n" + + "3. remove node\n" + + "Please select an action to perform on the behavior tree", + type=click.Choice(['1', '2', '3'], case_sensitive=False), + show_choices=True + ) + + if action == "1": db[experiment_id]["actions"].append("move node") move_node(origin_tree) - elif action == 2: + elif action == "2": db[experiment_id]["actions"].append("exchange node") exchange_nodes(origin_tree) - elif action == 3: + elif action == "3": db[experiment_id]["actions"].append("remove node") remove_node(origin_tree) else: - print("Wrong choice, please enter correct number.\n") + print("Invalid choice, please select a valid number (1, 2, or 3).\n") else: - run_simulation = False db[experiment_id]["final_behavior_tree"] = serialize_tree(origin_tree) db[experiment_id]["end_date"] = datetime.now().isoformat() - print("\nSimulation has ended.") + break + + return db def main(): - #TODO: load db from disc now. and each time program runs is 1 run of the program print("AIT Prototype #1 Simulator") - db = load_db() + DB_FILE = "db.json" + db = load_db(DB_FILE) participant_id, origin_tree, experiment_id = experiment_setup(db) - run_experiment(db, participant_id, origin_tree, experiment_id) + db = run_experiment(db, participant_id, origin_tree, experiment_id) - save_db(db) + save_db(db, DB_FILE) + print("\nSimulation has ended.") + #TODO: visualize the differences between old and new behavior trees after experiment. # Potentially use git diff From d760b233fdf9f9bd92940e41414559acc64bdd8d Mon Sep 17 00:00:00 2001 From: Josh Date: Tue, 10 Sep 2024 10:08:12 -0400 Subject: [PATCH 6/8] remove previous unecessary changes --- src/social_norms_trees/mutate_tree.py | 64 +++------------------------ 1 file changed, 6 insertions(+), 58 deletions(-) diff --git a/src/social_norms_trees/mutate_tree.py b/src/social_norms_trees/mutate_tree.py index 60c599c1..46486a55 100644 --- a/src/social_norms_trees/mutate_tree.py +++ b/src/social_norms_trees/mutate_tree.py @@ -144,7 +144,7 @@ def format_children_with_indices(composite: py_trees.composites.Composite) -> st return output -def format_tree_with_indices(tree: py_trees.behaviour.Behaviour, mode: str = "all") -> str: +def format_tree_with_indices(tree: py_trees.behaviour.Behaviour) -> str: """ Examples: >>> print(format_tree_with_indices(py_trees.behaviours.Dummy())) @@ -172,20 +172,8 @@ def format_tree_with_indices(tree: py_trees.behaviour.Behaviour, mode: str = "al 7: --> p """ - index_strings = [] - index = 0 - for i, node in enumerate_nodes(tree): - if mode == "children" and node.children: - # Do not number parent nodes in 'children' mode - index_strings.append('_') - elif mode == "parents" and not node.children: - # Do not number child nodes in 'parents' mode - index_strings.append('_') - else: - # Number all nodes in 'all' mode - index_strings.append(str(index)) - index += 1 + index_strings = [str(i) for i, _ in enumerate_nodes(tree)] output = label_tree_lines(tree, index_strings) @@ -200,10 +188,9 @@ def prompt_identify_node( tree: py_trees.behaviour.Behaviour, message: str = "Which node?", display_nodes: bool = True, - mode: str = "all" ) -> py_trees.behaviour.Behaviour: node_index = prompt_identify_tree_iterator_index( - tree=tree, message=message, display_nodes=display_nodes, mode=mode + tree=tree, message=message, display_nodes=display_nodes ) node = next(islice(iterate_nodes(tree), node_index, node_index + 1)) return node @@ -213,10 +200,9 @@ def prompt_identify_tree_iterator_index( tree: py_trees.behaviour.Behaviour, message: str = "Which position?", display_nodes: bool = True, - mode: str = "all" ) -> int: if display_nodes: - text = f"{format_tree_with_indices(tree, mode)}\n{message}" + text = f"{format_tree_with_indices(tree)}\n{message}" else: text = f"{message}" node_index = click.prompt( @@ -302,7 +288,7 @@ def remove_node(tree: T, node: Optional[py_trees.behaviour.Behaviour] = None) -> """ if node is None: - node = prompt_identify_node(tree, f"Which node do you want to remove?", True, "children") + node = prompt_identify_node(tree, f"Which node do you want to remove?") parent_node = node.parent if parent_node is None: warnings.warn( @@ -409,42 +395,4 @@ def exchange_nodes( tree = move_node(tree, node0, node1_parent, node1_index) tree = move_node(tree, node1, node0_parent, node0_index) - return tree - - -# if __name__ == "__main__": - - -# tree = py_trees.composites.Sequence( -# "", -# False, -# children=[ -# py_trees.behaviours.Dummy(), -# py_trees.behaviours.Dummy(), -# py_trees.composites.Sequence( -# "", -# False, -# children=[ -# py_trees.behaviours.Success(), -# py_trees.behaviours.Dummy(), -# ], -# ), -# py_trees.composites.Sequence( -# "", -# False, -# children=[ -# py_trees.behaviours.Dummy(), -# py_trees.behaviours.Failure(), -# py_trees.behaviours.Dummy(), -# py_trees.behaviours.Running(), -# ], -# ), -# ], -# ) - -# print(py_trees.display.ascii_tree(tree)) -# move_node(tree) -# exchange_nodes(tree) -# remove_node(tree) -# print(format_tree_with_indices(tree)) -# print("Done with demo!") \ No newline at end of file + return tree \ No newline at end of file From c0e111efd7aa3eb53f1bc14c8688d83c2f14e283 Mon Sep 17 00:00:00 2001 From: Josh Date: Tue, 10 Sep 2024 10:46:21 -0400 Subject: [PATCH 7/8] update index to include root when needed --- src/social_norms_trees/mutate_tree.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/social_norms_trees/mutate_tree.py b/src/social_norms_trees/mutate_tree.py index 46486a55..4ac02905 100644 --- a/src/social_norms_trees/mutate_tree.py +++ b/src/social_norms_trees/mutate_tree.py @@ -144,7 +144,10 @@ def format_children_with_indices(composite: py_trees.composites.Composite) -> st return output -def format_tree_with_indices(tree: py_trees.behaviour.Behaviour) -> str: +def format_tree_with_indices( + tree: py_trees.behaviour.Behaviour, + show_root: bool = False, +) -> str: """ Examples: >>> print(format_tree_with_indices(py_trees.behaviours.Dummy())) @@ -173,8 +176,14 @@ def format_tree_with_indices(tree: py_trees.behaviour.Behaviour) -> str: """ - index_strings = [str(i) for i, _ in enumerate_nodes(tree)] - + index_strings = [] + index = 0 + for i, node in enumerate_nodes(tree): + if i == 0 and not show_root: + index_strings.append('_') + else: + index_strings.append(str(index)) + index += 1 output = label_tree_lines(tree, index_strings) return output @@ -188,9 +197,10 @@ def prompt_identify_node( tree: py_trees.behaviour.Behaviour, message: str = "Which node?", display_nodes: bool = True, + show_root: bool = False, ) -> py_trees.behaviour.Behaviour: node_index = prompt_identify_tree_iterator_index( - tree=tree, message=message, display_nodes=display_nodes + tree=tree, message=message, display_nodes=display_nodes, show_root=show_root ) node = next(islice(iterate_nodes(tree), node_index, node_index + 1)) return node @@ -200,9 +210,10 @@ def prompt_identify_tree_iterator_index( tree: py_trees.behaviour.Behaviour, message: str = "Which position?", display_nodes: bool = True, + show_root: bool = False, ) -> int: if display_nodes: - text = f"{format_tree_with_indices(tree)}\n{message}" + text = f"{format_tree_with_indices(tree, show_root)}\n{message}" else: text = f"{message}" node_index = click.prompt( @@ -319,7 +330,7 @@ def move_node( node = prompt_identify_node(tree, f"Which node do you want to move?") if new_parent is None: new_parent = prompt_identify_node( - tree, f"What should its parent be?", display_nodes=False + tree, f"What should its parent be?", display_nodes=True, show_root=True ) if index is None: index = prompt_identify_child_index(new_parent) From 5934098ec18173d1757318553822cc838d5a7ad3 Mon Sep 17 00:00:00 2001 From: Josh Date: Tue, 10 Sep 2024 11:28:04 -0400 Subject: [PATCH 8/8] update helper functions for moveNode, to improve visual display --- src/social_norms_trees/mutate_tree.py | 36 +++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/social_norms_trees/mutate_tree.py b/src/social_norms_trees/mutate_tree.py index 4ac02905..dd12024d 100644 --- a/src/social_norms_trees/mutate_tree.py +++ b/src/social_norms_trees/mutate_tree.py @@ -140,6 +140,21 @@ def format_children_with_indices(composite: py_trees.composites.Composite) -> st else: index_strings.append("_") + index_strings.append(str(i)) + + output = label_tree_lines(composite, index_strings) + return output + +def format_parents_with_indices(composite: py_trees.composites.Composite) -> str: + index_strings = [] + i = 0 + for b in iterate_nodes(composite): + if b.children: + index_strings.append(str(i)) + else: + index_strings.append("_") + i += 1 + output = label_tree_lines(composite, index_strings) return output @@ -205,6 +220,23 @@ def prompt_identify_node( node = next(islice(iterate_nodes(tree), node_index, node_index + 1)) return node +def prompt_identify_parent_node( + tree: py_trees.behaviour.Behaviour, + message: str = "Which position?", + display_nodes: bool = True, +) -> int: + if display_nodes: + text = f"{format_parents_with_indices(tree)}\n{message}" + else: + text = f"{message}" + node_index = click.prompt( + text=text, + type=int, + ) + + node = next(islice(iterate_nodes(tree), node_index, node_index + 1)) + return node + def prompt_identify_tree_iterator_index( tree: py_trees.behaviour.Behaviour, @@ -329,8 +361,8 @@ def move_node( if node is None: node = prompt_identify_node(tree, f"Which node do you want to move?") if new_parent is None: - new_parent = prompt_identify_node( - tree, f"What should its parent be?", display_nodes=True, show_root=True + new_parent = prompt_identify_parent_node( + tree, f"What should its parent be?", display_nodes=True ) if index is None: index = prompt_identify_child_index(new_parent)