granicus.if.org Git - liblinear/blob - train.c
Update Windows binaries for 2.46 release
[liblinear] / train.c
#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include "linear.h"
/* Allocate an uninitialized array of n objects of the given type.
 * The whole expansion is parenthesized so that postfix operators applied to
 * the result (e.g. Malloc(double, n)[0]) act on the typed pointer. */
#define Malloc(type,n) ((type *) malloc((n)*sizeof(type)))
/* "Not set" sentinel; e.g. param.eps starts at INF until a default is chosen. */
#define INF HUGE_VAL
10
/* Output sink installed by -q: discards every message it is given. */
void print_null(const char *s)
{
	(void) s; /* intentionally unused */
}
12
/* Print the command-line usage summary to stdout and terminate with status 1. */
void exit_with_help(void)
{
	printf(
	"Usage: train [options] training_set_file [model_file]\n"
	"options:\n"
	"-s type : set type of solver (default 1)\n"
	"  for multi-class classification\n"
	"        0 -- L2-regularized logistic regression (primal)\n"
	"        1 -- L2-regularized L2-loss support vector classification (dual)\n"
	"        2 -- L2-regularized L2-loss support vector classification (primal)\n"
	"        3 -- L2-regularized L1-loss support vector classification (dual)\n"
	"        4 -- support vector classification by Crammer and Singer\n"
	"        5 -- L1-regularized L2-loss support vector classification\n"
	"        6 -- L1-regularized logistic regression\n"
	"        7 -- L2-regularized logistic regression (dual)\n"
	"  for regression\n"
	"       11 -- L2-regularized L2-loss support vector regression (primal)\n"
	"       12 -- L2-regularized L2-loss support vector regression (dual)\n"
	"       13 -- L2-regularized L1-loss support vector regression (dual)\n"
	"  for outlier detection\n"
	"       21 -- one-class support vector machine (dual)\n"
	"-c cost : set the parameter C (default 1)\n"
	"-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n"
	"-n nu : set the parameter nu of one-class SVM (default 0.5)\n"
	"-e epsilon : set tolerance of termination criterion\n"
	"       -s 0 and 2\n"
	"               |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
	"               where f is the primal function and pos/neg are # of\n"
	"               positive/negative data (default 0.01)\n"
	"       -s 11\n"
	"               |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)\n"
	"       -s 1, 3, 4, 7, and 21\n"
	"               Dual maximal violation <= eps; similar to libsvm (default 0.1 except 0.01 for -s 21)\n"
	"       -s 5 and 6\n"
	"               |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n"
	"               where f is the primal function (default 0.01)\n"
	"       -s 12 and 13\n"
	"               |f'(alpha)|_1 <= eps*|f'(alpha0)|_1,\n"
	"               where f is the dual function (default 0.1)\n"
	"-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)\n"
	"-R : not regularize the bias; must with -B 1 to have the bias; DON'T use this unless you know what it is\n"
	"       (for -s 0, 2, 5, 6, 11)\n"
	"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
	"-v n: n-fold cross validation mode\n"
	"-C : find parameters (C for -s 0, 2 and C, p for -s 11)\n"
	"-q : quiet mode (no outputs)\n"
	);
	exit(1);
}
62
/*
 * Report a malformed training-file line on stderr and terminate with status 1.
 * line_num: 1-based line number to show to the user (callers pass i+1).
 */
void exit_input_error(int line_num)
{
	fprintf(stderr, "Wrong input format at line %d\n", line_num);
	exit(1);
}
68
// Growable buffer holding the most recent line returned by readline().
// Allocated by read_problem() and freed at the end of main().
static char *line = NULL;
// Current capacity of `line`, in bytes.
static int max_line_len;

