From: Chun-Heng Huang Date: Mon, 11 Mar 2013 04:20:04 +0000 (+0800) Subject: Matlab interface parameter checking X-Git-Tag: v194~14 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=3a30c797bd2e570dfcb736475264f25d0bf79e99;p=liblinear Matlab interface parameter checking Check 'nlhs' argument. Allow various output size for predict. --- diff --git a/matlab/README b/matlab/README index 3c8b0eb..3259bdf 100644 --- a/matlab/README +++ b/matlab/README @@ -95,6 +95,7 @@ matlab> model = train(training_label_vector, training_instance_matrix [,'libline if 'col' is set, each column of training_instance_matrix is a data instance. Otherwise each row is a data instance. matlab> [predicted_label, accuracy, decision_values/prob_estimates] = predict(testing_label_vector, testing_instance_matrix, model [, 'liblinear_options', 'col']); +matlab> [predicted_label] = predict(testing_label_vector, testing_instance_matrix, model [, 'liblinear_options', 'col']); -testing_label_vector: An m by 1 vector of prediction labels. If labels of test diff --git a/matlab/predict.c b/matlab/predict.c index b1d56d5..2fd992e 100644 --- a/matlab/predict.c +++ b/matlab/predict.c @@ -49,14 +49,14 @@ void read_sparse_instance(const mxArray *prhs, int index, struct feature_node *x x[j].index = -1; } -static void fake_answer(mxArray *plhs[]) +static void fake_answer(int nlhs, mxArray *plhs[]) { - plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL); - plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL); - plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL); + int i; + for(i=0;i 5 || nrhs < 3) + if(nlhs == 2 || nlhs > 3 || nrhs > 5 || nrhs < 3) { exit_with_help(); - fake_answer(plhs); + fake_answer(nlhs, plhs); return; } if(nrhs == 5) @@ -242,7 +251,7 @@ void mexFunction( int nlhs, mxArray *plhs[], if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) { mexPrintf("Error: label vector and instance matrix must be double\n"); - fake_answer(plhs); + fake_answer(nlhs, plhs); return; } @@ -269,7 +278,7 @@ void mexFunction( int nlhs, mxArray *plhs[], if(i>=argc && argv[i-1][1] != 'q') { exit_with_help(); - fake_answer(plhs); + fake_answer(nlhs, plhs); return; } switch(argv[i-1][1]) @@ -284,7 +293,7 @@ void mexFunction( int nlhs, mxArray *plhs[], default: mexPrintf("unknown option\n"); exit_with_help(); - fake_answer(plhs); + fake_answer(nlhs, plhs); return; } } @@ -296,7 +305,7 @@ void mexFunction( int nlhs, mxArray *plhs[], { mexPrintf("Error: can't read model: %s\n", error_msg); free_and_destroy_model(&model_); - fake_answer(plhs); + fake_answer(nlhs, plhs); return; } @@ -310,12 +319,12 @@ void mexFunction( int nlhs, mxArray *plhs[], } if(mxIsSparse(prhs[1])) - do_predict(plhs, prhs, model_, prob_estimate_flag); + do_predict(nlhs, plhs, prhs, model_, prob_estimate_flag); else { mexPrintf("Testing_instance_matrix must be sparse; " "use sparse(Testing_instance_matrix) first\n"); - fake_answer(plhs); + fake_answer(nlhs, plhs); } // destroy model_ @@ -324,7 +333,7 @@ void mexFunction( int nlhs, mxArray *plhs[], else { mexPrintf("model file should be a struct array\n"); - fake_answer(plhs); + fake_answer(nlhs, plhs); } return; diff --git a/matlab/train.c b/matlab/train.c index 6351dab..1301ed5 100644 --- a/matlab/train.c +++ b/matlab/train.c @@ -243,9 +243,11 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name) return 0; } -static void fake_answer(mxArray *plhs[]) +static void fake_answer(int nlhs, mxArray *plhs[]) { - plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL); + int i; + for(i=0;i 1) + { + exit_with_help(); + fake_answer(nlhs, plhs); + return; + } + // Transform the input Matrix to libsvm format if(nrhs > 1 && nrhs < 5) { @@ -349,7 +358,7 @@ void mexFunction( int nlhs, mxArray *plhs[], if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) { mexPrintf("Error: label vector and instance matrix must be double\n"); - fake_answer(plhs); + fake_answer(nlhs, plhs); return; } @@ -357,7 +366,7 @@ void mexFunction( int nlhs, mxArray *plhs[], { exit_with_help(); destroy_param(¶m); - fake_answer(plhs); + fake_answer(nlhs, plhs); return; } @@ -368,7 +377,7 @@ void mexFunction( int nlhs, mxArray *plhs[], mexPrintf("Training_instance_matrix must be sparse; " "use sparse(Training_instance_matrix) first\n"); destroy_param(¶m); - fake_answer(plhs); + fake_answer(nlhs, plhs); return; } @@ -383,7 +392,7 @@ void mexFunction( int nlhs, mxArray *plhs[], free(prob.y); free(prob.x); free(x_space); - fake_answer(plhs); + fake_answer(nlhs, plhs); return; } @@ -412,7 +421,7 @@ void mexFunction( int nlhs, mxArray *plhs[], else { exit_with_help(); - fake_answer(plhs); + fake_answer(nlhs, plhs); return; } } diff --git a/windows/predict.mexw64 b/windows/predict.mexw64 index 1a64777..4a7592c 100755 Binary files a/windows/predict.mexw64 and b/windows/predict.mexw64 differ diff --git a/windows/train.mexw64 b/windows/train.mexw64 index f08699c..9fc0676 100755 Binary files a/windows/train.mexw64 and b/windows/train.mexw64 differ