# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections, struct
from functools import partial
from . import ir
from .. import flags, dalvik
from .jvmops import *
from . import arraytypes as arrays
from . import scalartypes as scalars
from . import mathops
from .optimization import stack
from .. import util
from ..typeinference import typeinference
# Code for converting Dalvik bytecode to an intermediate representation.
# Effectively, the IR is just Java bytecode instructions with some abstractions
# that allow for later optimization.
# Position of a scalar type in the order int, long, float, double, object.
# IRBlock.return_ adds this offset to IRETURN to pick the typed return opcode.
_ilfdaOrd = [scalars.INT, scalars.LONG, scalars.FLOAT, scalars.DOUBLE, scalars.OBJ].index

# Primitive array descriptor (b'[Z', b'[C', ...) -> operand byte for NEWARRAY.
_newArrayCodes = dict(zip((b'[' + t.encode() for t in 'ZCFDBSIJ'), range(4, 12)))

# Element type character -> array store/load opcode, relying on the opcodes
# being contiguous. The space is a placeholder for the slot between the D and B
# variants (presumably AASTORE/AALOAD); callers reach the object-array opcode
# through .get(elet, AASTORE) / .get(elet, AALOAD) instead.
_arrStoreOps = dict(zip(map(str.encode, 'IJFD BCS'), range(IASTORE, SASTORE+1)))
_arrLoadOps = dict(zip(map(str.encode, 'IJFD BCS'), range(IALOAD, SALOAD+1)))

# boolean arrays are accessed with the byte array opcodes
_arrStoreOps[b'Z'] = BASTORE
_arrLoadOps[b'Z'] = BALOAD
# For generating IR instructions corresponding to a single Dalvik instruction
class IRBlock:
    """IR instructions generated for the single Dalvik instruction at ``pos``.

    The instruction list always starts with an ``ir.Label`` for ``pos``; the
    helper methods append further IR instructions after it.
    """

    def __init__(self, parent, pos):
        # parent supplies the inferred types, the constant pool and the options
        self.type_data = parent.types[pos]
        self.pool = parent.pool
        self.delay_consts = parent.opts.delay_consts
        self.pos = pos
        self.instructions = [ir.Label(pos)]

    def add(self, jvm_instr):
        """Append one IR instruction to this block."""
        self.instructions.append(jvm_instr)

    def _other(self, bytecode):
        self.add(ir.Other(bytecode=bytecode))

    # The helpers below are named after the big-endian operand layout they emit.
    def u8(self, op):
        self._other(struct.pack('>B', op))

    def u8u8(self, op, x):
        self._other(struct.pack('>BB', op, x))

    def u8u16(self, op, x):
        self._other(struct.pack('>BH', op, x))

    # wide non iinc
    def u8u8u16(self, op, op2, x):
        self._other(struct.pack('>BBH', op, op2, x))

    # invokeinterface
    def u8u16u8u8(self, op, x, y, z):
        self._other(struct.pack('>BHBB', op, x, y, z))

    def ldc(self, index):
        """Load constant pool entry ``index``, using LDC_W when it needs two bytes."""
        if index < 256:
            self.add(ir.OtherConstant(bytecode=bytes([LDC, index])))
        else:
            self.add(ir.OtherConstant(bytecode=struct.pack('>BH', LDC_W, index)))

    def load(self, reg, stype, desc=None, clsname=None):
        """Push register ``reg`` as scalar type ``stype``.

        At most one of ``desc`` (a type descriptor) and ``clsname`` (a class
        name) may be given. It names the type a tainted object register is
        checkcast to after loading; no cast is emitted for java/lang/Object or
        when neither is given.
        """
        # if we know the register to be 0/null, don't bother loading
        if self.type_data.arrs[reg] == arrays.NULL:
            self.const(0, stype)
            return

        self.add(ir.RegAccess(reg, stype, store=False))
        # cast to appropriate type if tainted
        if stype == scalars.OBJ and self.type_data.tainted[reg]:
            assert desc is None or clsname is None
            if clsname is None:
                # remember to handle arrays - also fallthrough if desc is None
                clsname = desc[1:-1] if (desc and desc.startswith(b'L')) else desc
            if clsname is not None and clsname != b'java/lang/Object':
                self.u8u16(CHECKCAST, self.pool.class_(clsname))

    def loadAsArray(self, reg):
        """Push register ``reg`` as an array reference, casting it if tainted."""
        at = self.type_data.arrs[reg]
        if at == arrays.NULL:
            self.const_null()
            return

        self.add(ir.RegAccess(reg, scalars.OBJ, store=False))
        if self.type_data.tainted[reg]:
            if at == arrays.INVALID:
                # needs to be some type of object array, so just cast to Object[]
                self.u8u16(CHECKCAST, self.pool.class_(b'[Ljava/lang/Object;'))
            else:
                # note - will throw if actual type is boolean[] but there's not
                # much we can do in this case
                self.u8u16(CHECKCAST, self.pool.class_(at))

    def store(self, reg, stype):
        """Pop the top of the stack into register ``reg`` as ``stype``."""
        self.add(ir.RegAccess(reg, stype, store=True))

    def return_(self, stype=None):
        """Emit RETURN when ``stype`` is None, else the typed return opcode."""
        if stype is None:
            self.u8(RETURN)
        else:
            self.u8(IRETURN + _ilfdaOrd(stype))

    def const(self, val, stype):
        """Push constant ``val`` (an unsigned value below 2**64) as ``stype``."""
        assert (1<<64) > val >= 0
        if stype == scalars.OBJ:
            # the only object constant that can be expressed here is null
            assert val == 0
            self.const_null()
        else:
            # If constant pool is simple, assume we're in non-opt mode and only use
            # the constant pool for generating constants instead of calculating
            # bytecode sequences for them. If we're in opt mode, pass None for pool
            # to generate bytecode instead
            pool = None if self.delay_consts else self.pool
            self.add(ir.PrimConstant(stype, val, pool=pool))

    def const_null(self):
        """Push a null reference."""
        self.add(ir.OtherConstant(bytecode=bytes([ACONST_NULL])))

    def fillarraysub(self, op, cbs, pop=True):
        """Store one element per callback into the array on top of the stack.

        For each index ``i`` this emits the next instruction group from
        ``stack.genDups``, the index constant, whatever ``cbs[i]()`` emits to
        push the value, and then the store opcode ``op``. ``genDups`` is passed
        0 when ``pop`` is true and 1 otherwise (presumably the number of array
        references left on the stack at the end).
        """
        gen = stack.genDups(len(cbs), 0 if pop else 1)
        for i, cb in enumerate(cbs):
            for instr in next(gen):
                self.add(instr)
            self.const(i, scalars.INT)
            cb()
            self.u8(op)

        # may need to pop at end
        for instr in next(gen):
            self.add(instr)

    def newarray(self, desc):
        """Emit NEWARRAY or ANEWARRAY for array descriptor ``desc``."""
        if desc in _newArrayCodes:
            self.u8u8(NEWARRAY, _newArrayCodes[desc])
        else:
            # can be either multidim array or object array descriptor
            desc = desc[1:]
            if desc.startswith(b'L'):
                desc = desc[1:-1]
            self.u8u16(ANEWARRAY, self.pool.class_(desc))

    def fillarraydata(self, op, stype, vals):
        """Store the constants ``vals`` into consecutive array elements."""
        self.fillarraysub(op, [partial(self.const, val, stype) for val in vals])

    def cast(self, dex, reg, index):
        """Checkcast register ``reg`` in place to the class at dex type ``index``."""
        self.load(reg, scalars.OBJ)
        self.u8u16(CHECKCAST, self.pool.class_(dex.clsType(index)))
        self.store(reg, scalars.OBJ)

    def goto(self, target):
        self.add(ir.Goto(target))

    def if_(self, op, target):
        self.add(ir.If(op, target))

    def switch(self, default, jumps):
        """Emit a switch, or a plain goto when every case goes to ``default``."""
        # keys become signed 32 bit; cases that equal the default are dropped
        jumps = {util.s32(k): v for k, v in jumps.items() if v != default}
        if jumps:
            self.add(ir.Switch(default, jumps))
        else:
            self.goto(default)

    def generateExceptLabels(self):
        """Insert and return labels bracketing the instructions that can throw.

        The start label goes before the first ``ir.Other`` instruction and the
        end label after the last one. Asserts that at least one exists.
        """
        s_ind = 0
        e_ind = len(self.instructions)
        # assume only Other instructions can throw
        while s_ind < e_ind and not isinstance(self.instructions[s_ind], ir.Other):
            s_ind += 1
        while s_ind < e_ind and not isinstance(self.instructions[e_ind-1], ir.Other):
            e_ind -= 1

        assert s_ind < e_ind
        start_lbl, end_lbl = ir.Label(), ir.Label()
        self.instructions.insert(s_ind, start_lbl)
        # +1 because inserting start_lbl shifted everything after it
        self.instructions.insert(e_ind+1, end_lbl)
        return start_lbl, end_lbl
