]> granicus.if.org Git - liblinear/commitdiff
initial commit of line search branch
authorleepei <leepei@local>
Thu, 31 Dec 2015 12:25:44 +0000 (20:25 +0800)
committerleepei <leepei@local>
Thu, 31 Dec 2015 12:25:44 +0000 (20:25 +0800)
linear.cpp
tron.cpp
tron.h

index b7bfecaf6713a200eb8f0b64ad80940f3131a05f..620781d06e1baace7752dd5aa5353d50cf6b0669 100644 (file)
@@ -79,47 +79,47 @@ public:
        }
 };
 
-class l2r_lr_fun: public function
+class l2r_erm_fun: public function
 {
 public:
-       l2r_lr_fun(const problem *prob, double *C);
-       ~l2r_lr_fun();
+       l2r_erm_fun(const problem *prob, double *C);
+       ~l2r_erm_fun();
 
        double fun(double *w);
-       void grad(double *w, double *g);
-       void Hv(double *s, double *Hs);
-
+       double line_search(double *d, double *w, double *g, double alpha, double *f);
        int get_nr_variable(void);
 
-private:
+protected:
+       virtual double C_times_loss(int i, double wx_i) = 0;
        void Xv(double *v, double *Xv);
        void XTv(double *v, double *XTv);
 
        double *C;
        double *z;
-       double *D;
        const problem *prob;
+       double *wx;
+       double wTw;
+       double current_f;
 };
 
-l2r_lr_fun::l2r_lr_fun(const problem *prob, double *C)
+l2r_erm_fun::l2r_erm_fun(const problem *prob, double *C)
 {
        int l=prob->l;
 
        this->prob = prob;
 
        z = new double[l];
-       D = new double[l];
+       wx = new double[l];
        this->C = C;
 }
 
-l2r_lr_fun::~l2r_lr_fun()
+l2r_erm_fun::~l2r_erm_fun()
 {
        delete[] z;
-       delete[] D;
+       delete[] wx;
 }
 
-
-double l2r_lr_fun::fun(double *w)
+double l2r_erm_fun::fun(double *w)
 {
        int i;
        double f=0;
@@ -127,23 +127,137 @@ double l2r_lr_fun::fun(double *w)
        int l=prob->l;
        int w_size=get_nr_variable();
 
+       wTw = 0;
        Xv(w, z);
 
        for(i=0;i<w_size;i++)
-               f += w[i]*w[i];
-       f /= 2.0;
+               wTw += w[i]*w[i];
        for(i=0;i<l;i++)
        {
-               double yz = y[i]*z[i];
-               if (yz >= 0)
-                       f += C[i]*log(1 + exp(-yz));
-               else
-                       f += C[i]*(-yz+log(1 + exp(yz)));
+               wx[i] = z[i];
+               f += C_times_loss(i, wx[i]);
        }
+       f = f + 0.5 * wTw;
 
+       current_f = f;
        return(f);
 }
 
