1 //===---- CGOpenMPRuntimeNVPTX.cpp - Interface to OpenMP NVPTX Runtimes ---===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // This provides a class for OpenMP runtime code generation specialized to NVPTX
13 //===----------------------------------------------------------------------===//
15 #include "CGOpenMPRuntimeNVPTX.h"
16 #include "clang/AST/DeclOpenMP.h"
18 using namespace clang;
19 using namespace CodeGen;
21 /// \brief Get the GPU warp size.
22 llvm::Value *CGOpenMPRuntimeNVPTX::getNVPTXWarpSize(CodeGenFunction &CGF) {
23 CGBuilderTy &Bld = CGF.Builder;
24 return Bld.CreateCall(
25 llvm::Intrinsic::getDeclaration(
26 &CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_warpsize),
27 llvm::None, "nvptx_warp_size");
30 /// \brief Get the id of the current thread on the GPU.
31 llvm::Value *CGOpenMPRuntimeNVPTX::getNVPTXThreadID(CodeGenFunction &CGF) {
32 CGBuilderTy &Bld = CGF.Builder;
33 return Bld.CreateCall(
34 llvm::Intrinsic::getDeclaration(
35 &CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x),
36 llvm::None, "nvptx_tid");
39 // \brief Get the maximum number of threads in a block of the GPU.
40 llvm::Value *CGOpenMPRuntimeNVPTX::getNVPTXNumThreads(CodeGenFunction &CGF) {
41 CGBuilderTy &Bld = CGF.Builder;
42 return Bld.CreateCall(
43 llvm::Intrinsic::getDeclaration(
44 &CGM.getModule(), llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x),
45 llvm::None, "nvptx_num_threads");
48 /// \brief Get barrier to synchronize all threads in a block.
49 void CGOpenMPRuntimeNVPTX::getNVPTXCTABarrier(CodeGenFunction &CGF) {
50 CGBuilderTy &Bld = CGF.Builder;
51 Bld.CreateCall(llvm::Intrinsic::getDeclaration(
52 &CGM.getModule(), llvm::Intrinsic::nvvm_barrier0));
55 // \brief Synchronize all GPU threads in a block.
56 void CGOpenMPRuntimeNVPTX::syncCTAThreads(CodeGenFunction &CGF) {
57 getNVPTXCTABarrier(CGF);
60 /// \brief Get the thread id of the OMP master thread.
61 /// The master thread id is the first thread (lane) of the last warp in the
62 /// GPU block. Warp size is assumed to be some power of 2.
63 /// Thread id is 0 indexed.
64 /// E.g: If NumThreads is 33, master id is 32.
65 /// If NumThreads is 64, master id is 32.
66 /// If NumThreads is 1024, master id is 992.
67 llvm::Value *CGOpenMPRuntimeNVPTX::getMasterThreadID(CodeGenFunction &CGF) {
68 CGBuilderTy &Bld = CGF.Builder;
69 llvm::Value *NumThreads = getNVPTXNumThreads(CGF);
71 // We assume that the warp size is a power of 2.
72 llvm::Value *Mask = Bld.CreateSub(getNVPTXWarpSize(CGF), Bld.getInt32(1));
74 return Bld.CreateAnd(Bld.CreateSub(NumThreads, Bld.getInt32(1)),
75 Bld.CreateNot(Mask), "master_tid");
// Identifiers for the NVPTX-specific runtime library functions. The unsigned
// argument of createNVPTXRuntimeFunction() is cast back to this enum to select
// which runtime function to declare.
79 enum OpenMPRTLFunctionNVPTX {
80 /// \brief Call to void __kmpc_kernel_init(kmp_int32 omp_handle,
81 /// kmp_int32 thread_limit);
82 OMPRTL_NVPTX__kmpc_kernel_init,
85 // NVPTX address spaces.
// ADDRESS_SPACE_SHARED is the address-space number given to the ActiveWorkers
// and WorkID globals created in initializeEnvironment().
87 ADDRESS_SPACE_SHARED = 3,
// Construct the worker state for one target region. Both members start out
// null and are assigned by createWorkerFunction(), which is called right away.
// NOTE(review): the parameter list is not visible in this listing; the body
// uses a name CGM, presumably the constructor's CodeGenModule argument.
91 CGOpenMPRuntimeNVPTX::WorkerFunctionState::WorkerFunctionState(
93 : WorkerFn(nullptr), CGFI(nullptr) {
94 createWorkerFunction(CGM);
// Create the LLVM function that will hold the worker loop and record it,
// together with its function info, in WorkerFn and CGFI. Only the function
// object is created here; its body is emitted later by emitWorkerFunction().
97 void CGOpenMPRuntimeNVPTX::WorkerFunctionState::createWorkerFunction(
99 // Create a worker function with no arguments.
100 CGFI = &CGM.getTypes().arrangeNullaryFunction();
// "_worker" is only a placeholder name: emitTargetOutlinedFunction() renames
// the function to "<entry function name>_worker" once that name is known.
102 WorkerFn = llvm::Function::Create(
103 CGM.getTypes().GetFunctionType(*CGFI), llvm::GlobalValue::InternalLinkage,
104 /* placeholder */ "_worker", &CGM.getModule());
105 CGM.SetInternalFunctionAttributes(/*D=*/nullptr, WorkerFn, *CGFI);
// The function was already created with InternalLinkage above.
// NOTE(review): this second setLinkage looks redundant unless
// SetInternalFunctionAttributes() can change the linkage - confirm.
106 WorkerFn->setLinkage(llvm::GlobalValue::InternalLinkage);
// Mark the worker function as not inlinable.
107 WorkerFn->addFnAttr(llvm::Attribute::NoInline);
110 void CGOpenMPRuntimeNVPTX::initializeEnvironment() {
112 // Initialize master-worker control state in shared memory.
115 auto DL = CGM.getDataLayout();
116 ActiveWorkers = new llvm::GlobalVariable(
117 CGM.getModule(), CGM.Int32Ty, /*isConstant=*/false,
118 llvm::GlobalValue::CommonLinkage,
119 llvm::Constant::getNullValue(CGM.Int32Ty), "__omp_num_threads", 0,
120 llvm::GlobalVariable::NotThreadLocal, ADDRESS_SPACE_SHARED);
121 ActiveWorkers->setAlignment(DL.getPrefTypeAlignment(CGM.Int32Ty));
123 WorkID = new llvm::GlobalVariable(
124 CGM.getModule(), CGM.Int64Ty, /*isConstant=*/false,
125 llvm::GlobalValue::CommonLinkage,
126 llvm::Constant::getNullValue(CGM.Int64Ty), "__tgt_work_id", 0,
127 llvm::GlobalVariable::NotThreadLocal, ADDRESS_SPACE_SHARED);
128 WorkID->setAlignment(DL.getPrefTypeAlignment(CGM.Int64Ty));
131 void CGOpenMPRuntimeNVPTX::emitWorkerFunction(WorkerFunctionState &WST) {
132 auto &Ctx = CGM.getContext();
134 CodeGenFunction CGF(CGM, /*suppressNewContext=*/true);
135 CGF.StartFunction(GlobalDecl(), Ctx.VoidTy, WST.WorkerFn, *WST.CGFI, {});
136 emitWorkerLoop(CGF, WST);
137 CGF.FinishFunction();
// Emit the loop executed by worker threads into CGF.
//
// Control flow between the blocks created below:
//   .await.work         -> .exit when the loaded WorkID is zero,
//                          otherwise .select.workers
//   .select.workers     -> .execute.parallel when ActiveThread is true,
//                          otherwise .barrier.parallel
//   .execute.parallel   is followed directly by .terminate.parallel (no
//                          explicit branch is emitted between them here)
//   .terminate.parallel -> .barrier.parallel
//   .barrier.parallel   -> .await.work (loop back edge)
//
// WST is not referenced by the statements visible in this listing.
140 void CGOpenMPRuntimeNVPTX::emitWorkerLoop(CodeGenFunction &CGF,
141 WorkerFunctionState &WST) {
143 // The workers enter this loop and wait for parallel work from the master.
144 // When the master encounters a parallel region it sets up the work + variable
145 // arguments, and wakes up the workers. The workers first check to see if
146 // they are required for the parallel region, i.e., within the # of requested
147 // parallel threads. The activated workers load the variable arguments and
148 // execute the parallel work.
151 CGBuilderTy &Bld = CGF.Builder;
// Create all blocks up front so that branches can target blocks that are
// emitted later.
153 llvm::BasicBlock *AwaitBB = CGF.createBasicBlock(".await.work");
154 llvm::BasicBlock *SelectWorkersBB = CGF.createBasicBlock(".select.workers");
155 llvm::BasicBlock *ExecuteBB = CGF.createBasicBlock(".execute.parallel");
156 llvm::BasicBlock *TerminateBB = CGF.createBasicBlock(".terminate.parallel");
157 llvm::BasicBlock *BarrierBB = CGF.createBasicBlock(".barrier.parallel");
158 llvm::BasicBlock *ExitBB = CGF.createBasicBlock(".exit");
160 CGF.EmitBranch(AwaitBB);
162 // Workers wait for work from master.
163 CGF.EmitBlock(AwaitBB);
164 // Wait for parallel work
166 // On termination condition (workid == 0), exit loop.
// The comparison is against the null (zero) value of WorkID's element type;
// emitEntryFooter() stores that same value to signal termination.
167 llvm::Value *ShouldTerminate = Bld.CreateICmpEQ(
168 Bld.CreateAlignedLoad(WorkID, WorkID->getAlignment()),
169 llvm::Constant::getNullValue(WorkID->getType()->getElementType()),
171 Bld.CreateCondBr(ShouldTerminate, ExitBB, SelectWorkersBB);
173 // Activate requested workers.
174 CGF.EmitBlock(SelectWorkersBB);
175 llvm::Value *ThreadID = getNVPTXThreadID(CGF);
// Signed less-than comparison involving the loaded ActiveWorkers value.
// NOTE(review): the first operand is not visible in this listing; presumably
// it is ThreadID, i.e. a worker is active when its id is below the number of
// requested workers - confirm.
176 llvm::Value *ActiveThread = Bld.CreateICmpSLT(
178 Bld.CreateAlignedLoad(ActiveWorkers, ActiveWorkers->getAlignment()),
180 Bld.CreateCondBr(ActiveThread, ExecuteBB, BarrierBB);
182 // Signal start of parallel region.
183 CGF.EmitBlock(ExecuteBB);
184 // TODO: Add parallel work.
186 // Signal end of parallel region.
187 CGF.EmitBlock(TerminateBB);
188 CGF.EmitBranch(BarrierBB);
190 // All active and inactive workers wait at a barrier after parallel region.
191 CGF.EmitBlock(BarrierBB);
192 // Barrier after parallel region.
// Loop back and wait for the next piece of work.
194 CGF.EmitBranch(AwaitBB);
196 // Exit target region.
197 CGF.EmitBlock(ExitBB);
200 // Setup NVPTX threads for master-worker OpenMP scheme.
//
// Emits the prologue of the target entry function into CGF. Threads are split
// three ways by comparing their id with the master thread id:
//   id >  master id : branch straight to EST.ExitBB;
//   id <  master id : call the worker function, then branch to EST.ExitBB;
//   id == master id : continue in the ".master" block, where the runtime
//                     initialization call is emitted.
// EST.ExitBB is created here. emitEntryFooter() later branches to it and emits
// the block.
201 void CGOpenMPRuntimeNVPTX::emitEntryHeader(CodeGenFunction &CGF,
202 EntryFunctionState &EST,
203 WorkerFunctionState &WST) {
204 CGBuilderTy &Bld = CGF.Builder;
206 // Get the master thread id.
207 llvm::Value *MasterID = getMasterThreadID(CGF);
208 // Current thread's identifier.
209 llvm::Value *ThreadID = getNVPTXThreadID(CGF);
211 // Setup BBs in entry function.
212 llvm::BasicBlock *WorkerCheckBB = CGF.createBasicBlock(".check.for.worker");
213 llvm::BasicBlock *WorkerBB = CGF.createBasicBlock(".worker");
214 llvm::BasicBlock *MasterBB = CGF.createBasicBlock(".master");
215 EST.ExitBB = CGF.createBasicBlock(".exit");
217 // The head (master thread) marches on while its body of companion threads in
218 // the warp go to sleep.
// Unsigned comparison: threads whose id is above the master's exit at once.
219 llvm::Value *ShouldDie =
220 Bld.CreateICmpUGT(ThreadID, MasterID, "excess_in_master_warp");
221 Bld.CreateCondBr(ShouldDie, EST.ExitBB, WorkerCheckBB);
223 // Select worker threads...
224 CGF.EmitBlock(WorkerCheckBB);
225 llvm::Value *IsWorker = Bld.CreateICmpULT(ThreadID, MasterID, "is_worker");
226 Bld.CreateCondBr(IsWorker, WorkerBB, MasterBB);
228 // ... and send to worker loop, awaiting parallel invocation.
229 CGF.EmitBlock(WorkerBB);
230 CGF.EmitCallOrInvoke(WST.WorkerFn, llvm::None);
// Once the worker function returns, the worker leaves the kernel.
231 CGF.EmitBranch(EST.ExitBB);
233 // Only master thread executes subsequent serial code.
234 CGF.EmitBlock(MasterBB);
236 // First action in sequential region:
237 // Initialize the state of the OpenMP runtime library on the GPU.
// Arguments follow the __kmpc_kernel_init(omp_handle, thread_limit) prototype.
// The thread id is read again here for the second argument. Only the thread
// whose id equals MasterID reaches this block, so that value is the master
// thread id.
238 llvm::Value *Args[] = {Bld.getInt32(/*OmpHandle=*/0), getNVPTXThreadID(CGF)};
239 CGF.EmitRuntimeCall(createNVPTXRuntimeFunction(OMPRTL_NVPTX__kmpc_kernel_init),
// Emit the epilogue of the target entry function into CGF: store zero to
// WorkID (the value the worker loop treats as its termination condition),
// branch to the exit block created by emitEntryHeader(), and emit that block.
243 void CGOpenMPRuntimeNVPTX::emitEntryFooter(CodeGenFunction &CGF,
244 EntryFunctionState &EST) {
245 CGBuilderTy &Bld = CGF.Builder;
246 llvm::BasicBlock *TerminateBB = CGF.createBasicBlock(".termination.notifier");
247 CGF.EmitBranch(TerminateBB);
249 CGF.EmitBlock(TerminateBB);
250 // Signal termination condition.
// The stored value is the null (zero) value of WorkID's element type, written
// with WorkID's own alignment.
251 Bld.CreateAlignedStore(
252 llvm::Constant::getNullValue(WorkID->getType()->getElementType()), WorkID,
253 WorkID->getAlignment());
254 // Barrier to terminate worker threads.
// NOTE(review): no statement follows the comment above in this listing;
// confirm that a barrier is actually emitted here.
256 // Master thread jumps to exit point.
257 CGF.EmitBranch(EST.ExitBB);
// The exit block is shared with the early-exit and worker paths set up in
// emitEntryHeader().
259 CGF.EmitBlock(EST.ExitBB);
262 /// \brief Returns specified OpenMP runtime function for the current OpenMP
263 /// implementation. Specialized for the NVPTX device.
264 /// \param Function OpenMP runtime function, given as an
/// OpenMPRTLFunctionNVPTX enumerator converted to unsigned.
265 /// \return Specified function.
267 CGOpenMPRuntimeNVPTX::createNVPTXRuntimeFunction(unsigned Function) {
// Stays null unless a case below assigns it.
268 llvm::Constant *RTLFn = nullptr;
269 switch (static_cast<OpenMPRTLFunctionNVPTX>(Function)) {
270 case OMPRTL_NVPTX__kmpc_kernel_init: {
271 // Build void __kmpc_kernel_init(kmp_int32 omp_handle,
272 // kmp_int32 thread_limit);
// Both kmp_int32 parameters are represented as Int32Ty.
273 llvm::Type *TypeParams[] = {CGM.Int32Ty, CGM.Int32Ty};
274 llvm::FunctionType *FnTy =
275 llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
276 RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_kernel_init");
// Register an offload entry by appending a node to the module's
// "nvvm.annotations" named metadata. The node has three operands: the
// function, the string "kernel" and the i32 constant 1.
// ID is not referenced by the statements visible in this listing.
283 void CGOpenMPRuntimeNVPTX::createOffloadEntry(llvm::Constant *ID,
284 llvm::Constant *Addr,
// dyn_cast yields null when Addr is not an llvm::Function.
// NOTE(review): F is dereferenced below; the handling of a null F is not
// visible in this listing - confirm it is checked before use.
286 auto *F = dyn_cast<llvm::Function>(Addr);
287 // TODO: Add support for global variables on the device after declare target
291 llvm::Module *M = F->getParent();
292 llvm::LLVMContext &Ctx = M->getContext();
294 // Get "nvvm.annotations" metadata node
295 llvm::NamedMDNode *MD = M->getOrInsertNamedMetadata("nvvm.annotations");
// Operands of the annotation node: {function, "kernel", i32 1}.
297 llvm::Metadata *MDVals[] = {
298 llvm::ConstantAsMetadata::get(F), llvm::MDString::get(Ctx, "kernel"),
299 llvm::ConstantAsMetadata::get(
300 llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
301 // Append metadata to nvvm.annotations
302 MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
// Emit the outlined function for target directive D, together with its worker
// function.
//
// The target region body is wrapped between emitEntryHeader() and
// emitEntryFooter(), so the entry function contains the master-worker setup,
// the region's statement and the termination signalling. After the entry
// function has been emitted, the worker function body is generated and the
// worker is renamed after the entry function.
//
// OutlinedFn and OutlinedFnID are passed by reference to
// emitTargetOutlinedFunctionHelper(); OutlinedFn is read afterwards to build
// the worker's name.
305 void CGOpenMPRuntimeNVPTX::emitTargetOutlinedFunction(
306 const OMPExecutableDirective &D, StringRef ParentName,
307 llvm::Function *&OutlinedFn, llvm::Constant *&OutlinedFnID,
308 bool IsOffloadEntry) {
309 if (!IsOffloadEntry) // Nothing to do.
312 assert(!ParentName.empty() && "Invalid target region parent name!");
314 const CapturedStmt &CS = *cast<CapturedStmt>(D.getAssociatedStmt());
// Per-region state shared by the header and footer emission (EST) and by the
// worker function (WST). Constructing WST already creates the worker function
// object under a placeholder name.
316 EntryFunctionState EST;
317 WorkerFunctionState WST(CGM);
319 // Emit target region as a standalone region.
// The lambda captures EST, WST, CS and D by reference; it is handed to
// emitTargetOutlinedFunctionHelper() below.
320 auto &&CodeGen = [&EST, &WST, &CS, &D, this](CodeGenFunction &CGF) {
321 CodeGenFunction::OMPPrivateScope PrivateScope(CGF);
322 (void)CGF.EmitOMPFirstprivateClause(D, PrivateScope);
323 CGF.EmitOMPPrivateClause(D, PrivateScope);
324 (void)PrivateScope.Privatize();
// Header, region body and footer are emitted in this order.
326 emitEntryHeader(CGF, EST, WST);
327 CGF.EmitStmt(CS.getCapturedStmt());
328 emitEntryFooter(CGF, EST);
330 emitTargetOutlinedFunctionHelper(D, ParentName, OutlinedFn, OutlinedFnID,
331 IsOffloadEntry, CodeGen);
333 // Create the worker function
334 emitWorkerFunction(WST);
336 // Now change the name of the worker function to correspond to this target
337 // region's entry function.
338 WST.WorkerFn->setName(OutlinedFn->getName() + "_worker");
341 CGOpenMPRuntimeNVPTX::CGOpenMPRuntimeNVPTX(CodeGenModule &CGM)
342 : CGOpenMPRuntime(CGM), ActiveWorkers(nullptr), WorkID(nullptr) {
343 if (!CGM.getLangOpts().OpenMPIsDevice)
344 llvm_unreachable("OpenMP NVPTX can only handle device code.");
346 // Called once per module during initialization.
347 initializeEnvironment();