# Source code for genlayer.py.eth.calldata

"""
Module that provides eth calldata encoding and decoding
"""

__all__ = ('MethodEncoder', 'decode')
from ..keccak import Keccak256

import typing

from genlayer.py.types import *
from genlayer.py.storage import Array, DynArray
from functools import partial

# Lookup table from every fixed-width integer type (``u8``..``u256`` and
# ``i8``..``i256``, in steps of 8 bits) to its eth type name.
# Unsigned names start with ``uint``, signed names start with ``int``.
_integer_types = {
	u8: 'uint8',
	u16: 'uint16',
	u24: 'uint24',
	u32: 'uint32',
	u40: 'uint40',
	u48: 'uint48',
	u56: 'uint56',
	u64: 'uint64',
	u72: 'uint72',
	u80: 'uint80',
	u88: 'uint88',
	u96: 'uint96',
	u104: 'uint104',
	u112: 'uint112',
	u120: 'uint120',
	u128: 'uint128',
	u136: 'uint136',
	u144: 'uint144',
	u152: 'uint152',
	u160: 'uint160',
	u168: 'uint168',
	u176: 'uint176',
	u184: 'uint184',
	u192: 'uint192',
	u200: 'uint200',
	u208: 'uint208',
	u216: 'uint216',
	u224: 'uint224',
	u232: 'uint232',
	u240: 'uint240',
	u248: 'uint248',
	u256: 'uint256',
	i8: 'int8',
	i16: 'int16',
	i24: 'int24',
	i32: 'int32',
	i40: 'int40',
	i48: 'int48',
	i56: 'int56',
	i64: 'int64',
	i72: 'int72',
	i80: 'int80',
	i88: 'int88',
	i96: 'int96',
	i104: 'int104',
	i112: 'int112',
	i120: 'int120',
	i128: 'int128',
	i136: 'int136',
	i144: 'int144',
	i152: 'int152',
	i160: 'int160',
	i168: 'int168',
	i176: 'int176',
	i184: 'int184',
	i192: 'int192',
	i200: 'int200',
	i208: 'int208',
	i216: 'int216',
	i224: 'int224',
	i232: 'int232',
	i240: 'int240',
	i248: 'int248',
	i256: 'int256',
}

# Types whose eth name is a plain lookup (no generic arguments to inspect):
# the non-integer basics plus every entry of ``_integer_types``.
_simple = {
	bool: 'bool',
	str: 'string',
	bytes: 'bytes',
	Address: 'address',
	**_integer_types,
}


def get_type_eth_name(t: type) -> str:
	"""
	Builds the eth type name that stands for ``t`` in a method signature

	Simple types are looked up in a table; ``Array``, ``list``/``DynArray`` and ``tuple``
	aliases are expanded recursively. ``Array`` of ``u8`` with a length from 1 to 32
	is named ``bytes<length>``.

	:param t: type to name
	:returns: eth type name, e.g. ``uint256``, ``bytes32``, ``address[]`` or ``(bool,string)``
	"""
	direct = _simple.get(t, None)
	if direct is not None:
		return direct

	origin = typing.get_origin(t)
	if origin is not None:
		args = typing.get_args(t)
		if origin is Array:
			# length is carried by the second generic argument as a ``Literal``
			length_spec = args[1]
			assert typing.get_origin(length_spec) is typing.Literal
			length = int(*typing.get_args(length_spec))
			elem = args[0]
			if elem == u8 and 1 <= length <= 32:
				return f'bytes{length}'
			return f'{get_type_eth_name(elem)}[{length}]'
		if origin is list or origin is DynArray:
			assert len(args) == 1
			(elem,) = args
			return get_type_eth_name(elem) + '[]'
		if origin is tuple:
			inner = ','.join(get_type_eth_name(member) for member in args)
			return f'({inner})'
	assert False, f'unknown type {t} {type(t)} {t is str} {type(t) is str}'


def is_dynamic(param: type) -> bool:
	"""
	Tells whether ``param`` is a dynamic type in terms of eth calldata encoding

	``bytes``, ``str`` and ``list``/``DynArray`` are always dynamic. A ``tuple`` is dynamic
	when any of its members is, and a fixed-size ``Array`` is dynamic only when its
	element type is. Everything else is static.

	:param param: type to inspect
	:returns: ``True`` for dynamic types, ``False`` for static ones
	"""
	if param is bytes or param is str:
		return True
	origin = typing.get_origin(param)
	if origin is None:
		# plain (non-generic) types other than bytes/str are static
		return False
	type_args = typing.get_args(param)
	if origin is tuple:
		return any(is_dynamic(member) for member in type_args)
	if origin is DynArray or origin is list:
		return True
	if origin is Array:
		# length is fixed, so only the element type can make it dynamic
		return is_dynamic(type_args[0])
	return False


