#!/usr/bin/env python3
"""Count start-to-end paths through a cave graph (Advent of Code 2021, day 12).

Each line of ``input.txt`` is an undirected edge ``a-b``. Caves whose name is
not entirely uppercase are "small"; by default a path may enter one small cave
(never ``start``) twice and every other small cave at most once.
"""
from collections import defaultdict
from functools import lru_cache


def parse_graph(lines):
    """Build an undirected adjacency map from ``a-b`` edge lines.

    Surrounding whitespace is stripped and blank lines are skipped.

    Args:
        lines: Iterable of strings, e.g. an open text file.

    Returns:
        ``defaultdict(set)`` mapping each cave name to its neighbours.

    Raises:
        ValueError: If a non-blank line is not exactly two names joined by ``-``.
    """
    graph = defaultdict(set)
    for raw in lines:
        line = raw.strip()
        if not line:
            # Tolerate blank/trailing lines instead of failing on unpacking.
            continue
        parts = line.split("-")
        if len(parts) != 2:
            raise ValueError("malformed edge line: {!r}".format(line))
        a, b = parts
        graph[a].add(b)
        graph[b].add(a)
    return graph


def _is_small(cave):
    """Return True for small caves: names that are not entirely uppercase."""
    return cave != cave.upper()


def count_paths(graph, allow_twice=True):
    """Return the number of distinct paths from ``start`` to ``end``.

    ``start`` is never re-entered and a path stops as soon as it reaches
    ``end``. Large caves may be revisited freely; each small cave at most
    once, except that when ``allow_twice`` is true a single small cave per
    path may be entered a second time. Assumes no two large caves are
    adjacent; otherwise the walk would recurse without bound.

    Args:
        graph: Mapping from cave name to an iterable of neighbour names.
        allow_twice: ``True`` (default) permits one repeated small cave;
            ``False`` forbids any repeat.

    Returns:
        The path count (0 if ``start`` has no route to ``end``).
    """
    # The count from a state depends only on the current cave, the set of
    # small caves already entered, and whether the single repeat is spent,
    # so that tuple is an exact, collision-free cache key. The cache lives
    # inside this call, so it is released afterwards and never mixes graphs.
    @lru_cache(maxsize=None)
    def visit(node, seen, repeat_used):
        if node == "end":
            return 1
        total = 0
        for nxt in graph.get(node, ()):
            if nxt == "start":
                continue
            if not _is_small(nxt):
                total += visit(nxt, seen, repeat_used)
            elif nxt not in seen:
                total += visit(nxt, seen | {nxt}, repeat_used)
            elif not repeat_used:
                # Spend the one allowed repeat on this small cave.
                total += visit(nxt, seen, True)
        return total

    return visit("start", frozenset(["start"]), not allow_twice)


def main(path="input.txt"):
    """Read the edge list at *path* and print the path count."""
    with open(path, encoding="utf-8") as handle:
        graph = parse_graph(handle)
    print(count_paths(graph))


if __name__ == "__main__":
    main()