class IRWriter:
    """Collects the IR blocks of one method and the bookkeeping around them."""

    def __init__(self, pool, method, types, opts):
        self.pool = pool
        self.method = method
        self.types = types
        self.opts = opts

        self.iblocks = {}
        self.flat_instructions = None
        self.excepts = []
        self.labels = {}
        self.initial_args = None
        self.exception_redirects = {}

        self.except_starts = set()
        self.except_ends = set()
        self.jump_targets = set()
        # used to detect jump targets with a unique predecessor
        self.target_pred_counts = collections.defaultdict(int)

        self.numregs = None # will be set once registers are allocated (see registers.py)

    def calcInitialArgs(self, nregs, scalar_ptypes):
        """Record a ``(register, scalar type)`` pair for each parameter slot.

        Parameters occupy the last ``len(scalar_ptypes)`` of the ``nregs``
        registers. Slots whose type is ``scalars.INVALID`` are recorded as None.
        """
        regoff = nregs - len(scalar_ptypes)
        self.initial_args = [
            None if st == scalars.INVALID else (i + regoff, st)
            for i, st in enumerate(scalar_ptypes)
        ]

    def addExceptionRedirect(self, target):
        """Return the redirect label for handler ``target``, creating it if needed."""
        return self.exception_redirects.setdefault(target, ir.Label())

    def createBlock(self, instr):
        """Create and register the IRBlock for Dalvik instruction ``instr``."""
        block = IRBlock(self, instr.pos)
        self.iblocks[block.pos] = block
        self.labels[block.pos] = block.instructions[0]
        return block

    def flatten(self):
        """Concatenate all blocks, in position order, into ``flat_instructions``.

        Exception redirects (label + pop) are placed directly in front of their
        handler block when the preceding instruction cannot fall through, and
        are otherwise appended at the end with a goto back to the handler.
        Afterwards ``iblocks`` and ``exception_redirects`` are set to None.
        """
        instructions = []
        for pos in sorted(self.iblocks):
            if pos in self.exception_redirects:
                # check if we can put handler pop in front of block
                if instructions and not instructions[-1].fallsthrough():
                    instructions.append(self.exception_redirects.pop(pos))
                    instructions.append(ir.Pop())
                # if not, leave it in dict to be redirected later
            # now add instructions for actual block
            instructions += self.iblocks[pos].instructions

        # exception handler pops that couldn't be placed inline
        # in this case, just put them at the end with a goto back to the handler
        for target in sorted(self.exception_redirects):
            instructions.append(self.exception_redirects[target])
            instructions.append(ir.Pop())
            instructions.append(ir.Goto(target))

        self.flat_instructions = instructions
        self.iblocks = self.exception_redirects = None

    def replaceInstrs(self, replace):
        """Replace each instruction that is a key of ``replace`` by its list.

        Does nothing when ``replace`` is empty. Asserts that no instruction
        object occurs twice in the result.
        """
        if replace:
            instructions = []
            for instr in self.flat_instructions:
                instructions.extend(replace.get(instr, [instr]))
            self.flat_instructions = instructions
            assert len(set(instructions)) == len(instructions)

    def calcUpperBound(self):
        """Return an upper bound on the size of the bytecode, in bytes."""
        # instructions without bytecode yet contribute their declared maximum
        return sum(
            ins.max if ins.bytecode is None else len(ins.bytecode)
            for ins in self.flat_instructions
        )
################################################################################
def visitNop(method, dex, instr_d, type_data, block, instr):
    """Handle a nop: no IR instructions are emitted."""
def visitMove(method, dex, instr_d, type_data, block, instr):
    """Copy register ``instr.args[1]`` into ``instr.args[0]``.

    A load/store pair is emitted for each of INT, OBJ and FLOAT whose bit is
    set in the inferred type mask of the source register.
    """
    dest, src = instr.args[0], instr.args[1]
    for st in (scalars.INT, scalars.OBJ, scalars.FLOAT):
        if st & type_data.prims[src]:
            block.load(src, st)
            block.store(dest, st)
def visitMoveWide(method, dex, instr_d, type_data, block, instr):
    """Copy wide register ``instr.args[1]`` into ``instr.args[0]``.

    A load/store pair is emitted for each of LONG and DOUBLE whose bit is set
    in the inferred type mask of the source register.
    """
    dest, src = instr.args[0], instr.args[1]
    for st in (scalars.LONG, scalars.DOUBLE):
        if st & type_data.prims[src]:
            block.load(src, st)
            block.store(dest, st)