+int l2r_erm_fun::get_nr_variable(void)
+{
+       return prob->n;
+}
+
+double l2r_erm_fun::line_search(double *d, double *w, double* g, double alpha, double *f)
+{
+       int i;
+       int l = prob->l;
+       double dTd, wTd, gd;
+       double eta = 1e-4;
+       int w_size = get_nr_variable();
+       int max_line_search_iter = 1000;
+       double *y=prob->y;
+       Xv(d, z);
+
+       dTd = 0;
+       wTd = 0;
+       gd = 0;
+       for (i=0;i<w_size;i++)
+       {
+               dTd += d[i] * d[i];
+               wTd += d[i] * w[i];
+               gd += d[i] * g[i];
+       }
+       int line_search_times = 0;
+       while (true)
+       {
+               if (line_search_times++ >= max_line_search_iter)
+               {
+                       f[0] = current_f;
+                       return 0;
+               }
+               double loss = 0;
+               for(i=0;i<l;i++)
+               {
+                       double inner_product = z[i] * alpha + wx[i];
+                       loss += C_times_loss(i, inner_product);
+               }
+               f[0] = loss + (alpha * alpha * dTd + wTw) / 2.0 + alpha * wTd;
+               if (f[0] - current_f <= eta * alpha * gd)
+               {
+                       for (i=0;i<l;i++)
+                       {
+                               wx[i] += alpha * z[i];
+                               z[i] = wx[i];
+                       }
+                       break;
+               }
+               else
+                       alpha *= 0.5;
+       }
+       wTw += alpha*alpha * dTd + 2* alpha * wTd;
+       current_f = f[0];
+       return alpha;
+}
+
+void l2r_erm_fun::Xv(double *v, double *Xv)
+{
+       int i;
+       int l=prob->l;
+       feature_node **x=prob->x;
+
+       for(i=0;i<l;i++)
+               Xv[i]=sparse_operator::dot(v, x[i]);
+}
+
+void l2r_erm_fun::XTv(double *v, double *XTv)
+{
+       int i;
+       int l=prob->l;
+       int w_size=get_nr_variable();
+       feature_node **x=prob->x;
+
+       for(i=0;i<w_size;i++)
+               XTv[i]=0;
+       for(i=0;i<l;i++)
+               sparse_operator::axpy(v[i], x[i], XTv);
+}
+
+class l2r_lr_fun: public l2r_erm_fun
+{
+public:
+       l2r_lr_fun(const problem *prob, double *C);
+       ~l2r_lr_fun();
+
+       void grad(double *w, double *g);
+       void Hv(double *s, double *Hs);
+
+private:
+       double *D;
+       double C_times_loss(int i, double wx_i);
+};
+
+l2r_lr_fun::l2r_lr_fun(const problem *prob, double *C):
+       l2r_erm_fun(prob, C)
+{
+       int l=prob->l;
+       D = new double[l];
+}
+
+l2r_lr_fun::~l2r_lr_fun()
+{
+       delete[] D;
+}
+
+double l2r_lr_fun::C_times_loss(int i, double wx_i)
+{
+       double ywx_i = wx_i * prob->y[i];
+       if (ywx_i >= 0)
+               return C[i]*log(1 + exp(-ywx_i));
+       else
+               return C[i]*(-ywx_i + log(1 + exp(ywx_i)));
+}
+
 void l2r_lr_fun::grad(double *w, double *g)
 {
        int i;
@@ -163,11 +277,6 @@ void l2r_lr_fun::grad(double *w, double *g)
                g[i] = w[i] + g[i];
 }
 
-int l2r_lr_fun::get_nr_variable(void)
-{
-       return prob->n;
-}
-
 void l2r_lr_fun::Hv(double *s, double *Hs)
 {
        int i;
@@ -192,91 +301,43 @@ void l2r_lr_fun::Hv(double *s, double *Hs)
        delete[] wa;
 }
 
-void l2r_lr_fun::Xv(double *v, double *Xv)
-{
-       int i;
-       int l=prob->l;
-       feature_node **x=prob->x;
-
-       for(i=0;i<l;i++)
-               Xv[i]=sparse_operator::dot(v, x[i]);
-}
-
-void l2r_lr_fun::XTv(double *v, double *XTv)
-{
-       int i;
-       int l=prob->l;
-       int w_size=get_nr_variable();
-       feature_node **x=prob->x;
-
-       for(i=0;i<w_size;i++)
-               XTv[i]=0;
-       for(i=0;i<l;i++)
-               sparse_operator::axpy(v[i], x[i], XTv);
-}
-
-class l2r_l2_svc_fun: public function
+class l2r_l2_svc_fun: public l2r_erm_fun
 {
 public:
        l2r_l2_svc_fun(const problem *prob, double *C);
        ~l2r_l2_svc_fun();
 
-       double fun(double *w);
        void grad(double *w, double *g);
        void Hv(double *s, double *Hs);
 
-       int get_nr_variable(void);
-
 protected:
-       void Xv(double *v, double *Xv);
        void subXTv(double *v, double *XTv);
 
-       double *C;
-       double *z;
        int *I;
        int sizeI;
-       const problem *prob;
+
+private:
+       double C_times_loss(int i, double wx_i);
 };
 
