diff --git a/models/demos/llama3/PERF.md b/models/demos/llama3/PERF.md
index f0dbf00ec4b0..69161dfff083 100644
--- a/models/demos/llama3/PERF.md
+++ b/models/demos/llama3/PERF.md
@@ -1,8 +1,8 @@
 # Llama 3 model performance and accuracy

-Performance collected from [demo/demo.py](demo/demo.py) and accuracy collected from [tests/test_llama_accuracy.py](tests/test_llama_accuracy.py). You can generate this table by running these tests with the `lt` tool (tell it to run `accuracy,demo`) and pressing `m` whilst in the results section to export to markdown.
+Performance collected from [demo/demo.py](demo/demo.py) and accuracy collected from [tests/test_llama_accuracy.py](tests/test_llama_accuracy.py). You can generate this table by running these tests with the `lt` tool (tell it to run `table`) and pressing `m` whilst in the results section to export to markdown.

-Note that `test_llama_accuracy.py` parses the below to determine expected values.
+Note that `test_llama_accuracy.py` parses the table below to determine expected values ±0.5.

 ## LlamaOptimizations.performance

@@ -10,18 +10,18 @@ This configuration uses bfp4 MLP FF1+FF3 for all models.

 | Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
 |-------|--------|-----------|-----------|---------------|
-| 1b | N150 | 79 | 98 | 90.5 |
-| 1b | N300 | 81 | 98 | 101.7 |
-| 1b | T3K | 81 | 98 | 96.8 |
-| 3b | N150 | 85 | 96 | 49.0 |
-| 3b | N300 | 88 | 97 | 56.9 |
-| 3b | T3K | 88 | 97 | 54.5 |
-| 8b | N150 | 86 | 98 | 28.4 |
-| 8b | N300 | 84 | 98 | 38.6 |
-| 8b | T3K | 84 | 97 | 52.6 |
-| 11b | N300 | 86 | 97 | 38.6 |
-| 11b | T3K | 84 | 98 | 52.6 |
-| 70b | T3K | 94 | 100 | 14.3 |
+| 1b | N150 | 88 | 98 | 85.6 |
+| 1b | N300 | 88 | 98 | 93.6 |
+| 1b | T3K | 88 | 98 | 90.5 |
+| 3b | N150 | 89 | 98 | 46.3 |
+| 3b | N300 | 91 | 98 | 52.8 |
+| 3b | T3K | 89 | 98 | 52.0 |
+| 8b | N150 | 87 | 98 | 27.5 |
+| 8b | N300 | 86 | 98 | 36.5 |
+| 8b | T3K | 84 | 97 | 46.7 |
+| 11b | N300 | 88 | 98 | 36.4 |
+| 11b | T3K | 87 | 98 | 46.8 |
+| 70b | T3K | 94 | 100 | 13.9 |

 ## LlamaOptimizations.accuracy

@@ -29,15 +29,15 @@ This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model.
 | Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
 |-------|--------|-----------|-----------|---------------|