def visitMoveResult(method, dex, instr_d, type_data, block, instr):
    """Store the value already on the stack into register ``instr.args[0]``.

    The scalar type is derived from ``instr.prev_result``, a type descriptor.
    """
    block.store(instr.args[0], scalars.fromDesc(instr.prev_result))
def visitReturn(method, dex, instr_d, type_data, block, instr):
    """Emit a return matching the declared return type of ``method``.

    For a non-void method, register ``instr.args[0]`` is loaded first, using
    the return descriptor as the cast target for tainted object registers.
    """
    ret_desc = method.id.return_type
    if ret_desc == b'V':
        block.return_()
        return

    st = scalars.fromDesc(ret_desc)
    block.load(instr.args[0], st, desc=ret_desc)
    block.return_(st)
def visitConst32(method, dex, instr_d, type_data, block, instr):
    """Store a 32 bit constant into register ``instr.args[0]``.

    The value is written both as INT and as FLOAT, and additionally as a null
    reference when it is zero.
    """
    dest = instr.args[0]
    val = instr.args[1] % (1<<32)  # reduce to the unsigned 32 bit pattern

    block.const(val, scalars.INT)
    block.store(dest, scalars.INT)
    block.const(val, scalars.FLOAT)
    block.store(dest, scalars.FLOAT)
    if not val:
        block.const_null()
        block.store(dest, scalars.OBJ)
def visitConst64(method, dex, instr_d, type_data, block, instr):
    """Store a 64 bit constant into register ``instr.args[0]`` as LONG and DOUBLE."""
    dest = instr.args[0]
    val = instr.args[1] % (1<<64)  # reduce to the unsigned 64 bit pattern

    block.const(val, scalars.LONG)
    block.store(dest, scalars.LONG)
    block.const(val, scalars.DOUBLE)
    block.store(dest, scalars.DOUBLE)
def visitConstString(method, dex, instr_d, type_data, block, instr):
    """Load dex string ``instr.args[1]`` via ldc and store it in ``instr.args[0]``."""
    val = dex.string(instr.args[1])
    block.ldc(block.pool.string(val))
    block.store(instr.args[0], scalars.OBJ)
def visitConstClass(method, dex, instr_d, type_data, block, instr):
    """Load the class constant for dex type ``instr.args[1]`` into ``instr.args[0]``."""
    # Could use dex.type here since the JVM doesn't care, but this is cleaner
    val = dex.clsType(instr.args[1])
    block.ldc(block.pool.class_(val))
    block.store(instr.args[0], scalars.OBJ)
def visitMonitorEnter(method, dex, instr_d, type_data, block, instr):
    """Load the object in register ``instr.args[0]`` and emit MONITORENTER."""
    block.load(instr.args[0], scalars.OBJ)
    block.u8(MONITORENTER)
def visitMonitorExit(method, dex, instr_d, type_data, block, instr):
    """Load the object in register ``instr.args[0]`` and emit MONITOREXIT."""
    block.load(instr.args[0], scalars.OBJ)
    block.u8(MONITOREXIT)
def visitCheckCast(method, dex, instr_d, type_data, block, instr):
    """Cast register ``instr.args[0]`` in place to dex type ``instr.args[1]``."""
    reg, type_index = instr.args[0], instr.args[1]
    block.cast(dex, reg, type_index)
def visitInstanceOf(method, dex, instr_d, type_data, block, instr):
    """Test register ``instr.args[1]`` against dex type ``instr.args[2]``.

    The int result of INSTANCEOF is stored in register ``instr.args[0]``.
    """
    block.load(instr.args[1], scalars.OBJ)
    cls_index = block.pool.class_(dex.clsType(instr.args[2]))
    block.u8u16(INSTANCEOF, cls_index)
    block.store(instr.args[0], scalars.INT)
def visitArrayLen(method, dex, instr_d, type_data, block, instr):
    """Store the length of the array in ``instr.args[1]`` into ``instr.args[0]``."""
    block.loadAsArray(instr.args[1])
    block.u8(ARRAYLENGTH)
    block.store(instr.args[0], scalars.INT)
def visitNewInstance(method, dex, instr_d, type_data, block, instr):
    """Emit NEW for dex type ``instr.args[1]`` and store the reference in ``instr.args[0]``."""
    cls_index = block.pool.class_(dex.clsType(instr.args[1]))
    block.u8u16(NEW, cls_index)
    block.store(instr.args[0], scalars.OBJ)
