In [1]:
import numpy as np
In [2]:
maxitr = 50  # default escape-time iteration budget


def mandelbrot_map(z, c):
    """One step of the Mandelbrot iteration: z -> z**2 + c."""
    return z * z + c


def mandelbrot_check(z, max_iter=None):
    """Classify a point of the complex plane against the Mandelbrot set.

    Starting from z_0 = 0 with c = z, iterate z_{k+1} = z_k**2 + c and
    declare the point outside the set as soon as |z_k| exceeds 2 (the
    standard escape radius).

    Parameters
    ----------
    z : complex
        Candidate value of c.
    max_iter : int, optional
        Iteration budget; defaults to the module-level ``maxitr`` (50),
        preserving the original behavior.

    Returns
    -------
    int
        1 if the orbit stayed bounded for all iterations, 0 if it escaped.
    """
    if max_iter is None:
        max_iter = maxitr
    c = z
    z_iter = 0 + 0j
    for _ in range(max_iter):
        z_iter = mandelbrot_map(z_iter, c)
        if abs(z_iter) > 2:
            return 0  # escaped: definitely outside the set
    return 1  # never escaped within the budget: treated as inside
In [3]:
# Spot-check a single point: c = 0.5 + 0.5i escapes, so it is outside the set.
z = 0.5 + 0.5j
mandelbrot_check(z)
Out[3]:
0
In [4]:
import matplotlib.pyplot as plt

# Image resolution and the region of the complex plane to render.
width, height = 800, 800
x_min, x_max = -2.0, 1.0
y_min, y_max = -1.5, 1.5

x = np.linspace(x_min, x_max, width)
y = np.linspace(y_min, y_max, height)
X, Y = np.meshgrid(x, y)
Z = X + 1j * Y  # one candidate c per pixel

# Vectorized escape-time test: iterate z <- z**2 + c for the whole grid at
# once instead of 640,000 Python-level mandelbrot_check calls. A pixel is
# "in" the set (label 1) iff |z| never exceeds 2 within maxitr iterations —
# the same classification the scalar function produces, orders of magnitude
# faster. Escaped points are frozen so they cannot overflow.
z_iter = np.zeros_like(Z)
escaped = np.zeros(Z.shape, dtype=bool)
for _ in range(maxitr):
    active = ~escaped  # only advance points that have not escaped yet
    z_iter[active] = z_iter[active] * z_iter[active] + Z[active]
    escaped |= np.abs(z_iter) > 2
mandelbrot_result = (~escaped).astype(int)

plt.figure(figsize=(10, 10))
plt.imshow(mandelbrot_result, extent=(x_min, x_max, y_min, y_max), cmap='hot', origin='lower')
plt.title('Mandelbrot Set')
plt.xlabel('Re(z)')
plt.ylabel('Im(z)')
plt.show()
In [5]:
mandelbrot_result
Out[5]:
array([[0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], ..., [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0]], shape=(800, 800))
In [6]:
# Flatten the pixel grid into a supervised dataset: one (x, y) row per
# pixel with a matching 0/1 label from the escape-time computation.
coordinates = np.column_stack((X.ravel(), Y.ravel()))
results = mandelbrot_result.ravel()
results.shape
coordinates
Out[6]:
array([[-2. , -1.5 ], [-1.99624531, -1.5 ], [-1.99249061, -1.5 ], ..., [ 0.99249061, 1.5 ], [ 0.99624531, 1.5 ], [ 1. , 1.5 ]], shape=(640000, 2))
In [7]:
coordinates.shape
Out[7]:
(640000, 2)
In [8]:
results.shape
Out[8]:
(640000,)
In [9]:
import torch

