From 92f096797a1cb58b3f6af9d2a4e05e0ac5f9da66 Mon Sep 17 00:00:00 2001
From: Chou Hung-Yi
Date: Sat, 23 May 2020 23:25:17 +0800
Subject: [PATCH] solver for one-class SVM supported

-add new solver ONECLASS_SVM (-s 21)
-add function solve_oneclass_svm
-add new attribute rho to model
-add new attribute nu to parameter
-modify functions check_parameter and get_decfun_bias to reject one-class SVM from accessing the bias term
-modify/add functions get_decfun_coef, get_decfun_rho and get_w_value so that one-class SVM can expose its decision function
-add function check_oneclass_model
-update the python/MATLAB interfaces and train.c
-update README
---
 README                       |  55 ++++-
 linear.cpp                   | 411 ++++++++++++++++++++++++++++++++++-
 linear.h                     |   6 +-
 matlab/README                |   3 +-
 matlab/linear_model_matlab.c |   9 +-
 matlab/train.c               |  11 +-
 python/README                |   8 +-
 python/liblinear.py          |  36 ++-
 python/liblinearutil.py      |   6 +-
 train.c                      |  17 +-
 10 files changed, 530 insertions(+), 32 deletions(-)

diff --git a/README b/README
index d78d6b4..5143680 100644
--- a/README
+++ b/README
@@ -1,8 +1,9 @@
 LIBLINEAR is a simple package for solving large-scale regularized linear
-classification and regression. It currently supports
+classification, regression and outlier detection. It currently supports
 - L2-regularized logistic regression/L2-loss support vector classification/L1-loss support vector classification
 - L1-regularized L2-loss support vector classification/L1-regularized logistic regression
-- L2-regularized L2-loss support vector regression/L1-loss support vector regression.
+- L2-regularized L2-loss support vector regression/L1-loss support vector regression
+- one-class support vector machine.
 This document explains the usage of LIBLINEAR.
 
 To get started, please read the ``Quick Start'' section first.
@@ -114,8 +115,11 @@ options:
 	11 -- L2-regularized L2-loss support vector regression (primal)
 	12 -- L2-regularized L2-loss support vector regression (dual)
 	13 -- L2-regularized L1-loss support vector regression (dual)
+  for outlier detection
+	21 -- one-class support vector machine (dual)
 -c cost : set the parameter C (default 1)
 -p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
+-n nu : set the parameter nu of one-class SVM (default 0.5)
 -e epsilon : set tolerance of termination criterion
 	-s 0 and 2
 		|f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,
@@ -123,12 +127,12 @@ options:
 		positive/negative data (default 0.01)
 	-s 11
 		|f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)
-	-s 1, 3, 4 and 7
-		Dual maximal violation <= eps; similar to libsvm (default 0.1)
+	-s 1, 3, 4, 7, and 21
+		Dual maximal violation <= eps; similar to libsvm (default 0.1 except 0.01 for -s 21)
 	-s 5 and 6
 		|f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,
 		where f is the primal function (default 0.01)
-	-s 12 and 13\n"
+	-s 12 and 13
 		|f'(alpha)|_1 <= eps |f'(alpha0)|,
 		where f is the dual function (default 0.1)
 -B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)
@@ -199,6 +203,15 @@ where
 
 Q is a matrix with Q_ij = x_i^T x_j.
 
+For one-class SVM dual (-s 21), we solve
+
+min_alpha 0.5(alpha^T Q alpha)
+    s.t. 0 <= alpha_i <= 1 and \sum alpha_i = nu*l,
+
+where
+
+Q is a matrix with Q_ij = x_i^T x_j.
+
 If bias >= 0, w becomes [w; w_{n+1}] and x becomes [x; bias].
 
 The primal-dual relationship implies that -s 1 and -s 2 give the same
@@ -253,6 +266,10 @@ Train linear SVM with L2-loss function.
 
 Train a logistic regression model.
 
+> train -s 21 -n 0.1 data_file
+
+Train a linear one-class SVM which selects roughly 10% of the data as outliers.
+ > train -v 5 -e 0.001 data_file Do five-fold cross-validation using L2-loss SVM. @@ -365,6 +382,7 @@ in linear.h, so you can check the version number. /* these are for training only */ double eps; /* stopping criteria */ double C; + double nu; /* one-class SVM only */ int nr_weight; int *weight_label; double* weight; @@ -372,7 +390,7 @@ in linear.h, so you can check the version number. double *init_sol; }; - solver_type can be one of L2R_LR, L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC, L2R_L1LOSS_SVC_DUAL, MCSVM_CS, L1R_L2LOSS_SVC, L1R_LR, L2R_LR_DUAL, L2R_L2LOSS_SVR, L2R_L2LOSS_SVR_DUAL, L2R_L1LOSS_SVR_DUAL. + solver_type can be one of L2R_LR, L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC, L2R_L1LOSS_SVC_DUAL, MCSVM_CS, L1R_L2LOSS_SVC, L1R_LR, L2R_LR_DUAL, L2R_L2LOSS_SVR, L2R_L2LOSS_SVR_DUAL, L2R_L1LOSS_SVR_DUAL, ONECLASS_SVM. for classification L2R_LR L2-regularized logistic regression (primal) L2R_L2LOSS_SVC_DUAL L2-regularized L2-loss support vector classification (dual) @@ -386,9 +404,12 @@ in linear.h, so you can check the version number. L2R_L2LOSS_SVR L2-regularized L2-loss support vector regression (primal) L2R_L2LOSS_SVR_DUAL L2-regularized L2-loss support vector regression (dual) L2R_L1LOSS_SVR_DUAL L2-regularized L1-loss support vector regression (dual) + for outlier detection + ONECLASS_SVM one-class support vector machine (dual) C is the cost of constraints violation. p is the sensitiveness of loss of support vector regression. + nu in ONECLASS_SVM approximates the fraction of data as outliers. eps is the stopping criterion. nr_weight, weight_label, and weight are used to change the penalty @@ -420,6 +441,7 @@ in linear.h, so you can check the version number. double *w; int *label; /* label of each class */ double bias; + double rho; /* one-class SVM only */ }; param describes the parameters used to obtain the model. @@ -438,11 +460,13 @@ in linear.h, so you can check the version number. | for 1st feature | for 2nd feature | +------------------+------------------+------------+ + The array label stores class labels. + If bias >= 0, x becomes [x; bias]. The number of features is increased by one, so w is a (nr_feature+1)*nr_class array. The value of bias is stored in the variable bias. - The array label stores class labels. + rho is the bias term used in one-class SVM only. - Function: void cross_validation(const problem *prob, const parameter *param, int nr_fold, double *target); @@ -531,15 +555,21 @@ in linear.h, so you can check the version number. starts from 1, while label_idx starts from 0. If feat_idx is not in the valid range (1 to nr_feature), then a zero value will be returned. For classification models, if label_idx is not in the valid range (0 to - nr_class-1), then a zero value will be returned; for regression models, - label_idx is ignored. + nr_class-1), then a zero value will be returned; for regression models + and one-class SVM models, label_idx is ignored. - Function: double get_decfun_bias(const struct model *model_, int label_idx); This function gives the bias term corresponding to the class with the label_idx. For classification models, if label_idx is not in a valid range (0 to nr_class-1), then a zero value will be returned; for regression - models, label_idx is ignored. + models, label_idx is ignored. This function cannot be called for a one-class + SVM model. + +- Function: double get_decfun_rho(const struct model *model_); + + This function gives rho, the bias term used in one-class SVM only. This + function can only be called for a one-class SVM model. 
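+
+    As an illustration, a minimal sketch of evaluating the one-class SVM
+    decision function w^Tx - rho through this interface; x_dense here is an
+    assumed dense array holding one instance's feature values and is not
+    part of the LIBLINEAR API:
+
+        double dec = -get_decfun_rho(model_);
+        int j;
+        for(j=1; j<=get_nr_feature(model_); j++)
+            dec += get_decfun_coef(model_, j, 0)*x_dense[j-1];
+        /* dec > 0: predicted as normal data; dec <= 0: predicted as an outlier */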
- Function: const char *check_parameter(const struct problem *prob, const struct parameter *param); @@ -559,6 +589,11 @@ in linear.h, so you can check the version number. This function returns 1 if the model is a regression model; otherwise it returns 0. +- Function: int check_oneclass_model(const struct model *model); + + This function returns 1 if the model is a one-class SVM model; otherwise + it returns 0. + - Function: int save_model(const char *model_file_name, const struct model *model_); diff --git a/linear.cpp b/linear.cpp index 3bdb114..c873156 100644 --- a/linear.cpp +++ b/linear.cpp @@ -70,6 +70,28 @@ public: return (ret); } + static double sparse_dot(const feature_node *x1, const feature_node *x2) + { + double ret = 0; + while(x1->index != -1 && x2->index != -1) + { + if(x1->index == x2->index) + { + ret += x1->value * x2->value; + ++x1; + ++x2; + } + else + { + if(x1->index > x2->index) + ++x2; + else + ++x1; + } + } + return (ret); + } + static void axpy(const double a, const feature_node *x, double *y) { while(x->index != -1) @@ -2017,6 +2039,342 @@ static void solve_l1r_lr( delete [] D; } +struct heap { + enum HEAP_TYPE { MIN, MAX }; + int _size; + HEAP_TYPE _type; + feature_node* a; + + heap(int max_size, HEAP_TYPE type) + { + _size = 0; + a = new feature_node[max_size]; + _type = type; + } + ~heap() + { + delete [] a; + } + bool cmp(const feature_node& left, const feature_node& right) + { + if(_type == MIN) + return left.value > right.value; + else + return left.value < right.value; + } + int size() + { + return _size; + } + void push(feature_node node) + { + a[_size] = node; + _size++; + int i = _size-1; + while(i) + { + int p = (i-1)/2; + if(cmp(a[p], a[i])) + { + swap(a[i], a[p]); + i = p; + } + else + break; + } + } + void pop() + { + _size--; + a[0] = a[_size]; + int i = 0; + while(i*2+1 < _size) + { + int l = i*2+1; + int r = i*2+2; + if(r < _size && cmp(a[l], a[r])) + l = r; + if(cmp(a[i], a[l])) + { + swap(a[i], a[l]); + i = l; + } + else + break; + } + } + feature_node top() + { + return a[0]; + } +}; + +// A two-level coordinate descent algorithm for +// a scaled one-class SVM dual problem +// +// min_\alpha 0.5(\alpha^T Q \alpha), +// s.t. 0 <= \alpha_i <= 1 and +// e^T \alpha = \nu l +// +// where Qij = xi^T xj +// +// Given: +// x, nu +// eps is the stopping tolerance +// +// solution will be put in w and rho +// +// See Algorithm 7 in supplementary materials of Chou et al., SDM 2020. 
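+//
+// Implementation note: each outer iteration selects up to
+// max(active_size/10, 1) most violating pairs. A min-heap keeps the
+// largest values of -G over {i | alpha_i < 1} (candidates to increase)
+// and a max-heap keeps the smallest values of -G over {j | alpha_j > 0}
+// (candidates to decrease); each selected pair is then updated as in
+// SMO. Since alpha_i + alpha_j is preserved by every pair update, the
+// constraint \sum alpha_i = nu*l holds throughout.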
+
+static void solve_oneclass_svm(const problem *prob, double *w, double *rho, double eps, double nu)
+{
+ int l = prob->l;
+ int w_size = prob->n;
+ int i, j, s, iter = 0;
+ double Gi, Gj;
+ double Qij, quad_coef, delta, sum;
+ double old_alpha_i;
+ double *QD = new double[l];
+ double *G = new double[l];
+ int *index = new int[l];
+ double *alpha = new double[l];
+ int max_inner_iter;
+ int max_iter = 1000;
+ int active_size = l;
+
+ double negGmax; // max { -grad(f)_i | alpha_i < 1 }
+ double negGmin; // min { -grad(f)_i | alpha_i > 0 }
+
+ int *most_violating_i = new int[l];
+ int *most_violating_j = new int[l];
+
+ int n = (int)(nu*l); // # of alpha's at upper bound
+ for(i=0; i<n; i++)
+ alpha[i] = 1;
+ if (n<l)
+ alpha[n] = nu*l-n;
+ for(i=n+1; i<l; i++)
+ alpha[i] = 0;
+
+ for(i=0; i<w_size; i++)
+ w[i] = 0;
+ for(i=0; i<l; i++)
+ {
+ feature_node * const xi = prob->x[i];
+ QD[i] = sparse_operator::nrm2_sq(xi);
+ sparse_operator::axpy(alpha[i], xi, w);
+
+ index[i] = i;
+ }
+
+ while (iter < max_iter)
+ {
+ negGmax = -INF;
+ negGmin = INF;
+
+ for (s=0; s<active_size; s++)
+ {
+ i = index[s];
+ feature_node * const xi = prob->x[i];
+ G[i] = sparse_operator::dot(w, xi);
+ if (alpha[i] < 1)
+ negGmax = max(negGmax, -G[i]);
+ if (alpha[i] > 0)
+ negGmin = min(negGmin, -G[i]);
+ }
+
+ if (negGmax - negGmin < eps)
+ {
+ if (active_size == l)
+ break;
+ else
+ {
+ active_size = l;
+ info("*");
+ continue;
+ }
+ }
+
+ for(s=0; s<active_size; s++)
+ {
+ i = index[s];
+ if ((alpha[i] == 1 && -G[i] > negGmax) ||
+ (alpha[i] == 0 && -G[i] < negGmin))
+ {
+ active_size--;
+ swap(index[s], index[active_size]);
+ s--;
+ }
+ }
+
+ max_inner_iter = max(active_size/10, 1);
+ struct heap min_heap = heap(max_inner_iter, heap::MIN);
+ struct heap max_heap = heap(max_inner_iter, heap::MAX);
+ struct feature_node node;
+ for(s=0; s<active_size; s++)
+ {
+ i = index[s];
+ node.index = i;
+ node.value = -G[i];
+
+ if (alpha[i] < 1)
+ {
+ if (min_heap.size() < max_inner_iter)
+ min_heap.push(node);
+ else if (min_heap.top().value < node.value)
+ {
+ min_heap.pop();
+ min_heap.push(node);
+ }
+ }
+
+ if (alpha[i] > 0)
+ {
+ if (max_heap.size() < max_inner_iter)
+ max_heap.push(node);
+ else if (max_heap.top().value > node.value)
+ {
+ max_heap.pop();
+ max_heap.push(node);
+ }
+ }
+ }
+ max_inner_iter = min(min_heap.size(), max_heap.size());
+ while (max_heap.size() > max_inner_iter)
+ max_heap.pop();
+ while (min_heap.size() > max_inner_iter)
+ min_heap.pop();
+
+ for (s=max_inner_iter-1; s>=0; s--)
+ {
+ most_violating_i[s] = min_heap.top().index;
+ most_violating_j[s] = max_heap.top().index;
+ min_heap.pop();
+ max_heap.pop();
+ }
+
+ for (s=0; s<max_inner_iter; s++)
+ {
+ i = most_violating_i[s];
+ j = most_violating_j[s];
+ feature_node const * xi = prob->x[i];
+ feature_node const * xj = prob->x[j];
+
+ Gi = sparse_operator::dot(w, xi);
+ Gj = sparse_operator::dot(w, xj);
+
+ int violating_pair = 0;
+ if (alpha[i] < 1 && alpha[j] > 0 && -Gj + 1e-12 < -Gi)
+ violating_pair = 1;
+ else
+ if (alpha[i] > 0 && alpha[j] < 1 && -Gi + 1e-12 < -Gj)
+ violating_pair = 1;
+ if (violating_pair == 0)
+ continue;
+
+ Qij = sparse_operator::sparse_dot(xi, xj);
+ quad_coef = QD[i] + QD[j] - 2*Qij;
+ if(quad_coef <= 0)
+ quad_coef = 1e-12;
+ delta = (Gi - Gj) / quad_coef;
+ old_alpha_i = alpha[i];
+ sum = alpha[i] + alpha[j];
+ alpha[i] = alpha[i] - delta;
+ alpha[j] = alpha[j] + delta;
+ if (sum > 1)
+ {
+ if (alpha[i] > 1)
+ {
+ alpha[i] = 1;
+ alpha[j] = sum - 1;
+ }
+ }
+ else
+ {
+ if (alpha[j] < 0)
+ {
+ alpha[j] = 0;
+ alpha[i] = sum;
+ }
+ }
+ if (sum > 1)
+ {
+ if (alpha[j] > 1)
+ {
+ alpha[j] = 1;
+ alpha[i] = sum - 1;
+ }
+ }
+ else
+ {
+ if (alpha[i] < 0)
+ {
+ alpha[i] = 0;
+ alpha[j] = sum;
+ }
+ }
+ delta = alpha[i] - old_alpha_i;
+ sparse_operator::axpy(delta, xi, w);
+ sparse_operator::axpy(-delta, xj, w);
+ }
+ iter++;
+ if (iter % 10 == 0)
+ info(".");
+ }
+ info("\noptimization finished, #iter = %d\n",iter);
+ if (iter >= max_iter)
+ info("\nWARNING: reaching max number of iterations\n\n");
+
+ // calculate objective value
+ double v = 0;
+ int nSV = 0;
+ for(i=0; i<l; i++)
+ {
+ v += alpha[i] * G[i];
+ if(alpha[i] > 0)
+ ++nSV;
+ }
+ info("Objective value = %lf\n", v/2);
+ info("nSV = %d\n", nSV);
+
+ // calculate rho
+ double nr_free = 0;
+ double ub = INF, lb = -INF, sum_free = 0;
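+ // rho is the average of G over the free alpha's (0 < alpha_i < 1) if
+ // any exist; otherwise it is the midpoint of lb (the largest G among
+ // alpha_i at the lower bound 0) and ub (the smallest G among alpha_i
+ // at the upper bound 1)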
+ for(i=0; i<l; i++)
+ {
+ double G = sparse_operator::dot(w, prob->x[i]);
+ if (alpha[i] == 0)
+ lb = max(lb, G);
+ else if (alpha[i] == 1)
+ ub = min(ub, G);
+ else
+ {
+ ++nr_free;
+ sum_free += G;
+ }
+ }
+
+ if (nr_free > 0)
+ *rho = sum_free/nr_free;
+ else
+ *rho = (ub + lb)/2;
+
+ info("rho = %lf\n", *rho);
+
+ delete [] QD;
+ delete [] G;
+ delete [] index;
+ delete [] alpha;
+ delete [] most_violating_i;
+ delete [] most_violating_j;
+}
+
 // transpose matrix X from row format to column format
 static void transpose(const problem *prob, feature_node **x_space_ret, problem *prob_col)
@@ -2473,6 +2831,13 @@ model* train(const problem *prob, const parameter *param)
 model_->label = NULL;
 train_one(prob, param, model_->w, 0, 0);
 }
+ else if(check_oneclass_model(model_))
+ {
+ model_->w = Malloc(double, w_size);
+ model_->nr_class = 2;
+ model_->label = NULL;
+ solve_oneclass_svm(prob, model_->w, &(model_->rho), param->eps, param->nu);
+ }
 else
 {
 int nr_class;
@@ -2793,11 +3158,15 @@ double predict_values(const struct model *model_, const struct feature_node *x,
 for(i=0; i<nr_w; i++)
 dec_values[i] += w[(idx-1)*nr_w+i]*x->value;
 }
 
+ if(check_oneclass_model(model_))
+ dec_values[0] -= model_->rho;
 
 if(nr_class==2)
 {
 if(check_regression_model(model_))
 return dec_values[0];
+ else if(check_oneclass_model(model_))
+ return (dec_values[0]>0)?1:-1;
 else
 return (dec_values[0]>0)?model_->label[0]:model_->label[1];
 }
@@ -2860,7 +3229,9 @@
 static const char *solver_type_table[]=
 {
 "L2R_LR", "L2R_L2LOSS_SVC_DUAL", "L2R_L2LOSS_SVC", "L2R_L1LOSS_SVC_DUAL", "MCSVM_CS",
 "L1R_L2LOSS_SVC", "L1R_LR", "L2R_LR_DUAL",
 "", "", "",
-"L2R_L2LOSS_SVR", "L2R_L2LOSS_SVR_DUAL", "L2R_L1LOSS_SVR_DUAL", NULL
+"L2R_L2LOSS_SVR", "L2R_L2LOSS_SVR_DUAL", "L2R_L1LOSS_SVR_DUAL",
+"", "", "", "", "", "", "",
+"ONECLASS_SVM", NULL
 };
@@ -2906,6 +3277,9 @@ int save_model(const char *model_file_name, const struct model *model_)
 
 fprintf(fp, "bias %.17g\n", model_->bias);
 
+ if(check_oneclass_model(model_))
+ fprintf(fp, "rho %.17g\n", model_->rho);
+
 fprintf(fp, "w\n");
 for(i=0; i<w_size; i++)
 {
@@ -2963,6 +3337,7 @@ struct model *load_model(const char *model_file_name)
 int nr_class;
 double bias;
+ double rho;
 model *model_ = Malloc(model,1);
 parameter& param = model_->param;
 // parameters for training only won't be assigned, but arrays are assigned as NULL for safety
@@ -3010,6 +3385,11 @@ struct model *load_model(const char *model_file_name)
 FSCANF(fp,"%lf",&bias);
 model_->bias=bias;
 }
+ else if(strcmp(cmd,"rho")==0)
+ {
+ FSCANF(fp,"%lf",&rho);
+ model_->rho=rho;
+ }
 else if(strcmp(cmd,"w")==0)
 {
 break;
@@ -3082,7 +3462,7 @@ static inline double get_w_value(const struct model *model_, int idx, int label_idx)
 
 if(idx < 0 || idx > model_->nr_feature)
 return 0;
-if(check_regression_model(model_))
+if(check_regression_model(model_) || check_oneclass_model(model_))
 return w[idx];
 else
 {
@@ -3112,6 +3492,11 @@ double get_decfun_coef(const struct model *model_, int feat_idx, int label_idx)
 
 double get_decfun_bias(const struct model *model_, int label_idx)
 {
+ if(check_oneclass_model(model_))
+ {
+ fprintf(stderr, "ERROR: get_decfun_bias can not be called for a one-class SVM model\n");
+ return 0;
+ }
 int bias_idx = model_->nr_feature;
 double bias = model_->bias;
 if(bias <= 0)
@@ -3120,6 +3505,17 @@ double get_decfun_bias(const struct model *model_, int label_idx)
 return bias*get_w_value(model_, bias_idx, label_idx);
 }
 
+double get_decfun_rho(const struct model *model_)
+{
+ if(check_oneclass_model(model_))
+ return model_->rho;
+ else
+ {
+ fprintf(stderr, "ERROR: get_decfun_rho can be called only for a one-class SVM model\n");
+ return 0;
+ }
+}
+
 void free_model_content(struct model *model_ptr)
 {
 if(model_ptr->w != NULL)
@@ -3159,6 +3555,9 @@ const char
*check_parameter(const problem *prob, const parameter *param) if(param->p < 0) return "p < 0"; + if(prob->bias >= 0 && param->solver_type == ONECLASS_SVM) + return "prob->bias >=0, but this is ignored in ONECLASS_SVM"; + if(param->solver_type != L2R_LR && param->solver_type != L2R_L2LOSS_SVC_DUAL && param->solver_type != L2R_L2LOSS_SVC @@ -3169,7 +3568,8 @@ const char *check_parameter(const problem *prob, const parameter *param) && param->solver_type != L2R_LR_DUAL && param->solver_type != L2R_L2LOSS_SVR && param->solver_type != L2R_L2LOSS_SVR_DUAL - && param->solver_type != L2R_L1LOSS_SVR_DUAL) + && param->solver_type != L2R_L1LOSS_SVR_DUAL + && param->solver_type != ONECLASS_SVM) return "unknown solver type"; if(param->init_sol != NULL @@ -3195,6 +3595,11 @@ int check_regression_model(const struct model *model_) model_->param.solver_type==L2R_L2LOSS_SVR_DUAL); } +int check_oneclass_model(const struct model *model_) +{ + return model_->param.solver_type == ONECLASS_SVM; +} + void set_print_string_function(void (*print_func)(const char*)) { if (print_func == NULL) diff --git a/linear.h b/linear.h index 7ad9606..2419188 100644 --- a/linear.h +++ b/linear.h @@ -23,7 +23,7 @@ struct problem double bias; /* < 0 if no bias term */ }; -enum { L2R_LR, L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC, L2R_L1LOSS_SVC_DUAL, MCSVM_CS, L1R_L2LOSS_SVC, L1R_LR, L2R_LR_DUAL, L2R_L2LOSS_SVR = 11, L2R_L2LOSS_SVR_DUAL, L2R_L1LOSS_SVR_DUAL }; /* solver_type */ +enum { L2R_LR, L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC, L2R_L1LOSS_SVC_DUAL, MCSVM_CS, L1R_L2LOSS_SVC, L1R_LR, L2R_LR_DUAL, L2R_L2LOSS_SVR = 11, L2R_L2LOSS_SVR_DUAL, L2R_L1LOSS_SVR_DUAL, ONECLASS_SVM = 21 }; /* solver_type */ struct parameter { @@ -36,6 +36,7 @@ struct parameter int *weight_label; double* weight; double p; + double nu; double *init_sol; }; @@ -47,6 +48,7 @@ struct model double *w; int *label; /* label of each class */ double bias; + double rho; /* one-class SVM only */ }; struct model* train(const struct problem *prob, const struct parameter *param); @@ -65,6 +67,7 @@ int get_nr_class(const struct model *model_); void get_labels(const struct model *model_, int* label); double get_decfun_coef(const struct model *model_, int feat_idx, int label_idx); double get_decfun_bias(const struct model *model_, int label_idx); +double get_decfun_rho(const struct model *model_); void free_model_content(struct model *model_ptr); void free_and_destroy_model(struct model **model_ptr_ptr); @@ -73,6 +76,7 @@ void destroy_param(struct parameter *param); const char *check_parameter(const struct problem *prob, const struct parameter *param); int check_probability_model(const struct model *model); int check_regression_model(const struct model *model); +int check_oneclass_model(const struct model *model); void set_print_string_function(void (*print_func) (const char*)); #ifdef __cplusplus diff --git a/matlab/README b/matlab/README index 8a32826..af656a0 100644 --- a/matlab/README +++ b/matlab/README @@ -109,7 +109,7 @@ Returned Model Structure The 'train' function returns a model which can be used for future prediction. It is a structure and is organized as [Parameters, nr_class, -nr_feature, bias, Label, w]: +nr_feature, bias, Label, w, rho]: -Parameters: Parameters (now only solver type is provided) -nr_class: number of classes; = 2 for regression @@ -122,6 +122,7 @@ nr_feature, bias, Label, w]: nr_w is 1 if nr_class=2 and -s is not 4 (i.e., not multi-class svm by Crammer and Singer). It is nr_class otherwise. + -rho: the bias term of one-class SVM. 
 If the '-v' option is specified, cross validation is conducted and the
 returned model is just a scalar: cross-validation accuracy for
diff --git a/matlab/linear_model_matlab.c b/matlab/linear_model_matlab.c
index f11b451..a048b6d 100644
--- a/matlab/linear_model_matlab.c
+++ b/matlab/linear_model_matlab.c
@@ -12,7 +12,7 @@ typedef int mwIndex;
 
 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
 
-#define NUM_OF_RETURN_FIELD 6
+#define NUM_OF_RETURN_FIELD 7
 
 static const char *field_names[] = {
 "Parameters",
@@ -21,6 +21,7 @@ static const char *field_names[] = {
 "bias",
 "Label",
 "w",
+ "rho",
 };
 
 const char *model_to_matlab_structure(mxArray *plhs[], struct model *model_)
@@ -89,6 +90,12 @@ const char *model_to_matlab_structure(mxArray *plhs[], struct model *model_)
 ptr[i]=model_->w[i];
 out_id++;
 
+ // rho
+ rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
+ ptr = mxGetPr(rhs[out_id]);
+ ptr[0] = model_->rho;
+ out_id++;
+
 /* Create a struct matrix contains NUM_OF_RETURN_FIELD fields */
 return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
 
diff --git a/matlab/train.c b/matlab/train.c
index 1771aee..63cd4b6 100644
--- a/matlab/train.c
+++ b/matlab/train.c
@@ -41,6 +41,7 @@ void exit_with_help()
 " 13 -- L2-regularized L1-loss support vector regression (dual)\n"
 "-c cost : set the parameter C (default 1)\n"
 "-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n"
+ "-n nu : set the parameter nu of one-class SVM (default 0.5)\n"
 "-e epsilon : set tolerance of termination criterion\n"
 " -s 0 and 2\n"
 " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
@@ -48,8 +49,8 @@ void exit_with_help()
 " positive/negative data (default 0.01)\n"
 " -s 11\n"
 " |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)\n"
-" -s 1, 3, 4 and 7\n"
-" Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
+" -s 1, 3, 4, 7, and 21\n"
+" Dual maximal violation <= eps; similar to libsvm (default 0.1 except 0.01 for -s 21)\n"
 " -s 5 and 6\n"
 " |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n"
 " where f is the primal function (default 0.01)\n"
@@ -214,6 +215,9 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
 param.p = atof(argv[i]);
 flag_p_specified = 1;
 break;
+ case 'n':
+ param.nu = atof(argv[i]);
+ break;
 case 'e':
 param.eps = atof(argv[i]);
 break;
@@ -294,6 +298,9 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
 case L2R_L2LOSS_SVR_DUAL:
 param.eps = 0.1;
 break;
+ case ONECLASS_SVM:
+ param.eps = 0.01;
+ break;
 }
 }
 return 0;
diff --git a/python/README b/python/README
index f4e52f6..1403ab9 100644
--- a/python/README
+++ b/python/README
@@ -280,12 +280,18 @@ LIBLINEAR shared library:
 Note that w_k is a Python list of length nr_feature, which means that
 w_k[0] = W_k1. For regression models, W is just a vector of length
 nr_feature. Either
-set label_idx=0 or omit the label_idx parameter to access the coefficients. 
+set label_idx=0 or omit the label_idx parameter to access the coefficients.
 
 >>> W_j = model_.get_decfun_coef(feat_idx=j)
 >>> b = model_.get_decfun_bias()
 >>> [W, b] = model_.get_decfun()
 
+ For one-class SVM models, label_idx is ignored and b=-rho is returned. That
+ is, the decision function is w*x+b=w*x-rho.
+
+ >>> rho = model_.get_decfun_rho()
+ >>> [W, b] = model_.get_decfun()
+
 Note that in get_decfun_coef, get_decfun_bias, and get_decfun, feat_idx
 starts from 1, while label_idx starts from 0.
If label_idx is not in the valid range (0 to nr_class-1), then a NaN will be returned; and if feat_idx diff --git a/python/liblinear.py b/python/liblinear.py index eb66608..ae2bd77 100644 --- a/python/liblinear.py +++ b/python/liblinear.py @@ -20,14 +20,15 @@ __all__ = ['liblinear', 'feature_node', 'gen_feature_nodearray', 'problem', 'parameter', 'model', 'toPyModel', 'L2R_LR', 'L2R_L2LOSS_SVC_DUAL', 'L2R_L2LOSS_SVC', 'L2R_L1LOSS_SVC_DUAL', 'MCSVM_CS', 'L1R_L2LOSS_SVC', 'L1R_LR', 'L2R_LR_DUAL', 'L2R_L2LOSS_SVR', - 'L2R_L2LOSS_SVR_DUAL', 'L2R_L1LOSS_SVR_DUAL', 'print_null'] + 'L2R_L2LOSS_SVR_DUAL', 'L2R_L1LOSS_SVR_DUAL', 'ONECLASS_SVM', + 'print_null'] try: dirname = path.dirname(path.abspath(__file__)) if sys.platform == 'win32': liblinear = CDLL(path.join(dirname, r'..\windows\liblinear.dll')) else: - liblinear = CDLL(path.join(dirname, '../liblinear.so.3')) + liblinear = CDLL(path.join(dirname, '../liblinear.so.4')) except: # For unix the prefix 'lib' is not considered. if find_library('linear'): @@ -48,6 +49,7 @@ L2R_LR_DUAL = 7 L2R_L2LOSS_SVR = 11 L2R_L2LOSS_SVR_DUAL = 12 L2R_L1LOSS_SVR_DUAL = 13 +ONECLASS_SVM = 21 PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p) def print_null(s): @@ -226,8 +228,8 @@ class problem(Structure): class parameter(Structure): - _names = ["solver_type", "eps", "C", "nr_weight", "weight_label", "weight", "p", "init_sol"] - _types = [c_int, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double), c_double, POINTER(c_double)] + _names = ["solver_type", "eps", "C", "nr_weight", "weight_label", "weight", "p", "nu", "init_sol"] + _types = [c_int, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double), c_double, c_double, POINTER(c_double)] _fields_ = genFields(_names, _types) def __init__(self, options = None): @@ -250,6 +252,7 @@ class parameter(Structure): self.eps = float('inf') self.C = 1 self.p = 0.1 + self.nu = 0.5 self.nr_weight = 0 self.weight_label = None self.weight = None @@ -289,6 +292,9 @@ class parameter(Structure): i = i + 1 self.p = float(argv[i]) self.flag_p_specified = True + elif argv[i] == "-n": + i = i + 1 + self.nu = float(argv[i]) elif argv[i] == "-e": i = i + 1 self.eps = float(argv[i]) @@ -343,10 +349,12 @@ class parameter(Structure): self.eps = 0.01 elif self.solver_type in [L2R_L2LOSS_SVR_DUAL, L2R_L1LOSS_SVR_DUAL]: self.eps = 0.1 + elif self.solver_type in [ONECLASS_SVM]: + self.eps = 0.01 class model(Structure): - _names = ["param", "nr_class", "nr_feature", "w", "label", "bias"] - _types = [parameter, c_int, c_int, POINTER(c_double), POINTER(c_int), c_double] + _names = ["param", "nr_class", "nr_feature", "w", "label", "bias", "rho"] + _types = [parameter, c_int, c_int, POINTER(c_double), POINTER(c_int), c_double, c_double] _fields_ = genFields(_names, _types) def __init__(self): @@ -375,10 +383,17 @@ class model(Structure): def get_decfun_bias(self, label_idx=0): return liblinear.get_decfun_bias(self, label_idx) + def get_decfun_rho(self): + return liblinear.get_decfun_rho(self) + def get_decfun(self, label_idx=0): w = [liblinear.get_decfun_coef(self, feat_idx, label_idx) for feat_idx in range(1, self.nr_feature+1)] - b = liblinear.get_decfun_bias(self, label_idx) - return (w, b) + if self.is_oneclass_model(): + rho = self.get_decfun_rho() + return (w, -rho) + else: + b = liblinear.get_decfun_bias(self, label_idx) + return (w, b) def is_probability_model(self): return (liblinear.check_probability_model(self) == 1) @@ -386,6 +401,9 @@ class model(Structure): def is_regression_model(self): return 
(liblinear.check_regression_model(self) == 1)
 
+ def is_oneclass_model(self):
+ return (liblinear.check_oneclass_model(self) == 1)
+
 def toPyModel(model_ptr):
 """
 toPyModel(model_ptr) -> model
@@ -414,6 +432,7 @@ fillprototype(liblinear.get_nr_class, c_int, [POINTER(model)])
 fillprototype(liblinear.get_labels, None, [POINTER(model), POINTER(c_int)])
 fillprototype(liblinear.get_decfun_coef, c_double, [POINTER(model), c_int, c_int])
 fillprototype(liblinear.get_decfun_bias, c_double, [POINTER(model), c_int])
+fillprototype(liblinear.get_decfun_rho, c_double, [POINTER(model)])
 fillprototype(liblinear.free_model_content, None, [POINTER(model)])
 fillprototype(liblinear.free_and_destroy_model, None, [POINTER(POINTER(model))])
 
@@ -421,4 +440,5 @@ fillprototype(liblinear.destroy_param, None, [POINTER(parameter)])
 fillprototype(liblinear.check_parameter, c_char_p, [POINTER(problem), POINTER(parameter)])
 fillprototype(liblinear.check_probability_model, c_int, [POINTER(model)])
 fillprototype(liblinear.check_regression_model, c_int, [POINTER(model)])
+fillprototype(liblinear.check_oneclass_model, c_int, [POINTER(model)])
 fillprototype(liblinear.set_print_string_function, None, [CFUNCTYPE(None, c_char_p)])
diff --git a/python/liblinearutil.py b/python/liblinearutil.py
index a768ecf..9427974 100644
--- a/python/liblinearutil.py
+++ b/python/liblinearutil.py
@@ -75,6 +75,8 @@ def train(arg1, arg2=None, arg3=None):
 11 -- L2-regularized L2-loss support vector regression (primal)
 12 -- L2-regularized L2-loss support vector regression (dual)
 13 -- L2-regularized L1-loss support vector regression (dual)
+ for outlier detection
+ 21 -- one-class support vector machine (dual)
 -c cost : set the parameter C (default 1)
 -p epsilon : set the epsilon in loss function of SVR (default 0.1)
 -e epsilon : set tolerance of termination criterion
@@ -81,10 +83,10 @@ def train(arg1, arg2=None, arg3=None):
 -s 0 and 2
 |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,
 where f is the primal function, (default 0.01)
 -s 11
 |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)
- -s 1, 3, 4, and 7
- Dual maximal violation <= eps; similar to libsvm (default 0.1)
+ -s 1, 3, 4, 7, and 21 + Dual maximal violation <= eps; similar to libsvm (default 0.1 except 0.01 for -s 21) -s 5 and 6 |f'(w)|_inf <= eps*min(pos,neg)/l*|f'(w0)|_inf, where f is the primal function (default 0.01) diff --git a/train.c b/train.c index f544dca..bd0af94 100644 --- a/train.c +++ b/train.c @@ -29,8 +29,11 @@ void exit_with_help() " 11 -- L2-regularized L2-loss support vector regression (primal)\n" " 12 -- L2-regularized L2-loss support vector regression (dual)\n" " 13 -- L2-regularized L1-loss support vector regression (dual)\n" + " for outlier detection\n" + " 21 -- one-class support vector machine (dual)\n" "-c cost : set the parameter C (default 1)\n" "-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n" + "-n nu : set the parameter nu of one-class SVM (default 0.5)\n" "-e epsilon : set tolerance of termination criterion\n" " -s 0 and 2\n" " |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n" @@ -38,8 +41,8 @@ void exit_with_help() " positive/negative data (default 0.01)\n" " -s 11\n" " |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)\n" - " -s 1, 3, 4, and 7\n" - " Dual maximal violation <= eps; similar to libsvm (default 0.1)\n" + " -s 1, 3, 4, 7, and 21\n" + " Dual maximal violation <= eps; similar to libsvm (default 0.1 except 0.01 for -s 21)\n" " -s 5 and 6\n" " |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n" " where f is the primal function (default 0.01)\n" @@ -211,8 +214,9 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode // default values param.solver_type = L2R_L2LOSS_SVC_DUAL; param.C = 1; - param.eps = INF; // see setting below param.p = 0.1; + param.nu = 0.5; + param.eps = INF; // see setting below param.nr_weight = 0; param.weight_label = NULL; param.weight = NULL; @@ -247,6 +251,10 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode param.p = atof(argv[i]); break; + case 'n': + param.nu = atof(argv[i]); + break; + case 'e': param.eps = atof(argv[i]); break; @@ -352,6 +360,9 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode case L2R_L2LOSS_SVR_DUAL: param.eps = 0.1; break; + case ONECLASS_SVM: + param.eps = 0.01; + break; } } } -- 2.50.1
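For reference, a minimal sketch of driving the new solver through the C API
on a made-up three-instance toy set; the data and nu value are illustrative
only, and error handling is omitted:

#include <stdio.h>
#include "linear.h"

int main()
{
	/* two features per instance, sparse format, index -1 terminates */
	struct feature_node x1[] = {{1,0.1},{2,0.2},{-1,0}};
	struct feature_node x2[] = {{1,0.2},{2,0.1},{-1,0}};
	struct feature_node x3[] = {{1,5.0},{2,5.0},{-1,0}};
	struct feature_node *x[] = {x1, x2, x3};
	double y[] = {1, 1, 1};	/* labels are not used by ONECLASS_SVM */
	struct problem prob = {3, 2, y, x, -1};	/* l, n, y, x, bias (< 0: none) */

	struct parameter param;
	param.solver_type = ONECLASS_SVM;
	param.eps = 0.01;	/* default for -s 21 */
	param.C = 1;
	param.nu = 0.5;	/* roughly half the data treated as outliers */
	param.p = 0.1;
	param.nr_weight = 0;
	param.weight_label = NULL;
	param.weight = NULL;
	param.init_sol = NULL;

	if(check_parameter(&prob, &param) == NULL)
	{
		struct model *m = train(&prob, &param);
		/* predict returns +1 for normal data and -1 for outliers */
		printf("%g\n", predict(m, x3));
		free_and_destroy_model(&m);
	}
	destroy_param(&param);
	return 0;
}

Note that prob.bias is set to -1 here because check_parameter rejects
bias >= 0 when the solver type is ONECLASS_SVM.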