]> granicus.if.org Git - liblinear/commitdiff
solver for one-class SVM supported
authorChou Hung-Yi <s1243221@gmail.com>
Sat, 23 May 2020 15:25:17 +0000 (23:25 +0800)
committerChou Hung-Yi <s1243221@gmail.com>
Sat, 23 May 2020 15:42:50 +0000 (23:42 +0800)
    -add new solver ONECLASS_SVM (-s 21)
    -add function solve_oneclass_svm
    -add new attribute rho to model
    -add new attribute nu to parameter
    -modify function check_parameter, get_decfun_bias
        to reject one-class SVM from accessing bias
    -modify/add function get_decfun_coef, get_decfun_rho, get_w_value
        for one-class SVM to get its decision function
    -add function check_oneclass_model
    -modify python/MATLAB interface and train.c update
    -REDME update

README
linear.cpp
linear.h
matlab/README
matlab/linear_model_matlab.c
matlab/train.c
python/README
python/liblinear.py
python/liblinearutil.py
train.c

diff --git a/README b/README
index d78d6b48529beaa9b921066e2bc57e66a03c6535..51436804db732ebaffa8e2918fd6229cf6a6bec4 100644 (file)
--- a/README
+++ b/README
@@ -1,8 +1,9 @@
 LIBLINEAR is a simple package for solving large-scale regularized linear
-classification and regression. It currently supports
+classification, regression and outlier detection. It currently supports
 - L2-regularized logistic regression/L2-loss support vector classification/L1-loss support vector classification
 - L1-regularized L2-loss support vector classification/L1-regularized logistic regression
-- L2-regularized L2-loss support vector regression/L1-loss support vector regression.
+- L2-regularized L2-loss support vector regression/L1-loss support vector regression
+- one-class support vector machine.
 This document explains the usage of LIBLINEAR.
 
 To get started, please read the ``Quick Start'' section first.
@@ -114,8 +115,11 @@ options:
        11 -- L2-regularized L2-loss support vector regression (primal)
        12 -- L2-regularized L2-loss support vector regression (dual)
        13 -- L2-regularized L1-loss support vector regression (dual)
+  for outlier detection
+       21 -- one-class support vector machine (dual)
 -c cost : set the parameter C (default 1)
 -p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
+-n nu : set the parameter nu of one-class SVM (default 0.5)
 -e epsilon : set tolerance of termination criterion
        -s 0 and 2
                |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,
@@ -123,12 +127,12 @@ options:
                positive/negative data (default 0.01)
        -s 11
                |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)
-       -s 1, 3, 4 and 7
-               Dual maximal violation <= eps; similar to libsvm (default 0.1)
+       -s 1, 3, 4, 7, and 21
+               Dual maximal violation <= eps; similar to libsvm (default 0.1 except 0.01 for -s 21)
        -s 5 and 6
                |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,
                where f is the primal function (default 0.01)
-       -s 12 and 13\n"
+       -s 12 and 13
                |f'(alpha)|_1 <= eps |f'(alpha0)|,
                where f is the dual function (default 0.1)
 -B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)
@@ -199,6 +203,15 @@ where
 
 Q is a matrix with Q_ij = x_i^T x_j.
 
+For one-class SVM dual (-s 21), we solve
+
+min_alpha 0.5(alpha^T Q alpha)
+    s.t.   0 <= alpha_i <= 1 and \sum alpha_i = nu*l,
+
+where
+
+Q is a matrix with Q_ij = x_i^T x_j.
+
 If bias >= 0, w becomes [w; w_{n+1}] and x becomes [x; bias].
 
 The primal-dual relationship implies that -s 1 and -s 2 give the same
@@ -253,6 +266,10 @@ Train linear SVM with L2-loss function.
 
 Train a logistic regression model.
 
+> train -s 21 -n 0.1 data_file
+
+Train a linear one-class SVM which selects roughly 10% data as outliers.
+
 > train -v 5 -e 0.001 data_file
 
 Do five-fold cross-validation using L2-loss SVM.
@@ -365,6 +382,7 @@ in linear.h, so you can check the version number.
                 /* these are for training only */
                 double eps;             /* stopping criteria */
                 double C;
+                double nu;              /* one-class SVM only */
                 int nr_weight;
                 int *weight_label;
                 double* weight;
@@ -372,7 +390,7 @@ in linear.h, so you can check the version number.
                 double *init_sol;
         };
 
-    solver_type can be one of L2R_LR, L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC, L2R_L1LOSS_SVC_DUAL, MCSVM_CS, L1R_L2LOSS_SVC, L1R_LR, L2R_LR_DUAL, L2R_L2LOSS_SVR, L2R_L2LOSS_SVR_DUAL, L2R_L1LOSS_SVR_DUAL.
+    solver_type can be one of L2R_LR, L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC, L2R_L1LOSS_SVC_DUAL, MCSVM_CS, L1R_L2LOSS_SVC, L1R_LR, L2R_LR_DUAL, L2R_L2LOSS_SVR, L2R_L2LOSS_SVR_DUAL, L2R_L1LOSS_SVR_DUAL, ONECLASS_SVM.
   for classification
     L2R_LR                L2-regularized logistic regression (primal)
     L2R_L2LOSS_SVC_DUAL   L2-regularized L2-loss support vector classification (dual)
@@ -386,9 +404,12 @@ in linear.h, so you can check the version number.
     L2R_L2LOSS_SVR        L2-regularized L2-loss support vector regression (primal)
     L2R_L2LOSS_SVR_DUAL   L2-regularized L2-loss support vector regression (dual)
     L2R_L1LOSS_SVR_DUAL   L2-regularized L1-loss support vector regression (dual)
+  for outlier detection
+    ONECLASS_SVM          one-class support vector machine (dual)
 
     C is the cost of constraints violation.
     p is the sensitiveness of loss of support vector regression.
+    nu in ONECLASS_SVM approximates the fraction of data as outliers.
     eps is the stopping criterion.
 
     nr_weight, weight_label, and weight are used to change the penalty
@@ -420,6 +441,7 @@ in linear.h, so you can check the version number.
                 double *w;
                 int *label;             /* label of each class */
                 double bias;
+                double rho;             /* one-class SVM only */
         };
 
      param describes the parameters used to obtain the model.
@@ -438,11 +460,13 @@ in linear.h, so you can check the version number.
      | for 1st feature  | for 2nd feature  |
      +------------------+------------------+------------+
 
+     The array label stores class labels.
+
      If bias >= 0, x becomes [x; bias]. The number of features is
      increased by one, so w is a (nr_feature+1)*nr_class array. The
      value of bias is stored in the variable bias.
 
-     The array label stores class labels.
+     rho is the bias term used in one-class SVM only.
 
 - Function: void cross_validation(const problem *prob, const parameter *param, int nr_fold, double *target);
 
@@ -531,15 +555,21 @@ in linear.h, so you can check the version number.
     starts from 1, while label_idx starts from 0. If feat_idx is not in the
     valid range (1 to nr_feature), then a zero value will be returned. For
     classification models, if label_idx is not in the valid range (0 to
-    nr_class-1), then a zero value will be returned; for regression models,
-    label_idx is ignored.
+    nr_class-1), then a zero value will be returned; for regression models
+    and one-class SVM models, label_idx is ignored.
 
 - Function: double get_decfun_bias(const struct model *model_, int label_idx);
 
     This function gives the bias term corresponding to the class with the
     label_idx. For classification models, if label_idx is not in a valid range
     (0 to nr_class-1), then a zero value will be returned; for regression
-    models, label_idx is ignored.
+    models, label_idx is ignored. This function cannot be called for a one-class
+    SVM model.
+
+- Function: double get_decfun_rho(const struct model *model_);
+
+    This function gives rho, the bias term used in one-class SVM only. This
+    function can only be called for a one-class SVM model.
 
 - Function: const char *check_parameter(const struct problem *prob,
             const struct parameter *param);
@@ -559,6 +589,11 @@ in linear.h, so you can check the version number.
     This function returns 1 if the model is a regression model; otherwise
     it returns 0.
 
+- Function: int check_oneclass_model(const struct model *model);
+
+    This function returns 1 if the model is a one-class SVM model; otherwise
+    it returns 0.
+
 - Function: int save_model(const char *model_file_name,
             const struct model *model_);
 
index 3bdb114d7bc69a6b496cfd71f776fe35339e8677..c873156930df6dfa4a2235ac60e7811834875da5 100644 (file)
@@ -70,6 +70,28 @@ public:
                return (ret);
        }
 
