diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 89a54e4ecf..0b1a80dbf5 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -630,8 +630,8 @@ def _resolve_reuse_state_layer_idx( ) if 'attn_config' not in parent_config: parent_config['attn_config'] = {} - parent_config['attn_config']['reuse_state_layer_idx'] = override_config[ - 'attn_config']['reuse_state_layer_idx'] + parent_config['attn_config'][reuse_type] = override_config['attn_config' + ][reuse_type] if override_config != parent_config and not ( 'allow_mismatch' in override_config and @@ -994,7 +994,7 @@ def forward( raise KeyError( f'kv cache for layer {attn_block.reuse_kv_x_layer_idx} not found in {layer_kv_x_cache_dict=}.', ) - x_prev = layer_kv_x_cache_dict[attn_block.reuse_kv_layer_idx] + x_prev = layer_kv_x_cache_dict[attn_block.reuse_kv_x_layer_idx] else: x_prev = None if output_hidden_states: