[REVIEW] Fix Padding Related Bugs: Crossfit #66

Merged

Changes from 17 commits
1 change: 1 addition & 0 deletions crossfit/__init__.py
@@ -84,6 +84,7 @@ def __call__(self, *args, **kwargs):
load_dataset = LazyLoader("crossfit.dataset.load.load_dataset")
embed = LazyLoader("crossfit.report.beir.embed.embed")
beir_report = LazyLoader("crossfit.report.beir.report.beir_report")
utils = LazyLoader("crossfit.utils")

__all__.extend(
[
93 changes: 93 additions & 0 deletions crossfit/backend/torch/hf/memory_curve_utils.py
@@ -0,0 +1,93 @@
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import joblib
import numpy as np
import torch
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, PreTrainedModel

from crossfit.utils.model_adapter import adapt_model_input


def fit_memory_estimate_curve(
model: PreTrainedModel,
path_or_name: str,
start_batch_size: int = 1,
end_batch_size: int = 2048,
batch_size_increment: int = 256,
start_seq_len: int = 1,
end_seq_len: int = 2048,
seq_len_increment: int = 64,
mem_model_path: str = None,
) -> LinearRegression:
print(f"Fitting memory estimate curve for model: {path_or_name}")

device = next(model.parameters()).device
X: list[list[int]] = []
y: list[float] = []

max_seq = min(AutoTokenizer.from_pretrained(path_or_name).model_max_length, end_seq_len)

Similar to the other issue, I think every time AutoTokenizer or AutoConfig is used in this file it should use the corresponding methods of the model instead. Right now I'm getting this error with Llama Guard:

Traceback (most recent call last):
  File "/usr/local/bin/aegis_classifier_inference", line 8, in <module>
    sys.exit(console_script())
  File "/usr/local/lib/python3.10/dist-packages/nemo_curator/scripts/aegis_classifier_inference.py", line 126, in console_script
    main()
  File "/usr/local/lib/python3.10/dist-packages/nemo_curator/scripts/aegis_classifier_inference.py", line 62, in main
    domain_classifier = AegisClassifier(
  File "/usr/local/lib/python3.10/dist-packages/nemo_curator/classifiers/aegis.py", line 161, in __init__
    model = AegisHFModel(config=config)
  File "/usr/local/lib/python3.10/dist-packages/nemo_curator/classifiers/aegis.py", line 90, in __init__
    super().__init__(
  File "/usr/local/lib/python3.10/dist-packages/crossfit/backend/torch/hf/model.py", line 58, in __init__
    self.mem = fit_memory_estimate_curve(
  File "/usr/local/lib/python3.10/dist-packages/crossfit/backend/torch/hf/memory_curve_utils.py", line 44, in fit_memory_estimate_curve
    max_seq = min(AutoTokenizer.from_pretrained(path_or_name).model_max_length, end_seq_len)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/auto/tokenization_auto.py", line 843, in from_pretrained
    return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 2032, in from_pretrained
    raise EnvironmentError(
OSError: Can't load tokenizer for 'meta-llama/LlamaGuard-7b'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'meta-llama/LlamaGuard-7b' is the correct path to a directory containing all relevant files for a LlamaTokenizerFast tokenizer.

Member Author
Resolved by: 987836f
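
For illustration, a minimal sketch of the direction the comment points at: resolve the maximum sequence length through the model wrapper's own load_tokenizer() / load_cfg() methods (which subclasses such as the Aegis wrapper can override) instead of calling AutoTokenizer.from_pretrained(path_or_name) directly. The helper name below is hypothetical; the actual fix in 987836f may differ.

# Hypothetical helper, not the actual fix: let the HFModel wrapper decide how
# the tokenizer and config are loaded, so gated or wrapped checkpoints still work.
def _resolve_max_seq(hf_model, end_seq_len: int) -> int:
    tokenizer = hf_model.load_tokenizer()  # subclass-overridable
    max_seq = min(tokenizer.model_max_length, end_seq_len)
    if max_seq > 1e5:  # HF sentinel value meaning "unset"
        max_seq = min(hf_model.load_cfg().max_position_embeddings, end_seq_len)
    return int(max_seq)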

if max_seq > 1e5:
max_seq = min(AutoConfig.from_pretrained(path_or_name).max_position_embeddings, end_seq_len)

batch_size_pbar = tqdm(
range(start_batch_size, end_batch_size + 1, batch_size_increment), desc="Batch size"
)
for batch_size in batch_size_pbar:
seq_len_pbar = tqdm(
range(start_seq_len, max_seq + 1, seq_len_increment),
desc="Sequence length",
leave=False,
)
for seq_len in seq_len_pbar:
torch.cuda.reset_peak_memory_stats()

batch = {
"input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(device=device),
"attention_mask": torch.ones((batch_size, seq_len)).to(device=device),
}

try:
_ = adapt_model_input(model, batch)
memory_used = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
X.append([batch_size, seq_len, seq_len**2])
y.append(memory_used)

except RuntimeError as e:
if "out of memory" in str(e) or "out_of_memory" in str(e):
# Early stopping for this batch size
seq_len_pbar.close()
break
else:
raise e
finally:
del batch
if "outputs" in vars():
del outputs
gc.collect()
torch.cuda.empty_cache()

# Check if we've hit the memory limit for all sequence lengths
if seq_len == start_seq_len:
batch_size_pbar.close()
break

mem_model = LinearRegression().fit(np.array(X), np.array(y))
joblib.dump(mem_model, mem_model_path)

return mem_model
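
For context, a rough usage sketch of fit_memory_estimate_curve as added above; the model name, cache path, and sweep sizes here are illustrative, not taken from the PR:

import torch
from transformers import AutoModel
from crossfit.backend.torch.hf.memory_curve_utils import fit_memory_estimate_curve

model = AutoModel.from_pretrained("microsoft/deberta-v3-base").to("cuda")
with torch.no_grad():
    mem_model = fit_memory_estimate_curve(
        model=model,
        path_or_name="microsoft/deberta-v3-base",
        end_batch_size=512,                  # smaller sweep than the 2048 default
        mem_model_path="/tmp/mem_model.pkl",
    )
# The regression features are [batch_size, seq_len, seq_len**2]; predictions are in MB.
predicted_mb = mem_model.predict([[64, 512, 512**2]])[0]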
131 changes: 58 additions & 73 deletions crossfit/backend/torch/hf/model.py
@@ -19,24 +19,67 @@
import joblib
import numpy as np
import torch
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer

from crossfit.backend.torch.hf.memory_curve_utils import fit_memory_estimate_curve
from crossfit.backend.torch.model import Model
from crossfit.dataset.home import CF_HOME
from crossfit.utils.model_adapter import adapt_model_input


class HFModel(Model):
def __init__(self, path_or_name: str, max_mem_gb: int = 16, training=False):
def __init__(
self,
path_or_name: str,
max_mem_gb: int = 16,
training: bool = False,
start_batch_size: int = 1,
end_batch_size: int = 2048,
batch_size_increment: int = 256,
start_seq_len: int = 1,
seq_len_increment: int = 64,
):
super().__init__(path_or_name, max_mem_gb)

if not training:
with torch.no_grad():
self.fit_memory_estimate_curve()
self.start_batch_size = start_batch_size
self.end_batch_size = end_batch_size
self.batch_size_increment = batch_size_increment
self.start_seq_len = start_seq_len
self.seq_len_increment = seq_len_increment

cache_dir = os.path.join(
CF_HOME, "memory", AutoConfig.from_pretrained(path_or_name)._name_or_path
)
os.makedirs(cache_dir, exist_ok=True)
mem_model_path = os.path.join(cache_dir, "mem_model.pkl")
if os.path.exists(mem_model_path):
self.mem = joblib.load(mem_model_path)
else:
self.fit_memory_estimate_curve()
model = self.load_model("cuda") if not training else None
end_seq_len = self.max_seq_length()
if not training:
with torch.no_grad():
self.mem = fit_memory_estimate_curve(
model=model,
path_or_name=self.path_or_name,
start_batch_size=start_batch_size,
end_batch_size=end_batch_size,
start_seq_len=start_seq_len,
end_seq_len=end_seq_len,
batch_size_increment=batch_size_increment,
seq_len_increment=seq_len_increment,
mem_model_path=mem_model_path,
)
else:
self.mem = fit_memory_estimate_curve(
model=model,
path_or_name=self.path_or_name,
start_batch_size=start_batch_size,
end_batch_size=end_batch_size,
start_seq_len=start_seq_len,
end_seq_len=end_seq_len,
batch_size_increment=batch_size_increment,
seq_len_increment=seq_len_increment,
mem_model_path=mem_model_path,
)

def load_on_worker(self, worker, device="cuda"):
worker.torch_model = self.load_model(device)
@@ -60,77 +103,19 @@ def load_tokenizer(self):
def load_cfg(self):
return AutoConfig.from_pretrained(self.path_or_name)

def fit_memory_estimate_curve(self, model=None):
remove_model = False
if model is None:
remove_model = True
model = self.load_model(device="cuda")

cache_dir = os.path.join(CF_HOME, "memory", self.load_cfg()._name_or_path)
mem_model_path = os.path.join(cache_dir, "mem_model.pkl")

if os.path.exists(mem_model_path):
self.mem = joblib.load(mem_model_path)

return self

print(f"Fitting memory estimate curve for model: {self.path_or_name}")

device = next(model.parameters()).device
X = []
y = []

max_seq = self.max_seq_length()
for batch_size in tqdm(range(2048, 0, -256)):
if batch_size <= 0:
continue

for seq_len in range(max_seq, 0, -64):
if seq_len <= 0:
continue

torch.cuda.reset_peak_memory_stats()

batch = {
"input_ids": torch.randint(1, 501, (batch_size, seq_len)).to(device=device),
"attention_mask": torch.ones((batch_size, seq_len)).to(device=device),
}

try:
_ = adapt_model_input(model, batch)
memory_used = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
X.append([batch_size, seq_len, seq_len**2])
y.append(memory_used)

except RuntimeError as e:
if "out of memory" in str(e) or "out_of_memory" in str(e):
pass
else:
raise e
finally:
del batch
if "outputs" in vars():
del outputs
gc.collect()
torch.cuda.empty_cache()

self.mem = LinearRegression().fit(np.array(X), np.array(y))
os.makedirs(cache_dir, exist_ok=True)
joblib.dump(self.mem, mem_model_path)

if remove_model:
del model
gc.collect()
torch.cuda.empty_cache()

def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
predicted_memory = self.mem.predict(
np.array([[batch_size, max_num_tokens, max_num_tokens**2]])
)
return predicted_memory[0] / 1024 # Convert from MB to GB

def max_seq_length(self) -> int:
return self.load_cfg().max_position_embeddings
max_seq_length = self.load_tokenizer().model_max_length
# Guard against the HF bug
# which sets max_seq_length to max(int) for some models
if max_seq_length > 1e5:
max_seq_length = AutoConfig.from_pretrained(self.path_or_name).max_position_embeddings
return max_seq_length


class SentenceTransformerModel(HFModel):
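
As a rough usage sketch of the reworked constructor (illustrative model name; the new keyword arguments, the cache location, and the estimate_memory units follow the diff above):

from crossfit.backend.torch.hf.model import HFModel

model = HFModel(
    "microsoft/deberta-v3-base",
    max_mem_gb=16,
    start_batch_size=1,
    end_batch_size=1024,
    batch_size_increment=256,
    seq_len_increment=64,
)
# The first construction fits the memory curve and caches it as
# $CF_HOME/memory/<model name>/mem_model.pkl; later constructions load it via joblib.
estimated_gb = model.estimate_memory(max_num_tokens=512, batch_size=64)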
46 changes: 37 additions & 9 deletions crossfit/backend/torch/loader.py
@@ -1,4 +1,4 @@
# Copyright 2023 NVIDIA CORPORATION
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
from crossfit.data.array.conversion import convert_array
from crossfit.data.array.dispatch import crossarray
from crossfit.data.dataframe.dispatch import CrossFrame
from crossfit.op.tokenize import clip_tokens
from crossfit.utils.model_adapter import adapt_model_input

DEFAULT_BATCH_SIZE = 512
@@ -36,7 +37,14 @@ def __init__(self, data: Dict[str, torch.Tensor], batch_size: int, progress_bar=None)
def __init__(self, data: CrossFrame, batch_size: int, progress_bar=None):
...

def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None):
def __init__(
self,
data,
batch_size: int,
progress_bar=None,
max_seq_len=None,
padding_side: str = "right",
):
self.data = CrossFrame(data).cast(torch.Tensor)
self.tensor_dict = self.data.to_dict()
self._batch_size = batch_size
@@ -45,6 +53,7 @@ def __init__(self, data, batch_size: int, progress_bar=None, max_seq_len=None):
self._to_map = []
self.progress_bar = progress_bar
self.max_seq_len = max_seq_len
self.padding_side = padding_side

def map(self, fn):
self._to_map.append(fn)
@@ -66,7 +75,10 @@ def __next__(self):

batch = {key: val[self.current_idx : end] for key, val in self.tensor_dict.items()}
if self.max_seq_len is not None:
batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()}
if self.padding_side == "right":
batch = {key: val[:, : self.max_seq_len] for key, val in batch.items()}
else:
batch = {key: val[:, -self.max_seq_len :] for key, val in batch.items()}

self.current_idx += self.batch_size

@@ -97,14 +109,27 @@ def __init__(
self.to_ignore = to_ignore or []
self.to_ignore.append("seq_length")
self.model = model
tokenizer = self.model.load_tokenizer()
pad_token_id = tokenizer.pad_token_id
padding_side = tokenizer.padding_side

if padding_side not in ["right", "left"]:
raise ValueError("padding_side must be either 'right' or 'left'")

self.pad_token_id = pad_token_id
frame = CrossFrame(data).cast(torch.Tensor)
seq_length = (frame[sort_key] != 0).sum(axis=1)
seq_length = (frame[sort_key] != self.pad_token_id).sum(axis=1)
self.sorted_indices = seq_length.argsort(descending=True)
frame = frame.apply(lambda x: x[self.sorted_indices])
frame = frame.assign(seq_length=seq_length[self.sorted_indices])

super().__init__(frame, initial_batch_size, progress_bar=progress_bar)
super().__init__(
frame,
initial_batch_size,
progress_bar=progress_bar,
max_seq_len=self.model.max_seq_length(),
padding_side=padding_side,
)
self.splits = self._find_optimal_splits()

def sort_column(self, col):
@@ -128,8 +153,6 @@ def __next__(self):
else:
start = self.splits[self.current_idx - 1]

_tokens = self.tensor_dict["seq_length"]

end = min(self.splits[self.current_idx], self.num_rows)
while end > start:
try:
@@ -138,8 +161,13 @@
for key, val in self.tensor_dict.items()
if key not in self.to_ignore
}
clip_len = min(max(_tokens[start], _tokens[end - 1]), self.model.max_seq_length())
batch = {key: val[:, :clip_len] for key, val in batch.items()}
batch = clip_tokens(
token_o=batch,
max_length=self.max_seq_len,
padding_side=self.padding_side,
pad_token_id=self.pad_token_id,
return_type="pt",
)

for fn in self._to_map:
batch = adapt_model_input(fn, batch)
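
A small illustration of why the clip direction in InMemoryLoader has to follow the tokenizer's padding side (the pad id 0 and the tensors here are just for the example):

import torch

max_seq_len = 3
input_ids = torch.tensor([
    [11, 12, 13, 0, 0],   # right-padded row: real tokens at the start
    [0, 0, 21, 22, 23],   # left-padded row: real tokens at the end
])
right_clip = input_ids[:, :max_seq_len]   # keeps the real tokens of a right-padded batch
left_clip = input_ids[:, -max_seq_len:]   # keeps the real tokens of a left-padded batch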
8 changes: 7 additions & 1 deletion crossfit/backend/torch/op/base.py
@@ -26,6 +26,7 @@
from crossfit.backend.torch.loader import DEFAULT_BATCH_SIZE, InMemoryLoader, SortedSeqLoader
from crossfit.backend.torch.model import Model
from crossfit.op.base import Op
from crossfit.utils.torch_utils import concat_and_pad_tensors


class Predictor(Op):
@@ -66,6 +67,7 @@ def call(self, data, partition_info=None):
loader = InMemoryLoader(
data[["input_ids", "attention_mask"]],
batch_size=self.batch_size,
padding_side=self.model.load_tokenizer().padding_side,
progress_bar=self.create_progress_bar(len(data), partition_info),
max_seq_len=self.model.max_seq_length(),
)
@@ -83,7 +85,11 @@ def call(self, data, partition_info=None):
all_outputs_ls.append(output)

out = cudf.DataFrame(index=index)
outputs = cp.asarray(torch.cat(all_outputs_ls, dim=0))
outputs = cp.asarray(
concat_and_pad_tensors(
all_outputs_ls, pad_token_id=loader.pad_token_id, padding_side=loader.padding_side
)
)
_index = loader.sort_column(index.values) if self.sorted_data_loader else index
if len(outputs.shape) <= 2:
out[self.pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index)
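
The concat_and_pad_tensors helper is new in this PR; below is a rough sketch of what such a utility can look like for 2-D outputs. This is an assumption about its behavior (pad every chunk to the widest sequence length on the correct side, then concatenate), not the actual crossfit implementation.

import torch
import torch.nn.functional as F

def pad_and_cat(tensors, pad_value, padding_side="right"):
    # Pad each (batch, seq_len) tensor to the widest seq_len, then concatenate on dim 0.
    width = max(t.shape[1] for t in tensors)
    padded = []
    for t in tensors:
        missing = width - t.shape[1]
        pad = (missing, 0) if padding_side == "left" else (0, missing)
        padded.append(F.pad(t, pad, value=pad_value))
    return torch.cat(padded, dim=0)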