]> granicus.if.org Git - liblinear/blob - train.c
Change find_parameter_C to find_parameters in linear.def
[liblinear] / train.c
1 #include <stdio.h>
2 #include <math.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <ctype.h>
6 #include <errno.h>
7 #include "linear.h"
8 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
9 #define INF HUGE_VAL
10
/* Print callback that discards its message; installed by the -q (quiet) option. */
void print_null(const char *s)
{
	(void) s; /* intentionally ignored */
}
12
/*
 * Print the command-line usage summary to stdout and terminate with status 1.
 * Called for malformed or missing arguments.
 */
void exit_with_help()
{
	static const char usage_text[] =
	"Usage: train [options] training_set_file [model_file]\n"
	"options:\n"
	"-s type : set type of solver (default 1)\n"
	"  for multi-class classification\n"
	"        0 -- L2-regularized logistic regression (primal)\n"
	"        1 -- L2-regularized L2-loss support vector classification (dual)\n"
	"        2 -- L2-regularized L2-loss support vector classification (primal)\n"
	"        3 -- L2-regularized L1-loss support vector classification (dual)\n"
	"        4 -- support vector classification by Crammer and Singer\n"
	"        5 -- L1-regularized L2-loss support vector classification\n"
	"        6 -- L1-regularized logistic regression\n"
	"        7 -- L2-regularized logistic regression (dual)\n"
	"  for regression\n"
	"       11 -- L2-regularized L2-loss support vector regression (primal)\n"
	"       12 -- L2-regularized L2-loss support vector regression (dual)\n"
	"       13 -- L2-regularized L1-loss support vector regression (dual)\n"
	"-c cost : set the parameter C (default 1)\n"
	"-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n"
	"-e epsilon : set tolerance of termination criterion\n"
	"       -s 0 and 2\n"
	"               |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
	"               where f is the primal function and pos/neg are # of\n"
	"               positive/negative data (default 0.01)\n"
	"       -s 11\n"
	"               |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)\n"
	"       -s 1, 3, 4, and 7\n"
	"               Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
	"       -s 5 and 6\n"
	"               |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n"
	"               where f is the primal function (default 0.01)\n"
	"       -s 12 and 13\n"
	"               |f'(alpha)|_1 <= eps |f'(alpha0)|,\n"
	"               where f is the dual function (default 0.1)\n"
	"-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)\n"
	"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
	"-v n: n-fold cross validation mode\n"
	"-C : find parameters (C for -s 0, 2 and C, p for -s 11)\n"
	"-q : quiet mode (no outputs)\n";

	/* the text contains no '%', so fputs prints exactly what printf did */
	fputs(usage_text, stdout);
	exit(1);
}
57
/*
 * Report a malformed record in the training file and terminate with status 1.
 * line_num: 1-based line number of the offending record.
 */
void exit_input_error(int line_num)
{
	const int failure_status = 1;

	fprintf(stderr, "Wrong input format at line %d\n", line_num);
	exit(failure_status);
}
63
/* Shared line buffer filled by readline(); allocated in read_problem, freed in main. */
static char *line = NULL;
/* Current capacity of `line` in bytes. */
static int max_line_len;

/*
 * Read one complete line from `input` into the global `line` buffer.
 * The buffer is doubled until a '\n' has been read or input runs out.
 * Returns `line`, or NULL when fgets cannot read anything.
 * Assumes the caller has already allocated `line` with max_line_len bytes.
 * Exits the program if the buffer cannot be grown.
 */
