# --- Toy-dataset setup -------------------------------------------------------
# Generate a 2-D toy point cloud and centre it, then train the denoising
# autoencoder on it.  (The training loop itself lives in train_model();
# NOTE(review): data_spiral / data_zigzag / data_penta are project helpers
# defined elsewhere — presumably each returns an (N, 2) tensor; confirm.)

data = data_spiral(1000)
# data = data_zigzag(1000)
# data = data_penta(1000)

# Centre the cloud so the network does not have to learn a global offset.
data = data - data.mean(0)

model = train_model(data)
def train_model(data, batch_size=100, nb_epochs=1000, lr=1e-3, noise_std=0.1):
    """Train a small denoising autoencoder on a 2-D point cloud.

    A 2 -> 100 -> 2 MLP is trained to reconstruct each point from a
    Gaussian-corrupted copy of itself, using Adam and MSE loss.

    Args:
        data: Tensor of shape (N, 2) with the training points.
        batch_size: mini-batch size passed to ``data.split``.
        nb_epochs: number of full passes over ``data``.
        lr: Adam learning rate.
        noise_std: std-dev of the Gaussian corruption added to the input.

    Returns:
        The trained ``nn.Sequential`` model.
    """
    model = nn.Sequential(
        nn.Linear(2, 100),
        nn.ReLU(),
        nn.Linear(100, 2),
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for e in range(nb_epochs):
        acc_loss = 0.0
        for batch in data.split(batch_size):  # 'batch' — avoid shadowing builtin input()
            # Denoising objective: reconstruct the clean batch from a noisy copy.
            noise = torch.randn_like(batch) * noise_std
            output = model(batch + noise)
            loss = criterion(output, batch)
            acc_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Lightweight progress report: accumulated epoch loss every 100 epochs.
        if (e + 1) % 100 == 0:
            print(e + 1, acc_loss)

    return model