From: eaudex
Date: Fri, 21 Aug 2009 13:40:52 +0000 (+0000)
Subject: L1-regularized L2-loss SVC
X-Git-Tag: v140~9
X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=d15ff00b8a3153852a6a63cffd198cc6ae7c9f4b;p=liblinear

L1-regularized L2-loss SVC
L1-regularized Logistic Regression
---

diff --git a/linear.cpp b/linear.cpp
index d1560e2..74978c3 100644
--- a/linear.cpp
+++ b/linear.cpp
@@ -6,14 +6,14 @@
 #include "linear.h"
 #include "tron.h"
 typedef signed char schar;
-template <class T> inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
+template <class T> static inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
 #ifndef min
-template <class T> inline T min(T x,T y) { return (x<y)?x:y; }
+template <class T> static inline T min(T x,T y) { return (x<y)?x:y; }
 #endif
 #ifndef max
-template <class T> inline T max(T x,T y) { return (x>y)?x:y; }
+template <class T> static inline T max(T x,T y) { return (x>y)?x:y; }
 #endif
-template <class S, class T> inline void clone(T*& dst, S* src, int n)
+template <class S, class T> static inline void clone(T*& dst, S* src, int n)
 {
 	dst = new T[n];
 	memcpy((void *)dst,(void *)src,sizeof(T)*n);
@@ -696,7 +696,7 @@ void Solver_MCSVM_CS::Solve(double *w)
 //
 // solution will be put in w
 
-static void solve_linear_c_svc(
+static void solve_l2_l1l2_svc(
 	const problem *prob, double *w, double eps,
 	double Cp, double Cn, int solver_type)
 {
@@ -717,10 +717,10 @@ static void solve_linear_c_svc(
 	double PGmin_old = -INF;
 	double PGmax_new, PGmin_new;
 
-	// default solver_type: L2LOSS_SVM_DUAL
+	// default solver_type: L2_L2LOSS_SVC_DUAL
 	double diag_p = 0.5/Cp, diag_n = 0.5/Cn;
 	double upper_bound_p = INF, upper_bound_n = INF;
-	if(solver_type == L1LOSS_SVM_DUAL)
+	if(solver_type == L2_L1LOSS_SVC_DUAL)
 	{
 		diag_p = 0; diag_n = 0;
 		upper_bound_p = Cp; upper_bound_n = Cn;
@@ -834,9 +834,7 @@ static void solve_linear_c_svc(
 
 		iter++;
 		if(iter % 10 == 0)
-		{
 			info(".");
-		}
 
 		if(PGmax_new - PGmin_new <= eps)
 		{
@@ -887,9 +885,618 @@ static void solve_linear_c_svc(
 	delete [] index;
 }
 
+// A coordinate descent algorithm for
+// L1-regularized L2-loss support vector classification
+//
+//  min_w \sum |wj| + C \sum max(0, 1-yi w^T xi)^2,
+//
+// Given:
+// x, y, Cp, Cn
+// eps is the stopping tolerance
+//
+// solution will be put in w
+
+static void solve_l1_l2_svc(
+	problem *prob_col, double *w, double eps,
+	double Cp, double Cn)
+{
+	int l = prob_col->l;
+	int w_size = prob_col->n;
+	int i, s, iter = 0;
+	int max_iter = 1000;
+	int active_size = w_size;
+	int max_num_linesearch = 20;
+
+	double sigma = 0.01;
+	double d, G_loss, G, H;
+	double Gmax_old = INF;
+	double Gmax_new;
+	double Gmax_init;
+	double d_old, d_diff;
+	double loss_old, loss_new;
+	double appxcond, cond;
+
+	int *index = new int[w_size];
+	schar *y = new schar[l];
+	double *b = new double[l]; // b = 1-ywTx
+	double *xi_sq = new double[w_size];
+	feature_node *x;
+
+	// To support weights for instances,
+	// replace C[y[ind]] with C[i].
+	double C[2] = {Cn,Cp};
+
+	for(i=0; i<l; i++)
+	{
+		b[i] = 1;
+		if(prob_col->y[i] > 0)
+			y[i] = 1;
+		else
+			y[i] = 0;
+	}
+
+	for(i=0; i<w_size; i++)
+	{
+		w[i] = 0;
+		index[i] = i;
+		xi_sq[i] = 0;
+		x = prob_col->x[i];
+		while(x->index != -1)
+		{
+			int ind = x->index;
+			double val = x->value;
+			x->value *= prob_col->y[ind]; // x->value stores yj*xji
+			xi_sq[i] += C[y[ind]]*val*val;
+			x++;
+		}
+	}
+
+	while(iter < max_iter)
+	{
+		Gmax_new = 0;
+
+		for(i=0; i<active_size; i++)
+		{
+			int j = i+rand()%(active_size-i);
+			swap(index[i], index[j]);
+		}
+
+		for(s=0; s<active_size; s++)
+		{
+			i = index[s];
+			G_loss = 0;
+			H = 0;
+
+			x = prob_col->x[i];
+			while(x->index != -1)
+			{
+				int ind = x->index;
+				if(b[ind] > 0)
+				{
+					double val = x->value;
+					double tmp = C[y[ind]]*val;
+					G_loss -= tmp*b[ind];
+					H += tmp*val;
+				}
+				x++;
+			}
+			G_loss *= 2;
+
+			G = G_loss;
+			H *= 2;
+
+			double Gp = G+1;
+			double Gn = G-1;
+			// shrinking
+			if(w[i] == 0)
+			{
+				if(Gp>Gmax_old/l && Gn<-Gmax_old/l)
+				{
+					active_size--;
+					swap(index[s], index[active_size]);
+					s--;
+					continue;
+				}
+			}
+
+			// obtain Newton direction d
+			if(Gp <= H*w[i])
+			{
+				G = Gp;
+				d = -Gp/H;
+			}
+			else if(Gn >= H*w[i])
+			{
+				G = Gn;
+				d = -Gn/H;
+			}
+			else
+			{
+				G = 0;
+				d = -w[i];
+			}
+
+			Gmax_new = max(Gmax_new, fabs(G));
+
+			if(fabs(d) < 1.0e-12)
+				continue;
+
+			d_old = 0;
+			int num_linesearch;
+			for(num_linesearch=0; num_linesearch < max_num_linesearch; num_linesearch++)
+			{
+				d_diff = d_old - d;
+				cond = fabs(w[i]+d)-fabs(w[i]) + sigma*d*d;
+
+				appxcond = xi_sq[i]*d*d + G_loss*d + cond;
+				if(appxcond <= 0)
+				{
+					x = prob_col->x[i];
+					while(x->index != -1)
+					{
+						b[x->index] += d_diff*x->value;
+						x++;
+					}
+					break;
+				}
+
+				if(num_linesearch == 0)
+				{
+					loss_old = 0;
+					loss_new = 0;
+					x = prob_col->x[i];
+					while(x->index != -1)
+					{
+						int ind = x->index;
+						if(b[ind] > 0)
+							loss_old += C[y[ind]]*b[ind]*b[ind];
+						double b_new = b[ind] + d_diff*x->value;
+						b[ind] = b_new;
+						if(b_new > 0)
+							loss_new += C[y[ind]]*b_new*b_new;
+						x++;
+					}
+				}
+				else
+				{
+					loss_new = 0;
+					x = prob_col->x[i];
+					while(x->index != -1)
+					{
+						int ind = x->index;
+						double b_new = b[ind] + d_diff*x->value;
+						b[ind] = b_new;
+						if(b_new > 0)
+							loss_new += C[y[ind]]*b_new*b_new;
+						x++;
+					}
+				}
+
+				cond = cond + loss_new - loss_old;
+				if(cond <= 0)
+					break;
+				else
+				{
+					d_old = d;
+					d *= 0.5;
+				}
+			}
+
+			w[i] += d;
+
+			// recompute b[] if line search takes too many steps
+			if(num_linesearch >= max_num_linesearch)
+			{
+				info("#");
+				for(int j=0; j<l; j++)
+					b[j] = 1;
+
+				for(int j=0; j<w_size; j++)
+				{
+					if(w[j]==0) continue;
+					x = prob_col->x[j];
+					while(x->index != -1)
+					{
+						b[x->index] -= w[j]*x->value;
+						x++;
+					}
+				}
+			}
+		}
+
+		if(iter == 0)
+			Gmax_init = Gmax_new;
+		iter++;
+		if(iter % 10 == 0)
+			info(".");
+
+		if(Gmax_new <= eps*Gmax_init)
+		{
+			if(active_size == w_size)
+				break;
+			else
+			{
+				active_size = w_size;
+				info("*");
+				Gmax_old = INF;
+				continue;
+			}
+		}
+
+		Gmax_old = Gmax_new;
+	}
+
+	info("\noptimization finished, #iter = %d\n", iter);
+	if(iter >= max_iter)
+		info("\nWARNING: reaching max number of iterations\n");
+
+	// calculate objective value
+
+	double v = 0;
+	int nnz = 0;
+	for(i=0; i<w_size; i++)
+	{
+		x = prob_col->x[i];
+		while(x->index != -1)
+		{
+			x->value *= prob_col->y[x->index]; // restore x->value
+			x++;
+		}
+		if(w[i] != 0)
+		{
+			v += fabs(w[i]);
+			nnz++;
+		}
+	}
+	for(i=0; i<l; i++)
+		if(b[i] > 0)
+			v += C[y[i]]*b[i]*b[i];
+
+	info("Objective value = %lf\n", v);
+	info("#nonzeros/#features = %d/%d\n", nnz, w_size);
+
+	delete [] index;
+	delete [] y;
+	delete [] b;
+	delete [] xi_sq;
+}
+
+// A coordinate descent algorithm for
+// L1-regularized logistic regression problems
+//
+//  min_w \sum |wj| + C \sum log(1+exp(-yi w^T xi)),
+//
+// Given:
+// x, y, Cp, Cn
+// eps is the stopping tolerance
+//
+// solution will be put in w
+
+static void solve_l1_lr(
+	const problem *prob_col, double *w, double eps,
+	double Cp, double Cn)
+{
+	int l = prob_col->l;
+	int w_size = prob_col->n;
+	int i, s, iter = 0;
+	int max_iter = 1000;
+	int active_size = w_size;
+	int max_num_linesearch = 20;
+
+	double x_min = 0;
+	double sigma = 0.01;
+	double d, G, H;
+	double Gmax_old = INF;
+	double Gmax_new;
+	double Gmax_init;
+	double sum1, appxcond1;
+	double sum2, appxcond2;
+	double cond;
+
+	int *index = new int[w_size];
+	schar *y = new schar[l];
+	double *exp_wTx = new double[l];
+	double *exp_wTx_new = new double[l];
+	double *xi_max = new double[w_size];
+	double *C_sum = new double[w_size];
+	double *xineg_sum = new double[w_size];
+	double *xipos_sum = new double[w_size];
+	feature_node *x;
+
+	// To support weights for instances,
+	// replace C[y[ind]] with C[i].
+	double C[2] = {Cn,Cp};
+
+	for(i=0; i<l; i++)
+	{
+		exp_wTx[i] = 1;
+		if(prob_col->y[i] > 0)
+			y[i] = 1;
+		else
+			y[i] = 0;
+	}
+	for(i=0; i<w_size; i++)
+	{
+		w[i] = 0;
+		index[i] = i;
+		xi_max[i] = 0;
+		C_sum[i] = 0;
+		xineg_sum[i] = 0;
+		xipos_sum[i] = 0;
+		x = prob_col->x[i];
+		while(x->index != -1)
+		{
+			int ind = x->index;
+			double val = x->value;
+			x_min = min(x_min, val);
+			xi_max[i] = max(xi_max[i], val);
+			C_sum[i] += C[y[ind]];
+			if(y[ind] == 0)
+				xineg_sum[i] += C[y[ind]]*val;
+			else
+				xipos_sum[i] += C[y[ind]]*val;
+			x++;
+		}
+	}
+
+	while(iter < max_iter)
+	{
+		Gmax_new = 0;
+
+		for(i=0; i<active_size; i++)
+		{
+			int j = i+rand()%(active_size-i);
+			swap(index[i], index[j]);
+		}
+
+		for(s=0; s<active_size; s++)
+		{
+			i = index[s];
+			sum1 = 0;
+			sum2 = 0;
+			H = 0;
+
+			x = prob_col->x[i];
+			while(x->index != -1)
+			{
+				int ind = x->index;
+				double exp_wTxind = exp_wTx[ind];
+				double tmp1 = x->value/(1+exp_wTxind);
+				double tmp2 = C[y[ind]]*tmp1;
+				double tmp3 = tmp2*exp_wTxind;
+				sum2 += tmp2;
+				sum1 += tmp3;
+				H += tmp1*tmp3;
+				x++;
+			}
+
+			G = -sum2 + xineg_sum[i];
+
+			double Gp = G+1;
+			double Gn = G-1;
+			// shrinking
+			if(w[i] == 0)
+			{
+				if(Gp>Gmax_old/l && Gn<-Gmax_old/l)
+				{
+					active_size--;
+					swap(index[s], index[active_size]);
+					s--;
+					continue;
+				}
+			}
+
+			// obtain Newton direction d
+			if(Gp <= H*w[i])
+			{
+				G = Gp;
+				d = -Gp/H;
+			}
+			else if(Gn >= H*w[i])
+			{
+				G = Gn;
+				d = -Gn/H;
+			}
+			else
+			{
+				G = 0;
+				d = -w[i];
+			}
+
+			Gmax_new = max(Gmax_new, fabs(G));
+
+			if(fabs(d) < 1.0e-12)
+				continue;
+
+			d = min(max(d,-10.0),10.0);
+
+			int num_linesearch;
+			for(num_linesearch=0; num_linesearch < max_num_linesearch; num_linesearch++)
+			{
+				cond = fabs(w[i]+d)-fabs(w[i]) + sigma*d*d;
+
+				if(x_min >= 0)
+				{
+					double tmp = exp(d*xi_max[i]);
+					appxcond1 = log(1+sum1*(tmp-1)/xi_max[i]/C_sum[i])*C_sum[i] + cond - d*xipos_sum[i];
+					appxcond2 = log(1+sum2*(1/tmp-1)/xi_max[i]/C_sum[i])*C_sum[i] + cond + d*xineg_sum[i];
+					if(min(appxcond1,appxcond2) <= 0)
+					{
+						x = prob_col->x[i];
+						while(x->index != -1)
+						{
+							exp_wTx[x->index] *= exp(d*x->value);
+							x++;
+						}
+						break;
+					}
+				}
+
+				cond += d*xineg_sum[i];
+
+				int j = 0;
+				x = prob_col->x[i];
+				while(x->index != -1)
+				{
+					int ind = x->index;
+					double exp_dx = exp(d*x->value);
+					exp_wTx_new[j] = exp_wTx[ind]*exp_dx;
+					cond += C[y[ind]]*log((1+exp_wTx_new[j])/(exp_dx+exp_wTx_new[j]));
+					x++; j++;
+				}
+
+				if(cond <= 0)
+				{
+					int j = 0;
+					x = prob_col->x[i];
+					while(x->index != -1)
+					{
+						int ind = x->index;
+						exp_wTx[ind] = exp_wTx_new[j];
+						x++; j++;
+					}
+					break;
+				}
+				else
+					d *= 0.5;
+			}
+
+			w[i] += d;
+
+			// recompute exp_wTx[] if line search takes too many steps
+			if(num_linesearch >= max_num_linesearch)
+			{
+				info("#");
+				for(int j=0; j<l; j++)
+					exp_wTx[j] = 0;
+
+				for(int j=0; j<w_size; j++)
+				{
+					if(w[j]==0) continue;
+					x = prob_col->x[j];
+					while(x->index != -1)
+					{
+						exp_wTx[x->index] += w[j]*x->value;
+						x++;
+					}
+				}
+
+				for(int j=0; j<l; j++)
+					exp_wTx[j] = exp(exp_wTx[j]);
+			}
+		}
+
+		if(iter == 0)
+			Gmax_init = Gmax_new;
+		iter++;
+		if(iter % 10 == 0)
+			info(".");
+
+		if(Gmax_new <= eps*Gmax_init)
+		{
+			if(active_size == w_size)
+				break;
+			else
+			{
+				active_size = w_size;
+				info("*");
+				Gmax_old = INF;
+				continue;
+			}
+		}
+
+		Gmax_old = Gmax_new;
+	}
+
+	info("\noptimization finished, #iter = %d\n", iter);
+	if(iter >= max_iter)
+		info("\nWARNING: reaching max number of iterations\n");
+
+	// calculate objective value
+
+	double v = 0;
+	int nnz = 0;
+	for(i=0; i<w_size; i++)
+		if(w[i] != 0)
+		{
+			v += fabs(w[i]);
+			nnz++;
+		}
+	for(i=0; i<l; i++)
+		if(y[i] == 1)
+			v += C[y[i]]*log(1+1/exp_wTx[i]);
+		else
+			v += C[y[i]]*log(1+exp_wTx[i]);
+
+	info("Objective value = %lf\n", v);
+	info("#nonzeros/#features = %d/%d\n", nnz, w_size);
+
+	delete [] index;
+	delete [] y;
+	delete [] exp_wTx;
+	delete [] exp_wTx_new;
+	delete [] xi_max;
+	delete [] C_sum;
+	delete [] xineg_sum;
+	delete [] xipos_sum;
+}
+
+// transpose matrix X from row format to column format
+static void transpose(const problem *prob, feature_node* &x_space, problem *prob_col)
+{
+	int i;
+	int l = prob->l;
+	int n = prob->n;
+	int nnz = 0;
+	int *col_ptr = new int[n+1];
+	prob_col->l = l;
+	prob_col->n = n;
+	prob_col->y = new int[l];
+	prob_col->x = new feature_node*[n];
+
+	for(i=0; i<l; i++)
+		prob_col->y[i] = prob->y[i];
+
+	for(i=0; i<n+1; i++)
+		col_ptr[i] = 0;
+	for(i=0; i<l; i++)
+	{
+		feature_node *x = prob->x[i];
+		while(x->index != -1)
+		{
+			nnz++;
+			col_ptr[x->index]++;
+			x++;
+		}
+	}
+	for(i=1; i<n+1; i++)
+		col_ptr[i] += col_ptr[i-1] + 1;
+
+	x_space = new feature_node[nnz+n];
+	for(i=0; i<n; i++)
+		prob_col->x[i] = &x_space[col_ptr[i]];
+
+	for(i=0; i<l; i++)
+	{
+		feature_node *x = prob->x[i];
+		while(x->index != -1)
+		{
+			int ind = x->index-1;
+			x_space[col_ptr[ind]].index = i;
+			x_space[col_ptr[ind]].value = x->value;
+			col_ptr[ind]++;
+			x++;
+		}
+	}
+	for(i=0; i<n; i++)
+		x_space[col_ptr[i]].index = -1;
+
+	delete [] col_ptr;
+}
+
 void group_classes(const problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
 {
 	int l = prob->l;
 	int max_nr_class = 16;
@@ -946,7 +1553,7 @@ void group_classes(const problem *prob, int *nr_class_ret, int **label_ret, int
 	free(data_label);
 }
 
-void train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
+static void train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
 {
 	double eps=param->eps;
 	int pos = 0;
@@ -968,7 +1575,7 @@ void train_one(const problem *prob, const parameter *param, double *w, double Cp
 			delete fun_obj;
 			break;
 		}
-		case L2LOSS_SVM:
+		case L2_L2LOSS_SVC:
 		{
 			fun_obj=new l2loss_svm_fun(prob, Cp, Cn);
 			TRON tron_obj(fun_obj, eps*min(pos,neg)/prob->l);
@@ -977,12 +1584,34 @@ void train_one(const problem *prob, const parameter *param, double *w, double Cp
 			tron_obj.tron(w);
 			delete fun_obj;
 			break;
 		}
-		case L2LOSS_SVM_DUAL:
-			solve_linear_c_svc(prob, w, eps, Cp, Cn, L2LOSS_SVM_DUAL);
+		case L2_L2LOSS_SVC_DUAL:
+			solve_l2_l1l2_svc(prob, w, eps, Cp, Cn, L2_L2LOSS_SVC_DUAL);
+			break;
+		case L2_L1LOSS_SVC_DUAL:
+			solve_l2_l1l2_svc(prob, w, eps, Cp, Cn, L2_L1LOSS_SVC_DUAL);
+			break;
+		case L1_L2LOSS_SVC:
+		{
+			problem prob_col;
+			feature_node *x_space = NULL;
+			transpose(prob, x_space ,&prob_col);
+			solve_l1_l2_svc(&prob_col, w, eps*min(pos,neg)/prob->l, Cp, Cn);
+			delete [] prob_col.y;
+			delete [] prob_col.x;
+			delete [] x_space;
 			break;
-		case L1LOSS_SVM_DUAL:
-			solve_linear_c_svc(prob, w, eps, Cp, Cn, L1LOSS_SVM_DUAL);
+		}
+		case L1_LR:
+		{
+			problem prob_col;
+			feature_node *x_space = NULL;
+			transpose(prob, x_space ,&prob_col);
+			solve_l1_lr(&prob_col, w, eps*min(pos,neg)/prob->l, Cp, Cn);
+			delete [] prob_col.y;
+			delete [] prob_col.x;
+			delete [] x_space;
 			break;
+		}
 		default:
 			fprintf(stderr, "Error: unknown solver_type\n");
 			break;
@@ -1123,9 +1752,9 @@ void destroy_model(struct model *model_)
 	free(model_);
 }
 
-const char *solver_type_table[]=
+static const char *solver_type_table[]=
 {
-	"L2_LR", "L2LOSS_SVM_DUAL", "L2LOSS_SVM","L1LOSS_SVM_DUAL","MCSVM_CS", NULL
+	"L2_LR", "L2_L2LOSS_SVC_DUAL", "L2_L2LOSS_SVC","L2_L1LOSS_SVC_DUAL","MCSVM_CS", "L1_L2LOSS_SVC","L1_LR", NULL
 };
 
 int save_model(const char *model_file_name, const struct model *model_)
@@ -1372,10 +2001,12 @@ const char *check_parameter(const problem *prob, const parameter *param)
 		return "C <= 0";
 
 	if(param->solver_type != L2_LR
-		&& param->solver_type != L2LOSS_SVM_DUAL
-		&& param->solver_type != L2LOSS_SVM
-		&& param->solver_type != L1LOSS_SVM_DUAL
-		&& param->solver_type != MCSVM_CS)
+		&& param->solver_type != L2_L2LOSS_SVC_DUAL
+		&& param->solver_type != L2_L2LOSS_SVC
+		&& param->solver_type != L2_L1LOSS_SVC_DUAL
+		&& param->solver_type != MCSVM_CS
+		&& param->solver_type != L1_L2LOSS_SVC
+		&& param->solver_type != L1_LR)
 		return "unknown solver type";
 
 	return NULL;
diff --git a/linear.h b/linear.h
index f36d65a..10be96b 100644
--- a/linear.h
+++ b/linear.h
@@ -19,7 +19,7 @@ struct problem
 	double bias;            /* < 0 if no bias term */
 };
 
-enum { L2_LR, L2LOSS_SVM_DUAL, L2LOSS_SVM, L1LOSS_SVM_DUAL, MCSVM_CS }; /* solver_type */
+enum { L2_LR, L2_L2LOSS_SVC_DUAL, L2_L2LOSS_SVC, L2_L1LOSS_SVC_DUAL, MCSVM_CS, L1_L2LOSS_SVC, L1_LR }; /* solver_type */
 
 struct parameter
 {
diff --git a/matlab/train.c b/matlab/train.c
index 1ffb6ee..f8ad784 100644
--- a/matlab/train.c
+++ b/matlab/train.c
@@ -27,10 +27,12 @@ void exit_with_help()
 	"liblinear_options:\n"
 	"-s type : set type of solver (default 1)\n"
 	"	0 -- L2-regularized logistic regression\n"
-	"	1 -- L2-loss support vector machines (dual)\n"
-	"	2 -- L2-loss support vector machines (primal)\n"
-	"	3 -- L1-loss support vector machines (dual)\n"
-	"	4 -- multi-class support vector machines by Crammer and Singer\n"
+	"	1 -- L2-regularized L2-loss support vector classification (dual)\n"
+	"	2 -- L2-regularized L2-loss support vector classification (primal)\n"
+	"	3 -- L2-regularized L1-loss support vector classification (dual)\n"
+	"	4 -- multi-class support vector classification by Crammer and Singer\n"
+	"	5 -- L1-regularized L2-loss support vector classification\n"
+	"	6 -- L1-regularized logistic regression\n"
 	"-c cost : set the parameter C (default 1)\n"
 	"-e epsilon : set tolerance of termination criterion\n"
 	"	-s 0 and 2\n"
@@ -38,6 +40,9 @@ void exit_with_help()
 	"		where f is the primal function, (default 0.01)\n"
 	"	-s 1, 3, and 4\n"
 	"		Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
+	"	-s 5 and 6\n"
+	"		|f'(w)|_inf <= eps*min(pos,neg)/l*|f'(w0)|_inf,\n"
+	"		where f is the primal function (default 0.01)\n"
 	"-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default 1)\n"
 	"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
 	"-v n: n-fold cross validation mode\n"
@@ -84,7 +89,7 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
 	char *argv[CMD_LEN/2];
 
 	// default values
-	param.solver_type = L2LOSS_SVM_DUAL;
+	param.solver_type = L2_L2LOSS_SVC_DUAL;
 	param.C = 1;
 	param.eps = INF; // see setting below
 	param.nr_weight = 0;
@@ -168,10 +173,12 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
 
 	if(param.eps == INF)
 	{
-		if(param.solver_type == L2_LR || param.solver_type == L2LOSS_SVM)
+		if(param.solver_type == L2_LR || param.solver_type == L2_L2LOSS_SVC)
 			param.eps = 0.01;
-		else if(param.solver_type == L2LOSS_SVM_DUAL || param.solver_type == L1LOSS_SVM_DUAL || param.solver_type == MCSVM_CS)
+		else if(param.solver_type == L2_L2LOSS_SVC_DUAL || param.solver_type == L2_L1LOSS_SVC_DUAL || param.solver_type == MCSVM_CS)
 			param.eps = 0.1;
+		else if(param.solver_type == L1_L2LOSS_SVC || param.solver_type == L1_LR)
+			param.eps = 0.01;
 	}
 	return 0;
 }
diff --git a/train.c b/train.c
index c905635..66194bf 100644
--- a/train.c
+++ b/train.c
@@ -17,10 +17,12 @@ void exit_with_help()
 	"options:\n"
 	"-s type : set type of solver (default 1)\n"
 	"	0 -- L2-regularized logistic regression\n"
-	"	1 -- L2-loss support vector machines (dual)\n"
-	"	2 -- L2-loss support vector machines (primal)\n"
-	"	3 -- L1-loss support vector machines (dual)\n"
-	"	4 -- multi-class support vector machines by Crammer and Singer\n"
+	"	1 -- L2-regularized L2-loss support vector classification (dual)\n"
+	"	2 -- L2-regularized L2-loss support vector classification (primal)\n"
+	"	3 -- L2-regularized L1-loss support vector classification (dual)\n"
+	"	4 -- multi-class support vector classification by Crammer and Singer\n"
+	"	5 -- L1-regularized L2-loss support vector classification\n"
+	"	6 -- L1-regularized logistic regression\n"
 	"-c cost : set the parameter C (default 1)\n"
 	"-e epsilon : set tolerance of termination criterion\n"
 	"	-s 0 and 2\n"
@@ -29,6 +31,9 @@ void exit_with_help()
 	"		positive/negative data (default 0.01)\n"
 	"	-s 1, 3, and 4\n"
 	"		Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
+	"	-s 5 and 6\n"
+	"		|f'(w)|_inf <= eps*min(pos,neg)/l*|f'(w0)|_inf,\n"
+	"		where f is the primal function (default 0.01)\n"
 	"-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default 1)\n"
 	"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
 	"-v n: n-fold cross validation mode\n"
@@ -132,7 +137,7 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode
 	int i;
 
 	// default values
-	param.solver_type = L2LOSS_SVM_DUAL;
+	param.solver_type = L2_L2LOSS_SVC_DUAL;
 	param.C = 1;
 	param.eps = INF; // see setting below
 	param.nr_weight = 0;
@@ -215,10 +220,12 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode
 
 	if(param.eps == INF)
 	{
-		if(param.solver_type == L2_LR || param.solver_type == L2LOSS_SVM)
+		if(param.solver_type == L2_LR || param.solver_type == L2_L2LOSS_SVC)
			param.eps = 0.01;
-		else if(param.solver_type == L2LOSS_SVM_DUAL || param.solver_type == L1LOSS_SVM_DUAL || param.solver_type == MCSVM_CS)
+		else if(param.solver_type == L2_L2LOSS_SVC_DUAL || param.solver_type == L2_L1LOSS_SVC_DUAL || param.solver_type == MCSVM_CS)
 			param.eps = 0.1;
+		else if(param.solver_type == L1_L2LOSS_SVC || param.solver_type == L1_LR)
+			param.eps = 0.01;
 	}
 }
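A note on the update rule shared by the two new solvers (this note is not part of the patch; it restates, in the notation of the comments above, what the "obtain Newton direction d" branch computes). At coordinate j, both solve_l1_l2_svc and solve_l1_lr minimize the second-order model

    min_d  |wj + d| + G*d + (1/2)*H*d^2,

where G and H are the gradient and the (approximate) Hessian of the loss term along coordinate j. Writing Gp = G+1 and Gn = G-1 for the right and left derivatives contributed by the |wj| term, the minimizer is

    d = -Gp/H   if Gp <= H*wj
    d = -Gn/H   if Gn >= H*wj
    d = -wj     otherwise,

the last case being the step that sets wj exactly to zero, which is how the solvers produce sparse w.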
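On the line search (again a reading note, not part of the patch): a step d is accepted once

    |wj+d| - |wj| + sigma*d^2 + loss(w + d*ej) - loss(w) <= 0,

i.e. once the objective decreases by at least sigma*d^2; otherwise d is halved, with d_old - d giving the correction applied incrementally to the maintained quantities b[] (in solve_l1_l2_svc) and exp_wTx[] (in solve_l1_lr). Before evaluating the loss, solve_l1_l2_svc tests the cheaper quantity appxcond = xi_sq[j]*d^2 + G_loss*d + |wj+d| - |wj| + sigma*d^2; since xi_sq[j] bounds half the curvature of the L2 loss along coordinate j, appxcond <= 0 already certifies sufficient decrease without touching b[]. solve_l1_lr applies the analogous shortcut through appxcond1/appxcond2 whenever all feature values are nonnegative (x_min >= 0).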
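For reference, a minimal driver showing how the two new solver types are reached through the public API in linear.h. This program is not part of the patch; it is a sketch assuming the interface at this revision (problem.y is an int array, predict() returns the integer label), with illustrative data, and would be compiled together with linear.cpp and tron.cpp in the usual way:

    #include <stdio.h>
    #include "linear.h"

    int main()
    {
    	/* two instances with three features, in liblinear's sparse row
    	   format: 1-based feature indices, each row ends with index = -1 */
    	feature_node x1[] = {{1,1.0},{3,1.0},{-1,0.0}};
    	feature_node x2[] = {{2,1.0},{3,-1.0},{-1,0.0}};
    	feature_node *rows[] = {x1,x2};
    	int labels[] = {+1,-1};

    	problem prob;
    	prob.l = 2;       /* number of instances */
    	prob.n = 3;       /* number of features */
    	prob.y = labels;
    	prob.x = rows;
    	prob.bias = -1;   /* < 0: no bias term */

    	parameter param;
    	param.solver_type = L1_L2LOSS_SVC; /* "-s 5"; use L1_LR for "-s 6" */
    	param.C = 1;
    	param.eps = 0.01; /* default tolerance for the L1-regularized solvers */
    	param.nr_weight = 0;               /* no per-class C adjustment */
    	param.weight_label = NULL;
    	param.weight = NULL;

    	const char *err = check_parameter(&prob, &param);
    	if(err) { fprintf(stderr, "Error: %s\n", err); return 1; }

    	model *m = train(&prob, &param);
    	printf("prediction for x1: %d\n", predict(m, x1));
    	destroy_model(m);
    	return 0;
    }

From the command-line tools the same solvers are selected with "train -s 5" and "train -s 6", using the stopping rule documented in the updated help text above.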