# Convert the numpy dataset to float32 tensors for PyTorch: (N, 2) inputs
# and (N,) 0/1 targets. (The original cell also evaluated
# results_tensor.shape mid-cell — dead code, since only a cell's last
# expression is displayed — removed here.)
coordinates_tensor = torch.tensor(coordinates, dtype=torch.float32)
results_tensor = torch.tensor(results, dtype=torch.float32)
# Recorded run shows False: training happens on CPU.
print(torch.cuda.is_available())
False
In [10]:
# Hold out the last 20% of points for testing.
# NOTE(review): the grid is not shuffled, so the test split is the top band
# of the image (y >= 0.9) rather than a random sample — consider indexing
# with torch.randperm (seeded) before splitting for a fairer evaluation.
train_ratio = 0.8
train_size = int(train_ratio * len(coordinates_tensor))
X_train, X_test = coordinates_tensor[:train_size], coordinates_tensor[train_size:]
y_train, y_test = results_tensor[:train_size], results_tensor[train_size:]
print(y_train.shape)
y_train = y_train.unsqueeze(1)  # (N,) -> (N, 1) to match the model's output shape
y_test = y_test.unsqueeze(1)    # fix: keep test targets consistent with train targets
print(y_train.shape)
In [11]:
import torch.nn as nn
class mlnn2_10_1(nn.Module):
    """Tiny MLP classifier: 2 inputs -> 600 hidden (ReLU) -> 1 sigmoid output.

    The forward pass maps an (N, 2) batch of (x, y) coordinates to an
    (N, 1) batch of values in [0, 1], interpreted as the probability of
    membership in the Mandelbrot set.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 600)
        self.fc2 = nn.Linear(600, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        probability = self.sigmoid(self.fc2(hidden))
        return probability
In [12]:
model = mlnn2_10_1()
In [13]:
import torch.optim as optim

# Binary cross-entropy on the sigmoid outputs; Adam with lr=0.1.
# NOTE(review): 0.1 is very high vs. Adam's 1e-3 default — the recorded run
# still converged (with occasional loss spikes), but a smaller lr would
# likely train more stably. Left unchanged to keep results reproducible.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.1)
In [14]:
# Full-batch training: every epoch runs one forward/backward pass over the
# entire 512k-point training set, logging the loss every 10 epochs.
epochs = 1500
for epoch in range(1, epochs + 1):
    optimizer.zero_grad()
    predictions = model(X_train)
    loss = criterion(predictions, y_train)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch [{epoch}/{epochs}], Loss: {loss.item():.4f}")
Epoch [10/1500], Loss: 0.2667 Epoch [20/1500], Loss: 0.1967 Epoch [30/1500], Loss: 0.1043 Epoch [40/1500], Loss: 0.0947 Epoch [50/1500], Loss: 0.0858 Epoch [60/1500], Loss: 0.0804 Epoch [70/1500], Loss: 0.0776 Epoch [80/1500], Loss: 0.0759 Epoch [90/1500], Loss: 0.0746 Epoch [100/1500], Loss: 0.0735 Epoch [110/1500], Loss: 0.0725 Epoch [120/1500], Loss: 0.0717 Epoch [130/1500], Loss: 0.0709 Epoch [140/1500], Loss: 0.0703 Epoch [150/1500], Loss: 0.0698 Epoch [160/1500], Loss: 0.0693 Epoch [170/1500], Loss: 0.0689 Epoch [180/1500], Loss: 0.0686 Epoch [190/1500], Loss: 0.0682 Epoch [200/1500], Loss: 0.0679 Epoch [210/1500], Loss: 0.0676 Epoch [220/1500], Loss: 0.0673 Epoch [230/1500], Loss: 0.0670 Epoch [240/1500], Loss: 0.0667 Epoch [250/1500], Loss: 0.0663 Epoch [260/1500], Loss: 0.0660 Epoch [270/1500], Loss: 0.0657 Epoch [280/1500], Loss: 0.0654 Epoch [290/1500], Loss: 0.0651 Epoch [300/1500], Loss: 0.0648 Epoch [310/1500], Loss: 0.0647 Epoch [320/1500], Loss: 0.0643 Epoch [330/1500], Loss: 0.0639 Epoch [340/1500], Loss: 0.0636 Epoch [350/1500], Loss: 0.0636 Epoch [360/1500], Loss: 0.0630 Epoch [370/1500], Loss: 0.0625 Epoch [380/1500], Loss: 0.0624 Epoch [390/1500], Loss: 0.0617 Epoch [400/1500], Loss: 0.0618 Epoch [410/1500], Loss: 0.0610 Epoch [420/1500], Loss: 0.0605 Epoch [430/1500], Loss: 0.0609 Epoch [440/1500], Loss: 0.0599 Epoch [450/1500], Loss: 0.0594 Epoch [460/1500], Loss: 0.0598 Epoch [470/1500], Loss: 0.0605 Epoch [480/1500], Loss: 0.0580 Epoch [490/1500], Loss: 0.0578 Epoch [500/1500], Loss: 0.0574 Epoch [510/1500], Loss: 0.0582 Epoch [520/1500], Loss: 0.0564 Epoch [530/1500], Loss: 0.0570 Epoch [540/1500], Loss: 0.0567 Epoch [550/1500], Loss: 0.0552 Epoch [560/1500], Loss: 0.0570 Epoch [570/1500], Loss: 0.0552 Epoch [580/1500], Loss: 0.0542 Epoch [590/1500], Loss: 0.0563 Epoch [600/1500], Loss: 0.0543 Epoch [610/1500], Loss: 0.0534 Epoch [620/1500], Loss: 0.0546 Epoch [630/1500], Loss: 0.0528 Epoch [640/1500], Loss: 0.0547 Epoch [650/1500], Loss: 
0.0549 Epoch [660/1500], Loss: 0.0521 Epoch [670/1500], Loss: 0.0526 Epoch [680/1500], Loss: 0.0519 Epoch [690/1500], Loss: 0.0516 Epoch [700/1500], Loss: 0.0545 Epoch [710/1500], Loss: 0.0521 Epoch [720/1500], Loss: 0.0513 Epoch [730/1500], Loss: 0.0504 Epoch [740/1500], Loss: 0.0516 Epoch [750/1500], Loss: 0.0501 Epoch [760/1500], Loss: 0.0515 Epoch [770/1500], Loss: 0.0501 Epoch [780/1500], Loss: 0.0499 Epoch [790/1500], Loss: 0.0492 Epoch [800/1500], Loss: 0.0491 Epoch [810/1500], Loss: 0.0677 Epoch [820/1500], Loss: 0.0547 Epoch [830/1500], Loss: 0.0499 Epoch [840/1500], Loss: 0.0487 Epoch [850/1500], Loss: 0.0485 Epoch [860/1500], Loss: 0.0483 Epoch [870/1500], Loss: 0.0480 Epoch [880/1500], Loss: 0.0479 Epoch [890/1500], Loss: 0.0477 Epoch [900/1500], Loss: 0.0480 Epoch [910/1500], Loss: 0.0551 Epoch [920/1500], Loss: 0.0484 Epoch [930/1500], Loss: 0.0481 Epoch [940/1500], Loss: 0.0472 Epoch [950/1500], Loss: 0.0472 Epoch [960/1500], Loss: 0.0468 Epoch [970/1500], Loss: 0.0467 Epoch [980/1500], Loss: 0.0587 Epoch [990/1500], Loss: 0.0475 Epoch [1000/1500], Loss: 0.0503 Epoch [1010/1500], Loss: 0.0471 Epoch [1020/1500], Loss: 0.0466 Epoch [1030/1500], Loss: 0.0464 Epoch [1040/1500], Loss: 0.0460 Epoch [1050/1500], Loss: 0.0459 Epoch [1060/1500], Loss: 0.0458 Epoch [1070/1500], Loss: 0.0457 Epoch [1080/1500], Loss: 0.0458 Epoch [1090/1500], Loss: 0.0638 Epoch [1100/1500], Loss: 0.0506 Epoch [1110/1500], Loss: 0.0469 Epoch [1120/1500], Loss: 0.0454 Epoch [1130/1500], Loss: 0.0455 Epoch [1140/1500], Loss: 0.0452 Epoch [1150/1500], Loss: 0.0453 Epoch [1160/1500], Loss: 0.0478 Epoch [1170/1500], Loss: 0.0453 Epoch [1180/1500], Loss: 0.0461 Epoch [1190/1500], Loss: 0.0451 Epoch [1200/1500], Loss: 0.0457 Epoch [1210/1500], Loss: 0.0455 Epoch [1220/1500], Loss: 0.0486 Epoch [1230/1500], Loss: 0.0458 Epoch [1240/1500], Loss: 0.0445 Epoch [1250/1500], Loss: 0.0462 Epoch [1260/1500], Loss: 0.0443 Epoch [1270/1500], Loss: 0.0458 Epoch [1280/1500], Loss: 0.0448 Epoch 
[1290/1500], Loss: 0.0442 Epoch [1300/1500], Loss: 0.0439 Epoch [1310/1500], Loss: 0.0628 Epoch [1320/1500], Loss: 0.0478 Epoch [1330/1500], Loss: 0.0462 Epoch [1340/1500], Loss: 0.0440 Epoch [1350/1500], Loss: 0.0441 Epoch [1360/1500], Loss: 0.0437 Epoch [1370/1500], Loss: 0.0436 Epoch [1380/1500], Loss: 0.0444 Epoch [1390/1500], Loss: 0.0459 Epoch [1400/1500], Loss: 0.0450 Epoch [1410/1500], Loss: 0.0434 Epoch [1420/1500], Loss: 0.0486 Epoch [1430/1500], Loss: 0.0454 Epoch [1440/1500], Loss: 0.0434 Epoch [1450/1500], Loss: 0.0446 Epoch [1460/1500], Loss: 0.0434 Epoch [1470/1500], Loss: 0.0429 Epoch [1480/1500], Loss: 0.0576 Epoch [1490/1500], Loss: 0.0450 Epoch [1500/1500], Loss: 0.0431
In [16]:
import numpy as np
import torch
import matplotlib.pyplot as plt


def predict(x, y):
    """Classify a single (x, y) point with the trained model (1 = in the set).

    Bug fix: the original returned int(model(...)), which truncates the
    sigmoid probability and therefore yields 0 for any output below 1.0;
    the probability must be thresholded at 0.5 instead (matching
    predict_vectorized below).
    """
    with torch.no_grad():
        probability = model(torch.tensor([x, y], dtype=torch.float32))
    return int(probability.item() > 0.5)


def predict_vectorized(coords):
    """Classify an (N, 2) array of points in one batch; returns (N, 1) ints."""
    coords_tensor = torch.tensor(coords, dtype=torch.float32)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        raw_predictions = model(coords_tensor)
    predictions = (raw_predictions > 0.5).int()
    return predictions.numpy()


# Rebuild the same pixel grid used for training and render the model's
# reconstruction of the set side-by-side comparable with the exact image.
width, height = 800, 800
x_min, x_max = -2.0, 1.0
y_min, y_max = -1.5, 1.5
x = np.linspace(x_min, x_max, width)
y = np.linspace(y_min, y_max, height)
X, Y = np.meshgrid(x, y)

mandelbrot_result = predict_vectorized(coordinates)
mandelbrot_result = mandelbrot_result.reshape((height, width))

plt.figure(figsize=(10, 10))
plt.imshow(mandelbrot_result, extent=(x_min, x_max, y_min, y_max), cmap='hot', origin='lower')
plt.title('Mandelbrot Set')
plt.xlabel('Re(z)')
plt.ylabel('Im(z)')
plt.show()
In [ ]:
In [ ]: