Learning with tree tensor networks: complexity estimates and model selection

by Bertrand Michel, et al.

In this paper, we propose and analyze a model selection method for tree tensor networks in an empirical risk minimization framework. Tree tensor networks, or tree-based tensor formats, are prominent model classes for the approximation of high-dimensional functions in numerical analysis and data science. They correspond to sum-product neural networks with a sparse connectivity associated with a dimension partition tree T, widths given by a tuple r of tensor ranks, and multilinear activation functions (or units). The approximation power of these model classes has been proved to be near-optimal for classical smoothness classes. However, in an empirical risk minimization framework with a limited number of observations, the dimension tree T and ranks r should be selected carefully to balance estimation and approximation errors. We propose a complexity-based model selection strategy à la Barron, Birgé, Massart. Given a family of model classes, with different trees, ranks and tensor product feature spaces, a model is selected by minimizing a penalized empirical risk, with a penalty depending on the complexity of the model class. After deriving bounds on the metric entropy of tree tensor networks with bounded parameters, we deduce a form of the penalty from bounds on suprema of empirical processes. This choice of penalty yields a risk bound for the predictor associated with the selected model. For classical smoothness spaces, we show that the proposed strategy is minimax optimal in a least-squares setting. In practice, the amplitude of the penalty is calibrated with a slope heuristics method. Numerical experiments in a least-squares regression setting illustrate the performance of the strategy for the approximation of multivariate functions and univariate functions identified with tensors by tensorization (quantization).
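
The selection step described above can be sketched as follows. This is a minimal illustration, not the authors' implementation: the `Candidate` class, the `select_model` function, the penalty shape `amplitude * complexity / n_samples`, and all numerical values are assumptions made for the example. In the paper, the form of the penalty is derived from metric entropy bounds and its amplitude is calibrated with a slope heuristics method.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    """Hypothetical summary of one fitted model class (tree, ranks, feature space)."""
    name: str
    empirical_risk: float  # empirical risk of the minimizer within this class
    complexity: float      # complexity measure of the class, e.g. number of parameters


def select_model(candidates, n_samples, amplitude):
    """Return the candidate minimizing empirical risk plus a complexity penalty.

    The penalty amplitude * complexity / n_samples is an assumed,
    simplified shape used only for illustration.
    """
    def criterion(c):
        return c.empirical_risk + amplitude * c.complexity / n_samples

    return min(candidates, key=criterion)


# Made-up values: richer classes fit the sample better but cost more.
candidates = [
    Candidate("low ranks", empirical_risk=0.30, complexity=20),
    Candidate("medium ranks", empirical_risk=0.10, complexity=80),
    Candidate("high ranks", empirical_risk=0.08, complexity=400),
]

best = select_model(candidates, n_samples=100, amplitude=0.1)
print(best.name)
```

With these illustrative numbers the penalized criterion favors the intermediate class, whereas setting `amplitude=0.0` reduces the rule to plain empirical risk minimization and picks the most complex class.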

