SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #ifndef MULTITASKLOGISTICREGRESSION_H_ 00011 #define MULTITASKLOGISTICREGRESSION_H_ 00012 00013 #include <shogun/lib/config.h> 00014 #include <shogun/transfer/multitask/MultitaskLinearMachine.h> 00015 #include <shogun/transfer/multitask/TaskRelation.h> 00016 #include <shogun/transfer/multitask/TaskGroup.h> 00017 #include <shogun/transfer/multitask/TaskTree.h> 00018 #include <shogun/transfer/multitask/Task.h> 00019 00020 #include <vector> 00021 #include <set> 00022 00023 using namespace std; 00024 00025 namespace shogun 00026 { 00031 class CMultitaskLogisticRegression : public CMultitaskLinearMachine 00032 { 00033 00034 public: 00036 MACHINE_PROBLEM_TYPE(PT_BINARY) 00037 00038 00039 CMultitaskLogisticRegression(); 00040 00048 CMultitaskLogisticRegression( 00049 float64_t z, CDotFeatures* training_data, 00050 CBinaryLabels* training_labels, CTaskRelation* task_relation); 00051 00053 virtual ~CMultitaskLogisticRegression(); 00054 00056 virtual const char* get_name() const 00057 { 00058 return "MultitaskLogisticRegression"; 00059 } 00060 00062 int32_t get_max_iter() const; 00064 float64_t get_q() const; 00066 int32_t get_regularization() const; 00068 int32_t get_termination() const; 00070 float64_t get_tolerance() const; 00072 float64_t get_z() const; 00073 00075 void set_max_iter(int32_t max_iter); 00077 void set_q(float64_t q); 00079 void set_regularization(int32_t regularization); 00081 void set_termination(int32_t termination); 00083 void set_tolerance(float64_t tolerance); 00085 void set_z(float64_t z); 00086 00088 virtual float64_t apply_one(int32_t i); 00089 00090 protected: 00091 00093 virtual bool train_machine(CFeatures* data=NULL); 00094 00096 virtual bool train_locked_implementation(SGVector<index_t>* tasks); 00097 00098 private: 00099 00101 void register_parameters(); 00102 00104 void initialize_parameters(); 00105 00106 protected: 00107 00109 int32_t m_regularization; 00110 00112 int32_t m_termination; 00113 00115 int32_t m_max_iter; 00116 00118 float64_t m_tolerance; 00119 00121 float64_t m_q; 00122 00124 float64_t m_z; 00125 00126 }; 00127 } 00128 #endif