]> granicus.if.org Git - liblinear/commitdiff
Matlab interface parameter checking
authorChun-Heng Huang <applerman@hotmail.com>
Mon, 11 Mar 2013 04:20:04 +0000 (12:20 +0800)
committerChun-Heng Huang <applerman@hotmail.com>
Mon, 11 Mar 2013 04:20:04 +0000 (12:20 +0800)
Check 'nlhs' argument.
Allow various output size for predict.

matlab/README
matlab/predict.c
matlab/train.c
windows/predict.mexw64
windows/train.mexw64

index 3c8b0eb064cbddd8d76588e6d4e02bbcfe473b9f..3259bdf178a702dbe2855e3540273cce15a7b9e1 100644 (file)
@@ -95,6 +95,7 @@ matlab> model = train(training_label_vector, training_instance_matrix [,'libline
             if 'col' is set, each column of training_instance_matrix is a data instance. Otherwise each row is a data instance.
 
 matlab> [predicted_label, accuracy, decision_values/prob_estimates] = predict(testing_label_vector, testing_instance_matrix, model [, 'liblinear_options', 'col']);
+matlab> [predicted_label] = predict(testing_label_vector, testing_instance_matrix, model [, 'liblinear_options', 'col']);
 
         -testing_label_vector:
             An m by 1 vector of prediction labels. If labels of test
index b1d56d5ee78de9d390c4861938f708712391da6b..2fd992e52d8afdeaf043b6153a8c156a994c1d35 100644 (file)
@@ -49,14 +49,14 @@ void read_sparse_instance(const mxArray *prhs, int index, struct feature_node *x
        x[j].index = -1;
 }
 
-static void fake_answer(mxArray *plhs[])
+static void fake_answer(int nlhs, mxArray *plhs[])
 {
-       plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
-       plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
-       plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
+       int i;
+       for(i=0;i<nlhs;i++)
+               plhs[i] = mxCreateDoubleMatrix(0, 0, mxREAL);
 }
 
-void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, const int predict_probability_flag)
+void do_predict(int nlhs, mxArray *plhs[], const mxArray *prhs[], struct model *model_, const int predict_probability_flag)
 {
        int label_vector_row_num, label_vector_col_num;
        int feature_number, testing_instance_number;
@@ -65,6 +65,7 @@ void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, co
        double *ptr_prob_estimates, *ptr_dec_values, *ptr;
        struct feature_node *x;
        mxArray *pplhs[1]; // instance sparse matrix in row format
+       mxArray *tplhs[3]; // temporary storage for plhs[]
 
        int correct = 0;
        int total = 0;
@@ -95,13 +96,13 @@ void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, co
        if(label_vector_row_num!=testing_instance_number)
        {
                mexPrintf("Length of label vector does not match # of instances.\n");
-               fake_answer(plhs);
+               fake_answer(nlhs, plhs);
                return;
        }
        if(label_vector_col_num!=1)
        {
                mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
-               fake_answer(plhs);
+               fake_answer(nlhs, plhs);
                return;
        }
 
@@ -117,7 +118,7 @@ void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, co
                if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
                {
                        mexPrintf("Error: cannot transpose testing instance matrix\n");
-                       fake_answer(plhs);
+                       fake_answer(nlhs, plhs);
                        return;
                }
        }
@@ -125,15 +126,15 @@ void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, co
 
        prob_estimates = Malloc(double, nr_class);
 
-       plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
+       tplhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
        if(predict_probability_flag)
-               plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
+               tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
        else
-               plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_w, mxREAL);
+               tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_w, mxREAL);
 
-       ptr_predict_label = mxGetPr(plhs[0]);
-       ptr_prob_estimates = mxGetPr(plhs[2]);
-       ptr_dec_values = mxGetPr(plhs[2]);
+       ptr_predict_label = mxGetPr(tplhs[0]);
+       ptr_prob_estimates = mxGetPr(tplhs[2]);
+       ptr_dec_values = mxGetPr(tplhs[2]);
        x = Malloc(struct feature_node, feature_number+2);
        for(instance_index=0;instance_index<testing_instance_number;instance_index++)
        {
@@ -189,8 +190,8 @@ void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, co
                info("Accuracy = %g%% (%d/%d)\n", (double) correct/total*100,correct,total);
 
        // return accuracy, mean squared error, squared correlation coefficient
-       plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
-       ptr = mxGetPr(plhs[1]);
+       tplhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
+       ptr = mxGetPr(tplhs[1]);
        ptr[0] = (double)correct/total*100;
        ptr[1] = error/total;
        ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
@@ -199,12 +200,20 @@ void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, co
        free(x);
        if(prob_estimates != NULL)
                free(prob_estimates);
+
+       switch(nlhs){
+               case 3: plhs[2] = tplhs[2];
+                               plhs[1] = tplhs[1];
+               case 1:
+               case 0: plhs[0] = tplhs[0];
+       }
 }
 
 void exit_with_help()
 {
        mexPrintf(
                        "Usage: [predicted_label, accuracy, decision_values/prob_estimates] = predict(testing_label_vector, testing_instance_matrix, model, 'liblinear_options','col')\n"
+                       "       [predicted_label] = predict(testing_label_vector, testing_instance_matrix, model, 'liblinear_options','col')\n"
                        "liblinear_options:\n"
                        "-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0); currently for logistic regression only\n"
                        "-q quiet mode (no outputs)\n"
@@ -225,10 +234,10 @@ void mexFunction( int nlhs, mxArray *plhs[],
        info = &mexPrintf;
        col_format_flag = 0;
 
-       if(nrhs > 5 || nrhs < 3)
+       if(nlhs == 2 || nlhs > 3 || nrhs > 5 || nrhs < 3)
        {
                exit_with_help();
-               fake_answer(plhs);
+               fake_answer(nlhs, plhs);
                return;
        }
        if(nrhs == 5)
@@ -242,7 +251,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
 
        if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
                mexPrintf("Error: label vector and instance matrix must be double\n");
-               fake_answer(plhs);
+               fake_answer(nlhs, plhs);
                return;
        }
 
@@ -269,7 +278,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
                                if(i>=argc && argv[i-1][1] != 'q')
                                {
                                        exit_with_help();
-                                       fake_answer(plhs);
+                                       fake_answer(nlhs, plhs);
                                        return;
                                }
                                switch(argv[i-1][1])
@@ -284,7 +293,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
                                        default:
                                                mexPrintf("unknown option\n");
                                                exit_with_help();
-                                               fake_answer(plhs);
+                                               fake_answer(nlhs, plhs);
                                                return;
                                }
                        }
