diff --git a/naive/naive.go b/naive/naive.go index 3d598f4..7b7e990 100644 --- a/naive/naive.go +++ b/naive/naive.go @@ -2,8 +2,11 @@ package naive import ( "bytes" + "encoding/gob" "errors" "io" + "io/ioutil" + "os" "sync" "github.com/n3integration/classifier" @@ -17,8 +20,8 @@ type Option func(c *Classifier) error // Classifier implements a naive bayes classifier type Classifier struct { - feat2cat map[string]map[string]int - catCount map[string]int + Feat2Cat map[string]map[string]int + CatCount map[string]int tokenizer classifier.Tokenizer mu sync.RWMutex } @@ -26,8 +29,8 @@ type Classifier struct { // New initializes a new naive Classifier using the standard tokenizer func New(opts ...Option) *Classifier { c := &Classifier{ - feat2cat: make(map[string]map[string]int), - catCount: make(map[string]int), + Feat2Cat: make(map[string]map[string]int), + CatCount: make(map[string]int), tokenizer: classifier.NewTokenizer(), } for _, opt := range opts { @@ -36,6 +39,25 @@ func New(opts ...Option) *Classifier { return c } +// Load reads a classifier from a binary file +func Load(path string) (*Classifier, error) { + c := &Classifier{ + tokenizer: classifier.NewTokenizer(), + } + + file, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + + dec := gob.NewDecoder(bytes.NewBuffer(file)) + if err = dec.Decode(c); err != nil { + return nil, err + } + + return c, err +} + // Tokenizer overrides the classifier's default Tokenizer func Tokenizer(t classifier.Tokenizer) Option { return func(c *Classifier) error { @@ -92,34 +114,54 @@ func (c *Classifier) ClassifyString(doc string) (string, error) { return c.Classify(asReader(doc)) } +// Save writes a classifier to a binary file +func (c *Classifier) Save(path string) (err error) { + var network bytes.Buffer + + file, err := os.Create(path) + if err != nil { + return err + } + defer file.Close() + + enc := gob.NewEncoder(&network) + err = enc.Encode(c) + if err != nil { + return err + } + + _, err = file.Write(network.Bytes()) + return err +} + func (c *Classifier) addFeature(feature string, category string) { - if _, ok := c.feat2cat[feature]; !ok { - c.feat2cat[feature] = make(map[string]int) + if _, ok := c.Feat2Cat[feature]; !ok { + c.Feat2Cat[feature] = make(map[string]int) } - c.feat2cat[feature][category]++ + c.Feat2Cat[feature][category]++ } func (c *Classifier) featureCount(feature string, category string) float64 { - if _, ok := c.feat2cat[feature]; ok { - return float64(c.feat2cat[feature][category]) + if _, ok := c.Feat2Cat[feature]; ok { + return float64(c.Feat2Cat[feature][category]) } return 0.0 } func (c *Classifier) addCategory(category string) { - c.catCount[category]++ + c.CatCount[category]++ } func (c *Classifier) categoryCount(category string) float64 { - if _, ok := c.catCount[category]; ok { - return float64(c.catCount[category]) + if _, ok := c.CatCount[category]; ok { + return float64(c.CatCount[category]) } return 0.0 } func (c *Classifier) count() int { sum := 0 - for _, value := range c.catCount { + for _, value := range c.CatCount { sum += value } return sum @@ -127,7 +169,7 @@ func (c *Classifier) count() int { func (c *Classifier) categories() []string { var keys []string - for k := range c.catCount { + for k := range c.CatCount { keys = append(keys, k) } return keys