00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00042
00043 #include <assert.h>
00044 #include <string.h>
00045 #include <math.h>
00046
00047
00048 #include <ckd_alloc.h>
00049 #include <listelem_alloc.h>
00050 #include <strfuncs.h>
00051 #include <pio.h>
00052
00053
00054 #include "pocketsphinx_internal.h"
00055 #include "ps_lattice_internal.h"
00056 #include "ngram_search.h"
00057
00058
00059
00060
00061
00062 void
00063 ps_lattice_link(ps_lattice_t *dag, ps_latnode_t *from, ps_latnode_t *to, int32 score, int32 ef)
00064 {
00065 latlink_list_t *fwdlink;
00066
00067
00068 for (fwdlink = from->exits; fwdlink; fwdlink = fwdlink->next)
00069 if (fwdlink->link->to == to)
00070 break;
00071
00072 if (fwdlink == NULL) {
00073 latlink_list_t *revlink;
00074 ps_latlink_t *link;
00075
00076
00077 link = listelem_malloc(dag->latlink_alloc);
00078 fwdlink = listelem_malloc(dag->latlink_list_alloc);
00079 revlink = listelem_malloc(dag->latlink_list_alloc);
00080
00081 link->from = from;
00082 link->to = to;
00083 link->ascr = score;
00084 link->ef = ef;
00085 link->best_prev = NULL;
00086
00087 fwdlink->link = revlink->link = link;
00088 fwdlink->next = from->exits;
00089 from->exits = fwdlink;
00090 revlink->next = to->entries;
00091 to->entries = revlink;
00092 }
00093 else {
00094
00095 if (fwdlink->link->ascr < score) {
00096 fwdlink->link->ascr = score;
00097 fwdlink->link->ef = ef;
00098 }
00099 }
00100 }
00101
00102 void
00103 ps_lattice_bypass_fillers(ps_lattice_t *dag, int32 silpen, int32 fillpen)
00104 {
00105 ps_latnode_t *node;
00106 int32 score;
00107
00108
00109 for (node = dag->nodes; node; node = node->next) {
00110 latlink_list_t *revlink;
00111 if (node == dag->end || !ISA_FILLER_WORD(dag->search, node->basewid))
00112 continue;
00113
00114
00115 for (revlink = node->entries; revlink; revlink = revlink->next) {
00116 latlink_list_t *forlink;
00117 ps_latlink_t *rlink = revlink->link;
00118
00119 score = (node->basewid == ps_search_silence_wid(dag->search)) ? silpen : fillpen;
00120 score += rlink->ascr;
00121
00122
00123
00124
00125
00126 for (forlink = node->exits; forlink; forlink = forlink->next) {
00127 ps_latlink_t *flink = forlink->link;
00128 if (!ISA_FILLER_WORD(dag->search, flink->to->basewid)) {
00129 ps_lattice_link(dag, rlink->from, flink->to,
00130 score + flink->ascr, flink->ef);
00131 }
00132 }
00133 }
00134 }
00135 }
00136
00137 static void
00138 delete_node(ps_lattice_t *dag, ps_latnode_t *node)
00139 {
00140 latlink_list_t *x, *next_x;
00141
00142 for (x = node->exits; x; x = next_x) {
00143 next_x = x->next;
00144 x->link->from = NULL;
00145 listelem_free(dag->latlink_list_alloc, x);
00146 }
00147 for (x = node->entries; x; x = next_x) {
00148 next_x = x->next;
00149 x->link->to = NULL;
00150 listelem_free(dag->latlink_list_alloc, x);
00151 }
00152 listelem_free(dag->latnode_alloc, node);
00153 }
00154
00155 static void
00156 remove_dangling_links(ps_lattice_t *dag, ps_latnode_t *node)
00157 {
00158 latlink_list_t *x, *prev_x, *next_x;
00159
00160 prev_x = NULL;
00161 for (x = node->exits; x; x = next_x) {
00162 next_x = x->next;
00163 if (x->link->to == NULL) {
00164 if (prev_x)
00165 prev_x->next = next_x;
00166 else
00167 node->exits = next_x;
00168 listelem_free(dag->latlink_alloc, x->link);
00169 listelem_free(dag->latlink_list_alloc, x);
00170 }
00171 else
00172 prev_x = x;
00173 }
00174 prev_x = NULL;
00175 for (x = node->entries; x; x = next_x) {
00176 next_x = x->next;
00177 if (x->link->from == NULL) {
00178 if (prev_x)
00179 prev_x->next = next_x;
00180 else
00181 node->exits = next_x;
00182 listelem_free(dag->latlink_alloc, x->link);
00183 listelem_free(dag->latlink_list_alloc, x);
00184 }
00185 else
00186 prev_x = x;
00187 }
00188 }
00189
00190 void
00191 ps_lattice_delete_unreachable(ps_lattice_t *dag)
00192 {
00193 ps_latnode_t *node, *prev_node, *next_node;
00194 int i;
00195
00196
00197 prev_node = NULL;
00198 for (node = dag->nodes; node; node = next_node) {
00199 next_node = node->next;
00200 if (!node->reachable) {
00201 if (prev_node)
00202 prev_node->next = next_node;
00203 else
00204 dag->nodes = next_node;
00205
00206 delete_node(dag, node);
00207 }
00208 else
00209 prev_node = node;
00210 }
00211
00212
00213 i = 0;
00214 for (node = dag->nodes; node; node = node->next) {
00215
00216 node->id = i++;
00217
00218
00219 assert(node->reachable);
00220
00221
00222 remove_dangling_links(dag, node);
00223 }
00224 }
00225
00226 int32
00227 ps_lattice_write(ps_lattice_t *dag, char const *filename)
00228 {
00229 FILE *fp;
00230 int32 i;
00231 ps_latnode_t *d, *initial, *final;
00232
00233 initial = dag->start;
00234 final = dag->end;
00235
00236 E_INFO("Writing lattice file: %s\n", filename);
00237 if ((fp = fopen(filename, "w")) == NULL) {
00238 E_ERROR("fopen(%s,w) failed\n", filename);
00239 return -1;
00240 }
00241
00242
00243 fprintf(fp, "# getcwd: /this/is/bogus\n");
00244 fprintf(fp, "# -logbase %e\n", logmath_get_base(dag->lmath));
00245 fprintf(fp, "#\n");
00246
00247 fprintf(fp, "Frames %d\n", dag->n_frames);
00248 fprintf(fp, "#\n");
00249
00250 for (i = 0, d = dag->nodes; d; d = d->next, i++);
00251 fprintf(fp,
00252 "Nodes %d (NODEID WORD STARTFRAME FIRST-ENDFRAME LAST-ENDFRAME)\n",
00253 i);
00254 for (i = 0, d = dag->nodes; d; d = d->next, i++) {
00255 d->id = i;
00256 fprintf(fp, "%d %s %d %d %d\n",
00257 i, dict_word_str(ps_search_dict(dag->search), d->wid),
00258 d->sf, d->fef, d->lef);
00259 }
00260 fprintf(fp, "#\n");
00261
00262 fprintf(fp, "Initial %d\nFinal %d\n", initial->id, final->id);
00263 fprintf(fp, "#\n");
00264
00265
00266 fprintf(fp, "BestSegAscr %d (NODEID ENDFRAME ASCORE)\n",
00267 0 );
00268 fprintf(fp, "#\n");
00269
00270 fprintf(fp, "Edges (FROM-NODEID TO-NODEID ASCORE)\n");
00271 for (d = dag->nodes; d; d = d->next) {
00272 latlink_list_t *l;
00273 for (l = d->exits; l; l = l->next)
00274 fprintf(fp, "%d %d %d\n",
00275 d->id, l->link->to->id, l->link->ascr);
00276 }
00277 fprintf(fp, "End\n");
00278 fclose(fp);
00279
00280 return 0;
00281 }
00282
00283
00284 static int
00285 dag_param_read(lineiter_t *li, char *param)
00286 {
00287 int32 n;
00288
00289 while ((li = lineiter_next(li)) != NULL) {
00290 char *c;
00291
00292
00293 if (li->buf[0] == '#')
00294 continue;
00295
00296
00297 c = strchr(li->buf, ' ');
00298 if (c == NULL) continue;
00299
00300
00301 if (strncmp(li->buf, param, strlen(param)) == 0
00302 && sscanf(c + 1, "%d", &n) == 1)
00303 return n;
00304 }
00305 return -1;
00306 }
00307
00308
00309 static void
00310 dag_mark_reachable(ps_latnode_t * d)
00311 {
00312 latlink_list_t *l;
00313
00314 d->reachable = 1;
00315 for (l = d->entries; l; l = l->next)
00316 if (l->link->from && !l->link->from->reachable)
00317 dag_mark_reachable(l->link->from);
00318 }
00319
00320 ps_lattice_t *
00321 ps_lattice_read(ps_decoder_t *ps,
00322 char const *file)
00323 {
00324 FILE *fp;
00325 int32 ispipe;
00326 lineiter_t *line;
00327 float64 lb;
00328 float32 logratio;
00329 ps_latnode_t *tail;
00330 ps_latnode_t **darray;
00331 ps_lattice_t *dag;
00332 int i, k, n_nodes;
00333 int32 pip, silpen, fillpen;
00334
00335 dag = ckd_calloc(1, sizeof(*dag));
00336 dag->search = ps->search;
00337 dag->lmath = logmath_retain(ps->lmath);
00338 dag->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t));
00339 dag->latlink_alloc = listelem_alloc_init(sizeof(ps_latlink_t));
00340 dag->latlink_list_alloc = listelem_alloc_init(sizeof(latlink_list_t));
00341 dag->refcount = 1;
00342
00343 tail = NULL;
00344 darray = NULL;
00345
00346 E_INFO("Reading DAG file: %s\n", file);
00347 if ((fp = fopen_compchk(file, &ispipe)) == NULL) {
00348 E_ERROR("fopen_compchk(%s) failed\n", file);
00349 return NULL;
00350 }
00351 line = lineiter_start(fp);
00352
00353
00354 if (line == NULL) {
00355 E_ERROR("Premature EOF(%s)\n", file);
00356 goto load_error;
00357 }
00358 if (strncmp(line->buf, "# getcwd: ", 10) != 0) {
00359 E_ERROR("%s does not begin with '# getcwd: '\n%s", file, line->buf);
00360 goto load_error;
00361 }
00362 if ((line = lineiter_next(line)) == NULL) {
00363 E_ERROR("Premature EOF(%s)\n", file);
00364 goto load_error;
00365 }
00366 if ((strncmp(line->buf, "# -logbase ", 11) != 0)
00367 || (sscanf(line->buf + 11, "%lf", &lb) != 1)) {
00368 E_WARN("%s: Cannot find -logbase in header\n", file);
00369 lb = 1.0001;
00370 }
00371 logratio = 1.0f;
00372 if (dag->lmath == NULL)
00373 dag->lmath = logmath_init(lb, 0, TRUE);
00374 else {
00375 float32 pb = logmath_get_base(dag->lmath);
00376 if (fabs(lb - pb) >= 0.0001) {
00377 E_WARN("Inconsistent logbases: %f vs %f: will compensate\n", lb, pb);
00378 logratio = (float32)(log(lb) / log(pb));
00379 E_INFO("Lattice log ratio: %f\n", logratio);
00380 }
00381 }
00382
00383 dag->n_frames = dag_param_read(line, "Frames");
00384 if (dag->n_frames <= 0) {
00385 E_ERROR("Frames parameter missing or invalid\n");
00386 goto load_error;
00387 }
00388
00389 n_nodes = dag_param_read(line, "Nodes");
00390 if (n_nodes <= 0) {
00391 E_ERROR("Nodes parameter missing or invalid\n");
00392 goto load_error;
00393 }
00394
00395
00396 darray = ckd_calloc(n_nodes, sizeof(*darray));
00397 for (i = 0; i < n_nodes; i++) {
00398 ps_latnode_t *d;
00399 int32 w;
00400 int seqid, sf, fef, lef;
00401 char wd[256];
00402
00403 if ((line = lineiter_next(line)) == NULL) {
00404 E_ERROR("Premature EOF while loading Nodes(%s)\n", file);
00405 goto load_error;
00406 }
00407
00408 if ((k =
00409 sscanf(line->buf, "%d %255s %d %d %d", &seqid, wd, &sf, &fef,
00410 &lef)) != 5) {
00411 E_ERROR("Cannot parse line: %s, value of count %d\n", line->buf, k);
00412 goto load_error;
00413 }
00414
00415 w = dict_to_id(ps->dict, wd);
00416 if (w < 0) {
00417 E_ERROR("Unknown word in line: %s\n", line->buf);
00418 goto load_error;
00419 }
00420
00421 if (seqid != i) {
00422 E_ERROR("Seqno error: %s\n", line->buf);
00423 goto load_error;
00424 }
00425
00426 d = listelem_malloc(dag->latnode_alloc);
00427 darray[i] = d;
00428 d->wid = w;
00429 d->basewid = dict_base_wid(ps->dict, w);
00430 d->id = seqid;
00431 d->sf = sf;
00432 d->fef = fef;
00433 d->lef = lef;
00434 d->reachable = 0;
00435 d->exits = d->entries = NULL;
00436 d->next = NULL;
00437
00438 if (!dag->nodes)
00439 dag->nodes = d;
00440 else
00441 tail->next = d;
00442 tail = d;
00443 }
00444
00445
00446 k = dag_param_read(line, "Initial");
00447 if ((k < 0) || (k >= n_nodes)) {
00448 E_ERROR("Initial node parameter missing or invalid\n");
00449 goto load_error;
00450 }
00451 dag->start = darray[k];
00452
00453
00454 k = dag_param_read(line, "Final");
00455 if ((k < 0) || (k >= n_nodes)) {
00456 E_ERROR("Final node parameter missing or invalid\n");
00457 goto load_error;
00458 }
00459 dag->end = darray[k];
00460
00461
00462 if ((k = dag_param_read(line, "BestSegAscr")) < 0) {
00463 E_ERROR("BestSegAscr parameter missing\n");
00464 goto load_error;
00465 }
00466 for (i = 0; i < k; i++) {
00467 if ((line = lineiter_next(line)) == NULL) {
00468 E_ERROR("Premature EOF while (%s) ignoring BestSegAscr\n",
00469 line);
00470 goto load_error;
00471 }
00472 }
00473
00474
00475 while ((line = lineiter_next(line)) != NULL) {
00476 if (line->buf[0] == '#')
00477 continue;
00478 if (0 == strncmp(line->buf, "Edges", 5))
00479 break;
00480 }
00481 if (line == NULL) {
00482 E_ERROR("Edges missing\n");
00483 goto load_error;
00484 }
00485 while ((line = lineiter_next(line)) != NULL) {
00486 int from, to, ascr;
00487 ps_latnode_t *pd, *d;
00488
00489 if (sscanf(line->buf, "%d %d %d", &from, &to, &ascr) != 3)
00490 break;
00491 pd = darray[from];
00492 d = darray[to];
00493 if (logratio != 1.0f)
00494 ascr = (int32)(ascr * logratio);
00495 ps_lattice_link(dag, pd, d, ascr, d->sf - 1);
00496 }
00497 if (strcmp(line->buf, "End\n") != 0) {
00498 E_ERROR("Terminating 'End' missing\n");
00499 goto load_error;
00500 }
00501 lineiter_free(line);
00502 fclose_comp(fp, ispipe);
00503 ckd_free(darray);
00504
00505
00506
00507
00508 if (ISA_FILLER_WORD(dag->search, dag->end->wid))
00509 dag->end->basewid = ps_search_finish_wid(dag->search);
00510
00511
00512 dag_mark_reachable(dag->end);
00513
00514
00515 ps_lattice_delete_unreachable(dag);
00516
00517
00518
00519 pip = logmath_log(dag->lmath, cmd_ln_float32_r(ps->config, "-pip"));
00520 silpen = pip + logmath_log(dag->lmath,
00521 cmd_ln_float32_r(ps->config, "-silprob"));
00522 fillpen = pip + logmath_log(dag->lmath,
00523 cmd_ln_float32_r(ps->config, "-fillprob"));
00524 ps_lattice_bypass_fillers(dag, silpen, fillpen);
00525
00526 return dag;
00527
00528 load_error:
00529 E_ERROR("Failed to load %s\n", file);
00530 lineiter_free(line);
00531 if (fp) fclose_comp(fp, ispipe);
00532 ckd_free(darray);
00533 return NULL;
00534 }
00535
00536 int
00537 ps_lattice_n_frames(ps_lattice_t *dag)
00538 {
00539 return dag->n_frames;
00540 }
00541
00542 ps_lattice_t *
00543 ps_lattice_init_search(ps_search_t *search, int n_frame)
00544 {
00545 ps_lattice_t *dag;
00546
00547 dag = ckd_calloc(1, sizeof(*dag));
00548 dag->search = search;
00549 dag->lmath = logmath_retain(search->acmod->lmath);
00550 dag->n_frames = n_frame;
00551 dag->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t));
00552 dag->latlink_alloc = listelem_alloc_init(sizeof(ps_latlink_t));
00553 dag->latlink_list_alloc = listelem_alloc_init(sizeof(latlink_list_t));
00554 dag->refcount = 1;
00555 return dag;
00556 }
00557
00558 ps_lattice_t *
00559 ps_lattice_retain(ps_lattice_t *dag)
00560 {
00561 ++dag->refcount;
00562 return dag;
00563 }
00564
00565 int
00566 ps_lattice_free(ps_lattice_t *dag)
00567 {
00568 if (dag == NULL)
00569 return 0;
00570 if (--dag->refcount > 0)
00571 return dag->refcount;
00572 logmath_free(dag->lmath);
00573 listelem_alloc_free(dag->latnode_alloc);
00574 listelem_alloc_free(dag->latlink_alloc);
00575 listelem_alloc_free(dag->latlink_list_alloc);
00576 ckd_free(dag->hyp_str);
00577 ckd_free(dag);
00578 return 0;
00579 }
00580
00581 logmath_t *
00582 ps_lattice_get_logmath(ps_lattice_t *dag)
00583 {
00584 return dag->lmath;
00585 }
00586
00587 ps_latnode_iter_t *
00588 ps_latnode_iter(ps_lattice_t *dag)
00589 {
00590 return dag->nodes;
00591 }
00592
00593 ps_latnode_iter_t *
00594 ps_latnode_iter_next(ps_latnode_iter_t *itor)
00595 {
00596 return itor->next;
00597 }
00598
00599 void
00600 ps_latnode_iter_free(ps_latnode_iter_t *itor)
00601 {
00602
00603 }
00604
00605 ps_latnode_t *
00606 ps_latnode_iter_node(ps_latnode_iter_t *itor)
00607 {
00608 return itor;
00609 }
00610
00611 int
00612 ps_latnode_times(ps_latnode_t *node, int16 *out_fef, int16 *out_lef)
00613 {
00614 if (out_fef) *out_fef = (int16)node->fef;
00615 if (out_lef) *out_lef = (int16)node->lef;
00616 return node->sf;
00617 }
00618
00619 char const *
00620 ps_latnode_word(ps_lattice_t *dag, ps_latnode_t *node)
00621 {
00622 return dict_word_str(ps_search_dict(dag->search), node->wid);
00623 }
00624
00625 char const *
00626 ps_latnode_baseword(ps_lattice_t *dag, ps_latnode_t *node)
00627 {
00628 return dict_word_str(ps_search_dict(dag->search), node->basewid);
00629 }
00630
00631 int32
00632 ps_latnode_prob(ps_lattice_t *dag, ps_latnode_t *node,
00633 ps_latlink_t **out_link)
00634 {
00635 latlink_list_t *links;
00636 int32 bestpost = logmath_get_zero(dag->lmath);
00637
00638 for (links = node->exits; links; links = links->next) {
00639 int32 post = links->link->alpha + links->link->beta - dag->norm;
00640 if (post > bestpost) {
00641 if (out_link) *out_link = links->link;
00642 bestpost = post;
00643 }
00644 }
00645 return bestpost;
00646 }
00647
00648 ps_latlink_iter_t *
00649 ps_latnode_exits(ps_latnode_t *node)
00650 {
00651 return node->exits;
00652 }
00653
00654 ps_latlink_iter_t *
00655 ps_latnode_entries(ps_latnode_t *node)
00656 {
00657 return node->entries;
00658 }
00659
00660 ps_latlink_iter_t *
00661 ps_latlink_iter_next(ps_latlink_iter_t *itor)
00662 {
00663 return itor->next;
00664 }
00665
00666 void
00667 ps_latlink_iter_free(ps_latlink_iter_t *itor)
00668 {
00669
00670 }
00671
00672 ps_latlink_t *
00673 ps_latlink_iter_link(ps_latlink_iter_t *itor)
00674 {
00675 return itor->link;
00676 }
00677
00678 int
00679 ps_latlink_times(ps_latlink_t *link, int16 *out_sf)
00680 {
00681 if (out_sf) {
00682 if (link->from) {
00683 *out_sf = link->from->sf;
00684 }
00685 else {
00686 *out_sf = 0;
00687 }
00688 }
00689 return link->ef;
00690 }
00691
00692 ps_latnode_t *
00693 ps_latlink_nodes(ps_latlink_t *link, ps_latnode_t **out_src)
00694 {
00695 if (out_src) *out_src = link->from;
00696 return link->to;
00697 }
00698
00699 char const *
00700 ps_latlink_word(ps_lattice_t *dag, ps_latlink_t *link)
00701 {
00702 if (link->from == NULL)
00703 return NULL;
00704 return dict_word_str(ps_search_dict(dag->search), link->from->wid);
00705 }
00706
00707 char const *
00708 ps_latlink_baseword(ps_lattice_t *dag, ps_latlink_t *link)
00709 {
00710 if (link->from == NULL)
00711 return NULL;
00712 return dict_word_str(ps_search_dict(dag->search), link->from->basewid);
00713 }
00714
00715 ps_latlink_t *
00716 ps_latlink_pred(ps_latlink_t *link)
00717 {
00718 return link->best_prev;
00719 }
00720
00721 int32
00722 ps_latlink_prob(ps_lattice_t *dag, ps_latlink_t *link, int32 *out_ascr)
00723 {
00724 int32 post = link->alpha + link->beta - dag->norm;
00725 if (out_ascr) *out_ascr = link->ascr;
00726 return post;
00727 }
00728
00729 char const *
00730 ps_lattice_hyp(ps_lattice_t *dag, ps_latlink_t *link)
00731 {
00732 ps_latlink_t *l;
00733 size_t len;
00734 char *c;
00735
00736
00737 len = 0;
00738 if (ISA_REAL_WORD(dag->search, link->to->basewid))
00739 len += strlen(dict_word_str(ps_search_dict(dag->search), link->to->basewid)) + 1;
00740 for (l = link; l; l = l->best_prev) {
00741 if (ISA_REAL_WORD(dag->search, l->from->basewid))
00742 len += strlen(dict_word_str(ps_search_dict(dag->search), l->from->basewid)) + 1;
00743 }
00744
00745
00746 ckd_free(dag->hyp_str);
00747 dag->hyp_str = ckd_calloc(1, len);
00748 c = dag->hyp_str + len - 1;
00749 if (ISA_REAL_WORD(dag->search, link->to->basewid)) {
00750 len = strlen(dict_word_str(ps_search_dict(dag->search), link->to->basewid));
00751 c -= len;
00752 memcpy(c, dict_word_str(ps_search_dict(dag->search), link->to->basewid), len);
00753 if (c > dag->hyp_str) {
00754 --c;
00755 *c = ' ';
00756 }
00757 }
00758 for (l = link; l; l = l->best_prev) {
00759 if (ISA_REAL_WORD(dag->search, l->from->basewid)) {
00760 len = strlen(dict_word_str(ps_search_dict(dag->search), l->from->basewid));
00761 c -= len;
00762 memcpy(c, dict_word_str(ps_search_dict(dag->search), l->from->basewid), len);
00763 if (c > dag->hyp_str) {
00764 --c;
00765 *c = ' ';
00766 }
00767 }
00768 }
00769
00770 return dag->hyp_str;
00771 }
00772
00773 static void
00774 ps_lattice_compute_lscr(ps_seg_t *seg, ps_latlink_t *link, int to)
00775 {
00776 ngram_model_t *lmset;
00777
00778
00779
00780 if (0 != strcmp(ps_search_name(seg->search), "ngram")) {
00781 seg->lback = 1;
00782 seg->lscr = 0;
00783 return;
00784 }
00785
00786 lmset = ((ngram_search_t *)seg->search)->lmset;
00787
00788 if (link->best_prev == NULL) {
00789 if (to)
00790 seg->lscr = ngram_bg_score(lmset, link->to->basewid,
00791 link->from->basewid, &seg->lback);
00792 else {
00793 seg->lscr = 0;
00794 seg->lback = 1;
00795 }
00796 }
00797 else {
00798
00799 if (to) {
00800 seg->lscr = ngram_tg_score(lmset, link->to->basewid,
00801 link->from->basewid,
00802 link->best_prev->from->basewid,
00803 &seg->lback);
00804 }
00805 else {
00806 if (link->best_prev->best_prev)
00807 seg->lscr = ngram_tg_score(lmset, link->from->basewid,
00808 link->best_prev->from->basewid,
00809 link->best_prev->best_prev->from->basewid,
00810 &seg->lback);
00811 else
00812 seg->lscr = ngram_bg_score(lmset, link->from->basewid,
00813 link->best_prev->from->basewid,
00814 &seg->lback);
00815 }
00816 }
00817 }
00818
00819 static void
00820 ps_lattice_link2itor(ps_seg_t *seg, ps_latlink_t *link, int to)
00821 {
00822 dag_seg_t *itor = (dag_seg_t *)seg;
00823 ps_latnode_t *node;
00824
00825 if (to) {
00826 node = link->to;
00827 seg->ef = node->lef;
00828 seg->prob = 0;
00829 }
00830 else {
00831 latlink_list_t *x;
00832 ps_latnode_t *n;
00833 logmath_t *lmath = ps_search_acmod(seg->search)->lmath;
00834
00835 node = link->from;
00836 seg->ef = link->ef;
00837 seg->prob = link->alpha + link->beta - itor->norm;
00838
00839
00840 for (n = node; n; n = n->alt) {
00841 for (x = n->exits; x; x = x->next) {
00842 if (x->link == link)
00843 continue;
00844 seg->prob = logmath_add(lmath, seg->prob,
00845 x->link->alpha + x->link->beta - itor->norm);
00846 }
00847 }
00848 }
00849 seg->word = dict_word_str(ps_search_dict(seg->search), node->wid);
00850 seg->sf = node->sf;
00851 seg->ascr = link->ascr;
00852
00853 ps_lattice_compute_lscr(seg, link, to);
00854 }
00855
00856 static void
00857 ps_lattice_seg_free(ps_seg_t *seg)
00858 {
00859 dag_seg_t *itor = (dag_seg_t *)seg;
00860
00861 ckd_free(itor->links);
00862 ckd_free(itor);
00863 }
00864
00865 static ps_seg_t *
00866 ps_lattice_seg_next(ps_seg_t *seg)
00867 {
00868 dag_seg_t *itor = (dag_seg_t *)seg;
00869
00870 ++itor->cur;
00871 if (itor->cur == itor->n_links + 1) {
00872 ps_lattice_seg_free(seg);
00873 return NULL;
00874 }
00875 else if (itor->cur == itor->n_links) {
00876
00877 ps_lattice_link2itor(seg, itor->links[itor->cur - 1], TRUE);
00878 }
00879 else {
00880 ps_lattice_link2itor(seg, itor->links[itor->cur], FALSE);
00881 }
00882
00883 return seg;
00884 }
00885
/*
 * Function table for lattice segment iterators; installed as base.vt by
 * ps_lattice_seg_iter().  The entries are, in order, the "advance" and
 * the "free" callbacks.
 */
