"""
.. currentmodule:: neet.boolean.wtnetwork
.. testsetup:: wtnetwork
from neet.boolean.examples import s_pombe
from neet.boolean.wtnetwork import WTNetwork
Weight/Threshold Networks
=========================
"""
import numpy as np
import networkx as nx
import re
from neet.statespace import StateSpace
class WTNetwork(object):
    """
    The WTNetwork class represents weight/threshold-based boolean networks. As
    such it is specified in terms of a matrix of edge weights (rows are target
    nodes) and a vector of node thresholds, and each node of the network is
    expected to be in either of two states ``0`` or ``1``.
    """

    def __init__(self, weights, thresholds=None, names=None, theta=None):
        """
        Construct a network from weights and thresholds.

        .. rubric:: Examples

        .. doctest:: wtnetwork

            >>> net = WTNetwork([[1,0],[1,1]])
            >>> net.size
            2
            >>> net.weights
            array([[1., 0.],
                   [1., 1.]])
            >>> net.thresholds
            array([0., 0.])

        .. doctest:: wtnetwork

            >>> net = WTNetwork([[1,0],[1,1]], [0.5,-0.5])
            >>> net.size
            2
            >>> net.weights
            array([[1., 0.],
                   [1., 1.]])
            >>> net.thresholds
            array([ 0.5, -0.5])

        .. doctest:: wtnetwork

            >>> net = WTNetwork(3)
            >>> net.size
            3
            >>> net.weights
            array([[0., 0., 0.],
                   [0., 0., 0.],
                   [0., 0., 0.]])
            >>> net.thresholds
            array([0., 0., 0.])

        :param weights: the network weights, where: source/column ->
                        target/row; an ``int`` ``n`` yields an ``n x n``
                        matrix of zeros
        :param thresholds: the network thresholds (defaults to all zeros)
        :param names: the names of the network nodes (optional); a string is
                      split into one single-character name per node
        :param theta: the threshold function to use (defaults to
                      :meth:`split_threshold`)
        :raises ValueError: if ``weights`` is empty
        :raises ValueError: if ``weights`` is not a square matrix
        :raises ValueError: if ``thresholds`` is not a vector
        :raises ValueError: if ``weights`` and ``thresholds`` have different
                            dimensions
        :raises ValueError: if ``len(names)`` is not equal to the number of
                            nodes
        :raises TypeError: if ``theta`` is not callable
        """
        # The builtin ``float`` is used as the dtype: the ``np.float`` alias
        # was removed from NumPy and is the same 64-bit float type.
        if isinstance(weights, int):
            self.weights = np.zeros([weights, weights])
        else:
            self.weights = np.asarray(weights, dtype=float)

        shape = self.weights.shape
        if self.weights.ndim != 2:
            raise ValueError("weights must be a matrix")
        elif shape[0] != shape[1]:
            raise ValueError("weights must be square")

        if thresholds is None:
            self.thresholds = np.zeros(shape[1], dtype=float)
        else:
            self.thresholds = np.asarray(thresholds, dtype=float)

        # The network size is defined by the number of thresholds; it is
        # checked against the weight matrix below.
        self.__size = self.thresholds.size

        if isinstance(names, str):
            self.names = list(names)
        else:
            self.names = names

        if theta is None:
            self.theta = type(self).split_threshold
        elif callable(theta):
            self.theta = theta
        else:
            raise TypeError("theta must be a function")

        if self.thresholds.ndim != 1:
            raise ValueError("thresholds must be a vector")
        elif shape[0] != self.size:
            raise ValueError(
                "weights and thresholds have different dimensions")
        elif self.size < 1:
            raise ValueError("invalid network size")
        elif names is not None and len(names) != self.size:
            raise ValueError(
                "either all or none of the nodes may have a name")

        self.metadata = {}

    @property
    def size(self):
        """
        The number of nodes in the network.

        .. doctest:: wtnetwork

            >>> net = WTNetwork(5)
            >>> net.size
            5
            >>> WTNetwork(0)
            Traceback (most recent call last):
                ...
            ValueError: invalid network size

        :type: int
        """
        return self.__size

    def state_space(self):
        """
        Return a :class:`neet.statespace.StateSpace` object for the network.

        .. doctest:: wtnetwork

            >>> net = WTNetwork(3)
            >>> net.state_space()
            <neet.statespace.StateSpace object at 0x...>
            >>> space = net.state_space()
            >>> list(space)
            [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]]

        :returns: the network's :class:`neet.statespace.StateSpace`
        """
        return StateSpace(self.size, base=2)

    def _unsafe_update(self, states, index=None, pin=None, values=None):
        """
        Update ``states``, in place, according to the network update rules
        without checking the validity of the arguments.

        .. rubric:: Basic Use

        .. doctest:: wtnetwork

            >>> s_pombe.size
            9
            >>> xs = [0,0,0,0,1,0,0,0,0]
            >>> s_pombe._unsafe_update(xs)
            [0, 0, 0, 0, 0, 0, 0, 0, 1]
            >>> s_pombe._unsafe_update(xs)
            [0, 1, 1, 1, 0, 0, 1, 0, 0]

        .. rubric:: Single-Node Update

        .. doctest:: wtnetwork

            >>> xs = [0,0,0,0,1,0,0,0,0]
            >>> s_pombe._unsafe_update(xs, index=-1)
            [0, 0, 0, 0, 1, 0, 0, 0, 1]
            >>> s_pombe._unsafe_update(xs, index=2)
            [0, 0, 1, 0, 1, 0, 0, 0, 1]
            >>> s_pombe._unsafe_update(xs, index=3)
            [0, 0, 1, 1, 1, 0, 0, 0, 1]

        .. rubric:: State Pinning

        .. doctest:: wtnetwork

            >>> s_pombe._unsafe_update([0,0,0,0,1,0,0,0,0], pin=[-1])
            [0, 0, 0, 0, 0, 0, 0, 0, 0]
            >>> s_pombe._unsafe_update([0,0,0,0,0,0,0,0,1], pin=[1])
            [0, 0, 1, 1, 0, 0, 1, 0, 0]
            >>> s_pombe._unsafe_update([0,0,0,0,0,0,0,0,1], pin=range(1,4))
            [0, 0, 0, 0, 0, 0, 1, 0, 0]
            >>> s_pombe._unsafe_update([0,0,0,0,0,0,0,0,1], pin=[1,2,3,-1])
            [0, 0, 0, 0, 0, 0, 1, 0, 1]

        .. rubric:: Value Fixing

        .. doctest:: wtnetwork

            >>> s_pombe._unsafe_update([0,0,0,0,1,0,0,0,0], values={0:1, 2:1})
            [1, 0, 1, 0, 0, 0, 0, 0, 1]
            >>> s_pombe._unsafe_update([0,0,0,0,0,0,0,0,1], values={0:1, 1:0, 2:0})
            [1, 0, 0, 1, 0, 0, 1, 0, 0]
            >>> s_pombe._unsafe_update([0,0,0,0,0,0,0,0,1], values={-1:1, -2:1})
            [0, 1, 1, 1, 0, 0, 1, 1, 1]

        .. rubric:: Erroneous Usage

        Because nothing is validated, malformed arguments either surface as
        NumPy errors or are silently accepted (note the state ``2`` below).

        .. doctest:: wtnetwork

            >>> s_pombe._unsafe_update([0,0,0])
            Traceback (most recent call last):
                ...
            ValueError: shapes (9,9) and (3,) not aligned: 9 (dim 1) != 3 (dim 0)
            >>> s_pombe._unsafe_update([0,0,0,0,2,0,0,0,0])
            [0, 0, 0, 0, 0, 0, 0, 0, 1]
            >>> s_pombe._unsafe_update([0,0,0,0,1,0,0,0,0], 9)
            Traceback (most recent call last):
                ...
            IndexError: index 9 is out of bounds for axis 0 with size 9
            >>> s_pombe._unsafe_update([0,0,0,0,0,0,0,0,1], pin=[10])
            Traceback (most recent call last):
                ...
            IndexError: index 10 is out of bounds for axis ... with size 9

        :param states: the one-dimensional sequence of node states
        :param index: the index to update or None
        :param pin: the indices to pin (fix to their current state) or None
        :param values: a dictionary of index-value pairs to fix after update
        :returns: the updated states
        """
        # Normalise ``pin`` to a list so that emptiness is well defined for
        # any iterable (comparing a NumPy array against ``[]`` is not) and so
        # that tuples are treated as a collection of indices by NumPy.
        pin = [] if pin is None else list(pin)

        if index is None:
            if pin:
                # Remember the pinned states before they are overwritten.
                pinned = np.asarray(states)[pin]

            temp = np.dot(self.weights, states) - self.thresholds
            # ``theta`` is expected to write its result into ``states``.
            self.theta(temp, states)

            for j, i in enumerate(pin):
                states[i] = pinned[j]
        else:
            temp = np.dot(self.weights[index], states) - self.thresholds[index]
            states[index] = self.theta(temp, states[index])

        if values is not None:
            for key in values:
                states[key] = values[key]

        return states

    def update(self, states, index=None, pin=None, values=None):
        """
        Update ``states``, in place, according to the network update rules.

        .. rubric:: Basic Use

        .. doctest:: wtnetwork

            >>> s_pombe.size
            9
            >>> xs = [0,0,0,0,1,0,0,0,0]
            >>> s_pombe.update(xs)
            [0, 0, 0, 0, 0, 0, 0, 0, 1]
            >>> s_pombe.update(xs)
            [0, 1, 1, 1, 0, 0, 1, 0, 0]

        .. rubric:: Single-Node Update

        .. doctest:: wtnetwork

            >>> xs = [0,0,0,0,1,0,0,0,0]
            >>> s_pombe.update(xs, index=-1)
            [0, 0, 0, 0, 1, 0, 0, 0, 1]
            >>> s_pombe.update(xs, index=2)
            [0, 0, 1, 0, 1, 0, 0, 0, 1]
            >>> s_pombe.update(xs, index=3)
            [0, 0, 1, 1, 1, 0, 0, 0, 1]

        .. rubric:: State Pinning

        .. doctest:: wtnetwork

            >>> s_pombe.update([0,0,0,0,1,0,0,0,0], pin=[-1])
            [0, 0, 0, 0, 0, 0, 0, 0, 0]
            >>> s_pombe.update([0,0,0,0,0,0,0,0,1], pin=[1])
            [0, 0, 1, 1, 0, 0, 1, 0, 0]
            >>> s_pombe.update([0,0,0,0,0,0,0,0,1], pin=range(1,4))
            [0, 0, 0, 0, 0, 0, 1, 0, 0]
            >>> s_pombe.update([0,0,0,0,0,0,0,0,1], pin=[1,2,3,-1])
            [0, 0, 0, 0, 0, 0, 1, 0, 1]

        .. rubric:: Value Fixing

        .. doctest:: wtnetwork

            >>> s_pombe.update([0,0,0,0,1,0,0,0,0], values={0:1, 2:1})
            [1, 0, 1, 0, 0, 0, 0, 0, 1]
            >>> s_pombe.update([0,0,0,0,0,0,0,0,1], values={0:1, 1:0, 2:0})
            [1, 0, 0, 1, 0, 0, 1, 0, 0]
            >>> s_pombe.update([0,0,0,0,0,0,0,0,1], values={-1:1, -2:1})
            [0, 1, 1, 1, 0, 0, 1, 1, 1]

        .. rubric:: Erroneous Usage

        .. doctest:: wtnetwork

            >>> s_pombe.update([0,0,0])
            Traceback (most recent call last):
                ...
            ValueError: the provided state is not in the network's state space
            >>> s_pombe.update([0,0,0,0,2,0,0,0,0])
            Traceback (most recent call last):
                ...
            ValueError: the provided state is not in the network's state space
            >>> s_pombe.update([0,0,0,0,1,0,0,0,0], 9)
            Traceback (most recent call last):
                ...
            IndexError: index 9 is out of bounds for axis 0 with size 9
            >>> s_pombe.update([0,0,0,0,1,0,0,0,0], index=-1, pin=[-1])
            Traceback (most recent call last):
                ...
            ValueError: cannot provide both the index and pin arguments
            >>> s_pombe.update([0,0,0,0,1,0,0,0,0], pin=[10])
            Traceback (most recent call last):
                ...
            IndexError: index 10 is out of bounds for axis ... with size 9
            >>> s_pombe.update([0,0,0,0,1,0,0,0,0], index=1, values={1:0,3:0,2:1})
            Traceback (most recent call last):
                ...
            ValueError: cannot provide both the index and values arguments
            >>> s_pombe.update([0,0,0,0,1,0,0,0,0], pin=[1], values={1:0,3:0,2:1})
            Traceback (most recent call last):
                ...
            ValueError: cannot set a value for a pinned state
            >>> s_pombe.update([0,0,0,0,1,0,0,0,0], values={1:2})
            Traceback (most recent call last):
                ...
            ValueError: invalid state in values argument

        :param states: the one-dimensional sequence of node states
        :param index: the index to update (or None)
        :param pin: the indices to pin (or None)
        :param values: a dictionary of index-value pairs to set after update
        :returns: the updated states
        :raises ValueError: if ``states`` is not in the network's state space
        :raises ValueError: if ``index`` and ``pin`` are both provided
        :raises ValueError: if ``index`` and ``values`` are both provided
        :raises ValueError: if an element of ``pin`` is a key in ``values``
        :raises ValueError: if a value in ``values`` is not binary (0 or 1)
        """
        if states not in self.state_space():
            raise ValueError(
                "the provided state is not in the network's state space")

        # Materialise ``pin`` once so that the emptiness and membership tests
        # below behave the same for lists, ranges, tuples and NumPy arrays.
        if pin is not None:
            pin = list(pin)
        has_values = values is not None and len(values) > 0

        if index is not None:
            if pin:
                raise ValueError(
                    "cannot provide both the index and pin arguments")
            elif has_values:
                raise ValueError(
                    "cannot provide both the index and values arguments")
        elif pin and has_values:
            if any(key in pin for key in values):
                raise ValueError("cannot set a value for a pinned state")

        if values is not None:
            for val in values.values():
                if val not in (0, 1):
                    raise ValueError("invalid state in values argument")

        return self._unsafe_update(states, index, pin, values)

    @staticmethod
    def read(nodes_path, edges_path):
        """
        Read a network from a pair of node/edge files.

        Each non-comment line of the nodes file holds a node name and its
        threshold; each non-comment line of the edges file holds a source
        name, a target name and a weight. Fields are whitespace separated.
        Lines whose first non-blank character is ``#`` and blank lines are
        ignored.

        .. doctest:: wtnetwork

            >>> nodes_path = '../neet/boolean/data/s_pombe-nodes.txt'
            >>> edges_path = '../neet/boolean/data/s_pombe-edges.txt'
            >>> net = WTNetwork.read(nodes_path, edges_path)
            >>> net.size
            9
            >>> net.names
            ['SK', 'Cdc2_Cdc13', 'Ste9', 'Rum1', 'Slp1', 'Cdc2_Cdc13_active', 'Wee1_Mik1', 'Cdc25', 'PP']

        :param nodes_path: path to the nodes file
        :type nodes_path: str
        :param edges_path: path to the edges file
        :type edges_path: str
        :returns: a :class:`WTNetwork`
        :raises KeyError: if an edge refers to a name absent from the nodes
                          file
        :raises ValueError: if a data line has the wrong number of fields or
                            a non-numeric threshold/weight
        """
        comment = re.compile(r'^\s*#.*$')

        def data_lines(path):
            # Yield the whitespace-split fields of every line that is neither
            # blank nor a comment.
            with open(path, "r") as f:
                for line in f:
                    if line.strip() and comment.match(line) is None:
                        yield line.split()

        names, thresholds = [], []
        nameindices = {}
        for name, threshold in data_lines(nodes_path):
            nameindices[name] = len(names)
            names.append(name)
            thresholds.append(float(threshold))

        n = len(names)
        weights = np.zeros((n, n), dtype=float)
        for source, target, weight in data_lines(edges_path):
            # Rows are targets, columns are sources.
            weights[nameindices[target], nameindices[source]] = float(weight)

        return WTNetwork(weights, thresholds, names)

    @staticmethod
    def split_threshold(values, states):
        """
        Applies the following functional form to the arguments:

        .. math::

            \\theta_s(x,y) = \\begin{cases}
                0 & x < 0 \\\\
                y & x = 0 \\\\
                1 & x > 0.
            \\end{cases}

        If ``values`` is a list or NumPy array, then apply the above
        function to each pair ``(x,y) in zip(values, states)`` and store
        the result in ``states``.

        If ``values`` and ``states`` are scalar values, then simply apply
        the above threshold function to the pair ``(values, states)`` and
        return the result.

        .. rubric:: Examples

        .. doctest:: wtnetwork

            >>> ys = [0,0,0]
            >>> WTNetwork.split_threshold([1, -1, 0], ys)
            [1, 0, 0]
            >>> ys
            [1, 0, 0]
            >>> ys = [1,1,1]
            >>> WTNetwork.split_threshold([1, -1, 0], ys)
            [1, 0, 1]
            >>> ys
            [1, 0, 1]
            >>> WTNetwork.split_threshold(0,0)
            0
            >>> WTNetwork.split_threshold(0,1)
            1
            >>> WTNetwork.split_threshold(1,0)
            1
            >>> WTNetwork.split_threshold(1,1)
            1

        :param values: the threshold-shifted values of each node
        :param states: the pre-updated states of the nodes
        :returns: the updated states
        """
        if isinstance(values, (list, np.ndarray)):
            for i, x in enumerate(values):
                if x < 0:
                    states[i] = 0
                elif x > 0:
                    states[i] = 1
                # x == 0: the node keeps its current state
            return states

        if values < 0:
            return 0
        elif values > 0:
            return 1
        return states

    @staticmethod
    def negative_threshold(values, states):
        """
        Applies the following functional form to the arguments:

        .. math::

            \\theta_n(x) = \\begin{cases}
                0 & x \\leq 0 \\\\
                1 & x > 0.
            \\end{cases}

        If ``values`` is a list or NumPy array, then apply the above
        function to each element and store the result in ``states``.

        If ``values`` is a scalar value, then simply apply the above
        threshold function to it and return the result; ``states`` is not
        consulted.

        .. rubric:: Examples

        .. doctest:: wtnetwork

            >>> ys = [0,0,0]
            >>> WTNetwork.negative_threshold([1, -1, 0], ys)
            [1, 0, 0]
            >>> ys
            [1, 0, 0]
            >>> ys = [1,1,1]
            >>> WTNetwork.negative_threshold([1, -1, 0], ys)
            [1, 0, 0]
            >>> ys
            [1, 0, 0]
            >>> WTNetwork.negative_threshold(0,0)
            0
            >>> WTNetwork.negative_threshold(0,1)
            0
            >>> WTNetwork.negative_threshold(1,0)
            1
            >>> WTNetwork.negative_threshold(1,1)
            1

        :param values: the threshold-shifted values of each node
        :param states: the pre-updated states of the nodes
        :returns: the updated states
        """
        if isinstance(values, (list, np.ndarray)):
            for i, x in enumerate(values):
                states[i] = 0 if x <= 0 else 1
            return states

        return 0 if values <= 0 else 1

    @staticmethod
    def positive_threshold(values, states):
        """
        Applies the following functional form to the arguments:

        .. math::

            \\theta_p(x) = \\begin{cases}
                0 & x < 0 \\\\
                1 & x \\geq 0.
            \\end{cases}

        If ``values`` is a list or NumPy array, then apply the above
        function to each element and store the result in ``states``.

        If ``values`` is a scalar value, then simply apply the above
        threshold function to it and return the result; ``states`` is not
        consulted.

        .. rubric:: Examples

        .. doctest:: wtnetwork

            >>> ys = [0,0,0]
            >>> WTNetwork.positive_threshold([1, -1, 0], ys)
            [1, 0, 1]
            >>> ys
            [1, 0, 1]
            >>> ys = [1,1,1]
            >>> WTNetwork.positive_threshold([1, -1, 0], ys)
            [1, 0, 1]
            >>> ys
            [1, 0, 1]
            >>> WTNetwork.positive_threshold(0,0)
            1
            >>> WTNetwork.positive_threshold(0,1)
            1
            >>> WTNetwork.positive_threshold(1,0)
            1
            >>> WTNetwork.positive_threshold(-1,0)
            0

        :param values: the threshold-shifted values of each node
        :param states: the pre-updated states of the nodes
        :returns: the updated states
        """
        if isinstance(values, (list, np.ndarray)):
            for i, x in enumerate(values):
                states[i] = 0 if x < 0 else 1
            return states

        return 0 if values < 0 else 1

    def _has_implicit_self_loops(self):
        """
        Return ``True`` unless ``theta`` is one of the two built-in threshold
        functions that never consult a node's previous state.
        """
        cls = type(self)
        return (self.theta is not cls.negative_threshold and
                self.theta is not cls.positive_threshold)

    def neighbors_in(self, index):
        """
        Return the set of all neighbor nodes, where edge(neighbor_node-->index)
        exists. An important consideration is that some threshold functions
        can introduce implicit dependence between nodes, e.g.
        :meth:`WTNetwork.split_threshold`.

        :param index: node index
        :returns: the set of all node indices which point toward the index node

        .. rubric:: Examples

        .. doctest:: wtnetwork

            >>> net = WTNetwork([[0,0,0],[1,0,1],[0,1,0]],
            ...                 theta=WTNetwork.split_threshold)
            >>> [net.neighbors_in(node) for node in range(net.size)]
            [{0}, {0, 1, 2}, {1, 2}]
            >>> net.theta = WTNetwork.negative_threshold
            >>> [net.neighbors_in(node) for node in range(net.size)]
            [set(), {0, 2}, {1}]
        """
        # ``tolist`` yields plain Python ints rather than NumPy scalars.
        sources = set(np.flatnonzero(self.weights[index]).tolist())
        if self._has_implicit_self_loops():
            # Assume every other theta has self loops. This will be deprecated
            # when we convert all WTNetworks to logicnetworks by default.
            sources.add(index)
        return sources

    def neighbors_out(self, index):
        """
        Return the set of all neighbor nodes, where
        edge(index-->neighbor_node) exists.

        :param index: node index
        :returns: the set of all node indices which the index node points to

        .. rubric:: Basic Use

        .. doctest:: wtnetwork

            >>> net = WTNetwork([[0,0,0],[1,0,1],[0,1,0]],
            ...                 theta=WTNetwork.split_threshold)
            >>> [net.neighbors_out(node) for node in range(net.size)]
            [{0, 1}, {1, 2}, {1, 2}]
            >>> net.theta = WTNetwork.negative_threshold
            >>> [net.neighbors_out(node) for node in range(net.size)]
            [{1}, {2}, {1}]
        """
        # ``tolist`` yields plain Python ints rather than NumPy scalars.
        targets = set(np.flatnonzero(self.weights[:, index]).tolist())
        if self._has_implicit_self_loops():
            # Assume every other theta has self loops. This will be deprecated
            # when we convert all WTNetworks to logicnetworks by default.
            targets.add(index)
        return targets

    def neighbors(self, index):
        """
        Return the set of neighbors of a specified node, i.e. the union of
        its incoming and outgoing neighbors.

        :param index: node index
        :returns: the set of indices of all nodes connected to ``index``

        .. doctest:: wtnetwork

            >>> net = WTNetwork([[0,0,0],[1,0,1],[0,1,0]],
            ...                 theta=WTNetwork.split_threshold)
            >>> [net.neighbors(node) for node in range(net.size)]
            [{0, 1}, {0, 1, 2}, {1, 2}]
            >>> net.theta = WTNetwork.negative_threshold
            >>> [net.neighbors(node) for node in range(net.size)]
            [{1}, {0, 2}, {1}]
        """
        return self.neighbors_in(index) | self.neighbors_out(index)

    def to_networkx_graph(self, labels='indices'):
        """
        Return a ``networkx`` graph from a :class:`WTNetwork`.

        Every node of the network appears in the graph, including nodes
        without any edge.

        :param labels: how nodes are labeled and thus identified in networkx
                       graph (``'names'`` or ``'indices'``)
        :returns: a ``networkx.DiGraph``
        :raises ValueError: if ``labels`` is ``'names'`` but the network has
                            no names
        :raises ValueError: if ``labels`` is neither ``'names'`` nor
                            ``'indices'``
        """
        if labels == 'names':
            if hasattr(self, 'names') and (self.names is not None):
                labels = self.names
            else:
                raise ValueError("network nodes do not have names")
        elif labels == 'indices':
            labels = range(self.size)
        else:
            raise ValueError("labels must be 'names' or 'indices'")

        graph = nx.DiGraph(name=self.metadata.get('name'))
        # Add the nodes explicitly so that isolated nodes are not lost.
        graph.add_nodes_from(labels)
        for i, label in enumerate(labels):
            for j in self.neighbors_out(i):
                graph.add_edge(label, labels[j])
        return graph

    def draw(self, labels='indices', filename=None):
        """
        Output a file with a simple network drawing.

        Requires ``networkx`` and ``pygraphviz``.

        Supported image formats are determined by ``graphviz``. In particular,
        pdf support requires ``cairo`` and ``pango`` to be installed prior to
        ``graphviz`` installation.

        :param labels: how nodes are labeled and thus identified in the
                       networkx graph (``'names'`` or ``'indices'``)
        :param filename: filename to write drawing to. Temporary filename will
                         be used if no filename provided.
        :returns: ``None``; the drawing is written to file
        """
        graph = self.to_networkx_graph(labels=labels)
        nx.nx_agraph.view_pygraphviz(graph, prog='circo', path=filename)