import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# Configuration
NUM_DEVICES = 5        # number of simulated devices (one data shard each)
EPOCHS = 2             # local epochs a selected device trains per round
ROUNDS = 10            # number of federated aggregation rounds
BATCH_SIZE = 64
LEARNING_RATE = 0.01

# Energy, Bandwidth, SNR initialization
# One value per device, each drawn uniformly from [low, high).
# Units are not specified in this script.
device_energy = np.random.uniform(30, 100, NUM_DEVICES)
device_bandwidth = np.random.uniform(1.0, 5.0, NUM_DEVICES)
device_snr = np.random.uniform(10, 40, NUM_DEVICES)
# Dataset loading (split among devices)
transform = transforms.Compose([transforms.ToTensor()])
# NOTE(review): the dataset class name is garbled in the source text; MNIST is
# assumed from the 1-channel conv input and the 1440-feature linear layer
# (which requires 28x28 images) -- confirm.
full_dataset = datasets.MNIST('./data', train=True, download=True,
                              transform=transform)
# random_split needs integer lengths that sum exactly to len(full_dataset).
# Split as evenly as possible, giving the first `remainder` shards one extra
# sample, instead of assuming the dataset size divides by NUM_DEVICES.
shard_size, remainder = divmod(len(full_dataset), NUM_DEVICES)
shard_lengths = [shard_size + (1 if i < remainder else 0)
                 for i in range(NUM_DEVICES)]
datasets_split = torch.utils.data.random_split(full_dataset, shard_lengths)
# Homogeneous CNN model for all devices
class CNN_Model(nn.Module):
    """Small CNN classifier; every device uses this same architecture.

    Layers: Conv2d(1 -> 10, 5x5) -> ReLU -> MaxPool2d(2) -> Flatten ->
    Linear(1440 -> num_classes).

    The 1440 input features of the linear layer equal 10 channels * 12 * 12.
    That is what a 1x28x28 input yields after the 5x5 convolution (24x24) and
    the 2x2 max-pool (12x12). Inputs of another spatial size will fail at the
    linear layer.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 10, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1440, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalised) logits of shape (batch, num_classes)."""
        return self.model(x)
# Initialize models and optimizers
# One model and one optimizer per device.
models = [CNN_Model() for _ in range(NUM_DEVICES)]
# Each CNN_Model() draws its own random weights. If they started out different,
# the first averaging round would combine unrelated parameters. Start every
# device from device 0's weights, as in standard FedAvg.
for replica in models[1:]:
    replica.load_state_dict(models[0].state_dict())
# NOTE(review): the optimizer and loss class names are garbled in the source;
# plain SGD and CrossEntropyLoss (the model emits raw logits) are assumed --
# confirm.
optimizers = [
    optim.SGD(model.parameters(), lr=LEARNING_RATE) for model in models
]
loss_fn = nn.CrossEntropyLoss()
# Training function
def train_model(model, data_loader, optimizer, epochs=None, criterion=None):
    """Train ``model`` on ``data_loader`` and return a snapshot of its weights.

    Args:
        model: module to train. It is switched to training mode and updated
            in place.
        data_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer already bound to ``model``'s parameters.
        epochs: number of passes over ``data_loader``. ``None`` uses the
            module-level ``EPOCHS``.
        criterion: loss callable ``criterion(output, target)``. ``None`` uses
            the module-level ``loss_fn``.

    Returns:
        A dict holding a detached copy of every ``state_dict()`` entry, taken
        after training. Later in-place updates to ``model`` do not change it.
    """
    epochs = EPOCHS if epochs is None else epochs
    criterion = loss_fn if criterion is None else criterion
    model.train()
    for _ in range(epochs):
        for data, target in data_loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
    # state_dict() holds references to the live tensors; clone for a snapshot.
    return {name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()}
