Blame view

egs/cifar/v1/image/get_image2num_frames.py 2.04 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  #!/usr/bin/env python3
  
  # Copyright      2018  Hossein Hadian
  
  
  """ This script computes the image lengths (with padding) in an image data dir.
      The output is written to 'image2num_frames' in the given data dir. This
      file is later used by image/get_allowed_lengths.py to find a set of allowed lengths
      for the data dir. The output format is similar to utt2num_frames
  
  """
  
  import argparse
  import os
  import sys
  import numpy as np
  from PIL import Image
  
  # Command-line interface. Arguments are parsed at module level, so the rest
  # of the script reads its settings from the module-global `args`.
  parser = argparse.ArgumentParser(description="""Computes the image lengths (i.e. width) in an image data dir
                                                  and writes them (by default) to image2num_frames.""")
  parser.add_argument('dir', type=str,
                      help='Source data directory (containing images.scp)')
  # An output of '-' is treated as standard output by the main section.
  parser.add_argument('--out-ark', type=str, default=None,
                      help='Where to write the output image-to-num_frames info. '
                      'Default: "dir"/image2num_frames')
  parser.add_argument('--feat-dim', type=int, default=40,
                      help='Size to scale the height of all images')
  # The two help fragments are concatenated by the parser, so the first one
  # must end with a space; otherwise --help prints "leftand right".
  parser.add_argument('--padding', type=int, default=5,
                      help='Number of white pixels to pad on the left '
                      'and right side of the image.')
  args = parser.parse_args()
  
  
  def get_scaled_image_length(im, feat_dim=None):
      """Return the width of `im` after scaling its height to `feat_dim`.

      The aspect ratio is preserved and the scaled width is truncated (not
      rounded) to an integer. No padding is added here.

      Args:
          im: Object exposing a `size` attribute that unpacks to
              (width, height), such as a PIL image.
          feat_dim: Target height in pixels. When None (the default), the
              value of the --feat-dim command-line option is used, so
              existing one-argument calls behave exactly as before.

      Returns:
          The scaled image width as an int.
      """
      if feat_dim is None:
          # Fall back to the command-line setting parsed at module level.
          feat_dim = args.feat_dim
      width, height = im.size
      # Multiply by 1.0 to force float division regardless of operand types.
      scale = (1.0 * feat_dim) / height
      return int(scale * width)
  
  ### main ###
  # Reads "<image-id> <image-path>" lines from <dir>/images.scp and writes one
  # "<image-id> <num-frames>" line per image, where num-frames is the scaled
  # image width plus the padding on both sides.
  data_list_path = os.path.join(args.dir, 'images.scp')

  if not args.out_ark:
      args.out_ark = os.path.join(args.dir, 'image2num_frames')
  # '-' selects standard output instead of a file.
  write_to_stdout = args.out_ark == '-'
  if write_to_stdout:
      out_fh = sys.stdout
  else:
      out_fh = open(args.out_ark, 'w', encoding='latin-1')

  try:
      # NOTE(review): the list is read with the default text encoding while
      # the output file is latin-1, so ids with characters outside latin-1
      # would fail to encode on output -- confirm this is intended.
      with open(data_list_path) as f:
          for line_no, line in enumerate(f, 1):
              line = line.strip()
              if not line:
                  # Tolerate blank lines instead of crashing on them.
                  continue
              # The first whitespace-separated token is the id and the rest
              # of the line is the path, so tabs, repeated spaces and paths
              # containing spaces are all handled.
              fields = line.split(None, 1)
              if len(fields) != 2:
                  raise ValueError(
                      '{}: line {}: expected "<image-id> <image-path>", '
                      'got: {!r}'.format(data_list_path, line_no, line))
              image_id, image_path = fields
              # Only the image size is needed; the context manager releases
              # the underlying file handle right away.
              with Image.open(image_path) as im:
                  im_len = get_scaled_image_length(im) + (args.padding * 2)
              print('{} {}'.format(image_id, im_len), file=out_fh)
  finally:
      # Close the file we opened even when an image fails to load, but never
      # close the interpreter's stdout.
      if not write_to_stdout:
          out_fh.close()