block_exchange.cuh
51.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file
* The cub::BlockExchange class provides [<em>collective</em>](index.html#sec0) methods for rearranging data partitioned across a CUDA thread block.
*/
#pragma once
#include "../util_ptx.cuh"
#include "../util_arch.cuh"
#include "../util_macro.cuh"
#include "../util_type.cuh"
#include "../util_namespace.cuh"
/// Optional outer namespace(s)
CUB_NS_PREFIX
/// CUB namespace
namespace cub {
/**
* \brief The BlockExchange class provides [<em>collective</em>](index.html#sec0) methods for rearranging data partitioned across a CUDA thread block. ![](transpose_logo.png)
* \ingroup BlockModule
*
* \tparam InputT The data type to be exchanged.
* \tparam BLOCK_DIM_X The thread block length in threads along the X dimension
* \tparam ITEMS_PER_THREAD The number of items partitioned onto each thread.
* \tparam WARP_TIME_SLICING <b>[optional]</b> When \p true, only use enough shared memory for a single warp's worth of tile data, time-slicing the block-wide exchange over multiple synchronized rounds. Yields a smaller memory footprint at the expense of decreased parallelism. (Default: false)
* \tparam BLOCK_DIM_Y <b>[optional]</b> The thread block length in threads along the Y dimension (default: 1)
* \tparam BLOCK_DIM_Z <b>[optional]</b> The thread block length in threads along the Z dimension (default: 1)
* \tparam PTX_ARCH <b>[optional]</b> \ptxversion
*
* \par Overview
* - It is commonplace for blocks of threads to rearrange data items between
* threads. For example, the device-accessible memory subsystem prefers access patterns
* where data items are "striped" across threads (where consecutive threads access consecutive items),
* yet most block-wide operations prefer a "blocked" partitioning of items across threads
* (where consecutive items belong to a single thread).
* - BlockExchange supports the following types of data exchanges:
* - Transposing between [<em>blocked</em>](index.html#sec5sec3) and [<em>striped</em>](index.html#sec5sec3) arrangements
* - Transposing between [<em>blocked</em>](index.html#sec5sec3) and [<em>warp-striped</em>](index.html#sec5sec3) arrangements
* - Scattering ranked items to a [<em>blocked arrangement</em>](index.html#sec5sec3)
* - Scattering ranked items to a [<em>striped arrangement</em>](index.html#sec5sec3)
* - \rowmajor
*
* \par A Simple Example
* \blockcollective{BlockExchange}
* \par
* The code snippet below illustrates the conversion from a "striped" to a "blocked" arrangement
* of 512 integer items partitioned across 128 threads where each thread owns 4 items.
* \par
* \code
* #include <cub/cub.cuh> // or equivalently <cub/block/block_exchange.cuh>
*
* __global__ void ExampleKernel(int *d_data, ...)
* {
* // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
* typedef cub::BlockExchange<int, 128, 4> BlockExchange;
*
* // Allocate shared memory for BlockExchange
* __shared__ typename BlockExchange::TempStorage temp_storage;
*
* // Load a tile of data striped across threads
* int thread_data[4];
* cub::LoadDirectStriped<128>(threadIdx.x, d_data, thread_data);
*
* // Collectively exchange data into a blocked arrangement across threads
* BlockExchange(temp_storage).StripedToBlocked(thread_data);
*
* \endcode
* \par
* Suppose the set of striped input \p thread_data across the block of threads is
* <tt>{ [0,128,256,384], [1,129,257,385], ..., [127,255,383,511] }</tt>.
* The corresponding output \p thread_data in those threads will be
* <tt>{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }</tt>.
*
* \par Performance Considerations
* - Proper device-specific padding ensures zero bank conflicts for most types.
*
*/
template <
typename InputT,
int BLOCK_DIM_X,
int ITEMS_PER_THREAD,
bool WARP_TIME_SLICING = false,
int BLOCK_DIM_Y = 1,
int BLOCK_DIM_Z = 1,
int PTX_ARCH = CUB_PTX_ARCH>
class BlockExchange
{
private:
/******************************************************************************
* Constants
******************************************************************************/
/// Constants
enum
{
/// The thread block size in threads
BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
/// Log2 of the warp width, as reported by CUB_LOG_WARP_THREADS for PTX_ARCH
LOG_WARP_THREADS = CUB_LOG_WARP_THREADS(PTX_ARCH),
/// Threads per warp
WARP_THREADS = 1 << LOG_WARP_THREADS,
/// Number of warps needed to cover BLOCK_THREADS (rounded up)
WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,
/// Log2 of the shared-memory bank count, as reported by CUB_LOG_SMEM_BANKS for PTX_ARCH
LOG_SMEM_BANKS = CUB_LOG_SMEM_BANKS(PTX_ARCH),
/// Number of shared-memory banks
SMEM_BANKS = 1 << LOG_SMEM_BANKS,
/// Total number of items exchanged by the block
TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD,
/// Number of exchange rounds: one per warp when WARP_TIME_SLICING, otherwise a single round
TIME_SLICES = (WARP_TIME_SLICING) ? WARPS : 1,
/// Number of threads whose items are resident in the buffer during one round
TIME_SLICED_THREADS = (WARP_TIME_SLICING) ? CUB_MIN(BLOCK_THREADS, WARP_THREADS) : BLOCK_THREADS,
/// Number of items resident in the buffer during one round (sizes _TempStorage::buff before padding)
TIME_SLICED_ITEMS = TIME_SLICED_THREADS * ITEMS_PER_THREAD,
/// Threads per warp-sized partition, independent of WARP_TIME_SLICING
WARP_TIME_SLICED_THREADS = CUB_MIN(BLOCK_THREADS, WARP_THREADS),
/// Items per warp-sized partition (stride between consecutive warps' regions, see warp_offset)
WARP_TIME_SLICED_ITEMS = WARP_TIME_SLICED_THREADS * ITEMS_PER_THREAD,
// Insert padding to avoid bank conflicts during raking when items per thread is a power of two and > 4 (otherwise we can typically use 128b loads)
INSERT_PADDING = (ITEMS_PER_THREAD > 4) && (PowerOfTwo<ITEMS_PER_THREAD>::VALUE),
/// Extra buffer slots reserved for padding: one per SMEM_BANKS items, matching the
/// "offset += offset >> LOG_SMEM_BANKS" adjustment applied by the exchange methods
PADDING_ITEMS = (INSERT_PADDING) ? (TIME_SLICED_ITEMS >> LOG_SMEM_BANKS) : 0,
};
/******************************************************************************
* Type definitions
******************************************************************************/
/// Shared memory storage layout type: a single exchange buffer holding one
/// round's worth of items (TIME_SLICED_ITEMS) plus any padding slots (PADDING_ITEMS).
struct __align__(16) _TempStorage
{
InputT buff[TIME_SLICED_ITEMS + PADDING_ITEMS];
};
public:
/// \smemstorage{BlockExchange}
///
/// Public storage type handed to the constructor; it wraps _TempStorage in Uninitialized<>
/// (declared outside this file). The constructor obtains the typed view through Alias().
struct TempStorage : Uninitialized<_TempStorage> {};
private:
/******************************************************************************
* Thread fields
******************************************************************************/
/// Shared storage reference
_TempStorage &temp_storage;
/// Linear thread-id (result of RowMajorTid over BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)
unsigned int linear_tid;
/// Lane index of this thread, as returned by LaneId()
unsigned int lane_id;
/// Index of this thread's warp: linear_tid / WARP_THREADS, or 0 when the block has a single warp
unsigned int warp_id;
/// Offset of this warp's region in a warp-partitioned buffer: warp_id * WARP_TIME_SLICED_ITEMS
unsigned int warp_offset;
/******************************************************************************
* Utility methods
******************************************************************************/
/// Returns a reference to a function-local __shared__ _TempStorage instance;
/// the default constructor binds temp_storage to it when the caller supplies no storage.
__device__ __forceinline__ _TempStorage& PrivateStorage()
{
    __shared__ _TempStorage block_private_storage;
    return block_private_storage;
}
/**
 * Blocked-to-striped transpose through the shared buffer (overload without warp time-slicing).
 *
 * Each thread stores its items at offsets linear_tid * ITEMS_PER_THREAD + k, a CTA_SYNC()
 * is issued, and each thread then loads from offsets k * BLOCK_THREADS + linear_tid.
 * When INSERT_PADDING is set, every offset is skewed by (offset >> LOG_SMEM_BANKS).
 */
template <typename OutputT>
__device__ __forceinline__ void BlockedToStriped(
    InputT          input_items[ITEMS_PER_THREAD],      ///< [in] Items held in <em>blocked</em> arrangement
    OutputT         output_items[ITEMS_PER_THREAD],     ///< [out] Items received in <em>striped</em> arrangement
    Int2Type<false> /*time_slicing*/)
{
    // Deposit phase: write this thread's consecutive segment
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = (linear_tid * ITEMS_PER_THREAD) + k;
        if (INSERT_PADDING)
        {
            smem_idx += smem_idx >> LOG_SMEM_BANKS;
        }
        temp_storage.buff[smem_idx] = input_items[k];
    }

    CTA_SYNC();

    // Collect phase: read one item from each strip of BLOCK_THREADS items
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = int(k * BLOCK_THREADS) + linear_tid;
        if (INSERT_PADDING)
        {
            smem_idx += smem_idx >> LOG_SMEM_BANKS;
        }
        output_items[k] = temp_storage.buff[smem_idx];
    }
}
/**
* Transposes data items from <em>blocked</em> arrangement to <em>striped</em> arrangement. Specialized for warp-timeslicing.
*
* The exchange runs in TIME_SLICES rounds. In round SLICE only the warp with
* warp_id == SLICE writes its blocked items into the buffer; afterwards every thread
* picks up those of its striped items whose tile positions fall inside that slice.
* Results are staged in a local array and copied to output_items at the end.
*/
template <typename OutputT>
__device__ __forceinline__ void BlockedToStriped(
InputT input_items[ITEMS_PER_THREAD], ///< [in] Items held in <em>blocked</em> arrangement.
OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items received in <em>striped</em> arrangement.
Int2Type<true> /*time_slicing*/)
{
// Staging area filled piecewise across the rounds
InputT temp_items[ITEMS_PER_THREAD];
#pragma unroll
for (int SLICE = 0; SLICE < TIME_SLICES; SLICE++)
{
// Tile-position range [SLICE_OFFSET, SLICE_OOB) covered by this round
const int SLICE_OFFSET = SLICE * TIME_SLICED_ITEMS;
const int SLICE_OOB = SLICE_OFFSET + TIME_SLICED_ITEMS;
// Barrier before the buffer is overwritten for this round
CTA_SYNC();
if (warp_id == SLICE)
{
// Only the warp owning this slice writes, using lane-relative blocked offsets
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
int item_offset = (lane_id * ITEMS_PER_THREAD) + ITEM;
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
temp_storage.buff[item_offset] = input_items[ITEM];
}
}
// Barrier between the writes above and the reads below
CTA_SYNC();
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
// Read a strip of items
const int STRIP_OFFSET = ITEM * BLOCK_THREADS;
const int STRIP_OOB = STRIP_OFFSET + BLOCK_THREADS;
// Skip strips whose range [STRIP_OFFSET, STRIP_OOB) does not overlap the slice
if ((SLICE_OFFSET < STRIP_OOB) && (SLICE_OOB > STRIP_OFFSET))
{
// Position of this thread's striped item relative to the start of the slice
int item_offset = STRIP_OFFSET + linear_tid - SLICE_OFFSET;
if ((item_offset >= 0) && (item_offset < TIME_SLICED_ITEMS))
{
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
temp_items[ITEM] = temp_storage.buff[item_offset];
}
}
}
}
// Copy
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
output_items[ITEM] = temp_items[ITEM];
}
}
/**
 * Blocked-to-warp-striped transpose (overload without warp time-slicing).
 *
 * Every warp works in its own region of the buffer starting at warp_offset. Lanes store
 * at lane_id * ITEMS_PER_THREAD + k and, after WARP_SYNC, load from
 * k * WARP_TIME_SLICED_THREADS + lane_id. This overload issues WARP_SYNC only; it
 * contains no CTA_SYNC.
 */
