from inspect import getsource from utils import argmax, argmin from games import TicTacToe, alphabeta_player, random_player, Fig52Extended, infinity from logic import parse_definite_clause, standardize_variables, unify, subst from learning import DataSet from IPython.display import HTML, display from collections import Counter, defaultdict import matplotlib.pyplot as plt import numpy as np from PIL import Image import os, struct import array import time #______________________________________________________________________________ # Magic Words def pseudocode(algorithm): """Print the pseudocode for the given algorithm.""" from urllib.request import urlopen from IPython.display import Markdown algorithm = algorithm.replace(' ', '-') url = "https://raw.githubusercontent.com/aimacode/aima-pseudocode/master/md/{}.md".format(algorithm) f = urlopen(url) md = f.read().decode('utf-8') md = md.split('\n', 1)[-1].strip() md = '#' + md return Markdown(md) def psource(*functions): """Print the source code for the given function(s).""" source_code = '\n\n'.join(getsource(fn) for fn in functions) try: from pygments.formatters import HtmlFormatter from pygments.lexers import PythonLexer from pygments import highlight display(HTML(highlight(source_code, PythonLexer(), HtmlFormatter(full=True)))) except ImportError: print(source_code) # ______________________________________________________________________________ # Iris Visualization def show_iris(i=0, j=1, k=2): """Plots the iris dataset in a 3D plot. The three axes are given by i, j and k, which correspond to three of the four iris features.""" from mpl_toolkits.mplot3d import Axes3D plt.rcParams.update(plt.rcParamsDefault) fig = plt.figure() ax = fig.add_subplot(111, projection='3d') iris = DataSet(name="iris") buckets = iris.split_values_by_classes() features = ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"] f1, f2, f3 = features[i], features[j], features[k] a_setosa = [v[i] for v in buckets["setosa"]] b_setosa = [v[j] for v in buckets["setosa"]] c_setosa = [v[k] for v in buckets["setosa"]] a_virginica = [v[i] for v in buckets["virginica"]] b_virginica = [v[j] for v in buckets["virginica"]] c_virginica = [v[k] for v in buckets["virginica"]] a_versicolor = [v[i] for v in buckets["versicolor"]] b_versicolor = [v[j] for v in buckets["versicolor"]] c_versicolor = [v[k] for v in buckets["versicolor"]] for c, m, sl, sw, pl in [('b', 's', a_setosa, b_setosa, c_setosa), ('g', '^', a_virginica, b_virginica, c_virginica), ('r', 'o', a_versicolor, b_versicolor, c_versicolor)]: ax.scatter(sl, sw, pl, c=c, marker=m) ax.set_xlabel(f1) ax.set_ylabel(f2) ax.set_zlabel(f3) plt.show() # ______________________________________________________________________________ # MNIST def load_MNIST(path="aima-data/MNIST/Digits", fashion=False): import os, struct import array import numpy as np from collections import Counter if fashion: path = "aima-data/MNIST/Fashion" plt.rcParams.update(plt.rcParamsDefault) plt.rcParams['figure.figsize'] = (10.0, 8.0) plt.rcParams['image.interpolation'] = 'nearest' plt.rcParams['image.cmap'] = 'gray' train_img_file = open(os.path.join(path, "train-images-idx3-ubyte"), "rb") train_lbl_file = open(os.path.join(path, "train-labels-idx1-ubyte"), "rb") test_img_file = open(os.path.join(path, "t10k-images-idx3-ubyte"), "rb") test_lbl_file = open(os.path.join(path, 't10k-labels-idx1-ubyte'), "rb") magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(">IIII", train_img_file.read(16)) tr_img = array.array("B", train_img_file.read()) train_img_file.close() magic_nr, tr_size = struct.unpack(">II", train_lbl_file.read(8)) tr_lbl = array.array("b", train_lbl_file.read()) train_lbl_file.close() magic_nr, te_size, te_rows, te_cols = struct.unpack(">IIII", test_img_file.read(16)) te_img = array.array("B", test_img_file.read()) test_img_file.close() magic_nr, te_size = struct.unpack(">II", test_lbl_file.read(8)) te_lbl = array.array("b", test_lbl_file.read()) test_lbl_file.close() #print(len(tr_img), len(tr_lbl), tr_size) #print(len(te_img), len(te_lbl), te_size) train_img = np.zeros((tr_size, tr_rows*tr_cols), dtype=np.int16) train_lbl = np.zeros((tr_size,), dtype=np.int8) for i in range(tr_size): train_img[i] = np.array(tr_img[i*tr_rows*tr_cols : (i+1)*tr_rows*tr_cols]).reshape((tr_rows*te_cols)) train_lbl[i] = tr_lbl[i] test_img = np.zeros((te_size, te_rows*te_cols), dtype=np.int16) test_lbl = np.zeros((te_size,), dtype=np.int8) for i in range(te_size): test_img[i] = np.array(te_img[i*te_rows*te_cols : (i+1)*te_rows*te_cols]).reshape((te_rows*te_cols)) test_lbl[i] = te_lbl[i] return(train_img, train_lbl, test_img, test_lbl) digit_classes = [str(i) for i in range(10)] fashion_classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] def show_MNIST(labels, images, samples=8, fashion=False): if not fashion: classes = digit_classes else: classes = fashion_classes num_classes = len(classes) for y, cls in enumerate(classes): idxs = np.nonzero([i == y for i in labels]) idxs = np.random.choice(idxs[0], samples, replace=False) for i , idx in enumerate(idxs): plt_idx = i * num_classes + y + 1 plt.subplot(samples, num_classes, plt_idx) plt.imshow(images[idx].reshape((28, 28))) plt.axis("off") if i == 0: plt.title(cls) plt.show() def show_ave_MNIST(labels, images, fashion=False): if not fashion: item_type = "Digit" classes = digit_classes else: item_type = "Apparel" classes = fashion_classes num_classes = len(classes) for y, cls in enumerate(classes): idxs = np.nonzero([i == y for i in labels]) print(item_type, y, ":", len(idxs[0]), "images.") ave_img = np.mean(np.vstack([images[i] for i in idxs[0]]), axis = 0) #print(ave_img.shape) plt.subplot(1, num_classes, y+1) plt.imshow(ave_img.reshape((28, 28))) plt.axis("off") plt.title(cls) plt.show() # ______________________________________________________________________________ # MDP def make_plot_grid_step_function(columns, rows, U_over_time): """ipywidgets interactive function supports single parameter as input. This function creates and return such a function by taking as input other parameters.""" def plot_grid_step(iteration): data = U_over_time[iteration] data = defaultdict(lambda: 0, data) grid = [] for row in range(rows): current_row = [] for column in range(columns): current_row.append(data[(column, row)]) grid.append(current_row) grid.reverse() # output like book fig = plt.imshow(grid, cmap=plt.cm.bwr, interpolation='nearest') plt.axis('off') fig.axes.get_xaxis().set_visible(False) fig.axes.get_yaxis().set_visible(False) for col in range(len(grid)): for row in range(len(grid[0])): magic = grid[col][row] fig.axes.text(row, col, "{0:.2f}".format(magic), va='center', ha='center') plt.show() return plot_grid_step def make_visualize(slider): """Takes an input a sliderand returns callback function for timer and animation.""" def visualize_callback(Visualize, time_step): if Visualize is True: for i in range(slider.min, slider.max + 1): slider.value = i time.sleep(float(time_step)) return visualize_callback # ______________________________________________________________________________ _canvas = """
""" # noqa class Canvas: """Inherit from this class to manage the HTML canvas element in jupyter notebooks. To create an object of this class any_name_xyz = Canvas("any_name_xyz") The first argument given must be the name of the object being created. IPython must be able to reference the variable name that is being passed.""" def __init__(self, varname, width=800, height=600, cid=None): self.name = varname self.cid = cid or varname self.width = width self.height = height self.html = _canvas.format(self.cid, self.width, self.height, self.name) self.exec_list = [] display_html(self.html) def mouse_click(self, x, y): """Override this method to handle mouse click at position (x, y)""" raise NotImplementedError def mouse_move(self, x, y): raise NotImplementedError def execute(self, exec_str): """Stores the command to be executed to a list which is used later during update()""" if not isinstance(exec_str, str): print("Invalid execution argument:", exec_str) self.alert("Received invalid execution command format") prefix = "{0}_canvas_object.".format(self.cid) self.exec_list.append(prefix + exec_str + ';') def fill(self, r, g, b): """Changes the fill color to a color in rgb format""" self.execute("fill({0}, {1}, {2})".format(r, g, b)) def stroke(self, r, g, b): """Changes the colors of line/strokes to rgb""" self.execute("stroke({0}, {1}, {2})".format(r, g, b)) def strokeWidth(self, w): """Changes the width of lines/strokes to 'w' pixels""" self.execute("strokeWidth({0})".format(w)) def rect(self, x, y, w, h): """Draw a rectangle with 'w' width, 'h' height and (x, y) as the top-left corner""" self.execute("rect({0}, {1}, {2}, {3})".format(x, y, w, h)) def rect_n(self, xn, yn, wn, hn): """Similar to rect(), but the dimensions are normalized to fall between 0 and 1""" x = round(xn * self.width) y = round(yn * self.height) w = round(wn * self.width) h = round(hn * self.height) self.rect(x, y, w, h) def line(self, x1, y1, x2, y2): """Draw a line from (x1, y1) to (x2, y2)""" self.execute("line({0}, {1}, {2}, {3})".format(x1, y1, x2, y2)) def line_n(self, x1n, y1n, x2n, y2n): """Similar to line(), but the dimensions are normalized to fall between 0 and 1""" x1 = round(x1n * self.width) y1 = round(y1n * self.height) x2 = round(x2n * self.width) y2 = round(y2n * self.height) self.line(x1, y1, x2, y2) def arc(self, x, y, r, start, stop): """Draw an arc with (x, y) as centre, 'r' as radius from angles 'start' to 'stop'""" self.execute("arc({0}, {1}, {2}, {3}, {4})".format(x, y, r, start, stop)) def arc_n(self, xn, yn, rn, start, stop): """Similar to arc(), but the dimensions are normalized to fall between 0 and 1 The normalizing factor for radius is selected between width and height by seeing which is smaller.""" x = round(xn * self.width) y = round(yn * self.height) r = round(rn * min(self.width, self.height)) self.arc(x, y, r, start, stop) def clear(self): """Clear the HTML canvas""" self.execute("clear()") def font(self, font): """Changes the font of text""" self.execute('font("{0}")'.format(font)) def text(self, txt, x, y, fill=True): """Display a text at (x, y)""" if fill: self.execute('fill_text("{0}", {1}, {2})'.format(txt, x, y)) else: self.execute('stroke_text("{0}", {1}, {2})'.format(txt, x, y)) def text_n(self, txt, xn, yn, fill=True): """Similar to text(), but with normalized coordinates""" x = round(xn * self.width) y = round(yn * self.height) self.text(txt, x, y, fill) def alert(self, message): """Immediately display an alert""" display_html(''.format(message)) def update(self): """Execute the JS code to execute the commands queued by execute()""" exec_code = "" self.exec_list = [] display_html(exec_code) def display_html(html_string): display(HTML(html_string)) ################################################################################ class Canvas_TicTacToe(Canvas): """Play a 3x3 TicTacToe game on HTML canvas""" def __init__(self, varname, player_1='human', player_2='random', width=300, height=350, cid=None): valid_players = ('human', 'random', 'alphabeta') if player_1 not in valid_players or player_2 not in valid_players: raise TypeError("Players must be one of {}".format(valid_players)) Canvas.__init__(self, varname, width, height, cid) self.ttt = TicTacToe() self.state = self.ttt.initial self.turn = 0 self.strokeWidth(5) self.players = (player_1, player_2) self.font("20px Arial") self.draw_board() def mouse_click(self, x, y): player = self.players[self.turn] if self.ttt.terminal_test(self.state): if 0.55 <= x/self.width <= 0.95 and 6/7 <= y/self.height <= 6/7+1/8: self.state = self.ttt.initial self.turn = 0 self.draw_board() return if player == 'human': x, y = int(3*x/self.width) + 1, int(3*y/(self.height*6/7)) + 1 if (x, y) not in self.ttt.actions(self.state): # Invalid move return move = (x, y) elif player == 'alphabeta': move = alphabeta_player(self.ttt, self.state) else: move = random_player(self.ttt, self.state) self.state = self.ttt.result(self.state, move) self.turn ^= 1 self.draw_board() def draw_board(self): self.clear() self.stroke(0, 0, 0) offset = 1/20 self.line_n(0 + offset, (1/3)*6/7, 1 - offset, (1/3)*6/7) self.line_n(0 + offset, (2/3)*6/7, 1 - offset, (2/3)*6/7) self.line_n(1/3, (0 + offset)*6/7, 1/3, (1 - offset)*6/7) self.line_n(2/3, (0 + offset)*6/7, 2/3, (1 - offset)*6/7) board = self.state.board for mark in board: if board[mark] == 'X': self.draw_x(mark) elif board[mark] == 'O': self.draw_o(mark) if self.ttt.terminal_test(self.state): # End game message utility = self.ttt.utility(self.state, self.ttt.to_move(self.ttt.initial)) if utility == 0: self.text_n('Game Draw!', offset, 6/7 + offset) else: self.text_n('Player {} wins!'.format("XO"[utility < 0]), offset, 6/7 + offset) # Find the 3 and draw a line self.stroke([255, 0][self.turn], [0, 255][self.turn], 0) for i in range(3): if all([(i + 1, j + 1) in self.state.board for j in range(3)]) and \ len({self.state.board[(i + 1, j + 1)] for j in range(3)}) == 1: self.line_n(i/3 + 1/6, offset*6/7, i/3 + 1/6, (1 - offset)*6/7) if all([(j + 1, i + 1) in self.state.board for j in range(3)]) and \ len({self.state.board[(j + 1, i + 1)] for j in range(3)}) == 1: self.line_n(offset, (i/3 + 1/6)*6/7, 1 - offset, (i/3 + 1/6)*6/7) if all([(i + 1, i + 1) in self.state.board for i in range(3)]) and \ len({self.state.board[(i + 1, i + 1)] for i in range(3)}) == 1: self.line_n(offset, offset*6/7, 1 - offset, (1 - offset)*6/7) if all([(i + 1, 3 - i) in self.state.board for i in range(3)]) and \ len({self.state.board[(i + 1, 3 - i)] for i in range(3)}) == 1: self.line_n(offset, (1 - offset)*6/7, 1 - offset, offset*6/7) # restart button self.fill(0, 0, 255) self.rect_n(0.5 + offset, 6/7, 0.4, 1/8) self.fill(0, 0, 0) self.text_n('Restart', 0.5 + 2*offset, 13/14) else: # Print which player's turn it is self.text_n("Player {}'s move({})".format("XO"[self.turn], self.players[self.turn]), offset, 6/7 + offset) self.update() def draw_x(self, position): self.stroke(0, 255, 0) x, y = [i-1 for i in position] offset = 1/15 self.line_n(x/3 + offset, (y/3 + offset)*6/7, x/3 + 1/3 - offset, (y/3 + 1/3 - offset)*6/7) self.line_n(x/3 + 1/3 - offset, (y/3 + offset)*6/7, x/3 + offset, (y/3 + 1/3 - offset)*6/7) def draw_o(self, position): self.stroke(255, 0, 0) x, y = [i-1 for i in position] self.arc_n(x/3 + 1/6, (y/3 + 1/6)*6/7, 1/9, 0, 360) class Canvas_minimax(Canvas): """Minimax for Fig52Extended on HTML canvas""" def __init__(self, varname, util_list, width=800, height=600, cid=None): Canvas.__init__(self, varname, width, height, cid) self.utils = {node:util for node, util in zip(range(13, 40), util_list)} self.game = Fig52Extended() self.game.utils = self.utils self.nodes = list(range(40)) self.l = 1/40 self.node_pos = {} for i in range(4): base = len(self.node_pos) row_size = 3**i for node in [base + j for j in range(row_size)]: self.node_pos[node] = ((node - base)/row_size + 1/(2*row_size) - self.l/2, self.l/2 + (self.l + (1 - 5*self.l)/3)*i) self.font("12px Arial") self.node_stack = [] self.explored = {node for node in self.utils} self.thick_lines = set() self.change_list = [] self.draw_graph() self.stack_manager = self.stack_manager_gen() def minimax(self, node): game = self.game player = game.to_move(node) def max_value(node): if game.terminal_test(node): return game.utility(node, player) self.change_list.append(('a', node)) self.change_list.append(('h',)) max_a = argmax(game.actions(node), key=lambda x: min_value(game.result(node, x))) max_node = game.result(node, max_a) self.utils[node] = self.utils[max_node] x1, y1 = self.node_pos[node] x2, y2 = self.node_pos[max_node] self.change_list.append(('l', (node, max_node - 3*node - 1))) self.change_list.append(('e', node)) self.change_list.append(('p',)) self.change_list.append(('h',)) return self.utils[node] def min_value(node): if game.terminal_test(node): return game.utility(node, player) self.change_list.append(('a', node)) self.change_list.append(('h',)) min_a = argmin(game.actions(node), key=lambda x: max_value(game.result(node, x))) min_node = game.result(node, min_a) self.utils[node] = self.utils[min_node] x1, y1 = self.node_pos[node] x2, y2 = self.node_pos[min_node] self.change_list.append(('l', (node, min_node - 3*node - 1))) self.change_list.append(('e', node)) self.change_list.append(('p',)) self.change_list.append(('h',)) return self.utils[node] return max_value(node) def stack_manager_gen(self): self.minimax(0) for change in self.change_list: if change[0] == 'a': self.node_stack.append(change[1]) elif change[0] == 'e': self.explored.add(change[1]) elif change[0] == 'h': yield elif change[0] == 'l': self.thick_lines.add(change[1]) elif change[0] == 'p': self.node_stack.pop() def mouse_click(self, x, y): try: self.stack_manager.send(None) except StopIteration: pass self.draw_graph() def draw_graph(self): self.clear() # draw nodes self.stroke(0, 0, 0) self.strokeWidth(1) # highlight for nodes in stack for node in self.node_stack: x, y = self.node_pos[node] self.fill(200, 200, 0) self.rect_n(x - self.l/5, y - self.l/5, self.l*7/5, self.l*7/5) for node in self.nodes: x, y = self.node_pos[node] if node in self.explored: self.fill(255, 255, 255) else: self.fill(200, 200, 200) self.rect_n(x, y, self.l, self.l) self.line_n(x, y, x + self.l, y) self.line_n(x, y, x, y + self.l) self.line_n(x + self.l, y + self.l, x + self.l, y) self.line_n(x + self.l, y + self.l, x, y + self.l) self.fill(0, 0, 0) if node in self.explored: self.text_n(self.utils[node], x + self.l/10, y + self.l*9/10) # draw edges for i in range(13): x1, y1 = self.node_pos[i][0] + self.l/2, self.node_pos[i][1] + self.l for j in range(3): x2, y2 = self.node_pos[i*3 + j + 1][0] + self.l/2, self.node_pos[i*3 + j + 1][1] if i in [1, 2, 3]: self.stroke(200, 0, 0) else: self.stroke(0, 200, 0) if (i, j) in self.thick_lines: self.strokeWidth(3) else: self.strokeWidth(1) self.line_n(x1, y1, x2, y2) self.update() class Canvas_alphabeta(Canvas): """Alpha-beta pruning for Fig52Extended on HTML canvas""" def __init__(self, varname, util_list, width=800, height=600, cid=None): Canvas.__init__(self, varname, width, height, cid) self.utils = {node:util for node, util in zip(range(13, 40), util_list)} self.game = Fig52Extended() self.game.utils = self.utils self.nodes = list(range(40)) self.l = 1/40 self.node_pos = {} for i in range(4): base = len(self.node_pos) row_size = 3**i for node in [base + j for j in range(row_size)]: self.node_pos[node] = ((node - base)/row_size + 1/(2*row_size) - self.l/2, 3*self.l/2 + (self.l + (1 - 6*self.l)/3)*i) self.font("12px Arial") self.node_stack = [] self.explored = {node for node in self.utils} self.pruned = set() self.ab = {} self.thick_lines = set() self.change_list = [] self.draw_graph() self.stack_manager = self.stack_manager_gen() def alphabeta_search(self, node): game = self.game player = game.to_move(node) # Functions used by alphabeta def max_value(node, alpha, beta): if game.terminal_test(node): self.change_list.append(('a', node)) self.change_list.append(('h',)) self.change_list.append(('p',)) return game.utility(node, player) v = -infinity self.change_list.append(('a', node)) self.change_list.append(('ab',node, v, beta)) self.change_list.append(('h',)) for a in game.actions(node): min_val = min_value(game.result(node, a), alpha, beta) if v < min_val: v = min_val max_node = game.result(node, a) self.change_list.append(('ab',node, v, beta)) if v >= beta: self.change_list.append(('h',)) self.pruned.add(node) break alpha = max(alpha, v) self.utils[node] = v if node not in self.pruned: self.change_list.append(('l', (node, max_node - 3*node - 1))) self.change_list.append(('e',node)) self.change_list.append(('p',)) self.change_list.append(('h',)) return v def min_value(node, alpha, beta): if game.terminal_test(node): self.change_list.append(('a', node)) self.change_list.append(('h',)) self.change_list.append(('p',)) return game.utility(node, player) v = infinity self.change_list.append(('a', node)) self.change_list.append(('ab',node, alpha, v)) self.change_list.append(('h',)) for a in game.actions(node): max_val = max_value(game.result(node, a), alpha, beta) if v > max_val: v = max_val min_node = game.result(node, a) self.change_list.append(('ab',node, alpha, v)) if v <= alpha: self.change_list.append(('h',)) self.pruned.add(node) break beta = min(beta, v) self.utils[node] = v if node not in self.pruned: self.change_list.append(('l', (node, min_node - 3*node - 1))) self.change_list.append(('e',node)) self.change_list.append(('p',)) self.change_list.append(('h',)) return v return max_value(node, -infinity, infinity) def stack_manager_gen(self): self.alphabeta_search(0) for change in self.change_list: if change[0] == 'a': self.node_stack.append(change[1]) elif change[0] == 'ab': self.ab[change[1]] = change[2:] elif change[0] == 'e': self.explored.add(change[1]) elif change[0] == 'h': yield elif change[0] == 'l': self.thick_lines.add(change[1]) elif change[0] == 'p': self.node_stack.pop() def mouse_click(self, x, y): try: self.stack_manager.send(None) except StopIteration: pass self.draw_graph() def draw_graph(self): self.clear() # draw nodes self.stroke(0, 0, 0) self.strokeWidth(1) # highlight for nodes in stack for node in self.node_stack: x, y = self.node_pos[node] # alpha > beta if node not in self.explored and self.ab[node][0] > self.ab[node][1]: self.fill(200, 100, 100) else: self.fill(200, 200, 0) self.rect_n(x - self.l/5, y - self.l/5, self.l*7/5, self.l*7/5) for node in self.nodes: x, y = self.node_pos[node] if node in self.explored: if node in self.pruned: self.fill(50, 50, 50) else: self.fill(255, 255, 255) else: self.fill(200, 200, 200) self.rect_n(x, y, self.l, self.l) self.line_n(x, y, x + self.l, y) self.line_n(x, y, x, y + self.l) self.line_n(x + self.l, y + self.l, x + self.l, y) self.line_n(x + self.l, y + self.l, x, y + self.l) self.fill(0, 0, 0) if node in self.explored and node not in self.pruned: self.text_n(self.utils[node], x + self.l/10, y + self.l*9/10) # draw edges for i in range(13): x1, y1 = self.node_pos[i][0] + self.l/2, self.node_pos[i][1] + self.l for j in range(3): x2, y2 = self.node_pos[i*3 + j + 1][0] + self.l/2, self.node_pos[i*3 + j + 1][1] if i in [1, 2, 3]: self.stroke(200, 0, 0) else: self.stroke(0, 200, 0) if (i, j) in self.thick_lines: self.strokeWidth(3) else: self.strokeWidth(1) self.line_n(x1, y1, x2, y2) # display alpha and beta for node in self.node_stack: if node not in self.explored: x, y = self.node_pos[node] alpha, beta = self.ab[node] self.text_n(alpha, x - self.l/2, y - self.l/10) self.text_n(beta, x + self.l, y - self.l/10) self.update() class Canvas_fol_bc_ask(Canvas): """fol_bc_ask() on HTML canvas""" def __init__(self, varname, kb, query, width=800, height=600, cid=None): Canvas.__init__(self, varname, width, height, cid) self.kb = kb self.query = query self.l = 1/20 self.b = 3*self.l bc_out = list(self.fol_bc_ask()) if len(bc_out) is 0: self.valid = False else: self.valid = True graph = bc_out[0][0][0] s = bc_out[0][1] while True: new_graph = subst(s, graph) if graph == new_graph: break graph = new_graph self.make_table(graph) self.context = None self.draw_table() def fol_bc_ask(self): KB = self.kb query = self.query def fol_bc_or(KB, goal, theta): for rule in KB.fetch_rules_for_goal(goal): lhs, rhs = parse_definite_clause(standardize_variables(rule)) for theta1 in fol_bc_and(KB, lhs, unify(rhs, goal, theta)): yield ([(goal, theta1[0])], theta1[1]) def fol_bc_and(KB, goals, theta): if theta is None: pass elif not goals: yield ([], theta) else: first, rest = goals[0], goals[1:] for theta1 in fol_bc_or(KB, subst(theta, first), theta): for theta2 in fol_bc_and(KB, rest, theta1[1]): yield (theta1[0] + theta2[0], theta2[1]) return fol_bc_or(KB, query, {}) def make_table(self, graph): table = [] pos = {} links = set() edges = set() def dfs(node, depth): if len(table) <= depth: table.append([]) pos = len(table[depth]) table[depth].append(node[0]) for child in node[1]: child_id = dfs(child, depth + 1) links.add(((depth, pos), child_id)) return (depth, pos) dfs(graph, 0) y_off = 0.85/len(table) for i, row in enumerate(table): x_off = 0.95/len(row) for j, node in enumerate(row): pos[(i, j)] = (0.025 + j*x_off + (x_off - self.b)/2, 0.025 + i*y_off + (y_off - self.l)/2) for p, c in links: x1, y1 = pos[p] x2, y2 = pos[c] edges.add((x1 + self.b/2, y1 + self.l, x2 + self.b/2, y2)) self.table = table self.pos = pos self.edges = edges def mouse_click(self, x, y): x, y = x/self.width, y/self.height for node in self.pos: xs, ys = self.pos[node] xe, ye = xs + self.b, ys + self.l if xs <= x <= xe and ys <= y <= ye: self.context = node break self.draw_table() def draw_table(self): self.clear() self.strokeWidth(3) self.stroke(0, 0, 0) self.font("12px Arial") if self.valid: # draw nodes for i, j in self.pos: x, y = self.pos[(i, j)] self.fill(200, 200, 200) self.rect_n(x, y, self.b, self.l) self.line_n(x, y, x + self.b, y) self.line_n(x, y, x, y + self.l) self.line_n(x + self.b, y, x + self.b, y + self.l) self.line_n(x, y + self.l, x + self.b, y + self.l) self.fill(0, 0, 0) self.text_n(self.table[i][j], x + 0.01, y + self.l - 0.01) #draw edges for x1, y1, x2, y2 in self.edges: self.line_n(x1, y1, x2, y2) else: self.fill(255, 0, 0) self.rect_n(0, 0, 1, 1) # text area self.fill(255, 255, 255) self.rect_n(0, 0.9, 1, 0.1) self.strokeWidth(5) self.stroke(0, 0, 0) self.line_n(0, 0.9, 1, 0.9) self.font("22px Arial") self.fill(0, 0, 0) self.text_n(self.table[self.context[0]][self.context[1]] if self.context else "Click for text", 0.025, 0.975) self.update() ############################################################################################################ ##################### Functions to assist plotting in search.ipynb #################### ############################################################################################################ import networkx as nx import matplotlib.pyplot as plt from matplotlib import lines from ipywidgets import interact import ipywidgets as widgets from IPython.display import display import time from search import GraphProblem, romania_map def show_map(graph_data, node_colors = None): G = nx.Graph(graph_data['graph_dict']) node_colors = node_colors or graph_data['node_colors'] node_positions = graph_data['node_positions'] node_label_pos = graph_data['node_label_positions'] edge_weights= graph_data['edge_weights'] # set the size of the plot plt.figure(figsize=(18,13)) # draw the graph (both nodes and edges) with locations from romania_locations nx.draw(G, pos={k: node_positions[k] for k in G.nodes()}, node_color=[node_colors[node] for node in G.nodes()], linewidths=0.3, edgecolors='k') # draw labels for nodes node_label_handles = nx.draw_networkx_labels(G, pos=node_label_pos, font_size=14) # add a white bounding box behind the node labels [label.set_bbox(dict(facecolor='white', edgecolor='none')) for label in node_label_handles.values()] # add edge lables to the graph nx.draw_networkx_edge_labels(G, pos=node_positions, edge_labels=edge_weights, font_size=14) # add a legend white_circle = lines.Line2D([], [], color="white", marker='o', markersize=15, markerfacecolor="white") orange_circle = lines.Line2D([], [], color="orange", marker='o', markersize=15, markerfacecolor="orange") red_circle = lines.Line2D([], [], color="red", marker='o', markersize=15, markerfacecolor="red") gray_circle = lines.Line2D([], [], color="gray", marker='o', markersize=15, markerfacecolor="gray") green_circle = lines.Line2D([], [], color="green", marker='o', markersize=15, markerfacecolor="green") plt.legend((white_circle, orange_circle, red_circle, gray_circle, green_circle), ('Un-explored', 'Frontier', 'Currently Exploring', 'Explored', 'Final Solution'), numpoints=1, prop={'size':16}, loc=(.8,.75)) # show the plot. No need to use in notebooks. nx.draw will show the graph itself. plt.show() ## helper functions for visualisations def final_path_colors(initial_node_colors, problem, solution): "Return a node_colors dict of the final path provided the problem and solution." # get initial node colors final_colors = dict(initial_node_colors) # color all the nodes in solution and starting node to green final_colors[problem.initial] = "green" for node in solution: final_colors[node] = "green" return final_colors def display_visual(graph_data, user_input, algorithm=None, problem=None): initial_node_colors = graph_data['node_colors'] if user_input == False: def slider_callback(iteration): # don't show graph for the first time running the cell calling this function try: show_map(graph_data, node_colors=all_node_colors[iteration]) except: pass def visualize_callback(Visualize): if Visualize is True: button.value = False global all_node_colors iterations, all_node_colors, node = algorithm(problem) solution = node.solution() all_node_colors.append(final_path_colors(all_node_colors[0], problem, solution)) slider.max = len(all_node_colors) - 1 for i in range(slider.max + 1): slider.value = i #time.sleep(.5) slider = widgets.IntSlider(min=0, max=1, step=1, value=0) slider_visual = widgets.interactive(slider_callback, iteration=slider) display(slider_visual) button = widgets.ToggleButton(value=False) button_visual = widgets.interactive(visualize_callback, Visualize=button) display(button_visual) if user_input == True: node_colors = dict(initial_node_colors) if isinstance(algorithm, dict): assert set(algorithm.keys()).issubset({"Breadth First Tree Search", "Depth First Tree Search", "Breadth First Search", "Depth First Graph Search", "Best First Graph Search", "Uniform Cost Search", "Depth Limited Search", "Iterative Deepening Search", "Greedy Best First Search", "A-star Search", "Recursive Best First Search"}) algo_dropdown = widgets.Dropdown(description="Search algorithm: ", options=sorted(list(algorithm.keys())), value="Breadth First Tree Search") display(algo_dropdown) elif algorithm is None: print("No algorithm to run.") return 0 def slider_callback(iteration): # don't show graph for the first time running the cell calling this function try: show_map(graph_data, node_colors=all_node_colors[iteration]) except: pass def visualize_callback(Visualize): if Visualize is True: button.value = False problem = GraphProblem(start_dropdown.value, end_dropdown.value, romania_map) global all_node_colors user_algorithm = algorithm[algo_dropdown.value] iterations, all_node_colors, node = user_algorithm(problem) solution = node.solution() all_node_colors.append(final_path_colors(all_node_colors[0], problem, solution)) slider.max = len(all_node_colors) - 1 for i in range(slider.max + 1): slider.value = i #time.sleep(.5) start_dropdown = widgets.Dropdown(description="Start city: ", options=sorted(list(node_colors.keys())), value="Arad") display(start_dropdown) end_dropdown = widgets.Dropdown(description="Goal city: ", options=sorted(list(node_colors.keys())), value="Fagaras") display(end_dropdown) button = widgets.ToggleButton(value=False) button_visual = widgets.interactive(visualize_callback, Visualize=button) display(button_visual) slider = widgets.IntSlider(min=0, max=1, step=1, value=0) slider_visual = widgets.interactive(slider_callback, iteration=slider) display(slider_visual) # Function to plot NQueensCSP in csp.py and NQueensProblem in search.py def plot_NQueens(solution): n = len(solution) board = np.array([2 * int((i + j) % 2) for j in range(n) for i in range(n)]).reshape((n, n)) im = Image.open('images/queen_s.png') height = im.size[1] im = np.array(im).astype(np.float) / 255 fig = plt.figure(figsize=(7, 7)) ax = fig.add_subplot(111) ax.set_title('{} Queens'.format(n)) plt.imshow(board, cmap='binary', interpolation='nearest') # NQueensCSP gives a solution as a dictionary if isinstance(solution, dict): for (k, v) in solution.items(): newax = fig.add_axes([0.064 + (k * 0.112), 0.062 + ((7 - v) * 0.112), 0.1, 0.1], zorder=1) newax.imshow(im) newax.axis('off') # NQueensProblem gives a solution as a list elif isinstance(solution, list): for (k, v) in enumerate(solution): newax = fig.add_axes([0.064 + (k * 0.112), 0.062 + ((7 - v) * 0.112), 0.1, 0.1], zorder=1) newax.imshow(im) newax.axis('off') fig.tight_layout() plt.show() # Function to plot a heatmap, given a grid def heatmap(grid, cmap='binary', interpolation='nearest'): fig = plt.figure(figsize=(7, 7)) ax = fig.add_subplot(111) ax.set_title('Heatmap') plt.imshow(grid, cmap=cmap, interpolation=interpolation) fig.tight_layout() plt.show() # Generates a gaussian kernel def gaussian_kernel(l=5, sig=1.0): ax = np.arange(-l // 2 + 1., l // 2 + 1.) xx, yy = np.meshgrid(ax, ax) kernel = np.exp(-(xx**2 + yy**2) / (2. * sig**2)) return kernel # Plots utility function for a POMDP def plot_pomdp_utility(utility): save = utility['0'][0] delete = utility['1'][0] ask_save = utility['2'][0] ask_delete = utility['2'][-1] left = (save[0] - ask_save[0]) / (save[0] - ask_save[0] + ask_save[1] - save[1]) right = (delete[0] - ask_delete[0]) / (delete[0] - ask_delete[0] + ask_delete[1] - delete[1]) colors = ['g', 'b', 'k'] for action in utility: for value in utility[action]: plt.plot(value, color=colors[int(action)]) plt.vlines([left, right], -20, 10, linestyles='dashed', colors='c') plt.ylim(-20, 13) plt.xlim(0, 1) plt.text(left/2 - 0.05, 10, 'Save') plt.text((right + left)/2 - 0.02, 10, 'Ask') plt.text((right + 1)/2 - 0.07, 10, 'Delete') plt.show()