2020from .generators import Generator2D
2121from .generators import GeneratorND
2222from .function_basis import RealSphericalHarmonics
23- from .conditions import BaseCondition
23+ from .conditions import BaseCondition , NoCondition
2424from .neurodiffeq import safe_diff as diff
2525from .losses import _losses
2626
@@ -113,7 +113,7 @@ class BaseSolver(ABC, PretrainedSolver):
113113 def __init__ (self , diff_eqs , conditions ,
114114 nets = None , train_generator = None , valid_generator = None , analytic_solutions = None ,
115115 optimizer = None , loss_fn = None , n_batches_train = 1 , n_batches_valid = 4 ,
116- metrics = None , n_input_units = None , n_output_units = None ,
116+ metrics = None , n_input_units = None , n_output_units = None , system_parameters = None ,
117117 # deprecated arguments are listed below
118118 shuffle = None , batch_size = None ):
119119 # deprecate argument `shuffle`
@@ -130,6 +130,9 @@ def __init__(self, diff_eqs, conditions,
130130 )
131131
132132 self .diff_eqs = diff_eqs
133+ self .system_parameters = {}
134+ if system_parameters is not None :
135+ self .system_parameters = system_parameters
133136 self .conditions = conditions
134137 self .n_funcs = len (conditions )
135138 if nets is None :
@@ -376,7 +379,7 @@ def closure(zero_grad=True):
376379 for name in self .metrics_fn :
377380 value = self .metrics_fn [name ](* funcs , * batch ).item ()
378381 metric_values [name ] += value
379- residuals = self .diff_eqs (* funcs , * batch )
382+ residuals = self .diff_eqs (* funcs , * batch , ** self . system_parameters )
380383 residuals = torch .cat (residuals , dim = 1 )
381384 try :
382385 loss = self .loss_fn (residuals , funcs , batch ) + self .additional_loss (residuals , funcs , batch )
@@ -1105,7 +1108,7 @@ class Solver1D(BaseSolver):
11051108
11061109 def __init__ (self , ode_system , conditions , t_min = None , t_max = None ,
11071110 nets = None , train_generator = None , valid_generator = None , analytic_solutions = None , optimizer = None ,
1108- loss_fn = None , n_batches_train = 1 , n_batches_valid = 4 , metrics = None , n_output_units = 1 ,
1111+ loss_fn = None , n_batches_train = 1 , n_batches_valid = 4 , metrics = None , n_output_units = 1 , system_parameters = None ,
11091112 # deprecated arguments are listed below
11101113 batch_size = None , shuffle = None ):
11111114
@@ -1136,6 +1139,7 @@ def __init__(self, ode_system, conditions, t_min=None, t_max=None,
11361139 metrics = metrics ,
11371140 n_input_units = 1 ,
11381141 n_output_units = n_output_units ,
1142+ system_parameters = system_parameters ,
11391143 shuffle = shuffle ,
11401144 batch_size = batch_size ,
11411145 )
@@ -1164,11 +1168,12 @@ def get_solution(self, copy=True, best=True):
11641168 :rtype: BaseSolution
11651169 """
11661170 nets = self .best_nets if best else self .nets
1171+ print (nets )
11671172 conditions = self .conditions
11681173 if copy :
11691174 nets = deepcopy (nets )
11701175 conditions = deepcopy (conditions )
1171-
1176+ print ( nets )
11721177 return Solution1D (nets , conditions )
11731178
11741179 def _get_internal_variables (self ):
@@ -1590,3 +1595,300 @@ def _get_internal_variables(self):
15901595 'xy_max' : self .xy_max ,
15911596 })
15921597 return available_variables
1598+
1599+ class _SingleSolver1D (GenericSolver ):
1600+
1601+ class Head (nn .Module ):
1602+ def __init__ (self , u_0 , base , n_input , n_output = 1 ):
1603+ super ().__init__ ()
1604+ self .u_0 = u_0
1605+ self .base = base
1606+ self .last_layer = nn .Linear (n_input , n_output )
1607+
1608+ def forward (self , x ):
1609+ x = self .base (x )
1610+ x = self .last_layer (x )
1611+ return x
1612+
1613+ def __init__ (self , bases , HeadClass , initial_conditions , n_last_layer_head , diff_eqs ,
1614+ system_parameters = [{}],
1615+ optimizer = torch .optim .Adam , optimizer_args = None , optimizer_kwargs = {"lr" :1e-3 },
1616+ train_generator = None , valid_generator = None , n_batches_train = 1 , n_batches_valid = 4 ,
1617+ loss_fn = None , metrics = None , is_system = False ):
1618+
1619+ if train_generator is None or valid_generator is None :
1620+ raise Exception (f"Train and Valid Generator cannot be None" )
1621+
1622+ self .num = len (initial_conditions )
1623+ self .bases = bases
1624+ if HeadClass is None :
1625+ if is_system :
1626+ self .head = [self .Head (initial_conditions [i ], self .bases [i ], n_last_layer_head ) for i in range (self .num )]
1627+ else :
1628+ self .head = [self .Head (torch .Tensor (initial_conditions ).view (1 , - 1 ), self .bases , n_last_layer_head , len (initial_conditions ))]
1629+ else :
1630+ if is_system :
1631+ self .head = [HeadClass (initial_conditions [i ], self .bases [i ], n_last_layer_head ) for i in range (self .num )]
1632+ else :
1633+ self .head = [HeadClass (torch .Tensor (initial_conditions ).view (1 , - 1 ), self .bases , n_last_layer_head , len (initial_conditions ))]
1634+
1635+ self .optimizer_args = optimizer_args or ()
1636+ self .optimizer_kwargs = optimizer_kwargs or {}
1637+
1638+ if isinstance (optimizer , torch .optim .Optimizer ):
1639+ self .optimizer = optimizer
1640+ elif issubclass (optimizer , torch .optim .Optimizer ):
1641+ params = chain .from_iterable (n .parameters () for n in self .head )
1642+ self .optimizer = optimizer (params , * self .optimizer_args , ** self .optimizer_kwargs )
1643+ else :
1644+ raise TypeError (f"Unknown optimizer instance/type { self .optimizer } " )
1645+
1646+ super ().__init__ (
1647+ diff_eqs = diff_eqs ,
1648+ conditions = [NoCondition ()]* self .num ,
1649+ train_generator = train_generator ,
1650+ valid_generator = valid_generator ,
1651+ nets = self .head ,
1652+ system_parameters = system_parameters ,
1653+ optimizer = self .optimizer ,
1654+ n_batches_train = n_batches_train ,
1655+ n_batches_valid = n_batches_valid ,
1656+ loss_fn = loss_fn ,
1657+ metrics = metrics
1658+ )
1659+
1660+ def additional_loss (self , residuals , funcs , coords ):
1661+
1662+ loss = 0
1663+ for i in range (len (self .nets )):
1664+ out = self .nets [i ](torch .zeros ((1 ,1 )))
1665+ loss += ((self .nets [i ].u_0 - out )** 2 ).mean ()
1666+ return loss
1667+
1668+
1669+ class UniversalSolver1D (ABC ):
1670+ r"""A solver class for solving a family of ODEs (for different initial conditions and parameters)
1671+
1672+ :param ode_system:
1673+ The ODE system to solve, which maps a torch.Tensor to a tuple of ODE residuals,
1674+ both the input and output must have shape (n_samples, 1).
1675+ :type ode_system: callable
1676+ """
1677+
1678+ class Base (nn .Module ):
1679+ def __init__ (self ):
1680+ super ().__init__ ()
1681+ self .linear_1 = nn .Linear (1 , 10 )
1682+ self .linear_2 = nn .Linear (10 , 10 )
1683+ self .linear_3 = nn .Linear (10 , 10 )
1684+
1685+ def forward (self , x ):
1686+ x = self .linear_1 (x )
1687+ x = torch .tanh (x )
1688+ x = self .linear_2 (x )
1689+ x = torch .tanh (x )
1690+ x = self .linear_3 (x )
1691+ x = torch .tanh (x )
1692+ return x
1693+
1694+ def __init__ (self , diff_eqs , is_system = True ):
1695+
1696+ self .diff_eqs = diff_eqs
1697+ self .is_system = is_system
1698+
1699+ self .t_min = None
1700+ self .t_max = None
1701+ self .train_generator = None
1702+ self .valid_generator = None
1703+
1704+ def build (self ,u_0s = None ,
1705+ system_parameters = [{}],
1706+ BaseClass = Base ,
1707+ HeadClass = None ,
1708+ n_last_layer_head = 10 ,
1709+ build_source = False ,
1710+ optimizer = torch .optim .Adam ,
1711+ optimizer_args = None , optimizer_kwargs = {"lr" :1e-3 },
1712+ t_min = None ,
1713+ t_max = None ,
1714+ train_generator = None ,
1715+ valid_generator = None ,
1716+ n_batches_train = 1 ,
1717+ n_batches_valid = 4 ,
1718+ loss_fn = None ,
1719+ metrics = None ):
1720+
1721+ r"""
1722+ :param system_parameters:
1723+ List of dictionaries of parameters for which the solver will be trained
1724+ :type system_parameters: list[dict]
1725+ :param BaseClass:
1726+ Neural network class for base networks
1727+ :type nets: torch.nn.Module
1728+ :param n_last_layer_head:
1729+ Number of neurons in the last layer for each network
1730+ :type n_last_layer_head: int
1731+ :param build_source:
1732+ Boolean value for training the base networks or freezing their weights
1733+ :type build_source: bool
1734+ :param optimizer:
1735+ Optimizer to be used for training.
1736+ Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
1737+ :type optimizer: ``torch.nn.optim.Optimizer``, optional
1738+ :param t_min:
1739+ Lower bound of input (start time).
1740+ Ignored if ``train_generator`` and ``valid_generator`` are both set.
1741+ :type t_min: float, optional
1742+ :param t_max:
1743+ Upper bound of input (start time).
1744+ Ignored if ``train_generator`` and ``valid_generator`` are both set.
1745+ :type t_max: float, optional
1746+ :param train_generator:
1747+ Generator for sampling training points,
1748+ which must provide a ``.get_examples()`` method and a ``.size`` field.
1749+ ``train_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
1750+ :type train_generator: `neurodiffeq.generators.BaseGenerator`, optional
1751+ :param valid_generator:
1752+ Generator for sampling validation points,
1753+ which must provide a ``.get_examples()`` method and a ``.size`` field.
1754+ ``valid_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
1755+ :type valid_generator: `neurodiffeq.generators.BaseGenerator`, optional
1756+ :param n_batches_train:
1757+ Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
1758+ Defaults to 1.
1759+ :type n_batches_train: int, optional
1760+ :param n_batches_valid:
1761+ Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``.
1762+ Defaults to 4.
1763+ :type n_batches_valid: int, optional
1764+ :param loss_fn:
1765+ The loss function used for training.
1766+
1767+ - If a str, must be present in the keys of `neurodiffeq.losses._losses`.
1768+ - If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
1769+ - If any other callable, it must map
1770+ A) a residual tensor (shape `(n_points, n_equations)`),
1771+ B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
1772+ C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`
1773+ to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
1774+ so that backpropagation can be performed.
1775+
1776+ :type loss_fn:
1777+ str or `torch.nn.moduesl.loss._Loss` or callable
1778+ :param metrics:
1779+ Additional metrics to be logged (besides loss). ``metrics`` should be a dict where
1780+
1781+ - Keys are metric names (e.g. 'analytic_mse');
1782+ - Values are functions (callables) that computes the metric value.
1783+ These functions must accept the same input as the differential equation ``ode_system``.
1784+
1785+ :type metrics: dict[str, callable], optional
1786+ """
1787+
1788+ self .u_0s = u_0s
1789+ self .system_parameters = system_parameters
1790+ self .n_last_layer_head = n_last_layer_head
1791+
1792+ if t_min is not None :
1793+ self .t_min = t_min
1794+ if t_max is not None :
1795+ self .t_max = t_max
1796+
1797+ if self .t_min is not None and self .t_max is not None :
1798+ self .train_generator = Generator1D (32 , t_min = self .t_min , t_max = self .t_max , method = 'equally-spaced-noisy' )
1799+ self .valid_generator = Generator1D (32 , t_min = self .t_min , t_max = self .t_max , method = 'equally-spaced-noisy' )
1800+
1801+ if train_generator is not None :
1802+ self .train_generator = train_generator
1803+ if valid_generator is not None :
1804+ self .valid_generator = valid_generator
1805+
1806+ if self .u_0s is None :
1807+ raise Exception ("ICs must be specified" )
1808+ if self .train_generator is None or self .valid_generator is None :
1809+ raise Exception (f"Train and valid generators cannot be None. Either provide `t_min` and `t_max` \
1810+ or provide the generators as arguments" )
1811+
1812+ self .optimizer = optimizer
1813+ self .optimizer_args = optimizer_args or ()
1814+ self .optimizer_kwargs = optimizer_kwargs or {}
1815+
1816+ if build_source :
1817+ if self .is_system :
1818+ self .bases = [BaseClass () for _ in range (len (u_0s [0 ]))]
1819+ else :
1820+ self .bases = BaseClass ()
1821+
1822+ self .solvers_base = [_SingleSolver1D (
1823+ bases = self .bases ,
1824+ HeadClass = HeadClass ,
1825+ initial_conditions = self .u_0s [i ],
1826+ n_last_layer_head = n_last_layer_head ,
1827+ diff_eqs = self .diff_eqs ,
1828+ train_generator = self .train_generator ,
1829+ valid_generator = self .valid_generator ,
1830+ system_parameters = self .system_parameters [p ],
1831+ optimizer = optimizer ,optimizer_args = optimizer_args , optimizer_kwargs = optimizer_kwargs ,
1832+ n_batches_train = n_batches_train ,
1833+ n_batches_valid = n_batches_valid ,
1834+ loss_fn = loss_fn ,
1835+ metrics = metrics ,
1836+ is_system = self .is_system
1837+ ) for i in range (len (u_0s )) for p in range (len (self .system_parameters ))]
1838+ else :
1839+ self .solvers_head = [_SingleSolver1D (
1840+ bases = self .bases ,
1841+ HeadClass = HeadClass ,
1842+ initial_conditions = self .u_0s [i ],
1843+ n_last_layer_head = self .n_last_layer_head ,
1844+ diff_eqs = self .diff_eqs ,
1845+ train_generator = self .train_generator ,
1846+ valid_generator = self .valid_generator ,
1847+ system_parameters = self .system_parameters [p ],
1848+ optimizer = optimizer ,optimizer_args = optimizer_args , optimizer_kwargs = optimizer_kwargs ,
1849+ n_batches_train = n_batches_train ,
1850+ n_batches_valid = n_batches_valid ,
1851+ loss_fn = loss_fn ,
1852+ metrics = metrics ,
1853+ is_system = self .is_system
1854+ ) for i in range (len (self .u_0s )) for p in range (len (self .system_parameters ))]
1855+
1856+
1857+ def fit (self , epochs = 10 , freeze_source = True ):
1858+ r"""
1859+ :param epochs:
1860+ Number of epochs for training
1861+ :type epochs: int
1862+ :param freeze_source:
1863+ Boolean value indicating whether to freeze the base networks or not
1864+ :type freeze_source: bool
1865+ """
1866+
1867+ if not freeze_source :
1868+ for i in range (len (self .solvers_base )):
1869+ self .solvers_base [i ].fit (max_epochs = epochs )
1870+ else :
1871+ if self .is_system :
1872+ for net in self .bases :
1873+ for param in net .parameters ():
1874+ param .requires_grad = False
1875+ else :
1876+ for param in self .bases .parameters ():
1877+ param .requires_grad = False
1878+ for i in range (len (self .solvers_head )):
1879+ self .solvers_head [i ].fit (max_epochs = epochs )
1880+
1881+
1882+ def get_solution (self , base = False ):
1883+ r"""
1884+ :param base:
1885+ Boolean value indicating whether to get solutions for those conditions for which the base
1886+ was trained or solutions for those conditions for which only the last layer was trained
1887+ :type base: bool
1888+ :rtype: list[BaseSolution]
1889+ """
1890+
1891+ if base :
1892+ return [self .solvers_base [i ].get_solution () for i in range (len (self .solvers_base ))]
1893+ else :
1894+ return [self .solvers_head [i ].get_solution () for i in range (len (self .solvers_head ))]
0 commit comments