template <typename OutputT>
__device__ __forceinline__ void BlockedToWarpStriped(
    InputT          input_items[ITEMS_PER_THREAD],      ///< [in] Items held in <em>blocked</em> arrangement
    OutputT         output_items[ITEMS_PER_THREAD],     ///< [out] Items received in <em>warp-striped</em> arrangement
    Int2Type<false> /*time_slicing*/)
{
    // Deposit phase: lane-consecutive segments inside this warp's region
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = warp_offset + k + (lane_id * ITEMS_PER_THREAD);
        if (INSERT_PADDING)
        {
            smem_idx += smem_idx >> LOG_SMEM_BANKS;
        }
        temp_storage.buff[smem_idx] = input_items[k];
    }

    WARP_SYNC(0xffffffff);

    // Collect phase: one item per warp-wide strip
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = warp_offset + (k * WARP_TIME_SLICED_THREADS) + lane_id;
        if (INSERT_PADDING)
        {
            smem_idx += smem_idx >> LOG_SMEM_BANKS;
        }
        output_items[k] = temp_storage.buff[smem_idx];
    }
}
/**
* Transposes data items from <em>blocked</em> arrangement to <em>warp-striped</em> arrangement. Specialized for warp-timeslicing
*
* Warps take turns using the buffer: the warp with warp_id == 0 goes first, then one
* warp per iteration of the SLICE loop. Offsets are lane-relative (no warp_offset).
*/
template <typename OutputT>
__device__ __forceinline__ void BlockedToWarpStriped(
InputT input_items[ITEMS_PER_THREAD], ///< [in] Items held in <em>blocked</em> arrangement.
OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items received in <em>warp-striped</em> arrangement.
Int2Type<true> /*time_slicing*/)
{
// First turn: warp 0 exchanges without a preceding CTA_SYNC
if (warp_id == 0)
{
// Store in lane-blocked order
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
int item_offset = ITEM + (lane_id * ITEMS_PER_THREAD);
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
temp_storage.buff[item_offset] = input_items[ITEM];
}
WARP_SYNC(0xffffffff);
// Load in lane-striped order
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
int item_offset = (ITEM * WARP_TIME_SLICED_THREADS) + lane_id;
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
output_items[ITEM] = temp_storage.buff[item_offset];
}
}
// Remaining turns: every thread reaches the CTA_SYNC each iteration, then only the
// warp whose id equals SLICE performs the same store / WARP_SYNC / load sequence
#pragma unroll
for (unsigned int SLICE = 1; SLICE < TIME_SLICES; ++SLICE)
{
// NOTE(review): presumably ensures the previous warp has finished with the buffer - confirm
CTA_SYNC();
if (warp_id == SLICE)
{
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
int item_offset = ITEM + (lane_id * ITEMS_PER_THREAD);
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
temp_storage.buff[item_offset] = input_items[ITEM];
}
WARP_SYNC(0xffffffff);
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
int item_offset = (ITEM * WARP_TIME_SLICED_THREADS) + lane_id;
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
output_items[ITEM] = temp_storage.buff[item_offset];
}
}
}
}
/**
 * Striped-to-blocked transpose through the shared buffer (overload without warp time-slicing).
 *
 * Each thread stores its items at offsets k * BLOCK_THREADS + linear_tid, a CTA_SYNC()
 * is issued, and each thread then loads from offsets linear_tid * ITEMS_PER_THREAD + k.
 * When INSERT_PADDING is set, every offset is skewed by (offset >> LOG_SMEM_BANKS).
 */