def calc_size_here(param: type) -> int:
	"""
	Calculates how many bytes a value of type ``param`` occupies in place

	Dynamic types occupy a single 32 byte word in place. Integers, ``bool`` and
	``Address`` occupy one word each; a static ``tuple`` occupies the sum of its members
	and a static ``Array`` occupies one word when it is named ``bytes<N>``, otherwise
	its length times the element size.

	:param param: type to measure
	:returns: size in bytes
	:raises TypeError: if ``param`` is not a type this module knows how to size
	"""
	if is_dynamic(param):
		return 32
	if param in _integer_types or param is bool or param is Address:
		return 32
	origin = typing.get_origin(param)
	if origin is tuple:
		return sum(calc_size_here(member) for member in typing.get_args(param))
	if origin is Array:
		elem, length_spec = typing.get_args(param)
		assert typing.get_origin(length_spec) is typing.Literal
		length = int(*typing.get_args(length_spec))
		if elem == u8 and 1 <= length <= 32:
			# this is ``bytes<length>``, which is a single word
			return 32
		return length * calc_size_here(elem)
	raise TypeError(f'unknown type {param}')


# List of deferred writers used by the encoder: every callable is invoked with
# the list that collects writers deferred to the following pass.
type _Tails = list[typing.Callable[[_Tails], None]]


class MethodEncoder:
	"""
	Type used to encode method call
	"""

	name: str
	"""
	method name
	"""
	params: list[type]
	"""
	method parameter types
	"""
	ret: type
	"""
	return type (can be unused)
	"""
	selector: bytes
	"""
	calculated function "selector", see eth docs
	"""

	def __init__(self, name: str, params: list[type], ret: type):
		"""
		:param name: method name
		:param params: method parameter types
		:param ret: return type (can be unused)
		"""
		self.name = name
		self.params = params
		self.ret = ret
		sig = self.make_sig()
		# selector is the first 4 bytes of the signature hash
		self.selector = Keccak256(sig.encode('utf-8')).digest()[:4]

	def make_sig(self) -> str:
		"""
		calculates signature that is used for making method selector

		:returns: method name followed by comma separated eth type names in parentheses
		"""
		return self.name + '(' + ','.join(get_type_eth_name(par) for par in self.params) + ')'

	def encode(self, args: list[typing.Any]) -> bytes:
		"""
		encodes ``args`` according to this encoder ``params`` to produce a calldata

		:param args: values to encode, one per element of ``params``
		:returns: full calldata encoded call: both selector and arguments
		"""
		assert len(args) == len(self.params)
		result = bytearray(self.selector)

		# position in ``result`` that written offsets are relative to
		current_off: int = len(result)

		def run_seq_with_new_tails(cur: _Tails):
			# runs writers of one sequence, then the writers they deferred, and so on
			# until nothing is deferred; offsets are relative to the sequence start
			nonlocal current_off
			old_off = current_off
			current_off = len(result)
			while cur:
				loc_tails: _Tails = []
				for writer in cur:
					writer(loc_tails)
				cur = loc_tails
			current_off = old_off

		def put_offset_at(off: int, off0: int) -> None:
			# patches the placeholder at ``off`` with the distance from ``off0`` to the current end
			to_put = len(result) - off0
			memoryview(result)[off : off + 32] = int.to_bytes(to_put, 32, 'big')

		def put_iloc(tails: _Tails):
			# reserves a word for an offset; it is patched right before the data is written
			off = len(result)
			result.extend(b'\x00' * 32)
			tails.append(lambda _t: put_offset_at(off, current_off))

		def put_regular(param: type, arg: typing.Any, tails: _Tails) -> None:
			as_int = _integer_types.get(param, None)
			if as_int is not None:
				# signed type names start with 'int', unsigned ones with 'uint'
				result.extend(int.to_bytes(arg, 32, 'big', signed=as_int.startswith('i')))
			elif param is bool:
				result.extend(int.to_bytes(1 if arg else 0, 32, 'big'))
			elif param is Address:
				# 20 byte address is left padded to a full word
				result.extend(b'\x00' * 12)
				result.extend(arg.as_bytes)
			elif param is bytes or param is str:
				put_iloc(tails)
				if param is bytes:
					as_bytes = typing.cast(bytes, arg)
				else:
					as_bytes = typing.cast(str, arg).encode('utf-8')

				def put_bytes(_tails):
					# length word, then data right padded to a multiple of 32 bytes
					result.extend(int.to_bytes(len(as_bytes), 32, 'big'))
					result.extend(as_bytes)
					result.extend(b'\x00' * ((32 - len(as_bytes) % 32) % 32))

				tails.append(put_bytes)
			elif (origin := typing.get_origin(param)) is not None:
				type_args = typing.get_args(param)
				if origin is DynArray or origin is list:
					put_iloc(tails)
					as_seq = typing.cast(collections.abc.Sequence, arg)

					def put_arr(_tails: _Tails):
						# length word, then elements as a sequence of their own
						result.extend(int.to_bytes(len(as_seq), 32, 'big'))
						run_seq_with_new_tails(
							[partial(put_regular, type_args[0], elem) for elem in as_seq]
						)

					tails.append(put_arr)
				elif origin is Array:
					assert typing.get_origin(type_args[1]) is typing.Literal
					le = int(*typing.get_args(type_args[1]))
					if type_args[0] == u8 and 1 <= le <= 32:
						# this type is named ``bytes<le>`` in the signature, so it is
						# a single word: raw bytes followed by zero padding
						result.extend(bytes(arg[i] for i in range(le)))
						result.extend(b'\x00' * (32 - le))
					else:
						# NOTE(review): elements are written in place even when the element
						# type is dynamic -- confirm this matches the eth ABI for such arrays
						for i in range(le):
							put_regular(type_args[0], arg[i], tails)
				elif origin is tuple:

					def put_tuple(_tails):
						run_seq_with_new_tails(
							[partial(put_regular, p, a) for p, a in zip(type_args, arg)]
						)

					if is_dynamic(param):
						put_iloc(tails)
						tails.append(put_tuple)
					else:
						# static tuple is written in place
						put_tuple(None)
				else:
					assert False
			else:
				assert False

		run_seq_with_new_tails([partial(put_regular, p, a) for p, a in zip(self.params, args)])

		return bytes(result)
