# pylint: disable=missing-docstring
"""Play back a scripted two-party chat over TCP, encrypted with AES-CBC.

The CLI describes this as a *vulnerable* secured conversation (CTF ISRI 2025).
The following weaknesses look intentional — do not "fix" them without checking
the challenge design:

* the AES key is sent unencrypted over the same connection (``send_key``);
* every message is encrypted with a fixed all-zero IV.
"""
import argparse
import json
import os
import socket
import time
from dataclasses import dataclass
from ipaddress import IPv4Address
from pathlib import Path

from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

# AES block size; also used as the PKCS7 padding block size.
_BLOCK_SIZE_BITS = 128
_BLOCK_SIZE = _BLOCK_SIZE_BITS // 8
# AES-128 key length in bytes.
_KEY_SIZE = 16
# Largest plaintext (UTF-8 bytes) that send_message accepts.
_MAX_PLAINTEXT_SIZE = 256
# PKCS7 always appends 1..block-size bytes, so a maximal plaintext grows by
# one full block once padded.
_MAX_CIPHERTEXT_SIZE = _MAX_PLAINTEXT_SIZE + _BLOCK_SIZE
# Fixed all-zero IV: one of the weaknesses listed in the module docstring.
_ZERO_IV = b'\x00' * _BLOCK_SIZE
# Frame headers; each frame on the wire is ``header + b':' + payload``.
_KEY_HEADER = b'AES_KEY'
_MESSAGE_HEADER = b'MESSAGE'


@dataclass
class User:
    """A chat participant."""

    id: int  # identifier referenced by Message.user_id
    name: str


@dataclass
class Message:
    """One scripted line of the conversation."""

    user_id: int  # id of the User who sends this line
    msg: str  # plaintext that is sent, or expected on receipt
    sleep: int  # seconds the sender waits before sending


@dataclass
class Chat:
    """A complete scripted conversation loaded from the JSON chat file."""

    stop: str  # read from the chat file; not used elsewhere in this file
    users: list[User]
    messages: list[Message]


def encode(message: str) -> bytes:
    """Return ``message`` encoded as UTF-8."""
    return message.encode()


def decode(data: bytes) -> str:
    """Return ``data`` decoded from UTF-8."""
    return data.decode()


def dict_to_chat(chat_dict: dict) -> Chat:
    """Build a :class:`Chat` from the parsed JSON chat file.

    Each message's ``sleep`` is optional and defaults to 1 second.

    :raise: KeyError if a required key ("stop", "users", "messages", or a
        user/message field) is missing
    :raise: ValueError if a message refers to a user id absent from the
        user list
    """
    stop: str = chat_dict["stop"]

    user_list: list[User] = [
        User(id=user["id"], name=user["name"]) for user in chat_dict["users"]
    ]
    # Built once so every message's membership test is O(1).
    known_ids = {user.id for user in user_list}

    message_list: list[Message] = []
    for message in chat_dict["messages"]:
        user_id = message["user_id"]
        if user_id not in known_ids:
            raise ValueError(
                f"The user id '{user_id}' is used but is not in the user list"
            )
        message_list.append(
            Message(
                user_id=user_id,
                msg=message["msg"],
                sleep=message.get("sleep", 1),
            )
        )

    return Chat(stop=stop, users=user_list, messages=message_list)