# Model averaging
def average_models(model_states, weights=None):
    """Return the element-wise (optionally weighted) average of model states.

    Args:
        model_states: non-empty sequence of mappings with the same keys, e.g.
            ``state_dict()`` snapshots of identically shaped models.
        weights: optional per-state weights, such as local sample counts as in
            FedAvg. ``None`` gives every state equal weight (a plain mean).

    Returns:
        A new dict mapping each key of ``model_states[0]`` (in that order) to
        the averaged value.

    Raises:
        ValueError: if ``model_states`` is empty, or if ``weights`` has the
            wrong length or sums to zero.
    """
    if not model_states:
        raise ValueError("average_models() needs at least one model state")
    if weights is None:
        count = len(model_states)
        return {
            key: sum(state[key] for state in model_states) / count
            for key in model_states[0]
        }
    weights = list(weights)
    if len(weights) != len(model_states):
        raise ValueError("weights must have exactly one entry per model state")
    total = sum(weights)
    if total == 0:
        raise ValueError("weights must not sum to zero")
    return {
        key: sum(w * state[key] for w, state in zip(weights, model_states)) / total
        for key in model_states[0]
    }
# Evaluation function
def evaluate_model(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` as a fraction.

    The model is put into evaluation mode and run without gradient tracking.
    For each sample, the predicted class is the argmax of the model output
    along dimension 1; ties resolve to the first maximal index.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        test_loader: iterable of ``(data, target)`` batches.

    Raises:
        ValueError: if ``test_loader`` yields no samples (accuracy undefined).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return correct / total
# Test data loader
# NOTE(review): the dataset class name is garbled in the source text; MNIST is
# assumed to match the 28x28 single-channel model input -- confirm.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True, transform=transform),
    batch_size=1000,
)
# Federated training rounds
# Scheduling thresholds: a device takes part in a round only if all three of
# its resource values strictly exceed these limits.
MIN_ENERGY = 40
MIN_BANDWIDTH = 2.0
MIN_SNR = 15
# NOTE(review): nothing in this script updates device_energy, device_bandwidth
# or device_snr, so the same devices are selected in every round. Model
# per-round consumption here if dynamic scheduling is intended.

# Each device's shard never changes, so build its loader once instead of every
# round. shuffle=True still reshuffles the shard on every pass.
device_loaders = [
    torch.utils.data.DataLoader(shard, batch_size=BATCH_SIZE, shuffle=True)
    for shard in datasets_split
]

accuracy_per_round = []
for round_idx in range(ROUNDS):
    print(f"\n--- Round {round_idx+1} ---")
    selected_devices = [
        i
        for i in range(NUM_DEVICES)
        if device_energy[i] > MIN_ENERGY
        and device_bandwidth[i] > MIN_BANDWIDTH
        and device_snr[i] > MIN_SNR
    ]
    print(f"Selected devices: {selected_devices}")
    model_states = [
        train_model(models[i], device_loaders[i], optimizers[i])
        for i in selected_devices
    ]
    if model_states:
        # Broadcast the aggregated model to every device, selected or not.
        avg_state = average_models(model_states)
        for model in models:
            model.load_state_dict(avg_state)
    else:
        print("No device met the scheduling thresholds; models left unchanged.")
    acc = evaluate_model(models[0], test_loader)
    accuracy_per_round.append(acc)
    print(f"Accuracy after Round {round_idx+1}: {acc*100:.2f}%")
# Restore the built-in round() in case a module-level name shadows it, e.g. a
# leftover `round = ...` or `for round in ...` from an interactive session.
# Removing the module-level binding re-exposes the builtin whatever the
# shadowing value's type (int, float, numpy scalar, ...). It is a no-op when
# there is no such binding.
globals().pop("round", None)
# Report the resource values drawn for each device, two decimals each.
print("\n--- Device Initialization Info ---")
for i, (energy, bandwidth, snr) in enumerate(
    zip(device_energy, device_bandwidth, device_snr)
):
    print(
        f"Device {i}: Energy = {energy:.2f}, Bandwidth = {bandwidth:.2f}, "
        f"SNR = {snr:.2f}"
    )
# Plotting results
# Derive the x-axis from the recorded accuracies so the two series always have
# the same length, even if fewer than ROUNDS rounds were recorded.
rounds_axis = range(1, len(accuracy_per_round) + 1)
accuracy_pct = [round(a * 100, 2) for a in accuracy_per_round]
plt.plot(rounds_axis, accuracy_pct, marker='o')
plt.xlabel("Rounds")
plt.ylabel("Accuracy (%)")
plt.title("Federated Learning Accuracy Over Rounds")
plt.grid(True)
plt.tight_layout()
plt.show()