# fitstomat.py
import copy

import torch

import stomatalmodels as stomat


def getACi(fvcbmtt, gsw, learnrate=2, maxiteration=8000, minloss=1e-10):
    """Fit intercellular CO2 (Ci) from measured stomatal conductance (gsw)
    by driving the conductance-based assimilation estimate toward the FvCB
    model's prediction."""
    gsmtest = stomat.gsACi(torch.tensor(gsw))
    optimizer = torch.optim.Adam(gsmtest.parameters(), lr=learnrate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.9)
    best_loss = float('inf')
    best_iter = 0
    # Deep-copy snapshots so later optimizer steps cannot mutate them.
    best_weights = copy.deepcopy(gsmtest.state_dict())
    criterion = stomat.lossA()
    for it in range(maxiteration):
        optimizer.zero_grad()
        An_gs = gsmtest()
        # Push the current Ci and A estimates into the FvCB model's leaf data.
        fvcbmtt.lcd.Ci = gsmtest.Ci
        fvcbmtt.lcd.A = An_gs
        An_f, Ac_o, Aj_o, Ap_o = fvcbmtt()
        loss = criterion(An_f, An_gs, gsmtest.Ci)
        loss.backward()
        if (it + 1) % 100 == 0:
            print(f'Loss at iter {it}: {loss.item():.4e}')
        # Snapshot before stepping so the saved weights match the reported loss.
        if loss.item() < minloss:
            best_loss = loss.item()
            best_weights = copy.deepcopy(gsmtest.state_dict())
            best_iter = it
            print(f'Fitting converged at iter {it}: {loss.item():.4e}')
            break
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_weights = copy.deepcopy(gsmtest.state_dict())
            best_iter = it
        optimizer.step()
        scheduler.step()
    print(f'Best loss at iter {best_iter}: {best_loss:.4e}')
    gsmtest.load_state_dict(best_weights)
    return gsmtest
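
# Usage sketch (illustrative, not from the original file). getACi expects an
# FvCB photosynthesis model exposing leaf-condition data at `fvcbmtt.lcd`,
# plus measured stomatal conductance `gsw`. Only stomat.gsACi, stomat.lossA,
# and stomat.lossSC appear in this file, so the constructor below is a
# hypothetical placeholder; check the stomatalmodels package for the real one.
#
#   fvcbm = stomat.FvCB(lcd)               # hypothetical constructor
#   gsm = getACi(fvcbm, gsw, learnrate=2)  # fit Ci until both A estimates agree
#   ci_fit = gsm.Ci.detach()               # fitted intercellular CO2
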
def run(scm, gsw, learnrate=0.01, maxiteration=8000, minloss=1e-6):
    """Fit the parameters of a stomatal conductance model (scm) to measured
    conductance (gsw), returning the model reloaded with its best weights."""
    optimizer = torch.optim.Adam(scm.parameters(), lr=learnrate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.9)
    best_loss = float('inf')
    best_iter = 0
    best_weights = copy.deepcopy(scm.state_dict())
    criterion = stomat.lossSC()
    for it in range(maxiteration):
        optimizer.zero_grad()
        gs_fit = scm()
        # lossSC also receives the model itself, presumably so it can
        # constrain the model's parameters in addition to the fit residual.
        loss = criterion(scm, gs_fit, gsw)
        loss.backward()
        if (it + 1) % 100 == 0:
            print(f'Loss at iter {it}: {loss.item():.4e}')
        # Snapshot before stepping so the saved weights match the reported loss.
        if loss.item() < minloss:
            best_loss = loss.item()
            best_weights = copy.deepcopy(scm.state_dict())
            best_iter = it
            print(f'Fitting converged at iter {it}: {loss.item():.4e}')
            break
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_weights = copy.deepcopy(scm.state_dict())
            best_iter = it
        optimizer.step()
        scheduler.step()
    print(f'Best loss at iter {best_iter}: {best_loss:.4e}')
    scm.load_state_dict(best_weights)
    return scm
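
# Usage sketch (illustrative, not from the original file). `run` fits any
# stomatal conductance model whose forward pass returns predicted gs and
# whose parameters are torch Parameters. The constructor below is a
# hypothetical placeholder for whatever model class stomatalmodels provides.
#
#   scm = stomat.SCM(lcd)                  # hypothetical constructor
#   scm = run(scm, gsw, learnrate=0.01)    # returns model with best weights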