Blame view
src/util/kaldi-thread.h
11.3 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 |
// util/kaldi-thread.h // Copyright 2012 Johns Hopkins University (Author: Daniel Povey) // Frantisek Skala // 2017 University of Southern California (Author: Dogan Can) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_THREAD_KALDI_THREAD_H_ #define KALDI_THREAD_KALDI_THREAD_H_ 1 #include <thread> #include "itf/options-itf.h" #include "util/kaldi-semaphore.h" // This header provides convenient mechanisms for parallelization. // // The class MultiThreader, and the function RunMultiThreaded provide a // mechanism to run a specified number of jobs in parellel and wait for them // all to finish. They accept objects of some class C that derives from the // base class MultiThreadable. C needs to define the operator () that takes // no arguments. See ExampleClass below. // // The class TaskSequencer addresses a different problem typically encountered // in Kaldi command-line programs that process a sequence of items. The items // to be processed are coming in. They are all of different sizes, e.g. // utterances with different numbers of frames. We would like them to be // processed in parallel to make good use of the threads available but they // must be output in the same order they came in. Here, we again accept objects // of some class C with an operator () that takes no arguments. C may also have // a destructor with side effects (typically some kind of output). // TaskSequencer is responsible for running the jobs in parallel. It has a // function Run() that will accept a new object of class C; this will block // until a thread is free, at which time it will spawn a thread that starts // running the operator () of the object. When threads are finished running, // the objects will be deleted. TaskSequencer guarantees that the destructors // will be called sequentially (not in parallel) and in the same order the // objects were given to the Run() function, so that it is safe for the // destructor to have side effects such as outputting data. // Note: the destructor of TaskSequencer will wait for any remaining jobs that // are still running and will call the destructors. namespace kaldi { extern int32 g_num_threads; // Maximum number of threads (for programs that // use threads, which is not many of them, e.g. the SGMM update program does. // This is 8 by default. You can change this on the command line, where // used, with --num-threads. Programs that think they will use threads // should register it with their ParseOptions, as something like: // po.Register("num-threads", &g_num_threads, "Number of threads to use."); class MultiThreadable { // To create a function object that does part of the job, inherit from this // class, implement a copy constructor calling the default copy constructor // of this base class (so that thread_id_ and num_threads_ are copied to new // instances), and finally implement the operator() that does part of the job // based on thread_id_ and num_threads_ variables. // Note: example implementations are in util/kaldi-thread-test.cc public: virtual void operator() () = 0; // Does the main function of the class // Subclasses have to redefine this virtual ~MultiThreadable(); // Optional destructor. Note: the destructor of the object passed by the user // will also be called, so watch out. public: // Do not redeclare thread_id_ and num_threads_ in derived classes. int32 thread_id_; // 0 <= thread_id_ < num_threads_ int32 num_threads_; private: // Have additional member variables as needed. }; class ExampleClass: public MultiThreadable { public: ExampleClass(int32 *foo); // Typically there will be an initializer that // takes arguments. ExampleClass(const ExampleClass &other); // A copy constructor is also needed; // some example classes use the default version of this. void operator() () { // Does the main function of the class. This // function will typically want to look at the values of the // member variables thread_id_ and num_threads_, inherited // from MultiThreadable. } ~ExampleClass() { // Optional destructor. Sometimes useful things happen here, // for example summing up of certain quantities. See code // that uses RunMultiThreaded for examples. } private: // Have additional member variables as needed. }; template<class C> class MultiThreader { public: MultiThreader(int32 num_threads, const C &c_in) : threads_(std::max<int32>(1, num_threads)), cvec_(std::max<int32>(1, num_threads), c_in) { if (num_threads == 0) { // This is a special case with num_threads == 0, which behaves like with // num_threads == 1 but without creating extra threads. This can be // useful in GPU computations where threads cannot be used. cvec_[0].thread_id_ = 0; cvec_[0].num_threads_ = 1; (cvec_[0])(); } else { for (int32 i = 0; i < threads_.size(); i++) { cvec_[i].thread_id_ = i; cvec_[i].num_threads_ = threads_.size(); threads_[i] = std::thread(std::ref(cvec_[i])); } } } ~MultiThreader() { for (size_t i = 0; i < threads_.size(); i++) if (threads_[i].joinable()) threads_[i].join(); } private: std::vector<std::thread> threads_; std::vector<C> cvec_; }; /// Here, class C should inherit from MultiThreadable. Note: if you want to /// control the number of threads yourself, or need to do something in the main /// thread of the program while the objects exist, just initialize the /// MultiThreader<C> object yourself. template<class C> void RunMultiThreaded(const C &c_in) { MultiThreader<C> m(g_num_threads, c_in); } struct TaskSequencerConfig { int32 num_threads; int32 num_threads_total; TaskSequencerConfig(): num_threads(1), num_threads_total(0) { } void Register(OptionsItf *opts) { opts->Register("num-threads", &num_threads, "Number of actively processing " "threads to run in parallel"); opts->Register("num-threads-total", &num_threads_total, "Total number of " "threads, including those that are waiting on other threads " "to produce their output. Controls memory use. If <= 0, " "defaults to --num-threads plus 20. Otherwise, must " "be >= num-threads."); } }; // C should have an operator () taking no arguments, that does some kind // of computation, and a destructor that produces some kind of output (the // destructors will be run sequentially in the same order Run as called. template<class C> class TaskSequencer { public: TaskSequencer(const TaskSequencerConfig &config): num_threads_(config.num_threads), threads_avail_(config.num_threads), tot_threads_avail_(config.num_threads_total > 0 ? config.num_threads_total : config.num_threads + 20), thread_list_(NULL) { KALDI_ASSERT((config.num_threads_total <= 0 || config.num_threads_total >= config.num_threads) && "num-threads-total, if specified, must be >= num-threads"); } /// This function takes ownership of the pointer "c", and will delete it /// in the same sequence as Run was called on the jobs. void Run(C *c) { // run in main thread if (num_threads_ == 0) { (*c)(); delete c; return; } threads_avail_.Wait(); // wait till we have a thread for computation free. tot_threads_avail_.Wait(); // this ensures we don't have too many threads // waiting on I/O, and consume too much memory. // put the new RunTaskArgsList object at head of the singly // linked list thread_list_. thread_list_ = new RunTaskArgsList(this, c, thread_list_); thread_list_->thread = std::thread(TaskSequencer<C>::RunTask, thread_list_); } void Wait() { // You call this at the end if it's more convenient // than waiting for the destructor. It waits for all tasks to finish. if (thread_list_ != NULL) { thread_list_->thread.join(); KALDI_ASSERT(thread_list_->tail == NULL); // thread would not // have exited without setting tail to NULL. delete thread_list_; thread_list_ = NULL; } } /// The destructor waits for the last thread to exit. ~TaskSequencer() { Wait(); } private: struct RunTaskArgsList { TaskSequencer *me; // Think of this as a "this" pointer. C *c; // Clist element of the task we're expected std::thread thread; RunTaskArgsList *tail; RunTaskArgsList(TaskSequencer *me, C *c, RunTaskArgsList *tail): me(me), c(c), tail(tail) {} }; // This static function gets run in the threads that we create. static void RunTask(RunTaskArgsList *args) { // (1) run the job. (*(args->c))(); // call operator () on args->c, which does the computation. args->me->threads_avail_.Signal(); // Signal that the compute-intensive // part of the thread is done (we want to run no more than // config_.num_threads of these.) // (2) we want to destroy the object "c" now, by deleting it. But for // correct sequencing (this is the whole point of this class, it // is intended to ensure the output of the program is in correct order), // we first wait till the previous thread, whose details will be in "tail", // is finished. if (args->tail != NULL) { args->tail->thread.join(); } delete args->c; // delete the object "c". This may cause some output, // e.g. to a stream. We don't need to worry about concurrent access to // the output stream, because each thread waits for the previous thread // to be done, before doing this. So there is no risk of concurrent // access. args->c = NULL; if (args->tail != NULL) { KALDI_ASSERT(args->tail->tail == NULL); // Because we already // did join on args->tail->thread, which means that // thread was done, and before it exited, it would have // deleted and set to NULL its tail (which is the next line of code). delete args->tail; args->tail = NULL; } // At this point we are exiting from the thread. Signal the // "tot_threads_avail_" semaphore which is used to limit the total number of threads that are alive, including // not onlhy those that are in active computation in c->operator (), but those // that are waiting on I/O or other threads. args->me->tot_threads_avail_.Signal(); } int32 num_threads_; // copy of config.num_threads (since Semaphore doesn't store original count) Semaphore threads_avail_; // Initialized to the number of threads we are // supposed to run with; the function Run() waits on this. Semaphore tot_threads_avail_; // We use this semaphore to ensure we don't // consume too much memory... RunTaskArgsList *thread_list_; }; } // namespace kaldi #endif // KALDI_THREAD_KALDI_THREAD_H_ |