template <typename OutputT>
__device__ __forceinline__ void StripedToBlocked(
    InputT          input_items[ITEMS_PER_THREAD],      ///< [in] Items held in <em>striped</em> arrangement
    OutputT         output_items[ITEMS_PER_THREAD],     ///< [out] Items received in <em>blocked</em> arrangement
    Int2Type<false> /*time_slicing*/)
{
    // Deposit phase: one item into each strip of BLOCK_THREADS items
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = int(k * BLOCK_THREADS) + linear_tid;
        if (INSERT_PADDING)
        {
            smem_idx += smem_idx >> LOG_SMEM_BANKS;
        }
        temp_storage.buff[smem_idx] = input_items[k];
    }

    CTA_SYNC();

    // Collect phase: read this thread's consecutive segment
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = (linear_tid * ITEMS_PER_THREAD) + k;
        if (INSERT_PADDING)
        {
            smem_idx += smem_idx >> LOG_SMEM_BANKS;
        }
        output_items[k] = temp_storage.buff[smem_idx];
    }
}
/**
* Transposes data items from <em>striped</em> arrangement to <em>blocked</em> arrangement. Specialized for warp-timeslicing.
*
* The exchange runs in TIME_SLICES rounds. In round SLICE every thread writes those of
* its striped items whose tile positions fall inside the slice; afterwards only the warp
* with warp_id == SLICE reads its blocked items back. Results are staged in a local
* array and copied to output_items at the end.
*/
template <typename OutputT>
__device__ __forceinline__ void StripedToBlocked(
InputT input_items[ITEMS_PER_THREAD], ///< [in] Items held in <em>striped</em> arrangement.
OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items received in <em>blocked</em> arrangement.
Int2Type<true> /*time_slicing*/)
{
// Warp time-slicing
InputT temp_items[ITEMS_PER_THREAD];
#pragma unroll
for (int SLICE = 0; SLICE < TIME_SLICES; SLICE++)
{
// Tile-position range [SLICE_OFFSET, SLICE_OOB) covered by this round
const int SLICE_OFFSET = SLICE * TIME_SLICED_ITEMS;
const int SLICE_OOB = SLICE_OFFSET + TIME_SLICED_ITEMS;
// Barrier before the buffer is overwritten for this round
CTA_SYNC();
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
// Write a strip of items
const int STRIP_OFFSET = ITEM * BLOCK_THREADS;
const int STRIP_OOB = STRIP_OFFSET + BLOCK_THREADS;
// Skip strips whose range [STRIP_OFFSET, STRIP_OOB) does not overlap the slice
if ((SLICE_OFFSET < STRIP_OOB) && (SLICE_OOB > STRIP_OFFSET))
{
// Position of this thread's striped item relative to the start of the slice
int item_offset = STRIP_OFFSET + linear_tid - SLICE_OFFSET;
if ((item_offset >= 0) && (item_offset < TIME_SLICED_ITEMS))
{
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
temp_storage.buff[item_offset] = input_items[ITEM];
}
}
}
// Barrier between the writes above and the reads below
CTA_SYNC();
if (warp_id == SLICE)
{
// Only the warp owning this slice reads, using lane-relative blocked offsets
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
int item_offset = (lane_id * ITEMS_PER_THREAD) + ITEM;
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
temp_items[ITEM] = temp_storage.buff[item_offset];
}
}
}
// Copy
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
output_items[ITEM] = temp_items[ITEM];
}
}
/**
 * Warp-striped-to-blocked transpose (overload without warp time-slicing).
 *
 * Every warp works in its own region of the buffer starting at warp_offset. Lanes store
 * at k * WARP_TIME_SLICED_THREADS + lane_id and, after WARP_SYNC, load from
 * lane_id * ITEMS_PER_THREAD + k. This overload issues WARP_SYNC only; it contains
 * no CTA_SYNC.
 */
