"""Differential tests that run two ConflictSet implementations side by side."""

from conflict_set import *

# Directory handed to ConflictSet as ``build_dir``; the command-line runner at
# the bottom of this file overwrites it from ``--build-dir``.
build_dir = None


class DebugConflictSet:
    """
    Bisimulates the skip list and radix tree conflict sets for testing purposes

    Every operation is forwarded to both implementations; ``check`` asserts
    that the two return equal results.
    """

    def __init__(self, version: int = 0) -> None:
        """Create both underlying conflict sets starting at ``version``."""
        self.skip_list = ConflictSet(
            version, build_dir=build_dir, implementation="skip_list"
        )
        try:
            self.radix_tree = ConflictSet(
                version, build_dir=build_dir, implementation="radix_tree"
            )
        except BaseException:
            # Don't leak the first set when the second one cannot be built.
            self.skip_list.close()
            raise

    def addWrites(self, version: int, *writes: WriteRange) -> None:
        """Apply ``writes`` at ``version`` to both implementations."""
        self.skip_list.addWrites(version, *writes)
        self.radix_tree.addWrites(version, *writes)

    def check(self, *reads: ReadRange) -> list[Result]:
        """Check ``reads`` against both implementations.

        Returns the radix tree's results after asserting that they equal the
        skip list's results.
        """
        expected = self.skip_list.check(*reads)
        actual = self.radix_tree.check(*reads)
        assert expected == actual
        return actual

    def setOldestVersion(self, version: int) -> None:
        """Forward ``setOldestVersion(version)`` to both implementations."""
        self.skip_list.setOldestVersion(version)
        self.radix_tree.setOldestVersion(version)

    def getBytes(self) -> int:
        """Return ``getBytes()`` of the radix tree implementation only."""
        return self.radix_tree.getBytes()

    def __enter__(self):
        return self

    def close(self) -> None:
        """Close both underlying conflict sets."""
        try:
            self.skip_list.close()
        finally:
            # Always release the second set, even if closing the first raised.
            self.radix_tree.close()

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()


def test_conflict_set():
    """Basic write / check / setOldestVersion round trip on a single key."""
    with DebugConflictSet() as cs:
        before = cs.getBytes()
        key = b"a key"
        cs.addWrites(1, write(key))
        assert cs.getBytes() - before > 0
        assert cs.check(read(0, key)) == [Result.CONFLICT]
        cs.setOldestVersion(1)
        assert cs.check(read(0, key), read(1, key)) == [Result.TOO_OLD, Result.COMMIT]


def test_inner_full_words():
    """Range read across 0x41 consecutive first bytes; relies on check()'s assert."""
    with DebugConflictSet() as cs:
        cs.addWrites(1, write(b"\x3f\x61"), write(b"\x81\x61"))
        writes = [write(bytes([i, 0x61])) for i in range(0x40, 0x81)]
        cs.addWrites(2, *writes)
        cs.check(read(1, b"\x21", b"\xc2"))


def test_internal_version_zero():
    """Use versions on both sides of 2**32; relies on check()'s assert."""
    with DebugConflictSet() as cs:
        cs.setOldestVersion(0xFFFFFFF0)
        for i in range(24):
            cs.addWrites(0xFFFFFFF1, write(bytes([i])))
        for i in range(256 - 25, 256):
            cs.addWrites(0xFFFFFFF1, write(bytes([i])))
        cs.addWrites(0x100000000, write(b"\xff"))
        cs.check(read(0xFFFFFFF1, b"\x00", b"\xff"))


def test_decrease_capacity():
    """Run setOldestVersion after many empty addWrites calls on a wide node."""
    # make a Node48, then a Node256
    for count in (17, 49):
        with DebugConflictSet() as cs:
            for i in range(count):
                cs.addWrites(1, write(bytes(([0] * 99) + [i])))
            # lower its partial key length
            cs.addWrites(2, write(bytes([0] * 98)))
            # create work for setOldestVersion
            for i in range(3, 1000):
                cs.addWrites(i)
            # setOldestVersion should decrease the capacity
            cs.setOldestVersion(1)


def test_large():
    """Insert 100000 point writes, then one range write covering all of them."""
    with DebugConflictSet() as cs:
        end = 100000
        for i in range(end):
            cs.addWrites(1, write(i.to_bytes(8, byteorder="big")))
        cs.addWrites(
            2,
            write((0).to_bytes(8, byteorder="big"), (end).to_bytes(8, byteorder="big")),
        )


def test_merge_child_node48():
    """Range write over a key that has 17 children (a Node48, going by the name)."""
    with DebugConflictSet() as cs:
        cs.addWrites(1, write(b"\x00" * 9))
        for i in range(17):
            cs.addWrites(1, write(b"\x00" * 10 + bytes([i])))
        cs.addWrites(1, write(b"\x00" * 8, b"\x00" * 10))


if __name__ == "__main__":
    # budget "pytest" for ctest integration without pulling in a dependency. You can of course still use pytest in local development.
    import argparse
    import inspect
    import sys

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command")

    list_parser = subparsers.add_parser("list")

    test_parser = subparsers.add_parser("test")
    test_parser.add_argument("test")
    test_parser.add_argument("--build-dir")

    args = parser.parse_args()

    if args.command == "list":
        # Emit the test names (without the "test_" prefix) as one
        # semicolon-separated line with no trailing newline.
        sys.stdout.write(
            ";".join(
                name[5:]
                for name in dir()
                if name.startswith("test_")
                and inspect.isfunction(getattr(sys.modules[__name__], name))
            )
        )
    elif args.command == "test":
        build_dir = args.build_dir
        # Only dispatch to real test functions: other "test_*" globals (e.g.
        # test_parser) and unknown names get a clear error, not a traceback.
        selected = globals().get("test_" + args.test)
        if not inspect.isfunction(selected):
            parser.error(f"unknown test: {args.test!r}")
        selected()