From 56d69b2d87fc943611cbfa18b072fb6ebb451102 Mon Sep 17 00:00:00 2001 From: Yu-Chin Date: Fri, 25 Jul 2014 23:37:50 +0800 Subject: [PATCH] use __all__ in Python interfaces to make sure only useful things are visible to the users --- python/liblinear.py | 24 +++++++++++++++++------- python/liblinearutil.py | 5 +++++ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/python/liblinear.py b/python/liblinear.py index 584687f..ac06a33 100644 --- a/python/liblinear.py +++ b/python/liblinear.py @@ -5,6 +5,12 @@ from ctypes.util import find_library from os import path import sys +__all__ = ['liblinear', 'feature_node', 'gen_feature_nodearray', 'problem', + 'parameter', 'model', 'toPyModel', 'L2R_LR', 'L2R_L2LOSS_SVC_DUAL', + 'L2R_L2LOSS_SVC', 'L2R_L1LOSS_SVC_DUAL', 'MCSVM_CS', + 'L1R_L2LOSS_SVC', 'L1R_LR', 'L2R_LR_DUAL', 'L2R_L2LOSS_SVR', + 'L2R_L2LOSS_SVR_DUAL', 'L2R_L1LOSS_SVR_DUAL', 'print_null'] + try: dirname = path.dirname(path.abspath(__file__)) if sys.platform == 'win32': @@ -20,13 +26,17 @@ except: else: raise Exception('LIBLINEAR library not found.') -# Construct constants -SOLVER_TYPE = ['L2R_LR', 'L2R_L2LOSS_SVC_DUAL', 'L2R_L2LOSS_SVC', 'L2R_L1LOSS_SVC_DUAL',\ - 'MCSVM_CS', 'L1R_L2LOSS_SVC', 'L1R_LR', 'L2R_LR_DUAL', \ - None, None, None, \ - 'L2R_L2LOSS_SVR', 'L2R_L2LOSS_SVR_DUAL', 'L2R_L1LOSS_SVR_DUAL'] -for i, s in enumerate(SOLVER_TYPE): - if s is not None: exec("%s = %d" % (s , i)) +L2R_LR = 0 +L2R_L2LOSS_SVC_DUAL = 1 +L2R_L2LOSS_SVC = 2 +L2R_L1LOSS_SVC_DUAL = 3 +MCSVM_CS = 4 +L1R_L2LOSS_SVC = 5 +L1R_LR = 6 +L2R_LR_DUAL = 7 +L2R_L2LOSS_SVR = 11 +L2R_L2LOSS_SVR_DUAL = 12 +L2R_L1LOSS_SVR_DUAL = 13 PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p) def print_null(s): diff --git a/python/liblinearutil.py b/python/liblinearutil.py index d63e088..626489b 100644 --- a/python/liblinearutil.py +++ b/python/liblinearutil.py @@ -3,6 +3,11 @@ import os, sys sys.path = [os.path.dirname(os.path.abspath(__file__))] + sys.path from liblinear import * +from ctypes import c_double + +__all__ = ['svm_read_problem', 'load_model', 'save_model', 'evaluations', + 'train', 'predict'] + def svm_read_problem(data_file_name): """ -- 2.40.0