template <typename OutputT>
__device__ __forceinline__ void WarpStripedToBlocked(
    InputT          input_items[ITEMS_PER_THREAD],      ///< [in] Items held in <em>warp-striped</em> arrangement
    OutputT         output_items[ITEMS_PER_THREAD],     ///< [out] Items received in <em>blocked</em> arrangement
    Int2Type<false> /*time_slicing*/)
{
    // Deposit phase: one item per warp-wide strip
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = warp_offset + (k * WARP_TIME_SLICED_THREADS) + lane_id;
        if (INSERT_PADDING)
        {
            smem_idx += smem_idx >> LOG_SMEM_BANKS;
        }
        temp_storage.buff[smem_idx] = input_items[k];
    }

    WARP_SYNC(0xffffffff);

    // Collect phase: lane-consecutive segments inside this warp's region
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = warp_offset + k + (lane_id * ITEMS_PER_THREAD);
        if (INSERT_PADDING)
        {
            smem_idx += smem_idx >> LOG_SMEM_BANKS;
        }
        output_items[k] = temp_storage.buff[smem_idx];
    }
}
/**
* Transposes data items from <em>warp-striped</em> arrangement to <em>blocked</em> arrangement. Specialized for warp-timeslicing
*
* Warps take turns using the buffer, one warp per iteration of the SLICE loop.
* Offsets are lane-relative (no warp_offset).
*/
template <typename OutputT>
__device__ __forceinline__ void WarpStripedToBlocked(
InputT input_items[ITEMS_PER_THREAD], ///< [in] Items held in <em>warp-striped</em> arrangement.
OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items received in <em>blocked</em> arrangement.
Int2Type<true> /*time_slicing*/)
{
#pragma unroll
for (unsigned int SLICE = 0; SLICE < TIME_SLICES; ++SLICE)
{
// Every thread reaches this barrier each iteration, including the first
CTA_SYNC();
if (warp_id == SLICE)
{
// Store in lane-striped order
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
int item_offset = (ITEM * WARP_TIME_SLICED_THREADS) + lane_id;
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
temp_storage.buff[item_offset] = input_items[ITEM];
}
WARP_SYNC(0xffffffff);
// Load in lane-blocked order
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
int item_offset = ITEM + (lane_id * ITEMS_PER_THREAD);
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
output_items[ITEM] = temp_storage.buff[item_offset];
}
}
}
}
/**
 * Scatters items by rank and reads them back in <em>blocked</em> order (overload without
 * warp time-slicing).
 *
 * Item k of each thread is stored at buffer offset ranks[k]; after CTA_SYNC() each thread
 * loads from offsets linear_tid * ITEMS_PER_THREAD + k. With INSERT_PADDING, offsets are
 * adjusted through SHR_ADD(offset, LOG_SMEM_BANKS, offset).
 * NOTE(review): SHR_ADD is declared outside this file; presumably (x >> shift) + addend,
 * mirroring the "offset += offset >> LOG_SMEM_BANKS" form used elsewhere - confirm.
 */
template <typename OutputT, typename OffsetT>
__device__ __forceinline__ void ScatterToBlocked(
    InputT          input_items[ITEMS_PER_THREAD],      ///< [in] Items to scatter
    OutputT         output_items[ITEMS_PER_THREAD],     ///< [out] Items received in <em>blocked</em> arrangement
    OffsetT         ranks[ITEMS_PER_THREAD],            ///< [in] Destination offset of each input item
    Int2Type<false> /*time_slicing*/)
{
    // Scatter phase: destination is dictated by the rank
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = ranks[k];
        if (INSERT_PADDING)
        {
            smem_idx = SHR_ADD(smem_idx, LOG_SMEM_BANKS, smem_idx);
        }
        temp_storage.buff[smem_idx] = input_items[k];
    }

    CTA_SYNC();

    // Collect phase: read this thread's consecutive segment
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = (linear_tid * ITEMS_PER_THREAD) + k;
        if (INSERT_PADDING)
        {
            smem_idx = SHR_ADD(smem_idx, LOG_SMEM_BANKS, smem_idx);
        }
        output_items[k] = temp_storage.buff[smem_idx];
    }
}
/**
* Exchanges data items annotated by rank into <em>blocked</em> arrangement. Specialized for warp-timeslicing.
*
* The exchange runs in TIME_SLICES rounds. In round SLICE every thread stores the items
* whose rank falls inside the slice; afterwards only the warp with warp_id == SLICE reads
* its blocked items back. Results are staged in a local array and copied out at the end.
*/
template <typename OutputT, typename OffsetT>
__device__ __forceinline__ void ScatterToBlocked(
InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to scatter.
OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items received in <em>blocked</em> arrangement.
OffsetT ranks[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks
Int2Type<true> /*time_slicing*/)
{
// Staging area filled by the one round in which this thread's warp owns the slice
InputT temp_items[ITEMS_PER_THREAD];
#pragma unroll
for (int SLICE = 0; SLICE < TIME_SLICES; SLICE++)
{
// Barrier before the buffer is overwritten for this round
CTA_SYNC();
const int SLICE_OFFSET = TIME_SLICED_ITEMS * SLICE;
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
// Rank relative to the start of the slice; stored only if it lands inside the slice.
// The bound is WARP_TIME_SLICED_ITEMS while the slice stride is TIME_SLICED_ITEMS; per the
// enum definitions the two are equal whenever WARP_TIME_SLICING is true.
// NOTE(review): assumes this overload is only dispatched with WARP_TIME_SLICING == true - confirm.
int item_offset = ranks[ITEM] - SLICE_OFFSET;
if ((item_offset >= 0) && (item_offset < WARP_TIME_SLICED_ITEMS))
{
if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
temp_storage.buff[item_offset] = input_items[ITEM];
}
}
// Barrier between the writes above and the reads below
CTA_SYNC();
if (warp_id == SLICE)
{
// Only the warp owning this slice reads, using lane-relative blocked offsets
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
int item_offset = (lane_id * ITEMS_PER_THREAD) + ITEM;
if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
temp_items[ITEM] = temp_storage.buff[item_offset];
}
}
}
// Copy
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
output_items[ITEM] = temp_items[ITEM];
}
}
/**
 * Scatters items by rank and reads them back in <em>striped</em> order (overload without
 * warp time-slicing).
 *
 * Item k of each thread is stored at buffer offset ranks[k]; after CTA_SYNC() each thread
 * loads from offsets k * BLOCK_THREADS + linear_tid. With INSERT_PADDING, offsets are
 * adjusted through SHR_ADD(offset, LOG_SMEM_BANKS, offset).
 * NOTE(review): SHR_ADD is declared outside this file; presumably (x >> shift) + addend,
 * mirroring the "offset += offset >> LOG_SMEM_BANKS" form used elsewhere - confirm.
 */
