# Source code for apkutils.dex.jvm.ir

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import struct

from .constants import calc
from .jvmops import *
from . import constantpool, error
from . import scalartypes as scalars

# IR representation roughly corresponding to JVM bytecode instructions. Note that these
# may correspond to more than one instruction in the actual bytecode generated but they
# are useful logical units for the internal optimization passes.

class JvmInstruction:
    """Base class for all IR instructions in this module.

    Attributes:
        bytecode: ``None`` or a bytestring, stored exactly as passed in.
    """

    def __init__(self, bytecode=None):
        # None or bytestring
        self.bytecode = bytecode

    def fallsthrough(self):
        """Return True; subclasses override this where it does not hold."""
        return True

    def targets(self):
        """Return a new empty list; subclasses override to list jump targets."""
        return list()
# Used to mark locations in the IR instructions for various purposes. These are
# separate IR 'instructions' since the optimization passes may remove or replace
# the other instructions.
class Label(JvmInstruction):
    """Instruction whose bytecode is the empty bytestring and which carries an id.

    Attributes:
        id: ``None`` or an int, stored exactly as passed in.
    """

    def __init__(self, id=None):
        # A label always starts out with empty (zero-length) bytecode.
        super().__init__(bytecode=b'')
        # None or int
        self.id = id
# Scalar types in int, long, float, double, object order. ``_ilfdaOrd(stype)``
# returns the position (0-4) of ``stype`` within this sequence.
_ilfda_sequence = [
    scalars.INT,
    scalars.LONG,
    scalars.FLOAT,
    scalars.DOUBLE,
    scalars.OBJ,
]
_ilfdaOrd = _ilfda_sequence.index
class RegAccess(JvmInstruction):
    """Load from or store to a register / local variable slot.

    The bytecode is left as ``None`` by the constructor and is filled in
    later by ``calcBytecode`` once the local variable index is known.

    Attributes:
        key: the tuple ``(dreg, st)``.
        store: truthy for a store, falsy for a load.
        wide: result of ``scalars.iswide(st)``.
    """

    def __init__(self, dreg, st, store):
        super().__init__()
        self.key = (dreg, st)
        self.store = store
        self.wide = scalars.iswide(st)

    @staticmethod
    def raw(local, stype, store):
        """Build a RegAccess (with ``dreg`` 0) already encoded for ``local``."""
        access = RegAccess(0, stype, store)
        access.calcBytecode(local)
        return access

    def calcBytecode(self, local):
        """Encode this access for local variable index ``local``.

        Must only be called while ``self.bytecode`` is still ``None``.
        """
        assert self.bytecode is None
        _, stype = self.key
        # The store opcodes sit at a fixed distance from the load opcodes.
        store_shift = (ISTORE - ILOAD) if self.store else 0
        type_ord = _ilfdaOrd(stype)

        if local < 4:
            # One-byte form: the local index is folded into the opcode.
            fields = ('>B', ILOAD_0 + store_shift + local + type_ord * 4)
        elif local < 256:
            # Two-byte form: opcode followed by a one-byte index.
            fields = ('>BB', ILOAD + store_shift + type_ord, local)
        else:
            # WIDE prefix: opcode followed by a two-byte index.
            fields = ('>BBH', WIDE, ILOAD + store_shift + type_ord, local)
        self.bytecode = struct.pack(*fields)