static ps_segfuncs_t ps_lattice_segfuncs = {
    ps_lattice_seg_next,
    ps_lattice_seg_free
};
00890
00891 ps_seg_t *
00892 ps_lattice_seg_iter(ps_lattice_t *dag, ps_latlink_t *link, float32 lwf)
00893 {
00894 dag_seg_t *itor;
00895 ps_latlink_t *l;
00896 int cur;
00897
00898
00899
00900
00901 itor = ckd_calloc(1, sizeof(*itor));
00902 itor->base.vt = &ps_lattice_segfuncs;
00903 itor->base.search = dag->search;
00904 itor->base.lwf = lwf;
00905 itor->n_links = 0;
00906 itor->norm = dag->norm;
00907
00908 for (l = link; l; l = l->best_prev) {
00909 ++itor->n_links;
00910 }
00911 if (itor->n_links == 0) {
00912 ckd_free(itor);
00913 return NULL;
00914 }
00915
00916 itor->links = ckd_calloc(itor->n_links, sizeof(*itor->links));
00917 cur = itor->n_links - 1;
00918 for (l = link; l; l = l->best_prev) {
00919 itor->links[cur] = l;
00920 --cur;
00921 }
00922
00923 ps_lattice_link2itor((ps_seg_t *)itor, itor->links[0], FALSE);
00924 return (ps_seg_t *)itor;
00925 }
00926
00927 latlink_list_t *
00928 latlink_list_new(ps_lattice_t *dag, ps_latlink_t *link, latlink_list_t *next)
00929 {
00930 latlink_list_t *ll;
00931
00932 ll = listelem_malloc(dag->latlink_list_alloc);
00933 ll->link = link;
00934 ll->next = next;
00935
00936 return ll;
00937 }
00938
00939 void
00940 ps_lattice_pushq(ps_lattice_t *dag, ps_latlink_t *link)
00941 {
00942 if (dag->q_head == NULL)
00943 dag->q_head = dag->q_tail = latlink_list_new(dag, link, NULL);
00944 else {
00945 dag->q_tail->next = latlink_list_new(dag, link, NULL);
00946 dag->q_tail = dag->q_tail->next;
00947 }
00948
00949 }
00950
00951 ps_latlink_t *
00952 ps_lattice_popq(ps_lattice_t *dag)
00953 {
00954 latlink_list_t *x;
00955 ps_latlink_t *link;
00956
00957 if (dag->q_head == NULL)
00958 return NULL;
00959 link = dag->q_head->link;
00960 x = dag->q_head->next;
00961 listelem_free(dag->latlink_list_alloc, dag->q_head);
00962 dag->q_head = x;
00963 if (dag->q_head == NULL)
00964 dag->q_tail = NULL;
00965 return link;
00966 }
00967
00968 void
00969 ps_lattice_delq(ps_lattice_t *dag)
00970 {
00971 while (ps_lattice_popq(dag)) {
00972
00973 }
00974 }
00975
00976 ps_latlink_t *
00977 ps_lattice_traverse_edges(ps_lattice_t *dag, ps_latnode_t *start, ps_latnode_t *end)
00978 {
00979 ps_latnode_t *node;
00980 latlink_list_t *x;
00981
00982
00983 ps_lattice_delq(dag);
00984
00985
00986 for (node = dag->nodes; node; node = node->next)
00987 node->info.fanin = 0;
00988 for (node = dag->nodes; node; node = node->next) {
00989 for (x = node->exits; x; x = x->next)
00990 (x->link->to->info.fanin)++;
00991 }
00992
00993
00994 if (start == NULL) start = dag->start;
00995 for (x = start->exits; x; x = x->next)
00996 ps_lattice_pushq(dag, x->link);
00997
00998
00999 return ps_lattice_traverse_next(dag, end);
01000 }
01001
01002 ps_latlink_t *
01003 ps_lattice_traverse_next(ps_lattice_t *dag, ps_latnode_t *end)
01004 {
01005 ps_latlink_t *next;
01006
01007 next = ps_lattice_popq(dag);
01008 if (next == NULL)
01009 return NULL;
01010
01011
01012
01013 --next->to->info.fanin;
01014 if (next->to->info.fanin == 0) {
01015 latlink_list_t *x;
01016
01017 if (end == NULL) end = dag->end;
01018 if (next->to == end) {
01019
01020
01021
01022 ps_lattice_delq(dag);
01023 return next;
01024 }
01025
01026
01027 for (x = next->to->exits; x; x = x->next)
01028 ps_lattice_pushq(dag, x->link);
01029 }
01030 return next;
01031 }
01032
01033 ps_latlink_t *
01034 ps_lattice_reverse_edges(ps_lattice_t *dag, ps_latnode_t *start, ps_latnode_t *end)
01035 {
01036 ps_latnode_t *node;
01037 latlink_list_t *x;
01038
01039
01040 ps_lattice_delq(dag);
01041
01042
01043 for (node = dag->nodes; node; node = node->next) {
01044 node->info.fanin = 0;
01045 for (x = node->exits; x; x = x->next)
01046 ++node->info.fanin;
01047 }
01048
01049
01050 if (end == NULL) end = dag->end;
01051 for (x = end->entries; x; x = x->next)
01052 ps_lattice_pushq(dag, x->link);
01053
01054
01055 return ps_lattice_reverse_next(dag, start);
01056 }
01057
01058 ps_latlink_t *
01059 ps_lattice_reverse_next(ps_lattice_t *dag, ps_latnode_t *start)
01060 {
01061 ps_latlink_t *next;
01062
01063 next = ps_lattice_popq(dag);
01064 if (next == NULL)
01065 return NULL;
01066
01067
01068
01069 --next->from->info.fanin;
01070 if (next->from->info.fanin == 0) {
01071 latlink_list_t *x;
01072
01073 if (start == NULL) start = dag->start;
01074 if (next->from == start) {
01075
01076
01077
01078 ps_lattice_delq(dag);
01079 return next;
01080 }
01081
01082
01083 for (x = next->from->entries; x; x = x->next)
01084 ps_lattice_pushq(dag, x->link);
01085 }
01086 return next;
01087 }
01088
01089
01090
01091
01092
01093
01094
01095
01096
01097
01098
01099
01100 ps_latlink_t *
01101 ps_lattice_bestpath(ps_lattice_t *dag, ngram_model_t *lmset,
01102 float32 lwf, float32 ascale)
01103 {
01104 ps_search_t *search;
01105 ps_latnode_t *node;
01106 ps_latlink_t *link;
01107 ps_latlink_t *bestend;
01108 latlink_list_t *x;
01109 logmath_t *lmath;
01110 int32 bestescr;
01111
01112 search = dag->search;
01113 lmath = dag->lmath;
01114
01115
01116
01117
01118 for (node = dag->nodes; node; node = node->next) {
01119 for (x = node->exits; x; x = x->next) {
01120 x->link->path_scr = MAX_NEG_INT32;
01121 x->link->alpha = logmath_get_zero(lmath);
01122 }
01123 }
01124 for (x = dag->start->exits; x; x = x->next) {
01125 int32 n_used;
01126
01127
01128 if (ISA_FILLER_WORD(search, x->link->to->basewid)
01129 && x->link->to != dag->end)
01130 continue;
01131
01132
01133 if (lmset)
01134 x->link->path_scr = x->link->ascr +
01135 ngram_bg_score(lmset, x->link->to->basewid,
01136 ps_search_start_wid(search), &n_used) * lwf;
01137 else
01138 x->link->path_scr = x->link->ascr;
01139 x->link->best_prev = NULL;
01140
01141 x->link->alpha = 0;
01142 }
01143
01144
01145 for (link = ps_lattice_traverse_edges(dag, NULL, NULL);
01146 link; link = ps_lattice_traverse_next(dag, NULL)) {
01147 int32 bprob, n_used;
01148
01149
01150 if (ISA_FILLER_WORD(search, link->from->basewid) && link->from != dag->start)
01151 continue;
01152 if (ISA_FILLER_WORD(search, link->to->basewid) && link->to != dag->end)
01153 continue;
01154
01155
01156
01157 assert(link->path_scr != MAX_NEG_INT32);
01158
01159
01160 if (lmset)
01161 bprob = ngram_ng_prob(lmset,
01162 link->to->basewid,
01163 &link->from->basewid, 1, &n_used);
01164 else
01165 bprob = 0;
01166
01167
01168 link->alpha += link->ascr * ascale;
01169
01170
01171 for (x = link->to->exits; x; x = x->next) {
01172 int32 tscore, score;
01173
01174
01175 if (ISA_FILLER_WORD(search, x->link->to->basewid)
01176 && x->link->to != dag->end)
01177 continue;
01178
01179
01180 x->link->alpha = logmath_add(lmath, x->link->alpha, link->alpha + bprob);
01181
01182 if (lmset)
01183 tscore = ngram_tg_score(lmset, x->link->to->basewid,
01184 link->to->basewid,
01185 link->from->basewid, &n_used) * lwf;
01186 else
01187 tscore = 0;
01188
01189 score = link->path_scr + tscore + x->link->ascr;
01190 if (score > x->link->path_scr) {
01191 x->link->path_scr = score;
01192 x->link->best_prev = link;
01193 }
01194 }
01195 }
01196
01197
01198
01199 bestend = NULL;
01200 bestescr = MAX_NEG_INT32;
01201
01202
01203
01204 dag->norm = logmath_get_zero(lmath);
01205 for (x = dag->end->entries; x; x = x->next) {
01206 int32 bprob, n_used;
01207
01208 if (ISA_FILLER_WORD(search, x->link->from->basewid))
01209 continue;
01210 if (lmset)
01211 bprob = ngram_ng_prob(lmset,
01212 x->link->to->basewid,
01213 &x->link->from->basewid, 1, &n_used);
01214 else
01215 bprob = 0;
01216 dag->norm = logmath_add(lmath, dag->norm, x->link->alpha + bprob);
01217 if (x->link->path_scr > bestescr) {
01218 bestescr = x->link->path_scr;
01219 bestend = x->link;
01220 }
01221 }
01222
01223 dag->norm += (int32)dag->final_node_ascr * ascale;
01224
01225 E_INFO("Normalizer P(O) = alpha(%s:%d:%d) = %d\n",
01226 dict_word_str(dag->search->dict, dag->end->wid),
01227 dag->end->sf, dag->end->lef,
01228 dag->norm);
01229 return bestend;
01230 }
01231
01232 static int32
01233 ps_lattice_joint(ps_lattice_t *dag, ps_latlink_t *link, float32 ascale)
01234 {
01235 ngram_model_t *lmset;
01236 int32 jprob;
01237
01238
01239 if (dag->search && 0 == strcmp(ps_search_name(dag->search), "ngram"))
01240 lmset = ((ngram_search_t *)dag->search)->lmset;
01241 else
01242 lmset = NULL;
01243
01244 jprob = dag->final_node_ascr * ascale;
01245 while (link) {
01246 if (lmset) {
01247 int lback;
01248
01249
01250
01251
01252
01253 jprob += ngram_ng_prob(lmset, link->to->basewid,
01254 &link->from->basewid, 1, &lback);
01255 }
01256
01257
01258
01259 jprob += link->ascr * ascale;
01260 link = link->best_prev;
01261 }
01262
01263 E_INFO("Joint P(O,S) = %d P(S|O) = %d\n", jprob, jprob - dag->norm);
01264 return jprob;
01265 }
01266
01267 int32
01268 ps_lattice_posterior(ps_lattice_t *dag, ngram_model_t *lmset,
01269 float32 ascale)
01270 {
01271 ps_search_t *search;
01272 logmath_t *lmath;
01273 ps_latnode_t *node;
01274 ps_latlink_t *link;
01275 latlink_list_t *x;
01276 ps_latlink_t *bestend;
01277 int32 bestescr;
01278
01279 search = dag->search;
01280 lmath = dag->lmath;
01281
01282
01283 for (node = dag->nodes; node; node = node->next) {
01284 for (x = node->exits; x; x = x->next) {
01285 x->link->beta = logmath_get_zero(lmath);
01286 }
01287 }
01288
01289 bestend = NULL;
01290 bestescr = MAX_NEG_INT32;
01291
01292 for (link = ps_lattice_reverse_edges(dag, NULL, NULL);
01293 link; link = ps_lattice_reverse_next(dag, NULL)) {
01294 int32 bprob, n_used;
01295
01296
01297 if (ISA_FILLER_WORD(search, link->from->basewid) && link->from != dag->start)
01298 continue;
01299 if (ISA_FILLER_WORD(search, link->to->basewid) && link->to != dag->end)
01300 continue;
01301
01302
01303 if (lmset)
01304 bprob = ngram_ng_prob(lmset, link->to->basewid,
01305 &link->from->basewid, 1, &n_used);
01306 else
01307 bprob = 0;
01308
01309 if (link->to == dag->end) {
01310
01311
01312
01313 if (link->path_scr > bestescr) {
01314 bestescr = link->path_scr;
01315 bestend = link;
01316 }
01317
01318 link->beta = bprob + dag->final_node_ascr * ascale;
01319 }
01320 else {
01321
01322 for (x = link->to->exits; x; x = x->next) {
01323 if (ISA_FILLER_WORD(search, x->link->to->basewid) && x->link->to != dag->end)
01324 continue;
01325 link->beta = logmath_add(lmath, link->beta,
01326 x->link->beta + bprob + x->link->ascr * ascale);
01327 }
01328 }
01329 }
01330
01331
01332 return ps_lattice_joint(dag, bestend, ascale) - dag->norm;
01333 }
01334
01335
01336
/* Cap on the length of the A* path list: path_insert() truncates the list
 * to this many entries and path_extend() compares n_path against it. */
