]> granicus.if.org Git - liblinear/blob - train.c
Change version number to 221 for 2.21 release
[liblinear] / train.c
1 #include <stdio.h>
2 #include <math.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <ctype.h>
6 #include <errno.h>
7 #include "linear.h"
// Allocate an uninitialized array of n objects of the given type and cast the
// result to type*. The result may be NULL; callers must check it.
// The whole expansion is parenthesized so that it stays a single operand when
// combined with sizeof, indexing, or other operators.
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))
// Positive infinity; parse_command_line() uses it as the "eps not given" marker.
#define INF HUGE_VAL
10
// Output callback that discards its message; parse_command_line() installs it
// when the -q option is given.
void print_null(const char *s)
{
	(void) s; // intentionally unused
}
12
// Print the command-line usage summary to stdout and terminate with status 1.
// Never returns.
void exit_with_help()
{
	static const char usage_text[] =
	"Usage: train [options] training_set_file [model_file]\n"
	"options:\n"
	"-s type : set type of solver (default 1)\n"
	"  for multi-class classification\n"
	"        0 -- L2-regularized logistic regression (primal)\n"
	"        1 -- L2-regularized L2-loss support vector classification (dual)\n"
	"        2 -- L2-regularized L2-loss support vector classification (primal)\n"
	"        3 -- L2-regularized L1-loss support vector classification (dual)\n"
	"        4 -- support vector classification by Crammer and Singer\n"
	"        5 -- L1-regularized L2-loss support vector classification\n"
	"        6 -- L1-regularized logistic regression\n"
	"        7 -- L2-regularized logistic regression (dual)\n"
	"  for regression\n"
	"       11 -- L2-regularized L2-loss support vector regression (primal)\n"
	"       12 -- L2-regularized L2-loss support vector regression (dual)\n"
	"       13 -- L2-regularized L1-loss support vector regression (dual)\n"
	"-c cost : set the parameter C (default 1)\n"
	"-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n"
	"-e epsilon : set tolerance of termination criterion\n"
	"       -s 0 and 2\n"
	"               |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
	"               where f is the primal function and pos/neg are # of\n"
	"               positive/negative data (default 0.01)\n"
	"       -s 11\n"
	"               |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.001)\n"
	"       -s 1, 3, 4, and 7\n"
	"               Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
	"       -s 5 and 6\n"
	"               |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n"
	"               where f is the primal function (default 0.01)\n"
	"       -s 12 and 13\n"
	"               |f'(alpha)|_1 <= eps |f'(alpha0)|,\n"
	"               where f is the dual function (default 0.1)\n"
	"-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)\n"
	"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
	"-v n: n-fold cross validation mode\n"
	"-C : find parameter C (only for -s 0 and 2)\n"
	"-q : quiet mode (no outputs)\n";

	// the text contains no conversion specifications, so it is written verbatim
	fputs(usage_text, stdout);
	exit(1);
}
57
// Report a malformed line of the training file on stderr and terminate with
// status 1. line_num is the 1-based number of the offending line.
// Never returns.
void exit_input_error(int line_num)
{
	fprintf(stderr,
	        "Wrong input format at line %d\n",
	        line_num);
	exit(1);
}
63
// Shared line buffer used by readline(). The caller allocates it (and sets
// max_line_len) before the first call; readline() grows it on demand.
static char *line = NULL;
// Current capacity of `line` in bytes.
static int max_line_len;

