【技术实现步骤摘要】
用户分类模型的训练方法、装置、电子设备及存储介质
本专利技术涉及数据处理
,更具体地,涉及一种用户分类模型的训练方法、装置、电子设备及存储介质。
技术介绍
在数据处理领域,当有大量的类别标签时,使用有监督的深度神经网络可以有效地进行分类模型的训练和预测,但是,在实际应用中,对类别标签的标注需要大量的人力资源和时间。并且,在有新的数据集时,需要重新花费大量的人力和时间去标注新的数据集,这显然不是高效的方法。目前,存在着大量的已标注好的数据,和新的数据集具有相同的标签分布,因此,如何使用迁移学习利用已标注好的数据对新的具有较少类别标签数据集进行预测具有重要的意义。
技术实现思路
有鉴于此,本专利技术实施例提供一种用户分类模型的训练方法、装置、电子设备及存储介质,以通过随机噪声掩盖源域数据集的部分特征,减小源域数据集与目标域数据集特征分布差异,从而提高用户分类模型的准确性。第一方面,本专利技术实施例提供一种用户分类模型的训练方法,所述方法包括:获取目标域数据集、源域数据集以及随机噪声,所述 ...
【技术保护点】
1.一种用户分类模型的训练方法,其特征在于,所述方法包括:/n获取目标域数据集、源域数据集以及随机噪声,所述源域数据集包括具有标签的多个源域样本数据,所述目标域数据集包括多个目标域样本数据,其中部分目标域样本数据具有标签;其中,所述源域样本数据包括对应用户在预定时间范围内的历史任务记录,所述目标域样本数据中不存在预定时间范围内的历史任务记录;/n将各所述源域样本数据和所述随机噪声输入至所述用户分类模型的掩码生成网络,确定所述源域数据集对应的掩码数据集;/n根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络;/n响应 ...
【技术特征摘要】
1.一种用户分类模型的训练方法,其特征在于,所述方法包括:
获取目标域数据集、源域数据集以及随机噪声,所述源域数据集包括具有标签的多个源域样本数据,所述目标域数据集包括多个目标域样本数据,其中部分目标域样本数据具有标签;其中,所述源域样本数据包括对应用户在预定时间范围内的历史任务记录,所述目标域样本数据中不存在预定时间范围内的历史任务记录;
将各所述源域样本数据和所述随机噪声输入至所述用户分类模型的掩码生成网络,确定所述源域数据集对应的掩码数据集;
根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络;
响应于所述特征生成网络、分类网络和域判别网络对应的损失函数满足预定条件,确定所述用户分类模型;
其中,所述特征生成网络用于获取所述目标域样本数据和源域样本数据的特征向量,所述分类网络用于根据所述目标域样本数据和源域样本数据的特征向量、所述源域样本数据的标签、部分所述目标域样本数据的标签确定对应的标签预测值,所述域判别网络用于使得所述目标域数据集与所述源域数据集进行特征对齐。
2.根据权利要求1所述的方法,其特征在于,所述标签用于表征用户在未来预定时间范围内的任务状态,所述任务状态包括用户在未来预定时间范围内会执行任务、以及用户在未来预定时间范围内不会执行任务。
3.根据权利要求1所述的方法,其特征在于,根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络包括:
将所述目标域数据集、所述源域数据集和所述掩码数据集输入至所述特征生成网络进行处理,确定各所述目标域样本数据和源域样本数据的特征向量;
将各所述目标域样本数据和源域样本数据的特征向量输入至所述域判别网络中进行处理,确定对应的特征分布;
根据所述特征分布确定所述域判别网络对应的对抗损失;
将各所述目标域样本数据和源域样本数据的特征向量输入至所述分类网络,确定各所述目标域样本数据和源域样本数据的标签预测值;
根据具有标签的所述目标域样本数据的标签及对应的标签预测值、源域样本数据的标签及对应的标签预测值确定所述分类网络对应的分类损失;
根据具有标签的所述目标域样本数据的特征向量及对应的标签、源域样本数据的特征向量及对应的标签确定对应的对比损失;
根据所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
4.根据权利要求3所述的方法,其特征在于,根据所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数包括:
保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失;
保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述分类损失、对比损失和对抗损失。
5.根据权利要求3所述的方法,其特征在于,所述特征生成网络包括第一特征生成子网络、第二特征生成子网络和第三特征生成子网络;
所述第一特征生成子网络用于根据所述源域样本数据和对应的掩码数据生成所述源域样本数据的特征向量;
所述第二特征生成子网络用于根据所述目标域样本数据生成所述目标域样本数据的特征向量;
所述第三特征生成子网络用于对所述源域样本数据的特征向量和所述目标域样本数据的特征向量进行特征处理,获取预定维度的所述源域样本数据的特征向量和所述目标域样本数据的特征向量。
6.根据权利要求5所述的方法,其特征在于,所述特征生成网络包括多个第三特征生成子网络,多个所述第三特征生成子网络权值共享。
7.根据权利要求5所述的方法,其特征在于,所述第一特征生成子网络具有对应的第一解码器和第一自编码损失,所述第二特征生成子网络具有对应的第二解码器和第二自编码损失;
根据所述目标域数据集、所述源域数据集和所述掩码数据集训练所述用户分类模型中的特征生成网络、分类网络和域判别网络还包括:
根据所述第一特征生成子网络的输入值和所述第一解码器的输出值计算所述第一自编码损失;
根据所述第二特征生成子网络的输入值和所述第二解码器的输出值计算所述第二自编码损失;
根据所述第一自编码损失和所述第二自编码损失确定所述特征生成网络的损失;
根据所述特征生成网络的损失、所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数。
8.根据权利要求7所述的方法,其特征在于,根据所述特征生成网络的损失、所述对抗损失、所述分类损失和对比损失调节所述特征生成网络、分类网络和域判别网络的参数包括:
保持所述特征生成网络和所述分类网络的参数,调节所述域判别网络的参数以最大化所述对抗损失;
保持所述域判别网络的参数,调节所述特征生成网络和所述分类网络的参数以最小化所述特征生成网络的损失、分类损失、对比损失和对抗损失。
9.一种用户分类模型的训练装置,其特征在于,所述装置包括
数据获取单元,被配置为获取目标域数据集、源域数据集以及随机噪声,所述源域数据集包括具有标签的多个源域样本数据,所述目标域数据集包括多个目标域样本数据,其中部分目标域样本数据具有标签;其中,所...
【专利技术属性】
技术研发人员:李振鹏,姜佳男,郭玉红,
申请(专利权)人:北京嘀嘀无限科技发展有限公司,
类型:发明
国别省市:北京;11
还没有人留言评论。发表了对其他浏览者有用的留言会获得科技券。