1. Introduction
In recent years, parametric Machine Learning (ML) applications have shown remarkable performance in capturing relevant symmetries and hidden patterns characterizing a specific knowledge base. Specifically, Neural Networks (NNs), i.e., systems of interconnected artificial neurons, constitute a fundamental tool to capture complex patterns and to make accurate predictions for various applications, ranging from computer vision and natural language processing to robotics and reinforcement learning. Their growing popularity has prompted an increasing demand for a deep mathematical description of the underlying training procedures, specifically in high dimensions, to tackle the curse of dimensionality.
For this latter research challenge, we consider a novel class of NNs, termed Mean Field Neural Networks (MFNNs), which are defined as the limiting object of a population of NNs when its number of components tends to infinity. Our aim is to derive a unified perspective for this class of models based on the existing symmetries between Mean Field Control (MFC) theory and Optimal Transport (OT) methods. Our approach is based on an infinite-dimensional lifting that yields new insights into the relationships between data in the corresponding finite-dimensional scenario.
We start the analysis by looking at the continuous idealization of a specific class of NNs, namely Residual NNs, also named ResNets, whose training process in a supervised learning scenario is stated as a Mean Field Optimal Control Problem (MFOCP). We consider a deterministic dynamic that evolves in terms of an ordinary differential equation (ODE). Moreover, the training problem of a ResNet is shown to be equivalent to an MFOCP of Bolza type, see [1,2] for further details.
The next step in our analysis concerns introducing a noisy component into the dynamics of the ODE, moving to a Stochastic Differential Equation (SDE) that allows us to consider the inherent uncertainty connected to variations in real-world data, while simultaneously allowing for the integration of stochastic aspects into the learning process. Although this second model does not include any mean field terms, it allows the development of a class of algorithms known as Stochastic NNs (SNNs). In [3], the authors develop a sample-wise backpropagation method for SNNs based on a backward SDE that models the gradient process (with respect to the parameters) of the loss function, representing a feasible tool for quantifying the uncertainty of the learning process. Another possible approach for probabilistic learning is studied in [4], where the authors develop the so-called Stochastic Deep Network (SDN), namely an NN architecture that can take as input not only single vectors but also random vectors modeling the probability distribution of given inputs. Following their analysis, the SDN is an architecture based on the composition of maps between probability measures, performing inference tasks and solving ML problems over the space of probability measures.
In the last step, we merge the stochastic aspect with the mean field one by considering the so-called Mean Field Optimal Transport (MFOT) formulation, recently introduced in [5]. We describe the MFC tools relevant to formalizing the training process; hence, we formulate the training problem as an MFOT problem in an infinite-dimensional setting. Considering the collective interactions and distributions of the network's parameters may facilitate the analysis of the network behavior on a macroscopic level, hence improving the interpretability, scalability, and robustness of NN models, while adding knowledge by highlighting the hidden symmetries and relations between data.
We highlight that the symmetry between mean field models and ML algorithms is also studied in [6], where the authors establish a mathematical relationship between the MFG framework and normalizing flows, a popular method for generative modeling composed of a sequence of invertible mappings. Similarly, in [7], the authors analyze Generative Adversarial Networks (GANs) from the perspective of MFGs, providing a theoretical connection between GANs, OT, and MFGs, along with numerical experiments.
This paper is organized as follows: In Section 2, we introduce the mathematical formalism of the supervised learning paradigm and describe the continuous idealization of a Residual NN stated as an MFOCP; in Section 3, we introduce a noisy component into the network dynamic, thus focusing on Stochastic NNs formalized as stochastic optimal control problems; in Section 4, we review the MFG setting in a cooperative scenario defined in terms of MFC theory. Then, we consider recently developed Mean Field Optimal Transport methods that allow MFC problems to be rephrased as OT ones. We also illustrate related approximation schemes and a possible connection to an abstract class of NNs that respects the MFOT structure. We conclude by reviewing some methods to learn, i.e., approximate, mean field functions that depend on a probability distribution, obtained as the limiting object of empirical measures.
2. Residual Neural Networks as a Mean Field Optimal Control Problem
In this section, we present the workflow to treat a feed-forward NN, specifically a Residual NN, as a dynamical system, based on the work in [8]. The main reference for this part is the well-known paper [2], where the authors introduce a continuous idealization of Deep Learning (DL) to study the Supervised Learning (SL) procedure; this is stated as an optimal control problem by considering the associated population risk minimization problem.
2.1. The Supervised Learning Paradigm
Following [9,10], the SL problem aims at estimating the function $F \colon \mathcal{X} \to \mathcal{Y}$, commonly known as the Oracle. The space $\mathcal{X}$ can be identified with a subset of $\mathbb{R}^{d}$ related to input arrays (such as images, string texts, or time series), while $\mathcal{Y}$ is the corresponding target set. Here, for simplicity, we consider $\mathcal{X}$ and $\mathcal{Y}$ to be Euclidean spaces with different dimensions. Thus, training begins with a set of $N$ input–target pairs $\{(x^{i}_{0}, y^{i})\}_{i=1}^{N}$, where:
$x^{i}_{0} \in \mathcal{X}$ denotes the inputs of the NN;
$x^{i}_{T}$ denotes the outputs of the NN;
$y^{i} \in \mathcal{Y}$ denotes the corresponding targets.
We assume the same dimension of the Euclidean space for NN inputs and outputs, allowing us to explicitly write the dynamic in terms of a difference equation. Hence, for a ResNet (see [11] for more details) with $T$ layers, the feed-forward propagation is given by
$$x^{i}_{t+1} = x^{i}_{t} + f(x^{i}_{t}, \theta_{t}), \qquad t = 0, 1, \dots, T-1, \qquad (1)$$
with $f$ being a parameterized function and $\theta_{t}$ being the trainable parameters, e.g., the bias and weights of the $t$-th layer, that belong to a measurable set $\Theta$ with values in a subset of the Euclidean space $\mathbb{R}^{m}$.
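To make the discrete dynamic concrete, the following minimal sketch implements the residual update of Equation (1); the specific layer function $f(x, \theta) = \tanh(Wx + b)$ and all variable names are illustrative assumptions, not the particular architecture of [11].

```python
import numpy as np

def f(x, W, b):
    # Parameterized layer function; a tanh layer is an illustrative choice.
    return np.tanh(W @ x + b)

def resnet_forward(x0, params):
    # Residual update x_{t+1} = x_t + f(x_t, theta_t), Equation (1).
    x = x0
    for (W, b) in params:
        x = x + f(x, W, b)
    return x

d, T = 4, 10
rng = np.random.default_rng(0)
params = [(0.1 * rng.standard_normal((d, d)), np.zeros(d)) for _ in range(T)]
x_T = resnet_forward(rng.standard_normal(d), params)  # network output x_T
```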
Remark 1. Following [12], we report an example of a domain for the parameters of an NN with ReLU activation functions, i.e., with activations defined componentwise as $\sigma(x) = \max(0, x)$; in this case, the parameter domain can be taken as a measurable set of weight–bias pairs $\theta = (W, b)$ compatible with such activations.
2.2. Empirical Risk Minimization
We aim at minimizing, over the set of measurable parameters $\theta$, a terminal loss function $\Phi$ plus a regularization term, $L$, to derive the Supervised Learning problem as an Empirical Risk Minimization (ERM) problem, namely
$$\inf_{\theta} \frac{1}{N} \sum_{i=1}^{N} \left[ \Phi\big(x^{i}_{T}, y^{i}\big) + \sum_{t=0}^{T-1} L\big(x^{i}_{t}, \theta_{t}\big) \right], \qquad (2)$$
over the $N$ training data samples indexed by $i$. We write $\theta = \{\theta_{t}\}_{t=0}^{T-1}$ to identify the set of all parameters of the network.
If we consider no regularization of the parameters, i.e., $L \equiv 0$, and a quadratic loss function in terms of $\Phi$, then Equation (2) reads
$$\inf_{\theta} \frac{1}{N} \sum_{i=1}^{N} \left| x^{i}_{T} - y^{i} \right|^{2}, \qquad (3)$$
with $x^{i}_{t}$ being the discrete state process defined in Equation (1).
Optimizing (3) by computing its full gradient is computationally expensive, especially if the amount of data $N$ is very large. To handle the curse of dimensionality, it is common to initialize the parameters by sampling from a probability distribution, to then optimize their choice inductively according to a Stochastic Gradient Descent (SGD) scheme
$$\theta^{(k+1)} = \theta^{(k)} - \eta \, \nabla_{\theta} \widehat{J}\big(\theta^{(k)}\big), \qquad k = 0, \dots, K-1, \qquad (4)$$
with learning rate $\eta$ over $K$ optimization steps, where $\widehat{J}$ denotes the empirical risk in (3) evaluated on a random subsample of the data.
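As an illustration of the scheme (4), the following sketch runs minibatch SGD on a toy linear least-squares problem; the model, data, and hyperparameters are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
N, d = 1000, 5
X = rng.standard_normal((N, d))
w_true = rng.standard_normal(d)
y = X @ w_true                        # synthetic input-target pairs

w = np.zeros(d)                       # parameter initialization
eta, K, batch = 0.1, 200, 32          # learning rate, optimization steps, minibatch size
for k in range(K):
    idx = rng.integers(0, N, batch)   # sample a minibatch instead of all N points
    grad = 2.0 / batch * X[idx].T @ (X[idx] @ w - y[idx])  # gradient of the quadratic loss
    w = w - eta * grad                # SGD update, Equation (4)
```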
For the sake of completeness, before passing to the limit (from a discrete set of training data to the corresponding distribution), we point out in the following remark that it is also possible to associate a measure corresponding to the empirical distribution of the parameters when the number of neurons goes to infinity.
Remark 2. A different approach, as illustrated, e.g., by Sirignano and Spiliopoulos in [13], consists of associating to each layer the corresponding empirical measure and building a measure describing the whole network, hence working with the empirical measure of controls, rather than of states, as presented in Section 4. Following the perspective of the mean field term in the controls, the SGD Equation (4) can be formalized as a minimization method over the set of probability distributions. Moreover, the training of the NN is based on the correspondence between the empirical measure of neurons and the function that is approximated by the NN. Specifically, it has been proved that training via gradient descent of an over-parameterized one-hidden-layer NN with infinite width is equivalent to a gradient flow in Wasserstein space [2,9,14,15]. Conversely, in the small learning rate regime, the training is equivalent to an SDE, see, e.g., [16]. From here on, we deal with the empirical distribution and measures associated to the training data.
2.3. Population Risk Minimization as Mean Field Optimal Control Problem
In what follows, we move from the discrete setting to the corresponding continuous idealization by:
Going from the discrete layer index $t \in \{0, \dots, T\}$ to the continuous parameter $t \in [0, T]$;
Passing from a discrete set of input–target pairs to a distribution $\mu$ on $\mathcal{X} \times \mathcal{Y}$, modeling the joint input–label distribution;
Passing from empirical risk minimization to population risk minimization (i.e., minimization of an expectation with respect to $\mu$).
In particular, we pass to the limit in the number of data samples (the number of input–target pairs), also assuming a continuous dynamic in place of the layer discretization. The latter limit allows us to describe the dynamic of the state process $x$ with the following Ordinary Differential Equation (ODE)
$$\dot{x}_{t} = f(x_{t}, \theta_{t}), \qquad t \in [0, T], \qquad (5)$$
in place of the finite difference Equation (1). We identify the input–target pairs as sampled from a given distribution $\mu$, allowing us to write the SL problem as a Population Risk Minimization (PRM) problem.
In summary, we aim at approximating the Oracle function using a provided set of training data sampled from a (known) distribution by optimizing the weights to achieve maximal proximity between $x_{T}$ (output) and $y$ (target). Thus, we consider a probability space $(\Omega, \mathcal{F}, \mathbb{P})$ and we assume the inputs in $\mathcal{X}$ to be sampled from a distribution $\mu_{\mathcal{X}}$, with the corresponding targets in $\mathcal{Y}$ sampled from a distribution $\mu_{\mathcal{Y}}$, while the joint probability distribution $\mu$ on $\mathcal{X} \times \mathcal{Y}$, which models the distribution of the input–target pairs, belongs to the Wasserstein space $\mathcal{P}_{p}(\mathcal{X} \times \mathcal{Y})$ and has $\mu_{\mathcal{X}}$ and $\mu_{\mathcal{Y}}$ as its marginals. We recall that, given a metric space $(X, d)$, the $p$-Wasserstein space $\mathcal{P}_{p}(X)$ is defined as the set of all Borel probability measures on $X$ with finite $p$-moments.
The marginal distributions are obtained by projecting the joint probability distribution $\mu$ onto the subspaces of inputs and outputs, respectively. We identify the first marginal, i.e., the projection over $\mathcal{X}$, with the distribution of inputs, $\mu_{\mathcal{X}}(A) = \mu(A \times \mathcal{Y})$, while the distribution of targets reads $\mu_{\mathcal{Y}}(B) = \mu(\mathcal{X} \times B)$.
Moreover, we assume the controls $\theta$ depend on the whole distribution of input–target pairs, capturing the mean field aspect of the training data. We consider a measurable set of admissible controls, i.e., training weights, $\Theta$, and we state a Mean Field Optimal Control Problem (MFOCP) to solve the following PRM problem:
$$\inf_{\theta \in \Theta} \; \mathbb{E}_{(x_{0}, y) \sim \mu} \left[ \Phi(x_{T}, y) + \int_{0}^{T} L(x_{t}, \theta_{t}) \, dt \right] \quad \text{subject to} \quad \dot{x}_{t} = f(x_{t}, \theta_{t}). \qquad (6)$$
We briefly report the basic assumptions allowing us to have a solution for (6):
$f$, $L$, and $\Phi$ are bounded;
$f$, $L$, and $\Phi$ are Lipschitz-continuous with respect to $x$, with the Lipschitz constants of $f$ and $L$ being independent of the parameters $\theta$;
$\mu$ has finite support in $\mathcal{X} \times \mathcal{Y}$.
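Under these assumptions, a Monte Carlo approximation of the cost in (6) can be sketched as follows: input–target pairs are sampled from $\mu$ and the ODE (5) is integrated with an explicit Euler step, taking $L \equiv 0$ and a quadratic $\Phi$ as in (3). The drift $f$ and the sampling distribution are illustrative placeholders.

```python
import numpy as np

rng = np.random.default_rng(2)

def f(x, theta):
    # Drift of the state ODE (5); a tanh layer as in the discrete case (illustrative).
    W, b = theta
    return np.tanh(W @ x + b)

def population_risk(theta_path, mu_samples, dt):
    # Monte Carlo estimate of the cost in (6) with L = 0 and Phi(x_T, y) = |x_T - y|^2,
    # using an explicit Euler discretization of the ODE (5).
    total = 0.0
    for x0, y in mu_samples:
        x = x0
        for theta in theta_path:
            x = x + dt * f(x, theta)
        total += np.sum((x - y) ** 2)
    return total / len(mu_samples)

d, steps = 3, 20
dt = 1.0 / steps
theta_path = [(0.1 * rng.standard_normal((d, d)), np.zeros(d)) for _ in range(steps)]
mu_samples = [(rng.standard_normal(d), rng.standard_normal(d)) for _ in range(100)]
print(population_risk(theta_path, mu_samples, dt))
```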
Problem (6) can be approached through two different methods: the first one is based on the Hamilton–Jacobi–Bellman (HJB) equation in the Wasserstein space, while the second one is based on a Mean Field Pontryagin Principle. We refer to [17,18] for viscosity solutions to the HJB equation in the Wasserstein space of probability measures, and to [19] for solving constrained optimal control problems via the Pontryagin Maximum Principle.
For the sake of completeness, let us also cite [20], where the authors introduce a BSDE technique to solve the related Stochastic Maximum Principle, allowing us to consider the uncertainty associated with the NN. The authors employ a Stochastic Differential Equation (SDE) in place of the ODE appearing in (6) to continuously approximate a Stochastic Neural Network (SNN). We deepen this approach in the next section.
3. Stochastic Neural Network as a Stochastic Optimal Control Problem
In this section, we generalize the previous setting by considering a noisy dynamic, namely adding a stochastic integral to the deterministic setting described by the ODE in Problem (6). The reference model corresponds to a Stochastic NN whose discrete state process is described by the following equation
$$x_{t+1} = x_{t} + f(x_{t}, \theta_{t}) + \sigma_{t} \, \omega_{t}, \qquad t = 0, 1, \dots, T-1, \qquad (7)$$
with $(\omega_{t})_{t}$ being a sequence of i.i.d. standard Gaussian random variables. We refer to [4] for a theoretical and computational analysis of the SNN.
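A minimal sketch of the stochastic forward pass (7) follows; repeated evaluations of the same input give an empirical estimate of the output uncertainty. The layer function and noise scale are illustrative assumptions.

```python
import numpy as np

def snn_forward(x0, params, sigmas, rng):
    # Stochastic residual update (7): x_{t+1} = x_t + f(x_t, theta_t) + sigma_t * omega_t,
    # with omega_t i.i.d. standard Gaussian noise.
    x = x0
    for (W, b), sigma in zip(params, sigmas):
        omega = rng.standard_normal(x.shape)
        x = x + np.tanh(W @ x + b) + sigma * omega
    return x

rng = np.random.default_rng(3)
d, T = 4, 10
params = [(0.1 * rng.standard_normal((d, d)), np.zeros(d)) for _ in range(T)]
outputs = np.stack([snn_forward(np.ones(d), params, [0.05] * T, rng) for _ in range(500)])
print(outputs.mean(axis=0), outputs.std(axis=0))  # empirical output mean and uncertainty
```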
Equation (7) can be generalized to a continuous setting. To this end, we consider a complete filtered probability space $(\Omega, \mathcal{F}, (\mathcal{F}_{t})_{t \geq 0}, \mathbb{P})$, and we introduce the following SDE
$$dX_{t} = f(X_{t}, \theta_{t}) \, dt + \sigma_{t} \, dW_{t}, \qquad t \in [0, T], \qquad (8)$$
with standard Brownian motion $W$ and diffusion term $\sigma_{t}$. Analogously to ResNets, the index $t \in [0, T]$ represents a continuous parameter modeling the depth of the network, with $X_{T}$ being the output of the network.
Here, we report the theory developed in [3] to study Equation (8) in the framework of an SOC problem by introducing the control process $u_{t} = (\theta_{t}, \sigma_{t})$. Thus, we also consider the diffusion $\sigma$ as a trainable parameter of the model. We start by translating the SDE (8) into the following controlled process, written in differential form
$$dX^{u}_{t} = f(X^{u}_{t}, \theta_{t}) \, dt + \sigma_{t} \, dW_{t}, \qquad X^{u}_{0} = x_{0}, \qquad (9)$$
where $u = (u_{t})_{t \in [0, T]}$ and $x_{0}$ is the initial datum. As in classical control theory applied to ML, the aim is to select the control $u$ that minimizes the discrepancy between the SNN output and the data. Accordingly, we define the cost function for our stochastic optimal control problem as
$$J(u) = \mathbb{E} \left[ \Phi\big(X^{u}_{T}, Y\big) \right], \qquad (10)$$
with $Y$ being a random variable that corresponds to the target of a given input, i.e., $Y \sim \mu_{\mathcal{Y}}$. Then, the optimal control $u^{*}$ is the one that solves
$$u^{*} = \operatorname*{arg\,min}_{u \in \mathcal{U}} J(u), \qquad (11)$$
over the class of measurable controls $\mathcal{U}$.
At this point, we are able to write the optimization problem that represents the analogue of Equation (6) with a stochastic evolution (where the diffusion is also considered as a model parameter), but without reference to the mean field aspect of the learning procedure.
Following [3], we address the Stochastic Maximum Principle approach to solve the stochastic optimal control problem stated in (11). Firstly, the functional $J$ is differentiated with respect to the control, with a derivative in the Gateaux sense,
$$\lim_{\varepsilon \to 0} \frac{J(u + \varepsilon v) - J(u)}{\varepsilon}, \qquad (12)$$
computed along an admissible direction $v$. Then, via the martingale representation theorem, the following backward SDE is introduced
$$dp_{t} = -\nabla_{x} f\big(X^{u}_{t}, \theta_{t}\big)^{\top} p_{t} \, dt + q_{t} \, dW_{t}, \qquad p_{T} = \nabla_{x} \Phi\big(X^{u}_{T}, Y\big), \qquad (13)$$
to model the back-propagation of the forward state process defined in (9) associated with the optimal control $u^{*}$. Finally, the problem is solved via the gradient descent method with step size $\eta$,
$$u^{(k+1)} = u^{(k)} - \eta \, \nabla_{u} J\big(u^{(k)}\big). \qquad (14)$$
Also in [3], the authors provide a numerical scheme whose main benefit is to derive an estimate of the uncertainty connected to the output of this stochastic class of NNs.
We remark that for Equation (14) it is not possible to write the chain rule as previously performed for Equation (4), due to the presence of the stochastic integral term that, differently from classical ML theory, makes the back-propagation itself a stochastic process, see Equation (13). However, modern programming libraries (e.g., TensorFlow or PyTorch) may perform the computation (14) automatically, reducing the computational cost, hence allowing us to move towards a mean field formulation (in terms of multiple interacting agents) of the previous problems.
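As a hedged illustration of this point, the following sketch differentiates a Monte Carlo estimate of the cost (10) through an Euler–Maruyama discretization of the controlled SDE (9) using PyTorch automatic differentiation; it is a toy example, not the sample-wise BSDE scheme of [3], and the drift, targets, and hyperparameters are assumptions.

```python
import torch

torch.manual_seed(0)
d, steps, batch = 3, 20, 256
dt = 1.0 / steps

W = torch.zeros(steps, d, d, requires_grad=True)    # drift parameters theta_t
log_sigma = torch.zeros(steps, requires_grad=True)  # diffusion, also trainable
opt = torch.optim.SGD([W, log_sigma], lr=0.1)

x0 = torch.randn(batch, d)
y = -x0                                             # toy targets

for k in range(200):
    x = x0
    for t in range(steps):
        noise = torch.randn_like(x)
        # Euler-Maruyama step of the controlled SDE (9)
        x = x + dt * torch.tanh(x @ W[t].T) + log_sigma[t].exp() * dt**0.5 * noise
    loss = ((x - y) ** 2).sum(dim=1).mean()         # Monte Carlo estimate of J(u) in (10)
    opt.zero_grad()
    loss.backward()                                 # stochastic back-propagation through the SDE
    opt.step()
```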
4. Mean Field Neural Network as a Mean Field Optimal Transport Problem
In this section, we focus on the connection between SOC and OT, highlighting potential symmetries specifically for a class of infinite-dimensional stochastic games.
4.1. Optimal Transport
As seen in Section 3, SOC deals with finding the optimal control policy for a dynamical system in the presence of uncertainty. Conversely, OT theory focuses on finding the optimal map to transport one distribution onto another. More precisely, given two marginal distributions $\mu$ and $\nu$, the classical OT problem in the Kantorovich formulation reads
$$\inf_{\pi \in \Pi(\mu, \nu)} \int c(x, y) \, d\pi(x, y), \qquad (15)$$
where $c$ is a cost function and $\Pi(\mu, \nu)$ corresponds to the set of couplings between $\mu$ and $\nu$.
We focus on the setting where $\mu$ and $\nu$ are distributions defined on $\mathbb{R}^{d}$, i.e., $\mu \in \mathcal{P}(\mathbb{R}^{d})$ and $\nu \in \mathcal{P}(\mathbb{R}^{d})$. The Monge formulation reads
$$\inf_{T \colon T_{\#}\mu = \nu} \int c(x, T(x)) \, d\mu(x), \qquad (16)$$
where the infimum is computed over all measurable maps $T \colon \mathbb{R}^{d} \to \mathbb{R}^{d}$ with the pushforward constraint $T_{\#}\mu = \nu$.
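For two uniform empirical measures with the same number of atoms, the Kantorovich problem (15) reduces to a linear assignment problem over permutations; a minimal sketch with the quadratic cost follows (sample sizes and distributions are illustrative).

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(4)
n, d = 64, 2
X = rng.standard_normal((n, d))         # samples from mu
Y = rng.standard_normal((n, d)) + 2.0   # samples from nu

# Quadratic cost c(x, y) = |x - y|^2 between all pairs of atoms
C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)

# For two uniform empirical measures of equal size, the Kantorovich problem (15)
# is solved by an optimal assignment (a permutation coupling).
row, col = linear_sum_assignment(C)
W2_squared = C[row, col].mean()         # squared 2-Wasserstein distance estimate
print(W2_squared)
```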
The possibility to link an SOC problem, hence the related mathematical formulation of a specific learning procedure, to the corresponding OT formulation relies on lifting the SOC problem to a proper Wasserstein space. For example, considering the SOC problem introduced in (11), the stochastic process $X^{u}$ described by Equation (9) can be viewed as a vehicle of mass transportation under an initial measure $\mu_{0}$.
We mention that there are also specific scenarios where the dynamics of the stochastic control problem can be interpreted as a mass transportation problem, provided that certain assumptions on the functionals and the cost are satisfied. For example, in [21,22] and similarly in [23], the authors focus on extending an OT problem into the corresponding SOC formulation for a cost which depends on the drift and the diffusion coefficients of a continuous semimartingale, where the minimization is run among all continuous semimartingales with given initial and terminal distributions.
For example, in [22], the authors consider a special form for the cost function, namely $c(x, u)$ with $c$ convex in $u$, proving its equivalence to a proper SOC problem based on the so-called graph property. Indeed, we can define an image measure as $\pi_{T} := (\mathrm{id}, T)_{\#} \mu$, mapping $x$ into $(x, T(x))$. Thus, for any measurable map $T$, the following equality between the two formulations holds:
$$\int c(x, T(x)) \, d\mu(x) = \int c(x, y) \, d\pi_{T}(x, y). \qquad (17)$$
Thus, $\pi_{T}$ models a probability measure on $\mathbb{R}^{d} \times \mathbb{R}^{d}$ with marginals $\mu$ and $\nu$.
For the problem stated in (17), we know from [24] that an optimal measure $\pi^{*}$ always exists. Moreover, if the optimal measure $\pi^{*}$ is supported by the graph of a measurable map, we say that the graph property holds; that is, if for any $\pi$ optimal for (15), there exists a set $\Gamma \subset \mathbb{R}^{d} \times \mathbb{R}^{d}$ satisfying $\pi(\Gamma) = 1$ with $\Gamma \subset \{(x, T(x)) \colon x \in \mathbb{R}^{d}\}$ for some measurable mapping $T$ that resembles the NN parameters introduced in Section 2, and analogously $T(x)$ represents the corresponding output according to Equation (1).
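A simple one-dimensional illustration of the graph property: with quadratic cost, the optimal plan between two measures on the real line is supported on the graph of the monotone map $T = F_{\nu}^{-1} \circ F_{\mu}$, which the sketch below estimates from samples (the chosen distributions are illustrative).

```python
import numpy as np

rng = np.random.default_rng(5)
x = rng.normal(0.0, 1.0, 10_000)      # samples from mu
y = rng.exponential(1.0, 10_000)      # samples from nu

# In 1D with quadratic cost, the optimal Monge map of (16) is monotone:
# T = F_nu^{-1} o F_mu, a composition of CDF and quantile function.
x_sorted = np.sort(x)
y_sorted = np.sort(y)

def T(q):
    # Empirical Monge map: rank q within mu, read off the matching quantile of nu.
    ranks = np.searchsorted(x_sorted, q) / len(x_sorted)
    return np.quantile(y_sorted, np.clip(ranks, 0.0, 1.0))

print(T(np.array([-1.0, 0.0, 1.0])))  # images of a few points under the optimal map
```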
4.2. Mean Field Games
In the context of Mean Field Games (MFGs), i.e., stochastic games where a large number of agents interact and influence each other, the link between SOC and OT is particularly explicit, specifically according to the variational formulation of MFGs, which is directly linked to the dynamic formulation of OT by Benamou and Brenier; see, e.g., [25] for an in-depth analysis.
In Section 2, we focus on a deterministic evolution by means of Equation (5), with the mean field interactions captured by the loss function as an expectation given a known joint measure $\mu$ between the input and the target in the corresponding Mean Field Optimal Control Problem (6). On the other hand, in Section 3, we introduce the stochastic process in Equation (8) and state the learning problem as an SOC problem, as shown in Equation (10), without focusing on the interaction during the evolution but looking at just a single trajectory. Finally, the further natural step relies on extending the previous equation to a McKean–Vlasov setting, where the dynamic of a random variable $X$ depends on the other $N$ random variables by the mean of their distribution, in order to merge the two scenarios presented in Section 2 and Section 3 while extending the problem stated in (10) by allowing the presence of a mean field term.
Indeed, instead of considering a single evolution as in Equation (9), we introduce the following McKean–Vlasov SDE for $N$ particles/agents
$$dX^{i}_{t} = b\big(X^{i}_{t}, \mu^{N}_{t}\big) \, dt + \sigma \, dW^{i}_{t}, \qquad i = 1, \dots, N, \qquad (18)$$
with $X^{i}_{0}$ being the initial states. We assume a measurable drift $b$, a constant diffusion $\sigma$, and we define the empirical distribution $\mu^{N}_{t}$ as
$$\mu^{N}_{t} = \frac{1}{N} \sum_{i=1}^{N} \delta_{X^{i}_{t}}. \qquad (19)$$
The main idea would be to model multiple SNNs and generalize the dynamic defined in (9); including the dependence on a mean field term in the drift allows us to model the shared connections between the neurons of different SNNs.
In the limit $N \to \infty$, the population of SNNs corresponds to the evolution of a representative SNN, while the empirical measure tends to the probability measure $m$ belonging to the Wasserstein space $\mathcal{P}_{2}(\mathbb{R}^{d})$, i.e., the space of probability measures on $\mathbb{R}^{d}$ with a finite second-order moment, which captures the interactions among SNNs.
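A minimal particle sketch of the system (18) and (19) follows; the mean-reverting drift toward the empirical mean is an illustrative choice of interaction, not a model prescribed above.

```python
import numpy as np

rng = np.random.default_rng(6)
N, steps, dt, sigma = 500, 100, 0.01, 0.3

def drift(x, mean_field):
    # Measurable drift b(x, mu_N): here a pull toward the empirical mean,
    # an illustrative choice of mean field interaction.
    return -(x - mean_field)

X = rng.standard_normal(N)                 # initial states X_0^i
for t in range(steps):
    m = X.mean()                           # statistic of the empirical measure (19)
    X = X + drift(X, m) * dt + sigma * np.sqrt(dt) * rng.standard_normal(N)

# As N grows, the empirical distribution of X approaches the law m_t of the
# representative (limiting) particle.
print(X.mean(), X.std())
```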
More precisely, we introduce the following setting, which we need to define the solution of an MFG:
A finite time horizon $T > 0$;
$\mathbb{R}^{d}$ is the state space;
$\mathcal{P}_{2}(\mathbb{R}^{d})$ is the space of probability measures over $\mathbb{R}^{d}$;
$(x, m, \alpha) \in \mathbb{R}^{d} \times \mathcal{P}_{2}(\mathbb{R}^{d}) \times A$ describes the agent state, the mean field term, and the agent control, with $A$ the action space;
$f \colon \mathbb{R}^{d} \times \mathcal{P}_{2}(\mathbb{R}^{d}) \times A \to \mathbb{R}$ and $g \colon \mathbb{R}^{d} \times \mathcal{P}_{2}(\mathbb{R}^{d}) \to \mathbb{R}$ provide the running and the terminal cost, respectively;
$b \colon \mathbb{R}^{d} \times \mathcal{P}_{2}(\mathbb{R}^{d}) \times A \to \mathbb{R}^{d}$ represents the drift function;
$\sigma \geq 0$ is the volatility of the state.
Definition 1 (MFG equilibrium). We consider an MFG problem with a given initial distribution $m_{0} \in \mathcal{P}_{2}(\mathbb{R}^{d})$. A Nash equilibrium is a flow of probability measures $\hat{m} = (\hat{m}_{t})_{t \in [0, T]}$ in $\mathcal{P}_{2}(\mathbb{R}^{d})$ plus a feedback control $\hat{\alpha} \colon [0, T] \times \mathbb{R}^{d} \to A$ satisfying the following two conditions:
- 1. $\hat{\alpha}$ minimizes over $\alpha$:
$$J_{\hat{m}}(\alpha) = \mathbb{E} \left[ \int_{0}^{T} f\big(X^{\alpha, \hat{m}}_{t}, \hat{m}_{t}, \alpha(t, X^{\alpha, \hat{m}}_{t})\big) \, dt + g\big(X^{\alpha, \hat{m}}_{T}, \hat{m}_{T}\big) \right],$$
where $X^{\alpha, \hat{m}}$ solves the SDE
$$dX^{\alpha, \hat{m}}_{t} = b\big(X^{\alpha, \hat{m}}_{t}, \hat{m}_{t}, \alpha(t, X^{\alpha, \hat{m}}_{t})\big) \, dt + \sigma \, dW_{t},$$
with $W$ being a $d$-dimensional Brownian motion and $X^{\alpha, \hat{m}}_{0}$ having distribution $m_{0}$;
- 2. For all $t \in [0, T]$, $\hat{m}_{t}$ is the probability distribution of $X^{\hat{\alpha}, \hat{m}}_{t}$.
4.3. Mean Field Control
Differently from MFGs, where players are modeled as competitors, Mean Field Control (MFC) models a framework that considers a large population of agents aiming to cooperate and optimize individual objectives. In the MFC setting, each agent's cost depends on a mean field term representing the average behavior of all agents. Accordingly, the solution of an MFC problem is defined in the following way:
Definition 2 (MFC optimum). Given $m_{0} \in \mathcal{P}_{2}(\mathbb{R}^{d})$, a feedback control $\alpha^{*} \colon [0, T] \times \mathbb{R}^{d} \to A$ is an optimal control for the MFC problem if it minimizes over $\alpha$ the functional $J$ defined by
$$J(\alpha) = \mathbb{E} \left[ \int_{0}^{T} f\big(X^{\alpha}_{t}, m^{\alpha}_{t}, \alpha(t, X^{\alpha}_{t})\big) \, dt + g\big(X^{\alpha}_{T}, m^{\alpha}_{T}\big) \right], \qquad (20)$$
where $m^{\alpha}_{t}$ is the probability distribution of $X^{\alpha}_{t}$, under the constraint that the process $X^{\alpha}$ solves the following McKean–Vlasov-type SDE:
$$dX^{\alpha}_{t} = b\big(X^{\alpha}_{t}, m^{\alpha}_{t}, \alpha(t, X^{\alpha}_{t})\big) \, dt + \sigma \, dW_{t}, \qquad (21)$$
with $X^{\alpha}_{0}$ having distribution $m_{0}$. We refer to [26] for an extensive treatment of McKean–Vlasov control problems (20) and (21).
By considering the joint optimization problem of the entire population, MFC enables the analysis of large-scale systems with cooperative agents and provides insights into the emergence of collective behavior. One possibility relies on stating the dynamic in Equation (6) in terms of probability measures; for example, we can consider a continuity equation, such as the Fokker–Planck equation, to describe the evolution of the density function. Along these lines, we cite the measure-theoretical approach for NeurODEs developed in [1], where the authors introduce a forward continuity equation in the space of measures with a constrained dynamic in the form of an ODE. Conversely, within the cooperative setting, we can also rely on a novel approach, named Mean Field Optimal Transport, introduced in [5], which we explore in the following subsection.
4.4. Mean Field Optimal Transport
Mean Field Optimal Transport deals with a framework where all the agents cooperate (as in MFC) in order to minimize a total cost, without terminal cost but with an additional constraint, since the final distribution is also prescribed. We notice that the setting with fixed initial and terminal distributions resembles the one introduced in the Population Risk Minimization problem described in Section 2. We follow the numerical scheme introduced in Section 3.1 of [5] to approximate feedback controls; that is, we introduce the following model.
Definition 3 (Mean Field Optimal Transport). Let $T > 0$, let $\mathbb{R}^{d}$ describe the state space, and denote by $\mathcal{P}_{2}(\mathbb{R}^{d})$ the set of square-integrable probability measures on $\mathbb{R}^{d}$. Let $f \colon \mathbb{R}^{d} \times \mathcal{P}_{2}(\mathbb{R}^{d}) \times A \to \mathbb{R}$ be the running cost, $g$ be the terminal cost (identically zero in the MFOT setting, since the terminal distribution is prescribed), $b \colon \mathbb{R}^{d} \times \mathcal{P}_{2}(\mathbb{R}^{d}) \times A \to \mathbb{R}^{d}$ the drift function, and $\sigma \geq 0$ the non-negative diffusion. Given two distributions, $\rho_{0}$ and $\rho_{T}$, the aim of MFOT is to compute the optimal feedback control $v \colon [0, T] \times \mathbb{R}^{d} \to A$ minimizing
$$J(v) = \mathbb{E} \left[ \int_{0}^{T} f\big(X^{v}_{t}, m^{v}_{t}, v(t, X^{v}_{t})\big) \, dt \right], \qquad (22)$$
where $m^{v}_{t}$ is the distribution of the process $X^{v}_{t}$, whose dynamics is given by
$$dX^{v}_{t} = b\big(X^{v}_{t}, m^{v}_{t}, v(t, X^{v}_{t})\big) \, dt + \sigma \, dW_{t}, \qquad (23)$$
with $X^{v}_{0} \sim \rho_{0}$ and $X^{v}_{T} \sim \rho_{T}$ the prescribed initial and terminal distributions. This type of problem incorporates mean field interactions into the drift and the running cost. Furthermore, it encompasses classical OT as a special case by considering $b(x, m, a) = a$, $\sigma = 0$, and $f(x, m, a) = |a|^{2}$.
The integration of MFC and OT allows us both to tackle the weight optimization problem in NNs and to model the flow of information, or mass, between layers of neurons, while the optimal weights may be computed as the minimizers of the functional (22) with respect to the controls $v \in \mathcal{V}$ along all the trajectories $X^{v}$, where $\mathcal{V}$ is the set of admissible controls.
Thus, we look at the MFNN as a collection of identical, interchangeable, indistinguishable NNs, where the dynamic of the representative agent is a generalization of an SNN (7), allowing a dependence on the term $m$ modeling the mean field interactions. By considering the MFNN dynamic as a population of interconnected NNs, we can employ mean field control to analyze the collective behavior and interactions of networks, accounting for their impact on the overall network performance.
To summarize, we are looking at this novel class of NNs, i.e., MFNNs, as the asymptotic configuration of NNs in a cooperative setting.
We remark that the representative agent does not know the mean field interaction term, since it depends on the whole population, but an approximated version can be recursively learned. For example, in [5], the authors present different numerical schemes to solve MFOT:
1. Optimal control via direct approximation of the controls $v$;
2. The Deep Galerkin method for solving forward–backward systems of PDEs;
3. The Augmented Lagrangian method with Deep Learning, exploiting the variational formulation of MFOT and the primal/dual approach.
We briefly review the direct method (1) to approximate feedback-type controls via an optimal control formulation. The controls are assumed to be of feedback form and can be approximated by
$$v(t, x) \approx v_{\phi}(t, x), \qquad (24)$$
with $v_{\phi}$ a parametric (e.g., NN-based) approximation with parameters $\phi$, while the terminal constraint is replaced by a terminal penalty of the form
$$\gamma\big( W_{2}(m^{v}_{T}, \rho_{T}) \big), \qquad (25)$$
where $\gamma$ is an increasing function. The idea is to use the function in Equation (25) as a penalty for being far from the target distribution, acting as the terminal cost, in order to embed the problem into the classical MFG/MFC literature. Intuitively, Equation (25) corresponds to the infinite-dimensional analogue of the loss function of the leveraged NN algorithm, where $m^{v}_{T}$ is the final distribution that has to be as close as possible, in the sense of the Wasserstein metric, to the target distribution $\rho_{T}$.
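The following sketch illustrates the direct method in one dimension under simplifying assumptions: the feedback control is an NN $v_{\phi}$ as in (24), the dynamics (23) is discretized by Euler–Maruyama, the running cost is $f = |a|^{2}/2$, and the terminal penalty (25) uses the sorted-sample estimate of $W_{2}$, which is exact for one-dimensional empirical measures of equal size. It is a toy version, not the implementation of [5].

```python
import torch

torch.manual_seed(0)
N, steps = 512, 20
dt = 1.0 / steps

# Feedback control v_phi(t, x) parameterized by a small NN (the direct method, Eq. (24)).
control = torch.nn.Sequential(
    torch.nn.Linear(2, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1)
)
opt = torch.optim.Adam(control.parameters(), lr=1e-2)

x0 = torch.randn(N, 1)                       # samples from the prescribed rho_0
y_target = torch.randn(N, 1) * 0.5 + 2.0     # samples from the target rho_T
sigma, penalty_weight = 0.1, 10.0

for k in range(300):
    x = x0
    running = 0.0
    for t in range(steps):
        tt = torch.full((N, 1), t * dt)
        a = control(torch.cat([tt, x], dim=1))
        running = running + 0.5 * (a ** 2).mean() * dt   # running cost f = |a|^2 / 2
        x = x + a * dt + sigma * dt**0.5 * torch.randn_like(x)  # Euler-Maruyama step of (23)
    # 1D Wasserstein-2 distance between sorted samples as the terminal penalty (25),
    # with gamma taken as a linear (increasing) function of the squared distance.
    w2 = ((x.sort(dim=0).values - y_target.sort(dim=0).values) ** 2).mean()
    loss = running + penalty_weight * w2
    opt.zero_grad(); loss.backward(); opt.step()
```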
In view of obtaining a numerically tractable version of the SDE (23), one may consider a classical Euler–Maruyama discretization scheme, also requiring the set of controls $v$ to be restricted to those approximated by NNs $v_{\phi}$ with parameters $\phi$. Moreover, approximating the mean field term $m$ by its finite-dimensional counterpart, see Equation (19), allows us to develop a stable numerical algorithm; see Section 3.1 of [5] for further details, particularly with respect to the linked numerical implementation.
4.5. Other Approaches for Learning Mean Field Functions
For the sake of completeness, we also mention two different methods to deal with the approximation of mean field functions that can be used in parallel with MFOT:
The first, data-driven approach, presented in [27], has been considered to solve a stochastic optimal control problem where the unknown model parameters are estimated in real time using a direct filter method. This method involves transitioning from the Stochastic Maximum Principle to approximating the conditional probability density functions of the parameters given an observation, which is a set of random samples;
In [28], the authors report a map that, by operating over appropriate classes of neural networks, specifically the bin-density-based approximation and the cylindrical approximation, is able to reconstruct a mapping between the Wasserstein space of probability measures and an infinite-dimensional function space, in a setting similar to MFGs.
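As a rough illustration of the bin-density idea (only the general principle, not the exact construction in [28]), a probability measure can be represented by its normalized histogram on a fixed grid, so that a mean field function becomes an ordinary function of finitely many bin weights:

```python
import numpy as np

def bin_density_features(samples, bins=20, lo=-4.0, hi=4.0):
    # Represent an empirical measure by its histogram on a fixed grid, so that
    # a function of a probability distribution can be learned as an ordinary
    # function of the bin weights (a finite-dimensional proxy for the measure).
    hist, _ = np.histogram(samples, bins=bins, range=(lo, hi))
    return hist / len(samples)

rng = np.random.default_rng(7)
mu_features = bin_density_features(rng.normal(0.0, 1.0, 5000))
print(mu_features.shape, mu_features.sum())  # bin weights; total mass within the grid
```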
5. Conclusions and Further Directions
In the present article, we provided a general overview of methods at the intersection of parametric ML, MFC, and OT. Adopting a dynamical systems viewpoint, we considered the deterministic, ODE-based setting of the supervised learning problem, then incorporated noisy components, allowing for the definition of stochastic NNs, and finally introduced the MFOT approach. The latter, derived as the limit in the number of training data, recasts the classical learning process as a Mean Field Optimal Transport one. As a result, we gained a unified perspective on the parameter optimization process, characterizing ML models with a specified learning dynamic within the framework of OT and MFC, which may allow high-dimensional data sets to be handled efficiently.
We emphasize that the major limitation of MFOT (22) concerns the fact that many of its convergence results, such as those related to the corresponding forward–backward systems, still need to be verified. Nevertheless, it represents an indubitably fertile and stimulating research ground that deserves further development, since it permits the derivation of techniques that may significantly improve the robustness of algorithms, particularly when dealing with huge sets of training data that are potentially perturbed by random noise components, while also allowing hidden symmetries within data to be highlighted. The latter aspect is particularly interesting when dealing with intrinsically structured problems as, e.g., in the case of NLP tasks; see, e.g., [29,30].