SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Fernando José Iglesias García 00008 * Copyright (C) 2012 Fernando José Iglesias García 00009 */ 00010 00011 #include <shogun/evaluation/StructuredAccuracy.h> 00012 #include <shogun/structure/HMSVMLabels.h> 00013 #include <shogun/structure/MulticlassSOLabels.h> 00014 00015 using namespace shogun; 00016 00017 CStructuredAccuracy::CStructuredAccuracy() : CEvaluation() 00018 { 00019 } 00020 00021 CStructuredAccuracy::~CStructuredAccuracy() 00022 { 00023 } 00024 00025 float64_t CStructuredAccuracy::evaluate(CLabels* predicted, CLabels* ground_truth) 00026 { 00027 REQUIRE(predicted && ground_truth, "CLabels objects passed to evaluate " 00028 "cannot be null\n"); 00029 REQUIRE(predicted->get_num_labels() == ground_truth->get_num_labels(), 00030 "The number of predicted and ground truth labels must " 00031 "be the same\n"); 00032 REQUIRE(predicted->get_label_type() == LT_STRUCTURED, "The predicted " 00033 "labels must be of type CStructuredLabels\n"); 00034 REQUIRE(ground_truth->get_label_type() == LT_STRUCTURED, "The ground truth " 00035 "labels must be of type CStructuredLabels\n"); 00036 00037 CStructuredLabels* pred_labs = CStructuredLabels::obtain_from_generic(predicted); 00038 CStructuredLabels* true_labs = CStructuredLabels::obtain_from_generic(ground_truth); 00039 00040 REQUIRE(pred_labs->get_structured_data_type() == 00041 true_labs->get_structured_data_type(), "Predicted and ground truth " 00042 "labels must be composed of the same structured data\n"); 00043 00044 switch ( pred_labs->get_structured_data_type() ) 00045 { 00046 case (SDT_REAL): 00047 return evaluate_real(pred_labs, true_labs); 00048 case (SDT_SEQUENCE): 00049 return evaluate_sequence(pred_labs, true_labs); 00050 default: 00051 SG_ERROR("Unknown structured data type for evaluation\n"); 00052 } 00053 00054 return 0.0; 00055 } 00056 00057 SGMatrix< int32_t > CStructuredAccuracy::get_confusion_matrix( 00058 CLabels* predicted, CLabels* ground_truth) 00059 { 00060 SG_SERROR("Not implemented\n"); 00061 return SGMatrix< int32_t >(); 00062 } 00063 00064 float64_t CStructuredAccuracy::evaluate_real(CStructuredLabels* predicted, 00065 CStructuredLabels* ground_truth) 00066 { 00067 int32_t length = predicted->get_num_labels(); 00068 int32_t num_equal = 0; 00069 00070 for ( int32_t i = 0 ; i < length ; ++i ) 00071 { 00072 CRealNumber* truth = 00073 CRealNumber::obtain_from_generic(ground_truth->get_label(i)); 00074 CRealNumber* pred = 00075 CRealNumber::obtain_from_generic(predicted->get_label(i)); 00076 00077 num_equal += truth->value == pred->value; 00078 00079 SG_UNREF(truth); 00080 SG_UNREF(pred); 00081 } 00082 00083 return (1.0*num_equal) / length; 00084 } 00085 00086 float64_t CStructuredAccuracy::evaluate_sequence(CStructuredLabels* predicted, 00087 CStructuredLabels* ground_truth) 00088 { 00089 int32_t length = predicted->get_num_labels(); 00090 // Accuracy of each each label 00091 SGVector< float64_t > accuracies(length); 00092 int32_t num_equal = 0; 00093 00094 for ( int32_t i = 0 ; i < length ; ++i ) 00095 { 00096 CSequence* true_seq = 00097 CSequence::obtain_from_generic(ground_truth->get_label(i)); 00098 CSequence* pred_seq = 00099 CSequence::obtain_from_generic(predicted->get_label(i)); 00100 00101 SGVector<int32_t> true_seq_data = true_seq->get_data(); 00102 SGVector<int32_t> pred_seq_data = pred_seq->get_data(); 00103 00104 REQUIRE(true_seq_data.size() == pred_seq_data.size(), "Corresponding ground " 00105 "truth and predicted sequences must be equally long\n"); 00106 00107 num_equal = 0; 00108 // Count the number of elements that are equal in both sequences 00109 for ( int32_t j = 0 ; j < true_seq_data.size() ; ++j ) 00110 num_equal += true_seq_data[j] == pred_seq_data[j]; 00111 00112 accuracies[i] = (1.0*num_equal) / true_seq_data.size(); 00113 00114 SG_UNREF(true_seq); 00115 SG_UNREF(pred_seq); 00116 } 00117 00118 return accuracies.mean(); 00119 }