]> granicus.if.org Git - liblinear/commitdiff
Add option -q to disable the screen output from train
authorbiconnect <biconnect@16e7d947-dcc2-db11-b54a-0017319806e7>
Wed, 1 Jul 2009 15:58:51 +0000 (15:58 +0000)
committerbiconnect <biconnect@16e7d947-dcc2-db11-b54a-0017319806e7>
Wed, 1 Jul 2009 15:58:51 +0000 (15:58 +0000)
README
linear.cpp
linear.h
matlab/train.c
train.c
tron.cpp
tron.h

diff --git a/README b/README
index 2440161922f5f49cc7030b71aa93f5a0b05d9dff..d1dbc328a151f245e3ac03b511e1fdc29c847407 100644 (file)
--- a/README
+++ b/README
@@ -110,6 +110,7 @@ options:
 -B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default 1)
 -wi weight: weights adjust the parameter C of different classes (see README for details)
 -v n: n-fold cross validation mode
+-q : quiet mode (no outputs)
 
 Option -v randomly splits the data into n parts and calculates cross
 validation accuracy on them.
@@ -371,6 +372,11 @@ Library Usage
 
     This function frees the memory used by a parameter set.
 
+- Variable: extern void (*liblinear_print_string) (const char *);
+
+    Users can specify their output format by
+    liblinear_print_string = &your_print_function;
+
 Building Windows Binaries
 =========================
 
index 22373edbcb53c6324477c05d6f011183358dbe93..afd4075a89d1164ed2f5fde81f9cea04ac8a2d33 100644 (file)
@@ -21,21 +21,26 @@ template <class S, class T> inline void clone(T*& dst, S* src, int n)
 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
 #define INF HUGE_VAL
 
+static void print_string_stdout(const char *s)
+{
+       fputs(s,stdout);
+       fflush(stdout);
+}
+
+void (*liblinear_print_string) (const char *) = &print_string_stdout;
+
 #if 1
 static void info(const char *fmt,...)
 {
+       char buf[BUFSIZ];
        va_list ap;
        va_start(ap,fmt);
-       vprintf(fmt,ap);
+       vsprintf(buf,fmt,ap);
        va_end(ap);
-}
-static void info_flush()
-{
-       fflush(stdout);
+       (*liblinear_print_string)(buf);
 }
 #else
-static void info(char *fmt,...) {}
-static void info_flush() {}
+static void info(const char *fmt,...) {}
 #endif
 
 class l2_lr_fun : public function
@@ -614,8 +619,7 @@ void Solver_MCSVM_CS::Solve(double *w)
                iter++;
                if(iter % 10 == 0)
                {
-                       info("."); 
-                       info_flush();
+                       info(".");
                }
 
                if(stopping < eps_shrink)
@@ -627,7 +631,7 @@ void Solver_MCSVM_CS::Solve(double *w)
                                active_size = l;
                                for(i=0;i<l;i++)
                                        active_size_i[i] = nr_class;
-                               info("*"); info_flush();
+                               info("*");
                                eps_shrink = max(eps_shrink/2, eps);
                                start_from_all = true;
                        }