01337 #define MAX_PATHS 500
/* NOTE(review): not referenced in the visible part of this file; presumably
 * a limit on the number of hypotheses tried -- confirm against callers. */
01338 #define MAX_HYP_TRIES 10000
01339
01340
01341
01342
01343
01344
01345
01346 static int32
01347 best_rem_score(ps_astar_t *nbest, ps_latnode_t * from)
01348 {
01349 ps_lattice_t *dag;
01350 latlink_list_t *x;
01351 int32 bestscore, score;
01352
01353 dag = nbest->dag;
01354 if (from->info.rem_score <= 0)
01355 return (from->info.rem_score);
01356
01357
01358 bestscore = WORST_SCORE;
01359 for (x = from->exits; x; x = x->next) {
01360 int32 n_used;
01361
01362 score = best_rem_score(nbest, x->link->to);
01363 score += x->link->ascr;
01364 if (nbest->lmset)
01365 score += ngram_bg_score(nbest->lmset, x->link->to->basewid,
01366 from->basewid, &n_used) * nbest->lwf;
01367 if (score > bestscore)
01368 bestscore = score;
01369 }
01370 from->info.rem_score = bestscore;
01371
01372 return bestscore;
01373 }
01374
01375
01376
01377
01378
01379
01380 static void
01381 path_insert(ps_astar_t *nbest, ps_latpath_t *newpath, int32 total_score)
01382 {
01383 ps_lattice_t *dag;
01384 ps_latpath_t *prev, *p;
01385 int32 i;
01386
01387 dag = nbest->dag;
01388 prev = NULL;
01389 for (i = 0, p = nbest->path_list; (i < MAX_PATHS) && p; p = p->next, i++) {
01390 if ((p->score + p->node->info.rem_score) < total_score)
01391 break;
01392 prev = p;
01393 }
01394
01395
01396 if (i < MAX_PATHS) {
01397
01398 newpath->next = p;
01399 if (!prev)
01400 nbest->path_list = newpath;
01401 else
01402 prev->next = newpath;
01403 if (!p)
01404 nbest->path_tail = newpath;
01405
01406 nbest->n_path++;
01407 nbest->n_hyp_insert++;
01408 nbest->insert_depth += i;
01409 }
01410 else {
01411
01412 nbest->path_tail = prev;
01413 prev->next = NULL;
01414 nbest->n_path = MAX_PATHS;
01415 listelem_free(nbest->latpath_alloc, newpath);
01416
01417 nbest->n_hyp_reject++;
01418 for (; p; p = newpath) {
01419 newpath = p->next;
01420 listelem_free(nbest->latpath_alloc, p);
01421 nbest->n_hyp_reject++;
01422 }
01423 }
01424 }
01425
01426
01427 static void
01428 path_extend(ps_astar_t *nbest, ps_latpath_t * path)
01429 {
01430 latlink_list_t *x;
01431 ps_latpath_t *newpath;
01432 int32 total_score, tail_score;
01433 ps_lattice_t *dag;
01434
01435 dag = nbest->dag;
01436
01437
01438 for (x = path->node->exits; x; x = x->next) {
01439 int32 n_used;
01440
01441
01442 if (x->link->to->info.rem_score <= WORST_SCORE)
01443 continue;
01444
01445
01446 newpath = listelem_malloc(nbest->latpath_alloc);
01447 newpath->node = x->link->to;
01448 newpath->parent = path;
01449 newpath->score = path->score + x->link->ascr;
01450 if (nbest->lmset) {
01451 if (path->parent) {
01452 newpath->score += nbest->lwf
01453 * ngram_tg_score(nbest->lmset, newpath->node->basewid,
01454 path->node->basewid,
01455 path->parent->node->basewid, &n_used);
01456 }
01457 else
01458 newpath->score += nbest->lwf
01459 * ngram_bg_score(nbest->lmset, newpath->node->basewid,
01460 path->node->basewid, &n_used);
01461 }
01462
01463
01464 nbest->n_hyp_tried++;
01465 total_score = newpath->score + newpath->node->info.rem_score;
01466
01467
01468 if (nbest->n_path >= MAX_PATHS) {
01469 tail_score =
01470 nbest->path_tail->score
01471 + nbest->path_tail->node->info.rem_score;
01472 if (total_score < tail_score) {
01473 listelem_free(nbest->latpath_alloc, newpath);
01474 nbest->n_hyp_reject++;
01475 continue;
01476 }
01477 }
01478
01479 path_insert(nbest, newpath, total_score);
01480 }
01481 }
01482
01483 ps_astar_t *
01484 ps_astar_start(ps_lattice_t *dag,
01485 ngram_model_t *lmset,
01486 float32 lwf,
01487 int sf, int ef,
01488 int w1, int w2)
01489 {
01490 ps_astar_t *nbest;
01491 ps_latnode_t *node;
01492
01493 nbest = ckd_calloc(1, sizeof(*nbest));
01494 nbest->dag = dag;
01495 nbest->lmset = lmset;
01496 nbest->lwf = lwf;
01497 nbest->sf = sf;
01498 if (ef < 0)
01499 nbest->ef = dag->n_frames - ef;
01500 else
01501 nbest->ef = ef;
01502 nbest->w1 = w1;
01503 nbest->w2 = w2;
01504 nbest->latpath_alloc = listelem_alloc_init(sizeof(ps_latpath_t));
01505
01506
01507 for (node = dag->nodes; node; node = node->next) {
01508 if (node == dag->end)
01509 node->info.rem_score = 0;
01510 else if (node->exits == NULL)
01511 node->info.rem_score = WORST_SCORE;
01512 else
01513 node->info.rem_score = 1;
01514 }
01515
01516
01517 nbest->path_list = nbest->path_tail = NULL;
01518 for (node = dag->nodes; node; node = node->next) {
01519 if (node->sf == sf) {
01520 ps_latpath_t *path;
01521 int32 n_used;
01522
01523 best_rem_score(nbest, node);
01524 path = listelem_malloc(nbest->latpath_alloc);
01525 path->node = node;
01526 path->parent = NULL;
01527 if (nbest->lmset)
01528 path->score = nbest->lwf *
01529 (w1 < 0)
01530 ? ngram_bg_score(nbest->lmset, node->basewid, w2, &n_used)
01531 : ngram_tg_score(nbest->lmset, node->basewid, w2, w1, &n_used);
01532 else
01533 path->score = 0;
01534 path_insert(nbest, path, path->score + node->info.rem_score);
01535 }
01536 }
01537
01538 return nbest;
01539 }
01540
01541 ps_latpath_t *
01542 ps_astar_next(ps_astar_t *nbest)
01543 {
01544 ps_latpath_t *top;
01545 ps_lattice_t *dag;
01546
01547 dag = nbest->dag;
01548
01549
01550 while ((top = nbest->path_list) != NULL) {
01551 nbest->path_list = nbest->path_list->next;
01552 if (top == nbest->path_tail)
01553 nbest->path_tail = NULL;
01554 nbest->n_path--;
01555
01556
01557 if ((top->node->sf >= nbest->ef)
01558 || ((top->node == dag->end) &&
01559 (nbest->ef > dag->end->sf))) {
01560
01561 return top;
01562 }
01563 else {
01564 if (top->node->fef < nbest->ef)
01565 path_extend(nbest, top);
01566 }
01567
01568
01569
01570
01571
01572 top->next = nbest->paths_done;
01573 nbest->paths_done = top;
01574 }
01575
01576
01577 return NULL;
01578 }
01579
01580 char const *
01581 ps_astar_hyp(ps_astar_t *nbest, ps_latpath_t *path)
01582 {
01583 ps_search_t *search;
01584 ps_latpath_t *p;
01585 size_t len;
01586 char *c;
01587 char *hyp;
01588
01589 search = nbest->dag->search;
01590
01591
01592 len = 0;
01593 for (p = path; p; p = p->parent) {
01594 if (ISA_REAL_WORD(search, p->node->basewid))
01595 len += strlen(dict_word_str(ps_search_dict(search), p->node->basewid)) + 1;
01596 }
01597
01598
01599 hyp = ckd_calloc(1, len);
01600 c = hyp + len - 1;
01601 for (p = path; p; p = p->parent) {
01602 if (ISA_REAL_WORD(search, p->node->basewid)) {
01603 len = strlen(dict_word_str(ps_search_dict(search), p->node->basewid));
01604 c -= len;
01605 memcpy(c, dict_word_str(ps_search_dict(search), p->node->basewid), len);
01606 if (c > hyp) {
01607 --c;
01608 *c = ' ';
01609 }
01610 }
01611 }
01612
01613 nbest->hyps = glist_add_ptr(nbest->hyps, hyp);
01614 return hyp;
01615 }
01616
01617 static void
01618 ps_astar_node2itor(astar_seg_t *itor)
01619 {
01620 ps_seg_t *seg = (ps_seg_t *)itor;
01621 ps_latnode_t *node;
01622
01623 assert(itor->cur < itor->n_nodes);
01624 node = itor->nodes[itor->cur];
01625 if (itor->cur == itor->n_nodes - 1)
01626 seg->ef = node->lef;
01627 else
01628 seg->ef = itor->nodes[itor->cur + 1]->sf - 1;
01629 seg->word = dict_word_str(ps_search_dict(seg->search), node->wid);
01630 seg->sf = node->sf;
01631 seg->prob = 0;
01632 }
01633
01634 static void
01635 ps_astar_seg_free(ps_seg_t *seg)
01636 {
01637 astar_seg_t *itor = (astar_seg_t *)seg;
01638 ckd_free(itor->nodes);
01639 ckd_free(itor);
01640 }
01641
01642 static ps_seg_t *
01643 ps_astar_seg_next(ps_seg_t *seg)
01644 {
01645 astar_seg_t *itor = (astar_seg_t *)seg;
01646
01647 ++itor->cur;
01648 if (itor->cur == itor->n_nodes) {
01649 ps_astar_seg_free(seg);
01650 return NULL;
01651 }
01652 else {
01653 ps_astar_node2itor(itor);
01654 }
01655
01656 return seg;
01657 }
01658
/* Function table installed as base.vt by ps_astar_seg_iter(): the first
 * entry advances the iterator, the second releases it. */
