Skip to content

Commit

Permalink
Merge pull request #232 from hmcalister/stateManager
Browse files Browse the repository at this point in the history
State manager
  • Loading branch information
hmcalister authored Jun 4, 2023
2 parents 29e02ba + 2d9ad35 commit 4126f8f
Show file tree
Hide file tree
Showing 13 changed files with 155 additions and 180 deletions.
26 changes: 9 additions & 17 deletions hopfieldnetwork/HopfieldNetwork.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package hopfieldnetwork

import (
"fmt"
"hmcalister/hopfield/hopfieldnetwork/activationfunction"
"hmcalister/hopfield/hopfieldnetwork/datacollector"
"hmcalister/hopfield/hopfieldnetwork/domain"
"hmcalister/hopfield/hopfieldnetwork/energyfunction"
"hmcalister/hopfield/hopfieldnetwork/learningmappingfunction"
"hmcalister/hopfield/hopfieldnetwork/noiseapplication"
"hmcalister/hopfield/hopfieldnetwork/states/statemanager"
"hmcalister/hopfield/hopfieldutils"
"log"

Expand All @@ -27,6 +25,7 @@ type HopfieldNetwork struct {
matrix *mat.Dense
dimension int
domain domain.DomainEnum
domainStateManager statemanager.StateManager
forceSymmetric bool
forceZeroDiagonal bool
learningMethod LearningMethod
Expand Down Expand Up @@ -168,9 +167,7 @@ func (network *HopfieldNetwork) String() string {
// A float64 representing the energy of the given state with respect to the network.
// Note a lower energy is more stable - but a negative state energy may still be unstable!
func (network *HopfieldNetwork) StateEnergy(state *mat.VecDense) float64 {
stateCopy := mat.VecDenseCopyOf(state)
learningmappingfunction.GetLearningMappingFunction(network.domain)(stateCopy)
return energyfunction.StateEnergy(network.matrix, stateCopy)
return network.domainStateManager.StateEnergy(network.matrix, state)
}

// Get the energy of a given unit (indexed by i) in the state with respect to the network matrix.
Expand All @@ -184,9 +181,7 @@ func (network *HopfieldNetwork) StateEnergy(state *mat.VecDense) float64 {
//
// A float64 representing the energy of the given unit within the state.
func (network *HopfieldNetwork) UnitEnergy(state *mat.VecDense, unitIndex int) float64 {
stateCopy := mat.VecDenseCopyOf(state)
learningmappingfunction.GetLearningMappingFunction(network.domain)(stateCopy)
return energyfunction.UnitEnergy(network.matrix, stateCopy, unitIndex)
return network.domainStateManager.UnitEnergy(network.matrix, state, unitIndex)
}

// Get the energy of a each unit within a state with respect to the network matrix.
Expand All @@ -199,10 +194,7 @@ func (network *HopfieldNetwork) UnitEnergy(state *mat.VecDense, unitIndex int) f
//
// A slice of float64 representing the energy of the given state's units with respect to the network.
func (network *HopfieldNetwork) AllUnitEnergies(state *mat.VecDense) []float64 {
stateCopy := mat.VecDenseCopyOf(state)
learningmappingfunction.GetLearningMappingFunction(network.domain)(stateCopy)
unitEnergies := energyfunction.AllUnitEnergies(network.matrix, stateCopy)
return unitEnergies.RawVector().Data
return network.domainStateManager.AllUnitEnergies(network.matrix, state)
}

// Determine if a given state is unstable.
Expand Down Expand Up @@ -262,7 +254,7 @@ func (network *HopfieldNetwork) AllStatesAreStable(states []*mat.VecDense) bool
// states []*mat.VecDense: A collection of states to learn
func (network *HopfieldNetwork) LearnStates(states []*mat.VecDense) []*datacollector.LearnStateData {
for _, state := range states {
activationfunction.GetActivationFunction(network.domain)(state)
network.domainStateManager.ActivationFunction(state)
}
network.targetStates = append(network.targetStates, states...)

Expand Down Expand Up @@ -298,7 +290,7 @@ func (network *HopfieldNetwork) UpdateState(state *mat.VecDense) {
for _, unitIndex := range chunk {
state.SetVec(unitIndex, newState.AtVec(unitIndex))
}
activationfunction.GetActivationFunction(network.domain)(state)
network.domainStateManager.ActivationFunction(state)
}
}

Expand Down Expand Up @@ -334,7 +326,7 @@ func (network *HopfieldNetwork) RelaxState(state *mat.VecDense) *RelaxationResul
for _, unitIndex := range chunk {
state.SetVec(unitIndex, newState.AtVec(unitIndex))
}
activationfunction.GetActivationFunction(network.domain)(state)
network.domainStateManager.ActivationFunction(state)
}

// Collect the current history item if requested
Expand Down Expand Up @@ -412,7 +404,7 @@ StateRecvLoop:
for _, unitIndex := range chunk {
state.SetVec(unitIndex, newState.AtVec(unitIndex))
}
activationfunction.GetActivationFunction(network.domain)(state)
network.domainStateManager.ActivationFunction(state)
}

// Collect the current history item if requested
Expand Down
2 changes: 2 additions & 0 deletions hopfieldnetwork/HopfieldNetworkBuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"hmcalister/hopfield/hopfieldnetwork/datacollector"
"hmcalister/hopfield/hopfieldnetwork/domain"
"hmcalister/hopfield/hopfieldnetwork/noiseapplication"
"hmcalister/hopfield/hopfieldnetwork/states/statemanager"
"log"
"time"

Expand Down Expand Up @@ -239,6 +240,7 @@ func (networkBuilder *HopfieldNetworkBuilder) Build() *HopfieldNetwork {
matrix: matrix,
dimension: networkBuilder.dimension,
domain: networkBuilder.domain,
domainStateManager: statemanager.GetDomainStateManager(networkBuilder.domain),
forceSymmetric: networkBuilder.forceSymmetric,
forceZeroDiagonal: networkBuilder.forceZeroDiagonal,
learningMethod: networkBuilder.learningMethod,
Expand Down
28 changes: 10 additions & 18 deletions hopfieldnetwork/LearningRule.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
package hopfieldnetwork

import (
"fmt"
"hmcalister/hopfield/hopfieldnetwork/activationfunction"
"hmcalister/hopfield/hopfieldnetwork/learningmappingfunction"

"gonum.org/v1/gonum/mat"
)

Expand Down Expand Up @@ -69,17 +65,18 @@ func getLearningRule(learningRule LearningRuleEnum) LearningRule {
//
// A pointer to a new matrix that stabilizes the given states as much as possible.
func hebbian(network *HopfieldNetwork, states []*mat.VecDense) *mat.Dense {
var instanceContribution float64
updatedMatrix := mat.DenseCopyOf(network.GetMatrix())
updatedMatrix.Zero()
for _, state := range states {
stateCopy := mat.VecDenseCopyOf(state)
learningmappingfunction.GetLearningMappingFunction(network.domain)(stateCopy)
fmt.Printf("%v\n", stateCopy)
for i := 0; i < network.GetDimension(); i++ {
for j := 0; j < network.GetDimension(); j++ {
val := stateCopy.AtVec(i) * stateCopy.AtVec(j)
val += updatedMatrix.At(i, j)
updatedMatrix.Set(i, j, val)
if state.AtVec(i) == state.AtVec(j) {
instanceContribution = 1.0
} else {
instanceContribution = -1.0
}
updatedMatrix.Set(i, j, updatedMatrix.At(i, j)+instanceContribution)
}
}
}
Expand Down Expand Up @@ -112,7 +109,7 @@ func delta(network *HopfieldNetwork, states []*mat.VecDense) *mat.Dense {
relaxedStates[stateIndex] = mat.VecDenseCopyOf(states[stateIndex])
// We also apply some noise to the state to aide in learning
network.learningNoiseMethod(network.randomGenerator, relaxedStates[stateIndex], network.learningNoiseScale)
activationfunction.GetActivationFunction(network.domain)(relaxedStates[stateIndex])
network.domainStateManager.ActivationFunction(relaxedStates[stateIndex])
}

// This is the most important call - relax all the states!
Expand All @@ -124,16 +121,11 @@ func delta(network *HopfieldNetwork, states []*mat.VecDense) *mat.Dense {
stateHistory := relaxationResults[stateIndex].StateHistory
relaxedState := stateHistory[len(stateHistory)-1]

stateCopy := mat.VecDenseCopyOf(state)
learningmappingfunction.GetLearningMappingFunction(network.domain)(stateCopy)
relaxedCopy := mat.VecDenseCopyOf(relaxedState)
learningmappingfunction.GetLearningMappingFunction(network.domain)(relaxedCopy)

relaxationDifference.Zero()
relaxationDifference.SubVec(stateCopy, relaxedCopy)
relaxationDifference.SubVec(state, relaxedState)

stateContribution.Zero()
stateContribution.Outer(0.5, relaxationDifference, stateCopy)
stateContribution.Outer(0.5, relaxationDifference, state)

updatedMatrix.Add(updatedMatrix, stateContribution)
}
Expand Down
65 changes: 0 additions & 65 deletions hopfieldnetwork/activationfunction/ActivationFunction.go

This file was deleted.

1 change: 0 additions & 1 deletion hopfieldnetwork/domain/DomainEnum.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package domain

// An enum to note the domain of the network.
// This determines the learning mapping function to use.
type DomainEnum int

const (
Expand Down
34 changes: 0 additions & 34 deletions hopfieldnetwork/energyfunction/EnergyFunction.go

This file was deleted.

36 changes: 0 additions & 36 deletions hopfieldnetwork/learningmappingfunction/LearningMappingFunction.go

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
package states

import (
"hmcalister/hopfield/hopfieldnetwork/activationfunction"
"hmcalister/hopfield/hopfieldnetwork/domain"
"hmcalister/hopfield/hopfieldnetwork/states/statemanager"

"gonum.org/v1/gonum/mat"
"gonum.org/v1/gonum/stat/distuv"
Expand All @@ -13,9 +12,9 @@ import (
//
// Note this struct should be initialized using the StateGeneratorBuilder from [hmcalister/hopfield/hopfieldnetwork/states].
type StateGenerator struct {
domain domain.DomainEnum
rng distuv.Uniform
dimension int
domainStateManager statemanager.StateManager
rng distuv.Uniform
dimension int
}

// Creates and returns a fresh array that can store a state.
Expand Down Expand Up @@ -48,7 +47,7 @@ func (gen *StateGenerator) NextState(dataArray []float64) *mat.VecDense {
}

state := mat.NewVecDense(gen.dimension, dataArray)
activationfunction.GetActivationFunction(gen.domain)(state)
gen.domainStateManager.ActivationFunction(state)
return state
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package states

import (
"hmcalister/hopfield/hopfieldnetwork/domain"
"hmcalister/hopfield/hopfieldnetwork/states/statemanager"
"time"

"golang.org/x/exp/rand"
Expand Down Expand Up @@ -124,7 +125,8 @@ func (builder *StateGeneratorBuilder) Build() *StateGenerator {
}

return &StateGenerator{
rng: rand_dist,
dimension: builder.dimension,
domainStateManager: statemanager.GetDomainStateManager(builder.domain),
rng: rand_dist,
dimension: builder.dimension,
}
}
Loading

0 comments on commit 4126f8f

Please sign in to comment.