cuadron11 committed
Commit 20b690b · verified · 1 Parent(s): ecb323b

Add new CrossEncoder model

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,408 @@
1
+ ---
2
+ tags:
3
+ - sentence-transformers
4
+ - cross-encoder
5
+ - reranker
6
+ - generated_from_trainer
7
+ - dataset_size:3200
8
+ - loss:CachedMultipleNegativesRankingLoss
9
+ base_model: jinaai/jina-reranker-v2-base-multilingual
10
+ pipeline_tag: text-ranking
11
+ library_name: sentence-transformers
12
+ metrics:
13
+ - map
14
+ - mrr@10
15
+ - ndcg@10
16
+ model-index:
17
+ - name: CrossEncoder based on jinaai/jina-reranker-v2-base-multilingual
18
+ results:
19
+ - task:
20
+ type: cross-encoder-reranking
21
+ name: Cross Encoder Reranking
22
+ dataset:
23
+ name: jina reranker v2 base multilingual contrastive parl 4 10ep
24
+ type: jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep
25
+ metrics:
26
+ - type: map
27
+ value: 0.0194
28
+ name: Map
29
+ - type: mrr@10
30
+ value: 0.0194
31
+ name: Mrr@10
32
+ - type: ndcg@10
33
+ value: 0.0198
34
+ name: Ndcg@10
35
+ ---
36
+
37
+ # CrossEncoder based on jinaai/jina-reranker-v2-base-multilingual
38
+
39
+ This is a [Cross Encoder](https://www.sbert.net/docs/cross_encoder/usage/usage.html) model finetuned from [jinaai/jina-reranker-v2-base-multilingual](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual) using the [sentence-transformers](https://www.SBERT.net) library. It computes scores for pairs of texts, which can be used for text reranking and semantic search.
40
+
41
+ ## Model Details
42
+
43
+ ### Model Description
44
+ - **Model Type:** Cross Encoder
45
+ - **Base model:** [jinaai/jina-reranker-v2-base-multilingual](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual) <!-- at revision 2f894e63642a95228da19cdd583cd2309983c867 -->
46
+ - **Maximum Sequence Length:** 1024 tokens
47
+ - **Number of Output Labels:** 1 label
48
+ <!-- - **Training Dataset:** Unknown -->
49
+ <!-- - **Language:** Unknown -->
50
+ <!-- - **License:** Unknown -->
51
+
52
+ ### Model Sources
53
+
54
+ - **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
55
+ - **Documentation:** [Cross Encoder Documentation](https://www.sbert.net/docs/cross_encoder/usage/usage.html)
56
+ - **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
57
+ - **Hugging Face:** [Cross Encoders on Hugging Face](https://huggingface.co/models?library=sentence-transformers&other=cross-encoder)
58
+
59
+ ## Usage
60
+
61
+ ### Direct Usage (Sentence Transformers)
62
+
63
+ First install the Sentence Transformers library:
64
+
65
+ ```bash
66
+ pip install -U sentence-transformers
67
+ ```
68
+
69
+ Then you can load this model and run inference.
70
+ ```python
71
+ from sentence_transformers import CrossEncoder
72
+
73
+ # Download from the 🤗 Hub
74
+ model = CrossEncoder("cuadron11/jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep")
75
+ # Get scores for pairs of texts
76
+ pairs = [
77
+ ['Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?', '[TOPIC: Honako ekimen hauek batera eztabaidatu eta behin betiko ebazpena hartzea: ]\n[UNZALU HERMOSA, (SV-ES)]:\nSekula. Gertatzen dena da uste dugula martxoaren 3ko jokaerak baduela zer hobetua. Eta hobetzeko abiapuntu bakarra gogoeta egitea da, aztertzea eta hasieratik aitortzea hutsegiteak egin zirela. Izan ere, nire lehenengo hitzaldian esan dudanez, triskantzak gertatu izanak pentsarazi behar liguke zerbaitek huts egin zuela egun hartako dispositiboa edo operazioa planifikatzean eta zuzentzean. Horixe sartu nahi dugu guk: eztabaida-elementuak, hobekuntzarako kritika-elementuak, eta UPyDrekin eta Alderdi Popularrarekin sinatu dugun zuzenketan hori esaten da, onar dadila gauzak hobetu egin daitezkeela. Izan ere, Iturrate jauna, zuk egin dizkiguzun galderei nik beste batzuekin erantzungo nieke. Posible da hutsegiteetatik ikastea eta herritarren segurtasuna hobetzea? Posible da? Edo, besterik gabe, "Ahal zen modu bakarrean jokatu dugu" esatera mugatu behar dugu? Posible da herritarrei kalte gutxiago eragitea horrelako istiluak gertatzen direnean? Horixe planteatu nahi dugu guk, beharrezkoa dela… Eta uste osoa dugunez hobetu daitekeela, eta uste osoa dugunez hobeto joka zitekeela, horregatik nahi dugu eta horregatik planteatzen dugu hutsegiteak aztertzea, gogoeta egitea, eta elementu zuzentzaileak martxan jartzea horrelako egoerarik berriro gerta ez dadin. Eta, begira, sailarekin batera dispositiboari babesa eman dioten bakarrak dira, hain justu, Ertzaintzaren jokaerak inoiz babesten ez dituztenak; lehen esan dudanez, Ertzaintzaren kontrako ekintzak ere gaitzetsi ez dituztenak. Eta horrek kezkatu egiten gaitu. Nik ez dakit zu, Iturrate jauna, eta sailburu andrea kezkatzen zaituzten; baina, (Date: 03.04.2014)'],
78
+ ['Zenbat denbora behar da Ertzaintzako promozio baten deialdia egiten denetik agenteak kalera irteten diren arte?', '[TOPIC: Interpelazioa, Javier Ruiz de Arbulo Cerio Euskal Talde Popularreko legebiltzarkideak Segurtasuneko sailburuari egina, Arabako Miñoien Atalari buruz]\n[SEGURTASUNEKO SAILBURUAK (BELTRÁN DE HEREDIA ARRONIZ), (EA-NV)]:\nhoriek aldatu egiten dira egun batetik bestera, unitate batetik bestera, kontuan hartuta zer bilakaera duten erretiroek, kontuan hartuta nola gertatzen diren baja horiek… Baina, batez ere, nik bezain ondo dakizu Ertzaintzan defizit handia daukagula, eta ezin hobeto dakizu zergatia zein den. Ez dakit defizit horren zergatia zein den errepika diezazudan etorri zaren hona, baina ez daukat inolako eragozpenik Legebiltzar honetan berriro azaltzeko eta zuek berriro entzun behar izateko. Honela gaude Espainiako Gobernuak, Alderdi Popularraren Gobernuak, denbora asko behar izan zuelako, denbora gehiegi, zuk behar izan duzun be- zala, ulertzeko premia geneukala Ertzaintzan gertatzen ari ziren erretiro-bajak estaltzeko promozio berriak deitzeko –gero eta gehiago dira erretiroak eragindako bajak–; logikoa denez, baja horiek eragina zeukaten eta daukate Miñoien Atalean ere, bajak oraindik ere gertatzen ari baitira. 26. promozioa hautatzeko prozesua urtebete baino gehiago atzeratu da, errekurtsoek mehatxatu egin zituztelako 25. promozioaren bilakaera normala eta amaiera. Nik uste dut orain bide onetik goazela, baina ez duzu ahaztu behar promozio baten deialdia egiten dugunetik agenteak kalera irteten diren arte bi urte baino gehiago igarotzen direla. Bi urte baino gehiago. Eta ziztu bizian ibili ginen, betoa amaitu orduko azterketak egiteko: hogei egun eskas behar izan genituen 26. promozioko azterketen deialdia egiteko. Ziztu bizian ibili ginen, baina, hala ere, kale. Denbora eman behar da, ezta? Hemen, urdaiazpikoekin bezala geratzen da: denbora eman behar zaie, ontzeko. Bada, (Date: 01.12.2017)'],
79
+ ['Zergatik dimititu zuen Eusko Jaurlaritzako Komunikazio zuzendariak?', '[TOPIC: Galdera, Gorka Maneiro Labayen Mistoa-UPyD taldeko legebiltzarkideak lehendakariari egina, Eusko Jaurlaritzako Komunikazio zuzendariaren dimisioaren ondoren hartu beharreko erantzukizun politikoei buruz]\n[MANEIRO LABAYEN, (Mixto-UPyD)]:\nsailburu jakin batzuei elkarrizketak egitearen truke? Erantzun ahal diezaiokezu galdera horri? Halaxe da, bai. Zure esanetan, ez dago ezer arrarorik eta irregularrik, baina pertsona batek dimititu egin du. Zer egiteko asmoa duzu zuk? Bide batez, zer da aldi baterako dimisioaren kontu hori? Beste postu batean jarri al duzue pertsona hori? Diru publikoa kobratzen jarraitzen al du? Argitu dezakezu, edo herritarrak engainatu nahi dituzue? Pertsona horrek dimititu egin du. Zer egiteko (Date: 30.10.2015)'],
80
+ ['Zein da euskal herritarren iritzia independentziari buruz, Soziometroaren arabera?', '[TOPIC: Mozioa, Maddalen Iriarte Okiñena EH Bildu taldeko legebiltzarkideak aurkeztua, herri bezala ditugun erronka estrategikoei erantzuteko, herri-jakintza aktibatzeko eta ariketa kolektibo bat egiteko beharraren inguruan. Eztabaida eta behin betiko ebazpena]\n[BARRIO BAROJA, (PV-ETP)]:\nasko; eta ezin dela horren autokonplazientea izan eta dena positiboki egin dela esan. Argi dago, Iriarte andrea, amaitzeko, etorkizuneko erronkak ditugula; ados gaude gogor lan egin behar dela; baina estatus berria herritarrei arazo gehiago sortzea da; hura agerian jartzea eta hona ekartzea, berriz ere konfrontazio- eta eztabaida-eremu izatea da, herritarrei arazo gehiago sortzea da. Atzo argi eta garbi zioen euskal Soziometroak euskal herritarrok independentziari buruz zer iritzi dugu; eta inoiz ez da hain maila baxurik ikusi. Beraz, ildo horretan, erronka estrategikoei buruz hitz egiten ari zaren une honetan, estatus berriaren eztabaida hona ekartzea atzerapausoa litzateke, arazo gehiago ematea litzateke; eta, jakina, gu –zuri erantzuten dizut, baita orain hura aldarrikatu duen Egibar jaunari ere esaten diot– aurka egongo gara. Eskerrik asko. (Date: 10.06.2021)'],
81
+ ['Zeintzuk dira Eusko Jaurlaritzaren asmoak euskararen normalizazioan sakontzeko?', '[TOPIC: Galdera, Rebeka Ubera Aranzeta EH Bildu taldeko legebiltzarkideak Kultura eta Hizkuntza Politikako sailburuari egina, euskararen normalizazioan sakontzeko neurri funtsezkoak hartzeari buruz]\n[UBERA ARANZETA, (EH Bildu)]:\nAdministrazioa euskalduntzeko urratsak emango zirela: ekarpenak egin ditugu eta ezezkoa jaso dugu. Esan zitzaigun euskara ikastea doako bilakatzeko urratsak emango zirela, eta mugak besterik ez dugu ikusi eta ezezkoa jaso dugu. Eta jada dagoeneko zalantzan jartzen hasiak gara Gobernu honen borondate politikoa zein den. Eta, legegintzaldi honetan, sailburuen aldetik ere, atzerakada izugarria izan da, aurreko legegintzaldiarekin konparatuta –nabarmen gainera–, eta zentzu horretan ere, zerbait egin beharko duzu. Neurtzen ari (Date: 19.05.2017)'],
82
+ ]
83
+ scores = model.predict(pairs)
84
+ print(scores.shape)
85
+ # (5,)
86
+
87
+ # Or rank different texts based on similarity to a single text
88
+ ranks = model.rank(
89
+ 'Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?',
90
+ [
91
+ '[TOPIC: Honako ekimen hauek batera eztabaidatu eta behin betiko ebazpena hartzea: ]\n[UNZALU HERMOSA, (SV-ES)]:\nSekula. Gertatzen dena da uste dugula martxoaren 3ko jokaerak baduela zer hobetua. Eta hobetzeko abiapuntu bakarra gogoeta egitea da, aztertzea eta hasieratik aitortzea hutsegiteak egin zirela. Izan ere, nire lehenengo hitzaldian esan dudanez, triskantzak gertatu izanak pentsarazi behar liguke zerbaitek huts egin zuela egun hartako dispositiboa edo operazioa planifikatzean eta zuzentzean. Horixe sartu nahi dugu guk: eztabaida-elementuak, hobekuntzarako kritika-elementuak, eta UPyDrekin eta Alderdi Popularrarekin sinatu dugun zuzenketan hori esaten da, onar dadila gauzak hobetu egin daitezkeela. Izan ere, Iturrate jauna, zuk egin dizkiguzun galderei nik beste batzuekin erantzungo nieke. Posible da hutsegiteetatik ikastea eta herritarren segurtasuna hobetzea? Posible da? Edo, besterik gabe, "Ahal zen modu bakarrean jokatu dugu" esatera mugatu behar dugu? Posible da herritarrei kalte gutxiago eragitea horrelako istiluak gertatzen direnean? Horixe planteatu nahi dugu guk, beharrezkoa dela… Eta uste osoa dugunez hobetu daitekeela, eta uste osoa dugunez hobeto joka zitekeela, horregatik nahi dugu eta horregatik planteatzen dugu hutsegiteak aztertzea, gogoeta egitea, eta elementu zuzentzaileak martxan jartzea horrelako egoerarik berriro gerta ez dadin. Eta, begira, sailarekin batera dispositiboari babesa eman dioten bakarrak dira, hain justu, Ertzaintzaren jokaerak inoiz babesten ez dituztenak; lehen esan dudanez, Ertzaintzaren kontrako ekintzak ere gaitzetsi ez dituztenak. Eta horrek kezkatu egiten gaitu. Nik ez dakit zu, Iturrate jauna, eta sailburu andrea kezkatzen zaituzten; baina, (Date: 03.04.2014)',
92
+ '[TOPIC: Interpelazioa, Javier Ruiz de Arbulo Cerio Euskal Talde Popularreko legebiltzarkideak Segurtasuneko sailburuari egina, Arabako Miñoien Atalari buruz]\n[SEGURTASUNEKO SAILBURUAK (BELTRÁN DE HEREDIA ARRONIZ), (EA-NV)]:\nhoriek aldatu egiten dira egun batetik bestera, unitate batetik bestera, kontuan hartuta zer bilakaera duten erretiroek, kontuan hartuta nola gertatzen diren baja horiek… Baina, batez ere, nik bezain ondo dakizu Ertzaintzan defizit handia daukagula, eta ezin hobeto dakizu zergatia zein den. Ez dakit defizit horren zergatia zein den errepika diezazudan etorri zaren hona, baina ez daukat inolako eragozpenik Legebiltzar honetan berriro azaltzeko eta zuek berriro entzun behar izateko. Honela gaude Espainiako Gobernuak, Alderdi Popularraren Gobernuak, denbora asko behar izan zuelako, denbora gehiegi, zuk behar izan duzun be- zala, ulertzeko premia geneukala Ertzaintzan gertatzen ari ziren erretiro-bajak estaltzeko promozio berriak deitzeko –gero eta gehiago dira erretiroak eragindako bajak–; logikoa denez, baja horiek eragina zeukaten eta daukate Miñoien Atalean ere, bajak oraindik ere gertatzen ari baitira. 26. promozioa hautatzeko prozesua urtebete baino gehiago atzeratu da, errekurtsoek mehatxatu egin zituztelako 25. promozioaren bilakaera normala eta amaiera. Nik uste dut orain bide onetik goazela, baina ez duzu ahaztu behar promozio baten deialdia egiten dugunetik agenteak kalera irteten diren arte bi urte baino gehiago igarotzen direla. Bi urte baino gehiago. Eta ziztu bizian ibili ginen, betoa amaitu orduko azterketak egiteko: hogei egun eskas behar izan genituen 26. promozioko azterketen deialdia egiteko. Ziztu bizian ibili ginen, baina, hala ere, kale. Denbora eman behar da, ezta? Hemen, urdaiazpikoekin bezala geratzen da: denbora eman behar zaie, ontzeko. Bada, (Date: 01.12.2017)',
93
+ '[TOPIC: Galdera, Gorka Maneiro Labayen Mistoa-UPyD taldeko legebiltzarkideak lehendakariari egina, Eusko Jaurlaritzako Komunikazio zuzendariaren dimisioaren ondoren hartu beharreko erantzukizun politikoei buruz]\n[MANEIRO LABAYEN, (Mixto-UPyD)]:\nsailburu jakin batzuei elkarrizketak egitearen truke? Erantzun ahal diezaiokezu galdera horri? Halaxe da, bai. Zure esanetan, ez dago ezer arrarorik eta irregularrik, baina pertsona batek dimititu egin du. Zer egiteko asmoa duzu zuk? Bide batez, zer da aldi baterako dimisioaren kontu hori? Beste postu batean jarri al duzue pertsona hori? Diru publikoa kobratzen jarraitzen al du? Argitu dezakezu, edo herritarrak engainatu nahi dituzue? Pertsona horrek dimititu egin du. Zer egiteko (Date: 30.10.2015)',
94
+ '[TOPIC: Mozioa, Maddalen Iriarte Okiñena EH Bildu taldeko legebiltzarkideak aurkeztua, herri bezala ditugun erronka estrategikoei erantzuteko, herri-jakintza aktibatzeko eta ariketa kolektibo bat egiteko beharraren inguruan. Eztabaida eta behin betiko ebazpena]\n[BARRIO BAROJA, (PV-ETP)]:\nasko; eta ezin dela horren autokonplazientea izan eta dena positiboki egin dela esan. Argi dago, Iriarte andrea, amaitzeko, etorkizuneko erronkak ditugula; ados gaude gogor lan egin behar dela; baina estatus berria herritarrei arazo gehiago sortzea da; hura agerian jartzea eta hona ekartzea, berriz ere konfrontazio- eta eztabaida-eremu izatea da, herritarrei arazo gehiago sortzea da. Atzo argi eta garbi zioen euskal Soziometroak euskal herritarrok independentziari buruz zer iritzi dugu; eta inoiz ez da hain maila baxurik ikusi. Beraz, ildo horretan, erronka estrategikoei buruz hitz egiten ari zaren une honetan, estatus berriaren eztabaida hona ekartzea atzerapausoa litzateke, arazo gehiago ematea litzateke; eta, jakina, gu –zuri erantzuten dizut, baita orain hura aldarrikatu duen Egibar jaunari ere esaten diot– aurka egongo gara. Eskerrik asko. (Date: 10.06.2021)',
95
+ '[TOPIC: Galdera, Rebeka Ubera Aranzeta EH Bildu taldeko legebiltzarkideak Kultura eta Hizkuntza Politikako sailburuari egina, euskararen normalizazioan sakontzeko neurri funtsezkoak hartzeari buruz]\n[UBERA ARANZETA, (EH Bildu)]:\nAdministrazioa euskalduntzeko urratsak emango zirela: ekarpenak egin ditugu eta ezezkoa jaso dugu. Esan zitzaigun euskara ikastea doako bilakatzeko urratsak emango zirela, eta mugak besterik ez dugu ikusi eta ezezkoa jaso dugu. Eta jada dagoeneko zalantzan jartzen hasiak gara Gobernu honen borondate politikoa zein den. Eta, legegintzaldi honetan, sailburuen aldetik ere, atzerakada izugarria izan da, aurreko legegintzaldiarekin konparatuta –nabarmen gainera–, eta zentzu horretan ere, zerbait egin beharko duzu. Neurtzen ari (Date: 19.05.2017)',
96
+ ]
97
+ )
98
+ # [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]
99
+ ```
100
+
101
+ <!--
102
+ ### Direct Usage (Transformers)
103
+
104
+ <details><summary>Click to see the direct usage in Transformers</summary>
105
+
106
+ </details>
107
+ -->
108
+
109
+ <!--
110
+ ### Downstream Usage (Sentence Transformers)
111
+
112
+ You can finetune this model on your own dataset.
113
+
114
+ <details><summary>Click to expand</summary>
115
+
116
+ </details>
117
+ -->
118
+
119
+ <!--
120
+ ### Out-of-Scope Use
121
+
122
+ *List how the model may foreseeably be misused and address what users ought not to do with the model.*
123
+ -->
124
+
125
+ ## Evaluation
126
+
127
+ ### Metrics
128
+
129
+ #### Cross Encoder Reranking
130
+
131
+ * Dataset: `jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep`
132
+ * Evaluated with [<code>CrossEncoderRerankingEvaluator</code>](https://sbert.net/docs/package_reference/cross_encoder/evaluation.html#sentence_transformers.cross_encoder.evaluation.CrossEncoderRerankingEvaluator) with these parameters:
133
+ ```json
134
+ {
135
+ "at_k": 10,
136
+ "always_rerank_positives": false
137
+ }
138
+ ```
139
+
140
+ | Metric | Value |
141
+ |:------------|:---------------------|
142
+ | map | 0.0194 (+0.0172) |
143
+ | mrr@10 | 0.0194 (+0.0176) |
144
+ | **ndcg@10** | **0.0198 (+0.0172)** |
145
+
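+ For illustration, these metrics can be recomputed on your own data roughly as follows; `eval_samples` and its contents below are placeholders, not the actual evaluation split used for this card:
+
+ ```python
+ from sentence_transformers import CrossEncoder
+ from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
+
+ model = CrossEncoder("cuadron11/jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep")
+
+ # Each sample pairs a query with its relevant passage(s) and a candidate list to rerank.
+ eval_samples = [
+     {
+         "query": "Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?",
+         "positive": ["... relevant parliamentary passage ..."],
+         "documents": ["... candidate passage 1 ...", "... candidate passage 2 ..."],
+     },
+ ]
+
+ evaluator = CrossEncoderRerankingEvaluator(
+     samples=eval_samples,
+     at_k=10,
+     always_rerank_positives=False,
+     name="jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep",
+ )
+ results = evaluator(model)
+ # results is a dict with map, mrr@10 and ndcg@10 entries keyed by the evaluator name
+ ```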
146
+ <!--
147
+ ## Bias, Risks and Limitations
148
+
149
+ *What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
150
+ -->
151
+
152
+ <!--
153
+ ### Recommendations
154
+
155
+ *What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
156
+ -->
157
+
158
+ ## Training Details
159
+
160
+ ### Training Dataset
161
+
162
+ #### Unnamed Dataset
163
+
164
+ * Size: 3,200 training samples
165
+ * Columns: <code>query</code> and <code>positive</code>
166
+ * Approximate statistics based on the first 1000 samples:
167
+ | | query | positive |
168
+ |:--------|:-----------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------|
169
+ | type | string | string |
170
+ | details | <ul><li>min: 27 characters</li><li>mean: 99.5 characters</li><li>max: 250 characters</li></ul> | <ul><li>min: 569 characters</li><li>mean: 975.13 characters</li><li>max: 2175 characters</li></ul> |
171
+ * Samples:
172
+ | query | positive |
173
+ |:-------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
174
+ | <code>Zein urtetan egin zuen José Ramón Becerra Carollo legebiltzarkideak SOS Deiak-112 larrialdi-deien arretarako zerbitzuaren esleipenari buruzko mozioa?</code> | <code>[TOPIC: Mozioa, José Ramón Becerra Carollo Elkarrekin Podemos taldeko legebiltzarkideak aurkeztua, SOS Deiak-112 larrialdi-deien arretarako zerbitzuaren esleipenari buruz. Eztabaida eta behin betiko ebazpena]<br>[LATXAGA UGARTEMENDIA, (EA-NV)]:<br>eta gero Sabin Etxearekin, Eliza Katolikoarekin, Xabier Arzalluzekin eta Eusko Jaurlaritzarekin berarekin lotu zenuen enpresa esleipenduna. Konspirazio perfektua lortzeko, Mosad eta BBVA falta zitzaizkizun, nik uste. Mesedez, ez erabili Ganbera hau gure eserlekuen gainean zikinkeria, zaborra botatzeko. Ez erabili horretarako, onbidezko gauzetarako baizik. Eta ez egin funtsik gabe, inolako frogarik gabe. Zuk esaten zenuena oso larria zen, oso larria, eta ezin duzu hemen tribuna honetan besterik gabe (Date: 21.12.2017)</code> |
175
+ | <code>Zergatik da beharrezkoa kargudun publikoen jokaera kodea arautzea?</code> | <code>[TOPIC: Euskal Sozialistak legebiltzar-taldeak egindako lege-proposamena, Kargudun Publikoaren Jokaera Kodea eta haren Bateraezintasunen Erregimena arautzeko. Aintzat hartzeari buruzko eztabaida eta behin betiko ebazpena]<br>[MINTEGI LAKARRA, (EH Bildu)]:<br>Egun on, presidente andrea, lehendakari jauna, legebiltzarkideok. Legerik onena da behar ez dena eta arautu beharra dagoenean hor badago ja gabeziaren sintoma, edo ez dagoelako adostasunik edo jokaera desegokiak egon direlako eta horiek saihestu behar direlako eta ez da ikusi beste biderik arautu beharra baino. Beraz, orain kargu publikoen jokaera etikoa edo jokaera kodea arautu beharrak adierazten digu badagoela gabezia, horren sintoma da. Izatez, jokaera zuzena berezkoa izan beharko (Date: 28.02.2013)</code> |
176
+ | <code>Zein da EH Bildu talde parlamentarioaren jarrera Ikuskizunen eta Jolas Jardueren Legea garatzeko erregelamenduaren inguruan?</code> | <code>[TOPIC: EH Bildu talde parlamentarioak egindako legez besteko proposamena, Ikuskizunen eta Jolas Jardueren Legea garatzeko erregelamenduaren inguruan. Eztabaida eta behin betiko ebazpena]<br>[ÁLVAREZ MARTÍNEZ, (EA-NV)]:<br>mintzaldian aipatu ditugun puntuak zehaztu behar ditugun. Uste dugu, erantzukizunetik, dekretu hori berrikusi egin behar dela, eta uste dugu dagoeneko abian dela berrikuspen-prozesu hori, Eudelekin batera, udalek dituzten ikuspegiekin batera. Puntu honetan, gogoratu behar da Eudelen kolore guzti-guztietako udalak daudela ordezkatuta, eta kontuan hartu behar da, halaber, udal horiek guztiek zer iritzi duten eta zer ikuspuntu duten. Sémper jauna, nik ere uste dut –esperientzia handirik ez daukat, baina (Date: 14.03.2019)</code> |
177
+ * Loss: [<code>CachedMultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/cross_encoder/losses.html#cachedmultiplenegativesrankingloss) with these parameters:
178
+ ```json
179
+ {
180
+ "scale": 10.0,
181
+ "num_negatives": null,
182
+ "activation_fn": "torch.nn.modules.activation.Sigmoid",
183
+ "mini_batch_size": 16
184
+ }
185
+ ```
186
+
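+ For illustration, the loss above maps onto the sentence-transformers cross-encoder training API roughly as follows (the base model id is the one this card reports fine-tuning from):
+
+ ```python
+ import torch
+ from sentence_transformers import CrossEncoder
+ from sentence_transformers.cross_encoder.losses import CachedMultipleNegativesRankingLoss
+
+ # trust_remote_code is needed for the custom Jina modeling code
+ model = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual", trust_remote_code=True)
+
+ # num_negatives=None uses every other in-batch passage as a negative; the "cached"
+ # variant processes the batch in mini-batches of 16 to keep memory usage bounded.
+ loss = CachedMultipleNegativesRankingLoss(
+     model=model,
+     num_negatives=None,
+     scale=10.0,
+     activation_fn=torch.nn.Sigmoid(),
+     mini_batch_size=16,
+ )
+ ```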
187
+ ### Evaluation Dataset
188
+
189
+ #### Unnamed Dataset
190
+
191
+ * Size: 800 evaluation samples
192
+ * Columns: <code>query</code> and <code>positive</code>
193
+ * Approximate statistics based on the first 800 samples:
194
+ | | query | positive |
195
+ |:--------|:-------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------|
196
+ | type | string | string |
197
+ | details | <ul><li>min: 32 characters</li><li>mean: 102.26 characters</li><li>max: 247 characters</li></ul> | <ul><li>min: 550 characters</li><li>mean: 1011.95 characters</li><li>max: 2370 characters</li></ul> |
198
+ * Samples:
199
+ | query | positive |
200
+ |:-----------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
201
+ | <code>Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?</code> | <code>[TOPIC: Honako ekimen hauek batera eztabaidatu eta behin betiko ebazpena hartzea: ]<br>[UNZALU HERMOSA, (SV-ES)]:<br>Sekula. Gertatzen dena da uste dugula martxoaren 3ko jokaerak baduela zer hobetua. Eta hobetzeko abiapuntu bakarra gogoeta egitea da, aztertzea eta hasieratik aitortzea hutsegiteak egin zirela. Izan ere, nire lehenengo hitzaldian esan dudanez, triskantzak gertatu izanak pentsarazi behar liguke zerbaitek huts egin zuela egun hartako dispositiboa edo operazioa planifikatzean eta zuzentzean. Horixe sartu nahi dugu guk: eztabaida-elementuak, hobekuntzarako kritika-elementuak, eta UPyDrekin eta Alderdi Popularrarekin sinatu dugun zuzenketan hori esaten da, onar dadila gauzak hobetu egin daitezkeela. Izan ere, Iturrate jauna, zuk egin dizkiguzun galderei nik beste batzuekin erantzungo nieke. Posible da hutsegiteetatik ikastea eta herritarren segurtasuna hobetzea? Posible da? Edo, besterik gabe, "Ahal zen modu bakarrean jokatu dugu" esatera mugatu behar dugu? Posible da herritarrei k...</code> |
202
+ | <code>Zenbat denbora behar da Ertzaintzako promozio baten deialdia egiten denetik agenteak kalera irteten diren arte?</code> | <code>[TOPIC: Interpelazioa, Javier Ruiz de Arbulo Cerio Euskal Talde Popularreko legebiltzarkideak Segurtasuneko sailburuari egina, Arabako Miñoien Atalari buruz]<br>[SEGURTASUNEKO SAILBURUAK (BELTRÁN DE HEREDIA ARRONIZ), (EA-NV)]:<br>horiek aldatu egiten dira egun batetik bestera, unitate batetik bestera, kontuan hartuta zer bilakaera duten erretiroek, kontuan hartuta nola gertatzen diren baja horiek… Baina, batez ere, nik bezain ondo dakizu Ertzaintzan defizit handia daukagula, eta ezin hobeto dakizu zergatia zein den. Ez dakit defizit horren zergatia zein den errepika diezazudan etorri zaren hona, baina ez daukat inolako eragozpenik Legebiltzar honetan berriro azaltzeko eta zuek berriro entzun behar izateko. Honela gaude Espainiako Gobernuak, Alderdi Popularraren Gobernuak, denbora asko behar izan zuelako, denbora gehiegi, zuk behar izan duzun be- zala, ulertzeko premia geneukala Ertzaintzan gertatzen ari ziren erretiro-bajak estaltzeko promozio berriak deitzeko –gero eta gehiago dira erretiro...</code> |
203
+ | <code>Zergatik dimititu zuen Eusko Jaurlaritzako Komunikazio zuzendariak?</code> | <code>[TOPIC: Galdera, Gorka Maneiro Labayen Mistoa-UPyD taldeko legebiltzarkideak lehendakariari egina, Eusko Jaurlaritzako Komunikazio zuzendariaren dimisioaren ondoren hartu beharreko erantzukizun politikoei buruz]<br>[MANEIRO LABAYEN, (Mixto-UPyD)]:<br>sailburu jakin batzuei elkarrizketak egitearen truke? Erantzun ahal diezaiokezu galdera horri? Halaxe da, bai. Zure esanetan, ez dago ezer arrarorik eta irregularrik, baina pertsona batek dimititu egin du. Zer egiteko asmoa duzu zuk? Bide batez, zer da aldi baterako dimisioaren kontu hori? Beste postu batean jarri al duzue pertsona hori? Diru publikoa kobratzen jarraitzen al du? Argitu dezakezu, edo herritarrak engainatu nahi dituzue? Pertsona horrek dimititu egin du. Zer egiteko (Date: 30.10.2015)</code> |
204
+ * Loss: [<code>CachedMultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/cross_encoder/losses.html#cachedmultiplenegativesrankingloss) with these parameters:
205
+ ```json
206
+ {
207
+ "scale": 10.0,
208
+ "num_negatives": null,
209
+ "activation_fn": "torch.nn.modules.activation.Sigmoid",
210
+ "mini_batch_size": 16
211
+ }
212
+ ```
213
+
214
+ ### Training Hyperparameters
215
+ #### Non-Default Hyperparameters
216
+
217
+ - `eval_strategy`: steps
218
+ - `per_device_train_batch_size`: 16
219
+ - `per_device_eval_batch_size`: 16
220
+ - `learning_rate`: 2e-05
221
+ - `num_train_epochs`: 10
222
+ - `warmup_ratio`: 0.1
223
+ - `load_best_model_at_end`: True
224
+ - `batch_sampler`: no_duplicates
225
+
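+ For illustration, these non-default values map onto the sentence-transformers trainer roughly as follows; `model`, `loss`, `train_dataset` and `eval_dataset` stand in for the objects described above, and the output directory is only illustrative:
+
+ ```python
+ from sentence_transformers.cross_encoder import CrossEncoderTrainer, CrossEncoderTrainingArguments
+ from sentence_transformers.training_args import BatchSamplers
+
+ args = CrossEncoderTrainingArguments(
+     output_dir="models/jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep",
+     eval_strategy="steps",
+     per_device_train_batch_size=16,
+     per_device_eval_batch_size=16,
+     learning_rate=2e-5,
+     num_train_epochs=10,
+     warmup_ratio=0.1,
+     load_best_model_at_end=True,
+     batch_sampler=BatchSamplers.NO_DUPLICATES,
+ )
+
+ trainer = CrossEncoderTrainer(
+     model=model,                  # CrossEncoder to fine-tune
+     args=args,
+     train_dataset=train_dataset,  # 3,200 (query, positive) pairs
+     eval_dataset=eval_dataset,    # 800 (query, positive) pairs
+     loss=loss,                    # CachedMultipleNegativesRankingLoss from above
+ )
+ trainer.train()
+ ```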
226
+ #### All Hyperparameters
227
+ <details><summary>Click to expand</summary>
228
+
229
+ - `overwrite_output_dir`: False
230
+ - `do_predict`: False
231
+ - `eval_strategy`: steps
232
+ - `prediction_loss_only`: True
233
+ - `per_device_train_batch_size`: 16
234
+ - `per_device_eval_batch_size`: 16
235
+ - `per_gpu_train_batch_size`: None
236
+ - `per_gpu_eval_batch_size`: None
237
+ - `gradient_accumulation_steps`: 1
238
+ - `eval_accumulation_steps`: None
239
+ - `torch_empty_cache_steps`: None
240
+ - `learning_rate`: 2e-05
241
+ - `weight_decay`: 0.0
242
+ - `adam_beta1`: 0.9
243
+ - `adam_beta2`: 0.999
244
+ - `adam_epsilon`: 1e-08
245
+ - `max_grad_norm`: 1.0
246
+ - `num_train_epochs`: 10
247
+ - `max_steps`: -1
248
+ - `lr_scheduler_type`: linear
249
+ - `lr_scheduler_kwargs`: {}
250
+ - `warmup_ratio`: 0.1
251
+ - `warmup_steps`: 0
252
+ - `log_level`: passive
253
+ - `log_level_replica`: warning
254
+ - `log_on_each_node`: True
255
+ - `logging_nan_inf_filter`: True
256
+ - `save_safetensors`: True
257
+ - `save_on_each_node`: False
258
+ - `save_only_model`: False
259
+ - `restore_callback_states_from_checkpoint`: False
260
+ - `no_cuda`: False
261
+ - `use_cpu`: False
262
+ - `use_mps_device`: False
263
+ - `seed`: 42
264
+ - `data_seed`: None
265
+ - `jit_mode_eval`: False
266
+ - `use_ipex`: False
267
+ - `bf16`: False
268
+ - `fp16`: False
269
+ - `fp16_opt_level`: O1
270
+ - `half_precision_backend`: auto
271
+ - `bf16_full_eval`: False
272
+ - `fp16_full_eval`: False
273
+ - `tf32`: None
274
+ - `local_rank`: 0
275
+ - `ddp_backend`: None
276
+ - `tpu_num_cores`: None
277
+ - `tpu_metrics_debug`: False
278
+ - `debug`: []
279
+ - `dataloader_drop_last`: False
280
+ - `dataloader_num_workers`: 0
281
+ - `dataloader_prefetch_factor`: None
282
+ - `past_index`: -1
283
+ - `disable_tqdm`: False
284
+ - `remove_unused_columns`: True
285
+ - `label_names`: None
286
+ - `load_best_model_at_end`: True
287
+ - `ignore_data_skip`: False
288
+ - `fsdp`: []
289
+ - `fsdp_min_num_params`: 0
290
+ - `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
291
+ - `fsdp_transformer_layer_cls_to_wrap`: None
292
+ - `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
293
+ - `parallelism_config`: None
294
+ - `deepspeed`: None
295
+ - `label_smoothing_factor`: 0.0
296
+ - `optim`: adamw_torch
297
+ - `optim_args`: None
298
+ - `adafactor`: False
299
+ - `group_by_length`: False
300
+ - `length_column_name`: length
301
+ - `ddp_find_unused_parameters`: None
302
+ - `ddp_bucket_cap_mb`: None
303
+ - `ddp_broadcast_buffers`: False
304
+ - `dataloader_pin_memory`: True
305
+ - `dataloader_persistent_workers`: False
306
+ - `skip_memory_metrics`: True
307
+ - `use_legacy_prediction_loop`: False
308
+ - `push_to_hub`: False
309
+ - `resume_from_checkpoint`: None
310
+ - `hub_model_id`: None
311
+ - `hub_strategy`: every_save
312
+ - `hub_private_repo`: None
313
+ - `hub_always_push`: False
314
+ - `hub_revision`: None
315
+ - `gradient_checkpointing`: False
316
+ - `gradient_checkpointing_kwargs`: None
317
+ - `include_inputs_for_metrics`: False
318
+ - `include_for_metrics`: []
319
+ - `eval_do_concat_batches`: True
320
+ - `fp16_backend`: auto
321
+ - `push_to_hub_model_id`: None
322
+ - `push_to_hub_organization`: None
323
+ - `mp_parameters`:
324
+ - `auto_find_batch_size`: False
325
+ - `full_determinism`: False
326
+ - `torchdynamo`: None
327
+ - `ray_scope`: last
328
+ - `ddp_timeout`: 1800
329
+ - `torch_compile`: False
330
+ - `torch_compile_backend`: None
331
+ - `torch_compile_mode`: None
332
+ - `include_tokens_per_second`: False
333
+ - `include_num_input_tokens_seen`: False
334
+ - `neftune_noise_alpha`: None
335
+ - `optim_target_modules`: None
336
+ - `batch_eval_metrics`: False
337
+ - `eval_on_start`: False
338
+ - `use_liger_kernel`: False
339
+ - `liger_kernel_config`: None
340
+ - `eval_use_gather_object`: False
341
+ - `average_tokens_across_devices`: False
342
+ - `prompts`: None
343
+ - `batch_sampler`: no_duplicates
344
+ - `multi_dataset_batch_sampler`: proportional
345
+ - `router_mapping`: {}
346
+ - `learning_rate_mapping`: {}
347
+
348
+ </details>
349
+
350
+ ### Training Logs
351
+ | Epoch | Step | Training Loss | Validation Loss | jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep_ndcg@10 |
352
+ |:-------:|:-------:|:-------------:|:---------------:|:------------------------------------------------------------------:|
353
+ | **1.0** | **200** | **0.0644** | **0.0238** | **0.0200 (+0.0175)** |
354
+ | 2.0 | 400 | 0.0238 | 0.0220 | 0.0198 (+0.0172) |
355
+ | 3.0 | 600 | 0.0182 | 0.0231 | 0.0200 (+0.0175) |
356
+ | 4.0 | 800 | 0.0167 | 0.0235 | 0.0198 (+0.0172) |
357
+ | 5.0 | 1000 | 0.0123 | 0.0240 | 0.0198 (+0.0172) |
358
+ | 6.0 | 1200 | 0.0123 | 0.0260 | 0.0198 (+0.0172) |
359
+ | 7.0 | 1400 | 0.0133 | 0.0260 | 0.0198 (+0.0172) |
360
+ | 8.0 | 1600 | 0.0143 | 0.0258 | 0.0198 (+0.0172) |
361
+ | 9.0 | 1800 | 0.0136 | 0.0258 | 0.0198 (+0.0172) |
362
+ | 10.0 | 2000 | 0.0135 | 0.0257 | 0.0198 (+0.0172) |
363
+
364
+ * The bold row denotes the saved checkpoint.
365
+
366
+ ### Framework Versions
367
+ - Python: 3.9.7
368
+ - Sentence Transformers: 5.0.0
369
+ - Transformers: 4.56.0
370
+ - PyTorch: 2.7.1+cu126
371
+ - Accelerate: 1.5.2
372
+ - Datasets: 4.0.0
373
+ - Tokenizers: 0.22.0
374
+
375
+ ## Citation
376
+
377
+ ### BibTeX
378
+
379
+ #### Sentence Transformers
380
+ ```bibtex
381
+ @inproceedings{reimers-2019-sentence-bert,
382
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
383
+ author = "Reimers, Nils and Gurevych, Iryna",
384
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
385
+ month = "11",
386
+ year = "2019",
387
+ publisher = "Association for Computational Linguistics",
388
+ url = "https://arxiv.org/abs/1908.10084",
389
+ }
390
+ ```
391
+
392
+ <!--
393
+ ## Glossary
394
+
395
+ *Clearly define terms in order to be accessible across audiences.*
396
+ -->
397
+
398
+ <!--
399
+ ## Model Card Authors
400
+
401
+ *Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
402
+ -->
403
+
404
+ <!--
405
+ ## Model Card Contact
406
+
407
+ *Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
408
+ -->
block.py ADDED
@@ -0,0 +1,470 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+
4
+ # Copyright (c) 2024, Tri Dao.
5
+
6
+ from functools import partial
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.fx
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+
15
+ from .mha import MHA
16
+ from .mlp import Mlp
17
+
18
+ try:
19
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
20
+ except ImportError:
21
+ layer_norm_fn, RMSNorm = None, None
22
+
23
+
24
+ def stochastic_depth(
25
+ input: Tensor, p: float, mode: str, training: bool = True
26
+ ) -> Tensor:
27
+ """
28
+ Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
29
+ <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
30
+ branches of residual architectures.
31
+ Args:
32
+ input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
33
+ being its batch i.e. a batch with ``N`` rows.
34
+ p (float): probability of the input to be zeroed.
35
+ mode (str): ``"batch"`` or ``"row"``.
36
+ ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
37
+ randomly selected rows from the batch.
38
+ training: apply stochastic depth if is ``True``. Default: ``True``
39
+ Returns:
40
+ Tensor[N, ...]: The randomly zeroed tensor.
41
+ """
42
+ if p < 0.0 or p > 1.0:
43
+ raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
44
+ if mode not in ["batch", "row"]:
45
+ raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
46
+ if not training or p == 0.0:
47
+ return input
48
+
49
+ survival_rate = 1.0 - p
50
+ if mode == "row":
51
+ size = [input.shape[0]] + [1] * (input.ndim - 1)
52
+ else:
53
+ size = [1] * input.ndim
54
+ noise = torch.empty(size, dtype=input.dtype, device=input.device)
55
+ noise = noise.bernoulli_(survival_rate)
56
+ if survival_rate > 0.0:
57
+ noise.div_(survival_rate)
58
+ return input * noise
59
+
60
+
61
+ torch.fx.wrap("stochastic_depth")
62
+
63
+
64
+ class StochasticDepth(nn.Module):
65
+ """
66
+ See :func:`stochastic_depth`.
67
+ """
68
+
69
+ def __init__(self, p: float, mode: str) -> None:
70
+ super().__init__()
71
+ self.p = p
72
+ self.mode = mode
73
+
74
+ def forward(self, input: Tensor) -> Tensor:
75
+ return stochastic_depth(input, self.p, self.mode, self.training)
76
+
77
+ def __repr__(self) -> str:
78
+ s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
79
+ return s
80
+
81
+
82
+ class Block(nn.Module):
83
+ def __init__(
84
+ self,
85
+ dim,
86
+ mixer_cls=None,
87
+ mlp_cls=None,
88
+ norm_cls=nn.LayerNorm,
89
+ dropout_cls=nn.Dropout,
90
+ prenorm=True,
91
+ resid_dropout1=0.0,
92
+ resid_dropout2=0.0,
93
+ drop_path1=0.0,
94
+ drop_path2=0.0,
95
+ fused_dropout_add_ln=False,
96
+ return_residual=False,
97
+ residual_in_fp32=False,
98
+ sequence_parallel=False,
99
+ mark_shared_params=False,
100
+ ):
101
+ """
102
+ For prenorm=True, this Block has a slightly different structure compared to a regular
103
+ prenorm Transformer block.
104
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
105
+ [Ref: https://arxiv.org/abs/2002.04745]
106
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
107
+ the hidden_states (output of the MLP) and the residual.
108
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
109
+ The residual needs to be provided (except for the very first block).
110
+
111
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
112
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
113
+
114
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
115
+ This is for performance reason: for post-norm architecture, returning the input allows us
116
+ to fuse the backward of nn.Linear with the residual connection.
117
+ """
118
+ super().__init__()
119
+ self.prenorm = prenorm
120
+ self.fused_dropout_add_ln = fused_dropout_add_ln
121
+ self.return_residual = return_residual
122
+ self.residual_in_fp32 = residual_in_fp32
123
+ if self.residual_in_fp32:
124
+ assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
125
+ if mixer_cls is None:
126
+ mixer_cls = partial(MHA, num_heads=dim // 64)
127
+ if mlp_cls is None:
128
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
129
+ self.mixer = mixer_cls(dim)
130
+ self.dropout1 = dropout_cls(resid_dropout1)
131
+ self.drop_path1 = StochasticDepth(drop_path1, mode="row")
132
+ self.norm1 = norm_cls(dim)
133
+ self.mlp = mlp_cls(dim)
134
+ if not isinstance(self.mlp, nn.Identity):
135
+ self.dropout2 = dropout_cls(resid_dropout2)
136
+ self.drop_path2 = StochasticDepth(drop_path2, mode="row")
137
+ self.norm2 = norm_cls(dim)
138
+
139
+ if self.fused_dropout_add_ln:
140
+ assert layer_norm_fn is not None, "Triton is not installed"
141
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
142
+ self.dropout1, nn.Dropout
143
+ )
144
+
145
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
146
+ # then the input to each worker in the tensor parallel group will be different.
147
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
148
+ # For now this is not an issue because we always use sequence_parallel=True during training
149
+ # and only use sequence_parallel=False during inference.
150
+
151
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
152
+ if sequence_parallel:
153
+ for p in self.norm1.parameters():
154
+ p._sequence_parallel = True
155
+ if hasattr(self, "norm2"):
156
+ for p in self.norm2.parameters():
157
+ p._sequence_parallel = True
158
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
159
+ if mark_shared_params:
160
+ for p in self.norm1.parameters():
161
+ p._shared_params = True
162
+ if hasattr(self, "norm2"):
163
+ for p in self.norm2.parameters():
164
+ p._shared_params = True
165
+
166
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
167
+ return self.mixer.allocate_inference_cache(
168
+ batch_size, max_seqlen, dtype=dtype, **kwargs
169
+ )
170
+
171
+ def forward(
172
+ self,
173
+ hidden_states: Tensor,
174
+ residual: Optional[Tensor] = None,
175
+ mixer_subset=None,
176
+ mixer_kwargs=None,
177
+ ):
178
+ r"""Pass the input through the encoder layer.
179
+
180
+ Args:
181
+ hidden_states: the sequence to the encoder layer (required).
182
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
183
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
184
+ before applying the query projection. Useful for e.g., ViT where we only care
185
+ about the CLS token in the last layer.
186
+ """
187
+ if self.prenorm:
188
+ if not self.fused_dropout_add_ln:
189
+ dropped = self.drop_path1(self.dropout1(hidden_states))
190
+ residual = (dropped + residual) if residual is not None else dropped
191
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
192
+ if self.residual_in_fp32:
193
+ residual = residual.to(torch.float32)
194
+ else:
195
+ if self.drop_path1.p == 0 or not self.training:
196
+ rowscale1 = None
197
+ else:
198
+ rowscale1 = self.drop_path1(
199
+ torch.ones(
200
+ hidden_states.shape[:-1],
201
+ device=hidden_states.device,
202
+ dtype=hidden_states.dtype,
203
+ )
204
+ )
205
+ hidden_states, residual = layer_norm_fn(
206
+ hidden_states,
207
+ self.norm1.weight,
208
+ self.norm1.bias,
209
+ residual=residual,
210
+ eps=self.norm1.eps,
211
+ dropout_p=self.dropout1.p if self.training else 0.0,
212
+ rowscale=rowscale1,
213
+ prenorm=True,
214
+ residual_in_fp32=self.residual_in_fp32,
215
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
216
+ )
217
+ if mixer_kwargs is None:
218
+ mixer_kwargs = {}
219
+ if mixer_subset is not None:
220
+ mixer_kwargs["mixer_subset"] = mixer_subset
221
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
222
+ if mixer_subset is not None:
223
+ residual = residual[:, mixer_subset]
224
+ if not isinstance(self.mlp, nn.Identity):
225
+ if not self.fused_dropout_add_ln:
226
+ dropped = self.drop_path2(self.dropout2(hidden_states))
227
+ residual = (dropped + residual) if residual is not None else dropped
228
+ hidden_states = self.norm2(
229
+ residual.to(dtype=self.norm2.weight.dtype)
230
+ )
231
+ if self.residual_in_fp32:
232
+ residual = residual.to(torch.float32)
233
+ else:
234
+ if self.drop_path2.p == 0 or not self.training:
235
+ rowscale2 = None
236
+ else:
237
+ rowscale2 = self.drop_path2(
238
+ torch.ones(
239
+ hidden_states.shape[:-1],
240
+ device=hidden_states.device,
241
+ dtype=hidden_states.dtype,
242
+ )
243
+ )
244
+ hidden_states, residual = layer_norm_fn(
245
+ hidden_states,
246
+ self.norm2.weight,
247
+ self.norm2.bias,
248
+ residual=residual,
249
+ eps=self.norm2.eps,
250
+ dropout_p=self.dropout2.p if self.training else 0.0,
251
+ rowscale=rowscale2,
252
+ prenorm=True,
253
+ residual_in_fp32=self.residual_in_fp32,
254
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
255
+ )
256
+ hidden_states = self.mlp(hidden_states)
257
+ return hidden_states, residual
258
+ else:
259
+ assert residual is None
260
+ mixer_out = self.mixer(
261
+ hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
262
+ )
263
+ if self.return_residual: # mixer out is actually a pair here
264
+ mixer_out, hidden_states = mixer_out
265
+ if not self.fused_dropout_add_ln:
266
+ hidden_states = self.norm1(
267
+ (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
268
+ dtype=self.norm1.weight.dtype
269
+ )
270
+ )
271
+ else:
272
+ if self.drop_path1.p == 0 or not self.training:
273
+ rowscale1 = None
274
+ else:
275
+ rowscale1 = self.drop_path1(
276
+ torch.ones(
277
+ mixer_out.shape[:-1],
278
+ device=mixer_out.device,
279
+ dtype=mixer_out.dtype,
280
+ )
281
+ )
282
+ hidden_states = layer_norm_fn(
283
+ mixer_out,
284
+ self.norm1.weight,
285
+ self.norm1.bias,
286
+ residual=hidden_states,
287
+ eps=self.norm1.eps,
288
+ dropout_p=self.dropout1.p if self.training else 0.0,
289
+ rowscale=rowscale1,
290
+ prenorm=False,
291
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
292
+ )
293
+ if not isinstance(self.mlp, nn.Identity):
294
+ mlp_out = self.mlp(hidden_states)
295
+ if self.return_residual: # mlp out is actually a pair here
296
+ mlp_out, hidden_states = mlp_out
297
+ if not self.fused_dropout_add_ln:
298
+ hidden_states = self.norm2(
299
+ (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
300
+ dtype=self.norm2.weight.dtype
301
+ )
302
+ )
303
+ else:
304
+ if self.drop_path2.p == 0 or not self.training:
305
+ rowscale2 = None
306
+ else:
307
+ rowscale2 = self.drop_path2(
308
+ torch.ones(
309
+ mlp_out.shape[:-1],
310
+ device=mlp_out.device,
311
+ dtype=mlp_out.dtype,
312
+ )
313
+ )
314
+ hidden_states = layer_norm_fn(
315
+ mlp_out,
316
+ self.norm2.weight,
317
+ self.norm2.bias,
318
+ residual=hidden_states,
319
+ eps=self.norm2.eps,
320
+ dropout_p=self.dropout2.p if self.training else 0.0,
321
+ rowscale=rowscale2,
322
+ prenorm=False,
323
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
324
+ )
325
+ return hidden_states
326
+
327
+
328
+ class ParallelBlock(nn.Module):
329
+ """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
330
+ and PaLM.
331
+ """
332
+
333
+ def __init__(
334
+ self,
335
+ dim,
336
+ mixer_cls=None,
337
+ mlp_cls=None,
338
+ norm_cls=nn.LayerNorm,
339
+ dropout_cls=nn.Dropout,
340
+ resid_dropout1=0.0,
341
+ resid_dropout2=0.0,
342
+ tied_norm=False,
343
+ fused_dropout_add_ln=False,
344
+ residual_in_fp32=False,
345
+ sequence_parallel=False,
346
+ mark_shared_params=False,
347
+ ):
348
+ """
349
+ This Block has a slightly different structure compared to a regular
350
+ prenorm Transformer block.
351
+ The standard block is: LN -> MHA / MLP -> Dropout -> Add.
352
+ [Ref: https://arxiv.org/abs/2002.04745]
353
+ Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
354
+ the hidden_states (output1 of the MHA / MLP) and the residual.
355
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
356
+ The residual needs to be provided (except for the very first block).
357
+ """
358
+ super().__init__()
359
+ self.tied_norm = tied_norm
360
+ self.fused_dropout_add_ln = fused_dropout_add_ln
361
+ self.residual_in_fp32 = residual_in_fp32
362
+ if mixer_cls is None:
363
+ mixer_cls = partial(MHA, num_heads=dim // 64)
364
+ if mlp_cls is None:
365
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
366
+ self.mixer = mixer_cls(dim)
367
+ self.dropout1 = dropout_cls(resid_dropout1)
368
+ self.norm1 = norm_cls(dim)
369
+ self.mlp = mlp_cls(dim)
370
+ self.dropout2 = dropout_cls(resid_dropout2)
371
+ if not self.tied_norm:
372
+ self.norm2 = norm_cls(dim)
373
+
374
+ if self.fused_dropout_add_ln:
375
+ assert layer_norm_fn is not None, "Triton is not installed"
376
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
377
+ self.dropout1, nn.Dropout
378
+ )
379
+
380
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
381
+ # then the input to each worker in the tensor parallel group will be different.
382
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
383
+ # For now this is not an issue because we always use sequence_parallel=True during training
384
+ # and only use sequence_parallel=False during inference.
385
+
386
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
387
+ if sequence_parallel:
388
+ for p in self.norm1.parameters():
389
+ p._sequence_parallel = True
390
+ if hasattr(self, "norm2"):
391
+ for p in self.norm2.parameters():
392
+ p._sequence_parallel = True
393
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
394
+ if mark_shared_params:
395
+ for p in self.norm1.parameters():
396
+ p._shared_params = True
397
+ if hasattr(self, "norm2"):
398
+ for p in self.norm2.parameters():
399
+ p._shared_params = True
400
+
401
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
402
+ return self.mixer.allocate_inference_cache(
403
+ batch_size, max_seqlen, dtype=dtype, **kwargs
404
+ )
405
+
406
+ def forward(
407
+ self,
408
+ hidden_states1: Tensor,
409
+ hidden_states2: Optional[Tensor] = None,
410
+ residual: Optional[Tensor] = None,
411
+ mixer_kwargs=None,
412
+ ):
413
+ r"""Pass the input through the encoder layer.
414
+
415
+ Args:
416
+ hidden_states1: the output of the previous attention (mixer) or embedding layer.
417
+ hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
418
+ residual: the residual tensor carried over from the previous block (None for the very first block).
419
+ """
420
+ # TODO: Ideally we should only do the allgather / allreduce once for
421
+ # the Linear to MLP & Attention
422
+ if not self.fused_dropout_add_ln:
423
+ dropped1 = self.dropout1(hidden_states1)
424
+ # For the very 1st block, we only want 1 dropout, not two different dropouts
425
+ if hidden_states2 is not None:
426
+ dropped2 = self.dropout2(hidden_states2)
427
+ residual = (
428
+ (residual + dropped1 + dropped2)
429
+ if residual is not None
430
+ else dropped1 + dropped2
431
+ )
432
+ else:
433
+ residual = (residual + dropped1) if residual is not None else dropped1
434
+ hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
435
+ hidden_states2 = (
436
+ self.norm2(residual.to(dtype=self.norm2.weight.dtype))
437
+ if not self.tied_norm
438
+ else hidden_states1
439
+ )
440
+ if self.residual_in_fp32:
441
+ residual = residual.to(torch.float32)
442
+ else:
443
+ weight2, bias2 = (
444
+ (self.norm2.weight, self.norm2.bias)
445
+ if not self.tied_norm
446
+ else (None, None)
447
+ )
448
+ hidden_states1, *rest, residual = layer_norm_fn(
449
+ hidden_states1,
450
+ self.norm1.weight,
451
+ self.norm1.bias,
452
+ residual=residual,
453
+ x1=hidden_states2,
454
+ weight1=weight2,
455
+ bias1=bias2,
456
+ eps=self.norm1.eps,
457
+ dropout_p=self.dropout1.p if self.training else 0.0,
458
+ prenorm=True,
459
+ residual_in_fp32=self.residual_in_fp32,
460
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
461
+ )
462
+ if self.tied_norm:
463
+ hidden_states2 = hidden_states1
464
+ else:
465
+ (hidden_states2,) = rest
466
+ if mixer_kwargs is None:
467
+ mixer_kwargs = {}
468
+ hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
469
+ hidden_states2 = self.mlp(hidden_states2)
470
+ return hidden_states1, hidden_states2, residual
config.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "architectures": [
3
+ "XLMRobertaForSequenceClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
8
+ "AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
9
+ "AutoModelForSequenceClassification": "modeling_xlm_roberta.XLMRobertaForSequenceClassification"
10
+ },
11
+ "bos_token_id": 0,
12
+ "classifier_dropout": null,
13
+ "dtype": "bfloat16",
14
+ "emb_pooler": null,
15
+ "eos_token_id": 2,
16
+ "hidden_act": "gelu",
17
+ "hidden_dropout_prob": 0.1,
18
+ "hidden_size": 768,
19
+ "id2label": {
20
+ "0": "LABEL_0"
21
+ },
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 3072,
24
+ "label2id": {
25
+ "LABEL_0": 0
26
+ },
27
+ "layer_norm_eps": 1e-05,
28
+ "load_trained_adapters": false,
29
+ "lora_adaptations": null,
30
+ "lora_alpha": 1,
31
+ "lora_dropout_p": 0.0,
32
+ "lora_main_params_trainable": false,
33
+ "lora_rank": 4,
34
+ "matryoshka_dimensions": null,
35
+ "max_position_embeddings": 1026,
36
+ "num_attention_heads": 12,
37
+ "num_hidden_layers": 12,
38
+ "output_past": true,
39
+ "pad_token_id": 1,
40
+ "position_embedding_type": "absolute",
41
+ "sentence_transformers": {
42
+ "activation_fn": "torch.nn.modules.activation.Sigmoid",
43
+ "version": "5.0.0"
44
+ },
45
+ "transformers_version": "4.56.0",
46
+ "truncate_dim": null,
47
+ "type_vocab_size": 1,
48
+ "use_cache": false,
49
+ "use_flash_attn": true,
50
+ "vocab_size": 250002
51
+ }
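The `auto_map` entries above point at the custom configuration and modeling files shipped in this repository, so loading through the plain transformers auto classes requires `trust_remote_code=True`. An illustrative, untested sketch (the `CrossEncoder` usage in the README is the supported path):

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

repo = "cuadron11/jina-reranker-v2-base-multilingual-contrastive-parl-4-10ep"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForSequenceClassification.from_pretrained(
    repo, torch_dtype=torch.bfloat16, trust_remote_code=True
)
model.eval()

# Score one (query, passage) pair; the sentence_transformers activation_fn is Sigmoid,
# so the single logit is squashed to a 0-1 relevance score here as well.
query = "Zer gertatu zen martxoaren 3an Euskal Autonomia Erkidegoan?"
passage = "..."  # candidate passage text
inputs = tokenizer(query, passage, return_tensors="pt", truncation=True, max_length=1024)
with torch.no_grad():
    score = torch.sigmoid(model(**inputs).logits.squeeze().float())
```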
configuration_xlm_roberta.py ADDED
@@ -0,0 +1,69 @@
+ from transformers import PretrainedConfig
+ import torch
+ 
+ class XLMRobertaFlashConfig(PretrainedConfig):
+     def __init__(
+         self,
+         vocab_size=30522,
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.1,
+         attention_probs_dropout_prob=0.1,
+         max_position_embeddings=512,
+         type_vocab_size=2,
+         initializer_range=0.02,
+         layer_norm_eps=1e-12,
+         pad_token_id=1,
+         bos_token_id=0,
+         eos_token_id=2,
+         position_embedding_type="absolute",
+         use_cache=True,
+         classifier_dropout=None,
+         lora_adaptations=None,
+         lora_rank=4,
+         lora_dropout_p=0.0,
+         lora_alpha=1,
+         lora_main_params_trainable=False,
+         load_trained_adapters=False,
+         use_flash_attn=True,
+         torch_dtype=None,
+         emb_pooler=None,
+         matryoshka_dimensions=None,
+         truncate_dim=None,
+         **kwargs,
+     ):
+         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ 
+ 
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.hidden_act = hidden_act
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.type_vocab_size = type_vocab_size
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.position_embedding_type = position_embedding_type
+         self.use_cache = use_cache
+         self.classifier_dropout = classifier_dropout
+         self.load_trained_adapters = load_trained_adapters
+         self.lora_adaptations = lora_adaptations
+         self.lora_rank = lora_rank
+         self.lora_dropout_p = lora_dropout_p
+         self.lora_alpha = lora_alpha
+         self.lora_main_params_trainable = lora_main_params_trainable
+         self.use_flash_attn = use_flash_attn
+         self.emb_pooler = emb_pooler
+         self.matryoshka_dimensions = matryoshka_dimensions
+         self.truncate_dim = truncate_dim
+         if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
+             self.torch_dtype = getattr(torch, torch_dtype)
+         else:
+             self.torch_dtype = torch_dtype
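The config class above mirrors the stock XLM-RoBERTa configuration and adds the LoRA, pooling, Matryoshka, and FlashAttention fields used by the custom modeling code; a string `torch_dtype` is resolved to the corresponding `torch.dtype`. A small sketch of instantiating it directly, assuming `configuration_xlm_roberta.py` is importable as a local module:

```python
import torch
from configuration_xlm_roberta import XLMRobertaFlashConfig

cfg = XLMRobertaFlashConfig(
    vocab_size=250002,
    max_position_embeddings=1026,
    type_vocab_size=1,
    layer_norm_eps=1e-5,
    use_flash_attn=True,
    torch_dtype="bfloat16",  # string form, as stored in config.json
)
print(cfg.torch_dtype)  # torch.bfloat16 — resolved from the string by the constructor above
```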
embedding.py ADDED
@@ -0,0 +1,62 @@
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
+ # Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0
+ 
+ # Copyright (c) 2022, Tri Dao.
+ 
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+ from torch import Tensor
+ 
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
+ 
+ 
+ class XLMRobertaEmbeddings(nn.Module):
+     def __init__(
+         self,
+         embed_dim,
+         vocab_size,
+         max_position_embeddings,
+         type_vocab_size,
+         padding_idx=None,
+         device=None,
+         dtype=None,
+     ):
+         """
+         If max_position_embeddings <= 0, there's no position embeddings
+         If type_vocab_size <= 0, there's no token type embeddings
+         """
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.word_embeddings = nn.Embedding(
+             vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
+         )
+         self.max_position_embeddings = max_position_embeddings
+         self.type_vocab_size = type_vocab_size
+         if self.max_position_embeddings > 0:
+             self.position_embeddings = nn.Embedding(
+                 max_position_embeddings, embed_dim, **factory_kwargs
+             )
+         if self.type_vocab_size > 0:
+             self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
+ 
+     def forward(self, input_ids, position_ids=None, token_type_ids=None):
+         """
+         input_ids: (batch, seqlen)
+         position_ids: (batch, seqlen)
+         token_type_ids: (batch, seqlen)
+         """
+         batch_size, seqlen = input_ids.shape
+         embeddings = self.word_embeddings(input_ids)
+         if self.max_position_embeddings > 0:
+             if position_ids is None:
+                 position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
+                 # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+             position_embeddings = self.position_embeddings(position_ids)
+             embeddings = embeddings + position_embeddings
+         if self.type_vocab_size > 0:
+             if token_type_ids is None:
+                 token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
+             token_type_embeddings = self.token_type_embeddings(token_type_ids)
+             embeddings = embeddings + token_type_embeddings
+         return embeddings
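The position ids above follow the RoBERTa convention: padding positions keep `padding_idx`, and real tokens are numbered starting at `padding_idx + 1`. A quick check of that convention with the same helper the module imports:

```python
import torch
from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids

pad = 1  # pad_token_id in this model's config
input_ids = torch.tensor([[0, 100, 101, 2, pad, pad]])  # <s> tok tok </s> <pad> <pad>
print(create_position_ids_from_input_ids(input_ids, padding_idx=pad))
# tensor([[2, 3, 4, 5, 1, 1]]) — non-pad tokens count up from padding_idx + 1
```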
mha.py ADDED
@@ -0,0 +1,662 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+ # Adapted from https://github.com/Dao-AILab/flash-attention/pull/556
3
+
4
+ import math
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+ try:
12
+ from flash_attn import (
13
+ flash_attn_kvpacked_func,
14
+ flash_attn_qkvpacked_func,
15
+ flash_attn_varlen_kvpacked_func,
16
+ flash_attn_varlen_qkvpacked_func,
17
+ flash_attn_with_kvcache,
18
+ )
19
+ except ImportError:
20
+ flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
21
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
22
+ flash_attn_with_kvcache = None
23
+
24
+ try:
25
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
26
+ except ImportError:
27
+ FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
28
+
29
+
30
+ class FlashSelfAttention(nn.Module):
31
+ """Implement the scaled dot product attention with softmax.
32
+ Arguments
33
+ ---------
34
+ softmax_scale: The temperature to use for the softmax attention.
35
+ (default: 1/sqrt(d_keys) where d_keys is computed at
36
+ runtime)
37
+ attention_dropout: The dropout rate to apply to the attention
38
+ (default: 0.0)
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ causal=False,
44
+ softmax_scale=None,
45
+ attention_dropout=0.0,
46
+ window_size=(-1, -1),
47
+ deterministic=False,
48
+ ):
49
+ super().__init__()
50
+ assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
51
+ assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
52
+ self.causal = causal
53
+ self.softmax_scale = softmax_scale
54
+ self.drop = nn.Dropout(attention_dropout)
55
+ self.window_size = window_size
56
+ self.deterministic = deterministic
57
+
58
+ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
59
+ """Implements the multihead softmax attention.
60
+ Arguments
61
+ ---------
62
+ qkv: The tensor containing the query, key, and value.
63
+ If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
64
+ If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
65
+ (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
66
+ causal: if passed, will override self.causal
67
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
68
+ of the sequences in the batch, used to index into qkv.
69
+ max_seqlen: int. Maximum sequence length in the batch.
70
+ Returns:
71
+ --------
72
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
73
+ else (B, S, H, D).
74
+ """
75
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
76
+ assert qkv.is_cuda
77
+ causal = self.causal if causal is None else causal
78
+ unpadded = cu_seqlens is not None
79
+
80
+ if unpadded:
81
+ assert cu_seqlens.dtype == torch.int32
82
+ assert max_seqlen is not None
83
+ assert isinstance(max_seqlen, int)
84
+ return flash_attn_varlen_qkvpacked_func(
85
+ qkv,
86
+ cu_seqlens,
87
+ max_seqlen,
88
+ self.drop.p if self.training else 0.0,
89
+ softmax_scale=self.softmax_scale,
90
+ causal=causal,
91
+ alibi_slopes=None,
92
+ window_size=self.window_size,
93
+ deterministic=self.deterministic,
94
+ )
95
+ else:
96
+ return flash_attn_qkvpacked_func(
97
+ qkv,
98
+ self.drop.p if self.training else 0.0,
99
+ softmax_scale=self.softmax_scale,
100
+ causal=causal,
101
+ alibi_slopes=None,
102
+ window_size=self.window_size,
103
+ deterministic=self.deterministic,
104
+ )
105
+
106
+
107
+ class FlashCrossAttention(nn.Module):
108
+ """Implement the scaled dot product attention with softmax.
109
+ Arguments
110
+ ---------
111
+ softmax_scale: The temperature to use for the softmax attention.
112
+ (default: 1/sqrt(d_keys) where d_keys is computed at
113
+ runtime)
114
+ attention_dropout: The dropout rate to apply to the attention
115
+ (default: 0.0)
116
+ """
117
+
118
+ def __init__(
119
+ self,
120
+ causal=False,
121
+ softmax_scale=None,
122
+ attention_dropout=0.0,
123
+ window_size=(-1, -1),
124
+ deterministic=False,
125
+ ):
126
+ super().__init__()
127
+ assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
128
+ assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
129
+ self.causal = causal
130
+ self.softmax_scale = softmax_scale
131
+ self.drop = nn.Dropout(attention_dropout)
132
+ self.window_size = window_size
133
+ self.deterministic = deterministic
134
+
135
+ def forward(
136
+ self,
137
+ q,
138
+ kv,
139
+ causal=None,
140
+ cu_seqlens=None,
141
+ max_seqlen=None,
142
+ cu_seqlens_k=None,
143
+ max_seqlen_k=None,
144
+ ):
145
+ """Implements the multihead softmax attention.
146
+ Arguments
147
+ ---------
148
+ q: The tensor containing the query. (B, Sq, H, D)
149
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
150
+ causal: if passed, will override self.causal
151
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
152
+ of the sequences in the batch, used to index into q.
153
+ max_seqlen: int. Maximum sequence length in the batch of q.
154
+ cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
155
+ of the sequences in the batch, used to index into kv.
156
+ max_seqlen_k: int. Maximum sequence length in the batch of k and v.
157
+ """
158
+ assert q.dtype in [torch.float16, torch.bfloat16]
159
+ assert q.is_cuda and kv.is_cuda
160
+ causal = self.causal if causal is None else causal
161
+ unpadded = cu_seqlens is not None
162
+
163
+ if unpadded:
164
+ assert cu_seqlens.dtype == torch.int32
165
+ assert max_seqlen is not None
166
+ assert isinstance(max_seqlen, int)
167
+ assert cu_seqlens_k is not None
168
+ assert cu_seqlens_k.dtype == torch.int32
169
+ assert max_seqlen_k is not None
170
+ assert isinstance(max_seqlen, int)
171
+ return flash_attn_varlen_kvpacked_func(
172
+ q,
173
+ kv,
174
+ cu_seqlens,
175
+ cu_seqlens_k,
176
+ max_seqlen,
177
+ max_seqlen_k,
178
+ self.drop.p if self.training else 0.0,
179
+ softmax_scale=self.softmax_scale,
180
+ causal=causal,
181
+ alibi_slopes=None,
182
+ window_size=self.window_size,
183
+ deterministic=self.deterministic,
184
+ )
185
+ else:
186
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
187
+ seqlen_k = kv.shape[1]
188
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
189
+ return flash_attn_kvpacked_func(
190
+ q,
191
+ kv,
192
+ self.drop.p if self.training else 0.0,
193
+ causal=causal,
194
+ softmax_scale=self.softmax_scale,
195
+ alibi_slopes=None,
196
+ window_size=self.window_size,
197
+ deterministic=self.deterministic,
198
+ )
199
+
200
+
201
+ class SelfAttention(nn.Module):
202
+ """Implement the scaled dot product attention with softmax.
203
+ Arguments
204
+ ---------
205
+ softmax_scale: The temperature to use for the softmax attention.
206
+ (default: 1/sqrt(d_keys) where d_keys is computed at
207
+ runtime)
208
+ attention_dropout: The dropout rate to apply to the attention
209
+ (default: 0.0)
210
+ """
211
+
212
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
213
+ super().__init__()
214
+ self.causal = causal
215
+ self.softmax_scale = softmax_scale
216
+ self.drop = nn.Dropout(attention_dropout)
217
+
218
+ def forward(self, qkv, causal=None, key_padding_mask=None):
219
+ """Implements the multihead softmax attention.
220
+ Arguments
221
+ ---------
222
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
223
+ causal: if passed, will override self.causal
224
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
225
+ False means to mask out. (B, S)
226
+ """
227
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
228
+ causal = self.causal if causal is None else causal
229
+ q, k, v = qkv.unbind(dim=2)
230
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
231
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
232
+ if key_padding_mask is not None:
233
+ padding_mask = torch.full(
234
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
235
+ )
236
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
237
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
238
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
239
+ if causal:
240
+ # "triu_tril_cuda_template" not implemented for 'BFloat16'
241
+ # So we have to construct the mask in float
242
+ causal_mask = torch.triu(
243
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
244
+ )
245
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
246
+ scores = scores + causal_mask.to(dtype=scores.dtype)
247
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
248
+ attention_drop = self.drop(attention)
249
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
250
+ return output
251
+
252
+
253
+ class CrossAttention(nn.Module):
254
+ """Implement the scaled dot product attention with softmax.
255
+ Arguments
256
+ ---------
257
+ softmax_scale: The temperature to use for the softmax attention.
258
+ (default: 1/sqrt(d_keys) where d_keys is computed at
259
+ runtime)
260
+ attention_dropout: The dropout rate to apply to the attention
261
+ (default: 0.0)
262
+ """
263
+
264
+ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
265
+ super().__init__()
266
+ self.causal = causal
267
+ self.softmax_scale = softmax_scale
268
+ self.drop = nn.Dropout(attention_dropout)
269
+
270
+ def forward(self, q, kv, causal=None, key_padding_mask=None):
271
+ """Implements the multihead softmax attention.
272
+ Arguments
273
+ ---------
274
+ q: The tensor containing the query. (B, Sq, H, D)
275
+ kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
276
+ causal: if passed, will override self.causal
277
+ key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
278
+ False means to mask out. (B, Sk)
279
+ """
280
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
281
+ causal = self.causal if causal is None else causal
282
+ seqlen_k = kv.shape[1]
283
+ assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
284
+ if kv.shape[3] != q.shape[2]: # MQA/GQA
285
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
286
+ k, v = kv.unbind(dim=2)
287
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
288
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
289
+ if key_padding_mask is not None:
290
+ padding_mask = torch.full(
291
+ (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
292
+ )
293
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
294
+ # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
295
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
296
+ if causal:
297
+ # causal mask needs to take into account the difference between seqlen_q and seqlen_k
298
+ row_idx = rearrange(
299
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
300
+ )
301
+ col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
302
+ sk = (
303
+ seqlen_k
304
+ if key_padding_mask is None
305
+ else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
306
+ )
307
+ causal_mask = col_idx > row_idx + sk - seqlen_q
308
+ scores = scores.masked_fill(causal_mask, -10000.0)
309
+ attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
310
+ attention_drop = self.drop(attention)
311
+ output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
312
+ return output
313
+
314
+
315
+ class LinearResidual(nn.Linear):
316
+ """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
317
+
318
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
319
+ return super().forward(input), input
320
+
321
+
322
+ def _update_kv_cache(kv, inference_params, layer_idx):
323
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
324
+ # Pre-allocate memory for key-values for inference.
325
+ num_heads, head_dim = kv.shape[-2:]
326
+ if layer_idx not in inference_params.key_value_memory_dict:
327
+ kv_cache = torch.empty(
328
+ inference_params.max_batch_size,
329
+ inference_params.max_seqlen,
330
+ 2,
331
+ num_heads,
332
+ head_dim,
333
+ dtype=kv.dtype,
334
+ device=kv.device,
335
+ )
336
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
337
+ else:
338
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
339
+ # Adjust key and value for inference
340
+ batch_start = inference_params.batch_size_offset
341
+ batch_end = batch_start + kv.shape[0]
342
+ sequence_start = inference_params.seqlen_offset
343
+ sequence_end = sequence_start + kv.shape[1]
344
+ assert batch_end <= kv_cache.shape[0]
345
+ assert sequence_end <= kv_cache.shape[1]
346
+ assert kv_cache is not None
347
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
348
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
349
+
350
+
351
+ class MHA(nn.Module):
352
+ """Multi-head self-attention and cross-attention"""
353
+
354
+ def __init__(
355
+ self,
356
+ embed_dim,
357
+ num_heads,
358
+ num_heads_kv=None,
359
+ cross_attn=False,
360
+ qkv_proj_bias=True,
361
+ out_proj_bias=True,
362
+ dropout=0.0,
363
+ softmax_scale=None,
364
+ causal=False,
365
+ layer_idx=None,
366
+ dwconv=False,
367
+ window_size=(-1, -1),
368
+ fused_bias_fc=False,
369
+ use_flash_attn=False,
370
+ return_residual=False,
371
+ checkpointing=False,
372
+ device=None,
373
+ dtype=None,
374
+ ) -> None:
375
+ """
376
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
377
+ return_residual: whether to return the input x along with the output. This is for
378
+ performance reason: for post-norm architecture, returning the input allows us
379
+ to fuse the backward of nn.Linear with the residual connection.
380
+ """
381
+ factory_kwargs = {"device": device, "dtype": dtype}
382
+ super().__init__()
383
+ self.embed_dim = embed_dim
384
+ self.cross_attn = cross_attn
385
+ self.causal = causal
386
+ self.layer_idx = layer_idx
387
+ self.dwconv = dwconv
388
+ self.use_flash_attn = use_flash_attn
389
+ self.return_residual = return_residual
390
+ self.checkpointing = checkpointing
391
+
392
+ if window_size != (-1, -1):
393
+ assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
394
+
395
+ self.num_heads = num_heads
396
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
397
+ assert (
398
+ self.num_heads % self.num_heads_kv == 0
399
+ ), "num_heads must be divisible by num_heads_kv"
400
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
401
+ self.head_dim = self.embed_dim // num_heads
402
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
403
+ kv_dim = 2 * self.head_dim * self.num_heads_kv
404
+
405
+ if fused_bias_fc and FusedDense is None:
406
+ raise ImportError("fused_dense is not installed")
407
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
408
+ linear_resid_cls = (
409
+ LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
410
+ )
411
+ wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
412
+ inner_attn_cls = (
413
+ partial(FlashSelfAttention, window_size=window_size)
414
+ if use_flash_attn
415
+ else SelfAttention
416
+ )
417
+ inner_cross_attn_cls = (
418
+ partial(FlashCrossAttention, window_size=window_size)
419
+ if use_flash_attn
420
+ else CrossAttention
421
+ )
422
+ if not self.cross_attn:
423
+ self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
424
+ else:
425
+ self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
426
+ self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
427
+ if self.dwconv:
428
+ if self.num_heads_kv == self.num_heads:
429
+ self.dwconv_qkv = nn.Conv1d(
430
+ qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
431
+ )
432
+ else:
433
+ self.dwconv_q = nn.Conv1d(
434
+ embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
435
+ )
436
+ self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
437
+ self.inner_attn = inner_attn_cls(
438
+ causal=causal,
439
+ softmax_scale=softmax_scale,
440
+ attention_dropout=dropout,
441
+ )
442
+ self.inner_cross_attn = inner_cross_attn_cls(
443
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
444
+ )
445
+ self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
446
+
447
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
448
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
449
+ device = self.out_proj.weight.device
450
+ return torch.empty(
451
+ batch_size,
452
+ max_seqlen,
453
+ 2,
454
+ self.num_heads_kv,
455
+ self.head_dim,
456
+ dtype=dtype,
457
+ device=device,
458
+ )
459
+
460
+ def _update_kv_cache(self, kv, inference_params):
461
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
462
+ assert not self.dwconv, "Generation does not support dwconv yet"
463
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
464
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
465
+
466
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
467
+ """
468
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
469
+ q: (batch_size, seqlen_q, nheads, head_dim)
470
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
471
+ """
472
+ assert inference_params is not None and inference_params.seqlen_offset > 0
473
+ assert self.use_flash_attn
474
+ batch = q.shape[0]
475
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
476
+ cache_seqlens = (
477
+ inference_params.lengths_per_sample[:batch]
478
+ if inference_params.lengths_per_sample is not None
479
+ else inference_params.seqlen_offset
480
+ )
481
+ context = flash_attn_with_kvcache(
482
+ q,
483
+ kv_cache[:, :, 0],
484
+ kv_cache[:, :, 1],
485
+ kv[:, :, 0],
486
+ kv[:, :, 1],
487
+ cache_seqlens=cache_seqlens,
488
+ softmax_scale=self.inner_cross_attn.softmax_scale,
489
+ causal=self.inner_cross_attn.causal,
490
+ rotary_interleaved=False,
491
+ alibi_slopes=None,
492
+ )
493
+ return context
494
+
495
+ def _update_kvcache_attention(self, q, kv, inference_params):
496
+ """Write kv to inference_params, then do attention"""
497
+ if (
498
+ inference_params.seqlen_offset == 0
499
+ or flash_attn_with_kvcache is None
500
+ or not self.use_flash_attn
501
+ ):
502
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
503
+ kv = self._update_kv_cache(kv, inference_params)
504
+ return self.inner_cross_attn(q, kv)
505
+ else:
506
+ batch = q.shape[0]
507
+ kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
508
+ cache_seqlens = (
509
+ inference_params.lengths_per_sample[:batch]
510
+ if inference_params.lengths_per_sample is not None
511
+ else inference_params.seqlen_offset
512
+ )
513
+ return flash_attn_with_kvcache(
514
+ q,
515
+ kv_cache[:, :, 0],
516
+ kv_cache[:, :, 1],
517
+ kv[:, :, 0],
518
+ kv[:, :, 1],
519
+ cache_seqlens=cache_seqlens,
520
+ softmax_scale=self.inner_cross_attn.softmax_scale,
521
+ causal=self.inner_cross_attn.causal,
522
+ alibi_slopes=None,
523
+ )
524
+
525
+ def forward(
526
+ self,
527
+ x,
528
+ x_kv=None,
529
+ key_padding_mask=None,
530
+ cu_seqlens=None,
531
+ max_seqlen=None,
532
+ mixer_subset=None,
533
+ inference_params=None,
534
+ **kwargs,
535
+ ):
536
+ """
537
+ Arguments:
538
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
539
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
540
+ is the sum of the sequence lengths in the batch.
541
+ x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
542
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
543
+ of the sequences in the batch, used to index into x. Only applicable when using
544
+ FlashAttention.
545
+ max_seqlen: int. Maximum sequence length in the batch.
546
+ key_padding_mask: boolean mask, True means to keep, False means to mask out.
547
+ (batch, seqlen). Only applicable when not using FlashAttention.
548
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
549
+ before applying the query projection. Useful for e.g., ViT where we only care
550
+ about the CLS token in the last layer.
551
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
552
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
553
+ """
554
+ if cu_seqlens is not None:
555
+ assert max_seqlen is not None
556
+ assert key_padding_mask is None
557
+ assert self.use_flash_attn
558
+ assert not self.dwconv
559
+ if key_padding_mask is not None:
560
+ assert cu_seqlens is None
561
+ assert max_seqlen is None
562
+ assert not self.use_flash_attn
563
+ if inference_params is not None:
564
+ assert key_padding_mask is None
565
+ assert cu_seqlens is None and max_seqlen is None
566
+ assert not self.dwconv
567
+
568
+ kwargs = (
569
+ {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
570
+ if self.use_flash_attn
571
+ else {"key_padding_mask": key_padding_mask, **kwargs}
572
+ )
573
+ seqlen_offset = (
574
+ 0
575
+ if inference_params is None
576
+ else (
577
+ inference_params.lengths_per_sample
578
+ if inference_params.lengths_per_sample is not None
579
+ else inference_params.seqlen_offset
580
+ )
581
+ )
582
+ rotary_max_seqlen = (
583
+ inference_params.max_sequence_len if inference_params is not None else max_seqlen
584
+ )
585
+ batch, seqlen = x.shape[:2]
586
+ if not self.cross_attn and self.num_heads_kv == self.num_heads:
587
+ assert x_kv is None and mixer_subset is None
588
+ if not self.return_residual:
589
+ qkv = self.Wqkv(x)
590
+ else:
591
+ qkv, x = self.Wqkv(x)
592
+ if self.dwconv:
593
+ qkv = rearrange(
594
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
595
+ ).contiguous()
596
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
597
+ if (
598
+ inference_params is None
599
+ or inference_params.seqlen_offset == 0
600
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
601
+ or not self.use_flash_attn
602
+ ):
603
+ if inference_params is None:
604
+ if not self.checkpointing:
605
+ context = self.inner_attn(qkv, **kwargs)
606
+ else:
607
+ context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
608
+ else:
609
+ context = self._update_kvcache_attention(
610
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
611
+ )
612
+ else:
613
+ context = self._apply_rotary_update_kvcache_attention(
614
+ qkv[:, :, 0], qkv[:, :, 1:], inference_params
615
+ )
616
+ else:
617
+ if self.cross_attn:
618
+ if not self.return_residual:
619
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
620
+ kv = self.Wkv(x_kv if x_kv is not None else x)
621
+ else:
622
+ if x_kv is not None:
623
+ kv, x_kv = self.Wkv(x_kv)
624
+ else:
625
+ kv, x = self.Wkv(x)
626
+ q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
627
+ else:
628
+ assert self.num_heads_kv != self.num_heads
629
+ if not self.return_residual:
630
+ qkv = self.Wqkv(x)
631
+ else:
632
+ qkv, x = self.Wqkv(x)
633
+ q = qkv[..., : self.num_heads * self.head_dim]
634
+ kv = qkv[..., self.num_heads * self.head_dim :]
635
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
636
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
637
+ if self.dwconv:
638
+ q = rearrange(
639
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
640
+ ).contiguous()
641
+ kv = rearrange(
642
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
643
+ ).contiguous()
644
+ if (
645
+ inference_params is None
646
+ or inference_params.seqlen_offset == 0
647
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
648
+ or not self.use_flash_attn
649
+ ):
650
+ if inference_params is None:
651
+ if not self.checkpointing:
652
+ context = self.inner_cross_attn(q, kv, **kwargs)
653
+ else:
654
+ context = torch.utils.checkpoint.checkpoint(
655
+ self.inner_cross_attn, q, kv, **kwargs
656
+ )
657
+ else:
658
+ context = self._update_kvcache_attention(q, kv, inference_params)
659
+ else:
660
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
661
+ out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
662
+ return out if not self.return_residual else (out, x)
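`MHA` above projects Q, K, and V with a single `Wqkv` linear layer, then dispatches either to the FlashAttention kernels or to the pure-PyTorch `SelfAttention`/`CrossAttention` fallbacks. A hedged CPU smoke test of the non-flash self-attention path, assuming `mha.py` is importable as a local module:

```python
import torch
from mha import MHA  # assumes mha.py is on the import path

attn = MHA(embed_dim=768, num_heads=12, use_flash_attn=False, causal=False)
x = torch.randn(2, 16, 768)

# Boolean key padding mask: True = keep, False = mask out.
mask = torch.ones(2, 16, dtype=torch.bool)
mask[1, 10:] = False  # second sequence only has 10 real tokens

out = attn(x, key_padding_mask=mask)
print(out.shape)  # torch.Size([2, 16, 768])
```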
mlp.py ADDED
@@ -0,0 +1,194 @@
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
2
+ # Commit id: c3b219665292c61a51153d0ded4473c494296382
3
+
4
+ # Copyright (c) 2023, Tri Dao.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.distributed import ProcessGroup
10
+
11
+
12
+ try:
13
+ from flash_attn.ops.activations import swiglu
14
+ except ImportError:
15
+ swiglu = None
16
+
17
+ try:
18
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
19
+ except ImportError:
20
+ ColumnParallelLinear, RowParallelLinear = None, None
21
+
22
+ try:
23
+ from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
24
+ except ImportError:
25
+ FusedMLP, ParallelFusedMLP = None, None
26
+
27
+
28
+ class Mlp(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_features,
32
+ hidden_features=None,
33
+ out_features=None,
34
+ activation=F.gelu,
35
+ bias1=True,
36
+ bias2=True,
37
+ return_residual=False,
38
+ device=None,
39
+ dtype=None,
40
+ ):
41
+ factory_kwargs = {"device": device, "dtype": dtype}
42
+ super().__init__()
43
+ out_features = out_features if out_features is not None else in_features
44
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
45
+ self.return_residual = return_residual
46
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
47
+ self.activation = activation
48
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
+
50
+ def forward(self, x):
51
+ y = self.fc1(x)
52
+ y = self.activation(y)
53
+ y = self.fc2(y)
54
+ return y if not self.return_residual else (y, x)
55
+
56
+
57
+ class ParallelMLP(nn.Module):
58
+ def __init__(
59
+ self,
60
+ in_features,
61
+ hidden_features=None,
62
+ out_features=None,
63
+ activation=F.gelu,
64
+ process_group: ProcessGroup = None,
65
+ sequence_parallel=True,
66
+ bias1=True,
67
+ bias2=True,
68
+ device=None,
69
+ dtype=None,
70
+ ):
71
+ factory_kwargs = {"device": device, "dtype": dtype}
72
+ super().__init__()
73
+ assert ColumnParallelLinear is not None, "Need to install fused_dense"
74
+ assert RowParallelLinear is not None, "Need to install fused_dense"
75
+ out_features = out_features if out_features is not None else in_features
76
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
77
+ self.fc1 = ColumnParallelLinear(
78
+ in_features,
79
+ hidden_features,
80
+ process_group,
81
+ bias=bias1,
82
+ sequence_parallel=sequence_parallel,
83
+ **factory_kwargs,
84
+ )
85
+ self.activation = activation
86
+ self.fc2 = RowParallelLinear(
87
+ hidden_features,
88
+ out_features,
89
+ process_group,
90
+ bias=bias2,
91
+ sequence_parallel=sequence_parallel,
92
+ **factory_kwargs,
93
+ )
94
+
95
+ def forward(self, x):
96
+ y = self.fc1(x)
97
+ y = self.activation(y)
98
+ y = self.fc2(y)
99
+ return y
100
+
101
+
102
+ class GatedMlp(nn.Module):
103
+ def __init__(
104
+ self,
105
+ in_features,
106
+ hidden_features=None,
107
+ out_features=None,
108
+ activation=F.sigmoid,
109
+ bias1=True,
110
+ bias2=True,
111
+ multiple_of=128,
112
+ return_residual=False,
113
+ device=None,
114
+ dtype=None,
115
+ ):
116
+ factory_kwargs = {"device": device, "dtype": dtype}
117
+ super().__init__()
118
+ out_features = out_features if out_features is not None else in_features
119
+ hidden_features = (
120
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
121
+ )
122
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
123
+ self.return_residual = return_residual
124
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
125
+ self.activation = activation
126
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
127
+
128
+ def forward(self, x):
129
+ y = self.fc1(x)
130
+ if self.activation == F.sigmoid: # Special case for GLU
131
+ y = F.glu(y, dim=-1)
132
+ elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU
133
+ y, gate = y.chunk(2, dim=-1)
134
+ y = swiglu(gate, y)
135
+ else:
136
+ y, gate = y.chunk(2, dim=-1)
137
+ y = y * self.activation(gate)
138
+ y = self.fc2(y)
139
+ return y if not self.return_residual else (y, x)
140
+
141
+
142
+ class ParallelGatedMlp(nn.Module):
143
+ """Parallel GatedMlp"""
144
+
145
+ def __init__(
146
+ self,
147
+ in_features,
148
+ process_group,
149
+ hidden_features=None,
150
+ out_features=None,
151
+ activation=F.sigmoid,
152
+ bias1=True,
153
+ bias2=True,
154
+ multiple_of=128,
155
+ sequence_parallel=True,
156
+ device=None,
157
+ dtype=None,
158
+ ):
159
+ factory_kwargs = {"device": device, "dtype": dtype}
160
+ super().__init__()
161
+ out_features = out_features if out_features is not None else in_features
162
+ hidden_features = (
163
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
164
+ )
165
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
166
+ if ColumnParallelLinear is None or RowParallelLinear is None:
167
+ raise ImportError("fused_dense is not installed")
168
+ self.fc1 = ColumnParallelLinear(
169
+ in_features,
170
+ 2 * hidden_features,
171
+ process_group,
172
+ bias=bias1,
173
+ sequence_parallel=sequence_parallel,
174
+ **factory_kwargs,
175
+ )
176
+ self.activation = activation
177
+ self.fc2 = RowParallelLinear(
178
+ hidden_features,
179
+ out_features,
180
+ process_group,
181
+ bias=bias2,
182
+ sequence_parallel=sequence_parallel,
183
+ **factory_kwargs,
184
+ )
185
+
186
+ def forward(self, x):
187
+ y = self.fc1(x)
188
+ if self.activation == F.sigmoid: # Special case for GLU
189
+ y = F.glu(y, dim=-1)
190
+ else:
191
+ y, gate = y.chunk(2, dim=-1)
192
+ y = y * self.activation(gate)
193
+ y = self.fc2(y)
194
+ return y
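`GatedMlp` above doubles the width of `fc1` and splits its output into a value half and a gate half; the activated gate multiplies the value before `fc2` projects back down (GLU with sigmoid, SwiGLU with SiLU). A hedged sketch of that computation, assuming `mlp.py` is importable and flash_attn's fused `swiglu` is unavailable so the plain-PyTorch branch runs:

```python
import torch
import torch.nn.functional as F
from mlp import GatedMlp  # assumes mlp.py is on the import path

gated = GatedMlp(in_features=768, activation=F.silu)  # SwiGLU-style gating
x = torch.randn(2, 16, 768)
y = gated(x)
print(y.shape)  # torch.Size([2, 16, 768])

# The same computation written out by hand (non-fused code path):
value, gate = gated.fc1(x).chunk(2, dim=-1)
manual = gated.fc2(value * F.silu(gate))
print(torch.allclose(manual, y, atol=1e-6))  # True
```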
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5c7c4b46d9d342fe128e46a55adbac52fd96dd0aaafee7c26e8d8f2286bee91a
+ size 556892306
modeling_xlm_roberta.py ADDED
@@ -0,0 +1,1119 @@
1
+ # This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+ # Copyright (c) 2022, Tri Dao.
4
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
5
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
6
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
7
+
8
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+
10
+ import importlib.util
11
+ import logging
12
+ import re
13
+ from collections import OrderedDict
14
+ from collections.abc import Sequence
15
+ from functools import partial
16
+ import numpy as np
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ import torch.utils.checkpoint
22
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
+ from einops import rearrange
24
+ from transformers import PretrainedConfig
25
+ from transformers.modeling_utils import PreTrainedModel
26
+ from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
27
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
28
+
29
+ from transformers.models.bert.modeling_bert import (
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ BertForPreTrainingOutput,
32
+ )
33
+
34
+ from typing import List, Optional, Tuple, Union
35
+
36
+ from .xlm_padding import (
37
+ index_first_axis,
38
+ index_first_axis_residual,
39
+ pad_input,
40
+ unpad_input,
41
+ )
42
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
43
+ from .block import Block
44
+ from .embedding import XLMRobertaEmbeddings
45
+ from .mha import MHA
46
+ from .mlp import FusedMLP, Mlp
47
+
48
+ try:
49
+ from flash_attn.ops.fused_dense import FusedDense
50
+ except ImportError:
51
+ FusedDense = None
52
+
53
+ try:
54
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
55
+ except ImportError:
56
+ layer_norm_fn = None
57
+
58
+
59
+ try:
60
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
61
+ except ImportError:
62
+ CrossEntropyLoss = torch.nn.CrossEntropyLoss
63
+
64
+ try:
65
+ from tqdm.autonotebook import trange
66
+ except ImportError:
67
+ trange = None
68
+
69
+
70
+ logger = logging.getLogger(__name__)
71
+
72
+
73
+ def get_use_flash_attn(config: XLMRobertaFlashConfig):
74
+ if not getattr(config, "use_flash_attn", False):
75
+ return False
76
+ if not torch.cuda.is_available():
77
+ return False
78
+ if importlib.util.find_spec("flash_attn") is None:
79
+ logger.warning(
80
+ 'flash_attn is not installed. Using PyTorch native attention implementation.'
81
+ )
82
+ return False
83
+ return True
84
+
85
+
86
+ def create_mixer_cls(config, cross_attn=False, return_residual=False):
87
+ use_flash_attn = get_use_flash_attn(config)
88
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
89
+
90
+ mixer_cls = partial(
91
+ MHA,
92
+ num_heads=config.num_attention_heads,
93
+ cross_attn=cross_attn,
94
+ dropout=config.attention_probs_dropout_prob,
95
+ causal=False,
96
+ fused_bias_fc=fused_bias_fc,
97
+ use_flash_attn=use_flash_attn,
98
+ return_residual=return_residual,
99
+ )
100
+ return mixer_cls
101
+
102
+
103
+ def create_mlp_cls(config, layer_idx=None, return_residual=False):
104
+ inner_dim = config.intermediate_size
105
+ fused_mlp = getattr(config, "fused_mlp", False)
106
+ if fused_mlp:
107
+ assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
108
+ "fused_mlp only " "supports approximate gelu"
109
+ )
110
+ if not fused_mlp:
111
+ approximate = (
112
+ "tanh"
113
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
114
+ else "none"
115
+ )
116
+ mlp_cls = partial(
117
+ Mlp,
118
+ hidden_features=inner_dim,
119
+ activation=partial(F.gelu, approximate=approximate),
120
+ return_residual=return_residual,
121
+ )
122
+ else:
123
+ if FusedMLP is None:
124
+ raise ImportError("fused_dense is not installed")
125
+ mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
126
+ # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
127
+ if isinstance(mlp_checkpoint_lvl, Sequence):
128
+ assert layer_idx is not None
129
+ mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
130
+ mlp_cls = partial(
131
+ FusedMLP,
132
+ hidden_features=inner_dim,
133
+ checkpoint_lvl=mlp_checkpoint_lvl,
134
+ return_residual=return_residual,
135
+ )
136
+ return mlp_cls
137
+
138
+
139
+ def create_block(config, layer_idx=None):
140
+ last_layer_subset = getattr(config, "last_layer_subset", False)
141
+ cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
142
+ # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
143
+ # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
144
+ # one layer) so we just choose not to return residual in this case.
145
+ return_residual = not cross_attn
146
+ mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
147
+ mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
148
+ norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
149
+ block = Block(
150
+ config.hidden_size,
151
+ mixer_cls,
152
+ mlp_cls,
153
+ norm_cls=norm_cls,
154
+ prenorm=False,
155
+ resid_dropout1=config.hidden_dropout_prob,
156
+ resid_dropout2=config.hidden_dropout_prob,
157
+ fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
158
+ return_residual=return_residual,
159
+ )
160
+ return block
161
+
162
+
163
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
164
+ def _init_weights(module, initializer_range=0.02):
165
+ if isinstance(module, nn.Linear):
166
+ nn.init.normal_(module.weight, std=initializer_range)
167
+ if module.bias is not None:
168
+ nn.init.zeros_(module.bias)
169
+ elif isinstance(module, nn.Embedding):
170
+ nn.init.normal_(module.weight, std=initializer_range)
171
+ if module.padding_idx is not None:
172
+ nn.init.zeros_(module.weight[module.padding_idx])
173
+
174
+
175
+ class XLMRobertaEncoder(nn.Module):
176
+ def __init__(self, config: XLMRobertaFlashConfig):
177
+ super().__init__()
178
+ self.use_flash_attn = get_use_flash_attn(config)
179
+ self.layers = nn.ModuleList(
180
+ [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
181
+ )
182
+ self._grad_checkpointing = False
183
+
184
+ @property
185
+ def gradient_checkpointing(self):
186
+ return self._grad_checkpointing
187
+
188
+ @gradient_checkpointing.setter
189
+ def gradient_checkpointing(self, value):
190
+ self._grad_checkpointing = value
191
+
192
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
193
+ """If subset_mask is not None, we only want output for the subset of the sequence.
194
+ This means that we only compute the last layer output for these tokens.
195
+ subset_mask: (batch, seqlen), dtype=torch.bool
196
+ """
197
+ if key_padding_mask is None or not self.use_flash_attn:
198
+ mixer_kwargs = (
199
+ {"key_padding_mask": key_padding_mask.bool()}
200
+ if key_padding_mask is not None
201
+ else None
202
+ )
203
+ for layer in self.layers:
204
+ if self._grad_checkpointing:
205
+ hidden_states = torch.utils.checkpoint.checkpoint(
206
+ layer,
207
+ hidden_states,
208
+ use_reentrant=False,
209
+ mixer_kwargs=mixer_kwargs,
210
+ )
211
+ else:
212
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
213
+ if subset_mask is not None:
214
+ hidden_states = hidden_states[subset_mask]
215
+ else:
216
+ batch, seqlen = hidden_states.shape[:2]
217
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
218
+ hidden_states, key_padding_mask
219
+ )
220
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
221
+ if subset_mask is None:
222
+ for layer in self.layers:
223
+ if self._grad_checkpointing:
224
+ hidden_states = torch.utils.checkpoint.checkpoint(
225
+ layer,
226
+ hidden_states,
227
+ use_reentrant=False,
228
+ mixer_kwargs=mixer_kwargs,
229
+ )
230
+ else:
231
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
232
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
233
+ else:
234
+ for layer in self.layers[:-1]:
235
+ if self._grad_checkpointing:
236
+ hidden_states = torch.utils.checkpoint.checkpoint(
237
+ layer,
238
+ hidden_states,
239
+ use_reentrant=False,
240
+ mixer_kwargs=mixer_kwargs,
241
+ )
242
+ else:
243
+ hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
244
+ if key_padding_mask is not None:
245
+ subset_idx = torch.nonzero(
246
+ subset_mask[key_padding_mask], as_tuple=False
247
+ ).flatten()
248
+ subset_seqlens = (subset_mask & key_padding_mask).sum(
249
+ dim=-1, dtype=torch.int32
250
+ )
251
+ subset_cu_seqlens = F.pad(
252
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
253
+ (1, 0),
254
+ )
255
+ else:
256
+ subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
257
+ subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
258
+ subset_cu_seqlens = F.pad(
259
+ torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
260
+ (1, 0),
261
+ )
262
+ hidden_states_subset, hidden_states = index_first_axis_residual(
263
+ hidden_states, subset_idx
264
+ )
265
+ # It's ok to set max_seqlen_q to be much larger
266
+ mixer_kwargs = {
267
+ "x_kv": hidden_states,
268
+ "cu_seqlens": subset_cu_seqlens,
269
+ "max_seqlen": max_seqlen_in_batch,
270
+ "cu_seqlens_k": cu_seqlens,
271
+ "max_seqlen_k": max_seqlen_in_batch,
272
+ }
273
+ if self._grad_checkpointing:
274
+ hidden_states = torch.utils.checkpoint.checkpoint(
275
+ self.layers[-1],
276
+ hidden_states_subset,
277
+ use_reentrant=False,
278
+ mixer_kwargs=mixer_kwargs,
279
+ )
280
+ else:
281
+ hidden_states = self.layers[-1](
282
+ hidden_states_subset, mixer_kwargs=mixer_kwargs
283
+ )
284
+ return hidden_states
285
+
286
+
287
+ class XLMRobertaPooler(nn.Module):
288
+ def __init__(self, config):
289
+ super().__init__()
290
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
291
+ if fused_bias_fc and FusedDense is None:
292
+ raise ImportError("fused_dense is not installed")
293
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
294
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
295
+ self.activation = nn.Tanh()
296
+
297
+ def forward(self, hidden_states, pool=True):
298
+ # We "pool" the model by simply taking the hidden state corresponding
299
+ # to the first token.
300
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
301
+ pooled_output = self.dense(first_token_tensor)
302
+ pooled_output = self.activation(pooled_output)
303
+ return pooled_output
304
+
305
+
306
+ class XLMRobertaPredictionHeadTransform(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
310
+ if fused_bias_fc and FusedDense is None:
311
+ raise ImportError("fused_dense is not installed")
312
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
313
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
314
+ raise ImportError("Triton is not installed")
315
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
316
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
317
+ approximate = (
318
+ "tanh"
319
+ if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
320
+ else "none"
321
+ )
322
+ self.transform_act_fn = nn.GELU(approximate=approximate)
323
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
324
+
325
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
326
+ hidden_states = self.dense(hidden_states)
327
+ hidden_states = self.transform_act_fn(hidden_states)
328
+ if not self.fused_dropout_add_ln:
329
+ hidden_states = self.layer_norm(hidden_states)
330
+ else:
331
+ hidden_states = layer_norm_fn(
332
+ hidden_states,
333
+ self.layer_norm.weight,
334
+ self.layer_norm.bias,
335
+ eps=self.layer_norm.eps,
336
+ )
337
+ return hidden_states
338
+
339
+
340
+ class XLMRobertaLMPredictionHead(nn.Module):
341
+ def __init__(self, config):
342
+ super().__init__()
343
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
344
+ if fused_bias_fc and FusedDense is None:
345
+ raise ImportError("fused_dense is not installed")
346
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
347
+
348
+ self.transform = XLMRobertaPredictionHeadTransform(config)
349
+
350
+ # The output weights are the same as the input embeddings, but there is
351
+ # an output-only bias for each token.
352
+ self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
353
+
354
+ def forward(self, hidden_states):
355
+ hidden_states = self.transform(hidden_states)
356
+ hidden_states = self.decoder(hidden_states)
357
+ return hidden_states
358
+
359
+
360
+ class XLMRobertaPreTrainingHeads(nn.Module):
361
+ def __init__(self, config):
362
+ super().__init__()
363
+ self.predictions = XLMRobertaLMPredictionHead(config)
364
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
365
+
366
+ def forward(self, sequence_output, pooled_output):
367
+ prediction_scores = self.predictions(sequence_output)
368
+ seq_relationship_score = self.seq_relationship(pooled_output)
369
+ return prediction_scores, seq_relationship_score
370
+
371
+
372
+ class XLMRobertaPreTrainedModel(PreTrainedModel):
373
+ """An abstract class to handle weights initialization and
374
+ a simple interface for downloading and loading pretrained models.
375
+ """
376
+
377
+ config_class = XLMRobertaFlashConfig
378
+ base_model_prefix = "roberta"
379
+ supports_gradient_checkpointing = True
380
+
381
+ def _set_gradient_checkpointing(self, module, value=False):
382
+ if isinstance(module, XLMRobertaEncoder):
383
+ module.gradient_checkpointing = value
384
+
385
+ @classmethod
386
+ def from_pretrained(
387
+ cls,
388
+ *args,
389
+ **kwargs,
390
+ ):
391
+ if not 'torch_dtype' in kwargs:
392
+ kwargs['torch_dtype'] = 'auto'
393
+ return super().from_pretrained(*args, **kwargs)
394
+
395
+
396
+
397
+ class XLMRobertaModel(XLMRobertaPreTrainedModel):
398
+ def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
399
+ super().__init__(config)
400
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
401
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
402
+ config.vocab_size += self.pad_vocab_size_multiple - (
403
+ config.vocab_size % self.pad_vocab_size_multiple
404
+ )
405
+ self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
406
+ if self.fused_dropout_add_ln and layer_norm_fn is None:
407
+ raise ImportError("Triton is not installed")
408
+ assert config.hidden_act in [
409
+ "gelu",
410
+ "gelu_new",
411
+ "gelu_fast",
412
+ "gelu_pytorch_tanh",
413
+ ]
414
+
415
+ self.embeddings = XLMRobertaEmbeddings(
416
+ config.hidden_size,
417
+ config.vocab_size,
418
+ config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
419
+ config.type_vocab_size,
420
+ padding_idx=config.pad_token_id,
421
+ )
422
+ self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
423
+ self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
424
+ self.encoder = XLMRobertaEncoder(config)
425
+ self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
426
+
427
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
428
+
429
+
430
+ @torch.inference_mode()
431
+ def encode(
432
+ self: 'XLMRobertaModel',
433
+ sentences: Union[str, List[str]],
434
+ batch_size: int = 32,
435
+ show_progress_bar: Optional[bool] = None,
436
+ output_value: str = 'sentence_embedding',
437
+ convert_to_numpy: bool = True,
438
+ convert_to_tensor: bool = False,
439
+ device: Optional[torch.device] = None,
440
+ normalize_embeddings: bool = False,
441
+ truncate_dim: Optional[int] = None,
442
+ **tokenizer_kwargs,
443
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
444
+ """
445
+ Computes sentence embeddings
446
+ Args:
447
+ sentences(`str` or `List[str]`):
448
+ Sentence or sentences to be encoded
449
+ batch_size(`int`, *optional*, defaults to 32):
450
+ Batch size for the computation
451
+ show_progress_bar(`bool`, *optional*, defaults to None):
452
+ Show a progress bar when encoding sentences.
453
+ If set to None, progress bar is only shown when
454
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
455
+ output_value(`str`, *optional*, defaults to 'sentence_embedding'):
456
+ Default sentence_embedding, to get sentence embeddings.
457
+ Can be set to token_embeddings to get wordpiece token embeddings.
458
+ Set to None, to get all output values
459
+ convert_to_numpy(`bool`, *optional*, defaults to True):
460
+ If true, the output is a list of numpy vectors.
461
+ Else, it is a list of pytorch tensors.
462
+ convert_to_tensor(`bool`, *optional*, defaults to False):
463
+ If true, you get one large tensor as return.
464
+ Overwrites any setting from convert_to_numpy
465
+ device(`torch.device`, *optional*, defaults to None):
466
+ Which torch.device to use for the computation
467
+ normalize_embeddings(`bool`, *optional*, defaults to False):
468
+ If set to true, returned vectors will have length 1. In that case, the
469
+ faster dot-product (util.dot_score) instead of cosine similarity can
470
+ be used.
471
+ truncate_dim(`int`, *optional*, defaults to None):
472
+ The dimension to truncate sentence embeddings to. `None` does no truncation.
473
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
474
+ Keyword arguments for the tokenizer
475
+ Returns:
476
+ By default, a list of tensors is returned.
477
+ If convert_to_tensor, a stacked tensor is returned.
478
+ If convert_to_numpy, a numpy matrix is returned.
479
+ """
480
+ from transformers import AutoTokenizer
481
+
482
+ self.tokenizer = AutoTokenizer.from_pretrained(
483
+ self.name_or_path, trust_remote_code=True
484
+ )
485
+
486
+ is_training = self.training
487
+ self.eval()
488
+
489
+ if show_progress_bar is None:
490
+ show_progress_bar = (
491
+ logger.getEffectiveLevel() == logging.INFO
492
+ or logger.getEffectiveLevel() == logging.DEBUG
493
+ )
494
+
495
+ if convert_to_tensor:
496
+ convert_to_numpy = False
497
+
498
+ if output_value != 'sentence_embedding':
499
+ convert_to_tensor = False
500
+ convert_to_numpy = False
501
+
502
+ input_was_string = False
503
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
504
+ sentences = [sentences]
505
+ input_was_string = True
506
+
507
+ if device is not None:
508
+ self.to(device)
509
+
510
+ permutation = np.argsort([-len(i) for i in sentences])
511
+ inverse_permutation = np.argsort(permutation)
512
+ sentences = [sentences[idx] for idx in permutation]
513
+
514
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
515
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
516
+ 'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
517
+ )
518
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
519
+
520
+ all_embeddings = []
521
+
522
+ if trange is not None:
523
+ range_iter = trange(
524
+ 0,
525
+ len(sentences),
526
+ batch_size,
527
+ desc="Encoding",
528
+ disable=not show_progress_bar,
529
+ )
530
+ else:
531
+ range_iter = range(0, len(sentences), batch_size)
532
+
533
+ for i in range_iter:
534
+ encoded_input = self.tokenizer(
535
+ sentences[i : i + batch_size],
536
+ return_tensors='pt',
537
+ **tokenizer_kwargs,
538
+ ).to(self.device)
539
+ token_embs = self.forward(**encoded_input)[0]
540
+
541
+ # Accumulate in fp32 to avoid overflow
542
+ token_embs = token_embs.float()
543
+
544
+ if output_value == 'token_embeddings':
545
+ raise NotImplementedError
546
+ elif output_value is None:
547
+ raise NotImplementedError
548
+ else:
549
+ if self.config.emb_pooler == 'cls':
550
+ embeddings = self.cls_pooling(
551
+ token_embs, encoded_input['attention_mask']
552
+ )
553
+ else:
554
+ embeddings = self.mean_pooling(
555
+ token_embs, encoded_input['attention_mask']
556
+ )
557
+
558
+ if normalize_embeddings:
559
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
560
+
561
+ if convert_to_numpy:
562
+ embeddings = embeddings.cpu()
563
+ all_embeddings.extend(embeddings)
564
+
565
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
566
+
567
+ truncate_dim = truncate_dim or self.config.truncate_dim
568
+ if truncate_dim:
569
+ all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
570
+
571
+ if convert_to_tensor:
572
+ all_embeddings = torch.stack(all_embeddings)
573
+ elif convert_to_numpy:
574
+ all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
575
+
576
+ if input_was_string:
577
+ all_embeddings = all_embeddings[0]
578
+
579
+ self.train(is_training)
580
+ return all_embeddings
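
A minimal usage sketch for the `encode` method defined above, assuming a checkpoint whose `auto_map` resolves to this `XLMRobertaModel` is loaded with `trust_remote_code=True`; the repository id below is a placeholder, not a real checkpoint name.

```python
# Minimal sketch, assuming a checkpoint that ships this custom XLMRobertaModel.
from transformers import AutoModel

repo = "your-org/your-xlm-roberta-flash-model"  # hypothetical repository id
model = AutoModel.from_pretrained(repo, trust_remote_code=True)

sentences = ["How is the weather today?", "What is the current weather like today?"]
# Returns a numpy array by default; pass convert_to_tensor=True for a stacked tensor.
embeddings = model.encode(sentences, batch_size=2, normalize_embeddings=True)
print(embeddings.shape)  # (2, hidden_size)
```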
581
+
582
+
583
+ def truncate_embeddings(self, embeddings, truncate_dim):
584
+ if not self.config.matryoshka_dimensions:
585
+ logger.warning(
586
+ 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
587
+ )
588
+ return embeddings
589
+ elif truncate_dim in self.config.matryoshka_dimensions:
590
+ return [tensor[:truncate_dim] for tensor in embeddings]
591
+ else:
592
+ raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
593
+ f'Supported dimensions are {self.config.matryoshka_dimensions}.')
594
+
595
+ def mean_pooling(
596
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
597
+ ):
598
+ input_mask_expanded = (
599
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
600
+ )
601
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
602
+ input_mask_expanded.sum(1), min=1e-9
603
+ )
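
For clarity, the masked mean above ignores padded positions; a standalone toy example (no model required):

```python
import torch

token_embeddings = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]])  # (batch=1, seq=3, dim=2)
attention_mask = torch.tensor([[1, 1, 0]])  # last position is padding

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
mean = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(mean)  # tensor([[2., 2.]]) -- the padded [9., 9.] vector is excluded from the average
```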
604
+
605
+
606
+ def cls_pooling(
607
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
608
+ ):
609
+ return token_embeddings[:,0]
610
+
611
+
612
+ def forward(
613
+ self,
614
+ input_ids,
615
+ position_ids=None,
616
+ token_type_ids=None,
617
+ attention_mask=None,
618
+ masked_tokens_mask=None,
619
+ return_dict=None,
620
+ **kwargs,
621
+ ):
622
+ """If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
623
+ we only want the output for the masked tokens. This means that we only compute the last
624
+ layer output for these tokens.
625
+ masked_tokens_mask: (batch, seqlen), dtype=torch.bool
626
+ """
627
+
628
+ if kwargs:
629
+ for key, value in kwargs.items():
630
+ if value is not None:
631
+ logger.warning(
632
+ 'Flash attention implementation does not support kwargs: %s',
633
+ key,
634
+ )
635
+
636
+ return_dict = (
637
+ return_dict if return_dict is not None else self.config.use_return_dict
638
+ )
639
+
640
+ hidden_states = self.embeddings(
641
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
642
+ )
643
+ # TD [2022-12:18]: Don't need to force residual in fp32
644
+ # BERT puts embedding LayerNorm before embedding dropout.
645
+ if not self.fused_dropout_add_ln:
646
+ hidden_states = self.emb_ln(hidden_states)
647
+ else:
648
+ hidden_states = layer_norm_fn(
649
+ hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
650
+ )
651
+ hidden_states = self.emb_drop(hidden_states)
652
+
653
+ if masked_tokens_mask is not None:
654
+ batch_size, seqlen = input_ids.shape[:2]
655
+ # We also need the first column for the CLS token
656
+ first_col_mask = torch.zeros(
657
+ batch_size, seqlen, dtype=torch.bool, device=input_ids.device
658
+ )
659
+ first_col_mask[:, 0] = True
660
+ subset_mask = masked_tokens_mask | first_col_mask
661
+ else:
662
+ subset_mask = None
663
+
664
+ sequence_output = self.encoder(
665
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
666
+ )
667
+
668
+ if masked_tokens_mask is None:
669
+ pooled_output = (
670
+ self.pooler(sequence_output) if self.pooler is not None else None
671
+ )
672
+ else:
673
+ # TD [2022-03-01]: the indexing here is very tricky.
674
+ if attention_mask is not None:
675
+ subset_idx = subset_mask[attention_mask]
676
+ pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
677
+ sequence_output = sequence_output[
678
+ masked_tokens_mask[attention_mask][subset_idx]
679
+ ]
680
+ else:
681
+ pool_input = sequence_output[first_col_mask[subset_mask]]
682
+ sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
683
+ pooled_output = (
684
+ self.pooler(pool_input, pool=False) if self.pooler is not None else None
685
+ )
686
+
687
+ if not return_dict:
688
+ return sequence_output, pooled_output
689
+
690
+ return BaseModelOutputWithPoolingAndCrossAttentions(
691
+ last_hidden_state=sequence_output,
692
+ pooler_output=pooled_output,
693
+ )
694
+
695
+
696
+ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
697
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
698
+
699
+ def __init__(self, config):
700
+ super().__init__(config)
701
+
702
+ if config.is_decoder:
703
+ logger.warning(
704
+ "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
705
+ "bi-directional self-attention."
706
+ )
707
+
708
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
709
+ self.lm_head = XLMRobertaLMHead(config)
710
+
711
+ # Initialize weights and apply final processing
712
+ self.post_init()
713
+
714
+ def get_input_embeddings(self):
715
+ return self.roberta.embeddings.word_embeddings
716
+
717
+ def get_output_embeddings(self):
718
+ return self.lm_head.decoder
719
+
720
+ def set_output_embeddings(self, new_embeddings):
721
+ self.lm_head.decoder = new_embeddings
722
+
723
+ def forward(
724
+ self,
725
+ input_ids: Optional[torch.LongTensor] = None,
726
+ attention_mask: Optional[torch.FloatTensor] = None,
727
+ token_type_ids: Optional[torch.LongTensor] = None,
728
+ position_ids: Optional[torch.LongTensor] = None,
729
+ head_mask: Optional[torch.FloatTensor] = None,
730
+ inputs_embeds: Optional[torch.FloatTensor] = None,
731
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
732
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
733
+ labels: Optional[torch.LongTensor] = None,
734
+ output_attentions: Optional[bool] = None,
735
+ output_hidden_states: Optional[bool] = None,
736
+ return_dict: Optional[bool] = None,
737
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
738
+ r"""
739
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
740
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
741
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
742
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
743
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
744
+ Used to hide legacy arguments that have been deprecated.
745
+ """
746
+ return_dict = (
747
+ return_dict if return_dict is not None else self.config.use_return_dict
748
+ )
749
+
750
+ outputs = self.roberta(
751
+ input_ids,
752
+ attention_mask=attention_mask,
753
+ token_type_ids=token_type_ids,
754
+ position_ids=position_ids,
755
+ head_mask=head_mask,
756
+ inputs_embeds=inputs_embeds,
757
+ encoder_hidden_states=encoder_hidden_states,
758
+ encoder_attention_mask=encoder_attention_mask,
759
+ output_attentions=output_attentions,
760
+ output_hidden_states=output_hidden_states,
761
+ return_dict=return_dict,
762
+ )
763
+ sequence_output = outputs[0]
764
+ prediction_scores = self.lm_head(sequence_output)
765
+
766
+ masked_lm_loss = None
767
+ if labels is not None:
768
+ # move labels to correct device to enable model parallelism
769
+ labels = labels.to(prediction_scores.device)
770
+ loss_fct = CrossEntropyLoss()
771
+ masked_lm_loss = loss_fct(
772
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
773
+ )
774
+
775
+ if not return_dict:
776
+ output = (prediction_scores,) + outputs[2:]
777
+ return (
778
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
779
+ )
780
+
781
+ return MaskedLMOutput(
782
+ loss=masked_lm_loss,
783
+ logits=prediction_scores,
784
+ hidden_states=outputs.hidden_states,
785
+ attentions=outputs.attentions,
786
+ )
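
A hedged usage sketch for the masked-LM class above; the repository id is a placeholder for a checkpoint whose `auto_map` points at `XLMRobertaForMaskedLM`.

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

repo = "your-org/your-xlm-roberta-flash-mlm"  # hypothetical repository id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_ids = logits[0, mask_positions].argmax(dim=-1)
print(tokenizer.decode(predicted_ids))
```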
787
+
788
+
789
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
790
+ class XLMRobertaClassificationHead(nn.Module):
791
+ """Head for sentence-level classification tasks."""
792
+
793
+ def __init__(self, config):
794
+ super().__init__()
795
+ fused_bias_fc = getattr(config, "fused_bias_fc", False)
796
+ if fused_bias_fc and FusedDense is None:
797
+ raise ImportError("fused_dense is not installed")
798
+ linear_cls = nn.Linear if not fused_bias_fc else FusedDense
799
+ self.dense = linear_cls(config.hidden_size, config.hidden_size)
800
+ classifier_dropout = (
801
+ config.classifier_dropout
802
+ if config.classifier_dropout is not None
803
+ else config.hidden_dropout_prob
804
+ )
805
+ self.dropout = nn.Dropout(classifier_dropout)
806
+ self.out_proj = linear_cls(config.hidden_size, config.num_labels)
807
+
808
+ def forward(self, features, **kwargs):
809
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
810
+ x = self.dropout(x)
811
+ x = self.dense(x)
812
+ x = torch.tanh(x)
813
+ x = self.dropout(x)
814
+ x = self.out_proj(x)
815
+ return x
816
+
817
+
818
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
819
+ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
820
+ def __init__(self, config):
821
+ super().__init__(config)
822
+ self.num_labels = config.num_labels
823
+ self.config = config
824
+
825
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
826
+ self.classifier = XLMRobertaClassificationHead(config)
827
+
828
+ # Initialize weights and apply final processing
829
+ self.post_init()
830
+
831
+ def forward(
832
+ self,
833
+ input_ids: Optional[torch.LongTensor] = None,
834
+ attention_mask: Optional[torch.FloatTensor] = None,
835
+ token_type_ids: Optional[torch.LongTensor] = None,
836
+ position_ids: Optional[torch.LongTensor] = None,
837
+ head_mask: Optional[torch.FloatTensor] = None,
838
+ inputs_embeds: Optional[torch.FloatTensor] = None,
839
+ labels: Optional[torch.LongTensor] = None,
840
+ output_attentions: Optional[bool] = None,
841
+ output_hidden_states: Optional[bool] = None,
842
+ return_dict: Optional[bool] = None,
843
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
844
+ r"""
845
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
846
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
847
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
848
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
849
+ """
850
+ return_dict = (
851
+ return_dict if return_dict is not None else self.config.use_return_dict
852
+ )
853
+
854
+ outputs = self.roberta(
855
+ input_ids,
856
+ attention_mask=attention_mask,
857
+ token_type_ids=token_type_ids,
858
+ position_ids=position_ids,
859
+ head_mask=head_mask,
860
+ inputs_embeds=inputs_embeds,
861
+ output_attentions=output_attentions,
862
+ output_hidden_states=output_hidden_states,
863
+ return_dict=return_dict,
864
+ )
865
+ sequence_output = outputs[0]
866
+ logits = self.classifier(sequence_output)
867
+
868
+ loss = None
869
+ if labels is not None:
870
+ # move labels to correct device to enable model parallelism
871
+ labels = labels.to(logits.device)
872
+ if self.config.problem_type is None:
873
+ if self.num_labels == 1:
874
+ self.config.problem_type = "regression"
875
+ elif self.num_labels > 1 and (
876
+ labels.dtype == torch.long or labels.dtype == torch.int
877
+ ):
878
+ self.config.problem_type = "single_label_classification"
879
+ else:
880
+ self.config.problem_type = "multi_label_classification"
881
+
882
+ if self.config.problem_type == "regression":
883
+ loss_fct = MSELoss()
884
+ if self.num_labels == 1:
885
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
886
+ else:
887
+ loss = loss_fct(logits, labels)
888
+ elif self.config.problem_type == "single_label_classification":
889
+ loss_fct = CrossEntropyLoss()
890
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
891
+ elif self.config.problem_type == "multi_label_classification":
892
+ loss_fct = BCEWithLogitsLoss()
893
+ loss = loss_fct(logits, labels)
894
+
895
+ if not return_dict:
896
+ output = (logits,) + outputs[2:]
897
+ return ((loss,) + output) if loss is not None else output
898
+
899
+ return SequenceClassifierOutput(
900
+ loss=loss,
901
+ logits=logits,
902
+ hidden_states=outputs.hidden_states,
903
+ attentions=outputs.attentions,
904
+ )
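
The loss above is selected from `config.problem_type`, which is inferred from `num_labels` and the label dtype on the first call with labels; a standalone sketch of that decision rule (the helper name is illustrative):

```python
import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    # Mirrors the branching above: one label -> regression; integer targets ->
    # single-label classification; anything else -> multi-label classification.
    if num_labels == 1:
        return "regression"
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"
    return "multi_label_classification"

print(infer_problem_type(1, torch.tensor([0.7])))              # regression (MSELoss)
print(infer_problem_type(3, torch.tensor([2])))                # single_label_classification (CrossEntropyLoss)
print(infer_problem_type(3, torch.tensor([[1.0, 0.0, 1.0]])))  # multi_label_classification (BCEWithLogitsLoss)
```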
905
+
906
+
907
+ @torch.inference_mode()
908
+ def compute_score(
909
+ self,
910
+ sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
911
+ batch_size: int = 32,
912
+ max_length: Optional[int] = None,
913
+ ) -> List[float]:
914
+
915
+ if not hasattr(self, "_tokenizer"):
916
+ from transformers import AutoTokenizer
917
+
918
+ self._tokenizer = AutoTokenizer.from_pretrained(
919
+ self.name_or_path, trust_remote_code=True
920
+ )
921
+
922
+ assert isinstance(sentence_pairs, list)
923
+ if isinstance(sentence_pairs[0], str):
924
+ sentence_pairs = [sentence_pairs]
925
+
926
+ all_scores = []
927
+ for start_index in range(
928
+ 0, len(sentence_pairs), batch_size
929
+ ):
930
+ sentences_batch = sentence_pairs[
931
+ start_index : start_index + batch_size
932
+ ]
933
+ inputs = self._tokenizer(
934
+ sentences_batch,
935
+ padding=True,
936
+ truncation=True,
937
+ return_tensors='pt',
938
+ max_length=max_length,
939
+ ).to(self.device)
940
+ scores = (
941
+ self.forward(**inputs, return_dict=True)
942
+ .logits.view(
943
+ -1,
944
+ )
945
+ .float()
946
+ )
947
+ scores = torch.sigmoid(scores)
948
+ all_scores.extend(scores.cpu().numpy().tolist())
949
+
950
+ if len(all_scores) == 1:
951
+ return all_scores[0]
952
+ return all_scores
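
A hedged sketch of scoring query/document pairs with `compute_score`; it assumes the checkpoint is loaded through `AutoModelForSequenceClassification` with `trust_remote_code=True`, as with the base checkpoint this code ships in.

```python
from transformers import AutoModelForSequenceClassification

repo = "jinaai/jina-reranker-v2-base-multilingual"  # base checkpoint carrying this custom code
model = AutoModelForSequenceClassification.from_pretrained(repo, trust_remote_code=True)
model.eval()

pairs = [
    ("How long do sea turtles live?", "Sea turtles can live for 50 to 100 years."),
    ("How long do sea turtles live?", "The Eiffel Tower is located in Paris."),
]
scores = model.compute_score(pairs, batch_size=2)
print(scores)  # one sigmoid-scaled relevance score per pair; higher means more relevant
```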
953
+
954
+ def predict(
955
+ self,
956
+ sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
957
+ batch_size: int = 32,
958
+ max_length: Optional[int] = None,
959
+ ) -> List[float]:
960
+ # used for beir evaluation
961
+ return self.compute_score(sentence_pairs, batch_size=batch_size, max_length=max_length)
962
+
963
+ def rerank(
964
+ self,
965
+ query: str,
966
+ documents: List[str],
967
+ batch_size: int = 32,
968
+ max_length: int = 1024,
969
+ max_query_length: int = 512,
970
+ overlap_tokens: int = 80,
971
+ top_n: Optional[int] = None,
972
+ **kwargs,
973
+ ):
974
+ assert max_length >= max_query_length * 2, (
975
+ f'max_length ({max_length}) must be greater than or equal to '
976
+ f'max_query_length ({max_query_length}) * 2'
977
+ )
978
+
979
+ if not hasattr(self, "_tokenizer"):
980
+ from transformers import AutoTokenizer
981
+
982
+ self._tokenizer = AutoTokenizer.from_pretrained(
983
+ self.name_or_path, trust_remote_code=True
984
+ )
985
+
986
+ # preproc of tokenization
987
+ sentence_pairs, sentence_pairs_pids = reranker_tokenize_preproc(
988
+ query,
989
+ documents,
990
+ tokenizer=self._tokenizer,
991
+ max_length=max_length,
992
+ max_query_length=max_query_length,
993
+ overlap_tokens=overlap_tokens,
994
+ )
995
+
996
+ tot_scores = []
997
+ with torch.no_grad():
998
+ for k in range(0, len(sentence_pairs), batch_size):
999
+ batch = self._tokenizer.pad(
1000
+ sentence_pairs[k : k + batch_size],
1001
+ padding=True,
1002
+ max_length=max_length,
1003
+ pad_to_multiple_of=None,
1004
+ return_tensors="pt",
1005
+ )
1006
+ batch_on_device = {k: v.to(self.device) for k, v in batch.items()}
1007
+ scores = (
1008
+ self.forward(**batch_on_device, return_dict=True)
1009
+ .logits.view(
1010
+ -1,
1011
+ )
1012
+ .float()
1013
+ )
1014
+ scores = torch.sigmoid(scores)
1015
+ tot_scores.extend(scores.cpu().numpy().tolist())
1016
+
1017
+ # ranking
1018
+ merge_scores = [0 for _ in range(len(documents))]
1019
+ for pid, score in zip(sentence_pairs_pids, tot_scores):
1020
+ merge_scores[pid] = max(merge_scores[pid], score)
1021
+
1022
+ merge_scores_argsort = np.argsort(merge_scores)[::-1]
1023
+ sorted_documents = []
1024
+ sorted_scores = []
1025
+ for mid in merge_scores_argsort:
1026
+ sorted_scores.append(merge_scores[mid])
1027
+ sorted_documents.append(documents[mid])
1028
+
1029
+ top_n = min(top_n or len(sorted_documents), len(sorted_documents))
1030
+
1031
+ return [
1032
+ {
1033
+ 'document': sorted_documents[i],
1034
+ 'relevance_score': sorted_scores[i],
1035
+ 'index': merge_scores_argsort[i],
1036
+ }
1037
+ for i in range(top_n)
1038
+ ]
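
`rerank` chunks long documents with overlap, scores every query/chunk pair, and keeps the best chunk score per document. A hedged usage sketch under the same loading assumptions as above:

```python
from transformers import AutoModelForSequenceClassification

repo = "jinaai/jina-reranker-v2-base-multilingual"  # base checkpoint carrying this custom code
model = AutoModelForSequenceClassification.from_pretrained(repo, trust_remote_code=True)
model.eval()

results = model.rerank(
    query="What is the capital of France?",
    documents=[
        "Paris is the capital and largest city of France.",
        "Berlin is the capital of Germany.",
    ],
    top_n=2,
)
for hit in results:
    print(hit["index"], round(hit["relevance_score"], 3), hit["document"])
```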
1039
+
1040
+
1041
+ def reranker_tokenize_preproc(
1042
+ query: str,
1043
+ passages: List[str],
1044
+ tokenizer=None,
1045
+ max_length: int = 1024,
1046
+ max_query_length: int = 512,
1047
+ overlap_tokens: int = 80,
1048
+ ):
1049
+ from copy import deepcopy
1050
+
1051
+ assert tokenizer is not None, "Please provide a valid tokenizer for tokenization!"
1052
+ sep_id = tokenizer.sep_token_id
1053
+
1054
+ def _merge_inputs(chunk1_raw, chunk2):
1055
+ chunk1 = deepcopy(chunk1_raw)
1056
+ chunk1['input_ids'].append(sep_id)
1057
+ chunk1['input_ids'].extend(chunk2['input_ids'])
1058
+ chunk1['input_ids'].append(sep_id)
1059
+ chunk1['attention_mask'].append(chunk2['attention_mask'][0])
1060
+ chunk1['attention_mask'].extend(chunk2['attention_mask'])
1061
+ chunk1['attention_mask'].append(chunk2['attention_mask'][-1])
1062
+ if 'token_type_ids' in chunk1:
1063
+ token_type_ids = [1 for _ in range(len(chunk2['token_type_ids']) + 2)]
1064
+ chunk1['token_type_ids'].extend(token_type_ids)
1065
+ return chunk1
1066
+
1067
+ # Note: long queries are truncated to `max_query_length` tokens (512 by default here)
1068
+ query_inputs = tokenizer.encode_plus(
1069
+ query, truncation=True, padding=False, max_length=max_query_length
1070
+ )
1071
+
1072
+ max_passage_inputs_length = max_length - len(query_inputs['input_ids']) - 2
1073
+ # assert (
1074
+ # max_passage_inputs_length > 100
1075
+ # ), "Your query is too long! Please make sure your query less than 500 tokens!"
1076
+
1077
+ overlap_tokens_implt = min(overlap_tokens, max_passage_inputs_length // 4)
1078
+
1079
+ res_merge_inputs = []
1080
+ res_merge_inputs_pids = []
1081
+ for pid, passage in enumerate(passages):
1082
+ passage_inputs = tokenizer.encode_plus(
1083
+ passage,
1084
+ truncation=False,
1085
+ padding=False,
1086
+ add_special_tokens=False,
1087
+ max_length=0,
1088
+ )
1089
+ passage_inputs_length = len(passage_inputs['input_ids'])
1090
+
1091
+ if passage_inputs_length <= max_passage_inputs_length:
1092
+ qp_merge_inputs = _merge_inputs(query_inputs, passage_inputs)
1093
+ res_merge_inputs.append(qp_merge_inputs)
1094
+ res_merge_inputs_pids.append(pid)
1095
+ else:
1096
+ start_id = 0
1097
+ while start_id < passage_inputs_length:
1098
+ end_id = start_id + max_passage_inputs_length
1099
+ # make sure the length of the last chunk is `max_passage_inputs_length`
1100
+ if end_id >= passage_inputs_length:
1101
+ sub_passage_inputs = {
1102
+ k: v[-max_passage_inputs_length:]
1103
+ for k, v in passage_inputs.items()
1104
+ }
1105
+ else:
1106
+ sub_passage_inputs = {
1107
+ k: v[start_id:end_id] for k, v in passage_inputs.items()
1108
+ }
1109
+ start_id = (
1110
+ end_id - overlap_tokens_implt
1111
+ if end_id < passage_inputs_length
1112
+ else end_id
1113
+ )
1114
+
1115
+ qp_merge_inputs = _merge_inputs(query_inputs, sub_passage_inputs)
1116
+ res_merge_inputs.append(qp_merge_inputs)
1117
+ res_merge_inputs_pids.append(pid)
1118
+
1119
+ return res_merge_inputs, res_merge_inputs_pids
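
The chunking loop in `reranker_tokenize_preproc` advances by `max_passage_inputs_length - overlap_tokens` each step and pins the final chunk to the tail of the passage; a standalone sketch of the same window arithmetic (no tokenizer required, names are illustrative):

```python
def window_bounds(passage_len: int, chunk_len: int, overlap: int):
    # Mirrors the loop above: fixed-size overlapping windows, with the last
    # window taken from the end of the passage so it is always full length.
    bounds, start = [], 0
    while start < passage_len:
        end = start + chunk_len
        if end >= passage_len:
            bounds.append((max(passage_len - chunk_len, 0), passage_len))
            break
        bounds.append((start, end))
        start = end - overlap
    return bounds

print(window_bounds(passage_len=1000, chunk_len=400, overlap=80))
# [(0, 400), (320, 720), (600, 1000)]
```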
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e802fe5337779428818439760a1e6161ed36ceed72d4ebcbda9c139a2108fc99
3
+ size 17082988
tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "250001": {
36
+ "content": "<mask>",
37
+ "lstrip": true,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "<s>",
47
+ "eos_token": "</s>",
48
+ "extra_special_tokens": {},
49
+ "mask_token": "<mask>",
50
+ "model_max_length": 1024,
51
+ "pad_token": "<pad>",
52
+ "sep_token": "</s>",
53
+ "tokenizer_class": "XLMRobertaTokenizerFast",
54
+ "unk_token": "<unk>"
55
+ }
xlm_padding.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b
3
+
4
+ # Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+
10
+
11
+ class IndexFirstAxis(torch.autograd.Function):
12
+ @staticmethod
13
+ def forward(ctx, input, indices):
14
+ ctx.save_for_backward(indices)
15
+ assert input.ndim >= 2
16
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
17
+ second_dim = other_shape.numel()
18
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
+ # return input[indices]
20
+ return torch.gather(
21
+ rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
22
+ ).reshape(-1, *other_shape)
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad_output):
26
+ (indices,) = ctx.saved_tensors
27
+ assert grad_output.ndim >= 2
28
+ other_shape = grad_output.shape[1:]
29
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
30
+ grad_input = torch.zeros(
31
+ [ctx.first_axis_dim, grad_output.shape[1]],
32
+ device=grad_output.device,
33
+ dtype=grad_output.dtype,
34
+ )
35
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
36
+ # grad_input[indices] = grad_output
37
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
38
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
39
+
40
+
41
+ index_first_axis = IndexFirstAxis.apply
42
+
43
+
44
+ class IndexPutFirstAxis(torch.autograd.Function):
45
+ @staticmethod
46
+ def forward(ctx, values, indices, first_axis_dim):
47
+ ctx.save_for_backward(indices)
48
+ assert indices.ndim == 1
49
+ assert values.ndim >= 2
50
+ output = torch.zeros(
51
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
52
+ )
53
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
54
+ output[indices] = values
55
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
56
+ return output
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output):
60
+ (indices,) = ctx.saved_tensors
61
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
62
+ grad_values = grad_output[indices]
63
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
64
+ return grad_values, None, None
65
+
66
+
67
+ index_put_first_axis = IndexPutFirstAxis.apply
68
+
69
+
70
+ class IndexFirstAxisResidual(torch.autograd.Function):
71
+ @staticmethod
72
+ def forward(ctx, input, indices):
73
+ ctx.save_for_backward(indices)
74
+ assert input.ndim >= 2
75
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
76
+ second_dim = other_shape.numel()
77
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
78
+ output = input[indices]
79
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
80
+ # memory format to channel_first. In other words, input might not be contiguous.
81
+ # If we don't detach, PyTorch complains that output is a view being modified in place
82
+ return output, input.detach()
83
+
84
+ @staticmethod
85
+ def backward(ctx, grad_output, grad_residual):
86
+ (indices,) = ctx.saved_tensors
87
+ assert grad_output.ndim >= 2
88
+ other_shape = grad_output.shape[1:]
89
+ assert grad_residual.shape[1:] == other_shape
90
+ grad_input = grad_residual
91
+ # grad_input[indices] += grad_output
92
+ indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
93
+ indices = indices.expand_as(grad_output)
94
+ grad_input.scatter_add_(0, indices, grad_output)
95
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
96
+
97
+
98
+ index_first_axis_residual = IndexFirstAxisResidual.apply
99
+
100
+
101
+ def unpad_input(hidden_states, attention_mask):
102
+ """
103
+ Arguments:
104
+ hidden_states: (batch, seqlen, ...)
105
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
106
+ Return:
107
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
108
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
109
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
110
+ max_seqlen_in_batch: int
111
+ """
112
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
113
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
115
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
116
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
117
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
118
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
119
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
120
+ # so we write custom forward and backward to make it a bit faster.
121
+ return (
122
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
123
+ indices,
124
+ cu_seqlens,
125
+ max_seqlen_in_batch,
126
+ )
127
+
128
+
129
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
130
+ """
131
+ Supports concatenating short samples into one sequence. The attention_mask_in_length is used to mask out the other short samples, which enables efficient training on variable-length samples (e.g., supervised fine-tuning of large language models).
132
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
133
+
134
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
135
+ ```
136
+ [
137
+ [2, 3, 0, 0, 0, 0],
138
+ [3, 2, 0, 0, 0, 0],
139
+ [6, 0, 0, 0, 0, 0]
140
+ ]
141
+ ```
142
+ , which refers to the 3D-attention mask:
143
+ ```
144
+ [
145
+ [
146
+ [1, 0, 0, 0, 0, 0],
147
+ [1, 1, 0, 0, 0, 0],
148
+ [0, 0, 1, 0, 0, 0],
149
+ [0, 0, 1, 1, 0, 0],
150
+ [0, 0, 1, 1, 1, 0],
151
+ [0, 0, 0, 0, 0, 1]
152
+ ],
153
+ [
154
+ [1, 0, 0, 0, 0, 0],
155
+ [1, 1, 0, 0, 0, 0],
156
+ [1, 1, 1, 0, 0, 0],
157
+ [0, 0, 0, 1, 0, 0],
158
+ [0, 0, 0, 1, 1, 0],
159
+ [0, 0, 0, 0, 0, 1]
160
+ ],
161
+ [
162
+ [1, 0, 0, 0, 0, 0],
163
+ [1, 1, 0, 0, 0, 0],
164
+ [1, 1, 1, 0, 0, 0],
165
+ [1, 1, 1, 1, 0, 0],
166
+ [1, 1, 1, 1, 1, 0],
167
+ [1, 1, 1, 1, 1, 1]
168
+ ]
169
+ ]
170
+ ```.
171
+
172
+ Arguments:
173
+ hidden_states: (batch, seqlen, ...)
174
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
175
+ Return:
176
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
177
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
178
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
179
+ max_seqlen_in_batch: int
180
+ """
181
+ length = attention_mask_in_length.sum(dim=-1)
182
+ seqlen = attention_mask_in_length.size(-1)
183
+ attention_mask_2d = torch.arange(
184
+ seqlen, device=length.device, dtype=length.dtype
185
+ ).expand(len(length), seqlen) < length.unsqueeze(1)
186
+ real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
187
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
188
+ indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
189
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
190
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
191
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
192
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
193
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
194
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
195
+ # so we write custom forward and backward to make it a bit faster.
196
+ return (
197
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
198
+ indices,
199
+ cu_seqlens,
200
+ max_seqlen_in_batch,
201
+ )
202
+
203
+
204
+ def pad_input(hidden_states, indices, batch, seqlen):
205
+ """
206
+ Arguments:
207
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
208
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
209
+ batch: int, batch size for the padded sequence.
210
+ seqlen: int, maximum sequence length for the padded sequence.
211
+ Return:
212
+ hidden_states: (batch, seqlen, ...)
213
+ """
214
+ dim = hidden_states.shape[-1]
215
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
216
+ # output[indices] = hidden_states
217
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
218
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)