01659 static ps_segfuncs_t ps_astar_segfuncs = {
01660 ps_astar_seg_next,
01661 ps_astar_seg_free
01662 };
01663
01664 ps_seg_t *
01665 ps_astar_seg_iter(ps_astar_t *astar, ps_latpath_t *path, float32 lwf)
01666 {
01667 astar_seg_t *itor;
01668 ps_latpath_t *p;
01669 int cur;
01670
01671
01672 itor = ckd_calloc(1, sizeof(*itor));
01673 itor->base.vt = &ps_astar_segfuncs;
01674 itor->base.search = astar->dag->search;
01675 itor->base.lwf = lwf;
01676 itor->n_nodes = itor->cur = 0;
01677 for (p = path; p; p = p->parent) {
01678 ++itor->n_nodes;
01679 }
01680 itor->nodes = ckd_calloc(itor->n_nodes, sizeof(*itor->nodes));
01681 cur = itor->n_nodes - 1;
01682 for (p = path; p; p = p->parent) {
01683 itor->nodes[cur] = p->node;
01684 --cur;
01685 }
01686
01687 ps_astar_node2itor(itor);
01688 return (ps_seg_t *)itor;
01689 }
01690
01691 void
01692 ps_astar_finish(ps_astar_t *nbest)
01693 {
01694 gnode_t *gn;
01695
01696
01697 for (gn = nbest->hyps; gn; gn = gnode_next(gn)) {
01698 ckd_free(gnode_ptr(gn));
01699 }
01700 glist_free(nbest->hyps);
01701
01702 listelem_alloc_free(nbest->latpath_alloc);
01703
01704 ckd_free(nbest);
01705 }