diff --git a/conflict_set.py b/conflict_set.py new file mode 100644 index 0000000..0f82970 --- /dev/null +++ b/conflict_set.py @@ -0,0 +1,118 @@ +import ctypes +import enum +import os + +_lib = None +for f in ( + os.path.dirname(__file__) + "/build/radix_tree/libconflict-set.so.0", + os.path.dirname(__file__) + "/build/radix_tree/libconflict-set.dylib.0", +): + try: + _lib = ctypes.cdll.LoadLibrary(f) + except: + pass + +if _lib is None: + import sys + + print("Could not find libconflict-set", file=sys.stderr) + sys.exit(1) + + +class _Key(ctypes.Structure): + _fields_ = [("p", ctypes.POINTER(ctypes.c_ubyte)), ("len", ctypes.c_int)] + + +class ReadRange(ctypes.Structure): + _fields_ = [ + ("begin", _Key), + ("end", _Key), + ("readVersion", ctypes.c_int64), + ] + + +class WriteRange(ctypes.Structure): + _fields_ = [("begin", _Key), ("end", _Key)] + + +_lib.ConflictSet_create.argtypes = (ctypes.c_int64,) +_lib.ConflictSet_create.restype = ctypes.c_void_p + +_lib.ConflictSet_check.argtypes = ( + ctypes.c_void_p, + ctypes.POINTER(ReadRange), + ctypes.POINTER(ctypes.c_int), + ctypes.c_int, +) + +_lib.ConflictSet_addWrites.argtypes = ( + ctypes.c_void_p, + ctypes.POINTER(WriteRange), + ctypes.c_int, + ctypes.c_int64, +) + +_lib.ConflictSet_setOldestVersion.argtypes = (ctypes.c_void_p, ctypes.c_int64) + +_lib.ConflictSet_destroy.argtypes = (ctypes.c_void_p,) + + +class Result(enum.Enum): + COMMIT = 0 + CONFLICT = 1 + TOO_OLD = 2 + + +def write(begin: bytes, end: bytes | None = None) -> WriteRange: + b = (ctypes.c_ubyte * len(begin))() + b.value = begin + if end is None: + e = (ctypes.c_ubyte * 0)() + e.value = b"" + else: + e = (ctypes.c_ubyte * len(end))() + e.value = end + return WriteRange(_Key(b, len(b)), _Key(e, len(e))) + + +def read(version: int, begin: bytes, end: bytes = None) -> ReadRange: + b = (ctypes.c_ubyte * len(begin))() + b.value = begin + if end is None: + e = (ctypes.c_ubyte * 0)() + e.value = b"" + else: + e = (ctypes.c_ubyte * len(end))() + e.value = end + return ReadRange(_Key(b, len(b)), _Key(e, len(e)), version) + + +class ConflictSet: + def __init__(self, version: int = 0) -> None: + self.p = _lib.ConflictSet_create(version) + + def addWrites(self, version: int, *writes: WriteRange): + _lib.ConflictSet_addWrites( + self.p, (WriteRange * len(writes))(*writes), len(writes), version + ) + + def check(self, *reads: ReadRange) -> list[Result]: + r = (ctypes.c_int * len(reads))() + _lib.ConflictSet_check(self.p, *reads, r, 1) + return [Result(x) for x in r] + + def setOldestVersion(self, version: int) -> None: + _lib.ConflictSet_setOldestVersion(self.p, version) + + def __enter__(self): + return self + + def close(self) -> None: + if self.p is not None: + _lib.ConflictSet_destroy(self.p) + self.p = None + + def __exit__(self, exception_type, exception_value, exception_traceback): + if self.p is not None: + _lib.ConflictSet_destroy(self.p) + self.p = None diff --git a/test_conflict_set.py b/test_conflict_set.py new file mode 100644 index 0000000..7d45336 --- /dev/null +++ b/test_conflict_set.py @@ -0,0 +1,9 @@ +from conflict_set import * + + +def test_conflict_set(): + with ConflictSet() as cs: + cs.addWrites(1, write(b"")) + assert cs.check(read(0, b"")) == [Result.CONFLICT] + cs.setOldestVersion(1) + assert cs.check(read(0, b"")) == [Result.TOO_OLD]