-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathEmTwentyNewsgroups.java
182 lines (145 loc) · 8.06 KB
/
EmTwentyNewsgroups.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.JointClassifier;
import com.aliasi.classify.JointClassification;
import com.aliasi.classify.JointClassifierEvaluator;
import com.aliasi.classify.TradNaiveBayesClassifier;
import com.aliasi.classify.ConfusionMatrix;
import com.aliasi.io.LogLevel;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.stats.Statistics;
import com.aliasi.tokenizer.EnglishStopTokenizerFactory;
import com.aliasi.tokenizer.IndoEuropeanTokenizerFactory;
import com.aliasi.tokenizer.LowerCaseTokenizerFactory;
import com.aliasi.tokenizer.RegExFilteredTokenizerFactory;
import com.aliasi.tokenizer.TokenizerFactory;
import com.aliasi.tokenizer.WhitespaceNormTokenizerFactory;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Factory;
import com.aliasi.util.Strings;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.*;
import java.util.Arrays;
import java.util.Random;
import java.util.regex.Pattern;
/**
 * Command-line driver that trains a traditional naive Bayes classifier with
 * EM over a labeled corpus plus its unlabeled view, then evaluates both the
 * initial classifier and the EM-trained classifier on the corpus test section.
 *
 * <p>Usage: {@code java EmTwentyNewsgroups <corpus-path>}
 */
public class EmTwentyNewsgroups {

    /** Seed for the random generator used to permute corpus instances. */
    static final long RANDOM_SEED = 45L;

    /** Number of train/evaluate trials run per supervised-items setting. */
    static final int NUM_REPLICATIONS = 1;

    /** Maximum number of epochs handed to EM training. */
    static final int MAX_EPOCHS = 50;

    /** Minimum improvement handed to EM training. */
    static final double MIN_IMPROVEMENT = 0.0001;

    static final double CATEGORY_PRIOR = 0.005; // balanced, doesn't matter
    static final double TOKEN_IN_CATEGORY_PRIOR = 0.001; // very sensitive to this
    static final double INITIAL_TOKEN_IN_CATEGORY_PRIOR = 0.1; // only used first run; want more uniform

    /** Document length normalization passed to every classifier built here. */
    static final double DOC_LENGTH_NORM = 9.0;

    /** Not referenced anywhere in this class; kept for any external users. */
    static final double COUNT_MULTIPLIER = 1.0;

    /** Minimum count handed to EM training. */
    static final double MIN_COUNT = 0.0001;

    /** Tokenizer factory shared by every classifier built here. */
    static final TokenizerFactory TOKENIZER_FACTORY = tokenizerFactory();

    /**
     * Topic identifiers for which {@link #eval} prints a per-topic score.
     * NOTE(review): these look like corpus-specific topic ids rather than
     * newsgroup names -- confirm they match the corpus being evaluated.
     */
    private static final String[] EVAL_CATEGORIES = {
        "004", "139", "013", "014", "016", "020", "149", "025", "029",
        "034", "040", "046", "047", "050", "052", "055", "062", "011",
        "070", "073", "079", "083", "089", "092", "104", "111"
    };

    /**
     * Runs the experiment: prints the configuration, trains with EM for each
     * supervised-items setting and trial, and prints accuracy summaries.
     *
     * @param args command-line arguments; {@code args[0]} is the corpus path
     * @throws IllegalArgumentException if no corpus path is supplied
     * @throws Exception if corpus loading, training or evaluation fails
     */
    public static void main(String[] args) throws Exception {
        // Fail with a usage message instead of a bare index-out-of-bounds.
        if (args.length < 1) {
            throw new IllegalArgumentException(
                "Usage: java EmTwentyNewsgroups <corpus-path>");
        }
        long startTime = System.currentTimeMillis();
        File corpusPath = new File(args[0]);

        System.out.println("CORPUS PATH=" + corpusPath);
        System.out.println("DOC LENGTH NORM=" + DOC_LENGTH_NORM);
        System.out.println("CATEGORY PRIOR=" + CATEGORY_PRIOR);
        System.out.println("TOKEN IN CATEGORY PRIOR=" + TOKEN_IN_CATEGORY_PRIOR);
        System.out.println("INITIAL TOKEN IN CATEGORY PRIOR=" + INITIAL_TOKEN_IN_CATEGORY_PRIOR);
        System.out.println("NUM REPS=" + NUM_REPLICATIONS);
        System.out.println("MAX EPOCHS=" + MAX_EPOCHS);
        System.out.println("RANDOM SEED=" + RANDOM_SEED);
        System.out.println();

        // Final because the anonymous classifier factory below captures it.
        final TwentyNewsgroupsCorpus corpus = new TwentyNewsgroupsCorpus(corpusPath);
        Corpus<ObjectHandler<CharSequence>> unlabeledCorpus = corpus.unlabeledCorpus();
        System.out.println(corpus);
        System.out.println();

        Reporter reporter = Reporters.stream(System.out, "ISO-8859-1").setLevel(LogLevel.DEBUG);
        try {
            Random random = new Random(RANDOM_SEED);
            // Primitive array: avoids boxing each value only to unbox it again.
            int[] supervisedItemCounts = { 0 };
            for (int numSupervisedItems : supervisedItemCounts) {
                System.out.println("SUPERVISED DOCS/CAT=" + numSupervisedItems);
                corpus.setMaxSupervisedInstancesPerCategory(numSupervisedItems);
                double[] accs = new double[NUM_REPLICATIONS];
                double[] accsEm = new double[NUM_REPLICATIONS];

                for (int trial = 0; trial < NUM_REPLICATIONS; ++trial) {
                    System.out.println("TRIAL=" + trial);
                    corpus.permuteInstances(random);

                    // Starting classifier uses the initial token prior.
                    TradNaiveBayesClassifier initialClassifier
                        = new TradNaiveBayesClassifier(corpus.categorySet(),
                                                       TOKENIZER_FACTORY,
                                                       CATEGORY_PRIOR,
                                                       INITIAL_TOKEN_IN_CATEGORY_PRIOR,
                                                       DOC_LENGTH_NORM);

                    // Classifiers created through the factory use the regular token prior.
                    Factory<TradNaiveBayesClassifier> classifierFactory
                        = new Factory<TradNaiveBayesClassifier>() {
                            public TradNaiveBayesClassifier create() {
                                return new TradNaiveBayesClassifier(corpus.categorySet(),
                                                                    TOKENIZER_FACTORY,
                                                                    CATEGORY_PRIOR,
                                                                    TOKEN_IN_CATEGORY_PRIOR,
                                                                    DOC_LENGTH_NORM);
                            }
                        };

                    TradNaiveBayesClassifier emClassifier
                        = TradNaiveBayesClassifier.emTrain(initialClassifier,
                                                           classifierFactory,
                                                           corpus,
                                                           unlabeledCorpus,
                                                           MIN_COUNT,
                                                           MAX_EPOCHS,
                                                           MIN_IMPROVEMENT,
                                                           reporter);

                    System.out.println("=====INITIAL CLASSIFIER=====");
                    accs[trial] = eval(initialClassifier, corpus);
                    System.out.println("=====EM CLASSIFIER=====");
                    accsEm[trial] = eval(emClassifier, corpus);
                    System.out.printf("ACC=%5.3f EM ACC=%5.3f\n\n",
                                      accs[trial], accsEm[trial]);
                }

                System.out.println(" ---------------------");
                System.out.printf("#Sup=%4d Supervised mean(acc)=%5.3f sd(acc)=%5.3f EM mean(acc)=%5.3f sd(acc)=%5.3f %10s\n\n",
                                  numSupervisedItems,
                                  Statistics.mean(accs),
                                  Statistics.standardDeviation(accs),
                                  Statistics.mean(accsEm),
                                  Statistics.standardDeviation(accsEm),
                                  Strings.msToString(System.currentTimeMillis() - startTime));
            }
        } finally {
            // Close the reporter even when training or evaluation throws.
            reporter.close();
        }
    }

