Skip to content
This repository has been archived by the owner on Jul 24, 2023. It is now read-only.

Commit

Permalink
renamed type "expected" to "item", and error handling for model options
Browse files Browse the repository at this point in the history
  • Loading branch information
shixzie committed Jul 2, 2017
1 parent 6482975 commit 20f99e6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
47 changes: 30 additions & 17 deletions nlp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package nlp

import (
"bytes"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
"unicode"

"github.com/Shixzie/nlp/parser"
"github.com/cdipaolo/goml/base"
Expand Down Expand Up @@ -69,13 +71,13 @@ func (nl *NL) Learn() error {
type model struct {
tpy reflect.Type
fields []field
expected [][]expected
expected [][]item
samples []string
timeFormat string
timeLocation *time.Location
}

type expected struct {
type item struct {
limit bool
value string
field field
Expand All @@ -88,23 +90,31 @@ type field struct {
}

// ModelOption is an option for a specific model
type ModelOption func(*model)
type ModelOption func(*model) error

// WithTimeFormat sets the format used in time.Parse(format, val).
// Note that format can't contain any spaces. The default is 01-02-2006_3:04pm.
func WithTimeFormat(format string) ModelOption {
return func(m *model) {
m.timeFormat = strings.Replace(format, " ", "", -1)
return func(m *model) error {
for _, v := range format {
if unicode.IsSpace(v) {
return errors.New("time format can't contain any spaces")
}
}
m.timeFormat = format
return nil
}
}

// WithTimeLocation sets the location used in time.ParseInLocation(format, value, loc).
// The default is time.Local.
func WithTimeLocation(loc *time.Location) ModelOption {
return func(m *model) {
if loc != nil {
m.timeLocation = loc
return func(m *model) error {
if loc == nil {
return errors.New("time location can't be nil")
}
m.timeLocation = loc
return nil
}
}

Expand All @@ -126,12 +136,15 @@ func (nl *NL) RegisterModel(i interface{}, samples []string, ops ...ModelOption)
mod := &model{
tpy: tpy,
samples: samples,
expected: make([][]expected, len(samples)),
expected: make([][]item, len(samples)),
timeFormat: "01-02-2006_3:04pm",
timeLocation: time.Local,
}
for _, op := range ops {
op(mod)
err := op(mod)
if err != nil {
return err
}
}
NextField:
for i := 0; i < tpy.NumField(); i++ {
Expand Down Expand Up @@ -162,7 +175,7 @@ func (m *model) learn() error {
if err != nil {
return err
}
var exps []expected
var exps []item
var hasAtLeastOneKey bool
for _, tk := range tokens {
if tk.Kw {
Expand All @@ -171,14 +184,14 @@ func (m *model) learn() error {
for _, f := range m.fields {
if tk.Val == f.name {
mistypedField = false
exps = append(exps, expected{field: f, value: tk.Val})
exps = append(exps, item{field: f, value: tk.Val})
}
}
if mistypedField {
return fmt.Errorf("sample#%d: mistyped field %q", sid, tk.Val)
}
} else {
exps = append(exps, expected{limit: true, value: tk.Val})
exps = append(exps, item{limit: true, value: tk.Val})
}
}
if !hasAtLeastOneKey {
Expand All @@ -189,7 +202,7 @@ func (m *model) learn() error {
return nil
}

func (m *model) selectBestSample(expr string) []expected {
func (m *model) selectBestSample(expr string) []item {
// slice [sample_id]score
scores := make([]int, len(m.samples))

Expand All @@ -200,7 +213,7 @@ func (m *model) selectBestSample(expr string) []expected {

// fmt.Printf("tokens: %v\n", tokens)

mapping := make([][]expected, len(m.samples))
mapping := make([][]item, len(m.samples))
limitsOrder := make([][]string, len(m.samples)+1)

for sid, exps := range m.expected {
Expand Down Expand Up @@ -235,7 +248,7 @@ func (m *model) selectBestSample(expr string) []expected {
scores[sid] = scores[sid] + 1
if len(currentVal) > 0 {
// fmt.Printf("appending: %v {%v}\n", strings.Join(currentVal, " "), e.field.n)
mapping[sid] = append(mapping[sid], expected{field: e.field, value: strings.Join(currentVal, " ")})
mapping[sid] = append(mapping[sid], item{field: e.field, value: strings.Join(currentVal, " ")})
currentVal = currentVal[:0]
lastToken = i
continue expecteds
Expand All @@ -251,7 +264,7 @@ func (m *model) selectBestSample(expr string) []expected {
}
if len(currentVal) > 0 {
// fmt.Printf("appending: %v {%v}\n", strings.Join(currentVal, " "), e.field.n)
mapping[sid] = append(mapping[sid], expected{field: e.field, value: strings.Join(currentVal, " ")})
mapping[sid] = append(mapping[sid], item{field: e.field, value: strings.Join(currentVal, " ")})
}
}
// fmt.Printf("\n\n")
Expand Down
2 changes: 1 addition & 1 deletion nlp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func Test_model_learn(t *testing.T) {
type fields struct {
tpy reflect.Type
fields []field
expected [][]expected
expected [][]item
samples []string
}
tests := []struct {
Expand Down

0 comments on commit 20f99e6

Please sign in to comment.