refactor(network.py): annotate forward and backward functions

This commit is contained in:
2026-01-18 18:28:35 +01:00
parent 4cca66db99
commit d2be9952f3

View File

@@ -1,5 +1,6 @@
import math import math
import random import random
from typing import Optional, Callable
def sigmoid(x: float) -> float: def sigmoid(x: float) -> float:
@@ -29,6 +30,50 @@ class Neuron:
# Initialize a shift to the activation threshold with a random bias # Initialize a shift to the activation threshold with a random bias
self._bias: float = random.uniform(-1., 1.) self._bias: float = random.uniform(-1., 1.)
def forward(self, x: list[int], fn: Optional[Callable[[float], float]] = None) -> float:
"""
Execute the neuron's forward pass.
:param x: Description
:param activate: Description
"""
if len(x) != self._input_size:
raise ValueError(
f"Input vertex dimension {len(x)} mismatches the "
f"stored size {self._input_size}")
self._z: float = sum(welement * xelement for welement,
xelement in zip(self._weight, x)) + self._bias
return fn(self._z) if fn is not None else self._z
def backward(self, dz_dw: list[float], dcost_dy: float, learning_rate: float,
fn_deriv: Optional[Callable[[float], float]] = None) -> list[float]:
# Check dimension consistency
if len(dz_dw) != self._input_size:
raise ValueError(
f"Input vertex dimension {len(dz_dw)} mismatches "
f"stored size {self._input_size}")
# Local gradient: dy/dz (defaults to 1.0 for linear identity)
dy_dz: float = fn_deriv(self._z) if fn_deriv is not None else 1.
# Compute common error term (delta)
# dC/dz = dC/dy * dy/dz
delta: float = dcost_dy * dy_dz
# Update weights: weight -= learning_rate * dC/dz * dz/dw
for i in range(self._input_size):
self._weight[i] -= learning_rate * delta * dz_dw[i]
# Update bias: bias -= learning_rate * dC/dz * dz/db (where dz/db = 1)
self._bias -= learning_rate * delta * 1.0
# Return input gradient (dC/dx): dC/dz * dz/dx (where dz/dx = weights)
# This vector allows the error to flow back to previous layers.
return [delta * w for w in self._weight]
def __repr__(self) -> str: def __repr__(self) -> str:
jmp: int = int(math.sqrt(self._input_size)) jmp: int = int(math.sqrt(self._input_size))
text: list[str] = [] text: list[str] = []