template <typename OutputT, typename OffsetT>
__device__ __forceinline__ void ScatterToStriped(
    InputT          input_items[ITEMS_PER_THREAD],      ///< [in] Items to scatter
    OutputT         output_items[ITEMS_PER_THREAD],     ///< [out] Items received in <em>striped</em> arrangement
    OffsetT         ranks[ITEMS_PER_THREAD],            ///< [in] Destination offset of each input item
    Int2Type<false> /*time_slicing*/)
{
    // Scatter phase: destination is dictated by the rank
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = ranks[k];
        if (INSERT_PADDING)
        {
            smem_idx = SHR_ADD(smem_idx, LOG_SMEM_BANKS, smem_idx);
        }
        temp_storage.buff[smem_idx] = input_items[k];
    }

    CTA_SYNC();

    // Collect phase: read one item from each strip of BLOCK_THREADS items
    #pragma unroll
    for (int k = 0; k < ITEMS_PER_THREAD; ++k)
    {
        int smem_idx = int(k * BLOCK_THREADS) + linear_tid;
        if (INSERT_PADDING)
        {
            smem_idx = SHR_ADD(smem_idx, LOG_SMEM_BANKS, smem_idx);
        }
        output_items[k] = temp_storage.buff[smem_idx];
    }
}
/**
* Exchanges data items annotated by rank into <em>striped</em> arrangement. Specialized for warp-timeslicing.
*
* The exchange runs in TIME_SLICES rounds. In round SLICE every thread stores the items
* whose rank falls inside the slice; afterwards every thread picks up those of its striped
* items whose tile positions fall inside that slice. Results are staged in a local array
* and copied out at the end.
*/
template <typename OutputT, typename OffsetT>
__device__ __forceinline__ void ScatterToStriped(
InputT input_items[ITEMS_PER_THREAD], ///< [in] Items to scatter.
OutputT output_items[ITEMS_PER_THREAD], ///< [out] Items received in <em>striped</em> arrangement.
OffsetT ranks[ITEMS_PER_THREAD], ///< [in] Corresponding scatter ranks
Int2Type<true> /*time_slicing*/)
{
// Staging area filled piecewise across the rounds
InputT temp_items[ITEMS_PER_THREAD];
#pragma unroll
for (int SLICE = 0; SLICE < TIME_SLICES; SLICE++)
{
// Tile-position range [SLICE_OFFSET, SLICE_OOB) covered by this round
const int SLICE_OFFSET = SLICE * TIME_SLICED_ITEMS;
const int SLICE_OOB = SLICE_OFFSET + TIME_SLICED_ITEMS;
// Barrier before the buffer is overwritten for this round
CTA_SYNC();
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
// Rank relative to the start of the slice; stored only if it lands inside the slice.
// The bound is WARP_TIME_SLICED_ITEMS while the slice stride is TIME_SLICED_ITEMS; per the
// enum definitions the two are equal whenever WARP_TIME_SLICING is true.
// NOTE(review): assumes this overload is only dispatched with WARP_TIME_SLICING == true - confirm.
int item_offset = ranks[ITEM] - SLICE_OFFSET;
if ((item_offset >= 0) && (item_offset < WARP_TIME_SLICED_ITEMS))
{
if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
temp_storage.buff[item_offset] = input_items[ITEM];
}
}
// Barrier between the writes above and the reads below
CTA_SYNC();
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
// Read a strip of items
const int STRIP_OFFSET = ITEM * BLOCK_THREADS;
const int STRIP_OOB = STRIP_OFFSET + BLOCK_THREADS;
// Skip strips whose range [STRIP_OFFSET, STRIP_OOB) does not overlap the slice
if ((SLICE_OFFSET < STRIP_OOB) && (SLICE_OOB > STRIP_OFFSET))
{
// Position of this thread's striped item relative to the start of the slice
int item_offset = STRIP_OFFSET + linear_tid - SLICE_OFFSET;
if ((item_offset >= 0) && (item_offset < TIME_SLICED_ITEMS))
{
// Padding here uses the "+= >>" form, whereas the store above uses SHR_ADD
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS;
temp_items[ITEM] = temp_storage.buff[item_offset];
}
}
}
}
// Copy
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
{
output_items[ITEM] = temp_items[ITEM];
}
}
public:
/******************************************************************//**
* \name Collective constructors
*********************************************************************/
//@{
/**
 * \brief Collective constructor using a private static allocation of shared memory as temporary storage.
 *
 * Binds temp_storage to the buffer returned by PrivateStorage() and caches the calling
 * thread's linear id, lane id, warp id and warp offset.
 */
__device__ __forceinline__ BlockExchange()
:
    // Initializers are listed in member declaration order (temp_storage, linear_tid,
    // lane_id, warp_id, warp_offset), which is the order C++ actually runs them in.
    // warp_id depends on linear_tid and warp_offset depends on warp_id.
    temp_storage(PrivateStorage()),
    linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)),
    lane_id(LaneId()),
    warp_id((WARPS == 1) ? 0 : linear_tid / WARP_THREADS),
    warp_offset(warp_id * WARP_TIME_SLICED_ITEMS)
{}
/**
 * \brief Collective constructor using the specified memory allocation as temporary storage.
 *
 * \param temp_storage [in] Reference to memory allocation having layout type TempStorage;
 *                     its Alias() view becomes the exchange buffer used by this instance.
 */
