10 changes: 10 additions & 0 deletions neurodiffeq/conditions.py
@@ -1,9 +1,12 @@
import numpy as np
import torch
import warnings
import logging
from .neurodiffeq import safe_diff as diff
from ._version_utils import deprecated_alias

conditions_logger = logging.getLogger('neurodiffeq.conditions')


class BaseCondition:
r"""Base class for all conditions.
@@ -48,6 +51,8 @@ def enforce(self, net, *coordinates):
:return: The re-parameterized output, where the condition is automatically satisfied.
:rtype: `torch.Tensor`
"""
if conditions_logger.isEnabledFor(logging.DEBUG):
conditions_logger.debug(f"Enforcing {self.__class__.__name__} condition on {coordinates[0].shape[0]} points")
# concatenate the coordinates and pass to network
network_output = net(torch.cat(coordinates, dim=1))
# if `ith_unit` is set, the condition will only be enforced on the i-th output unit
@@ -165,13 +170,18 @@ class EnsembleCondition(BaseCondition):

def __init__(self, *sub_conditions, force=False):
super(EnsembleCondition, self).__init__()

conditions_logger.debug(f"Creating EnsembleCondition with {len(sub_conditions)} sub-conditions")

for i, c in enumerate(sub_conditions):
if c.__class__.enforce != BaseCondition.enforce:
msg = f"{c.__class__.__name__} (index={i})'s overrides BaseCondition's `.enforce` method. " \
f"Ensembl'ing is likely not going to work."
if force:
conditions_logger.warning(f"Forcing ensemble creation despite override: {msg}")
warnings.warn(msg)
else:
conditions_logger.error(f"Cannot create ensemble: {msg}")
raise ValueError(msg + "\nTry with `force=True` if you know what you are doing.")

self.conditions = sub_conditions
5 changes: 5 additions & 0 deletions neurodiffeq/generators.py
@@ -2,8 +2,11 @@
"""
import torch
import numpy as np
import logging
from typing import List

generators_logger = logging.getLogger('neurodiffeq.generators')


def _chebyshev_first(a, b, n):
nodes = torch.cos(((torch.arange(n) + 0.5) / n) * np.pi)
@@ -143,6 +146,8 @@ def __init__(self, size, t_min=0.0, t_max=1.0, method='uniform', noise_std=None)
self.size = size
self.t_min, self.t_max = t_min, t_max
self.method = method

generators_logger.debug(f"Initialized Generator1D: size={size}, t_min={t_min}, t_max={t_max}, method={method}")
if noise_std:
self.noise_std = noise_std
else:
47 changes: 38 additions & 9 deletions neurodiffeq/losses.py
@@ -1,29 +1,58 @@
import torch
import logging
from .operators import grad

losses_logger = logging.getLogger('neurodiffeq.losses')


def _l1_norm(residual, funcs, coords):
return torch.abs(residual).mean()
loss = torch.abs(residual).mean()
if losses_logger.isEnabledFor(logging.DEBUG):
losses_logger.debug(f"L1 loss: {loss.item():.6f}, residual stats: mean={residual.mean().item():.6f}, std={residual.std().item():.6f}")
return loss


def _l2_norm(residual, funcs, coords):
return (residual ** 2).mean()
loss = (residual ** 2).mean()
if losses_logger.isEnabledFor(logging.DEBUG):
losses_logger.debug(f"L2 loss: {loss.item():.6f}, residual stats: mean={residual.mean().item():.6f}, std={residual.std().item():.6f}")
return loss


def _infinity_norm(residual, funcs, coords):
return residual.abs().max(dim=1)[0].mean()
loss = residual.abs().max(dim=1)[0].mean()
if losses_logger.isEnabledFor(logging.DEBUG):
max_residual = residual.abs().max().item()
losses_logger.debug(f"Infinity loss: {loss.item():.6f}, max residual: {max_residual:.6f}")
return loss


def _h1_norm(residual, funcs, coords):
g = grad(residual, *coords)
rg = torch.cat([residual, *g], dim=1)
return (rg ** 2).mean()
try:
g = grad(residual, *coords)
rg = torch.cat([residual, *g], dim=1)
loss = (rg ** 2).mean()
if losses_logger.isEnabledFor(logging.DEBUG):
grad_norm = torch.cat(g, dim=1).norm().item()
losses_logger.debug(f"H1 loss: {loss.item():.6f}, gradient norm: {grad_norm:.6f}")
return loss
except Exception as e:
losses_logger.error(f"Error computing H1 norm: {e}")
raise


def _h1_semi_norm(residual, funcs, coords):
g = grad(residual, *coords)
g = torch.cat(g, dim=1)
return (g ** 2).mean()
try:
g = grad(residual, *coords)
g = torch.cat(g, dim=1)
loss = (g ** 2).mean()
if losses_logger.isEnabledFor(logging.DEBUG):
grad_norm = g.norm().item()
losses_logger.debug(f"H1 semi loss: {loss.item():.6f}, gradient norm: {grad_norm:.6f}")
return loss
except Exception as e:
losses_logger.error(f"Error computing H1 semi-norm: {e}")
raise


_losses = {
12 changes: 12 additions & 0 deletions neurodiffeq/monitors.py
@@ -2,6 +2,7 @@
import math
import torch
import warnings
import logging
import matplotlib
import numpy as np
import pandas as pd
@@ -10,6 +11,8 @@
import seaborn as sns
from abc import ABC, abstractmethod

monitors_logger = logging.getLogger('neurodiffeq.monitors')

from ._version_utils import deprecated_alias
from .function_basis import RealSphericalHarmonics as _RealSphericalHarmonics
from .generators import Generator1D as _Generator1D
@@ -40,8 +43,11 @@ def __init__(self, check_every=None):
self.check_every = check_every or 100
self.fig = ...
self.using_non_gui_backend = (matplotlib.get_backend() == 'agg')

monitors_logger.debug(f"Initialized {self.__class__.__name__}: check_every={self.check_every}, backend={matplotlib.get_backend()}")

if matplotlib.get_backend() == 'module://ipykernel.pylab.backend_inline':
monitors_logger.warning("Using jupyter inline backend - plots may not update properly")
warnings.warn(
"You seem to be using jupyter notebook with '%matplotlib inline' "
"which can lead to monitor plots not updating. "
@@ -705,6 +711,12 @@ def check(self, nets, conditions, history):
.. note::
`check` is meant to be called by the function `solve2D`.
"""

monitors_logger.debug(f"Monitor2D check: {len(nets)} networks, {len(conditions)} conditions")

if not self.fig:
# initialize the figure and axes here so that the Monitor knows the number of dependent variables and
5 changes: 5 additions & 0 deletions neurodiffeq/networks.py
@@ -1,7 +1,10 @@
import torch
import torch.nn as nn
import logging
from warnings import warn

networks_logger = logging.getLogger('neurodiffeq.networks')


class FCNN(nn.Module):
"""A fully connected neural network.
@@ -64,6 +67,8 @@ def __init__(self, n_input_units=1, n_output_units=1, n_hidden_units=None, n_hid
# There's no activation after the last layer
layers.append(nn.Linear(units[-1], n_output_units))
self.NN = torch.nn.Sequential(*layers)

networks_logger.debug(f"Initialized FCNN: input={n_input_units}, output={n_output_units}, hidden={hidden_units}")

def forward(self, t):
x = self.NN(t)
6 changes: 6 additions & 0 deletions neurodiffeq/ode.py
@@ -1,6 +1,7 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import logging

import torch
import torch.nn as nn
@@ -15,6 +16,8 @@
from copy import deepcopy
import warnings

ode_logger = logging.getLogger('neurodiffeq.ode')

ExampleGenerator = warn_deprecate_class(Generator1D)
Monitor = warn_deprecate_class(Monitor1D)

@@ -261,6 +264,9 @@ def solve_system(
"The `solve_system` function is deprecated, use a `neurodiffeq.solvers.Solver1D` instance instead",
FutureWarning,
)

ode_logger.info(f"Solving ODE system with {len(conditions)} equations for {max_epochs} epochs")

if single_net and nets:
raise ValueError('Only one of net and nets should be specified')

28 changes: 21 additions & 7 deletions neurodiffeq/operators.py
@@ -1,8 +1,11 @@
import torch
import logging
from torch import sin, cos
from torch import autograd
from .neurodiffeq import safe_diff as diff

operators_logger = logging.getLogger('neurodiffeq.operators')


def _split_u_x(*us_xs):
if len(us_xs) == 0 or len(us_xs) % 2 != 0:
@@ -24,13 +27,24 @@ def grad(u, *xs):
:return: A tuple of :math:`\frac{\partial u}{\partial x_1}`, ..., :math:`\frac{\partial u}{\partial x_n}`
:rtype: List[`torch.Tensor`]
"""
grads = []
for x, g in zip(xs, autograd.grad(u, xs, grad_outputs=torch.ones_like(u), create_graph=True, allow_unused=True)):
if g is None:
grads.append(torch.zeros_like(x, requires_grad=True))
else:
grads.append(g.requires_grad_(True))
return grads
try:
grads = []
for x, g in zip(xs, autograd.grad(u, xs, grad_outputs=torch.ones_like(u), create_graph=True, allow_unused=True)):
if g is None:
grads.append(torch.zeros_like(x, requires_grad=True))
if operators_logger.isEnabledFor(logging.DEBUG):
operators_logger.debug(f"Gradient w.r.t. {x.shape} is None, using zeros")
else:
grads.append(g.requires_grad_(True))

if operators_logger.isEnabledFor(logging.DEBUG):
grad_norms = [g.norm().item() for g in grads]
operators_logger.debug(f"Computed gradients with norms: {grad_norms}")

return grads
except Exception as e:
operators_logger.error(f"Error computing gradient: {e}")
raise


def div(*us_xs):
6 changes: 6 additions & 0 deletions neurodiffeq/pde.py
@@ -1,5 +1,6 @@
import torch
import warnings
import logging
import torch.optim as optim
import torch.nn as nn

@@ -19,6 +20,8 @@
from .solvers import Solver2D
from copy import deepcopy

pde_logger = logging.getLogger('neurodiffeq.pde')

ExampleGenerator2D = warn_deprecate_class(Generator2D)
PredefinedExampleGenerator2D = warn_deprecate_class(PredefinedGenerator)
Solution = warn_deprecate_class(Solution2D)
@@ -286,6 +289,9 @@ def solve2D_system(
"The `solve2D_system` function is deprecated, use a `neurodiffeq.solvers.Solver2D` instance instead",
FutureWarning,
)

pde_logger.info(f"Solving 2D PDE system with {len(conditions)} equations for {max_epochs} epochs")

if single_net and nets:
raise ValueError('Only one of net and nets should be specified')

25 changes: 25 additions & 0 deletions neurodiffeq/solvers.py
@@ -1,6 +1,7 @@
import sys
import warnings
import inspect
import logging
from inspect import signature
from abc import ABC, abstractmethod
from itertools import chain
@@ -25,6 +26,8 @@
from .neurodiffeq import safe_diff as diff
from .losses import _losses

logger = logging.getLogger('neurodiffeq.solvers')


def _requires_closure(optimizer):
# starting from torch v1.13, simple optimizers no longer have a `closure` argument
@@ -133,13 +136,18 @@ def __init__(self, diff_eqs, conditions,
self.diff_eqs = diff_eqs
self.conditions = conditions
self.n_funcs = len(conditions)

logger.info(f"Initializing solver with {self.n_funcs} functions")

if nets is None:
self.nets = [
FCNN(n_input_units=n_input_units, n_output_units=n_output_units, hidden_units=(32, 32), actv=nn.Tanh)
for _ in range(self.n_funcs)
]
logger.info(f"Created {len(self.nets)} default FCNN networks")
else:
self.nets = nets
logger.info(f"Using {len(self.nets)} provided networks")

if train_generator is None:
raise ValueError("train_generator must be specified")
@@ -180,7 +188,10 @@ def analytic_mse(*args):
self.metrics_history.update({'valid__' + name: [] for name in self.metrics_fn})

self.optimizer = optimizer if optimizer else Adam(OrderedSet(chain.from_iterable(n.parameters() for n in self.nets)))
logger.info(f"Using optimizer: {type(self.optimizer).__name__}")

self._set_loss_fn(loss_fn)
logger.info(f"Loss function: {self.loss_fn}")

def make_pair_dict(train=None, valid=None):
return {'train': train, 'valid': valid}
@@ -437,8 +448,13 @@ def _update_best(self, key):
"""
current_loss = self.metrics_history[key + '_loss'][-1]
if (self.lowest_loss is None) or current_loss < self.lowest_loss:
previous_best = self.lowest_loss
self.lowest_loss = current_loss
self.best_nets = deepcopy(self.nets)
if previous_best is not None:
logger.debug(f"New best model found: {key}_loss improved from {previous_best:.6f} to {current_loss:.6f}")
else:
logger.debug(f"Initial best model set: {key}_loss = {current_loss:.6f}")

def fit(self, max_epochs, callbacks=(), tqdm_file=sys.stderr, **kwargs):
r"""Run multiple epochs of training and validation, update best loss at the end of each epoch.
@@ -463,6 +479,8 @@ def fit(self, max_epochs, callbacks=(), tqdm_file=sys.stderr, **kwargs):
"""
self._stop_training = False
self._max_local_epoch = max_epochs

logger.info(f"Starting training for {max_epochs} epochs")

monitor = kwargs.pop('monitor', None)
if monitor:
@@ -486,12 +504,19 @@
for local_epoch in loop:
# stop training if self._stop_training is set to True by a callback
if self._stop_training:
logger.info(f"Training stopped early at epoch {local_epoch + 1}")
break

# register local epoch (starting from 1 instead of 0) so it can be accessed by callbacks
self.local_epoch = local_epoch + 1
self.run_train_epoch()
self.run_valid_epoch()

# Log progress at the first epoch, the last epoch, and roughly every 10% of max_epochs
if local_epoch == 0 or (local_epoch + 1) % max(1, max_epochs // 10) == 0 or local_epoch == max_epochs - 1:
train_loss = self.metrics_history['train_loss'][-1] if self.metrics_history['train_loss'] else float('nan')
valid_loss = self.metrics_history['valid_loss'][-1] if self.metrics_history['valid_loss'] else float('nan')
logger.info(f"Epoch {local_epoch + 1}/{max_epochs}: train_loss={train_loss:.6f}, valid_loss={valid_loss:.6f}")

for cb in callbacks:
cb(self)
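Note: all loggers introduced in this diff live under the `neurodiffeq.*` namespace, so downstream code can opt into the new messages through the standard `logging` hierarchy. A minimal usage sketch follows; the handler and format choices are illustrative and not part of this change.

import logging

# Illustrative configuration (not part of the diff): route log records to stderr with a simple format.
logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s: %(message)s")

# Enable the new DEBUG messages for every neurodiffeq module at once; child loggers
# such as 'neurodiffeq.losses' inherit this level because they never set their own.
logging.getLogger('neurodiffeq').setLevel(logging.DEBUG)

# Or target a single module, e.g. only the solver progress messages:
logging.getLogger('neurodiffeq.solvers').setLevel(logging.INFO)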