def visitNewArray(method, dex, instr_d, type_data, block, instr):
    """Create an array of dex type ``instr.args[2]`` and store it in ``instr.args[0]``.

    The array length is read from register ``instr.args[1]``.
    """
    block.load(instr.args[1], scalars.INT)
    block.newarray(dex.type(instr.args[2]))
    block.store(instr.args[0], scalars.OBJ)
def visitFilledNewArray(method, dex, instr_d, type_data, block, instr):
    """Create an array of dex type ``instr.args[0]`` filled from registers.

    ``instr.args[1]`` lists the source registers, one per element. The array
    reference is left on the stack only when the next instruction is a
    move-result; otherwise (including when there is no next instruction) it
    is popped.
    """
    regs = instr.args[1]
    desc = dex.type(instr.args[0])

    block.const(len(regs), scalars.INT)
    block.newarray(desc)

    st, elet = arrays.eletPair(arrays.fromDesc(desc))
    op = _arrStoreOps.get(elet, AASTORE)
    cbs = [partial(block.load, reg, st) for reg in regs]

    # if not followed by move-result, don't leave it on the stack
    next_instr = instr_d.get(instr.pos2)
    mustpop = next_instr is None or next_instr.type != dalvik.MoveResult
    block.fillarraysub(op, cbs, pop=mustpop)
def visitFillArrayData(method, dex, instr_d, type_data, block, instr):
    """Fill the array in register ``instr.args[0]`` with constant data.

    The data comes from the ``fillarrdata`` attribute of the instruction at
    position ``instr.args[1]``. A register known to be null results in an
    ATHROW on the null reference.
    """
    _, arrdata = instr_d[instr.args[1]].fillarrdata
    at = type_data.arrs[instr.args[0]]

    block.loadAsArray(instr.args[0])
    # equality, matching the NULL test loadAsArray itself performs
    if at == arrays.NULL:
        block.u8(ATHROW)
    elif not arrdata:
        # fill-array-data throws a NPE if array is null even when
        # there is 0 data, so we need to add an instruction that
        # throws a NPE in this case
        block.u8(ARRAYLENGTH)
        block.add(ir.Pop())
    else:
        st, elet = arrays.eletPair(at)
        # check if we need to sign extend
        if elet == b'B' or elet == b'Z':
            arrdata = [util.signExtend(x, 8) & 0xFFFFFFFF for x in arrdata]
        elif elet == b'S':
            arrdata = [util.signExtend(x, 16) & 0xFFFFFFFF for x in arrdata]
        block.fillarraydata(_arrStoreOps.get(elet, AASTORE), st, arrdata)
def visitThrow(method, dex, instr_d, type_data, block, instr):
    """Load register ``instr.args[0]`` as a Throwable and emit ATHROW."""
    block.load(instr.args[0], scalars.OBJ, clsname=b'java/lang/Throwable')
    block.u8(ATHROW)
def visitGoto(method, dex, instr_d, type_data, block, instr):
    """Emit an unconditional jump to the target ``instr.args[0]``."""
    target = instr.args[0]
    block.goto(target)
def visitSwitch(method, dex, instr_d, type_data, block, instr):
    """Emit a switch on the int in register ``instr.args[0]``.

    The case table is the ``switchdata`` attribute of the instruction at
    position ``instr.args[1]``. Its offsets are relative to this instruction,
    and the default target is the following instruction (``instr.pos2``).
    """
    block.load(instr.args[0], scalars.INT)
    switchdata = instr_d[instr.args[1]].switchdata
    default = instr.pos2
    # turn relative offsets into absolute positions, wrapped to 32 bits
    jumps = {k: (offset + instr.pos) % (1<<32) for k, offset in switchdata.items()}
    block.switch(default, jumps)
def visitCmp(method, dex, instr_d, type_data, block, instr):
    """Compare registers ``instr.args[1]`` and ``instr.args[2]``.

    The int result is stored in register ``instr.args[0]``.
    """
    # both tables are indexed by the opcode's offset from 0x2d
    idx = instr.opcode - 0x2d
    op = [FCMPL, FCMPG, DCMPL, DCMPG, LCMP][idx]
    st = [scalars.FLOAT, scalars.FLOAT, scalars.DOUBLE, scalars.DOUBLE, scalars.LONG][idx]

    block.load(instr.args[1], st)
    block.load(instr.args[2], st)
    block.u8(op)
    block.store(instr.args[0], scalars.INT)
