# MethSurvPredictor / webplot.py
# Last commit: 84db192 "Upload 10 files" by csycsycsy (verified) — 2.29 kB
import torch
import torch.nn as nn
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import os
# Switch the process working directory to the folder containing this file so
# that relative paths used later in the script (e.g. 'best_model.pth') resolve
# next to the script regardless of where it was launched from.
# Note: this is an import-time side effect that changes the cwd of the whole
# process.
script_path = os.path.abspath(__file__)
script_dir = os.path.split(script_path)[0]
os.chdir(script_dir)
class SimpleNN(nn.Module):
    """Three-layer fully connected network with a single output unit.

    Architecture: ``input_dim -> 100 -> 100 -> 1``. Each hidden layer is
    followed by a ReLU and dropout with ``p=0.5``; the last linear layer has
    no activation, so its raw value is returned.

    The submodule names (``fc1``, ``fc2``, ``fc3``) define the state-dict
    keys, so they must match any checkpoint loaded into this model.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Registration order is kept fixed: it determines parameter
        # initialisation order and the ordering of state-dict entries.
        self.fc1 = nn.Linear(input_dim, 100)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(100, 100)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(100, 1)

    def forward(self, x):
        """Return the network output for ``x``.

        The last dimension of ``x`` must equal ``input_dim`` (required by
        ``fc1``); the output replaces it with a dimension of size 1.
        """
        hidden_stages = ((self.fc1, self.dropout1), (self.fc2, self.dropout2))
        for linear, dropout in hidden_stages:
            x = dropout(torch.relu(linear(x)))
        return self.fc3(x)
# Build the network and restore its trained parameters.
# input_dim must match the fc1 input size stored in the checkpoint, otherwise
# load_state_dict raises a size-mismatch error.
input_dim = 51
model = SimpleNN(input_dim)
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine; load_state_dict copies the values into the model's (CPU)
# parameters either way. torch.load unpickles the file, so only load
# checkpoints from a trusted source.
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()

# Weight matrices of the three linear layers as NumPy arrays, each shaped
# (out_features, in_features) as stored by nn.Linear.
weights = [layer.weight.detach().numpy() for layer in (model.fc1, model.fc2, model.fc3)]
# Node count per layer, derived from the model itself so it always agrees
# with the shapes in `weights`.
layers = [
    model.fc1.in_features,
    model.fc1.out_features,
    model.fc2.out_features,
    model.fc3.out_features,
]
def draw_neural_network(layers, weights):
    """Plot a fully connected feed-forward network as a layered graph.

    Layer ``i`` is drawn as a vertical column of ``layers[i]`` nodes at
    x = ``i``. Every node of layer ``i`` is connected to every node of layer
    ``i + 1``, and each edge is coloured with the viridis colormap according
    to its connection weight. Each node is labelled ``L{layer}_N{index}``.

    Parameters
    ----------
    layers : sequence of int
        Number of nodes in each layer, first layer first.
    weights : sequence of 2-D arrays
        ``weights[i][k, j]`` is the weight of the edge from node ``j`` of
        layer ``i`` to node ``k`` of layer ``i + 1`` (the
        ``(out_features, in_features)`` layout of ``nn.Linear.weight``), so
        ``weights[i]`` needs at least ``layers[i + 1]`` rows and ``layers[i]``
        columns.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    G = nx.Graph()
    pos = {}
    layer_nodes = []
    for i, num_nodes in enumerate(layers):
        names = [f'L{i}_N{j}' for j in range(num_nodes)]
        layer_nodes.append(names)
        for j, node_name in enumerate(names):
            # Column i, rows stacked downward around y = 0.
            pos[node_name] = (i, -j + num_nodes // 2)
    # Register every node explicitly so layers without edges are still drawn.
    G.add_nodes_from(pos)

    edges = []
    edge_colors = []
    for i in range(len(layers) - 1):
        layer_weights = weights[i]
        for j, node in enumerate(layer_nodes[i]):
            for k, next_node in enumerate(layer_nodes[i + 1]):
                edges.append((node, next_node))
                edge_colors.append(layer_weights[k, j])
    G.add_edges_from(edges)

    plt.figure(figsize=(10, 10))
    # nx.draw pairs a sequence of edge colours with the edge list it draws.
    # Without an explicit edgelist it uses G.edges(), whose order for an
    # undirected Graph follows node adjacency and does NOT match the order
    # edge_colors was built in, so weights would be painted on the wrong
    # edges. Passing edgelist=edges keeps every colour on its own edge.
    # min()/max() are guarded because a network with a single layer has no edges.
    nx.draw(G, pos, with_labels=False, node_size=700, node_color='lightblue',
            edgelist=edges, edge_color=edge_colors, edge_cmap=plt.cm.viridis,
            width=2,
            edge_vmin=min(edge_colors) if edge_colors else None,
            edge_vmax=max(edge_colors) if edge_colors else None)
    # Draw node names manually, slightly above each node.
    for key, value in pos.items():
        plt.text(value[0], value[1] + 0.1, key, ha='center', va='center')
    plt.title("Neural Network Visualization")
    plt.show()
# Render the loaded model's weights as a layered graph (runs when the module executes).
draw_neural_network(layers=layers, weights=weights)