#!/usr/bin/env python3
"""Count the distinct routes from 'start' to 'end' in an undirected graph.

Edges are read from ``input.txt``, one ``a-b`` pair per line.  A node whose
name equals its own upper-cased form may be entered any number of times on a
route; every other node may be entered at most once.
"""
from collections import defaultdict

# Adjacency map: node name -> set of directly connected node names.
paths = defaultdict(set)

# Memo table: (node, frozenset(visited)) -> number of routes from that state
# to 'end'.  Only valid for the current contents of ``paths``; clear it
# whenever the graph changes.
stored_sum = {}


def load_paths(lines):
    """Add the undirected edges described by *lines* to ``paths``.

    Args:
        lines: iterable of strings, each of the form ``"a-b"``.  Surrounding
            whitespace is stripped and blank lines are skipped.

    Raises:
        ValueError: if a non-blank line does not contain exactly one ``-``.
    """
    for raw in lines:
        line = raw.strip()
        if not line:
            # Tolerate empty lines (e.g. a trailing newline at end of file).
            continue
        a, b = line.split("-")
        paths[a].add(b)
        paths[b].add(a)


def visit(node, visited):
    """Return the number of routes from *node* to ``'end'``.

    Args:
        node: name of the node currently being stood on.
        visited: set of node names that may not be entered again.

    Returns:
        The count of distinct routes; ``1`` when *node* is ``'end'``.

    Note:
        Two directly connected "revisitable" nodes (names equal to their
        upper-cased form) would make this recurse without end, because moving
        between them never changes *visited*.
    """
    if node == 'end':
        return 1

    # Key the memo by the state itself rather than by its hash: the dict then
    # resolves hash collisions through equality, so two different states can
    # never share a cached count.
    store_key = (node, frozenset(visited))
    if store_key not in stored_sum:
        stored_sum[store_key] = sum(
            # Only nodes that are not "all upper-case" get marked as visited.
            visit(nxt, visited.union([nxt]) if nxt != nxt.upper() else visited)
            for nxt in paths[node].difference(visited)
        )
    return stored_sum[store_key]


def main():
    """Load the graph from ``input.txt`` and print the number of routes."""
    # ``with`` guarantees the file handle is closed once parsing is done.
    with open("input.txt") as handle:
        load_paths(handle)
    print(visit('start', {'start'}))


if __name__ == "__main__":
    main()