Core API¶
Main classes for configuration, model construction, simulation, and signal handling.
Configuration¶
jaxfne.Configuration()
Declarative TFNE model configuration. Methods return new objects, enabling immutable, chainable construction.
Methods¶
runtime(**kwargs) -> Configuration¶
Set runtime/simulation metadata in chainable form.
Parameters:
- seed (int): Random seed for reproducible PRNG
- dtype (str): JAX dtype, e.g., "float32" or "float64"
- duration_ms (float): Total simulation duration in milliseconds
- dt_ms (float): Timestep in milliseconds
Returns: Updated Configuration
Example:
cfg = jtfne.Configuration()
cfg = cfg.runtime(seed=7, dtype="float32", duration_ms=1000.0, dt_ms=0.1)
column(name: str, layers: list, n: int) -> Configuration¶
Declare one cortical column with specified layers and neuron count.
Parameters:
- name (str): Unique column name
- layers (list[str]): Layer labels (e.g., ["L2/3", "L4", "L5"])
- n (int): Number of neurons in the column
Returns: Updated Configuration
Example:
cfg = cfg.column("V1", layers=["L2/3", "L4", "L5"], n=100)
cell_types(cell_type_map: dict) -> Configuration¶
Define the proportion or count of excitatory/inhibitory cell types.
Parameters:
- cell_type_map (dict): E/I distribution, e.g., {"E": 0.8, "I": 0.2} or {"E": 80, "I": 20}
Returns: Updated Configuration
Example:
cfg = cfg.cell_types({"E": 0.8, "I": 0.2})
connectivity(**kwargs) -> Configuration¶
Declare synaptic connectivity parameters (recurrence, plasticity, etc.).
Returns: Updated Configuration
Example:
cfg = cfg.connectivity()
set_emitter(family: str, preset: str) -> Configuration¶
Set the neuron model family and parameter preset.
Parameters:
- family (str): Emitter family (e.g., "izhikevich")
- preset (str): Parameter preset (e.g., "cortical_eig", "tonic_spiking")
Returns: Updated Configuration
Example:
cfg = cfg.set_emitter("izhikevich", "cortical_eig")
probes(modes: list, **kwargs) -> Configuration¶
Declare multimodal proxy probe modes (SPK, Vm, LFP-proxy, CSD-proxy, etc.).
Parameters:
- modes (list[str]): Probe modes, e.g., ["MUA-proxy", "LFP-proxy", "CSD-proxy"]
- name (str, optional): Probe name (default: "multimodal_probe")
Returns: Updated Configuration
Example:
cfg = cfg.probes(["MUA-proxy", "source-proxy", "LFP-proxy"])
Model¶
jaxfne.Model
Constructed TFNE workflow ready for simulation. Created via jaxfne.construct(cfg).
Attributes¶
cfg(Configuration): Source configurationgeometry(LaminarSourceGeometry): Spatial geometrybasis_spec(BasisSpec): Basis function specification
Methods¶
simulate(signals_or_params, ...) -> Signals¶
Run the neural simulation.
Parameters:
- duration_ms (float): Simulation duration
- dt_ms (float): Timestep
- seed (int, optional): Random seed
Returns: Signals object containing spikes, voltage, and field readouts
Example:
model = jtfne.construct(cfg)
signals = jtfne.simulate(model, duration_ms=1000.0, dt_ms=0.1, seed=7)
compute_readout(signals, readout_specs) -> ReadoutResult¶
Compute metrics from signals according to specifications.
Parameters:
- signals (Signals): Output from simulation
- readout_specs (list[ReadoutSpec]): Metric specifications
Returns: ReadoutResult with computed metrics
Example:
readouts = model.compute_readout(signals, [
jtfne.readout_spec("rate", "spike_rate_hz"),
jtfne.readout_spec("voltage", "mean_V_m")
])
Simulation¶
jaxfne.Simulation
Simulation run parameters.
Attributes¶
duration_ms(float): Total duration in millisecondsdt_ms(float): Timestep in millisecondsseed(int, optional): Random seed
Example:
sim = jtfne.simulation(duration_ms=1000.0, dt_ms=0.1, seed=7)
Signals¶
jaxfne.Signals
Neural signals output from simulation.
Attributes¶
V_m(jax.Array): Membrane voltage [time, neurons]spikes(jax.Array): Spike raster [time, neurons]I_syn(jax.Array, optional): Synaptic currentsource(jax.Array, optional): Source density [time, locations]LFP(jax.Array, optional): Local field potential [time, locations]CSD(jax.Array, optional): Current source density [time, locations]EEG(jax.Array, optional): EEG proxy [time, channels]MEG(jax.Array, optional): MEG proxy [time, channels]EMM(jax.Array, optional): Metabolic proxy [time]
Methods¶
to_dict() -> dict¶
Convert signals to JSON-safe dictionary.
save_json(path: str)¶
Save signals to JSON file.
Example:
signals.save_json("output.json")
ReadoutSpec¶
jaxfne.readout_spec(name: str, metric: str, **kwargs)
Specification for a single readout metric.
Parameters:
- name (str): Human-readable name
- metric (str): Metric key (see Probe Operators)
Returns: ReadoutSpec object
Available metrics:
- spike_rate_hz: Mean firing rate in Hz
- mean_V_m: Mean membrane voltage (mV)
- mean_source: Mean source density
- mean_LFP: Mean LFP-proxy magnitude (unscaled)
- mean_CSD: Mean CSD-proxy magnitude (unscaled)
- max_spike_rate_hz: Maximum firing rate
- burst_frequency_hz: Burst frequency estimate
Example:
readout = jtfne.readout_spec("firing_rate", "spike_rate_hz")
ReadoutResult¶
jaxfne.ReadoutResult
Container for computed readout metrics.
Attributes¶
specs(list[ReadoutSpec]): Original specificationsresults(dict[str, float]): Computed metric values
Methods¶
to_dict() -> dict¶
Convert results to JSON-safe dictionary.
Example:
results_dict = readout.to_dict()
Objective¶
jaxfne.objective(name: str, metric: str, target: float, **kwargs)
Optimization objective specification.
Parameters:
- name (str): Objective name
- metric (str): Target metric (from ReadoutSpec)
- target (float): Target value
- weight (float, optional): Loss weight in multi-objective optimization
Returns: Objective object
Example:
obj = jtfne.objective(name="spike_rate", metric="spike_rate_hz", target=10.0)
Helper Functions¶
construct(cfg: Configuration) -> Model¶
Build a compiled Model from Configuration.
Parameters:
- cfg (Configuration): Declarative configuration
Returns: Compiled Model ready for simulation
Example:
model = jtfne.construct(cfg)
simulate(model: Model, duration_ms: float, dt_ms: float, seed: int) -> Signals¶
Run simulation on a Model.
Parameters:
- model (Model): Constructed model
- duration_ms (float): Duration in milliseconds
- dt_ms (float): Timestep in milliseconds
- seed (int): Random seed
Returns: Signals object
Example:
signals = jtfne.simulate(model, duration_ms=1000.0, dt_ms=0.1, seed=7)
simulation(**kwargs) -> Simulation¶
Create a Simulation object.
Example:
sim = jtfne.simulation(duration_ms=1000.0, dt_ms=0.1, seed=7)
configuration() -> Configuration¶
Create a new Configuration object.
Example:
cfg = jtfne.configuration()