diff --git a/ConflictSet.cpp b/ConflictSet.cpp index ebe88c6..b3555f0 100644 --- a/ConflictSet.cpp +++ b/ConflictSet.cpp @@ -2647,17 +2647,27 @@ ConflictSet_check(void *cs, const ConflictSet_ReadRange *reads, __attribute__((__visibility__("default"))) void ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes, int count, int64_t writeVersion) { - auto *impl = ((ConflictSet::Impl *)cs); + auto *impl = (ConflictSet::Impl *)cs; mallocBytesDelta = 0; impl->addWrites(writes, count, writeVersion); impl->totalBytes += mallocBytesDelta; +#if SHOW_MEMORY + if (impl->totalBytes != mallocBytes) { + abort(); + } +#endif } __attribute__((__visibility__("default"))) void ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) { - auto *impl = ((ConflictSet::Impl *)cs); + auto *impl = (ConflictSet::Impl *)cs; mallocBytesDelta = 0; impl->setOldestVersion(oldestVersion); impl->totalBytes += mallocBytesDelta; +#if SHOW_MEMORY + if (impl->totalBytes != mallocBytes) { + abort(); + } +#endif } __attribute__((__visibility__("default"))) void * ConflictSet_create(int64_t oldestVersion) { diff --git a/SkipList.cpp b/SkipList.cpp index 6b47a44..ea91c4d 100644 --- a/SkipList.cpp +++ b/SkipList.cpp @@ -702,16 +702,35 @@ ConflictSet_check(void *cs, const ConflictSet_ReadRange *reads, __attribute__((__visibility__("default"))) void ConflictSet_addWrites(void *cs, const ConflictSet_WriteRange *writes, int count, int64_t writeVersion) { - ((ConflictSet::Impl *)cs)->addWrites(writes, count, writeVersion); + auto *impl = (ConflictSet::Impl *)cs; + mallocBytesDelta = 0; + impl->addWrites(writes, count, writeVersion); + impl->totalBytes += mallocBytesDelta; +#if SHOW_MEMORY + if (impl->totalBytes != mallocBytes) { + abort(); + } +#endif } __attribute__((__visibility__("default"))) void ConflictSet_setOldestVersion(void *cs, int64_t oldestVersion) { - ((ConflictSet::Impl *)cs)->setOldestVersion(oldestVersion); + auto *impl = (ConflictSet::Impl *)cs; + mallocBytesDelta = 0; + impl->setOldestVersion(oldestVersion); + impl->totalBytes += mallocBytesDelta; +#if SHOW_MEMORY + if (impl->totalBytes != mallocBytes) { + abort(); + } +#endif } __attribute__((__visibility__("default"))) void * ConflictSet_create(int64_t oldestVersion) { - return new (safe_malloc(sizeof(ConflictSet::Impl))) + mallocBytesDelta = 0; + auto *result = new (safe_malloc(sizeof(ConflictSet::Impl))) ConflictSet::Impl{oldestVersion}; + result->totalBytes += mallocBytesDelta; + return result; } __attribute__((__visibility__("default"))) void ConflictSet_destroy(void *cs) { using Impl = ConflictSet::Impl; diff --git a/conflict_set.py b/conflict_set.py index 10c68fa..ccab8c3 100644 --- a/conflict_set.py +++ b/conflict_set.py @@ -4,22 +4,6 @@ import os from typing import Optional -_lib = None -for f in ( - os.path.dirname(__file__) + "/build/radix_tree/libconflict-set.so.0", - os.path.dirname(__file__) + "/build/radix_tree/libconflict-set.0.dylib", -): - try: - _lib = ctypes.cdll.LoadLibrary(f) - except: - pass - -if _lib is None: - import sys - - print("Could not find libconflict-set", file=sys.stderr) - sys.exit(1) - class _Key(ctypes.Structure): _fields_ = [("p", ctypes.POINTER(ctypes.c_ubyte)), ("len", ctypes.c_int)] @@ -37,31 +21,6 @@ class WriteRange(ctypes.Structure): _fields_ = [("begin", _Key), ("end", _Key)] -_lib.ConflictSet_create.argtypes = (ctypes.c_int64,) -_lib.ConflictSet_create.restype = ctypes.c_void_p - -_lib.ConflictSet_check.argtypes = ( - ctypes.c_void_p, - ctypes.POINTER(ReadRange), - ctypes.POINTER(ctypes.c_int), - ctypes.c_int, -) - -_lib.ConflictSet_addWrites.argtypes = ( - ctypes.c_void_p, - ctypes.POINTER(WriteRange), - ctypes.c_int, - ctypes.c_int64, -) - -_lib.ConflictSet_setOldestVersion.argtypes = (ctypes.c_void_p, ctypes.c_int64) - -_lib.ConflictSet_destroy.argtypes = (ctypes.c_void_p,) - -_lib.ConflictSet_getBytes.argtypes = (ctypes.c_void_p,) -_lib.ConflictSet_getBytes.restype = ctypes.c_int64 - - class Result(enum.Enum): COMMIT = 0 CONFLICT = 1 @@ -93,34 +52,88 @@ def read(version: int, begin: bytes, end: Optional[bytes] = None) -> ReadRange: class ConflictSet: - def __init__(self, version: int = 0) -> None: - self.p = _lib.ConflictSet_create(version) + def __init__(self, version: int = 0, implementation: Optional[str] = None) -> None: + self._lib = None + if implementation is None: + implementation = "radix_tree" + for f in ( + os.path.dirname(__file__) + + "/build/" + + implementation + + "/libconflict-set.so.0", + os.path.dirname(__file__) + + "/build/" + + implementation + + "/libconflict-set.0.dylib", + ): + try: + self._lib = ctypes.cdll.LoadLibrary(f) + except: + pass + + if self._lib is None: + import sys + + print( + "Could not find libconflict-set implementation " + implementation, + file=sys.stderr, + ) + sys.exit(1) + + self._lib.ConflictSet_create.argtypes = (ctypes.c_int64,) + self._lib.ConflictSet_create.restype = ctypes.c_void_p + + self._lib.ConflictSet_check.argtypes = ( + ctypes.c_void_p, + ctypes.POINTER(ReadRange), + ctypes.POINTER(ctypes.c_int), + ctypes.c_int, + ) + + self._lib.ConflictSet_addWrites.argtypes = ( + ctypes.c_void_p, + ctypes.POINTER(WriteRange), + ctypes.c_int, + ctypes.c_int64, + ) + + self._lib.ConflictSet_setOldestVersion.argtypes = ( + ctypes.c_void_p, + ctypes.c_int64, + ) + + self._lib.ConflictSet_destroy.argtypes = (ctypes.c_void_p,) + + self._lib.ConflictSet_getBytes.argtypes = (ctypes.c_void_p,) + self._lib.ConflictSet_getBytes.restype = ctypes.c_int64 + + self.p = self._lib.ConflictSet_create(version) def addWrites(self, version: int, *writes: WriteRange): - _lib.ConflictSet_addWrites( + self._lib.ConflictSet_addWrites( self.p, (WriteRange * len(writes))(*writes), len(writes), version ) def check(self, *reads: ReadRange) -> list[Result]: r = (ctypes.c_int * len(reads))() - _lib.ConflictSet_check(self.p, *reads, r, 1) + self._lib.ConflictSet_check(self.p, *reads, r, 1) return [Result(x) for x in r] def setOldestVersion(self, version: int) -> None: - _lib.ConflictSet_setOldestVersion(self.p, version) + self._lib.ConflictSet_setOldestVersion(self.p, version) def getBytes(self) -> int: - return _lib.ConflictSet_getBytes(self.p) + return self._lib.ConflictSet_getBytes(self.p) def __enter__(self): return self def close(self) -> None: if self.p is not None: - _lib.ConflictSet_destroy(self.p) + self._lib.ConflictSet_destroy(self.p) self.p = None def __exit__(self, exception_type, exception_value, exception_traceback): if self.p is not None: - _lib.ConflictSet_destroy(self.p) + self._lib.ConflictSet_destroy(self.p) self.p = None