]> granicus.if.org Git - liblinear/commitdiff
Change the C access mechanism for all solvers
authorrofu <rofu@16e7d947-dcc2-db11-b54a-0017319806e7>
Sun, 20 Dec 2009 05:20:37 +0000 (05:20 +0000)
committerrofu <rofu@16e7d947-dcc2-db11-b54a-0017319806e7>
Sun, 20 Dec 2009 05:20:37 +0000 (05:20 +0000)
by adding a macro GETI(i).
This makes it easier to support instance weights.

linear.cpp

index c06ad3ef2a56f1da9c6b136fc1945ccdc4a0e52a..33fcae477601d7cd19ef98730e0f1d50cfe9ad75 100644 (file)
@@ -391,6 +391,10 @@ void l2r_l2_svc_fun::subXTv(double *v, double *XTv)
 // eps is the stopping tolerance
 //
 // solution will be put in w
+
+#define GETI(i) (prob->y[i])
+// To support weights for instances, use GETI(i) (i)
+
 class Solver_MCSVM_CS
 {
        public:
@@ -408,7 +412,7 @@ class Solver_MCSVM_CS
                const problem *prob;
 };
 
-Solver_MCSVM_CS::Solver_MCSVM_CS(const problem *prob, int nr_class, double *C, double eps, int max_iter)
+Solver_MCSVM_CS::Solver_MCSVM_CS(const problem *prob, int nr_class, double *weighted_C, double eps, int max_iter)
 {
        this->w_size = prob->n;
        this->l = prob->l;
@@ -416,7 +420,7 @@ Solver_MCSVM_CS::Solver_MCSVM_CS(const problem *prob, int nr_class, double *C, d
        this->eps = eps;
        this->max_iter = max_iter;
        this->prob = prob;
-       this->C = C;
+       this->C = weighted_C;
        this->B = new double[nr_class];
        this->G = new double[nr_class];
 }
@@ -465,7 +469,7 @@ bool Solver_MCSVM_CS::be_shrunk(int i, int m, int yi, double alpha_i, double min
 {
        double bound = 0;
        if(m == yi)
-               bound = C[prob->y[i]];
+               bound = C[GETI(i)];
        if(alpha_i == bound && G[m] < minG)
                return true;
        return false;
@@ -549,7 +553,7 @@ void Solver_MCSVM_CS::Solve(double *w)
                                                maxG = G[m];
                                }
                                if(y_index[i] < active_size_i[i])
-                                       if(alpha_i[prob->y[i]] < C[prob->y[i]] && G[y_index[i]] < minG)
+                                       if(alpha_i[prob->y[i]] < C[GETI(i)] && G[y_index[i]] < minG)
                                                minG = G[y_index[i]];
 
                                for(m=0;m<active_size_i[i];m++)
@@ -591,7 +595,7 @@ void Solver_MCSVM_CS::Solve(double *w)
                                for(m=0;m<active_size_i[i];m++)
                                        B[m] = G[m] - Ai*alpha_i[alpha_index_i[m]] ;
 
-                               solve_sub_problem(Ai, y_index[i], C[prob->y[i]], active_size_i[i], alpha_new);
+                               solve_sub_problem(Ai, y_index[i], C[GETI(i)], active_size_i[i], alpha_new);
                                int nz_d = 0;
                                for(m=0;m<active_size_i[i];m++)
                                {
@@ -696,6 +700,10 @@ void Solver_MCSVM_CS::Solve(double *w)
 //
 // solution will be put in w
 
+#undef GETI
+#define GETI(i) (y[i]+1)
+// To support weights for instances, use GETI(i) (i)
+
 static void solve_l2r_l1l2_svc(
        const problem *prob, double *w, double eps, 
        double Cp, double Cn, int solver_type)
@@ -718,12 +726,14 @@ static void solve_l2r_l1l2_svc(
        double PGmax_new, PGmin_new;
 
        // default solver_type: L2R_L2LOSS_SVC_DUAL
-       double diag_p = 0.5/Cp, diag_n = 0.5/Cn;
-       double upper_bound_p = INF, upper_bound_n = INF;
+       double diag[3] = {0.5/Cn, 0, 0.5/Cp};
+       double upper_bound[3] = {INF, 0, INF};
        if(solver_type == L2R_L1LOSS_SVC_DUAL)
        {
-               diag_p = 0; diag_n = 0;
-               upper_bound_p = Cp; upper_bound_n = Cn;
+               diag[0] = 0;
+               diag[2] = 0;
+               upper_bound[0] = Cn;
+               upper_bound[2] = Cp;
        }
 
        for(i=0; i<w_size; i++)
@@ -734,13 +744,12 @@ static void solve_l2r_l1l2_svc(
                if(prob->y[i] > 0)
                {
                        y[i] = +1; 
-                       QD[i] = diag_p;
                }
                else
                {
                        y[i] = -1;
-                       QD[i] = diag_n;
                }
+               QD[i] = diag[GETI(i)];
 
                feature_node *xi = prob->x[i];
                while (xi->index != -1)
@@ -776,16 +785,8 @@ static void solve_l2r_l1l2_svc(
                        }
                        G = G*yi-1;
 
-                       if(yi == 1)
-                       {
-                               C = upper_bound_p; 
-                               G += alpha[i]*diag_p; 
-                       }
-                       else 
-                       {
-                               C = upper_bound_n;
-                               G += alpha[i]*diag_n; 
-                       }
+                       C = upper_bound[GETI(i)];
+                       G += alpha[i]*diag[GETI(i)];
 
                        PG = 0;
                        if (alpha[i] == 0)
@@ -869,10 +870,7 @@ static void solve_l2r_l1l2_svc(
                v += w[i]*w[i];
        for(i=0; i<l; i++)
        {
-               if (y[i] == 1)
-                       v += alpha[i]*(alpha[i]*diag_p - 2); 
-               else
-                       v += alpha[i]*(alpha[i]*diag_n - 2);
+               v += alpha[i]*(alpha[i]*diag[GETI(i)] - 2);
                if(alpha[i] > 0)
                        ++nSV;
        }
@@ -896,6 +894,10 @@ static void solve_l2r_l1l2_svc(
 //
 // solution will be put in w
 
+#undef GETI
+#define GETI(i) (y[i]+1)
+// To support weights for instances, use GETI(i) (i)
+
 static void solve_l1r_l2_svc(
        problem *prob_col, double *w, double eps, 
        double Cp, double Cn)
@@ -922,9 +924,7 @@ static void solve_l1r_l2_svc(
        double *xj_sq = new double[w_size];
        feature_node *x;
 
-       // To support weights for instances,
-       // replace C[y[i]] with C[i].
-       double C[2] = {Cn,Cp};
+       double C[3] = {Cn,0,Cp};
 
        for(j=0; j<l; j++)
        {
@@ -932,7 +932,7 @@ static void solve_l1r_l2_svc(
                if(prob_col->y[j] > 0)
                        y[j] = 1;
                else
-                       y[j] = 0;
+                       y[j] = -1;
        }
        for(j=0; j<w_size; j++)
        {
@@ -944,8 +944,8 @@ static void solve_l1r_l2_svc(
                {
                        int ind = x->index-1;
                        double val = x->value;
-                       x->value *= prob_col->y[ind]; // x->value stores yi*xij
-                       xj_sq[j] += C[y[ind]]*val*val;
+                       x->value *= y[ind]; // x->value stores yi*xij
+                       xj_sq[j] += C[GETI(ind)]*val*val;
                        x++;
                }
        }
@@ -973,7 +973,7 @@ static void solve_l1r_l2_svc(
                                if(b[ind] > 0)
                                {
                                        double val = x->value;
-                                       double tmp = C[y[ind]]*val;
+                                       double tmp = C[GETI(ind)]*val;
                                        G_loss -= tmp*b[ind];
                                        H += tmp*val;
                                }
@@ -1049,11 +1049,11 @@ static void solve_l1r_l2_svc(
                                        {
                                                int ind = x->index-1;
                                                if(b[ind] > 0)
-                                                       loss_old += C[y[ind]]*b[ind]*b[ind];
+                                                       loss_old += C[GETI(ind)]*b[ind]*b[ind];
                                                double b_new = b[ind] + d_diff*x->value;
                                                b[ind] = b_new;
                                                if(b_new > 0)
-                                                       loss_new += C[y[ind]]*b_new*b_new;
+                                                       loss_new += C[GETI(ind)]*b_new*b_new;
                                                x++;
                                        }
                                }
@@ -1067,7 +1067,7 @@ static void solve_l1r_l2_svc(
                                                double b_new = b[ind] + d_diff*x->value;
                                                b[ind] = b_new;
                                                if(b_new > 0)
-                                                       loss_new += C[y[ind]]*b_new*b_new;
+                                                       loss_new += C[GETI(ind)]*b_new*b_new;
                                                x++;
                                        }
                                }
@@ -1151,7 +1151,7 @@ static void solve_l1r_l2_svc(
        }
        for(j=0; j<l; j++)
                if(b[j] > 0)
-                       v += C[y[j]]*b[j]*b[j];
+                       v += C[GETI(j)]*b[j]*b[j];
 
        info("Objective value = %lf\n", v);
        info("#nonzeros/#features = %d/%d\n", nnz, w_size);
@@ -1173,6 +1173,10 @@ static void solve_l1r_l2_svc(
 //
 // solution will be put in w
 
+#undef GETI
+#define GETI(i) (y[i]+1)
+// To support weights for instances, use GETI(i) (i)
+
 static void solve_l1r_lr(
        const problem *prob_col, double *w, double eps, 
        double Cp, double Cn)
@@ -1204,9 +1208,7 @@ static void solve_l1r_lr(
        double *xjpos_sum = new double[w_size];
        feature_node *x;
 
-       // To support weights for instances,
-       // replace C[y[i]] with C[i].
-       double C[2] = {Cn,Cp};
+       double C[3] = {Cn,0,Cp};
 
        for(j=0; j<l; j++)
        {
@@ -1214,7 +1216,7 @@ static void solve_l1r_lr(
                if(prob_col->y[j] > 0)
                        y[j] = 1;
                else
-                       y[j] = 0;
+                       y[j] = -1;
        }
        for(j=0; j<w_size; j++)
        {
@@ -1231,11 +1233,11 @@ static void solve_l1r_lr(
                        double val = x->value;
                        x_min = min(x_min, val);
                        xj_max[j] = max(xj_max[j], val);
-                       C_sum[j] += C[y[ind]];
-                       if(y[ind] == 0)
-                               xjneg_sum[j] += C[y[ind]]*val;
+                       C_sum[j] += C[GETI(ind)];
+                       if(y[ind] == -1)
+                               xjneg_sum[j] += C[GETI(ind)]*val;
                        else
-                               xjpos_sum[j] += C[y[ind]]*val;
+                               xjpos_sum[j] += C[GETI(ind)]*val;
                        x++;
                }
        }
@@ -1263,7 +1265,7 @@ static void solve_l1r_lr(
                                int ind = x->index-1;
                                double exp_wTxind = exp_wTx[ind];
                                double tmp1 = x->value/(1+exp_wTxind);
-                               double tmp2 = C[y[ind]]*tmp1;
+                               double tmp2 = C[GETI(ind)]*tmp1;
                                double tmp3 = tmp2*exp_wTxind;
                                sum2 += tmp2;
                                sum1 += tmp3;
@@ -1342,7 +1344,7 @@ static void solve_l1r_lr(
                                        int ind = x->index-1;
                                        double exp_dx = exp(d*x->value);
                                        exp_wTx_new[i] = exp_wTx[ind]*exp_dx;
-                                       cond += C[y[ind]]*log((1+exp_wTx_new[i])/(exp_dx+exp_wTx_new[i]));
+                                       cond += C[GETI(ind)]*log((1+exp_wTx_new[i])/(exp_dx+exp_wTx_new[i]));
                                        x++; i++;
                                }
 
@@ -1428,9 +1430,9 @@ static void solve_l1r_lr(
                }
        for(j=0; j<l; j++)
                if(y[j] == 1)
-                       v += C[y[j]]*log(1+1/exp_wTx[j]);
+                       v += C[GETI(j)]*log(1+1/exp_wTx[j]);
                else
-                       v += C[y[j]]*log(1+exp_wTx[j]);
+                       v += C[GETI(j)]*log(1+exp_wTx[j]);
 
        info("Objective value = %lf\n", v);
        info("#nonzeros/#features = %d/%d\n", nnz, w_size);