__device__ __forceinline__ BlockExchange(TempStorage &temp_storage)
    : temp_storage(temp_storage.Alias())
    , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    , lane_id(LaneId())
    , warp_id((WARPS == 1) ? 0 : linear_tid / WARP_THREADS)    // a single-warp block always has warp 0
    , warp_offset(warp_id * WARP_TIME_SLICED_ITEMS)
{
}
//@} end member group
/******************************************************************//**
* \name Structured exchanges
*********************************************************************/
//@{
/**
* \brief Transposes data items from <em>striped</em> arrangement to <em>blocked</em> arrangement.
*
* \par
* - \smemreuse
*
* \par Snippet
* The code snippet below illustrates the conversion from a "striped" to a "blocked" arrangement
* of 512 integer items partitioned across 128 threads where each thread owns 4 items.
* \par
* \code
* #include <cub/cub.cuh> // or equivalently <cub/block/block_exchange.cuh>
*
* __global__ void ExampleKernel(int *d_data, ...)
* {
* // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
* typedef cub::BlockExchange<int, 128, 4> BlockExchange;
*
* // Allocate shared memory for BlockExchange
* __shared__ typename BlockExchange::TempStorage temp_storage;
*
* // Load a tile of ordered data into a striped arrangement across block threads
* int thread_data[4];
* cub::LoadDirectStriped<128>(threadIdx.x, d_data, thread_data);
*
* // Collectively exchange data into a blocked arrangement across threads
* BlockExchange(temp_storage).StripedToBlocked(thread_data, thread_data);
*
* \endcode
* \par
* Suppose the set of striped input \p thread_data across the block of threads is
* <tt>{ [0,128,256,384], [1,129,257,385], ..., [127,255,383,511] }</tt> after loading from device-accessible memory.
* The corresponding output \p thread_data in those threads will be
* <tt>{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }</tt>.
*
*/
template <typename OutputT>
__device__ __forceinline__ void StripedToBlocked(
    InputT      input_items[ITEMS_PER_THREAD],      ///< [in] Items to exchange, held in <em>striped</em> arrangement
    OutputT     output_items[ITEMS_PER_THREAD])     ///< [out] Exchanged items, delivered in <em>blocked</em> arrangement
{
    // Compile-time tag built from WARP_TIME_SLICING; its type picks the
    // three-argument overload that performs the actual exchange.
    typedef Int2Type<WARP_TIME_SLICING> TimeSlicingTag;

    this->StripedToBlocked(input_items, output_items, TimeSlicingTag());
}
/**
* \brief Transposes data items from <em>blocked</em> arrangement to <em>striped</em> arrangement.
*
* \par
* - \smemreuse
*
* \par Snippet
* The code snippet below illustrates the conversion from a "blocked" to a "striped" arrangement
* of 512 integer items partitioned across 128 threads where each thread owns 4 items.
* \par
* \code
* #include <cub/cub.cuh> // or equivalently <cub/block/block_exchange.cuh>
*
* __global__ void ExampleKernel(int *d_data, ...)
* {
* // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
* typedef cub::BlockExchange<int, 128, 4> BlockExchange;
*
* // Allocate shared memory for BlockExchange
* __shared__ typename BlockExchange::TempStorage temp_storage;
*
* // Obtain a segment of consecutive items that are blocked across threads
* int thread_data[4];
* ...
*
* // Collectively exchange data into a striped arrangement across threads
* BlockExchange(temp_storage).BlockedToStriped(thread_data, thread_data);
*
* // Store data striped across block threads into an ordered tile
* cub::StoreDirectStriped<STORE_DEFAULT, 128>(threadIdx.x, d_data, thread_data);
*
* \endcode
* \par
* Suppose the set of blocked input \p thread_data across the block of threads is
* <tt>{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }</tt>.
* The corresponding output \p thread_data in those threads will be
* <tt>{ [0,128,256,384], [1,129,257,385], ..., [127,255,383,511] }</tt> in
* preparation for storing to device-accessible memory.
*
*/
template <typename OutputT>
__device__ __forceinline__ void BlockedToStriped(
    InputT      input_items[ITEMS_PER_THREAD],      ///< [in] Items to exchange, held in <em>blocked</em> arrangement
    OutputT     output_items[ITEMS_PER_THREAD])     ///< [out] Exchanged items, delivered in <em>striped</em> arrangement
{
    // Compile-time tag built from WARP_TIME_SLICING; its type picks the
    // three-argument overload that performs the actual exchange.
    typedef Int2Type<WARP_TIME_SLICING> TimeSlicingTag;

    this->BlockedToStriped(input_items, output_items, TimeSlicingTag());
}
/**
* \brief Transposes data items from <em>warp-striped</em> arrangement to <em>blocked</em> arrangement.
*
* \par
* - \smemreuse
*
* \par Snippet
* The code snippet below illustrates the conversion from a "warp-striped" to a "blocked" arrangement
* of 512 integer items partitioned across 128 threads where each thread owns 4 items.
* \par
* \code
* #include <cub/cub.cuh> // or equivalently <cub/block/block_exchange.cuh>
*
* __global__ void ExampleKernel(int *d_data, ...)
* {
* // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
* typedef cub::BlockExchange<int, 128, 4> BlockExchange;
*
* // Allocate shared memory for BlockExchange
* __shared__ typename BlockExchange::TempStorage temp_storage;
*
* // Load a tile of ordered data into a warp-striped arrangement across warp threads
* int thread_data[4];
* cub::LoadDirectWarpStriped(threadIdx.x, d_data, thread_data);
*
* // Collectively exchange data into a blocked arrangement across threads
* BlockExchange(temp_storage).WarpStripedToBlocked(thread_data);
*
* \endcode
* \par
* Suppose the set of warp-striped input \p thread_data across the block of threads is
* <tt>{ [0,32,64,96], [1,33,65,97], [2,34,66,98], ..., [415,447,479,511] }</tt>
* after loading from device-accessible memory. (The first 128 items are striped across
* the first warp of 32 threads, the second 128 items are striped across the second warp, etc.)
* The corresponding output \p thread_data in those threads will be
* <tt>{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }</tt>.
*
*/
template <typename OutputT>
__device__ __forceinline__ void WarpStripedToBlocked(
    InputT      input_items[ITEMS_PER_THREAD],      ///< [in] Items to exchange, held in <em>warp-striped</em> arrangement
    OutputT     output_items[ITEMS_PER_THREAD])     ///< [out] Exchanged items, delivered in <em>blocked</em> arrangement
{
    // Compile-time tag built from WARP_TIME_SLICING; its type picks the
    // three-argument overload that performs the actual exchange.
    typedef Int2Type<WARP_TIME_SLICING> TimeSlicingTag;

    this->WarpStripedToBlocked(input_items, output_items, TimeSlicingTag());
}
/**
* \brief Transposes data items from <em>blocked</em> arrangement to <em>warp-striped</em> arrangement.
*
* \par
* - \smemreuse
*
* \par Snippet
* The code snippet below illustrates the conversion from a "blocked" to a "warp-striped" arrangement
* of 512 integer items partitioned across 128 threads where each thread owns 4 items.
* \par
* \code
* #include <cub/cub.cuh> // or equivalently <cub/block/block_exchange.cuh>
*
* __global__ void ExampleKernel(int *d_data, ...)
* {
* // Specialize BlockExchange for a 1D block of 128 threads owning 4 integer items each
* typedef cub::BlockExchange<int, 128, 4> BlockExchange;
*
* // Allocate shared memory for BlockExchange
* __shared__ typename BlockExchange::TempStorage temp_storage;
*
* // Obtain a segment of consecutive items that are blocked across threads
* int thread_data[4];
* ...
*
* // Collectively exchange data into a warp-striped arrangement across threads
* BlockExchange(temp_storage).BlockedToWarpStriped(thread_data, thread_data);
*
* // Store data striped across warp threads into an ordered tile
* cub::StoreDirectWarpStriped(threadIdx.x, d_data, thread_data);
*
* \endcode
* \par
* Suppose the set of blocked input \p thread_data across the block of threads is
* <tt>{ [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }</tt>.
* The corresponding output \p thread_data in those threads will be
* <tt>{ [0,32,64,96], [1,33,65,97], [2,34,66,98], ..., [415,447,479,511] }</tt>
* in preparation for storing to device-accessible memory. (The first 128 items are striped across
* the first warp of 32 threads, the second 128 items are striped across the second warp, etc.)
*
*/
template <typename OutputT>
__device__ __forceinline__ void BlockedToWarpStriped(
    InputT      input_items[ITEMS_PER_THREAD],      ///< [in] Items to exchange, held in <em>blocked</em> arrangement
    OutputT     output_items[ITEMS_PER_THREAD])     ///< [out] Exchanged items, delivered in <em>warp-striped</em> arrangement
{
    // Compile-time tag built from WARP_TIME_SLICING; its type picks the
    // three-argument overload that performs the actual exchange.
    typedef Int2Type<WARP_TIME_SLICING> TimeSlicingTag;

    this->BlockedToWarpStriped(input_items, output_items, TimeSlicingTag());
}
//@} end member group
/******************************************************************//**
* \name Scatter exchanges
*********************************************************************/
//@{
/**
* \brief Exchanges data items annotated by rank into <em>blocked</em> arrangement.
*
* \par
* - \smemreuse
*
* \tparam OffsetT <b>[inferred]</b> Signed integer type for local offsets
*/
template <typename OutputT, typename OffsetT>
__device__ __forceinline__ void ScatterToBlocked(
    InputT      input_items[ITEMS_PER_THREAD],      ///< [in] Items to scatter
    OutputT     output_items[ITEMS_PER_THREAD],     ///< [out] Scattered items, delivered in <em>blocked</em> arrangement
    OffsetT     ranks[ITEMS_PER_THREAD])            ///< [in] Scatter rank of each corresponding input item
{
    // Compile-time tag built from WARP_TIME_SLICING; its type picks the
    // four-argument overload that performs the actual scatter.
    typedef Int2Type<WARP_TIME_SLICING> TimeSlicingTag;

    this->ScatterToBlocked(input_items, output_items, ranks, TimeSlicingTag());
}
/**
* \brief Exchanges data items annotated by rank into <em>striped</em> arrangement.
*
* \par
* - \smemreuse
*
* \tparam OffsetT <b>[inferred]</b> Signed integer type for local offsets
*/
template <typename OutputT, typename OffsetT>
__device__ __forceinline__ void ScatterToStriped(
    InputT      input_items[ITEMS_PER_THREAD],      ///< [in] Items to scatter
    OutputT     output_items[ITEMS_PER_THREAD],     ///< [out] Scattered items, delivered in <em>striped</em> arrangement
    OffsetT     ranks[ITEMS_PER_THREAD])            ///< [in] Scatter rank of each corresponding input item
{
    // Compile-time tag built from WARP_TIME_SLICING; its type picks the
    // four-argument overload that performs the actual scatter.
    typedef Int2Type<WARP_TIME_SLICING> TimeSlicingTag;

    this->ScatterToStriped(input_items, output_items, ranks, TimeSlicingTag());
}
/**
* \brief Exchanges data items annotated by rank into <em>striped</em> arrangement. Items with rank -1 are not exchanged.
*
* \par
* - \smemreuse
*
* \tparam OffsetT <b>[inferred]</b> Signed integer type for local offsets
*/
template <typename OutputT, typename OffsetT>
__device__ __forceinline__ void ScatterToStripedGuarded(
    InputT      input_items[ITEMS_PER_THREAD],      ///< [in] Items to scatter
    OutputT     output_items[ITEMS_PER_THREAD],     ///< [out] Items read back in <em>striped</em> arrangement
    OffsetT     ranks[ITEMS_PER_THREAD])            ///< [in] Scatter rank of each corresponding input item; negative ranks are skipped
{
    // Phase 1: every thread writes its items into the exchange buffer at the
    // slot named by the item's rank. Items with a negative rank are not written.
    #pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; ++i)
    {
        int write_slot = ranks[i];
        if (INSERT_PADDING)
        {
            // Skew the slot by (slot >> LOG_SMEM_BANKS)
            write_slot = SHR_ADD(write_slot, LOG_SMEM_BANKS, write_slot);
        }

        if (ranks[i] >= 0)
        {
            temp_storage.buff[write_slot] = input_items[i];
        }
    }

    // All writes must land before any thread starts reading
    CTA_SYNC();

    // Phase 2: every thread reads slots linear_tid, linear_tid + BLOCK_THREADS, ...
    // The read is unconditional, so a slot that no thread wrote yields
    // whatever the buffer already held.
    #pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; ++i)
    {
        int read_slot = int(i * BLOCK_THREADS) + linear_tid;
        if (INSERT_PADDING)
        {
            read_slot = SHR_ADD(read_slot, LOG_SMEM_BANKS, read_slot);
        }

        output_items[i] = temp_storage.buff[read_slot];
    }
}
/**
* \brief Exchanges valid data items annotated by rank into <em>striped</em> arrangement.
*
* \par
* - \smemreuse
*
* \tparam OffsetT <b>[inferred]</b> Signed integer type for local offsets
* \tparam ValidFlag <b>[inferred]</b> FlagT type denoting which items are valid
*/
template <typename OutputT, typename OffsetT, typename ValidFlag>
__device__ __forceinline__ void ScatterToStripedFlagged(
    InputT      input_items[ITEMS_PER_THREAD],      ///< [in] Items to scatter
    OutputT     output_items[ITEMS_PER_THREAD],     ///< [out] Items read back in <em>striped</em> arrangement
    OffsetT     ranks[ITEMS_PER_THREAD],            ///< [in] Scatter rank of each corresponding input item
    ValidFlag   is_valid[ITEMS_PER_THREAD])         ///< [in] Per-item flag; only items whose flag tests true are scattered
{
    // Phase 1: every thread writes its flagged items into the exchange buffer
    // at the slot named by the item's rank.
    #pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; ++i)
    {
        int write_slot = ranks[i];
        if (INSERT_PADDING)
        {
            // Skew the slot by (slot >> LOG_SMEM_BANKS)
            write_slot = SHR_ADD(write_slot, LOG_SMEM_BANKS, write_slot);
        }

        if (is_valid[i])
        {
            temp_storage.buff[write_slot] = input_items[i];
        }
    }

    // All writes must land before any thread starts reading
    CTA_SYNC();

    // Phase 2: every thread reads slots linear_tid, linear_tid + BLOCK_THREADS, ...
    // The read is unconditional, so a slot that no thread wrote yields
    // whatever the buffer already held.
    #pragma unroll
    for (int i = 0; i < ITEMS_PER_THREAD; ++i)
    {
        int read_slot = int(i * BLOCK_THREADS) + linear_tid;
        if (INSERT_PADDING)
        {
            read_slot = SHR_ADD(read_slot, LOG_SMEM_BANKS, read_slot);
        }

        output_items[i] = temp_storage.buff[read_slot];
    }
}
//@} end member group
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
/// In-place convenience form: forwards to the two-array overload with \p items as both source and destination.
__device__ __forceinline__ void StripedToBlocked(
    InputT      items[ITEMS_PER_THREAD])    ///< [in-out] Items to exchange: <em>striped</em> on entry, <em>blocked</em> on return
{
    this->StripedToBlocked(items, items);
}
/// In-place convenience form: forwards to the two-array overload with \p items as both source and destination.
__device__ __forceinline__ void BlockedToStriped(
    InputT      items[ITEMS_PER_THREAD])    ///< [in-out] Items to exchange: <em>blocked</em> on entry, <em>striped</em> on return
{
    this->BlockedToStriped(items, items);
}
/// In-place convenience form: forwards to the two-array overload with \p items as both source and destination.
__device__ __forceinline__ void WarpStripedToBlocked(
    InputT      items[ITEMS_PER_THREAD])    ///< [in-out] Items to exchange: <em>warp-striped</em> on entry, <em>blocked</em> on return
{
    this->WarpStripedToBlocked(items, items);
}
/// In-place convenience form: forwards to the two-array overload with \p items as both source and destination.
__device__ __forceinline__ void BlockedToWarpStriped(
    InputT      items[ITEMS_PER_THREAD])    ///< [in-out] Items to exchange: <em>blocked</em> on entry, <em>warp-striped</em> on return
{
    this->BlockedToWarpStriped(items, items);
}
/// In-place convenience form: forwards to the three-array overload with \p items as both source and destination.
template <typename OffsetT>
__device__ __forceinline__ void ScatterToBlocked(
    InputT      items[ITEMS_PER_THREAD],    ///< [in-out] Items to scatter; overwritten with the <em>blocked</em> result
    OffsetT     ranks[ITEMS_PER_THREAD])    ///< [in] Scatter rank of each corresponding item
{
    this->ScatterToBlocked(items, items, ranks);
}
/// In-place convenience form: forwards to the three-array overload with \p items as both source and destination.
template <typename OffsetT>
__device__ __forceinline__ void ScatterToStriped(
    InputT      items[ITEMS_PER_THREAD],    ///< [in-out] Items to scatter; overwritten with the <em>striped</em> result
    OffsetT     ranks[ITEMS_PER_THREAD])    ///< [in] Scatter rank of each corresponding item
{
    this->ScatterToStriped(items, items, ranks);
}
/// In-place convenience form: forwards to the three-array overload with \p items as both source and destination.
template <typename OffsetT>
__device__ __forceinline__ void ScatterToStripedGuarded(
    InputT      items[ITEMS_PER_THREAD],    ///< [in-out] Items to scatter; overwritten with the <em>striped</em> result
    OffsetT     ranks[ITEMS_PER_THREAD])    ///< [in] Scatter rank of each corresponding item
{
    this->ScatterToStripedGuarded(items, items, ranks);
}
/// In-place convenience form of the flagged scatter: forwards to the
/// four-array ScatterToStripedFlagged overload with \p items as both source
/// and destination.
template <typename OffsetT, typename ValidFlag>
__device__ __forceinline__ void ScatterToStripedFlagged(
    InputT      items[ITEMS_PER_THREAD],        ///< [in-out] Items to scatter; overwritten with the <em>striped</em> result
    OffsetT     ranks[ITEMS_PER_THREAD],        ///< [in] Scatter rank of each corresponding item
    ValidFlag   is_valid[ITEMS_PER_THREAD])     ///< [in] Per-item flag denoting item validity
{
    // Must dispatch to the *Flagged* overload: ScatterToStriped has no variant
    // that accepts a validity-flag array as its fourth argument.
    ScatterToStripedFlagged(items, items, ranks, is_valid);
}
#endif // DOXYGEN_SHOULD_SKIP_THIS
};
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
/**
 * \brief Warp-level scatter exchange of items through a caller-provided TempStorage buffer.
 *
 * \tparam T                        Item type being exchanged
 * \tparam ITEMS_PER_THREAD         Number of items owned by each thread
 * \tparam LOGICAL_WARP_THREADS     Number of threads per logical warp (defaults to CUB_PTX_WARP_THREADS)
 * \tparam PTX_ARCH                 Target PTX architecture (defaults to CUB_PTX_ARCH)
 */