// Read one complete line (including its '\n', if present) from input into the
// global buffer `line`, doubling the buffer until the whole line fits.
// Returns `line`, or NULL if no characters could be read (end of file or error).
// A last line without a trailing '\n' is returned as-is.
// Requires `line` to point to a malloc'ed buffer of max_line_len bytes.
// Exits with status 1 if the buffer cannot be grown any further.
static char* readline(FILE *input)
{
	int len;

	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		char *grown;

		// doubling past INT_MAX would be signed overflow (undefined behavior)
		if(max_line_len > INT_MAX / 2)
		{
			fprintf(stderr,"input line too long\n");
			exit(1);
		}
		max_line_len *= 2;

		// keep the old pointer until realloc is known to have succeeded
		grown = (char *) realloc(line,max_line_len);
		if(grown == NULL)
		{
			fprintf(stderr,"can't allocate memory for input line\n");
			exit(1);
		}
		line = grown;

		// append the rest of the physical line after what was already read
		len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
89
90 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
91 void read_problem(const char *filename);
92 void do_cross_validation();
93 void do_find_parameters();
94
95 struct feature_node *x_space;
96 struct parameter param;
97 struct problem prob;
98 struct model* model_;
99 int flag_cross_validation;
100 int flag_find_parameters;
101 int flag_C_specified;
102 int flag_p_specified;
103 int flag_solver_specified;
104 int nr_fold;
105 double bias;
106
107 int main(int argc, char **argv)
108 {
109         char input_file_name[1024];
110         char model_file_name[1024];
111         const char *error_msg;
112
113         parse_command_line(argc, argv, input_file_name, model_file_name);
114         read_problem(input_file_name);
115         error_msg = check_parameter(&prob,&param);
116
117         if(error_msg)
118         {
119                 fprintf(stderr,"ERROR: %s\n",error_msg);
120                 exit(1);
121         }
122
123         if (flag_find_parameters)
124         {
125                 do_find_parameters();
126         }
127         else if(flag_cross_validation)
128         {
129                 do_cross_validation();
130         }
131         else
132         {
133                 model_=train(&prob, &param);
134                 if(save_model(model_file_name, model_))
135                 {
136                         fprintf(stderr,"can't save model to file %s\n",model_file_name);
137                         exit(1);
138                 }
139                 free_and_destroy_model(&model_);
140         }
141         destroy_param(&param);
142         free(prob.y);
143         free(prob.x);
144         free(x_space);
145         free(line);
146
147         return 0;
148 }
149
150 void do_find_parameters()
151 {
152         double start_C, start_p, best_C, best_p, best_score;
153         if (flag_C_specified)
154                 start_C = param.C;
155         else
156                 start_C = -1.0;
157         if (flag_p_specified)
158                 start_p = param.p;
159         else
160                 start_p = -1.0;
161
162         printf("Doing parameter search with %d-fold cross validation.\n", nr_fold);
163         find_parameters(&prob, &param, nr_fold, start_C, start_p, &best_C, &best_p, &best_score);
164         if(param.solver_type == L2R_LR || param.solver_type == L2R_L2LOSS_SVC)
165                 printf("Best C = %g  CV accuracy = %g%%\n", best_C, 100.0*best_score);
166         else if(param.solver_type == L2R_L2LOSS_SVR)
167                 printf("Best C = %g Best p = %g  CV MSE = %g\n", best_C, best_p, best_score);
168 }
169
170 void do_cross_validation()
171 {
172         int i;
173         int total_correct = 0;
174         double total_error = 0;
175         double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
176         double *target = Malloc(double, prob.l);
177
178         cross_validation(&prob,&param,nr_fold,target);
179         if(param.solver_type == L2R_L2LOSS_SVR ||
180            param.solver_type == L2R_L1LOSS_SVR_DUAL ||
181            param.solver_type == L2R_L2LOSS_SVR_DUAL)
182         {
183                 for(i=0;i<prob.l;i++)
184                 {
185                         double y = prob.y[i];
186                         double v = target[i];
187                         total_error += (v-y)*(v-y);
188                         sumv += v;
189                         sumy += y;
190                         sumvv += v*v;
191                         sumyy += y*y;
192                         sumvy += v*y;
193                 }
194                 printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
195                 printf("Cross Validation Squared correlation coefficient = %g\n",
196                                 ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
197                                 ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
198                           );
199         }
200         else
201         {
202                 for(i=0;i<prob.l;i++)
203                         if(target[i] == prob.y[i])
204                                 ++total_correct;
205                 printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
206         }
207
208         free(target);
209 }
210
211 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
212 {
213         int i;
214         void (*print_func)(const char*) = NULL;  // default printing to stdout
215
216         // default values
217         param.solver_type = L2R_L2LOSS_SVC_DUAL;
218         param.C = 1;
219         param.p = 0.1;
220         param.nu = 0.5;
221         param.eps = INF; // see setting below
222         param.nr_weight = 0;
223         param.regularize_bias = 1;
224         param.weight_label = NULL;
225         param.weight = NULL;
226         param.init_sol = NULL;
227         flag_cross_validation = 0;
228         flag_C_specified = 0;
229         flag_p_specified = 0;
230         flag_solver_specified = 0;
231         flag_find_parameters = 0;
232         bias = -1;
233
234         // parse options
235         for(i=1;i<argc;i++)
236         {
237                 if(argv[i][0] != '-') break;
238                 if(++i>=argc)
239                         exit_with_help();
240                 switch(argv[i-1][1])
241                 {
242                         case 's':
243                                 param.solver_type = atoi(argv[i]);
244                                 flag_solver_specified = 1;
245                                 break;
246
247                         case 'c':
248                                 param.C = atof(argv[i]);
249                                 flag_C_specified = 1;
250                                 break;
251
252                         case 'p':
253                                 flag_p_specified = 1;
254                                 param.p = atof(argv[i]);
255                                 break;
256
257                         case 'n':
258                                 param.nu = atof(argv[i]);
259                                 break;
260
261                         case 'e':
262                                 param.eps = atof(argv[i]);
263                                 break;
264
265                         case 'B':
266                                 bias = atof(argv[i]);
267                                 break;
268
269                         case 'w':
270                                 ++param.nr_weight;
271                                 param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight);
272                                 param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight);
273                                 param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
274                                 param.weight[param.nr_weight-1] = atof(argv[i]);
275                                 break;
276
277                         case 'v':
278                                 flag_cross_validation = 1;
279                                 nr_fold = atoi(argv[i]);
280                                 if(nr_fold < 2)
281                                 {
282                                         fprintf(stderr,"n-fold cross validation: n must >= 2\n");
283                                         exit_with_help();
284                                 }
285                                 break;
286
287                         case 'q':
288                                 print_func = &print_null;
289                                 i--;
290                                 break;
291
292                         case 'C':
293                                 flag_find_parameters = 1;
294                                 i--;
295                                 break;
296
297                         case 'R':
298                                 param.regularize_bias = 0;
299                                 i--;
300                                 break;
301
302                         default:
303                                 fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]);
304                                 exit_with_help();
305                                 break;
306                 }
307         }
308
309         set_print_string_function(print_func);
310
311         // determine filenames
312         if(i>=argc)
313                 exit_with_help();
314
315         strcpy(input_file_name, argv[i]);
316
317         if(i<argc-1)
318                 strcpy(model_file_name,argv[i+1]);
319         else
320         {
321                 char *p = strrchr(argv[i],'/');
322                 if(p==NULL)
323                         p = argv[i];
324                 else
325                         ++p;
326                 sprintf(model_file_name,"%s.model",p);
327         }
328
329         // default solver for parameter selection is L2R_L2LOSS_SVC
330         if(flag_find_parameters)
331         {
332                 if(!flag_cross_validation)
333                         nr_fold = 5;
334                 if(!flag_solver_specified)
335                 {
336                         fprintf(stderr, "Solver not specified. Using -s 2\n");
337                         param.solver_type = L2R_L2LOSS_SVC;
338                 }
339                 else if(param.solver_type != L2R_LR && param.solver_type != L2R_L2LOSS_SVC && param.solver_type != L2R_L2LOSS_SVR)
340                 {
341                         fprintf(stderr, "Warm-start parameter search only available for -s 0, -s 2 and -s 11\n");
342                         exit_with_help();
343                 }
344         }
345
346         if(param.eps == INF)
347         {
348                 switch(param.solver_type)
349                 {
350                         case L2R_LR:
351                         case L2R_L2LOSS_SVC:
352                                 param.eps = 0.01;
353                                 break;
354                         case L2R_L2LOSS_SVR:
355                                 param.eps = 0.0001;
356                                 break;
357                         case L2R_L2LOSS_SVC_DUAL:
358                         case L2R_L1LOSS_SVC_DUAL:
359                         case MCSVM_CS:
360                         case L2R_LR_DUAL:
361                                 param.eps = 0.1;
362                                 break;
363                         case L1R_L2LOSS_SVC:
364                         case L1R_LR:
365                                 param.eps = 0.01;
366                                 break;
367                         case L2R_L1LOSS_SVR_DUAL:
368                         case L2R_L2LOSS_SVR_DUAL:
369                                 param.eps = 0.1;
370                                 break;
371                         case ONECLASS_SVM:
372                                 param.eps = 0.01;
373                                 break;
374                 }
375         }
376 }
377
378 // read in a problem (in libsvm format)
379 void read_problem(const char *filename)
380 {
381         int max_index, inst_max_index, i;
382         size_t elements, j;
383         FILE *fp = fopen(filename,"r");
384         char *endptr;
385         char *idx, *val, *label;
386
387         if(fp == NULL)
388         {
389                 fprintf(stderr,"can't open input file %s\n",filename);
390                 exit(1);
391         }
392
393         prob.l = 0;
394         elements = 0;
395         max_line_len = 1024;
396         line = Malloc(char,max_line_len);
397         while(readline(fp)!=NULL)
398         {
399                 char *p = strtok(line," \t"); // label
400
401                 // features
402                 while(1)
403                 {
404                         p = strtok(NULL," \t");
405                         if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
406                                 break;
407                         elements++;
408                 }
409                 elements++; // for bias term
410                 prob.l++;
411         }
412         rewind(fp);
413
414         prob.bias=bias;
415
416         prob.y = Malloc(double,prob.l);
417         prob.x = Malloc(struct feature_node *,prob.l);
418         x_space = Malloc(struct feature_node,elements+prob.l);
419
420         max_index = 0;
421         j=0;
422         for(i=0;i<prob.l;i++)
423         {
424                 inst_max_index = 0; // strtol gives 0 if wrong format
425                 readline(fp);
426                 prob.x[i] = &x_space[j];
427                 label = strtok(line," \t\n");
428                 if(label == NULL) // empty line
429                         exit_input_error(i+1);
430
431                 prob.y[i] = strtod(label,&endptr);
432                 if(endptr == label || *endptr != '\0')
433                         exit_input_error(i+1);
434
435                 while(1)
436                 {
437                         idx = strtok(NULL,":");
438                         val = strtok(NULL," \t");
439
440                         if(val == NULL)
441                                 break;
442
443                         errno = 0;
444                         x_space[j].index = (int) strtol(idx,&endptr,10);
445                         if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
446                                 exit_input_error(i+1);
447                         else
448                                 inst_max_index = x_space[j].index;
449
450                         errno = 0;
451                         x_space[j].value = strtod(val,&endptr);
452                         if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
453                                 exit_input_error(i+1);
454
455                         ++j;
456                 }
457
458                 if(inst_max_index > max_index)
459                         max_index = inst_max_index;
460
461                 if(prob.bias >= 0)
462                         x_space[j++].value = prob.bias;
463
464                 x_space[j++].index = -1;
465         }
466
467         if(prob.bias >= 0)
468         {
469                 prob.n=max_index+1;
470                 for(i=1;i<prob.l;i++)
471                         (prob.x[i]-2)->index = prob.n;
472                 x_space[j-2].index = prob.n;
473         }
474         else
475                 prob.n=max_index;
476
477         fclose(fp);
478 }