Skip to content

Commit 5f3ee86

Browse files
jbrukh and claude committed
Add dynamic class addition with thread safety (closes #24)
- Add AddClass method to add new classes after classifier creation
- Add RWMutex to protect Classes and datas from concurrent access
- Add locks to all public methods for thread safety:
  - RLock for read operations (LogScores, ProbScores, etc.)
  - Lock for write operations (Learn, Observe, AddClass, etc.)
- Add ErrClassExists and ErrAlreadyConverted error types
- Add comprehensive tests for AddClass functionality
- Add concurrent access tests to verify thread safety

All methods are now safe for concurrent use.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 0dd9d42 commit 5f3ee86

File tree

2 files changed

+191
-17
lines changed

2 files changed

+191
-17
lines changed

bayesian.go

Lines changed: 102 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"math"
88
"os"
99
"path/filepath"
10+
"sync"
1011
"sync/atomic"
1112
)
1213

@@ -19,6 +20,12 @@ const defaultProb = 1e-11
1920
// Sentinel errors returned by the classifier. Callers can compare a returned
// error against these values (for example with errors.Is).
var (
	// ErrUnderflow is returned when an underflow is detected.
	ErrUnderflow = errors.New("possible underflow detected")

	// ErrClassExists is returned when trying to add a class that already exists.
	ErrClassExists = errors.New("class already exists")

	// ErrAlreadyConverted is returned when trying to add a class after TF-IDF conversion.
	ErrAlreadyConverted = errors.New("cannot add class after TF-IDF conversion")
)
28+
2229
// Class defines a class that the classifier will filter:
2330
// C = {C_1, ..., C_n}. You should define your classes as a
2431
// set of constants, for example as follows:
@@ -39,7 +46,8 @@ type Classifier struct {
3946
datas map[Class]*classData
4047
tfIdf bool
4148
DidConvertTfIdf bool // we can't classify a TF-IDF classifier if we haven't yet
42-
// called ConverTermsFreqToTfIdf
49+
// called ConvertTermsFreqToTfIdf
50+
mu sync.RWMutex // protects Classes and datas for concurrent access
4351
}
4452

4553
// serializableClassifier represents a container for
@@ -145,7 +153,38 @@ func NewClassifierFromReader(r io.Reader) (c *Classifier, err error) {
145153
w := new(serializableClassifier)
146154
err = dec.Decode(w)
147155

148-
return &Classifier{w.Classes, w.Learned, int32(w.Seen), w.Datas, w.TfIdf, w.DidConvertTfIdf}, err
156+
return &Classifier{
157+
Classes: w.Classes,
158+
learned: w.Learned,
159+
seen: int32(w.Seen),
160+
datas: w.Datas,
161+
tfIdf: w.TfIdf,
162+
DidConvertTfIdf: w.DidConvertTfIdf,
163+
// mu is zero-valued and ready to use
164+
}, err
165+
}
166+
167+
// AddClass adds a new class to the classifier dynamically.
168+
// Returns ErrClassExists if the class already exists, or
169+
// ErrAlreadyConverted if the classifier has been converted to TF-IDF.
170+
// This method is safe for concurrent use.
171+
func (c *Classifier) AddClass(class Class) error {
172+
c.mu.Lock()
173+
defer c.mu.Unlock()
174+
175+
// Check if TF-IDF conversion has happened
176+
if c.DidConvertTfIdf {
177+
return ErrAlreadyConverted
178+
}
179+
180+
// Check if class already exists
181+
if _, exists := c.datas[class]; exists {
182+
return ErrClassExists
183+
}
184+
185+
c.Classes = append(c.Classes, class)
186+
c.datas[class] = newClassData()
187+
return nil
149188
}
150189