template <
    typename    T,
    int         ITEMS_PER_THREAD,
    int         LOGICAL_WARP_THREADS    = CUB_PTX_WARP_THREADS,
    int         PTX_ARCH                = CUB_PTX_ARCH>
class WarpExchange
{
private:

    /******************************************************************************
     * Constants
     ******************************************************************************/

    /// Constants
    enum
    {
        // Whether the logical warp size and the PTX warp size coincide
        IS_ARCH_WARP                = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),

        // Buffer capacity in items: one slot per item in the logical warp, plus one extra slot
        WARP_ITEMS                  = (ITEMS_PER_THREAD * LOGICAL_WARP_THREADS) + 1,

        LOG_SMEM_BANKS              = CUB_LOG_SMEM_BANKS(PTX_ARCH),
        SMEM_BANKS                  = 1 << LOG_SMEM_BANKS,

        // Insert padding if the number of items per thread is a power of two and > 4 (otherwise we can typically use 128b loads)
        INSERT_PADDING              = (ITEMS_PER_THREAD > 4) && (PowerOfTwo<ITEMS_PER_THREAD>::VALUE),

        // Extra slots needed so that padded offsets stay inside the buffer
        PADDING_ITEMS               = (INSERT_PADDING) ? (WARP_ITEMS >> LOG_SMEM_BANKS) : 0,
    };

    /******************************************************************************
     * Type definitions
     ******************************************************************************/

    /// Shared memory storage layout type
    struct _TempStorage
    {
        T buff[WARP_ITEMS + PADDING_ITEMS];
    };

