Skip to content

Commit

Permalink
refactor: standardized the hfconfig resolver ptrs
Browse files Browse the repository at this point in the history
  • Loading branch information
Rexwang8 committed May 17, 2024
1 parent 2a7ad6b commit 11c7b17
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions resources/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ func ResolveConfig(vocabId string, token string) (config *HFConfig,
resources = rslvdResources
}

var hfConfig HFConfig
var hfConfig *HFConfig
if configErr := json.Unmarshal(*((*resources)["config.json"]).Data,
&hfConfig); configErr != nil {
resources.Cleanup()
Expand Down Expand Up @@ -787,19 +787,19 @@ func ResolveConfig(vocabId string, token string) (config *HFConfig,
}
hfConfig.BosTokenStr = &bosToken

hfConfigPtr, err := ResolveHFFromResources(resources, hfConfig)
hfConfig, err = ResolveHFFromResources(resources, hfConfig)
if err != nil {
return nil, nil, err
}
hfConfig = *hfConfigPtr

return &hfConfig, resources, nil
return hfConfig, resources, nil

}

// ResolveHFFromResources
// Given a set of resources, resolve the HuggingFace configuration.
// Used to be able to resolve both embedded and local resources.
func ResolveHFFromResources(resources *Resources, hfConfig HFConfig) (*HFConfig, error) {
func ResolveHFFromResources(resources *Resources, hfConfig *HFConfig) (*HFConfig, error) {
//use interfaces to unmarsal the config file and tokenizer config file
var config interface{}
var tokenizerConfig interface{}
Expand Down Expand Up @@ -829,7 +829,6 @@ func ResolveHFFromResources(resources *Resources, hfConfig HFConfig) (*HFConfig,
if config != nil || tokenizerConfig != nil {
hasReadConfig := false
if config != nil {
fmt.Printf("Config resolution: %v\n", config)
//using interfaces, first check if bos_token is in string format
if bosToken, ok := config.(map[string]interface{})["bos_token"].(string); ok {
hfConfig.BosTokenStr = &bosToken
Expand All @@ -845,7 +844,6 @@ func ResolveHFFromResources(resources *Resources, hfConfig HFConfig) (*HFConfig,
if tokenizerConfig != nil && !hasReadConfig {
//using interfaces, first check if bos_token is in string format
if bosToken, ok := tokenizerConfig.(map[string]interface{})["bos_token"].(string); ok {
fmt.Printf("string tokenizer resolution\n %v\n", tokenizerConfig)
hfConfig.BosTokenStr = &bosToken
if eosToken, ok := tokenizerConfig.(map[string]interface{})["eos_token"].(string); ok {
hfConfig.EosTokenStr = &eosToken
Expand All @@ -858,7 +856,6 @@ func ResolveHFFromResources(resources *Resources, hfConfig HFConfig) (*HFConfig,
}
//if not, assume llama2 format and try to unmarshal
if !hasReadConfig {
fmt.Printf("tokenizer llama2 resolution\n %v\n", tokenizerConfig)
cfg := tokenizerConfig.(map[string]interface{})
if bosToken, ok := cfg["bos_token"].(map[string]interface{}); ok {
if content, ok := bosToken["content"].(string); ok {
Expand All @@ -876,7 +873,6 @@ func ResolveHFFromResources(resources *Resources, hfConfig HFConfig) (*HFConfig,
}
//if that doesn't work, assume mistral format
if !hasReadConfig {
fmt.Printf("tokenizer mistral resolution\n %v\n", tokenizerConfig)
if bosToken, ok := tokenizerConfig.(map[string]interface{})["bos_token"].(string); ok {
hfConfig.BosTokenStr = &bosToken
}
Expand All @@ -891,14 +887,15 @@ func ResolveHFFromResources(resources *Resources, hfConfig HFConfig) (*HFConfig,

}
fmt.Printf("Resolved config: %v\n", &hfConfig)
return &hfConfig, nil
return hfConfig, nil
}

// ResolveVocabId
// Resolves a vocabulary id to a set of resources, from embedded,
// local filesystem, or remote.
func ResolveVocabId(vocabId string, token string) (*HFConfig, *Resources, error) {
var resolvedVocabId string
fmt.Printf("Resolving vocab id: %s\n", vocabId)
if _, vocabErr := EmbeddedDirExists(vocabId); vocabErr == nil {
endOfText := "<|endoftext|>"
bosText := "<|startoftext|>"
Expand Down Expand Up @@ -941,7 +938,7 @@ func ResolveVocabId(vocabId string, token string) (*HFConfig, *Resources, error)
resources["tokenizer_config.json"] = *tokenizer_specials_config
}

hf, err := ResolveHFFromResources(&resources, *hf)
hf, err := ResolveHFFromResources(&resources, hf)
if err != nil {
return nil, nil, err
}
Expand Down

0 comments on commit 11c7b17

Please sign in to comment.