+       static double sparse_dot(const feature_node *x1, const feature_node *x2)
+       {
+               double ret = 0;
+               while(x1->index != -1 && x2->index != -1)
+               {
+                       if(x1->index == x2->index)
+                       {
+                               ret += x1->value * x2->value;
+                               ++x1;
+                               ++x2;
+                       }
+                       else
+                       {
+                               if(x1->index > x2->index)
+                                       ++x2;
+                               else
+                                       ++x1;
+                       }
+               }
+               return (ret);
+       }
+
        static void axpy(const double a, const feature_node *x, double *y)
        {
                while(x->index != -1)
@@ -2017,6 +2039,342 @@ static void solve_l1r_lr(
        delete [] D;
 }
 
+struct heap {
+       enum HEAP_TYPE { MIN, MAX };
+       int _size;
+       HEAP_TYPE _type;
+       feature_node* a;
+
+       heap(int max_size, HEAP_TYPE type)
+       {
+               _size = 0;
+               a = new feature_node[max_size];
+               _type = type;
+       }
+       ~heap()
+       {
+               delete [] a;
+       }
+       bool cmp(const feature_node& left, const feature_node& right)
+       {
+               if(_type == MIN)
+                       return left.value > right.value;
+               else
+                       return left.value < right.value;
+       }
+       int size()
+       {
+               return _size;
+       }
+       void push(feature_node node)
+       {
+               a[_size] = node;
+               _size++;
+               int i = _size-1;
+               while(i)
+               {
+                       int p = (i-1)/2;
+                       if(cmp(a[p], a[i]))
+                       {
+                               swap(a[i], a[p]);
+                               i = p;
+                       }
+                       else
+                               break;
+               }
+       }
+       void pop()
+       {
+               _size--;
+               a[0] = a[_size];
+               int i = 0;
+               while(i*2+1 < _size)
+               {
+                       int l = i*2+1;
+                       int r = i*2+2;
+                       if(r < _size && cmp(a[l], a[r]))
+                               l = r;
+                       if(cmp(a[i], a[l]))
+                       {
+                               swap(a[i], a[l]);
+                               i = l;
+                       }
+                       else
+                               break;
+               }
+       }
+       feature_node top()
+       {
+               return a[0];
+       }
+};
+
+// A two-level coordinate descent algorithm for
+// a scaled one-class SVM dual problem
+//
+//  min_\alpha  0.5(\alpha^T Q \alpha),
+//    s.t.      0 <= \alpha_i <= 1 and
+//              e^T \alpha = \nu l
+//
+//  where Qij = xi^T xj
+//
+// Given:
+// x, nu
+// eps is the stopping tolerance
+//
+// solution will be put in w and rho
+//
+// See Algorithm 7 in supplementary materials of Chou et al., SDM 2020.
+
+static void solve_oneclass_svm(const problem *prob, double *w, double *rho, double eps, double nu)
+{
+       int l = prob->l;
+       int w_size = prob->n;
+       int i, j, s, iter = 0;
+       double Gi, Gj;
+       double Qij, quad_coef, delta, sum;
+       double old_alpha_i;
+       double *QD = new double[l];
+       double *G = new double[l];
+       int *index = new int[l];
+       double *alpha = new double[l];
+       int max_inner_iter;
+       int max_iter = 1000;
+       int active_size = l;
+
+       double negGmax;                 // max { -grad(f)_i | alpha_i < 1 }
+       double negGmin;                 // min { -grad(f)_i | alpha_i > 0 }
+
+       int *most_violating_i = new int[l];
+       int *most_violating_j = new int[l];
+
+       int n = (int)(nu*l);            // # of alpha's at upper bound
+       for(i=0; i<n; i++)
+               alpha[i] = 1;
+       if (n<l)
+               alpha[i] = nu*l-n;
+       for(i=n+1; i<l; i++)
+               alpha[i] = 0;
+
+       for(i=0; i<w_size; i++)
+               w[i] = 0;
+       for(i=0; i<l; i++)
+       {
+               feature_node * const xi = prob->x[i];
+               QD[i] = sparse_operator::nrm2_sq(xi);
+               sparse_operator::axpy(alpha[i], xi, w);
+
+               index[i] = i;
+       }
+
+       while (iter < max_iter)
+       {
+               negGmax = -INF;
+               negGmin = INF;
+
+               for (s=0; s<active_size; s++)
+               {
+                       i = index[s];
+                       feature_node * const xi = prob->x[i];
+                       G[i] = sparse_operator::dot(w, xi);
+                       if (alpha[i] < 1)
+                               negGmax = max(negGmax, -G[i]);
+                       if (alpha[i] > 0)
+                               negGmin = min(negGmin, -G[i]);
+               }
+
+               if (negGmax - negGmin < eps)
+               {
+                       if (active_size == l)
+                               break;
+                       else
+                       {
+                               active_size = l;
+                               info("*");
+                               continue;
+                       }
+               }
+
+               for(s=0; s<active_size; s++)
+               {
+                       i = index[s];
+                       if ((alpha[i] == 1 && -G[i] > negGmax) ||
+                           (alpha[i] == 0 && -G[i] < negGmin))
+                       {
+                               active_size--;
+                               swap(index[s], index[active_size]);
+                               s--;
+                       }
+               }
+
+               max_inner_iter = max(active_size/10, 1);
+               struct heap min_heap = heap(max_inner_iter, heap::MIN);
+               struct heap max_heap = heap(max_inner_iter, heap::MAX);
+               struct feature_node node;
+               for(s=0; s<active_size; s++)
+               {
+                       i = index[s];
+                       node.index = i;
+                       node.value = -G[i];
+
+                       if (alpha[i] < 1)
+                       {
+                               if (min_heap.size() < max_inner_iter)
+                                       min_heap.push(node);
+                               else if (min_heap.top().value < node.value)
+                               {
+                                       min_heap.pop();
+                                       min_heap.push(node);
+                               }
+                       }
+
+                       if (alpha[i] > 0)
+                       {
+                               if (max_heap.size() < max_inner_iter)
+                                       max_heap.push(node);
+                               else if (max_heap.top().value > node.value)
+                               {
+                                       max_heap.pop();
+                                       max_heap.push(node);
+                               }
+                       }
+               }
+               max_inner_iter = min(min_heap.size(), max_heap.size());
+               while (max_heap.size() > max_inner_iter)
+                       max_heap.pop();
+               while (min_heap.size() > max_inner_iter)
+                       min_heap.pop();
+
+               for (s=max_inner_iter-1; s>=0; s--)
+               {
+                       most_violating_i[s] = min_heap.top().index;
+                       most_violating_j[s] = max_heap.top().index;
+                       min_heap.pop();
+                       max_heap.pop();
+               }
+
+               for (s=0; s<max_inner_iter; s++)
+               {
+                       i = most_violating_i[s];
+                       j = most_violating_j[s];
+
+                       if ((alpha[i] == 0 && alpha[j] == 0) ||
+                           (alpha[i] == 1 && alpha[j] == 1))
+                               continue;
+
+                       feature_node const * xi = prob->x[i];
+                       feature_node const * xj = prob->x[j];
+
+                       Gi = sparse_operator::dot(w, xi);
+                       Gj = sparse_operator::dot(w, xj);
+
+                       int violating_pair = 0;
+                       if (alpha[i] < 1 && alpha[j] > 0 && -Gj + 1e-12 < -Gi)
+                               violating_pair = 1;
+                       else
+                               if (alpha[i] > 0 && alpha[j] < 1 && -Gi + 1e-12 < -Gj)
+                                       violating_pair = 1;
+                       if (violating_pair == 0)
+                               continue;
+
+                       Qij = sparse_operator::sparse_dot(xi, xj);
+                       quad_coef = QD[i] + QD[j] - 2*Qij;
+                       if(quad_coef <= 0)
+                               quad_coef = 1e-12;
+                       delta = (Gi - Gj) / quad_coef;
+                       old_alpha_i = alpha[i];
+                       sum = alpha[i] + alpha[j];
+                       alpha[i] = alpha[i] - delta;
+                       alpha[j] = alpha[j] + delta;
+                       if (sum > 1)
+                       {
+                               if (alpha[i] > 1)
+                               {
+                                       alpha[i] = 1;
+                                       alpha[j] = sum - 1;
+                               }
+                       }
+                       else
+                       {
+                               if (alpha[j] < 0)
+                               {
+                                       alpha[j] = 0;
+                                       alpha[i] = sum;
+                               }
+                       }
+                       if (sum > 1)
+                       {
+                               if (alpha[j] > 1)
+                               {
+                                       alpha[j] = 1;
+                                       alpha[i] = sum - 1;
+                               }
+                       }
+                       else
+                       {
+                               if (alpha[i] < 0)
+                               {
+                                       alpha[i] = 0;
+                                       alpha[j] = sum;
+                               }
+                       }
+                       delta = alpha[i] - old_alpha_i;
+                       sparse_operator::axpy(delta, xi, w);
+                       sparse_operator::axpy(-delta, xj, w);
+               }
+               iter++;
+               if (iter % 10 == 0)
+                       info(".");
+       }
+       info("\noptimization finished, #iter = %d\n",iter);
+       if (iter >= max_iter)
+               info("\nWARNING: reaching max number of iterations\n\n");
+
+       // calculate object value
+       double v = 0;
+       for(i=0; i<w_size; i++)
+               v += w[i]*w[i];
+       int nSV = 0;
+       for(i=0; i<l; i++)
+       {
+               if (alpha[i] > 0)
+                       ++nSV;
+       }
+       info("Objective value = %lf\n", v/2);
+       info("nSV = %d\n", nSV);
+
+       // calculate rho
+       double nr_free = 0;
+       double ub = INF, lb = -INF, sum_free = 0;
+       for(i=0; i<l; i++)
+       {
+               double G = sparse_operator::dot(w, prob->x[i]);
+               if (alpha[i] == 0)
+                       lb = max(lb, G);
+               else if (alpha[i] == 1)
+                       ub = min(ub, G);
+               else
+               {
+                       ++nr_free;
+                       sum_free += G;
+               }
+       }
+
+       if (nr_free > 0)
+               *rho = sum_free/nr_free;
+       else
+               *rho = (ub + lb)/2;
+
+       info("rho = %lf\n", *rho);
+
+       delete [] QD;
+       delete [] G;
+       delete [] index;
+       delete [] alpha;
+       delete [] most_violating_i;
+       delete [] most_violating_j;
+}
+
 // transpose matrix X from row format to column format
 static void transpose(const problem *prob, feature_node **x_space_ret, problem *prob_col)
 {
@@ -2473,6 +2831,13 @@ model* train(const problem *prob, const parameter *param)
                model_->label = NULL;
                train_one(prob, param, model_->w, 0, 0);
        }
+       else if(check_oneclass_model(model_))
+       {
+               model_->w = Malloc(double, w_size);
+               model_->nr_class = 2;
+               model_->label = NULL;
+               solve_oneclass_svm(prob, model_->w, &(model_->rho), param->eps, param->nu);
+       }
        else
        {
                int nr_class;
@@ -2793,11 +3158,15 @@ double predict_values(const struct model *model_, const struct feature_node *x,
                        for(i=0;i<nr_w;i++)
                                dec_values[i] += w[(idx-1)*nr_w+i]*lx->value;
        }
+       if(check_oneclass_model(model_))
+               dec_values[0] -= model_->rho;
 
        if(nr_class==2)
        {
                if(check_regression_model(model_))
                        return dec_values[0];
+               else if(check_oneclass_model(model_))
+                       return (dec_values[0]>0)?1:-1;
                else
                        return (dec_values[0]>0)?model_->label[0]:model_->label[1];
        }
@@ -2860,7 +3229,9 @@ static const char *solver_type_table[]=
        "L2R_LR", "L2R_L2LOSS_SVC_DUAL", "L2R_L2LOSS_SVC", "L2R_L1LOSS_SVC_DUAL", "MCSVM_CS",
        "L1R_L2LOSS_SVC", "L1R_LR", "L2R_LR_DUAL",
        "", "", "",
-       "L2R_L2LOSS_SVR", "L2R_L2LOSS_SVR_DUAL", "L2R_L1LOSS_SVR_DUAL", NULL
+       "L2R_L2LOSS_SVR", "L2R_L2LOSS_SVR_DUAL", "L2R_L1LOSS_SVR_DUAL",
+       "", "", "", "", "", "", "",
+       "ONECLASS_SVM", NULL
 };
 
 int save_model(const char *model_file_name, const struct model *model_)
@@ -2906,6 +3277,9 @@ int save_model(const char *model_file_name, const struct model *model_)
 
        fprintf(fp, "bias %.17g\n", model_->bias);
 
+       if(check_oneclass_model(model_))
+               fprintf(fp, "rho %.17g\n", model_->rho);
+
        fprintf(fp, "w\n");
        for(i=0; i<w_size; i++)
        {
@@ -2956,6 +3330,7 @@ struct model *load_model(const char *model_file_name)
        int n;
        int nr_class;
        double bias;
+       double rho;
        model *model_ = Malloc(model,1);
        parameter& param = model_->param;
        // parameters for training only won't be assigned, but arrays are assigned as NULL for safety
@@ -3010,6 +3385,11 @@ struct model *load_model(const char *model_file_name)
                        FSCANF(fp,"%lf",&bias);
                        model_->bias=bias;
                }
+               else if(strcmp(cmd,"rho")==0)
+               {
+                       FSCANF(fp,"%lf",&rho);
+                       model_->rho=rho;
+               }
                else if(strcmp(cmd,"w")==0)
                {
                        break;
@@ -3082,7 +3462,7 @@ static inline double get_w_value(const struct model *model_, int idx, int label_
 
        if(idx < 0 || idx > model_->nr_feature)
                return 0;
-       if(check_regression_model(model_))
+       if(check_regression_model(model_) || check_oneclass_model(model_))
                return w[idx];
        else
        {
@@ -3112,6 +3492,11 @@ double get_decfun_coef(const struct model *model_, int feat_idx, int label_idx)
 
 double get_decfun_bias(const struct model *model_, int label_idx)
 {
+       if(check_oneclass_model(model_))
+       {
+               fprintf(stderr, "ERROR: get_decfun_bias can not be called for a one-class SVM model\n");
+               return 0;
+       }
        int bias_idx = model_->nr_feature;
        double bias = model_->bias;
        if(bias <= 0)
@@ -3120,6 +3505,17 @@ double get_decfun_bias(const struct model *model_, int label_idx)
                return bias*get_w_value(model_, bias_idx, label_idx);
 }
 
+double get_decfun_rho(const struct model *model_)
+{
+       if(check_oneclass_model(model_))
+               return model_->rho;
+       else
+       {
+               fprintf(stderr, "ERROR: get_decfun_rho can be called only for a one-class SVM model\n");
+               return 0;
+       }
+}
+
 void free_model_content(struct model *model_ptr)
 {
        if(model_ptr->w != NULL)
@@ -3159,6 +3555,9 @@ const char *check_parameter(const problem *prob, const parameter *param)
        if(param->p < 0)
                return "p < 0";
 
+       if(prob->bias >= 0 && param->solver_type == ONECLASS_SVM)
+               return "prob->bias >=0, but this is ignored in ONECLASS_SVM";
+
        if(param->solver_type != L2R_LR
                && param->solver_type != L2R_L2LOSS_SVC_DUAL
                && param->solver_type != L2R_L2LOSS_SVC
@@ -3169,7 +3568,8 @@ const char *check_parameter(const problem *prob, const parameter *param)
                && param->solver_type != L2R_LR_DUAL
                && param->solver_type != L2R_L2LOSS_SVR
                && param->solver_type != L2R_L2LOSS_SVR_DUAL
-               && param->solver_type != L2R_L1LOSS_SVR_DUAL)
+               && param->solver_type != L2R_L1LOSS_SVR_DUAL
+               && param->solver_type != ONECLASS_SVM)
                return "unknown solver type";
 
        if(param->init_sol != NULL
@@ -3195,6 +3595,11 @@ int check_regression_model(const struct model *model_)
                        model_->param.solver_type==L2R_L2LOSS_SVR_DUAL);
 }
 
+int check_oneclass_model(const struct model *model_)
+{
+       return model_->param.solver_type == ONECLASS_SVM;
+}
+
 void set_print_string_function(void (*print_func)(const char*))
 {
        if (print_func == NULL)
index 7ad9606d9dc7a9c4a4ed35e099916d0e273c162a..2419188cf029ee503beaa69abad59e5b8dda3478 100644 (file)
--- a/linear.h
+++ b/linear.h
@@ -23,7 +23,7 @@ struct problem
        double bias;            /* < 0 if no bias term */
 };
 
-enum { L2R_LR, L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC, L2R_L1LOSS_SVC_DUAL, MCSVM_CS, L1R_L2LOSS_SVC, L1R_LR, L2R_LR_DUAL, L2R_L2LOSS_SVR = 11, L2R_L2LOSS_SVR_DUAL, L2R_L1LOSS_SVR_DUAL }; /* solver_type */
+enum { L2R_LR, L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC, L2R_L1LOSS_SVC_DUAL, MCSVM_CS, L1R_L2LOSS_SVC, L1R_LR, L2R_LR_DUAL, L2R_L2LOSS_SVR = 11, L2R_L2LOSS_SVR_DUAL, L2R_L1LOSS_SVR_DUAL, ONECLASS_SVM = 21 }; /* solver_type */
 
 struct parameter
 {
@@ -36,6 +36,7 @@ struct parameter
        int *weight_label;
        double* weight;
        double p;
+       double nu;
        double *init_sol;
 };
 
@@ -47,6 +48,7 @@ struct model
        double *w;
        int *label;             /* label of each class */
        double bias;
+       double rho;             /* one-class SVM only */
 };
 
 struct model* train(const struct problem *prob, const struct parameter *param);
@@ -65,6 +67,7 @@ int get_nr_class(const struct model *model_);
 void get_labels(const struct model *model_, int* label);
 double get_decfun_coef(const struct model *model_, int feat_idx, int label_idx);
 double get_decfun_bias(const struct model *model_, int label_idx);
+double get_decfun_rho(const struct model *model_);
 
 void free_model_content(struct model *model_ptr);
 void free_and_destroy_model(struct model **model_ptr_ptr);
@@ -73,6 +76,7 @@ void destroy_param(struct parameter *param);
 const char *check_parameter(const struct problem *prob, const struct parameter *param);
 int check_probability_model(const struct model *model);
 int check_regression_model(const struct model *model);
+int check_oneclass_model(const struct model *model);
 void set_print_string_function(void (*print_func) (const char*));
 
 #ifdef __cplusplus
index 8a32826f5e90a4ceae59467b84539953371cda7e..af656a0e9a1cdfbf68eac8b12c6587d5ebbd63d9 100644 (file)
@@ -109,7 +109,7 @@ Returned Model Structure
 
 The 'train' function returns a model which can be used for future
 prediction.  It is a structure and is organized as [Parameters, nr_class,
-nr_feature, bias, Label, w]:
+nr_feature, bias, Label, w, rho]:
 
         -Parameters: Parameters (now only solver type is provided)
         -nr_class: number of classes; = 2 for regression
@@ -122,6 +122,7 @@ nr_feature, bias, Label, w]:
             nr_w is 1 if nr_class=2 and -s is not 4 (i.e., not
             multi-class svm by Crammer and Singer). It is
             nr_class otherwise.
