SphinxBase 0.6

src/libsphinxbase/lm/lm3g_model.c

00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
00002 /* ====================================================================
00003  * Copyright (c) 1999-2007 Carnegie Mellon University.  All rights
00004  * reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions
00008  * are met:
00009  *
00010  * 1. Redistributions of source code must retain the above copyright
00011  *    notice, this list of conditions and the following disclaimer. 
00012  *
00013  * 2. Redistributions in binary form must reproduce the above copyright
00014  *    notice, this list of conditions and the following disclaimer in
00015  *    the documentation and/or other materials provided with the
00016  *    distribution.
00017  *
00018  * This work was supported in part by funding from the Defense Advanced 
00019  * Research Projects Agency and the National Science Foundation of the 
00020  * United States of America, and the CMU Sphinx Speech Consortium.
00021  *
00022  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
00023  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
00024  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00025  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
00026  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00027  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
00028  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
00029  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
00030  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
00031  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
00032  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00033  *
00034  * ====================================================================
00035  *
00036  */
00037 /*
00038  * \file lm3g_model.c Core Sphinx 3-gram code used in
00039  * DMP/DMP32/ARPA (for now) model code.
00040  *
00041  * Author: A cast of thousands, probably.
00042  */
00043 #include <string.h>
00044 #include <assert.h>
00045 #include <limits.h>
00046 
00047 #include "sphinxbase/listelem_alloc.h"
00048 #include "sphinxbase/ckd_alloc.h"
00049 #include "sphinxbase/err.h"
00050 
00051 #include "lm3g_model.h"
00052 
00053 void
00054 lm3g_tginfo_free(ngram_model_t *base, lm3g_model_t *lm3g)
00055 {
00056         if (lm3g->tginfo == NULL)
00057                 return;
00058         listelem_alloc_free(lm3g->le);
00059         ckd_free(lm3g->tginfo);
00060 }
00061 
00062 void
00063 lm3g_tginfo_reset(ngram_model_t *base, lm3g_model_t *lm3g)
00064 {
00065     if (lm3g->tginfo == NULL)
00066         return;
00067     listelem_alloc_free(lm3g->le);
00068     memset(lm3g->tginfo, 0, base->n_counts[0] * sizeof(tginfo_t *));
00069     lm3g->le = listelem_alloc_init(sizeof(tginfo_t));
00070 }
00071 
00072 void
00073 lm3g_apply_weights(ngram_model_t *base,
00074                    lm3g_model_t *lm3g,
00075                    float32 lw, float32 wip, float32 uw)
00076 {
00077     int32 log_wip, log_uw, log_uniform_weight;
00078     int i;
00079 
00080     /* Precalculate some log values we will like. */
00081     log_wip = logmath_log(base->lmath, wip);
00082     log_uw = logmath_log(base->lmath, uw);
00083     log_uniform_weight = logmath_log(base->lmath, 1.0 - uw);
00084 
00085     for (i = 0; i < base->n_counts[0]; ++i) {
00086         int32 prob1, bo_wt, n_used;
00087 
00088         /* Backoff weights just get scaled by the lw. */
00089         bo_wt = (int32)(lm3g->unigrams[i].bo_wt1.l / base->lw);
00090         /* Unscaling unigram probs is a bit more complicated, so punt
00091          * it back to the general code. */
00092         prob1 = ngram_ng_prob(base, i, NULL, 0, &n_used);
00093         /* Now compute the new scaled probabilities. */
00094         lm3g->unigrams[i].bo_wt1.l = (int32)(bo_wt * lw);
00095         if (strcmp(base->word_str[i], "<s>") == 0) { /* FIXME: configurable start_sym */
00096             /* Apply language weight and WIP */
00097             lm3g->unigrams[i].prob1.l = (int32)(prob1 * lw) + log_wip;
00098         }
00099         else {
00100             /* Interpolate unigram probability with uniform. */
00101             prob1 += log_uw;
00102             prob1 = logmath_add(base->lmath, prob1, base->log_uniform + log_uniform_weight);
00103             /* Apply language weight and WIP */
00104             lm3g->unigrams[i].prob1.l = (int32)(prob1 * lw) + log_wip;
00105         }
00106     }
00107 
00108     for (i = 0; i < lm3g->n_prob2; ++i) {
00109         int32 prob2;
00110         /* Can't just punt this back to general code since it is quantized. */
00111         prob2 = (int32)((lm3g->prob2[i].l - base->log_wip) / base->lw);
00112         lm3g->prob2[i].l = (int32)(prob2 * lw) + log_wip;
00113     }
00114 
00115     if (base->n > 2) {
00116         for (i = 0; i < lm3g->n_bo_wt2; ++i) {
00117             lm3g->bo_wt2[i].l = (int32)(lm3g->bo_wt2[i].l  / base->lw * lw);
00118         }
00119         for (i = 0; i < lm3g->n_prob3; i++) {
00120             int32 prob3;
00121             /* Can't just punt this back to general code since it is quantized. */
00122             prob3 = (int32)((lm3g->prob3[i].l - base->log_wip) / base->lw);
00123             lm3g->prob3[i].l = (int32)(prob3 * lw) + log_wip;
00124         }
00125     }
00126 
00127     /* Store updated values in the model. */
00128     base->log_wip = log_wip;
00129     base->log_uw = log_uw;
00130     base->log_uniform_weight = log_uniform_weight;
00131     base->lw = lw;
00132 }
00133 
00134 int32
00135 lm3g_add_ug(ngram_model_t *base,
00136             lm3g_model_t *lm3g, int32 wid, int32 lweight)
00137 {
00138     int32 score;
00139 
00140     /* This would be very bad if this happened! */
00141     assert(!NGRAM_IS_CLASSWID(wid));
00142 
00143     /* Reallocate unigram array. */
00144     lm3g->unigrams = ckd_realloc(lm3g->unigrams,
00145                                  sizeof(*lm3g->unigrams) * base->n_1g_alloc);
00146     memset(lm3g->unigrams + base->n_counts[0], 0,
00147            (base->n_1g_alloc - base->n_counts[0]) * sizeof(*lm3g->unigrams));
00148     /* Reallocate tginfo array. */
00149     lm3g->tginfo = ckd_realloc(lm3g->tginfo,
00150                                sizeof(*lm3g->tginfo) * base->n_1g_alloc);
00151     memset(lm3g->tginfo + base->n_counts[0], 0,
00152            (base->n_1g_alloc - base->n_counts[0]) * sizeof(*lm3g->tginfo));
00153     /* FIXME: we really ought to update base->log_uniform *and*
00154      * renormalize all the other unigrams.  This is really slow, so I
00155      * will probably just provide a function to renormalize after
00156      * adding unigrams, for anyone who really cares. */
00157     /* This could be simplified but then we couldn't do it in logmath */
00158     score = lweight + base->log_uniform + base->log_uw;
00159     score = logmath_add(base->lmath, score,
00160                         base->log_uniform + base->log_uniform_weight);
00161     lm3g->unigrams[wid].prob1.l = score;
00162     /* This unigram by definition doesn't participate in any bigrams,
00163      * so its backoff weight and bigram pointer are both undefined. */
00164     lm3g->unigrams[wid].bo_wt1.l = 0;
00165     lm3g->unigrams[wid].bigrams = 0;
00166     /* Finally, increase the unigram count */
00167     ++base->n_counts[0];
00168     /* FIXME: Note that this can actually be quite bogus due to the
00169      * presence of class words.  If wid falls outside the unigram
00170      * count, increase it to compensate, at the cost of no longer
00171      * really knowing how many unigrams we have :( */
00172     if (wid >= base->n_counts[0])
00173         base->n_counts[0] = wid + 1;
00174 
00175     return score;
00176 }
00177 
00178 void
00179 init_sorted_list(sorted_list_t * l)
00180 {
00181     /* FIXME FIXME FIXME: Fixed size array!??! */
00182     l->list = ckd_calloc(MAX_SORTED_ENTRIES,
00183                          sizeof(sorted_entry_t));
00184     l->list[0].val.l = INT_MIN;
00185     l->list[0].lower = 0;
00186     l->list[0].higher = 0;
00187     l->free = 1;
00188 }
00189 
00190 void
00191 free_sorted_list(sorted_list_t * l)
00192 {
00193     free(l->list);
00194 }
00195 
00196 lmprob_t *
00197 vals_in_sorted_list(sorted_list_t * l)
00198 {
00199     lmprob_t *vals;
00200     int32 i;
00201 
00202     vals = ckd_calloc(l->free, sizeof(lmprob_t));
00203     for (i = 0; i < l->free; i++)
00204         vals[i] = l->list[i].val;
00205     return (vals);
00206 }
00207 
00208 int32
00209 sorted_id(sorted_list_t * l, int32 *val)
00210 {
00211     int32 i = 0;
00212 
00213     for (;;) {
00214         if (*val == l->list[i].val.l)
00215             return (i);
00216         if (*val < l->list[i].val.l) {
00217             if (l->list[i].lower == 0) {
00218                 if (l->free >= MAX_SORTED_ENTRIES) {
00219                     /* Make the best of a bad situation. */
00220                     E_WARN("sorted list overflow (%d => %d)\n",
00221                            *val, l->list[i].val.l);
00222                     return i;
00223                 }
00224 
00225                 l->list[i].lower = l->free;
00226                 (l->free)++;
00227                 i = l->list[i].lower;
00228                 l->list[i].val.l = *val;
00229                 return (i);
00230             }
00231             else
00232                 i = l->list[i].lower;
00233         }
00234         else {
00235             if (l->list[i].higher == 0) {
00236                 if (l->free >= MAX_SORTED_ENTRIES) {
00237                     /* Make the best of a bad situation. */
00238                     E_WARN("sorted list overflow (%d => %d)\n",
00239                            *val, l->list[i].val);
00240                     return i;
00241                 }
00242 
00243                 l->list[i].higher = l->free;
00244                 (l->free)++;
00245                 i = l->list[i].higher;
00246                 l->list[i].val.l = *val;
00247                 return (i);
00248             }
00249             else
00250                 i = l->list[i].higher;
00251         }
00252     }
00253 }
00254