# MethSurvPredictor / webplot.py
# Uploaded by csycsycsy (commit 84db192, "Upload 10 files").
import torch
import torch.nn as nn
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import os
# Locate the directory that contains this script and make it the working
# directory, so relative paths used by the script (e.g. 'best_model.pth')
# are resolved next to the script rather than wherever it was launched from.
script_path = os.path.abspath(__file__)
script_dir = os.path.split(script_path)[0]
os.chdir(script_dir)
class SimpleNN(nn.Module):
    """Fully connected network: input_dim -> 100 -> 100 -> 1.

    Each of the two hidden layers applies a Linear layer, ReLU, then
    Dropout(p=0.5); the output is a single linear unit with no activation.
    The submodule attribute names (fc1, dropout1, fc2, dropout2, fc3) define
    the state_dict keys, so they must match any checkpoint loaded into it.
    """

    def __init__(self, input_dim):
        super().__init__()
        width = 100
        # Registration order is kept fc1, dropout1, fc2, dropout2, fc3.
        self.fc1 = nn.Linear(input_dim, width)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(width, width)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(width, 1)

    def forward(self, x):
        """Map inputs of trailing size ``input_dim`` to one output unit each."""
        first = self.dropout1(torch.relu(self.fc1(x)))
        second = self.dropout2(torch.relu(self.fc2(first)))
        return self.fc3(second)
# Build the network and load the trained parameters.
input_dim = 51
model = SimpleNN(input_dim)
# map_location='cpu' lets a checkpoint saved from a GPU run be deserialized on
# a CPU-only machine; the model itself is constructed on the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()
# One weight matrix per linear layer, shaped (out_features, in_features),
# detached from autograd and converted to NumPy for plotting.
weights = [layer.weight.detach().numpy() for layer in (model.fc1, model.fc2, model.fc3)]
# Layer widths are read from the model so they always agree with `weights`.
layers = [input_dim, model.fc1.out_features, model.fc2.out_features, model.fc3.out_features]
def draw_neural_network(layers, weights):
    """Plot a fully connected feed-forward network, coloring edges by weight.

    Args:
        layers: Node count of each layer, input layer first
            (e.g. ``[51, 100, 100, 1]``).
        weights: One 2-D array per pair of consecutive layers, where
            ``weights[i][k, j]`` is the weight from node ``j`` of layer ``i``
            to node ``k`` of layer ``i + 1`` (the ``(out, in)`` layout of
            ``nn.Linear.weight``).

    Nodes are named ``L{layer}_N{index}``, placed in column ``layer`` and
    labelled with that name. Edges are colored with the viridis colormap
    scaled to the smallest/largest weight. The figure is shown with
    ``plt.show()``.
    """
    G = nx.Graph()
    pos = {}
    layer_nodes = []
    for i, num_nodes in enumerate(layers):
        names = [f'L{i}_N{j}' for j in range(num_nodes)]
        layer_nodes.append(names)
        for j, node_name in enumerate(names):
            # Column i; rows run downward starting near num_nodes // 2.
            pos[node_name] = (i, -j + num_nodes // 2)
    # Add every node up front so a layer is drawn even when it has no edges
    # (e.g. a network described by a single layer).
    G.add_nodes_from(pos)

    edges = []
    edge_colors = []
    for i in range(len(layers) - 1):
        layer_weights = weights[i]
        for j, node in enumerate(layer_nodes[i]):
            for k, next_node in enumerate(layer_nodes[i + 1]):
                edges.append((node, next_node))
                edge_colors.append(layer_weights[k, j])
    G.add_edges_from(edges)

    # Color-scale limits; min()/max() would raise on an empty sequence.
    color_limits = {}
    if edge_colors:
        color_limits = {'edge_vmin': min(edge_colors), 'edge_vmax': max(edge_colors)}

    plt.figure(figsize=(10, 10))
    # Pass edgelist explicitly: an undirected nx.Graph does not necessarily
    # report G.edges() in the order edges were added, so relying on the
    # default edge list would pair edge_colors with the wrong edges.
    nx.draw(G, pos, edgelist=edges, with_labels=False, node_size=700,
            node_color='lightblue', edge_color=edge_colors,
            edge_cmap=plt.cm.viridis, width=2, **color_limits)
    for key, value in pos.items():
        plt.text(value[0], value[1] + 0.1, key, ha='center', va='center')
    plt.title("Neural Network Visualization")
    plt.show()
# Render the loaded network using the layer widths and weight matrices above.
draw_neural_network(layers=layers, weights=weights)