Blame view

LDA/00-mmf_make_features.py 1.4 KB
7db73861f   Killian   add vae et mmf
1
2
3
4
5
6
7
8
9
10
11
12
13
  import sys 
  import os 
  
  import pandas 
  import numpy 
  import shelve
  
  from sklearn.preprocessing import LabelBinarizer
  
  from utils import select_mmf as select
  
  # Build the MMF feature shelve from per-modality LDA feature tables.
  #
  # Usage: python 00-mmf_make_features.py INPUT_DIR LEVEL OUTPUT_DIR
  #   INPUT_DIR  -- top-level directory containing the ASR/ and TRS/ sub-directories,
  #                 each holding train_<LEVEL>.tab, dev_<LEVEL>.tab and test_<LEVEL>.tab
  #   LEVEL      -- desired LDA size tag, used in the input and output file names
  #                 (the original note said "( -5)"; its meaning is unclear -- TODO confirm)
  #   OUTPUT_DIR -- directory where mmf_<LEVEL>.shelve is written
  #
  # The shelve receives two keys, for mod in ("ASR", "TRS"):
  #   "LABEL" -> {mod: {"TRAIN"|"DEV"|"TEST": binarized label matrix}}
  #   "LDA"   -> {mod: {"TRAIN"|"DEV"|"TEST": feature matrix}}
  if len(sys.argv) < 4:
      sys.exit("usage: {} INPUT_DIR LEVEL OUTPUT_DIR".format(sys.argv[0]))
  input_dir = sys.argv[1]
  level = sys.argv[2]
  output_dir = sys.argv[3]

  MODALITIES = ("ASR", "TRS")
  SPLITS = ("train", "dev", "test")


  def _read_split(input_dir, mod, split, level):
      """Read <input_dir>/<mod>/<split>_<level>.tab as a headerless, space-separated table."""
      path = "{}/{}/{}_{}.tab".format(input_dir, mod, split, level)
      return pandas.read_table(path, sep=" ", header=None)


  lb = LabelBinarizer()
  labels = {}
  lda = {}
  for mod in MODALITIES:
      frames = {split: _read_split(input_dir, mod, split, level) for split in SPLITS}
      # Column 0 is mapped to a class label by utils.select_mmf.
      targets = {split: frames[split].iloc[:, 0].apply(select) for split in SPLITS}
      # The binarizer is refit on each modality's own train labels; dev/test reuse that fit.
      lb.fit(targets["train"])
      labels[mod] = {split.upper(): lb.transform(targets[split]) for split in SPLITS}
      # Features: drop column 0 (label source) and the last column -- presumably the
      # empty field left by a trailing separator; TODO confirm against the .tab files.
      lda[mod] = {split.upper(): frames[split].iloc[:, 1:-1].values for split in SPLITS}
      print(lda[mod]["TRAIN"].shape)

  # All inputs are loaded before the shelve is touched, so a missing or malformed input
  # file no longer leaves the output shelve holding empty placeholders. try/finally
  # guarantees the shelve is closed (and therefore flushed) even if a write fails.
  # Whole top-level values are assigned, so writeback=True is no longer needed.
  data = shelve.open("{}/mmf_{}.shelve".format(output_dir, level))
  try:
      data["LABEL"] = labels
      data["LDA"] = lda
  finally:
      data.close()