欢迎光临小豌豆知识网!
当前位置:首页 > 电学技术 > 电通讯技术> 一种基于预测的联邦学习通信优化方法及系统独创技术28550字

一种基于预测的联邦学习通信优化方法及系统

2021-03-21 18:35:49

一种基于预测的联邦学习通信优化方法及系统

  技术领域

  本发明涉及联邦机器学习领域,更具体地,涉及一种基于预测的联邦学习通信优化方法及系统,用于解决联邦学习技术中终端用户/设备与云中心频繁传递更新参数所导致的高通信代价问题。

  背景技术

  机器学习作为人工智能领域的一个重要分支,被成功且广泛的应用于模式识别、数据挖掘和计算机视觉等各个领域。由于终端设备计算资源受限,目前对于机器学习模型的训练通常采用基于云的方式,在这种方式中,终端设备所收集的数据,如图片、视频,或者个人位置信息,必须上传至云中心集中完成模型的训练。然而,上传用户真实数据会泄露其隐私,出于隐私保护的考虑,终端用户不愿共享其隐私数据。从长远来看,这严重阻碍了机器学习技术的发展和应用。

  因此,为了保护终端用户的敏感数据,同时又不影响机器学习模型的训练,联邦学习应运而生。在联邦学习环境中,用户不用上传其敏感数据至云中心,而只需共享其本地模型参数更新,云中心通过与终端用户多次交互,迭代计算得到全局模型更新,既保护了用户的敏感数据,又得到了最终的可用模型。

  在联邦学习环境中,终端用户与云中心需要多轮交互才能获得目标精度的全局模型。那么,对于复杂的模型训练,如深度学习模型训练,每次模型更新可能包含数百万个参数,模型更新的高维性将耗费大量的通信成本,甚至成为一个模型训练瓶颈。此外,由于终端用户/设备的异构性,每个设备网络状态的不可靠性以及互联网连接速度的不对称性,如下传速度大于上载速度,导致终端用户上传更新参数的延迟,都会使模型训练瓶颈进一步恶化。

  目前,为了解决联邦学习的高通信代价问题,国内外研究学者纷纷对其进行了大量研究,并提出了许多有效的通信优化方法。根据其优化的目标,这些解决方法大致可划分为两类:一类是以减少终端用户与云中心通信轮数为目标;另一类是以减少终端用户与云中心通信量为目标。在以减少通信量为目标的方法中,通常对本地模型更新进行压缩、轻量化、知识蒸馏以及稀疏化等操作,使得上传的模型更加紧致,从而达到通信量减少的目的。然而,由于模型压缩通常会造成模型信息量的丢失,甚至不能保证模型收敛,因此,越来越多的研究学者开始研究以减少通信次数为目标的通信优化方法。

  主流的通信次数减少方法可划分为两类,一类是基于模型收敛的方法,另一类是基于重要性的方法。在基于模型收敛的方法中,通常采用增加本地模型训练迭代轮数、减少每轮本地训练batch块的大小或者修改联邦学习算法等方式加快模型学习速度,使得每次通信迭代上传的本地模型更新更有利于全局模型的收敛;另一类则是通过研究本地模型更新与全局模型更新的相关性或者计算本地模型更新对全局模型的重要性,选择重要的或者与全局模型收敛趋势相同的本地模型更新上传至云中心。虽然这两类方法能够从一定程度上提高联邦学习的通信效率,但它们仍然存在以下不足:基于模型收敛的方法,通常是以消耗更多的本地计算资源为代价,然而,在联邦学习环境中,终端通常是资源受限的异构设备,它们没有足够的计算资源来处理复杂模型的训练,因此,将该算法运用于实际场景的联邦通信优化具有一定的挑战性;基于重要性的方法中,本地更新的重要性或者相关性都是通过一个可调的阈值判断,且这个阈值的设置通常是基于最大化通信次数减少为目标,因此,这类算法由于大量本地更新没有被聚集,而导致严重的模型准确率降低。

  综合所述,为了弥补基于云训练所造成的用户敏感数据泄露以及模型可用性问题,联邦学习应运而生。然而由于模型训练参数的高维性以及联邦学习环境中网络的不可靠性,使得通信代价问题成为联邦学习中基础且重要的问题。虽然现有研究方法分别从减小通信量和通信轮数两个方面提出了许多有效的通信优化方法,但他们通常伴随着其他方面的不足,如需要消耗更多的本地计算资源或者严重降低训练模型的准确率,因此,为了更好的解决联邦学习的高通信代价问题,需要设计一种既不需要消耗更多本地计算资源,又能极大减少所需的通信轮数同时保证训练模型准确率的方法。

  基于上述背景,本发明提出了一种简单易实现的基于预测的联邦学习通信优化方法及系统,为联邦学习高通信代价问题的解决奠定基础。

  发明内容

  为了更有效的解决联邦学习的高通信代价问题,本发明提出了一种基于预测的联邦学习通信优化方法及系统。首先,初始化全局模型以及本发明中所需的全局变量,每个终端用户根据其本地数据进行本地模型训练,得到本地模型更新。随后,云中心分别根据每个终端用户的历史模型更新趋势,预测其本地模型更新。然后,通过计算每个终端用户采用预测更新时全局模型损失函数的变化,设置其预测误差阈值,其中包括初始阈值和动态阈值设置两个步骤。最后,根据设置的预测误差阈值设计全局模型更新策略,云中心采用准确的预测更新代替本地模型更新计算全局模型更新。

  本发明提出的基于预测的联邦学习通信优化方法及系统,包括以下步骤,

  步骤S1,云中心初始化。包括搭建训练模型、初始化全局模型以及所需的全局变量,包括以下子步骤:

  步骤S1-1,用于搭建训练模型,其主要包括输入层、隐藏层以及输出层的神经元个数设计。

  步骤S1-2,用于初始化全局模型,其主要包括全局模型参数W0,全局模型更新G0。

  步骤S1-3,用于初始化全局变量,其主要包括用户集合U={u1,u2,...,uj,...un},通信轮数R。

  步骤S2,本地模型训练。集合U中每个终端用户根据其本地数据并行地进行本地模型训练,得到本地模型更新,以用户uj为例,包括以下子步骤:

  步骤S2-1,用于从云中心获取聚集的全局模型参数Wt。

  步骤S2-2,用于根据本地数据进行本地模型训练,得到用户uj在第t轮迭代的本地模型更新Lj,t。

  重复步骤S2,得到集合U中所有终端用户的本地更新集合L={L1,t,L2,t,...,Lj,t,Ln,t}。

  步骤S3,本地更新预测。预测集合U中每个终端用户的本地模型更新,以集合U中用户uj为例,包括以下子步骤:

  步骤S3-1,用于从云中心获取用户uj的历史参数更新集合Hj,计算用户uj在第t-1轮迭代的一步预测更新,其中,Hj=<Hj,1,Hj,2,…,Hj,i,Hj,t-1>,k表示更新参数的维度。下面以用户uj第d维更新参数为例,假设由用户uj第d维更新参数组成的历史参数更新集合为则用户uj第d维更新参数的一步预测更新值可表示为:

  

  其中,f表示状态转移矩阵,b表示控制矩阵,表示用户uj的第d维更新参数在第t-2轮迭代的预测更新值。

  步骤S3-2,用于计算第t-1轮迭代的状态协方差矩阵mt-1,其计算公式如(2)所示:

  mt-1=f*mt-2*fT+q(2)

  其中,q为预测噪声,fT为状态转移矩阵f的转置。

  步骤S3-3,用于计算第t-1轮迭代的卡尔曼增益zt-1,其计算公式如(3)所示:

  

  其中,r表示本地更新协方差,c表示转换矩阵。

  步骤S3-4,用于计算第d维更新参数在第t轮迭代的预测更新值计算公式如(4)所示:

  

  其中,表示第t-1轮迭代的预测偏差。

  步骤S3-5,用于更新第t轮迭代的状态协方差矩阵mt,更新公式如(5)所示:

  mt=(1-zt-1*c)*mt-1(5)

  重复步骤S3,并行计算得到集合U中所有终端用户当前迭代的预测更新集合Pt,其中,Pt=<P1,t,P2,t,...,Pj,t,…,Pn,t>,Pj,t表示用户uj的预测更新,k表示更新参数的维度。

  步骤S4,计算全局损失函数的变化。并行计算集合U中每个终端用户在第t-1轮迭代采用其预测更新,全局模型的损失函数变化,以用户uj为例,包括以下子步骤:

  步骤S4-1,用于从云中心获取用户uj第t-1轮迭代的预测更新Pj,t-1以及U中所有终端用户的本地更新集合Lt-1。

  步骤S4-2,用于检查标记变量Checkj,若Checkj=true,则进入步骤S4-3;反之,若Checkj=false,则进入步骤S5-2。

  步骤S4-3,用于计算用户uj第t-1轮迭代采用预测更新Pj,t-1时全局模型更新Gj,t-1、全局模型Wj,t-1,集合U中所有用户采用本地更新时全局更新Gall,t-1、全局模型Wall,t-1以及全局模型损失函数变化e,具体计算公式如(6)、(7)、(8)、(9)、(10)所示:

  

  其中,L-j,t-1表示非用户uj第t-1轮迭代的本地模型更新。

  

  用户uj第t-1轮迭代的全局模型Wj,t-1和Wall,t-1计算公式分别如(8)、(9)所示:

  Wj,t-1=Wt-1-Gj,t-1(8)

  Wall,t-1=Wt-1-Gall,t-1(9)

  进一步,全局模型损失函数变化e的计算公式如(10)所示:

  

  其中,f(·)表示损失函数,|·|表示绝对值。

  步骤S4-4,用于比较e与预先设定阈值δ的大小,若e≤δ,则进入步骤S5-1,设置Checkj=false,变量Tj=Tj+1;反之,若e>δ,则进入步骤S4-5。

  步骤S4-5,用于上传本地更新Lj,t至云中心,设置通信轮数R=R+i,同时为了得到更加准确的预测更新,添加Lj,t至预测资源池,模型训练进入下一轮迭代。

  步骤S5,设置预测误差阈值。设置用户的预测误差阈值,以用户uj为例,包括以下子步骤,

  步骤S5-1,用于设置预测误差初始阈值vj,0,其具体计算公式如下:

  vj,0=||Pj,t-1-Lj,t-1||(11)

  其中,||·||表示两个向量的内积。

  步骤S5-2,用于设置用户uj第t轮迭代的预测误差阈值vj,t,其具体计算公式如下:

  

  步骤S6,全局模型更新策略。制定全局模型更新策略,以用户uj为例,包括以下子步骤:

  步骤S6-1,用于计算第t轮迭代的预测更新误差Δj,t,其具体公式如下:

  Δj,t=||Pj,t-Lj,t||(13)

  步骤S6-2,比较Δj,t与vj,t的大小,若Δj,t≤vj,t,表示预测更新准确,则进入步骤S6-3;反之,若Δj,t>vj,t,表示预测参数不准确,则进入步骤S6-4。

  步骤S6-3,云中心采用用户uj的预测更新Pj,t进行全局模型聚集,模型训练进入下一轮迭代。

  步骤S6-4,用于上传本地更新Lj,t至云中心,设置通信轮数R=R+1,同时为了得到更加准确的预测更新,添加Lj,t至预测资源池,模型训练进入下一轮迭代。

  步骤S7,云中心全局模型更新。云中心聚集上传的本地更新与云中心准确的预测更新,计算得到聚集的全局模型更新和全局模型,模型训练进入下一轮迭代。

  重复以上步骤S1~S7,直至全局模型收敛,模型训练结束。

  同时,本发明还相应提供了一种基于预测的联邦学习通信优化系统,如图4所示,包含:

  初始化模块,用户初始化全局模型和全局变量,包含以下子模块,

  训练模型构建子模块,用于搭建训练模型,主要包括输入层、隐藏层以及输出层的神经元个数设计。

  全局模型初始化子模块,用于初始化全局模型参数W0和全局模型更新G0。

  全局变量初始化子模块,用于初始化全局模型训练过程所需的变量,如通信轮数R。

  本地模型训练模块,用于训练本地模型,得到本地模型更新,以用户uj为例,包含以下子模块,

  全局模型输入子模块,用于从云中心获取聚集的全局模型参数Wt。

  模型训练子模块,用于根据本地数据并行地进行本地模型训练,得到用户uj在第t轮迭代的本地模型更新Lj,t。

  并行训练模块,用于并行执行以上子步骤,得到集合U中所有终端用户的本地更新集合L={L1,t,L2,t,…,Lj,t,Ln,t}。

  本地更新预测模块,用于预测集合U中每个终端用户的本地模型更新,以集合U中用户uj为例,包含以下子模块,

  历史更新输入子模块,用于从云中心获取终端用户的历史模型更新集合Hj。

  中间变量子模块,用于存储中间步骤所计算得到的中间变量值,如一步预测更新值状态协方差矩阵mt-1,卡尔曼增益zt-1,预测更新值状态协方差矩阵mt。

  预测更新输出子模块,用于下发云中心的预测更新Pj,t至终端用户uj。

  并行预测子模块,用于并行执行以上子步骤,预测得到集合U中所有终端用户当前迭代的预测更新集合Pt,其中,Pt=<P1,t,P2,t,…,Pj,t,…,Pn,t>。

  预测误差阈值设置模块,用于设置终端用户的预测误差阈值,包括预测误差初始阈值和动态阈值设置两个步骤,以集合U中用户uj为例,包含以下子模块,

  变量判断子模块,用于判断终端用户是否已经设置预测误差初始阈值,若标记变量Checkj=true,表示用户uj未设置预测误差初始阈值,则进入全局损失函数变化计算子模块;反之,进入预测误差动态阈值设置子模块。

  全局损失函数变化计算子模块,用于计算终端用户前一轮迭代采用预测更新,全局模型损失函数的变化e。

  损失函数判断子模块,用于比较全局模型损失函数变化e与预先设定阈值δ的大小,若e≤δ,则进入预测误差初始阈值设置子模块。

  预测误差初始阈值设置子模块,用于设置终端用户的预测误差初始阈值vj,0。

  预测误差动态阈值设置子模块,用于设置终端用户当前迭代的预测误差阈值vj,t。

  并行设置子模块,用于并行执行以上子步骤,为每个用户设置当前迭代的预测误差阈值。

  全局模型更新策略模块,用于设计全局模型更新方式,以用户uj为例,包含以下子模块,

  预测误差阈值输入子模块,用于获取终端用户当前迭代的预测误差阈值vj,t。

  变量判断子模块,用于判断终端用户是否设置预测误差阈值,若已设置,则进入预测误差计算子模块,反之,进入本地更新上传子模块。

  预测误差计算子模块,用于计算终端用户当前迭代的预测更新误差Δj,t。

  预测准确性判断子模块,用于比较终端用户的预测误差Δj,t与预测误差阈值vj,t的大小,若Δj,t>vj,t,则进入本地更新上传子模块。

  本地更新上传子模块,用于上传终端用户的本地模型更新Lj,t至云中心及预测资源池。

  通信轮数计算及输出子模块,用于计算和输出模型训练的通信轮数R。

  云中心全局模型更新模块,用于计算全局模型更新和判定训练模型是否收敛,包含以下子模块,

  全局模型更新子模块,用于聚集上传的本地模型更新和云中心准确的预测更新,计算得到全局模型更新Gt和全局模型Wt,模型训练进入下一轮更新迭代。

  终止判定子模块,用于判定训练模型是否收敛,若收敛,则模型训练结束;反之,进入下一轮训练迭代。

  本发明根据本地模型的历史更新趋势,预测本地模型更新,然后通过计算全局模型损失函数变化设置预测误差阈值,并根据设置的预测误差阈值设计全局模型更新策略,云中心采用准确的预测更新代替本地模型更新计算全局模型更新,解决了联邦学习技术中,终端用户与云中心频繁传递更新参数所导致的高通信代价问题,与现有技术相比,具有以下有益效果:

  (1)本发明所提方法及系统不仅可以极大减少终端用户与云中心的通信轮次,而且可以极小降低训练模型的准确率;

  (2)由于本发明将本地模型更新的预测放置在资源丰富的云中心,而终端用户只需进行简单的预测准确性判断,因此,可以消耗极少的本地计算资源;

  (3)本发明的本地模型更新预测部分采用卡尔曼滤波预测,由于卡尔曼滤波能够对数据进行实时处理,具有较好的预测效果,且便于计算机编程实现,因此采用卡尔曼滤波预测不仅可以获得准确的本地模型更新预测,而且可以进一步降低计算复杂度,便于算法高效实施。

  附图说明

  图1是本发明实施例提供的总体方法流程图。

  图2是本发明实施例提供的具体步骤流程图。

  图3是本发明实施例提供的总体原理示意图。

  图4是本发明实施例基于预测的联邦学习通信优化系统模块设计示意图。

  具体实施方式

  以下将结合附图及实施例,对本发明的构思、具体结构及产生的技术效果作进一步说明,以充分地了解本发明的目的、特征和效果。

  下面以100名终端用户联合训练线性回归模型为例,来阐述本发明的具体实施步骤。记线性回归模型的表达式为其中,|k|表示训练样本的数目,W表示训练模型参数向量,X表示训练样本的特征向量。

  本发明技术方案所提供方法可采用计算机软件技术实现自动运行流程,图1是本发明实施例的总体方法流程图,参见图1,结合图2本发明实施例的具体步骤流程图,本发明基于预测的联邦学习通信优化方法及系统的实施例具体步骤包括:

  步骤S1,云中心初始化。包括搭建训练模型、初始化全局模型以及所需的全局变量,包括以下子步骤:

  步骤S1-1,用于搭建训练模型,其主要包括输入层、隐藏层以及输出层的神经元个数设计。

  实施例中,模拟输入层和输出层分别为784和1个神经元节点的线性回归模型。

  步骤S1-2,用于初始化全局模型,其主要包括全局模型参数W0,全局模型更新G0。

  实施例中,初始化全局模型参数W0,全局模型更新G0。

  步骤S1-3,用于初始化全局变量,其主要包括用户集合U={u1,u2,...,uj,...un},通信轮数R。

  实施例中,初始化终端用户集合U={u1,u2,...,u100},通信轮数R=0。

  步骤S2,本地模型训练。集合U中每个终端用户根据其本地数据并行地进行本地模型训练,得到本地模型更新,以用户uj为例,包括以下子步骤:

  步骤S2-1,用于从云中心获取聚集的全局模型参数Wt。

  实施例中,假定当前迭代轮次t=4,则终端用户uj从云中心获取聚集的全局模型参数W4。

  步骤S2-2,用于根据本地数据进行本地模型训练,得到用户uj在第t轮迭代的本地模型更新Lj,t。

  实施例中,终端用户uj根据其本地数据进行本地模型训练,得到本地模型更新L100,4。

  重复步骤S2,得到集合U中所有终端用户的本地更新集合L={L1,4,L2,4,...,Lj,4,L100,4}。

  步骤S3,本地更新预测。预测集合U中每个终端用户的本地模型更新,以集合U中用户uj为例,包括以下子步骤:

  步骤S3-1,用于从云中心获取用户uj的历史参数更新集合Hj,计算用户uj在第t-1轮迭代的一步预测更新,其中,Hj=<Hj,1,Hj,2,…,Hj,i,Hj,t-1>,k表示更新参数的维度。下面以用户uj第d维更新参数为例,假设由用户uj第d维更新参数组成的历史参数更新集合为则用户uj第d维更新参数的一步预测更新值可表示为:

  

  其中,f表示状态转移矩阵,b表示控制矩阵,表示用户uj的第d维更新参数在第t-2轮迭代的预测更新值。

  实施例中,从云中心获取终端用户u100的历史参数更新集合H100=<H100,1,H100,2,H100,3>,以用户u100第784维更新参数为例,由终端用户u100第784维参数组成的历史参数更新集合为设置公式(1)中f=1,b=0,根据公式(1)计算得到终端用户u100的第784维参数的一步预测更新值

  步骤S3-2,用于计算第t-1轮迭代的状态协方差矩阵mt-1,其计算公式如(2)所示:

  mt-1=f*mt-2*fT+q(2)

  其中,q为预测噪声,fT为状态转移矩阵f的转置。

  实施例中,设置q=0.001,根据公式(2)计算第3lun迭代的状态协方差矩阵m3=m2+q=m3=m2+0.001。

  步骤S3-3,用于计算第t-1轮迭代的卡尔曼增益zt-1,其计算公式如(3)所示:

  

  其中,r表示本地更新协方差,c表示转换矩阵。

  实施例中,设置c=1,r=0.042,根据公式(3)计算第3轮迭代的卡尔曼增益

  步骤S3-4,用于计算第d维更新参数在第t轮迭代的预测更新值计算公式如(4)所示:

  

  其中,表示第t-1轮迭代的预测偏差。

  实施例中,根据公式(4)计算终端用户u100第784维参数在当前迭代轮次t=4的预测更新

  步骤S3-5,用于更新第t轮迭代的状态协方差矩阵mt,更新公式如(5)所示:

  mt=(1-zt-1*c)*mt-1(5)

  实施例中,根据公式(5)更新当前迭代轮次t=4的状态协方差矩阵m4=(1-z3)*m3。

  重复步骤S3,并行计算得到集合U中所有终端用户当前迭代t=4的预测更新向量集合P4,其中,P4=<P1,4,P2,4,…,Pj,4,…,P100,4>,P100,4表示用户u100的预测更新,784表示更新向量的维度大小。

  步骤S4,计算全局损失函数的变化。并行计算集合U中每个终端用户在第t-1轮迭代采用其预测更新,全局模型的损失函数变化,以用户uj为例,包括以下子步骤:

  步骤S4-1,用于从云中心获取用户uj第t-1轮迭代的预测更新Pj,t-1以及U中所有终端用户的本地更新集合Lt-1。

  实施例中,从云中心获取终端用户u100第3轮迭代的预测更新P100,3以及U中所有终端用户的本地更新集合L3。

  步骤S4-2,用于检查标记变量Checkj,若Checkj=true,则进入步骤S4-3;反之,若Checkj=false,则进入步骤S5-2。

  实施例中,检查终端用户u100的标记变量Check100,若Check100=true,进入步骤S4-3;反之,若Check100=false,则进入步骤S5-2。

  步骤S4-3,用于计算用户uj第t-1轮迭代采用预测更新Pj,t-1时全局模型更新Gj,t-1、全局模型Wj,t-1,集合U中所有用户采用本地更新时全局更新Gall,t-1、全局模型Wall,t-1以及全局模型损失函数变化e,具体计算公式如(6)、(7)、(8)、(9)、(10)所示:

  

  其中,L-j,t-1表示非用户uj第t-1轮迭代的本地模型更新。

  

  用户uj第t-1轮迭代的全局模型Wj,t-1和Wall,t-1计算公式分别如(8)、(9)所示:

  Wj,t-1=Wt-1-Gj,t-1(8)

  Wall,t-1=Wt-1-Gall,t-1(9)

  进一步,全局模型损失函数变化e的计算公式如(10)所示:

  

  其中,f(·)表示损失函数,|·|表示绝对值。

  实例中,根据公式(6)和公式(7)计算第3轮迭代,终端用户u100采用预测更新P100,3时全局模型更新G100,3,以及集合U中所有终端采用本地更新时全局更新Gall,3,并根据公式(8)和(9)分别计算得到终端用户u100第3轮迭代的全局模型W100,3和Wall,3,根据公式(10)计算全局模型损失函数变化

  步骤S4-4,用于比较e与预先设定阈值δ的大小,若e≤δ,则进入步骤S5-1,设置Checkj=false,变量Tj=Tj+1;反之,若e>δ,则进入步骤S4-5。

  实施例中,设置δ=0.01,比较e与阈值δ的大小,若e≤0.01,则进入步骤S5-1,置Check100=false,T100=T100+1;反之,若e>0.01,则进入步骤S4-5。

  步骤S4-5,用于上传本地更新Lj,t至云中心,设置通信轮数R=R+1,同时为了得到更加准确的预测更新,添加Lj,t至预测资源池,模型训练进入下一轮迭代。

  实施例中,上传终端用户u100的本地更新L100,4至云中心及预测资源池,设置通信轮数R=R+1,模型训练进入下一轮迭代。

  步骤S5,设置预测误差阈值。设置用户的预测误差阈值,以用户uj为例,包括以下子步骤,

  步骤S5-1,用于设置预测误差初始阈值vj,0,其具体计算公式如下:

  vj,0=||Pj,t-1-Lj,t-1||(11)

  其中,||·||表示两个向量的内积。

  实施例中,设置终端用户u100的预测误差初始阈值v100,0=||P100,3-L100,3||。

  步骤S5-2,用于设置用户uj第t轮迭代的预测误差阈值vj,t,其具体计算公式如下:

  

  实施例中,设置终端用户u100在当前迭代轮次t=4的预测误差阈值

  步骤S6,全局模型更新策略。制定全局模型更新策略,以用户uj为例,包括以下子步骤:

  步骤S6-1,用于计算第t轮迭代的预测更新误差Δj,t,其具体公式如下:

  Δj,t=||Pj,t-Lj,t||(13)

  实施例中,计算终端用户u100当前迭代轮次t=4的预测误差Δ100,4=||P100,4-L100,4||。

  步骤S6-2,比较Δj,t与vj,t的大小,若Δj,t≤vj,t,表示预测更新准确,则进入步骤S6-3;反之,若Δj,t>vj,t,表示预测参数不准确,则进入步骤S6-4。

  实施例中,比较终端用户u100当前迭代轮次t=4的预测误差Δ100,4与设定的预测误差阈值v100,4的大小,若Δ100,4≤v100,4,则进入步骤S6-3;反之,Δ100,4>v100,4,则进入步骤S6-4。

  步骤S6-3,云中心采用用户uj的预测更新Pj,t进行全局模型聚集,模型训练进入下一轮迭代。

  实施例中,云中心采用终端用户u100的预测更新P100,4进行全局模型聚集,模型训练进入下一轮迭代。

  步骤S6-4,用于上传本地更新Lj,t至云中心,设置通信轮数R=R+1,同时为了得到更加准确的预测更新,添加Lj,t至预测资源池,模型训练进入下一轮迭代。

  实施例中,上传终端用户u100的本地更新L100,4至云中心及预测资源池,设置通信轮数R=R+1,模型训练进入下一轮迭代。

  步骤S7,云中心全局模型更新。云中心聚集上传的本地更新与云中心准确的预测更新,计算得到聚集的全局模型更新和全局模型,模型训练进入下一轮迭代。

  实施例中,云中心聚集上传的本地更新与云中心准确的预测更新,计算得到全局模型更新Gt和全局模型Wt,模型训练进入下一轮更新迭代。

  重复以上步骤S1~S7,直至全局模型收敛,模型训练结束。

  本发明提供了本领域技术人员能够实现的技术方案。以上实施例仅供说明本发明之用,而非对本发明的限制,有关技术领域的技术人员,在不脱离本发明的精神和范围的情况下,还可以做出各种变换或变型,因此所有等同的技术方案,都落入本发明的保护范围。

《一种基于预测的联邦学习通信优化方法及系统.doc》
将本文的Word文档下载到电脑,方便收藏和打印
推荐度:
点击下载文档

文档为doc格式(或pdf格式)