search algorithm, dfs, symmetrical problem

1 2 3
4 5 6
7 8 9

0,0 0,1 0,2
1,0 1,1 1,2
2,0 2,1 2,2

I think implementing this in Lua with a table will be a good idea.

1 -> 2 (right)
  -> 5 (diag)
  -> 4 (down)

the assumption is that we do not know the correct solution.

rules:
you may create straight strikes across the grid until all the numbers have disappeared.

the lower the number of strikes the better. we know it can be done in 4 strikes and 5 strikes.


In [355]:
import numpy as np

class Grid():
    def __init__(self, n, m):
        self.n = n
        self.m = m
        self.pucks = n * m
        self.board = np.int32(2**(n*m)-1)
        self.lines = generate_lines(n, m)
        self.memo = set()

    def print_board(self):
        arr = str(int_to_bits(self.board, 32))[47:64]
        print(arr)

    def do_move(self, move):
        self.board ^= move
        return self.board

    def available_moves(self):
        moves = []
        for line in self.lines:
            mask = sum(1 << i for i in line)
            if (self.board & mask) == mask:  # all bits are still set (i.e. unvisited)
                moves.append(mask)
        return moves

    def search(self, k, depth=0, visited=None):
        if visited is None:
            visited = self.board
    
        if depth == k:
            return int(visited == 0)
    
        count = 0
        for line in self.lines:
            mask = sum(1 << i for i in line)
            if (visited & mask) != mask:
                continue  # can't draw over already used dots
    
            is_first_move = visited == (1 << (self.n * self.m)) - 1
            if is_first_move or (mask & ~visited):  # connects to drawn part
                new_visited = visited ^ mask
                key = canonical(new_visited, self.n, self.m)
                if key in self.memo:
                    continue
                self.memo.add(key)
                count += self.search(k, depth + 1, new_visited)
        return count

def generate_lines(n, m):
    lines = []

    # horizontal
    for i in range(n):
        for j in range(m):
            for length in range(2, m - j + 1):
                line = [i * m + (j + k) for k in range(length)]
                lines.append(line)

    # vertical
    for j in range(m):
        for i in range(n):
            for length in range(2, n - i + 1):
                line = [(i + k) * m + j for k in range(length)]
                lines.append(line)

    # diagonal down-right
    for i in range(n):
        for j in range(m):
            for length in range(2, min(n - i, m - j) + 1):
                line = [(i + k) * m + (j + k) for k in range(length)]
                lines.append(line)

    # diagonal up-right
    for i in range(n):
        for j in range(m):
            for length in range(2, min(i + 1, m - j) + 1):
                line = [(i - k) * m + (j + k) for k in range(length)]
                lines.append(line)

    return lines

def int_to_bits(integer, num_bits=None):
    if num_bits is None:
        num_bits = integer.bit_length()
    binary_string = bin(integer)[2:].zfill(num_bits)
    bit_array = np.array([int(bit) for bit in binary_string], dtype=np.int8)
    return bit_array

def rotate90(board, n, m):
    rotated = 0
    for i in range(n):       # original rows
        for j in range(m):   # original cols
            bit = (board >> (i * m + j)) & 1
            if bit:
                new_i = j
                new_j = n - 1 - i
                new_index = new_i * n + new_j
                rotated |= (1 << new_index)
    return rotated, m, n

def canonical(board, n, m):
    boards = []
    for _ in range(4):
        board, n, m = rotate90(board, n, m)
        boards.append(board)
    return min(boards)

In [356]:
g = Grid(3, 3)
print("Ways with 4 lines:", g.search(4))
print("Ways with 5 lines:", g.search(5))
g.lines

Ways with 4 lines: 0
Ways with 5 lines: 0


[[0, 1],
 [0, 1, 2],
 [1, 2],
 [3, 4],
 [3, 4, 5],
 [4, 5],
 [6, 7],
 [6, 7, 8],
 [7, 8],
 [0, 3],
 [0, 3, 6],
 [3, 6],
 [1, 4],
 [1, 4, 7],
 [4, 7],
 [2, 5],
 [2, 5, 8],
 [5, 8],
 [0, 4],
 [0, 4, 8],
 [1, 5],
 [3, 7],
 [4, 8],
 [3, 1],
 [4, 2],
 [6, 4],
 [6, 4, 2],
 [7, 5]]

In [357]:
g.print_board()
g.memo

1 1 1 1 1 1 1 1 1


{63, 127, 223, 238, 239, 351, 365, 367}

In [298]:
bits = [(np.int32(448) >> i) & 1 for i in range(9)]

In [299]:
bits

[np.int32(0),
 np.int32(0),
 np.int32(0),
 np.int32(0),
 np.int32(0),
 np.int32(0),
 np.int32(1),
 np.int32(1),
 np.int32(1)]

