mirror of
https://github.com/guezoloic/neural-network.git
synced 2026-01-25 03:34:21 +00:00
feat: binary to int conversion
This commit is contained in:
23
main.py
23
main.py
@@ -10,6 +10,9 @@ def data(size:int, max_val: int):
|
|||||||
for i in range(max_val + 1)
|
for i in range(max_val + 1)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def binatodeci(binary: list[int]):
|
||||||
|
return sum(val*(2**idx) for idx, val in enumerate(reversed(binary)))
|
||||||
|
|
||||||
def train_network(network: NeuralNetwork, epochs=10000, learning_rate=0.1,
|
def train_network(network: NeuralNetwork, epochs=10000, learning_rate=0.1,
|
||||||
verbose: bool = False, size_data: int = 8, max_val: int = 255):
|
verbose: bool = False, size_data: int = 8, max_val: int = 255):
|
||||||
|
|
||||||
@@ -23,31 +26,35 @@ def train_network(network: NeuralNetwork, epochs=10000, learning_rate=0.1,
|
|||||||
output = network.forward(bits)[0]
|
output = network.forward(bits)[0]
|
||||||
loss = (output - target[0]) ** 2
|
loss = (output - target[0]) ** 2
|
||||||
|
|
||||||
print(f"Epoch: {epoch}, Loss: {loss:.6f}")
|
print(f"Epoch: {epoch}, Loss: {loss:.6f} {(loss*100):.6f}%")
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
size = 8
|
size = 4
|
||||||
max_val = (1 << size) - 1
|
max_val = (1 << size) - 1
|
||||||
|
epoch_size = 6_500
|
||||||
|
|
||||||
network = NeuralNetwork([8, 16, 1])
|
network = NeuralNetwork([size, 16, 1])
|
||||||
|
|
||||||
print("Start training...")
|
print("Start training...")
|
||||||
train_network(network, verbose=True, size_data=size, epochs=5_000, max_val=max_val)
|
train_network(network, verbose=True, size_data=size, epochs=epoch_size, max_val=max_val)
|
||||||
print("End training...")
|
print("End training...")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
string = input("Enter 8 bit number (ex: 01101001) or 'quit' to close: ") \
|
string = input(f"Enter {size} bit number (ex: {''.join([str(random.randint(0, 1)) for i in range(size)])}) or 'quit' to close: ") \
|
||||||
.strip().lower()
|
.strip().lower()
|
||||||
|
|
||||||
if (string == 'quit'): break
|
if (string == 'quit'): break
|
||||||
if (len(string) != 8 or any (char not in '01' for char in string)):
|
if (len(string) != size or any (char not in '01' for char in string)):
|
||||||
print("Error: please enter exactly 8 bits (only 0 or 1).")
|
print(f"Error: please enter exactly {size} bits (only 0 or 1).")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
bits_input = [int(char) for char in string]
|
bits_input = [int(char) for char in string]
|
||||||
output = network.forward(bits_input)[0] * max_val
|
output = network.forward(bits_input)[0] * max_val
|
||||||
|
|
||||||
print(f"Estimated value: {output} (approx: {round(output)})\n")
|
print("\n===== Estimated value =====")
|
||||||
|
print(f"{output} (approx: {round(output)})")
|
||||||
|
print("\n===== Real value =====")
|
||||||
|
print(f"{binatodeci(bits_input)}\n")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user