Blame view
src/nnet3/nnet-nnet.h
14.3 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 |
// nnet3/nnet-nnet.h // Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) // 2016 Daniel Galvez // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_NNET3_NNET_NNET_H_ #define KALDI_NNET3_NNET_NNET_H_ #include "base/kaldi-common.h" #include "util/kaldi-io.h" #include "matrix/matrix-lib.h" #include "nnet3/nnet-common.h" #include "nnet3/nnet-component-itf.h" #include "nnet3/nnet-descriptor.h" #include <iostream> #include <sstream> #include <vector> #include <map> namespace kaldi { namespace nnet3 { /// This enum is for a kind of annotation we associate with output nodes of the /// network; it's for the convenience of calling code so that if the objective /// is one of a few standard types, we can compute it directly and know how to /// interpret the supervision labels. However, the core of the framework never /// makes use of the objective types, other than making them available to /// calling code which then supplies the derivatives. /// - Objective type kLinear is intended for Neural nets where the final /// component is a LogSoftmaxComponent, so the log-prob (negative /// cross-entropy) objective is just a linear function of the input. /// - Objective type kQuadratic is used to mean the objective function /// f(x, y) = -0.5 (x-y).(x-y), which is to be maximized, as in the kLinear /// case. enum ObjectiveType { kLinear, kQuadratic }; enum NodeType { kInput, kDescriptor, kComponent, kDimRange, kNone }; /// NetworkNode is used to represent, three types of thing: either an input of the /// network (which pretty much just states the dimension of the input vector); /// a Component (e.g. an affine component or a sigmoid component); or a Descriptor. /// A Descriptor is basically an expression that can do things like append /// the outputs of other components (or inputs) together, add them together, and /// do various other things like shifting the time index. /// /// Each Component must have an input of type kDescriptor that is numbered /// Preceding to the Component, and that is not used elsewhere. This may seem /// unintuitive but it makes the implementation a lot easier; any apparent waste /// can be optimized out after compilation. And outputs must also be of type /// kDescriptor. /// /// Note: in the actual computation you can provide input not only to nodes of /// type kInput but also to nodes of type kComponent; this is useful in things /// like recurrent nets where you may want to split the computation up into /// pieces. /// /// Note that in the config-file format, there are three types of node: input, /// component and output. output maps to kDescriptor, but the nodes of type /// kDescriptor that represent the input to a component, are described in the /// same config-file line as the Component itself. struct NetworkNode { NodeType node_type; // "descriptor" is relevant only for nodes of type kDescriptor. Descriptor descriptor; union { // For kComponent, the index into Nnet::components_ int32 component_index; // for kDimRange, the node-index of the input node, which must be of // type kComponent or kInput. int32 node_index; // for nodes of type kDescriptor that are output nodes (i.e. not followed by // a node of type kComponents), the objective function associated with the // output. The core parts of the nnet code just ignore; it is required only // for the information of the calling code, which is perfectly free to // ignore it. View it as a kind of annotation. ObjectiveType objective_type; } u; // for kInput, the dimension of the input feature. For kDimRange, the dimension // of the output (i.e. the length of the range) int32 dim; // for kDimRange, the dimension of the offset into the input component's feature. int32 dim_offset; int32 Dim(const Nnet &nnet) const; // Dimension that this node outputs. NetworkNode(NodeType nt = kNone): node_type(nt), dim(-1), dim_offset(-1) { u.component_index = -1; } NetworkNode(const NetworkNode &other); // copy constructor. // use default assignment operator }; class Nnet { public: // This function can be used either to initialize a new Nnet from a config // file, or to add to an existing Nnet, possibly replacing certain parts of // it. It will die with error if something went wrong. // Also see the function ReadEditConfig() in nnet-utils.h (it's made a // non-member because it doesn't need special access). void ReadConfig(std::istream &config_file); int32 NumComponents() const { return components_.size(); } int32 NumNodes() const { return nodes_.size(); } /// Return component indexed c. Not a copy; not owned by caller. Component *GetComponent(int32 c); /// Return component indexed c (const version). Not a copy; not owned by /// caller. const Component *GetComponent(int32 c) const; /// Replace the component indexed by c with a new component. /// Frees previous component indexed by c. Takes ownership of /// the pointer 'component'. void SetComponent(int32 c, Component *component); /// Adds a new component with the given name, which should not be the same as /// any existing component name. Returns the new component index. Takes /// ownership of the pointer 'component'. int32 AddComponent(const std::string &name, Component *component); /// returns const reference to a particular numbered network node. const NetworkNode &GetNode(int32 node) const { KALDI_ASSERT(node >= 0 && node < nodes_.size()); return nodes_[node]; } /// Non-const accessor for the node... use with extreme caution. NetworkNode &GetNode(int32 node) { KALDI_ASSERT(node >= 0 && node < nodes_.size()); return nodes_[node]; } /// Returns true if this is a component node, meaning that it is of type /// kComponent. bool IsComponentNode(int32 node) const; /// Returns true if this is a dim-range node, meaning that it is of type /// kDimRange. bool IsDimRangeNode(int32 node) const; /// Returns true if this is an output node, meaning that it is of type /// kInput. bool IsInputNode(int32 node) const; /// Returns true if this is a descriptor node, meaning that it is of type /// kDescriptor. Exactly one of IsOutput or IsComponentInput will also /// apply. bool IsDescriptorNode(int32 node) const; /// Returns true if this is an output node, meaning that it is of type kDescriptor /// and is not directly followed by a node of type kComponent. bool IsOutputNode(int32 node) const; /// Returns true if this is component-input node, i.e. a node of type kDescriptor /// that immediately precedes a node of type kComponent. bool IsComponentInputNode(int32 node) const; /// returns vector of node names (needed by some parsing code, for instance). const std::vector<std::string> &GetNodeNames() const; /// returns individual node name. const std::string &GetNodeName(int32 node_index) const; /// This can be used to modify invidual node names. Note, this does not /// affect the neural net structure at all, it just assigns a new /// name to an existing node while leaving all connections identical. void SetNodeName(int32 node_index, const std::string &new_name); /// returns vector of component names (needed by some parsing code, for instance). const std::vector<std::string> &GetComponentNames() const; /// returns individual component name. const std::string &GetComponentName(int32 component_index) const; /// returns index associated with this node name, or -1 if no such index. int32 GetNodeIndex(const std::string &node_name) const; /// returns index associated with this component name, or -1 if no such index. int32 GetComponentIndex(const std::string &node_name) const; // This convenience function returns the dimension of the input with name // "input_name" (e.g. input_name="input" or "ivector"), or -1 if there is no // such input. int32 InputDim(const std::string &input_name) const; // This convenience function returns the dimension of the output with // name "input_name" (e.g. output_name="input"), or -1 if there is // no such input. int32 OutputDim(const std::string &output_name) const; void Read(std::istream &istream, bool binary); void Write(std::ostream &ostream, bool binary) const; /// Checks the neural network for validity (dimension matches and various /// other requirements). /// You can call this with warn_for_orphans = false to disable the warnings /// that are printed if orphan nodes or components exist. void Check(bool warn_for_orphans = true) const; /// returns some human-readable information about the network, mostly for /// debugging purposes. /// Also see function NnetInfo() in nnet-utils.h, which prints out more /// extensive infoformation. std::string Info() const; /// [Relevant for clockwork RNNs and similar]. Computes the smallest integer /// n >=1 such that the neural net's behavior will be the same if we shift the /// input and output's time indexes (t) by integer multiples of n. Does this /// by computing the lcm of all the moduli of the Descriptors in the network. int32 Modulus() const; ~Nnet() { Destroy(); } // Default constructor Nnet() { } // Copy constructor Nnet(const Nnet &nnet); Nnet *Copy() const { return new Nnet(*this); } void Swap(Nnet *other); // Assignment operator Nnet& operator =(const Nnet &nnet); // Removes nodes that are never needed to compute any output. void RemoveOrphanNodes(bool remove_orphan_inputs = false); // Removes components that are not used by any node. void RemoveOrphanComponents(); // Removes some nodes. This is not to be called without a lot of thought, // as it could ruin the graph structure if done carelessly. void RemoveSomeNodes(const std::vector<int32> &nodes_to_remove); void ResetGenerators(); // resets random-number generators for all // random components. You must call srand() prior to this call, for this to // be effective. // This function outputs to "config_lines" the lines of a config file. If you // provide include_dim=false, this will enable you to reconstruct the nodes in // the network (but not the components, which need to be written separately). // If you provide include_dim=true, it also adds extra information about // node dimensions which is useful for a human reader but won't be // accepted as the config-file format. void GetConfigLines(bool include_dim, std::vector<std::string> *config_lines) const; private: void Destroy(); // This function returns as a string the contents of a line of a config-file // corresponding to the node indexed "node_index", which must not be of type // kComponentInput. If include_dim=false, it appears in the same format as it // would appear in a line of a config-file; if include_dim=true, we also // include dimension information that would not be provided in a config file. std::string GetAsConfigLine(int32 node_index, bool include_dim) const; // This function is used when reading config files; it exists in order to // handle replacement of existing nodes. The two input vectors have the same // size. Its job is to remove redundant lines that do not have "component" as // first_token, and where two lines have a configuration value name=xxx in the // config with the same name. In this case it removes the first of the two, // but that first one must have index less than num_lines_initial, else it is // an error. // This function also checks that all lines have a config name=xxx, that // IsValidName(xxx) is true, and that there are no two lines with "component" // as the first token and with the same config name=xxx. Note: here, "name" // means literally "name", but "xxx" stands in for the actual name, // e.g. "my-funky-component." static void RemoveRedundantConfigLines(int32 num_lines_initial, std::vector<ConfigLine> *config_lines); void ProcessComponentConfigLine(int32 initial_num_components, ConfigLine *config); void ProcessComponentNodeConfigLine(int32 pass, ConfigLine *config); void ProcessInputNodeConfigLine(ConfigLine *config); void ProcessOutputNodeConfigLine(int32 pass, ConfigLine *config); void ProcessDimRangeNodeConfigLine(int32 pass, ConfigLine *config); // This function output to "modified_node_names" a modified copy of // node_names_, in which all nodes which are not of type kComponent, kInput or // kDimRange are replaced with the string "***". This is useful when parsing // Descriptors, to avoid inadvertently accepting nodes of invalid types where // they are not allowed. void GetSomeNodeNames(std::vector<std::string> *modified_node_names) const; // the names of the components of the network. Note, these may be distinct // from the network node names below (and live in a different namespace); the // same component may be used in multiple network nodes, to define parameter // sharing. std::vector<std::string> component_names_; // the components of the nnet, in arbitrary order. The network topology is // defined separately, below; a given Component may appear more than once in // the network if necessary for parameter tying. std::vector<Component*> components_; // names of network nodes, i.e. inputs, components and outputs, used only in // reading and writing code. Indexed by network-node index. Note, // components' names are always listed twice, once as foo-input and once as // foo, because the input to a component always gets its own NetworkNode index. std::vector<std::string> node_names_; // the network nodes of the network. std::vector<NetworkNode> nodes_; }; } // namespace nnet3 } // namespace kaldi #endif |