-l2r_l2_svc_fun::l2r_l2_svc_fun(const problem *prob, double *C)
+l2r_l2_svc_fun::l2r_l2_svc_fun(const problem *prob, double *C):
+       l2r_erm_fun(prob, C)
 {
-       int l=prob->l;
-
-       this->prob = prob;
-
-       z = new double[l];
-       I = new int[l];
-       this->C = C;
+       I = new int[prob->l];
 }
 
 l2r_l2_svc_fun::~l2r_l2_svc_fun()
 {
-       delete[] z;
        delete[] I;
 }
 
-double l2r_l2_svc_fun::fun(double *w)
+double l2r_l2_svc_fun::C_times_loss(int i, double wx_i)
 {
-       int i;
-       double f=0;
-       double *y=prob->y;
-       int l=prob->l;
-       int w_size=get_nr_variable();
-
-       Xv(w, z);
-
-       for(i=0;i<w_size;i++)
-               f += w[i]*w[i];
-       f /= 2.0;
-       for(i=0;i<l;i++)
-       {
-               z[i] = y[i]*z[i];
-               double d = 1-z[i];
+               double d = 1 - prob->y[i] * wx_i;
                if (d > 0)
-                       f += C[i]*d*d;
-       }
-
-       return(f);
+                       return C[i]*d*d;
+               else
+                       return 0;
 }
 
 void l2r_l2_svc_fun::grad(double *w, double *g)
@@ -288,23 +349,21 @@ void l2r_l2_svc_fun::grad(double *w, double *g)
 
        sizeI = 0;
        for (i=0;i<l;i++)
+       {
+               z[i] *= y[i];
                if (z[i] < 1)
                {
                        z[sizeI] = C[i]*y[i]*(z[i]-1);
                        I[sizeI] = i;
                        sizeI++;
                }
+       }
        subXTv(z, g);
 
        for(i=0;i<w_size;i++)
                g[i] = w[i] + 2*g[i];
 }
 
-int l2r_l2_svc_fun::get_nr_variable(void)
-{
-       return prob->n;
-}
-
 void l2r_l2_svc_fun::Hv(double *s, double *Hs)
 {
        int i;
@@ -328,16 +387,6 @@ void l2r_l2_svc_fun::Hv(double *s, double *Hs)
        delete[] wa;
 }
 
-void l2r_l2_svc_fun::Xv(double *v, double *Xv)
-{
-       int i;
-       int l=prob->l;
-       feature_node **x=prob->x;
-
-       for(i=0;i<l;i++)
-               Xv[i]=sparse_operator::dot(v, x[i]);
-}
-
 void l2r_l2_svc_fun::subXTv(double *v, double *XTv)
 {
        int i;
@@ -355,10 +404,10 @@ class l2r_l2_svr_fun: public l2r_l2_svc_fun
 public:
        l2r_l2_svr_fun(const problem *prob, double *C, double p);
 
-       double fun(double *w);
        void grad(double *w, double *g);
 
 private:
+       double C_times_loss(int i, double wx_i);
        double p;
 };
 
@@ -368,30 +417,14 @@ l2r_l2_svr_fun::l2r_l2_svr_fun(const problem *prob, double *C, double p):
        this->p = p;
 }
 
-double l2r_l2_svr_fun::fun(double *w)
+double l2r_l2_svr_fun::C_times_loss(int i, double wx_i)
 {
-       int i;
-       double f=0;
-       double *y=prob->y;
-       int l=prob->l;
-       int w_size=get_nr_variable();
-       double d;
-
-       Xv(w, z);
-
-       for(i=0;i<w_size;i++)
-               f += w[i]*w[i];
-       f /= 2;
-       for(i=0;i<l;i++)
-       {
-               d = z[i] - y[i];
+               double d = wx_i - prob->y[i];
                if(d < -p)
-                       f += C[i]*(d+p)*(d+p);
+                       return C[i]*(d+p)*(d+p);
                else if(d > p)
-                       f += C[i]*(d-p)*(d-p);
-       }
-
-       return(f);
+                       return C[i]*(d-p)*(d-p);
+               return 0;
 }
 
 void l2r_l2_svr_fun::grad(double *w, double *g)