@@ -296,7 +305,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
                {
                        mexPrintf("Error: can't read model: %s\n", error_msg);
                        free_and_destroy_model(&model_);
-                       fake_answer(plhs);
+                       fake_answer(nlhs, plhs);
                        return;
                }
 
@@ -310,12 +319,12 @@ void mexFunction( int nlhs, mxArray *plhs[],
                }
 
                if(mxIsSparse(prhs[1]))
-                       do_predict(plhs, prhs, model_, prob_estimate_flag);
+                       do_predict(nlhs, plhs, prhs, model_, prob_estimate_flag);
                else
                {
                        mexPrintf("Testing_instance_matrix must be sparse; "
                                "use sparse(Testing_instance_matrix) first\n");
-                       fake_answer(plhs);
+                       fake_answer(nlhs, plhs);
                }
 
                // destroy model_
@@ -324,7 +333,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
        else
        {
                mexPrintf("model file should be a struct array\n");
-               fake_answer(plhs);
+               fake_answer(nlhs, plhs);
        }
 
        return;
index 6351dab3e6c6530806f10f6c4557717719d2c706..1301ed5bfa86f1086e27e9073e93a3b03f6b3d7e 100644 (file)
@@ -243,9 +243,11 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
        return 0;
 }
 
-static void fake_answer(mxArray *plhs[])
+static void fake_answer(int nlhs, mxArray *plhs[])
 {
-       plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
+       int i;
+       for(i=0;i<nlhs;i++)
+               plhs[i] = mxCreateDoubleMatrix(0, 0, mxREAL);
 }
 
 int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
@@ -342,6 +344,13 @@ void mexFunction( int nlhs, mxArray *plhs[],
        // (for cross validation)
        srand(1);
 
+       if(nlhs > 1)
+       {
+               exit_with_help();
+               fake_answer(nlhs, plhs);
+               return;
+       }
+
        // Transform the input Matrix to libsvm format
        if(nrhs > 1 && nrhs < 5)
        {
@@ -349,7 +358,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
 
                if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
                        mexPrintf("Error: label vector and instance matrix must be double\n");
-                       fake_answer(plhs);
+                       fake_answer(nlhs, plhs);
                        return;
                }
 
@@ -357,7 +366,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
                {
                        exit_with_help();
                        destroy_param(&param);
-                       fake_answer(plhs);
+                       fake_answer(nlhs, plhs);
                        return;
                }
 
@@ -368,7 +377,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
                        mexPrintf("Training_instance_matrix must be sparse; "
                                "use sparse(Training_instance_matrix) first\n");
                        destroy_param(&param);
-                       fake_answer(plhs);
+                       fake_answer(nlhs, plhs);
                        return;
                }
 
@@ -383,7 +392,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
                        free(prob.y);
                        free(prob.x);
                        free(x_space);
-                       fake_answer(plhs);
+                       fake_answer(nlhs, plhs);
                        return;
                }
 
@@ -412,7 +421,7 @@ void mexFunction( int nlhs, mxArray *plhs[],
        else
        {
                exit_with_help();
-               fake_answer(plhs);
+               fake_answer(nlhs, plhs);
                return;
        }
 }
index 1a64777a9a8dffe0740d0b00dfe8e72662981cc7..4a7592cf9175a36b7f124799170515d54e757ad9 100755 (executable)
Binary files a/windows/predict.mexw64 and b/windows/predict.mexw64 differ
index f08699ce5717b92e028b5f9518355d1e495a8a04..9fc067666e1d903e4dd4bd56e972a8922275147f 100755 (executable)
Binary files a/windows/train.mexw64 and b/windows/train.mexw64 differ