以下說明 XlaBuilder
介面中定義的運算語意。通常,這些作業會一對一對應至 xla_data.proto
中 RPC 介面中定義的作業。
關於名稱的說明:XLA 處理的一般資料類型是 N 維陣列,可保留某些統一型別 (例如 32 位元浮點) 的元素。在整份說明文件中,陣列一詞用於表示任意維度的陣列。為方便起見,特殊情況會使用更具體且熟悉的名稱,例如向量是 1 維陣列,而矩陣是 2 維陣列。
AfterAll
另請參閱 XlaBuilder::AfterAll
。
AfterAll 會接受可變數數量的符記,並產生單一符記。符記是原始類型,可在副作用運算之間建立執行緒,以便強制排序。AfterAll
可用於彙整符記,以便在設定作業後排序作業。
AfterAll(operands)
引數 | 類型 | 語意 |
---|---|---|
operands |
XlaOp |
符記的變數數量 |
AllGather
另請參閱 XlaBuilder::AllGather
。
在備援機制之間執行連接。
AllGather(operand, all_gather_dim, shard_count, replica_group_ids,
channel_id)
引數 | 類型 | 語意 |
---|---|---|
operand
|
XlaOp
|
用於在各個複本之間連接的陣列 |
all_gather_dim |
int64 |
連接維度 |
replica_groups
|
int64 的向量向量 |
要執行連接的群組 |
channel_id
|
選填 int64 |
跨模組通訊的選用管道 ID |
replica_groups
是執行連結的複本群組清單 (您可以使用ReplicaId
擷取目前複本的複本 ID)。每個群組中的複本順序,會決定複本輸入內容在結果中的順序。replica_groups
必須為空白 (在這種情況下,所有副本都屬於單一群組,並以0
到N - 1
的順序排列),或是包含與副本數量相同的元素。例如,replica_groups = {0, 2}, {1, 3}
會在備援機制0
和2
,以及1
和3
之間執行連結作業。shard_count
是每個複本群組的大小。在replica_groups
為空白的情況下,我們需要這個值。channel_id
用於跨模組通訊:只有具有相同channel_id
的all-gather
作業才能相互通訊。
輸出形狀是輸入形狀,其中 all_gather_dim
放大了 shard_count
倍。舉例來說,如果有兩個副本,且運算元在兩個副本中分別具有 [1.0, 2.5]
和 [3.0, 5.25]
的值,則此運算子的輸出值 (all_gather_dim
為 0
) 會在兩個副本中皆為 [1.0, 2.5, 3.0,
5.25]
。
AllReduce
另請參閱 XlaBuilder::AllReduce
。
跨備用資源執行自訂運算。
AllReduce(operand, computation, replica_group_ids, channel_id)
引數 | 類型 | 語意 |
---|---|---|
operand
|
XlaOp
|
陣列或非空陣列元組,用於在備用資源之間執行縮減作業 |
computation |
XlaComputation |
減法運算 |
replica_groups
|
int64 的向量向量 |
要執行減法運算的群組 |
channel_id
|
選填 int64 |
跨模組通訊的選用管道 ID |
- 如果
operand
是陣列的元組,則會對元組的每個元素執行 all-reduce。 replica_groups
是執行縮減作業的備份群組清單 (您可以使用ReplicaId
擷取目前備份的 ID)。replica_groups
必須為空白 (此時所有備份都屬於單一群組),或包含與備份數量相同的元素。例如,replica_groups = {0, 2}, {1, 3}
會在複本0
和2
,以及1
和3
之間執行縮減作業。channel_id
用於跨模組通訊:只有具有相同channel_id
的all-reduce
作業才能相互通訊。
輸出形狀與輸入形狀相同。舉例來說,如果有兩個副本,且運算元在兩個副本中分別具有 [1.0, 2.5]
和 [3.0, 5.25]
的值,則這項運算和加總運算的輸出值會在兩個副本中皆為 [4.0, 7.75]
。如果輸入內容是元組,輸出內容也會是元組。
計算 AllReduce
的結果時,需要從每個複本取得一個輸入內容,因此如果某個複本執行 AllReduce
節點的次數多於其他複本,則前者會一直等待。由於副本都執行相同的程式,因此發生這種情況的機率不高,但如果 while 迴圈的條件取決於 infeed 的資料,且 infeed 的資料會導致 while 迴圈在某個副本上重複執行的次數多於另一個副本,就有可能發生這種情況。
AllToAll
另請參閱 XlaBuilder::AllToAll
。
AllToAll 是集體運算,可將資料從所有核心傳送至所有核心。這項作業有兩個階段:
- 散布階段。在每個核心上,運算元式會沿著
split_dimensions
分割為split_count
個區塊,並將區塊分散到所有核心,例如第 i 個區塊會傳送至第 i 個核心。 - 收集階段:每個核心會沿著
concat_dimension
串連收到的區塊。
您可以透過下列方式設定參與的核心:
replica_groups
:每個 ReplicaGroup 都包含參與運算的複本 ID 清單 (您可以使用ReplicaId
擷取目前複本的複本 ID)。AllToAll 會依指定順序在子群組中套用。舉例來說,replica_groups = { {1,2,3}, {4,5,0} }
表示 AllToAll 會套用至複本{1, 2, 3}
和收集階段,且收到的區塊會以 1、2、3 的順序連接。接著,系統會在副本 4、5、0 中套用另一個 AllToAll,且連結順序也是 4、5、0。如果replica_groups
為空白,則所有複本都屬於同一個群組,並按照其出現的連結順序排列。
需求條件:
split_dimension
上運算元的維度大小可被split_count
整除。- 運算元的形狀不是元組。
AllToAll(operand, split_dimension, concat_dimension, split_count,
replica_groups)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
n 維輸入陣列 |
split_dimension
|
int64
|
範圍 [0,
n) 中的值,用於命名運算元會沿著哪個維度進行切割 |
concat_dimension
|
int64
|
[0,
n) 區間中的值,用於命名分割區塊連接的維度 |
split_count
|
int64
|
參與此作業的核心數量。如果 replica_groups 為空白,則應為複本數量;否則,應等於每個群組中的複本數量。 |
replica_groups
|
ReplicaGroup 向量 |
每個群組都包含副本 ID 清單。 |
以下是 Alltoall 的範例。
XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
在本例中,有 4 個核心參與 Alltoall。在每個核心上,運算子會沿著第 1 個維度分成 4 個部分,因此每個部分的形狀為 f32[4,4]。這 4 個部分會分散到所有核心。然後,每個核心會依照核心 0 至 4 的順序,沿著維度 0 連結收到的部分。因此,每個核心的輸出內容都具有 f32[16,4] 的形狀。
BatchNormGrad
如需演算法的詳細說明,請參閱 XlaBuilder::BatchNormGrad
和原始批次規格化論文。
計算批次標準化梯度。
BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon,
feature_index)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
要正規化的 n 維陣列 (x) |
scale |
XlaOp |
1 維陣列 (\(\gamma\)) |
mean |
XlaOp |
1 維陣列 (\(\mu\)) |
variance |
XlaOp |
1 維陣列 (\(\sigma^2\)) |
grad_output |
XlaOp |
傳遞至 BatchNormTraining (\(\nabla y\)) 的漸層 |
epsilon |
float |
隱私損失值 (\(\epsilon\)) |
feature_index |
int64 |
operand 中的特徵維度索引 |
對於特徵維度中的每個特徵 (feature_index
是 operand
中特徵維度的索引),此運算會針對所有其他維度中的 operand
、offset
和 scale
計算梯度。feature_index
必須是 operand
中特徵維度的有效索引。
這三個漸層是由下列公式定義 (假設 4 維度的陣列為 operand
,且具有特徵維度索引 l
、批次大小 m
和空間大小 w
和 h
):
\[ \begin{split} c_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right) \\\\ d_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \\\\ \nabla x_{ijkl} &= \frac{\gamma_{l} }{\sqrt{\sigma^2_{l}+\epsilon} } \left( \nabla y_{ijkl} - d_l - c_l (x_{ijkl} - \mu_{l}) \right) \\\\ \nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon} } \right) \\\\\ \nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \end{split} \]
輸入 mean
和 variance
代表批次和空間維度的時刻值。
輸出類型是三個句柄的元組:
輸出內容 | 類型 | 語意 |
---|---|---|
grad_operand
|
XlaOp
|
對輸入 operand 的梯度 ($\nabla x$) |
grad_scale
|
XlaOp
|
相對於輸入 scale 的梯度 ($\nabla \gamma$) |
grad_offset
|
XlaOp
|
對輸入 offset 的梯度($\nabla \beta$) |
BatchNormInference
如需演算法的詳細說明,請參閱 XlaBuilder::BatchNormInference
和原始批次規格化論文。
在批次和空間維度中正規化陣列。
BatchNormInference(operand, scale, offset, mean, variance, epsilon,
feature_index)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
要正規化的 n 維陣列 |
scale |
XlaOp |
1 維陣列 |
offset |
XlaOp |
1 維陣列 |
mean |
XlaOp |
1 維陣列 |
variance |
XlaOp |
1 維陣列 |
epsilon |
float |
隱私損失值 |
feature_index |
int64 |
operand 中的特徵維度索引 |
對於特徵維度中的每個特徵 (feature_index
是 operand
中特徵維度的索引),此運算會計算所有其他維度的平均值和變異數,並使用平均值和變異數將 operand
中的每個元素正規化。feature_index
必須是 operand
中地圖項目維度的有效索引。
BatchNormInference
等同於呼叫 BatchNormTraining
,但不會為每個批次計算 mean
和 variance
。而是使用輸入的 mean
和 variance
做為預估值。這個運算子的目的是減少推論的延遲時間,因此名稱為 BatchNormInference
。
輸出內容是 n 維的標準化陣列,其形狀與輸入 operand
相同。
BatchNormTraining
如需演算法的詳細說明,請參閱 XlaBuilder::BatchNormTraining
和 the original batch normalization paper
。
在批次和空間維度中正規化陣列。
BatchNormTraining(operand, scale, offset, epsilon, feature_index)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
要正規化的 n 維陣列 (x) |
scale |
XlaOp |
1 維陣列 (\(\gamma\)) |
offset |
XlaOp |
1 維陣列 (\(\beta\)) |
epsilon |
float |
隱私損失值 (\(\epsilon\)) |
feature_index |
int64 |
operand 中的特徵維度索引 |
對於特徵維度中的每個特徵 (feature_index
是 operand
中特徵維度的索引),此運算會計算所有其他維度的平均值和變異數,並使用平均值和變異數將 operand
中的每個元素正規化。feature_index
必須是 operand
中地圖項目維度的有效索引。
針對 operand
\(x\) 中包含 m
元素的每個批次,演算法會依照下列方式運作,其中 w
和 h
是空間維度的大小 (假設 operand
是 4 維陣列):
計算特徵維度中每個特徵
l
的批次平均值 \(\mu_l\) : \(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\)計算批次變異數 \(\sigma^2_l\): $\sigma^2l=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (x_{ijkl} - \mu_l)^2$
將資料正規化、縮放及移位: \(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon} }+\beta_l\)
為了避免除以零的錯誤,系統會加入 ε 值,通常為小數。
輸出類型是三個 XlaOp
的元組:
輸出內容 | 類型 | 語意 |
---|---|---|
output
|
XlaOp
|
與輸入 operand 形狀相同的 n 維陣列 (y) |
batch_mean |
XlaOp |
1 維陣列 (\(\mu\)) |
batch_var |
XlaOp |
1 維陣列 (\(\sigma^2\)) |
batch_mean
和 batch_var
是使用上述公式,在批次和空間維度中計算的片刻。
BitcastConvertType
另請參閱 XlaBuilder::BitcastConvertType
。
與 TensorFlow 中的 tf.bitcast
類似,會從資料形狀執行元素的位元組轉換運算,轉換至目標形狀。輸入和輸出大小必須相符:例如,s32
元素會透過位元組轉換例行程序成為 f32
元素,而一個 s32
元素會成為四個 s8
元素。位元轉換會以低階轉換方式實作,因此具有不同浮點表示法的機器會產生不同的結果。
BitcastConvertType(operand, new_element_type)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
具有 D 維度的 T 型陣列 |
new_element_type |
PrimitiveType |
類型 U |
除了最後一個維度會根據轉換前後的基礎元素大小比率而變更外,運算元和目標形狀的維度必須相符。
來源和目的地元素類型不得為元組。
將位元組轉換為不同寬度的原始類型
BitcastConvert
HLO 指令支援輸出元素類型 T'
的大小不等於輸入元素 T
的大小。由於整個作業在概念上是位元組轉換,且不會變更基礎位元組,因此輸出元素的形狀必須變更。對於 B = sizeof(T), B' =
sizeof(T')
,有兩種可能的情況。
首先,當 B > B'
時,輸出形狀會取得新的最小維度大小 B/B'
。例如:
f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)
有效量值的規則維持不變:
f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)
或者,對於 B' > B
,指令要求輸入形狀的最後一個邏輯維度必須等於 B'/B
,而這個維度會在轉換期間捨棄:
f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)
請注意,不同位元寬之間的轉換並非元素式轉換。
廣播
另請參閱 XlaBuilder::Broadcast
。
複製陣列中的資料,為陣列新增維度。
Broadcast(operand, broadcast_sizes)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
要複製的陣列 |
broadcast_sizes |
ArraySlice<int64> |
新維度的大小 |
新的維度會插入左側,也就是說,如果 broadcast_sizes
有值 {a0, ..., aN}
,且運算子形狀有維度 {b0, ..., bM}
,則輸出形狀的維度為 {a0, ..., aN, b0, ..., bM}
。
新的維度索引會進入運算元的副本,也就是
output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
舉例來說,如果 operand
是值為 2.0f
的標量 f32
,而 broadcast_sizes
是 {2, 3}
,則結果會是形狀為 f32[2, 3]
的陣列,且結果中的所有值都會是 2.0f
。
BroadcastInDim
另請參閱 XlaBuilder::BroadcastInDim
。
透過複製陣列中的資料,擴大陣列的大小和維度數量。
BroadcastInDim(operand, out_dim_size, broadcast_dimensions)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
要複製的陣列 |
out_dim_size |
ArraySlice<int64> |
目標形狀的尺寸大小 |
broadcast_dimensions |
ArraySlice<int64> |
運算元件形狀的每個維度對應至目標形狀的哪個維度 |
與 Broadcast 類似,但可在任何位置新增維度,並擴充大小為 1 的現有維度。
operand
會廣播至 out_dim_size
所描述的形狀。broadcast_dimensions
會將 operand
的維度對應至目標形狀的維度,也就是將運算元的第 i 維度對應至輸出形狀的 broadcast_dimension[i] 維度。operand
的維度必須為 1,或與所對應輸出形狀中的維度相同。其餘維度則會填入大小為 1 的維度。然後,退化維度廣播會沿著這些退化維度廣播,以達到輸出形狀。如需語意詳細說明,請參閱廣播頁面。
撥打電話
另請參閱 XlaBuilder::Call
。
使用指定的引數叫用運算。
Call(computation, args...)
引數 | 類型 | 語意 |
---|---|---|
computation |
XlaComputation |
具有任意類型 N 個參數的 T_0, T_1, ..., T_{N-1} -> S 類型運算 |
args |
N 個 XlaOp 的序列 |
任意類型的 N 個引數 |
args
的多項式和類型必須與 computation
的參數相符。允許不含 args
。
CompositeCall
另請參閱 XlaBuilder::CompositeCall
。
封裝由其他 StableHLO 作業組成的作業,並接收輸入內容和 composite_attributes 並產生結果。運算子的語意是由分解屬性實作。複合運算可替換為其分解運算,而不會變更程式語意。如果內嵌分解作業無法提供相同的 op 語意,建議使用 custom_call。
版本欄位 (預設為 0) 用於表示組合的語意變更時間。
這個運算是做為具有 is_composite=true
屬性的 kCall
實作。decomposition
欄位是由 computation
屬性指定。前端屬性會儲存前置字串為 composite.
的其餘屬性。
CompositeCall 作業範例:
f32[] call(f32[] %cst), to_apply=%computation, is_composite=true,
frontend_attributes = {
composite.name="foo.bar",
composite.attributes={n = 1 : i32, tensor = dense<1> : tensor<i32>},
composite.version="1"
}
Call(computation, args..., name, composite_attributes, version)
引數 | 類型 | 語意 |
---|---|---|
inputs |
XlaOp |
值的變化數 |
name |
string |
複合物的名稱 |
composite_attributes |
選填 string |
屬性選填字串化字典 |
decomposition |
XlaComputation |
具有任意類型 N 個參數的 T_0, T_1, ..., T_{N-1} -> S 類型運算 |
version |
int64 。 |
數字到版本,更新複合運算的語意 |
Cholesky
另請參閱 XlaBuilder::Cholesky
。
計算一批對稱 (Hermitian) 正定矩陣的 Cholesky 分解。
Cholesky(a, lower)
引數 | 類型 | 語意 |
---|---|---|
a |
XlaOp |
複數或浮點型別的陣列,維度大於 2。 |
lower |
bool |
是否使用 a 的上三角或下三角。 |
如果 lower
是 true
,則會計算下三角矩陣 l
,使 $a = l。l^T$。如果 lower
是 false
,則會計算上三角矩陣 u
,以便\(a = u^T . u\)。
輸入資料只會從 a
的下/上三角讀取,具體取決於 lower
的值。系統會忽略其他三角形的值。輸出資料會在同一三角形中傳回;其他三角形中的值則由實作定義,可以是任何值。
如果 a
的維度超過 2,a
會視為一批矩陣,其中除了次要 2 維度以外,所有都是批次維度。
如果 a
不是對稱 (Hermitian) 正定矩陣,則結果會由實作定義。
限制取值範圍
另請參閱 XlaBuilder::Clamp
。
將運算元組限制在最小值和最大值之間的範圍內。
Clamp(min, operand, max)
引數 | 類型 | 語意 |
---|---|---|
min |
XlaOp |
類型為 T 的陣列 |
operand |
XlaOp |
類型為 T 的陣列 |
max |
XlaOp |
類型為 T 的陣列 |
在給定運算元和最小值與最大值的情況下,如果運算元介於最小值和最大值之間,則會傳回運算元;如果運算元低於這個範圍,則會傳回最小值;如果運算元高於這個範圍,則會傳回最大值。即 clamp(a, x, b) = min(max(a, x), b)
。
這三個陣列的形狀必須相同。或者,min
和/或 max
可以是 T
類型的標量,做為廣播的限制形式。
使用純量 min
和 max
的範例:
let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};
收合
另請參閱 XlaBuilder::Collapse
和 tf.reshape
作業。
將陣列的維度縮減為一個維度。
Collapse(operand, dimensions)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
類型為 T 的陣列 |
dimensions |
int64 向量 |
以順序排列的 T 維度連續子集。 |
Collapse 會將運算元的維度指定子集取代為單一維度。輸入引數是任意型別 T 的陣列,以及維度索引的編譯時間常數向量。維度索引必須是 T 維度的有序 (由低到高) 連續子集。因此,{0, 1, 2}、{0, 1} 或 {1, 2} 都是有效的維度組合,但 {1, 0} 或 {0, 2} 則無效。這些維度會被單一新維度取代,且在維度序列中的位置與所取代的維度相同,新維度的大小則等於原始維度大小的乘積。dimensions
中最低的維度編號,是迴圈巢狀結構中變化最慢的維度 (最主要),該巢狀結構會折疊這些維度,而最高的維度編號則是變化最快的維度 (最次要)。如需更多一般摺疊順序,請參閱 tf.reshape
運算子。
例如,讓 v 為 24 個元素的陣列:
let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
{ {20, 21, 22}, {25, 26, 27} },
{ {30, 31, 32}, {35, 36, 37} },
{ {40, 41, 42}, {45, 46, 47} } };
// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37,
40, 41, 42, 45, 46, 47};
// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] { {10, 11, 12, 15, 16, 17},
{20, 21, 22, 25, 26, 27},
{30, 31, 32, 35, 36, 37},
{40, 41, 42, 45, 46, 47} };
// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] { {10, 11, 12},
{15, 16, 17},
{20, 21, 22},
{25, 26, 27},
{30, 31, 32},
{35, 36, 37},
{40, 41, 42},
{45, 46, 47} };
CollectivePermute
另請參閱 XlaBuilder::CollectivePermute
。
CollectivePermute 是集體運算,可跨副本傳送及接收資料。
CollectivePermute(operand, source_target_pairs)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
n 維輸入陣列 |
source_target_pairs |
<int64, int64> 向量 |
包含 (source_replica_id, target_replica_id) 組合的清單。對於每個組合,運算元式會從來源備份傳送至目標備份。 |
請注意,source_target_pair
有下列限制:
- 任何兩組都不得使用相同的目標備援 ID,也不得使用相同的來源備援 ID。
- 如果備份 ID 不是任何一組的目標,則該備份的輸出內容會是包含 0 的張量,其形狀與輸入內容相同。
串連
另請參閱 XlaBuilder::ConcatInDim
。
連接會從多個陣列運算元組合陣列。陣列的維度數量與每個輸入陣列運算子的維度數量相同 (每個運算子的維度數量必須相同),且會按照指定的順序包含引數。
Concatenate(operands..., dimension)
引數 | 類型 | 語意 |
---|---|---|
operands |
N 個 XlaOp 的序列 |
型別為 T 的 N 個陣列,維度為 [L0, L1, ...]。N 必須大於或等於 1。 |
dimension |
int64 |
[0, N) 區間中的值,用於命名要在 operands 之間連接的維度。 |
除了 dimension
之外,所有維度都必須相同。這是因為 XLA 不支援「不連續」陣列。另請注意,0 維度值無法連接 (因為無法命名連接發生的維度)。
1 維範例:
Concat({ {2, 3}, {4, 5}, {6, 7} }, 0)
>>> {2, 3, 4, 5, 6, 7}
2D 範例:
let a = {
{1, 2},
{3, 4},
{5, 6},
};
let b = {
{7, 8},
};
Concat({a, b}, 0)
>>> {
{1, 2},
{3, 4},
{5, 6},
{7, 8},
}
圖表:
條件式
另請參閱 XlaBuilder::Conditional
。
Conditional(pred, true_operand, true_computation, false_operand,
false_computation)
引數 | 類型 | 語意 |
---|---|---|
pred |
XlaOp |
PRED 類型的純量 |
true_operand |
XlaOp |
類型為 \(T_0\)的引數 |
true_computation |
XlaComputation |
類型為 \(T_0 \to S\)的 XlaComputation |
false_operand |
XlaOp |
類型為 \(T_1\)的引數 |
false_computation |
XlaComputation |
類型為 \(T_1 \to S\)的 XlaComputation |
如果 pred
是 true
,就會執行 true_computation
,如果 pred
是 false
,則會執行 false_computation
,並傳回結果。
true_computation
必須採用單一 \(T_0\) 類型引數,並且會以 true_operand
呼叫,且 true_operand
必須屬於相同類型。false_computation
必須採用單一 \(T_1\) 類型引數,並且會以 false_operand
進行叫用,且 false_operand
必須屬於相同類型。true_computation
和 false_computation
的傳回值類型必須相同。
請注意,系統會根據 pred
的值執行 true_computation
和 false_computation
其中一個。
Conditional(branch_index, branch_computations, branch_operands)
引數 | 類型 | 語意 |
---|---|---|
branch_index |
XlaOp |
S32 類型的純量 |
branch_computations |
N 個 XlaComputation 的序列 |
類型為 \(T_0 \to S , T_1 \to S , ..., T_{N-1} \to S\)的 XlaComputation |
branch_operands |
N 個 XlaOp 的序列 |
\(T_0 , T_1 , ..., T_{N-1}\)類型的引數 |
執行 branch_computations[branch_index]
,並傳回結果。如果 branch_index
是小於 0 或大於等於 N 的 S32
,則會以預設分支版本執行 branch_computations[N-1]
。
每個 branch_computations[b]
都必須使用單一 \(T_b\) 類型引數,並且會以 branch_operands[b]
叫用,且必須為相同類型。每個 branch_computations[b]
的傳回值類型必須相同。
請注意,系統只會根據 branch_index
的值執行其中一個 branch_computations
。
Conv (卷積)
另請參閱 XlaBuilder::Conv
。
與 ConvWithGeneralPadding 相同,但邊框會以簡寫方式指定為 SAME 或 VALID。SAME 填充會以零填充輸入內容 (lhs
),讓輸出內容的形狀與不考量步幅時的輸入內容相同。有效邊框 padding 表示沒有邊框。
ConvWithGeneralPadding (卷積)
另請參閱 XlaBuilder::ConvWithGeneralPadding
。
計算類神經網路中所用的卷積類型。在此,卷積可視為在 n 維基本區域中移動的 n 維視窗,並針對視窗的每個可能位置執行運算。
引數 | 類型 | 語意 |
---|---|---|
lhs |
XlaOp |
(n+2) 維度的輸入陣列 |
rhs |
XlaOp |
核權重 (n+2) 維度陣列 |
window_strides |
ArraySlice<int64> |
核步長的 n-d 陣列 |
padding |
ArraySlice< pair<int64,int64>> |
包含 (低、高) 邊框間距的 n-d 陣列 |
lhs_dilation |
ArraySlice<int64> |
n-d 左手邊擴張係數陣列 |
rhs_dilation |
ArraySlice<int64> |
n-d 右手邊擴張係數陣列 |
feature_group_count |
int64 | 特徵群組數量 |
batch_group_count |
int64 | 批次群組數量 |
讓 n 代表空間維度數量。lhs
引數是描述基底區域的 (n+2) 維陣列。這稱為輸入,雖然 rhs 當然也是輸入。在類神經網路中,這些是輸入啟用。n+2 維度的順序如下:
batch
:這個維度的每個座標都代表要執行卷積的獨立輸入。z/depth/features
:基本區域中的每個 (y,x) 位置都有一個相關聯的向量,會進入這個維度。spatial_dims
:說明n
空間維度,定義視窗移動的基礎區域。
rhs
引數是一個 (n+2) 維陣列,用於描述卷積濾鏡/核/視窗。維度依序如下:
output-z
:輸出的z
維度。input-z
:這個維度的大小乘以feature_group_count
應等於左側z
維度的大小。spatial_dims
:說明n
空間維度,定義在基本區域中移動的 n 維視窗。
window_strides
引數會在空間維度中指定卷積視窗的步距。舉例來說,如果第一個空間維度的步幅為 3,則視窗只能放在第一個空間索引可被 3 整除的座標。
padding
引數會指定要套用至基底區域的零邊框間距數量。填充量可以為負值,負值填充的絕對值表示在執行卷積之前,從指定維度移除的元素數量。padding[0]
指定維度 y
的邊框,padding[1]
則指定維度 x
的邊框。每個組合的第一個元素為低邊框間距,第二個元素為高邊框間距。低邊框間距會套用於較低索引的方向,而高邊框間距會套用於較高索引的方向。舉例來說,如果 padding[1]
是 (2,3)
,則第二個空間維度會在左側加上 2 個零,在右側加上 3 個零。使用填充功能,就等同於在進行卷積之前,將相同的零值插入輸入內容 (lhs
)。
lhs_dilation
和 rhs_dilation
引數會指定在每個空間維度中,分別套用至左側和右側的擴張因子。如果空間維度的擴張因子為 d,則會在該維度的每個項目之間隱含放置 d-1 個洞,進而增加陣列的大小。空白區塊會填入無操作值,對卷積而言,這代表零值。
右側的擴張也稱為 atrous 卷積。詳情請參閱 tf.nn.atrous_conv2d
。左側的擴張也稱為轉置的卷積。詳情請參閱 tf.nn.conv2d_transpose
。
feature_group_count
引數 (預設值為 1) 可用於分組卷積。feature_group_count
必須是輸入和輸出特徵維度的除數。如果 feature_group_count
大於 1,表示在概念上,輸入和輸出特徵維度以及 rhs
輸出特徵維度會平均分割成許多 feature_group_count
群組,每個群組都包含特徵的連續子序列。rhs
的輸入特徵維度必須等於 lhs
輸入特徵維度除以 feature_group_count
(也就是說,它已具有一組輸入特徵的大小)。第 i 組會一起用於計算許多個獨立卷積的 feature_group_count
。這些卷積的結果會在輸出特徵維度中連接在一起。
針對深度卷積,feature_group_count
引數會設為輸入特徵維度,而篩選器會從 [filter_height, filter_width, in_channels, channel_multiplier]
重塑為 [filter_height, filter_width, 1, in_channels * channel_multiplier]
。詳情請參閱 tf.nn.depthwise_conv2d
。
在反向傳播期間,您可以使用 batch_group_count
(預設值 1) 引數為分組篩選器。batch_group_count
必須是 lhs
(輸入) 批次維度的大小除數。如果 batch_group_count
大於 1,表示輸出批次維度的大小應為 input batch
/ batch_group_count
。batch_group_count
必須是輸出功能大小的除數。
輸出形狀的維度如下:
batch
:這個維度的大小乘以batch_group_count
應等於左側batch
維度的大小。z
:與核心 (rhs
) 上的output-z
相同大小。spatial_dims
:每個卷積視窗的有效位置都有一個值。
上圖說明 batch_group_count
欄位的運作方式。實際上,我們會將每個左手邊批次切割成 batch_group_count
群組,並對輸出功能執行相同的操作。接著,針對每個群組,我們會執行成對卷積,並沿著輸出特徵維度連接輸出內容。所有其他維度的作業語意 (地圖和空間) 都保持不變。
卷積視窗的有效位置取決於步幅和填充後的底部區域大小。
為了說明卷積的運作方式,請考慮 2D 卷積,並在輸出內容中選取一些固定的 batch
、z
、y
、x
座標。接著,(y,x)
是視窗在底部區域內的角落位置 (例如左上角,視您解讀空間維度的做法而定)。我們現在有一個從基礎區域取得的 2D 視窗,其中每個 2D 點都與 1D 向量相關聯,因此我們會取得 3D 方塊。從卷積核來看,由於我們已修正輸出座標 z
,因此也有 3D 方塊。兩個方塊的尺寸相同,因此我們可以計算兩個方塊之間元素逐一相乘的總和 (類似於點積積)。這就是輸出值。
請注意,如果 output-z
為 5,則視窗的每個位置都會在輸出內容中產生 5 個值,並輸出至輸出內容的 z
維度。這些值的差異在於所使用的卷積核部分,每個 output-z
座標都會使用不同的 3D 值方塊。因此,您可以將其視為 5 個獨立的卷積,每個卷積都有不同的濾鏡。
以下是 2D 卷積的虛擬程式碼,其中包含填充和步進:
for (b, oz, oy, ox) { // output coordinates
value = 0;
for (iz, ky, kx) { // kernel coordinates and input z
iy = oy*stride_y + ky - pad_low_y;
ix = ox*stride_x + kx - pad_low_x;
if ((iy, ix) inside the base area considered without padding) {
value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
}
}
output(b, oz, oy, ox) = value;
}
ConvertElementType
另請參閱 XlaBuilder::ConvertElementType
。
類似於 C++ 中的元素式 static_cast
,可執行從資料形狀到目標形狀的元素式轉換作業。維度必須相符,且轉換作業必須以元素為單位進行,例如 s32
元素透過 s32
至 f32
的轉換例程式變成 f32
元素。
ConvertElementType(operand, new_element_type)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
具有 D 維度的 T 型陣列 |
new_element_type |
PrimitiveType |
類型 U |
運算元和目標形狀的大小必須一致。來源和目的地元素類型不得為元組。
T=s32
到 U=f32
這類轉換會執行標準化整數至浮點轉換例程,例如四捨五入。
let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}
CrossReplicaSum
使用加總運算執行 AllReduce
。
CustomCall
另請參閱 XlaBuilder::CustomCall
。
在運算中呼叫使用者提供的函式。
CustomCall(target_name, args..., shape)
引數 | 類型 | 語意 |
---|---|---|
target_name |
string |
函式名稱。系統會發出以此符號名稱為目標的呼叫指令。 |
args |
N 個 XlaOp 的序列 |
N 個任意類型的引數,會傳遞至函式。 |
shape |
Shape |
函式的輸出形狀 |
無論 args 的類型或類型為何,函式簽名都相同:
extern "C" void target_name(void* out, void** in);
舉例來說,如果 CustomCall 的用法如下:
let x = f32[2] {1,2};
let y = f32[2x3] { {10, 20, 30}, {40, 50, 60} };
CustomCall("myfunc", {x, y}, f32[3x3])
以下是 myfunc
實作範例:
extern "C" void myfunc(void* out, void** in) {
float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
EXPECT_EQ(1, x[0]);
EXPECT_EQ(2, x[1]);
EXPECT_EQ(10, y[0][0]);
EXPECT_EQ(20, y[0][1]);
EXPECT_EQ(30, y[0][2]);
EXPECT_EQ(40, y[1][0]);
EXPECT_EQ(50, y[1][1]);
EXPECT_EQ(60, y[1][2]);
float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
z[0][0] = x[1] + y[1][0];
// ...
}
使用者提供的函式不得有副作用,且執行方式必須是冪等的。
Dot
另請參閱 XlaBuilder::Dot
。
Dot(lhs, rhs)
引數 | 類型 | 語意 |
---|---|---|
lhs |
XlaOp |
類型為 T 的陣列 |
rhs |
XlaOp |
類型為 T 的陣列 |
這項運算的確切語意取決於運算元的階層:
輸入 | 輸出 | 語意 |
---|---|---|
向量 [n] dot 向量 [n] |
純量 | 向量內積 |
矩陣 [m x k] dot 向量 [k] |
向量 [m] | 矩陣-向量相乘 |
矩陣 [m x k] dot 矩陣 [k x n] |
矩陣 [m x n] | 矩陣-矩陣相乘 |
此運算會在 lhs
的第二個維度 (如果有 1 個維度,則為第一個維度) 和 rhs
的第一個維度上,執行乘積的總和。這些是「收縮」維度。lhs
和 rhs
的縮減維度必須相同。在實際應用中,可用於在向量之間執行內積、向量/矩陣相乘或矩陣/矩陣相乘。
DotGeneral
另請參閱 XlaBuilder::DotGeneral
。
DotGeneral(lhs, rhs, dimension_numbers)
引數 | 類型 | 語意 |
---|---|---|
lhs |
XlaOp |
類型為 T 的陣列 |
rhs |
XlaOp |
類型為 T 的陣列 |
dimension_numbers |
DotDimensionNumbers |
收縮和批次維度數 |
與 Dot 類似,但可同時為 lhs
和 rhs
指定收縮和批次維度編號。
DotDimensionNumbers 欄位 | 類型 | 語意 |
---|---|---|
lhs_contracting_dimensions
|
repeated int64 | lhs 收縮維度數字 |
rhs_contracting_dimensions
|
repeated int64 | rhs 收縮維度數字 |
lhs_batch_dimensions
|
repeated int64 | lhs 批次維度編號 |
rhs_batch_dimensions
|
repeated int64 | rhs 批次維度編號 |
DotGeneral 會在 dimension_numbers
中指定的收縮維度上執行乘積和運算。
lhs
和 rhs
的相關收縮維度編號不必相同,但必須具有相同的維度大小。
以下是使用收縮維度數字的範例:
lhs = { {1.0, 2.0, 3.0},
{4.0, 5.0, 6.0} }
rhs = { {1.0, 1.0, 1.0},
{2.0, 2.0, 2.0} }
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);
DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
{15.0, 30.0} }
lhs
和 rhs
的關聯批量維度編號必須具有相同的維度大小。
以下是包含批次維度數字的範例 (批次大小 2,2x2 矩陣):
lhs = { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
rhs = { { {1.0, 0.0},
{0.0, 1.0} },
{ {1.0, 0.0},
{0.0, 1.0} } }
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);
DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
輸入 | 輸出 | 語意 |
---|---|---|
[b0, m, k] dot [b0, k, n] |
[b0, m, n] | 批次矩陣乘法 |
[b0, b1, m, k] dot [b0, b1, k, n] |
[b0, b1, m, n] | 批次矩陣乘法 |
因此,產生的維度編號會以批次維度開頭,接著是 lhs
非收縮/非批次維度,最後是 rhs
非收縮/非批次維度。
DynamicSlice
另請參閱 XlaBuilder::DynamicSlice
。
DynamicSlice 會從動態 start_indices
的輸入陣列中擷取子陣列。size_indices
會傳遞每個維度中切片的大小,並指定每個維度中獨立切片間隔的端點:[start, start + size)。start_indices
的形狀必須為 1 維,且維度大小必須等於 operand
的維度數量。
DynamicSlice(operand, start_indices, size_indices)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
類型為 T 的 N 維陣列 |
start_indices |
N 個 XlaOp 的序列 |
包含每個維度切片起始索引的 N 個單點整數清單。值必須大於或等於 0。 |
size_indices |
ArraySlice<int64> |
包含每個維度切片大小的 N 個整數清單。每個值都必須大於零,且開始 + 大小必須小於或等於維度的大小,以免產生模塊的維度大小。 |
有效切片索引的計算方式為,在執行切片之前,針對 [1, N)
中的每個索引 i
套用下列轉換:
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
這麼做可確保擷取的切片一律在運算子陣列的邊界內。如果切片在套用轉換前已在邊界內,轉換就不會生效。
1 維範例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}
DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}
2D 範例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
let s = {2, 1}
DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0, 8.0},
{10.0, 11.0} }
DynamicUpdateSlice
另請參閱 XlaBuilder::DynamicUpdateSlice
。
DynamicUpdateSlice 會產生結果,也就是輸入陣列 operand
的值,並在 start_indices
上覆寫切片 update
。update
的形狀會決定要更新的結果子陣列形狀。start_indices
的形狀必須為 1 維,且維度大小必須等於 operand
的維度數量。
DynamicUpdateSlice(operand, update, start_indices)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
類型為 T 的 N 維陣列 |
update |
XlaOp |
包含切片更新的 T 型 N 維陣列。更新形狀的每個維度都必須大於零,且開始 + 更新必須小於或等於每個維度的運算元大小,以免產生超出範圍的更新索引。 |
start_indices |
N 個 XlaOp 的序列 |
包含每個維度切片起始索引的 N 個單點整數清單。值必須大於或等於 0。 |
有效切片索引的計算方式為,在執行切片之前,針對 [1, N)
中的每個索引 i
套用下列轉換:
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
這可確保更新後的切片一律會在運算子陣列的邊界內。如果切片在套用轉換前已在邊界內,轉換就不會生效。
1 維範例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}
DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}
2D 範例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
let u =
{ {12.0, 13.0},
{14.0, 15.0},
{16.0, 17.0} }
let s = {1, 1}
DynamicUpdateSlice(b, u, s) produces:
{ {0.0, 1.0, 2.0},
{3.0, 12.0, 13.0},
{6.0, 14.0, 15.0},
{9.0, 16.0, 17.0} }
元素級別二進位算術運算
另請參閱 XlaBuilder::Add
。
支援一組元素級別的二元算術運算。
Op(lhs, rhs)
其中 Op
為 Add
(加法)、Sub
(減法)、Mul
(乘法)、Div
(除法)、Pow
(指數)、Rem
(餘數)、Max
(最大值)、Min
(最小值)、And
(邏輯 AND)、Or
(邏輯 OR)、Xor
(邏輯 XOR)、ShiftLeft
(左移)、ShiftRightArithmetic
(算術右移)、ShiftRightLogical
(邏輯右移)、Atan2
(2 個引數的反正切) 或 Complex
(將實數和虛數部分組合成複數)
引數 | 類型 | 語意 |
---|---|---|
lhs |
XlaOp |
左側運算元:型別為 T 的陣列 |
rhs |
XlaOp |
右側運算元:類型為 T 的陣列 |
引數的形狀必須相似或相容。請參閱廣播說明文件,瞭解形狀相容性的意義。作業的結果具有形狀,這是廣播兩個輸入陣列的結果。在這個變化版本中,系統「不」支援不同階層陣列之間的運算,除非其中一個運算元是純量。
當 Op
為 Rem
時,結果的符號會取自除數,且結果的絕對值一律小於除數的絕對值。
整數除法溢位 (以零為依據的帶符號/不帶符號除法/餘數,或 INT_SMIN
與 -1
的帶符號除法/餘數) 會產生實作定義的值。
以下作業有支援不同維度廣播的替代變化版本:
Op(lhs, rhs, broadcast_dimensions)
其中 Op
與上述相同。這項運算的變化版本應用於不同階層陣列之間的算術運算 (例如將矩陣加到向量)。
額外的 broadcast_dimensions
運算元是整數切片,用於將低維運算元的維度數量擴充至高維運算元的維度數量。broadcast_dimensions
會將較低維度形狀的維度對應至較高維度形狀的維度。展開形狀中未對應的維度會填入大小為 1 的維度。然後沿著這些退化維度廣播形狀,以便讓兩個運算元的形狀相等。廣播頁面會詳細說明語意。
元素級別比較運算
另請參閱 XlaBuilder::Eq
。
支援一組標準元素級別二元比較運算。請注意,比較浮點類型時,會套用標準的 IEEE 754 浮點比較語意。
Op(lhs, rhs)
其中 Op
是 Eq
(等於)、Ne
(不等於)、Ge
(大於或等於)、Gt
(大於)、Le
(小於或等於)、Lt
(小於) 之一。另一組運算子 (EqTotalOrder、NeTotalOrder、GeTotalOrder、GtTotalOrder、LeTotalOrder 和 LtTotalOrder) 提供相同功能,但它們還支援浮點數的總排序,方法是強制執行 -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN。
引數 | 類型 | 語意 |
---|---|---|
lhs |
XlaOp |
左側運算元:型別為 T 的陣列 |
rhs |
XlaOp |
右側運算元:類型為 T 的陣列 |
引數的形狀必須相似或相容。請參閱廣播說明文件,瞭解形狀相容性的意義。作業的結果具有形狀,這是以元素類型 PRED
廣播兩個輸入陣列的結果。在這個變化版本中,系統不支援不同階層陣列之間的運算,除非其中一個運算元是純量。
以下作業有支援不同維度廣播的替代變化版本:
Op(lhs, rhs, broadcast_dimensions)
其中 Op
與上述相同。這個運算變化版本應用於不同階層陣列之間的比較運算 (例如將矩陣加到向量)。
額外的 broadcast_dimensions
運算元是整數切片,可指定用於廣播運算元的維度。廣播頁面會詳細說明語意。
元素級別的單一函式
XlaBuilder 支援以下元素逐元素單一函式:
Abs(operand)
元素為 abs x -> |x|
。
Cbrt(operand)
元素級別立方根運算 x -> cbrt(x)
。
Ceil(operand)
元素逐元素的圓頂 x -> ⌈x⌉
。
Clz(operand)
逐元素計算前置零。
Cos(operand)
元素逐元素餘弦 x -> cos(x)
。
Erf(operand)
元素級別誤差函數 x -> erf(x)
,其中
\(\text{erf}(x) = \frac{2}{\sqrt{\pi} }\int_0^x e^{-t^2} \, dt\)。
Exp(operand)
元素逐元素自然指數 x -> e^x
。
Expm1(operand)
元素逐元素自然指數減一 x -> e^x - 1
。
Floor(operand)
元素逐元素的底層 x -> ⌊x⌋
。
Imag(operand)
複雜 (或實) 形狀的元素逐項虛部。x -> imag(x)
。如果運算元是浮點型,則會傳回 0。
IsFinite(operand)
會測試 operand
的每個元素是否有限,也就是說,不是正數或負無限大,也不是 NaN
。傳回 PRED
值的陣列,其形狀與輸入值相同,其中每個元素都是 true
,前提是相應的輸入元素必須有限。
Log(operand)
元素逐元素自然對數 x -> ln(x)
。
Log1p(operand)
元素逐元素移位的自然對數 x -> ln(1+x)
。
Logistic(operand)
元素級別邏輯函式計算 x ->
logistic(x)
。
Neg(operand)
元素逐元素否定 x -> -x
。
Not(operand)
元素逐元素邏輯否定 x -> !(x)
。
PopulationCount(operand)
計算 operand
中每個元素設定的位元數。
Real(operand)
複數 (或實數) 形狀的元素逐元素實部。x -> real(x)
。如果運算元為浮點型,則會傳回相同的值。
Round(operand)
元素逐元素四捨五入,與 0 相等的值會捨去。
RoundNearestEven(operand)
元素逐項捨入,並以最接近的整數為準。
Rsqrt(operand)
元素級別平方根運算 x -> 1.0 / sqrt(x)
的倒數。
Sign(operand)
元素級別符號運算 x -> sgn(x)
,其中
\[\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}\]
使用 operand
元素類型的比較運算子。
Sin(operand)
元素逐元素正弦 x -> sin(x)
。
Sqrt(operand)
元素級別平方根運算 x -> sqrt(x)
。
Tan(operand)
元素逐元素的切線 x -> tan(x)
。
Tanh(operand)
元素逐元素的雙曲正切 x -> tanh(x)
。
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
函式的運算元 |
函式會套用至 operand
陣列中的每個元素,產生形狀相同的陣列。operand
可以是標量 (0 維度)。
Fft
XLA FFT 運算會針對實數和複數輸入/輸出,實作正向和反向的傅立葉變換。支援最多 3 個軸的多維 FFT。
另請參閱 XlaBuilder::Fft
。
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
我們要進行傅立葉變換的陣列。 |
fft_type |
FftType |
請參閱下表。 |
fft_length |
ArraySlice<int64> |
要轉換的軸的時間域長度。這項操作特別適用於 IRFFT,因為 RFFT(fft_length=[16]) 的輸出形狀與 RFFT(fft_length=[17]) 相同。 |
FftType |
語意 |
---|---|
FFT |
正向複雜-複雜 FFT。形狀不變。 |
IFFT |
反向複雜-複雜 FFT。形狀不變。 |
RFFT |
將實數轉換為複數的 FFT 轉換。如果 fft_length[-1] 為非零值,則最內軸的形狀會縮減為 fft_length[-1] // 2 + 1 ,省略 Nyquist 頻率以外的轉換信號的反相共軛部分。 |
IRFFT |
反向實數至複數 FFT (即取複數,傳回實數)。如果 fft_length[-1] 為非零值,則會將最內軸的形狀展開為 fft_length[-1] ,從 1 到 fft_length[-1] // 2 + 1 項目的反函式推論超出 Nyquist 頻率的轉換信號。 |
多維 FFT
如果提供的 fft_length
超過 1 個,就等同於將一連串的 FFT 運算套用至每個最內側的軸。請注意,對於 real->complex 和 complex->real 情況,最內側的軸轉換會 (實際上) 優先執行 (RFFT;IRFFT 則為最後),因此最內側的軸是會變更大小的軸。其他軸轉換作業則會是 complex->complex。
實作詳情
CPU FFT 由 Eigen 的 TensorFFT 提供支援。GPU FFT 會使用 cuFFT。
Gather
XLA 收集運算會將輸入陣列的多個切片 (每個切片的執行時間偏移可能不同) 拼接在一起。
一般語意
另請參閱 XlaBuilder::Gather
。如需更直觀的說明,請參閱下方的「非正式說明」一節。
gather(operand, start_indices, offset_dims, collapsed_slice_dims,
slice_sizes, start_index_map)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
我們要從中收集資料的陣列。 |
start_indices |
XlaOp |
陣列,其中包含所收集陣列片段的起始索引。 |
index_vector_dim |
int64 |
start_indices 中「包含」起始索引的維度。詳情請參閱下文。 |
offset_dims |
ArraySlice<int64> |
輸出形狀中的維度集合,會偏移至從運算子切割的陣列。 |
slice_sizes |
ArraySlice<int64> |
slice_sizes[i] 是維度 i 上切片的邊界。 |
collapsed_slice_dims |
ArraySlice<int64> |
每個切片中已摺疊的維度組合。這些維度的大小必須為 1。 |
start_index_map |
ArraySlice<int64> |
這張對應表說明如何將 start_indices 中的索引對應至運算元的有效索引。 |
indices_are_sorted |
bool |
是否保證索引會由呼叫端排序。 |
為方便起見,我們將輸出陣列中非 offset_dims
的維度標示為 batch_dims
。
輸出結果是具有 batch_dims.size
+ offset_dims.size
維度的陣列。
operand.rank
必須等於 offset_dims.size
和 collapsed_slice_dims.size
的總和。此外,slice_sizes.size
必須等於 operand.rank
。
如果 index_vector_dim
等於 start_indices.rank
,系統會隱含地將 start_indices
視為具有尾隨 1
維度的形狀 (也就是說,如果 start_indices
的形狀為 [6,7]
,而 index_vector_dim
為 2
,則系統會隱含地將 start_indices
的形狀視為 [6,7,1]
)。
沿著維度 i
計算輸出陣列的邊界如下:
如果
i
出現在batch_dims
中 (即等於某些k
的batch_dims[k]
),我們會從start_indices.shape
中挑選對應的維度邊界,並略過index_vector_dim
(即如果k
<index_vector_dim
,則挑選start_indices.shape.dims
[k
];否則挑選start_indices.shape.dims
[k
+1
])。如果
i
出現在offset_dims
中 (也就是某些k
等於offset_dims
[k
]),那麼在考量collapsed_slice_dims
後,我們會從slice_sizes
中挑選對應的邊界 (也就是我們會挑選adjusted_slice_sizes
[k
],其中adjusted_slice_sizes
是slice_sizes
,但已移除索引collapsed_slice_dims
的邊界)。
正式來說,對應給定輸出索引 Out
的運算元索引 In
的計算方式如下:
讓
G
= {Out
[k
] fork
inbatch_dims
}。使用G
切割向量S
,讓S
[i
] =start_indices
[Combine(G
,i
)],其中 Combine(A, b) 會將 b 插入 A 的index_vector_dim
位置。請注意,即使G
為空白,這也是正確的定義:如果G
為空白,則S
=start_indices
。使用
start_index_map
將S
散布到operand
中,並使用S
建立起始索引S
in
。具體來說:S
in
[start_index_map
[k
]] =S
[k
],如果k
<start_index_map.size
。S
in
[_
] =0
,否則為0
。
根據
collapsed_slice_dims
集合,在Out
的偏移維度中散布索引,藉此在operand
中建立索引O
in
。具體來說:O
in
[remapped_offset_dims
(k
)] =Out
[offset_dims
[k
]],如果k
<offset_dims.size
(remapped_offset_dims
定義如下)。O
in
[_
] =0
,否則為0
。
In
是O
in
+S
in
,其中 + 是元素相加。
remapped_offset_dims
是一個單調函式,網域為 [0
, offset_dims.size
),範圍為 [0
, operand.rank
) \ collapsed_slice_dims
。因此,例如offset_dims.size
是 4
,operand.rank
是 6
,collapsed_slice_dims
是 {0
, 2
},則 remapped_offset_dims
是 {0
→1
, 1
→3
, 2
→4
, 3
→5
}。
如果將 indices_are_sorted
設為 true,XLA 可以假設 start_indices
已由使用者排序 (依升冪順序,在根據 start_index_map
散布其值之後)。如果不是,則語意是由實作定義。
非正式說明和範例
非正式地說,輸出陣列中的每個索引 Out
都對應至運算子陣列中的元素 E
,計算方式如下:
我們會使用
Out
中的批次維度,從start_indices
查詢起始索引。我們使用
start_index_map
將起始索引 (大小可能小於 operand.rank) 對應至operand
中的「完整」起始索引。我們會使用完整的起始索引,動態切出大小為
slice_sizes
的切片。我們會透過收合
collapsed_slice_dims
維度來調整切片形狀。由於所有已摺疊的切片維度都必須有 1 個邊界,因此這個重塑作業一律合法。我們使用
Out
中的偏移維度索引至此切片,以便取得與輸出索引Out
相對應的輸入元素E
。
在後續所有範例中,index_vector_dim
都設為 start_indices.rank
- 1
。index_vector_dim
的其他有趣值不會從根本上改變運算,但會使視覺呈現更加繁瑣。
為了讓您能直觀瞭解上述所有內容如何搭配使用,我們將舉例說明如何從 [16,11]
陣列中收集 5 個形狀 [8,6]
的切片。在 [16,11]
陣列中,切片的位置可表示為形狀為 S64[2]
的索引向量,因此 5 個位置的組合可表示為 S64[5,2]
陣列。
接著,收集作業的行為可視為索引轉換,該轉換會取得 [G
,O
0
,O
1
],也就是輸出形狀中的索引,並以以下方式將其對應至輸入陣列中的元素:
我們會先使用 G
從收集索引陣列中選取 (X
,Y
) 向量。輸出陣列中索引為 [G
,O
0
,O
1
] 的元素,就是輸入陣列中索引為 [X
+O
0
,Y
+O
1
] 的元素。
slice_sizes
是 [8,6]
,可決定 O0
和 O1
的範圍,進而決定切片的邊界。
這個收集作業會做為批次動態切片,並將 G
做為批次維度。
收集索引可能為多維。舉例來說,上例使用形狀為 [4,5,2]
的「收集索引」陣列,其較通用的版本會將索引轉譯為以下形式:
再次提醒,這會做為批次動態切片 G
0
,而 G
1
則是做為批次維度。切片大小仍為 [8,6]
。
XLA 中的收集運算會以以下方式概略上述非正式語意:
我們可以設定輸出形狀中的哪些維度為偏移維度 (包含
O
0
的維度,在最後一個範例中為O
1
)。輸出批次維度 (包含G
0
的維度,在上一例中為G
1
) 定義為非偏移維度的輸出維度。輸出形狀中明確顯示的輸出偏移維度數量,可能會小於輸入維度數量。這些「缺少」的維度 (明確列為
collapsed_slice_dims
) 必須具有1
的切片大小。由於這些元素的切片大小為1
,因此唯一有效的索引為0
,而省略這些元素不會造成歧義。從「Gather Indices」陣列 (上一個範例中的
X
、Y
) 擷取的切片,可能包含的元素數量少於輸入陣列的維度數量,而明確的對應方式會決定如何擴充索引,讓索引的維度數量與輸入陣列相同。
最後一個範例,我們使用 (2) 和 (3) 實作 tf.gather_nd
:
G
0
和 G
1
用於從收集索引陣列中切割出起始索引,這與平常一樣,只是起始索引只有一個元素 X
。同樣地,只有一個輸出偏移索引,其值為 O
0
。不過,在用於輸入陣列的索引之前,這些索引會根據「Gather Index Mapping」(正式說明中的 start_index_map
) 和「Offset Mapping」(正式說明中的 remapped_offset_dims
) 擴展為 [X
,0
] 和 [0
,O
0
],總計為 [X
,O
0
]。換句話說,輸出索引 [G
0
,G
1
,O
0
] 會對應至輸入索引 [GatherIndices
[G
0
,G
1
,0
],O
0
],這會為我們提供 tf.gather_nd
的語意。
在本例中,slice_sizes
為 [1,11]
。這表示收集索引陣列中的每個索引 X
會挑選整個資料列,而結果則是所有資料列的串連。
GetDimensionSize
另請參閱 XlaBuilder::GetDimensionSize
。
傳回運算元的指定維度大小。運算元必須為陣列形狀。
GetDimensionSize(operand, dimension)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
n 維輸入陣列 |
dimension |
int64 |
指定維度的 [0, n) 間隔值 |
SetDimensionSize
另請參閱 XlaBuilder::SetDimensionSize
。
設定 XlaOp 指定維度的動態大小。運算元必須為陣列形狀。
SetDimensionSize(operand, size, dimension)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
n 維輸入陣列。 |
size |
XlaOp |
int32,代表執行階段的動態大小。 |
dimension |
int64 |
指定維度的 [0, n) 間隔值。 |
將運算元做為結果傳遞,並由編譯器追蹤動態維度。
下游的縮減運算作業會忽略填補的值。
let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
let five: s32 = 5;
let six: s32 = 6;
// Setting dynamic dimension size doesn't change the upper bound of the static
// shape.
let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);
// sum == 1 + 2 + 3 + 4 + 5
let sum:f32[] = reduce_sum(padded_v_five);
// product == 1 * 2 * 3 * 4 * 5
let product:f32[] = reduce_product(padded_v_five);
// Changing padding size will yield different result.
// sum == 1 + 2 + 3 + 4 + 5 + 6
let sum:f32[] = reduce_sum(padded_v_six);
GetTupleElement
另請參閱 XlaBuilder::GetTupleElement
。
使用編譯時間常數值索引至元組。
值必須是編譯時間常數,這樣形狀推論才能判斷產生值的類型。
這類似於 C++ 中的 std::get<int N>(t)
。概念上:
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32.
另請參閱 tf.tuple
。
動態內
另請參閱 XlaBuilder::Infeed
。
Infeed(shape)
引數 | 類型 | 語意 |
---|---|---|
shape |
Shape |
從 Infeed 介面讀取的資料形狀。形狀的版面配置欄位必須設為與傳送至裝置的資料版面配置相符,否則其行為未定義。 |
從裝置的隱含動態饋給串流介面讀取單一資料項目,將資料解讀為指定形狀及其版面配置,並傳回資料的 XlaOp
。計算中允許有多個中介內容操作,但中介內容操作之間必須有總順序。舉例來說,下方程式碼中的兩個 Infeed 具有總順序,因為 while 迴圈之間存在依附關係。
result1 = while (condition, init = init_value) {
Infeed(shape)
}
result2 = while (condition, init = result1) {
Infeed(shape)
}
不支援巢狀元組形狀。對於空的元組形狀,Infeed 作業實際上是無操作,且在未從裝置的 Infeed 讀取任何資料的情況下繼續執行。
器皿打擊樂
另請參閱 XlaBuilder::Iota
。
Iota(shape, iota_dimension)
在裝置上建構常數字面值,而非可能龐大的主機轉移。建立具有指定形狀的陣列,並保留從零開始的值,並沿著指定維度遞增 1。對於浮點類型,產生的陣列等同於 ConvertElementType(Iota(...))
,其中 Iota
為整數類型,且轉換為浮點類型。
引數 | 類型 | 語意 |
---|---|---|
shape |
Shape |
由 Iota() 建立的陣列形狀 |
iota_dimension |
int64 |
要遞增的維度。 |
舉例來說,Iota(s32[4, 8], 0)
會傳回
[[0, 0, 0, 0, 0, 0, 0, 0 ],
[1, 1, 1, 1, 1, 1, 1, 1 ],
[2, 2, 2, 2, 2, 2, 2, 2 ],
[3, 3, 3, 3, 3, 3, 3, 3 ]]
可退貨 (費用:Iota(s32[4, 8], 1)
)
[[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ],
[0, 1, 2, 3, 4, 5, 6, 7 ]]
地圖
另請參閱 XlaBuilder::Map
。
Map(operands..., computation)
引數 | 類型 | 語意 |
---|---|---|
operands |
N 個 XlaOp 的序列 |
類型 T 的 N 個陣列 0..T{N-1} |
computation |
XlaComputation |
型別為 T_0, T_1, .., T_{N + M -1} -> S 的運算,其中 N 個參數為型別 T,而 M 為任意型別 |
dimensions |
int64 陣列 |
地圖維度的陣列 |
對指定的 operands
陣列套用標量函式,產生相同維度的陣列,其中每個元素都是將對應函式套用至輸入陣列中的對應元素所產生的結果。
對應函式是任意運算,但有以下限制:它有 N 個純量類型 T
的輸入,以及單一 S
類型的輸出。輸出內容的維度與運算元件的維度相同,但元素類型 T 已替換為 S。
例如:Map(op1, op2, op3, computation, par1)
會將 elem_out <-
computation(elem1, elem2, elem3, par1)
對應至輸入陣列中的每個 (多維) 索引,以產生輸出陣列。
OptimizationBarrier
阻止任何最佳化階段在邊界之間移動運算。
確保在任何依賴分隔符輸出的運算子之前,先評估所有輸入內容。
熱敷墊
另請參閱 XlaBuilder::Pad
。
Pad(operand, padding_value, padding_config)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
T 類型的陣列 |
padding_value |
XlaOp |
T 類型的純量,用於填入新增的邊框 |
padding_config |
PaddingConfig |
兩側邊框間距 (低、高) 和各維度元素之間的間距 |
使用指定的 padding_value
在陣列周圍和陣列元素之間加上邊框,藉此擴展指定的 operand
陣列。padding_config
會指定每個維度的邊框間距和內部間距。
PaddingConfig
是 PaddingConfigDimension
的重複欄位,其中包含每個維度的三個欄位:edge_padding_low
、edge_padding_high
和 interior_padding
。
edge_padding_low
和 edge_padding_high
分別指定在各維度的低端 (靠近索引 0) 和高端 (靠近最高索引) 新增的邊框間距量。邊緣邊框的邊框寬度可以為負值,負邊框的絕對值表示從指定維度移除的元素數量。
interior_padding
會指定在每個維度的任何兩個元素之間加入的邊框間距量,且不得為負值。內部邊框間距在邏輯上會出現在邊框間距之前,因此在邊框間距為負值的情況下,系統會從內部邊框間距運算元中移除元素。
如果邊緣邊框組合皆為 (0, 0),且內部邊框值皆為 0,則此作業會是無操作。下圖顯示了二維陣列的不同 edge_padding
和 interior_padding
值範例。
Recv
另請參閱 XlaBuilder::Recv
。
Recv(shape, channel_handle)
引數 | 類型 | 語意 |
---|---|---|
shape |
Shape |
要接收的資料形狀 |
channel_handle |
ChannelHandle |
每個傳送/接收組合的專屬 ID |
在共用相同管道句柄的其他運算中,從 Send
指令接收指定形狀的資料。針對已接收的資料傳回 XlaOp。
Recv
作業的用戶端 API 代表同步通訊。不過,這項指令會在內部分解為 2 個 HLO 指令 (Recv
和 RecvDone
),以便啟用非同步資料傳輸。另請參閱 HloInstruction::CreateRecv
和 HloInstruction::CreateRecvDone
。
Recv(const Shape& shape, int64 channel_id)
分配接收資料所需的資源,這些資料來自具有相同 channel_id 的 Send
指令。傳回已分配資源的內容,後續 RecvDone
指令會使用該內容,等待資料傳輸作業完成。這個內容是 {接收緩衝區 (形狀)、要求 ID (U32)} 的元組,且只能由 RecvDone
指令使用。
RecvDone(HloInstruction context)
提供 Recv
指令建立的內容,等待資料傳輸完成,並傳回收到的資料。
遏止
另請參閱 XlaBuilder::Reduce
。
將縮減函式並行套用至一或多個陣列。
Reduce(operands..., init_values..., computation, dimensions)
引數 | 類型 | 語意 |
---|---|---|
operands |
N 個 XlaOp 的序列 |
類型為 T_0, ..., T_{N-1} 的 N 個陣列。 |
init_values |
N 個 XlaOp 的序列 |
型別為 T_0, ..., T_{N-1} 的 N 個純量。 |
computation |
XlaComputation |
類型為 T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 的運算。 |
dimensions |
int64 陣列 |
要縮減的維度無序陣列。 |
在此情況下:
- N 必須大於或等於 1。
- 計算必須「大致」具有結合性 (請參閱下文)。
- 所有輸入陣列的維度都必須相同。
- 所有初始值都必須在
computation
下形成一個識別值。 - 如果是
N = 1
,Collate(T)
就是T
。 - 如果是
N > 1
,Collate(T_0, ..., T_{N-1})
就是T
類型的N
元素元組。
這個運算會將每個輸入陣列的一或多個維度縮減為標量。每個傳回陣列的維度數量為 number_of_dimensions(operand) - len(dimensions)
。此運算子的輸出值為 Collate(Q_0, ..., Q_N)
,其中 Q_i
是 T_i
類型的陣列,其維度如下所述。
允許不同的後端重新連結減法運算。這可能會導致數值差異,因為某些減法函式 (例如加法) 無法與浮點值建立關聯。不過,如果資料範圍有限,浮點加法在大多數實際用途上就足以達到關聯性。
範例
當您使用值 [10, 11,
12, 13]
的單一 1D 陣列,透過縮減函式 f
(即 computation
) 縮減單一維度時,則可計算為
f(10, f(11, f(12, f(init_value, 13)))
但還有許多其他可能性,例如
f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))
以下是粗略的模擬程式碼範例,說明如何實作減法,使用加總做為減法運算,並將初始值設為 0。
result_shape <- remove all dims in dimensions from operand_shape
# Iterate over all elements in result_shape. The number of r's here is equal
# to the number of dimensions of the result.
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
# Initialize this result element
result[r0, r1...] <- 0
# Iterate over all the reduction dimensions
for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
# Increment the result element with the value of the operand's element.
# The index of the operand's element is constructed from all ri's and di's
# in the right order (by construction ri's and di's together index over the
# whole operand shape).
result[r0, r1...] += operand[ri... di]
以下是 2D 陣列 (矩陣) 縮減的範例。形狀有 2 個維度,其中維度 0 的大小為 2,維度 1 的大小為 3:
使用「add」函式減少維度 0 或 1 的結果:
請注意,兩個縮減結果都是 1D 陣列。圖表中顯示一個為欄,另一個為列,只是為了方便查看。
以下是 3D 陣列的較複雜範例。其維度數量為 3,其中維度 0 的大小為 4,維度 1 的大小為 2,維度 2 的大小為 3。為了簡單起見,我們會在維度 0 中複製 1 到 6 的值。
與 2D 示例類似,我們可以只縮減一個維度。舉例來說,如果我們減少維度 0,就會得到 2 維陣列,其中維度 0 的所有值都會折疊為單一值:
| 4 8 12 |
| 16 20 24 |
如果我們縮減第 2 個維度,也會得到一個 2 維陣列,其中第 2 個維度的所有值都會折疊為一個標量:
| 6 15 |
| 6 15 |
| 6 15 |
| 6 15 |
請注意,輸入內容中其餘維度之間的相對順序會保留在輸出內容中,但部分維度可能會指派新的編號 (因為維度數量會改變)。
我們也可以減少多個維度。新增並減少維度 0 和 1,會產生 1D 陣列 [20, 28, 36]
。
在所有維度上縮減 3D 陣列,會產生標量 84
。
變化資料縮減
在 N > 1
時,reduce 函式應用程序會同時套用至所有輸入內容,因此會稍微複雜一些。運算子會以以下順序提供給運算:
- 為第一個運算元執行已降低的值
- ...
- 執行第 N 個運算元的減數值
- 輸入第一個運算元的值
- ...
- 第 N 個運算元的輸入值
舉例來說,請考慮下列縮減函式,可用於平行計算 1 維陣列的最大值和 argmax:
f: (Float, Int, Float, Int) -> Float, Int
f(max, argmax, value, index):
if value >= max:
return (value, index)
else:
return (max, argmax)
針對 1 維輸入陣列 V = Float[N], K = Int[N]
和初始值 I_V = Float, I_K = Int
,在單一輸入維度中縮減的結果 f_(N-1)
等同於以下遞迴應用程式:
f_0 = f(I_V, I_K, V_0, K_0)
f_1 = f(f_0.first, f_0.second, V_1, K_1)
...
f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
將這項縮減作業套用至值陣列和序列索引陣列 (即 iota),會同時對陣列進行迭代,並傳回包含最大值和相符索引的元組。
ReducePrecision
另請參閱 XlaBuilder::ReducePrecision
。
模擬將浮點值轉換為精確度較低的格式 (例如 IEEE-FP16) 並還原為原始格式的效果。您可以任意指定較低精確度格式中的指數和尾數位元數量,但所有硬體實作可能不支援所有位元大小。
ReducePrecision(operand, mantissa_bits, exponent_bits)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
浮點類型 T 的陣列。 |
exponent_bits |
int32 |
較低精度格式中的指數位元數 |
mantissa_bits |
int32 |
較低精確度格式的小數值位元數 |
結果為 T
類型的陣列。輸入值會四捨五入至最接近的值,可使用指定的尾數位元數量 (使用「ties to even」語義),任何超過指數位元數量所指定範圍的值都會被箝制為正無窮或負無窮。NaN
值會保留,但可能會轉換為標準 NaN
值。
較低精確度的格式必須至少包含一個指數位元 (為了區分零值和無窮大,因為兩者都具有零尾數位元),且必須包含非負數的尾數位元位元。指數或尾數位元數量可能超過類型 T
的對應值;轉換的對應部分就會變成無操作。
ReduceScatter
另請參閱 XlaBuilder::ReduceScatter
。
ReduceScatter 是集體運算,可有效執行 AllReduce,然後沿著 scatter_dimension
將結果分割成 shard_count
區塊,並在複本群組中接收 ith
資料分割的複本 i
。
ReduceScatter(operand, computation, scatter_dim, shard_count,
replica_group_ids, channel_id)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
陣列或陣列的非空元組,用於在備用資源之間執行縮減作業。 |
computation |
XlaComputation |
減法運算 |
scatter_dimension |
int64 |
散布圖的維度。 |
shard_count |
int64 |
要分割的區塊數量 scatter_dimension |
replica_groups |
int64 的向量向量 |
要執行減法運算的群組 |
channel_id |
選填 int64 |
跨模組通訊的選用管道 ID |
- 如果
operand
是陣列的元組,則會針對元組的每個元素執行 reduce-scatter。 replica_groups
是執行縮減作業的備份群組清單 (可使用ReplicaId
擷取目前備份的備份 ID)。每個群組中的備份順序,決定了全縮減結果的散發順序。replica_groups
必須為空白 (在這種情況下,所有副本都屬於單一群組),或包含與副本數量相同的元素。如果有超過一個複本群組,則所有群組的大小都必須相同。舉例來說,replica_groups = {0, 2}, {1, 3}
會在複本0
和2
之間,以及1
和3
之間執行縮減作業,然後散布結果。shard_count
是每個複本群組的大小。在replica_groups
為空白的情況下,我們需要這個值。如果replica_groups
非空白,shard_count
必須等於每個複本群組的大小。channel_id
用於跨模組通訊:只有具有相同channel_id
的reduce-scatter
作業才能相互通訊。
輸出形狀是輸入形狀,其中 scatter_dimension
縮小了 shard_count
倍。舉例來說,如果有兩個副本,且運算元在兩個副本中分別具有 [1.0, 2.25]
和 [3.0, 5.25]
值,則這個運算子的輸出值 (scatter_dim
為 0
) 將是第一個副本的 [4.0]
,第二個副本的 [7.5]
。
ReduceWindow
另請參閱 XlaBuilder::ReduceWindow
。
將縮減函式套用至 N 個多維陣列序列的每個視窗中的所有元素,產生單一或 N 個多維陣列的元組做為輸出。每個輸出陣列的元素數量都與視窗的有效位置數量相同。匯集層可以用 ReduceWindow
表示。與 Reduce
類似,套用的 computation
一律會傳遞左側的 init_values
。
ReduceWindow(operands..., init_values..., computation, window_dimensions,
window_strides, padding)
引數 | 類型 | 語意 |
---|---|---|
operands |
N XlaOps |
一系列 T_0,..., T_{N-1} 類型的 N 個多維陣列,每個陣列都代表窗口放置的基礎區域。 |
init_values |
N XlaOps |
運算的 N 個起始值,每個運算元式各一個。詳情請參閱「調降」一節。 |
computation |
XlaComputation |
T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}) 類型的簡化函式,可套用至所有輸入運算元的每個區間元素。 |
window_dimensions |
ArraySlice<int64> |
視窗維度值的整數陣列 |
window_strides |
ArraySlice<int64> |
視窗步距值的整數陣列 |
base_dilations |
ArraySlice<int64> |
用於基本放大值的整數陣列 |
window_dilations |
ArraySlice<int64> |
用於窗口擴大值的整數陣列 |
padding |
Padding |
視窗的邊框類型 (Padding::kSame,如果步幅為 1,則會填充邊框,使輸出形狀與輸入形狀相同;或 Padding::kValid,不使用邊框,且在視窗無法再填充時「停止」) |
在此情況下:
- N 必須大於或等於 1。
- 所有輸入陣列的維度都必須相同。
- 如果是
N = 1
,Collate(T)
就是T
。 - 如果是
N > 1
,Collate(T_0, ..., T_{N-1})
就是(T0,...T{N-1})
類型的N
元素元組。
以下程式碼和圖表顯示使用 ReduceWindow
的範例。輸入內容是大小為 [4x6] 的矩陣,且 window_dimensions 和 window_stride_dimensions 都是 [2x3]。
// Create a computation for the reduction (maximum).
XlaComputation max;
{
XlaBuilder builder(client_, "max");
auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
builder.Max(y, x);
max = builder.Build().value();
}
// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
input,
/*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
*max,
/*window_dimensions=*/{2, 3},
/*window_stride_dimensions=*/{2, 3},
Padding::kValid);
在維度中,步幅為 1 表示維度中視窗的位置與相鄰視窗相距 1 個元素。為指定不重疊的視窗,window_stride_dimensions 應等於 window_dimensions。下圖說明瞭使用兩個不同的步幅值。填充會套用至輸入內容的每個維度,而計算結果與輸入內容填充後的維度相同。
針對非簡單的填充範例,請考慮在輸入陣列 [10000, 1000, 100, 10, 1]
上,使用維度 3
和步幅 2
計算 reduce-window 最小值 (初始值為 MAX_FLOAT
)。Padding kValid
會在兩個有效視窗 ([10000, 1000, 100]
和 [100, 10, 1]
) 中計算最小值,並產生輸出 [100, 1]
。Padding kSame
會先為陣列填入邊框,藉此在兩側新增初始元素,讓縮減視窗後的形狀與第 1 步的輸入相同,進而取得 [MAX_VALUE, 10000, 1000, 100, 10, 1,
MAX_VALUE]
。在經過填補的陣列上執行 reduce-window 會對三個視窗 [MAX_VALUE, 10000, 1000]
、[1000, 100, 10]
、[10, 1, MAX_VALUE]
運算,並產生 [1000, 10, 1]
。
縮減函式的評估順序為任意順序,且可能是非決定性的。因此,縮減函式不應對重新關聯過度敏感。如需詳細資訊,請參閱 Reduce
的相關討論內容。
ReplicaId
另請參閱 XlaBuilder::ReplicaId
。
傳回複本的專屬 ID (U32 單一值)。
ReplicaId()
每個副本的專屬 ID 是 [0, N)
範圍內的無符號整數,其中 N
是副本數量。由於所有副本都執行相同的程式,因此程式中的 ReplicaId()
呼叫會在每個副本上傳回不同的值。
Reshape
另請參閱 XlaBuilder::Reshape
和 Collapse
作業。
將陣列的維度重新調整為新設定。
Reshape(operand, dimensions)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
類型為 T 的陣列 |
dimensions |
int64 向量 |
新維度的大小向量 |
從概念上來說,重新調整形狀會先將陣列扁平化為資料值的一維向量,然後將這個向量精緻化為新的形狀。輸入引數是任意型別 T 的陣列、編譯時間常數向量的維度索引,以及結果的編譯時間常數向量維度大小。dimensions
向量會決定輸出陣列的大小。dimensions
中索引 0 的值是維度 0 的大小,索引 1 的值是維度 1 的大小,以此類推。dimensions
維度的乘積必須等於運算元的維度大小乘積。當您將已摺疊的陣列精緻化為 dimensions
定義的多維陣列時,dimensions
中的維度會依變化速度由慢到快排序 (最主要) 和 (最次要)。
例如,讓 v 為 24 個元素的陣列:
let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
{ {20, 21, 22}, {25, 26, 27} },
{ {30, 31, 32}, {35, 36, 37} },
{ {40, 41, 42}, {45, 46, 47} } };
let v012_24 = Reshape(v, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};
let v012_83 = Reshape(v, {8,3});
then v012_83 == f32[8x3] { {10, 11, 12}, {15, 16, 17},
{20, 21, 22}, {25, 26, 27},
{30, 31, 32}, {35, 36, 37},
{40, 41, 42}, {45, 46, 47} };
在特殊情況下,reshape 可將單一元素陣列轉換為標量,反之亦然。例如:
Reshape(f32[1x1] { {5} }, {}) == 5;
Reshape(5, {1,1}) == f32[1x1] { {5} };
Rev (倒轉)
另請參閱 XlaBuilder::Rev
。
Rev(operand, dimensions)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
類型為 T 的陣列 |
dimensions |
ArraySlice<int64> |
要反轉的維度 |
沿著指定的 dimensions
反轉 operand
陣列中的元素順序,產生相同形狀的輸出陣列。多維索引運算子陣列的每個元素都會儲存至轉換索引的輸出陣列中。多維索引會透過反轉每個要反轉維度的索引來轉換 (也就是說,如果大小為 N 的維度是其中一個反轉維度,其索引 i 會轉換為 N - 1 - i)。
Rev
運算的其中一個用途,是在神經網路中沿著兩個視窗維度反轉卷積權重陣列。
RngNormal
另請參閱 XlaBuilder::RngNormal
。
根據 \(N(\mu, \sigma)\) 常態分佈產生的隨機號碼,建構特定形狀的輸出內容。參數 \(\mu\) 和 \(\sigma\),以及輸出形狀必須具有浮點元素類型。此外,參數必須是純量值。
RngNormal(mu, sigma, shape)
引數 | 類型 | 語意 |
---|---|---|
mu |
XlaOp |
指定產生數字平均值的 T 型純量 |
sigma |
XlaOp |
指定產生標準差的 T 類型單值 |
shape |
Shape |
輸出型別 T 的形狀 |
RngUniform
另請參閱 XlaBuilder::RngUniform
。
根據在區間 \([a,b)\)內均勻分布的隨機號碼,建構指定形狀的輸出內容。參數和輸出元素類型必須是布林值類型、整數類型或浮點類型,且類型必須一致。CPU 和 GPU 後端目前僅支援 F64、F32、F16、BF16、S64、U64、S32 和 U32。此外,參數必須是純量值。如果 \(b <= a\) 為結果,則由實作定義。
RngUniform(a, b, shape)
引數 | 類型 | 語意 |
---|---|---|
a |
XlaOp |
指定間隔下限的 T 型別向量 |
b |
XlaOp |
指定間隔上限的 T 型別標量 |
shape |
Shape |
輸出型別 T 的形狀 |
RngBitGenerator
使用指定的演算法 (或後端預設值),產生以特定形狀填入均勻隨機位元的輸出內容,並傳回更新的狀態 (與初始狀態相同的形狀) 和產生的隨機資料。
初始狀態是目前隨機號碼產生的初始狀態。這項屬性和所需形狀及有效值,取決於所使用的演算法。
輸出內容保證為初始狀態的決定性函式,但不保證在後端和不同編譯器版本之間具有決定性。
RngBitGenerator(algorithm, key, shape)
引數 | 類型 | 語意 |
---|---|---|
algorithm |
RandomAlgorithm |
要使用的 PRNG 演算法。 |
initial_state |
XlaOp |
PRNG 演算法的初始狀態。 |
shape |
Shape |
產生資料的輸出形狀。 |
algorithm
的可用值:
rng_default
:後端專屬演算法,具有後端專屬形狀需求。rng_three_fry
:ThreeFry 計數器式 PRNG 演算法。initial_state
形狀是u64[2]
,其中包含任意值。Salmon et al. SC 2011. 並行隨機號碼:輕鬆 3 步驟完成rng_philox
:Philox 演算法可並行產生隨機數字。initial_state
形狀是u64[3]
,含有任意值。Salmon et al. SC 2011. 並行隨機號碼:輕鬆 3 步驟完成
散布圖
XLA 散布運算會產生一系列結果,這些結果是輸入陣列 operands
的值,其中有幾個切片 (在 scatter_indices
指定的索引處) 會使用 update_computation
更新 updates
中的值序列。
另請參閱 XlaBuilder::Scatter
。
scatter(operands..., scatter_indices, updates..., update_computation,
index_vector_dim, update_window_dims, inserted_window_dims,
scatter_dims_to_operand_dims)
引數 | 類型 | 語意 |
---|---|---|
operands |
N 個 XlaOp 的序列 |
要散布到其中的 N 個 T_0, ..., T_N 類型陣列。 |
scatter_indices |
XlaOp |
陣列,其中包含必須散發至的切片起始索引。 |
updates |
N 個 XlaOp 的序列 |
類型為 T_0, ..., T_N 的 N 個陣列。updates[i] 包含必須用於散布 operands[i] 的值。 |
update_computation |
XlaComputation |
用於結合輸入陣列中現有值和散布期間更新的運算。這項運算的類型應為 T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N) 。 |
index_vector_dim |
int64 |
scatter_indices 中包含起始索引的維度。 |
update_window_dims |
ArraySlice<int64> |
updates 形狀中的視窗大小維度集。 |
inserted_window_dims |
ArraySlice<int64> |
必須插入 updates 形狀的視窗大小集。 |
scatter_dims_to_operand_dims |
ArraySlice<int64> |
從散布索引到運算元索引空間的維度對應。這個陣列會解讀為將 i 對應至 scatter_dims_to_operand_dims[i] 。必須是 1 比 1 的對應關係。 |
indices_are_sorted |
bool |
是否保證索引會由呼叫端排序。 |
unique_indices |
bool |
呼叫端是否保證索引不會重複。 |
在此情況下:
- N 必須大於或等於 1。
operands
[0
]、...、operands
[N-1
] 的尺寸都必須相同。updates
[0
]、...、updates
[N-1
] 的尺寸都必須相同。- 如果是
N = 1
,Collate(T)
就是T
。 - 如果是
N > 1
,Collate(T_0, ..., T_N)
就是T
類型的N
元素元組。
如果 index_vector_dim
等於 scatter_indices.rank
,我們會隱含地將 scatter_indices
視為具有尾隨 1
維度的資料。
我們將 ArraySlice<int64>
類型的 update_scatter_dims
定義為 updates
形狀中不屬於 update_window_dims
的一組維度,並以升冪順序排列。
散布圖的引數應遵循下列限制:
每個
updates
陣列都必須有update_window_dims.size + scatter_indices.rank - 1
維度。每個
updates
陣列中維度i
的邊界必須符合下列條件:- 如果
i
出現在update_window_dims
中 (即等於某些k
的update_window_dims
[k
]),則updates
中的維度i
邊界不得超過operand
的對應邊界,且必須考量inserted_window_dims
(即adjusted_window_bounds
[k
],其中adjusted_window_bounds
包含operand
的邊界,且已移除索引inserted_window_dims
的邊界)。 - 如果
i
出現在update_scatter_dims
中 (也就是在某些k
中等於update_scatter_dims
[k
]),updates
中的維度i
邊界必須等於scatter_indices
的對應邊界,並略過index_vector_dim
(也就是scatter_indices.shape.dims
[k
],如果k
<index_vector_dim
,否則為scatter_indices.shape.dims
[k+1
])。
- 如果
update_window_dims
必須依遞增順序排列,且不得有重複的維度編號,且必須在[0, updates.rank)
的範圍內。inserted_window_dims
必須依遞增順序排列,且不得有重複的維度編號,且必須在[0, operand.rank)
的範圍內。operand.rank
必須等於update_window_dims.size
和inserted_window_dims.size
的總和。scatter_dims_to_operand_dims.size
必須等於scatter_indices.shape.dims
[index_vector_dim
],且其值必須在[0, operand.rank)
的範圍內。
針對每個 updates
陣列中的指定索引 U
,系統會依下列方式計算要套用此更新的對應 operands
陣列中的對應索引 I
:
- 讓
G
= {U
[k
] fork
inupdate_scatter_dims
}。使用G
在scatter_indices
陣列中查詢索引向量S
,以便S
[i
] =scatter_indices
[Combine(G
,i
)],其中 Combine(A, b) 會將 b 插入 A 的index_vector_dim
位置。 - 使用
scatter_dims_to_operand_dims
地圖散布S
,藉此使用S
將索引S
in
建立至operand
。更正式的做法如下:S
in
[scatter_dims_to_operand_dims
[k
]] =S
[k
] (如果k
<scatter_dims_to_operand_dims.size
)。S
in
[_
] =0
,否則為0
。
- 根據
inserted_window_dims
將索引散布在U
的update_window_dims
中,藉此在每個operands
陣列中建立索引W
in
。更正式的做法如下:- 如果
k
位於update_window_dims
中,W
in
[window_dims_to_operand_dims
(k
)] =U
[k
],其中window_dims_to_operand_dims
是具有網域 [0
,update_window_dims.size
) 和範圍 [0
,operand.rank
) \inserted_window_dims
的單調函式。(舉例來說,如果update_window_dims.size
是4
、operand.rank
是6
,而inserted_window_dims
是 {0
,2
},那麼window_dims_to_operand_dims
就是 {0
→1
,1
→3
,2
→4
,3
→5
}。) W
in
[_
] =0
,否則為0
。
- 如果
I
是W
in
+S
in
,其中 + 是元素相加。
總而言之,散布運算可定義如下:
- 使用
operands
初始化output
,也就是針對J
的所有索引,以及operands
[J
] 陣列中的所有索引O
:
output
[J
][O
] =operands
[J
][O
] - 針對
updates
[J
] 陣列中的每個索引U
和operand
[J
] 陣列中的對應索引O
,如果O
是output
的有效索引:
(output
[0
][O
]、...、output
[N-1
][O
]) =update_computation
(output
[0
][O
]、...、output
[N-1
][O
]、updates
[0
][U
]、...、updates
[N-1
][U
])
更新套用順序不確定。因此,當 updates
中的多個索引參照 operands
中的同一個索引時,output
中的對應值就會變得不確定。
請注意,傳遞至 update_computation
的第一個參數一律為 output
陣列的目前值,而第二個參數一律為 updates
陣列的值。這點對於 update_computation
非可交換的情況尤其重要。
如果 indices_are_sorted
設為 true,XLA 可假設使用者已排序 scatter_indices
(以遞增順序排列,在根據 scatter_dims_to_operand_dims
散布其值)。如果不是,則語意是由實作方式定義。
如果將 unique_indices
設為 true,XLA 可以假設所有散布到的元素都是唯一的。因此 XLA 可以使用非原子作業。如果 unique_indices
設為 true,且散布的索引並非唯一,則語意是由實作方式定義。
非正式地說,散布運算可視為收集運算的反向運算,也就是散布運算會更新由對應收集運算擷取的輸入內容元素。
如需詳細的非正式說明和範例,請參閱 Gather
下方的「非正式說明」一節。
選取
另請參閱 XlaBuilder::Select
。
根據謂詞陣列的值,從兩個輸入陣列的元素建構輸出陣列。
Select(pred, on_true, on_false)
引數 | 類型 | 語意 |
---|---|---|
pred |
XlaOp |
PRED 類型的陣列 |
on_true |
XlaOp |
類型為 T 的陣列 |
on_false |
XlaOp |
類型為 T 的陣列 |
陣列 on_true
和 on_false
必須具有相同的形狀。這也是輸出陣列的形狀。陣列 pred
必須使用 PRED
元素類型,且維度必須與 on_true
和 on_false
相同。
對於 pred
的每個元素 P
,如果 P
的值為 true
,則輸出陣列的對應元素會從 on_true
取得;如果 P
的值為 false
,則會從 on_false
取得。pred
是廣播的限制形式,可以是 PRED
類型的標量。在這種情況下,如果 pred
為 true
,輸出陣列會完全從 on_true
取得;如果 pred
為 false
,則會從 on_false
取得。
使用非標量 pred
的範例:
let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};
使用標量 pred
的範例:
let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};
支援在元組之間進行選取。在此情況下,元組會視為純量型別。如果 on_true
和 on_false
是元組 (必須具有相同的形狀!),則 pred
必須是 PRED
類型的純量。
SelectAndScatter
另請參閱 XlaBuilder::SelectAndScatter
。
這項運算可視為複合運算,首先會在 operand
陣列上計算 ReduceWindow
,從每個視窗中選取一個元素,然後將 source
陣列散布至所選元素的索引,以建構與運算子陣列相同形狀的輸出陣列。二進位 select
函式可用於在每個視窗中套用函式,從每個視窗中選取元素,並以第一個參數的索引向量在字典順序上小於第二個參數的索引向量為條件來呼叫。如果選取第一個參數,select
函式會傳回 true
,如果選取第二個參數,則會傳回 false
,且函式必須具備傳遞性 (如果 select(a, b)
和 select(b, c)
為 true
,則 select(a, c)
也為 true
),以便所選元素不依賴特定視窗中經過的元素順序。
函式 scatter
會套用至輸出陣列中的每個所選索引。這個函式會使用兩個純量參數:
- 輸出陣列中所選索引的目前值
- 套用至所選索引的
source
散布圖值
它會結合兩個參數,並傳回純量值,用於更新輸出陣列中所選索引的值。一開始,輸出陣列的所有索引都會設為 init_value
。
輸出陣列的形狀與 operand
陣列相同,而 source
陣列的形狀必須與在 operand
陣列上套用 ReduceWindow
運算的結果相同。SelectAndScatter
可用於在神經網路中,為匯集層回傳梯度值。
SelectAndScatter(operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
系統在其中滑動視窗的 T 型陣列 |
select |
XlaComputation |
類型為 T, T -> PRED 的二進位運算,可套用至每個視窗中的所有元素;如果選取第一個參數,則會傳回 true ,如果選取第二個參數,則會傳回 false |
window_dimensions |
ArraySlice<int64> |
視窗維度值的整數陣列 |
window_strides |
ArraySlice<int64> |
視窗步距值的整數陣列 |
padding |
Padding |
視窗的邊框間距類型 (Padding::kSame 或 Padding::kValid) |
source |
XlaOp |
含有要散布值的 T 型陣列 |
init_value |
XlaOp |
輸出陣列的初始值,為型別 T 的純量 |
scatter |
XlaComputation |
類型為 T, T -> T 的二進位運算,可為每個散布來源元素套用其目的地元素 |
下圖顯示使用 SelectAndScatter
的範例,其中 select
函式會計算參數中的最大值。請注意,當視窗重疊時 (如下圖 2 所示),不同視窗可能會多次選取 operand
陣列的索引。在圖中,兩個頂端視窗 (藍色和紅色) 都選取了值為 9 的元素,而二進位加法 scatter
函式會產生值為 8 (2 + 6) 的輸出元素。
scatter
函式的評估順序為任意順序,且可能是非決定性的。因此,scatter
函式不應對重新關聯過度敏感。如需詳細資訊,請參閱 Reduce
的相關討論內容。
傳送
另請參閱 XlaBuilder::Send
。
Send(operand, channel_handle)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
要傳送的資料 (T 型態的陣列) |
channel_handle |
ChannelHandle |
每個傳送/接收組合的專屬 ID |
將指定的運算元資料傳送至共用相同管道句柄的另一個運算中 Recv
指令。不會傳回任何資料。
與 Recv
作業類似,Send
作業的用戶端 API 代表同步通訊,並在內部分解為 2 個 HLO 指令 (Send
和 SendDone
),以啟用非同步資料移轉。另請參閱 HloInstruction::CreateSend
和 HloInstruction::CreateSendDone
。
Send(HloInstruction operand, int64 channel_id)
啟動非同步轉移作業,將運算元傳送至 Recv
指令以相同管道 ID 分配的資源。會傳回背景資訊,供後續 SendDone
指令使用,以便等待資料傳輸作業完成。這個內容是 {運算子 (形狀)、要求 ID (U32)} 的元組,且只能由 SendDone
指令使用。
SendDone(HloInstruction context)
在 Send
指令建立的內容中,等待資料傳輸完成。這項指令不會傳回任何資料。
管道指示的排程
每個管道的 4 個指令的執行順序如下:Recv
、RecvDone
、Send
、SendDone
。
Recv
會在Send
之前發生Send
會在RecvDone
之前發生Recv
會在RecvDone
之前發生Send
會在SendDone
之前發生
當後端編譯器為每項透過管道指令進行通訊的運算產生線性排程時,運算之間不得有循環。舉例來說,下列排程會導致死結。
請注意,指令的限制僅適用於執行階段的 TPU。在 GPU 上,send
和 recv
會在來源和目標裝置之間完成握手後,才會封鎖並停止傳送任何實際資料。
配量
另請參閱 XlaBuilder::Slice
。
切片會從輸入陣列中擷取子陣列。子陣列的維度數量與輸入值相同,且包含輸入陣列內邊界框內的值,其中邊界框的維度和索引會做為切片運算的引數提供。
Slice(operand, start_indices, limit_indices, strides)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
類型為 T 的 N 維陣列 |
start_indices |
ArraySlice<int64> |
包含 N 個整數的清單,其中包含每個維度的切片起始索引。值必須大於或等於零。 |
limit_indices |
ArraySlice<int64> |
包含 N 個整數的清單,其中包含每個維度切片的結束索引 (不含)。每個值都必須大於或等於維度的相應 start_indices 值,且小於或等於維度的大小。 |
strides |
ArraySlice<int64> |
決定切片輸入步幅的 N 個整數清單。切片會選取維度 d 中的每個 strides[d] 元素。 |
1 維範例:
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
{2.0, 3.0}
2D 範例:
let b =
{ {0.0, 1.0, 2.0},
{3.0, 4.0, 5.0},
{6.0, 7.0, 8.0},
{9.0, 10.0, 11.0} }
Slice(b, {2, 1}, {4, 3}) produces:
{ { 7.0, 8.0},
{10.0, 11.0} }
排序
另請參閱 XlaBuilder::Sort
。
Sort(operands, comparator, dimension, is_stable)
引數 | 類型 | 語意 |
---|---|---|
operands |
ArraySlice<XlaOp> |
要排序的運算元。 |
comparator |
XlaComputation |
要使用的比較器運算。 |
dimension |
int64 |
排序的維度。 |
is_stable |
bool |
是否應使用穩定排序。 |
如果只提供一個運算元:
如果運算元是 1 維張量 (陣列),結果會是排序陣列。如果您想將陣列排序為遞增順序,比較器應執行小於比較。正式來說,陣列排序後,會保留所有索引位置
i, j
,其中i < j
為comparator(value[i], value[j]) = comparator(value[j], value[i]) = false
或comparator(value[i], value[j]) = true
。如果運算元項的維度數較多,系統會依照提供的維度對運算元項進行排序。舉例來說,如果是 2 維張量 (矩陣),
0
維度值會分別排序每個資料欄,而1
維度值則會分別排序每個資料列。如果未提供維度編號,系統會預設選擇最後一個維度。對於要排序的維度,會套用與 1 維度相同的排序順序。
如果提供 n > 1
運算元:
所有
n
運算元都必須是相同維度的張量。張量的元素類型可能不同。所有運算元皆會一起排序,而非個別排序。從概念上來說,運算元式會視為元組。在檢查索引位置
i
和j
的每個運算元的元素是否需要交換時,會使用2 * n
標量參數呼叫比較器,其中參數2 * k
對應k-th
運算元i
位置的值,而參數2 * k + 1
對應k-th
運算元j
位置的值。因此,比較器通常會比較2 * k
和2 * k + 1
參數,並可能使用其他參數組合做為平手時的判斷依據。結果是元組,其中包含以排序順序排列的運算子 (如上所述,沿著提供的維度)。元組的
i-th
運算元會對應至排序的i-th
運算元。
舉例來說,如果有三個運算元 operand0 = [3, 1]
、operand1 = [42, 50]
、operand2 = [-3.0, 1.1]
,且比較器只比較 operand0
的值,並以小於運算,則排序的輸出內容就是元組 ([1, 3], [50, 42], [1.1, -3.0])
。
如果 is_stable
設為 true,系統會保證排序穩定,也就是說,如果比較器認為某些元素相等,則會保留相等值的相對順序。只有在 comparator(e1, e2) = comparator(e2, e1) = false
的情況下,兩個元素 e1
和 e2
才會相等。預設情況下,is_stable
會設為 false。
TopK
另請參閱 XlaBuilder::TopK
。
TopK
會找出指定張量最後一維度的 k
最大或最小元素的值和索引。
TopK(operand, k, largest)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
要從中擷取前 k 個元素的張量。張量必須具有大於或等於一個維度的值。張量的最後一個維度大小必須大於或等於 k 。 |
k |
int64 |
要擷取的元素數量。 |
largest |
bool |
是否要擷取最大的或最小的 k 元素。 |
針對 1 維輸入張量 (陣列),找出陣列中 k
最大的或最小的項目,並輸出兩個陣列 (values, indices)
的元組。因此,values[j]
是 operand
中第 j
大的/小的項目,其索引為 indices[j]
。
對於具有超過 1 個維度的輸入張量,沿著最後一個維度計算前 k
個項目,並在輸出內容中保留所有其他維度 (列)。因此,對於形狀為 [A, B, ..., P, Q]
的運算元,Q >= k
的輸出值是元組 (values, indices)
,其中:
values.shape = indices.shape = [A, B, ..., P, k]
如果一列中的兩個元素相等,則會先顯示索引較低的元素。
轉置
另請參閱 tf.reshape
運算。
Transpose(operand)
引數 | 類型 | 語意 |
---|---|---|
operand |
XlaOp |
要轉置的運算元。 |
permutation |
ArraySlice<int64> |
如何排列維度。 |
使用指定的排列法對運算元維度進行排列,因此為 ∀ i . 0 ≤ i < number of dimensions ⇒
input_dimensions[permutation[i]] = output_dimensions[i]
。
這與 Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions)) 相同。
TriangularSolve
另請參閱 XlaBuilder::TriangularSolve
。
透過前向或後向代入法,解決具有上三角或下三角係數矩陣的線性方程式組。這個例行程式會沿著前導維度進行廣播,在 a
和 b
的情況下,為變數 x
解決其中一個矩陣系統 op(a) * x =
b
或 x * op(a) = b
,其中 op(a)
為 op(a) = a
、op(a) = Transpose(a)
或 op(a) = Conj(Transpose(a))
。
TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)
引數 | 類型 | 語意 |
---|---|---|
a |
XlaOp |
a > 2 維陣列,其型別為複數或浮點,形狀為 [..., M, M] 。 |
b |
XlaOp |
a > 2 維同型陣列,如果 left_side 為 true,則形狀為 [..., M, K] ,否則為 [..., K, M] 。 |
left_side |
bool |
表示要解決 op(a) * x = b (true ) 或 x * op(a) = b (false ) 形式的系統。 |
lower |
bool |
是否使用 a 的上三角或下三角。 |
unit_diagonal |
bool |
如果 true ,則 a 的對角元素會假設為 1 ,且不會存取。 |
transpose_a |
Transpose |
是否要使用 a 原樣、轉置或取其共軛轉置。 |
輸入資料只會從 a
的下/上三角讀取,具體取決於 lower
的值。系統會忽略其他三角形的值。輸出資料會在同一三角形中傳回;其他三角形中的值則由實作定義,可以是任何值。
如果 a
和 b
的維度數量大於 2,系統會將其視為矩陣批次,其中除了次要 2 維度以外,所有都是批次維度。a
和 b
必須具有相同的批次維度。
元組
另請參閱 XlaBuilder::Tuple
。
一個元組,其中包含可變數量的資料句柄,每個句柄都有各自的形狀。
這類似於 C++ 中的 std::tuple
。概念上:
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
您可以透過 GetTupleElement
運算來解構 (存取) 元組。
While
另請參閱 XlaBuilder::While
。
While(condition, body, init)
引數 | 類型 | 語意 |
---|---|---|
condition |
XlaComputation |
定義迴圈結束條件的 T -> PRED 類型 XlaComputation。 |
body |
XlaComputation |
定義迴圈主體的 T -> T 類型 XlaComputation。 |
init |
T |
condition 和 body 參數的初始值。 |
依序執行 body
,直到 condition
失敗為止。這與許多其他語言中的一般 while 迴圈類似,但請注意下列差異和限制。
While
節點會傳回T
類型的值,這是body
上次執行的結果。- 系統會以靜態方式決定
T
類型的形狀,且所有迭代作業的形狀都必須相同。
計算作業的 T 參數會在第一個迴迭中使用 init
值初始化,並在後續每個迴迭中自動更新為 body
的新結果。
While
節點的主要用途之一,是實作在神經網路中重複執行訓練的功能。下方列出簡化的虛擬程式碼,以及代表運算的圖表。您可以在 while_test.cc
中找到這段程式碼。此範例中的 T
類型為 Tuple
,其中包含用於疊代計數的 int32
,以及用於累加器的 vector[10]
。在 1000 次迭代中,迴圈會持續將常數向量新增至累加器。
// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
iteration = result(0) + 1;
new_vector = result(1) + constant_vector[10];
result = {iteration, new_vector};
}