-| 1b | N150 | 77 | 96 | 85.8 |
-| 1b | N300 | 80 | 98 | 98.6 |
-| 1b | T3K | 78 | 98 | 97.2 |
-| 3b | N150 | 88 | 98 | 44.1 |
-| 3b | N300 | 88 | 98 | 53.9 |
-| 3b | T3K | 88 | 98 | 54.8 |
-| 8b | N150 | 89 | 98 | 23.5 |
-| 8b | N300 | 90 | 98 | 34.1 |
-| 8b | T3K | 88 | 97 | 49.9 |
-| 11b | N300 | 90 | 97 | 33.8 |
-| 11b | T3K | 88 | 97 | 52.6 |
-| 70b | T3K | 94 | 100 | 14.5 |
+| 1b | N150 | 88 | 98 | 81.7 |
+| 1b | N300 | 88 | 98 | 91.5 |
+| 1b | T3K | 88 | 98 | 87.8 |
+| 3b | N150 | 89 | 98 | 41.9 |
+| 3b | N300 | 91 | 98 | 50.4 |
+| 3b | T3K | 89 | 98 | 51.4 |
+| 8b | N150 | 87 | 98 | 22.9 |
+| 8b | N300 | 86 | 98 | 32.8 |
+| 8b | T3K | 84 | 97 | 46.0 |
+| 11b | N300 | 88 | 98 | 32.4 |
+| 11b | T3K | 87 | 98 | 44.1 |
+| 70b | T3K | 94 | 100 | 13.9 |
diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py
index f3b5b998fcb1..090287ce610b 100644
--- a/models/demos/llama3/demo/demo.py
+++ b/models/demos/llama3/demo/demo.py
@@ -355,7 +355,11 @@ def run_llama3_demo(
     for batch_id in range(batch_size):
         prefill_seq_len = prefill_lens[batch_id]
         rot_mats_prefill = get_prefill_rot_mat(
-            model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=prefill_seq_len
+            model_args.head_dim,
+            model_args.max_seq_len,
+            mesh_device,
+            seq_len=prefill_seq_len,
+            scale_factor=model_args.rope_scaling_factor,
         )
         if decoding_pos[batch_id] < prefill_seq_len:
             pt_prefill_input[batch_id][
diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt
index 8f68983a2b6b..afdc36ebcee2 100755
--- a/models/demos/llama3/lt
+++ b/models/demos/llama3/lt
@@ -304,8 +304,7 @@ def main(stdscr):
        if new_max_y != last_drawn_state["max_y"] or new_max_x != last_drawn_state["max_x"]:
            stdscr.clear()
            last_drawn_state["max_y"], last_drawn_state["max_x"] = new_max_y, new_max_x
-           last_drawn_state["input_fields"] = []  # Reset to force redraw
-           last_drawn_state["output_entries"] = []  # Reset to force redraw
+           last_drawn_state["current_line"] = -1  # Reset to force redraw
            screen_needs_update.set()

        if screen_needs_update.is_set():
@@ -322,6 +321,8 @@
            if exiting and all(
                entry["status"] in ["Exiting", "Cancelled", "Error", "Finished"] for entry in output_entries
            ):
+               # Save state before exiting
+               output_entries.save_state()
                return

            if c == -1:
@@ -383,6 +384,9 @@ def main(stdscr):
            if command_input == "tests":
                command_input = "embedding,rmsnorm,attention,attention-prefill,mlp,lm-head,decoder,decoder-prefill,model,model-prefill"

+           if command_input == "table":
+               command_input = "accuracy,demo,accuracy-acc,demo-acc"
+
            # Parse models, devices, and commands
            models = parse_list(model_input)
            devices = parse_list(device_input)
@@ -488,7 +492,14 @@ def main(stdscr):
        elif c == ord("m") and current_line >= len(input_fields):
            # Export results to markdown
            export_results_to_markdown(output_entries, stdscr)
+           last_drawn_state["current_line"] = -1  # Reset to force redraw
            screen_needs_update.set()
+       elif c == ord("p") and current_line >= len(input_fields):
+           # Reparse the selected entry's log file
+           entry_index = current_line - len(input_fields)
+           if entry_index < len(output_entries):
+               entry = output_entries[entry_index]
+               reparse_log_file(entry, screen_needs_update)
        else:
            if current_line < len(input_fields) and not exiting:
                current_field = current_line
@@ -506,6 +517,7 @@ def define_color_pairs():
    COLOR_LIGHT_RED = 174  # Light pastel red
    COLOR_LIGHT_PURPLE = 183  # Light pastel purple
    COLOR_GRAY = 250  # Light gray
+   COLOR_GRAY_DARK = 244  # Dark gray
    COLOR_WHITE = 15  # Bright white
    COLOR_BLACK = 16  # Black

@@ -524,9 +536,10 @@ def define_color_pairs():
    curses.init_pair(9, COLOR_LIGHT_GREEN, -1)  # PCC > 0.99
    curses.init_pair(10, COLOR_LIGHT_YELLOW, -1)  # PCC 0.98-0.99
    curses.init_pair(11, COLOR_LIGHT_RED, -1)  # PCC < 0.98
+   curses.init_pair(12, COLOR_GRAY_DARK, -1)  # Accuracy percentages

    # Add a new color pair for the help bar
-   curses.init_pair(12, COLOR_LIGHT_CYAN, -1)
+   curses.init_pair(13, COLOR_LIGHT_CYAN, -1)

    # Store the color pair numbers for use in the rest of the program
    global COLOR_PAIR_SELECTED
@@ -541,6 +554,7 @@ def define_color_pairs():
    global COLOR_PAIR_PCC_GREEN
    global COLOR_PAIR_PCC_YELLOW
    global COLOR_PAIR_PCC_RED
+   global COLOR_PAIR_PCC_ACCURACY
    global COLOR_PAIR_HELP_BAR

    COLOR_PAIR_SELECTED = curses.color_pair(1)
@@ -555,7 +569,8 @@ def define_color_pairs():
    COLOR_PAIR_PCC_GREEN = curses.color_pair(9)
    COLOR_PAIR_PCC_YELLOW = curses.color_pair(10)
    COLOR_PAIR_PCC_RED = curses.color_pair(11)
-   COLOR_PAIR_HELP_BAR = curses.color_pair(12)
+   COLOR_PAIR_PCC_ACCURACY = curses.color_pair(12)
+   COLOR_PAIR_HELP_BAR = curses.color_pair(13)


 def draw_changes(stdscr, input_fields, output_entries, current_line, last_drawn_state):
@@ -657,13 +672,16 @@ def draw_output_entry(stdscr, entry, y, is_selected, max_x):
                color = COLOR_PAIR_SPEED
            elif i == 5:  # PCC column
                if col:
-                   pcc_value = float(col)
-                   if pcc_value > 0.99:
-                       color = COLOR_PAIR_PCC_GREEN
-                   elif 0.98 < pcc_value <= 0.99:
-                       color = COLOR_PAIR_PCC_YELLOW
-                   else:
-                       color = COLOR_PAIR_PCC_RED
+                   try:
+                       pcc_value = float(col)
+                       if pcc_value > 0.99:
+                           color = COLOR_PAIR_PCC_GREEN
+                       elif 0.98 < pcc_value <= 0.99:
+                           color = COLOR_PAIR_PCC_YELLOW
+                       else:
+                           color = COLOR_PAIR_PCC_RED
+                   except ValueError:
+                       color = COLOR_PAIR_PCC_ACCURACY
                else:
                    color = curses.color_pair(0)
            stdscr.addstr(y, x, col_text, color)
@@ -731,7 +749,10 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update):
    # Define command shortcuts
    command_shortcuts = {
+       "accuracy": "pytest models/demos/llama3/tests/test_llama_accuracy.py -k 'performance and file'",
+       "accuracy-acc": "pytest models/demos/llama3/tests/test_llama_accuracy.py -k 'accuracy and file'",
        "demo": "pytest models/demos/llama3/demo/demo.py -k performance-batch-1",
+       "demo-acc": "pytest models/demos/llama3/demo/demo.py -k accuracy-batch-1",
        "demo-32": "pytest models/demos/llama3/demo/demo.py -k performance-batch-32",
        "demo-long": "pytest models/demos/llama3/demo/demo.py -k performance-long",
        "attention": "pytest models/demos/llama3/tests/test_llama_attention.py",
@@ -759,7 +780,6 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update):
        "vision-encoder": "pytest models/demos/llama3/tests/multimodal/test_llama_vision_encoder.py",
        "vision-text-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py",
        "vision-vision-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py",
-       "accuracy": "pytest models/demos/llama3/tests/test_llama_accuracy.py -k performance",
    }

    # Check if the command is a shortcut and replace it if necessary
@@ -810,6 +830,7 @@ def process_output(entry, screen_lock, output_entries, screen_needs_update):
                # Update status and output based on output
                status, output, speed, pcc = parse_output_line(line, previous_line, entry.status)
                previous_line = line.strip()
+
                with entry.lock:
                    if status != entry.status or output or speed is not None or pcc is not None:
                        entry.status = status  # This will mark entry as changed via __setattr__
@@ -818,8 +839,14 @@
                        if speed is not None:
                            entry.speed = f"{speed:.1f}"
                        if pcc is not None:
-                           if entry.pcc is None or float(pcc) < float(entry.pcc):
+                           try:
+                               pcc_value = float(pcc)
+                               if entry.pcc is None or pcc_value < float(entry.pcc):
+                                   entry.pcc = pcc
+                           except ValueError:
                                entry.pcc = pcc
+                       # Save state whenever process status changes
+                       output_entries.save_state()
                        screen_needs_update.set()

            with screen_lock:
@@ -844,10 +871,12 @@
                    reset_device_async(entry, screen_lock, screen_needs_update)
                else:
                    entry.status = "Finished"
+                   # Save state when process completes
+                   output_entries.save_state()
            entry.process = None
            log_file.close()

-   screen_needs_update.set()  # Ensure screen is updated after process termination
+   screen_needs_update.set()


 def parse_output_line(line, previous_line, current_status):
@@ -869,6 +898,12 @@
    pcc_match = re.search(r"PCC: (\d+\.\d+)", line)
    if pcc_match:
        pcc = f"{float(pcc_match.group(1)):.5f}"
+   else:
+       # Check for Top-1/Top-5 accuracy format
+       acc_match = re.search(r"Top-1: (\d+)% \| Top-5: (\d+)%", line)
+       if acc_match:
+           top1, top5 = acc_match.groups()
+           pcc = f"{top1.strip():<3s}|{top5.strip():>3s}"

    if "Initializing device" in line:
        return "Initializing device", None, speed, pcc
@@ -1026,7 +1061,7 @@ def get_help_text(current_line, num_input_fields, num_output_entries):
    elif current_line <= num_input_fields - 1:
        return "Enter: Next field | ↑↓: Navigate fields | Esc: Exit"
    else:
-       return "Enter: View log | Backspace/x: Cancel entry | X: Cancel all | r: Restart entry | ↑↓: Navigate entries | Esc: Exit"
+       return "Enter: View log | Backspace/x: Cancel entry | X: Cancel all | r: Restart entry | p: Reparse log | ↑↓: Navigate entries | Esc: Exit"


 def cancel_entry(entry):
@@ -1047,33 +1082,84 @@ def export_results_to_markdown(output_entries, stdscr):
-   demo_results = {}
-   accuracy_results = {}
+   # Initialize ordered lists to maintain entry order
+   perf_entries = []
+   acc_entries = []

