Compare View
Commits (3)
Changes
Showing 1 changed file Inline Diff
volia/masseffect.py
1 | import argparse | 1 | import argparse |
2 | from os import path | 2 | from os import path |
3 | import core.data | 3 | import core.data |
4 | from utils import SubCommandRunner | 4 | from utils import SubCommandRunner |
5 | import os | 5 | import os |
6 | 6 | ||
7 | 7 | ||
8 | def utt2char(features: str, outfile: str): | 8 | def utt2char(features: str, outfile: str): |
9 | """Allow the user to generate utt2char file from masseffect features file. | 9 | """Allow the user to generate utt2char file from masseffect features file. |
10 | 10 | ||
11 | TODO: Don't forget to manage two cases: one with old ids, and an other with | 11 | TODO: Don't forget to manage two cases: one with old ids, and an other with |
12 | new ones. | 12 | new ones. |
13 | 13 | ||
14 | Args: | 14 | Args: |
15 | features (str): [description] | 15 | features (str): [description] |
16 | outfile (str): [description] | 16 | outfile (str): [description] |
17 | """ | 17 | """ |
18 | data = core.data.read_features(features) | 18 | data = core.data.read_features(features) |
19 | keys = list(data.keys()) | 19 | keys = list(data.keys()) |
20 | 20 | ||
21 | with open(outfile, "w") as f: | 21 | with open(outfile, "w") as f: |
22 | for key in keys: | 22 | for key in keys: |
23 | splited = key.replace("\n", "").split(",") | 23 | splited = key.replace("\n", "").split(",") |
24 | character = splited[1] | 24 | character = splited[1] |
25 | f.write(",".join(splited) + " " + character + "\n") | 25 | f.write(",".join(splited) + " " + character + "\n") |
26 | 26 | ||
27 | 27 | ||
28 | def char2utt(features: str, outfile: str): | 28 | def char2utt(features: str, outfile: str): |
29 | raise Exception("Not implemented yet") | 29 | raise Exception("Not implemented yet") |
30 | pass | 30 | pass |
31 | 31 | ||
32 | 32 | ||
33 | def wavscp(datadir: str, outfile: str): | 33 | def wavscp(datadir: str, outfile: str): |
34 | """Generate the masseffect wav scp file from the directories. | 34 | """Generate the masseffect wav scp file from the directories. |
35 | 35 | ||
36 | Args: | 36 | Args: |
37 | datadir (str): oath of the data directory where "audio_en-us" and "audio_fr-fr" are available | 37 | datadir (str): oath of the data directory where "audio_en-us" and "audio_fr-fr" are available |
38 | outfile (str): path of the wav scp output file | 38 | outfile (str): path of the wav scp output file |
39 | 39 | ||
40 | Raises: | 40 | Raises: |
41 | Exception: if one of the directory is not available | 41 | Exception: if one of the directory is not available |
42 | """ | 42 | """ |
43 | en_us_dir = os.path.join(datadir, "audio_en-us") | 43 | en_us_dir = os.path.join(datadir, "audio_en-us") |
44 | fr_fr_dir = os.path.join(datadir, "audio_fr-fr") | 44 | fr_fr_dir = os.path.join(datadir, "audio_fr-fr") |
45 | 45 | ||
46 | if (not os.path.isdir(en_us_dir)) or (not os.path.isdir(fr_fr_dir)): | 46 | if (not os.path.isdir(en_us_dir)) or (not os.path.isdir(fr_fr_dir)): |
47 | raise Exception("Directory audio_en-us or audio_fr-fr does not exist") | 47 | raise Exception("Directory audio_en-us or audio_fr-fr does not exist") |
48 | 48 | ||
49 | _,_,filenames_en=next(os.walk(en_us_dir)) | 49 | _,_,filenames_en=next(os.walk(en_us_dir)) |
50 | # filenames_en = [ os.path.join(en_us_dir, f) for f in filenames_en ] | 50 | # filenames_en = [ os.path.join(en_us_dir, f) for f in filenames_en ] |
51 | dir_en = [ en_us_dir for f in filenames_en ] | 51 | dir_en = [ en_us_dir for f in filenames_en ] |
52 | _,_,filenames_fr=next(os.walk(fr_fr_dir)) | 52 | _,_,filenames_fr=next(os.walk(fr_fr_dir)) |
53 | dir_fr = [ fr_fr_dir for f in filenames_fr ] | 53 | dir_fr = [ fr_fr_dir for f in filenames_fr ] |
54 | # filenames_fr = [ os.path.join(fr_fr_dir, f) for f in filenames_fr ] | 54 | # filenames_fr = [ os.path.join(fr_fr_dir, f) for f in filenames_fr ] |
55 | 55 | ||
56 | directories = dir_en + dir_fr | 56 | directories = dir_en + dir_fr |
57 | filenames = filenames_en + filenames_fr | 57 | filenames = filenames_en + filenames_fr |
58 | 58 | ||
59 | with open(outfile, "w") as f: | 59 | with open(outfile, "w") as f: |
60 | for i, fn in enumerate(filenames): | 60 | for i, fn in enumerate(filenames): |
61 | splited = fn.split(".")[0].split(",") | 61 | splited = fn.split(".")[0].split(",") |
62 | lang = splited[0] | 62 | lang = splited[0] |
63 | character = splited[1] | 63 | character = splited[1] |
64 | record_id = splited[3] | 64 | record_id = splited[3] |
65 | path = os.path.join(directories[i], fn) | 65 | path = os.path.join(directories[i], fn) |
66 | f.write(f"{lang},{character},{record_id} {path}\n") | 66 | f.write(f"{lang},{character},{record_id} {path}\n") |
67 | 67 | ||
68 | 68 | ||
69 | def changelabels(source: str, labels: str, outfile: str): | 69 | def changelabels(source: str, labels: str, outfile: str): |
70 | data_dict = core.data.read_id_values(source) | 70 | data_dict = core.data.read_id_values(source) |
71 | labels_dict = core.data.read_labels(labels) | 71 | labels_dict = core.data.read_labels(labels) |
72 | keys = list(data_dict.keys()) | 72 | keys = list(data_dict.keys()) |
73 | 73 | ||
74 | with open(outfile, "w") as f: | 74 | with open(outfile, "w") as f: |
75 | for key in keys: | 75 | for key in keys: |
76 | splited = key.split(",") | 76 | splited = key.split(",") |
77 | splited[1] = labels_dict[key][0] | 77 | splited[1] = labels_dict[key][0] |
78 | core.data.write_line(",".join(splited), data_dict[key], out=f) | 78 | core.data.write_line(",".join(splited), data_dict[key], out=f) |
79 | 79 | ||
80 | 80 | ||
81 | def converter(file: str, outtype: str, outfile: str): | ||
82 | data = core.data.read_id_values(file) | ||
83 | |||
84 | with open(outfile, "w") as of: | ||
85 | for key in data: | ||
86 | splited = key.replace("\n", "").split(",") | ||
87 | masseffect_id = key.replace("\n", "") | ||
88 | kaldi_id = ",".join([splited[0], splited[1], splited[3]]) | ||
89 | if outtype == "masseffect2kaldi": | ||
90 | of.write(f"{masseffect_id} {kaldi_id}\n") | ||
91 | elif outtype == "kaldi2masseffect": | ||
92 | of.write(f"{kaldi_id} {masseffect_id}\n") | ||
93 | |||
94 | |||
95 | def utt2sub(self, file: str, subfile: str, outfile: str): | ||
96 | data = core.data.read_id_values(file) | ||
97 | keys = [key for key in data] | ||
98 | |||
99 | data_sub = core.data.read_id_values(subfile) | ||
100 | keys_sub = [key for key in data_sub] | ||
101 | |||
102 | with open(outfile) as of: | ||
103 | for key in keys: | ||
104 | subkeys = [subkey for subkey in keys_sub if subkey.startswith(key)] | ||
105 | subkeys_str = " ".join(subkeys) | ||
106 | of.write(f"{key} {subkeys_str}") | ||
107 | |||
108 | |||
109 | def sub2utt(self, file: str, subfile: str, outfile: str): | ||
110 | data = core.data.read_id_values(file) | ||
111 | keys = [key for key in data] | ||
112 | |||
113 | data_sub = core.data.read_id_values(subfile) | ||
114 | keys_sub = [key for key in data_sub] | ||
115 | |||
116 | with open(outfile) as of: | ||
117 | for key in keys: | ||
118 | subkeys = [subkey for subkey in keys_sub if subkey.startswith(key)] | ||
119 | for subkey in subkeys: | ||
120 | of.write(f"{subkey} {key}") | ||
121 | |||
122 | |||
81 | def converter(file: str, outtype: str, outfile: str): | 123 | if __name__ == '__main__': |
82 | data = core.data.read_id_values(file) | 124 | # Main parser |
83 | 125 | parser = argparse.ArgumentParser(description="...") | |
84 | with open(outfile, "w") as of: | 126 | subparsers = parser.add_subparsers(title="action") |
85 | for key in data: | 127 | |
86 | splited = key.replace("\n", "").split(",") | 128 | # utt2char |
87 | masseffect_id = key.replace("\n", "") | 129 | parser_utt2char = subparsers.add_parser("utt2char", help="generate utt2char file") |
88 | kaldi_id = ",".join([splited[0], splited[1], splited[3]]) | 130 | parser_utt2char.add_argument("--features", type=str, help="features file") |
89 | if outtype == "masseffect2kaldi": | 131 | parser_utt2char.add_argument("--outfile", type=str, help="output file") |
90 | of.write(f"{masseffect_id} {kaldi_id}\n") | 132 | parser_utt2char.set_defaults(which="utt2char") |
91 | elif outtype == "kaldi2masseffect": | 133 | |
92 | of.write(f"{kaldi_id} {masseffect_id}\n") | 134 | # char2utt |
93 | 135 | parser_char2utt = subparsers.add_parser("char2utt", help="generate char2utt file") | |
94 | 136 | parser_char2utt.add_argument("--features", type=str, help="features file") | |
95 | if __name__ == '__main__': | 137 | parser_char2utt.add_argument("--outfile", type=str, help="output file") |
96 | # Main parser | 138 | parser_char2utt.set_defaults(which="char2utt") |
97 | parser = argparse.ArgumentParser(description="...") | 139 | |
98 | subparsers = parser.add_subparsers(title="action") | 140 | # wavscp |
99 | 141 | parser_wavscp = subparsers.add_parser("wavscp", help="generate wav scp file") | |
100 | # utt2char | 142 | parser_wavscp.add_argument("--datadir", required=True, help="data directory of masseffect") |
101 | parser_utt2char = subparsers.add_parser("utt2char", help="generate utt2char file") | 143 | parser_wavscp.add_argument("--outfile", default="wav.scp", help="wav.scp output file") |
102 | parser_utt2char.add_argument("--features", type=str, help="features file") | 144 | parser_wavscp.set_defaults(which="wavscp") |
103 | parser_utt2char.add_argument("--outfile", type=str, help="output file") | 145 | |
104 | parser_utt2char.set_defaults(which="utt2char") | 146 | # Change labels |
105 | 147 | parser_changelabels = subparsers.add_parser("changelabels", help="...") | |
106 | # char2utt | 148 | parser_changelabels.add_argument("--source", required=True, type=str, help="source file where we want to change ids.") |
107 | parser_char2utt = subparsers.add_parser("char2utt", help="generate char2utt file") | 149 | parser_changelabels.add_argument("--labels", required=True, type=str, help="file with labels") |
108 | parser_char2utt.add_argument("--features", type=str, help="features file") | 150 | parser_changelabels.add_argument("--outfile", required=True, type=str, help="Output file") |
109 | parser_char2utt.add_argument("--outfile", type=str, help="output file") | 151 | parser_changelabels.set_defaults(which="changelabels") |
110 | parser_char2utt.set_defaults(which="char2utt") | 152 | |
153 | # Create converter | ||
154 | parser_converter = subparsers.add_parser("converter", help="Create converter file") | ||
155 | parser_converter.add_argument("--file", | ||
156 | type=str, | ||
157 | required=True, | ||
158 | help="File with ids from which create converter.") | ||
159 | parser_converter.add_argument("--outtype", type=str, choices=["kaldi2masseffect", "masseffect2kaldi"]) | ||
160 | parser_converter.add_argument("--outfile", type=str, required=True, help="") | ||
161 | parser_converter.set_defaults(which="converter") | ||
162 | |||
163 | # Create utt2sub | ||
164 | parser_utt2sub = subparsers.add_parser("utt2sub", help="generate utt2sub file") | ||
165 | parser_utt2sub.add_argument("--file", required=True, type=str, help="features, list or labels file with normal ids") | ||
166 | parser_utt2sub.add_argument("--subfile", required=True, type=str, help="features, list or labels file with sub ids") | ||
167 | parser_utt2sub.add_argument("--outfile", required=True, type=str, help="output file") | ||
168 | parser_utt2sub.set_defaults(which="utt2sub") | ||
169 | |||
170 | # Create sub2utt | ||
171 | parser_sub2utt = subparsers.add_parser("sub2utt", help="generate sub2utt file") | ||
172 | parser_sub2utt.add_argument("--file", required=True, type=str, help="features, list or labels file with normal ids") | ||
173 | parser_sub2utt.add_argument("--subfile", required=True, type=str, help="features, list or labels file sub ids") | ||
174 | parser_sub2utt.add_argument("--outfile", required=True, type=str, help="output file") | ||
175 | parser_sub2utt.set_defaults(which="sub2utt") | ||
176 | |||
177 | |||
111 | 178 | # Parse | |
112 | # wavscp | 179 | args = parser.parse_args() |
113 | parser_wavscp = subparsers.add_parser("wavscp", help="generate wav scp file") | 180 | |
114 | parser_wavscp.add_argument("--datadir", required=True, help="data directory of masseffect") | 181 | # Run commands |
115 | parser_wavscp.add_argument("--outfile", default="wav.scp", help="wav.scp output file") | 182 | runner = SubCommandRunner({ |
116 | parser_wavscp.set_defaults(which="wavscp") | 183 | "utt2char" : utt2char, |
117 | 184 | "char2utt": char2utt, | |
118 | # Change labels | 185 | "wavscp": wavscp, |
119 | parser_changelabels = subparsers.add_parser("changelabels", help="...") | 186 | "changelabels": changelabels, |
187 | "converter": converter, | ||
188 | "utt2sub": utt2sub, | ||
189 | "sub2utt": sub2utt | ||
120 | parser_changelabels.add_argument("--source", required=True, type=str, help="source file where we want to change ids.") | 190 | }) |
121 | parser_changelabels.add_argument("--labels", required=True, type=str, help="file with labels") | 191 | |
122 | parser_changelabels.add_argument("--outfile", required=True, type=str, help="Output file") | 192 | runner.run(args.which, args.__dict__, remove="which") |
123 | parser_changelabels.set_defaults(which="changelabels") | 193 |