Skip to content

Save and load a classifier #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 56 additions & 14 deletions naive/naive.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package naive

import (
"bytes"
"encoding/gob"
"errors"
"io"
"io/ioutil"
"os"
"sync"

"github.com/n3integration/classifier"
Expand All @@ -17,17 +20,17 @@ type Option func(c *Classifier) error

// Classifier implements a naive bayes classifier.
//
// The count tables are exported fields: Save and Load hand the struct to
// encoding/gob, which only transmits exported fields.
type Classifier struct {
feat2cat map[string]map[string]int
catCount map[string]int
// Feat2Cat holds, per feature, how many times it was added under each category.
Feat2Cat map[string]map[string]int
// CatCount holds how many times each category was added.
CatCount map[string]int
// tokenizer is set to classifier.NewTokenizer() by New and Load.
tokenizer classifier.Tokenizer
// mu is a read/write mutex. NOTE(review): the code that takes it is not
// visible in this view - confirm which methods lock it.
mu sync.RWMutex
}

// New initializes a new naive Classifier using the standard tokenizer
func New(opts ...Option) *Classifier {
c := &Classifier{
feat2cat: make(map[string]map[string]int),
catCount: make(map[string]int),
Feat2Cat: make(map[string]map[string]int),
CatCount: make(map[string]int),
tokenizer: classifier.NewTokenizer(),
}
for _, opt := range opts {
Expand All @@ -36,6 +39,25 @@ func New(opts ...Option) *Classifier {
return c
}

// Load reads a gob-encoded classifier, as written by Save, from the file at
// path. The returned classifier uses the standard tokenizer.
//
// The error from reading the file or from decoding its contents is returned
// unchanged; on error the returned classifier is nil.
func Load(path string) (*Classifier, error) {
	data, err := ioutil.ReadFile(path)
	if err != nil {
		return nil, err
	}

	c := &Classifier{
		tokenizer: classifier.NewTokenizer(),
	}
	if err := gob.NewDecoder(bytes.NewReader(data)).Decode(c); err != nil {
		return nil, err
	}

	// gob does not transmit nil maps, so a file saved from a classifier whose
	// maps were never allocated decodes to nil maps. Allocate them here so
	// later writes (addFeature, addCategory) do not panic.
	if c.Feat2Cat == nil {
		c.Feat2Cat = make(map[string]map[string]int)
	}
	if c.CatCount == nil {
		c.CatCount = make(map[string]int)
	}

	return c, nil
}

// Tokenizer overrides the classifier's default Tokenizer
func Tokenizer(t classifier.Tokenizer) Option {
return func(c *Classifier) error {
Expand Down Expand Up @@ -92,42 +114,62 @@ func (c *Classifier) ClassifyString(doc string) (string, error) {
return c.Classify(asReader(doc))
}

// Save writes the classifier to the file at path in gob encoding, creating
// the file or truncating an existing one. Only the exported count tables are
// written; the tokenizer is not part of the encoding.
//
// The classifier is encoded in memory before the file is touched, so an
// encoding failure leaves any existing file at path intact. Errors from
// encoding, creating, writing or closing the file are returned.
func (c *Classifier) Save(path string) (err error) {
	var network bytes.Buffer

	// Hold the read lock only while encoding, so the maps are not mutated
	// mid-iteration by writers that take c.mu; file I/O happens unlocked.
	c.mu.RLock()
	err = gob.NewEncoder(&network).Encode(c)
	c.mu.RUnlock()
	if err != nil {
		return err
	}

	file, err := os.Create(path)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close on a written file can mean data was lost; report it
		// unless an earlier error is already being returned.
		if cerr := file.Close(); err == nil {
			err = cerr
		}
	}()

	_, err = file.Write(network.Bytes())
	return err
}

// addFeature increments the counter for feature under category, allocating
// the per-feature category map the first time the feature is seen.
func (c *Classifier) addFeature(feature string, category string) {
if _, ok := c.feat2cat[feature]; !ok {
c.feat2cat[feature] = make(map[string]int)
if _, ok := c.Feat2Cat[feature]; !ok {
c.Feat2Cat[feature] = make(map[string]int)
}
c.feat2cat[feature][category]++
c.Feat2Cat[feature][category]++
}

// featureCount returns, as a float64, the counter stored for feature under
// category. It returns 0.0 when the feature has no entry.
func (c *Classifier) featureCount(feature string, category string) float64 {
if _, ok := c.feat2cat[feature]; ok {
return float64(c.feat2cat[feature][category])
if _, ok := c.Feat2Cat[feature]; ok {
return float64(c.Feat2Cat[feature][category])
}
return 0.0
}

// addCategory increments the counter kept for category.
func (c *Classifier) addCategory(category string) {
c.catCount[category]++
c.CatCount[category]++
}

// categoryCount returns, as a float64, the counter kept for category.
// It returns 0.0 when the category has no entry.
func (c *Classifier) categoryCount(category string) float64 {
if _, ok := c.catCount[category]; ok {
return float64(c.catCount[category])
if _, ok := c.CatCount[category]; ok {
return float64(c.CatCount[category])
}
return 0.0
}

// count returns the sum of the counters of all categories.
func (c *Classifier) count() int {
sum := 0
for _, value := range c.catCount {
for _, value := range c.CatCount {
sum += value
}
return sum
}

// categories returns the names of all categories that have a counter.
// The names come from ranging over a map, so their order is not fixed;
// the result is a nil slice when there are no categories.
func (c *Classifier) categories() []string {
var keys []string
for k := range c.catCount {
for k := range c.CatCount {
keys = append(keys, k)
}
return keys
Expand Down