diff --git a/day17/part1.py b/day17/part1.py index b6af93a..e8e8440 100644 --- a/day17/part1.py +++ b/day17/part1.py @@ -3,11 +3,12 @@ import numpy as np from scipy import signal +DIMS = 3 lines = [[(1 if x == '#' else 0) for x in x.strip()] for x in open("input.txt")] grid = np.array(lines, dtype=np.byte) -grid = np.expand_dims(grid, axis=0) -kernel = np.ones((3, 3, 3), dtype=np.byte) -kernel[1, 1, 1] = 0 +grid = np.expand_dims(grid, axis=tuple(range(DIMS-2))) +kernel = np.ones((3, )*DIMS, dtype=np.byte) +kernel[(1, )*DIMS] = 0 for iter in range(6): @@ -17,9 +18,8 @@ for iter in range(6): set_active = np.logical_and(grid == 0, neighbors == 3) grid[set_inactive] = 0 grid[set_active] = 1 - x = np.flatnonzero(grid.sum(axis=(1, 2))) - y = np.flatnonzero(grid.sum(axis=(0, 2))) - z = np.flatnonzero(grid.sum(axis=(0, 1))) - grid = grid[min(x):max(x)+1, min(y):max(y)+1, min(z):max(z)+1] + for dim in range(DIMS): + a = np.flatnonzero(grid.sum(axis=tuple(x for x in range(DIMS) if x!=dim))) + grid = grid[(slice(None), )*dim + (slice(min(a), max(a)+1), )] print(np.sum(grid)) diff --git a/day17/part2.py b/day17/part2.py index ea62614..4113f1e 100644 --- a/day17/part2.py +++ b/day17/part2.py @@ -3,12 +3,12 @@ import numpy as np from scipy import signal +DIMS = 4 lines = [[(1 if x == '#' else 0) for x in x.strip()] for x in open("input.txt")] grid = np.array(lines, dtype=np.byte) -grid = np.expand_dims(grid, axis=0) -grid = np.expand_dims(grid, axis=0) -kernel = np.ones((3, 3, 3, 3), dtype=np.byte) -kernel[1, 1, 1, 1] = 0 +grid = np.expand_dims(grid, axis=tuple(range(DIMS-2))) +kernel = np.ones((3, )*DIMS, dtype=np.byte) +kernel[(1, )*DIMS] = 0 for iter in range(6): @@ -18,10 +18,8 @@ for iter in range(6): set_active = np.logical_and(grid == 0, neighbors == 3) grid[set_inactive] = 0 grid[set_active] = 1 - x = np.flatnonzero(grid.sum(axis=(1, 2, 3))) - y = np.flatnonzero(grid.sum(axis=(0, 2, 3))) - z = np.flatnonzero(grid.sum(axis=(0, 1, 3))) - w = np.flatnonzero(grid.sum(axis=(0, 1, 2))) - grid = grid[min(x):max(x)+1, min(y):max(y)+1, min(z):max(z)+1, min(w):max(w)+1] + for dim in range(DIMS): + a = np.flatnonzero(grid.sum(axis=tuple(x for x in range(DIMS) if x!=dim))) + grid = grid[(slice(None), )*dim + (slice(min(a), max(a)+1), )] print(np.sum(grid))