+        -rho: the bias term of one-class SVM.
 
 If the '-v' option is specified, cross validation is conducted and the
 returned model is just a scalar: cross-validation accuracy for
index f11b4516493d35596263eecf7f2b9d8f64549540..a048b6dda054660938d2c58822dd777dd9689c3e 100644 (file)
@@ -12,7 +12,7 @@ typedef int mwIndex;
 
 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
 
-#define NUM_OF_RETURN_FIELD 6
+#define NUM_OF_RETURN_FIELD 7
 
 static const char *field_names[] = {
        "Parameters",
@@ -21,6 +21,7 @@ static const char *field_names[] = {
        "bias",
        "Label",
        "w",
+       "rho",
 };
 
 const char *model_to_matlab_structure(mxArray *plhs[], struct model *model_)
@@ -89,6 +90,12 @@ const char *model_to_matlab_structure(mxArray *plhs[], struct model *model_)
                ptr[i]=model_->w[i];
        out_id++;
 
+       // rho
+       rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
+       ptr = mxGetPr(rhs[out_id]);
+       ptr[0] = model_->rho;
+       out_id++;
+
        /* Create a struct matrix contains NUM_OF_RETURN_FIELD fields */
        return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
 
index 1771aee076f543c7767655ae9d4d1dffea444c8f..63cd4b60f5ee09b2bc439186ed491eb4297687f8 100644 (file)
@@ -41,6 +41,7 @@ void exit_with_help()
        "       13 -- L2-regularized L1-loss support vector regression (dual)\n"
        "-c cost : set the parameter C (default 1)\n"
        "-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n"
+       "-n nu : set the parameter nu of one-class SVM (default 0.5)\n"
        "-e epsilon : set tolerance of termination criterion\n"
        "       -s 0 and 2\n"
        "               |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
@@ -48,8 +49,8 @@ void exit_with_help()
        "               positive/negative data (default 0.01)\n"
        "       -s 11\n"
        "               |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)\n"
-       "       -s 1, 3, 4 and 7\n"
-       "               Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
+       "       -s 1, 3, 4, 7, and 21\n"
+       "               Dual maximal violation <= eps; similar to libsvm (default 0.1 except 0.01 for -s 21)\n"
        "       -s 5 and 6\n"
        "               |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n"
        "               where f is the primal function (default 0.01)\n"
@@ -214,6 +215,9 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
                                param.p = atof(argv[i]);
                                flag_p_specified = 1;
                                break;
+                       case 'n':
+                               param.nu = atof(argv[i]);
+                               break;
                        case 'e':
                                param.eps = atof(argv[i]);
                                break;
@@ -294,6 +298,9 @@ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
                        case L2R_L2LOSS_SVR_DUAL:
                                param.eps = 0.1;
                                break;
+                       case ONECLASS_SVM:
+                               param.eps = 0.01;
+                               break;
                }
        }
        return 0;