In [308]:
rotate90(np.int32(448),3,3)

(73, 3, 3)

In [320]:
def generate_lines(n, m):
    lines = []

    # horizontal
    for i in range(n):
        for j in range(m):
            for length in range(2, m - j + 1):
                line = [i * m + (j + k) for k in range(length)]
                lines.append(line)

    # vertical
    for j in range(m):
        for i in range(n):
            for length in range(2, n - i + 1):
                line = [(i + k) * m + j for k in range(length)]
                lines.append(line)

    # diagonal (down-right)
    for i in range(n):
        for j in range(m):
            for length in range(2, min(n - i, m - j) + 1):
                line = [(i + k) * m + (j + k) for k in range(length)]
                lines.append(line)

    # diagonal (up-right)
    for i in range(n):
        for j in range(m):
            for length in range(2, min(i + 1, m - j) + 1):
                line = [(i - k) * m + (j + k) for k in range(length)]
                lines.append(line)

    return lines



In [321]:
def search(n, m, k, lines, visited=0, depth=0, last_endpoints=None):
    if depth == k:
        if visited == (1 << (n * m)) - 1:
            return 1  # found a solution
        return 0

    count = 0
    for line in lines:
        bits = sum(1 << i for i in line)
        if visited & bits:
            continue  # already used

        # if first move or connects to previous line
        if last_endpoints is None or any(i in last_endpoints for i in line):
            new_endpoints = [line[0], line[-1]]
            count += search(
                n, m, k, lines,
                visited | bits,
                depth + 1,
                new_endpoints
            )
    return count

In [322]:
lines = generate_lines(3, 3)
ways4 = search(3, 3, 4, lines)
ways5 = search(3, 3, 5, lines)
print(f"Ways with 4 lines: {ways4}")
print(f"Ways with 5 lines: {ways5}")

Ways with 4 lines: 0
Ways with 5 lines: 0


In [323]:
lines

[[0, 1],
 [0, 1, 2],
 [1, 2],
 [3, 4],
 [3, 4, 5],
 [4, 5],
 [6, 7],
 [6, 7, 8],
 [7, 8],
 [0, 3],
 [0, 3, 6],
 [3, 6],
 [1, 4],
 [1, 4, 7],
 [4, 7],
 [2, 5],
 [2, 5, 8],
 [5, 8],
 [0, 4],
 [0, 4, 8],
 [1, 5],
 [3, 7],
 [4, 8],
 [3, 1],
 [4, 2],
 [6, 4],
 [6, 4, 2],
 [7, 5]]

In [362]:
DIRECTIONS = [
    (-1,  0), (1,  0),  # up, down
    (0, -1), (0,  1),   # left, right
    (-1, -1), (1, 1),   # diag ↖ ↘
    (-1,  1), (1, -1),  # diag ↗ ↙
]

def in_bounds(x, y, n, m):
    return 0 <= x < n and 0 <= y < m

def coord_to_index(x, y, m):
    return x * m + y

def generate_straight_moves(n, m):
    moves = {i: [] for i in range(n * m)}
    for i in range(n):
        for j in range(m):
            origin = coord_to_index(i, j, m)
            for dx, dy in DIRECTIONS:
                path = []
                x, y = i + dx, j + dy
                while in_bounds(x, y, n, m):
                    path.append(coord_to_index(x, y, m))
                    moves[origin].append(list(path))  # all partials
                    x += dx
                    y += dy
    return moves

def search_paths(n, m, k):
    moves = generate_straight_moves(n, m)
    results = []

    def dfs(pos, visited, path, segments, last_dir):
        if visited == (1 << (n * m)) - 1:
            if segments == k:
                results.append(path[:])
            return

        for path_segment in moves[pos]:
            next_pos = path_segment[-1]
            bitmask = sum(1 << p for p in path_segment)
            if visited & bitmask:
                continue

            dx = (next_pos % m) - (pos % m)
            dy = (next_pos // m) - (pos // m)
            dirn = (dx and dx // abs(dx), dy and dy // abs(dy))

            path.extend(path_segment)
            if last_dir is None or dirn != last_dir:
                dfs(next_pos, visited | bitmask, path, segments + 1, dirn)
            else:
                dfs(next_pos, visited | bitmask, path, segments, dirn)
            for _ in path_segment:
                path.pop()

    for start in range(n * m):
        dfs(start, 1 << start, [start], 0, None)

    return results

print(search_path(3, 3, 4))  # number of ways to visit all dots in 4 straight lines
print(search_path(3, 3, 5))

0
416


In [361]:
paths = search_paths(3,3,4)
for path in paths:
    print(path)