SHOGUN
v2.0.0
|
00001 /* 00002 * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights 00003 * embodied in the content of this file are licensed under the BSD 00004 * (revised) open source license. 00005 * 00006 * This program is free software; you can redistribute it and/or modify 00007 * it under the terms of the GNU General Public License as published by 00008 * the Free Software Foundation; either version 3 of the License, or 00009 * (at your option) any later version. 00010 * 00011 * Written (W) 2011 Shashwat Lal Das 00012 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society. 00013 */ 00014 00015 #include <shogun/classifier/vw/VwRegressor.h> 00016 #include <shogun/loss/SquaredLoss.h> 00017 #include <shogun/io/IOBuffer.h> 00018 00019 using namespace shogun; 00020 00021 CVwRegressor::CVwRegressor() 00022 : CSGObject() 00023 { 00024 weight_vectors = NULL; 00025 loss = new CSquaredLoss(); 00026 init(NULL); 00027 } 00028 00029 CVwRegressor::CVwRegressor(CVwEnvironment* env_to_use) 00030 : CSGObject() 00031 { 00032 weight_vectors = NULL; 00033 loss = new CSquaredLoss(); 00034 init(env_to_use); 00035 } 00036 00037 CVwRegressor::~CVwRegressor() 00038 { 00039 SG_FREE(weight_vectors); 00040 SG_UNREF(loss); 00041 SG_UNREF(env); 00042 } 00043 00044 void CVwRegressor::init(CVwEnvironment* env_to_use) 00045 { 00046 if (!env_to_use) 00047 env_to_use = new CVwEnvironment(); 00048 00049 env = env_to_use; 00050 SG_REF(env); 00051 00052 // For each feature, there should be 'stride' number of 00053 // elements in the weight vector 00054 vw_size_t length = ((vw_size_t) 1) << env->num_bits; 00055 env->thread_mask = (env->stride * (length >> env->thread_bits)) - 1; 00056 00057 // Only one learning thread for now 00058 vw_size_t num_threads = 1; 00059 weight_vectors = SG_MALLOC(float32_t*, num_threads); 00060 00061 for (vw_size_t i = 0; i < num_threads; i++) 00062 { 00063 weight_vectors[i] = SG_CALLOC(float32_t, env->stride * length / num_threads); 00064 00065 if (env->random_weights) 00066 { 00067 for (vw_size_t j = 0; j < length/num_threads; j++) 00068 weight_vectors[i][j] = CMath::random(-0.5, 0.5); 00069 } 00070 00071 if (env->initial_weight != 0.) 00072 for (vw_size_t j = 0; j < env->stride*length/num_threads; j+=env->stride) 00073 weight_vectors[i][j] = env->initial_weight; 00074 00075 if (env->adaptive) 00076 for (vw_size_t j = 1; j < env->stride*length/num_threads; j+=env->stride) 00077 weight_vectors[i][j] = 1; 00078 } 00079 } 00080 00081 void CVwRegressor::dump_regressor(char* reg_name, bool as_text) 00082 { 00083 CIOBuffer io_temp; 00084 int32_t f = io_temp.open_file(reg_name,'w'); 00085 00086 if (f < 0) 00087 SG_SERROR("Can't open: %s for writing! Exiting.\n", reg_name); 00088 00089 const char* vw_version = env->vw_version; 00090 vw_size_t v_length = env->v_length; 00091 00092 if (!as_text) 00093 { 00094 // Write version info 00095 io_temp.write_file((char*)&v_length, sizeof(v_length)); 00096 io_temp.write_file(vw_version,v_length); 00097 00098 // Write max and min labels 00099 io_temp.write_file((char*)&env->min_label, sizeof(env->min_label)); 00100 io_temp.write_file((char*)&env->max_label, sizeof(env->max_label)); 00101 00102 // Write weight vector bits information 00103 io_temp.write_file((char *)&env->num_bits, sizeof(env->num_bits)); 00104 io_temp.write_file((char *)&env->thread_bits, sizeof(env->thread_bits)); 00105 00106 // For paired namespaces forming quadratic features 00107 int32_t len = env->pairs.get_num_elements(); 00108 io_temp.write_file((char *)&len, sizeof(len)); 00109 00110 for (int32_t k = 0; k < env->pairs.get_num_elements(); k++) 00111 io_temp.write_file(env->pairs.get_element(k), 2); 00112 00113 // ngram and skips information 00114 io_temp.write_file((char*)&env->ngram, sizeof(env->ngram)); 00115 io_temp.write_file((char*)&env->skips, sizeof(env->skips)); 00116 } 00117 else 00118 { 00119 // Write as human readable form 00120 char buff[512]; 00121 int32_t len; 00122 00123 len = sprintf(buff, "Version %s\n", vw_version); 00124 io_temp.write_file(buff, len); 00125 len = sprintf(buff, "Min label:%f max label:%f\n", env->min_label, env->max_label); 00126 io_temp.write_file(buff, len); 00127 len = sprintf(buff, "bits:%d thread_bits:%d\n", (int32_t)env->num_bits, (int32_t)env->thread_bits); 00128 io_temp.write_file(buff, len); 00129 00130 if (env->pairs.get_num_elements() > 0) 00131 { 00132 len = sprintf(buff, "\n"); 00133 io_temp.write_file(buff, len); 00134 } 00135 00136 len = sprintf(buff, "ngram:%d skips:%d\nindex:weight pairs:\n", (int32_t)env->ngram, (int32_t)env->skips); 00137 io_temp.write_file(buff, len); 00138 } 00139 00140 uint32_t length = 1 << env->num_bits; 00141 vw_size_t num_threads = env->num_threads(); 00142 vw_size_t stride = env->stride; 00143 00144 // Write individual weights 00145 for(uint32_t i = 0; i < length; i++) 00146 { 00147 float32_t v; 00148 v = weight_vectors[i%num_threads][stride*(i/num_threads)]; 00149 if (v != 0.) 00150 { 00151 if (!as_text) 00152 { 00153 io_temp.write_file((char *)&i, sizeof (i)); 00154 io_temp.write_file((char *)&v, sizeof (v)); 00155 } 00156 else 00157 { 00158 char buff[512]; 00159 int32_t len = sprintf(buff, "%d:%f\n", i, v); 00160 io_temp.write_file(buff, len); 00161 } 00162 } 00163 } 00164 00165 io_temp.close_file(); 00166 } 00167 00168 void CVwRegressor::load_regressor(char* file) 00169 { 00170 CIOBuffer source; 00171 int32_t fd = source.open_file(file, 'r'); 00172 00173 if (fd < 0) 00174 SG_SERROR("Unable to open file for loading regressor!\n"); 00175 00176 // Read version info 00177 vw_size_t v_length; 00178 source.read_file((char*)&v_length, sizeof(v_length)); 00179 char* t = SG_MALLOC(char, v_length); 00180 source.read_file(t,v_length); 00181 if (strcmp(t,env->vw_version) != 0) 00182 { 00183 SG_FREE(t); 00184 SG_SERROR("Regressor source has an incompatible VW version!\n"); 00185 } 00186 SG_FREE(t); 00187 00188 // Read min and max label 00189 source.read_file((char*)&env->min_label, sizeof(env->min_label)); 00190 source.read_file((char*)&env->max_label, sizeof(env->max_label)); 00191 00192 // Read num_bits, multiple sources are not supported 00193 vw_size_t local_num_bits; 00194 source.read_file((char *)&local_num_bits, sizeof(local_num_bits)); 00195 00196 if ((vw_size_t) env->num_bits != local_num_bits) 00197 SG_SERROR("Wrong number of bits in regressor source!\n"); 00198 00199 env->num_bits = local_num_bits; 00200 00201 vw_size_t local_thread_bits; 00202 source.read_file((char*)&local_thread_bits, sizeof(local_thread_bits)); 00203 00204 env->thread_bits = local_thread_bits; 00205 00206 int32_t len; 00207 source.read_file((char *)&len, sizeof(len)); 00208 00209 // Read paired namespace information 00210 DynArray<char*> local_pairs; 00211 for (; len > 0; len--) 00212 { 00213 char pair[3]; 00214 source.read_file(pair, sizeof(char)*2); 00215 pair[2]='\0'; 00216 local_pairs.push_back(pair); 00217 } 00218 00219 env->pairs = local_pairs; 00220 00221 // Initialize the weight vector 00222 if (weight_vectors) 00223 SG_FREE(weight_vectors); 00224 init(env); 00225 00226 vw_size_t local_ngram; 00227 source.read_file((char*)&local_ngram, sizeof(local_ngram)); 00228 vw_size_t local_skips; 00229 source.read_file((char*)&local_skips, sizeof(local_skips)); 00230 00231 env->ngram = local_ngram; 00232 env->skips = local_skips; 00233 00234 // Read individual weights 00235 vw_size_t stride = env->stride; 00236 while (true) 00237 { 00238 uint32_t hash; 00239 ssize_t hash_bytes = source.read_file((char *)&hash, sizeof(hash)); 00240 if (hash_bytes <= 0) 00241 break; 00242 00243 float32_t w = 0.; 00244 ssize_t weight_bytes = source.read_file((char *)&w, sizeof(float32_t)); 00245 if (weight_bytes <= 0) 00246 break; 00247 00248 vw_size_t num_threads = env->num_threads(); 00249 00250 weight_vectors[hash % num_threads][(hash*stride)/num_threads] 00251 = weight_vectors[hash % num_threads][(hash*stride)/num_threads] + w; 00252 } 00253 source.close_file(); 00254 }