]> granicus.if.org Git - liblinear/commitdiff
L1-regularized L2-loss SVC
author eaudex <eaudex@16e7d947-dcc2-db11-b54a-0017319806e7>
Fri, 21 Aug 2009 13:40:52 +0000 (13:40 +0000)
committer eaudex <eaudex@16e7d947-dcc2-db11-b54a-0017319806e7>
Fri, 21 Aug 2009 13:40:52 +0000 (13:40 +0000)
L1-regularized Logistic Regression

linear.cpp
linear.h
matlab/train.c
train.c

index d1560e265302530b3a98d42d561b50387634e0fa..74978c3c4238c755c94f552889ee769718990f10 100644 (file)
@@ -6,14 +6,14 @@
 #include "linear.h"
 #include "tron.h"
 typedef signed char schar;
-template <class T> inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
+template <class T> static inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
 #ifndef min
-template <class T> inline T min(T x,T y) { return (x<y)?x:y; }
+template <class T> static inline T min(T x,T y) { return (x<y)?x:y; }
 #endif
 #ifndef max
-template <class T> inline T max(T x,T y) { return (x>y)?x:y; }
+template <class T> static inline T max(T x,T y) { return (x>y)?x:y; }
 #endif
-template <class S, class T> inline void clone(T*& dst, S* src, int n)
+template <class S, class T> static inline void clone(T*& dst, S* src, int n)
 {   
        dst = new T[n];
        memcpy((void *)dst,(void *)src,sizeof(T)*n);
@@ -696,7 +696,7 @@ void Solver_MCSVM_CS::Solve(double *w)
 //
 // solution will be put in w
 
-static void solve_linear_c_svc(
+static void solve_l2_l1l2_svc(
        const problem *prob, double *w, double eps, 
        double Cp, double Cn, int solver_type)
 {
@@ -717,10 +717,10 @@ static void solve_linear_c_svc(
        double PGmin_old = -INF;
        double PGmax_new, PGmin_new;
 
-       // default solver_type: L2LOSS_SVM_DUAL
+       // default solver_type: L2_L2LOSS_SVC_DUAL
        double diag_p = 0.5/Cp, diag_n = 0.5/Cn;
        double upper_bound_p = INF, upper_bound_n = INF;
-       if(solver_type == L1LOSS_SVM_DUAL)
+       if(solver_type == L2_L1LOSS_SVC_DUAL)
        {
                diag_p = 0; diag_n = 0;
                upper_bound_p = Cp; upper_bound_n = Cn;
@@ -834,9 +834,7 @@ static void solve_linear_c_svc(
 
                iter++;
                if(iter % 10 == 0)
-               {
                        info(".");
-               }
 
                if(PGmax_new - PGmin_new <= eps)
                {
@@ -887,9 +885,618 @@ static void solve_linear_c_svc(
        delete [] index;
 }
 
+// A coordinate descent algorithm for 
+// L1-regularized L2-loss support vector classification
+//
+//  min_w \sum |wj| + C \sum max(0, 1-yi w^T xi)^2,
+//
+// Given: 
+// x, y, Cp, Cn
+// eps is the stopping tolerance
+//
+// solution will be put in w
+
+static void solve_l1_l2_svc( // coordinate descent over features with a backtracking line search and shrinking; the solution is written into w
+       problem *prob_col, double *w, double eps, // prob_col: column-format problem (x[j] lists the entries of feature j, each entry's index is an instance id); non-const because x->value is temporarily scaled by y. eps: stopping tolerance relative to the first pass's max |G|
+       double Cp, double Cn) // Cp: penalty for instances with y > 0; Cn: penalty for all other instances
+{
+       int l = prob_col->l;
+       int w_size = prob_col->n;
+       int i, s, iter = 0;
+       int max_iter = 1000; // cap on the number of outer passes
+       int active_size = w_size; // features index[0..active_size) are still being optimized
+       int max_num_linesearch = 20; // max trial steps per coordinate update
+
+       double sigma = 0.01; // coefficient of the d*d term in the line-search condition
+       double d, G_loss, G, H;
+       double Gmax_old = INF; // max |G| of the previous pass; INF keeps the shrinking test from firing on the first pass
+       double Gmax_new;
+       double Gmax_init;
+       double d_old, d_diff;
+       double loss_old, loss_new;
+       double appxcond, cond;
+
+       int *index = new int[w_size];
+       schar *y = new schar[l]; // class slot per instance: 1 if the label is > 0, else 0 (used to index C[])
+       double *b = new double[l]; // b = 1-ywTx
+       double *xi_sq = new double[w_size]; // per feature: sum over its entries of C*value^2
+       feature_node *x;
+
+       // To support weights for instances,
+       // replace C[y[ind]] with C[i].
+       double C[2] = {Cn,Cp};
+
+       for(i=0; i<l; i++)
+       {
+               b[i] = 1; // w starts at 0, so 1-ywTx is 1
+               if(prob_col->y[i] > 0)
+                       y[i] = 1;
+               else
+                       y[i] = 0;
+       }
+
+       for(i=0; i<w_size; i++)
+       {
+               w[i] = 0;
+               index[i] = i;
+               xi_sq[i] = 0;
+               x = prob_col->x[i];
+               while(x->index != -1)
+               {
+                       int ind = x->index;
+                       double val = x->value;
+                       x->value *= prob_col->y[ind]; // x->value stores yj*xji
+                       xi_sq[i] += C[y[ind]]*val*val;
+                       x++;
+               }
+       }
+
+       while(iter < max_iter)
+       {
+               Gmax_new  = 0;
+
+               for(i=0; i<active_size; i++) // shuffle the order of the active features
+               {
+                       int j = i+rand()%(active_size-i);
+                       swap(index[i], index[j]);
+               }
+
+               for(s=0; s<active_size; s++)
+               {
+                       i = index[s];
+                       G_loss = 0;
+                       H = 0;
+
+                       x = prob_col->x[i];
+                       while(x->index != -1)
+                       {
+                               int ind = x->index;
+                               if(b[ind] > 0) // only instances with positive b contribute to the loss and its derivatives
+                               {
+                                       double val = x->value;
+                                       double tmp = C[y[ind]]*val;
+                                       G_loss -= tmp*b[ind];
+                                       H += tmp*val;
+                               }
+                               x++;
+                       }
+                       G_loss *= 2; // derivative of the loss term w.r.t. w[i]
+
+                       G = G_loss;
+                       H *= 2; // second-derivative term built from the currently contributing instances
+
+                       double Gp = G+1; // loss derivative plus the +1 slope of |w[i]|
+                       double Gn = G-1; // loss derivative plus the -1 slope of |w[i]|
+                       // shrinking
+                       if(w[i] == 0)
+                       {
+                               if(Gp>Gmax_old/l && Gn<-Gmax_old/l) // both one-sided derivatives clear zero by the margin Gmax_old/l: drop feature i from the active set
+                               {
+                                       active_size--;
+                                       swap(index[s], index[active_size]);
+                                       s--; // revisit slot s, which now holds the feature swapped in
+                                       continue;
+                               }
+                       }
+
+                       // obtain Newton direction d
+                       if(Gp <= H*w[i])
+                       {
+                               G = Gp;
+                               d = -Gp/H;
+                       }
+                       else if(Gn >= H*w[i])
+                       {
+                               G = Gn;
+                               d = -Gn/H;
+                       }
+                       else
+                       {
+                               G = 0;
+                               d = -w[i]; // neither case applies: step w[i] to exactly zero
+                       }
+
+                       Gmax_new = max(Gmax_new, fabs(G)); // largest |G| of this pass feeds the stopping test
+
+                       if(fabs(d) < 1.0e-12) // negligible step: leave w[i] and b[] untouched
+                               continue;
+
+                       d_old = 0; // step currently reflected in b[]; nothing applied yet
+                       int num_linesearch;
+                       for(num_linesearch=0; num_linesearch < max_num_linesearch; num_linesearch++)
+                       {
+                               d_diff = d_old - d; // b[] reflects d_old; adding d_diff*value makes it reflect d
+                               cond = fabs(w[i]+d)-fabs(w[i]) + sigma*d*d; // change in |w[i]| plus the sigma*d*d term; the loss change is added below
+
+                               appxcond = xi_sq[i]*d*d + G_loss*d + cond; // approximate check from precomputed quantities; passing it skips the exact loss evaluation
+                               if(appxcond <= 0)
+                               {
+                                       x = prob_col->x[i];
+                                       while(x->index != -1)
+                                       {
+                                               b[x->index] += d_diff*x->value; // accept d: bring b[] in line with it
+                                               x++;
+                                       }
+                                       break;
+                               }
+
+                               if(num_linesearch == 0) // first trial also measures the loss before the step (loss_old)
+                               {
+                                       loss_old = 0;
+                                       loss_new = 0;
+                                       x = prob_col->x[i];
+                                       while(x->index != -1)
+                                       {
+                                               int ind = x->index;
+                                               if(b[ind] > 0)
+                                                       loss_old += C[y[ind]]*b[ind]*b[ind];
+                                               double b_new = b[ind] + d_diff*x->value;
+                                               b[ind] = b_new;
+                                               if(b_new > 0)
+                                                       loss_new += C[y[ind]]*b_new*b_new;
+                                               x++;
+                                       }
+                               }
+                               else 
+                               {
+                                       loss_new = 0;
+                                       x = prob_col->x[i];
+                                       while(x->index != -1)
+                                       {
+                                               int ind = x->index;
+                                               double b_new = b[ind] + d_diff*x->value;
+                                               b[ind] = b_new;
+                                               if(b_new > 0)
+                                                       loss_new += C[y[ind]]*b_new*b_new;
+                                               x++;
+                                       }
+                               }
+
+                               cond = cond + loss_new - loss_old; // add the exact loss change over this feature's instances
+                               if(cond <= 0)
+                                       break;
+                               else
+                               {
+                                       d_old = d; // b[] now reflects the rejected d
+                                       d *= 0.5; // halve the step and retry
+                               }
+                       }
+
+                       w[i] += d;
+
+                       // recompute b[] if line search takes too many steps
+                       if(num_linesearch >= max_num_linesearch)
+                       {
+                               info("#");
+                               for(int j=0; j<l; j++)
+                                       b[j] = 1;
+
+                               for(int j=0; j<w_size; j++)
+                               {
+                                       if(w[j]==0) continue;
+                                       x = prob_col->x[j];
+                                       while(x->index != -1)
+                                       {
+                                               b[x->index] -= w[j]*x->value; // x->value already holds y*x, so this rebuilds 1-ywTx
+                                               x++;
+                                       }
+                               }
+                       }
+               }
+
+               if(iter == 0)
+                       Gmax_init = Gmax_new; // reference magnitude for the relative stopping test
+               iter++;
+               if(iter % 10 == 0)
+                       info(".");
+
+               if(Gmax_new <= eps*Gmax_init)
+               {
+                       if(active_size == w_size) // tolerance met with every feature active: done
+                               break;
+                       else
+                       {
+                               active_size = w_size; // tolerance met only on the shrunken set: reactivate all features
+                               info("*");
+                               Gmax_old = INF; // and run the next pass with shrinking disabled
+                               continue;
+                       }
+               }
+
+               Gmax_old = Gmax_new;
+       }
+
+       info("\noptimization finished, #iter = %d\n", iter);
+       if(iter >= max_iter)
+               info("\nWARNING: reaching max number of iterations\n");
+
+       // calculate objective value
+
+       double v = 0;
+       int nnz = 0;
+       for(i=0; i<w_size; i++)
+       {
+               x = prob_col->x[i];
+               while(x->index != -1)
+               {
+                       x->value *= prob_col->y[x->index]; // restore x->value
+                       x++;
+               }
+               if(w[i] != 0)
+               {
+                       v += fabs(w[i]);
+                       nnz++;
+               }
+       }
+       for(i=0; i<l; i++)
+               if(b[i] > 0)
+                       v += C[y[i]]*b[i]*b[i]; // loss term: C * max(0, b)^2
+
+       info("Objective value = %lf\n", v);
+       info("#nonzeros/#features = %d/%d\n", nnz, w_size);
+
+       delete [] index;
+       delete [] y;
+       delete [] b;
+       delete [] xi_sq;
+}
+
+// A coordinate descent algorithm for 
+// L1-regularized logistic regression problems
+//
+//  min_w \sum |wj| + C \sum log(1+exp(-yi w^T xi)),
+//
+// Given: 
+// x, y, Cp, Cn
+// eps is the stopping tolerance
+//
+// solution will be put in w
+
+static void solve_l1_lr( // coordinate descent over features with a backtracking line search and shrinking; the solution is written into w
+       const problem *prob_col, double *w, double eps, // prob_col: column-format problem (x[j] lists the entries of feature j, each entry's index is an instance id). eps: stopping tolerance relative to the first pass's max |G|
+       double Cp, double Cn) // Cp: penalty for instances with y > 0; Cn: penalty for all other instances
+{
+       int l = prob_col->l;
+       int w_size = prob_col->n;
+       int i, s, iter = 0;
+       int max_iter = 1000; // cap on the number of outer passes
+       int active_size = w_size; // features index[0..active_size) are still being optimized
+       int max_num_linesearch = 20; // max trial steps per coordinate update
+
+       double x_min = 0; // smallest feature value seen; stays 0 when no value is negative
+       double sigma = 0.01; // coefficient of the d*d term in the line-search condition
+       double d, G, H;
+       double Gmax_old = INF; // max |G| of the previous pass; INF keeps the shrinking test from firing on the first pass
+       double Gmax_new;
+       double Gmax_init;
+       double sum1, appxcond1;
+       double sum2, appxcond2;
+       double cond;
+
+       int *index = new int[w_size];
+       schar *y = new schar[l]; // class slot per instance: 1 if the label is > 0, else 0 (used to index C[])
+       double *exp_wTx = new double[l]; // exp(w^T x) per instance, maintained incrementally
+       double *exp_wTx_new = new double[l]; // scratch: trial values, indexed by position within the current feature's column
+       double *xi_max = new double[w_size]; // per feature: largest value, floored at 0
+       double *C_sum = new double[w_size]; // per feature: sum of C over its entries
+       double *xineg_sum = new double[w_size]; // per feature: sum of C*value over entries whose label is not > 0
+       double *xipos_sum = new double[w_size]; // per feature: sum of C*value over entries whose label is > 0
+       feature_node *x;
+
+       // To support weights for instances,
+       // replace C[y[ind]] with C[i].
+       double C[2] = {Cn,Cp};
+
+       for(i=0; i<l; i++)
+       {
+               exp_wTx[i] = 1; // w starts at 0, so exp(w^T x) is 1
+               if(prob_col->y[i] > 0)
+                       y[i] = 1;
+               else
+                       y[i] = 0;
+       }
+       for(i=0; i<w_size; i++)
+       {
+               w[i] = 0;
+               index[i] = i;
+               xi_max[i] = 0;
+               C_sum[i] = 0;
+               xineg_sum[i] = 0;
+               xipos_sum[i] = 0;
+               x = prob_col->x[i];
+               while(x->index != -1)
+               {
+                       int ind = x->index;
+                       double val = x->value;
+                       x_min = min(x_min, val);
+                       xi_max[i] = max(xi_max[i], val);
+                       C_sum[i] += C[y[ind]];
+                       if(y[ind] == 0)
+                               xineg_sum[i] += C[y[ind]]*val;
+                       else
+                               xipos_sum[i] += C[y[ind]]*val;
+                       x++;
+               }
+       }
+
+       while(iter < max_iter)
+       {
+               Gmax_new = 0;
+
+               for(i=0; i<active_size; i++) // shuffle the order of the active features
+               {
+                       int j = i+rand()%(active_size-i);
+                       swap(index[i], index[j]);
+               }
+
+               for(s=0; s<active_size; s++)
+               {
+                       i = index[s];
+                       sum1 = 0;
+                       sum2 = 0;
+                       H = 0;
+
+                       x = prob_col->x[i];
+                       while(x->index != -1)
+                       {
+                               int ind = x->index;
+                               double exp_wTxind = exp_wTx[ind];
+                               double tmp1 = x->value/(1+exp_wTxind);
+                               double tmp2 = C[y[ind]]*tmp1;
+                               double tmp3 = tmp2*exp_wTxind;
+                               sum2 += tmp2; // accumulates C*x/(1+e), with e = exp(w^T x)
+                               sum1 += tmp3; // accumulates C*x*e/(1+e)
+                               H += tmp1*tmp3; // accumulates C*x^2*e/(1+e)^2
+                               x++;
+                       }
+
+                       G = -sum2 + xineg_sum[i]; // derivative of the loss term w.r.t. w[i]
+
+                       double Gp = G+1; // loss derivative plus the +1 slope of |w[i]|
+                       double Gn = G-1; // loss derivative plus the -1 slope of |w[i]|
+                       // shrinking
+                       if(w[i] == 0)
+                       {
+                               if(Gp>Gmax_old/l && Gn<-Gmax_old/l) // both one-sided derivatives clear zero by the margin Gmax_old/l: drop feature i from the active set
+                               {
+                                       active_size--;
+                                       swap(index[s], index[active_size]);
+                                       s--; // revisit slot s, which now holds the feature swapped in
+                                       continue;
+                               }
+                       }
+
+                       // obtain Newton direction d
+                       if(Gp <= H*w[i])
+                       {
+                               G = Gp;
+                               d = -Gp/H;
+                       }
+                       else if(Gn >= H*w[i])
+                       {
+                               G = Gn;
+                               d = -Gn/H;
+                       }
+                       else
+                       {
+                               G = 0;
+                               d = -w[i]; // neither case applies: step w[i] to exactly zero
+                       }
+
+                       Gmax_new = max(Gmax_new, fabs(G)); // largest |G| of this pass feeds the stopping test
+
+                       if(fabs(d) < 1.0e-12) // negligible step: leave w[i] and exp_wTx[] untouched
+                               continue;
+
+                       d = min(max(d,-10.0),10.0); // clamp the step to [-10, 10]
+
+                       int num_linesearch;
+                       for(num_linesearch=0; num_linesearch < max_num_linesearch; num_linesearch++)
+                       {
+                               cond = fabs(w[i]+d)-fabs(w[i]) + sigma*d*d; // change in |w[i]| plus the sigma*d*d term; the loss change is added below
+
+                               if(x_min >= 0) // the shortcut checks below are tried only when no feature value is negative
+                               {
+                                       double tmp = exp(d*xi_max[i]);
+                                       appxcond1 = log(1+sum1*(tmp-1)/xi_max[i]/C_sum[i])*C_sum[i] + cond - d*xipos_sum[i];
+                                       appxcond2 = log(1+sum2*(1/tmp-1)/xi_max[i]/C_sum[i])*C_sum[i] + cond + d*xineg_sum[i];
+                                       if(min(appxcond1,appxcond2) <= 0) // either approximate check passing accepts d without the exact evaluation
+                                       {
+                                               x = prob_col->x[i];
+                                               while(x->index != -1)
+                                               {
+                                                       exp_wTx[x->index] *= exp(d*x->value); // fold the accepted step into exp(w^T x)
+                                                       x++;
+                                               }
+                                               break;
+                                       }
+                               }
+
+                               cond += d*xineg_sum[i];
+
+                               int j = 0;
+                               x = prob_col->x[i];
+                               while(x->index != -1)
+                               {
+                                       int ind = x->index;
+                                       double exp_dx = exp(d*x->value);
+                                       exp_wTx_new[j] = exp_wTx[ind]*exp_dx; // trial exp(w^T x) for this entry, kept in case d is accepted
+                                       cond += C[y[ind]]*log((1+exp_wTx_new[j])/(exp_dx+exp_wTx_new[j]));
+                                       x++; j++;
+                               }
+
+                               if(cond <= 0)
+                               {
+                                       int j = 0;
+                                       x = prob_col->x[i];
+                                       while(x->index != -1)
+                                       {
+                                               int ind = x->index;
+                                               exp_wTx[ind] = exp_wTx_new[j]; // accept d: commit the trial values
+                                               x++; j++;
+                                       }
+                                       break;
+                               }
+                               else
+                                       d *= 0.5; // reject: halve the step and retry
+                       }
+
+                       w[i] += d;
+
+                       // recompute exp_wTx[] if line search takes too many steps
+                       if(num_linesearch >= max_num_linesearch)
+                       {
+                               info("#");
+                               for(int j=0; j<l; j++)
+                                       exp_wTx[j] = 0; // reused as an accumulator for w^T x before exponentiating
+
+                               for(int j=0; j<w_size; j++)
+                               {
+                                       if(w[j]==0) continue;
+                                       x = prob_col->x[j];
+                                       while(x->index != -1)
+                                       {
+                                               exp_wTx[x->index] += w[j]*x->value;
+                                               x++;
+                                       }
+                               }
+
+                               for(int j=0; j<l; j++)
+                                       exp_wTx[j] = exp(exp_wTx[j]);
+                       }
+               }
+
+               if(iter == 0)
+                       Gmax_init = Gmax_new; // reference magnitude for the relative stopping test
+               iter++;
+               if(iter % 10 == 0)
+                       info(".");
+
+               if(Gmax_new <= eps*Gmax_init)
+               {
+                       if(active_size == w_size) // tolerance met with every feature active: done
+                               break;
+                       else
+                       {
+                               active_size = w_size; // tolerance met only on the shrunken set: reactivate all features
+                               info("*");
+                               Gmax_old = INF; // and run the next pass with shrinking disabled
+                               continue;
+                       }
+               }
+
+               Gmax_old = Gmax_new;
+       }
+
+       info("\noptimization finished, #iter = %d\n", iter);
+       if(iter >= max_iter)
+               info("\nWARNING: reaching max number of iterations\n");
+
+       // calculate objective value
+       
+       double v = 0;
+       int nnz = 0;
+       for(i=0; i<w_size; i++)
+               if(w[i] != 0)
+               {
+                       v += fabs(w[i]);
+                       nnz++;
+               }
+       for(i=0; i<l; i++)
+               if(y[i] == 1)
+                       v += C[y[i]]*log(1+1/exp_wTx[i]); // label > 0: log(1+exp(-w^T x))
+               else
+                       v += C[y[i]]*log(1+exp_wTx[i]); // otherwise: log(1+exp(w^T x))
+
+       info("Objective value = %lf\n", v);
+       info("#nonzeros/#features = %d/%d\n", nnz, w_size);
+
+       delete [] index;
+       delete [] y;
+       delete [] exp_wTx;
+       delete [] exp_wTx_new;
+       delete [] xi_max;
+       delete [] C_sum;
+       delete [] xineg_sum;
+       delete [] xipos_sum;
+}
+
+// transpose matrix X from row format to column format
+static void transpose(const problem *prob, feature_node *x_space, problem *prob_col)
+{
+       int i;
+       int l = prob->l;
+       int n = prob->n;
+       int nnz = 0;
+       int *col_ptr = new int[n+1];
+       prob_col->l = l;
+       prob_col->n = n;
+       prob_col->y = new int[l];
+       prob_col->x = new feature_node*[n];
+
+       for(i=0; i<l; i++)
+               prob_col->y[i] = prob->y[i];
+
+       for(i=0; i<n+1; i++)
+               col_ptr[i] = 0;
+       for(i=0; i<l; i++)
+       {
+               feature_node *x = prob->x[i];
+               while(x->index != -1)
+               {
+                       nnz++;
+                       col_ptr[x->index]++;
+                       x++;
+               }
+       }
+       for(i=1; i<n+1; i++)
+               col_ptr[i] += col_ptr[i-1] + 1;
+
+       x_space = new feature_node[nnz+n];
+       for(i=0; i<n; i++)
+               prob_col->x[i] = &x_space[col_ptr[i]];
+
+       for(i=0; i<l; i++)
+       {
+               feature_node *x = prob->x[i];
+               while(x->index != -1)
+               {
+                       int ind = x->index-1;
+                       x_space[col_ptr[ind]].index = i;
+                       x_space[col_ptr[ind]].value = x->value;
+                       col_ptr[ind]++;
+                       x++;
+               }
+       }
+       for(i=0; i<n; i++)
+               x_space[col_ptr[i]].index = -1;
+
+       delete [] col_ptr;
+}
+
 // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
 // perm, length l, must be allocated before calling this subroutine
-void group_classes(const problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
+static void group_classes(const problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
 {
        int l = prob->l;
        int max_nr_class = 16;
@@ -946,7 +1553,7 @@ void group_classes(const problem *prob, int *nr_class_ret, int **label_ret, int
        free(data_label);
 }
 
-void train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
+static void train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
 {
        double eps=param->eps;
        int pos = 0;
@@ -968,7 +1575,7 @@ void train_one(const problem *prob, const parameter *param, double *w, double Cp
                        delete fun_obj;
                        break;
                }
-               case L2LOSS_SVM:
+               case L2_L2LOSS_SVC:
                {
                        fun_obj=new l2loss_svm_fun(prob, Cp, Cn);
                        TRON tron_obj(fun_obj, eps*min(pos,neg)/prob->l);
@@ -977,12 +1584,34 @@ void train_one(const problem *prob, const parameter *param, double *w, double Cp
                        delete fun_obj;
                        break;
                }
-               case L2LOSS_SVM_DUAL:
-                       solve_linear_c_svc(prob, w, eps, Cp, Cn, L2LOSS_SVM_DUAL);
+               case L2_L2LOSS_SVC_DUAL:
+                       solve_l2_l1l2_svc(prob, w, eps, Cp, Cn, L2_L2LOSS_SVC_DUAL);
+                       break;
+               case L2_L1LOSS_SVC_DUAL:
+                       solve_l2_l1l2_svc(prob, w, eps, Cp, Cn, L2_L1LOSS_SVC_DUAL);
+                       break;
+               case L1_L2LOSS_SVC:
+               {
+                       problem prob_col;
+                       feature_node *x_space = NULL;
+                       transpose(prob, x_space ,&prob_col);
+                       solve_l1_l2_svc(&prob_col, w, eps*min(pos,neg)/prob->l, Cp, Cn);
+                       delete [] prob_col.y;
+                       delete [] prob_col.x;
+                       delete [] x_space;
                        break;
-               case L1LOSS_SVM_DUAL:
-                       solve_linear_c_svc(prob, w, eps, Cp, Cn, L1LOSS_SVM_DUAL);
+               }
+               case L1_LR:
+               {
+                       problem prob_col;
+                       feature_node *x_space = NULL;
+                       transpose(prob, x_space ,&prob_col);
+                       solve_l1_lr(&prob_col, w, eps*min(pos,neg)/prob->l, Cp, Cn);
+                       delete [] prob_col.y;
+                       delete [] prob_col.x;
+                       delete [] x_space;
                        break;
+               }
                default:
                        fprintf(stderr, "Error: unknown solver_type\n");
                        break;
@@ -1123,9 +1752,9 @@ void destroy_model(struct model *model_)
        free(model_);
 }
 
-const char *solver_type_table[]=
+static const char *solver_type_table[]=
 {
-       "L2_LR", "L2LOSS_SVM_DUAL", "L2LOSS_SVM","L1LOSS_SVM_DUAL","MCSVM_CS", NULL
+       "L2_LR", "L2_L2LOSS_SVC_DUAL", "L2_L2LOSS_SVC","L2_L1LOSS_SVC_DUAL","MCSVM_CS", "L1_L2LOSS_SVC","L1_LR", NULL
 };
 
 int save_model(const char *model_file_name, const struct model *model_)
@@ -1372,10 +2001,12 @@ const char *check_parameter(const problem *prob, const parameter *param)
                return "C <= 0";
 
        if(param->solver_type != L2_LR
-               && param->solver_type != L2LOSS_SVM_DUAL
-               && param->solver_type != L2LOSS_SVM
-               && param->solver_type != L1LOSS_SVM_DUAL
-               && param->solver_type != MCSVM_CS)
+               && param->solver_type != L2_L2LOSS_SVC_DUAL
+               && param->solver_type != L2_L2LOSS_SVC
+               && param->solver_type != L2_L1LOSS_SVC_DUAL
+               && param->solver_type != MCSVM_CS
+               && param->solver_type != L1_L2LOSS_SVC
+               && param->solver_type != L1_LR)
                return "unknown solver type";
 
        return NULL;
index f36d65af717375f42b58b5cb4473973783d19e37..10be96bbce5a9e726671976cb5958950f4408069 100644 (file)
--- a/linear.h
+++ b/linear.h
@@ -19,7 +19,7 @@ struct problem
        double bias;            /* < 0 if no bias term */  
 };
 
-enum { L2_LR, L2LOSS_SVM_DUAL, L2LOSS_SVM, L1LOSS_SVM_DUAL, MCSVM_CS }; /* solver_type */
+enum { L2_LR, L2_L2LOSS_SVC_DUAL, L2_L2LOSS_SVC, L2_L1LOSS_SVC_DUAL, MCSVM_CS, L1_L2LOSS_SVC, L1_LR }; /* solver_type */
 
 struct parameter
 {
index 1ffb6ee7e3ed3a9a4f40662b6c98b15ff2ee23b7..f8ad784787f7f289d5a1ffe920a6cd37882c6633 100644 (file)
@@ -27,10 +27,12 @@ void exit_with_help()
        "liblinear_options:\n"
        "-s type : set type of solver (default 1)\n"
        "       0 -- L2-regularized logistic regression\n"
-       "       1 -- L2-loss support vector machines (dual)\n"  
-       "       2 -- L2-loss support vector machines (primal)\n"
-       "       3 -- L1-loss support vector machines (dual)\n"
-       "       4 -- multi-class support vector machines by Crammer and Singer\n"
+       "       1 -- L2-regularized L2-loss support vector classification (dual)\n"     
+       "       2 -- L2-regularized L2-loss support vector classification (primal)\n"
+       "       3 -- L2-regularized L1-loss support vector classification (dual)\n"
+       "       4 -- multi-class support vector classification by Crammer and Singer\n"
+       "       5 -- L1-regularized L2-loss support vector classification\n"
+       "       6 -- L1-regularized logistic regression\n"
        "-c cost : set the parameter C (default 1)\n"
        "-e epsilon : set tolerance of termination criterion\n"
        "       -s 0 and 2\n" 
@@ -38,6 +40,9 @@ void exit_with_help()
        "               where f is the primal function, (default 0.01)\n"
        "       -s 1, 3, and 4\n"
        "               Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
+       "       -s 5 and 6\n"
+       "               |f'(w)|_inf <= eps*min(pos,neg)/l*|f'(w0)|_inf,\n"
+       "               where f is the primal function (default 0.01)\n"
        "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default 1)\n"
        "-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
        "-v n: n-fold cross validation mode\n"
@@ -84,7 +89,7 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
        char *argv[CMD_LEN/2];
 
        // default values
-       param.solver_type = L2LOSS_SVM_DUAL;
+       param.solver_type = L2_L2LOSS_SVC_DUAL;
        param.C = 1;
        param.eps = INF; // see setting below
        param.nr_weight = 0;
@@ -168,10 +173,12 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
 
        if(param.eps == INF)
        {
-               if(param.solver_type == L2_LR || param.solver_type == L2LOSS_SVM)
+               if(param.solver_type == L2_LR || param.solver_type == L2_L2LOSS_SVC)
                        param.eps = 0.01;
-               else if(param.solver_type == L2LOSS_SVM_DUAL || param.solver_type == L1LOSS_SVM_DUAL || param.solver_type == MCSVM_CS)
+               else if(param.solver_type == L2_L2LOSS_SVC_DUAL || param.solver_type == L2_L1LOSS_SVC_DUAL || param.solver_type == MCSVM_CS)
                        param.eps = 0.1;
+               else if(param.solver_type == L1_L2LOSS_SVC || param.solver_type == L1_LR)
+                       param.eps = 0.01;
        }
        return 0;
 }
diff --git a/train.c b/train.c
index c905635c0de0bc219349cd871dd140c84b793bdb..66194bfc20dab530525edaf3ef0a9098050715bc 100644 (file)
--- a/train.c
+++ b/train.c
@@ -17,10 +17,12 @@ void exit_with_help()
        "options:\n"
        "-s type : set type of solver (default 1)\n"
        "       0 -- L2-regularized logistic regression\n"
-       "       1 -- L2-loss support vector machines (dual)\n"  
-       "       2 -- L2-loss support vector machines (primal)\n"
-       "       3 -- L1-loss support vector machines (dual)\n"
-       "       4 -- multi-class support vector machines by Crammer and Singer\n"
+       "       1 -- L2-regularized L2-loss support vector classification (dual)\n"     
+       "       2 -- L2-regularized L2-loss support vector classification (primal)\n"
+       "       3 -- L2-regularized L1-loss support vector classification (dual)\n"
+       "       4 -- multi-class support vector classification by Crammer and Singer\n"
+       "       5 -- L1-regularized L2-loss support vector classification\n"
+       "       6 -- L1-regularized logistic regression\n"
        "-c cost : set the parameter C (default 1)\n"
        "-e epsilon : set tolerance of termination criterion\n"
        "       -s 0 and 2\n" 
@@ -29,6 +31,9 @@ void exit_with_help()
        "               positive/negative data (default 0.01)\n"
        "       -s 1, 3, and 4\n"
        "               Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
+       "       -s 5 and 6\n"
+       "               |f'(w)|_inf <= eps*min(pos,neg)/l*|f'(w0)|_inf,\n"
+       "               where f is the primal function (default 0.01)\n"
        "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default 1)\n"
        "-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
        "-v n: n-fold cross validation mode\n"
@@ -132,7 +137,7 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode
        int i;
 
        // default values
-       param.solver_type = L2LOSS_SVM_DUAL;
+       param.solver_type = L2_L2LOSS_SVC_DUAL;
        param.C = 1;
        param.eps = INF; // see setting below
        param.nr_weight = 0;
@@ -215,10 +220,12 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode
 
        if(param.eps == INF)
        {
-               if(param.solver_type == L2_LR || param.solver_type == L2LOSS_SVM)
+               if(param.solver_type == L2_LR || param.solver_type == L2_L2LOSS_SVC)
                        param.eps = 0.01;
-               else if(param.solver_type == L2LOSS_SVM_DUAL || param.solver_type == L1LOSS_SVM_DUAL || param.solver_type == MCSVM_CS)
+               else if(param.solver_type == L2_L2LOSS_SVC_DUAL || param.solver_type == L2_L1LOSS_SVC_DUAL || param.solver_type == MCSVM_CS)
                        param.eps = 0.1;
+               else if(param.solver_type == L1_L2LOSS_SVC || param.solver_type == L1_LR)
+                       param.eps = 0.01;
        }
 }