Commit 8b70a0c8 authored by Eddie Schoute's avatar Eddie Schoute
Browse files

Fix bugs in routing with TBS

parent adbd4d25
......@@ -5,7 +5,7 @@ import random
from reversal_sort import routing
from reversal_sort.tripartite_binary_sort import tripartite_binary_sort
from reversal_sort.adaptive_tbs import binary_sort_parallel
from reversal_sort.adaptive_tbs import adaptive_tb_sort
"""
Collects routing time data for the algorithms GDC(TBS), GDC(ATBS), and OES.
......@@ -131,6 +131,6 @@ if __name__ == "__main__":
# Check if is permutation
if set(range(len(perm))) != set(perm):
raise ValueError(f"Given permutation does not contain all elements in [0,{len(perm)-1}].")
router = routing.DCRoute(binary_sort_parallel)
router = routing.DCRoute(adaptive_tb_sort)
reversals = router.route(perm)
print(reversals) # TODO: Implement pretty print
\ No newline at end of file
......@@ -35,6 +35,12 @@ class Reversal:
def offset(self, by):
return Reversal(self.beg + by, self.end + by)
def apply(self, perm):
# Perform the given reversal on the given permutation
i = self.beg
j = self.end
perm[i:j+1] = perm[i:j+1][::-1]
def __lt__(self, other):
return self.beg < other.beg
......@@ -45,13 +51,6 @@ class Reversal:
return repr(self)
# Perform the given reversal on the given permutation
def reverse(permutation, reversal):
i = reversal.beg
j = reversal.end
permutation[i:j+1] = permutation[i:j+1][::-1]
# Determines if the reversals in the given list are independent, where perm_length is the length of the permutation
def independent(rev_list, perm_length):
busy = [0]*perm_length
......
from . import reversal
def median(L):
return len(L) // 2
def perm_to_01(L):
"""
TUrns a permutation L[i:j] into a permutation of 0s and 1s, where L[i] is 0 if it is
less than the median and L[i] is 1 if it is greater than the median
"""
return [int(x > (len(L)-1) // 2) for x in L]
# Note that permutations are 0-indexed so we need to shift by one.
return [int(x + 1 > median(L)) for x in L]
class DCRoute:
"""Divide and Conquer routing algorithm"""
......@@ -16,16 +20,19 @@ class DCRoute:
def route(self, L):
"""
Sorts the given permutations L from index i to j (inclusive) using a divide and conquer approach. Returns a list of
Sorts the given permutations L using a divide and conquer approach. Returns a list of
reversals used to perform the sort.
L must contain all elements in [0, len(L)-1] once.
"""
if len(L) <= 1:
return []
T = perm_to_01(L)
revs = self.alg(T)
m = len(T) // 2
reversal.apply_revs(revs, L)
for rev in revs:
rev.apply(L)
m = median(L)
left_revs = self.route(L[:m])
right_perm = [x - m for x in L[m:]]
right_revs = [rev.offset(m) for rev in self.route(right_perm)]
......
......@@ -8,26 +8,31 @@ from . import reversal
def tripartite_binary_sort(T):
"""
Performs TBS (binary version) on T[i,j] inclusive, returns list of reversals
Performs TBS on T, returns list of reversals
"""
if all(T[ind] <= T[ind + 1] for ind in range(len(T) - 1)):
return []
part1, part2 = len(T) // 3, 2 * len(T) // 3
revlist = []
revlist += tripartite_binary_sort(T[:part1])
flipbits(T, part1, part2)
middle_revs = tripartite_binary_sort(T[part1: part2])
revlist += [rev.offset(part1) for rev in middle_revs]
flipbits(T, part1, part2)
end_revs = tripartite_binary_sort(T[part2: len(T)])
revlist += [rev.offset(part2) for rev in end_revs]
oneind, zerind = zero_one_indices(T, 0, len(T)-1)
revlist += tripartite_binary_sort(T[:part1 + 1])
middle_revs = tripartite_binary_sort(negatelist(T[part1 + 1: part2 + 1]))
revlist += [rev.offset(part1 + 1) for rev in middle_revs]
end_revs = tripartite_binary_sort(T[part2 + 1: len(T)])
revlist += [rev.offset(part2 + 1) for rev in end_revs]
# Leave T immutable
copyT = T.copy()
for rev in revlist:
rev.apply(copyT)
oneind, zerind = zero_one_indices(copyT, 0, len(T)-1)
if oneind < zerind:
rev = reversal.Reversal(oneind, zerind)
reversal.reverse(T, rev)
revlist.append(rev)
return revlist
def negatelist(T):
return [int(not x) for x in T]
def flipbits(T, i, j):
for index in range(i,j+1):
T[index] = int(not T[index])
......
import random
from reversal_sort import routing
from unittest import TestCase
from reversal_sort.routing import DCRoute
from reversal_sort.tripartite_binary_sort import tripartite_binary_sort
class TestRouting(TestCase):
def test_perm_sort(self):
def test_route_tbs_simple(self):
L = [0, 1, 3, 2]
L_orig = L.copy()
router = DCRoute(tripartite_binary_sort)
reversals = router.route(L.copy())
for rev in reversals:
rev.apply(L)
self.assertEqual(sorted(L_orig), L)
def test_route_tbs_simple2(self):
L = [1, 3, 0, 4, 2]
L_orig = L.copy()
router = DCRoute(tripartite_binary_sort)
reversals = router.route(L.copy())
for rev in reversals:
rev.apply(L)
self.assertEqual(sorted(L_orig), L)
def test_route_tbs(self):
"""
Testing correctness of sorting permutations with TBS as
bitstring sorting subroutine.
Returns number
"""
n = 1000
router = DCRoute(tripartite_binary_sort)
for ct in range(n):
L = list(range(1,random.randint(5, 200)))
L = list(range(0,random.randint(5, 200)))
random.shuffle(L)
before = L.copy()
routing.routing_divideconquer_tbs(L,0,len(L)-1)
self.assertEqual(sorted(before), L)
\ No newline at end of file
reversals = router.route(L.copy())
L_orig = L.copy()
for rev in reversals:
rev.apply(L)
self.assertEqual(sorted(L_orig), L, msg=f"Did not route {L_orig}")
import random
from reversal_sort import tripartite_binary_sort
from reversal_sort.tripartite_binary_sort import tripartite_binary_sort
from unittest import TestCase
class TestTBS(TestCase):
def test_simple(self):
l = [1, 0, 1, 0]
sortme = l.copy()
reversals = tripartite_binary_sort(sortme)
for reversal in reversals:
reversal.apply(sortme)
self.assertEqual(sortme, sorted(l))
def test_simple2(self):
l = [1, 0, 0, 1, 1]
sortme = l.copy()
reversals = tripartite_binary_sort(sortme)
for reversal in reversals:
reversal.apply(sortme)
self.assertEqual(sortme, sorted(l))
def test_binary_sort(self):
"""
Testing correctness of sorting bitstrings with TBS.
"""
for i in range(1000):
l = [0]*random.randint(100,200) + [1]*random.randint(100,200)
random.shuffle(l)
before = l.copy()
tripartite_binary_sort.tripartite_binary_sort(l, 0 , len(l)-1)
self.assertEqual(sorted(before), l)
\ No newline at end of file
shuffled = l.copy()
random.shuffle(shuffled)
shuffled_orig = shuffled.copy()
reversals = tripartite_binary_sort(shuffled)
for reversal in reversals:
reversal.apply(shuffled)
self.assertEqual(shuffled, l, msg=f"Did not sort {shuffled_orig}")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment