From 7bc05a2563dde2214393906eedf6cbb8238f0fa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20GUEZO?= Date: Sat, 12 Jul 2025 11:55:12 +0200 Subject: [PATCH] feat: binary to int conversion --- main.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index 7c44113..5c8c5f4 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,9 @@ def data(size:int, max_val: int): for i in range(max_val + 1) ] +def binatodeci(binary: list[int]): + return sum(val*(2**idx) for idx, val in enumerate(reversed(binary))) + def train_network(network: NeuralNetwork, epochs=10000, learning_rate=0.1, verbose: bool = False, size_data: int = 8, max_val: int = 255): @@ -23,31 +26,35 @@ def train_network(network: NeuralNetwork, epochs=10000, learning_rate=0.1, output = network.forward(bits)[0] loss = (output - target[0]) ** 2 - print(f"Epoch: {epoch}, Loss: {loss:.6f}") + print(f"Epoch: {epoch}, Loss: {loss:.6f} {(loss*100):.6f}%") def main(): - size = 8 + size = 4 max_val = (1 << size) - 1 + epoch_size = 6_500 - network = NeuralNetwork([8, 16, 1]) + network = NeuralNetwork([size, 16, 1]) print("Start training...") - train_network(network, verbose=True, size_data=size, epochs=5_000, max_val=max_val) + train_network(network, verbose=True, size_data=size, epochs=epoch_size, max_val=max_val) print("End training...") while True: - string = input("Enter 8 bit number (ex: 01101001) or 'quit' to close: ") \ + string = input(f"Enter {size} bit number (ex: {''.join([str(random.randint(0, 1)) for i in range(size)])}) or 'quit' to close: ") \ .strip().lower() if (string == 'quit'): break - if (len(string) != 8 or any (char not in '01' for char in string)): - print("Error: please enter exactly 8 bits (only 0 or 1).") + if (len(string) != size or any (char not in '01' for char in string)): + print(f"Error: please enter exactly {size} bits (only 0 or 1).") continue bits_input = [int(char) for char in string] output = network.forward(bits_input)[0] * max_val - print(f"Estimated value: {output} (approx: {round(output)})\n") + print("\n===== Estimated value =====") + print(f"{output} (approx: {round(output)})") + print("\n===== Real value =====") + print(f"{binatodeci(bits_input)}\n") if __name__ == "__main__": main()