  #!/bin/bash
  
  # Copyright 2012  Johns Hopkins University (author: Daniel Povey)  Tony Robinson
  #           2015  Guoguo Chen
  
  # This script trains LMs on the WSJ LM-training data.
  # It requires that you have already run wsj_extend_dict.sh,
  # to get the larger-size dictionary including all of CMUdict
  # plus any OOVs and possible acronyms that we could easily
  # derive pronunciations for.
  
# This script takes one command-line argument, the destination directory,
# plus options such as --cmd (see the configuration section below).
  
  # Begin configuration section.
  rand_seed=0
  cmd=run.pl
  nwords=10000 # This is how many words we're putting in the vocab of the RNNLM.
  hidden=30
  class=200 # Num-classes... should be somewhat larger than sqrt of nwords.
  direct=1000 # Number of weights that are used for "direct" connections, in millions.
  rnnlm_ver=rnnlm-0.3e # version of RNNLM to use
threads=1 # number of training threads (used with faster-rnnlm)
  bptt=2 # length of BPTT unfolding in RNNLM
  bptt_block=20 # length of BPTT unfolding in RNNLM
  dict_suffix=
  rnnlm_options="-direct-order 4"
  # End configuration section.
  
  [ -f ./path.sh ] && . ./path.sh
  . utils/parse_options.sh
  
  if [ $# != 1 ]; then
     echo "Usage: local/wsj_train_rnnlms.sh [options] <dest-dir>"
     echo "For options, see top of script file"
     exit 1;
  fi
  
  dir=$1
  srcdir=data/local/dict${dict_suffix}_larger
  mkdir -p $dir
  
  $KALDI_ROOT/tools/extras/check_for_rnnlm.sh "$rnnlm_ver" || exit 1
  export PATH=$KALDI_ROOT/tools/$rnnlm_ver:$PATH
  
  
  if [ ! -f $srcdir/cleaned.gz -o ! -f $srcdir/lexicon.txt ]; then
    echo "Expecting files $srcdir/cleaned.gz and $srcdir/wordlist.final to exist";
    echo "You need to run local/wsj_extend_dict.sh before running this script."
    exit 1;
  fi
  
  cat $srcdir/lexicon.txt | awk '{print $1}' | grep -v -w '!SIL' > $dir/wordlist.all
  
  # Get training data with OOV words (w.r.t. our current vocab) replaced with <UNK>.
  echo "Getting training data with OOV words replaced with <UNK> (train_nounk.gz)"
  gunzip -c $srcdir/cleaned.gz | awk -v w=$dir/wordlist.all \
    'BEGIN{while((getline<w)>0) v[$1]=1;}
    {for (i=1;i<=NF;i++) if ($i in v) printf $i" ";else printf "<UNK> ";print ""}'|sed 's/ $//g' \
    | gzip -c > $dir/all.gz
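# (The awk above loads $dir/wordlist.all into the array v, then rewrites each sentence
# token by token, keeping in-vocabulary words and replacing everything else with <UNK>;
# the sed strips the trailing space left by printf.)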
  
  echo "Splitting data into train and validation sets."
  heldout_sent=10000
  gunzip -c $dir/all.gz | head -n $heldout_sent > $dir/valid.in # validation data
gunzip -c $dir/all.gz | tail -n +$((heldout_sent+1)) | \
   perl -e ' use List::Util qw(shuffle); @A=<>; print join("", shuffle(@A)); ' \
    > $dir/train.in # training data
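# (The remaining sentences are shuffled in memory with Perl's List::Util::shuffle;
# note that this shuffle is not seeded.)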
  
  
# Get unigram counts from our training data, and use these to select the word-list
# for RNNLM training, e.g. the 10k most frequent words.  The rest of the words will
# go into a single word-class, represented by <RNN_UNK>, that maps (with probabilities
# we assign manually, at the shell level) to the whole class of out-of-shortlist words.
# Note: this word-list doesn't need to include </s>; that gets added automatically
# inside the rnnlm program.
# Note: by concatenating with $dir/wordlist.all, we are doing add-one
# smoothing of the counts.
  
  export TMPDIR=$dir # to avoid filling up /tmp/
  
  cat $dir/train.in $dir/wordlist.all | grep -v '</s>' | grep -v '<s>' | \
    awk '{ for(x=1;x<=NF;x++) count[$x]++; } END{for(w in count){print count[w], w;}}' | \
    sort -nr > $dir/unigram.counts
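# Each line of unigram.counts is "<count> <word>", sorted by decreasing count,
# e.g. (hypothetical): "41879 THE".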
  
  head -$nwords $dir/unigram.counts | awk '{print $2}' > $dir/wordlist.rnn
  
tail -n +$((nwords+1)) $dir/unigram.counts > $dir/unk_class.counts
  
  tot=`awk '{x=x+$1} END{print x}' $dir/unk_class.counts`
  awk -v tot=$tot '{print $2, ($1*1.0/tot);}' <$dir/unk_class.counts  >$dir/unk.probs
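# unk.probs now has one line per out-of-shortlist word, "<word> <prob>", where the
# probabilities are the relative unigram frequencies within the <RNN_UNK> class
# (so they sum to 1 over the class).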
  
  
  for type in train valid; do
    cat $dir/$type.in | awk -v w=$dir/wordlist.rnn \
      'BEGIN{while((getline<w)>0) v[$1]=1;}
      {for (i=1;i<=NF;i++) if ($i in v) printf $i" ";else printf "<RNN_UNK> ";print ""}'|sed 's/ $//g' \
      > $dir/$type
  done
  rm $dir/train.in # no longer needed-- and big.
  
  # Now randomize the order of the training data.
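# (Decorate-sort-undecorate: prefix each line with a random key drawn with the given
# seed, sort on that key, then cut the key away.)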
cat $dir/train | awk -v rand_seed=$rand_seed 'BEGIN{srand(rand_seed);} {printf("%f\t%s\n", rand(), $0);}' | \
   sort | cut -f 2 > $dir/foo
  mv $dir/foo $dir/train
  
  # OK we'll train the RNNLM on this data.
  
# The commented-out command below shows the original setup with hidden=100 and 100
# classes (roughly the square root of the 10k-word vocab); a larger setup (hidden=320)
# is shown, also commented out, at the bottom of this script.
  echo "Training RNNLM (note: this uses a lot of memory! Run it on a big machine.)"
  #time rnnlm -train $dir/train -valid $dir/valid -rnnlm $dir/100.rnnlm \
  #  -hidden 100 -rand-seed 1 -debug 2 -class 100 -bptt 2 -bptt-block 20 \
  #  -direct-order 4 -direct 1000 -binary >& $dir/rnnlm1.log &
  
  # since the mikolov rnnlm and faster-rnnlm have slightly different interfaces...
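# (As used here: faster-rnnlm takes -threads and -seed, while the Mikolov rnnlm takes
# -rand-seed, -class and -binary and is run with -independent; see the two branches below.)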
  if [ "$rnnlm_ver" == "faster-rnnlm" ]; then
    $cmd $dir/rnnlm.log \
       $KALDI_ROOT/tools/$rnnlm_ver/rnnlm -threads $threads -train $dir/train -valid $dir/valid \
       -rnnlm $dir/rnnlm -hidden $hidden -seed 1 -bptt $bptt -bptt-block $bptt_block \
       $rnnlm_options -direct $direct || exit 1;
  else
    $cmd $dir/rnnlm.log \
       $KALDI_ROOT/tools/$rnnlm_ver/rnnlm -independent -train $dir/train -valid $dir/valid \
       -rnnlm $dir/rnnlm -hidden $hidden -rand-seed 1 -debug 2 -class $class -bptt $bptt -bptt-block $bptt_block \
       $rnnlm_options -direct $direct -binary || exit 1;
  fi
  
  # make it like a Kaldi table format, with fake utterance-ids.
  cat $dir/valid.in | awk '{ printf("uttid-%d ", NR); print; }' > $dir/valid.with_ids
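# (valid.with_ids now has lines of the form "uttid-<N> <original sentence>", where <N>
# is the 1-based line number.)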
  
  utils/rnnlm_compute_scores.sh --ensure_normalized_probs true --rnnlm_ver $rnnlm_ver $dir $dir/tmp.valid $dir/valid.with_ids \
    $dir/valid.scores
  
nw=`cat $dir/valid.with_ids | awk '{a+=NF}END{print a}'` # Note: valid.with_ids includes
  # the utterance-ids, which add one extra token per sentence; this accounts for the </s>
  # at the end of each sentence, so it is the correct count to normalize by.
  # We have noticed that "wc -w" may give wrong results for certain languages.
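# Each line of valid.scores is assumed to hold "<uttid> <cost>", where the cost is a
# negated natural-log probability for that sentence, so exp(total cost / total token
# count) below is the validation-set perplexity.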
  
  p=`awk -v nw=$nw '{x=x+$2} END{print exp(x/nw);}' <$dir/valid.scores`
  echo Perplexity is $p | tee $dir/perplexity.log
  
  rm $dir/train $dir/all.gz
  
  # This is a better setup, but takes a long time to train:
  #echo "Training RNNLM (note: this uses a lot of memory! Run it on a big machine.)"
  #time rnnlm -train $dir/train -valid $dir/valid -rnnlm $dir/320.rnnlm \
  #  -hidden 320 -rand-seed 1 -debug 2 -class 300 -bptt 2 -bptt-block 20 \
  #  -direct-order 4 -direct 2000 -binary