diff --git a/connect4_env.py b/connect4_env.py new file mode 100644 index 0000000..20c678a --- /dev/null +++ b/connect4_env.py @@ -0,0 +1,74 @@ +import numpy as np + +class Connect4Env: + ROWS = 6 + COLS = 7 + EMPTY = 0 + PLAYER1 = 1 + PLAYER2 = -1 + + def __init__(self): + self.board = np.zeros((self.ROWS, self.COLS), dtype=np.int8) + self.current_player = self.PLAYER1 + + def reset(self): + self.board[:] = 0 + self.current_player = self.PLAYER1 + return self.get_state() + + def get_state(self): + return self.board.copy(), self.current_player + + def available_actions(self): + return [col for col in range(self.COLS) if self.board[0, col] == 0] + + def step(self, action): + if action not in self.available_actions(): + raise ValueError("Invalid action") + + for row in reversed(range(self.ROWS)): + if self.board[row][action] == 0: + self.board[row][action] = self.current_player + break + + done, winner = self.check_win() + reward = 0 + if done: + if winner == 0: + reward = 0 # draw + elif winner == self.current_player: + reward = 1 + else: + reward = -1 + self.current_player *= -1 + return self.get_state(), reward, done + + def check_win(self): + for row in range(self.ROWS): + for col in range(self.COLS - 3): + line = self.board[row, col:col + 4] + if abs(sum(line)) == 4: + return True, np.sign(sum(line)) + + for row in range(self.ROWS - 3): + for col in range(self.COLS): + line = self.board[row:row + 4, col] + if abs(sum(line)) == 4: + return True, np.sign(sum(line)) + + for row in range(self.ROWS - 3): + for col in range(self.COLS - 3): + diag = [self.board[row + i][col + i] for i in range(4)] + if abs(sum(diag)) == 4: + return True, np.sign(sum(diag)) + + for row in range(3, self.ROWS): + for col in range(self.COLS - 3): + diag = [self.board[row - i][col + i] for i in range(4)] + if abs(sum(diag)) == 4: + return True, np.sign(sum(diag)) + + if all(self.board[0, :] != 0): + return True, 0 # draw + + return False, None diff --git a/main.py b/main.py index e2efb2b..f90aa4b 100644 --- a/main.py +++ b/main.py @@ -129,51 +129,11 @@ def play_lan_server(): print("PvP LAN is in maintenance due to exploits.!") input("Press Enter to return to menu...") return - - HOST, PORT = "0.0.0.0", 65432 - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind((HOST, PORT)) - s.listen() - print("Waiting for player 2...") - conn, addr = s.accept() - with conn: - print(f"Connected by {addr}") - play_game( - lambda p, b: send_and_return_local_move(p, b, conn), - lambda p, b: socket_receive_move(conn) - ) - except: - print("Somebody broke something. Try again.") - input("Press ENTER to return to the menu.") - finally: - s.close() def play_lan_client(): print("PvP LAN is in maintenance due to exploits.!") input("Press Enter to return to menu...") return - - while True: - HOST, PORT = input("Enter server IP: "), 65432 - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect((HOST, PORT)) - print("Connected to server.") - play_game( - lambda p, b: socket_receive_move(s), - lambda p, b: send_and_return_local_move(p, b, s) - ) - break - except ConnectionRefusedError: - print("No game found on that IP. Try again.") - except ConnectionResetError or ValueError: - print("The game was closed by host (I think).") - -def send_and_return_local_move(player, board, sock): - col = local_move_provider(player, board) - socket_send_move(sock, col) - return col def play_vs_computer(): print("PvC mode coming soon!")