diff --git a/src/social_norms_trees/mutate_tree.py b/src/social_norms_trees/mutate_tree.py new file mode 100644 index 00000000..dd12024d --- /dev/null +++ b/src/social_norms_trees/mutate_tree.py @@ -0,0 +1,441 @@ +"""Example of using worlds with just an integer for the state of the world""" + +import warnings +from functools import partial, wraps +from itertools import islice +from typing import TypeVar, Optional, List + +import click +import py_trees + + +T = TypeVar("T", bound=py_trees.behaviour.Behaviour) + + +def print_tree(tree: py_trees.behaviour.Behaviour): + tree_display = py_trees.display.unicode_tree(tree) + print(tree_display) + + +def iterate_nodes(tree: py_trees.behaviour.Behaviour): + """ + + Examples: + >>> list(iterate_nodes(py_trees.behaviours.Dummy())) # doctest: +ELLIPSIS + [] + + >>> list(iterate_nodes( + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]))) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + [, + ] + + >>> list(iterate_nodes( + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Dummy(), + ... py_trees.behaviours.Dummy(), + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]), + ... ]))) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + [, + , + , + , + ] + + """ + yield tree + for child in tree.children: + yield from iterate_nodes(child) + + +def enumerate_nodes(tree: py_trees.behaviour.Behaviour): + """ + + Examples: + >>> list(enumerate_nodes(py_trees.behaviours.Dummy())) # doctest: +ELLIPSIS + [(0, )] + + >>> list(enumerate_nodes( + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]))) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + [(0, ), + (1, )] + + >>> list(enumerate_nodes( + ... py_trees.composites.Sequence("s1", False, children=[ + ... py_trees.behaviours.Dummy(), + ... py_trees.behaviours.Success(), + ... py_trees.composites.Sequence("s2", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]), + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Failure(), + ... py_trees.behaviours.Periodic("p", n=1), + ... ]), + ... ]))) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + [(0, ), + (1, ), + (2, ), + (3, ), + (4, ), + (5, ), + (6, ), + (7, )] + + """ + return enumerate(iterate_nodes(tree)) + + +def label_tree_lines( + tree: py_trees.behaviour.Behaviour, + labels: List[str], + representation=py_trees.display.unicode_tree, +) -> str: + max_len = max([len(s) for s in labels]) + padded_labels = [s.rjust(max_len) for s in labels] + + tree_representation_lines = representation(tree).split("\n") + enumerated_tree_representation_lines = [ + f"{i}: {t}" for i, t in zip(padded_labels, tree_representation_lines) + ] + + output = "\n".join(enumerated_tree_representation_lines) + return output + + +def format_children_with_indices(composite: py_trees.composites.Composite) -> str: + """ + Examples: + >>> tree = py_trees.composites.Sequence("s1", False, children=[ + ... py_trees.behaviours.Dummy(), + ... py_trees.behaviours.Success(), + ... py_trees.composites.Sequence("s2", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]), + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Failure(), + ... py_trees.behaviours.Periodic("p", n=1), + ... ]), + ... ]) + >>> print(format_children_with_indices(tree)) # doctest: +NORMALIZE_WHITESPACE + _: [-] s1 + 0: --> Dummy + 1: --> Success + 2: [-] s2 + _: --> Dummy + 3: [-] + _: --> Failure + _: --> p + """ + index_strings = [] + i = 0 + for b in iterate_nodes(composite): + if b in composite.children: + index_strings.append(str(i)) + i += 1 + else: + index_strings.append("_") + + index_strings.append(str(i)) + + output = label_tree_lines(composite, index_strings) + return output + +def format_parents_with_indices(composite: py_trees.composites.Composite) -> str: + index_strings = [] + i = 0 + for b in iterate_nodes(composite): + if b.children: + index_strings.append(str(i)) + else: + index_strings.append("_") + i += 1 + + output = label_tree_lines(composite, index_strings) + return output + + +def format_tree_with_indices( + tree: py_trees.behaviour.Behaviour, + show_root: bool = False, +) -> str: + """ + Examples: + >>> print(format_tree_with_indices(py_trees.behaviours.Dummy())) + 0: --> Dummy + + >>> tree = py_trees.composites.Sequence("s1", False, children=[ + ... py_trees.behaviours.Dummy(), + ... py_trees.behaviours.Success(), + ... py_trees.composites.Sequence("s2", False, children=[ + ... py_trees.behaviours.Dummy(), + ... ]), + ... py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Failure(), + ... py_trees.behaviours.Periodic("p", n=1), + ... ]), + ... ]) + >>> print(format_tree_with_indices(tree)) # doctest: +NORMALIZE_WHITESPACE + 0: [-] s1 + 1: --> Dummy + 2: --> Success + 3: [-] s2 + 4: --> Dummy + 5: [-] + 6: --> Failure + 7: --> p + + """ + + index_strings = [] + index = 0 + for i, node in enumerate_nodes(tree): + if i == 0 and not show_root: + index_strings.append('_') + else: + index_strings.append(str(index)) + index += 1 + output = label_tree_lines(tree, index_strings) + + return output + + +def say(message): + print(message) + + +def prompt_identify_node( + tree: py_trees.behaviour.Behaviour, + message: str = "Which node?", + display_nodes: bool = True, + show_root: bool = False, +) -> py_trees.behaviour.Behaviour: + node_index = prompt_identify_tree_iterator_index( + tree=tree, message=message, display_nodes=display_nodes, show_root=show_root + ) + node = next(islice(iterate_nodes(tree), node_index, node_index + 1)) + return node + +def prompt_identify_parent_node( + tree: py_trees.behaviour.Behaviour, + message: str = "Which position?", + display_nodes: bool = True, +) -> int: + if display_nodes: + text = f"{format_parents_with_indices(tree)}\n{message}" + else: + text = f"{message}" + node_index = click.prompt( + text=text, + type=int, + ) + + node = next(islice(iterate_nodes(tree), node_index, node_index + 1)) + return node + + +def prompt_identify_tree_iterator_index( + tree: py_trees.behaviour.Behaviour, + message: str = "Which position?", + display_nodes: bool = True, + show_root: bool = False, +) -> int: + if display_nodes: + text = f"{format_tree_with_indices(tree, show_root)}\n{message}" + else: + text = f"{message}" + node_index = click.prompt( + text=text, + type=int, + ) + return node_index + + +def prompt_identify_child_index( + tree: py_trees.behaviour.Behaviour, + message: str = "Which position?", + display_nodes: bool = True, +) -> int: + if display_nodes: + text = f"{format_children_with_indices(tree)}\n{message}" + else: + text = f"{message}" + node_index = click.prompt( + text=text, + type=int, + ) + return node_index + + +def add_child( + tree: T, + parent: Optional[py_trees.composites.Composite] = None, + child: Optional[py_trees.behaviour.Behaviour] = None, +) -> T: + """Add a behaviour to the tree + + Examples: + >>> tree = py_trees.composites.Sequence("", False, children=[]) + >>> print(py_trees.display.ascii_tree(tree)) # doctest: +NORMALIZE_WHITESPACE + [-] + + >>> print(py_trees.display.ascii_tree(add_child(tree, py_trees.behaviours.Success()))) + ... # doctest: +NORMALIZE_WHITESPACE + [-] + --> Success + + """ + if parent is None: + parent = prompt_identify_node( + tree, f"Which parent node do you want to add the child to?" + ) + if child is None: + child_key = click.prompt( + text="What should the child do?", type=click.Choice(["say"]) + ) + match child_key: + case "say": + message = click.prompt(text="What should it say?", type=str) + + child_function = wraps(say)(partial(say, message)) + child_type = py_trees.meta.create_behaviour_from_function( + child_function + ) + child = child_type() + case _: + raise NotImplementedError() + parent.add_child(child) + return tree + + +def remove_node(tree: T, node: Optional[py_trees.behaviour.Behaviour] = None) -> T: + """Remove a behaviour from the tree + + Examples: + >>> tree = py_trees.composites.Sequence("", False, children=[ + ... py_trees.behaviours.Success(), + ... failure_node := py_trees.behaviours.Failure()]) + >>> print(py_trees.display.ascii_tree(tree)) # doctest: +NORMALIZE_WHITESPACE + [-] + --> Success + --> Failure + + >>> print(py_trees.display.ascii_tree(remove_node(tree, failure_node))) + ... # doctest: +NORMALIZE_WHITESPACE + [-] + --> Success + + """ + if node is None: + node = prompt_identify_node(tree, f"Which node do you want to remove?") + parent_node = node.parent + if parent_node is None: + warnings.warn( + f"{node}'s parent is None, so we can't remove it. You can't remove the root node." + ) + return tree + elif isinstance(parent_node, py_trees.composites.Composite): + parent_node.remove_child(node) + else: + raise NotImplementedError() + return tree + + +def move_node( + tree: T, + node: Optional[py_trees.behaviour.Behaviour] = None, + new_parent: Optional[py_trees.behaviour.Behaviour] = None, + index: int = None, +) -> T: + """Exchange two behaviours in the tree + + Examples: + >>> tree = py_trees.composites.Sequence("", False, children=[]) + + """ + + if node is None: + node = prompt_identify_node(tree, f"Which node do you want to move?") + if new_parent is None: + new_parent = prompt_identify_parent_node( + tree, f"What should its parent be?", display_nodes=True + ) + if index is None: + index = prompt_identify_child_index(new_parent) + + assert isinstance(new_parent, py_trees.composites.Composite) + assert isinstance(node.parent, py_trees.composites.Composite) + + node.parent.remove_child(node) + new_parent.insert_child(node, index) + + return tree + + +def exchange_nodes( + tree: T, + node0: Optional[py_trees.behaviour.Behaviour] = None, + node1: Optional[py_trees.behaviour.Behaviour] = None, +) -> T: + """Exchange two behaviours in the tree + + Examples: + >>> tree = py_trees.composites.Sequence("", False, children=[ + ... s:=py_trees.behaviours.Success(), + ... f:=py_trees.behaviours.Failure(), + ... ]) + >>> print(py_trees.display.ascii_tree(tree)) # doctest: +NORMALIZE_WHITESPACE + [-] + --> Success + --> Failure + + >>> print(py_trees.display.ascii_tree(exchange_nodes(tree, s, f))) + ... # doctest: +NORMALIZE_WHITESPACE + [-] + --> Failure + --> Success + + >>> tree = py_trees.composites.Sequence("", False, children=[ + ... a:= py_trees.composites.Sequence("A", False, children=[ + ... py_trees.behaviours.Dummy() + ... ]), + ... py_trees.composites.Sequence("B", False, children=[ + ... py_trees.behaviours.Success(), + ... c := py_trees.composites.Sequence("C", False, children=[]) + ... ]) + ... ]) + >>> print(py_trees.display.ascii_tree(tree)) # doctest: +NORMALIZE_WHITESPACE + [-] + [-] A + --> Dummy + [-] B + --> Success + [-] C + >>> print(py_trees.display.ascii_tree(exchange_nodes(tree, a, c))) + ... # doctest: +NORMALIZE_WHITESPACE + [-] + [-] C + [-] B + --> Success + [-] A + --> Dummy + """ + + if node0 is None: + node0 = prompt_identify_node(tree, f"Which node do you want to switch?") + if node1 is None: + node1 = prompt_identify_node( + tree, f"Which node do you want to switch?", display_nodes=False + ) + + node0_parent, node0_index = node0.parent, node0.parent.children.index(node0) + node1_parent, node1_index = node1.parent, node1.parent.children.index(node1) + + tree = move_node(tree, node0, node1_parent, node1_index) + tree = move_node(tree, node1, node0_parent, node0_index) + + return tree \ No newline at end of file diff --git a/src/social_norms_trees/serialize_tree.py b/src/social_norms_trees/serialize_tree.py new file mode 100644 index 00000000..2db28d69 --- /dev/null +++ b/src/social_norms_trees/serialize_tree.py @@ -0,0 +1,31 @@ +import py_trees + + +def serialize_tree(tree): + def serialize_node(node): + data = { + "type": node.__class__.__name__, + "name": node.name, + "children": [serialize_node(child) for child in node.children], + } + return data + + return serialize_node(tree) + +def deserialize_tree(tree): + def deserialize_node(node): + node_type = node['type'] + children = [deserialize_node(child) for child in node['children']] + + if node_type == 'Sequence': + return py_trees.composites.Sequence(node['name'], False, children=children) + elif node_type == 'Dummy': + return py_trees.behaviours.Dummy(node['name']) + elif node_type == 'Success': + return py_trees.behaviours.Success(node['name']) + elif node_type == 'Failure': + return py_trees.behaviours.Failure(node['name']) + elif node_type == 'Running': + return py_trees.behaviours.Running(node['name']) + + return deserialize_node(tree) \ No newline at end of file diff --git a/src/social_norms_trees/ui_wrapper.py b/src/social_norms_trees/ui_wrapper.py new file mode 100644 index 00000000..4a56b3d5 --- /dev/null +++ b/src/social_norms_trees/ui_wrapper.py @@ -0,0 +1,178 @@ +import time +import click +from datetime import datetime +import json +import os +import uuid +import py_trees + +from social_norms_trees.mutate_tree import move_node, exchange_nodes, remove_node +from social_norms_trees.serialize_tree import serialize_tree, deserialize_tree + +DB_FILE = "db.json" + + +def load_db(db_file): + if os.path.exists(db_file): + with open(db_file, "r") as f: + return json.load(f) + else: + return {} + +def save_db(db, db_file): + """Saves the Python dictionary back to db.json.""" + + print(f"\nWriting results of simulation to {db_file}...") + + with open(db_file, "w") as f: + json.dump(db, f, indent=4) + +def experiment_setup(db): + + print("\n") + participant_id = participant_login() + + print("\n") + origin_tree = load_behavior_tree() + + experiment_id = initialize_experiment_record(db, participant_id, origin_tree) + + print("\nSetup Complete.\n") + + return participant_id, origin_tree, experiment_id + + +def participant_login(): + + participant_id = click.prompt("Please enter your participant id", type=str) + + return participant_id + + +def get_behavior_trees(): + #TODO: Get behavior trees from respective data structure + + print("1. Original Tree") + return [ + py_trees.composites.Sequence( + "", + False, + children=[ + py_trees.behaviours.Dummy(), + py_trees.behaviours.Dummy(), + py_trees.composites.Sequence( + "", + False, + children=[ + py_trees.behaviours.Success(), + py_trees.behaviours.Dummy(), + ], + ), + py_trees.composites.Sequence( + "", + False, + children=[ + py_trees.behaviours.Dummy(), + py_trees.behaviours.Failure(), + py_trees.behaviours.Dummy(), + py_trees.behaviours.Running(), + ], + ), + ], + ) + ] + + +def load_behavior_tree(): + + tree_array = get_behavior_trees() + tree_index = click.prompt("Please select a behavior tree to load for the experiment (enter the number)", type=int) + return tree_array[tree_index - 1] + + +def initialize_experiment_record(db, participant_id, origin_tree): + + experiment_id = str(uuid.uuid4()) + + #TODO: look into python data class + + #TODO: flatten structure of db to simply collction of experiment runs, that will include a field for the participant_id + #instead of grouping by participants + experiment_record = { + "experiment_id": experiment_id, + "participant_id": participant_id, + "base_behavior_tree": serialize_tree(origin_tree), + "start_date": datetime.now().isoformat(), + "actions": [], + } + + db[experiment_id] = experiment_record + + return experiment_id + + +def run_experiment(db, participant_id, origin_tree, experiment_id): + + # Loop for the actual experiment part, which takes user input to decide which action to take + print("\nExperiment beginning...\n") + + while(True): + + print(py_trees.display.ascii_tree(origin_tree)) + user_choice = click.prompt( + "Would you like to perform an action on the behavior tree?", + show_choices=True, + type=click.Choice(['y', 'n'], case_sensitive=False), + ) + + if user_choice == 'y': + action = click.prompt( + "1. move node\n" + + "2. exchange node\n" + + "3. remove node\n" + + "Please select an action to perform on the behavior tree", + type=click.Choice(['1', '2', '3'], case_sensitive=False), + show_choices=True + ) + + if action == "1": + db[experiment_id]["actions"].append("move node") + move_node(origin_tree) + elif action == "2": + db[experiment_id]["actions"].append("exchange node") + exchange_nodes(origin_tree) + elif action == "3": + db[experiment_id]["actions"].append("remove node") + remove_node(origin_tree) + else: + print("Invalid choice, please select a valid number (1, 2, or 3).\n") + + else: + db[experiment_id]["final_behavior_tree"] = serialize_tree(origin_tree) + db[experiment_id]["end_date"] = datetime.now().isoformat() + break + + return db + + + +def main(): + print("AIT Prototype #1 Simulator") + + DB_FILE = "db.json" + db = load_db(DB_FILE) + + + participant_id, origin_tree, experiment_id = experiment_setup(db) + db = run_experiment(db, participant_id, origin_tree, experiment_id) + + save_db(db, DB_FILE) + + + print("\nSimulation has ended.") + + #TODO: visualize the differences between old and new behavior trees after experiment. + # Potentially use git diff + +if __name__ == "__main__": + main() \ No newline at end of file