def visitIf(method, dex, instr_d, type_data, block, instr):
    """Emit a two-register conditional branch to ``instr.args[2]``.

    The operands are compared as ints when both registers may hold an int
    according to the inferred types, and as references otherwise.
    """
    idx = instr.opcode - 0x32
    st = type_data.prims[instr.args[0]] & type_data.prims[instr.args[1]]
    if st & scalars.INT:
        block.load(instr.args[0], scalars.INT)
        block.load(instr.args[1], scalars.INT)
        op = [IF_ICMPEQ, IF_ICMPNE, IF_ICMPLT, IF_ICMPGE, IF_ICMPGT, IF_ICMPLE][idx]
    else:
        block.load(instr.args[0], scalars.OBJ)
        block.load(instr.args[1], scalars.OBJ)
        # only eq/ne exist for references
        op = [IF_ACMPEQ, IF_ACMPNE][idx]
    block.if_(op, instr.args[2])
def visitIfZ(method, dex, instr_d, type_data, block, instr):
    """Emit a compare-with-zero conditional branch to ``instr.args[1]``.

    The register is tested as an int when it may hold one according to the
    inferred types, and as a reference (null / non-null) otherwise.
    """
    idx = instr.opcode - 0x38
    if type_data.prims[instr.args[0]] & scalars.INT:
        block.load(instr.args[0], scalars.INT)
        op = [IFEQ, IFNE, IFLT, IFGE, IFGT, IFLE][idx]
    else:
        block.load(instr.args[0], scalars.OBJ)
        # only eq/ne exist for references
        op = [IFNULL, IFNONNULL][idx]
    block.if_(op, instr.args[1])
def visitArrayGet(method, dex, instr_d, type_data, block, instr):
    """Read ``array[index]`` into register ``instr.args[0]``.

    The array is register ``instr.args[1]`` and the index is register
    ``instr.args[2]``. When the array register is known to be null, a null
    reference is thrown instead.
    """
    at = type_data.arrs[instr.args[1]]
    # equality, matching the NULL test used by IRBlock.loadAsArray
    if at == arrays.NULL:
        block.const_null()
        block.u8(ATHROW)
        return

    block.loadAsArray(instr.args[1])
    block.load(instr.args[2], scalars.INT)
    st, elet = arrays.eletPair(at)
    block.u8(_arrLoadOps.get(elet, AALOAD))
    block.store(instr.args[0], st)
def visitArrayPut(method, dex, instr_d, type_data, block, instr):
    """Write register ``instr.args[0]`` to ``array[index]``.

    The array is register ``instr.args[1]`` and the index is register
    ``instr.args[2]``. When the array register is known to be null, a null
    reference is thrown instead.
    """
    at = type_data.arrs[instr.args[1]]
    # equality, matching the NULL test used by IRBlock.loadAsArray
    if at == arrays.NULL:
        block.const_null()
        block.u8(ATHROW)
        return

    block.loadAsArray(instr.args[1])
    block.load(instr.args[2], scalars.INT)
    st, elet = arrays.eletPair(at)
    block.load(instr.args[0], st)
    block.u8(_arrStoreOps.get(elet, AASTORE))
def visitInstanceGet(method, dex, instr_d, type_data, block, instr):
    """Read instance field ``instr.args[2]`` into register ``instr.args[0]``.

    The object is taken from register ``instr.args[1]`` and cast to the
    field's declaring class if the register is tainted.
    """
    field_id = dex.field_id(instr.args[2])
    st = scalars.fromDesc(field_id.desc)

    block.load(instr.args[1], scalars.OBJ, clsname=field_id.cname)
    block.u8u16(GETFIELD, block.pool.field(field_id.triple()))
    block.store(instr.args[0], st)
def visitInstancePut(method, dex, instr_d, type_data, block, instr):
    """Write register ``instr.args[0]`` to instance field ``instr.args[2]``.

    The object is taken from register ``instr.args[1]``. The object is loaded
    first and the value second.
    """
    field_id = dex.field_id(instr.args[2])
    st = scalars.fromDesc(field_id.desc)

    block.load(instr.args[1], scalars.OBJ, clsname=field_id.cname)
    block.load(instr.args[0], st, desc=field_id.desc)
    block.u8u16(PUTFIELD, block.pool.field(field_id.triple()))
def visitStaticGet(method, dex, instr_d, type_data, block, instr):
    """Read static field ``instr.args[1]`` into register ``instr.args[0]``."""
    field_id = dex.field_id(instr.args[1])
    st = scalars.fromDesc(field_id.desc)

    block.u8u16(GETSTATIC, block.pool.field(field_id.triple()))
    block.store(instr.args[0], st)
