77 "math"
88 "os"
99 "path/filepath"
10+ "sync"
1011 "sync/atomic"
1112)
1213
@@ -19,6 +20,12 @@ const defaultProb = 1e-11
1920// ErrUnderflow is returned when an underflow is detected.
2021var ErrUnderflow = errors .New ("possible underflow detected" )
2122
23+ // ErrClassExists is returned when trying to add a class that already exists.
24+ var ErrClassExists = errors .New ("class already exists" )
25+
26+ // ErrAlreadyConverted is returned when trying to add a class after TF-IDF conversion.
27+ var ErrAlreadyConverted = errors .New ("cannot add class after TF-IDF conversion" )
28+
2229// Class defines a class that the classifier will filter:
2330// C = {C_1, ..., C_n}. You should define your classes as a
2431// set of constants, for example as follows:
@@ -39,7 +46,8 @@ type Classifier struct {
3946 datas map [Class ]* classData
4047 tfIdf bool
4148 DidConvertTfIdf bool // we can't classify a TF-IDF classifier if we haven't yet
42- // called ConverTermsFreqToTfIdf
49+ // called ConvertTermsFreqToTfIdf
50+ mu sync.RWMutex // protects Classes and datas for concurrent access
4351}
4452
4553// serializableClassifier represents a container for
@@ -145,7 +153,38 @@ func NewClassifierFromReader(r io.Reader) (c *Classifier, err error) {
145153 w := new (serializableClassifier )
146154 err = dec .Decode (w )
147155
148- return & Classifier {w .Classes , w .Learned , int32 (w .Seen ), w .Datas , w .TfIdf , w .DidConvertTfIdf }, err
156+ return & Classifier {
157+ Classes : w .Classes ,
158+ learned : w .Learned ,
159+ seen : int32 (w .Seen ),
160+ datas : w .Datas ,
161+ tfIdf : w .TfIdf ,
162+ DidConvertTfIdf : w .DidConvertTfIdf ,
163+ // mu is zero-valued and ready to use
164+ }, err
165+ }
166+
167+ // AddClass adds a new class to the classifier dynamically.
168+ // Returns ErrClassExists if the class already exists, or
169+ // ErrAlreadyConverted if the classifier has been converted to TF-IDF.
170+ // This method is safe for concurrent use.
171+ func (c * Classifier ) AddClass (class Class ) error {
172+ c .mu .Lock ()
173+ defer c .mu .Unlock ()
174+
175+ // Check if TF-IDF conversion has happened
176+ if c .DidConvertTfIdf {
177+ return ErrAlreadyConverted
178+ }
179+
180+ // Check if class already exists
181+ if _ , exists := c .datas [class ]; exists {
182+ return ErrClassExists
183+ }
184+
185+ c .Classes = append (c .Classes , class )
186+ c .datas [class ] = newClassData ()
187+ return nil
149188}
150189
151190// getPriors returns the prior probabilities for the
@@ -190,7 +229,10 @@ func (c *Classifier) IsTfIdf() bool {
190229
191230// WordCount returns the number of words counted for
192231// each class in the lifetime of the classifier.
232+ // This method is safe for concurrent use.
193233func (c * Classifier ) WordCount () (result []int ) {
234+ c .mu .RLock ()
235+ defer c .mu .RUnlock ()
194236 result = make ([]int , len (c .Classes ))
195237 for inx , class := range c .Classes {
196238 data := c .datas [class ]
@@ -200,16 +242,22 @@ func (c *Classifier) WordCount() (result []int) {
200242}
201243
202244// Observe should be used when word-frequencies have been already been learned
203- // externally (e.g., hadoop)
245+ // externally (e.g., hadoop).
246+ // This method is safe for concurrent use.
204247func (c * Classifier ) Observe (word string , count int , which Class ) {
248+ c .mu .Lock ()
249+ defer c .mu .Unlock ()
205250 data := c .datas [which ]
206251 data .Freqs [word ] += float64 (count )
207252 data .Total += count
208253}
209254
210255// Learn will accept new training documents for
211256// supervised learning.
257+ // This method is safe for concurrent use.
212258func (c * Classifier ) Learn (document []string , which Class ) {
259+ c .mu .Lock ()
260+ defer c .mu .Unlock ()
213261
214262 // If we are a tfidf classifier we first need to get terms as
215263 // terms frequency and store that to work out the idf part later
@@ -246,19 +294,20 @@ func (c *Classifier) Learn(document []string, which Class) {
246294// ConvertTermsFreqToTfIdf uses all the TF samples for the class and converts
247295// them to TF-IDF https://en.wikipedia.org/wiki/Tf%E2%80%93idf
248296// once we have finished learning all the classes and have the totals.
297+ // This method is safe for concurrent use.
249298func (c * Classifier ) ConvertTermsFreqToTfIdf () {
299+ c .mu .Lock ()
300+ defer c .mu .Unlock ()
250301
251302 if c .DidConvertTfIdf {
252303 panic ("Cannot call ConvertTermsFreqToTfIdf more than once. Reset and relearn to reconvert." )
253304 }
254305
255306 for className := range c .datas {
256-
257307 for wIndex := range c .datas [className ].FreqTfs {
258308 tfIdfAdder := float64 (0 )
259309
260310 for tfSampleIndex := range c .datas [className ].FreqTfs [wIndex ] {
261-
262311 // we always want a positive TF-IDF score.
263312 tf := c .datas [className ].FreqTfs [wIndex ][tfSampleIndex ]
264313 c .datas [className ].FreqTfs [wIndex ][tfSampleIndex ] = math .Log1p (tf ) * math .Log1p (float64 (c .learned )/ float64 (c .datas [className ].Total ))
@@ -267,12 +316,9 @@ func (c *Classifier) ConvertTermsFreqToTfIdf() {
267316 // convert the 'counts' to TF-IDF's
268317 c .datas [className ].Freqs [wIndex ] = tfIdfAdder
269318 }
270-
271319 }
272320
273- // sanity check
274321 c .DidConvertTfIdf = true
275-
276322}
277323
278324// LogScores produces "log-likelihood"-like scores that can
@@ -294,20 +340,22 @@ func (c *Classifier) ConvertTermsFreqToTfIdf() {
294340//
295341// Unlike c.Probabilities(), this function is not prone to
296342// floating point underflow and is relatively safe to use.
343+ // This method is safe for concurrent use.
297344func (c * Classifier ) LogScores (document []string ) (scores []float64 , inx int , strict bool ) {
345+ c .mu .RLock ()
346+ defer c .mu .RUnlock ()
347+
298348 if c .tfIdf && ! c .DidConvertTfIdf {
299349 panic ("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling LogScores." )
300350 }
301351
302352 n := len (c .Classes )
303- scores = make ([]float64 , n , n )
353+ scores = make ([]float64 , n )
304354 priors := c .getPriors ()
305355
306356 // calculate the score for each class
307357 for index , class := range c .Classes {
308358 data := c .datas [class ]
309- // c is the sum of the logarithms
310- // as outlined in the refresher
311359 score := math .Log (priors [index ])
312360 for _ , word := range document {
313361 score += math .Log (data .getWordProb (word ))
@@ -342,7 +390,11 @@ func (c *Classifier) Classify(document []string) (class Class, scores []float64,
342390//
343391// If all scores underflow to zero, returns equal probabilities
344392// for all classes (1/n each).
393+ // This method is safe for concurrent use.
345394func (c * Classifier ) ProbScores (doc []string ) (scores []float64 , inx int , strict bool ) {
395+ c .mu .RLock ()
396+ defer c .mu .RUnlock ()
397+
346398 if c .tfIdf && ! c .DidConvertTfIdf {
347399 panic ("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling ProbScores." )
348400 }
@@ -402,7 +454,11 @@ func (c *Classifier) ClassifyProb(document []string) (class Class, scores []floa
402454//
403455// When underflow is detected, the returned scores are computed from
404456// log-domain scores using the log-sum-exp trick for numerical stability.
457+ // This method is safe for concurrent use.
405458func (c * Classifier ) SafeProbScores (doc []string ) (scores []float64 , inx int , strict bool , err error ) {
459+ c .mu .RLock ()
460+ defer c .mu .RUnlock ()
461+
406462 if c .tfIdf && ! c .DidConvertTfIdf {
407463 panic ("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling SafeProbScores." )
408464 }
@@ -500,11 +556,14 @@ func (c *Classifier) ClassifySafe(document []string) (class Class, scores []floa
500556// exist in the classifier for each class state for the given input
501557// words. In other words, if you obtain the frequencies
502558//
503- // freqs := c.WordFrequencies(/* [j]string */)
559+ // freqs := c.WordFrequencies(/* [j]string */)
504560//
505561// then the expression freq[i][j] represents the frequency of the j-th
506562// word within the i-th class.
563+ // This method is safe for concurrent use.
507564func (c * Classifier ) WordFrequencies (words []string ) (freqMatrix [][]float64 ) {
565+ c .mu .RLock ()
566+ defer c .mu .RUnlock ()
508567 n , l := len (c .Classes ), len (words )
509568 freqMatrix = make ([][]float64 , n )
510569 for i := range freqMatrix {
@@ -520,38 +579,52 @@ func (c *Classifier) WordFrequencies(words []string) (freqMatrix [][]float64) {
520579
521580// WordsByClass returns a map of words and their probability of
522581// appearing in the given class.
582+ // This method is safe for concurrent use.
523583func (c * Classifier ) WordsByClass (class Class ) (freqMap map [string ]float64 ) {
584+ c .mu .RLock ()
585+ defer c .mu .RUnlock ()
524586 freqMap = make (map [string ]float64 )
525587 for word , cnt := range c .datas [class ].Freqs {
526588 freqMap [word ] = float64 (cnt ) / float64 (c .datas [class ].Total )
527589 }
528-
529590 return freqMap
530591}
531592
532593
533594// WriteToFile serializes this classifier to a file.
595+ // This method is safe for concurrent use.
534596func (c * Classifier ) WriteToFile (name string ) error {
535597 file , err := os .OpenFile (name , os .O_WRONLY | os .O_CREATE | os .O_TRUNC , 0644 )
536598 if err != nil {
537599 return err
538600 }
539601 defer file .Close ()
540- return c .WriteGob (file )
602+ return c .writeGobLocked (file )
541603}
542604
543605// WriteClassesToFile writes all classes to files.
606+ // This method is safe for concurrent use.
544607func (c * Classifier ) WriteClassesToFile (rootPath string ) error {
608+ c .mu .RLock ()
609+ defer c .mu .RUnlock ()
545610 for name := range c .datas {
546- if err := c .WriteClassToFile (name , rootPath ); err != nil {
611+ if err := c .writeClassToFileLocked (name , rootPath ); err != nil {
547612 return err
548613 }
549614 }
550615 return nil
551616}
552617
553618// WriteClassToFile writes a single class to file.
619+ // This method is safe for concurrent use.
554620func (c * Classifier ) WriteClassToFile (name Class , rootPath string ) error {
621+ c .mu .RLock ()
622+ defer c .mu .RUnlock ()
623+ return c .writeClassToFileLocked (name , rootPath )
624+ }
625+
626+ // writeClassToFileLocked writes a single class to file (caller must hold lock).
627+ func (c * Classifier ) writeClassToFileLocked (name Class , rootPath string ) error {
555628 data := c .datas [name ]
556629 fileName := filepath .Join (rootPath , string (name ))
557630 file , err := os .OpenFile (fileName , os .O_WRONLY | os .O_CREATE | os .O_TRUNC , 0644 )
@@ -564,20 +637,27 @@ func (c *Classifier) WriteClassToFile(name Class, rootPath string) error {
564637
565638
566639// WriteGob serializes this classifier to GOB and writes to Writer.
640+ // This method is safe for concurrent use.
567641func (c * Classifier ) WriteGob (w io.Writer ) (err error ) {
642+ c .mu .RLock ()
643+ defer c .mu .RUnlock ()
644+ return c .writeGobLocked (w )
645+ }
646+
647+ // writeGobLocked serializes this classifier to GOB (caller must hold lock).
648+ func (c * Classifier ) writeGobLocked (w io.Writer ) (err error ) {
568649 enc := gob .NewEncoder (w )
569650 err = enc .Encode (& serializableClassifier {c .Classes , c .learned , int (c .seen ), c .datas , c .tfIdf , c .DidConvertTfIdf })
570-
571651 return
572652}
573653
574654
575655// ReadClassFromFile loads existing class data from a
576656// file.
657+ // This method is safe for concurrent use.
577658func (c * Classifier ) ReadClassFromFile (class Class , location string ) (err error ) {
578659 fileName := filepath .Join (location , string (class ))
579660 file , err := os .Open (fileName )
580-
581661 if err != nil {
582662 return err
583663 }
@@ -586,7 +666,12 @@ func (c *Classifier) ReadClassFromFile(class Class, location string) (err error)
586666 dec := gob .NewDecoder (file )
587667 w := new (classData )
588668 err = dec .Decode (w )
669+ if err != nil {
670+ return err
671+ }
589672
673+ c .mu .Lock ()
674+ defer c .mu .Unlock ()
590675 c .learned ++
591676 c .datas [class ] = w
592677 return
0 commit comments