@@ -2202,7 +2235,7 @@ static void train_one(const problem *prob, const parameter *param, double *w, do
                                C[i] = param->C;
 
                        fun_obj=new l2r_l2_svr_fun(prob, C, param->p);
-                       TRON tron_obj(fun_obj, param->eps);
+                       TRON tron_obj(fun_obj, param->eps, eps_cg);
                        tron_obj.set_print_string(liblinear_print_string);
                        tron_obj.tron(w);
                        delete fun_obj;
index 3ea60f6dfa45be60f159a9ba3126d21a6d57071f..0ef20901019cd09acd1bdbcae39b63b3b5920d26 100644 (file)
--- a/tron.cpp
+++ b/tron.cpp
@@ -56,16 +56,11 @@ TRON::~TRON()
 
 void TRON::tron(double *w)
 {
-       // Parameters for updating the iterates.
-       double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
-
-       // Parameters for updating the trust region size delta.
-       double sigma1 = 0.25, sigma2 = 0.5, sigma3 = 4;
-
        int n = fun_obj->get_nr_variable();
        int i, cg_iter;
-       double delta, snorm, one=1.0;
-       double alpha, f, fnew, prered, actred, gs;
+       double snorm, step_size, one=1.0;
+       double alpha, f, fnew;
+       double init_step_size = 1;
        int search = 1, iter = 1, inc = 1;
        double *s = new double[n];
        double *r = new double[n];
@@ -82,8 +77,7 @@ void TRON::tron(double *w)
 
        f = fun_obj->fun(w);
        fun_obj->grad(w, g);
-       delta = dnrm2_(&n, g, &inc);
-       double gnorm = delta;
+       double gnorm = dnrm2_(&n, g, &inc);
 
        if (gnorm <= eps*gnorm0)
                search = 0;
@@ -93,66 +87,32 @@ void TRON::tron(double *w)
        double *w_new = new double[n];
        while (iter <= max_iter && search)
        {
-               cg_iter = trcg(delta, g, s, r);
+               cg_iter = trcg(g, s, r);
 
                memcpy(w_new, w, sizeof(double)*n);
                daxpy_(&n, &one, s, &inc, w_new, &inc);
 
-               gs = ddot_(&n, g, &inc, s, &inc);
-               prered = -0.5*(gs-ddot_(&n, s, &inc, r, &inc));
-               fnew = fun_obj->fun(w_new);
-
-               // Compute the actual reduction.
-               actred = f - fnew;
-
-               // On the first iteration, adjust the initial step bound.
-               snorm = dnrm2_(&n, s, &inc);
-               if (iter == 1)
-                       delta = min(delta, snorm);
-
-               // Compute prediction alpha*snorm of the step.
-               if (fnew - f - gs <= 0)
-                       alpha = sigma3;
-               else
-                       alpha = max(sigma1, -0.5*(gs/(fnew - f - gs)));
-
-               // Update the trust region bound according to the ratio of actual to predicted reduction.
-               if (actred < eta0*prered)
-                       delta = min(max(alpha, sigma1)*snorm, sigma2*delta);
-               else if (actred < eta1*prered)
-                       delta = max(sigma1*delta, min(alpha*snorm, sigma2*delta));
-               else if (actred < eta2*prered)
-                       delta = max(sigma1*delta, min(alpha*snorm, sigma3*delta));
-               else
-                       delta = max(delta, min(alpha*snorm, sigma3*delta));
-
-               info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter);
-
-               if (actred > eta0*prered)
-               {
-                       iter++;
-                       memcpy(w, w_new, sizeof(double)*n);
-                       f = fnew;
-                       fun_obj->grad(w, g);
-
-                       gnorm = dnrm2_(&n, g, &inc);
-                       if (gnorm <= eps*gnorm0)
-                               break;
-               }
-               if (f < -1.0e+32)
+               step_size = fun_obj->line_search(s, w, g, init_step_size, &fnew);
+               if (step_size == 0)
                {
-                       info("WARNING: f < -1.0e+32\n");
+                       info("WARNING: line search fails\n");
                        break;
                }
-               if (fabs(actred) <= 0 && prered <= 0)
-               {
-                       info("WARNING: actred and prered <= 0\n");
+               daxpy_(&n, &step_size, s, &inc, w, &inc);
+
+               info("iter %2d f %5.3e |g| %5.3e CG %3d step_size %5.3e \n", iter, f, gnorm, cg_iter, step_size);
+
+               f = fnew;
+               iter++;
+
+               fun_obj->grad(w, g);
+
+               gnorm = dnrm2_(&n, g, &inc);
+               if (gnorm <= eps*gnorm0)
                        break;
-               }
-               if (fabs(actred) <= 1.0e-12*fabs(f) &&
-                   fabs(prered) <= 1.0e-12*fabs(f))
+               if (f < -1.0e+32)
                {
-                       info("WARNING: actred and prered too small\n");
+                       info("WARNING: f < -1.0e+32\n");
                        break;
                }
        }
@@ -163,7 +123,7 @@ void TRON::tron(double *w)
        delete[] s;
 }
 
