egs/wsj/s5/steps/dict/internal/sum_arc_info.py
#!/usr/bin/env python

# Copyright 2018  Xiaohui Zhang
# Apache 2.0

from __future__ import print_function
from collections import defaultdict
import argparse
import sys


class StrToBoolAction(argparse.Action):
    """ A custom action to convert bools from shell format, i.e. true/false,
        to python format, i.e. True/False """
    def __call__(self, parser, namespace, values, option_string=None):
        if values == "true":
            setattr(namespace, self.dest, True)
        elif values == "false":
            setattr(namespace, self.dest, False)
        else:
            raise Exception("Unknown value {0} for --{1}".format(values, self.dest))


def GetArgs():
    parser = argparse.ArgumentParser(
        description = "Accumulate statistics from per-arc lattice statistics "
        "for lexicon learning",
        epilog = "See steps/dict/learn_lexicon_greedy.sh for an example")
    parser.add_argument("--set-sum-to-one", type = str, default = True,
                        action = StrToBoolAction, choices = ["true", "false"],
                        help = "If true, normalize posteriors such that the sum of "
                        "pronunciation posteriors of a word in an utterance is 1.")
    parser.add_argument("arc_info_file", metavar = "<arc-info-file>", type = str,
                        help = "File containing per-arc statistics; each line must be "
                        "<utt-id> <start-frame> <duration> <posterior> <word> "
                        "<phones-with-word-boundary-markers>")
    parser.add_argument("phone_map", metavar = "<phone-map>", type = str,
                        help = "An input phone map used to remove word boundary markers "
                        "from phones; generated in steps/cleanup/debug_lexicon.sh")
    parser.add_argument("stats_file", metavar = "<out-stats-file>", type = str,
                        help = "Write accumulated statistics to this file; each line is "
                        "<word> <utt-id> <start-frame> <posterior> "
                        "<phones-without-word-boundary-markers>")

    print(' '.join(sys.argv), file=sys.stderr)

    args = parser.parse_args()
    args = CheckArgs(args)

    return args


def CheckArgs(args):
    if args.arc_info_file == "-":
        args.arc_info_file_handle = sys.stdin
    else:
        args.arc_info_file_handle = open(args.arc_info_file)
    args.phone_map_handle = open(args.phone_map)
    if args.stats_file == "-":
        args.stats_file_handle = sys.stdout
    else:
        args.stats_file_handle = open(args.stats_file, "w")

    return args


def Main():
    args = GetArgs()

    lexicon = defaultdict(list)
    prons = defaultdict(list)
    start_frames = {}
    stats = defaultdict(lambda: defaultdict(float))
    sum_tot = defaultdict(float)
    phone_map = {}

    # Read the phone map: each line maps a phone carrying a word-boundary marker
    # to the corresponding phone without the marker.
    for line in args.phone_map_handle.readlines():
        splits = line.strip().split()
        phone_map[splits[0]] = splits[1]

    # Accumulate per-pronunciation posteriors for each (word, utterance) pair.
    for line in args.arc_info_file_handle.readlines():
        splits = line.strip().split()

        if len(splits) == 0:
            continue
        if len(splits) < 6:
            raise Exception('Invalid format of line ' + line
                            + ' in ' + args.arc_info_file)
        utt = splits[0]
        start_frame = int(splits[1])
        count = float(splits[3])
        word = splits[4]
        phones_unmapped = splits[5:]
        phones = [phone_map[phone] for phone in phones_unmapped]
        phones = ' '.join(phones)
        if word == '<eps>':
            # Skip epsilon (non-word) arcs.
            continue
        if (word, utt) not in start_frames:
            start_frames[(word, utt)] = start_frame
        if (word, utt) in stats:
            stats[(word, utt)][phones] = stats[(word, utt)].get(phones, 0) + count
        else:
            stats[(word, utt)][phones] = count
        sum_tot[(word, utt)] += count
        if phones not in prons[word]:
            prons[word].append(phones)

    for (word, utt) in stats:
        count_sum = 0.0
        counts = dict()
        for phones in stats[(word, utt)]:
            count = stats[(word, utt)][phones]
            count_sum += count
            counts[phones] = count
        # By default we normalize the pron posteriors of each word in each utterance,
        # so that they sum up exactly to one. If a word occurs two times in an utterance,
        # the effect of this operation is to average the posteriors of these two occurrences,
        # so that there's only one "equivalent occurrence" of this word in the utterance.
        # However, this case should be extremely rare if the utterances are already
        # short sub-utterances produced by steps/dict/internal/get_subsegments.py
        for phones in stats[(word, utt)]:
            count = counts[phones] / count_sum
            print(word, utt, start_frames[(word, utt)], count, phones,
                  file=args.stats_file_handle)
        # Diagnostic info implying incomplete arc_info, or multiple occurrences of a
        # word in an utterance:
        # if count_sum < 0.9 or count_sum > 1.1:
        #     print(word, utt, start_frame, count_sum, stats[(word, utt)], file=sys.stderr)

    args.stats_file_handle.close()


if __name__ == "__main__":
    Main()
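For illustration, the normalization in the final loop of Main() amounts to dividing each pronunciation's accumulated posterior by the total over all pronunciations of that (word, utterance) pair, so the printed posteriors sum to 1. A minimal, self-contained sketch with made-up counts (not part of the original script):

# Illustration only: hypothetical accumulated counts for one (word, utt) pair,
# normalized the same way Main() does before printing.
from collections import defaultdict

stats = defaultdict(lambda: defaultdict(float))
stats[("water", "utt1")]["W AO T ER"] = 1.4   # e.g. the word occurred twice in the utterance
stats[("water", "utt1")]["W AA T ER"] = 0.6

for (word, utt) in stats:
    count_sum = sum(stats[(word, utt)].values())      # 2.0 in this toy example
    for phones, count in stats[(word, utt)].items():
        print(word, utt, count / count_sum, phones)   # posteriors 0.7 and 0.3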