Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: spatial temporal navigation #355

Open
wants to merge 10 commits into
base: development
Choose a base branch
from
453 changes: 444 additions & 9 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ tomli = "^2.0.1"
openwakeword = { git = "https://github.com/maciejmajek/openWakeWord.git", branch = "chore/remove-tflite-backend" }
pytest-timeout = "^2.3.1"
tomli-w = "^1.1.0"
pyquaternion = "^0.9.9"
pymongo = "^4.10.1"
weaviate-client = "^4.10.2"
langchain-weaviate = "^0.0.3"
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.4"

Expand Down
6 changes: 6 additions & 0 deletions src/rai/rai/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .spatial_temporal_navigation.spatial_temporal_navigation import (
run_spatial_temporal_data_collection,
)

__all__ = ["run_spatial_temporal_data_collection"]
36 changes: 36 additions & 0 deletions src/rai/rai/apps/spatial_temporal_navigation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .spatial_temporal_navigation import (
Description,
ImageStamped,
Observation,
Orientation,
Pose,
PositionStamped,
Scene,
run_spatial_temporal_data_collection,
)

__all__ = [
"Pose",
"Orientation",
"PositionStamped",
"ImageStamped",
"Scene",
"Description",
"Observation",
"run_spatial_temporal_data_collection",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import threading
import time
from typing import Any, Dict, List, cast
from uuid import uuid4

import cv2
from cv_bridge import CvBridge
from geometry_msgs.msg import TransformStamped
from langchain_community.vectorstores import VectorStore
from pydantic import BaseModel, Field
from pymongo.collection import Collection
from rclpy.executors import SingleThreadedExecutor
from rclpy.node import Node
from sensor_msgs.msg import Image

from rai.messages.multimodal import HumanMultimodalMessage
from rai.messages.utils import preprocess_image
from rai.tools.ros.tools import TF2TransformFetcher
from rai.utils.model_initialization import get_llm_model

logger = logging.getLogger(__name__)


class Pose(BaseModel):
    """Robot translation in 3D space.

    Despite the name, this holds only the position component; orientation
    lives in the separate ``Orientation`` model. Units/frame are not
    established here — presumably meters in the TF target frame (TODO confirm).
    """

    x: float = Field(..., description="The x coordinate of the position")
    y: float = Field(..., description="The y coordinate of the position")
    z: float = Field(..., description="The z coordinate of the position")


class Orientation(BaseModel):
    """Orientation as a quaternion (x, y, z, w).

    Populated from ``TransformStamped.transform.rotation`` in
    ``ros2_transform_stamped_to_position``, so it follows the ROS
    geometry_msgs quaternion convention.
    """

    x: float = Field(..., description="The x coordinate of the orientation")
    y: float = Field(..., description="The y coordinate of the orientation")
    z: float = Field(..., description="The z coordinate of the orientation")
    w: float = Field(..., description="The w coordinate of the orientation")


class PositionStamped(BaseModel):
    """A timestamped robot pose: translation plus orientation at one instant."""

    # Seconds, flattened from the ROS header stamp (sec + nanosec / 1e9).
    timestamp: float
    position: Pose
    orientation: Orientation


class ImageStamped(BaseModel):
    """A camera frame paired with its capture timestamp."""

    # Seconds, flattened from the ROS header stamp (sec + nanosec / 1e9).
    timestamp: float
    # repr=False keeps the (large) base64 payload out of printed reprs/logs.
    image: str = Field(..., description="Base64 encoded image", repr=False)


class Scene(BaseModel):
    """Identifier grouping observations that belong to the same scene."""

    # Shared id; observations created together reference the same scene uuid.
    uuid: str


class Description(BaseModel):
    """Structured LLM output for a single camera frame.

    Filled by ``generate_description`` via ``llm.with_structured_output``.
    Field semantics follow from their names — presumably a free-text summary,
    detected object names, and anything flagged as unusual (TODO confirm
    against the model's actual outputs).
    """

    description: str
    objects: List[str]
    anomalies: List[str]


class Observation(BaseModel):
    """One data point: where the robot was, what it saw, and the LLM's take."""

    # Unique id of this observation (distinct from the scene uuid).
    uuid: str
    scene: Scene
    position_stamped: PositionStamped
    image_stamped: ImageStamped
    description: Description
    # Creation time of this record (wall clock), not the sensor stamp —
    # the sensor stamps live inside position_stamped / image_stamped.
    timestamp: float = Field(default_factory=time.time)


class VectorDatabaseEntry(BaseModel):
    """Record written to the vector store: text to embed plus lookup metadata."""

    text: str
    # e.g. {"uuid": <observation uuid>}, see observation_to_vector_database_entry.
    metadata: Dict[str, str]


def ros2_transform_stamped_to_position(
    transform_stamped: TransformStamped,
) -> PositionStamped:
    """Convert a ROS2 ``TransformStamped`` into a ``PositionStamped`` model.

    The header stamp (sec + nanosec) is flattened into a single float of
    seconds; translation and rotation components are copied field by field.
    """
    stamp = transform_stamped.header.stamp  # type: ignore
    translation = transform_stamped.transform.translation  # type: ignore
    rotation = transform_stamped.transform.rotation  # type: ignore
    seconds = stamp.sec + stamp.nanosec / 1e9  # type: ignore
    return PositionStamped(
        timestamp=seconds,
        position=Pose(x=translation.x, y=translation.y, z=translation.z),  # type: ignore
        orientation=Orientation(
            x=rotation.x,  # type: ignore
            y=rotation.y,  # type: ignore
            z=rotation.z,  # type: ignore
            w=rotation.w,  # type: ignore
        ),
    )


def ros2_image_to_image(ros2_image: Image) -> ImageStamped:
    """Convert a ROS2 ``sensor_msgs/Image`` into a base64-encoded ``ImageStamped``.

    The frame is decoded with cv_bridge, converted BGR -> RGB, and encoded
    via ``preprocess_image``; the header stamp is flattened to float seconds.
    """
    logger.info("Converting ROS2 image to base64 image")
    frame = CvBridge().imgmsg_to_cv2(ros2_image)  # type: ignore
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # type: ignore
    encoded = preprocess_image(frame)  # type: ignore
    stamp = ros2_image.header.stamp  # type: ignore
    return ImageStamped(
        timestamp=stamp.sec + stamp.nanosec / 1e9,  # type: ignore
        image=encoded,
    )


def generate_description(image: ImageStamped) -> Description:
    """Ask the configured LLM for a structured ``Description`` of the frame.

    Uses the "simple_model" LLM with structured output so the response is
    parsed directly into the ``Description`` schema.
    """
    logger.info("Generating LLM description")
    llm = get_llm_model(model_type="simple_model")
    structured_llm = llm.with_structured_output(Description)
    messages = [
        HumanMultimodalMessage(
            content="Describe the image",
            images=[image.image],
        )
    ]
    return cast(Description, structured_llm.invoke(messages))  # type: ignore


def build_observation(
    scene_uuid: str, position: TransformStamped, image: Image
) -> Observation:
    """Assemble a full ``Observation`` from raw ROS messages.

    Converts the image and transform into their model forms, generates an
    LLM description of the frame, and wraps everything with a fresh uuid.
    """
    logger.info("Building observation")
    frame = ros2_image_to_image(image)
    pose = ros2_transform_stamped_to_position(position)
    return Observation(
        uuid=str(uuid4()),
        scene=Scene(uuid=scene_uuid),
        image_stamped=frame,
        position_stamped=pose,
        description=generate_description(frame),
    )


def observation_to_vector_database_entry(observation: Observation):
    """Project an ``Observation`` into the record stored in the vector store.

    The stringified ``Description`` becomes the text to embed; the
    observation's uuid is kept in metadata to link the vector entry back to
    the full stored observation.
    """
    lookup_metadata = {"uuid": observation.uuid}
    return VectorDatabaseEntry(
        text=str(observation.description),
        metadata=lookup_metadata,
    )


def data_collection_pipeline(
    vectorstore: VectorStore,
    observations_collection: Collection[Dict[str, Any]],
    image: Image,
    transform: TransformStamped,
):
    """Persist one (image, transform) snapshot.

    Builds an ``Observation`` (including the LLM description), indexes its
    description text in the vector store, then stores the full observation
    document in MongoDB.
    """
    logger.info("Running data collection pipeline")
    observation = build_observation(str(uuid4()), transform, image)
    entry = observation_to_vector_database_entry(observation)

    logger.info(f"Adding to {vectorstore.__class__.__name__}")
    vectorstore.add_texts(
        texts=[entry.text],
        metadatas=[entry.metadata],
    )

    logger.info("Adding to MongoDB")
    observations_collection.insert_one(observation.model_dump())


class TransformGrabber:
    """Continuously polls TF2 and caches the latest transform.

    ``run()`` is intended to execute on a background thread; the newest
    transform is published on ``self.transform`` under ``self.lock``, and
    readers should take the same lock when reading it.
    """

    def __init__(self, target_frame: str, source_frame: str):
        self.transform_fetcher = TF2TransformFetcher(
            target_frame=target_frame, source_frame=source_frame
        )
        # Latest fetched transform, or None until the first fetch completes.
        self.transform = None
        self.lock = threading.Lock()
        # Backward-compatible addition: lets callers stop the polling loop
        # cleanly. Left unset, run() loops forever as before.
        self.stop_event = threading.Event()

    def run(self):
        """Poll TF2 in a loop until ``stop_event`` is set.

        Fix: the fetch happens OUTSIDE the lock — the original held the lock
        across ``get_data()``, so a slow TF lookup blocked every reader of
        ``self.transform`` for the duration of the fetch.
        """
        while not self.stop_event.is_set():
            latest = self.transform_fetcher.get_data()
            with self.lock:
                self.transform = latest


class ImageGrabber(Node):
    """Minimal ROS2 node that caches the most recent image from a topic."""

    def __init__(self, image_topic: str):
        super().__init__("image_grabber")
        # QoS queue depth 10, but only the newest message is retained.
        self.subscription = self.create_subscription(
            Image, image_topic, self.image_callback, 10
        )
        # Latest frame received, or None until the first message arrives.
        self.image: Image | None = None

    def image_callback(self, msg: Image):
        # Overwrite rather than queue: consumers only want the newest frame.
        self.image = msg

    def shutdown(self):
        # Tear down the node (and its subscription) when collection stops.
        self.destroy_node()


def run_spatial_temporal_data_collection(
    image_topic: str,
    source_frame: str,
    target_frame: str,
    vectorstore: VectorStore,
    observations_collection: Collection[Dict[str, Any]],
    time_between_observations: float = 5.0,
) -> None:
    """Continuously collect (image, transform) snapshots and persist them.

    Spins up a TF2 poller and an image-subscriber node on background
    threads, then every ``time_between_observations`` seconds launches a
    pipeline thread that describes the current frame with an LLM and stores
    the result in the vector store and MongoDB. Runs forever; never returns.

    Args:
        image_topic: ROS2 topic to subscribe to for camera frames.
        source_frame: TF2 source frame for the robot pose.
        target_frame: TF2 target frame for the robot pose.
        vectorstore: Vector store in which description texts are indexed.
        observations_collection: MongoDB collection for full observations.
        time_between_observations: Seconds between successive snapshots.
    """
    transform_fetcher = TransformGrabber(
        target_frame=target_frame, source_frame=source_frame
    )
    image_grabber = ImageGrabber(image_topic)
    executor = SingleThreadedExecutor()
    executor.add_node(image_grabber)
    # Fix: daemon=True so these background loops no longer keep the process
    # alive forever if the caller's main thread exits.
    threading.Thread(target=transform_fetcher.run, daemon=True).start()
    threading.Thread(target=executor.spin, daemon=True).start()

    while True:
        image = image_grabber.image
        # Fix: read the shared transform under its lock — the writer
        # (TransformGrabber.run) updates it under the same lock, but the
        # original read skipped the lock entirely.
        with transform_fetcher.lock:
            transform = transform_fetcher.transform
        if image is None or transform is None:
            time.sleep(0.1)
            continue
        # The pipeline runs on its own thread so slow LLM/database calls do
        # not delay the sampling cadence. NOTE(review): threads are spawned
        # unbounded; if the pipeline is consistently slower than the sampling
        # interval they will accumulate — consider a worker pool.
        threading.Thread(
            target=data_collection_pipeline,
            args=(vectorstore, observations_collection, image, transform),
            daemon=True,
        ).start()
        time.sleep(time_between_observations)
maciejmajek marked this conversation as resolved.
Show resolved Hide resolved
Loading