-   # Collect results from entries
+   # Collect results from entries in their original order
    for entry in output_entries:
-       if entry.command_name == "demo" and entry.status == "Finished":
-           demo_results[(entry.model, entry.device)] = entry.speed
-       elif entry.command_name == "accuracy" and entry.status == "Finished":
-           # Parse Top-1 and Top-5 from output
-           top1 = "N/A"
-           top5 = "N/A"
-           if entry.output:
-               match = re.search(r"Top-1: (\d+)% \| Top-5: (\d+)%", entry.output)
-               if match:
-                   top1, top5 = match.groups()
-           accuracy_results[(entry.model, entry.device)] = (top1, top5)
-
-   # Create markdown table
+       if entry.status == "Finished":
+           key = (entry.model, entry.device)
+
+           if entry.command_name == "demo" or entry.command_name == "accuracy":
+               # Get speed from demo entry
+               speed = entry.speed if entry.command_name == "demo" else None
+               # Get accuracy from accuracy entry
+               top1, top5 = "N/A", "N/A"
+               if entry.command_name == "accuracy" and entry.pcc:
+                   match = re.match(r"(\d+)\s*\|\s*(\d+)", entry.pcc)
+                   if match:
+                       top1, top5 = match.group(1), match.group(2)
+
+               # Find existing entry or create new one
+               existing_entry = next((e for e in perf_entries if e[0] == key), None)
+               if existing_entry:
+                   if speed:
+                       existing_entry[3] = speed
+                   if top1 != "N/A":
+                       existing_entry[1:3] = [top1, top5]
+               else:
+                   perf_entries.append([key, top1, top5, speed or "N/A"])
+
+           elif entry.command_name == "demo-acc" or entry.command_name == "accuracy-acc":
+               # Same logic for accuracy configuration
+               speed = entry.speed if entry.command_name == "demo-acc" else None
+               top1, top5 = "N/A", "N/A"
+               if entry.command_name == "accuracy-acc" and entry.pcc:
+                   match = re.match(r"(\d+)\s*\|\s*(\d+)", entry.pcc)
+                   if match:
+                       top1, top5 = match.group(1), match.group(2)
+
+               existing_entry = next((e for e in acc_entries if e[0] == key), None)
+               if existing_entry:
+                   if speed:
+                       existing_entry[3] = speed
+                   if top1 != "N/A":
+                       existing_entry[1:3] = [top1, top5]
+               else:
+                   acc_entries.append([key, top1, top5, speed or "N/A"])
+
+   # Create markdown content
    markdown_lines = [
+       "## LlamaOptimizations.performance",
+       "",
+       "This configuration uses bfp4 MLP FF1+FF3 for all models.",
+       "",
        "| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |",
        "|-------|--------|-----------|-----------|---------------|",
    ]

-   for key in demo_results.keys():
-       model, device = key
-       speed = demo_results.get(key, "N/A")
-       top1, top5 = accuracy_results.get(key, ("N/A", "N/A"))
+   # Add rows for performance table in original order
+   for entry in perf_entries:
+       (model, device), top1, top5, speed = entry
+       markdown_lines.append(f"| {model} | {device} | {top1} | {top5} | {speed} |")
+
+   # Add accuracy table
+   markdown_lines.extend(
+       [
+           "",
+           "## LlamaOptimizations.accuracy",
+           "",
+           "This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model.",
+           "",
+           "| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |",
+           "|-------|--------|-----------|-----------|---------------|",
+       ]
+   )
+
+   # Add rows for accuracy table in original order
+   for entry in acc_entries:
+       (model, device), top1, top5, speed = entry
        markdown_lines.append(f"| {model} | {device} | {top1} | {top5} | {speed} |")

    # Write to PERF.md
@@ -1083,10 +1169,52 @@ def export_results_to_markdown(output_entries, stdscr):
    # Clear screen and show message
    stdscr.clear()
    stdscr.addstr(0, 0, "\n".join(markdown_lines))
-   stdscr.addstr(len(markdown_lines) + 1, 0, f"Table written to {os.path.abspath('PERF.md')}")
-   stdscr.addstr(len(markdown_lines) + 2, 0, "Press any key to return...")
+   stdscr.addstr(len(markdown_lines) + 2, 0, f"Table written to {os.path.abspath('PERF.md')}")
+   stdscr.addstr(len(markdown_lines) + 3, 0, "Press any key to return...")
    stdscr.refresh()
-   stdscr.getch()  # Wait for a key press
+
+   # Temporarily make getch() blocking
+   stdscr.nodelay(False)
+
+   # Wait for a key press and flush input buffer
+   stdscr.getch()
+   curses.flushinp()
+
+   # Restore non-blocking mode
+   stdscr.nodelay(True)
+
+
+def reparse_log_file(entry, screen_needs_update):
+   """Reparse an entry's log file to update speed and pcc values."""
+   try:
+       with open(entry.get_log_filename(), "r") as f:
+           previous_line = ""
+           status = entry.status  # Preserve the current status
+
+           # Reset speed and pcc before reparsing
+           entry.speed = None
+           entry.pcc = None
+
+           for line in f:
+               new_status, output, speed, pcc = parse_output_line(line, previous_line, status)
+               previous_line = line.strip()
+
+               if speed is not None:
+                   entry.speed = f"{speed:.1f}"
+               if pcc is not None:
+                   try:
+                       pcc_value = float(pcc)
+                       if entry.pcc is None or pcc_value < float(entry.pcc):
+                           entry.pcc = pcc
+                   except ValueError:
+                       entry.pcc = pcc
+               if output:
+                   entry.output = output
+
+       screen_needs_update.set()
+
+   except FileNotFoundError:
+       pass  # Log file doesn't exist


 if __name__ == "__main__":
diff --git a/models/demos/llama3/model_params/Llama3.1-70B-Instruct/params.json b/models/demos/llama3/model_params/Llama3.1-70B-Instruct/params.json
index c358d8ce7b62..35700d4579d3 100755
--- a/models/demos/llama3/model_params/Llama3.1-70B-Instruct/params.json
+++ b/models/demos/llama3/model_params/Llama3.1-70B-Instruct/params.json
@@ -1 +1 @@
-{"dim": 8192, "n_layers": 80, "n_heads": 64, "n_kv_heads": 8, "vocab_size": 128256, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "norm_eps": 1e-05, "rope_theta": 500000.0, "use_scaled_rope": true}
+{"dim": 8192, "n_layers": 80, "n_heads": 64, "n_kv_heads": 8, "vocab_size": 128256, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "norm_eps": 1e-05, "rope_theta": 500000.0, "use_scaled_rope": true, "rope_scaling_factor": 8}
diff --git a/models/demos/llama3/model_params/Llama3.1-8B-Instruct/params.json b/models/demos/llama3/model_params/Llama3.1-8B-Instruct/params.json
index e4be627b0658..353aff622ddd 100755
--- a/models/demos/llama3/model_params/Llama3.1-8B-Instruct/params.json
+++ b/models/demos/llama3/model_params/Llama3.1-8B-Instruct/params.json
@@ -1 +1 @@
-{"dim": 4096, "n_layers": 32, "n_heads": 32, "n_kv_heads": 8, "vocab_size": 128256, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "norm_eps": 1e-05, "rope_theta": 500000.0, "use_scaled_rope": true}
+{"dim": 4096, "n_layers": 32, "n_heads": 32, "n_kv_heads": 8, "vocab_size": 128256, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "norm_eps": 1e-05, "rope_theta": 500000.0, "use_scaled_rope": true, "rope_scaling_factor": 8}
diff --git a/models/demos/llama3/model_params/Llama3.2-11B-Vision-Instruct/params.json b/models/demos/llama3/model_params/Llama3.2-11B-Vision-Instruct/params.json
index 5030f736b786..6de5f6a05cf5 100755
--- a/models/demos/llama3/model_params/Llama3.2-11B-Vision-Instruct/params.json
+++ b/models/demos/llama3/model_params/Llama3.2-11B-Vision-Instruct/params.json
@@ -8,6 +8,7 @@
     "norm_eps": 1e-05,
     "rope_theta": 500000.0,
     "use_scaled_rope": true,