public:

    /// \smemstorage{WarpExchange}
    struct TempStorage : Uninitialized<_TempStorage> {};

private:

    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    _TempStorage    &temp_storage;  ///< Aliased view of the caller-provided storage
    int             lane_id;        ///< This thread's lane within its logical warp

public:

    /******************************************************************************
     * Construction
     ******************************************************************************/

    /// Constructor
    __device__ __forceinline__ WarpExchange(
        TempStorage &temp_storage)  ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        // For sub-warp logical warps, fold the hardware lane id into [0, LOGICAL_WARP_THREADS)
        lane_id(IS_ARCH_WARP ?
            LaneId() :
            LaneId() % LOGICAL_WARP_THREADS)
    {}

    /******************************************************************************
     * Interface
     ******************************************************************************/

    /**
     * \brief Exchanges valid data items annotated by rank into <em>striped</em> arrangement.
     *
     * \par
     * - \smemreuse
     * - \p ranks is read-only: padded buffer offsets are computed in a local,
     *   so the same ranks may be reused for a subsequent exchange.
     *
     * \tparam OffsetT <b>[inferred]</b> Signed integer type for local offsets
     */
    template <typename OffsetT>
    __device__ __forceinline__ void ScatterToStriped(
        T           items[ITEMS_PER_THREAD],    ///< [in-out] Items to exchange
        OffsetT     ranks[ITEMS_PER_THREAD])    ///< [in] Corresponding scatter ranks
    {
        // Scatter: write each item to the buffer slot named by its rank
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
        {
            // Work on a local copy so the caller's ranks are left untouched
            int item_offset = ranks[ITEM];
            if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
            temp_storage.buff[item_offset] = items[ITEM];
        }

        // Writes must complete before any lane reads back
        WARP_SYNC(0xffffffff);

        // Gather: lane reads slots lane_id, lane_id + LOGICAL_WARP_THREADS, ...
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
        {
            int item_offset = (ITEM * LOGICAL_WARP_THREADS) + lane_id;
            if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset);
            items[ITEM] = temp_storage.buff[item_offset];
        }
    }
};
#endif // DOXYGEN_SHOULD_SKIP_THIS
} // CUB namespace
CUB_NS_POSTFIX // Optional outer namespace(s)