index f4e52f6365bdf2831e60956fc38d5ef15882b297..1403ab93343745ac790e526c8de9e7f76161b1e6 100644 (file)
@@ -280,12 +280,18 @@ LIBLINEAR shared library:
     Note that w_k is a Python list of length nr_feature, which means that
         w_k[0] = W_k1.
     For regression models, W is just a vector of length nr_feature. Either
-       set label_idx=0 or omit the label_idx parameter to access the coefficients.
+    set label_idx=0 or omit the label_idx parameter to access the coefficients.
 
     >>> W_j = model_.get_decfun_coef(feat_idx=j)
     >>> b = model_.get_decfun_bias()
     >>> [W, b] = model_.get_decfun()
 
+    For one-class SVM models, label_idx is ignored and b=-rho is returned. That
+    is, the decision function is w*x+b=w*x-rho.
+
+    >>> rho = model_.get_decfun_rho()
+    >>> [W, rho] = model_.get_decfun()
+
     Note that in get_decfun_coef, get_decfun_bias, and get_decfun, feat_idx
     starts from 1, while label_idx starts from 0. If label_idx is not in the
     valid range (0 to nr_class-1), then a NaN will be returned; and if feat_idx
index eb666080e8ce217e1f4040edbf64814e3b3f61bb..ae2bd770994a8ec91b1d99abe2578b53a661c0ca 100644 (file)
@@ -20,14 +20,15 @@ __all__ = ['liblinear', 'feature_node', 'gen_feature_nodearray', 'problem',
            'parameter', 'model', 'toPyModel', 'L2R_LR', 'L2R_L2LOSS_SVC_DUAL',
            'L2R_L2LOSS_SVC', 'L2R_L1LOSS_SVC_DUAL', 'MCSVM_CS',
            'L1R_L2LOSS_SVC', 'L1R_LR', 'L2R_LR_DUAL', 'L2R_L2LOSS_SVR',
-           'L2R_L2LOSS_SVR_DUAL', 'L2R_L1LOSS_SVR_DUAL', 'print_null']
+           'L2R_L2LOSS_SVR_DUAL', 'L2R_L1LOSS_SVR_DUAL', 'ONECLASS_SVM',
+           'print_null']
 
 try:
        dirname = path.dirname(path.abspath(__file__))
        if sys.platform == 'win32':
                liblinear = CDLL(path.join(dirname, r'..\windows\liblinear.dll'))
        else:
