-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcosmo_classify_naive_bayes.m
More file actions
64 lines (49 loc) · 1.87 KB
/
cosmo_classify_naive_bayes.m
File metadata and controls
64 lines (49 loc) · 1.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
function predicted=cosmo_classify_naive_bayes(samples_train, targets_train, samples_test, opt)
% Gaussian naive bayes classifier
%
% predicted=cosmo_classify_naive_bayes(samples_train, targets_train, samples_test[, opt])
%
% Inputs
%   - samples_train   PxR training data for P samples and R features
%   - targets_train   Px1 training data classes
%   - samples_test    QxR test data
%   - opt             (currently ignored)
%
% Output
%   - predicted       Qx1 predicted data classes for samples_test
%
% Notes
%   - each class is modelled as a Gaussian per feature (independence
%     assumption); classification maximizes the sum of per-feature
%     log-densities plus the log class prior.
%
% NNO Aug 2013
    if nargin<4, opt=struct(); end

    [ntrain, nfeatures]=size(samples_train);
    [ntest, nfeatures_]=size(samples_test);
    ntrain_=numel(targets_train);

    if nfeatures~=nfeatures_ || ntrain_~=ntrain
        error('illegal input size');
    end

    classes=unique(targets_train);
    nclasses=numel(classes);

    % allocate space for statistics of each class
    mus=zeros(nclasses,nfeatures);
    stds=zeros(nclasses,nfeatures);
    log_class_probs=zeros(nclasses,1);

    % compute means, standard deviations and log priors of each class
    for k=1:nclasses
        msk=targets_train==classes(k);
        n=sum(msk);              % number of samples in this class
        d=samples_train(msk,:);  % samples in this class
        mu=mean(d,1);
        mus(k,:)=mu;
        % unbiased standard deviation - faster implementation than 'std'
        stds(k,:)=sqrt(sum(bsxfun(@minus,d,mu).^2,1)/(n-1));
        log_class_probs(k)=log(n/ntrain); % log of class prior probability
    end

    % loop-invariant terms of the Gaussian log-density
    half_log_2pi=0.5*log(2*pi);
    log_stds=log(stds);

    predicted=zeros(ntest,1);
    for k=1:ntest
        sample=samples_test(k,:);
        % BUG FIX: the original used normcdf here, i.e. the *cumulative*
        % distribution; the naive-Bayes likelihood requires the
        % probability *density*. Compute the Gaussian log-density
        % directly (this also avoids the Statistics Toolbox dependency):
        %   log N(x; mu, s) = -0.5*log(2*pi) - log(s) - 0.5*((x-mu)/s)^2
        z=bsxfun(@minus,sample,mus)./stds;
        log_ps=-half_log_2pi-log_stds-0.5*z.^2;
        % being 'naive' we assume independence - so sum the
        % log-densities over features and add the log prior
        test_log_prob=sum(log_ps,2)+log_class_probs;
        % pick the class with the highest posterior log-probability
        [unused,mx_idx]=max(test_log_prob); %#ok<ASGLU>
        predicted(k)=classes(mx_idx);
    end