SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/transfer/multitask/MultitaskTraceLogisticRegression.h> 00011 #include <shogun/lib/malsar/malsar_low_rank.h> 00012 #include <shogun/lib/malsar/malsar_options.h> 00013 #include <shogun/lib/IndexBlockGroup.h> 00014 #include <shogun/lib/SGVector.h> 00015 00016 namespace shogun 00017 { 00018 00019 CMultitaskTraceLogisticRegression::CMultitaskTraceLogisticRegression() : 00020 CMultitaskLogisticRegression(), m_rho(0.0) 00021 { 00022 init(); 00023 } 00024 00025 CMultitaskTraceLogisticRegression::CMultitaskTraceLogisticRegression( 00026 float64_t rho, CDotFeatures* train_features, 00027 CBinaryLabels* train_labels, CTaskGroup* task_group) : 00028 CMultitaskLogisticRegression(0.0,train_features,train_labels,(CTaskRelation*)task_group) 00029 { 00030 set_rho(rho); 00031 init(); 00032 } 00033 00034 void CMultitaskTraceLogisticRegression::init() 00035 { 00036 SG_ADD(&m_rho,"rho","rho",MS_AVAILABLE); 00037 } 00038 00039 void CMultitaskTraceLogisticRegression::set_rho(float64_t rho) 00040 { 00041 m_rho = rho; 00042 } 00043 00044 float64_t CMultitaskTraceLogisticRegression::get_rho() const 00045 { 00046 return m_rho; 00047 } 00048 00049 CMultitaskTraceLogisticRegression::~CMultitaskTraceLogisticRegression() 00050 { 00051 } 00052 00053 bool CMultitaskTraceLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks) 00054 { 00055 SGVector<float64_t> y(m_labels->get_num_labels()); 00056 for (int32_t i=0; i<y.vlen; i++) 00057 y[i] = ((CBinaryLabels*)m_labels)->get_label(i); 00058 00059 malsar_options options = malsar_options::default_options(); 00060 options.termination = m_termination; 00061 options.tolerance = m_tolerance; 00062 options.max_iter = m_max_iter; 00063 options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00064 options.tasks_indices = tasks; 00065 00066 #ifdef HAVE_EIGEN3 00067 malsar_result_t model = malsar_low_rank( 00068 features, y.vector, m_rho, options); 00069 00070 m_tasks_w = model.w; 00071 m_tasks_c = model.c; 00072 #else 00073 SG_WARNING("Please install Eigen3 to use MultitaskTraceLogisticRegression\n"); 00074 m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks); 00075 m_tasks_c = SGVector<float64_t>(options.n_tasks); 00076 #endif 00077 return true; 00078 } 00079 00080 bool CMultitaskTraceLogisticRegression::train_machine(CFeatures* data) 00081 { 00082 if (data && (CDotFeatures*)data) 00083 set_features((CDotFeatures*)data); 00084 00085 ASSERT(features); 00086 ASSERT(m_labels); 00087 ASSERT(m_task_relation); 00088 00089 SGVector<float64_t> y(m_labels->get_num_labels()); 00090 for (int32_t i=0; i<y.vlen; i++) 00091 y[i] = ((CBinaryLabels*)m_labels)->get_label(i); 00092 00093 malsar_options options = malsar_options::default_options(); 00094 options.termination = m_termination; 00095 options.tolerance = m_tolerance; 00096 options.max_iter = m_max_iter; 00097 options.n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00098 options.tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices(); 00099 00100 #ifdef HAVE_EIGEN3 00101 malsar_result_t model = malsar_low_rank( 00102 features, y.vector, m_rho, options); 00103 00104 m_tasks_w = model.w; 00105 m_tasks_c = model.c; 00106 #else 00107 SG_WARNING("Please install Eigen3 to use MultitaskTraceLogisticRegression\n"); 00108 m_tasks_w = SGMatrix<float64_t>(((CDotFeatures*)features)->get_dim_feature_space(), options.n_tasks); 00109 m_tasks_c = SGVector<float64_t>(options.n_tasks); 00110 #endif 00111 00112 for (int32_t i=0; i<options.n_tasks; i++) 00113 options.tasks_indices[i].~SGVector<index_t>(); 00114 SG_FREE(options.tasks_indices); 00115 00116 return true; 00117 } 00118 00119 }