-               liblinear = CDLL(path.join(dirname, '../liblinear.so.3'))
+               liblinear = CDLL(path.join(dirname, '../liblinear.so.4'))
 except:
 # For unix the prefix 'lib' is not considered.
        if find_library('linear'):
@@ -48,6 +49,7 @@ L2R_LR_DUAL = 7
 L2R_L2LOSS_SVR = 11
 L2R_L2LOSS_SVR_DUAL = 12
 L2R_L1LOSS_SVR_DUAL = 13
+ONECLASS_SVM = 21
 
 PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)
 def print_null(s):
@@ -226,8 +228,8 @@ class problem(Structure):
 
 
 class parameter(Structure):
-       _names = ["solver_type", "eps", "C", "nr_weight", "weight_label", "weight", "p", "init_sol"]
-       _types = [c_int, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double), c_double, POINTER(c_double)]
+       _names = ["solver_type", "eps", "C", "nr_weight", "weight_label", "weight", "p", "nu", "init_sol"]
+       _types = [c_int, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double), c_double, c_double, POINTER(c_double)]
        _fields_ = genFields(_names, _types)
 
        def __init__(self, options = None):
@@ -250,6 +252,7 @@ class parameter(Structure):
                self.eps = float('inf')
                self.C = 1
                self.p = 0.1
