[WIP] compile graph breaks #2027

Draft · wants to merge 6 commits into base: main
8 changes: 7 additions & 1 deletion recipes/lora_finetune_single_device.py
@@ -664,6 +664,7 @@ def train(self) -> None:
)

# Initialize tokens count and running loss (for grad accumulation)
start = time.perf_counter()
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0
@@ -730,6 +731,7 @@ def train(self) -> None:
# Log per-step metrics
if self.global_step % self._log_every_n_steps == 0:
time_per_step = time.perf_counter() - t0
print(time_per_step)
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
@@ -773,13 +775,17 @@ def train(self) -> None:
self.epochs_run += 1
start_save_checkpoint = time.perf_counter()
log.info("Starting checkpoint save...")
self.save_checkpoint(epoch=curr_epoch)
# self.save_checkpoint(epoch=curr_epoch)
log.info(
"Checkpoint saved in {:.2f} seconds.".format(
time.perf_counter() - start_save_checkpoint
)
)

end = time.perf_counter()
time_total = end - start
print(f"{time_total=}")

def cleanup(self) -> None:
self._metric_logger.close()

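For reference, the timing added to the recipe above follows the standard perf_counter pattern: one timer spanning the whole run and one reset per logging window. A minimal standalone sketch of that pattern, with a placeholder step function and step count that are not from this PR:

import time

def run_timed(step_fn, num_steps: int, log_every: int = 10) -> float:
    # One timer for the whole run, one for each logging window.
    start = time.perf_counter()
    t0 = time.perf_counter()
    for step in range(1, num_steps + 1):
        step_fn()
        if step % log_every == 0:
            # Average step time over the last logging window.
            print(f"time_per_step={(time.perf_counter() - t0) / log_every:.4f}s")
            t0 = time.perf_counter()
    time_total = time.perf_counter() - start
    print(f"{time_total=:.2f}")
    return time_total

run_timed(lambda: sum(i * i for i in range(100_000)), num_steps=50)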
21 changes: 1 addition & 20 deletions torchtune/models/llama3_2_vision/_component_builders.py
@@ -338,7 +338,6 @@ def lora_llama3_2_vision_encoder(
fusion_lora: bool,
lora_attn_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool = False,
apply_lora_to_output: bool = False,
*,
# clip encoder parameters
patch_size: int,
@@ -377,8 +376,6 @@ def lora_llama3_2_vision_encoder(
``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
Default: False
apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
Default: False
patch_size (int): The size of each patch. Used to divide the tiles into patches.
E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches
with shape (40, 40) each.
@@ -412,7 +409,6 @@ def lora_llama3_2_vision_encoder(
lora_options = {
"lora_modules": lora_attn_modules,
"apply_lora_to_mlp": apply_lora_to_mlp,
"apply_lora_to_output": apply_lora_to_output,
"lora_rank": lora_rank,
"lora_alpha": lora_alpha,
"lora_dropout": lora_dropout,
@@ -679,7 +675,6 @@ def lora_llama3_2_vision_projection_head(
num_hidden_inputs: int,
# LoRA args
apply_lora_to_mlp: bool,
apply_lora_to_output: bool,
lora_rank: int,
lora_alpha: float,
lora_dropout: float = 0.0,
@@ -701,8 +696,6 @@ def lora_llama3_2_vision_projection_head(
num_hidden_inputs (int): number of hidden inputs to the projection head.
apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
Default: False
apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
Default: False
lora_rank (int): rank of each low-rank approximation
lora_alpha (float): scaling factor for the low-rank approximation
lora_dropout (float): LoRA dropout probability. Default: 0.0
@@ -773,19 +766,7 @@ def lora_llama3_2_vision_projection_head(
# cross encoding
# TODO: quantize_base is not applied to final output_proj currently.
proj_in = clip_embed_dim * (num_hidden_inputs + 1)
adapter_cls = DoRALinear if use_dora else LoRALinear
output_proj = (
adapter_cls(
proj_in,
decoder_embed_dim,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
use_bias=True,
)
if apply_lora_to_output
else nn.Linear(proj_in, decoder_embed_dim)
)
output_proj = nn.Linear(proj_in, decoder_embed_dim)
return Llama3VisionProjectionHead(
layers=layers,
output=output_proj,
2 changes: 0 additions & 2 deletions torchtune/models/llama3_2_vision/_model_builders.py
@@ -172,7 +172,6 @@ def lora_llama3_2_vision_11b(
fusion_lora=fusion_type == LoRATrainable.LORA,
lora_attn_modules=lora_attn_modules,
apply_lora_to_mlp=apply_lora_to_mlp,
apply_lora_to_output=apply_lora_to_output,
patch_size=14,
num_heads=16,
clip_embed_dim=1280,
@@ -330,7 +329,6 @@ def lora_llama3_2_vision_90b(
fusion_lora=fusion_type == LoRATrainable.LORA,
lora_attn_modules=lora_attn_modules,
apply_lora_to_mlp=apply_lora_to_mlp,
apply_lora_to_output=apply_lora_to_output,
patch_size=14,
num_heads=16,
clip_embed_dim=1280,
11 changes: 6 additions & 5 deletions torchtune/modules/attention_utils.py
@@ -191,11 +191,12 @@ def _attention_call(
# This will use flash attention under the hood with support for custom masks.
# Currently, it is used when sample packing is enabled (see torchtune.datasets.PackedDataset)
if isinstance(mask, BlockMask):
log_once(
_log,
"Using flex attention for attention computation since a BlockMask was passed in.",
level=logging.DEBUG,
)
if not torch.compiler.is_compiling(): # avoid graph break
log_once(
_log,
"Using flex attention for attention computation since a BlockMask was passed in.",
level=logging.DEBUG,
)
if dropout_p > 0.0:
raise ValueError(
"Flex attention does not support dropout. Please set dropout to 0.0."
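The guard above works because torch.compiler.is_compiling() is True while Dynamo traces a function, so the logging call is only reached in eager mode and never inside a compiled region, where it would otherwise force a graph break. A self-contained sketch of the same pattern; the function and logger here are illustrative, not torchtune code:

import logging

import torch

_log = logging.getLogger(__name__)

def scale(x: torch.Tensor) -> torch.Tensor:
    # Only log outside of tracing; calling the logging module from inside a
    # compiled region would introduce a graph break.
    if not torch.compiler.is_compiling():
        _log.debug("scale() called in eager mode")
    return x * 2.0

compiled_scale = torch.compile(scale)
print(compiled_scale(torch.ones(4)))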
6 changes: 6 additions & 0 deletions torchtune/modules/loss/ce_chunked_output_loss.py
@@ -78,6 +78,12 @@ def forward(self, logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
# compute one chunk at a time
total_loss = 0.0
for logits_chunk, labels_chunk in zip(logits, labels):

# avoid graph breaks when seq_len is not constant in the batch
torch._dynamo.mark_dynamic(logits_chunk, 0)
torch._dynamo.mark_dynamic(labels_chunk, 0)

# CE
total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)

return total_loss / total_elements
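mark_dynamic declares a dimension as dynamic before the first compilation, so Dynamo builds a symbolic-shape graph up front instead of specializing on the first chunk size it sees and recompiling when the sequence length changes. A small sketch of the API under that assumption; the toy function below is illustrative, not the torchtune loss:

import torch

def chunk_sum(chunk: torch.Tensor) -> torch.Tensor:
    return chunk.float().sum()

compiled_chunk_sum = torch.compile(chunk_sum)

for seq_len in (128, 256, 512):
    chunk = torch.randn(seq_len, 4096)
    # Declare dim 0 dynamic so every new length reuses one compiled graph.
    torch._dynamo.mark_dynamic(chunk, 0)
    print(seq_len, compiled_chunk_sum(chunk).item())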
2 changes: 1 addition & 1 deletion torchtune/modules/peft/dora.py
@@ -65,7 +65,7 @@ def __init__(
self.use_bias = use_bias
self._quantize_base = quantize_base

if not self._quantize_base and quantization_kwargs:
if not self._quantize_base and any([v for v in quantization_kwargs.values()]):
raise ValueError(
f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
)
2 changes: 1 addition & 1 deletion torchtune/modules/peft/lora.py
@@ -65,7 +65,7 @@ def __init__(
self.use_bias = use_bias
self._quantize_base = quantize_base

if not self._quantize_base and quantization_kwargs:
if not self._quantize_base and any([v for v in quantization_kwargs.values()]):
raise ValueError(
f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
)
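The one-line change in dora.py and lora.py swaps a truthiness check on the kwargs dict for a check on its values, presumably because a dict that only carries unset (None/False) defaults is still truthy, while any() over its values is False. A tiny illustration with placeholder keys, not the actual torchtune kwargs:

# Placeholder keys; the real quantization kwargs may differ.
quantization_kwargs = {"block_size": None, "scaler_block_size": None}

# Old check: any non-empty dict triggers the error, even if nothing is set.
print(bool(quantization_kwargs))  # True

# New check: only trips when at least one value is actually provided.
print(any([v for v in quantization_kwargs.values()]))  # False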
11 changes: 11 additions & 0 deletions torchtune/modules/transformer.py
@@ -628,11 +628,22 @@ def forward(

# shape: [b, s, d]
h = self.tok_embeddings(tokens)
h.requires_grad = True # avoid graph breaks when using LoRA

hidden = []
for i, layer in enumerate(self.layers):
if i in self.output_hidden_states:
hidden.append(h)

# avoid graph breaks when seq_len is not constant in the batch
torch._dynamo.mark_dynamic(h, 1)
if mask is not None:
torch._dynamo.mark_dynamic(mask, 1)
if encoder_mask is not None:
torch._dynamo.mark_dynamic(encoder_mask, 1)
if input_pos is not None:
torch._dynamo.mark_dynamic(input_pos, 1)

# shape: [b, s, d]
h = layer(
h,
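One way to check whether changes like these actually remove graph breaks is torch._dynamo.explain (or running with TORCH_LOGS=graph_breaks), which reports break counts and reasons for a callable. A minimal sketch, assuming a recent PyTorch 2.x; the toy function is illustrative only:

import torch

def f(x: torch.Tensor) -> torch.Tensor:
    y = x.sin()
    print("side effect")  # a print inside the traced region forces a graph break
    return y.cos()

explanation = torch._dynamo.explain(f)(torch.randn(8))
print(explanation.graph_break_count)
print(explanation.break_reasons)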