]> granicus.if.org Git - liblinear/blob - train.c
L1-regularized L2-loss SVM and Logistic Regression
[liblinear] / train.c
1 #include <stdio.h>
2 #include <math.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <ctype.h>
6 #include <errno.h>
7 #include "linear.h"
8 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
9 #define INF HUGE_VAL
10
void print_null(const char *s)
{
	/* Output sink installed as liblinear_print_string by -q: the message is dropped. */
	(void) s;
}
12
void exit_with_help()
{
	/*
	 * Usage text for the train tool. It contains no '%' conversion, so it is
	 * written verbatim with fputs rather than interpreted by printf.
	 */
	static const char usage_text[] =
		"Usage: train [options] training_set_file [model_file]\n"
		"options:\n"
		"-s type : set type of solver (default 1)\n"
		"       0 -- L2-regularized logistic regression\n"
		"       1 -- L2-regularized L2-loss support vector classification (dual)\n"
		"       2 -- L2-regularized L2-loss support vector classification (primal)\n"
		"       3 -- L2-regularized L1-loss support vector classification (dual)\n"
		"       4 -- multi-class support vector classification by Crammer and Singer\n"
		"       5 -- L1-regularized L2-loss support vector classification\n"
		"       6 -- L1-regularized logistic regression\n"
		"-c cost : set the parameter C (default 1)\n"
		"-e epsilon : set tolerance of termination criterion\n"
		"       -s 0 and 2\n"
		"               |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
		"               where f is the primal function and pos/neg are # of\n"
		"               positive/negative data (default 0.01)\n"
		"       -s 1, 3, and 4\n"
		"               Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
		"       -s 5 and 6\n"
		"               |f'(w)|_inf <= eps*min(pos,neg)/l*|f'(w0)|_inf,\n"
		"               where f is the primal function (default 0.01)\n"
		"-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default 1)\n"
		"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
		"-v n: n-fold cross validation mode\n"
		"-q : quiet mode (no outputs)\n";

	fputs(usage_text, stdout);
	exit(1);
}
44
void exit_input_error(int line_num)
{
	/* Report which input line is malformed (callers pass a 1-based number) and abort. */
	(void) fprintf(stderr, "Wrong input format at line %d\n", line_num);
	exit(1);
}
50
/* Shared line buffer filled by readline(); read_problem() allocates it, main() frees it. */
static char *line = NULL;
/* Current capacity of `line` in bytes; readline() doubles it when a line does not fit. */
static int max_line_len;

/*
 * Read one complete line from `input` into the global buffer `line`.
 *
 * The buffer is doubled (max_line_len *= 2) until the line's '\n' has been read
 * or EOF is reached, so the last line of a file may be returned without '\n'.
 *
 * Precondition: `line` points to a malloc'd buffer of max_line_len bytes.
 * Returns `line`, or NULL when no characters could be read (EOF).
 * Exits the program if the buffer cannot be grown.
 */
