LLVM Optimization (4)
실습 3
[실습 3]에서는 프로그램 내 특정 명령어의 수를 세어 출력하는 간단한 정적 프로파일러를 생성한다. LLVM 6.0.1 기준으로 LLVM IR에는 총 64개의 명령어 Opcode가 있으며 자주 사용되는 명령어는 아래와 같다.
- BinaryOperator: add, sub, mul과 같은 산술연산 / and, or과 같은 비교연산 / 등 2개의 operand를 연산하는 명령어
- ReturnInst, BranchInst: 제어흐름 명령어
- CallInst: 함수 호출 명령어
- CastInst: 타입 변환 명령어
- AllocInst: 스택(정적메모리)에 메모리 할당하는 명령어
- LoadInst, StoreInst: 메모리의 데이터들을 접근하는 명령어
- GetElementPtrInst: 배열 접근 등에서 메모리에 접근하는 명령어
[실습 3]에서는 [실습 2]에서 진행한 순회에서 한 단계 업그레이드 시켜 instruction의 종류에 따른 동작을 구분한다. 간단한 예시로 add instruction을 필터링하는 작업을 수행해보자.
- BinaryOperator 필터링하기:
llvm::instruction
의 멤버함수인inBinaryOp()
를 사용하여 해당 instruction이 BinaryOperator인지 확인할 수 있다. - Opcode를 직접 얻어오기:
llvm::Instruction
의getOpcode()
를 이용해서 현재 명령어가 어떤 명령어인지 알 수 있다. - Add 연산 필터링하기: 아래와 같이 BinaryOperator 중 add연산에 대해서만 조건을 추가할 수 있다.
if (Inst->getOpcode() == llvm::Instruction::Add) {total_add_inst++;}
예제
CountInst.cpp에서는 BasicBlock 내의 Instruction을 순회하면서 add, sub, mul, div의 instruction의 수와 Test.c 파일에 존재하는 함수의 호출 수를 출력한다.
#include <iostream>
#include <fstream>
#include <vector>
#include <string>
#include <stdlib.h>
#include <string.h>
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/DiagnosticHandler.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/Support/raw_ostream.h"
std::string input_path;
llvm::LLVMContext* TheContext;
std::unique_ptr<llvm::Module> TheModule;
void ParseIRSource(void);
void TraverseModule(void);
int main(int argc , char** argv)
{
if(argc < 2)
{
std::cout << "Usage: ./CountInst <ir_file_path>" << std::endl;
return -1;
}
input_path = std::string(argv[1]);
// Read & Parse IR Source
ParseIRSource();
// Traverse TheModule
TraverseModule();
return 0;
}
// Read & Parse IR Sources
// Human-readable assembly(*.ll) or Bitcode(*.bc) format is required
void ParseIRSource(void)
{
llvm::SMDiagnostic err;
// Context
TheContext = new llvm::LLVMContext();
if( ! TheContext )
{
std::cerr << "Failed to allocated llvm::LLVMContext" << std::endl;
exit( -1 );
}
// Module from IR Source
TheModule = llvm::parseIRFile(input_path, err, *TheContext);
if( ! TheModule )
{
std::cerr << "Failed to parse IR File : " << input_path << std::endl;
exit( -1 );
}
}
// Traverse Instructions in TheModule
void TraverseModule(void)
{
int total_add_inst = 0;
int total_sub_inst = 0;
int total_mul_inst = 0;
int total_div_inst = 0;
int total_init_func = 0;
int total_vAdd_func = 0;
int total_fuse_func = 0;
int total_main_func = 0;
int total_prnt_func = 0;
for( llvm::Module::iterator ModIter = TheModule->begin(); ModIter != TheModule->end(); ++ModIter )
{
llvm::Function* Func = llvm::cast<llvm::Function>(ModIter);
for( llvm::Function::iterator FuncIter = Func->begin(); FuncIter != Func->end(); ++FuncIter )
{
llvm::BasicBlock* BB = llvm::cast<llvm::BasicBlock>(FuncIter);
for( llvm::BasicBlock::iterator BBIter = BB->begin(); BBIter != BB->end(); ++BBIter )
{
llvm::Instruction* Inst = llvm::cast<llvm::Instruction>(BBIter);
if( Inst->getOpcode() == llvm::Instruction::Add ) { total_add_inst++; }
if (Inst->getOpcode() == llvm::Instruction::Sub) { total_sub_inst++; }
if (Inst->getOpcode() == llvm::Instruction::FMul) { total_mul_inst++; }
if (Inst->getOpcode() == llvm::Instruction::FDiv) { total_div_inst++; }
llvm::CallInst* CInst = llvm::cast<llvm::CallInst>(Inst);
llvm::Function* fn = CInst->getCalledFunction();
if (fn) {
std::string name = fn->getName().str();
std::cout << name << std::endl;
if (name == "init") { total_init_func++; }
if (name == "VectorAdd") { total_vAdd_func++; }
if (name == "FuseAddMul") { total_fuse_func++; }
if (name == "main") { total_main_func++; }
if (name == "printf") { total_prnt_func++; }
}
}
}
}
std::cout << "The Number of ADD Instructions in the Module " << TheModule->getName().str() << " is " << total_add_inst << std::endl;
std::cout << "The Number of SUB Instructions in the Module " << TheModule->getName().str() << " is " << total_sub_inst << std::endl;
std::cout << "The Number of MUL Instructions in the Module " << TheModule->getName().str() << " is " << total_mul_inst << std::endl;
std::cout << "The Number of DIV Instructions in the Module " << TheModule->getName().str() << " is " << total_div_inst << std::endl;
std::cout << "The Number of init function in the Module " << TheModule->getName().str() << " is " << total_init_func << std::endl;
std::cout << "The Number of VectorAdd function in the Module " << TheModule->getName().str() << " is " << total_vAdd_func << std::endl;
std::cout << "The Number of FuseAddMul function in the Module " << TheModule->getName().str() << " is " << total_fuse_func << std::endl;
std::cout << "The Number of main function in the Module " << TheModule->getName().str() << " is " << total_main_func << std::endl;
std::cout << "The Number of printf function in the Module " << TheModule->getName().str() << " is " << total_prnt_func << std::endl;
}
Test.c
#include <stdio.h>
void init(float* arr,float* brr,float* crr, size_t n)
{
for(size_t i=0; i<n;i++)
{
arr[i]=n-i;
brr[i]=n+n+i;
crr[i]=i+n;
}
}
void VectorAdd(float*arr,float*brr,float*crr,size_t n)
{
for(size_t i=0; i<n ; i++)
{
crr[i] = (arr[i]+brr[i])/arr[i];
}
}
void FuseAddMul(float*arr,float*brr,float*crr,size_t n)
{
for(size_t i=0; i<n; i++)
{
crr[i]=arr[i]*brr[i]+crr[i];
}
}
int main()
{
float arr[10];
float brr[10];
float crr[10];
init(arr,brr,crr,10);
VectorAdd(arr,brr,crr,10);
FuseAddMul(arr,brr,crr,10);
for(size_t i=0; i<10; i++)
{
printf("%f\n",crr[i]);
}
}
여기서 한가지 더 고려하자면 -O3
옵션을 적용하면 Loop Unrolling과 같은 최적화가 적용되므로 적용 전과 후의 add 명령어 수가 다르게 출력된다.
clang -O3 -emit-llvm -S Test.c -o Test.ll
./CountInst ./Test.ll
실행 결과
아래의 코드를 통해서 BasicBlock들을 순회하면서 목표로 하는 명령어 수를 세어준다.
llvm::BasicBlock* BB = llvm::cast<llvm::BasicBlock>(FuncIter);
for( llvm::BasicBlock::iterator BBIter = BB->begin(); BBIter != BB->end(); ++BBIter {
llvm::Instruction* Inst = llvm::cast<llvm::Instruction>(BBIter);
if( Inst->getOpcode() == llvm::Instruction::Add ) { total_add_inst++; }
if (Inst->getOpcode() == llvm::Instruction::Sub) { total_sub_inst++; }
if (Inst->getOpcode() == llvm::Instruction::FMul) { total_mul_inst++; }
if (Inst->getOpcode() == llvm::Instruction::FDiv) { total_div_inst++; }
llvm::CallInst* CInst = llvm::cast<llvm::CallInst>(Inst);
llvm::Function* fn = CInst->getCalledFunction();
if (fn) {
std::string name = fn->getName().str();
std::cout << name << std::endl;
if (name == "init") { total_init_func++; }
if (name == "VectorAdd") { total_vAdd_func++; }
if (name == "FuseAddMul") { total_fuse_func++; }
if (name == "main") { total_main_func++; }
if (name == "printf") { total_prnt_func++; }
}
}
해당 코드를 실행하면 위와 같이 명령어 수가 잘 출력되는 것을 확인할 수 있다.
댓글남기기