"""ctypes bindings for the native ``libconflict-set`` shared library.

The library is looked up in ``build/radix_tree/`` next to this file, first
under its Linux name (``libconflict-set.so.0``) and then its macOS name
(``libconflict-set.0.dylib``). If neither can be loaded, a message is printed
to stderr and the process exits with status 1.

Typical use::

    with ConflictSet() as cs:
        cs.addWrites(1, write(b"a"), write(b"b", b"c"))
        results = cs.check(read(0, b"a"), read(1, b"x", b"y"))
"""

import ctypes
import enum
import os
import sys
from typing import Optional

_LIB_DIR = os.path.join(os.path.dirname(__file__), "build", "radix_tree")

_lib = None
for _candidate in (
    os.path.join(_LIB_DIR, "libconflict-set.so.0"),
    os.path.join(_LIB_DIR, "libconflict-set.0.dylib"),
):
    try:
        _lib = ctypes.cdll.LoadLibrary(_candidate)
    except OSError:
        # Missing file or wrong platform: try the next candidate name.
        continue
    break

if _lib is None:
    print("Could not find libconflict-set", file=sys.stderr)
    sys.exit(1)


class _Key(ctypes.Structure):
    # A byte string handed to C as a (pointer, length) pair.
    _fields_ = [("p", ctypes.POINTER(ctypes.c_ubyte)), ("len", ctypes.c_int)]


class ReadRange(ctypes.Structure):
    """A key range read at ``readVersion``; build instances with :func:`read`."""

    _fields_ = [
        ("begin", _Key),
        ("end", _Key),
        ("readVersion", ctypes.c_int64),
    ]


class WriteRange(ctypes.Structure):
    """A written key range; build instances with :func:`write`."""

    _fields_ = [("begin", _Key), ("end", _Key)]


_lib.ConflictSet_create.argtypes = (ctypes.c_int64,)
_lib.ConflictSet_create.restype = ctypes.c_void_p

_lib.ConflictSet_check.argtypes = (
    ctypes.c_void_p,
    ctypes.POINTER(ReadRange),
    ctypes.POINTER(ctypes.c_int),
    ctypes.c_int,
)

_lib.ConflictSet_addWrites.argtypes = (
    ctypes.c_void_p,
    ctypes.POINTER(WriteRange),
    ctypes.c_int,
    ctypes.c_int64,
)

_lib.ConflictSet_setOldestVersion.argtypes = (ctypes.c_void_p, ctypes.c_int64)

_lib.ConflictSet_destroy.argtypes = (ctypes.c_void_p,)

_lib.ConflictSet_getBytes.argtypes = (ctypes.c_void_p,)
_lib.ConflictSet_getBytes.restype = ctypes.c_int64


class Result(enum.Enum):
    """Per-read outcome reported by :meth:`ConflictSet.check`.

    The integer values are the codes the C function writes into its
    results array.
    """

    COMMIT = 0
    CONFLICT = 1
    TOO_OLD = 2


def _key(data: bytes) -> _Key:
    """Copy *data* into a new ctypes buffer and wrap it as a ``_Key``.

    Assigning the buffer to the pointer field makes ctypes keep the buffer
    referenced from the structure (and from any structure or array the key
    is later copied into), so the bytes outlive this function.
    """
    # NOTE: `buf.value = data` does NOT work here -- only c_char arrays have a
    # `value` setter; for c_ubyte arrays the bytes would never be copied.
    buf = (ctypes.c_ubyte * len(data)).from_buffer_copy(data)
    return _Key(buf, len(data))


def write(begin: bytes, end: Optional[bytes] = None) -> WriteRange:
    """Build a :class:`WriteRange` from *begin* to *end*.

    When *end* is omitted, an empty end key (length 0) is passed to the
    library -- presumably meaning "the single key *begin*"; confirm against
    the C header.
    """
    return WriteRange(_key(begin), _key(b"" if end is None else end))


def read(version: int, begin: bytes, end: Optional[bytes] = None) -> ReadRange:
    """Build a :class:`ReadRange` from *begin* to *end* read at *version*.

    When *end* is omitted, an empty end key (length 0) is passed to the
    library -- presumably meaning "the single key *begin*"; confirm against
    the C header.
    """
    return ReadRange(_key(begin), _key(b"" if end is None else end), version)


class ConflictSet:
    """Owning wrapper around a native conflict-set handle.

    Release the native object with :meth:`close` or by using the instance as
    a context manager. Any other method called after closing raises
    ``ValueError`` instead of handing a NULL handle to C.
    """

    def __init__(self, version: int = 0) -> None:
        self.p = _lib.ConflictSet_create(version)

    def _handle(self) -> int:
        # Guard against use-after-close: passing None would become a NULL
        # pointer on the C side.
        if self.p is None:
            raise ValueError("operation on closed ConflictSet")
        return self.p

    def addWrites(self, version: int, *writes: WriteRange) -> None:
        """Record *writes* in the set at *version*."""
        handle = self._handle()
        count = len(writes)
        _lib.ConflictSet_addWrites(
            handle, (WriteRange * count)(*writes), count, version
        )

    def check(self, *reads: ReadRange) -> list[Result]:
        """Check each read and return one :class:`Result` per read, in order."""
        handle = self._handle()
        count = len(reads)
        if count == 0:
            return []
        # C expects a contiguous array of ReadRange plus its length, and
        # fills one int per read into `results`.
        results = (ctypes.c_int * count)()
        _lib.ConflictSet_check(handle, (ReadRange * count)(*reads), results, count)
        return [Result(code) for code in results]

    def setOldestVersion(self, version: int) -> None:
        """Forward *version* to ``ConflictSet_setOldestVersion``."""
        _lib.ConflictSet_setOldestVersion(self._handle(), version)

    def getBytes(self) -> int:
        """Return the value reported by ``ConflictSet_getBytes``."""
        return _lib.ConflictSet_getBytes(self._handle())

    def close(self) -> None:
        """Destroy the native object; safe to call more than once."""
        if self.p is not None:
            _lib.ConflictSet_destroy(self.p)
            self.p = None

    def __enter__(self) -> "ConflictSet":
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback) -> None:
        self.close()