A lower-level look at GRNNs and some of their mechanisms
!pip install git+https://github.com/netbrainml/nbml.git
from nbml.pytorch import *
from IPython.display import clear_output
import torch
import torch.nn as nn
import torch.nn.functional as F
clear_output()
def linear(ni,nh): return nn.Linear(ni,nh, bias=True), nn.Linear(nh,nh, bias=True) # an input->hidden and a hidden->hidden projection
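For reference, the cell below implements the standard LSTM update (biases are folded into the linear layers), with the layer pairs ii/ih, fi/fh, gi/gh, oi/oh supplying the input->hidden and hidden->hidden projections for each gate:

$$i = \sigma(W_{ii}x + W_{ih}h) \qquad f = \sigma(W_{fi}x + W_{fh}h)$$
$$g = \tanh(W_{gi}x + W_{gh}h) \qquad o = \sigma(W_{oi}x + W_{oh}h)$$
$$c' = f \odot c + i \odot g \qquad h' = o \odot \tanh(c')$$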
class LSTMCell(nn.Module):
    def __init__(self, ni, nh, bs):
        super().__init__()
        # One input->hidden and one hidden->hidden projection per gate
        self.ii, self.ih = linear(ni,nh)
        self.fi, self.fh = linear(ni,nh)
        self.gi, self.gh = linear(ni,nh)
        self.oi, self.oh = linear(ni,nh)
        self.bs = bs ; self.nh = nh
        self.reset()
    def forward(self, x):
        i = torch.sigmoid(self.ii(x)+self.ih(self.h)) # input gate
        f = torch.sigmoid(self.fi(x)+self.fh(self.h)) # forget gate
        g = torch.tanh(self.gi(x)+self.gh(self.h))    # candidate cell state
        o = torch.sigmoid(self.oi(x)+self.oh(self.h)) # output gate (sigmoid, not tanh)
        self.c = f * self.c + i * g
        self.h = o * torch.tanh(self.c)
        return self.c, self.h
    def reset(self):
        self.h, self.c = torch.zeros(self.bs,self.nh), torch.zeros(self.bs,self.nh)
lstm = LSTMCell(1,16,1)
for _ in range(3):
    print(lstm(torch.ones(1,1))[1]) # index 1 is the hidden state h
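The GRU gets by with two gates and no separate cell state. The cell below implements

$$z = \sigma(W_{zi}x + W_{zh}h) \qquad r = \sigma(W_{ri}x + W_{rh}h)$$
$$\tilde{h} = \tanh(W_{hi}x + W_{hh}(r \odot h)) \qquad h' = (1-z) \odot h + z \odot \tilde{h}$$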
class GRUCell(nn.Module):
    def __init__(self, ni, nh, bs):
        super().__init__()
        self.zi, self.zh = linear(ni,nh)
        self.ri, self.rh = linear(ni,nh)
        self.hi, self.hh = linear(ni,nh)
        self.bs = bs ; self.nh = nh
        self.reset()
    def forward(self, x):
        z = torch.sigmoid(self.zi(x)+self.zh(self.h)) # update gate
        r = torch.sigmoid(self.ri(x)+self.rh(self.h)) # reset gate
        h_tilde = torch.tanh(self.hi(x)+self.hh(self.h*r)) # candidate hidden state
        self.h = (1-z) * self.h + z * h_tilde
        return self.h
    def reset(self):
        self.h = torch.zeros(self.bs,self.nh)
gru = GRUCell(1,16,1)
for _ in range(3):
    print(gru(torch.ones(1,1)))
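As a sanity check on shapes (not a weight-for-weight comparison, since the built-in cell parameterizes its gates differently), PyTorch's own nn.GRUCell can be driven the same way:

cell = nn.GRUCell(1, 16)   # input size 1, hidden size 16
h = torch.zeros(1, 16)     # batch of 1; hidden state carried explicitly
for _ in range(3):
    h = cell(torch.ones(1, 1), h)
print(h.shape)             # torch.Size([1, 16]), same as our GRUCell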
class MultiGRU(nn.Module):
    def __init__(self, ni, nh, bs, nl=1):
        super().__init__()
        # First cell maps input->hidden; each of the nl extra layers maps hidden->hidden
        self.gru = nn.Sequential(GRUCell(ni, nh, bs),
                                 *[GRUCell(nh, nh, bs) for _ in range(nl)])
    def forward(self, x):
        return self.gru(x)
    def reset(self):
        # Reset the hidden state of every cell in the stack
        for m in self.modules():
            if isinstance(m, GRUCell): m.reset()
ml_gru = MultiGRU(1,16,1)
for _ in range(3):
    print(ml_gru(torch.ones(1,1)))
During training, dropout randomly 'drops' (zeroes) nodes from both forward and back propagation. At inference/validation all nodes are used; in the classic formulation their outputs are discounted by the keep probability, while PyTorch's nn.Dropout instead scales the surviving activations by 1/(1-p) during training so that inference needs no rescaling.
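A quick illustration of that train/eval asymmetry (a minimal sketch, independent of the GRU code):

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)
drop.train()
print(drop(x)) # roughly half the entries zeroed; survivors scaled by 1/(1-p) = 2
drop.eval()
print(drop(x)) # identity at inference: all ones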
class MultiGRU_do(nn.Module):
    def __init__(self, ni, nh, bs, do=0.5, nl=2):
        super().__init__()
        # Dropout on the output of every layer after the first
        self.gru = nn.Sequential(GRUCell(ni, nh, bs),
                                 *[nn.Sequential(GRUCell(nh, nh, bs), nn.Dropout(p=do))
                                   for _ in range(nl)])
    def forward(self, x):
        return self.gru(x)
    def reset(self):
        # Reset the hidden state of every cell in the stack
        for m in self.modules():
            if isinstance(m, GRUCell): m.reset()
ml_gru_do = MultiGRU_do(1,16,1)
for _ in range(3):
    print(ml_gru_do(torch.ones(1,1))) # module starts in train() mode, so dropout is active
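A bidirectional GRU keeps two separate stacks: one reads the sequence front-to-back and one back-to-front, and their per-step outputs are then combined, e.g. by summing or concatenating.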
class MultiGRU_bd(nn.Module):
    def __init__(self, ni, nh, bs, nl=2):
        super().__init__()
        # Two independent stacks: one for the forward pass over the sequence,
        # one for the backward (reversed) pass
        self.fgru = nn.Sequential(GRUCell(ni, nh, bs),
                                  *[GRUCell(nh, nh, bs) for _ in range(nl)])
        self.bgru = nn.Sequential(GRUCell(ni, nh, bs),
                                  *[GRUCell(nh, nh, bs) for _ in range(nl)])
    def forward(self, x, backward=False):
        return self.bgru(x) if backward else self.fgru(x)
    def reset(self):
        for m in self.modules():
            if isinstance(m, GRUCell): m.reset()
ml_gru_bd = MultiGRU_bd(1,16,1)
# Reading the data front-to-back
for _ in range(3):
    fo = ml_gru_bd(torch.ones(1,1))
# Reading the data back-to-front
for _ in range(3):
    bo = ml_gru_bd(torch.ones(1,1), backward=True)
fo + bo # combine the two directions by summing
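On a real sequence the two stacks would see the time steps in opposite orders before their per-step outputs are aligned and combined. A minimal sketch (the seq tensor and the summing step are illustrative assumptions, not part of the class above):

ml_gru_bd.reset()
seq = [torch.ones(1,1) * t for t in range(3)]              # toy 3-step sequence
fwd = [ml_gru_bd(x) for x in seq]                          # steps t = 0, 1, 2
bwd = [ml_gru_bd(x, backward=True) for x in reversed(seq)] # steps t = 2, 1, 0
out = [f + b for f, b in zip(fwd, reversed(bwd))]          # re-align steps, then sum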