SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/transfer/multitask/MultitaskClusteredLogisticRegression.h> 00011 #include <shogun/lib/malsar/malsar_clustered.h> 00012 #include <shogun/lib/malsar/malsar_options.h> 00013 #include <shogun/lib/SGVector.h> 00014 00015 namespace shogun 00016 { 00017 00018 CMultitaskClusteredLogisticRegression::CMultitaskClusteredLogisticRegression() : 00019 CMultitaskLogisticRegression(), m_rho1(0.0), m_rho2(0.0) 00020 { 00021 } 00022 00023 CMultitaskClusteredLogisticRegression::CMultitaskClusteredLogisticRegression( 00024 float64_t rho1, float64_t rho2, CDotFeatures* train_features, 00025 CBinaryLabels* train_labels, CTaskGroup* task_group, int32_t n_clusters) : 00026 CMultitaskLogisticRegression(0.0,train_features,train_labels,(CTaskRelation*)task_group) 00027 { 00028 set_rho1(rho1); 00029 set_rho2(rho2); 00030 set_num_clusters(n_clusters); 00031 } 00032 00033 int32_t CMultitaskClusteredLogisticRegression::get_rho1() const 00034 { 00035 return m_rho1; 00036 } 00037 00038 int32_t CMultitaskClusteredLogisticRegression::get_rho2() const 00039 { 00040 return m_rho2; 00041 } 00042 00043 void CMultitaskClusteredLogisticRegression::set_rho1(float64_t rho1) 00044 { 00045 m_rho1 = rho1; 00046 } 00047 00048 void CMultitaskClusteredLogisticRegression::set_rho2(float64_t rho2) 00049 { 00050 m_rho2 = rho2; 00051 } 00052 00053 int32_t CMultitaskClusteredLogisticRegression::get_num_clusters() const 00054 { 00055 return m_num_clusters; 00056 } 00057 00058 void CMultitaskClusteredLogisticRegression::set_num_clusters(int32_t num_clusters) 00059 { 00060 m_num_clusters = num_clusters; 00061 } 00062 00063 CMultitaskClusteredLogisticRegression::~CMultitaskClusteredLogisticRegression() 00064 { 00065 } 00066 00067 bool CMultitaskClusteredLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks) 00068 { 00069 SGVector<float64_t> y(m_labels->get_num_labels()); 00070 for (int32_t i=0; i<y.vlen; i++) 00071 y[i] = ((CBinaryLabels*)m_labels)->get_label(i); 00072 00073 malsar_options options = malsar_options::default_options(); 00074 options.termination = m_termination; 00075 options.tolerance = m_tolerance; 00076 options.max_iter = m_max_iter; 00077 options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00078 options.tasks_indices = tasks; 00079 options.n_clusters = m_num_clusters; 00080 00081 #ifdef HAVE_EIGEN3 00082 malsar_result_t model = malsar_clustered( 00083 features, y.vector, m_rho1, m_rho2, options); 00084 00085 m_tasks_w = model.w; 00086 m_tasks_c = model.c; 00087 #else 00088 SG_WARNING("Please install Eigen3 to use MultitaskClusteredLogisticRegression\n"); 00089 m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks); 00090 m_tasks_c = SGVector<float64_t>(options.n_tasks); 00091 #endif 00092 return true; 00093 } 00094 00095 bool CMultitaskClusteredLogisticRegression::train_machine(CFeatures* data) 00096 { 00097 if (data && (CDotFeatures*)data) 00098 set_features((CDotFeatures*)data); 00099 00100 ASSERT(features); 00101 ASSERT(m_labels); 00102 ASSERT(m_task_relation); 00103 00104 SGVector<float64_t> y(m_labels->get_num_labels()); 00105 for (int32_t i=0; i<y.vlen; i++) 00106 y[i] = ((CBinaryLabels*)m_labels)->get_label(i); 00107 00108 malsar_options options = malsar_options::default_options(); 00109 options.termination = m_termination; 00110 options.tolerance = m_tolerance; 00111 options.max_iter = m_max_iter; 00112 options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00113 options.tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices(); 00114 options.n_clusters = m_num_clusters; 00115 00116 #ifdef HAVE_EIGEN3 00117 malsar_result_t model = malsar_clustered( 00118 features, y.vector, m_rho1, m_rho2, options); 00119 00120 m_tasks_w = model.w; 00121 m_tasks_c = model.c; 00122 #else 00123 SG_WARNING("Please install Eigen3 to use MultitaskClusteredLogisticRegression\n"); 00124 m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks); 00125 m_tasks_c = SGVector<float64_t>(options.n_tasks); 00126 #endif 00127 00128 for (int32_t i=0; i<options.n_tasks; i++) 00129 options.tasks_indices[i].~SGVector<index_t>(); 00130 SG_FREE(options.tasks_indices); 00131 00132 return true; 00133 } 00134 00135 }