151190
// getPriors returns the prior probabilities for the
@@ -190,7 +229,10 @@ func (c *Classifier) IsTfIdf() bool {
190229

191230
// WordCount returns the number of words counted for
192231
// each class in the lifetime of the classifier.
232+
// This method is safe for concurrent use.
193233
func (c *Classifier) WordCount() (result []int) {
234+
c.mu.RLock()
235+
defer c.mu.RUnlock()
194236
result = make([]int, len(c.Classes))
195237
for inx, class := range c.Classes {
196238
data := c.datas[class]
@@ -200,16 +242,22 @@ func (c *Classifier) WordCount() (result []int) {
200242
}
201243

202244
// Observe should be used when word-frequencies have been already been learned
203-
// externally (e.g., hadoop)
245+
// externally (e.g., hadoop).
246+
// This method is safe for concurrent use.
204247
func (c *Classifier) Observe(word string, count int, which Class) {
248+
c.mu.Lock()
249+
defer c.mu.Unlock()
205250
data := c.datas[which]
206251
data.Freqs[word] += float64(count)
207252
data.Total += count
208253
}
209254

210255
// Learn will accept new training documents for
211256
// supervised learning.
257+
// This method is safe for concurrent use.
212258
func (c *Classifier) Learn(document []string, which Class) {
259+
c.mu.Lock()
260+
defer c.mu.Unlock()
213261

214262
// If we are a tfidf classifier we first need to get terms as
215263
// terms frequency and store that to work out the idf part later
@@ -246,19 +294,20 @@ func (c *Classifier) Learn(document []string, which Class) {
246294
// ConvertTermsFreqToTfIdf uses all the TF samples for the class and converts
247295
// them to TF-IDF https://en.wikipedia.org/wiki/Tf%E2%80%93idf
248296
// once we have finished learning all the classes and have the totals.
297+
// This method is safe for concurrent use.
249298
func (c *Classifier) ConvertTermsFreqToTfIdf() {
299+
c.mu.Lock()
300+
defer c.mu.Unlock()
250301

251302
if c.DidConvertTfIdf {
252303
panic("Cannot call ConvertTermsFreqToTfIdf more than once. Reset and relearn to reconvert.")
253304
}
254305

255306
for className := range c.datas {
256-
257307
for wIndex := range c.datas[className].FreqTfs {
258308
tfIdfAdder := float64(0)
259309

260310
for tfSampleIndex := range c.datas[className].FreqTfs[wIndex] {
261-
262311
// we always want a positive TF-IDF score.
263312
tf := c.datas[className].FreqTfs[wIndex][tfSampleIndex]
264313
c.datas[className].FreqTfs[wIndex][tfSampleIndex] = math.Log1p(tf) * math.Log1p(float64(c.learned)/float64(c.datas[className].Total))
@@ -267,12 +316,9 @@ func (c *Classifier) ConvertTermsFreqToTfIdf() {
267316
// convert the 'counts' to TF-IDF's
268317
c.datas[className].Freqs[wIndex] = tfIdfAdder
269318
}
270-
271319
}
272320

273-
// sanity check
274321
c.DidConvertTfIdf = true
275-
276322
}
277323

278324
// LogScores produces "log-likelihood"-like scores that can
@@ -294,20 +340,22 @@ func (c *Classifier) ConvertTermsFreqToTfIdf() {
294340
//
295341
// Unlike c.Probabilities(), this function is not prone to
296342
// floating point underflow and is relatively safe to use.
343+
// This method is safe for concurrent use.
297344
func (c *Classifier) LogScores(document []string) (scores []float64, inx int, strict bool) {
345+
c.mu.RLock()
346+
defer c.mu.RUnlock()
347+
298348
if c.tfIdf && !c.DidConvertTfIdf {
299349
panic("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling LogScores.")
300350
}
301351

302352
n := len(c.Classes)
303-
scores = make([]float64, n, n)
353+
scores = make([]float64, n)
304354
priors := c.getPriors()
305355

306356
// calculate the score for each class
307357
for index, class := range c.Classes {
308358
data := c.datas[class]
309-
// c is the sum of the logarithms
310-
// as outlined in the refresher
311359
score := math.Log(priors[index])
312360
for _, word := range document {
313361
score += math.Log(data.getWordProb(word))
@@ -342,7 +390,11 @@ func (c *Classifier) Classify(document []string) (class Class, scores []float64,
342390
//
343391
// If all scores underflow to zero, returns equal probabilities
344392
// for all classes (1/n each).
393+
// This method is safe for concurrent use.
345394
func (c *Classifier) ProbScores(doc []string) (scores []float64, inx int, strict bool) {
395+
c.mu.RLock()
396+
defer c.mu.RUnlock()
397+
346398
if c.tfIdf && !c.DidConvertTfIdf {
347399
panic("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling ProbScores.")
348400
}
@@ -402,7 +454,11 @@ func (c *Classifier) ClassifyProb(document []string) (class Class, scores []floa
402454
//
403455
// When underflow is detected, the returned scores are computed from
404456
// log-domain scores using the log-sum-exp trick for numerical stability.
457+
// This method is safe for concurrent use.
405458
func (c *Classifier) SafeProbScores(doc []string) (scores []float64, inx int, strict bool, err error) {
459+
c.mu.RLock()
460+
defer c.mu.RUnlock()
461+
406462
if c.tfIdf && !c.DidConvertTfIdf {
407463
panic("Using a TF-IDF classifier. Please call ConvertTermsFreqToTfIdf before calling SafeProbScores.")
408464
}
@@ -500,11 +556,14 @@ func (c *Classifier) ClassifySafe(document []string) (class Class, scores []floa
500556
// exist in the classifier for each class state for the given input
501557
// words. In other words, if you obtain the frequencies
502558
//
503-
// freqs := c.WordFrequencies(/* [j]string */)
559+
// freqs := c.WordFrequencies(/* [j]string */)
504560
//
505561
// then the expression freq[i][j] represents the frequency of the j-th
506562
// word within the i-th class.
563+
// This method is safe for concurrent use.
507564
func (c *Classifier) WordFrequencies(words []string) (freqMatrix [][]float64) {
565+
c.mu.RLock()
566+
defer c.mu.RUnlock()
508567
n, l := len(c.Classes), len(words)
509568
freqMatrix = make([][]float64, n)
510569
for i := range freqMatrix {
@@ -520,38 +579,52 @@ func (c *Classifier) WordFrequencies(words []string) (freqMatrix [][]float64) {
520579

521580
// WordsByClass returns a map of words and their probability of
522581
// appearing in the given class.
582+
// This method is safe for concurrent use.
523583
func (c *Classifier) WordsByClass(class Class) (freqMap map[string]float64) {
584+
c.mu.RLock()
585+
defer c.mu.RUnlock()
524586
freqMap = make(map[string]float64)
525587
for word, cnt := range c.datas[class].Freqs {
526588
freqMap[word] = float64(cnt) / float64(c.datas[class].Total)
527589
}
528-
529590
return freqMap
530591
}
531592

532593

533594
// WriteToFile serializes this classifier to a file.
595+
// This method is safe for concurrent use.
534596
func (c *Classifier) WriteToFile(name string) error {
535597
file, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
536598
if err != nil {
537599
return err
538600
}
539601
defer file.Close()
540-
return c.WriteGob(file)
602+
return c.writeGobLocked(file)
541603
}
542604

543605
// WriteClassesToFile writes all classes to files.
606+
// This method is safe for concurrent use.
544607
func (c *Classifier) WriteClassesToFile(rootPath string) error {
608+
c.mu.RLock()
609+
defer c.mu.RUnlock()
545610
for name := range c.datas {
546-
if err := c.WriteClassToFile(name, rootPath); err != nil {
611+
if err := c.writeClassToFileLocked(name, rootPath); err != nil {
547612
return err
548613
}
549614
}
550615
return nil
551616
}
552617

553618
// WriteClassToFile writes a single class to file.
619+
// This method is safe for concurrent use.
554620
func (c *Classifier) WriteClassToFile(name Class, rootPath string) error {
621+
c.mu.RLock()
622+
defer c.mu.RUnlock()
623+
return c.writeClassToFileLocked(name, rootPath)
624+
}
625+
626+
// writeClassToFileLocked writes a single class to file (caller must hold lock).
627+
func (c *Classifier) writeClassToFileLocked(name Class, rootPath string) error {
555628
data := c.datas[name]
556629
fileName := filepath.Join(rootPath, string(name))
557630
file, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
@@ -564,20 +637,27 @@ func (c *Classifier) WriteClassToFile(name Class, rootPath string) error {
564637

565638

566639
// WriteGob serializes this classifier to GOB and writes to Writer.
640+
// This method is safe for concurrent use.
567641
func (c *Classifier) WriteGob(w io.Writer) (err error) {
642+
c.mu.RLock()
643+
defer c.mu.RUnlock()
644+
return c.writeGobLocked(w)
645+
}
646+
647+
// writeGobLocked serializes this classifier to GOB (caller must hold lock).
648+
func (c *Classifier) writeGobLocked(w io.Writer) (err error) {
568649
enc := gob.NewEncoder(w)
569650
err = enc.Encode(&serializableClassifier{c.Classes, c.learned, int(c.seen), c.datas, c.tfIdf, c.DidConvertTfIdf})
570-
571651
return
572652
}
573653

574654

575655
// ReadClassFromFile loads existing class data from a
576656
// file.
657+
// This method is safe for concurrent use.
577658
func (c *Classifier) ReadClassFromFile(class Class, location string) (err error) {
578659
fileName := filepath.Join(location, string(class))
579660
file, err := os.Open(fileName)
580-
581661
if err != nil {
582662
return err
583663
}
@@ -586,7 +666,12 @@ func (c *Classifier) ReadClassFromFile(class Class, location string) (err error)
586666
dec := gob.NewDecoder(file)
587667
w := new(classData)
588668
err = dec.Decode(w)
669+
if err != nil {
670+
return err
671+
}
589672

673+
c.mu.Lock()
674+
defer c.mu.Unlock()
590675
c.learned++
591676
c.datas[class] = w
592677
return

0 commit comments

Comments
 (0)