def visitStaticPut(method, dex, instr_d, type_data, block, instr):
    """Write register ``instr.args[0]`` to static field ``instr.args[1]``."""
    field_id = dex.field_id(instr.args[1])
    st = scalars.fromDesc(field_id.desc)

    block.load(instr.args[0], st, desc=field_id.desc)
    block.u8u16(PUTSTATIC, block.pool.field(field_id.triple()))
def visitInvoke(method, dex, instr_d, type_data, block, instr):
    """Translate any of the invoke-* instructions.

    ``instr.args[0]`` is the dex method index and ``instr.args[1]`` the list
    of argument registers. The arguments are loaded in order and the matching
    JVM invoke opcode is emitted. A non-void result is popped unless the next
    instruction is a move-result (a missing next instruction counts as "not a
    move-result").
    """
    isstatic = instr.type == dalvik.InvokeStatic
    called_id = dex.method_id(instr.args[0])
    sts = scalars.paramTypes(called_id, static=isstatic)
    descs = called_id.getSpacedParamTypes(isstatic=isstatic)
    assert len(sts) == len(instr.args[1]) == len(descs)

    for st, desc, reg in zip(sts, descs, instr.args[1]):
        if st != scalars.INVALID: # skip long/double tops
            block.load(reg, st, desc=desc)

    op = {
        dalvik.InvokeVirtual: INVOKEVIRTUAL,
        dalvik.InvokeSuper: INVOKESPECIAL,
        dalvik.InvokeDirect: INVOKESPECIAL,
        dalvik.InvokeStatic: INVOKESTATIC,
        dalvik.InvokeInterface: INVOKEINTERFACE,
    }[instr.type]

    if instr.type == dalvik.InvokeInterface:
        # invokeinterface takes two extra operand bytes: a count and a zero
        block.u8u16u8u8(op, block.pool.imethod(called_id.triple()), len(descs), 0)
    else:
        block.u8u16(op, block.pool.method(called_id.triple()))

    # check if we need to pop result instead of leaving on stack
    next_instr = instr_d.get(instr.pos2)
    result_used = next_instr is not None and next_instr.type == dalvik.MoveResult
    if not result_used and called_id.return_type != b'V':
        st = scalars.fromDesc(called_id.return_type)
        block.add(ir.Pop2() if scalars.iswide(st) else ir.Pop())
def visitUnaryOp(method, dex, instr_d, type_data, block, instr):
    """Apply a unary operation to ``instr.args[1]`` and store it in ``instr.args[0]``.

    The JVM opcode and the source and destination scalar types are looked up
    in ``mathops.UNARY`` by Dalvik opcode.
    """
    op, srct, destt = mathops.UNARY[instr.opcode]
    block.load(instr.args[1], srct)

    # *not requires special handling since there's no direct Java equivalent. Instead we have to do x ^ -1
    if op == IXOR:
        block.u8(ICONST_M1)
    elif op == LXOR:
        # widen the int -1 to a long before xoring
        block.u8(ICONST_M1)
        block.u8(I2L)

    block.u8(op)
    block.store(instr.args[0], destt)
def visitBinaryOp(method, dex, instr_d, type_data, block, instr):
    """Apply a binary operation and store the result in ``instr.args[0]``.

    The operands are the last two entries of ``instr.args``. The JVM opcode
    and the two operand types come from ``mathops.BINARY``; the result is
    stored with the first operand's type.
    """
    op, st, st2 = mathops.BINARY[instr.opcode]
    # index arguments as negative so it works for regular and 2addr forms
    block.load(instr.args[-2], st)
    block.load(instr.args[-1], st2)
    block.u8(op)
    block.store(instr.args[0], st)
def visitBinaryOpConst(method, dex, instr_d, type_data, block, instr):
    """Apply an int operation between a register and a literal.

    The register is ``instr.args[1]``, the literal is ``instr.args[2]`` and
    the result goes to register ``instr.args[0]``. For ISUB (rsub) the literal
    is pushed first, so the result is ``literal - register``.
    """
    op = mathops.BINARY_LIT[instr.opcode]
    literal = instr.args[2] % (1<<32)  # unsigned 32 bit pattern of the literal

    if op == ISUB: # rsub
        block.const(literal, scalars.INT)
        block.load(instr.args[1], scalars.INT)
    else:
        block.load(instr.args[1], scalars.INT)
        block.const(literal, scalars.INT)
    block.u8(op)
    block.store(instr.args[0], scalars.INT)
