1. Add code to include control flow dependence in slicing
[dyninst.git] / parseAPI / src / IndirectAnalyzer.C
1 #include "dyntypes.h"
2 #include "IndirectControlFlow.h"
3 #include "BackwardSlicing.h"
4 #include "IA_IAPI.h"
5 #include "debug_parse.h"
6
7 #include "CodeObject.h"
8 #include "Graph.h"
9
10 #include "Instruction.h"
11 #include "InstructionDecoder.h"
12 #include "Register.h"
13
// Masks OR-ed into a zero-extended table entry to sign-extend it.
// SIGNEX_<to>_<from>: every bit above the low <from> bits of a <to>-bit value.
#define SIGNEX_64_32 0xFFFFFFFF00000000ULL
#define SIGNEX_64_16 0xFFFFFFFFFFFF0000ULL
#define SIGNEX_64_8  0xFFFFFFFFFFFFFF00ULL
#define SIGNEX_32_16 0xFFFF0000U
#define SIGNEX_32_8  0xFFFFFF00U

// Upper limit on the number of entries a jump table is assumed to contain;
// larger candidates are not treated as jump tables.
#define MAX_TABLE_ENTRY 1000000
22
23 bool IndirectControlFlowAnalyzer::FillInOutEdges(BoundValue &target, 
24                                                  vector<pair< Address, Dyninst::ParseAPI::EdgeTypeEnum > >& outEdges) {
25
26     Address tableBase = target.interval.low;
27     Address tableLastEntry = target.interval.high;
28     Architecture arch = block->obj()->cs()->getArch();
29     if (arch == Arch_x86) {
30         tableBase &= 0xffffffff;
31         tableLastEntry &= 0xffffffff;
32     }
33
34 #if defined(os_windows)
35     tableBase -= block->obj()->cs()->loadAddress();
36     tableLastEntry -= block->obj()->cs()->loadAddress();
37 #endif
38
39     parsing_printf("The final target bound fact:\n");
40     target.Print();
41
42     if (!block->obj()->cs()->isValidAddress(tableBase)) {
43         parsing_printf("\ttableBase 0x%lx invalid, returning false\n", tableBase);
44         return false;
45     }
46
47     for (Address tableEntry = tableBase; tableEntry <= tableLastEntry; tableEntry += target.interval.stride) {
48         if (!block->obj()->cs()->isValidAddress(tableEntry)) continue;
49         Address targetAddress = 0;
50         if (target.isTableRead) {
51             // Two assumptions:
52             // 1. Assume the table contents are moved in a sign extended way;
53             // 2. Assume memory access size is the same as the table stride
54             switch (target.interval.stride) {
55                 case 8:
56                     targetAddress = *(const uint64_t *) block->obj()->cs()->getPtrToInstruction(tableEntry);
57                     break;
58                 case 4:
59                     targetAddress = *(const uint32_t *) block->obj()->cs()->getPtrToInstruction(tableEntry);
60                     if ((arch == Arch_x86_64) && (targetAddress & 0x80000000)) {
61                         targetAddress |= SIGNEX_64_32;
62                     }
63                     break;
64                 case 2:
65                     targetAddress = *(const uint16_t *) block->obj()->cs()->getPtrToInstruction(tableEntry);
66                     if ((arch == Arch_x86_64) && (targetAddress & 0x8000)) {
67                         targetAddress |= SIGNEX_64_16;
68                     }
69                     if ((arch == Arch_x86) && (targetAddress & 0x8000)) {
70                         targetAddress |= SIGNEX_32_16;
71                     }
72
73                     break;
74                 case 1:
75                     targetAddress = *(const uint8_t *) block->obj()->cs()->getPtrToInstruction(tableEntry);
76                     if ((arch == Arch_x86_64) && (targetAddress & 0x80)) {
77                         targetAddress |= SIGNEX_64_8;
78                     }
79                     if ((arch == Arch_x86) && (targetAddress & 0x80)) {
80                         targetAddress |= SIGNEX_32_8;
81                     }
82
83                     break;
84
85                 default:
86                     parsing_printf("Invalid table stride %d\n", target.interval.stride);
87                     return false;
88             }
89             if (targetAddress != 0) {
90                 if (target.isSubReadContent) 
91                     targetAddress = target.targetBase - targetAddress;
92                 else 
93                     targetAddress += target.targetBase; 
94
95             }
96 #if defined(os_windows)
97             targetAddress -= block->obj()->cs()->loadAddress();
98 #endif
99         } else targetAddress = tableEntry;
100
101         if (block->obj()->cs()->getArch() == Arch_x86) targetAddress &= 0xffffffff;
102         parsing_printf("Jumping to target %lx,", targetAddress);
103         if (block->obj()->cs()->isCode(targetAddress)) {
104             outEdges.push_back(make_pair(targetAddress, INDIRECT));
105             parsing_printf(" is code.\n" );
106         } else {
107             parsing_printf(" not code.\n");
108         }
109     }
110     return true;
111 }
112
113 bool IndirectControlFlowAnalyzer::NewJumpTableAnalysis(std::vector<std::pair< Address, Dyninst::ParseAPI::EdgeTypeEnum > >& outEdges) {
114
115 //    if (block->last() != 0x4e4ffb) return false;
116
117     parsing_printf("Apply indirect control flow analysis at %lx\n", block->last());
118 //    FindAllConditionalGuards();
119     FindAllThunks();
120     ReachFact rf(guards, thunks);
121     parsing_printf("Calculate backward slice\n");
122
123     BackwardSlicer bs(func, block, block->last(), guards, rf);
124     GraphPtr slice =  bs.CalculateBackwardSlicing();
125
126     parsing_printf("Calculate bound facts\n");     
127     BoundFactsCalculator bfc(guards, func, slice, func->entry() == block, rf, thunks);
128     bfc.CalculateBoundedFacts();
129
130     BoundValue target;
131     bool ijt = IsJumpTable(slice, bfc, target);
132     if (ijt) {
133         return FillInOutEdges(target, outEdges);
134     } else return false;
135 }                                                      
136
137
138
139 bool IndirectControlFlowAnalyzer::EndWithConditionalJump(ParseAPI::Block * b) {
140
141     const unsigned char * buf = (const unsigned char*) b->obj()->cs()->getPtrToInstruction(b->last());
142     InstructionDecoder dec(buf, b->end() - b->last(), b->obj()->cs()->getArch());
143     Instruction::Ptr insn = dec.decode();
144     entryID id = insn->getOperation().getID();
145
146     if (id == e_jz || id == e_jnz ||
147         id == e_jb || id == e_jnb ||
148         id == e_jbe || id == e_jnbe) return true;
149    
150 //    for (auto eit = b->targets().begin(); eit != b->targets().end(); ++eit)
151 //        if ((*eit)->type() == COND_TAKEN) return true;
152     return false;
153
154 }
155
156 void IndirectControlFlowAnalyzer::GetAllReachableBlock() {
157     reachable.clear();
158     queue<Block*> q;
159     q.push(block);
160     while (!q.empty()) {
161         ParseAPI::Block *cur = q.front();
162         q.pop();
163         if (reachable.find(cur) != reachable.end()) continue;
164         reachable.insert(cur);
165         for (auto eit = cur->sources().begin(); eit != cur->sources().end(); ++eit)
166             if ((*eit)->intraproc()) 
167                 q.push((*eit)->src());
168     }
169
170 }
171
172 void IndirectControlFlowAnalyzer::SaveGuardData(ParseAPI::Block *prev) {
173
174     Address curAddr = prev->start();
175     const unsigned char* buf = (const unsigned char*) prev->obj()->cs()->getPtrToInstruction(prev->start());
176     InstructionDecoder dec(buf, prev->end() - prev->start(), prev->obj()->cs()->getArch());
177     Instruction::Ptr insn;
178     vector<pair<Instruction::Ptr, Address> > insns;
179     
180     while ( (insn = dec.decode()) != NULL ) {
181         insns.push_back(make_pair(insn, curAddr));
182         curAddr += insn->size();
183     }
184     
185     for (auto iit = insns.rbegin(); iit != insns.rend(); ++iit) {
186         insn = iit->first;
187         if (insn->getOperation().getID() == e_cmp || insn->getOperation().getID() == e_test) {
188             guards.insert(GuardData(func, prev, insn, insns.rbegin()->first, iit->second, insns.rbegin()->second));
189             parsing_printf("Find guard and cmp pair: cmp %s, addr %lx, cond jump %s, addr %lx\n", insn->format().c_str(), iit->second, insns.rbegin()->first->format().c_str(), insns.rbegin()->second); 
190             break;
191         }    
192     }
193 }
194
195 void IndirectControlFlowAnalyzer::FindAllConditionalGuards(){
196     set<ParseAPI::Block*> visited;
197     queue<Block*> q;
198     q.push(block);
199     GetAllReachableBlock();
200
201     while (!q.empty()) {
202         ParseAPI::Block * cur = q.front();
203         q.pop();
204         if (visited.find(cur) != visited.end()) continue;
205         visited.insert(cur);
206
207         // Since a guard has the condition that one branch must always reach the indirect jump,
208         // if the current block can reach a block that cannot reach the indirect jump, 
209         // then all the sources of the current block is not post-dominated by the indirect jump.
210         bool postDominate = true;
211         for (auto eit = cur->targets().begin(); eit != cur->targets().end(); ++eit) 
212             if ((*eit)->intraproc() && (*eit)->type() != INDIRECT)
213                 if (reachable.find((*eit)->trg()) == reachable.end()) postDominate = false;
214         if (!postDominate) continue;
215
216         for (auto eit = cur->sources().begin(); eit != cur->sources().end(); ++eit)
217             if ((*eit)->intraproc()) {
218                 ParseAPI::Block* prev = (*eit)->src();
219                 if (EndWithConditionalJump(prev)) {                
220                     SaveGuardData(prev);
221                 }
222                 else {
223                     q.push(prev);
224                 }
225             }
226     }
227 }
228
229
230
231 bool IndirectControlFlowAnalyzer::IsJumpTable(GraphPtr slice, 
232                                               BoundFactsCalculator &bfc,
233                                               BoundValue &target) {
234     NodeIterator exitBegin, exitEnd, srcBegin, srcEnd;
235     slice->exitNodes(exitBegin, exitEnd);
236     SliceNode::Ptr virtualExit = boost::static_pointer_cast<SliceNode>(*exitBegin);
237     virtualExit->ins(srcBegin, srcEnd);
238     SliceNode::Ptr jumpNode = boost::static_pointer_cast<SliceNode>(*srcBegin);
239     
240     const Absloc &loc = jumpNode->assign()->out().absloc();
241     parsing_printf("Checking final bound fact for %s\n",loc.format().c_str()); 
242     BoundFact *bf = bfc.GetBoundFact(virtualExit);
243     if (bf->IsBounded(loc)) {
244         target = *(bf->GetBound(loc));
245         uint64_t s = target.interval.size();
246         if (s > 0 && s <= MAX_TABLE_ENTRY) return true;
247     }
248     return false;
249 }
250
251 void IndirectControlFlowAnalyzer::FindAllThunks() {
252     for (auto bit = reachable.begin(); bit != reachable.end(); ++bit) {
253         // We intentional treat a getting PC call as a special case that does not
254         // end a basic block. So, we need to check every instruction to find all thunks
255         ParseAPI::Block *b = *bit;
256         const unsigned char* buf =
257             (const unsigned char*)(b->obj()->cs()->getPtrToInstruction(b->start()));
258         if( buf == NULL ) {
259             parsing_printf("%s[%d]: failed to get pointer to instruction by offset\n",FILE__, __LINE__);
260             return;
261         }
262         parsing_printf("Looking for thunk in block [%lx,%lx).", b->start(), b->end());
263         InstructionDecoder dec(buf, b->end() - b->start(), b->obj()->cs()->getArch());
264         InsnAdapter::IA_IAPI block(dec, b->start(), b->obj() , b->region(), b->obj()->cs(), b);
265         while (block.getAddr() < b->end()) {
266             if (block.getInstruction()->getCategory() == c_CallInsn && block.isThunk()) {
267                 bool valid;
268                 Address addr;
269                 boost::tie(valid, addr) = block.getCFT();
270                 const unsigned char *target = (const unsigned char *) b->obj()->cs()->getPtrToInstruction(addr);
271                 InstructionDecoder targetChecker(target, InstructionDecoder::maxInstructionLength, b->obj()->cs()->getArch());
272                 Instruction::Ptr thunkFirst = targetChecker.decode();
273                 set<RegisterAST::Ptr> thunkTargetRegs;
274                 thunkFirst->getWriteSet(thunkTargetRegs);
275                 
276                 for (auto curReg = thunkTargetRegs.begin(); curReg != thunkTargetRegs.end(); ++curReg) {
277                     ThunkInfo t;
278                     t.reg = (*curReg)->getID();
279                     t.value = block.getAddr() + block.getInstruction()->size();
280                     t.block = b;
281                     thunks.insert(make_pair(block.getAddr(), t));
282
283                     parsing_printf("\tfind thunk at %lx, storing value %lx to %s\n", block.getAddr(), t.value , t.reg.name().c_str());
284                 }
285             }
286             block.advance();
287         }
288     }
289 }
290
291