# Source code for genlayer_embeddings.model_wrappers

# Public API for `from genlayer_embeddings.model_wrappers import *`. Every name
# listed here must be bound at module level, otherwise the star-import raises
# AttributeError.
__all__ = ('Model', 'SentenceTransformer')

import numpy as np
from numpy.typing import DTypeLike
from ._nn.builder import Builder
from ._nn import get_run_onnx
from pathlib import Path
import json
import onnx
import collections.abc
import typing
import os
import warnings

def _load_models_db(spec: str) -> dict[str, dict[str, typing.Any]]:
	"""Index the model directories listed in *spec* by model name.

	*spec* is a ``:``-separated list of directories; empty entries are
	skipped. Each directory must contain a ``model.json`` object with at
	least a ``name`` key. The stored entry is that JSON object plus a
	``path`` key pointing at ``model.onnx`` in the same directory (because
	the JSON is merged last, a ``path`` key inside ``model.json`` wins).

	Raises FileNotFoundError for a missing ``model.json``,
	json.JSONDecodeError for malformed JSON and KeyError for a missing
	``name``.
	"""
	db: dict[str, dict[str, typing.Any]] = {}
	for entry in spec.split(':'):
		if not entry:
			continue
		root = Path(entry)
		# JSON is UTF-8 by specification; don't depend on the locale encoding.
		desc = json.loads(root.joinpath('model.json').read_text(encoding='utf-8'))
		db[desc['name']] = {'path': root.joinpath('model.onnx'), **desc}
	return db


# ':'-separated list of model directories, read once at import time.
_models = os.getenv('GENLAYER_EMBEDDINGS_MODELS', '')
_models_paths = _models.split(':')

# Registry of available models: name -> description (see _load_models_db).
_ALL_MODELS = _load_models_db(_models)


# Presumably the type of the callables built by `get_model`: array inputs
# passed as keyword arguments, a dict of named output arrays returned.
# Bound with a plain assignment (not a 3.12 `type` statement) so that the
# 'Model' entry in `__all__` resolves on older Python versions too.
Model = typing.Callable[..., dict[str, np.ndarray]]


def get_model(model: str, inputs: dict[str, DTypeLike], *, models_db=_ALL_MODELS):
	"""Wrap the ONNX graph registered under *model* into a runnable object.

	Looks *model* up in *models_db* (KeyError if absent), loads the ONNX file
	named by its ``path`` entry without external data, and hands it to
	``get_run_onnx`` with one placeholder per input name plus the optional
	``rename-outputs`` mapping. A ``tokens_truncate = <value>`` line is added
	to the builder's prelude before the result of ``builder.finish`` is
	returned.

	Only the keys of *inputs* are used here; the dtype values are not read.
	NOTE(review): the return value is presumably a callable taking the
	inputs as keyword arrays and returning a dict of outputs — confirm
	against ``Builder.finish``.
	"""
	desc = models_db[model]
	# Each input placeholder is a variable named after the input itself.
	placeholders = {name: name for name in inputs}
	graph = onnx.load_model(desc['path'], load_external_data=False)
	builder, params = get_run_onnx(graph, placeholders, desc.get('rename-outputs', {}))
	truncate_literal = repr(desc.get("tokens_truncate", None))
	builder._prelude.append(f'tokens_truncate = {truncate_literal}\n')
	return builder.finish(parameters=params)
def prod(x: collections.abc.Sequence[int]):
	"""Multiply together the integers in *x*; an empty sequence yields 1."""
	total = 1
	for factor in x:
		total = total * factor
	return total


def _unfold(x: np.ndarray):
	"""Return *x* reshaped into a 1-D array holding all of its elements."""
	flat_len = prod(x.shape)
	return x.reshape((flat_len,))