def server_connection(host_address: str, port: int) -> socket.socket:
    """Listen on ``host_address:port`` and return the first accepted connection.

    The listening socket is closed once a client has been accepted; the
    returned connection socket is independent of it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind((host_address, port))
        listener.listen()
        connection, _ = listener.accept()
    return connection


def client_connection(host_address: str, port: int) -> socket.socket:
    """Connect to ``host_address:port``, retrying every 5 s while refused.

    A fresh socket is created for every attempt: after a failed ``connect()``
    the socket's state is unspecified (POSIX), so it is closed, not reused.
    Errors other than a refused connection are propagated.
    """
    retry_delay = 5
    while True:
        attempt = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            attempt.connect((host_address, port))
        except ConnectionRefusedError:
            attempt.close()
        except BaseException:
            # Do not leak the socket on any other failure.
            attempt.close()
            raise
        else:
            return attempt
        print(f"connection refused. Retrying in {retry_delay} seconds...")
        time.sleep(retry_delay)


def _recv_exact(connection: socket.socket, size: int) -> bytes:
    """Read exactly ``size`` bytes from ``connection``.

    ``recv`` may return fewer bytes than requested on a stream socket, so
    this loops until the full amount has arrived.

    :raise: OSError if the peer closes the connection first
    """
    chunks: list[bytes] = []
    remaining = size
    while remaining:
        chunk = connection.recv(remaining)
        if not chunk:
            raise OSError("Connection closed")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def send_key(connection: socket.socket, key: bytes) -> None:
    """Send ``key`` to the peer as the frame ``b'AES_KEY:' + key``.

    The key is sent unencrypted (see module docstring).

    :raise: ValueError if ``key`` is not 16 bytes long
    """
    if len(key) != _KEY_SIZE:
        raise ValueError("Invalid key size")
    connection.sendall(_KEY_HEADER + b':' + key)


def recv_key(connection: socket.socket) -> bytes:
    """Receive the AES key sent by :func:`send_key`.

    The key frame has a fixed size, so exactly that many bytes are read. Any
    data the peer sends right after it (e.g. the first chat message) stays in
    the socket buffer for :func:`recv_message` instead of being consumed here.

    :raise: OSError if the connection closes before the whole frame arrives
    :raise: ValueError if the frame header or key size is invalid
    """
    data = _recv_exact(connection, len(_KEY_HEADER) + 1 + _KEY_SIZE)
    header, key = data.split(b':', maxsplit=1)
    if header != _KEY_HEADER:
        raise ValueError("Invalid key header")
    if len(key) != _KEY_SIZE:
        raise ValueError("Invalid key size")
    return key


def _cipher(key: bytes) -> Cipher:
    """Return the AES-CBC cipher (fixed zero IV) used for every message."""
    return Cipher(algorithms.AES(key), modes.CBC(_ZERO_IV))


def send_message(connection: socket.socket, message: str, key: bytes) -> None:
    """Encrypt ``message`` and send it as ``b'MESSAGE:' + ciphertext``.

    The UTF-8 plaintext is PKCS7-padded to the AES block size, then
    encrypted with AES-CBC under ``key``.

    :raise: ValueError if the encoded message is longer than 256 bytes
    """
    encoded_message = message.encode()
    if len(encoded_message) > _MAX_PLAINTEXT_SIZE:
        raise ValueError("Message is greater than 256 bytes")

    padder = padding.PKCS7(_BLOCK_SIZE_BITS).padder()
    padded_message = padder.update(encoded_message) + padder.finalize()

    encryptor = _cipher(key).encryptor()
    ciphertext = encryptor.update(padded_message) + encryptor.finalize()
    connection.sendall(_MESSAGE_HEADER + b':' + ciphertext)


def recv_message(connection: socket.socket, key: bytes) -> str:
    """Receive, decrypt and unpad one message sent by :func:`send_message`.

    NOTE(review): message frames carry no length prefix, so a single
    ``recv`` is assumed to return exactly one frame. TCP does not guarantee
    this (frames sent back to back may arrive together or split). Fixing it
    requires changing the wire format on both peers.

    :raise: OSError if the connection is closed
    :raise: ValueError on a malformed frame, an oversized ciphertext, bad
        padding, or invalid UTF-8
    """
    data = connection.recv(1024)
    if not data:
        raise OSError("Connection closed")

    try:
        header, ciphertext = data.split(b':', maxsplit=1)
    except ValueError:
        # No separator in the frame: show what arrived, then propagate.
        print(data)
        raise
    if header != _MESSAGE_HEADER:
        raise ValueError("Invalid message header")
    # A maximal 256-byte plaintext pads to 272 bytes, which must be accepted.
    if len(ciphertext) > _MAX_CIPHERTEXT_SIZE:
        raise ValueError("Invalid message size")

    decryptor = _cipher(key).decryptor()
    padded_message = decryptor.update(ciphertext) + decryptor.finalize()

    unpadder = padding.PKCS7(_BLOCK_SIZE_BITS).unpadder()
    encoded_message = unpadder.update(padded_message) + unpadder.finalize()
    return encoded_message.decode()


def exchange(chat: Chat, is_server: bool, address: str, port: int):
    """Run this side of ``chat`` over a TCP connection with the peer.

    The server plays user id 1 and the client plays user id 0. The server
    generates a random 16-byte key and sends it; the client receives it.
    Messages whose ``user_id`` is ours are sent after ``message.sleep``
    seconds; all others are received and compared with the script.

    :raise: ValueError if a received message differs from the expected one
    """
    # Oriana is the server by default
    user_id: int = 1 if is_server else 0
    print(f"user_id={user_id}")

    # Connect to the other host
    print(f"new connection address={address}, port={port}")
    connection: socket.socket
    if is_server:
        print("listening to a client...")
        connection = server_connection(address, port)
    else:
        print("connection to server...")
        connection = client_connection(address, port)
    print("connected")

    # Close the connection however the conversation ends.
    with connection:
        # Share the secret key
        secret_key: bytes
        if is_server:
            print("send secret key")
            secret_key = os.urandom(_KEY_SIZE)
            send_key(connection, secret_key)
        else:
            print("receive secret key")
            secret_key = recv_key(connection)
        print(f"secret_key={secret_key!r}")

        # Exchange messages
        for message in chat.messages:
            print(f"next message is: {message}")
            if message.user_id == user_id:
                print("sending the message...")
                time.sleep(message.sleep)
                send_message(connection, message.msg, secret_key)
                continue

            print("receiving the message...")
            received_message = recv_message(connection, secret_key)
            if message.msg != received_message:
                print(f"received_message={received_message}")
                raise ValueError(
                    "Receive message doesn't match the expected message"
                )

    print("conversation finished")


def main() -> None:
    """Parse the command line, load the chat file and run the exchange.

    :raise: ValueError if the address is not a valid IPv4 address or the chat
        file references an unknown user id
    """
    parser = argparse.ArgumentParser(
        prog='ctf_spied_conversation.py',
        description='Launch a vulnerable secured conversation',
        epilog='CTF ISRI 2025')
    parser.add_argument('role', choices=['client', 'server'])
    parser.add_argument('address', type=str)
    parser.add_argument('port', type=int)
    parser.add_argument('--path', default="chat.json", type=str)
    args = parser.parse_args()

    # argparse ``choices`` already restricts role to "client" or "server".
    is_server: bool = args.role == "server"
    print(f"is_server={is_server}")

    # Check the address; ipaddress.AddressValueError subclasses ValueError.
    try:
        IPv4Address(args.address)
    except ValueError as err:
        raise ValueError(f"{args.address} is not an IPv4 address") from err
    print(f"address={args.address}")
    print(f"port={args.port}")

    # Load json file to dict
    chat_filepath = Path(args.path)
    print(f"chat_filepath={chat_filepath}")
    with open(chat_filepath, encoding="utf-8") as chat_file:
        print("load json")
        chat_dict: dict = json.load(chat_file)

    # Load chat
    print("convert json to dict")
    chat: Chat = dict_to_chat(chat_dict)

    exchange(
        chat=chat,
        is_server=is_server,
        address=args.address,
        port=args.port)


if __name__ == '__main__':
    main()