@@ -11,8 +11,10 @@ import (
1111)
1212
1313// defaultProb is the tiny non-zero probability that a word
14- // we have not seen before appears in the class.
15- const defaultProb = 0.00000000001
14+ // we have not seen before appears in the class. It is used
15+ // only as a fallback when Laplace smoothing cannot be applied
16+ // (i.e., when this class's Total is 0 or its Freqs map is empty).
17+ const defaultProb = 1e-11
1618
1719// ErrUnderflow is returned when an underflow is detected.
1820var ErrUnderflow = errors .New ("possible underflow detected" )
@@ -72,75 +74,56 @@ func newClassData() *classData {
7274
7375// getWordProb returns P(W|C_j) -- the probability of seeing
7476// a particular word W in a document of this class.
77+ // Uses Laplace smoothing (add-one smoothing) to handle unseen words:
78+ // P(W|C) = (count(W,C) + 1) / (total_words_in_C + V), where V = len(d.Freqs).
7579func (d * classData ) getWordProb (word string ) float64 {
76- value , ok := d .Freqs [ word ]
77- if ! ok {
80+ vocab := len ( d .Freqs )
81+ if d . Total == 0 || vocab == 0 {
7882 return defaultProb
7983 }
80- return float64 (value ) / float64 (d .Total )
84+ value := d .Freqs [word ] // 0 if not found
85+ return (value + 1 ) / (float64 (d .Total ) + float64 (vocab ))
8186}
8287
83- // NewClassifierTfIdf returns a new classifier. The classes provided
84- // should be at least 2 in number and unique, or this method will
85- // panic.
86- func NewClassifierTfIdf ( classes ... Class ) ( c * Classifier ) {
88+ // newClassifier is the internal constructor that creates a classifier.
89+ // The classes provided should be at least 2 in number and unique,
90+ // or this function will panic.
91+ func newClassifier ( tfIdf bool , classes [] Class ) * Classifier {
8792 n := len (classes )
88-
89- // check size
9093 if n < 2 {
9194 panic ("provide at least two classes" )
9295 }
9396
9497 // check uniqueness
95- check := make (map [Class ]bool , n )
98+ check := make (map [Class ]struct {} , n )
9699 for _ , class := range classes {
97- check [class ] = true
100+ check [class ] = struct {}{}
98101 }
99102 if len (check ) != n {
100103 panic ("classes must be unique" )
101104 }
102- // create the classifier
103- c = & Classifier {
105+
106+ c := & Classifier {
104107 Classes : classes ,
105108 datas : make (map [Class ]* classData , n ),
106- tfIdf : true ,
109+ tfIdf : tfIdf ,
107110 }
108111 for _ , class := range classes {
109112 c .datas [class ] = newClassData ()
110113 }
111- return
114+ return c
112115}
113116
114- // NewClassifier returns a new classifier. The classes provided
115- // should be at least 2 in number and unique, or this method will
116- // panic.
117- func NewClassifier (classes ... Class ) (c * Classifier ) {
118- n := len (classes )
119-
120- // check size
121- if n < 2 {
122- panic ("provide at least two classes" )
123- }
117+ // NewClassifierTfIdf returns a new TF-IDF classifier. The classes provided
118+ // should be at least 2 in number and unique, or this function will panic.
119+ func NewClassifierTfIdf (classes ... Class ) * Classifier {
120+ return newClassifier (true , classes )
121+ }
124122
125- // check uniqueness
126- check := make (map [Class ]bool , n )
127- for _ , class := range classes {
128- check [class ] = true
129- }
130- if len (check ) != n {
131- panic ("classes must be unique" )
132- }
133- // create the classifier
134- c = & Classifier {
135- Classes : classes ,
136- datas : make (map [Class ]* classData , n ),
137- tfIdf : false ,
138- DidConvertTfIdf : false ,
139- }
140- for _ , class := range classes {
141- c .datas [class ] = newClassData ()
142- }
143- return
123+ // NewClassifier returns a new classifier. The classes provided
124+ // should be at least 2 in number and unique, or this function will panic.
125+ func NewClassifier (classes ... Class ) * Classifier {
126+ return newClassifier (false , classes )
144127}
145128
146129// NewClassifierFromFile loads an existing classifier from
@@ -493,37 +476,35 @@ func (c *Classifier) WordsByClass(class Class) (freqMap map[string]float64) {
493476
494477
495478// WriteToFile serializes this classifier to a file.
496- func (c * Classifier ) WriteToFile (name string ) ( err error ) {
497- file , err := os .OpenFile (name , os .O_WRONLY | os .O_CREATE , 0644 )
479+ func (c * Classifier ) WriteToFile (name string ) error {
480+ file , err := os .OpenFile (name , os .O_WRONLY | os .O_CREATE | os . O_TRUNC , 0644 )
498481 if err != nil {
499482 return err
500483 }
501484 defer file .Close ()
502-
503485 return c .WriteGob (file )
504486}
505487
506488// WriteClassesToFile writes all classes to files.
507- func (c * Classifier ) WriteClassesToFile (rootPath string ) ( err error ) {
489+ func (c * Classifier ) WriteClassesToFile (rootPath string ) error {
508490 for name := range c .datas {
509- c .WriteClassToFile (name , rootPath )
491+ if err := c .WriteClassToFile (name , rootPath ); err != nil {
492+ return err
493+ }
510494 }
511- return
495+ return nil
512496}
513497
514498// WriteClassToFile writes a single class to file.
515- func (c * Classifier ) WriteClassToFile (name Class , rootPath string ) ( err error ) {
499+ func (c * Classifier ) WriteClassToFile (name Class , rootPath string ) error {
516500 data := c .datas [name ]
517501 fileName := filepath .Join (rootPath , string (name ))
518- file , err := os .OpenFile (fileName , os .O_WRONLY | os .O_CREATE , 0644 )
502+ file , err := os .OpenFile (fileName , os .O_WRONLY | os .O_CREATE | os . O_TRUNC , 0644 )
519503 if err != nil {
520504 return err
521505 }
522506 defer file .Close ()
523-
524- enc := gob .NewEncoder (file )
525- err = enc .Encode (data )
526- return
507+ return gob .NewEncoder (file ).Encode (data )
527508}
528509
529510
0 commit comments