// Read one complete line from input into the shared buffer `line`, doubling the
// buffer until the whole line (including its trailing '\n', if any) fits.
// Returns `line`, or NULL when nothing could be read (end of file or read error).
// Terminates the program with status 1 if the buffer cannot be grown.
static char* readline(FILE *input)
{
	int len;

	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	// no newline yet: either the buffer was too small or this is the last,
	// unterminated line of the file
	while(strrchr(line,'\n') == NULL)
	{
		char *grown;

		max_line_len *= 2;
		// keep the old pointer until realloc succeeds, so a failed allocation is
		// reported instead of dereferencing NULL in strlen/fgets below
		grown = (char *) realloc(line,max_line_len);
		if(grown == NULL)
		{
			fprintf(stderr,"can't allocate %d bytes for the input line buffer\n",max_line_len);
			exit(1);
		}
		line = grown;
		len = (int) strlen(line);
		// append the rest of the line after what has been read so far
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
84
// Forward declarations of the helpers that are defined after main().
void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
void read_problem(const char *filename);
void do_cross_validation(void);
void do_find_parameter_C(void);
89
90 struct feature_node *x_space;
91 struct parameter param;
92 struct problem prob;
93 struct model* model_;
94 int flag_cross_validation;
95 int flag_find_C;
96 int flag_C_specified;
97 int flag_solver_specified;
98 int nr_fold;
99 double bias;
100
101 int main(int argc, char **argv)
102 {
103         char input_file_name[1024];
104         char model_file_name[1024];
105         const char *error_msg;
106
107         parse_command_line(argc, argv, input_file_name, model_file_name);
108         read_problem(input_file_name);
109         error_msg = check_parameter(&prob,&param);
110
111         if(error_msg)
112         {
113                 fprintf(stderr,"ERROR: %s\n",error_msg);
114                 exit(1);
115         }
116
117         if (flag_find_C)
118         {
119                 do_find_parameter_C();
120         }
121         else if(flag_cross_validation)
122         {
123                 do_cross_validation();
124         }
125         else
126         {
127                 model_=train(&prob, &param);
128                 if(save_model(model_file_name, model_))
129                 {
130                         fprintf(stderr,"can't save model to file %s\n",model_file_name);
131                         exit(1);
132                 }
133                 free_and_destroy_model(&model_);
134         }
135         destroy_param(&param);
136         free(prob.y);
137         free(prob.x);
138         free(x_space);
139         free(line);
140
141         return 0;
142 }
143
144 void do_find_parameter_C()
145 {
146         double start_C, best_C, best_rate;
147         double max_C = 1024;
148         if (flag_C_specified)
149                 start_C = param.C;
150         else
151                 start_C = -1.0;
152         printf("Doing parameter search with %d-fold cross validation.\n", nr_fold);
153         find_parameter_C(&prob, &param, nr_fold, start_C, max_C, &best_C, &best_rate);
154         printf("Best C = %g  CV accuracy = %g%%\n", best_C, 100.0*best_rate);
155 }
156
157 void do_cross_validation()
158 {
159         int i;
160         int total_correct = 0;
161         double total_error = 0;
162         double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
163         double *target = Malloc(double, prob.l);
164
165         cross_validation(&prob,&param,nr_fold,target);
166         if(param.solver_type == L2R_L2LOSS_SVR ||
167            param.solver_type == L2R_L1LOSS_SVR_DUAL ||
168            param.solver_type == L2R_L2LOSS_SVR_DUAL)
169         {
170                 for(i=0;i<prob.l;i++)
171                 {
172                         double y = prob.y[i];
173                         double v = target[i];
174                         total_error += (v-y)*(v-y);
175                         sumv += v;
176                         sumy += y;
177                         sumvv += v*v;
178                         sumyy += y*y;
179                         sumvy += v*y;
180                 }
181                 printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
182                 printf("Cross Validation Squared correlation coefficient = %g\n",
183                                 ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
184                                 ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
185                           );
186         }
187         else
188         {
189                 for(i=0;i<prob.l;i++)
190                         if(target[i] == prob.y[i])
191                                 ++total_correct;
192                 printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
193         }
194
195         free(target);
196 }
197
198 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
199 {
200         int i;
201         void (*print_func)(const char*) = NULL; // default printing to stdout
202
203         // default values
204         param.solver_type = L2R_L2LOSS_SVC_DUAL;
205         param.C = 1;
206         param.eps = INF; // see setting below
207         param.p = 0.1;
208         param.nr_weight = 0;
209         param.weight_label = NULL;
210         param.weight = NULL;
211         param.init_sol = NULL;
212         flag_cross_validation = 0;
213         flag_C_specified = 0;
214         flag_solver_specified = 0;
215         flag_find_C = 0;
216         bias = -1;
217
218         // parse options
219         for(i=1;i<argc;i++)
220         {
221                 if(argv[i][0] != '-') break;
222                 if(++i>=argc)
223                         exit_with_help();
224                 switch(argv[i-1][1])
225                 {
226                         case 's':
227                                 param.solver_type = atoi(argv[i]);
228                                 flag_solver_specified = 1;
229                                 break;
230
231                         case 'c':
232                                 param.C = atof(argv[i]);
233                                 flag_C_specified = 1;
234                                 break;
235
236                         case 'p':
237                                 param.p = atof(argv[i]);
238                                 break;
239
240                         case 'e':
241                                 param.eps = atof(argv[i]);
242                                 break;
243
244                         case 'B':
245                                 bias = atof(argv[i]);
246                                 break;
247
248                         case 'w':
249                                 ++param.nr_weight;
250                                 param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight);
251                                 param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight);
252                                 param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
253                                 param.weight[param.nr_weight-1] = atof(argv[i]);
254                                 break;
255
256                         case 'v':
257                                 flag_cross_validation = 1;
258                                 nr_fold = atoi(argv[i]);
259                                 if(nr_fold < 2)
260                                 {
261                                         fprintf(stderr,"n-fold cross validation: n must >= 2\n");
262                                         exit_with_help();
263                                 }
264                                 break;
265
266                         case 'q':
267                                 print_func = &print_null;
268                                 i--;
269                                 break;
270
271                         case 'C':
272                                 flag_find_C = 1;
273                                 i--;
274                                 break;
275
276                         default:
277                                 fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]);
278                                 exit_with_help();
279                                 break;
280                 }
281         }
282
283         set_print_string_function(print_func);
284
285         // determine filenames
286         if(i>=argc)
287                 exit_with_help();
288
289         strcpy(input_file_name, argv[i]);
290
291         if(i<argc-1)
292                 strcpy(model_file_name,argv[i+1]);
293         else
294         {
295                 char *p = strrchr(argv[i],'/');
296                 if(p==NULL)
297                         p = argv[i];
298                 else
299                         ++p;
300                 sprintf(model_file_name,"%s.model",p);
301         }
302
303         // default solver for parameter selection is L2R_L2LOSS_SVC
304         if(flag_find_C)
305         {
306                 if(!flag_cross_validation)
307                         nr_fold = 5;
308                 if(!flag_solver_specified)
309                 {
310                         fprintf(stderr, "Solver not specified. Using -s 2\n");
311                         param.solver_type = L2R_L2LOSS_SVC;
312                 }
313                 else if(param.solver_type != L2R_LR && param.solver_type != L2R_L2LOSS_SVC)
314                 {
315                         fprintf(stderr, "Warm-start parameter search only available for -s 0 and -s 2\n");
316                         exit_with_help();
317                 }
318         }
319
320         if(param.eps == INF)
321         {
322                 switch(param.solver_type)
323                 {
324                         case L2R_LR:
325                         case L2R_L2LOSS_SVC:
326                                 param.eps = 0.01;
327                                 break;
328                         case L2R_L2LOSS_SVR:
329                                 param.eps = 0.001;
330                                 break;
331                         case L2R_L2LOSS_SVC_DUAL:
332                         case L2R_L1LOSS_SVC_DUAL:
333                         case MCSVM_CS:
334                         case L2R_LR_DUAL:
335                                 param.eps = 0.1;
336                                 break;
337                         case L1R_L2LOSS_SVC:
338                         case L1R_LR:
339                                 param.eps = 0.01;
340                                 break;
341                         case L2R_L1LOSS_SVR_DUAL:
342                         case L2R_L2LOSS_SVR_DUAL:
343                                 param.eps = 0.1;
344                                 break;
345                 }
346         }
347 }
348
349 // read in a problem (in libsvm format)
350 void read_problem(const char *filename)
351 {
352         int max_index, inst_max_index, i;
353         size_t elements, j;
354         FILE *fp = fopen(filename,"r");
355         char *endptr;
356         char *idx, *val, *label;
357
358         if(fp == NULL)
359         {
360                 fprintf(stderr,"can't open input file %s\n",filename);
361                 exit(1);
362         }
363
364         prob.l = 0;
365         elements = 0;
366         max_line_len = 1024;
367         line = Malloc(char,max_line_len);
368         while(readline(fp)!=NULL)
369         {
370                 char *p = strtok(line," \t"); // label
371
372                 // features
373                 while(1)
374                 {
375                         p = strtok(NULL," \t");
376                         if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
377                                 break;
378                         elements++;
379                 }
380                 elements++; // for bias term
381                 prob.l++;
382         }
383         rewind(fp);
384
385         prob.bias=bias;
386
387         prob.y = Malloc(double,prob.l);
388         prob.x = Malloc(struct feature_node *,prob.l);
389         x_space = Malloc(struct feature_node,elements+prob.l);
390
391         max_index = 0;
392         j=0;
393         for(i=0;i<prob.l;i++)
394         {
395                 inst_max_index = 0; // strtol gives 0 if wrong format
396                 readline(fp);
397                 prob.x[i] = &x_space[j];
398                 label = strtok(line," \t\n");
399                 if(label == NULL) // empty line
400                         exit_input_error(i+1);
401
402                 prob.y[i] = strtod(label,&endptr);
403                 if(endptr == label || *endptr != '\0')
404                         exit_input_error(i+1);
405
406                 while(1)
407                 {
408                         idx = strtok(NULL,":");
409                         val = strtok(NULL," \t");
410
411                         if(val == NULL)
412                                 break;
413
414                         errno = 0;
415                         x_space[j].index = (int) strtol(idx,&endptr,10);
416                         if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
417                                 exit_input_error(i+1);
418                         else
419                                 inst_max_index = x_space[j].index;
420
421                         errno = 0;
422                         x_space[j].value = strtod(val,&endptr);
423                         if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
424                                 exit_input_error(i+1);
425
426                         ++j;
427                 }
428
429                 if(inst_max_index > max_index)
430                         max_index = inst_max_index;
431
432                 if(prob.bias >= 0)
433                         x_space[j++].value = prob.bias;
434
435                 x_space[j++].index = -1;
436         }
437
438         if(prob.bias >= 0)
439         {
440                 prob.n=max_index+1;
441                 for(i=1;i<prob.l;i++)
442                         (prob.x[i]-2)->index = prob.n;
443                 x_space[j-2].index = prob.n;
444         }
445         else
446                 prob.n=max_index;
447
448         fclose(fp);
449 }