"""Visualize the weights of a trained ``SimpleNN`` as a layered graph.

Run as a script: loads ``best_model.pth`` from this file's directory and
draws one node per neuron and one edge per weight, coloured by its value.
"""
import os

import numpy as np
import torch
import torch.nn as nn

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


class SimpleNN(nn.Module):
    """MLP ``input_dim -> hidden_dim -> hidden_dim -> 1`` with ReLU and dropout.

    The attribute names (``fc1``..``fc3``, ``dropout1``, ``dropout2``) are the
    ``state_dict`` keys, so they must not change or saved checkpoints will no
    longer load.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of both hidden layers (default 100).
        dropout: Dropout probability after each hidden layer (default 0.5).
    """

    def __init__(self, input_dim, hidden_dim=100, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the raw (un-activated) single-unit output for batch ``x``."""
        x = self.dropout1(torch.relu(self.fc1(x)))
        x = self.dropout2(torch.relu(self.fc2(x)))
        return self.fc3(x)


def extract_weights(model):
    """Return the weight matrices of ``model.fc1``..``fc3`` as NumPy arrays.

    Each array has shape ``(out_features, in_features)``, the layout
    ``nn.Linear`` stores its weight in.
    """
    return [
        layer.weight.detach().cpu().numpy()
        for layer in (model.fc1, model.fc2, model.fc3)
    ]


def network_layout(layers, weights):
    """Compute node positions and the weighted edge list for drawing.

    Args:
        layers: Node count per layer, e.g. ``[51, 100, 100, 1]``.
        weights: ``len(layers) - 1`` matrices; ``weights[i]`` must have shape
            ``(layers[i + 1], layers[i])`` (the ``nn.Linear`` convention).

    Returns:
        ``(pos, edges, edge_weights)``. ``pos`` maps each node name
        ``"L{i}_N{j}"`` to ``(x, y)`` with every layer centred on ``y = 0``.
        ``edges[n]`` and ``edge_weights[n]`` describe the same connection.

    Raises:
        ValueError: If the number or shapes of ``weights`` do not match
            ``layers``.
    """
    if len(weights) != len(layers) - 1:
        raise ValueError(
            f"expected {len(layers) - 1} weight matrices for {len(layers)} "
            f"layers, got {len(weights)}"
        )

    pos = {}
    layer_nodes = []
    for i, num_nodes in enumerate(layers):
        names = [f"L{i}_N{j}" for j in range(num_nodes)]
        layer_nodes.append(names)
        offset = (num_nodes - 1) / 2  # centre the column symmetrically on y = 0
        for j, name in enumerate(names):
            pos[name] = (i, offset - j)

    edges = []
    edge_weights = []
    for i, (src_nodes, dst_nodes, w) in enumerate(
        zip(layer_nodes, layer_nodes[1:], weights)
    ):
        w = np.asarray(w)
        expected = (len(dst_nodes), len(src_nodes))
        if w.shape != expected:
            raise ValueError(
                f"weights[{i}] has shape {w.shape}, expected {expected} "
                "(out_features, in_features)"
            )
        for j, src in enumerate(src_nodes):
            for k, dst in enumerate(dst_nodes):
                edges.append((src, dst))
                # nn.Linear weight is indexed [out, in] = [dst, src].
                edge_weights.append(float(w[k, j]))
    return pos, edges, edge_weights


def draw_neural_network(layers, weights):
    """Draw the network as a graph whose edge colours encode weight values.

    Args:
        layers: Node count per layer.
        weights: Per-layer weight matrices; see ``network_layout``.

    Raises:
        ValueError: If ``weights`` do not match ``layers``.
    """
    # Imported lazily: only drawing needs them, and this keeps SimpleNN and
    # the layout helpers importable where the plotting stack is not installed.
    import matplotlib.pyplot as plt
    import networkx as nx

    pos, edges, edge_weights = network_layout(layers, weights)

    G = nx.Graph()
    G.add_nodes_from(pos)
    G.add_edges_from(edges)

    plt.figure(figsize=(10, 10))
    nx.draw(
        G,
        pos,
        # Explicit edgelist: nx.Graph reports edges in adjacency order, not
        # insertion order, so without it the colours in edge_weights would be
        # applied to the wrong edges.
        edgelist=edges,
        with_labels=False,
        node_size=700,
        node_color="lightblue",
        edge_color=edge_weights,
        edge_cmap=plt.cm.viridis,
        width=2,
        edge_vmin=min(edge_weights, default=None),
        edge_vmax=max(edge_weights, default=None),
    )
    for name, (x, y) in pos.items():
        plt.text(x, y + 0.1, name, ha="center", va="center")
    plt.title("Neural Network Visualization")
    plt.show()


def main(checkpoint="best_model.pth", input_dim=51):
    """Load the trained checkpoint and visualize its weights.

    Args:
        checkpoint: Path to the saved ``state_dict``; relative paths are
            resolved against this script's directory (absolute paths are
            used as-is).
        input_dim: Input width the model was trained with.
    """
    # Resolve against the script directory instead of os.chdir(), so the
    # process-wide working directory is left untouched.
    path = os.path.join(SCRIPT_DIR, checkpoint)

    model = SimpleNN(input_dim)
    # map_location="cpu" lets a checkpoint saved from a GPU load on a
    # CPU-only machine. NOTE: torch.load unpickles the file -- only load
    # checkpoints you trust (on torch>=1.13 consider weights_only=True).
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()

    weights = extract_weights(model)
    # Layer widths follow from the weights: [in, hidden, hidden, out].
    layers = [weights[0].shape[1]] + [w.shape[0] for w in weights]
    draw_neural_network(layers, weights)


if __name__ == "__main__":
    main()