Blame view
tools/cub-1.8.0/examples/block/reduce_by_key.cu
1.52 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
#include <cub/cub.cuh>

/**
 * Block-level reduce-by-key example skeleton.
 *
 * Each thread synthesizes three keys (all equal to threadIdx.x / 3) and three
 * values (all equal to 1), then uses cub::BlockDiscontinuity to compute a
 * head flag for every item.  The scan-related typedefs and the value array
 * are declared but are not consumed by the code in this kernel.
 *
 * Launch expectations: keys are derived from threadIdx.x only and the CUB
 * collectives are specialized with a single block dimension, so this looks
 * intended for a 1D thread block of BLOCK_THREADS threads -- TODO confirm
 * against the launch site.
 *
 * Shared memory: one statically allocated union that overlays the BlockScan
 * and BlockDiscontinuity temporary storage.
 *
 * The kernel takes no arguments and writes no results to global memory.
 */
template <
    int         BLOCK_THREADS,          ///< Number of CTA threads
    typename    KeyT,                   ///< Key type
    typename    ValueT>                 ///< Value type
__global__ void Kernel()
{
    // Number of consecutive items owned by each thread
    const int           ITEMS_PER_THREAD    = 3;

    // Number of consecutive threads that share the same key
    const unsigned int  THREADS_PER_SEGMENT = 3;

    // Tuple type for scanning (pairs accumulated segment-value with segment-index)
    typedef cub::KeyValuePair<int, ValueT> OffsetValuePairT;

    // Reduce-value-by-segment scan operator (not used by the code in this kernel)
    typedef cub::ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;

    // Parameterized BlockDiscontinuity type for setting head flags
    typedef cub::BlockDiscontinuity<
            KeyT,
            BLOCK_THREADS>
        BlockDiscontinuityKeysT;

    // Parameterized BlockScan type
    typedef cub::BlockScan<
            OffsetValuePairT,
            BLOCK_THREADS,
            cub::BLOCK_SCAN_WARP_SCANS>
        BlockScanT;

    // Shared memory: the two collectives overlay the same bytes, so they must
    // not be in flight at the same time
    __shared__ union TempStorage
    {
        typename BlockScanT::TempStorage                scan;           // Scan storage
        typename BlockDiscontinuityKeysT::TempStorage   discontinuity;  // Discontinuity storage
    } temp_storage;

    // Read data (each thread gets 3 items each, every 9 items is a segment).
    // threadIdx.x is unsigned int: convert explicitly, because an implicit
    // conversion to a narrower or signed KeyT inside a braced initializer is
    // a narrowing conversion (ill-formed since C++11).
    const KeyT segment_key = static_cast<KeyT>(threadIdx.x / THREADS_PER_SEGMENT);

    KeyT    my_keys[ITEMS_PER_THREAD]   = {segment_key, segment_key, segment_key};
    ValueT  my_values[ITEMS_PER_THREAD] = {1, 1, 1};    // not consumed by the code in this kernel

    // Set segment head flags
    int     my_flags[ITEMS_PER_THREAD];
    BlockDiscontinuityKeysT(temp_storage.discontinuity).FlagHeads(
        my_flags,
        my_keys,
        cub::Inequality());

    // Block-wide barrier after the collective; the discontinuity storage
    // aliases the scan storage, so any later use of temp_storage must come
    // after this point
    __syncthreads();
}