@@ -831,8 +835,7 @@ static void solve_linear_c_svc(
                iter++;
                if(iter % 10 == 0)
                {
-                       info("."); 
-                       info_flush();
+                       info(".");
                }
 
                if(PGmax_new - PGmin_new <= eps)
@@ -842,7 +845,7 @@ static void solve_linear_c_svc(
                        else
                        {
                                active_size = l;
-                               info("*"); info_flush();
+                               info("*");
                                PGmax_old = INF;
                                PGmin_old = -INF;
                                continue;
@@ -960,6 +963,7 @@ void train_one(const problem *prob, const parameter *param, double *w, double Cp
                {
                        fun_obj=new l2_lr_fun(prob, Cp, Cn);
                        TRON tron_obj(fun_obj, eps*min(pos,neg)/prob->l);
+                       tron_obj.set_print_string(liblinear_print_string);
                        tron_obj.tron(w);
                        delete fun_obj;
                        break;
@@ -968,6 +972,7 @@ void train_one(const problem *prob, const parameter *param, double *w, double Cp
                {
                        fun_obj=new l2loss_svm_fun(prob, Cp, Cn);
                        TRON tron_obj(fun_obj, eps*min(pos,neg)/prob->l);
+                       tron_obj.set_print_string(liblinear_print_string);
                        tron_obj.tron(w);
                        delete fun_obj;
                        break;
index 5031d5acd79d573bb39610fe29c8351ad73710c0..ae85013d34144882aff046ac3a47d5464fea40f2 100644 (file)
--- a/linear.h
+++ b/linear.h
@@ -60,6 +60,7 @@ void get_labels(const struct model *model_, int* label);
 void destroy_model(struct model *model_);
 void destroy_param(struct parameter *param);
 const char *check_parameter(const struct problem *prob, const struct parameter *param);
+extern void (*liblinear_print_string) (const char *);
 
 #ifdef __cplusplus
 }
index 0777e10bf8f4f96ba090ab577044a42462b5d747..c97ad7b0ae5946712b2e36b1901850572f1d8d65 100644 (file)
@@ -16,6 +16,10 @@ typedef int mwIndex;
 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
 #define INF HUGE_VAL
 
+void print_null(const char *s){}
+
+void (*liblinear_default_print_string) (const char *);
+
 void exit_with_help()
 {
        mexPrintf(
@@ -89,6 +93,12 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
        col_format_flag = 0;
        bias = 1;
 
+       // train loaded only once under matlab
+       if(liblinear_default_print_string == NULL)
+               liblinear_default_print_string = liblinear_print_string;
+       else
+               liblinear_print_string = liblinear_default_print_string;
+
        if(nrhs <= 1)
                return 1;
 
@@ -112,7 +122,8 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
        for(i=1;i<argc;i++)
        {
                if(argv[i][0] != '-') break;
-               if(++i>=argc)
+               ++i;
+               if(i>=argc && argv[i-1][1] != 'q') // since option -q has no parameter
                        return 1;
                switch(argv[i-1][1])
                {
@@ -144,6 +155,10 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
                                param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
                                param.weight[param.nr_weight-1] = atof(argv[i]);
                                break;
+                       case 'q':
+                               liblinear_print_string = &print_null;
+                               i--;
+                               break;
                        default:
                                mexPrintf("unknown option\n");
                                return 1;
diff --git a/train.c b/train.c
index 62a857895fe24bfcae7d77ccc2fa11d3afa21723..903ac358913f689d406b5b87ee5d85325d6136d8 100644 (file)
--- a/train.c
+++ b/train.c
@@ -8,6 +8,8 @@
 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
 #define INF HUGE_VAL
 
+void print_null(const char *s) {}
+
 void exit_with_help()
 {
        printf(
@@ -29,6 +31,7 @@ void exit_with_help()
        "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default 1)\n"
        "-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
        "-v n: n-fold cross validation mode\n"
+       "-q : quiet mode (no outputs)\n"
        );
        exit(1);
 }
@@ -179,6 +182,11 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode
                                }
                                break;
 
+                       case 'q':
+                               liblinear_print_string = &print_null;
+                               i--;
+                               break;
+
                        default:
                                fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]);
                                exit_with_help();
index 293370a8ab021f0276d7624ced5b2de0ffce820a..5cc829d4e2ce30e70106ef9dd609bc423106f107 100644 (file)
--- a/tron.cpp
+++ b/tron.cpp
@@ -25,12 +25,28 @@ extern int dscal_(int *, double *, double *, int *);
 }
 #endif
 
+void default_print(const char *buf)
+{
+       fputs(buf,stdout);
+       fflush(stdout);
+}
+
+void TRON::info(const char *fmt,...)
+{
+       char buf[BUFSIZ];
+       va_list ap;
+       va_start(ap,fmt);
+       vsprintf(buf,fmt,ap);
+       va_end(ap);
+       (*tron_print_string)(buf);
+}
 
 TRON::TRON(const function *fun_obj, double eps, int max_iter)
 {
        this->fun_obj=const_cast<function *>(fun_obj);
        this->eps=eps;
        this->max_iter=max_iter;
+       tron_print_string = default_print;
 }
 
 TRON::~TRON()
@@ -104,7 +120,7 @@ void TRON::tron(double *w)
                else
                        delta = max(delta, min(alpha*snorm, sigma3*delta));
 
-               printf("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter);
+               info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d\n", iter, actred, prered, delta, f, gnorm, cg_iter);
 
                if (actred > eta0*prered)
                {
@@ -119,18 +135,18 @@ void TRON::tron(double *w)
                }
                if (f < -1.0e+32)
                {
-                       printf("warning: f < -1.0e+32\n");
+                       info("warning: f < -1.0e+32\n");
                        break;
                }
                if (fabs(actred) <= 0 && prered <= 0)
                {
-                       printf("warning: actred and prered <= 0\n");
+                       info("warning: actred and prered <= 0\n");
                        break;
                }
                if (fabs(actred) <= 1.0e-12*fabs(f) &&
                    fabs(prered) <= 1.0e-12*fabs(f))
                {
-                       printf("warning: actred and prered too small\n");
+                       info("warning: actred and prered too small\n");
                        break;
                }
        }
@@ -171,7 +187,7 @@ int TRON::trcg(double delta, double *g, double *s, double *r)
                daxpy_(&n, &alpha, d, &inc, s, &inc);
                if (dnrm2_(&n, s, &inc) > delta)
                {
-                       printf("cg reaches trust region boundary\n");
+                       info("cg reaches trust region boundary\n");
                        alpha = -alpha;
                        daxpy_(&n, &alpha, d, &inc, s, &inc);
 
@@ -212,3 +228,8 @@ double TRON::norm_inf(int n, double *x)
                        dmax = fabs(x[i]);
        return(dmax);
 }
+
+void TRON::set_print_string(void (*print_string) (const char *buf))
+{
+       tron_print_string = print_string;
+}
diff --git a/tron.h b/tron.h
index fe6a96bcc08fe0f6f59a70d6e699c14515eed8b0..3045c2e83a1338eb8ec148ed9bc689ea7d7a71ae 100644 (file)
--- a/tron.h
+++ b/tron.h
@@ -19,6 +19,7 @@ public:
        ~TRON();
 
        void tron(double *w);
+       void set_print_string(void (*i_print) (const char *buf));
 
 private:
        int trcg(double delta, double *g, double *s, double *r);
@@ -27,6 +28,7 @@ private:
        double eps;
        int max_iter;
        function *fun_obj;
+       void info(const char *fmt,...);
+       void (*tron_print_string)(const char *buf);
 };
-
 #endif