[docs] def decode(params: list[type], data: collections.abc.Buffer) -> list[typing.Any]: """ :param params: eth method returns (if it returns a single value, i.e. ``u32``, provide a list of one element) :param data: eth calldata encoded structure that conforms to ``params``, note that it must not contain selector :returns: list of what is encoded in the calldata """ mem = memoryview(data) current_off: int = 0 current_off_0: int = 0 def with_indirection[T](fn: typing.Callable[[], T], new_self_length) -> T: nonlocal current_off, current_off_0 off = int.from_bytes(mem[current_off : current_off + 32], 'big', signed=False) current_off += 32 old_current_off = current_off old_current_off_0 = current_off_0 # current_off_0 = current_off_0 + self_length + off - 32 current_off_0 = current_off_0 + off # current_off_0 = current_off_0 + off current_off = current_off_0 assert current_off_0 < len(mem) res = fn() current_off = old_current_off current_off_0 = old_current_off_0 return res def read_regular(param: type) -> typing.Any: nonlocal current_off, current_off_0 as_int = _integer_types.get(param, None) if as_int is not None: res = int.from_bytes( mem[current_off : current_off + 32], 'big', signed=as_int.startswith('i') ) current_off += 32 return res elif param is bool: res = int.from_bytes(mem[current_off : current_off + 32], 'big', signed=False) current_off += 32 return res != 0 elif param is Address: current_off += 12 as_bytes = mem[current_off : current_off + 20] current_off += 20 return Address(as_bytes) elif param is bytes or param is str: def read_bytes_str() -> bytes | str: nonlocal current_off le = int.from_bytes(mem[current_off : current_off + 32], 'big', signed=False) current_off += 32 as_bytes = mem[current_off : current_off + le] if param is bytes: return bytes(as_bytes) else: return str(as_bytes, encoding='utf-8') return with_indirection(read_bytes_str, -1) elif (origin := typing.get_origin(param)) is not None: type_args = typing.get_args(param) if origin is tuple: 
if is_dynamic(param): return with_indirection( lambda: tuple(read_regular(p) for p in type_args), calc_size_here(param) ) else: return tuple(read_regular(p) for p in type_args) elif origin is list or origin is DynArray: [elem] = type_args def read_list() -> list: nonlocal current_off, current_off_0 le = int.from_bytes(mem[current_off : current_off + 32], 'big', signed=False) current_off_0 += 32 current_off += 32 return [read_regular(elem) for _i in range(le)] return with_indirection(read_list, -1) elif origin is Array: assert typing.get_origin(type_args[1]) is typing.Literal le = int(*typing.get_args(type_args[1])) return [read_regular(type_args[0]) for _i in range(le)] else: assert False else: assert False return [read_regular(par) for par in params]