+               self.nu = 0.5
                self.nr_weight = 0
                self.weight_label = None
                self.weight = None
@@ -289,6 +292,9 @@ class parameter(Structure):
                                i = i + 1
                                self.p = float(argv[i])
                                self.flag_p_specified = True
+                       elif argv[i] == "-n":
+                               i = i + 1
+                               self.nu = float(argv[i])
                        elif argv[i] == "-e":
                                i = i + 1
                                self.eps = float(argv[i])
@@ -343,10 +349,12 @@ class parameter(Structure):
                                self.eps = 0.01
                        elif self.solver_type in [L2R_L2LOSS_SVR_DUAL, L2R_L1LOSS_SVR_DUAL]:
                                self.eps = 0.1
+                       elif self.solver_type in [ONECLASS_SVM]:
+                               self.eps = 0.01
 
 class model(Structure):
-       _names = ["param", "nr_class", "nr_feature", "w", "label", "bias"]
-       _types = [parameter, c_int, c_int, POINTER(c_double), POINTER(c_int), c_double]
+       _names = ["param", "nr_class", "nr_feature", "w", "label", "bias", "rho"]
+       _types = [parameter, c_int, c_int, POINTER(c_double), POINTER(c_int), c_double, c_double]
        _fields_ = genFields(_names, _types)
 
        def __init__(self):
@@ -375,10 +383,17 @@ class model(Structure):
        def get_decfun_bias(self, label_idx=0):
                return liblinear.get_decfun_bias(self, label_idx)
 
+       def get_decfun_rho(self):
+               return liblinear.get_decfun_rho(self)
+
        def get_decfun(self, label_idx=0):
                w = [liblinear.get_decfun_coef(self, feat_idx, label_idx) for feat_idx in range(1, self.nr_feature+1)]
-               b = liblinear.get_decfun_bias(self, label_idx)
-               return (w, b)
+               if self.is_oneclass_model():
+                       rho = self.get_decfun_rho()
+                       return (w, -rho)
+               else:
+                       b = liblinear.get_decfun_bias(self, label_idx)
+                       return (w, b)
 
        def is_probability_model(self):
                return (liblinear.check_probability_model(self) == 1)
