-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathmultisvc.go
176 lines (151 loc) · 4.79 KB
/
multisvc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
package ensemble
import (
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/linear_models"
"github.com/sjwhitworth/golearn/meta"
"fmt"
)
// MultiLinearSVC implements a multi-class Support Vector Classifier using a one-vs-all
// voting scheme. Only one CategoricalAttribute class is supported.
type MultiLinearSVC struct {
// m is the one-vs-all wrapper created by initializeOneVsAllModel; Fit, Predict
// and the prefix-based save/load methods delegate to it.
m *meta.OneVsAllModel
// parameters holds the LinearSVC settings; every per-class classifier is
// built from a copy of it.
parameters *linear_models.LinearSVCParams
// weights optionally maps a class value to its weight. When nil, the
// per-class classifiers are created with nil ClassWeights.
weights map[string]float64
}
// NewMultiLinearSVC builds a MultiLinearSVC backed by a OneVsAllModel.
//
// loss and penalty are each "l1" or "l2" (commonly "l1" loss with an "l2"
// penalty). dual chooses between the dual and primal SVM formulation; true is
// appropriate in most cases. C is the penalty term (normally 1.0) and eps the
// convergence tolerance (typically 1e-4). weights may be nil; when set, it maps
// class values to the weights used to build each one-vs-all classifier.
//
// The function panics if the loss/penalty combination is rejected by
// SetKindFromStrings.
func NewMultiLinearSVC(loss, penalty string, dual bool, C float64, eps float64, weights map[string]float64) *MultiLinearSVC {
	// Positional literal: solver kind and class weights start at their zero
	// values; the kind is filled in by SetKindFromStrings just below.
	svcParams := &linear_models.LinearSVCParams{0, nil, C, eps, false, dual}
	if err := svcParams.SetKindFromStrings(loss, penalty); err != nil {
		panic(err)
	}

	model := new(MultiLinearSVC)
	model.parameters = svcParams
	model.weights = weights
	model.initializeOneVsAllModel()
	return model
}
// initializeOneVsAllModel replaces m.m with a fresh OneVsAllModel whose factory
// creates one LinearSVC per class value from a copy of m.parameters.
//
// When m.weights is set, each classifier gets a two-element ClassWeights slice:
// index 0 is the sum of the weights of every other class, index 1 is the weight
// of the class itself. With nil weights, ClassWeights is left nil. The factory
// panics if the LinearSVC cannot be constructed.
func (m *MultiLinearSVC) initializeOneVsAllModel() {
	m.m = meta.NewOneVsAllModel(func(cls string) base.Classifier {
		classParams := m.parameters.Copy()

		// Stays nil unless explicit weights were supplied.
		var pair []float64
		if m.weights != nil {
			pair = make([]float64, 2)
			for label, w := range m.weights {
				if label == cls {
					pair[1] = w
					continue
				}
				pair[0] += w
			}
		}
		classParams.ClassWeights = pair

		svc, err := linear_models.NewLinearSVCFromParams(classParams)
		if err != nil {
			panic(err)
		}
		return svc
	})
}
// Fit builds the MultiLinearSVC by building n (where n is the number of values
// the singular CategoricalAttribute can take) separate one-vs-rest models.
//
// Training is delegated to the underlying OneVsAllModel. No result is captured
// from that call, so this method always returns nil.
func (m *MultiLinearSVC) Fit(instances base.FixedDataGrid) error {
m.m.Fit(instances)
return nil
}
// Predict issues predictions from the MultiLinearSVC. Each underlying LinearSVC is
// used to predict whether an instance takes on a class or some other class, and the
// model which definitively reports a given class is the one chosen. The result is
// undefined if all underlying models predict that the instance originates from some
// other class.
//
// The call is forwarded unchanged to the underlying OneVsAllModel, and its
// result and error are returned as-is.
func (m *MultiLinearSVC) Predict(from base.FixedDataGrid) (base.FixedDataGrid, error) {
return m.m.Predict(from)
}
// GetClassifierMetadata returns the metadata record that Save writes into the
// serialized classifier stub: format version 1, name "MultiLinearSVC",
// classifier version "1", and no extra metadata.
func (m *MultiLinearSVC) GetClassifierMetadata() base.ClassifierMetadataV1 {
	var md base.ClassifierMetadataV1
	md.FormatVersion = 1
	md.ClassifierName = "MultiLinearSVC"
	md.ClassifierVersion = "1"
	md.ClassifierMetadata = nil
	return md
}
// Save serializes the classifier to filePath.
//
// It creates a serialized classifier stub carrying GetClassifierMetadata, then
// writes the parameters, weights and one-vs-all model via SaveWithPrefix with
// an empty prefix. An error creating the stub is returned unchanged; an error
// from SaveWithPrefix is wrapped with "Unable to Save(): ".
func (m *MultiLinearSVC) Save(filePath string) error {
	serializer, err := base.CreateSerializedClassifierStub(filePath, m.GetClassifierMetadata())
	if err != nil {
		return err
	}
	// Deferred so the serializer is also closed when SaveWithPrefix fails;
	// previously it was only closed on the success path.
	// NOTE(review): if Close reports an error it is still discarded here, as
	// before — confirm its signature and consider propagating it.
	defer serializer.Close()

	if err := m.SaveWithPrefix(serializer, ""); err != nil {
		return fmt.Errorf("Unable to Save(): %v", err)
	}
	return nil
}
// SaveWithPrefix writes the classifier into serializer under prefix.
//
// Three entries are written, each keyed as prefix + "/" + name: the JSON-encoded
// parameters ("params"), the JSON-encoded weights ("weights"), and the
// underlying one-vs-all model ("one-vs-all"). The first failure aborts the
// write and is returned; the two JSON failures are wrapped with a description.
func (m *MultiLinearSVC) SaveWithPrefix(serializer *base.ClassifierSerializer, prefix string) error {
	keyFor := func(name string) string {
		return prefix + "/" + name
	}

	// Linear SVC parameters.
	if err := serializer.WriteJSONForKey(keyFor("params"), m.parameters); err != nil {
		return fmt.Errorf("Unable to marshal parameters: %v", err)
	}

	// Per-class weights.
	if err := serializer.WriteJSONForKey(keyFor("weights"), m.weights); err != nil {
		return fmt.Errorf("Unable to write weights: %v", err)
	}

	// The wrapped one-vs-all model serializes itself.
	return m.m.SaveWithPrefix(serializer, keyFor("one-vs-all"))
}
// GetMetadata returns a metadata record describing this classifier: format
// version 1, name "MultiLinearSVC", classifier version "1.0", and no extra
// metadata.
//
// NOTE(review): GetClassifierMetadata reports the version as "1" while this
// method reports "1.0" — confirm whether the difference is intentional.
func (m *MultiLinearSVC) GetMetadata() base.ClassifierMetadataV1 {
	md := base.ClassifierMetadataV1{
		ClassifierName:     "MultiLinearSVC",
		ClassifierVersion:  "1.0",
		FormatVersion:      1,
		ClassifierMetadata: nil,
	}
	return md
}
// Load restores the classifier from the serialized file at filePath.
//
// It opens the serialized classifier stub and hands the reader to
// LoadWithPrefix with an empty prefix. Any error from either step is returned
// unchanged.
//
// NOTE(review): the reader is not explicitly released here — confirm whether
// the deserializer needs closing.
func (m *MultiLinearSVC) Load(filePath string) error {
	reader, err := base.ReadSerializedClassifierStub(filePath)
	if err != nil {
		return err
	}
	return m.LoadWithPrefix(reader, "")
}
// LoadWithPrefix restores the classifier from reader, reading the entries that
// SaveWithPrefix wrote under prefix.
//
// The parameters ("params") and weights ("weights") are decoded from JSON into
// the receiver. The one-vs-all model is then rebuilt from them and asked to
// load its own state from the "one-vs-all" entry. Each key is formed as
// prefix + "/" + name. The first failure is returned; JSON failures are
// wrapped with a message naming the entry that could not be loaded.
func (m *MultiLinearSVC) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
	p := func(fName string) string {
		return fmt.Sprintf("%s/%s", prefix, fName)
	}

	if err := reader.GetJSONForKey(p("params"), &m.parameters); err != nil {
		return fmt.Errorf("Can't load parameters: %v", err)
	}
	// This message used to say "parameters" as well, which made a weights
	// failure indistinguishable from a parameters failure.
	if err := reader.GetJSONForKey(p("weights"), &m.weights); err != nil {
		return fmt.Errorf("Can't load weights: %v", err)
	}

	// The classifier factory closes over the receiver's parameters and weights,
	// so the wrapper is recreated only after both have been loaded.
	m.initializeOneVsAllModel()

	// Load the model
	return m.m.LoadWithPrefix(reader, p("one-vs-all"))
}