Skip to content

Commit

Permalink
🗝️ Update type hints (#2399)
Browse files Browse the repository at this point in the history
* New type hint structure

* Update type hints

* Delete wrong file

* Remove dict import
  • Loading branch information
qgallouedec authored Nov 26, 2024
1 parent 9368dcc commit c10cc89
Show file tree
Hide file tree
Showing 42 changed files with 462 additions and 464 deletions.
4 changes: 2 additions & 2 deletions examples/datasets/hh-rlhf-helpful-base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import re
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Optional

from datasets import load_dataset
from transformers import HfArgumentParser
Expand Down Expand Up @@ -51,7 +51,7 @@ def common_start(str1: str, str2: str) -> str:
return "".join(common_chars)


def extract_dialogue(example: str) -> List[Dict[str, str]]:
def extract_dialogue(example: str) -> list[dict[str, str]]:
# Extract the prompt, which corresponds to the common start of the chosen and rejected dialogues
prompt_text = common_start(example["chosen"], example["rejected"])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union

import evaluate
import numpy as np
Expand Down Expand Up @@ -236,7 +236,7 @@ class RewardDataCollatorWithPadding:
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"

def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
features_j = []
features_k = []
for feature in features:
Expand Down
10 changes: 5 additions & 5 deletions examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# 0. imports
import os
from dataclasses import dataclass, field
from typing import Dict, Optional
from typing import Optional

import torch
from accelerate import Accelerator
Expand Down Expand Up @@ -109,9 +109,9 @@ def get_stack_exchange_paired(
The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
'prompt': list[str],
'chosen': list[str],
'rejected': list[str],
}
Prompts are structured as follows:
Expand All @@ -126,7 +126,7 @@ def get_stack_exchange_paired(
)
original_columns = dataset.column_names

def return_prompt_and_responses(samples) -> Dict[str, str]:
def return_prompt_and_responses(samples) -> dict[str, str]:
return {
"prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
"chosen": samples["response_j"],
Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/sft_video_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
import os
import random
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Any

import requests
import torch
Expand Down Expand Up @@ -90,7 +90,7 @@ def download_video(url: str, cache_dir: str) -> str:
raise Exception(f"Failed to download video: {e}") from e


def prepare_dataset(example: Dict[str, Any], cache_dir: str) -> Dict[str, List[Dict[str, Any]]]:
def prepare_dataset(example: dict[str, Any], cache_dir: str) -> dict[str, list[dict[str, Any]]]:
"""Prepare dataset example for training."""
video_url = example["video_url"]
timecoded_cc = example["timecoded_cc"]
Expand Down Expand Up @@ -120,7 +120,7 @@ def prepare_dataset(example: Dict[str, Any], cache_dir: str) -> Dict[str, List[D
return {"messages": messages}


def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
def collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
"""Collate batch of examples for training."""
texts = []
video_inputs = []
Expand Down
2 changes: 1 addition & 1 deletion trl/commands/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __init__(self, parsers, ignore_extra_args=False):
with the processed parsers.
Args:
parsers (`List[argparse.ArgumentParser`]):
parsers (`list[argparse.ArgumentParser`]):
List of parsers.
ignore_extra_args (`bool`):
Whether to ignore extra arguments passed by the config
Expand Down
18 changes: 9 additions & 9 deletions trl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import random
import warnings
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -70,10 +70,10 @@ def top_k_top_p_filtering(
return logits


def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
def flatten_dict(nested: dict, sep: str = "/") -> dict:
"""Flatten dictionary and concatenate nested keys with separator."""

def recurse(nest: Dict, prefix: str, into: Dict) -> None:
def recurse(nest: dict, prefix: str, into: dict) -> None:
for k, v in nest.items():
if sep in k:
raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
Expand All @@ -87,7 +87,7 @@ def recurse(nest: Dict, prefix: str, into: Dict) -> None:
return flat


def convert_to_scalar(stats: Dict) -> Dict:
def convert_to_scalar(stats: dict) -> dict:
"""
Converts the stats from a flattened dict to single scalar dicts
"""
Expand All @@ -103,7 +103,7 @@ def convert_to_scalar(stats: Dict) -> Dict:
return tensorboard_stats


def stack_dicts(stats_dicts: List[Dict]) -> Dict:
def stack_dicts(stats_dicts: list[dict]) -> dict:
"""Stack the values of a dict."""
results = dict()
for k in stats_dicts[0]:
Expand Down Expand Up @@ -185,7 +185,7 @@ def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
return entropy


def stats_to_np(stats_dict: Dict) -> Dict:
def stats_to_np(stats_dict: dict) -> dict:
"""Cast all torch.tensors in dict to numpy arrays."""
new_dict = dict()
for k, v in stats_dict.items():
Expand All @@ -202,7 +202,7 @@ def stats_to_np(stats_dict: Dict) -> Dict:


def respond_to_batch(
model: nn.Module, queries: List[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
model: nn.Module, queries: list[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
) -> torch.LongTensor:
"""Sample text from language model."""
input_ids = queries
Expand Down Expand Up @@ -271,8 +271,8 @@ def empty_device_cache(cls):


def randn_tensor(
shape: Union[Tuple, List],
generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
shape: Union[tuple, list],
generator: Optional[Union[list[torch.Generator], torch.Generator]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
Expand Down
26 changes: 13 additions & 13 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Sequence, TypeVar
from typing import Any, Optional, Sequence, TypeVar

from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer
Expand All @@ -20,12 +20,12 @@
DatasetType = TypeVar("DatasetType", Dataset, DatasetDict)


def is_conversational(example: Dict[str, Any]) -> bool:
def is_conversational(example: dict[str, Any]) -> bool:
r"""
Check if the example is in a conversational format.
Args:
example (`Dict[str, Any]`):
example (`dict[str, Any]`):
A single data entry of a dataset. The example can have different keys depending on the
dataset type.
Expand Down Expand Up @@ -60,7 +60,7 @@ def is_conversational(example: Dict[str, Any]) -> bool:
return False


def apply_chat_template(example: Dict[str, List[Dict[str, str]]], tokenizer: PreTrainedTokenizer) -> Dict[str, str]:
def apply_chat_template(example: dict[str, list[dict[str, str]]], tokenizer: PreTrainedTokenizer) -> dict[str, str]:
r"""
Apply a chat template to a conversational example.
Expand Down Expand Up @@ -139,13 +139,13 @@ def apply_chat_template(example: Dict[str, List[Dict[str, str]]], tokenizer: Pre


def maybe_apply_chat_template(
example: Dict[str, List[Dict[str, str]]], tokenizer: PreTrainedTokenizer
) -> Dict[str, str]:
example: dict[str, list[dict[str, str]]], tokenizer: PreTrainedTokenizer
) -> dict[str, str]:
r"""
If the example is in a conversational format, apply a chat template to it.
Args:
example (`Dict[str, List[Dict[str, str]]`):
example (`dict[str, list[dict[str, str]]`):
Dictionary representing a single data entry of a conversational dataset. Each data entry can have different
keys depending on the dataset type. The supported dataset types are:
Expand All @@ -163,7 +163,7 @@ def maybe_apply_chat_template(
The tokenizer to apply the chat template with.
Returns:
`Dict[str, str]`: The formatted example with the chat template applied.
`dict[str, str]`: The formatted example with the chat template applied.
Note:
This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
Expand All @@ -188,7 +188,7 @@ def maybe_apply_chat_template(
return example


def _unpair_row(examples: List[Dict[str, List[Dict[str, str]]]]) -> List[Dict[str, List[Dict[str, str]]]]:
def _unpair_row(examples: list[dict[str, list[dict[str, str]]]]) -> list[dict[str, list[dict[str, str]]]]:
batch_size = len(examples["chosen"])
new_rows = {
"completion": examples["chosen"] + examples["rejected"],
Expand Down Expand Up @@ -288,7 +288,7 @@ def maybe_unpair_preference_dataset(
return dataset


def extract_prompt(example: Dict[str, Sequence]) -> Dict[str, Sequence]:
def extract_prompt(example: dict[str, Sequence]) -> dict[str, Sequence]:
r"""
Extracts the shared prompt from a preference data example, where the prompt is implicit within both
the chosen and rejected completions.
Expand All @@ -307,7 +307,7 @@ def extract_prompt(example: Dict[str, Sequence]) -> Dict[str, Sequence]:
}


def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]:
r"""
Extracts the shared prompt from a preference data example, where the prompt is implicit within both
the chosen and rejected completions.
Expand All @@ -318,12 +318,12 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
"rejected" completions.
Args:
example (`Dict[str, List]`):
example (`dict[str, list]`):
A dictionary representing a single data entry in the preference dataset. It must contain the keys
`"chosen"` and `"rejected"`, where each value is either conversational or standard (`str`).
Returns:
`Dict[str, List]`: A dictionary containing:
`dict[str, list]`: A dictionary containing:
- `"prompt"`: The longest common prefix between the "chosen" and "rejected" completions.
- `"chosen"`: The remainder of the "chosen" completion, with the prompt removed.
- `"rejected"`: The remainder of the "rejected" completion, with the prompt removed.
Expand Down
16 changes: 8 additions & 8 deletions trl/extras/best_of_n_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
Expand All @@ -26,7 +26,7 @@ def __init__(
self,
model: PreTrainedModelWrapper,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
queries_to_scores: Callable[[List[str]], List[float]],
queries_to_scores: Callable[[list[str]], list[float]],
length_sampler: Any,
sample_size: int = 4,
seed: Optional[int] = None,
Expand All @@ -41,7 +41,7 @@ def __init__(
The pretrained model to use for generation
tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`):
Tokenizer associated with the pretrained model
queries_to_scores (`Callable[[List[str]], List[float]]`):
queries_to_scores (`Callable[[list[str]], list[float]]`):
Callable that takes a list of generated texts and returns the associated reward scores
length_sampler (`Any`):
Sampler used to sample the length of the generated text
Expand Down Expand Up @@ -78,16 +78,16 @@ def __init__(

def generate(
self,
tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]],
tokenized_query: Union[list[int], torch.Tensor, list[torch.Tensor], list[list[int]]],
skip_special_tokens: bool = True,
device: Optional[Union[str, torch.device]] = None,
**generation_kwargs,
) -> List[List[str]]:
) -> list[list[str]]:
r"""
Generate the best of n samples for input queries
Args:
tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[int]`):
tokenized_query (`list[int]` or `torch.Tensor` or `list[torch.Tensor]` or `list[int]`):
represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers)
skip_special_tokens (`bool`):
Whether to remove the special tokens from the output
Expand All @@ -98,13 +98,13 @@ def generate(
This is used to override generation config
Returns:
List[List[str]]: A list of lists of generated texts
list[list[str]]: A list of lists of generated texts
"""
queries = None

if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1:
queries = tokenized_query.unsqueeze(0)
elif isinstance(tokenized_query, List):
elif isinstance(tokenized_query, list):
element_type = type(tokenized_query[0])
if element_type is int:
queries = torch.tensor(tokenized_query).unsqueeze(0)
Expand Down
4 changes: 2 additions & 2 deletions trl/mergekit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ class MergeConfig:
target_model_path (`Optional[str]`): Path to the target model.
policy_model_weight (`float`): Weight for the policy model (for `linear` and `ties` methods).
target_model_weight (`float`): Weight for the target model (for `linear` and `ties` methods).
policy_model_density (`List[float]`): Density parameters for the policy model (for `ties` and `dare_ties`).
target_model_density (`List[float]`): Density parameters for the target model (for `ties` and `dare_ties`).
policy_model_density (`list[float]`): Density parameters for the policy model (for `ties` and `dare_ties`).
target_model_density (`list[float]`): Density parameters for the target model (for `ties` and `dare_ties`).
normalize (`Optional[float]`): Normalization factor for the TIES method.
t_values (`Optional[float]`): Interpolation factor for the SLERP method.
dtype (`str`): Data type to use for merging, e.g., `"float16"`.
Expand Down
Loading

0 comments on commit c10cc89

Please sign in to comment.