neurai.nn package
Subpackages
- neurai.nn.conn package
- neurai.nn.layer package
- Submodules
Acitivate, Celu, Elu, Gelu, Glu, HardSigmoid, HardSilu, HardSwish, HardTanh, LeakyRelu, LogSigmoid, LogSoftmax, LogSumexp, Mish, Relu, Relu6, Selu, Sigmoid, Silu, SoftSign, Softmax, Softmax2D, Softmin, Softplus, Tanh, Embed, canonicalize_dtype(), masked_softmax(), promote_dtype(), sequence_mask(), transpose_output(), transpose_qkv(), Conv, Conv1d, Conv2d, Conv3d, ConvTranspose, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, SNNConv2d, SNNConv3d, SNNConvTranspose3d, canonicalize_padding(), maybe_replicate(), Dropout, SNNDropout3d, Exodus, exodus_cpu_bwd(), exodus_cpu_fwd(), Flatten, Linear, SNNLinear3d, Sequential, BatchNorm, BatchNorm1d, BatchNorm2d, BatchNorm3d, LayerNorm, TdBatchNorm, TdLayer, WeightNorm, ConstantPad, ReflectionPad, ReplicationPad, ZeroPad, AvgPool, MaxPool, MinPool, Pool, SNNPool, UpSampleNearest, ALIFCell, GRUCell, LIFCell, LSTMCell, RNN, RNNCellBase, SRNN, SRNNCellBase, Slayer, psp_forward_jvp()
- Module contents
- Submodules
- neurai.nn.neuron package
- neurai.nn.rlayer package
- neurai.nn.synapse package
Submodules
- class neurai.nn.delay_spike.DelaySpike(max_delay_step=0, spike_size=1, init_data=None, init_val=False, parent=<neurai.nn.module._Sentinel object>, name=None, frozen_params=None)
Bases: Module
Maintains a matrix of the spikes at every time step and manages spike emission at each time step.
- Parameters:
max_delay_step (Union[int, jnp.ndarray, np.ndarray], Optional) – Max value of delay step of spikes transmitting between pre and post neuron group, by default 0.
spike_size (int, Optional) – The size of the spike matrix, by default 1.
init_data (jnp.ndarray, Optional) – An optional initial spike matrix data to initialize the matrix, by default None.
init_val (bool, Optional) – Whether to initialize using the initial data, by default False.
pid (int, Optional) – The process id of the neuron, used in multiple process simulations, by default 0.
name (str, Optional) – Name of the module. If not provided, a name will be automatically generated.
parent (Union[Module, VarManager, _Sentinel, None], Optional) – The parent module, by default neurai.nn.module._Sentinel.
frozen_params (dict, Optional) – A dictionary of frozen parameters. If provided, the module will be initialized with these parameters and will not be updated during training.
- reset()
Reset the DelaySpike instance to its initial state.
This function sets the data matrix to all zeros and the current_spiking_index to 0.
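To make the idea of a per-time-step spike matrix with a current index concrete, here is a minimal sketch of a delay ring buffer in plain Python. It is not neurai's DelaySpike: the class name, the `push` and `step` methods, and the `index` attribute are hypothetical and only illustrate how delayed spikes can be scheduled, emitted, and reset.

```python
class DelayBufferSketch:
    """Hypothetical stand-in for a spike delay buffer; not neurai's DelaySpike."""

    def __init__(self, max_delay_step=0, spike_size=1):
        self.max_delay_step = max_delay_step
        self.spike_size = spike_size
        self.reset()

    def reset(self):
        # One row per delay slot (0 .. max_delay_step), all zeros.
        self.data = [[0] * self.spike_size for _ in range(self.max_delay_step + 1)]
        self.index = 0  # plays the role of current_spiking_index

    def push(self, spikes, delay):
        # Schedule spikes to be emitted `delay` steps from now (delay <= max_delay_step).
        slot = (self.index + delay) % len(self.data)
        for i, s in enumerate(spikes):
            self.data[slot][i] |= s

    def step(self):
        # Emit the spikes due now, clear the slot, and advance the index.
        out = self.data[self.index]
        self.data[self.index] = [0] * self.spike_size
        self.index = (self.index + 1) % len(self.data)
        return out


buf = DelayBufferSketch(max_delay_step=2, spike_size=3)
buf.push([1, 0, 1], delay=2)
outputs = [buf.step() for _ in range(3)]
print(outputs)  # [[0, 0, 0], [0, 0, 0], [1, 0, 1]]
```

Calling `reset()` on this sketch zeroes every row and returns the index to 0, mirroring the behaviour described for `reset()` above.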
- neurai.nn.functional.flatten(input, start_dim=1, end_dim=-1)
Flattens a jnp.ndarray along the specified dimensions.
- Parameters:
- Return type:
- Returns:
jnp.ndarray – The flattened data.
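The parameter descriptions are missing above, so the following sketch only illustrates one plausible reading of start_dim and end_dim: the dimensions from start_dim to end_dim (inclusive, negative indices allowed) are merged into one. That inclusive convention is an assumption borrowed from other array libraries, and `flattened_shape` is a hypothetical helper that computes only the resulting shape.

```python
from math import prod


def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Shape after merging dimensions start_dim..end_dim (inclusive) into one."""
    ndim = len(shape)
    start = start_dim % ndim  # normalise negative indices
    end = end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flattened_shape((2, 3, 4, 5)))        # (2, 60)
print(flattened_shape((2, 3, 4, 5), 1, 2))  # (2, 12, 5)
```

With the defaults, everything except the leading (batch) dimension is merged, which is the usual step before a linear layer.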
- neurai.nn.functional.interpolate(input, size=None, scale_factor=None, mode='bilinear', align_corners=False)
Interpolate the input tensor to a specified size or scale.
- Parameters:
input (jnp.ndarray) – The input tensor to be interpolated.
size (tuple or int, Optional) – The target size for interpolation. If provided as a single integer, it’s interpreted as (size, size). Either ‘size’ or ‘scale_factor’ must be provided.
scale_factor (float or tuple, Optional) – The factor by which the input should be scaled. If provided as a single float, it’s interpreted as (scale_factor, scale_factor). Either size or scale_factor must be provided.
mode (str, Optional) – The interpolation mode. Supported modes are nearest, linear, bilinear, bicubic, and trilinear.
align_corners (bool, Optional) – A flag indicating whether to align the corners of the input and output when using bilinear interpolation. Only applicable when mode is bilinear.
- Returns:
jnp.ndarray – The interpolated tensor.
- Raises:
ValueError – If neither size nor scale_factor is provided, or if an unsupported interpolation mode is specified.
Examples
import jax.numpy as jnp
from neurai.nn.functional import interpolate
input = jnp.array([[[[1.0, 2.0], [3.0, 4.0]]]])
output = interpolate(input, size=(4, 4), mode='bilinear', align_corners=False)
print(output)
Note
When mode is nearest, the function uses nearest-neighbor interpolation.
When mode is linear, bilinear, bicubic, or trilinear, the function uses linear interpolation.
align_corners is relevant only for bilinear interpolation, and it specifies whether to align the corners of the input and output grids. Setting it to True can give more accurate results when aligning grid corners, but it may not be suitable for all use cases.
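The effect of align_corners is easiest to see in one dimension. The sketch below is a plain-Python illustration, not neurai's implementation: it assumes the common convention in which align_corners=True maps the first and last output samples onto the input endpoints, while align_corners=False uses half-pixel centres and clamps at the borders. `linear_resize_1d` is a hypothetical helper.

```python
def linear_resize_1d(values, out_size, align_corners=False):
    """Linearly resample a 1-D list to out_size samples (assumed convention)."""
    in_size = len(values)
    out = []
    for i in range(out_size):
        if align_corners:
            # First and last output samples coincide with the input endpoints.
            src = i * (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
        else:
            # Half-pixel centres: samples sit at the centres of equal-width cells.
            src = (i + 0.5) * in_size / out_size - 0.5
        src = min(max(src, 0.0), in_size - 1)  # clamp to the valid index range
        lo = int(src)
        hi = min(lo + 1, in_size - 1)
        frac = src - lo
        out.append(values[lo] * (1 - frac) + values[hi] * frac)
    return out


print(linear_resize_1d([1.0, 2.0], 4, align_corners=False))  # [1.0, 1.25, 1.75, 2.0]
print(linear_resize_1d([1.0, 2.0], 4, align_corners=True))   # evenly spaced from 1.0 to 2.0
```

Under this convention the two settings agree at the endpoints here only because of border clamping; the interior samples differ, which is why results from different libraries may not match unless the flag is set consistently.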
- class neurai.nn.loss.SpikeLoss(model=None, start_time=0, end_time=100, positive=60, negative=10, time_step=1.0, time_windows=100, psp_fn=None, **psp_fn_args)
Bases: object
This class defines different spike-based loss modules that can be used to optimize the SNN.
- Parameters:
model (Callable) – The neural network model.
start_time (int, Optional) – Start index of the target region, by default 0.
end_time (int, Optional) – Stop index of the target region, by default 100.
positive (int, Optional) – The desired spike count within the target region when the desired class is true, by default 60.
negative (int, Optional) – The desired spike count within the target region when the desired class is false, by default 10.
time_windows (int, Optional) – Time length of a sample, by default 100.
psp_fn (Callable) – Calculates the error based on the difference between the actual spike activity (spikeOut) and the desired spike activity (spikeDesired).
psp_fn_args – Parameters passed to psp_fn.
- numSpikes(predict, target, **kwargs)
Calculates spike loss based on number of spikes within a target region. For classification tasks, a decision is typically made based on the number of output spikes during an interval rather than the precise timing of the spikes. To handle such cases, the error signal during the interval can be defined as:
\[e^{(n_l)}(t) := \int_{T_{int}} S^{(n_l)}(\tau)\,d\tau - \int_{T_{int}} \hat{S}(\tau)\,d\tau, \quad t \in T_{int}\]
and zero outside the interval \(T_{int}\).
- Parameters:
predict (jnp.ndarray) – The output spikes.
target (jnp.ndarray) – One-hot encoded desired class. Time dimension should be 1 and the rest of the dimensions should be the same as predict.
kwargs (Any) – Any additional keyword arguments, such as numSpikesScale.
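The spike-count error can be illustrated for a single output neuron with a few lines of plain Python. This is a sketch of the formula above under stated assumptions, not the library code: `spike_count_error` is a hypothetical helper, time is discretised into unit steps so the integrals become sums, and any scaling such as numSpikesScale is ignored.

```python
def spike_count_error(spikes, is_target, start_time=0, end_time=100,
                      positive=60, negative=10):
    """Spike-count error for one output neuron (illustrative only).

    spikes: list of 0/1 values, one per time step.
    is_target: True if this neuron represents the desired class.
    """
    actual = sum(spikes[start_time:end_time])      # integral of S over T_int
    desired = positive if is_target else negative  # integral of S_hat over T_int
    return actual - desired


train = [1, 0] * 50  # 50 spikes in 100 time steps
print(spike_count_error(train, is_target=True))   # -10: fires too little for the true class
print(spike_count_error(train, is_target=False))  # 40: fires too much for a false class
```

The sign of the error tells the optimizer whether a neuron should spike more or less inside the target region.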
- spikeTime(ps, batch_data)
Calculates spike loss based on spike time. Consider a loss function for the network in the time interval t ∈ [0, T], defined as:
\[E := \int_0^T L(S^{(n_l)}(t), \hat{S}(t))\,dt = \frac{1}{2}\int_0^T \left(e^{(n_l)}(S^{(n_l)}(t), \hat{S}(t))\right)^2 dt\]
where \(\hat{S}(t)\) is the target spike train, \(L(S^{(n_l)}(t), \hat{S}(t))\) is the loss at time instance \(t\), and \(e^{(n_l)}(S^{(n_l)}(t), \hat{S}(t))\) is the error signal at the final layer. For brevity, the error signal is written as \(e^{(n_l)}(t)\) from here on.
To learn a target spike train \(\hat{S}(t)\), an error signal of the following form is used:
\[e^{(n_l)}(t) := \varepsilon(t) * (S^{(n_l)}(t) - \hat{S}(t))\]
where \(*\) denotes convolution in time and the kernel \(\varepsilon(t)\) is
\[\varepsilon(t) = \frac{t}{\tau_s} e^{1 - \frac{t}{\tau_s}} \Theta(t)\]
with \(\Theta(t)\) the Heaviside step function. The resulting loss is similar to the van Rossum distance between the output and desired spike trains.
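As a rough illustration of these formulas, the sketch below evaluates the kernel and a discretised version of the loss for two short spike trains in plain Python. The function names, the value tau_s = 5.0, and the unit time step are assumptions made for the example; the library's own discretisation and normalisation may differ.

```python
import math


def psp_kernel(t, tau_s):
    """epsilon(t) = (t / tau_s) * exp(1 - t / tau_s) for t >= 0, else 0."""
    if t < 0:
        return 0.0
    return (t / tau_s) * math.exp(1.0 - t / tau_s)


def spike_time_loss(output, target, tau_s=5.0, dt=1.0):
    """0.5 * sum_t e(t)^2 * dt, with e = epsilon * (output - target)."""
    diff = [o - s for o, s in zip(output, target)]
    loss = 0.0
    for t in range(len(diff)):
        # Discrete causal convolution of the kernel with the spike difference.
        e_t = sum(psp_kernel((t - k) * dt, tau_s) * diff[k] for k in range(t + 1))
        loss += 0.5 * e_t ** 2 * dt
    return loss


print(psp_kernel(5.0, 5.0))  # 1.0, the kernel peaks at t = tau_s

target = [0, 0, 1, 0, 0, 0, 0, 0]
shifted = [0, 0, 0, 1, 0, 0, 0, 0]
print(spike_time_loss(target, target))   # 0.0 for identical trains
print(spike_time_loss(shifted, target))  # positive: a timing error is penalised
```

Because the difference of the spike trains is smoothed by the kernel before squaring, a spike that is slightly early or late contributes a small error rather than a full miss.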
- neurai.nn.loss.binary_cross_entropy(pred, label, weight=None, reduction='mean')
Calculates the binary cross-entropy loss between predicted and true labels.
- Parameters:
pred (jnp.ndarray) – The predicted values, between 0 and 1, as a numpy array.
label (jnp.ndarray) – The true labels, between 0 and 1, as a numpy array.
weight (jnp.ndarray, Optional) – Manual rescaling weight; it must match the shape of the input.
reduction (str, Optional) – The reduction mode for the loss value. Default is ‘mean’. Options are ‘mean’, ‘sum’, or ‘none’.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> import jax.numpy as jnp
>>> from neurai.nn.loss import binary_cross_entropy
>>> pred = jnp.array([0.5, 0.6, 0.7, 0.8, 0.9])
>>> labels = jnp.array([0, 1, 0, 1, 0])
>>> binary_cross_entropy(pred, labels, reduction='mean')
array(0.9867)
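The documented value can be checked by hand with the standard binary cross-entropy formula. The sketch below is a plain-Python reference computation, not the library code; `binary_cross_entropy_sketch` is a hypothetical name, and the optional weight and any clipping of the predictions are left out.

```python
import math


def binary_cross_entropy_sketch(pred, label, reduction="mean"):
    """-[y*log(p) + (1-y)*log(1-p)] per element, then reduced."""
    losses = [-(y * math.log(p) + (1 - y) * math.log(1 - p))
              for p, y in zip(pred, label)]
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    return losses  # reduction == "none"


value = binary_cross_entropy_sketch([0.5, 0.6, 0.7, 0.8, 0.9], [0, 1, 0, 1, 0])
print(round(value, 4))  # 0.9867
```

Note that the inputs are probabilities: a prediction of exactly 0 or 1 would make the logarithm diverge in this unclipped sketch.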
- neurai.nn.loss.hinge_loss(pred, label)
Calculates the hinge loss between predicted and true labels.
- Parameters:
pred (jnp.ndarray) – The predicted labels as a numpy array.
label (jnp.ndarray) – The true labels as a numpy array.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> import numpy as np
>>> from neurai.nn.loss import hinge_loss
>>> pred = np.array([0.8, -0.4, 1.2])
>>> label = np.array([1, -1, 1])
>>> hinge_loss(pred, label)
array(0.26666667)
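The example value is consistent with the usual hinge loss, the mean of max(0, 1 - y·p) with labels in {-1, +1}. The following plain-Python sketch assumes that definition (the page does not state the formula) and uses the hypothetical name `hinge_loss_sketch`.

```python
def hinge_loss_sketch(pred, label):
    """Mean of max(0, 1 - y * p), assuming labels y in {-1, +1}."""
    return sum(max(0.0, 1.0 - y * p) for p, y in zip(pred, label)) / len(pred)


# Margins: 0.2 (inside the margin), 0.6 (inside the margin), 0 (beyond the margin).
print(round(hinge_loss_sketch([0.8, -0.4, 1.2], [1, -1, 1]), 4))  # 0.2667
```

Predictions with the correct sign still incur a loss until they clear the margin of 1, which is why the first two entries contribute.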
- neurai.nn.loss.huber_loss(logits, labels, reduction='mean', delta=1.0)
Calculates the huber loss between predicted and true labels.
- Parameters:
logits (jnp.ndarray) – The predicted labels as a numpy array.
labels (jnp.ndarray) – The true labels as a numpy array.
reduction (str, Optional) – The reduction mode for the loss value. Default is ‘mean’. Options are ‘mean’, ‘sum’, or ‘none’.
delta (float, Optional) – The threshold parameter for huber loss. Default is 1.0.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> import numpy as np
>>> from neurai.nn.loss import huber_loss
>>> pred = np.array([0.8, -0.4, 1.2])
>>> label = np.array([1, -1, 1])
>>> huber_loss(pred, label)
Array(0.07333334, dtype=float32)
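The role of delta is clearer with the piecewise formula written out. This plain-Python sketch assumes the standard Huber definition (quadratic for errors up to delta, linear beyond it) with mean reduction; `huber_loss_sketch` is a hypothetical name, and the library may differ in details.

```python
def huber_loss_sketch(pred, label, delta=1.0):
    """Mean Huber loss under the standard piecewise definition."""
    def element(err):
        a = abs(err)
        # Quadratic for small errors, linear beyond delta.
        return 0.5 * err * err if a <= delta else delta * (a - 0.5 * delta)
    return sum(element(p - y) for p, y in zip(pred, label)) / len(pred)


print(round(huber_loss_sketch([0.8, -0.4, 1.2], [1, -1, 1]), 4))  # 0.0733
print(huber_loss_sketch([3.0], [0.0]))  # 2.5, the linear branch limits the outlier
```

All errors in the documented example are below delta, so the result equals half the mean squared error; a squared-error loss would give 4.5 for the outlier case instead of 2.5.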
- neurai.nn.loss.l1_loss(logits, labels, reduction='mean')
Calculates the L1 loss between predicted logits and true labels.
- Parameters:
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> import numpy as np
>>> from neurai.nn.loss import l1_loss
>>> logits = np.array([0.8, -0.4, 1.2])
>>> labels = np.array([1, -1, 1])
>>> l1_loss(logits, labels)
array(0.33333334)
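Since the parameter list is empty above, the following plain-Python sketch shows how the three reduction modes could behave for an L1 loss. It assumes the same ‘mean’, ‘sum’, and ‘none’ options documented for the other losses on this page; `l1_loss_sketch` is a hypothetical name.

```python
def l1_loss_sketch(logits, labels, reduction="mean"):
    """Absolute error per element, reduced by 'mean', 'sum', or 'none'."""
    errors = [abs(p - y) for p, y in zip(logits, labels)]
    if reduction == "none":
        return errors
    total = sum(errors)
    return total / len(errors) if reduction == "mean" else total


logits, labels = [0.8, -0.4, 1.2], [1, -1, 1]
print(round(l1_loss_sketch(logits, labels), 4))         # 0.3333
print(round(l1_loss_sketch(logits, labels, "sum"), 4))  # 1.0
```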
- neurai.nn.loss.mse_loss(pred, label)
Calculates the mean squared error (MSE) loss between predicted and true labels.
- Parameters:
pred (jnp.ndarray) – The predicted labels as a numpy array.
label (jnp.ndarray) – The true labels as a numpy array.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> import numpy as np
>>> from neurai.nn.loss import mse_loss
>>> pred = np.array([1, 2, 3])
>>> label = np.array([2, 4, 6])
>>> mse_loss(pred, label)
array(4.66666667)
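For reference, the documented value follows directly from the definition: the squared errors are 1, 4, and 9, and their mean is 14/3. A plain-Python check (with the hypothetical name `mse_loss_sketch`) looks like this:

```python
def mse_loss_sketch(pred, label):
    """Mean of squared differences between predictions and labels."""
    return sum((p - y) ** 2 for p, y in zip(pred, label)) / len(pred)


print(round(mse_loss_sketch([1, 2, 3], [2, 4, 6]), 4))  # 4.6667
```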
- neurai.nn.loss.sigmoid_binary_cross_entropy(pred, label)
Calculates the sigmoid binary cross-entropy loss between predicted and true labels.
- Parameters:
pred (jnp.ndarray) – The predicted labels as a numpy array.
label (jnp.ndarray) – The true labels as a numpy array.
- Returns:
jnp.ndarray – The calculated loss value.
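This entry has no example, so here is a hedged sketch of what a sigmoid binary cross-entropy usually computes. It assumes that pred holds raw logits (the sigmoid is applied inside the loss) and that the result is averaged; neither is stated on this page. The numerically stable form max(x, 0) - x·z + log(1 + exp(-|x|)) is a standard identity, and `sigmoid_bce_sketch` is a hypothetical name.

```python
import math


def sigmoid_bce_sketch(logits, labels):
    """Mean of max(x, 0) - x*z + log(1 + exp(-|x|)); assumes logits as input."""
    losses = [max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
              for x, z in zip(logits, labels)]
    return sum(losses) / len(losses)


def naive_bce(logits, labels):
    """Same quantity via an explicit sigmoid; overflows for large |x|."""
    total = 0.0
    for x, z in zip(logits, labels):
        p = 1.0 / (1.0 + math.exp(-x))
        total += -(z * math.log(p) + (1 - z) * math.log(1 - p))
    return total / len(logits)


x, z = [2.0, -1.0, 0.0], [0, 1, 1]
print(abs(sigmoid_bce_sketch(x, z) - naive_bce(x, z)) < 1e-12)  # True
print(sigmoid_bce_sketch([1000.0], [1]))  # 0.0, no overflow in the stable form
```

Working from logits avoids taking the logarithm of a probability that has rounded to 0 or 1, which is the usual reason to prefer this variant over applying a sigmoid and then binary_cross_entropy.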
- neurai.nn.loss.smooth_l1_loss(logits, labels, reduction='mean', beta=1.0)
Calculates the smooth L1 loss between predicted logits and true labels.
- Parameters:
logits (jnp.ndarray) – The predicted logits as a JAX NumPy array.
labels (jnp.ndarray) – The true labels as a JAX NumPy array.
reduction (str, Optional) – The reduction mode for the loss value. Default is ‘mean’. Options are ‘mean’, ‘sum’, or ‘none’.
beta (float, Optional) – The threshold parameter for smooth L1 loss. Default is 1.0.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
>>> import jax.numpy as jnp
>>> from neurai.nn.loss import smooth_l1_loss
>>> logits = jnp.array([0.8, -0.4, 1.2])
>>> labels = jnp.array([1, -1, 1])
>>> smooth_l1_loss(logits, labels)
array(0.07333334)
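To show what beta changes, the sketch below assumes the commonly used smooth L1 definition: 0.5·e²/beta when |e| < beta, and |e| - 0.5·beta otherwise, averaged over elements. The page does not spell out the formula, so treat this as an assumption; `smooth_l1_sketch` is a hypothetical name.

```python
def smooth_l1_sketch(logits, labels, beta=1.0):
    """Mean smooth L1 loss under the assumed piecewise definition."""
    def element(err):
        a = abs(err)
        # Quadratic (scaled by 1/beta) near zero, absolute error minus an offset beyond beta.
        return 0.5 * err * err / beta if a < beta else a - 0.5 * beta
    return sum(element(p - y) for p, y in zip(logits, labels)) / len(logits)


logits, labels = [0.8, -0.4, 1.2], [1, -1, 1]
print(round(smooth_l1_sketch(logits, labels), 4))            # 0.0733
print(round(smooth_l1_sketch(logits, labels, beta=0.5), 4))  # 0.1433
```

With beta=1.0 this matches the Huber loss with delta=1.0; with a smaller beta the error of 0.6 moves onto the linear branch, so the loss behaves more like L1.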
- neurai.nn.loss.softmax_cross_entropy(pred, label)
Calculates the softmax cross-entropy loss between predicted and true labels.
- Parameters:
pred (jnp.ndarray) – The predicted labels as a numpy array.
label (jnp.ndarray) – The true labels as a numpy array.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
from jax import numpy as jnp
from neurai.nn.loss import softmax_cross_entropy
pred = jnp.array([[0.5, 0.3, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]])
label = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
softmax_cross_entropy(pred, label)
Expected output: Array(1.0599941, dtype=float32)
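The expected value can be reproduced with the textbook formula: apply a log-softmax to each row of pred, take the negative log-probability of the one-hot class, and average over rows. The plain-Python sketch below assumes exactly that (mean reduction, pred treated as unnormalised scores) and uses the hypothetical name `softmax_cross_entropy_sketch`.

```python
import math


def softmax_cross_entropy_sketch(pred, label):
    """Mean over rows of -sum_k label[k] * log_softmax(pred)[k]."""
    total = 0.0
    for scores, onehot in zip(pred, label):
        m = max(scores)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in scores))
        total += -sum(y * (x - log_z) for x, y in zip(scores, onehot))
    return total / len(pred)


pred = [[0.5, 0.3, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]]
label = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
print(round(softmax_cross_entropy_sketch(pred, label), 5))  # 1.05999
```

That the documented number matches this computation suggests pred is interpreted as raw scores rather than probabilities, although the rows in the example happen to sum to 1.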
- neurai.nn.loss.sparse_softmax_cross_entropy(pred, label)
Calculates the softmax cross-entropy loss between predicted and true labels, where the labels are not one-hot encoded.
- Parameters:
pred (jnp.ndarray) – The predicted labels as a numpy array.
label (jnp.ndarray) – The true labels as a numpy array.
- Returns:
jnp.ndarray – The calculated loss value.
Examples
from jax import numpy as jnp
from neurai.nn.loss import sparse_softmax_cross_entropy
pred = jnp.array([[0.5, 0.3, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]])
label = jnp.array([0, 1, 2])  # integer class indices instead of one-hot rows
sparse_softmax_cross_entropy(pred, label)
Expected output: Array(1.0599941, dtype=float32)
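A loss "without one-hot labels" normally takes one integer class index per row. Under that assumption, the plain-Python sketch below picks the log-probability of the indexed class directly and arrives at the same value as the one-hot computation; `sparse_softmax_cross_entropy_sketch` is a hypothetical name and not the library function.

```python
import math


def sparse_softmax_cross_entropy_sketch(pred, label):
    """Mean over rows of -log_softmax(pred)[label], with integer class labels."""
    total = 0.0
    for scores, k in zip(pred, label):
        m = max(scores)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in scores))
        total += log_z - scores[k]  # negative log-probability of class k
    return total / len(pred)


pred = [[0.5, 0.3, 0.2], [0.8, 0.1, 0.1], [0.2, 0.2, 0.6]]
print(round(sparse_softmax_cross_entropy_sketch(pred, [0, 1, 2]), 5))  # 1.05999
```

Integer labels avoid materialising a one-hot matrix, which matters when the number of classes is large.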