# Memo of the embedding callables built in this module, keyed by the string
# the builder was called with.
_cache_SentenceTransformer: dict[str, typing.Callable[[str], np.ndarray]] = {}
def SentenceTransformerFromPath(path: str) -> typing.Callable[[str], np.ndarray]:
	"""Return a text -> embedding function backed by the model script at *path*.

	The file at *path* is executed as Python source. Its ``main`` global is
	used as the network, and its optional ``tokens_truncate`` global caps the
	number of tokens fed to it. The returned callable is memoised in
	``_cache_SentenceTransformer`` under *path*.

	:param path: path of the model script (presumably one emitted by the
		ONNX ``Builder`` used in ``get_model`` — confirm).
	:returns: a callable mapping a string to the model's ``'embedding'``
		output flattened to 1-D.
	"""
	if res := _cache_SentenceTransformer.get(path):
		return res

	# Imported here rather than at module top, so the tokenizer package is
	# only needed once a model is actually built.
	from word_piece_tokenizer import WordPieceTokenizer

	tokenizer = WordPieceTokenizer()
	# Python source is UTF-8 by default; don't depend on the locale encoding.
	data = Path(path).read_text(encoding='utf-8')
	globs: dict[str, typing.Any] = {}
	# SECURITY: exec runs arbitrary code from this file — only pass trusted paths.
	exec(data, globs)
	nn_model = globs['main']
	truncate: int | None = globs.get('tokens_truncate', None)

	def ret(text: str) -> np.ndarray:
		res = tokenizer.tokenize(text)
		# A falsy limit (None or 0) disables truncation.
		if truncate and len(res) > truncate:
			# stacklevel=2 attributes the warning to the caller of the embedder.
			warnings.warn(
				f'truncating input tokens from {len(res)} to {truncate}', stacklevel=2
			)
			res = res[:truncate]
		res = np.array(res, np.int64)
		# Batch of one: shape (1, n_tokens).
		res = res.reshape(1, prod(res.shape))
		return _unfold(
			nn_model(
				input_ids=res,
				# NOTE(review): attention_mask is all zeros; in many BERT-style
				# exports 1 means "attend". Confirm the bundled graph expects this.
				attention_mask=np.zeros(res.shape, res.dtype),
				token_type_ids=np.zeros(res.shape, res.dtype),
			)['embedding']
		)

	_cache_SentenceTransformer[path] = ret
	return ret
def SentenceTransformer(model: str) -> typing.Callable[[str], np.ndarray]:
	"""Return a text -> embedding function for the registered model *model*.

	The ONNX graph is wrapped via ``get_model`` with ``int64`` inputs
	``input_ids``, ``attention_mask`` and ``token_type_ids``. Input text is
	tokenized with ``WordPieceTokenizer``. It is cut to the model's
	``tokens_truncate`` entry when that is set and non-zero. The model's
	``'embedding'`` output is returned flattened to 1-D. The callable is
	memoised in ``_cache_SentenceTransformer`` under *model*.

	:param model: a model name registered in ``_ALL_MODELS``.
	:returns: a callable mapping a string to a 1-D ``np.ndarray``.
	:raises KeyError: if *model* is not registered.
	"""
	# NOTE(review): this cache is shared with SentenceTransformerFromPath, which
	# keys it by file path, so a model name equal to a path string would collide.
	if res := _cache_SentenceTransformer.get(model):
		return res

	# Imported here rather than at module top, so the tokenizer package is
	# only needed once a model is actually built.
	from word_piece_tokenizer import WordPieceTokenizer

	tokenizer = WordPieceTokenizer()
	nn_model = get_model(
		model,
		{
			'input_ids': np.int64,
			'attention_mask': np.int64,
			'token_type_ids': np.int64,
		},
	)
	model_desc = _ALL_MODELS[model]
	truncate: int | None = model_desc.get('tokens_truncate', None)

	def ret(text: str) -> np.ndarray:
		res = tokenizer.tokenize(text)
		# A falsy limit (None or 0) disables truncation.
		if truncate and len(res) > truncate:
			# stacklevel=2 attributes the warning to the caller of the embedder.
			warnings.warn(
				f'truncating input tokens from {len(res)} to {truncate}', stacklevel=2
			)
			res = res[:truncate]
		res = np.array(res, np.int64)
		# Batch of one: shape (1, n_tokens).
		res = res.reshape(1, prod(res.shape))
		return _unfold(
			nn_model(
				input_ids=res,
				# NOTE(review): attention_mask is all zeros; in many BERT-style
				# exports 1 means "attend". Confirm the bundled graph expects this.
				attention_mask=np.zeros(res.shape, res.dtype),
				token_type_ids=np.zeros(res.shape, res.dtype),
			)['embedding']
		)

	_cache_SentenceTransformer[model] = ret
	return ret