+    "rope_scaling_factor": 8,
     "vision_chunk_size": 560,
     "vision_max_num_chunks": 4,
     "vocab_size": 128256,
diff --git a/models/demos/llama3/model_params/Llama3.2-1B-Instruct/params.json b/models/demos/llama3/model_params/Llama3.2-1B-Instruct/params.json
index 37494ec5c5ba..2327e622e91a 100755
--- a/models/demos/llama3/model_params/Llama3.2-1B-Instruct/params.json
+++ b/models/demos/llama3/model_params/Llama3.2-1B-Instruct/params.json
@@ -8,5 +8,6 @@
     "multiple_of": 256,
     "norm_eps": 1e-05,
     "rope_theta": 500000.0,
-    "use_scaled_rope": true
+    "use_scaled_rope": true,
+    "rope_scaling_factor": 32
 }
diff --git a/models/demos/llama3/model_params/Llama3.2-3B-Instruct/params.json b/models/demos/llama3/model_params/Llama3.2-3B-Instruct/params.json
index 81d5e1c78fa6..35467179b809 100755
--- a/models/demos/llama3/model_params/Llama3.2-3B-Instruct/params.json
+++ b/models/demos/llama3/model_params/Llama3.2-3B-Instruct/params.json
@@ -8,5 +8,6 @@
     "multiple_of": 256,
     "norm_eps": 1e-05,
     "rope_theta": 500000.0,
-    "use_scaled_rope": true
+    "use_scaled_rope": true,
+    "rope_scaling_factor": 32
 }
diff --git a/models/demos/llama3/tests/generate_reference_outputs.py b/models/demos/llama3/tests/generate_reference_outputs.py
index e770e803c636..1f0514bfe7b7 100644
--- a/models/demos/llama3/tests/generate_reference_outputs.py
+++ b/models/demos/llama3/tests/generate_reference_outputs.py
@@ -11,38 +11,43 @@
 from models.demos.llama3.tt.model_config import TtModelArgs
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
 from loguru import logger
-
-
-def generate_reference_outputs(total_length, output_file):
-    # Load the model arguments
-    model_args = TtModelArgs(mesh_device=None)
-    tokenizer = Tokenizer(model_args.tokenizer_path)
-
-    # Load the model state dict
-    state_dict = model_args.load_state_dict()
-
-    # Initialize the reference model
-    state_dict_prefix = model_args.get_state_dict_prefix("", None)
-    reference_state_dict = {
-        k[len(state_dict_prefix) :]: v
-        for k, v in state_dict.items()
-        if (
-            any([f"{state_dict_prefix}layers.{i}." in k for i in range(model_args.n_layers)])
-            or any(
-                [
-                    f"{state_dict_prefix}{name}" in k
-                    for name in ["tok_embeddings.weight", "norm.weight", "output.weight"]
-                ]
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def generate_reference_outputs(total_length, output_file, hf_model_name=None):
+    if hf_model_name:
+        # HuggingFace path
+        tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
+        model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=torch.float32)
+        model.eval()
+    else:
+        # Original path - load reference model
+        model_args = TtModelArgs(mesh_device=None)
+        model_args.max_seq_len = total_length
+        tokenizer = Tokenizer(model_args.tokenizer_path)
+
+        state_dict = model_args.load_state_dict()
+        state_dict_prefix = model_args.get_state_dict_prefix("", None)
+        reference_state_dict = {
+            k[len(state_dict_prefix) :]: v
+            for k, v in state_dict.items()
+            if (
+                any([f"{state_dict_prefix}layers.{i}." in k for i in range(model_args.n_layers)])
+                or any(
+                    [
+                        f"{state_dict_prefix}{name}" in k
+                        for name in ["tok_embeddings.weight", "norm.weight", "output.weight"]
+                    ]
+                )
             )
-        )
-    }
-    reference_model = Transformer(model_args)
-    reference_model.load_state_dict(reference_state_dict)
-    reference_model.eval()  # Set to evaluation mode
+        }
+        model = Transformer(model_args)
+        model.load_state_dict(reference_state_dict)
+        model.eval()

-    # Initialize HostEmbedding
-    embd = HostEmbedding(model_args)
-    embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]})
+        # Initialize HostEmbedding
+        embd = HostEmbedding(model_args)
+        embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]})

     # Load the book text and encode tokens
     current_file_path = os.path.abspath(__file__)
@@ -52,8 +57,12 @@ def generate_reference_outputs(total_length, output_file):
     with bz2.open(prompt_file, "rt", encoding="utf-8") as f:
         text = f.read()

-    # Encode text to tokens
-    encoded_tokens = tokenizer.encode(text, bos=True, eos=False)[:total_length]
+    # Modify token encoding based on model type
+    if hf_model_name:
+        encoded_tokens = tokenizer.encode(text, add_special_tokens=True)[:total_length]
+    else:
+        encoded_tokens = tokenizer.encode(text, bos=True, eos=False)[:total_length]
+
     encoded_tokens_tensor = torch.tensor(encoded_tokens).unsqueeze(0)  # Shape [1, seq_len]

     print(f"{'Progress':<15}{'Correct':<8}{'Actual':<15}{'Top 5 Predictions':<75}")
@@ -77,9 +86,13 @@ def generate_reference_outputs(total_length, output_file):
         # Trim input chunk if needed
         chunk_tokens = chunk_tokens[:, :actual_chunk_size]

-        # Process chunk
-        pt_decode_input = embd(chunk_tokens).view(1, actual_chunk_size, -1)
-        ref_output = reference_model(pt_decode_input, start_pos=chunk_start)
+        # Process chunk based on model type
+        if hf_model_name:
+            outputs = model(chunk_tokens)
+            ref_output = outputs.logits
+        else:
+            pt_decode_input = embd(chunk_tokens).view(1, actual_chunk_size, -1)
+            ref_output = model(pt_decode_input, start_pos=chunk_start)

         # Compute top-5 predictions
         probs = torch.softmax(ref_output, dim=-1)
@@ -121,6 +134,9 @@
         if len(segment_accuracies) <= global_pos // 100:
             segment_accuracies.append((segment_top1_acc, segment_top5_acc))