class PrimConstant(JvmInstruction):
    """Instruction producing a primitive constant of scalar type ``st``.

    Attributes:
        st: the scalar type passed in.
        val: the value after ``calc.normalize(st, val)``.
        wide: result of ``scalars.iswide(st)``.

    Raises:
        error.ClassfileLimitExceeded: when a pool is supplied but neither
            ``calc.lookupOnly`` nor the pool yields bytecode.
    """

    def __init__(self, st, val, pool=None):
        super().__init__()
        self.st = st
        self.val = normalized = calc.normalize(st, val)
        self.wide = scalars.iswide(st)

        # If pool is passed in, just grab an entry greedily, otherwise calculate
        # a sequence of bytecode to generate the constant
        if pool is None:
            self.bytecode = calc.calc(st, normalized)
            return

        self.bytecode = calc.lookupOnly(st, normalized)
        if self.bytecode is not None:
            return

        self._from_pool(pool)
        if self.bytecode is None:
            raise error.ClassfileLimitExceeded()

    def cpool_key(self):
        """Return the ``(tag, value)`` pair for this constant's pool entry."""
        tags_by_type = {
            scalars.INT: constantpool.CONSTANT_Integer,
            scalars.FLOAT: constantpool.CONSTANT_Float,
            scalars.DOUBLE: constantpool.CONSTANT_Double,
            scalars.LONG: constantpool.CONSTANT_Long,
        }
        return tags_by_type[self.st], self.val

    def _from_pool(self, pool):
        """Set the bytecode to an LDC-family load if the pool returns an index.

        When ``pool.tryGet`` returns ``None`` the bytecode is left untouched.
        """
        slot = pool.tryGet(self.cpool_key())
        if slot is None:
            return

        if scalars.iswide(self.st):
            fmt, opcode = '>BH', LDC2_W
        elif slot >= 256:
            # Index does not fit in one byte; use the two-byte-index form.
            fmt, opcode = '>BH', LDC_W
        else:
            fmt, opcode = '>BB', LDC
        self.bytecode = struct.pack(fmt, opcode, slot)

    def fix_with_pool(self, pool):
        """Retry the pool lookup when the current bytecode exceeds two bytes."""
        if len(self.bytecode) <= 2:
            return
        self._from_pool(pool)
class OtherConstant(JvmInstruction):
    """Constant instruction that is never wide."""

    # will be null, string or class - always single
    wide = False
class LazyJumpBase(JvmInstruction):
    """Base for jump instructions whose bytecode is computed later.

    This class reads ``self.min`` and ``self.max`` in ``widenIfNecessary``
    but does not set them itself; subclasses are expected to define them.

    Attributes:
        target: the jump target, used as a key into ``labels``.
    """

    def __init__(self, target):
        super().__init__()
        self.target = target

    def targets(self):
        """Return a one-element list holding this jump's target."""
        return [self.target]

    def widenIfNecessary(self, labels, posd):
        """Switch to the long form if the jump offset exceeds 16 signed bits.

        Args:
            labels: mapping from target to the instruction it refers to.
            posd: mapping from instruction to its position.

        Returns:
            True if the offset was out of range (``min`` is raised to
            ``max``), False otherwise.
        """
        destination = labels[self.target]
        delta = posd[destination] - posd[self]
        if -32768 <= delta < 32768:
            return False
        self.min = self.max
        return True
class Goto(LazyJumpBase):
    """Unconditional jump, encoded as GOTO (3 bytes) or GOTO_W (5 bytes)."""

    def __init__(self, target):
        super().__init__(target)
        # max is the upper limit on length of bytecode
        self.min, self.max = 3, 5

    def fallsthrough(self):
        """Return False; an unconditional jump does not fall through."""
        return False

    def calcBytecode(self, posd, labels):
        """Encode the jump using the positions in ``posd``.

        The short form is chosen when ``self.max`` is 3, otherwise the
        wide form with a 32 bit offset is emitted.
        """
        delta = posd[labels[self.target]] - posd[self]
        fmt, opcode = ('>Bh', GOTO) if self.max == 3 else ('>Bi', GOTO_W)
        self.bytecode = struct.pack(fmt, opcode, delta)
# Maps each conditional jump opcode to the opcode testing the negated
# condition; every pair below is registered in both directions.
_ifOpposite = {}
for _op1, _op2 in (
    (IFEQ, IFNE),
    (IFLT, IFGE),
    (IFGT, IFLE),
    (IF_ICMPEQ, IF_ICMPNE),
    (IF_ICMPLT, IF_ICMPGE),
    (IF_ICMPGT, IF_ICMPLE),
    (IFNULL, IFNONNULL),
    (IF_ACMPEQ, IF_ACMPNE),
):
    _ifOpposite.update({_op1: _op2, _op2: _op1})