@@ -386,6 +401,9 @@ class model(Structure):
        def is_regression_model(self):
                return (liblinear.check_regression_model(self) == 1)
 
+       def is_oneclass_model(self):
+               return (liblinear.check_oneclass_model(self) == 1)
+
 def toPyModel(model_ptr):
        """
        toPyModel(model_ptr) -> model
@@ -414,6 +432,7 @@ fillprototype(liblinear.get_nr_class, c_int, [POINTER(model)])
 fillprototype(liblinear.get_labels, None, [POINTER(model), POINTER(c_int)])
 fillprototype(liblinear.get_decfun_coef, c_double, [POINTER(model), c_int, c_int])
 fillprototype(liblinear.get_decfun_bias, c_double, [POINTER(model), c_int])
+fillprototype(liblinear.get_decfun_rho, c_double, [POINTER(model)])
 
 fillprototype(liblinear.free_model_content, None, [POINTER(model)])
 fillprototype(liblinear.free_and_destroy_model, None, [POINTER(POINTER(model))])
@@ -421,4 +440,5 @@ fillprototype(liblinear.destroy_param, None, [POINTER(parameter)])
 fillprototype(liblinear.check_parameter, c_char_p, [POINTER(problem), POINTER(parameter)])
 fillprototype(liblinear.check_probability_model, c_int, [POINTER(model)])
 fillprototype(liblinear.check_regression_model, c_int, [POINTER(model)])
+fillprototype(liblinear.check_oneclass_model, c_int, [POINTER(model)])
 fillprototype(liblinear.set_print_string_function, None, [CFUNCTYPE(None, c_char_p)])
index a768ecf0424c11f600c3b06e9f89b638c888c1fe..9427974501f26f82e0352e37270e62317ac5da97 100644 (file)
@@ -75,6 +75,8 @@ def train(arg1, arg2=None, arg3=None):
                        11 -- L2-regularized L2-loss support vector regression (primal)
                        12 -- L2-regularized L2-loss support vector regression (dual)
                        13 -- L2-regularized L1-loss support vector regression (dual)
+                 for outlier detection
+                       21 -- one-class support vector machine (dual)
                -c cost : set the parameter C (default 1)
                -p epsilon : set the epsilon in loss function of SVR (default 0.1)
                -e epsilon : set tolerance of termination criterion
@@ -83,8 +85,8 @@ def train(arg1, arg2=None, arg3=None):
                                where f is the primal function, (default 0.01)
                        -s 11
                                |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)
-                       -s 1, 3, 4, and 7
-                               Dual maximal violation <= eps; similar to liblinear (default 0.)
+                       -s 1, 3, 4, 7, and 21
+                               Dual maximal violation <= eps; similar to libsvm (default 0.1 except 0.01 for -s 21)
                        -s 5 and 6
                                |f'(w)|_inf <= eps*min(pos,neg)/l*|f'(w0)|_inf,
                                where f is the primal function (default 0.01)
diff --git a/train.c b/train.c
index f544dcabc5faf8dbe69fc373587a4969a0a2eae6..bd0af9465b218f5a619b791f7b03a5d5968fe74c 100644 (file)
--- a/train.c
+++ b/train.c
@@ -29,8 +29,11 @@ void exit_with_help()
        "       11 -- L2-regularized L2-loss support vector regression (primal)\n"
        "       12 -- L2-regularized L2-loss support vector regression (dual)\n"
        "       13 -- L2-regularized L1-loss support vector regression (dual)\n"
+       "  for outlier detection\n"
+       "       21 -- one-class support vector machine (dual)\n"
        "-c cost : set the parameter C (default 1)\n"
        "-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n"
+       "-n nu : set the parameter nu of one-class SVM (default 0.5)\n"
        "-e epsilon : set tolerance of termination criterion\n"
        "       -s 0 and 2\n"
        "               |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
@@ -38,8 +41,8 @@ void exit_with_help()
        "               positive/negative data (default 0.01)\n"
        "       -s 11\n"
        "               |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.0001)\n"
-       "       -s 1, 3, 4, and 7\n"
-       "               Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
+       "       -s 1, 3, 4, 7, and 21\n"
+       "               Dual maximal violation <= eps; similar to libsvm (default 0.1 except 0.01 for -s 21)\n"
        "       -s 5 and 6\n"
        "               |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n"
        "               where f is the primal function (default 0.01)\n"
@@ -211,8 +214,9 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode
        // default values
        param.solver_type = L2R_L2LOSS_SVC_DUAL;
        param.C = 1;
-       param.eps = INF; // see setting below
        param.p = 0.1;
+       param.nu = 0.5;
+       param.eps = INF; // see setting below
        param.nr_weight = 0;
        param.weight_label = NULL;
        param.weight = NULL;
@@ -247,6 +251,10 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode
                                param.p = atof(argv[i]);
                                break;
 
+                       case 'n':
+                               param.nu = atof(argv[i]);
+                               break;
+
                        case 'e':
                                param.eps = atof(argv[i]);
                                break;
@@ -352,6 +360,9 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode
                        case L2R_L2LOSS_SVR_DUAL:
                                param.eps = 0.1;
                                break;
+                       case ONECLASS_SVM:
+                               param.eps = 0.01;
+                               break;
                }
        }
 }