Start setup for AI training: add a Connect4Env class
This commit is contained in:
parent
5255ec30c9
commit
760c412a45
74
connect4_env.py
Normal file
74
connect4_env.py
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class Connect4Env:
    """Two-player Connect Four board with a reset/step style interface.

    The board is a ``ROWS x COLS`` ``int8`` array. Cells hold ``EMPTY``,
    ``PLAYER1`` or ``PLAYER2``. Row 0 is the top of the board: a dropped
    piece lands in the empty cell with the largest row index of its column,
    and a column is full once its row-0 cell is occupied.

    Win detection sums four consecutive cells and looks for ``+4`` / ``-4``,
    so it relies on the players being encoded as ``+1`` and ``-1``.
    """

    ROWS = 6
    COLS = 7
    EMPTY = 0
    PLAYER1 = 1
    PLAYER2 = -1

    # Number of aligned pieces needed to win.
    _CONNECT = 4
    # (row step, column step) of each line orientation, in scan order:
    # horizontal, vertical, diagonal going down-right, diagonal going up-right.
    _DIRECTIONS = ((0, 1), (1, 0), (1, 1), (-1, 1))

    def __init__(self):
        """Create an empty board with ``PLAYER1`` to move."""
        self.board = np.zeros((self.ROWS, self.COLS), dtype=np.int8)
        self.current_player = self.PLAYER1

    def reset(self):
        """Clear the board in place, give the move to ``PLAYER1``.

        Returns:
            The fresh state, as produced by :meth:`get_state`.
        """
        self.board[:] = 0
        self.current_player = self.PLAYER1
        return self.get_state()

    def get_state(self):
        """Return ``(board_copy, current_player)``.

        The board is copied so callers cannot mutate the environment through
        the returned array.
        """
        return self.board.copy(), self.current_player

    def available_actions(self):
        """Return the column indices whose top cell is still empty."""
        return [
            col for col in range(self.COLS)
            if self.board[0, col] == self.EMPTY
        ]

    def step(self, action):
        """Drop the current player's piece into column ``action``.

        Args:
            action: Column index; must be one of :meth:`available_actions`.

        Returns:
            ``(state, reward, done)`` where ``state`` comes from
            :meth:`get_state` *after* the turn has passed to the other
            player, ``reward`` is ``1`` if the detected winner is the player
            who just moved, ``-1`` if it is the other player, and ``0``
            otherwise (draw or game still running), and ``done`` tells
            whether :meth:`check_win` reported the game as finished.

        Raises:
            ValueError: If ``action`` is not an available column.
        """
        if action not in self.available_actions():
            raise ValueError("Invalid action")

        mover = self.current_player

        # Scan from the bottom row upwards; the first empty cell is where
        # the piece comes to rest.
        for row in reversed(range(self.ROWS)):
            if self.board[row][action] == self.EMPTY:
                self.board[row][action] = mover
                break

        done, winner = self.check_win()
        if done and winner != 0:
            reward = 1 if winner == mover else -1
        else:
            # Either the game goes on or it ended in a draw.
            reward = 0

        # The turn always passes on, even when the game has just ended.
        self.current_player = -mover
        return self.get_state(), reward, done

    def check_win(self):
        """Look for four aligned pieces or a full board.

        Returns:
            ``(True, winner)`` with ``winner`` a plain ``int`` equal to
            ``PLAYER1`` or ``PLAYER2`` when a line of four is found,
            ``(True, 0)`` when the top row is full and nobody has a line
            (draw), and ``(False, None)`` otherwise.
        """
        span = self._CONNECT - 1

        for d_row, d_col in self._DIRECTIONS:
            # Restrict start cells so the whole window stays on the board.
            first_row = span if d_row < 0 else 0
            end_row = self.ROWS - span if d_row > 0 else self.ROWS
            end_col = self.COLS - span * d_col

            for row in range(first_row, end_row):
                for col in range(end_col):
                    # Convert to int so the result never leaks NumPy scalars.
                    total = sum(
                        int(self.board[row + i * d_row, col + i * d_col])
                        for i in range(self._CONNECT)
                    )
                    if abs(total) == self._CONNECT:
                        winner = self.PLAYER1 if total > 0 else self.PLAYER2
                        return True, winner

        # No line found: the game is a draw once no column can take a piece.
        if not (self.board[0, :] == self.EMPTY).any():
            return True, 0  # draw

        return False, None
|
||||||
40
main.py
40
main.py
|
|
@ -129,51 +129,11 @@ def play_lan_server():
|
||||||
print("PvP LAN is in maintenance due to exploits.!")
|
print("PvP LAN is in maintenance due to exploits.!")
|
||||||
input("Press Enter to return to menu...")
|
input("Press Enter to return to menu...")
|
||||||
return
|
return
|
||||||
|
|
||||||
HOST, PORT = "0.0.0.0", 65432
|
|
||||||
try:
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
s.bind((HOST, PORT))
|
|
||||||
s.listen()
|
|
||||||
print("Waiting for player 2...")
|
|
||||||
conn, addr = s.accept()
|
|
||||||
with conn:
|
|
||||||
print(f"Connected by {addr}")
|
|
||||||
play_game(
|
|
||||||
lambda p, b: send_and_return_local_move(p, b, conn),
|
|
||||||
lambda p, b: socket_receive_move(conn)
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
print("Somebody broke something. Try again.")
|
|
||||||
input("Press ENTER to return to the menu.")
|
|
||||||
finally:
|
|
||||||
s.close()
|
|
||||||
|
|
||||||
def play_lan_client():
|
def play_lan_client():
|
||||||
print("PvP LAN is in maintenance due to exploits.!")
|
print("PvP LAN is in maintenance due to exploits.!")
|
||||||
input("Press Enter to return to menu...")
|
input("Press Enter to return to menu...")
|
||||||
return
|
return
|
||||||
|
|
||||||
while True:
|
|
||||||
HOST, PORT = input("Enter server IP: "), 65432
|
|
||||||
try:
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
s.connect((HOST, PORT))
|
|
||||||
print("Connected to server.")
|
|
||||||
play_game(
|
|
||||||
lambda p, b: socket_receive_move(s),
|
|
||||||
lambda p, b: send_and_return_local_move(p, b, s)
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except ConnectionRefusedError:
|
|
||||||
print("No game found on that IP. Try again.")
|
|
||||||
except ConnectionResetError or ValueError:
|
|
||||||
print("The game was closed by host (I think).")
|
|
||||||
|
|
||||||
def send_and_return_local_move(player, board, sock):
|
|
||||||
col = local_move_provider(player, board)
|
|
||||||
socket_send_move(sock, col)
|
|
||||||
return col
|
|
||||||
|
|
||||||
def play_vs_computer():
|
def play_vs_computer():
|
||||||
print("PvC mode coming soon!")
|
print("PvC mode coming soon!")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user