class If(LazyJumpBase):
    """Conditional jump using opcode ``op``.

    Encoded in 3 bytes when the offset fits in 16 bits, otherwise as an
    8 byte sequence (inverted condition followed by GOTO_W).
    """

    def __init__(self, op, target):
        super().__init__(target)
        self.op = op
        # max is the upper limit on length of bytecode
        self.min, self.max = 3, 8

    # Unlike with goto, if instructions are limited to a 16 bit jump offset.
    # Therefore, for larger jumps, we have to substitute a different sequence
    #
    # if x goto A
    # B: whatever
    #
    # becomes
    #
    # if !x goto B
    # goto A
    # B: whatever
    def calcBytecode(self, posd, labels):
        """Encode the jump using the positions in ``posd``.

        The short form is chosen when ``self.max`` is 3; otherwise the
        long substitute sequence described above is emitted.
        """
        if self.max == 3:
            delta = posd[labels[self.target]] - posd[self]
            self.bytecode = struct.pack('>Bh', self.op, delta)
            return

        inverted = _ifOpposite[self.op]
        # The GOTO_W sits 3 bytes after the start of this instruction, so its
        # offset is measured from there. The inverted test skips 8 bytes,
        # i.e. the whole substitute sequence.
        delta = posd[labels[self.target]] - posd[self] - 3
        self.bytecode = struct.pack('>BhBi', inverted, 8, GOTO_W, delta)
class Switch(JvmInstruction):
    """Multi-way branch, emitted as TABLESWITCH or LOOKUPSWITCH.

    The table form is chosen when it is strictly smaller than the
    lookup form.

    Attributes:
        default: target used for keys not present in ``jumps``.
        jumps: non-empty mapping from key to target.
        low, high: smallest and largest key in ``jumps``.
        istable: True when the TABLESWITCH encoding is used.
        nopad_size: encoded size in bytes, excluding alignment padding.
        max: ``nopad_size`` plus 3, the largest possible padding.
    """

    def __init__(self, default, jumps):
        super().__init__()
        self.default = default
        self.jumps = jumps
        assert jumps
        self.low = min(jumps)
        self.high = max(jumps)

        # Compare the payload size of both encodings.
        entries_in_table = self.high - self.low + 1
        size_as_table = 4 * (entries_in_table + 1)
        size_as_lookup = 8 * len(jumps)

        self.istable = size_as_lookup > size_as_table
        payload = size_as_table if self.istable else size_as_lookup
        self.nopad_size = 9 + payload
        self.max = self.nopad_size + 3

    def fallsthrough(self):
        """Return False; a switch does not fall through."""
        return False

    def targets(self):
        """Return the distinct jump targets in sorted order, then the default."""
        result = sorted(set(self.jumps.values()))
        result.append(self.default)
        return result

    def calcBytecode(self, posd, labels):
        """Encode the switch using the positions in ``posd``.

        All offsets are relative to this instruction's own position.
        """
        here = posd[self]
        default_delta = posd[labels[self.default]] - here
        # Zero bytes after the opcode so that (here + 1 + padding) is a
        # multiple of 4.
        padding = [0] * ((-here - 1) % 4)

        if self.istable:
            chunks = [
                bytes([TABLESWITCH] + padding),
                struct.pack('>iii', default_delta, self.low, self.high),
            ]
            # One entry per key in [low, high]; gaps jump to the default.
            for key in range(self.low, self.high + 1):
                dest = self.jumps.get(key, self.default)
                chunks.append(struct.pack('>i', posd[labels[dest]] - here))
        else:
            chunks = [
                bytes([LOOKUPSWITCH] + padding),
                struct.pack('>iI', default_delta, len(self.jumps)),
            ]
            # (key, offset) pairs are emitted in ascending key order.
            for key, dest in sorted(self.jumps.items()):
                delta = posd[labels[dest]] - here
                chunks.append(struct.pack('>ii', key, delta))
        self.bytecode = b''.join(chunks)
# Single-byte bytecodes for every opcode from IRETURN through RETURN
# (inclusive), plus ATHROW.
_return_or_throw_bytecodes = {
    bytes([op]) for op in [*range(IRETURN, RETURN + 1), ATHROW]
}
class Other(JvmInstruction):
    """Instruction carrying bytecode that has no dedicated IR class."""

    def fallsthrough(self):
        """Return False for return/throw bytecodes, True for anything else."""
        terminates = self.bytecode in _return_or_throw_bytecodes
        return not terminates
def Pop():
    """Return a new ``Other`` instruction holding the single POP opcode."""
    return Other(bytecode=bytes((POP,)))
def Pop2():
    """Return a new ``Other`` instruction holding the single POP2 opcode."""
    return Other(bytecode=bytes((POP2,)))
def Dup():
    """Return a new ``Other`` instruction holding the single DUP opcode."""
    return Other(bytecode=bytes((DUP,)))
def Dup2():
    """Return a new ``Other`` instruction holding the single DUP2 opcode."""
    return Other(bytecode=bytes((DUP2,)))