static char* readline(FILE *input)
{
	int len;

	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		char *grown;

		max_line_len *= 2;
		/* Use a temporary so the old buffer is not lost if realloc fails. */
		grown = (char *) realloc(line,max_line_len);
		if(grown == NULL)
		{
			fprintf(stderr,"out of memory while reading input line\n");
			exit(1);
		}
		line = grown;

		/* Continue reading right after the part of the line already buffered. */
		len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
71
72 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
73 void read_problem(const char *filename);
74 void do_cross_validation();
75
76 struct feature_node *x_space;
77 struct parameter param;
78 struct problem prob;
79 struct model* model_;
80 int flag_cross_validation;
81 int nr_fold;
82 double bias;
83
84 int main(int argc, char **argv)
85 {
86         char input_file_name[1024];
87         char model_file_name[1024];
88         const char *error_msg;
89
90         parse_command_line(argc, argv, input_file_name, model_file_name);
91         read_problem(input_file_name);
92         error_msg = check_parameter(&prob,&param);
93
94         if(error_msg)
95         {
96                 fprintf(stderr,"Error: %s\n",error_msg);
97                 exit(1);
98         }
99
100         if(flag_cross_validation)
101         {
102                 do_cross_validation();
103         }
104         else
105         {
106                 model_=train(&prob, &param);
107                 save_model(model_file_name, model_);
108                 destroy_model(model_);
109         }
110         destroy_param(&param);
111         free(prob.y);
112         free(prob.x);
113         free(x_space);
114         free(line);
115
116         return 0;
117 }
118
119 void do_cross_validation()
120 {
121         int i;
122         int total_correct = 0;
123         int *target = Malloc(int, prob.l);
124
125         cross_validation(&prob,&param,nr_fold,target);
126
127         for(i=0;i<prob.l;i++)
128                 if(target[i] == prob.y[i])
129                         ++total_correct;
130         printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
131
132         free(target);
133 }
134
135 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
136 {
137         int i;
138
139         // default values
140         param.solver_type = L2R_L2LOSS_SVC_DUAL;
141         param.C = 1;
142         param.eps = INF; // see setting below
143         param.nr_weight = 0;
144         param.weight_label = NULL;
145         param.weight = NULL;
146         flag_cross_validation = 0;
147         bias = 1;
148
149         // parse options
150         for(i=1;i<argc;i++)
151         {
152                 if(argv[i][0] != '-') break;
153                 if(++i>=argc)
154                         exit_with_help();
155                 switch(argv[i-1][1])
156                 {
157                         case 's':
158                                 param.solver_type = atoi(argv[i]);
159                                 break;
160
161                         case 'c':
162                                 param.C = atof(argv[i]);
163                                 break;
164
165                         case 'e':
166                                 param.eps = atof(argv[i]);
167                                 break;
168
169                         case 'B':
170                                 bias = atof(argv[i]);
171                                 break;
172
173                         case 'w':
174                                 ++param.nr_weight;
175                                 param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight);
176                                 param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight);
177                                 param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
178                                 param.weight[param.nr_weight-1] = atof(argv[i]);
179                                 break;
180
181                         case 'v':
182                                 flag_cross_validation = 1;
183                                 nr_fold = atoi(argv[i]);
184                                 if(nr_fold < 2)
185                                 {
186                                         fprintf(stderr,"n-fold cross validation: n must >= 2\n");
187                                         exit_with_help();
188                                 }
189                                 break;
190
191                         case 'q':
192                                 liblinear_print_string = &print_null;
193                                 i--;
194                                 break;
195
196                         default:
197                                 fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]);
198                                 exit_with_help();
199                                 break;
200                 }
201         }
202
203         // determine filenames
204         if(i>=argc)
205                 exit_with_help();
206
207         strcpy(input_file_name, argv[i]);
208
209         if(i<argc-1)
210                 strcpy(model_file_name,argv[i+1]);
211         else
212         {
213                 char *p = strrchr(argv[i],'/');
214                 if(p==NULL)
215                         p = argv[i];
216                 else
217                         ++p;
218                 sprintf(model_file_name,"%s.model",p);
219         }
220
221         if(param.eps == INF)
222         {
223                 if(param.solver_type == L2R_LR || param.solver_type == L2R_L2LOSS_SVC)
224                         param.eps = 0.01;
225                 else if(param.solver_type == L2R_L2LOSS_SVC_DUAL || param.solver_type == L2R_L1LOSS_SVC_DUAL || param.solver_type == MCSVM_CS)
226                         param.eps = 0.1;
227                 else if(param.solver_type == L1R_L2LOSS_SVC || param.solver_type == L1R_LR)
228                         param.eps = 0.01;
229         }
230 }
231
232 // read in a problem (in libsvm format)
233 void read_problem(const char *filename)
234 {
235         int max_index, inst_max_index, i;
236         long int elements, j;
237         FILE *fp = fopen(filename,"r");
238         char *endptr;
239         char *idx, *val, *label;
240
241         if(fp == NULL)
242         {
243                 fprintf(stderr,"can't open input file %s\n",filename);
244                 exit(1);
245         }
246
247         prob.l = 0;
248         elements = 0;
249         max_line_len = 1024;
250         line = Malloc(char,max_line_len);
251         while(readline(fp)!=NULL)
252         {
253                 char *p = strtok(line," \t"); // label
254
255                 // features
256                 while(1)
257                 {
258                         p = strtok(NULL," \t");
259                         if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
260                                 break;
261                         elements++;
262                 }
263                 elements++;
264                 prob.l++;
265         }
266         rewind(fp);
267
268         prob.bias=bias;
269
270         prob.y = Malloc(int,prob.l);
271         prob.x = Malloc(struct feature_node *,prob.l);
272         x_space = Malloc(struct feature_node,elements+prob.l);
273
274         max_index = 0;
275         j=0;
276         for(i=0;i<prob.l;i++)
277         {
278                 inst_max_index = 0; // strtol gives 0 if wrong format
279                 readline(fp);
280                 prob.x[i] = &x_space[j];
281                 label = strtok(line," \t");
282                 prob.y[i] = (int) strtol(label,&endptr,10);
283                 if(endptr == label)
284                         exit_input_error(i+1);
285
286                 while(1)
287                 {
288                         idx = strtok(NULL,":");
289                         val = strtok(NULL," \t");
290
291                         if(val == NULL)
292                                 break;
293
294                         errno = 0;
295                         x_space[j].index = (int) strtol(idx,&endptr,10);
296                         if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
297                                 exit_input_error(i+1);
298                         else
299                                 inst_max_index = x_space[j].index;
300
301                         errno = 0;
302                         x_space[j].value = strtod(val,&endptr);
303                         if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
304                                 exit_input_error(i+1);
305
306                         ++j;
307                 }
308
309                 if(inst_max_index > max_index)
310                         max_index = inst_max_index;
311
312                 if(prob.bias >= 0)
313                         x_space[j++].value = prob.bias;
314
315                 x_space[j++].index = -1;
316         }
317
318         if(prob.bias >= 0)
319         {
320                 prob.n=max_index+1;
321                 for(i=1;i<prob.l;i++)
322                         (prob.x[i]-2)->index = prob.n; 
323                 x_space[j-2].index = prob.n;
324         }
325         else
326                 prob.n=max_index;
327
328         fclose(fp);
329 }