Blame view

egs/wsj/s5/utils/lang/adjust_unk_arpa.pl 2.15 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
  #!/usr/bin/env perl
  
  # Copyright 2018  Xiaohui Zhang
  # Apache 2.0.
  #
  use strict;
  use warnings;
  use Getopt::Long;
  use Scalar::Util qw(looks_like_number);
  
  # Usage/help text; the script dies with this message when it is not given
  # exactly two positional arguments.
  # Built from a list of lines instead of a here-doc: a plain `<<EOU` here-doc
  # only terminates on a marker at column 0, so any indentation of this file
  # would leave it unterminated.  The trailing "" yields the final newline.
  my $Usage = join("\n",
    "# This is a simple script to set/scale the prob of n-grams where the OOV dict entry is the predicted word, in an ARPA lm file.",
    "Usage: utils/lang/adjust_unk_arpa.pl [options] <oov-dict-entry> <unk-scale> <input-arpa >output-arpa",
    "",
    "Allowed options:",
    "  --fixed-value (true|false)   : If true, interpret the unk-scale as a fixed value we'll set to",
    "                                 the unigram prob of the OOV dict entry, rather than using it to",
    "                                 scale the probs. In this case higher order n-grams containing",
    "                                 the OOV dict entry remain untouched. This is useful when the OOV",
    "                                 dict entry doesn't appear in n-grams (n>1) as the predicted word.",
    "");
  
  # --fixed-value is a "true"/"false" string option (not a boolean switch);
  # its value is validated explicitly below.
  my $fixed_value = "false";
  # GetOptions returns false on malformed/unknown options; show the usage
  # instead of silently continuing with a partially parsed command line.
  GetOptions('fixed-value=s' => \$fixed_value) or die $Usage;

  ($fixed_value eq "true" || $fixed_value eq "false") ||
    die "$0: Bad value for option --fixed-value\n";

  # Exactly two positional arguments are expected; the ARPA LM itself is
  # read from stdin and written to stdout.
  if (@ARGV != 2) {
    die $Usage;
  }

  # Gets parameters.
  my ($unk_word, $unk_scale) = @ARGV;

  # The scale (or fixed prob) is later passed through log(), so it has to be
  # a strictly positive number.
  (looks_like_number($unk_scale) && $unk_scale > 0.0) ||
    die "$0: Bad unk_scale '$unk_scale' (must be a positive number)\n";

  if ( $fixed_value eq "true" ) {
    print STDERR "$0: Setting the unigram prob of $unk_word in LM file as $unk_scale.\n";
  } else {
    print STDERR "$0: Scaling the probs of ngrams where $unk_word is the predicted word in LM file by $unk_scale.\n";
  }
  
  # Order of the n-gram section currently being read; 0 until the first
  # "\N-grams:" header has been seen.
  my $ngram = 0;

  # log10 of the scale / fixed value; ARPA files store log10 probabilities.
  # Loop-invariant, so computed once.
  my $log10_unk_scale = log($unk_scale) / log(10.0);

  # Copy the ARPA LM from stdin to stdout, adjusting the log-prob (column 0)
  # of n-grams whose predicted (last) word is the unk-word.
  while (my $line = <STDIN>) {
    # Section header such as "\3-grams:".  Any order is accepted, and trailing
    # whitespace (e.g. a CR from a DOS-formatted file) is tolerated.
    if ($line =~ m/^\\(\d+)-grams:\s*$/) { $ngram = $1; }

    # An entry of order N is "<log10-prob> w1 ... wN [<log10-backoff>]", so the
    # predicted word sits in column N.  Requiring more than N columns also
    # skips headers, blank lines and truncated entries without reading an
    # undefined column.
    my @col = split(" ", $line);
    my $predicts_unk =
      ($ngram > 0 && @col > $ngram && $col[$ngram] eq $unk_word);

    # In fixed-value mode only the unigram entry is rewritten; higher-order
    # n-grams are passed through untouched, as the usage message documents.
    if (!$predicts_unk || ($fixed_value eq "true" && $ngram > 1)) {
      print $line;
      next;
    }

    if ($fixed_value eq "true") {
      $col[0] = $log10_unk_scale;
    } else {
      $col[0] += $log10_unk_scale;
    }
    # Rewritten entries are emitted tab-separated.
    print join("\t", @col), "\n";
  }

  exit 0;