Source code for apkutils.dex.jvm.optimization.stack

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .. import ir
from ..jvmops import *

[docs]def visitLinearCode(irdata, visitor): # Visit linear sections of code, pessimistically treating all exception # handler ranges as jumps. except_level = 0 for instr in irdata.flat_instructions: if instr in irdata.except_starts: except_level += 1 visitor.visitExceptionRange() elif instr in irdata.except_ends: except_level -= 1 if except_level > 0: continue if instr in irdata.jump_targets or isinstance(instr, (ir.LazyJumpBase, ir.Switch)): visitor.visitJumpTargetOrBranch(instr) elif not instr.fallsthrough(): visitor.visitReturn() else: visitor.visit(instr) assert except_level == 0 return visitor
[docs]class NoExceptVisitorBase:
[docs] def visitExceptionRange(self): self.reset()
[docs] def visitJumpTargetOrBranch(self, instr): self.reset()
[docs]class ConstInliner(NoExceptVisitorBase): def __init__(self): self.uses = {} self.notmultiused = set() self.current = {}
[docs] def reset(self): self.current = {}
[docs] def visitReturn(self): for key in self.current: self.notmultiused.add(self.current[key]) self.reset()
[docs] def visit(self, instr): if isinstance(instr, ir.RegAccess): key = instr.key if instr.store: if key in self.current: self.notmultiused.add(self.current[key]) self.current[key] = instr elif key in self.current: # if currently used 0, mark it used once # if used once already, mark it as multiused if self.current[key] in self.uses: del self.current[key] else: self.uses[self.current[key]] = instr
[docs]def inlineConsts(irdata): # Inline constants which are only used once or not at all. This only covers # linear sections of code and pessimistically assumes everything is used # when it reaches a jump or exception range. Essentially, this means that # the value can only be considered unused if it is either overwritten by a # store or reaches a return or throw before any jumps. # As usual, assume no iinc. instrs = irdata.flat_instructions visitor = visitLinearCode(irdata, ConstInliner()) replace = {} for ins1, ins2 in zip(instrs, instrs[1:]): if ins2 in visitor.notmultiused and isinstance(ins1, (ir.PrimConstant, ir.OtherConstant)): replace[ins1] = [] replace[ins2] = [] if ins2 in visitor.uses: replace[visitor.uses[ins2]] = [ins1] irdata.replaceInstrs(replace)
[docs]class StoreLoadPruner(NoExceptVisitorBase): def __init__(self): self.current = {} self.last = None self.removed = set()
[docs] def reset(self): self.current = {} self.last = None
[docs] def visitReturn(self): for pair in self.current.values(): assert pair[0].store and not pair[1].store self.removed.update(pair) self.reset()
[docs] def visit(self, instr): if isinstance(instr, ir.RegAccess): key = instr.key if instr.store: if key in self.current: pair = self.current[key] assert pair[0].store and not pair[1].store self.removed.update(self.current.pop(key)) self.last = instr else: self.current.pop(key, None) if self.last and self.last.key == key: self.current[key] = self.last, instr self.last = None elif not isinstance(instr, ir.Label): self.last = None
[docs]def pruneStoreLoads(irdata): # Remove a store immediately followed by a load from the same register # (potentially with a label in between) if it can be proven that this # register isn't read again. As above, this only considers linear sections of code. # Must not be run before dup2ize! data = visitLinearCode(irdata, StoreLoadPruner()) irdata.replaceInstrs({instr:[] for instr in data.removed})
# used by writeir too
[docs]def genDups(needed, needed_after): # Generate a sequence of dup and dup2 instructions to duplicate the given # value. This keeps up to 4 copies of the value on the stack. Thanks to dup2 # this asymptotically takes only half a byte per access. have = 1 ele_count = needed needed += needed_after for _ in range(ele_count): cur = [] if have < needed: if have == 1 and needed >= 2: cur.append(ir.Dup()) have += 1 if have == 2 and needed >= 4: cur.append(ir.Dup2()) have += 2 have -= 1 needed -= 1 yield cur assert have >= needed # check if we have to pop at end yield [ir.Pop() for _ in range(have-needed)]
# Range of instruction indexes at which a given register is read (in linear code)
[docs]class UseRange: def __init__(self, uses): self.uses = uses
[docs] def add(self, i): self.uses.append(i)
@property def start(self): return self.uses[0] @property def end(self): return self.uses[-1]
[docs] def subtract(self, other): s, e = other.start, other.end left = [i for i in self.uses if i < s] right = [i for i in self.uses if i > e] if len(left) >= 2: yield UseRange(left) if len(right) >= 2: yield UseRange(right)
[docs] def sortkey(self): return len(self.uses), self.uses[0]
[docs]def makeRange(instr): assert isinstance(instr, ir.RegAccess) and not instr.store return UseRange([])
[docs]def dup2ize(irdata): # This optimization replaces narrow registers which are frequently read at # stack height 0 with a single read followed by the more efficient dup and # dup2 instructions. This asymptotically uses only half a byte per access. # For simplicity, instead of explicitly keeping track of which locations # have stack height 0, we take advantage of the invariant that ranges of code # corresponding to a single Dalvik instruction always begin with empty stack. # These can be recognized by labels with a non-None id. # This isn't true for move-result instructions, but in that case the range # won't begin with a register load so it doesn't matter. # Note that pruneStoreLoads breaks this invariant, so dup2ize must be run first. # Also, for simplicity, we only keep at most one such value on the stack at # a time (duplicated up to 4 times). instrs = irdata.flat_instructions ranges = [] current = {} at_head = False for i, instr in enumerate(instrs): # if not linear section of bytecode, reset everything. Exceptions are ok # since they clear the stack, but jumps obviously aren't. if instr in irdata.jump_targets or isinstance(instr, (ir.If, ir.Switch)): ranges.extend(current.values()) current = {} if isinstance(instr, ir.RegAccess): key = instr.key if not instr.wide: if instr.store: if key in current: ranges.append(current.pop(key)) elif at_head: current.setdefault(key, makeRange(instr)).add(i) at_head = isinstance(instr, ir.Label) and instr.id is not None ranges.extend(current.values()) ranges = [ur for ur in ranges if len(ur.uses) >= 2] ranges.sort(key=UseRange.sortkey) # Greedily choose a set of disjoint ranges to dup2ize. chosen = [] while ranges: best = ranges.pop() chosen.append(best) newranges = [] for ur in ranges: newranges.extend(ur.subtract(best)) ranges = sorted(newranges, key=UseRange.sortkey) replace = {} for ur in chosen: gen = genDups(len(ur.uses), 0) for pos in ur.uses: ops = next(gen) # remember to include initial load! if pos == ur.start: ops = [instrs[pos]] + ops replace[instrs[pos]] = ops irdata.replaceInstrs(replace)