    /**
     * Compiles the classifier, runs it over the test section of the corpus,
     * prints a per-topic score for each entry of {@code EVAL_CATEGORIES}, and
     * returns the total accuracy reported by the confusion matrix.
     *
     * <p>Each per-topic score is the confusion-matrix diagonal count for the
     * topic divided by {@code corpus.getTestCount(topic)}. Topics missing from
     * the confusion matrix are reported and skipped rather than crashing the
     * run; a topic whose test count is not positive gets a score of NaN.
     *
     * @param classifier classifier to evaluate
     * @param corpus corpus whose test section is visited
     * @return total accuracy over the visited test cases
     * @throws IOException if compiling the classifier fails on I/O
     * @throws ClassNotFoundException if compiling the classifier fails on class lookup
     */
    static double eval(TradNaiveBayesClassifier classifier,
                       Corpus<ObjectHandler<Classified<CharSequence>>> corpus)
        throws IOException, ClassNotFoundException {

        String[] categories = classifier.categorySet().toArray(new String[0]);
        Arrays.sort(categories);

        @SuppressWarnings("unchecked")
        JointClassifier<CharSequence> compiledClassifier
            = (JointClassifier<CharSequence>)
            AbstractExternalizable.compile(classifier);

        boolean storeInputs = false;
        JointClassifierEvaluator<CharSequence> evaluator
            = new JointClassifierEvaluator<CharSequence>(compiledClassifier,
                                                         categories,
                                                         storeInputs);
        corpus.visitTest(evaluator);

        // Per-topic breakdown taken from the confusion matrix diagonal.
        ConfusionMatrix confuse = evaluator.confusionMatrix();
        String[] cats = confuse.categories();
        for (String cat : EVAL_CATEGORIES) {
            int index = indexOf(cats, cat);
            if (index < 0) {
                // A topic the classifier does not know would otherwise be
                // passed to count(-1, -1) and abort the whole evaluation.
                System.out.printf("Evaluation score for topic [%s]: not a classifier category, skipped\n", cat);
                continue;
            }
            // NOTE(review): getTestCount is invoked on the Corpus-typed
            // parameter; confirm the Corpus class in use declares it.
            int realCount = corpus.getTestCount(cat);
            int emCount = confuse.count(index, index);
            // Undefined when the topic has no test documents; avoid Infinity.
            float score = realCount > 0 ? (float) emCount / realCount : Float.NaN;
            System.out.printf("Evaluation score for topic [%s]: %d %d %5.3f\n", cat, emCount, realCount, score);
        }
        return confuse.totalAccuracy();
    }

    /**
     * Returns the position of the first element of {@code values} equal to
     * {@code key}, or -1 if there is none.
     */
    private static int indexOf(String[] values, String key) {
        for (int i = 0; i < values.length; ++i) {
            if (key.equals(values[i])) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Builds the tokenizer chain used by the classifiers: Indo-European
     * tokenization, a regex filter on {@code \p{Alpha}+}, lower-casing, and
     * English stop-word filtering, applied in that order.
     *
     * @return the composed tokenizer factory
     */
    static TokenizerFactory tokenizerFactory() {
        TokenizerFactory factory = IndoEuropeanTokenizerFactory.INSTANCE;
        factory = new RegExFilteredTokenizerFactory(factory, Pattern.compile("\\p{Alpha}+"));
        factory = new LowerCaseTokenizerFactory(factory);
        factory = new EnglishStopTokenizerFactory(factory);
        return factory;
    }
}