User/rcadene/2024 10 07 vla #467

Draft · wants to merge 9 commits into base: main
65 changes: 65 additions & 0 deletions lerobot/common/datasets/mock.py
@@ -0,0 +1,65 @@
import unittest
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

class TestMultiLeRobotDataset(unittest.TestCase):
def setUp(self):
# Define the datasets to use
self.dataset_repo_ids = [
"lerobot/aloha_sim_insertion_human",
"lerobot/aloha_static_vinh_cup"
]
        self.dataset = MultiLeRobotDataset(
            repo_ids=self.dataset_repo_ids,
            root=None,  # replace with a local path, or keep None to pull from the Hugging Face Hub
            split="train",
            image_transforms=None,  # pass image transforms here if any
            delta_timestamps=None,
        )

def test_initialization(self):
# Check if datasets were initialized correctly
self.assertEqual(len(self.dataset.repo_ids), 2)
self.assertEqual(self.dataset.repo_ids, self.dataset_repo_ids)

def test_num_samples(self):
# Check the total number of samples
self.assertGreater(len(self.dataset), 0)

def test_num_episodes(self):
# Check the total number of episodes
self.assertGreater(self.dataset.num_episodes, 0)

def test_fps(self):
# Check that FPS is correctly returned and is consistent
fps = self.dataset.fps
self.assertGreater(fps, 0)

def test_video_property(self):
# Check if video loading is correctly handled
self.assertIsInstance(self.dataset.video, bool)

def test_getitem(self):
# Test accessing a few samples to see if they are returned correctly
for i in range(5):
sample = self.dataset[i]
self.assertIsInstance(sample, dict)
self.assertIn("dataset_index", sample) # Check that dataset index is included

def test_camera_keys(self):
# Test that camera keys are returned correctly
camera_keys = self.dataset.camera_keys
self.assertIsInstance(camera_keys, list)
self.assertGreater(len(camera_keys), 0)

def test_video_frame_keys(self):
# Test that video frame keys are returned correctly
video_frame_keys = self.dataset.video_frame_keys
if self.dataset.video:
self.assertIsInstance(video_frame_keys, list)
else:
self.assertEqual(len(video_frame_keys), 0)

if __name__ == "__main__":
unittest.main()
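
For a quick sanity check outside `unittest`, the aggregated dataset can also be exercised directly. A minimal sketch, assuming the same repo IDs as the test above and that `MultiLeRobotDataset` follows the standard torch `Dataset` protocol:

import torch
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

dataset = MultiLeRobotDataset(
    repo_ids=["lerobot/aloha_sim_insertion_human", "lerobot/aloha_static_vinh_cup"],
    split="train",
)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch["dataset_index"])  # identifies which sub-dataset each sample came from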
7 changes: 1 addition & 6 deletions lerobot/common/policies/act/modeling_act.py
@@ -97,12 +97,7 @@ def reset(self):

@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
"""Select a single action given environment observations."""
self.eval()

batch = self.normalize_inputs(batch)
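
The trimmed docstring drops the explanation of the internal action queue, but the calling convention is unchanged: `select_action` still returns one action per environment step. A minimal rollout sketch, where `num_steps` and `env_step` are hypothetical stand-ins for the episode length and environment transition:

policy.reset()  # clears the internal action queue between episodes
for _ in range(num_steps):
    action = policy.select_action(obs)  # pops one action, refilling the queue when empty
    obs = env_step(action)  # hypothetical: advance the environment and build the next observation batch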
6 changes: 6 additions & 0 deletions lerobot/common/policies/factory.py
@@ -24,6 +24,7 @@

def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)

if not set(hydra_cfg.policy).issuperset(expected_kwargs):
logging.warning(
f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
@@ -66,6 +67,11 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy

return VQBeTPolicy, VQBeTConfig
elif name == "vla":
from lerobot.common.policies.vla.configuration_vla import VLAConfig
from lerobot.common.policies.vla.modeling_vla import VLAPolicy

return VLAPolicy, VLAConfig
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")

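
With this branch added, the factory resolves the new name like any other policy. A minimal sketch, assuming `VLAConfig` is default-constructible (its definition is not part of this diff):

from lerobot.common.policies.factory import get_policy_and_config_classes

policy_cls, config_cls = get_policy_and_config_classes("vla")  # resolves to VLAPolicy, VLAConfig
cfg = config_cls()  # assumes VLAConfig provides usable defaults
policy = policy_cls(cfg)  # instantiation details live in modeling_vla.py, not shown here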
197 changes: 197 additions & 0 deletions lerobot/common/policies/vla/configuration_qwen2_vl.py
@@ -0,0 +1,197 @@
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""

import os
from typing import Union

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: set | None = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get(
        "rope_type", rope_scaling.get("type", None)
    )  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type"}
    received_keys = set(rope_scaling.keys())
    if "type" in received_keys:  # BC: treat the legacy "type" key as "rope_type"
        received_keys = (received_keys - {"type"}) | {"rope_type"}
    # Inline replacement for the `transformers` helper `_check_received_keys`:
    # fail on missing required keys, warn on unrecognized ones.
    if required_keys - received_keys:
        raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {required_keys - received_keys}")
    unused_keys = received_keys - required_keys - (ignore_keys or set())
    if unused_keys:
        logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")


# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
ROPE_VALIDATION_FUNCTIONS = {
"default": _validate_default_rope_parameters,
# "linear": _validate_linear_scaling_rope_parameters,
# "dynamic": _validate_dynamic_scaling_rope_parameters,
# "yarn": _validate_yarn_parameters,
# "longrope": _validate_longrope_parameters,
# "llama3": _validate_llama3_parameters,
}


def rope_config_validation(config: PretrainedConfig, ignore_keys: set | None = None):
"""
Validate the RoPE config arguments, given a `PretrainedConfig` object
"""
rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
if rope_scaling is None:
return

# BC: "rope_type" was originally "type"
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
if validation_fn is not None:
validation_fn(config, ignore_keys=ignore_keys)
else:
logger.warning(
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
)


class Qwen2VLVisionConfig(PretrainedConfig):
model_type = "qwen2_vl"

def __init__(
self,
depth=32,
embed_dim=1280,
hidden_size=3584,
hidden_act="quick_gelu",
mlp_ratio=4,
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
temporal_patch_size=2,
**kwargs,
):
super().__init__(**kwargs)

self.depth = depth
self.embed_dim = embed_dim
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size

@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)

config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

if config_dict.get("model_type") == "qwen2_vl":
config_dict = config_dict["vision_config"]

if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)


class Qwen2VLConfig(PretrainedConfig):
r"""
A simplified version of the Qwen2VL model configuration class without the `transformers` dependencies.
"""

model_type = "qwen2_vl"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=152064,
hidden_size=8192,
intermediate_size=29568,
num_hidden_layers=80,
num_decoder_layers=1,
num_attention_heads=64,
num_key_value_heads=8,
# dim_feedforward = 3200,
hidden_act="silu",
pad_token_id=0,
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
        rope_scaling=None,  # defaults to {"type": "mrope", "mrope_section": [2, 2, 2]} below; avoids a mutable default argument
pruned_heads=None,
**kwargs,
):
        # Initialize vision config from a dict, from defaults, or accept an already-built config
        if isinstance(vision_config, dict):
            self.vision_config = Qwen2VLVisionConfig(**vision_config)
        elif vision_config is None:
            self.vision_config = Qwen2VLVisionConfig()
        else:
            self.vision_config = vision_config

# Model hyperparameters
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
self.pad_token_id = pad_token_id
self.pruned_heads = pruned_heads or {}
        self.rope_scaling = rope_scaling if rope_scaling is not None else {"type": "mrope", "mrope_section": [2, 2, 2]}
self.num_decoder_layers = num_decoder_layers

if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self, ignore_keys={"mrope_section"})

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

# @classmethod
# def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
# # Custom loading logic from a pre-trained model or path
# logger.info(f"Loading pretrained config from {pretrained_model_name_or_path}...")
# # Add custom logic here to load a pretrained configuration
# return cls(**kwargs)
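
For reference, a minimal sketch of what default construction yields, following the `__init__` logic above (values assume the defaults in this diff):

config = Qwen2VLConfig()
# "mrope" is rewritten to "default" before validation, and "mrope_section" is kept:
print(config.rope_scaling)  # {"type": "default", "mrope_section": [2, 2, 2], "rope_type": "default"}
print(config.vision_config.embed_dim)  # 1280, from the Qwen2VLVisionConfig defaults
print(config.num_decoder_layers)  # 1, the decoder-depth knob added on top of the upstream config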