static char* readline(FILE *input)
{
	int len;

	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		char *grown;

		max_line_len *= 2;
		// grow through a temporary so a failed realloc never leaves line == NULL
		grown = (char *) realloc(line,max_line_len);
		if(grown == NULL)
		{
			fprintf(stderr,"out of memory while reading an input line\n");
			exit(1);
		}
		line = grown;
		len = (int) strlen(line);
		// continue reading the same line into the newly available space
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break; // EOF on a final line without '\n'
	}
	return line;
}
84
// Forward declarations for the routines main() calls before their definitions.
// Parameterless functions use (void) so these are real prototypes rather than
// obsolescent non-prototype declarations.
void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
void read_problem(const char *filename);
void do_cross_validation(void);
void do_find_parameters(void);
89
90 struct feature_node *x_space;
91 struct parameter param;
92 struct problem prob;
93 struct model* model_;
94 int flag_cross_validation;
95 int flag_find_parameters;
96 int flag_C_specified;
97 int flag_p_specified;
98 int flag_solver_specified;
99 int nr_fold;
100 double bias;
101
102 int main(int argc, char **argv)
103 {
104         char input_file_name[1024];
105         char model_file_name[1024];
106         const char *error_msg;
107
108         parse_command_line(argc, argv, input_file_name, model_file_name);
109         read_problem(input_file_name);
110         error_msg = check_parameter(&prob,&param);
111
112         if(error_msg)
113         {
114                 fprintf(stderr,"ERROR: %s\n",error_msg);
115                 exit(1);
116         }
117
118         if (flag_find_parameters)
119         {
120                 do_find_parameters();
121         }
122         else if(flag_cross_validation)
123         {
124                 do_cross_validation();
125         }
126         else
127         {
128                 model_=train(&prob, &param);
129                 if(save_model(model_file_name, model_))
130                 {
131                         fprintf(stderr,"can't save model to file %s\n",model_file_name);
132                         exit(1);
133                 }
134                 free_and_destroy_model(&model_);
135         }
136         destroy_param(&param);
137         free(prob.y);
138         free(prob.x);
139         free(x_space);
140         free(line);
141
142         return 0;
143 }
144
145 void do_find_parameters()
146 {
147         double start_C, start_p, best_C, best_p, best_score;
148         if (flag_C_specified)
149                 start_C = param.C;
150         else
151                 start_C = -1.0;
152         if (flag_p_specified)
153                 start_p = param.p;
154         else
155                 start_p = -1.0;
156         
157         printf("Doing parameter search with %d-fold cross validation.\n", nr_fold);
158         find_parameters(&prob, &param, nr_fold, start_C, start_p, &best_C, &best_p, &best_score);
159         if(param.solver_type == L2R_LR || param.solver_type == L2R_L2LOSS_SVC)
160                 printf("Best C = %g  CV accuracy = %g%%\n", best_C, 100.0*best_score);
161         else if(param.solver_type == L2R_L2LOSS_SVR)
162                 printf("Best C = %g Best p = %g  CV MSE = %g\n", best_C, best_p, best_score);
163 }
164
165 void do_cross_validation()
166 {
167         int i;
168         int total_correct = 0;
169         double total_error = 0;
170         double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
171         double *target = Malloc(double, prob.l);
172
173         cross_validation(&prob,&param,nr_fold,target);
174         if(param.solver_type == L2R_L2LOSS_SVR ||
175            param.solver_type == L2R_L1LOSS_SVR_DUAL ||
176            param.solver_type == L2R_L2LOSS_SVR_DUAL)
177         {
178                 for(i=0;i<prob.l;i++)
179                 {
180                         double y = prob.y[i];
181                         double v = target[i];
182                         total_error += (v-y)*(v-y);
183                         sumv += v;
184                         sumy += y;
185                         sumvv += v*v;
186                         sumyy += y*y;
187                         sumvy += v*y;
188                 }
189                 printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
190                 printf("Cross Validation Squared correlation coefficient = %g\n",
191                                 ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
192                                 ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
193                           );
194         }
195         else
196         {
197                 for(i=0;i<prob.l;i++)
198                         if(target[i] == prob.y[i])
199                                 ++total_correct;
200                 printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
201         }
202
203         free(target);
204 }
205
206 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
207 {
208         int i;
209         void (*print_func)(const char*) = NULL; // default printing to stdout
210
211         // default values
212         param.solver_type = L2R_L2LOSS_SVC_DUAL;
213         param.C = 1;
214         param.eps = INF; // see setting below
215         param.p = 0.1;
216         param.nr_weight = 0;
217         param.weight_label = NULL;
218         param.weight = NULL;
219         param.init_sol = NULL;
220         flag_cross_validation = 0;
221         flag_C_specified = 0;
222         flag_p_specified = 0;
223         flag_solver_specified = 0;
224         flag_find_parameters = 0;
225         bias = -1;
226
227         // parse options
228         for(i=1;i<argc;i++)
229         {
230                 if(argv[i][0] != '-') break;
231                 if(++i>=argc)
232                         exit_with_help();
233                 switch(argv[i-1][1])
234                 {
235                         case 's':
236                                 param.solver_type = atoi(argv[i]);
237                                 flag_solver_specified = 1;
238                                 break;
239
240                         case 'c':
241                                 param.C = atof(argv[i]);
242                                 flag_C_specified = 1;
243                                 break;
244
245                         case 'p':
246                                 flag_p_specified = 1;
247                                 param.p = atof(argv[i]);
248                                 break;
249
250                         case 'e':
251                                 param.eps = atof(argv[i]);
252                                 break;
253
254                         case 'B':
255                                 bias = atof(argv[i]);
256                                 break;
257
258                         case 'w':
259                                 ++param.nr_weight;
260                                 param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight);
261                                 param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight);
262                                 param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
263                                 param.weight[param.nr_weight-1] = atof(argv[i]);
264                                 break;
265
266                         case 'v':
267                                 flag_cross_validation = 1;
268                                 nr_fold = atoi(argv[i]);
269                                 if(nr_fold < 2)
270                                 {
271                                         fprintf(stderr,"n-fold cross validation: n must >= 2\n");
272                                         exit_with_help();
273                                 }
274                                 break;
275
276                         case 'q':
277                                 print_func = &print_null;
278                                 i--;
279                                 break;
280
281                         case 'C':
282                                 flag_find_parameters = 1;
283                                 i--;
284                                 break;
285
286                         default:
287                                 fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]);
288                                 exit_with_help();
289                                 break;
290                 }
291         }
292
293         set_print_string_function(print_func);
294
295         // determine filenames
296         if(i>=argc)
297                 exit_with_help();
298
299         strcpy(input_file_name, argv[i]);
300
301         if(i<argc-1)
302                 strcpy(model_file_name,argv[i+1]);
303         else
304         {
305                 char *p = strrchr(argv[i],'/');
306                 if(p==NULL)
307                         p = argv[i];
308                 else
309                         ++p;
310                 sprintf(model_file_name,"%s.model",p);
311         }
312
313         // default solver for parameter selection is L2R_L2LOSS_SVC
314         if(flag_find_parameters)
315         {
316                 if(!flag_cross_validation)
317                         nr_fold = 5;
318                 if(!flag_solver_specified)
319                 {
320                         fprintf(stderr, "Solver not specified. Using -s 2\n");
321                         param.solver_type = L2R_L2LOSS_SVC;
322                 }
323                 else if(param.solver_type != L2R_LR && param.solver_type != L2R_L2LOSS_SVC && param.solver_type != L2R_L2LOSS_SVR)
324                 {
325                         fprintf(stderr, "Warm-start parameter search only available for -s 0, -s 2 and -s 11\n");
326                         exit_with_help();
327                 }
328         }
329
330         if(param.eps == INF)
331         {
332                 switch(param.solver_type)
333                 {
334                         case L2R_LR:
335                         case L2R_L2LOSS_SVC:
336                                 param.eps = 0.01;
337                                 break;
338                         case L2R_L2LOSS_SVR:
339                                 param.eps = 0.0001;
340                                 break;
341                         case L2R_L2LOSS_SVC_DUAL:
342                         case L2R_L1LOSS_SVC_DUAL:
343                         case MCSVM_CS:
344                         case L2R_LR_DUAL:
345                                 param.eps = 0.1;
346                                 break;
347                         case L1R_L2LOSS_SVC:
348                         case L1R_LR:
349                                 param.eps = 0.01;
350                                 break;
351                         case L2R_L1LOSS_SVR_DUAL:
352                         case L2R_L2LOSS_SVR_DUAL:
353                                 param.eps = 0.1;
354                                 break;
355                 }
356         }
357 }
358
359 // read in a problem (in libsvm format)
360 void read_problem(const char *filename)
361 {
362         int max_index, inst_max_index, i;
363         size_t elements, j;
364         FILE *fp = fopen(filename,"r");
365         char *endptr;
366         char *idx, *val, *label;
367
368         if(fp == NULL)
369         {
370                 fprintf(stderr,"can't open input file %s\n",filename);
371                 exit(1);
372         }
373
374         prob.l = 0;
375         elements = 0;
376         max_line_len = 1024;
377         line = Malloc(char,max_line_len);
378         while(readline(fp)!=NULL)
379         {
380                 char *p = strtok(line," \t"); // label
381
382                 // features
383                 while(1)
384                 {
385                         p = strtok(NULL," \t");
386                         if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
387                                 break;
388                         elements++;
389                 }
390                 elements++; // for bias term
391                 prob.l++;
392         }
393         rewind(fp);
394
395         prob.bias=bias;
396
397         prob.y = Malloc(double,prob.l);
398         prob.x = Malloc(struct feature_node *,prob.l);
399         x_space = Malloc(struct feature_node,elements+prob.l);
400
401         max_index = 0;
402         j=0;
403         for(i=0;i<prob.l;i++)
404         {
405                 inst_max_index = 0; // strtol gives 0 if wrong format
406                 readline(fp);
407                 prob.x[i] = &x_space[j];
408                 label = strtok(line," \t\n");
409                 if(label == NULL) // empty line
410                         exit_input_error(i+1);
411
412                 prob.y[i] = strtod(label,&endptr);
413                 if(endptr == label || *endptr != '\0')
414                         exit_input_error(i+1);
415
416                 while(1)
417                 {
418                         idx = strtok(NULL,":");
419                         val = strtok(NULL," \t");
420
421                         if(val == NULL)
422                                 break;
423
424                         errno = 0;
425                         x_space[j].index = (int) strtol(idx,&endptr,10);
426                         if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
427                                 exit_input_error(i+1);
428                         else
429                                 inst_max_index = x_space[j].index;
430
431                         errno = 0;
432                         x_space[j].value = strtod(val,&endptr);
433                         if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
434                                 exit_input_error(i+1);
435
436                         ++j;
437                 }
438
439                 if(inst_max_index > max_index)
440                         max_index = inst_max_index;
441
442                 if(prob.bias >= 0)
443                         x_space[j++].value = prob.bias;
444
445                 x_space[j++].index = -1;
446         }
447
448         if(prob.bias >= 0)
449         {
450                 prob.n=max_index+1;
451                 for(i=1;i<prob.l;i++)
452                         (prob.x[i]-2)->index = prob.n;
453                 x_space[j-2].index = prob.n;
454         }
455         else
456                 prob.n=max_index;
457
458         fclose(fp);
459 }