-int TRON::trcg(double delta, double *g, double *s, double *r)
+int TRON::trcg(double *g, double *s, double *r)
 {
        int i, inc = 1;
        int n = fun_obj->get_nr_variable();
@@ -191,26 +151,6 @@ int TRON::trcg(double delta, double *g, double *s, double *r)
 
                alpha = rTr/ddot_(&n, d, &inc, Hd, &inc);
                daxpy_(&n, &alpha, d, &inc, s, &inc);
-               if (dnrm2_(&n, s, &inc) > delta)
-               {
-                       info("cg reaches trust region boundary\n");
-                       alpha = -alpha;
-                       daxpy_(&n, &alpha, d, &inc, s, &inc);
-
-                       double std = ddot_(&n, s, &inc, d, &inc);
-                       double sts = ddot_(&n, s, &inc, s, &inc);
-                       double dtd = ddot_(&n, d, &inc, d, &inc);
-                       double dsq = delta*delta;
-                       double rad = sqrt(std*std + dtd*(dsq-sts));
-                       if (std >= 0)
-                               alpha = (dsq - sts)/(std + rad);
-                       else
-                               alpha = (rad - std)/dtd;
-                       daxpy_(&n, &alpha, d, &inc, s, &inc);
-                       alpha = -alpha;
-                       daxpy_(&n, &alpha, Hd, &inc, r, &inc);
-                       break;
-               }
                alpha = -alpha;
                daxpy_(&n, &alpha, Hd, &inc, r, &inc);
                rnewTrnew = ddot_(&n, r, &inc, r, &inc);
diff --git a/tron.h b/tron.h
index 56002dcdbd0224d469196375d1aa9e053ae4addc..a7da979f6c7ac64c3dbdf4a0b66d9bb7ad00f3c4 100644 (file)
--- a/tron.h
+++ b/tron.h
@@ -7,6 +7,7 @@ public:
        virtual double fun(double *w) = 0 ;
        virtual void grad(double *w, double *g) = 0 ;
        virtual void Hv(double *s, double *Hs) = 0 ;
+       virtual double line_search(double *s, double *w, double *g, double init_step_size, double *fnew) = 0 ;
 
        virtual int get_nr_variable(void) = 0 ;
        virtual ~function(void){}
@@ -22,7 +23,7 @@ public:
        void set_print_string(void (*i_print) (const char *buf));
 
 private:
-       int trcg(double delta, double *g, double *s, double *r);
+       int trcg(double *g, double *s, double *r);
        double norm_inf(int n, double *x);
 
        double eps;