################################################################################
# Dispatch table: Dalvik instruction type -> visitor that emits its IR.
# Every visitor is called as f(method, dex, instr_d, type_data, block, instr).
# All five invoke kinds share visitInvoke.
VISIT_FUNCS = {
    dalvik.Nop: visitNop,
    dalvik.Move: visitMove,
    dalvik.MoveWide: visitMoveWide,
    dalvik.MoveResult: visitMoveResult,
    dalvik.Return: visitReturn,
    dalvik.Const32: visitConst32,
    dalvik.Const64: visitConst64,
    dalvik.ConstString: visitConstString,
    dalvik.ConstClass: visitConstClass,
    dalvik.MonitorEnter: visitMonitorEnter,
    dalvik.MonitorExit: visitMonitorExit,
    dalvik.CheckCast: visitCheckCast,
    dalvik.InstanceOf: visitInstanceOf,
    dalvik.ArrayLen: visitArrayLen,
    dalvik.NewInstance: visitNewInstance,
    dalvik.NewArray: visitNewArray,
    dalvik.FilledNewArray: visitFilledNewArray,
    dalvik.FillArrayData: visitFillArrayData,
    dalvik.Throw: visitThrow,
    dalvik.Goto: visitGoto,
    dalvik.Switch: visitSwitch,
    dalvik.Cmp: visitCmp,
    dalvik.If: visitIf,
    dalvik.IfZ: visitIfZ,
    dalvik.ArrayGet: visitArrayGet,
    dalvik.ArrayPut: visitArrayPut,
    dalvik.InstanceGet: visitInstanceGet,
    dalvik.InstancePut: visitInstancePut,
    dalvik.StaticGet: visitStaticGet,
    dalvik.StaticPut: visitStaticPut,
    dalvik.InvokeVirtual: visitInvoke,
    dalvik.InvokeSuper: visitInvoke,
    dalvik.InvokeDirect: visitInvoke,
    dalvik.InvokeStatic: visitInvoke,
    dalvik.InvokeInterface: visitInvoke,
    dalvik.UnaryOp: visitUnaryOp,
    dalvik.BinaryOp: visitBinaryOp,
    dalvik.BinaryOpConst: visitBinaryOpConst,
}
def writeBytecode(pool, method, opts):
    """Convert the Dalvik bytecode of ``method`` into IR.

    Runs type inference, creates one IRBlock per reachable instruction,
    records the exception table and jump targets, and flattens the blocks.

    Returns the populated IRWriter.
    """
    dex = method.dex
    code = method.code
    instr_d = {instr.pos: instr for instr in code.bytecode}
    types, all_handlers = typeinference.doInference(dex, method, code, code.bytecode, instr_d)

    scalar_ptypes = scalars.paramTypes(method.id, static=(method.access & flags.ACC_STATIC))

    writer = IRWriter(pool, method, types, opts)
    writer.calcInitialArgs(code.nregs, scalar_ptypes)

    for instr in code.bytecode:
        if instr.pos not in types: # skip unreachable instructions
            continue
        type_data = types[instr.pos]
        block = writer.createBlock(instr)
        VISIT_FUNCS[instr.type](method, dex, instr_d, type_data, block, instr)

    for instr in sorted(all_handlers, key=lambda instr: instr.pos):
        handlers = all_handlers[instr]
        assert handlers
        if instr.pos not in types: # skip unreachable instructions
            continue

        start, end = writer.iblocks[instr.pos].generateExceptLabels()
        writer.except_starts.add(start)
        writer.except_ends.add(end)

        for ctype, handler_pos in handlers:
            # If handler doesn't use the caught exception, we need to redirect to a pop instead
            if instr_d.get(handler_pos).type != dalvik.MoveResult:
                target = writer.addExceptionRedirect(handler_pos)
            else:
                target = writer.labels[handler_pos]
            writer.jump_targets.add(target)
            writer.target_pred_counts[target] += 1

            # When catching Throwable, we can use the special index 0 instead,
            # potentially saving a constant pool entry or two
            jctype = 0 if ctype == b'java/lang/Throwable' else pool.class_(ctype)
            writer.excepts.append((start, end, target, jctype))
    writer.flatten()

    # find jump targets (in addition to exception handler targets)
    for instr in writer.flat_instructions:
        for target in instr.targets():
            label = writer.labels[target]
            writer.jump_targets.add(label)
            writer.target_pred_counts[label] += 1

    return writer