You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
2.3 KiB

import numpy as np
import torch
from numba import jit
@jit(nopython=True, fastmath=True, cache=True)
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of *x*.

    The exponential is evaluated with ``np.exp``, so an array argument is
    processed element-wise.
    """
    decay = np.exp(-x)
    return 1.0 / (1 + decay)
@jit(nopython=True, fastmath=True, cache=True)
def AANN_Derivative(v1, v2, w1, w2, testsample, faulty_directions):
    """Iteratively correct selected entries of a sample to lower its SPE.

    The network evaluated here is
    ``sigmoid(sigmoid(y @ v1) @ w1 @ v2) @ w2`` (presumably an
    auto-associative network, going by the function name). Starting from
    ``testsample``, every iteration

    1. builds a step ``-0.0018 * previous_derivative + 0.9 * previous_step``
       and subtracts it from the current estimate,
    2. runs the network on the new estimate and records the squared
       prediction error (SPE) between estimate and network output,
    3. stops once 3000 iterations are reached or the SPE changed by less
       than ``1e-6`` compared with the previous iteration,
    4. otherwise recomputes the derivative entries for the indices listed
       in ``faulty_directions`` (all other entries stay zero).

    :param v1: weight matrix applied to the input before the first sigmoid;
        its number of rows is used as the feature count.
    :param v2: weight matrix applied before the second sigmoid.
    :param w1: weight matrix applied to the output of the first sigmoid.
    :param w2: weight matrix mapping the second sigmoid's output to the
        network output.
    :param testsample: sample the iteration starts from. It is not modified
        in place.
    :param faulty_directions: array of feature indices (each converted with
        ``int``) whose derivative entries are updated.
    :return: tuple ``(rbc_spe, delt_f, iteration_number, spe_history)`` where
        ``rbc_spe`` is the SPE of the last iteration, ``delt_f`` is the sum
        of all applied steps reshaped to ``(1, n_features)``,
        ``iteration_number`` is the number of iterations performed and
        ``spe_history`` is ``spe[:iteration_number]``. Note that this slice
        starts with the zero placeholder at index 0 and does not contain the
        final SPE, which is returned separately as ``rbc_spe``.
    """
    n_features = v1.shape[0]
    faulty_number = faulty_directions.shape[0]

    # Product of the two inner weight matrices; constant across iterations.
    z = np.dot(w1, v2)

    max_count = 3000
    ahfa1 = 0.0018  # weight of the previous derivative in the step update
    ahfa2 = 0.9     # weight of the previous step in the step update

    # Row 0 of each history is the all-zero starting state; row k belongs to
    # iteration k.
    spe = np.zeros(max_count + 1)
    derivative = np.zeros((max_count + 1, n_features))
    steps = np.zeros((max_count + 1, n_features))

    y_ = testsample
    count = 0
    while True:
        count += 1
        # Estimate as it was before this iteration's step; the derivative
        # below is evaluated against this snapshot.
        prev_estimate = np.copy(y_)

        steps[count, :] = (-ahfa1 * derivative[count - 1, :]
                           + ahfa2 * steps[count - 1, :])
        y_ = y_ - steps[count]

        # Forward pass through the network.
        g = sigmoid(np.dot(y_, v1))
        t = np.dot(g, w1)
        h = sigmoid(np.dot(t, v2))
        recon = np.dot(h, w2).reshape((n_features,))

        resid = y_ - recon
        spe[count] = np.sum(resid * resid)

        if count >= max_count or np.abs(spe[count - 1] - spe[count]) < 0.000001:
            break

        # Quantities shared by every fault index in this iteration.
        # NOTE(review): g[0] / h[0] take the first entry along axis 0, while
        # ee[idx] indexes axis 0 by feature. Whether callers pass a 1-D
        # sample or a (1, n) row cannot be determined from this function;
        # confirm against the call sites.
        ee = prev_estimate - recon
        g_slope = g[0] * (1 - g[0])
        h_slope = h[0] * (1 - h[0])

        for i in range(faulty_number):
            idx = int(faulty_directions[i])
            yitao = np.dot(g_slope * (-v1[idx]), z)
            yita = np.dot(h_slope * yitao, w2)
            derivative[count, idx] = (
                -2 * np.sum(ee * yita)
                - 2 * (ee[idx] - 2 * yita[idx] * steps[count, idx])
                + 2 * steps[count, idx]
            )

    iteration_number = count
    rbc_spe = spe[count]
    # Rows past `count` are still zero, so this is the total applied step.
    delt_f = np.sum(steps, axis=0)
    return rbc_spe, delt_f.reshape(1, n_features), iteration_number, spe[:count]