Add new CrossEncoder model
- .gitattributes +1 -0
- README.md +408 -0
- block.py +470 -0
- config.json +51 -0
- configuration_xlm_roberta.py +69 -0
- embedding.py +62 -0
- mha.py +662 -0
- mlp.py +194 -0
- model.safetensors +3 -0
- modeling_xlm_roberta.py +1119 -0
- special_tokens_map.json +51 -0
- tokenizer.json +3 -0
- tokenizer_config.json +55 -0
- xlm_padding.py +218 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,408 @@
---
tags:
- sentence-transformers
- cross-encoder
- reranker
- generated_from_trainer
- dataset_size:3200
- loss:CachedMultipleNegativesRankingLoss
base_model: jinaai/jina-reranker-v2-base-multilingual
pipeline_tag: text-ranking
library_name: sentence-transformers
metrics:
- map
- mrr@10
- ndcg@10
model-index:
- name: CrossEncoder based on jinaai/jina-reranker-v2-base-multilingual
  results:
  - task:
      type: cross-encoder-reranking
      name: Cross Encoder Reranking
    dataset:
      name: jina reranker v2 base multilingual contrastive parl 4 10ep
      type: jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep
    metrics:
    - type: map
      value: 0.0194
      name: Map
    - type: mrr@10
      value: 0.0194
      name: Mrr@10
    - type: ndcg@10
      value: 0.0198
      name: Ndcg@10
---

# CrossEncoder based on jinaai/jina-reranker-v2-base-multilingual

This is a [Cross Encoder](https://www.sbert.net/docs/cross_encoder/usage/usage.html) model finetuned from [jinaai/jina-reranker-v2-base-multilingual](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual) using the [sentence-transformers](https://www.SBERT.net) library. It computes scores for pairs of texts, which can be used for text reranking and semantic search.

## Model Details

### Model Description
- **Model Type:** Cross Encoder
- **Base model:** [jinaai/jina-reranker-v2-base-multilingual](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual) <!-- at revision 2f894e63642a95228da19cdd583cd2309983c867 -->
- **Maximum Sequence Length:** 1024 tokens
- **Number of Output Labels:** 1 label
<!-- - **Training Dataset:** Unknown -->
<!-- - **Language:** Unknown -->
<!-- - **License:** Unknown -->

### Model Sources

- **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
- **Documentation:** [Cross Encoder Documentation](https://www.sbert.net/docs/cross_encoder/usage/usage.html)
- **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
- **Hugging Face:** [Cross Encoders on Hugging Face](https://huggingface.co/models?library=sentence-transformers&other=cross-encoder)

## Usage

### Direct Usage (Sentence Transformers)

First install the Sentence Transformers library:

```bash
pip install -U sentence-transformers
```

Then you can load this model and run inference.
```python
from sentence_transformers import CrossEncoder

# Download from the 🤗 Hub
model = CrossEncoder("cuadron11/jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep")
# Get scores for pairs of texts
pairs = [
['Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?', '[TOPIC: Honako ekimen hauek batera eztabaidatu eta behin betiko ebazpena hartzea: ]\n[UNZALU HERMOSA, (SV-ES)]:\nSekula. Gertatzen dena da uste dugula martxoaren 3ko jokaerak baduela zer hobetua. Eta hobetzeko abiapuntu bakarra gogoeta egitea da, aztertzea eta hasieratik aitortzea hutsegiteak egin zirela. Izan ere, nire lehenengo hitzaldian esan dudanez, triskantzak gertatu izanak pentsarazi behar liguke zerbaitek huts egin zuela egun hartako dispositiboa edo operazioa planifikatzean eta zuzentzean. Horixe sartu nahi dugu guk: eztabaida-elementuak, hobekuntzarako kritika-elementuak, eta UPyDrekin eta Alderdi Popularrarekin sinatu dugun zuzenketan hori esaten da, onar dadila gauzak hobetu egin daitezkeela. Izan ere, Iturrate jauna, zuk egin dizkiguzun galderei nik beste batzuekin erantzungo nieke. Posible da hutsegiteetatik ikastea eta herritarren segurtasuna hobetzea? Posible da? Edo, besterik gabe, "Ahal zen modu bakarrean jokatu dugu" esatera mugatu behar dugu? Posible da herritarrei kalte gutxiago eragitea horrelako istiluak gertatzen direnean? Horixe planteatu nahi dugu guk, beharrezkoa dela… Eta uste osoa dugunez hobetu daitekeela, eta uste osoa dugunez hobeto joka zitekeela, horregatik nahi dugu eta horregatik planteatzen dugu hutsegiteak aztertzea, gogoeta egitea, eta elementu zuzentzaileak martxan jartzea horrelako egoerarik berriro gerta ez dadin. Eta, begira, sailarekin batera dispositiboari babesa eman dioten bakarrak dira, hain justu, Ertzaintzaren jokaerak inoiz babesten ez dituztenak; lehen esan dudanez, Ertzaintzaren kontrako ekintzak ere gaitzetsi ez dituztenak. Eta horrek kezkatu egiten gaitu. Nik ez dakit zu, Iturrate jauna, eta sailburu andrea kezkatzen zaituzten; baina, (Date: 03.04.2014)'],
['Zenbat denbora behar da Ertzaintzako promozio baten deialdia egiten denetik agenteak kalera irteten diren arte?', '[TOPIC: Interpelazioa, Javier Ruiz de Arbulo Cerio Euskal Talde Popularreko legebiltzarkideak Segurtasuneko sailburuari egina, Arabako Miñoien Atalari buruz]\n[SEGURTASUNEKO SAILBURUAK (BELTRÁN DE HEREDIA ARRONIZ), (EA-NV)]:\nhoriek aldatu egiten dira egun batetik bestera, unitate batetik bestera, kontuan hartuta zer bilakaera duten erretiroek, kontuan hartuta nola gertatzen diren baja horiek… Baina, batez ere, nik bezain ondo dakizu Ertzaintzan defizit handia daukagula, eta ezin hobeto dakizu zergatia zein den. Ez dakit defizit horren zergatia zein den errepika diezazudan etorri zaren hona, baina ez daukat inolako eragozpenik Legebiltzar honetan berriro azaltzeko eta zuek berriro entzun behar izateko. Honela gaude Espainiako Gobernuak, Alderdi Popularraren Gobernuak, denbora asko behar izan zuelako, denbora gehiegi, zuk behar izan duzun be- zala, ulertzeko premia geneukala Ertzaintzan gertatzen ari ziren erretiro-bajak estaltzeko promozio berriak deitzeko –gero eta gehiago dira erretiroak eragindako bajak–; logikoa denez, baja horiek eragina zeukaten eta daukate Miñoien Atalean ere, bajak oraindik ere gertatzen ari baitira. 26. promozioa hautatzeko prozesua urtebete baino gehiago atzeratu da, errekurtsoek mehatxatu egin zituztelako 25. promozioaren bilakaera normala eta amaiera. Nik uste dut orain bide onetik goazela, baina ez duzu ahaztu behar promozio baten deialdia egiten dugunetik agenteak kalera irteten diren arte bi urte baino gehiago igarotzen direla. Bi urte baino gehiago. Eta ziztu bizian ibili ginen, betoa amaitu orduko azterketak egiteko: hogei egun eskas behar izan genituen 26. promozioko azterketen deialdia egiteko. Ziztu bizian ibili ginen, baina, hala ere, kale. Denbora eman behar da, ezta? Hemen, urdaiazpikoekin bezala geratzen da: denbora eman behar zaie, ontzeko. Bada, (Date: 01.12.2017)'],
['Zergatik dimititu zuen Eusko Jaurlaritzako Komunikazio zuzendariak?', '[TOPIC: Galdera, Gorka Maneiro Labayen Mistoa-UPyD taldeko legebiltzarkideak lehendakariari egina, Eusko Jaurlaritzako Komunikazio zuzendariaren dimisioaren ondoren hartu beharreko erantzukizun politikoei buruz]\n[MANEIRO LABAYEN, (Mixto-UPyD)]:\nsailburu jakin batzuei elkarrizketak egitearen truke? Erantzun ahal diezaiokezu galdera horri? Halaxe da, bai. Zure esanetan, ez dago ezer arrarorik eta irregularrik, baina pertsona batek dimititu egin du. Zer egiteko asmoa duzu zuk? Bide batez, zer da aldi baterako dimisioaren kontu hori? Beste postu batean jarri al duzue pertsona hori? Diru publikoa kobratzen jarraitzen al du? Argitu dezakezu, edo herritarrak engainatu nahi dituzue? Pertsona horrek dimititu egin du. Zer egiteko (Date: 30.10.2015)'],
['Zein da euskal herritarren iritzia independentziari buruz, Soziometroaren arabera?', '[TOPIC: Mozioa, Maddalen Iriarte Okiñena EH Bildu taldeko legebiltzarkideak aurkeztua, herri bezala ditugun erronka estrategikoei erantzuteko, herri-jakintza aktibatzeko eta ariketa kolektibo bat egiteko beharraren inguruan. Eztabaida eta behin betiko ebazpena]\n[BARRIO BAROJA, (PV-ETP)]:\nasko; eta ezin dela horren autokonplazientea izan eta dena positiboki egin dela esan. Argi dago, Iriarte andrea, amaitzeko, etorkizuneko erronkak ditugula; ados gaude gogor lan egin behar dela; baina estatus berria herritarrei arazo gehiago sortzea da; hura agerian jartzea eta hona ekartzea, berriz ere konfrontazio- eta eztabaida-eremu izatea da, herritarrei arazo gehiago sortzea da. Atzo argi eta garbi zioen euskal Soziometroak euskal herritarrok independentziari buruz zer iritzi dugu; eta inoiz ez da hain maila baxurik ikusi. Beraz, ildo horretan, erronka estrategikoei buruz hitz egiten ari zaren une honetan, estatus berriaren eztabaida hona ekartzea atzerapausoa litzateke, arazo gehiago ematea litzateke; eta, jakina, gu –zuri erantzuten dizut, baita orain hura aldarrikatu duen Egibar jaunari ere esaten diot– aurka egongo gara. Eskerrik asko. (Date: 10.06.2021)'],
['Zeintzuk dira Eusko Jaurlaritzaren asmoak euskararen normalizazioan sakontzeko?', '[TOPIC: Galdera, Rebeka Ubera Aranzeta EH Bildu taldeko legebiltzarkideak Kultura eta Hizkuntza Politikako sailburuari egina, euskararen normalizazioan sakontzeko neurri funtsezkoak hartzeari buruz]\n[UBERA ARANZETA, (EH Bildu)]:\nAdministrazioa euskalduntzeko urratsak emango zirela: ekarpenak egin ditugu eta ezezkoa jaso dugu. Esan zitzaigun euskara ikastea doako bilakatzeko urratsak emango zirela, eta mugak besterik ez dugu ikusi eta ezezkoa jaso dugu. Eta jada dagoeneko zalantzan jartzen hasiak gara Gobernu honen borondate politikoa zein den. Eta, legegintzaldi honetan, sailburuen aldetik ere, atzerakada izugarria izan da, aurreko legegintzaldiarekin konparatuta –nabarmen gainera–, eta zentzu horretan ere, zerbait egin beharko duzu. Neurtzen ari (Date: 19.05.2017)'],
]
scores = model.predict(pairs)
print(scores.shape)
# (5,)

# Or rank different texts based on similarity to a single text
ranks = model.rank(
'Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?',
[
'[TOPIC: Honako ekimen hauek batera eztabaidatu eta behin betiko ebazpena hartzea: ]\n[UNZALU HERMOSA, (SV-ES)]:\nSekula. Gertatzen dena da uste dugula martxoaren 3ko jokaerak baduela zer hobetua. Eta hobetzeko abiapuntu bakarra gogoeta egitea da, aztertzea eta hasieratik aitortzea hutsegiteak egin zirela. Izan ere, nire lehenengo hitzaldian esan dudanez, triskantzak gertatu izanak pentsarazi behar liguke zerbaitek huts egin zuela egun hartako dispositiboa edo operazioa planifikatzean eta zuzentzean. Horixe sartu nahi dugu guk: eztabaida-elementuak, hobekuntzarako kritika-elementuak, eta UPyDrekin eta Alderdi Popularrarekin sinatu dugun zuzenketan hori esaten da, onar dadila gauzak hobetu egin daitezkeela. Izan ere, Iturrate jauna, zuk egin dizkiguzun galderei nik beste batzuekin erantzungo nieke. Posible da hutsegiteetatik ikastea eta herritarren segurtasuna hobetzea? Posible da? Edo, besterik gabe, "Ahal zen modu bakarrean jokatu dugu" esatera mugatu behar dugu? Posible da herritarrei kalte gutxiago eragitea horrelako istiluak gertatzen direnean? Horixe planteatu nahi dugu guk, beharrezkoa dela… Eta uste osoa dugunez hobetu daitekeela, eta uste osoa dugunez hobeto joka zitekeela, horregatik nahi dugu eta horregatik planteatzen dugu hutsegiteak aztertzea, gogoeta egitea, eta elementu zuzentzaileak martxan jartzea horrelako egoerarik berriro gerta ez dadin. Eta, begira, sailarekin batera dispositiboari babesa eman dioten bakarrak dira, hain justu, Ertzaintzaren jokaerak inoiz babesten ez dituztenak; lehen esan dudanez, Ertzaintzaren kontrako ekintzak ere gaitzetsi ez dituztenak. Eta horrek kezkatu egiten gaitu. Nik ez dakit zu, Iturrate jauna, eta sailburu andrea kezkatzen zaituzten; baina, (Date: 03.04.2014)',
'[TOPIC: Interpelazioa, Javier Ruiz de Arbulo Cerio Euskal Talde Popularreko legebiltzarkideak Segurtasuneko sailburuari egina, Arabako Miñoien Atalari buruz]\n[SEGURTASUNEKO SAILBURUAK (BELTRÁN DE HEREDIA ARRONIZ), (EA-NV)]:\nhoriek aldatu egiten dira egun batetik bestera, unitate batetik bestera, kontuan hartuta zer bilakaera duten erretiroek, kontuan hartuta nola gertatzen diren baja horiek… Baina, batez ere, nik bezain ondo dakizu Ertzaintzan defizit handia daukagula, eta ezin hobeto dakizu zergatia zein den. Ez dakit defizit horren zergatia zein den errepika diezazudan etorri zaren hona, baina ez daukat inolako eragozpenik Legebiltzar honetan berriro azaltzeko eta zuek berriro entzun behar izateko. Honela gaude Espainiako Gobernuak, Alderdi Popularraren Gobernuak, denbora asko behar izan zuelako, denbora gehiegi, zuk behar izan duzun be- zala, ulertzeko premia geneukala Ertzaintzan gertatzen ari ziren erretiro-bajak estaltzeko promozio berriak deitzeko –gero eta gehiago dira erretiroak eragindako bajak–; logikoa denez, baja horiek eragina zeukaten eta daukate Miñoien Atalean ere, bajak oraindik ere gertatzen ari baitira. 26. promozioa hautatzeko prozesua urtebete baino gehiago atzeratu da, errekurtsoek mehatxatu egin zituztelako 25. promozioaren bilakaera normala eta amaiera. Nik uste dut orain bide onetik goazela, baina ez duzu ahaztu behar promozio baten deialdia egiten dugunetik agenteak kalera irteten diren arte bi urte baino gehiago igarotzen direla. Bi urte baino gehiago. Eta ziztu bizian ibili ginen, betoa amaitu orduko azterketak egiteko: hogei egun eskas behar izan genituen 26. promozioko azterketen deialdia egiteko. Ziztu bizian ibili ginen, baina, hala ere, kale. Denbora eman behar da, ezta? Hemen, urdaiazpikoekin bezala geratzen da: denbora eman behar zaie, ontzeko. Bada, (Date: 01.12.2017)',
'[TOPIC: Galdera, Gorka Maneiro Labayen Mistoa-UPyD taldeko legebiltzarkideak lehendakariari egina, Eusko Jaurlaritzako Komunikazio zuzendariaren dimisioaren ondoren hartu beharreko erantzukizun politikoei buruz]\n[MANEIRO LABAYEN, (Mixto-UPyD)]:\nsailburu jakin batzuei elkarrizketak egitearen truke? Erantzun ahal diezaiokezu galdera horri? Halaxe da, bai. Zure esanetan, ez dago ezer arrarorik eta irregularrik, baina pertsona batek dimititu egin du. Zer egiteko asmoa duzu zuk? Bide batez, zer da aldi baterako dimisioaren kontu hori? Beste postu batean jarri al duzue pertsona hori? Diru publikoa kobratzen jarraitzen al du? Argitu dezakezu, edo herritarrak engainatu nahi dituzue? Pertsona horrek dimititu egin du. Zer egiteko (Date: 30.10.2015)',
'[TOPIC: Mozioa, Maddalen Iriarte Okiñena EH Bildu taldeko legebiltzarkideak aurkeztua, herri bezala ditugun erronka estrategikoei erantzuteko, herri-jakintza aktibatzeko eta ariketa kolektibo bat egiteko beharraren inguruan. Eztabaida eta behin betiko ebazpena]\n[BARRIO BAROJA, (PV-ETP)]:\nasko; eta ezin dela horren autokonplazientea izan eta dena positiboki egin dela esan. Argi dago, Iriarte andrea, amaitzeko, etorkizuneko erronkak ditugula; ados gaude gogor lan egin behar dela; baina estatus berria herritarrei arazo gehiago sortzea da; hura agerian jartzea eta hona ekartzea, berriz ere konfrontazio- eta eztabaida-eremu izatea da, herritarrei arazo gehiago sortzea da. Atzo argi eta garbi zioen euskal Soziometroak euskal herritarrok independentziari buruz zer iritzi dugu; eta inoiz ez da hain maila baxurik ikusi. Beraz, ildo horretan, erronka estrategikoei buruz hitz egiten ari zaren une honetan, estatus berriaren eztabaida hona ekartzea atzerapausoa litzateke, arazo gehiago ematea litzateke; eta, jakina, gu –zuri erantzuten dizut, baita orain hura aldarrikatu duen Egibar jaunari ere esaten diot– aurka egongo gara. Eskerrik asko. (Date: 10.06.2021)',
'[TOPIC: Galdera, Rebeka Ubera Aranzeta EH Bildu taldeko legebiltzarkideak Kultura eta Hizkuntza Politikako sailburuari egina, euskararen normalizazioan sakontzeko neurri funtsezkoak hartzeari buruz]\n[UBERA ARANZETA, (EH Bildu)]:\nAdministrazioa euskalduntzeko urratsak emango zirela: ekarpenak egin ditugu eta ezezkoa jaso dugu. Esan zitzaigun euskara ikastea doako bilakatzeko urratsak emango zirela, eta mugak besterik ez dugu ikusi eta ezezkoa jaso dugu. Eta jada dagoeneko zalantzan jartzen hasiak gara Gobernu honen borondate politikoa zein den. Eta, legegintzaldi honetan, sailburuen aldetik ere, atzerakada izugarria izan da, aurreko legegintzaldiarekin konparatuta –nabarmen gainera–, eta zentzu horretan ere, zerbait egin beharko duzu. Neurtzen ari (Date: 19.05.2017)',
]
)
# [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]
```
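
As a hedged illustration of the "semantic search" use mentioned above, this reranker is typically paired with a first-stage retriever. The bi-encoder model and the toy corpus below are assumptions for illustration only; they are not part of this repository.

```python
from sentence_transformers import CrossEncoder, SentenceTransformer, util

# Assumed first-stage bi-encoder retriever (any multilingual embedding model could be used).
retriever = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
reranker = CrossEncoder("cuadron11/jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep")

# Stand-ins for the parliamentary passages shown in the examples above.
corpus = ["pasarte bat ...", "beste pasarte bat ...", "hirugarren pasartea ..."]
query = "Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?"

# 1) Cheap candidate retrieval with the bi-encoder.
corpus_emb = retriever.encode(corpus, convert_to_tensor=True)
query_emb = retriever.encode(query, convert_to_tensor=True)
candidates = util.semantic_search(query_emb, corpus_emb, top_k=50)[0]

# 2) Rerank only those candidates with this cross-encoder.
reranked = reranker.rank(query, [corpus[hit["corpus_id"]] for hit in candidates])
print(reranked[:3])
```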

<!--
### Direct Usage (Transformers)

<details><summary>Click to see the direct usage in Transformers</summary>

</details>
-->

<!--
### Downstream Usage (Sentence Transformers)

You can finetune this model on your own dataset.

<details><summary>Click to expand</summary>

</details>
-->

<!--
### Out-of-Scope Use

*List how the model may foreseeably be misused and address what users ought not to do with the model.*
-->

## Evaluation

### Metrics

#### Cross Encoder Reranking

* Dataset: `jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep`
* Evaluated with [<code>CrossEncoderRerankingEvaluator</code>](https://sbert.net/docs/package_reference/cross_encoder/evaluation.html#sentence_transformers.cross_encoder.evaluation.CrossEncoderRerankingEvaluator) with these parameters:
  ```json
  {
      "at_k": 10,
      "always_rerank_positives": false
  }
  ```

| Metric      | Value                |
|:------------|:---------------------|
| map         | 0.0194 (+0.0172)     |
| mrr@10      | 0.0194 (+0.0176)     |
| **ndcg@10** | **0.0198 (+0.0172)** |

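The metrics above were produced with the `CrossEncoderRerankingEvaluator` configured as shown before the table. A minimal sketch of re-running such an evaluation follows; the sample dictionaries are placeholders, since the 800-pair held-out set is not bundled with this repository.

```python
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator

# Placeholder evaluation samples: one positive passage and mined negatives per query.
samples = [
    {
        "query": "Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?",
        "positive": ["[TOPIC: ...] ... pasarte egokia ..."],
        "negative": ["... pasarte ez-egokia ...", "... beste pasarte bat ..."],
    },
]

model = CrossEncoder("cuadron11/jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep")
evaluator = CrossEncoderRerankingEvaluator(
    samples=samples,
    at_k=10,
    always_rerank_positives=False,
    name="jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep",
)
print(evaluator(model))  # dict with map, mrr@10 and ndcg@10
```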
<!--
## Bias, Risks and Limitations

*What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
-->

<!--
### Recommendations

*What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
-->

## Training Details

### Training Dataset

#### Unnamed Dataset

* Size: 3,200 training samples
* Columns: <code>query</code> and <code>positive</code>
* Approximate statistics based on the first 1000 samples:
| | query | positive |
|:--------|:--------|:---------|
| type | string | string |
| details | <ul><li>min: 27 characters</li><li>mean: 99.5 characters</li><li>max: 250 characters</li></ul> | <ul><li>min: 569 characters</li><li>mean: 975.13 characters</li><li>max: 2175 characters</li></ul> |
* Samples:
| query | positive |
|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| <code>Zein urtetan egin zuen José Ramón Becerra Carollo legebiltzarkideak SOS Deiak-112 larrialdi-deien arretarako zerbitzuaren esleipenari buruzko mozioa?</code> | <code>[TOPIC: Mozioa, José Ramón Becerra Carollo Elkarrekin Podemos taldeko legebiltzarkideak aurkeztua, SOS Deiak-112 larrialdi-deien arretarako zerbitzuaren esleipenari buruz. Eztabaida eta behin betiko ebazpena]<br>[LATXAGA UGARTEMENDIA, (EA-NV)]:<br>eta gero Sabin Etxearekin, Eliza Katolikoarekin, Xabier Arzalluzekin eta Eusko Jaurlaritzarekin berarekin lotu zenuen enpresa esleipenduna. Konspirazio perfektua lortzeko, Mosad eta BBVA falta zitzaizkizun, nik uste. Mesedez, ez erabili Ganbera hau gure eserlekuen gainean zikinkeria, zaborra botatzeko. Ez erabili horretarako, onbidezko gauzetarako baizik. Eta ez egin funtsik gabe, inolako frogarik gabe. Zuk esaten zenuena oso larria zen, oso larria, eta ezin duzu hemen tribuna honetan besterik gabe (Date: 21.12.2017)</code> |
| <code>Zergatik da beharrezkoa kargudun publikoen jokaera kodea arautzea?</code> | <code>[TOPIC: Euskal Sozialistak legebiltzar-taldeak egindako lege-proposamena, Kargudun Publikoaren Jokaera Kodea eta haren Bateraezintasunen Erregimena arautzeko. Aintzat hartzeari buruzko eztabaida eta behin betiko ebazpena]<br>[MINTEGI LAKARRA, (EH Bildu)]:<br>Egun on, presidente andrea, lehendakari jauna, legebiltzarkideok. Legerik onena da behar ez dena eta arautu beharra dagoenean hor badago ja gabeziaren sintoma, edo ez dagoelako adostasunik edo jokaera desegokiak egon direlako eta horiek saihestu behar direlako eta ez da ikusi beste biderik arautu beharra baino. Beraz, orain kargu publikoen jokaera etikoa edo jokaera kodea arautu beharrak adierazten digu badagoela gabezia, horren sintoma da. Izatez, jokaera zuzena berezkoa izan beharko (Date: 28.02.2013)</code> |
| <code>Zein da EH Bildu talde parlamentarioaren jarrera Ikuskizunen eta Jolas Jardueren Legea garatzeko erregelamenduaren inguruan?</code> | <code>[TOPIC: EH Bildu talde parlamentarioak egindako legez besteko proposamena, Ikuskizunen eta Jolas Jardueren Legea garatzeko erregelamenduaren inguruan. Eztabaida eta behin betiko ebazpena]<br>[ÁLVAREZ MARTÍNEZ, (EA-NV)]:<br>mintzaldian aipatu ditugun puntuak zehaztu behar ditugun. Uste dugu, erantzukizunetik, dekretu hori berrikusi egin behar dela, eta uste dugu dagoeneko abian dela berrikuspen-prozesu hori, Eudelekin batera, udalek dituzten ikuspegiekin batera. Puntu honetan, gogoratu behar da Eudelen kolore guzti-guztietako udalak daudela ordezkatuta, eta kontuan hartu behar da, halaber, udal horiek guztiek zer iritzi duten eta zer ikuspuntu duten. Sémper jauna, nik ere uste dut –esperientzia handirik ez daukat, baina (Date: 14.03.2019)</code> |
* Loss: [<code>CachedMultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/cross_encoder/losses.html#cachedmultiplenegativesrankingloss) with these parameters:
  ```json
  {
      "scale": 10.0,
      "num_negatives": null,
      "activation_fn": "torch.nn.modules.activation.Sigmoid",
      "mini_batch_size": 16
  }
  ```

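For reference, here is a sketch of how this loss could be instantiated with the exact parameters listed above. The one-row toy datasets and the `trust_remote_code=True` flag (the base model ships custom modeling code) are assumptions for illustration.

```python
import torch
from datasets import Dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.losses import CachedMultipleNegativesRankingLoss

# Toy (query, positive) rows standing in for the 3,200 training / 800 evaluation samples.
train_dataset = Dataset.from_dict({
    "query": ["Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?"],
    "positive": ["[TOPIC: ...] ... pasarte egokia ..."],
})
eval_dataset = Dataset.from_dict({
    "query": ["Zergatik dimititu zuen Eusko Jaurlaritzako Komunikazio zuzendariak?"],
    "positive": ["[TOPIC: ...] ... pasarte egokia ..."],
})

model = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual", trust_remote_code=True)
# In-batch negatives with gradient caching: mini-batches of 16 keep memory flat while
# the full batch still supplies the negatives.
loss = CachedMultipleNegativesRankingLoss(
    model=model,
    scale=10.0,
    num_negatives=None,
    activation_fn=torch.nn.Sigmoid(),
    mini_batch_size=16,
)
```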
### Evaluation Dataset

#### Unnamed Dataset

* Size: 800 evaluation samples
* Columns: <code>query</code> and <code>positive</code>
* Approximate statistics based on the first 800 samples:
| | query | positive |
|:--------|:--------|:---------|
| type | string | string |
| details | <ul><li>min: 32 characters</li><li>mean: 102.26 characters</li><li>max: 247 characters</li></ul> | <ul><li>min: 550 characters</li><li>mean: 1011.95 characters</li><li>max: 2370 characters</li></ul> |
* Samples:
| query | positive |
|:-----------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| <code>Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?</code> | <code>[TOPIC: Honako ekimen hauek batera eztabaidatu eta behin betiko ebazpena hartzea: ]<br>[UNZALU HERMOSA, (SV-ES)]:<br>Sekula. Gertatzen dena da uste dugula martxoaren 3ko jokaerak baduela zer hobetua. Eta hobetzeko abiapuntu bakarra gogoeta egitea da, aztertzea eta hasieratik aitortzea hutsegiteak egin zirela. Izan ere, nire lehenengo hitzaldian esan dudanez, triskantzak gertatu izanak pentsarazi behar liguke zerbaitek huts egin zuela egun hartako dispositiboa edo operazioa planifikatzean eta zuzentzean. Horixe sartu nahi dugu guk: eztabaida-elementuak, hobekuntzarako kritika-elementuak, eta UPyDrekin eta Alderdi Popularrarekin sinatu dugun zuzenketan hori esaten da, onar dadila gauzak hobetu egin daitezkeela. Izan ere, Iturrate jauna, zuk egin dizkiguzun galderei nik beste batzuekin erantzungo nieke. Posible da hutsegiteetatik ikastea eta herritarren segurtasuna hobetzea? Posible da? Edo, besterik gabe, "Ahal zen modu bakarrean jokatu dugu" esatera mugatu behar dugu? Posible da herritarrei k...</code> |
| <code>Zenbat denbora behar da Ertzaintzako promozio baten deialdia egiten denetik agenteak kalera irteten diren arte?</code> | <code>[TOPIC: Interpelazioa, Javier Ruiz de Arbulo Cerio Euskal Talde Popularreko legebiltzarkideak Segurtasuneko sailburuari egina, Arabako Miñoien Atalari buruz]<br>[SEGURTASUNEKO SAILBURUAK (BELTRÁN DE HEREDIA ARRONIZ), (EA-NV)]:<br>horiek aldatu egiten dira egun batetik bestera, unitate batetik bestera, kontuan hartuta zer bilakaera duten erretiroek, kontuan hartuta nola gertatzen diren baja horiek… Baina, batez ere, nik bezain ondo dakizu Ertzaintzan defizit handia daukagula, eta ezin hobeto dakizu zergatia zein den. Ez dakit defizit horren zergatia zein den errepika diezazudan etorri zaren hona, baina ez daukat inolako eragozpenik Legebiltzar honetan berriro azaltzeko eta zuek berriro entzun behar izateko. Honela gaude Espainiako Gobernuak, Alderdi Popularraren Gobernuak, denbora asko behar izan zuelako, denbora gehiegi, zuk behar izan duzun be- zala, ulertzeko premia geneukala Ertzaintzan gertatzen ari ziren erretiro-bajak estaltzeko promozio berriak deitzeko –gero eta gehiago dira erretiro...</code> |
| <code>Zergatik dimititu zuen Eusko Jaurlaritzako Komunikazio zuzendariak?</code> | <code>[TOPIC: Galdera, Gorka Maneiro Labayen Mistoa-UPyD taldeko legebiltzarkideak lehendakariari egina, Eusko Jaurlaritzako Komunikazio zuzendariaren dimisioaren ondoren hartu beharreko erantzukizun politikoei buruz]<br>[MANEIRO LABAYEN, (Mixto-UPyD)]:<br>sailburu jakin batzuei elkarrizketak egitearen truke? Erantzun ahal diezaiokezu galdera horri? Halaxe da, bai. Zure esanetan, ez dago ezer arrarorik eta irregularrik, baina pertsona batek dimititu egin du. Zer egiteko asmoa duzu zuk? Bide batez, zer da aldi baterako dimisioaren kontu hori? Beste postu batean jarri al duzue pertsona hori? Diru publikoa kobratzen jarraitzen al du? Argitu dezakezu, edo herritarrak engainatu nahi dituzue? Pertsona horrek dimititu egin du. Zer egiteko (Date: 30.10.2015)</code> |
* Loss: [<code>CachedMultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/cross_encoder/losses.html#cachedmultiplenegativesrankingloss) with these parameters:
  ```json
  {
      "scale": 10.0,
      "num_negatives": null,
      "activation_fn": "torch.nn.modules.activation.Sigmoid",
      "mini_batch_size": 16
  }
  ```

### Training Hyperparameters
#### Non-Default Hyperparameters

- `eval_strategy`: steps
- `per_device_train_batch_size`: 16
- `per_device_eval_batch_size`: 16
- `learning_rate`: 2e-05
- `num_train_epochs`: 10
- `warmup_ratio`: 0.1
- `load_best_model_at_end`: True
- `batch_sampler`: no_duplicates

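Continuing the loss and evaluator sketches above, the non-default values listed here could be wired into a trainer roughly as follows; the output directory and the eval/save cadence are assumptions, not recorded settings.

```python
from sentence_transformers.cross_encoder import CrossEncoderTrainer, CrossEncoderTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = CrossEncoderTrainingArguments(
    output_dir="jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep",  # assumed
    eval_strategy="steps",
    eval_steps=200,   # assumed; the logs below evaluate once per 200-step epoch
    save_steps=200,   # assumed; needed so load_best_model_at_end can pick a checkpoint
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    num_train_epochs=10,
    warmup_ratio=0.1,
    load_best_model_at_end=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)

trainer = CrossEncoderTrainer(
    model=model,                  # the CrossEncoder from the loss sketch above
    args=args,
    train_dataset=train_dataset,  # 3,200 (query, positive) pairs
    eval_dataset=eval_dataset,    # 800 held-out pairs
    loss=loss,
    evaluator=evaluator,          # the CrossEncoderRerankingEvaluator sketch above
)
trainer.train()
```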
#### All Hyperparameters
<details><summary>Click to expand</summary>

- `overwrite_output_dir`: False
- `do_predict`: False
- `eval_strategy`: steps
- `prediction_loss_only`: True
- `per_device_train_batch_size`: 16
- `per_device_eval_batch_size`: 16
- `per_gpu_train_batch_size`: None
- `per_gpu_eval_batch_size`: None
- `gradient_accumulation_steps`: 1
- `eval_accumulation_steps`: None
- `torch_empty_cache_steps`: None
- `learning_rate`: 2e-05
- `weight_decay`: 0.0
- `adam_beta1`: 0.9
- `adam_beta2`: 0.999
- `adam_epsilon`: 1e-08
- `max_grad_norm`: 1.0
- `num_train_epochs`: 10
- `max_steps`: -1
- `lr_scheduler_type`: linear
- `lr_scheduler_kwargs`: {}
- `warmup_ratio`: 0.1
- `warmup_steps`: 0
- `log_level`: passive
- `log_level_replica`: warning
- `log_on_each_node`: True
- `logging_nan_inf_filter`: True
- `save_safetensors`: True
- `save_on_each_node`: False
- `save_only_model`: False
- `restore_callback_states_from_checkpoint`: False
- `no_cuda`: False
- `use_cpu`: False
- `use_mps_device`: False
- `seed`: 42
- `data_seed`: None
- `jit_mode_eval`: False
- `use_ipex`: False
- `bf16`: False
- `fp16`: False
- `fp16_opt_level`: O1
- `half_precision_backend`: auto
- `bf16_full_eval`: False
- `fp16_full_eval`: False
- `tf32`: None
- `local_rank`: 0
- `ddp_backend`: None
- `tpu_num_cores`: None
- `tpu_metrics_debug`: False
- `debug`: []
- `dataloader_drop_last`: False
- `dataloader_num_workers`: 0
- `dataloader_prefetch_factor`: None
- `past_index`: -1
- `disable_tqdm`: False
- `remove_unused_columns`: True
- `label_names`: None
- `load_best_model_at_end`: True
- `ignore_data_skip`: False
- `fsdp`: []
- `fsdp_min_num_params`: 0
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
- `fsdp_transformer_layer_cls_to_wrap`: None
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
- `parallelism_config`: None
- `deepspeed`: None
- `label_smoothing_factor`: 0.0
- `optim`: adamw_torch
- `optim_args`: None
- `adafactor`: False
- `group_by_length`: False
- `length_column_name`: length
- `ddp_find_unused_parameters`: None
- `ddp_bucket_cap_mb`: None
- `ddp_broadcast_buffers`: False
- `dataloader_pin_memory`: True
- `dataloader_persistent_workers`: False
- `skip_memory_metrics`: True
- `use_legacy_prediction_loop`: False
- `push_to_hub`: False
- `resume_from_checkpoint`: None
- `hub_model_id`: None
- `hub_strategy`: every_save
- `hub_private_repo`: None
- `hub_always_push`: False
- `hub_revision`: None
- `gradient_checkpointing`: False
- `gradient_checkpointing_kwargs`: None
- `include_inputs_for_metrics`: False
- `include_for_metrics`: []
- `eval_do_concat_batches`: True
- `fp16_backend`: auto
- `push_to_hub_model_id`: None
- `push_to_hub_organization`: None
- `mp_parameters`:
- `auto_find_batch_size`: False
- `full_determinism`: False
- `torchdynamo`: None
- `ray_scope`: last
- `ddp_timeout`: 1800
- `torch_compile`: False
- `torch_compile_backend`: None
- `torch_compile_mode`: None
- `include_tokens_per_second`: False
- `include_num_input_tokens_seen`: False
- `neftune_noise_alpha`: None
- `optim_target_modules`: None
- `batch_eval_metrics`: False
- `eval_on_start`: False
- `use_liger_kernel`: False
- `liger_kernel_config`: None
- `eval_use_gather_object`: False
- `average_tokens_across_devices`: False
- `prompts`: None
- `batch_sampler`: no_duplicates
- `multi_dataset_batch_sampler`: proportional
- `router_mapping`: {}
- `learning_rate_mapping`: {}

</details>

### Training Logs
| Epoch   | Step    | Training Loss | Validation Loss | jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep_ndcg@10 |
|:-------:|:-------:|:-------------:|:---------------:|:------------------------------------------------------------------:|
| **1.0** | **200** | **0.0644**    | **0.0238**      | **0.0200 (+0.0175)**                                                |
| 2.0     | 400     | 0.0238        | 0.0220          | 0.0198 (+0.0172)                                                    |
| 3.0     | 600     | 0.0182        | 0.0231          | 0.0200 (+0.0175)                                                    |
| 4.0     | 800     | 0.0167        | 0.0235          | 0.0198 (+0.0172)                                                    |
| 5.0     | 1000    | 0.0123        | 0.0240          | 0.0198 (+0.0172)                                                    |
| 6.0     | 1200    | 0.0123        | 0.0260          | 0.0198 (+0.0172)                                                    |
| 7.0     | 1400    | 0.0133        | 0.0260          | 0.0198 (+0.0172)                                                    |
| 8.0     | 1600    | 0.0143        | 0.0258          | 0.0198 (+0.0172)                                                    |
| 9.0     | 1800    | 0.0136        | 0.0258          | 0.0198 (+0.0172)                                                    |
| 10.0    | 2000    | 0.0135        | 0.0257          | 0.0198 (+0.0172)                                                    |

* The bold row denotes the saved checkpoint.

### Framework Versions
- Python: 3.9.7
- Sentence Transformers: 5.0.0
- Transformers: 4.56.0
- PyTorch: 2.7.1+cu126
- Accelerate: 1.5.2
- Datasets: 4.0.0
- Tokenizers: 0.22.0

## Citation

### BibTeX

#### Sentence Transformers
```bibtex
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
```

<!--
## Glossary

*Clearly define terms in order to be accessible across audiences.*
-->

<!--
## Model Card Authors

*Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
-->

<!--
## Model Card Contact

*Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
-->
block.py ADDED
@@ -0,0 +1,470 @@
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48

# Copyright (c) 2024, Tri Dao.

from functools import partial
from typing import Optional

import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .mha import MHA
from .mlp import Mlp

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
    layer_norm_fn, RMSNorm = None, None


def stochastic_depth(
    input: Tensor, p: float, mode: str, training: bool = True
) -> Tensor:
    """
    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
    branches of residual architectures.
    Args:
        input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
            being its batch i.e. a batch with ``N`` rows.
        p (float): probability of the input to be zeroed.
        mode (str): ``"batch"`` or ``"row"``.
            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
            randomly selected rows from the batch.
        training: apply stochastic depth if is ``True``. Default: ``True``
    Returns:
        Tensor[N, ...]: The randomly zeroed tensor.
    """
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
    if mode not in ["batch", "row"]:
        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
    if not training or p == 0.0:
        return input

    survival_rate = 1.0 - p
    if mode == "row":
        size = [input.shape[0]] + [1] * (input.ndim - 1)
    else:
        size = [1] * input.ndim
    noise = torch.empty(size, dtype=input.dtype, device=input.device)
    noise = noise.bernoulli_(survival_rate)
    if survival_rate > 0.0:
        noise.div_(survival_rate)
    return input * noise


torch.fx.wrap("stochastic_depth")


class StochasticDepth(nn.Module):
    """
    See :func:`stochastic_depth`.
    """

    def __init__(self, p: float, mode: str) -> None:
        super().__init__()
        self.p = p
        self.mode = mode

    def forward(self, input: Tensor) -> Tensor:
        return stochastic_depth(input, self.p, self.mode, self.training)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
        return s


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reason: for post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        self.prenorm = prenorm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(
                        residual.to(dtype=self.norm2.weight.dtype)
                    )
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1],
                            device=mixer_out.device,
                            dtype=mixer_out.dtype,
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1],
                                device=mlp_out.device,
                                dtype=mlp_out.dtype,
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                    )
            return hidden_states


class ParallelBlock(nn.Module):
|
| 329 |
+
"""The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
|
| 330 |
+
and PaLM.
|
| 331 |
+
"""
|
| 332 |
+
|
| 333 |
+
def __init__(
|
| 334 |
+
self,
|
| 335 |
+
dim,
|
| 336 |
+
mixer_cls=None,
|
| 337 |
+
mlp_cls=None,
|
| 338 |
+
norm_cls=nn.LayerNorm,
|
| 339 |
+
dropout_cls=nn.Dropout,
|
| 340 |
+
resid_dropout1=0.0,
|
| 341 |
+
resid_dropout2=0.0,
|
| 342 |
+
tied_norm=False,
|
| 343 |
+
fused_dropout_add_ln=False,
|
| 344 |
+
residual_in_fp32=False,
|
| 345 |
+
sequence_parallel=False,
|
| 346 |
+
mark_shared_params=False,
|
| 347 |
+
):
|
| 348 |
+
"""
|
| 349 |
+
This Block has a slightly different structure compared to a regular
|
| 350 |
+
prenorm Transformer block.
|
| 351 |
+
The standard block is: LN -> MHA / MLP -> Dropout -> Add.
|
| 352 |
+
[Ref: https://arxiv.org/abs/2002.04745]
|
| 353 |
+
Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
|
| 354 |
+
the hidden_states (output1 of the MHA / MLP) and the residual.
|
| 355 |
+
This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
|
| 356 |
+
The residual needs to be provided (except for the very first block).
|
| 357 |
+
"""
|
| 358 |
+
super().__init__()
|
| 359 |
+
self.tied_norm = tied_norm
|
| 360 |
+
self.fused_dropout_add_ln = fused_dropout_add_ln
|
| 361 |
+
self.residual_in_fp32 = residual_in_fp32
|
| 362 |
+
if mixer_cls is None:
|
| 363 |
+
mixer_cls = partial(MHA, num_heads=dim // 64)
|
| 364 |
+
if mlp_cls is None:
|
| 365 |
+
mlp_cls = partial(Mlp, hidden_features=4 * dim)
|
| 366 |
+
self.mixer = mixer_cls(dim)
|
| 367 |
+
self.dropout1 = dropout_cls(resid_dropout1)
|
| 368 |
+
self.norm1 = norm_cls(dim)
|
| 369 |
+
self.mlp = mlp_cls(dim)
|
| 370 |
+
self.dropout2 = dropout_cls(resid_dropout2)
|
| 371 |
+
if not self.tied_norm:
|
| 372 |
+
self.norm2 = norm_cls(dim)
|
| 373 |
+
|
| 374 |
+
if self.fused_dropout_add_ln:
|
| 375 |
+
assert layer_norm_fn is not None, "Triton is not installed"
|
| 376 |
+
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
|
| 377 |
+
self.dropout1, nn.Dropout
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
|
| 381 |
+
# then the input to each worker in the tensor parallel group will be different.
|
| 382 |
+
# This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
|
| 383 |
+
# For now this is not an issue because we always use sequence_parallel=True during training
|
| 384 |
+
# and only use sequence_parallel=False during inference.
|
| 385 |
+
|
| 386 |
+
# Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
|
| 387 |
+
if sequence_parallel:
|
| 388 |
+
for p in self.norm1.parameters():
|
| 389 |
+
p._sequence_parallel = True
|
| 390 |
+
if hasattr(self, "norm2"):
|
| 391 |
+
for p in self.norm2.parameters():
|
| 392 |
+
p._sequence_parallel = True
|
| 393 |
+
# Mark the norm parameters as "shared_params" so that we sync their values at init.
|
| 394 |
+
if mark_shared_params:
|
| 395 |
+
for p in self.norm1.parameters():
|
| 396 |
+
p._shared_params = True
|
| 397 |
+
if hasattr(self, "norm2"):
|
| 398 |
+
for p in self.norm2.parameters():
|
| 399 |
+
p._shared_params = True
|
| 400 |
+
|
| 401 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
| 402 |
+
return self.mixer.allocate_inference_cache(
|
| 403 |
+
batch_size, max_seqlen, dtype=dtype, **kwargs
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
def forward(
|
| 407 |
+
self,
|
| 408 |
+
hidden_states1: Tensor,
|
| 409 |
+
hidden_states2: Optional[Tensor] = None,
|
| 410 |
+
residual: Optional[Tensor] = None,
|
| 411 |
+
mixer_kwargs=None,
|
| 412 |
+
):
|
| 413 |
+
r"""Pass the input through the encoder layer.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
hidden_states1: the output of the previous attention (mixer) or embedding layer.
|
| 417 |
+
hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
|
| 418 |
+
residual.
|
| 419 |
+
"""
|
| 420 |
+
# TODO: Ideally we should only do the allgather / allreduce once for
|
| 421 |
+
# the Linear to MLP & Attention
|
| 422 |
+
if not self.fused_dropout_add_ln:
|
| 423 |
+
dropped1 = self.dropout1(hidden_states1)
|
| 424 |
+
# For the very 1st block, we only want 1 dropout, not two different dropouts
|
| 425 |
+
if hidden_states2 is not None:
|
| 426 |
+
dropped2 = self.dropout2(hidden_states2)
|
| 427 |
+
residual = (
|
| 428 |
+
(residual + dropped1 + dropped2)
|
| 429 |
+
if residual is not None
|
| 430 |
+
else dropped1 + dropped2
|
| 431 |
+
)
|
| 432 |
+
else:
|
| 433 |
+
residual = (residual + dropped1) if residual is not None else dropped1
|
| 434 |
+
hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
|
| 435 |
+
hidden_states2 = (
|
| 436 |
+
self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
| 437 |
+
if not self.tied_norm
|
| 438 |
+
else hidden_states1
|
| 439 |
+
)
|
| 440 |
+
if self.residual_in_fp32:
|
| 441 |
+
residual = residual.to(torch.float32)
|
| 442 |
+
else:
|
| 443 |
+
weight2, bias2 = (
|
| 444 |
+
(self.norm2.weight, self.norm2.bias)
|
| 445 |
+
if not self.tied_norm
|
| 446 |
+
else (None, None)
|
| 447 |
+
)
|
| 448 |
+
hidden_states1, *rest, residual = layer_norm_fn(
|
| 449 |
+
hidden_states1,
|
| 450 |
+
self.norm1.weight,
|
| 451 |
+
self.norm1.bias,
|
| 452 |
+
residual=residual,
|
| 453 |
+
x1=hidden_states2,
|
| 454 |
+
weight1=weight2,
|
| 455 |
+
bias1=bias2,
|
| 456 |
+
eps=self.norm1.eps,
|
| 457 |
+
dropout_p=self.dropout1.p if self.training else 0.0,
|
| 458 |
+
prenorm=True,
|
| 459 |
+
residual_in_fp32=self.residual_in_fp32,
|
| 460 |
+
is_rms_norm=isinstance(self.norm1, RMSNorm),
|
| 461 |
+
)
|
| 462 |
+
if self.tied_norm:
|
| 463 |
+
hidden_states2 = hidden_states1
|
| 464 |
+
else:
|
| 465 |
+
(hidden_states2,) = rest
|
| 466 |
+
if mixer_kwargs is None:
|
| 467 |
+
mixer_kwargs = {}
|
| 468 |
+
hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
|
| 469 |
+
hidden_states2 = self.mlp(hidden_states2)
|
| 470 |
+
return hidden_states1, hidden_states2, residual
|
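The ParallelBlock above runs the attention and MLP branches side by side and carries the residual stream separately between blocks, as its docstring describes. A minimal usage sketch, not part of the model files: the import path and sizes are assumptions for illustration, and the defaults fall back to the eager (non-flash) attention path, so this runs on CPU.

import torch
from block import ParallelBlock  # assumed import path: block.py ships at the repository root

block = ParallelBlock(dim=768)  # defaults: MHA with 768 // 64 = 12 heads, Mlp with 4 * 768 hidden units
x = torch.randn(2, 16, 768)     # (batch, seqlen, dim)

# First block: residual is None, so only one dropout is applied to the single input.
hidden1, hidden2, residual = block(x)
# Later blocks consume both branch outputs plus the running residual.
hidden1, hidden2, residual = block(hidden1, hidden2, residual)
print(hidden1.shape, hidden2.shape, residual.shape)  # all torch.Size([2, 16, 768])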
config.json
ADDED
|
@@ -0,0 +1,51 @@
{
  "architectures": [
    "XLMRobertaForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "auto_map": {
    "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
    "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
    "AutoModelForSequenceClassification": "modeling_xlm_roberta.XLMRobertaForSequenceClassification"
  },
  "bos_token_id": 0,
  "classifier_dropout": null,
  "dtype": "bfloat16",
  "emb_pooler": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0
  },
  "layer_norm_eps": 1e-05,
  "load_trained_adapters": false,
  "lora_adaptations": null,
  "lora_alpha": 1,
  "lora_dropout_p": 0.0,
  "lora_main_params_trainable": false,
  "lora_rank": 4,
  "matryoshka_dimensions": null,
  "max_position_embeddings": 1026,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "sentence_transformers": {
    "activation_fn": "torch.nn.modules.activation.Sigmoid",
    "version": "5.0.0"
  },
  "transformers_version": "4.56.0",
  "truncate_dim": null,
  "type_vocab_size": 1,
  "use_cache": false,
  "use_flash_attn": true,
  "vocab_size": 250002
}
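Reading this config: the checkpoint is exposed as XLMRobertaForSequenceClassification with a single output label and a Sigmoid activation recorded for sentence-transformers, i.e. a pointwise reranker, and auto_map routes loading through the custom files in this repository. A loading sketch with plain transformers; the model id below is a placeholder, and whether flash-attn must be installed on your machine depends on use_flash_attn and the custom modeling code.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "org/new-crossencoder"  # placeholder for this repository's id on the Hub
tokenizer = AutoTokenizer.from_pretrained(model_id)
# trust_remote_code is needed because auto_map points at the custom configuration/modeling files.
model = AutoModelForSequenceClassification.from_pretrained(model_id, trust_remote_code=True)
model.eval()

pairs = [("how many people live in berlin?", "Berlin has about 3.7 million inhabitants.")]
inputs = tokenizer([q for q, _ in pairs], [p for _, p in pairs], padding=True, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits         # shape (len(pairs), 1): a single LABEL_0 score per pair
scores = torch.sigmoid(logits.squeeze(-1))  # the config records a Sigmoid activation for sentence-transformers
print(scores)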
configuration_xlm_roberta.py
ADDED
|
@@ -0,0 +1,69 @@
from transformers import PretrainedConfig
import torch

class XLMRobertaFlashConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        lora_adaptations=None,
        lora_rank=4,
        lora_dropout_p=0.0,
        lora_alpha=1,
        lora_main_params_trainable=False,
        load_trained_adapters=False,
        use_flash_attn=True,
        torch_dtype=None,
        emb_pooler=None,
        matryoshka_dimensions=None,
        truncate_dim=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
        self.load_trained_adapters = load_trained_adapters
        self.lora_adaptations = lora_adaptations
        self.lora_rank = lora_rank
        self.lora_dropout_p = lora_dropout_p
        self.lora_alpha = lora_alpha
        self.lora_main_params_trainable = lora_main_params_trainable
        self.use_flash_attn = use_flash_attn
        self.emb_pooler = emb_pooler
        self.matryoshka_dimensions = matryoshka_dimensions
        self.truncate_dim = truncate_dim
        if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
            self.torch_dtype = getattr(torch, torch_dtype)
        else:
            self.torch_dtype = torch_dtype
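XLMRobertaFlashConfig mostly mirrors the stock XLM-R config but adds the LoRA, flash-attention and pooling flags, and accepts torch_dtype as a string, converting it to a real torch.dtype when a dtype of that name exists. A small sketch, assuming the file is importable from the repository root:

import torch
from configuration_xlm_roberta import XLMRobertaFlashConfig  # assumed local import

config = XLMRobertaFlashConfig(
    vocab_size=250002,
    max_position_embeddings=1026,
    type_vocab_size=1,
    layer_norm_eps=1e-5,
    use_flash_attn=True,
    torch_dtype="bfloat16",  # passed as a string
)
# The string is resolved to the corresponding torch dtype object.
assert config.torch_dtype is torch.bfloat16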
embedding.py
ADDED
|
@@ -0,0 +1,62 @@
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
# Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0

# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids


class XLMRobertaEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If type_vocab_size <= 0, there's no token type embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
                # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings
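The embedding module follows the XLM-R convention for positions: when position_ids is not given, they are derived from input_ids via create_position_ids_from_input_ids, so padding tokens keep padding_idx and real tokens are numbered from padding_idx + 1. A toy sketch with made-up token ids, assuming the file is importable from the repository root:

import torch
from embedding import XLMRobertaEmbeddings  # assumed local import

emb = XLMRobertaEmbeddings(
    embed_dim=768,
    vocab_size=250002,
    max_position_embeddings=1026,
    type_vocab_size=1,
    padding_idx=1,
)
# Toy batch: the second sequence is padded with token id 1 (the pad token).
input_ids = torch.tensor([[0, 2054, 2003, 2],
                          [0, 2023, 2, 1]])
out = emb(input_ids)
print(out.shape)  # torch.Size([2, 4, 768])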
mha.py
ADDED
|
@@ -0,0 +1,662 @@
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/Dao-AILab/flash-attention/pull/556

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None

        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None

        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input), input


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing

        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=False,
            alibi_slopes=None,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=None,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = (
            inference_params.max_sequence_len if inference_params is not None else max_seqlen
        )
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
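MHA projects Q, K and V with a single packed Wqkv linear layer and then dispatches either to the FlashAttention kernels (fp16/bf16 CUDA tensors, optionally the unpadded layout driven by cu_seqlens / max_seqlen) or to the eager SelfAttention fallback. A sketch of the eager path, which runs on CPU; the sizes are illustrative and the import path is assumed:

import torch
from mha import MHA  # assumed local import

mha = MHA(embed_dim=768, num_heads=12, use_flash_attn=False)  # eager SelfAttention path
x = torch.randn(2, 16, 768)                                   # (batch, seqlen, hidden)
out = mha(x)
print(out.shape)                                              # torch.Size([2, 16, 768])

# With use_flash_attn=True the same module instead expects fp16/bf16 CUDA tensors, and can
# consume the unpadded (total_tokens, hidden) layout together with cu_seqlens and max_seqlen.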
mlp.py
ADDED
|
@@ -0,0 +1,194 @@
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
# Commit id: c3b219665292c61a51153d0ded4473c494296382

# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup


try:
    from flash_attn.ops.activations import swiglu
except ImportError:
    swiglu = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None

try:
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
    FusedMLP, ParallelFusedMLP = None, None


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y


class GatedMlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
            y, gate = y.chunk(2, dim=-1)
            y = swiglu(gate, y)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelGatedMlp(nn.Module):
    """Parallel GatedMlp"""

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=128,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.fc1 = ColumnParallelLinear(
            in_features,
            2 * hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y
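GatedMlp widens fc1 to twice the hidden size and splits the result into a value half and a gate half, rounding the hidden size up to a multiple of multiple_of; with the default sigmoid activation this reduces to F.glu. A sketch, assuming the file is importable from the repository root and using illustrative sizes:

import torch
from mlp import Mlp, GatedMlp  # assumed local import

x = torch.randn(2, 16, 768)

mlp = Mlp(768)                 # hidden_features defaults to 4 * 768 = 3072
print(mlp(x).shape)            # torch.Size([2, 16, 768])

gated = GatedMlp(768)          # default activation F.sigmoid -> the F.glu special case
# hidden_features defaults to int(8 * 768 / 3) = 2048, already a multiple of 128
print(gated.fc1.out_features)  # 4096 = 2 * 2048 (value and gate halves)
print(gated(x).shape)          # torch.Size([2, 16, 768])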
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5c7c4b46d9d342fe128e46a55adbac52fd96dd0aaafee7c26e8d8f2286bee91a
size 556892306
modeling_xlm_roberta.py
ADDED
|
@@ -0,0 +1,1119 @@
|
| 1 |
+
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
|
| 2 |
+
# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
|
| 3 |
+
# Copyright (c) 2022, Tri Dao.
|
| 4 |
+
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
|
| 5 |
+
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
|
| 6 |
+
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
|
| 7 |
+
|
| 8 |
+
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
| 9 |
+
|
| 10 |
+
import importlib.util
|
| 11 |
+
import logging
|
| 12 |
+
import re
|
| 13 |
+
from collections import OrderedDict
|
| 14 |
+
from collections.abc import Sequence
|
| 15 |
+
from functools import partial
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 23 |
+
from einops import rearrange
|
| 24 |
+
from transformers import PretrainedConfig
|
| 25 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 26 |
+
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
|
| 27 |
+
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
|
| 28 |
+
|
| 29 |
+
from transformers.models.bert.modeling_bert import (
|
| 30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 31 |
+
BertForPreTrainingOutput,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
from typing import List, Optional, Tuple, Union
|
| 35 |
+
|
| 36 |
+
from .xlm_padding import (
|
| 37 |
+
index_first_axis,
|
| 38 |
+
index_first_axis_residual,
|
| 39 |
+
pad_input,
|
| 40 |
+
unpad_input,
|
| 41 |
+
)
|
| 42 |
+
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
| 43 |
+
from .block import Block
|
| 44 |
+
from .embedding import XLMRobertaEmbeddings
|
| 45 |
+
from .mha import MHA
|
| 46 |
+
from .mlp import FusedMLP, Mlp
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
from flash_attn.ops.fused_dense import FusedDense
|
| 50 |
+
except ImportError:
|
| 51 |
+
FusedDense = None
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn
|
| 55 |
+
except ImportError:
|
| 56 |
+
layer_norm_fn = None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
| 61 |
+
except ImportError:
|
| 62 |
+
CrossEntropyLoss = torch.nn.CrossEntropyLoss
|
| 63 |
+
|
| 64 |
+
try:
|
| 65 |
+
from tqdm.autonotebook import trange
|
| 66 |
+
except ImportError:
|
| 67 |
+
trange = None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
logger = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_use_flash_attn(config: XLMRobertaFlashConfig):
|
| 74 |
+
if not getattr(config, "use_flash_attn", False):
|
| 75 |
+
return False
|
| 76 |
+
if not torch.cuda.is_available():
|
| 77 |
+
return False
|
| 78 |
+
if importlib.util.find_spec("flash_attn") is None:
|
| 79 |
+
logger.warning(
|
| 80 |
+
'flash_attn is not installed. Using PyTorch native attention implementation.'
|
| 81 |
+
)
|
| 82 |
+
return False
|
| 83 |
+
return True
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
| 87 |
+
use_flash_attn = get_use_flash_attn(config)
|
| 88 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 89 |
+
|
| 90 |
+
mixer_cls = partial(
|
| 91 |
+
MHA,
|
| 92 |
+
num_heads=config.num_attention_heads,
|
| 93 |
+
cross_attn=cross_attn,
|
| 94 |
+
dropout=config.attention_probs_dropout_prob,
|
| 95 |
+
causal=False,
|
| 96 |
+
fused_bias_fc=fused_bias_fc,
|
| 97 |
+
use_flash_attn=use_flash_attn,
|
| 98 |
+
return_residual=return_residual,
|
| 99 |
+
)
|
| 100 |
+
return mixer_cls
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
| 104 |
+
inner_dim = config.intermediate_size
|
| 105 |
+
fused_mlp = getattr(config, "fused_mlp", False)
|
| 106 |
+
if fused_mlp:
|
| 107 |
+
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
|
| 108 |
+
"fused_mlp only " "supports approximate gelu"
|
| 109 |
+
)
|
| 110 |
+
if not fused_mlp:
|
| 111 |
+
approximate = (
|
| 112 |
+
"tanh"
|
| 113 |
+
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
| 114 |
+
else "none"
|
| 115 |
+
)
|
| 116 |
+
mlp_cls = partial(
|
| 117 |
+
Mlp,
|
| 118 |
+
hidden_features=inner_dim,
|
| 119 |
+
activation=partial(F.gelu, approximate=approximate),
|
| 120 |
+
return_residual=return_residual,
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
if FusedMLP is None:
|
| 124 |
+
raise ImportError("fused_dense is not installed")
|
| 125 |
+
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
| 126 |
+
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
| 127 |
+
if isinstance(mlp_checkpoint_lvl, Sequence):
|
| 128 |
+
assert layer_idx is not None
|
| 129 |
+
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
| 130 |
+
mlp_cls = partial(
|
| 131 |
+
FusedMLP,
|
| 132 |
+
hidden_features=inner_dim,
|
| 133 |
+
checkpoint_lvl=mlp_checkpoint_lvl,
|
| 134 |
+
return_residual=return_residual,
|
| 135 |
+
)
|
| 136 |
+
return mlp_cls
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def create_block(config, layer_idx=None):
|
| 140 |
+
last_layer_subset = getattr(config, "last_layer_subset", False)
|
| 141 |
+
cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
|
| 142 |
+
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
|
| 143 |
+
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
|
| 144 |
+
# one layer) so we just choose not to return residual in this case.
|
| 145 |
+
return_residual = not cross_attn
|
| 146 |
+
mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
|
| 147 |
+
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
|
| 148 |
+
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
|
| 149 |
+
block = Block(
|
| 150 |
+
config.hidden_size,
|
| 151 |
+
mixer_cls,
|
| 152 |
+
mlp_cls,
|
| 153 |
+
norm_cls=norm_cls,
|
| 154 |
+
prenorm=False,
|
| 155 |
+
resid_dropout1=config.hidden_dropout_prob,
|
| 156 |
+
resid_dropout2=config.hidden_dropout_prob,
|
| 157 |
+
fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
|
| 158 |
+
return_residual=return_residual,
|
| 159 |
+
)
|
| 160 |
+
return block
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
|
| 164 |
+
def _init_weights(module, initializer_range=0.02):
|
| 165 |
+
if isinstance(module, nn.Linear):
|
| 166 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 167 |
+
if module.bias is not None:
|
| 168 |
+
nn.init.zeros_(module.bias)
|
| 169 |
+
elif isinstance(module, nn.Embedding):
|
| 170 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
| 171 |
+
if module.padding_idx is not None:
|
| 172 |
+
nn.init.zeros_(module.weight[module.padding_idx])
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class XLMRobertaEncoder(nn.Module):
|
| 176 |
+
def __init__(self, config: XLMRobertaFlashConfig):
|
| 177 |
+
super().__init__()
|
| 178 |
+
self.use_flash_attn = get_use_flash_attn(config)
|
| 179 |
+
self.layers = nn.ModuleList(
|
| 180 |
+
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
| 181 |
+
)
|
| 182 |
+
self._grad_checkpointing = False
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
def gradient_checkpointing(self):
|
| 186 |
+
return self._grad_checkpointing
|
| 187 |
+
|
| 188 |
+
@gradient_checkpointing.setter
|
| 189 |
+
def gradient_checkpointing(self, value):
|
| 190 |
+
self._grad_checkpointing = value
|
| 191 |
+
|
| 192 |
+
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
| 193 |
+
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
| 194 |
+
This means that we only compute the last layer output for these tokens.
|
| 195 |
+
subset_mask: (batch, seqlen), dtype=torch.bool
|
| 196 |
+
"""
|
| 197 |
+
if key_padding_mask is None or not self.use_flash_attn:
|
| 198 |
+
mixer_kwargs = (
|
| 199 |
+
{"key_padding_mask": key_padding_mask.bool()}
|
| 200 |
+
if key_padding_mask is not None
|
| 201 |
+
else None
|
| 202 |
+
)
|
| 203 |
+
for layer in self.layers:
|
| 204 |
+
if self._grad_checkpointing:
|
| 205 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 206 |
+
layer,
|
| 207 |
+
hidden_states,
|
| 208 |
+
use_reentrant=False,
|
| 209 |
+
mixer_kwargs=mixer_kwargs,
|
| 210 |
+
)
|
| 211 |
+
else:
|
| 212 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
| 213 |
+
if subset_mask is not None:
|
| 214 |
+
hidden_states = hidden_states[subset_mask]
|
| 215 |
+
else:
|
| 216 |
+
batch, seqlen = hidden_states.shape[:2]
|
| 217 |
+
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
|
| 218 |
+
hidden_states, key_padding_mask
|
| 219 |
+
)
|
| 220 |
+
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
| 221 |
+
if subset_mask is None:
|
| 222 |
+
for layer in self.layers:
|
| 223 |
+
if self._grad_checkpointing:
|
| 224 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 225 |
+
layer,
|
| 226 |
+
hidden_states,
|
| 227 |
+
use_reentrant=False,
|
| 228 |
+
mixer_kwargs=mixer_kwargs,
|
| 229 |
+
)
|
| 230 |
+
else:
|
| 231 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
| 232 |
+
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
| 233 |
+
else:
|
| 234 |
+
for layer in self.layers[:-1]:
|
| 235 |
+
if self._grad_checkpointing:
|
| 236 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 237 |
+
layer,
|
| 238 |
+
hidden_states,
|
| 239 |
+
use_reentrant=False,
|
| 240 |
+
mixer_kwargs=mixer_kwargs,
|
| 241 |
+
)
|
| 242 |
+
else:
|
| 243 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
| 244 |
+
if key_padding_mask is not None:
|
| 245 |
+
subset_idx = torch.nonzero(
|
| 246 |
+
subset_mask[key_padding_mask], as_tuple=False
|
| 247 |
+
).flatten()
|
| 248 |
+
subset_seqlens = (subset_mask & key_padding_mask).sum(
|
| 249 |
+
dim=-1, dtype=torch.int32
|
| 250 |
+
)
|
| 251 |
+
subset_cu_seqlens = F.pad(
|
| 252 |
+
torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
|
| 253 |
+
(1, 0),
|
| 254 |
+
)
|
| 255 |
+
else:
|
| 256 |
+
subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
|
| 257 |
+
subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
|
| 258 |
+
subset_cu_seqlens = F.pad(
|
| 259 |
+
torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32),
|
| 260 |
+
(1, 0),
|
| 261 |
+
)
|
| 262 |
+
hidden_states_subset, hidden_states = index_first_axis_residual(
|
| 263 |
+
hidden_states, subset_idx
|
| 264 |
+
)
|
| 265 |
+
# It's ok to set max_seqlen_q to be much larger
|
| 266 |
+
mixer_kwargs = {
|
| 267 |
+
"x_kv": hidden_states,
|
| 268 |
+
"cu_seqlens": subset_cu_seqlens,
|
| 269 |
+
"max_seqlen": max_seqlen_in_batch,
|
| 270 |
+
"cu_seqlens_k": cu_seqlens,
|
| 271 |
+
"max_seqlen_k": max_seqlen_in_batch,
|
| 272 |
+
}
|
| 273 |
+
if self._grad_checkpointing:
|
| 274 |
+
torch.utils.checkpoint.checkpoint(
|
| 275 |
+
self.layers[-1],
|
| 276 |
+
hidden_states_subset,
|
| 277 |
+
use_reentrant=False,
|
| 278 |
+
mixer_kwargs=mixer_kwargs,
|
| 279 |
+
)
|
| 280 |
+
else:
|
| 281 |
+
hidden_states = self.layers[-1](
|
| 282 |
+
hidden_states_subset, mixer_kwargs=mixer_kwargs
|
| 283 |
+
)
|
| 284 |
+
return hidden_states
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class XLMRobertaPooler(nn.Module):
|
| 288 |
+
def __init__(self, config):
|
| 289 |
+
super().__init__()
|
| 290 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 291 |
+
if fused_bias_fc and FusedDense is None:
|
| 292 |
+
raise ImportError("fused_dense is not installed")
|
| 293 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 294 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
| 295 |
+
self.activation = nn.Tanh()
|
| 296 |
+
|
| 297 |
+
def forward(self, hidden_states, pool=True):
|
| 298 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 299 |
+
# to the first token.
|
| 300 |
+
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
| 301 |
+
pooled_output = self.dense(first_token_tensor)
|
| 302 |
+
pooled_output = self.activation(pooled_output)
|
| 303 |
+
return pooled_output
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class XLMRobertaPredictionHeadTransform(nn.Module):
|
| 307 |
+
def __init__(self, config):
|
| 308 |
+
super().__init__()
|
| 309 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 310 |
+
if fused_bias_fc and FusedDense is None:
|
| 311 |
+
raise ImportError("fused_dense is not installed")
|
| 312 |
+
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
|
| 313 |
+
if self.fused_dropout_add_ln and layer_norm_fn is None:
|
| 314 |
+
raise ImportError("Triton is not installed")
|
| 315 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 316 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
| 317 |
+
approximate = (
|
| 318 |
+
"tanh"
|
| 319 |
+
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
| 320 |
+
else "none"
|
| 321 |
+
)
|
| 322 |
+
self.transform_act_fn = nn.GELU(approximate=approximate)
|
| 323 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 324 |
+
|
| 325 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 326 |
+
hidden_states = self.dense(hidden_states)
|
| 327 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 328 |
+
if not self.fused_dropout_add_ln:
|
| 329 |
+
hidden_states = self.layer_norm(hidden_states)
|
| 330 |
+
else:
|
| 331 |
+
hidden_states = layer_norm_fn(
|
| 332 |
+
hidden_states,
|
| 333 |
+
self.layer_norm.weight,
|
| 334 |
+
self.layer_norm.bias,
|
| 335 |
+
eps=self.layer_norm.eps,
|
| 336 |
+
)
|
| 337 |
+
return hidden_states
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
class XLMRobertaLMPredictionHead(nn.Module):
|
| 341 |
+
def __init__(self, config):
|
| 342 |
+
super().__init__()
|
| 343 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 344 |
+
if fused_bias_fc and FusedDense is None:
|
| 345 |
+
raise ImportError("fused_dense is not installed")
|
| 346 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 347 |
+
|
| 348 |
+
self.transform = XLMRobertaPredictionHeadTransform(config)
|
| 349 |
+
|
| 350 |
+
# The output weights are the same as the input embeddings, but there is
|
| 351 |
+
# an output-only bias for each token.
|
| 352 |
+
self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
|
| 353 |
+
|
| 354 |
+
def forward(self, hidden_states):
|
| 355 |
+
hidden_states = self.transform(hidden_states)
|
| 356 |
+
hidden_states = self.decoder(hidden_states)
|
| 357 |
+
return hidden_states
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class XLMRobertaPreTrainingHeads(nn.Module):
|
| 361 |
+
def __init__(self, config):
|
| 362 |
+
super().__init__()
|
| 363 |
+
self.predictions = XLMRobertaLMPredictionHead(config)
|
| 364 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
| 365 |
+
|
| 366 |
+
def forward(self, sequence_output, pooled_output):
|
| 367 |
+
prediction_scores = self.predictions(sequence_output)
|
| 368 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
| 369 |
+
return prediction_scores, seq_relationship_score
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class XLMRobertaPreTrainedModel(PreTrainedModel):
|
| 373 |
+
"""An abstract class to handle weights initialization and
|
| 374 |
+
a simple interface for downloading and loading pretrained models.
|
| 375 |
+
"""
|
| 376 |
+
|
| 377 |
+
config_class = XLMRobertaFlashConfig
|
| 378 |
+
base_model_prefix = "roberta"
|
| 379 |
+
supports_gradient_checkpointing = True
|
| 380 |
+
|
| 381 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 382 |
+
if isinstance(module, XLMRobertaEncoder):
|
| 383 |
+
module.gradient_checkpointing = value
|
| 384 |
+
|
| 385 |
+
@classmethod
|
| 386 |
+
def from_pretrained(
|
| 387 |
+
cls,
|
| 388 |
+
*args,
|
| 389 |
+
**kwargs,
|
| 390 |
+
):
|
| 391 |
+
if 'torch_dtype' not in kwargs:
|
| 392 |
+
kwargs['torch_dtype'] = 'auto'
|
| 393 |
+
return super().from_pretrained(*args, **kwargs)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
| 398 |
+
def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
|
| 399 |
+
super().__init__(config)
|
| 400 |
+
self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
| 401 |
+
if config.vocab_size % self.pad_vocab_size_multiple != 0:
|
| 402 |
+
config.vocab_size += self.pad_vocab_size_multiple - (
|
| 403 |
+
config.vocab_size % self.pad_vocab_size_multiple
|
| 404 |
+
)
|
| 405 |
+
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
|
| 406 |
+
if self.fused_dropout_add_ln and layer_norm_fn is None:
|
| 407 |
+
raise ImportError("Triton is not installed")
|
| 408 |
+
assert config.hidden_act in [
|
| 409 |
+
"gelu",
|
| 410 |
+
"gelu_new",
|
| 411 |
+
"gelu_fast",
|
| 412 |
+
"gelu_pytorch_tanh",
|
| 413 |
+
]
|
| 414 |
+
|
| 415 |
+
self.embeddings = XLMRobertaEmbeddings(
|
| 416 |
+
config.hidden_size,
|
| 417 |
+
config.vocab_size,
|
| 418 |
+
config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
|
| 419 |
+
config.type_vocab_size,
|
| 420 |
+
padding_idx=config.pad_token_id,
|
| 421 |
+
)
|
| 422 |
+
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
| 423 |
+
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 424 |
+
self.encoder = XLMRobertaEncoder(config)
|
| 425 |
+
self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
|
| 426 |
+
|
| 427 |
+
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
@torch.inference_mode()
|
| 431 |
+
def encode(
|
| 432 |
+
self: 'XLMRobertaModel',
|
| 433 |
+
sentences: Union[str, List[str]],
|
| 434 |
+
batch_size: int = 32,
|
| 435 |
+
show_progress_bar: Optional[bool] = None,
|
| 436 |
+
output_value: str = 'sentence_embedding',
|
| 437 |
+
convert_to_numpy: bool = True,
|
| 438 |
+
convert_to_tensor: bool = False,
|
| 439 |
+
device: Optional[torch.device] = None,
|
| 440 |
+
normalize_embeddings: bool = False,
|
| 441 |
+
truncate_dim: Optional[int] = None,
|
| 442 |
+
**tokenizer_kwargs,
|
| 443 |
+
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
| 444 |
+
"""
|
| 445 |
+
Computes sentence embeddings
|
| 446 |
+
Args:
|
| 447 |
+
sentences(`str` or `List[str]`):
|
| 448 |
+
Sentence or sentences to be encoded
|
| 449 |
+
batch_size(`int`, *optional*, defaults to 32):
|
| 450 |
+
Batch size for the computation
|
| 451 |
+
show_progress_bar(`bool`, *optional*, defaults to None):
|
| 452 |
+
Show a progress bar when encoding sentences.
|
| 453 |
+
If set to None, progress bar is only shown when
|
| 454 |
+
`logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
|
| 455 |
+
output_value(`str`, *optional*, defaults to 'sentence_embedding'):
|
| 456 |
+
Default sentence_embedding, to get sentence embeddings.
|
| 457 |
+
Can be set to token_embeddings to get wordpiece token embeddings.
|
| 458 |
+
Set to None, to get all output values
|
| 459 |
+
convert_to_numpy(`bool`, *optional*, defaults to True):
|
| 460 |
+
If true, the output is a list of numpy vectors.
|
| 461 |
+
Else, it is a list of pytorch tensors.
|
| 462 |
+
convert_to_tensor(`bool`, *optional*, defaults to False):
|
| 463 |
+
If true, you get one large tensor as return.
|
| 464 |
+
Overwrites any setting from convert_to_numpy
|
| 465 |
+
device(`torch.device`, *optional*, defaults to None):
|
| 466 |
+
Which torch.device to use for the computation
|
| 467 |
+
normalize_embeddings(`bool`, *optional*, defaults to False):
|
| 468 |
+
If set to true, returned vectors will have length 1. In that case, the
|
| 469 |
+
faster dot-product (util.dot_score) instead of cosine similarity can
|
| 470 |
+
be used.
|
| 471 |
+
truncate_dim(`int`, *optional*, defaults to None):
|
| 472 |
+
The dimension to truncate sentence embeddings to. `None` does no truncation.
|
| 473 |
+
tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
|
| 474 |
+
Keyword arguments for the tokenizer
|
| 475 |
+
Returns:
|
| 476 |
+
By default, a list of tensors is returned.
|
| 477 |
+
If convert_to_tensor, a stacked tensor is returned.
|
| 478 |
+
If convert_to_numpy, a numpy matrix is returned.
|
| 479 |
+
"""
|
| 480 |
+
from transformers import AutoTokenizer
|
| 481 |
+
|
| 482 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 483 |
+
self.name_or_path, trust_remote_code=True
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
is_training = self.training
|
| 487 |
+
self.eval()
|
| 488 |
+
|
| 489 |
+
if show_progress_bar is None:
|
| 490 |
+
show_progress_bar = (
|
| 491 |
+
logger.getEffectiveLevel() == logging.INFO
|
| 492 |
+
or logger.getEffectiveLevel() == logging.DEBUG
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
if convert_to_tensor:
|
| 496 |
+
convert_to_numpy = False
|
| 497 |
+
|
| 498 |
+
if output_value != 'sentence_embedding':
|
| 499 |
+
convert_to_tensor = False
|
| 500 |
+
convert_to_numpy = False
|
| 501 |
+
|
| 502 |
+
input_was_string = False
|
| 503 |
+
if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
|
| 504 |
+
sentences = [sentences]
|
| 505 |
+
input_was_string = True
|
| 506 |
+
|
| 507 |
+
if device is not None:
|
| 508 |
+
self.to(device)
|
| 509 |
+
|
| 510 |
+
permutation = np.argsort([-len(i) for i in sentences])
|
| 511 |
+
inverse_permutation = np.argsort(permutation)
|
| 512 |
+
sentences = [sentences[idx] for idx in permutation]
|
| 513 |
+
|
| 514 |
+
tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
|
| 515 |
+
tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
|
| 516 |
+
'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
|
| 517 |
+
)
|
| 518 |
+
tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
|
| 519 |
+
|
| 520 |
+
all_embeddings = []
|
| 521 |
+
|
| 522 |
+
if trange is not None:
|
| 523 |
+
range_iter = trange(
|
| 524 |
+
0,
|
| 525 |
+
len(sentences),
|
| 526 |
+
batch_size,
|
| 527 |
+
desc="Encoding",
|
| 528 |
+
disable=not show_progress_bar,
|
| 529 |
+
)
|
| 530 |
+
else:
|
| 531 |
+
range_iter = range(0, len(sentences), batch_size)
|
| 532 |
+
|
| 533 |
+
for i in range_iter:
|
| 534 |
+
encoded_input = self.tokenizer(
|
| 535 |
+
sentences[i : i + batch_size],
|
| 536 |
+
return_tensors='pt',
|
| 537 |
+
**tokenizer_kwargs,
|
| 538 |
+
).to(self.device)
|
| 539 |
+
token_embs = self.forward(**encoded_input)[0]
|
| 540 |
+
|
| 541 |
+
# Accumulate in fp32 to avoid overflow
|
| 542 |
+
token_embs = token_embs.float()
|
| 543 |
+
|
| 544 |
+
if output_value == 'token_embeddings':
|
| 545 |
+
raise NotImplementedError
|
| 546 |
+
elif output_value is None:
|
| 547 |
+
raise NotImplementedError
|
| 548 |
+
else:
|
| 549 |
+
if self.config.emb_pooler == 'cls':
|
| 550 |
+
embeddings = self.cls_pooling(
|
| 551 |
+
token_embs, encoded_input['attention_mask']
|
| 552 |
+
)
|
| 553 |
+
else:
|
| 554 |
+
embeddings = self.mean_pooling(
|
| 555 |
+
token_embs, encoded_input['attention_mask']
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
if normalize_embeddings:
|
| 559 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
| 560 |
+
|
| 561 |
+
if convert_to_numpy:
|
| 562 |
+
embeddings = embeddings.cpu()
|
| 563 |
+
all_embeddings.extend(embeddings)
|
| 564 |
+
|
| 565 |
+
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
| 566 |
+
|
| 567 |
+
truncate_dim = truncate_dim or self.config.truncate_dim
|
| 568 |
+
if truncate_dim:
|
| 569 |
+
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
| 570 |
+
|
| 571 |
+
if convert_to_tensor:
|
| 572 |
+
all_embeddings = torch.stack(all_embeddings)
|
| 573 |
+
elif convert_to_numpy:
|
| 574 |
+
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
|
| 575 |
+
|
| 576 |
+
if input_was_string:
|
| 577 |
+
all_embeddings = all_embeddings[0]
|
| 578 |
+
|
| 579 |
+
self.train(is_training)
|
| 580 |
+
return all_embeddings
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def truncate_embeddings(self, embeddings, truncate_dim):
|
| 584 |
+
if not self.config.matryoshka_dimensions:
|
| 585 |
+
logger.warning(
|
| 586 |
+
'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
|
| 587 |
+
)
|
| 588 |
+
return embeddings
|
| 589 |
+
elif truncate_dim in self.config.matryoshka_dimensions:
|
| 590 |
+
return [tensor[:truncate_dim] for tensor in embeddings]
|
| 591 |
+
else:
|
| 592 |
+
raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
|
| 593 |
+
f'Supported dimensions are {self.config.matryoshka_dimensions}.')
|
| 594 |
+
|
| 595 |
+
def mean_pooling(
|
| 596 |
+
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
| 597 |
+
):
|
| 598 |
+
input_mask_expanded = (
|
| 599 |
+
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 600 |
+
)
|
| 601 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
|
| 602 |
+
input_mask_expanded.sum(1), min=1e-9
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def cls_pooling(
|
| 607 |
+
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
| 608 |
+
):
|
| 609 |
+
return token_embeddings[:,0]
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def forward(
|
| 613 |
+
self,
|
| 614 |
+
input_ids,
|
| 615 |
+
position_ids=None,
|
| 616 |
+
token_type_ids=None,
|
| 617 |
+
attention_mask=None,
|
| 618 |
+
masked_tokens_mask=None,
|
| 619 |
+
return_dict=None,
|
| 620 |
+
**kwargs,
|
| 621 |
+
):
|
| 622 |
+
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
|
| 623 |
+
we only want the output for the masked tokens. This means that we only compute the last
|
| 624 |
+
layer output for these tokens.
|
| 625 |
+
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
| 626 |
+
"""
|
| 627 |
+
|
| 628 |
+
if kwargs:
|
| 629 |
+
for key, value in kwargs.items():
|
| 630 |
+
if value is not None:
|
| 631 |
+
logger.warning(
|
| 632 |
+
'Flash attention implementation does not support kwargs: %s',
|
| 633 |
+
key,
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
return_dict = (
|
| 637 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
hidden_states = self.embeddings(
|
| 641 |
+
input_ids, position_ids=position_ids, token_type_ids=token_type_ids
|
| 642 |
+
)
|
| 643 |
+
# TD [2022-12:18]: Don't need to force residual in fp32
|
| 644 |
+
# BERT puts embedding LayerNorm before embedding dropout.
|
| 645 |
+
if not self.fused_dropout_add_ln:
|
| 646 |
+
hidden_states = self.emb_ln(hidden_states)
|
| 647 |
+
else:
|
| 648 |
+
hidden_states = layer_norm_fn(
|
| 649 |
+
hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
|
| 650 |
+
)
|
| 651 |
+
hidden_states = self.emb_drop(hidden_states)
|
| 652 |
+
|
| 653 |
+
if masked_tokens_mask is not None:
|
| 654 |
+
batch_size, seqlen = input_ids.shape[:2]
|
| 655 |
+
# We also need the first column for the CLS token
|
| 656 |
+
first_col_mask = torch.zeros(
|
| 657 |
+
batch_size, seqlen, dtype=torch.bool, device=input_ids.device
|
| 658 |
+
)
|
| 659 |
+
first_col_mask[:, 0] = True
|
| 660 |
+
subset_mask = masked_tokens_mask | first_col_mask
|
| 661 |
+
else:
|
| 662 |
+
subset_mask = None
|
| 663 |
+
|
| 664 |
+
sequence_output = self.encoder(
|
| 665 |
+
hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
if masked_tokens_mask is None:
|
| 669 |
+
pooled_output = (
|
| 670 |
+
self.pooler(sequence_output) if self.pooler is not None else None
|
| 671 |
+
)
|
| 672 |
+
else:
|
| 673 |
+
# TD [2022-03-01]: the indexing here is very tricky.
|
| 674 |
+
if attention_mask is not None:
|
| 675 |
+
subset_idx = subset_mask[attention_mask]
|
| 676 |
+
pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
|
| 677 |
+
sequence_output = sequence_output[
|
| 678 |
+
masked_tokens_mask[attention_mask][subset_idx]
|
| 679 |
+
]
|
| 680 |
+
else:
|
| 681 |
+
pool_input = sequence_output[first_col_mask[subset_mask]]
|
| 682 |
+
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
| 683 |
+
pooled_output = (
|
| 684 |
+
self.pooler(pool_input, pool=False) if self.pooler is not None else None
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
if not return_dict:
|
| 688 |
+
return sequence_output, pooled_output
|
| 689 |
+
|
| 690 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 691 |
+
last_hidden_state=sequence_output,
|
| 692 |
+
pooler_output=pooled_output,
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
|
| 697 |
+
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
| 698 |
+
|
| 699 |
+
def __init__(self, config):
|
| 700 |
+
super().__init__(config)
|
| 701 |
+
|
| 702 |
+
if config.is_decoder:
|
| 703 |
+
logger.warning(
|
| 704 |
+
"If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
|
| 705 |
+
"bi-directional self-attention."
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
|
| 709 |
+
self.lm_head = XLMRobertaLMHead(config)
|
| 710 |
+
|
| 711 |
+
# Initialize weights and apply final processing
|
| 712 |
+
self.post_init()
|
| 713 |
+
|
| 714 |
+
def get_input_embeddings(self):
|
| 715 |
+
return self.roberta.embeddings.word_embeddings
|
| 716 |
+
|
| 717 |
+
def get_output_embeddings(self):
|
| 718 |
+
return self.lm_head.decoder
|
| 719 |
+
|
| 720 |
+
def set_output_embeddings(self, new_embeddings):
|
| 721 |
+
self.lm_head.decoder = new_embeddings
|
| 722 |
+
|
| 723 |
+
def forward(
|
| 724 |
+
self,
|
| 725 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 726 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 727 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 728 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 729 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 730 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 731 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 732 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 733 |
+
labels: Optional[torch.LongTensor] = None,
|
| 734 |
+
output_attentions: Optional[bool] = None,
|
| 735 |
+
output_hidden_states: Optional[bool] = None,
|
| 736 |
+
return_dict: Optional[bool] = None,
|
| 737 |
+
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
| 738 |
+
r"""
|
| 739 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 740 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
| 741 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
| 742 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
| 743 |
+
kwargs (`Dict[str, any]`, optional, defaults to *{}*):
|
| 744 |
+
Used to hide legacy arguments that have been deprecated.
|
| 745 |
+
"""
|
| 746 |
+
return_dict = (
|
| 747 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
outputs = self.roberta(
|
| 751 |
+
input_ids,
|
| 752 |
+
attention_mask=attention_mask,
|
| 753 |
+
token_type_ids=token_type_ids,
|
| 754 |
+
position_ids=position_ids,
|
| 755 |
+
head_mask=head_mask,
|
| 756 |
+
inputs_embeds=inputs_embeds,
|
| 757 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 758 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 759 |
+
output_attentions=output_attentions,
|
| 760 |
+
output_hidden_states=output_hidden_states,
|
| 761 |
+
return_dict=return_dict,
|
| 762 |
+
)
|
| 763 |
+
sequence_output = outputs[0]
|
| 764 |
+
prediction_scores = self.lm_head(sequence_output)
|
| 765 |
+
|
| 766 |
+
masked_lm_loss = None
|
| 767 |
+
if labels is not None:
|
| 768 |
+
# move labels to correct device to enable model parallelism
|
| 769 |
+
labels = labels.to(prediction_scores.device)
|
| 770 |
+
loss_fct = CrossEntropyLoss()
|
| 771 |
+
masked_lm_loss = loss_fct(
|
| 772 |
+
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
if not return_dict:
|
| 776 |
+
output = (prediction_scores,) + outputs[2:]
|
| 777 |
+
return (
|
| 778 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
return MaskedLMOutput(
|
| 782 |
+
loss=masked_lm_loss,
|
| 783 |
+
logits=prediction_scores,
|
| 784 |
+
hidden_states=outputs.hidden_states,
|
| 785 |
+
attentions=outputs.attentions,
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
|
| 790 |
+
class XLMRobertaClassificationHead(nn.Module):
|
| 791 |
+
"""Head for sentence-level classification tasks."""
|
| 792 |
+
|
| 793 |
+
def __init__(self, config):
|
| 794 |
+
super().__init__()
|
| 795 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
| 796 |
+
if fused_bias_fc and FusedDense is None:
|
| 797 |
+
raise ImportError("fused_dense is not installed")
|
| 798 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
| 799 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
| 800 |
+
classifier_dropout = (
|
| 801 |
+
config.classifier_dropout
|
| 802 |
+
if config.classifier_dropout is not None
|
| 803 |
+
else config.hidden_dropout_prob
|
| 804 |
+
)
|
| 805 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
| 806 |
+
self.out_proj = linear_cls(config.hidden_size, config.num_labels)
|
| 807 |
+
|
| 808 |
+
def forward(self, features, **kwargs):
|
| 809 |
+
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
|
| 810 |
+
x = self.dropout(x)
|
| 811 |
+
x = self.dense(x)
|
| 812 |
+
x = torch.tanh(x)
|
| 813 |
+
x = self.dropout(x)
|
| 814 |
+
x = self.out_proj(x)
|
| 815 |
+
return x
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
|
| 819 |
+
class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
|
| 820 |
+
def __init__(self, config):
|
| 821 |
+
super().__init__(config)
|
| 822 |
+
self.num_labels = config.num_labels
|
| 823 |
+
self.config = config
|
| 824 |
+
|
| 825 |
+
self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
|
| 826 |
+
self.classifier = XLMRobertaClassificationHead(config)
|
| 827 |
+
|
| 828 |
+
# Initialize weights and apply final processing
|
| 829 |
+
self.post_init()
|
| 830 |
+
|
| 831 |
+
def forward(
|
| 832 |
+
self,
|
| 833 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 834 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 835 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
| 836 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 837 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 838 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 839 |
+
labels: Optional[torch.LongTensor] = None,
|
| 840 |
+
output_attentions: Optional[bool] = None,
|
| 841 |
+
output_hidden_states: Optional[bool] = None,
|
| 842 |
+
return_dict: Optional[bool] = None,
|
| 843 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
| 844 |
+
r"""
|
| 845 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 846 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 847 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 848 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 849 |
+
"""
|
| 850 |
+
return_dict = (
|
| 851 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
outputs = self.roberta(
|
| 855 |
+
input_ids,
|
| 856 |
+
attention_mask=attention_mask,
|
| 857 |
+
token_type_ids=token_type_ids,
|
| 858 |
+
position_ids=position_ids,
|
| 859 |
+
head_mask=head_mask,
|
| 860 |
+
inputs_embeds=inputs_embeds,
|
| 861 |
+
output_attentions=output_attentions,
|
| 862 |
+
output_hidden_states=output_hidden_states,
|
| 863 |
+
return_dict=return_dict,
|
| 864 |
+
)
|
| 865 |
+
sequence_output = outputs[0]
|
| 866 |
+
logits = self.classifier(sequence_output)
|
| 867 |
+
|
| 868 |
+
loss = None
|
| 869 |
+
if labels is not None:
|
| 870 |
+
# move labels to correct device to enable model parallelism
|
| 871 |
+
labels = labels.to(logits.device)
|
| 872 |
+
if self.config.problem_type is None:
|
| 873 |
+
if self.num_labels == 1:
|
| 874 |
+
self.config.problem_type = "regression"
|
| 875 |
+
elif self.num_labels > 1 and (
|
| 876 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
| 877 |
+
):
|
| 878 |
+
self.config.problem_type = "single_label_classification"
|
| 879 |
+
else:
|
| 880 |
+
self.config.problem_type = "multi_label_classification"
|
| 881 |
+
|
| 882 |
+
if self.config.problem_type == "regression":
|
| 883 |
+
loss_fct = MSELoss()
|
| 884 |
+
if self.num_labels == 1:
|
| 885 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
| 886 |
+
else:
|
| 887 |
+
loss = loss_fct(logits, labels)
|
| 888 |
+
elif self.config.problem_type == "single_label_classification":
|
| 889 |
+
loss_fct = CrossEntropyLoss()
|
| 890 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 891 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 892 |
+
loss_fct = BCEWithLogitsLoss()
|
| 893 |
+
loss = loss_fct(logits, labels)
|
| 894 |
+
|
| 895 |
+
if not return_dict:
|
| 896 |
+
output = (logits,) + outputs[2:]
|
| 897 |
+
return ((loss,) + output) if loss is not None else output
|
| 898 |
+
|
| 899 |
+
return SequenceClassifierOutput(
|
| 900 |
+
loss=loss,
|
| 901 |
+
logits=logits,
|
| 902 |
+
hidden_states=outputs.hidden_states,
|
| 903 |
+
attentions=outputs.attentions,
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
@torch.inference_mode()
|
| 908 |
+
def compute_score(
|
| 909 |
+
self,
|
| 910 |
+
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
|
| 911 |
+
batch_size: int = 32,
|
| 912 |
+
max_length: Optional[int] = None,
|
| 913 |
+
) -> List[float]:
|
| 914 |
+
|
| 915 |
+
if not hasattr(self, "_tokenizer"):
|
| 916 |
+
from transformers import AutoTokenizer
|
| 917 |
+
|
| 918 |
+
self._tokenizer = AutoTokenizer.from_pretrained(
|
| 919 |
+
self.name_or_path, trust_remote_code=True
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
assert isinstance(sentence_pairs, list)
|
| 923 |
+
if isinstance(sentence_pairs[0], str):
|
| 924 |
+
sentence_pairs = [sentence_pairs]
|
| 925 |
+
|
| 926 |
+
all_scores = []
|
| 927 |
+
for start_index in range(
|
| 928 |
+
0, len(sentence_pairs), batch_size
|
| 929 |
+
):
|
| 930 |
+
sentences_batch = sentence_pairs[
|
| 931 |
+
start_index : start_index + batch_size
|
| 932 |
+
]
|
| 933 |
+
inputs = self._tokenizer(
|
| 934 |
+
sentences_batch,
|
| 935 |
+
padding=True,
|
| 936 |
+
truncation=True,
|
| 937 |
+
return_tensors='pt',
|
| 938 |
+
max_length=max_length,
|
| 939 |
+
).to(self.device)
|
| 940 |
+
scores = (
|
| 941 |
+
self.forward(**inputs, return_dict=True)
|
| 942 |
+
.logits.view(
|
| 943 |
+
-1,
|
| 944 |
+
)
|
| 945 |
+
.float()
|
| 946 |
+
)
|
| 947 |
+
scores = torch.sigmoid(scores)
|
| 948 |
+
all_scores.extend(scores.cpu().numpy().tolist())
|
| 949 |
+
|
| 950 |
+
if len(all_scores) == 1:
|
| 951 |
+
return all_scores[0]
|
| 952 |
+
return all_scores
|
| 953 |
+
|
| 954 |
+
def predict(
|
| 955 |
+
self,
|
| 956 |
+
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
|
| 957 |
+
batch_size: int = 32,
|
| 958 |
+
max_length: Optional[int] = None,
|
| 959 |
+
) -> List[float]:
|
| 960 |
+
# used for beir evaluation
|
| 961 |
+
return self.compute_score(sentence_pairs, batch_size=batch_size, max_length=max_length)
|
| 962 |
+
|
| 963 |
+
def rerank(
|
| 964 |
+
self,
|
| 965 |
+
query: str,
|
| 966 |
+
documents: List[str],
|
| 967 |
+
batch_size: int = 32,
|
| 968 |
+
max_length: int = 1024,
|
| 969 |
+
max_query_length: int = 512,
|
| 970 |
+
overlap_tokens: int = 80,
|
| 971 |
+
top_n: Optional[int] = None,
|
| 972 |
+
**kwargs,
|
| 973 |
+
):
|
| 974 |
+
assert max_length >= max_query_length * 2, (
|
| 975 |
+
f'max_length ({max_length}) must be greater than or equal to '
|
| 976 |
+
f'max_query_length ({max_query_length}) * 2'
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
if not hasattr(self, "_tokenizer"):
|
| 980 |
+
from transformers import AutoTokenizer
|
| 981 |
+
|
| 982 |
+
self._tokenizer = AutoTokenizer.from_pretrained(
|
| 983 |
+
self.name_or_path, trust_remote_code=True
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
# preproc of tokenization
|
| 987 |
+
sentence_pairs, sentence_pairs_pids = reranker_tokenize_preproc(
|
| 988 |
+
query,
|
| 989 |
+
documents,
|
| 990 |
+
tokenizer=self._tokenizer,
|
| 991 |
+
max_length=max_length,
|
| 992 |
+
max_query_length=max_query_length,
|
| 993 |
+
overlap_tokens=overlap_tokens,
|
| 994 |
+
)
|
| 995 |
+
|
| 996 |
+
tot_scores = []
|
| 997 |
+
with torch.no_grad():
|
| 998 |
+
for k in range(0, len(sentence_pairs), batch_size):
|
| 999 |
+
batch = self._tokenizer.pad(
|
| 1000 |
+
sentence_pairs[k : k + batch_size],
|
| 1001 |
+
padding=True,
|
| 1002 |
+
max_length=max_length,
|
| 1003 |
+
pad_to_multiple_of=None,
|
| 1004 |
+
return_tensors="pt",
|
| 1005 |
+
)
|
| 1006 |
+
batch_on_device = {k: v.to(self.device) for k, v in batch.items()}
|
| 1007 |
+
scores = (
|
| 1008 |
+
self.forward(**batch_on_device, return_dict=True)
|
| 1009 |
+
.logits.view(
|
| 1010 |
+
-1,
|
| 1011 |
+
)
|
| 1012 |
+
.float()
|
| 1013 |
+
)
|
| 1014 |
+
scores = torch.sigmoid(scores)
|
| 1015 |
+
tot_scores.extend(scores.cpu().numpy().tolist())
|
| 1016 |
+
|
| 1017 |
+
# ranking
|
| 1018 |
+
merge_scores = [0 for _ in range(len(documents))]
|
| 1019 |
+
for pid, score in zip(sentence_pairs_pids, tot_scores):
|
| 1020 |
+
merge_scores[pid] = max(merge_scores[pid], score)
|
| 1021 |
+
|
| 1022 |
+
merge_scores_argsort = np.argsort(merge_scores)[::-1]
|
| 1023 |
+
sorted_documents = []
|
| 1024 |
+
sorted_scores = []
|
| 1025 |
+
for mid in merge_scores_argsort:
|
| 1026 |
+
sorted_scores.append(merge_scores[mid])
|
| 1027 |
+
sorted_documents.append(documents[mid])
|
| 1028 |
+
|
| 1029 |
+
top_n = min(top_n or len(sorted_documents), len(sorted_documents))
|
| 1030 |
+
|
| 1031 |
+
return [
|
| 1032 |
+
{
|
| 1033 |
+
'document': sorted_documents[i],
|
| 1034 |
+
'relevance_score': sorted_scores[i],
|
| 1035 |
+
'index': merge_scores_argsort[i],
|
| 1036 |
+
}
|
| 1037 |
+
for i in range(top_n)
|
| 1038 |
+
]
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
def reranker_tokenize_preproc(
|
| 1042 |
+
query: str,
|
| 1043 |
+
passages: List[str],
|
| 1044 |
+
tokenizer=None,
|
| 1045 |
+
max_length: int = 1024,
|
| 1046 |
+
max_query_length: int = 512,
|
| 1047 |
+
overlap_tokens: int = 80,
|
| 1048 |
+
):
|
| 1049 |
+
from copy import deepcopy
|
| 1050 |
+
|
| 1051 |
+
assert tokenizer is not None, "Please provide a valid tokenizer for tokenization!"
|
| 1052 |
+
sep_id = tokenizer.sep_token_id
|
| 1053 |
+
|
| 1054 |
+
def _merge_inputs(chunk1_raw, chunk2):
|
| 1055 |
+
chunk1 = deepcopy(chunk1_raw)
|
| 1056 |
+
chunk1['input_ids'].append(sep_id)
|
| 1057 |
+
chunk1['input_ids'].extend(chunk2['input_ids'])
|
| 1058 |
+
chunk1['input_ids'].append(sep_id)
|
| 1059 |
+
chunk1['attention_mask'].append(chunk2['attention_mask'][0])
|
| 1060 |
+
chunk1['attention_mask'].extend(chunk2['attention_mask'])
|
| 1061 |
+
chunk1['attention_mask'].append(chunk2['attention_mask'][-1])
|
| 1062 |
+
if 'token_type_ids' in chunk1:
|
| 1063 |
+
token_type_ids = [1 for _ in range(len(chunk2['token_type_ids']) + 2)]
|
| 1064 |
+
chunk1['token_type_ids'].extend(token_type_ids)
|
| 1065 |
+
return chunk1
|
| 1066 |
+
|
| 1067 |
+
# Note: a long query will be truncated to max_query_length tokens (512 by default)
|
| 1068 |
+
query_inputs = tokenizer.encode_plus(
|
| 1069 |
+
query, truncation=True, padding=False, max_length=max_query_length
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
max_passage_inputs_length = max_length - len(query_inputs['input_ids']) - 2
|
| 1073 |
+
# assert (
|
| 1074 |
+
# max_passage_inputs_length > 100
|
| 1075 |
+
# ), "Your query is too long! Please make sure your query less than 500 tokens!"
|
| 1076 |
+
|
| 1077 |
+
overlap_tokens_implt = min(overlap_tokens, max_passage_inputs_length // 4)
|
| 1078 |
+
|
| 1079 |
+
res_merge_inputs = []
|
| 1080 |
+
res_merge_inputs_pids = []
|
| 1081 |
+
for pid, passage in enumerate(passages):
|
| 1082 |
+
passage_inputs = tokenizer.encode_plus(
|
| 1083 |
+
passage,
|
| 1084 |
+
truncation=False,
|
| 1085 |
+
padding=False,
|
| 1086 |
+
add_special_tokens=False,
|
| 1087 |
+
max_length=0,
|
| 1088 |
+
)
|
| 1089 |
+
passage_inputs_length = len(passage_inputs['input_ids'])
|
| 1090 |
+
|
| 1091 |
+
if passage_inputs_length <= max_passage_inputs_length:
|
| 1092 |
+
qp_merge_inputs = _merge_inputs(query_inputs, passage_inputs)
|
| 1093 |
+
res_merge_inputs.append(qp_merge_inputs)
|
| 1094 |
+
res_merge_inputs_pids.append(pid)
|
| 1095 |
+
else:
|
| 1096 |
+
start_id = 0
|
| 1097 |
+
while start_id < passage_inputs_length:
|
| 1098 |
+
end_id = start_id + max_passage_inputs_length
|
| 1099 |
+
# make sure the length of the last chunk is `max_passage_inputs_length`
|
| 1100 |
+
if end_id >= passage_inputs_length:
|
| 1101 |
+
sub_passage_inputs = {
|
| 1102 |
+
k: v[-max_passage_inputs_length:]
|
| 1103 |
+
for k, v in passage_inputs.items()
|
| 1104 |
+
}
|
| 1105 |
+
else:
|
| 1106 |
+
sub_passage_inputs = {
|
| 1107 |
+
k: v[start_id:end_id] for k, v in passage_inputs.items()
|
| 1108 |
+
}
|
| 1109 |
+
start_id = (
|
| 1110 |
+
end_id - overlap_tokens_implt
|
| 1111 |
+
if end_id < passage_inputs_length
|
| 1112 |
+
else end_id
|
| 1113 |
+
)
|
| 1114 |
+
|
| 1115 |
+
qp_merge_inputs = _merge_inputs(query_inputs, sub_passage_inputs)
|
| 1116 |
+
res_merge_inputs.append(qp_merge_inputs)
|
| 1117 |
+
res_merge_inputs_pids.append(pid)
|
| 1118 |
+
|
| 1119 |
+
return res_merge_inputs, res_merge_inputs_pids
|
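As a rough usage sketch of the cross-encoder entry points defined above (`compute_score` and `rerank`): the repository id below is a placeholder for wherever this model is published, and `trust_remote_code=True` is required so that this `modeling_xlm_roberta.py` is actually loaded.

```python
import torch
from transformers import AutoModelForSequenceClassification

# "<repo-id>" is a placeholder; replace it with the actual Hub repository id.
model = AutoModelForSequenceClassification.from_pretrained(
    "<repo-id>", trust_remote_code=True
)
model.eval()

query = "How do I install flash attention?"
documents = [
    "flash-attn can be installed with pip once CUDA and PyTorch are set up.",
    "The Eiffel Tower is located in Paris.",
]

# compute_score takes (query, document) pairs and returns sigmoid scores in [0, 1]
scores = model.compute_score([(query, d) for d in documents], batch_size=16)
print(scores)

# rerank splits long documents into overlapping chunks (overlap_tokens) and keeps
# the maximum chunk score per document before sorting
ranked = model.rerank(query, documents, top_n=1)
print(ranked[0]["document"], ranked[0]["relevance_score"])
```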
special_tokens_map.json
ADDED
|
@@ -0,0 +1,51 @@
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"cls_token": {
|
| 10 |
+
"content": "<s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"eos_token": {
|
| 17 |
+
"content": "</s>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"mask_token": {
|
| 24 |
+
"content": "<mask>",
|
| 25 |
+
"lstrip": true,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
},
|
| 30 |
+
"pad_token": {
|
| 31 |
+
"content": "<pad>",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false
|
| 36 |
+
},
|
| 37 |
+
"sep_token": {
|
| 38 |
+
"content": "</s>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false
|
| 43 |
+
},
|
| 44 |
+
"unk_token": {
|
| 45 |
+
"content": "<unk>",
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"normalized": false,
|
| 48 |
+
"rstrip": false,
|
| 49 |
+
"single_word": false
|
| 50 |
+
}
|
| 51 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e802fe5337779428818439760a1e6161ed36ceed72d4ebcbda9c139a2108fc99
|
| 3 |
+
size 17082988
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,55 @@
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "<s>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "<pad>",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "</s>",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "<unk>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"250001": {
|
| 36 |
+
"content": "<mask>",
|
| 37 |
+
"lstrip": true,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"bos_token": "<s>",
|
| 45 |
+
"clean_up_tokenization_spaces": true,
|
| 46 |
+
"cls_token": "<s>",
|
| 47 |
+
"eos_token": "</s>",
|
| 48 |
+
"extra_special_tokens": {},
|
| 49 |
+
"mask_token": "<mask>",
|
| 50 |
+
"model_max_length": 1024,
|
| 51 |
+
"pad_token": "<pad>",
|
| 52 |
+
"sep_token": "</s>",
|
| 53 |
+
"tokenizer_class": "XLMRobertaTokenizerFast",
|
| 54 |
+
"unk_token": "<unk>"
|
| 55 |
+
}
|
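tokenizer_config.json caps inputs at `model_max_length: 1024` and maps the added token ids (`0`, `1`, `2`, `3`, `250001`) onto the special tokens above. A hedged usage sketch for building cross-encoder style query-passage pairs under that limit (placeholder repo path, example strings invented):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this-model")  # placeholder path
print(tok.model_max_length)  # 1024, from tokenizer_config.json

# Query and passage are tokenized as a single pair and truncated to the limit.
batch = tok(
    ["what is padding-free attention?"],
    ["Sequences of different lengths are packed without pad tokens."],
    padding=True,
    truncation=True,
    max_length=tok.model_max_length,
    return_tensors="pt",
)
print(batch["input_ids"].shape)
```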
xlm_padding.py
ADDED
@@ -0,0 +1,218 @@
+# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
+# Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b
+
+# Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+
+class IndexFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, indices):
+        ctx.save_for_backward(indices)
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        second_dim = other_shape.numel()
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        # return input[indices]
+        return torch.gather(
+            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
+        ).reshape(-1, *other_shape)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        assert grad_output.ndim >= 2
+        other_shape = grad_output.shape[1:]
+        grad_output = rearrange(grad_output, "b ... -> b (...)")
+        grad_input = torch.zeros(
+            [ctx.first_axis_dim, grad_output.shape[1]],
+            device=grad_output.device,
+            dtype=grad_output.dtype,
+        )
+        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
+        # grad_input[indices] = grad_output
+        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
+        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+index_first_axis = IndexFirstAxis.apply
+
+
+class IndexPutFirstAxis(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, values, indices, first_axis_dim):
+        ctx.save_for_backward(indices)
+        assert indices.ndim == 1
+        assert values.ndim >= 2
+        output = torch.zeros(
+            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
+        )
+        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
+        output[indices] = values
+        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        grad_values = grad_output[indices]
+        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
+        return grad_values, None, None
+
+
+index_put_first_axis = IndexPutFirstAxis.apply
+
+
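index_first_axis and index_put_first_axis are the gather/scatter pair the unpadding code below is built on: the first picks the rows of a flattened (batch * seqlen, hidden) tensor that belong to real tokens, the second scatters them back into a zero-initialized buffer. A toy illustration (not part of the file; assumes the two helpers above are in scope):

```python
import torch

flat = torch.arange(12.0).reshape(6, 2)   # 6 token slots, hidden size 2
indices = torch.tensor([0, 1, 3])          # positions of the non-padded tokens

packed = index_first_axis(flat, indices)             # (3, 2): rows 0, 1 and 3
restored = index_put_first_axis(packed, indices, 6)  # (6, 2), zeros elsewhere

# The selected rows survive the round trip unchanged.
assert torch.equal(restored[indices], flat[indices])
```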
+class IndexFirstAxisResidual(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, indices):
+        ctx.save_for_backward(indices)
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        second_dim = other_shape.numel()
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        output = input[indices]
+        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
+        # memory format to channel_first. In other words, input might not be contiguous.
+        # If we don't detach, PyTorch complains that output is a view being modified in-place.
+        return output, input.detach()
+
+    @staticmethod
+    def backward(ctx, grad_output, grad_residual):
+        (indices,) = ctx.saved_tensors
+        assert grad_output.ndim >= 2
+        other_shape = grad_output.shape[1:]
+        assert grad_residual.shape[1:] == other_shape
+        grad_input = grad_residual
+        # grad_input[indices] += grad_output
+        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
+        indices = indices.expand_as(grad_output)
+        grad_input.scatter_add_(0, indices, grad_output)
+        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+index_first_axis_residual = IndexFirstAxisResidual.apply
+
+
+def unpad_input(hidden_states, attention_mask):
+    """
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask.
+        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+    """
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    # TD [2022-03-04] We don't want to index with a bool mask, because PyTorch will expand the
+    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
+    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
+    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
+    # so we write custom forward and backward to make it a bit faster.
+    return (
+        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
+    """
+    Supports concatenating short samples into one sequence. attention_mask_in_length is used to mask out the other short samples in the same row; this enables efficient training on samples of varying lengths (e.g., the supervised fine-tuning task for large language models).
+    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
+
+    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
+    ```
+    [
+        [2, 3, 0, 0, 0, 0],
+        [3, 2, 0, 0, 0, 0],
+        [6, 0, 0, 0, 0, 0]
+    ]
+    ```
+    which corresponds to the 3D attention mask:
+    ```
+    [
+        [
+            [1, 0, 0, 0, 0, 0],
+            [1, 1, 0, 0, 0, 0],
+            [0, 0, 1, 0, 0, 0],
+            [0, 0, 1, 1, 0, 0],
+            [0, 0, 1, 1, 1, 0],
+            [0, 0, 0, 0, 0, 1]
+        ],
+        [
+            [1, 0, 0, 0, 0, 0],
+            [1, 1, 0, 0, 0, 0],
+            [1, 1, 1, 0, 0, 0],
+            [0, 0, 0, 1, 0, 0],
+            [0, 0, 0, 1, 1, 0],
+            [0, 0, 0, 0, 0, 1]
+        ],
+        [
+            [1, 0, 0, 0, 0, 0],
+            [1, 1, 0, 0, 0, 0],
+            [1, 1, 1, 0, 0, 0],
+            [1, 1, 1, 1, 0, 0],
+            [1, 1, 1, 1, 1, 0],
+            [1, 1, 1, 1, 1, 1]
+        ]
+    ]
+    ```
+
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask_in_length: (batch, seqlen), int; a nonzero value (e.g., 1, 2, 3) is the length of a concatenated sample in that batch row, and 0 means none.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask_in_length.
+        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+    """
+    length = attention_mask_in_length.sum(dim=-1)
+    seqlen = attention_mask_in_length.size(-1)
+    attention_mask_2d = torch.arange(
+        seqlen, device=length.device, dtype=length.dtype
+    ).expand(len(length), seqlen) < length.unsqueeze(1)
+    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
+    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
+    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    # TD [2022-03-04] We don't want to index with a bool mask, because PyTorch will expand the
+    # bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
+    # times larger than they need to be, wasting memory. It's faster and more memory-efficient to
+    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
+    # so we write custom forward and backward to make it a bit faster.
+    return (
+        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
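The docstring example above can be run directly: three rows of length 6 holding packed samples of lengths (2, 3), (3, 2) and (6,) yield 16 kept tokens and cumulative lengths [0, 2, 5, 8, 10, 16]. A small check (illustrative; assumes the function above is in scope):

```python
import torch

attention_mask_in_length = torch.tensor(
    [
        [2, 3, 0, 0, 0, 0],
        [3, 2, 0, 0, 0, 0],
        [6, 0, 0, 0, 0, 0],
    ]
)
hidden_states = torch.randn(3, 6, 8)  # (batch, seqlen, hidden)

out, indices, cu_seqlens, max_len = unpad_input_for_concatenated_sequences(
    hidden_states, attention_mask_in_length
)
print(out.shape)    # torch.Size([16, 8]); 2 + 3 + 3 + 2 + 6 tokens kept
print(cu_seqlens)   # tensor([ 0,  2,  5,  8, 10, 16], dtype=torch.int32)
print(max_len)      # 6
```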
+def pad_input(hidden_states, indices, batch, seqlen):
+    """
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    dim = hidden_states.shape[-1]
+    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
+    # output[indices] = hidden_states
+    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
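Together, unpad_input and pad_input form a round trip: padding is stripped before the attention layers and scattered back afterwards, with padded positions restored as zeros. A minimal sketch (illustrative only; assumes the two functions above are importable):

```python
import torch

hidden_states = torch.randn(2, 4, 8)              # (batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
print(unpadded.shape)  # torch.Size([5, 8]); only the 5 real tokens remain
print(cu_seqlens)      # tensor([0, 3, 5], dtype=torch.int32)

repadded = pad_input(unpadded, indices, batch=2, seqlen=4)
# Real positions are recovered exactly; padded positions come back as zeros.
assert torch.equal(repadded[attention_mask.bool()], hidden_states[attention_mask.bool()])
```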