Start setup for AI training: add a Connect4Env class

This commit is contained in:
Vincent Rodley 2025-08-07 10:08:20 +12:00
parent 5255ec30c9
commit 760c412a45
2 changed files with 74 additions and 40 deletions

74
connect4_env.py Normal file
View File

@ -0,0 +1,74 @@
import numpy as np
class Connect4Env:
    """Minimal two-player Connect Four environment.

    The board is a ``ROWS x COLS`` ``int8`` array. Row 0 is the top of the
    board and discs settle at the highest free row index of a column. Each
    cell holds ``EMPTY``, ``PLAYER1`` or ``PLAYER2``.

    The environment does not remember that a game has ended; callers should
    call :meth:`reset` once :meth:`step` reports ``done``.
    """

    ROWS = 6
    COLS = 7
    EMPTY = 0
    PLAYER1 = 1
    PLAYER2 = -1
    # Number of aligned discs required to win.
    CONNECT = 4
    # (row step, column step) of every orientation a winning line can have:
    # horizontal, vertical, down-right diagonal, up-right diagonal.
    # The order fixes which line is reported first by check_win().
    _DIRECTIONS = ((0, 1), (1, 0), (1, 1), (-1, 1))

    def __init__(self):
        """Create an empty board with ``PLAYER1`` to move."""
        self.board = np.full((self.ROWS, self.COLS), self.EMPTY, dtype=np.int8)
        self.current_player = self.PLAYER1

    def reset(self):
        """Clear the board in place, give the move to ``PLAYER1``.

        Returns:
            The fresh state, as produced by :meth:`get_state`.
        """
        self.board[:] = self.EMPTY
        self.current_player = self.PLAYER1
        return self.get_state()

    def get_state(self):
        """Return ``(board_copy, current_player)``.

        The board is copied so callers cannot mutate the environment through
        the returned array.
        """
        return self.board.copy(), self.current_player

    def available_actions(self):
        """Return the column indices whose top cell is still empty."""
        return [
            col for col in range(self.COLS)
            if self.board[0, col] == self.EMPTY
        ]

    def step(self, action):
        """Drop the current player's disc into column ``action``.

        Args:
            action: Index of the column to play.

        Returns:
            ``(state, reward, done)`` where ``state`` is the value of
            :meth:`get_state` *after* the turn has passed to the other
            player, ``reward`` is from the point of view of the player who
            just moved (1 win, -1 loss, 0 otherwise) and ``done`` tells
            whether the game is over.

        Raises:
            ValueError: If ``action`` is not a currently playable column.
        """
        if action not in self.available_actions():
            raise ValueError("Invalid action")

        # Scan from the bottom row upwards; the first empty cell is where
        # the disc lands.
        for row in range(self.ROWS - 1, -1, -1):
            if self.board[row, action] == self.EMPTY:
                self.board[row, action] = self.current_player
                break

        done, winner = self.check_win()
        if done and winner != 0:
            reward = 1 if winner == self.current_player else -1
        else:
            # Game still running, or finished as a draw.
            reward = 0

        self.current_player *= -1
        return self.get_state(), reward, done

    def check_win(self):
        """Inspect the whole board for a finished game.

        Returns:
            ``(True, winner)`` when a player has ``CONNECT`` discs in a line
            (``winner`` is that player's marker as a plain ``int``),
            ``(True, 0)`` when the top row is full without a winner (draw),
            and ``(False, None)`` while the game is still open.
        """
        span = self.CONNECT - 1
        for d_row, d_col in self._DIRECTIONS:
            # Restrict the starting cells so the whole line stays on the board.
            row_start = span if d_row < 0 else 0
            row_stop = self.ROWS - span if d_row > 0 else self.ROWS
            col_stop = self.COLS - span * d_col
            for row in range(row_start, row_stop):
                for col in range(col_stop):
                    first = self.board[row, col]
                    if first == self.EMPTY:
                        continue
                    if all(
                        self.board[row + i * d_row, col + i * d_col] == first
                        for i in range(1, self.CONNECT)
                    ):
                        # Convert the NumPy scalar to a plain int for callers.
                        return True, int(first)

        # No line found: a full top row means no move is left.
        if (self.board[0] != self.EMPTY).all():
            return True, 0  # draw
        return False, None

40
main.py
View File

@ -130,51 +130,11 @@ def play_lan_server():
input("Press Enter to return to menu...") input("Press Enter to return to menu...")
return return
HOST, PORT = "0.0.0.0", 65432
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((HOST, PORT))
s.listen()
print("Waiting for player 2...")
conn, addr = s.accept()
with conn:
print(f"Connected by {addr}")
play_game(
lambda p, b: send_and_return_local_move(p, b, conn),
lambda p, b: socket_receive_move(conn)
)
except:
print("Somebody broke something. Try again.")
input("Press ENTER to return to the menu.")
finally:
s.close()
def play_lan_client():
    """Join a LAN game as the second player (currently disabled).

    The function prints a maintenance notice and returns right away. The
    connection loop below the early ``return`` is unreachable for now and is
    kept so LAN play can be switched back on by removing the notice.
    """
    print("PvP LAN is in maintenance due to exploits.!")
    input("Press Enter to return to menu...")
    return

    # --- Unreachable while LAN play is disabled ---
    while True:
        host, port = input("Enter server IP: "), 65432
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect((host, port))
                print("Connected to server.")
                play_game(
                    lambda p, b: socket_receive_move(s),
                    lambda p, b: send_and_return_local_move(p, b, s)
                )
            break
        except ConnectionRefusedError:
            print("No game found on that IP. Try again.")
        # A tuple is required here: ``except A or B`` evaluates ``A or B``
        # first and therefore only ever catches ``A``.
        except (ConnectionResetError, ValueError):
            print("The game was closed by host (I think).")
def send_and_return_local_move(player, board, sock):
    """Get the local player's move, forward it over ``sock`` and return it.

    The move is obtained from ``local_move_provider(player, board)`` and
    passed to ``socket_send_move`` before being handed back to the caller.
    """
    chosen_column = local_move_provider(player, board)
    socket_send_move(sock, chosen_column)
    return chosen_column
def play_vs_computer(): def play_vs_computer():
print("PvC mode coming soon!") print("PvC mode coming soon!")
input("Press Enter to return to menu...") input("Press Enter to return to menu...")