// nnet3/nnet-graph.cc // Copyright 2015 Johns Hopkins University (author: Daniel Povey) // 2015 Guoguo Chen // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include #include #include "nnet3/nnet-graph.h" namespace kaldi { namespace nnet3 { void NnetToDirectedGraph(const Nnet &nnet, std::vector > *graph) { graph->clear(); int32 num_nodes = nnet.NumNodes(); graph->resize(num_nodes); for (int32 n = 0; n < num_nodes; n++) { const NetworkNode &node = nnet.GetNode(n); // handle dependencies of this node. std::vector node_dependencies; switch (node.node_type) { case kInput: break; // no node dependencies. case kDescriptor: node.descriptor.GetNodeDependencies(&node_dependencies); break; case kComponent: node_dependencies.push_back(n - 1); break; case kDimRange: node_dependencies.push_back(node.u.node_index); break; default: KALDI_ERR << "Invalid node type"; } SortAndUniq(&node_dependencies); for (size_t i = 0; i < node_dependencies.size(); i++) { int32 dep_n = node_dependencies[i]; KALDI_ASSERT(dep_n >= 0 && dep_n < num_nodes); (*graph)[dep_n].push_back(n); } } } void ComputeGraphTranspose(const std::vector > &graph, std::vector > *graph_transpose) { int32 size = graph.size(); graph_transpose->clear(); graph_transpose->resize(size); for (int32 n = 0; n < size; n++) { const std::vector &nodes = graph[n]; std::vector::const_iterator iter = nodes.begin(), end = nodes.end(); for (; iter != end; ++iter) { int32 dest = *iter; (*graph_transpose)[dest].push_back(n); } } } struct TarjanNode { int32 index; int32 lowlink; bool on_stack; TarjanNode() : index(-1), lowlink(-1), on_stack(false) {} }; void TarjanSccRecursive(int32 node, const std::vector > &graph, int32 *global_index, std::vector *tarjan_nodes, std::vector *tarjan_stack, std::vector > *sccs) { KALDI_ASSERT(sccs != NULL); KALDI_ASSERT(tarjan_nodes != NULL); KALDI_ASSERT(tarjan_stack != NULL); KALDI_ASSERT(global_index != NULL); KALDI_ASSERT(node >= 0 && node < graph.size()); // Initializes the current Tarjan node. (*tarjan_nodes)[node].index = *global_index; (*tarjan_nodes)[node].lowlink = *global_index; *global_index += 1; (*tarjan_nodes)[node].on_stack = true; tarjan_stack->push_back(node); // DFS from the current node. for (int32 i = 0; i < graph[node].size(); ++i) { int32 next = graph[node][i]; if ((*tarjan_nodes)[next].index == -1) { // First time we see this node. TarjanSccRecursive(next, graph, global_index, tarjan_nodes, tarjan_stack, sccs); (*tarjan_nodes)[node].lowlink = std::min((*tarjan_nodes)[node].lowlink, (*tarjan_nodes)[next].lowlink); } else if ((*tarjan_nodes)[next].on_stack) { // Next node is on the stack -- back edge. We can't use the lowlink of // next node, because that may point to the index of the root, while the // current node can't be the root. (*tarjan_nodes)[node].lowlink = std::min((*tarjan_nodes)[node].lowlink, (*tarjan_nodes)[next].index); } } // Output SCC. if ((*tarjan_nodes)[node].index == (*tarjan_nodes)[node].lowlink) { std::vector scc; int32 pop_node; do { pop_node = tarjan_stack->back(); tarjan_stack->pop_back(); (*tarjan_nodes)[pop_node].on_stack = false; scc.push_back(pop_node); } while (pop_node != node); KALDI_ASSERT(pop_node == node); sccs->push_back(scc); } } void FindSccsTarjan(const std::vector > &graph, std::vector > *sccs) { KALDI_ASSERT(sccs != NULL); // Initialization. std::vector tarjan_nodes(graph.size()); std::vector tarjan_stack; int32 global_index = 0; // Calls the recursive function. for (int32 n = 0; n < graph.size(); ++n) { if (tarjan_nodes[n].index == -1) { TarjanSccRecursive(n, graph, &global_index, &tarjan_nodes, &tarjan_stack, sccs); } } } void FindSccs(const std::vector > &graph, std::vector > *sccs) { // Internally we call Tarjan's SCC algorithm, as it only requires one DFS. We // can change this to other methods later on if necessary. KALDI_ASSERT(sccs != NULL); FindSccsTarjan(graph, sccs); } void MakeSccGraph(const std::vector > &graph, const std::vector > &sccs, std::vector > *scc_graph) { KALDI_ASSERT(scc_graph != NULL); scc_graph->clear(); scc_graph->resize(sccs.size()); // Hash map from node to SCC index. std::vector node_to_scc_index(graph.size()); for (int32 i = 0; i < sccs.size(); ++i) { for (int32 j = 0; j < sccs[i].size(); ++j) { KALDI_ASSERT(sccs[i][j] >= 0 && sccs[i][j] < graph.size()); node_to_scc_index[sccs[i][j]] = i; } } // Builds graph. for (int32 i = 0; i < sccs.size(); ++i) { for (int32 j = 0; j < sccs[i].size(); ++j) { int32 node = sccs[i][j]; KALDI_ASSERT(node >= 0 && node < graph.size()); for (int32 k = 0; k < graph[node].size(); ++k) { if (node_to_scc_index[graph[node][k]] != i) { // Exclucding self. (*scc_graph)[i].push_back(node_to_scc_index[graph[node][k]]); } } } // If necessary, we can use a hash maps to avoid this sorting. SortAndUniq(&((*scc_graph)[i])); } } void ComputeTopSortOrderRecursive(int32 node, const std::vector > &graph, std::vector *cycle_detector, std::vector *is_visited, std::vector *reversed_orders) { KALDI_ASSERT(node >= 0 && node < graph.size()); KALDI_ASSERT(cycle_detector != NULL); KALDI_ASSERT(is_visited != NULL); KALDI_ASSERT(reversed_orders != NULL); if ((*cycle_detector)[node]) { KALDI_ERR << "Cycle detected when computing the topological sorting order"; } if (!(*is_visited)[node]) { (*cycle_detector)[node] = true; for (int32 i = 0; i < graph[node].size(); ++i) { ComputeTopSortOrderRecursive(graph[node][i], graph, cycle_detector, is_visited, reversed_orders); } (*cycle_detector)[node] = false; (*is_visited)[node] = true; // At this point we have added all the children to , so we // can add the current now. reversed_orders->push_back(node); } } void ComputeTopSortOrder(const std::vector > &graph, std::vector *node_to_order) { // Internally we use DFS, but we only put the node to when all // its parents have been visited. KALDI_ASSERT(node_to_order != NULL); node_to_order->resize(graph.size()); std::vector cycle_detector(graph.size(), false); std::vector is_visited(graph.size(), false); std::vector reversed_orders; for(int32 i = 0; i < graph.size(); ++i) { if (!is_visited[i]) { ComputeTopSortOrderRecursive(i, graph, &cycle_detector, &is_visited, &reversed_orders); } } KALDI_ASSERT(node_to_order->size() == reversed_orders.size()); for (int32 i = 0; i < reversed_orders.size(); ++i) { KALDI_ASSERT(reversed_orders[i] >= 0 && reversed_orders[i] < graph.size()); (*node_to_order)[reversed_orders[i]] = graph.size() - i - 1; } } std::string PrintGraphToString(const std::vector > &graph) { std::ostringstream os; int32 num_nodes = graph.size(); for (int32 i = 0; i < num_nodes; i++) { os << i << " -> ("; const std::vector &vec = graph[i]; int32 size = vec.size(); for (int32 j = 0; j < size; j++) { os << vec[j]; if (j + 1 < size) os << ","; } os << ")"; if (i + 1 < num_nodes) os << "; "; } return os.str(); } void ComputeNnetComputationEpochs(const Nnet &nnet, std::vector *node_to_epoch) { KALDI_ASSERT(node_to_epoch != NULL); std::vector > graph; NnetToDirectedGraph(nnet, &graph); KALDI_VLOG(6) << "graph is: " << PrintGraphToString(graph); std::vector > sccs; FindSccs(graph, &sccs); std::vector > scc_graph; MakeSccGraph(graph, sccs, &scc_graph); KALDI_VLOG(6) << "scc graph is: " << PrintGraphToString(scc_graph); std::vector scc_node_to_epoch; ComputeTopSortOrder(scc_graph, &scc_node_to_epoch); if (GetVerboseLevel() >= 6) { std::ostringstream os; for (int32 i = 0; i < scc_node_to_epoch.size(); i++) os << scc_node_to_epoch[i] << ", "; KALDI_VLOG(6) << "scc_node_to_epoch is: " << os.str(); } node_to_epoch->clear(); node_to_epoch->resize(graph.size()); for (int32 i = 0; i < sccs.size(); ++i) { for (int32 j = 0; j < sccs[i].size(); ++j) { int32 node = sccs[i][j]; KALDI_ASSERT(node >= 0 && node < graph.size()); (*node_to_epoch)[node] = scc_node_to_epoch[i]; } } } bool GraphHasCycles(const std::vector > &graph) { std::vector > sccs; FindSccs(graph, &sccs); for (size_t i = 0; i < sccs.size(); i++) { if (sccs[i].size() > 1) return true; } // the next code checks for links from a state to itself. int32 num_nodes = graph.size(); for (size_t i = 0; i < num_nodes; i++) for (std::vector::const_iterator iter = graph[i].begin(), end = graph[i].end(); iter != end; ++iter) if (*iter == i) return true; return false; } } // namespace nnet3 } // namespace kaldi