3. TokyoTech
TokyoTech
Background: The Age of AI
スマートフォンやタブレットはデータの宝庫
・テキスト、画像、健康状態、移動履歴…
機械学習に活⽤したい!
2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 3
機械学習による様々なアプリケーションの実現
・⾃動翻訳、ロボット制御、⾃動診断…
⾼度なモデルの学習には豊富なデータが不可⽋
データの効率的収集が課題
7. TokyoTech
TokyoTech
Applications: Emoji prediction from Google [4]
2022/5/17 T5: Part1 8
[4] Ramaswamy, et al., “Federated Learning for Emoji
Prediction in a Mobile Keyboard,” arXiv:1906.04329.
ML model predicts a Emoji based on the context.
The model trained via FL achieved better prediction accuracy (+7%).
8. TokyoTech
TokyoTech
Applications: Oxygen needs prediction from NVIDIA [6]
[6] https://blogs.nvidia.com/blog/2020/10/05/federated-learning-covid-oxygen-needs/
2022/5/17 T5: Part1 9
Using NVIDIA Clara Federated Learning Framework,
researchers at individual hospitals were able to
use a chest X-ray, patient vitals and lab values to
train a local model and share only a subset of model
weights back with the global model in a privacy-
preserving technique called federated learning.
10. TokyoTech
TokyoTech
System model (これは概ね共通)
2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 18
Server
Clients (user devices)
Data
Client
• データ保持者であり、モデルの学習に必要な情報の
提供を⾏う。ただし、データの共有は不可。
• 学習に適したデータ(前処理やラベル付け済み)を
もつと仮定する
• データが少量ならば5〜10回程度のモデル訓練が実⾏
可能な程度の計算能⼒をもつ
Server
• 学習の管理と訓練対象のモデル(グローバルモデ
ル)の更新を⾏う
• Clientと通信が可能
• Clientが兼任することも可能
11. TokyoTech
TokyoTech
いろいろな設定のFederated Learning Problem
2022/5/18 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 20
データ: IID or Non-IID
学習タスク: 教師あり学習、半教師あり学習、教師なし学習、強化学習
システム構成: Server-Client型 or 階層型 or 分散型(Server-less)
攻撃者の有無: 学習の妨害(Poisoning)、バックドア、データ盗聴
シナリオ(主にClientの数や性能)
•Cross-silo FL: 10-100台程度のサーバなど
同じClientが何度も学習に参加する
•Cross-device FL: 1K-1M台のスマートフォンやラップトップなど
各Clientが学習に参加するのは数回程度
12. TokyoTech
TokyoTech
いろいろな設定のFederated Learning Problem
2022/5/18 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 21
データ: IID or Non-IID
学習タスク: 教師あり学習、半教師あり学習、教師なし学習、強化学習
システム構成: Server-Client型 or 階層型 or 分散型(Server-less)
攻撃者の有無: 学習の妨害(Poisoning)、バックドア、データ盗聴
シナリオ(主にClientの数や性能)
•Cross-silo FL: 10-100台程度のサーバなど
同じClientが何度も学習に参加する
•Cross-device FL: 1K-1M台のスマートフォンやラップトップなど
各Clientが学習に参加するのは数回程度
最もよくある
設定で解説
13. TokyoTech
TokyoTech
補⾜:シナリオ(主にClientの数や性能)の違い 1/2
2022/5/18 T5: Part1 22
Cross-silo federated learning
Cross-device federated learning
Clients: millions of devices such as
mobile phones and IoT sensors
Clients: small numbers of data silos
such as institutions and factories
Server
Use cases
• Keyboard next-word
prediction [3]
• Emoji prediction [4]
• Speaker recognition [5]
Client: Millions of smart phone
Server
Use case
Oxygen need prediction [6]
Client: 20 hospitals
Silo A
Silo B
Clients
14. TokyoTech
TokyoTech
補⾜:シナリオ(主にClientの数や性能)の違い 2/2
2022/5/18 T5: Part1 23
Cross-silo federated learning
Cross-device federated learning
Clients: millions of devices such as
mobile phones and IoT sensors
Clients: small numbers of data silos
such as institutions and factories
Server
• Clients are intermittently
available
• Only a portion of clients
participate in round.
• Clients may participate few
times or once.
Server
• Clients are always available.
• Most clients participate in
every round.
• Clients are identified.
• Server can know the
characteristics of each
client and manage their
participation in detail.
Silo A
Silo B
Clients
15. TokyoTech
TokyoTech
Federated Learningの具体的なアルゴリズム
2022/5/18 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 24
Clients (user devices)
Model
Data
3. モデルの更新
5. モデルの統合
2. モデルの配布
FedAvg (Federated Averaging) [1]
各Clientが訓練したモデルのパラメタ
を収集し、算術平均をとることで⼀つ
のモデルに統合し、学習する⽅式
• Server-Client間でやりとりするのは
モデルだけ
• データはそれを保持するClient⾃⾝
しか参照しない
[1] B. McMahan, et al.,“Communication-efficient learning of deep networks from decentralized data,” Proc. AISTATS, pp. 1273–1282, Apr. 2017.
16. TokyoTech
TokyoTech
Federated Learningの具体的なアルゴリズム
1. Client selection: サーバはラウンド(⼀連
の更新⼿順)に参加するClientを選択
2. 選択されたClientにグローバルモデルを配布
3. Local update: Clientは⾃⾝の持つデータを
使って、配布されたモデルを更新する。更
新したモデルはローカルモデルと呼ぶ。
4. ローカルモデルのパラメタをサーバに共有
する
5. Model aggregation: 共有されたパラメタを
平均し、グローバルモデルとする
6. 1~5の⼿順を繰り返す
2022/5/18 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 25
[1] B. McMahan, et al.,“Communication-efficient learning of deep networks from decentralized data,” Proc. AISTATS, pp. 1273–1282, Apr. 2017.
Clients (user devices)
Model
Data
3. モデルの更新
5. モデルの統合
2. モデルの配布
17. TokyoTech
TokyoTech
モデルの更新と統合(ニューラルネットワークを想定)
2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 27
厳密にはミニバッチ確率的勾配降下法
Client update
通常のモデル更新のように確率的
勾配降下法によりモデルを更新
損失関数に対する勾配
ミニバッチ
パラメタ(重み)
for local epoch 𝑖 from 1 to 𝐸:
for batch 𝑏 ∈ 𝐵:
𝑤 ← 𝑤 − 𝜂 ∇𝑙(𝑤; 𝑏)
Model aggregation
ローカルモデルのパラメタをデー
タ数で重み付けし平均
𝑤!
"
← Client update
𝑤!#$ ← ∑"∈𝑺!
'"
'
𝑤!
"
グローバルモデル Client kのデータ数 / 総データ数
18. TokyoTech
TokyoTech
性能評価 (Supplementary PDF of [B. McMahan, et al.,“Communication-efficient learning of deep networks from
decentralized data,” Proc. AISTATS, pp. 1273‒1282, Apr. 2017.])
2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 28
CIFAR-10 (画像分類タスク)
ハイパーパラメータにもよるが、
データを集約した場合と同程度の
精度までモデルを訓練できている
19. TokyoTech
TokyoTech
集中型機械学習 vs. Federated Learning
2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 29
Clients (user devices)
Model
Data
モデルの更新
モデルの統合
モデルの配布
Federated Learning
集中型機械学習
Server
User devices
モデルの訓練
データの集約によるプライバシ
情報や機密情報漏洩の懸念
学習のための情報のみ共有し、データ
は端末に保持されるため、漏洩リスク
が軽減
20. TokyoTech
TokyoTech
分散機械学習 vs. Federated Learning
2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 30
Servers
分散機械学習
モデルの訓練
Database
Clients (user devices)
Model
Data
モデルの更新
モデルの統合
モデルの配布
Federated Learning
学習処理の分散化に焦点
データを⼀度集約し任意に分配
クライアントごとに異なる分布
のデータを持ち、学習が困難に
24. TokyoTech
TokyoTech
Non-iid (not independent and identically distributed)データ
各Clientの持つデータが従う分布が、全クライアントのデータを集約し
た場合の分布と⼀致しない状況
2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 35
Class
Distribution of data
統合してもモデル性能
があまり向上しない
データが⼤きく異なるため
Model 2と極端に異なるモデル
を獲得
Model 1
Model
training
Class
Client 1
Client 2
Model 2
Aggregated model
Distribution of data
32. TokyoTech
TokyoTech
Distillation based Semi-Supervised Federated Learning (DS-FL) [A]
[A] S. Itahara, T. Nishio, et al.,, “Distillation-Based Semi-Supervised Federated Learning for Communication-Efficient Collaborative Training with Non-IID Private Data,”
IEEE Trans. Mobile Compt.
2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 44
Distillation (蒸留)
モデルの出⼒であるロジットを⽤いた学習⽅式
モデルの代わりにサイズの⼩さいロジットを⽤いることでトラヒック
を⼤幅に削減
Semi-supervised learning (半教師あり学習)
ラベル付きデータに加えて、ラベルなしデータも活⽤する機械学習
従来はモデルの汎化性能向上に⽤いられることが多い
本⽅式ではDistillationをFLに組み込むために活⽤
通信トラヒックを⼤幅削減可能(FedAvgの1/50)な学習⼿法
33. TokyoTech
TokyoTech
DS-FLと従来⼿法(FedAvg)の⽐較
[A] S. Itahara, T. Nishio, et al.,, “Distillation-Based Semi-Supervised Federated Learning for Communication-Efficient Collaborative Training with Non-IID Private Data,” IEEE Trans. Mobile Compt.
2022/5/17 SSII OS3「深層学習のための効率的なデータ収集と活⽤」 46
データサイズの⼤きいモデルを何度も
共有するため⼤きなトラヒックが発⽣
Clients (user devices)
Model
Data
モデルの更新
出⼒統合とモデル訓練
Model Logit
モデルの出⼒
提案⼿法
Clients (user devices)
Model
Data
モデルの更新
モデルの統合
モデルの配布
従来のFL
モデルの出⼒情報を⽤いて学習するこ
とで学習時のトラヒックを⼤幅に削減