+    # Concatenate all top5 tokens into a single tensor
+    all_top5_tokens = torch.cat(all_top5_tokens, dim=0)  # Shape: [total_tokens, 5]
+
     # Save the data
     data = {
         "top5_tokens": all_top5_tokens,
@@ -153,9 +169,10 @@ def main():
     parser.add_argument(
         "--output_file", type=str, default="reference_outputs.pt", help="Output file path for reference data"
     )
+    parser.add_argument("--model", type=str, help="Optional: HuggingFace model name (e.g., 'meta-llama/Llama-2-7b-hf')")
     args = parser.parse_args()

-    generate_reference_outputs(total_length=args.total_length, output_file=args.output_file)
+    generate_reference_outputs(total_length=args.total_length, output_file=args.output_file, hf_model_name=args.model)


 if __name__ == "__main__":
diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
index 1d9da2fbcca3..444a113e56c6 100644
--- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
@@ -214,7 +214,11 @@ def test_llama_cross_attention_transformer_text_inference(
     )

     rot_mats = get_prefill_rot_mat(
-        model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len
+        model_args.head_dim,
+        model_args.max_seq_len,
+        mesh_device,
+        seq_len=seq_len,
+        scale_factor=model_args.rope_scaling_factor,
     )
     tt_out = tt_model(
         tt_h,
diff --git a/models/demos/llama3/tests/reference_outputs/11b.refpt b/models/demos/llama3/tests/reference_outputs/11b.refpt
index 712c5402bd0e..34d4cc72b78e 100644
Binary files a/models/demos/llama3/tests/reference_outputs/11b.refpt and b/models/demos/llama3/tests/reference_outputs/11b.refpt differ
diff --git a/models/demos/llama3/tests/reference_outputs/1b.refpt b/models/demos/llama3/tests/reference_outputs/1b.refpt
index 3c23bf00460c..a5efba465613 100644
Binary files a/models/demos/llama3/tests/reference_outputs/1b.refpt and b/models/demos/llama3/tests/reference_outputs/1b.refpt differ
diff --git a/models/demos/llama3/tests/reference_outputs/3b.refpt b/models/demos/llama3/tests/reference_outputs/3b.refpt
index 230ea1815cf7..5a7c48370c04 100644
Binary files a/models/demos/llama3/tests/reference_outputs/3b.refpt and b/models/demos/llama3/tests/reference_outputs/3b.refpt differ
diff --git a/models/demos/llama3/tests/reference_outputs/70b.refpt b/models/demos/llama3/tests/reference_outputs/70b.refpt
index 849ae4d7791e..6911cbaa75fc 100644
Binary files a/models/demos/llama3/tests/reference_outputs/70b.refpt and b/models/demos/llama3/tests/reference_outputs/70b.refpt differ
diff --git a/models/demos/llama3/tests/reference_outputs/8b.refpt b/models/demos/llama3/tests/reference_outputs/8b.refpt
index 17641169e16a..b5ee1619faa9 100644
Binary files a/models/demos/llama3/tests/reference_outputs/8b.refpt and b/models/demos/llama3/tests/reference_outputs/8b.refpt differ
diff --git a/models/demos/llama3/tests/test_interleaved_to_sharded.py b/models/demos/llama3/tests/test_interleaved_to_sharded.py
index 9edc9a89dd03..62a0a20dd2e3 100644
--- a/models/demos/llama3/tests/test_interleaved_to_sharded.py
+++ b/models/demos/llama3/tests/test_interleaved_to_sharded.py
@@ -63,7 +63,7 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds):
     seqlen = 1
     batch = model_args.max_batch_size

-    cos, sin = precompute_freqs(model_args.head_dim, model_args.max_seq_len * 2)
+    cos, sin = precompute_freqs(model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_scaling_factor)
     freqs_cis = torch.complex(cos, sin)

     for i in range(generation_length):
diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py
index b19cb086066d..8f38055a6f33 100644
--- a/models/demos/llama3/tests/test_llama_accuracy.py
+++ b/models/demos/llama3/tests/test_llama_accuracy.py
@@ -96,7 +96,14 @@ def get_accuracy_thresholds(model_name: str, device_name: str, optimizations: Ll
     "batch_size",
     (1,),
 )
-def test_tt_model_accuracy(
+@pytest.mark.parametrize(
+    "use_reference_file",
+    [
+        pytest.param(True, id="reference_file"),
+        pytest.param(False, id="reference_text"),
+    ],
+)
+def test_tt_model_acc(
     prefill_len,
     decode_len,
     max_seq_len,
@@ -105,10 +112,15 @@ def test_tt_model_accuracy(
     page_params,
     optimizations,
     mesh_device,
+    use_reference_file,
     use_program_cache,
     reset_seeds,
     ensure_gc,
+    is_ci_env,
 ):
+    if is_ci_env and not use_reference_file:
+        pytest.skip("CI test only runs vs reference file")
+
     dtype = ttnn.bfloat8_b

     mesh_device.enable_async(True)
@@ -127,14 +139,27 @@ def test_tt_model_accuracy(
     # Load the reference data
     model_size = model_args.model_name.split("-")[1].lower()  # e.g., "1b", "3b", "8b", "70b"
-    reference_data_file = f"models/demos/llama3/tests/reference_outputs/{model_size}.refpt"
-    logger.info(f"Loading reference data from {reference_data_file}")
-    assert os.path.exists(
-        reference_data_file
-    ), f"Reference data file {reference_data_file} does not exist, generate it with generate_reference_outputs.sh"
-    reference_data = torch.load(reference_data_file)
-    reference_tokens = reference_data["reference_tokens"]
-    top5_tokens = reference_data["top5_tokens"]
+
+    if use_reference_file:
+        # Existing reference file loading logic
+        reference_data_file = f"models/demos/llama3/tests/reference_outputs/{model_size}.refpt"
+        logger.info(f"Loading reference data from {reference_data_file}")
+        assert os.path.exists(reference_data_file)
+        reference_data = torch.load(reference_data_file)
+        reference_tokens = reference_data["reference_tokens"]
+        top5_tokens = reference_data["top5_tokens"]
+    else:
+        # Load and encode the reference text
+        current_file_path = os.path.dirname(os.path.abspath(__file__))
+        prompt_file = os.path.join(current_file_path, "tale-of-two-cities.txt.bz2")
+        with bz2.open(prompt_file, "rt", encoding="utf-8") as f:
+            text = f.read()
+
+        # Encode text to tokens
+        encoded_tokens = tokenizer.encode(text, bos=True, eos=False)
+        total_length = prefill_len + decode_len + 1
+        reference_tokens = torch.tensor(encoded_tokens[:total_length]).unsqueeze(0)
+        top5_tokens = None  # Will be computed during inference

     N = prefill_len + decode_len
     input_ids = reference_tokens[:, : N + 1]  # Shape [1, N+1]
@@ -198,7 +223,11 @@ def test_tt_model_accuracy(
     # Pre-compute the rotational embedding matrix and send to device
     rot_mats_prefill = get_prefill_rot_mat(
-        model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=prefill_lens[0]
+        model_args.head_dim,
+        model_args.max_seq_len,
+        mesh_device,
+        seq_len=prefill_lens[0],
+        scale_factor=model_args.rope_scaling_factor,
     )

     prefill_input = model_args.prepare_residual_tensor_prefill(
@@ -234,8 +263,11 @@
     rot_mats = tt_model.rope_setup.get_rot_mats(current_pos)

     # Print table header
-    logger.info(f"{'Progress':<15}{'Correct':<8}{'True':<15}{'Actual':<15}{'Top 5 Predictions':<75}")
-    logger.info("-" * 128)
+    if use_reference_file:
+        logger.info(f"{'Progress':<15}{'Correct':<8}{'True':<15}{'Actual':<15}{'Top 5 Predictions':<75}")
+    else:
+        logger.info(f"{'Progress':<15}{'Correct':<8}{'True':<15}{'Top 5 Predictions':<75}")
+    logger.info("-" * 113)

     top1_correct = []
     top5_correct = []
@@ -276,30 +308,47 @@ def test_tt_model_accuracy(
             dim=3,
             use_multicore=True if model_args.max_batch_size == 1 else False,
         )
+        if not use_reference_file:
+            tt_logits = ttnn.to_torch(tt_out_rm, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[0, 0, 0, :]
+        ttnn.deallocate(tt_out_rm)
+
         tt_argmax_token = ttnn.to_torch(tt_out_tok, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))[
             0, 0, 0, 0
         ]
-        ttnn.deallocate(tt_out_rm)
+
         ttnn.plus_one(current_pos_tensor)

         # Update rot_mats for next iteration
         current_pos += 1
         rot_mats = tt_model.rope_setup.get_rot_mats(current_pos)

-        # Get reference top5 tokens and probabilities for this position
-        ref_top5_tokens = top5_tokens[prefill_len + i]
+        # Modify the accuracy checking section when using reference text
+        if not use_reference_file:
+            # Get probabilities from model output
+            probs = torch.softmax(tt_logits, dim=-1)
+            _, tt_top5_tokens = torch.topk(probs, k=5, dim=-1)
+
+            # Check against actual next token
+            true_token = input_ids[0, prefill_len + i + 1].item()
+            top1_match = tt_argmax_token.item() == true_token
+            top5_match = true_token in tt_top5_tokens
+            ref_top5_text = [tokenizer.decode([t]) for t in tt_top5_tokens]
+        else:
+            # Existing logic for reference file comparison
+            ref_top5_tokens = top5_tokens[prefill_len + i]
+            top1_match = tt_argmax_token.item() == ref_top5_tokens[0].item()
+            top5_match = tt_argmax_token in ref_top5_tokens
+            ref_top5_text = [tokenizer.decode([t]) for t in ref_top5_tokens]

         # Check top-1 and top-5 accuracy
-        top1_match = tt_argmax_token.item() == ref_top5_tokens[0].item()
         top1_correct.append(top1_match)
-        top5_match = tt_argmax_token in ref_top5_tokens
         top5_correct.append(top5_match)
         true_match = (
             tt_argmax_token.item() == input_ids[0, prefill_len + i + 1].item() if i < generation_length - 1 else False
         )

-        # Store error information if top5 is incorrect
-        if not top5_match:
+        # Store error information vs reference model if top5 is incorrect
+        if use_reference_file and not top5_match:
             context_start = max(0, prefill_len + i - 9)
             context_tokens = input_ids[0, context_start : prefill_len + i + 1]
             context_text = tokenizer.decode(context_tokens.tolist())
@@ -316,13 +365,13 @@ def test_tt_model_accuracy(
                 }
             )

+        sanitize = lambda x: repr(x)[1:-1]  # Use repr() and remove the outer quotes
+
         # Decode tokens to text
         tt_argmax_text = tokenizer.decode([tt_argmax_token])
         true_text = tokenizer.decode([true_token]) if true_token is not None else "N/A"
-        ref_top5_text = [tokenizer.decode([t]) for t in ref_top5_tokens]

         # Prepare table row
-        sanitize = lambda x: repr(x)[1:-1]  # Use repr() and remove the outer quotes
         progress_str = f"{i+1}/{generation_length}"
         correct = "x" if top1_match else ("-" if top5_match else ("!" if true_match else " "))
         tt_argmax_text = sanitize(tt_argmax_text)
@@ -330,7 +379,10 @@ def test_tt_model_accuracy(
         ref_top5_str = " ".join(f"{sanitize(t):<14}" for t in ref_top5_text)

         # Print table row
-        logger.info(f"{progress_str:<15}{correct:<8}{true_text:<15}{tt_argmax_text:<15}{ref_top5_str}")
+        if use_reference_file:
+            logger.info(f"{progress_str:<15}{correct:<8}{true_text:<15}{tt_argmax_text:<15}{ref_top5_str}")
+        else:
+            logger.info(f"{progress_str:<15}{correct:<8}{true_text:<15}{ref_top5_str}")

     # Compute accuracies over every 100 tokens
     num_tokens = len(top1_correct)
@@ -352,17 +404,19 @@ def test_tt_model_accuracy(
         f"Total tokens {num_tokens}: Top-1 accuracy: {total_top1_acc:3.0f} %, Top-5 accuracy: {total_top5_acc:3.0f} %"
     )

-    logger.info("\nError Summary (only showing errors where reference top-1 matches true token):")
-    logger.info("-" * 120)
-    for error in errors:
-        true_token = input_ids[0, error["position"] + 1].item()
-        if error["expected_ids"][0] == true_token:
-            sanitize = lambda x: repr(x)[1:-1]  # Use repr() and remove the outer quotes
-            context = sanitize(error["context"])
-            incorrect = sanitize(error["incorrect"])
-            expected = " | ".join(sanitize(t) for t in error["expected"])
-            true_word = sanitize(tokenizer.decode([true_token]))
-            logger.info(f"{error['position']}: {context}[{incorrect}] != [{expected}], true: [{true_word}]")
+    # Only show error summary when using reference files
+    if use_reference_file:
+        logger.info("\nError Summary (only showing errors where reference top-1 matches true token):")
+        logger.info("-" * 120)
+        for error in errors:
+            true_token = input_ids[0, error["position"] + 1].item()
+            if error["expected_ids"][0] == true_token:
+                sanitize = lambda x: repr(x)[1:-1]  # Use repr() and remove the outer quotes
+                context = sanitize(error["context"])
+                incorrect = sanitize(error["incorrect"])
+                expected = " | ".join(sanitize(t) for t in error["expected"])
+                true_word = sanitize(tokenizer.decode([true_token]))
+                logger.info(f"{error['position']}: {context}[{incorrect}] != [{expected}], true: [{true_word}]")

     # Get accuracy thresholds from PERF.md
     min_top1_acc, min_top5_acc = get_accuracy_thresholds(
@@ -372,5 +426,9 @@ def test_tt_model_accuracy(
     )

     logger.info(f"Top-1: {total_top1_acc:.0f}% | Top-5: {total_top5_acc:.0f}%")
-    assert total_top1_acc > min_top1_acc, f"Top-1 accuracy {total_top1_acc:.1f}% is too low (expected >{min_top1_acc}%)"
-    assert total_top5_acc > min_top5_acc, f"Top-5 accuracy {total_top5_acc:.1f}% is too low (expected >{min_top5_acc}%)"
+    assert (
+        total_top1_acc >= min_top1_acc
+    ), f"Top-1 accuracy {total_top1_acc:.1f}% is too low (expected >={min_top1_acc}%)"
+    assert (
+        total_top5_acc >= min_top5_acc
+    ), f"Top-5 accuracy {total_top5_acc:.1f}% is too low (expected >={min_top5_acc}%)"
diff --git a/models/demos/llama3/tests/test_llama_attention.py b/models/demos/llama3/tests/test_llama_attention.py
index edb9ac99a43b..6a40ead9f9c8 100644
--- a/models/demos/llama3/tests/test_llama_attention.py
+++ b/models/demos/llama3/tests/test_llama_attention.py
@@ -98,6 +98,7 @@ def test_llama_attention_inference(
         model_args.max_seq_len,
         model_args.rope_theta,
         model_args.use_scaled_rope,
+        model_args.rope_scaling_factor,
     )

     transformation_mats = rope_setup.get_both_trans_mats()
@@ -138,7 +139,11 @@ def test_llama_attention_inference(
     )

     cos, sin = precompute_freqs(
-        model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope
+        model_args.head_dim,
+        model_args.max_seq_len * 2,
+        model_args.rope_theta,
+        model_args.use_scaled_rope,
+        model_args.rope_scaling_factor,
     )
     freqs_cis = torch.complex(cos, sin)

diff --git a/models/demos/llama3/tests/test_llama_attention_prefill.py b/models/demos/llama3/tests/test_llama_attention_prefill.py
index 4335bdb4ee1a..59a4147b3737 100644
--- a/models/demos/llama3/tests/test_llama_attention_prefill.py
+++ b/models/demos/llama3/tests/test_llama_attention_prefill.py
@@ -88,7 +88,13 @@ def test_llama_attention_inference(
     reference_model.load_state_dict(partial_state_dict)

     # pre-compute the rotational embedding matrix and send to device
-    rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=max_seq_len)
+    rot_mats = get_prefill_rot_mat(
+        model_args.head_dim,
+        model_args.max_seq_len,
+        mesh_device,
+        seq_len=max_seq_len,
+        scale_factor=model_args.rope_scaling_factor,
+    )
     transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim)
     transformation_mats_prefill = ttnn.as_tensor(
         transformation_mat_torch,
@@ -162,7 +168,11 @@ def test_llama_attention_inference(
     positions = torch.LongTensor(range(max_seq_len))
     freqs_cis_i = precompute_freqs_cis(
-        model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope
+        model_args.head_dim,
+        model_args.max_seq_len * 2,
+        model_args.rope_theta,
+        model_args.use_scaled_rope,
+        model_args.rope_scaling_factor,
     )[positions]
     attn_mask = torch.full((max_seq_len, max_seq_len), torch.finfo(torch.float32).min)
     attn_mask_torch = torch.triu(attn_mask, diagonal=1)
diff --git a/models/demos/llama3/tests/test_llama_decoder.py b/models/demos/llama3/tests/test_llama_decoder.py
index 316c811aaf3e..f2ba83777b9e 100644
--- a/models/demos/llama3/tests/test_llama_decoder.py
+++ b/models/demos/llama3/tests/test_llama_decoder.py
@@ -93,6 +93,7 @@ def test_llama_decoder_inference(
         model_args.max_seq_len,
         model_args.rope_theta,
         model_args.use_scaled_rope,
+        model_args.rope_scaling_factor,
     )

     transformation_mats = rope_setup.get_both_trans_mats()
@@ -135,7 +136,11 @@ def test_llama_decoder_inference(
     seqlen = 1

     cos, sin = precompute_freqs(
-        model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope
+        model_args.head_dim,
+        model_args.max_seq_len * 2,
+        model_args.rope_theta,
+        model_args.use_scaled_rope,
+        model_args.rope_scaling_factor,
     )
     freqs_cis = torch.complex(cos, sin)

diff --git a/models/demos/llama3/tests/test_llama_decoder_prefill.py b/models/demos/llama3/tests/test_llama_decoder_prefill.py
index 622e67f91b41..e2f63e95f461 100644
--- a/models/demos/llama3/tests/test_llama_decoder_prefill.py
+++ b/models/demos/llama3/tests/test_llama_decoder_prefill.py
@@ -90,7 +90,13 @@ def test_llama_decoder_inference(
     all_tests_pass = True

     # pre-compute the rotational embedding matrix and send to device
-    rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=max_seq_len)
+    rot_mats = get_prefill_rot_mat(
+        model_args.head_dim,
+        model_args.max_seq_len,
+        mesh_device,
+        seq_len=max_seq_len,
+        scale_factor=model_args.rope_scaling_factor,
+    )
     transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim)
     transformation_mats_prefill = ttnn.as_tensor(
         transformation_mat_torch,
@@ -147,7 +153,11 @@ def test_llama_decoder_inference(
     )
     positions = torch.LongTensor(range(max_seq_len))
     freqs_cis_i = precompute_freqs_cis(
-        model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, model_args.use_scaled_rope
+        model_args.head_dim,
+        model_args.max_seq_len * 2,
+        model_args.rope_theta,
+        model_args.use_scaled_rope,
+        model_args.rope_scaling_factor,
     )[positions]

     # Reference model
diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py
index 37e0e4384192..e0ebb3a5f579 100644
--- a/models/demos/llama3/tests/test_llama_model.py
+++ b/models/demos/llama3/tests/test_llama_model.py
@@ -7,7 +7,6 @@
 import os
 import ttnn
 from models.demos.llama3.tt.llama_common import (
-    precompute_freqs,
     sample_host,
     encode_prompt_llama_instruct,
     HostEmbedding,
diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py
index e30c25cc8f47..2bc01289d038 100644
--- a/models/demos/llama3/tests/test_llama_model_prefill.py
+++ b/models/demos/llama3/tests/test_llama_model_prefill.py
@@ -16,7 +16,7 @@
 )
 from models.demos.llama3.tt.llama_model import TtTransformer
 from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations
-from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer, precompute_freqs_cis
+from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
 from models.utility_functions import (
     comp_pcc,
@@ -142,7 +142,13 @@ def test_llama_model_inference(
     embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]})

     # pre-compute the rotational embedding matrix and send to device
-    rot_mats = get_prefill_rot_mat(model_args.head_dim, model_args.max_seq_len, mesh_device, seq_len=seq_len)
+    rot_mats = get_prefill_rot_mat(
+        model_args.head_dim,
+        model_args.max_seq_len,
+        mesh_device,
+        seq_len=seq_len,
+        scale_factor=model_args.rope_scaling_factor,
+    )
     # Setup page table
     page_table_tt = None
     paged_attention_config = None
diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py
index fd7f368557f4..ee43a74281e6 100644
--- a/models/demos/llama3/tt/llama_common.py
+++ b/models/demos/llama3/tt/llama_common.py
@@ -44,10 +44,9 @@ def encode_prompt_llama_instruct(tokenizer, prompt_text, system_prompt_text=None
     return begin_of_text + system_prompt + user_prompt + assistant_reply


-def apply_scaling(freqs: torch.Tensor):
-    # Llama-3.1 specific scaling
+def apply_scaling(freqs: torch.Tensor, scale_factor: float = 8):
+    # Llama-3.x specific scaling
     # Values obtained from grid search
-    scale_factor = 8
     low_freq_factor = 1
     high_freq_factor = 4
     old_context_len = 8192  # original llama3 length
@@ -68,7 +67,7 @@ def apply_scaling(freqs: torch.Tensor):
     return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


-def precompute_freqs(dim: int, end: int, theta: float = 500000.0, use_scaled: bool = True):
+def precompute_freqs(dim: int, end: int, theta: float = 500000.0, use_scaled: bool = True, scale_factor: float = 8):
     """
     Precompute the frequency tensor for sine and cosine values with given dimensions.
@@ -83,7 +82,7 @@ def precompute_freqs(dim: int, end: int, theta: float = 500000.0, use_scaled: bo
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end)
     if use_scaled:
-        freqs = apply_scaling(freqs)
+        freqs = apply_scaling(freqs, scale_factor)
     freqs = torch.outer(t, freqs).float()
     return torch.cos(freqs), torch.sin(freqs)

@@ -113,8 +112,8 @@ def gather_cos_sin(position_ids, cos, sin):
     return cos, sin


-def get_prefill_rot_mat(head_dim, max_seq_len, mesh_device, seq_len):
-    cos, sin = precompute_freqs(head_dim, max_seq_len * 2)
+def get_prefill_rot_mat(head_dim, max_seq_len, mesh_device, seq_len, scale_factor):
+    cos, sin = precompute_freqs(head_dim, max_seq_len * 2, scale_factor=scale_factor)
     cos_gathered, sin_gathered = gather_cos_sin(torch.arange(0, seq_len), cos, sin)
     assert cos_gathered.size() == (1, 1, seq_len, head_dim)
     assert sin_gathered.size() == (1, 1, seq_len, head_dim)
diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py
index 9c55182115f8..79be949eb262 100644
--- a/models/demos/llama3/tt/llama_model.py
+++ b/models/demos/llama3/tt/llama_model.py
@@ -55,6 +55,7 @@ def __init__(
             args.max_seq_len,
             args.rope_theta,
             args.use_scaled_rope,
+            args.rope_scaling_factor,
         )

         self.trans_mats_dict = self.rope_setup.get_both_trans_mats()
@@ -117,7 +118,11 @@ def prepare_inputs_prefill(self, tokens, page_table=None):
         tokens_embd = self.embd(tokens)

         tt_rot_mats_prefill = get_prefill_rot_mat(
-            self.args.head_dim, self.args.max_seq_len, self.mesh_device, seq_len=S
+            self.args.head_dim,
+            self.args.max_seq_len,
+            self.mesh_device,
+            seq_len=S,
+            scale_factor=self.args.rope_scaling_factor,
         )

         if page_table is not None:
diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py
index c1b982308bc3..aad1e774ba4e 100644
--- a/models/demos/llama3/tt/llama_rope.py
+++ b/models/demos/llama3/tt/llama_rope.py
@@ -11,8 +11,8 @@
 from loguru import logger


-def compute_gather_cos_sin(dhead, end, theta, position_ids, use_scaled_rope):
-    cos, sin = precompute_freqs(dhead, end, theta, use_scaled_rope)
+def compute_gather_cos_sin(dhead, end, theta, position_ids, use_scaled_rope, scale_factor):
+    cos, sin = precompute_freqs(dhead, end, theta, use_scaled_rope, scale_factor)
     return gather_cos_sin(position_ids, cos, sin)


@@ -25,6 +25,7 @@ def __init__(
         max_seq_len: int,
         rope_theta: float = 10000,
         use_scaled_rope: bool = False,
+        scale_factor: float = 8,
         datatype=ttnn.bfloat16,
     ):
         super().__init__()
@@ -45,6 +46,7 @@ def __init__(
             theta=rope_theta,
             position_ids=torch.arange(max_seq_len),
             use_scaled_rope=use_scaled_rope,
+            scale_factor=scale_factor,
         )

         self.cos_matrix = ttnn.from_torch(
diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py
index 4ddb684fe9c5..81be425f4470 100644
--- a/models/demos/llama3/tt/model_config.py
+++ b/models/demos/llama3/tt/model_config.py
@@ -136,19 +136,25 @@ def __init__(
         if "3.2-1B" in LLAMA_DIR:
             local_params = "LLAMA3_2_1B_PARAMS"
             self.model_name = "3.2-1B"
+            self.rope_scaling_factor = 32
         elif "3.2-3B" in LLAMA_DIR:
             local_params = "LLAMA3_2_3B_PARAMS"
             self.model_name = "3.2-3B"
+            self.rope_scaling_factor = 32
         elif "3.1-8B" in LLAMA_DIR:
             local_params = "LLAMA3_1_8B_PARAMS"
             self.model_name = "3.1-8B"
+            self.rope_scaling_factor = 8
         elif "3.2-11B" in LLAMA_DIR:
             local_params = "LLAMA3_2_11B_PARAMS"
             self.model_name = "3.2-11B"
+            self.rope_scaling_factor = 8  # shared with 3.1-8B
         elif "3.1-70B" in LLAMA_DIR:
             local_params = "LLAMA3_1_70B_PARAMS"
             self.model_name = "3.1-70B"
+            self.rope_scaling_factor = 8
         else:
+            # NOTE: 3.2-90B and 3.3-70B also use scaling factor of 8
             raise ValueError(f"Unsupported LLAMA model: {LLAMA_DIR}")

         if callable(optimizations):
@@ -193,7 +199,7 @@ def __init__(
         self.model_config.update({f"{key}_TILE": ttnn.TILE_LAYOUT for key in self.OP_KEYS if "LAYOUT" in key})

         self.cos, self.sin = precompute_freqs(
-            self.head_dim, self.max_seq_len * 2, self.rope_theta, self.use_scaled_rope
+            self.head_dim, self.max_seq_len * 2, self.rope_theta, self.use_scaled_rope, self.rope_scaling_factor
         )  # for prefill
         self.rot_emb = freqs_to_rotation_matrix(self.cos, self.sin)  # for decode

diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
index a7ce9def430e..1c05ec06e1cf 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
@@ -127,6 +127,7 @@ def __init__(
             configuration.max_seq_len,
             configuration.rope_theta,
             configuration.use_scaled_rope,
+            configuration.rope_scaling_factor,
         )

         self.trans_mats_dict = self.rope_setup.get_both_trans_mats()
diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py
index 0b1f36fd6f4e..ef400f99275b 100644
--- a/models/demos/llama3/tt/multimodal/llama_vision_model.py
+++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py
@@ -333,7 +333,11 @@ def prepare_inputs_prefill(
             h,
         )
         rot_mats = get_prefill_rot_mat(
-            self.configuration.head_dim, self.configuration.max_seq_len, self.mesh_device, seq_len=S
+            self.configuration.head_dim,
+            self.configuration.max_seq_len,
+            self.mesh_device,
+            seq_len=S,
+            scale_factor=self.configuration.rope_scaling_factor,
         )

         full_text_mask_expand_11SD = full_text_mask.expand(-1, -1, -1, self.configuration.dim)
diff --git a/models/demos/t3000/llama2_70b/reference/llama b/models/demos/t3000/llama2_70b/reference/llama
index dc6852f56c7a..29125b7ad8b5 160000
--- a/models/demos/t3000/llama2_70b/reference/llama
+++ b/models/demos/t3000/llama2_70b/reference/llama
@@ -1 +1 @@
-Subproject commit dc6852f56c7a49acd96624c153e565e58e5a6ca0
+Subproject commit 29125b7ad8b5513eeaa4417ed92892bf39c8bd74