Blame view

tools/cub-1.8.0/examples/block/reduce_by_key.cu 1.52 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
  
  
  #include <cub/cub.cuh>
  
  
  /**
   * Block-level reduce-by-key example kernel.
   *
   * As written, only the first stage is implemented: each thread synthesizes
   * three (key, value) items in registers and the block computes segment
   * head flags with cub::BlockDiscontinuity.  The scan-related types
   * (OffsetValuePairT, ReduceBySegmentOpT, BlockScanT) and the `scan` member
   * of the shared-memory union are declared but not yet used; nothing is
   * written to global memory.
   *
   * Launch layout: the CUB collectives are parameterized on BLOCK_THREADS, so
   * the kernel is presumably meant to be launched with exactly BLOCK_THREADS
   * threads per block (1-D, x dimension) -- TODO confirm against the launcher.
   *
   * Shared memory: one statically allocated union sized for the larger of
   * the BlockScan and BlockDiscontinuity temporary storages.
   */
  template <
      int         BLOCK_THREADS,          ///< Number of CTA threads
      typename    KeyT,                   ///< Key type
      typename    ValueT>                 ///< Value type
  __global__ void Kernel()
  {
      // Tuple type for scanning (pairs accumulated segment-value with segment-index)
      typedef cub::KeyValuePair<int, ValueT> OffsetValuePairT;

      // Reduce-value-by-segment scan operator
      typedef cub::ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;

      // Parameterized BlockDiscontinuity type for setting head flags
      typedef cub::BlockDiscontinuity<
              KeyT,
              BLOCK_THREADS>
          BlockDiscontinuityKeysT;

      // Parameterized BlockScan type
      typedef cub::BlockScan<
              OffsetValuePairT,
              BLOCK_THREADS,
              cub::BLOCK_SCAN_WARP_SCANS>
          BlockScanT;

      // Shared memory: the two collectives share one allocation through a
      // union, so a barrier is required between uses of different members.
      __shared__ union TempStorage
      {
          typename BlockScanT::TempStorage                scan;           // Scan storage
          typename BlockDiscontinuityKeysT::TempStorage   discontinuity;  // Discontinuity storage
      } temp_storage;


      // Synthesize data: each thread holds 3 items, and every 3 consecutive
      // threads share a key, so each segment spans 9 items.
      // The explicit cast avoids a narrowing conversion (unsigned int -> KeyT)
      // inside the brace initializer, which is ill-formed in C++11 and later
      // for key types such as int.
      const KeyT segment_key  = static_cast<KeyT>(threadIdx.x / 3);
      KeyT    my_keys[3]      = {segment_key, segment_key, segment_key};
      ValueT  my_values[3]    = {1, 1, 1};

      // Set segment head flags: an item is flagged when its key differs from
      // its predecessor's key (per cub::Inequality).
      int     my_flags[3];
      BlockDiscontinuityKeysT(temp_storage.discontinuity).FlagHeads(
          my_flags,
          my_keys,
          cub::Inequality());

      // Barrier so every thread is done with temp_storage.discontinuity before
      // the union could be reused (no later stage follows in this kernel yet).
      __syncthreads();

  }