"""ctypes bindings for the native ``libconflict-set`` shared library."""

import ctypes
import enum
import os
import sys
from typing import Optional


class _Key(ctypes.Structure):
    """A byte string passed to the native library as a (pointer, length) pair."""

    _fields_ = [("p", ctypes.POINTER(ctypes.c_ubyte)), ("len", ctypes.c_int)]


class ReadRange(ctypes.Structure):
    """A key range read at ``readVersion``.

    The field layout must match the native struct; assumed to mirror the C
    header — confirm when changing either side.
    """

    _fields_ = [
        ("begin", _Key),
        ("end", _Key),
        ("readVersion", ctypes.c_int64),
    ]


class WriteRange(ctypes.Structure):
    """A key range written by a transaction (layout mirrors the native struct)."""

    _fields_ = [("begin", _Key), ("end", _Key)]


class Result(enum.Enum):
    """Integer codes written by ``ConflictSet_check``, one per read.

    NOTE(review): the numeric values must match the native library's enum —
    confirm against the C header.
    """

    COMMIT = 0
    CONFLICT = 1
    TOO_OLD = 2


def _key(data: bytes) -> _Key:
    """Copy *data* into a ctypes-owned buffer and wrap it in a ``_Key``.

    Assigning the ctypes array to the struct's pointer field makes the struct
    hold a reference to the array, so the bytes stay alive as long as the
    returned ``_Key`` (or any struct/array it is copied into) does.
    """
    raw = bytes(data)
    buf = (ctypes.c_ubyte * len(raw)).from_buffer_copy(raw)
    return _Key(buf, len(raw))


def write(begin: bytes, end: Optional[bytes] = None) -> WriteRange:
    """Build a ``WriteRange`` from *begin* to *end*.

    When *end* is ``None`` the end key is passed as a zero-length key; how the
    native library interprets that (presumably "single key *begin*") is not
    visible here — confirm against the C API.
    """
    return WriteRange(_key(begin), _key(b"" if end is None else end))


def read(version: int, begin: bytes, end: Optional[bytes] = None) -> ReadRange:
    """Build a ``ReadRange`` from *begin* to *end* read at *version*.

    A ``None`` *end* is encoded as a zero-length key, as in :func:`write`.
    """
    return ReadRange(_key(begin), _key(b"" if end is None else end), version)


class ConflictSet:
    """Owns one native ``ConflictSet`` handle loaded from ``libconflict-set``.

    Use as a context manager, or call :meth:`close`, to release the handle.
    Calling any operation after :meth:`close` raises ``ValueError`` instead of
    passing a NULL handle to native code.
    """

    def __init__(
        self,
        version: int = 0,
        build_dir: Optional[str] = None,
        implementation: Optional[str] = None,
    ) -> None:
        """Load the shared library and create a native conflict set.

        Args:
            version: Initial version passed to ``ConflictSet_create``.
            build_dir: Directory containing ``<implementation>/libconflict-set.*``;
                defaults to ``build`` next to this module.
            implementation: Sub-directory name of the implementation to load;
                defaults to ``"radix_tree"``.

        Exits the process with status 1 (after printing to stderr) if no
        candidate library can be loaded; this preserves the original behavior.
        """
        self._lib = None
        self.p = None
        if build_dir is None:
            build_dir = os.path.join(os.path.dirname(__file__), "build")
        if implementation is None:
            implementation = "radix_tree"
        # Linux and macOS library names; both are resolved under build_dir.
        candidates = (
            os.path.join(build_dir, implementation, "libconflict-set.so.0"),
            os.path.join(build_dir, implementation, "libconflict-set.0.dylib"),
        )
        for path in candidates:
            try:
                self._lib = ctypes.cdll.LoadLibrary(path)
                break  # first loadable candidate wins
            except OSError:
                # Missing file / wrong platform: try the next candidate.
                continue
        if self._lib is None:
            print(
                "Could not find libconflict-set implementation " + implementation,
                file=sys.stderr,
            )
            sys.exit(1)
        self._declare_signatures()
        self.p = self._lib.ConflictSet_create(version)

    def _declare_signatures(self) -> None:
        """Declare native argument/return types so ctypes converts correctly."""
        lib = self._lib
        lib.ConflictSet_create.argtypes = (ctypes.c_int64,)
        lib.ConflictSet_create.restype = ctypes.c_void_p
        lib.ConflictSet_check.argtypes = (
            ctypes.c_void_p,
            ctypes.POINTER(ReadRange),
            ctypes.POINTER(ctypes.c_int),
            ctypes.c_int,
        )
        lib.ConflictSet_addWrites.argtypes = (
            ctypes.c_void_p,
            ctypes.POINTER(WriteRange),
            ctypes.c_int,
            ctypes.c_int64,
        )
        lib.ConflictSet_setOldestVersion.argtypes = (
            ctypes.c_void_p,
            ctypes.c_int64,
        )
        lib.ConflictSet_destroy.argtypes = (ctypes.c_void_p,)
        lib.ConflictSet_getBytes.argtypes = (ctypes.c_void_p,)
        lib.ConflictSet_getBytes.restype = ctypes.c_int64

    def _handle(self):
        """Return the native handle, or raise ``ValueError`` if closed."""
        if self.p is None:
            raise ValueError("operation on closed ConflictSet")
        return self.p

    def addWrites(self, version: int, *writes: WriteRange):
        """Record *writes* at *version* in the native conflict set."""
        count = len(writes)
        self._lib.ConflictSet_addWrites(
            self._handle(), (WriteRange * count)(*writes), count, version
        )

    def check(self, *reads: ReadRange) -> list[Result]:
        """Check each read for conflicts and return one ``Result`` per read.

        The reads are passed as one contiguous array together with their count,
        and the native call fills one ``int`` result per read.
        """
        handle = self._handle()
        count = len(reads)
        results = (ctypes.c_int * count)()
        self._lib.ConflictSet_check(
            handle, (ReadRange * count)(*reads), results, count
        )
        return [Result(code) for code in results]

    def setOldestVersion(self, version: int) -> None:
        """Forward *version* to the native ``ConflictSet_setOldestVersion``."""
        self._lib.ConflictSet_setOldestVersion(self._handle(), version)

    def getBytes(self) -> int:
        """Return the value reported by the native ``ConflictSet_getBytes``."""
        return self._lib.ConflictSet_getBytes(self._handle())

    def __enter__(self):
        return self

    def close(self) -> None:
        """Destroy the native handle; safe to call more than once."""
        if self.p is not None:
            self._lib.ConflictSet_destroy(self.p)
            self.p = None

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()