class ConformerEncoder(Module):
  # NOTE(review): this looks like machine-serialized TorchScript code (mangled
  # type names, numbered temporaries, weights referenced as CONSTANTS.cN).
  # Presumably it must stay in sync with the accompanying constants table, so
  # treat the statements below as read-only unless the module is re-exported.
  __parameters__ = []
  __buffers__ = []
  # Sub-module holding the positional-encoding object (`embed.pos_enc`) used below.
  embed : __torch__.cosyvoice.transformer.subsampling.___torch_mangle_54.LinearNoSubsampling
  # forward(xs, xs_lens, decoding_chunk_size=0, num_decoding_left_chunks=-1)
  #   xs:      input tensor; dim 1 is used as the sequence length and the last
  #            dim is multiplied by CONSTANTS.c0 (assumes (batch, time, feat) — TODO confirm).
  #   xs_lens: per-sequence lengths; dim 0 is used as the batch size.
  #   decoding_chunk_size / num_decoding_left_chunks: accepted but never read
  #            anywhere in this body.
  # Returns a tuple (xs0, masks): the final layer-normalised activations and the
  # boolean mask built from xs_lens (True where position index < length).
  # Side effect: may overwrite `self.embed.pos_enc.pe` (dtype/device cast or rebuild).
  def forward(self: __torch__.cosyvoice.transformer.encoder.___torch_mangle_53.ConformerEncoder,
    xs: Tensor,
    xs_lens: Tensor,
    decoding_chunk_size: int=0,
    num_decoding_left_chunks: int=-1) -> Tuple[Tensor, Tensor]:
    # ---- Padding mask -------------------------------------------------------
    T = torch.size(xs, 1)
    batch_size = torch.size(xs_lens, 0)
    # Use the time dimension of xs when it is non-empty, else the largest length.
    if torch.gt(T, 0):
      max_len = T
    else:
      max_len = torch.item(torch.max(xs_lens))
    # dtype=4 is a serialized scalar-type code (presumably int64 — confirm).
    seq_range = torch.arange(0, max_len, dtype=4, layout=None, device=ops.prim.device(xs_lens))
    seq_range_expand = torch.expand(torch.unsqueeze(seq_range, 0), [batch_size, int(max_len)])
    seq_length_expand = torch.unsqueeze(xs_lens, -1)
    # mask is True where the position index is >= the sequence length.
    mask = torch.ge(seq_range_expand, seq_length_expand)
    # Invert and insert a dim at index 1: True now marks positions < length.
    masks = torch.bitwise_not(torch.unsqueeze(mask, 1))
    # ---- Input projection ---------------------------------------------------
    embed = self.embed
    # Affine map (c0 weight, c1 bias) followed by layer norm over a last dim of 512.
    # NOTE(review): presumably the inlined body of `embed`'s linear + norm — confirm.
    _0 = torch.add(torch.matmul(xs, CONSTANTS.c0), CONSTANTS.c1)
    input = torch.layer_norm(_0, [512], CONSTANTS.c2, CONSTANTS.c3)
    # ---- Positional-encoding table refresh ----------------------------------
    pos_enc = embed.pos_enc
    pe = pos_enc.pe
    _1 = torch.size(pe, 1)
    _2 = torch.size(input, 1)
    # Is the stored table at least 2*_2 - 1 long along dim 1?
    _3 = torch.ge(_1, torch.sub(torch.mul(_2, 2), 1))
    if _3:
      # Table is long enough: only cast/move it if dtype or device differ from `input`.
      pe0 = pos_enc.pe
      _4 = ops.prim.dtype(pe0)
      _5 = ops.prim.dtype(input)
      if torch.ne(_4, _5):
        _6 = True
      else:
        pe1 = pos_enc.pe
        _7 = torch.ne(ops.prim.device(pe1), ops.prim.device(input))
        _6 = _7
      if _6:
        pe2 = pos_enc.pe
        _8 = torch.to(pe2, ops.prim.device(input), _5)
        pos_enc.pe = _8
      else:
        pass
    else:
      # Table too short: rebuild it for the current length _2.
      # torch.slice argument order is presumably (tensor, dim, start, end, step) — confirm.
      _9 = [_2, 512]
      pe_positive = torch.zeros(_9)
      pe_negative = torch.zeros(_9)
      # Column vector of positions 0.._2-1; dtype=6 is presumably float32 — confirm.
      position = torch.unsqueeze(torch.arange(0, _2, dtype=6), 1)
      # c4 is presumably the per-channel frequency term — confirm.
      _10 = torch.mul(position, CONSTANTS.c4)
      _11 = torch.sin(_10)
      # sin goes into columns 0, 2, 4, ... of pe_positive (in-place copy_).
      _12 = torch.slice(torch.slice(pe_positive), 1, 0, None, 2)
      _13 = torch.copy_(_12, _11)
      _14 = torch.cos(_10)
      # cos goes into columns 1, 3, 5, ... of pe_positive.
      _15 = torch.slice(torch.slice(pe_positive), 1, 1, None, 2)
      _16 = torch.copy_(_15, _14)
      # Same construction with negated positions for pe_negative.
      _17 = torch.mul(torch.mul(position, -1), CONSTANTS.c4)
      _18 = torch.sin(_17)
      _19 = torch.slice(torch.slice(pe_negative), 1, 0, None, 2)
      _20 = torch.copy_(_19, _18)
      _21 = torch.cos(_17)
      _22 = torch.slice(torch.slice(pe_negative), 1, 1, None, 2)
      _23 = torch.copy_(_22, _21)
      # Reverse the positive half along dim 0; drop row 0 of the negative half;
      # concatenate along dim 1 after adding a leading dim of size 1.
      pe_positive0 = torch.unsqueeze(torch.flip(pe_positive, [0]), 0)
      pe_negative0 = torch.unsqueeze(torch.slice(pe_negative, 0, 1), 0)
      pe3 = torch.cat([pe_positive0, pe_negative0], 1)
      _24 = torch.to(pe3, ops.prim.device(input), ops.prim.dtype(input))
      pos_enc.pe = _24
    # Scale the projected input; 22.627416997969522 equals sqrt(512).
    x = torch.mul(input, 22.627416997969522)
    _25 = torch.size(x, 1)
    # Take a window of the table along dim 1, from (size//2 - _25 + 1) to
    # (size//2 + _25), i.e. 2*_25 - 1 entries around the midpoint.
    pe4 = pos_enc.pe
    _26 = torch.slice(pe4)
    pe5 = pos_enc.pe
    _27 = torch.floordiv(torch.size(pe5, 1), 2)
    _28 = torch.add(torch.sub(_27, _25), 1)
    pe6 = pos_enc.pe
    _29 = torch.floordiv(torch.size(pe6, 1), 2)
    pos_emb = torch.slice(_26, 1, _28, torch.add(_29, _25))
    # ==== Block 1 (constants c5..c19) ========================================
    # Layer norm, then one fused projection (c7 weight, c8 bias) whose last dim
    # is cut into three 512-wide pieces: [0:512] -> q, [512:1024] -> k, [1024:1536] -> v.
    x0 = torch.layer_norm(x, [512], CONSTANTS.c5, CONSTANTS.c6)
    n_batch = torch.size(x0, 0)
    _30 = torch.add(torch.matmul(x0, CONSTANTS.c7), CONSTANTS.c8)
    _31 = torch.slice(_30, -1, 1024, 1536)
    _32 = torch.slice(_30, -1, 512, 1024)
    _33 = torch.slice(_30, -1, 0, 512)
    # Split the 512 channels into 8 groups of 64, then swap dims 1 and 2.
    _34 = [n_batch, -1, 8, 64]
    q = torch.view(_33, _34)
    k = torch.view(_32, _34)
    v = torch.view(_31, _34)
    q0 = torch.transpose(q, 1, 2)
    k0 = torch.transpose(k, 1, 2)
    v0 = torch.transpose(v, 1, 2)
    # q is transposed back so c10 / c11 can be added, then transposed again below.
    q1 = torch.transpose(q0, 1, 2)
    # Project the positional window (c9, no bias added) and reshape like q/k/v.
    n_batch_pos = torch.size(pos_emb, 0)
    _35 = torch.matmul(pos_emb, CONSTANTS.c9)
    # _36 is reused by every later block when reshaping its positional projection.
    _36 = [n_batch_pos, -1, 8, 64]
    p = torch.view(_35, _36)
    p0 = torch.transpose(p, 1, 2)
    # Two biased copies of q: one scored against k, one against p.
    q_with_bias_u = torch.transpose(torch.add(q1, CONSTANTS.c10), 1, 2)
    q_with_bias_v = torch.transpose(torch.add(q1, CONSTANTS.c11), 1, 2)
    matrix_ac = torch.matmul(q_with_bias_u, torch.transpose(k0, -2, -1))
    matrix_bd = torch.matmul(q_with_bias_v, torch.transpose(p0, -2, -1))
    _37 = torch.size(matrix_ac)
    _38 = torch.size(matrix_bd)
    # When the two score tensors differ in shape, shift matrix_bd:
    # prepend a zero column on the last dim, view with the last two dims'
    # sizes swapped, drop index 0 of dim 2, view back to matrix_bd's shape,
    # and keep the first (size(-1) // 2 + 1) entries of the last dim.
    if torch.ne(_37, _38):
      _39 = _38[0]
      _40 = _38[1]
      _41 = _38[2]
      _42 = ops.prim.device(matrix_bd)
      _43 = ops.prim.dtype(matrix_bd)
      zero_pad = torch.zeros([_39, _40, _41, 1], dtype=_43, layout=None, device=_42)
      x_padded = torch.cat([zero_pad, matrix_bd], -1)
      _44 = torch.add(torch.size(matrix_bd, 3), 1)
      _45 = [_39, _40, _44, torch.size(matrix_bd, 2)]
      x_padded0 = torch.view(x_padded, _45)
      _46 = torch.slice(torch.slice(x_padded0), 1)
      _47 = torch.view_as(torch.slice(_46, 2, 1), matrix_bd)
      _48 = torch.slice(torch.slice(torch.slice(_47), 1), 2)
      _49 = torch.floordiv(torch.size(matrix_bd, -1), 2)
      matrix_bd1 = torch.slice(_48, 3, None, torch.add(_49, 1))
      matrix_bd0 = matrix_bd1
    else:
      matrix_bd0 = matrix_bd
    # Sum both score terms and divide by 8 (8 == sqrt(64), the per-group width).
    scores = torch.div(torch.add(matrix_ac, matrix_bd0), 8.)
    n_batch0 = torch.size(v0, 0)
    if torch.gt(torch.size(masks, 2), 0):
      # mask0 is True where `masks` is False; trimmed on dim 3 to scores' last-dim size.
      mask0 = torch.eq(torch.unsqueeze(masks, 1), 0)
      _50 = torch.slice(torch.slice(mask0), 1)
      mask1 = torch.slice(torch.slice(_50, 2), 3, None, torch.size(scores, -1))
      # Fill masked scores with -inf before softmax, then zero those weights after it.
      scores0 = torch.masked_fill(scores, mask1, -inf)
      attn0 = torch.masked_fill(torch.softmax(scores0, -1), mask1, 0.)
      attn = attn0
    else:
      attn = torch.softmax(scores, -1)
    # Weighted sum of v, merge the 8x64 groups back to 512, output projection (c12, c13).
    x1 = torch.matmul(attn, v0)
    _51 = torch.contiguous(torch.transpose(x1, 1, 2))
    x2 = torch.view(_51, [n_batch0, -1, 512])
    _52 = torch.add(torch.matmul(x2, CONSTANTS.c12), CONSTANTS.c13)
    # Residual connection onto x (the tensor before the layer norm above).
    x3 = torch.add(x, _52)
    # Feed-forward part: layer norm -> affine (c16, c17) -> SiLU -> affine (c18, c19),
    # multiplied by 1. and added back onto x3.
    x4 = torch.layer_norm(x3, [512], CONSTANTS.c14, CONSTANTS.c15)
    _53 = torch.add(torch.matmul(x4, CONSTANTS.c16), CONSTANTS.c17)
    _54 = torch.matmul(torch.silu(_53), CONSTANTS.c18)
    _55 = torch.mul(torch.add(_54, CONSTANTS.c19), 1.)
    x5 = torch.add(x3, _55)
    # ==== Block 2 (constants c20..c34) =======================================
    # Same operation sequence as block 1, applied to x5.
    x6 = torch.layer_norm(x5, [512], CONSTANTS.c20, CONSTANTS.c21)
    n_batch1 = torch.size(x6, 0)
    _56 = torch.add(torch.matmul(x6, CONSTANTS.c22), CONSTANTS.c23)
    _57 = torch.slice(_56, -1, 1024, 1536)
    _58 = torch.slice(_56, -1, 512, 1024)
    _59 = torch.slice(_56, -1, 0, 512)
    _60 = [n_batch1, -1, 8, 64]
    q2 = torch.view(_59, _60)
    k1 = torch.view(_58, _60)
    v1 = torch.view(_57, _60)
    q3 = torch.transpose(q2, 1, 2)
    k2 = torch.transpose(k1, 1, 2)
    v2 = torch.transpose(v1, 1, 2)
    q4 = torch.transpose(q3, 1, 2)
    # Positional projection for this block; reshaped with _36 from block 1.
    _61 = torch.matmul(pos_emb, CONSTANTS.c24)
    p1 = torch.view(_61, _36)
    p2 = torch.transpose(p1, 1, 2)
    q_with_bias_u0 = torch.transpose(torch.add(q4, CONSTANTS.c25), 1, 2)
    q_with_bias_v0 = torch.transpose(torch.add(q4, CONSTANTS.c26), 1, 2)
    matrix_ac0 = torch.matmul(q_with_bias_u0, torch.transpose(k2, -2, -1))
    matrix_bd2 = torch.matmul(q_with_bias_v0, torch.transpose(p2, -2, -1))
    _62 = torch.size(matrix_ac0)
    _63 = torch.size(matrix_bd2)
    # Shift of matrix_bd2 when shapes differ (same steps as in block 1).
    if torch.ne(_62, _63):
      _64 = _63[0]
      _65 = _63[1]
      _66 = _63[2]
      _67 = ops.prim.device(matrix_bd2)
      _68 = ops.prim.dtype(matrix_bd2)
      zero_pad0 = torch.zeros([_64, _65, _66, 1], dtype=_68, layout=None, device=_67)
      x_padded1 = torch.cat([zero_pad0, matrix_bd2], -1)
      _69 = torch.add(torch.size(matrix_bd2, 3), 1)
      _70 = [_64, _65, _69, torch.size(matrix_bd2, 2)]
      x_padded2 = torch.view(x_padded1, _70)
      _71 = torch.slice(torch.slice(x_padded2), 1)
      _72 = torch.view_as(torch.slice(_71, 2, 1), matrix_bd2)
      _73 = torch.slice(torch.slice(torch.slice(_72), 1), 2)
      _74 = torch.floordiv(torch.size(matrix_bd2, -1), 2)
      matrix_bd4 = torch.slice(_73, 3, None, torch.add(_74, 1))
      matrix_bd3 = matrix_bd4
    else:
      matrix_bd3 = matrix_bd2
    scores1 = torch.div(torch.add(matrix_ac0, matrix_bd3), 8.)
    n_batch2 = torch.size(v2, 0)
    # Masked softmax (same steps as in block 1).
    if torch.gt(torch.size(masks, 2), 0):
      mask2 = torch.eq(torch.unsqueeze(masks, 1), 0)
      _75 = torch.slice(torch.slice(mask2), 1)
      mask3 = torch.slice(torch.slice(_75, 2), 3, None, torch.size(scores1, -1))
      scores2 = torch.masked_fill(scores1, mask3, -inf)
      attn2 = torch.masked_fill(torch.softmax(scores2, -1), mask3, 0.)
      attn1 = attn2
    else:
      attn1 = torch.softmax(scores1, -1)
    x7 = torch.matmul(attn1, v2)
    _76 = torch.contiguous(torch.transpose(x7, 1, 2))
    x8 = torch.view(_76, [n_batch2, -1, 512])
    _77 = torch.add(torch.matmul(x8, CONSTANTS.c27), CONSTANTS.c28)
    x9 = torch.add(x5, _77)
    # Feed-forward part and residual.
    x10 = torch.layer_norm(x9, [512], CONSTANTS.c29, CONSTANTS.c30)
    _78 = torch.add(torch.matmul(x10, CONSTANTS.c31), CONSTANTS.c32)
    _79 = torch.matmul(torch.silu(_78), CONSTANTS.c33)
    _80 = torch.mul(torch.add(_79, CONSTANTS.c34), 1.)
    x11 = torch.add(x9, _80)
    # ==== Block 3 (constants c35..c49) =======================================
    # Same operation sequence as block 1, applied to x11.
    x12 = torch.layer_norm(x11, [512], CONSTANTS.c35, CONSTANTS.c36)
    n_batch3 = torch.size(x12, 0)
    _81 = torch.add(torch.matmul(x12, CONSTANTS.c37), CONSTANTS.c38)
    _82 = torch.slice(_81, -1, 1024, 1536)
    _83 = torch.slice(_81, -1, 512, 1024)
    _84 = torch.slice(_81, -1, 0, 512)
    _85 = [n_batch3, -1, 8, 64]
    q5 = torch.view(_84, _85)
    k3 = torch.view(_83, _85)
    v3 = torch.view(_82, _85)
    q6 = torch.transpose(q5, 1, 2)
    k4 = torch.transpose(k3, 1, 2)
    v4 = torch.transpose(v3, 1, 2)
    q7 = torch.transpose(q6, 1, 2)
    # Positional projection for this block; reshaped with _36 from block 1.
    _86 = torch.matmul(pos_emb, CONSTANTS.c39)
    p3 = torch.view(_86, _36)
    p4 = torch.transpose(p3, 1, 2)
    q_with_bias_u1 = torch.transpose(torch.add(q7, CONSTANTS.c40), 1, 2)
    q_with_bias_v1 = torch.transpose(torch.add(q7, CONSTANTS.c41), 1, 2)
    matrix_ac1 = torch.matmul(q_with_bias_u1, torch.transpose(k4, -2, -1))
    matrix_bd5 = torch.matmul(q_with_bias_v1, torch.transpose(p4, -2, -1))
    _87 = torch.size(matrix_ac1)
    _88 = torch.size(matrix_bd5)
    # Shift of matrix_bd5 when shapes differ (same steps as in block 1).
    if torch.ne(_87, _88):
      _89 = _88[0]
      _90 = _88[1]
      _91 = _88[2]
      _92 = ops.prim.device(matrix_bd5)
      _93 = ops.prim.dtype(matrix_bd5)
      zero_pad1 = torch.zeros([_89, _90, _91, 1], dtype=_93, layout=None, device=_92)
      x_padded3 = torch.cat([zero_pad1, matrix_bd5], -1)
      _94 = torch.add(torch.size(matrix_bd5, 3), 1)
      _95 = [_89, _90, _94, torch.size(matrix_bd5, 2)]
      x_padded4 = torch.view(x_padded3, _95)
      _96 = torch.slice(torch.slice(x_padded4), 1)
      _97 = torch.view_as(torch.slice(_96, 2, 1), matrix_bd5)
      _98 = torch.slice(torch.slice(torch.slice(_97), 1), 2)
      _99 = torch.floordiv(torch.size(matrix_bd5, -1), 2)
      matrix_bd7 = torch.slice(_98, 3, None, torch.add(_99, 1))
      matrix_bd6 = matrix_bd7
    else:
      matrix_bd6 = matrix_bd5
    scores3 = torch.div(torch.add(matrix_ac1, matrix_bd6), 8.)
    n_batch4 = torch.size(v4, 0)
    # Masked softmax (same steps as in block 1).
    if torch.gt(torch.size(masks, 2), 0):
      mask4 = torch.eq(torch.unsqueeze(masks, 1), 0)
      _100 = torch.slice(torch.slice(mask4), 1)
      mask5 = torch.slice(torch.slice(_100, 2), 3, None, torch.size(scores3, -1))
      scores4 = torch.masked_fill(scores3, mask5, -inf)
      attn4 = torch.masked_fill(torch.softmax(scores4, -1), mask5, 0.)
      attn3 = attn4
    else:
      attn3 = torch.softmax(scores3, -1)
    x13 = torch.matmul(attn3, v4)
    _101 = torch.contiguous(torch.transpose(x13, 1, 2))
    x14 = torch.view(_101, [n_batch4, -1, 512])
    _102 = torch.add(torch.matmul(x14, CONSTANTS.c42), CONSTANTS.c43)
    x15 = torch.add(x11, _102)
    # Feed-forward part and residual.
    x16 = torch.layer_norm(x15, [512], CONSTANTS.c44, CONSTANTS.c45)
    _103 = torch.add(torch.matmul(x16, CONSTANTS.c46), CONSTANTS.c47)
    _104 = torch.matmul(torch.silu(_103), CONSTANTS.c48)
    _105 = torch.mul(torch.add(_104, CONSTANTS.c49), 1.)
    x17 = torch.add(x15, _105)
    # ==== Block 4 (constants c50..c64) =======================================
    # Same operation sequence as block 1, applied to x17.
    x18 = torch.layer_norm(x17, [512], CONSTANTS.c50, CONSTANTS.c51)
    n_batch5 = torch.size(x18, 0)
    _106 = torch.add(torch.matmul(x18, CONSTANTS.c52), CONSTANTS.c53)
    _107 = torch.slice(_106, -1, 1024, 1536)
    _108 = torch.slice(_106, -1, 512, 1024)
    _109 = torch.slice(_106, -1, 0, 512)
    _110 = [n_batch5, -1, 8, 64]
    q8 = torch.view(_109, _110)
    k5 = torch.view(_108, _110)
    v5 = torch.view(_107, _110)
    q9 = torch.transpose(q8, 1, 2)
    k6 = torch.transpose(k5, 1, 2)
    v6 = torch.transpose(v5, 1, 2)
    q10 = torch.transpose(q9, 1, 2)
    # Positional projection for this block; reshaped with _36 from block 1.
    _111 = torch.matmul(pos_emb, CONSTANTS.c54)
    p5 = torch.view(_111, _36)
    p6 = torch.transpose(p5, 1, 2)
    q_with_bias_u2 = torch.transpose(torch.add(q10, CONSTANTS.c55), 1, 2)
    q_with_bias_v2 = torch.transpose(torch.add(q10, CONSTANTS.c56), 1, 2)
    matrix_ac2 = torch.matmul(q_with_bias_u2, torch.transpose(k6, -2, -1))
    matrix_bd8 = torch.matmul(q_with_bias_v2, torch.transpose(p6, -2, -1))
    _112 = torch.size(matrix_ac2)
    _113 = torch.size(matrix_bd8)
    # Shift of matrix_bd8 when shapes differ (same steps as in block 1).
    if torch.ne(_112, _113):
      _114 = _113[0]
      _115 = _113[1]
      _116 = _113[2]
      _117 = ops.prim.device(matrix_bd8)
      _118 = ops.prim.dtype(matrix_bd8)
      zero_pad2 = torch.zeros([_114, _115, _116, 1], dtype=_118, layout=None, device=_117)
      x_padded5 = torch.cat([zero_pad2, matrix_bd8], -1)
      _119 = torch.add(torch.size(matrix_bd8, 3), 1)
      _120 = [_114, _115, _119, torch.size(matrix_bd8, 2)]
      x_padded6 = torch.view(x_padded5, _120)
      _121 = torch.slice(torch.slice(x_padded6), 1)
      _122 = torch.view_as(torch.slice(_121, 2, 1), matrix_bd8)
      _123 = torch.slice(torch.slice(torch.slice(_122), 1), 2)
      _124 = torch.floordiv(torch.size(matrix_bd8, -1), 2)
      matrix_bd10 = torch.slice(_123, 3, None, torch.add(_124, 1))
      matrix_bd9 = matrix_bd10
    else:
      matrix_bd9 = matrix_bd8
    scores5 = torch.div(torch.add(matrix_ac2, matrix_bd9), 8.)
    n_batch6 = torch.size(v6, 0)
    # Masked softmax (same steps as in block 1).
    if torch.gt(torch.size(masks, 2), 0):
      mask6 = torch.eq(torch.unsqueeze(masks, 1), 0)
      _125 = torch.slice(torch.slice(mask6), 1)
      mask7 = torch.slice(torch.slice(_125, 2), 3, None, torch.size(scores5, -1))
      scores6 = torch.masked_fill(scores5, mask7, -inf)
      attn6 = torch.masked_fill(torch.softmax(scores6, -1), mask7, 0.)
      attn5 = attn6
    else:
      attn5 = torch.softmax(scores5, -1)
    x19 = torch.matmul(attn5, v6)
    _126 = torch.contiguous(torch.transpose(x19, 1, 2))
    x20 = torch.view(_126, [n_batch6, -1, 512])
    _127 = torch.add(torch.matmul(x20, CONSTANTS.c57), CONSTANTS.c58)
    x21 = torch.add(x17, _127)
    # Feed-forward part and residual.
    x22 = torch.layer_norm(x21, [512], CONSTANTS.c59, CONSTANTS.c60)
    _128 = torch.add(torch.matmul(x22, CONSTANTS.c61), CONSTANTS.c62)
    _129 = torch.matmul(torch.silu(_128), CONSTANTS.c63)
    _130 = torch.mul(torch.add(_129, CONSTANTS.c64), 1.)
    x23 = torch.add(x21, _130)
    # ==== Block 5 (constants c65..c79) =======================================
    # Same operation sequence as block 1, applied to x23.
    x24 = torch.layer_norm(x23, [512], CONSTANTS.c65, CONSTANTS.c66)
    n_batch7 = torch.size(x24, 0)
    _131 = torch.add(torch.matmul(x24, CONSTANTS.c67), CONSTANTS.c68)
    _132 = torch.slice(_131, -1, 1024, 1536)
    _133 = torch.slice(_131, -1, 512, 1024)
    _134 = torch.slice(_131, -1, 0, 512)
    _135 = [n_batch7, -1, 8, 64]
    q11 = torch.view(_134, _135)
    k7 = torch.view(_133, _135)
    v7 = torch.view(_132, _135)
    q12 = torch.transpose(q11, 1, 2)
    k8 = torch.transpose(k7, 1, 2)
    v8 = torch.transpose(v7, 1, 2)
    q13 = torch.transpose(q12, 1, 2)
    # Positional projection for this block; reshaped with _36 from block 1.
    _136 = torch.matmul(pos_emb, CONSTANTS.c69)
    p7 = torch.view(_136, _36)
    p8 = torch.transpose(p7, 1, 2)
    q_with_bias_u3 = torch.transpose(torch.add(q13, CONSTANTS.c70), 1, 2)
    q_with_bias_v3 = torch.transpose(torch.add(q13, CONSTANTS.c71), 1, 2)
    matrix_ac3 = torch.matmul(q_with_bias_u3, torch.transpose(k8, -2, -1))
    matrix_bd11 = torch.matmul(q_with_bias_v3, torch.transpose(p8, -2, -1))
    _137 = torch.size(matrix_ac3)
    _138 = torch.size(matrix_bd11)
    # Shift of matrix_bd11 when shapes differ (same steps as in block 1).
    if torch.ne(_137, _138):
      _139 = _138[0]
      _140 = _138[1]
      _141 = _138[2]
      _142 = ops.prim.device(matrix_bd11)
      _143 = ops.prim.dtype(matrix_bd11)
      zero_pad3 = torch.zeros([_139, _140, _141, 1], dtype=_143, layout=None, device=_142)
      x_padded7 = torch.cat([zero_pad3, matrix_bd11], -1)
      _144 = torch.add(torch.size(matrix_bd11, 3), 1)
      _145 = [_139, _140, _144, torch.size(matrix_bd11, 2)]
      x_padded8 = torch.view(x_padded7, _145)
      _146 = torch.slice(torch.slice(x_padded8), 1)
      _147 = torch.view_as(torch.slice(_146, 2, 1), matrix_bd11)
      _148 = torch.slice(torch.slice(torch.slice(_147), 1), 2)
      _149 = torch.floordiv(torch.size(matrix_bd11, -1), 2)
      matrix_bd13 = torch.slice(_148, 3, None, torch.add(_149, 1))
      matrix_bd12 = matrix_bd13
    else:
      matrix_bd12 = matrix_bd11
    scores7 = torch.div(torch.add(matrix_ac3, matrix_bd12), 8.)
    n_batch8 = torch.size(v8, 0)
    # Masked softmax (same steps as in block 1).
    if torch.gt(torch.size(masks, 2), 0):
      mask8 = torch.eq(torch.unsqueeze(masks, 1), 0)
      _150 = torch.slice(torch.slice(mask8), 1)
      mask9 = torch.slice(torch.slice(_150, 2), 3, None, torch.size(scores7, -1))
      scores8 = torch.masked_fill(scores7, mask9, -inf)
      attn8 = torch.masked_fill(torch.softmax(scores8, -1), mask9, 0.)
      attn7 = attn8
    else:
      attn7 = torch.softmax(scores7, -1)
    x25 = torch.matmul(attn7, v8)
    _151 = torch.contiguous(torch.transpose(x25, 1, 2))
    x26 = torch.view(_151, [n_batch8, -1, 512])
    _152 = torch.add(torch.matmul(x26, CONSTANTS.c72), CONSTANTS.c73)
    x27 = torch.add(x23, _152)
    # Feed-forward part and residual.
    x28 = torch.layer_norm(x27, [512], CONSTANTS.c74, CONSTANTS.c75)
    _153 = torch.add(torch.matmul(x28, CONSTANTS.c76), CONSTANTS.c77)
    _154 = torch.matmul(torch.silu(_153), CONSTANTS.c78)
    _155 = torch.mul(torch.add(_154, CONSTANTS.c79), 1.)
    x29 = torch.add(x27, _155)
    # ==== Block 6 (constants c80..c94) =======================================
    # Same operation sequence as block 1, applied to x29.
    x30 = torch.layer_norm(x29, [512], CONSTANTS.c80, CONSTANTS.c81)
    n_batch9 = torch.size(x30, 0)
    _156 = torch.add(torch.matmul(x30, CONSTANTS.c82), CONSTANTS.c83)
    _157 = torch.slice(_156, -1, 1024, 1536)
    _158 = torch.slice(_156, -1, 512, 1024)
    _159 = torch.slice(_156, -1, 0, 512)
    _160 = [n_batch9, -1, 8, 64]
    q14 = torch.view(_159, _160)
    k9 = torch.view(_158, _160)
    v9 = torch.view(_157, _160)
    q15 = torch.transpose(q14, 1, 2)
    k10 = torch.transpose(k9, 1, 2)
    v10 = torch.transpose(v9, 1, 2)
    q16 = torch.transpose(q15, 1, 2)
    # Positional projection for this block; reshaped with _36 from block 1.
    _161 = torch.matmul(pos_emb, CONSTANTS.c84)
    p9 = torch.view(_161, _36)
    p10 = torch.transpose(p9, 1, 2)
    q_with_bias_u4 = torch.transpose(torch.add(q16, CONSTANTS.c85), 1, 2)
    q_with_bias_v4 = torch.transpose(torch.add(q16, CONSTANTS.c86), 1, 2)
    matrix_ac4 = torch.matmul(q_with_bias_u4, torch.transpose(k10, -2, -1))
    matrix_bd14 = torch.matmul(q_with_bias_v4, torch.transpose(p10, -2, -1))
    _162 = torch.size(matrix_ac4)
    _163 = torch.size(matrix_bd14)
    # Shift of matrix_bd14 when shapes differ (same steps as in block 1).
    if torch.ne(_162, _163):
      _164 = _163[0]
      _165 = _163[1]
      _166 = _163[2]
      _167 = ops.prim.device(matrix_bd14)
      _168 = ops.prim.dtype(matrix_bd14)
      zero_pad4 = torch.zeros([_164, _165, _166, 1], dtype=_168, layout=None, device=_167)
      x_padded9 = torch.cat([zero_pad4, matrix_bd14], -1)
      _169 = torch.add(torch.size(matrix_bd14, 3), 1)
      _170 = [_164, _165, _169, torch.size(matrix_bd14, 2)]
      x_padded10 = torch.view(x_padded9, _170)
      _171 = torch.slice(torch.slice(x_padded10), 1)
      _172 = torch.view_as(torch.slice(_171, 2, 1), matrix_bd14)
      _173 = torch.slice(torch.slice(torch.slice(_172), 1), 2)
      _174 = torch.floordiv(torch.size(matrix_bd14, -1), 2)
      matrix_bd16 = torch.slice(_173, 3, None, torch.add(_174, 1))
      matrix_bd15 = matrix_bd16
    else:
      matrix_bd15 = matrix_bd14
    scores9 = torch.div(torch.add(matrix_ac4, matrix_bd15), 8.)
    n_batch10 = torch.size(v10, 0)
    # Masked softmax (same steps as in block 1).
    if torch.gt(torch.size(masks, 2), 0):
      mask10 = torch.eq(torch.unsqueeze(masks, 1), 0)
      _175 = torch.slice(torch.slice(mask10), 1)
      mask11 = torch.slice(torch.slice(_175, 2), 3, None, torch.size(scores9, -1))
      scores10 = torch.masked_fill(scores9, mask11, -inf)
      attn10 = torch.masked_fill(torch.softmax(scores10, -1), mask11, 0.)
      attn9 = attn10
    else:
      attn9 = torch.softmax(scores9, -1)
    x31 = torch.matmul(attn9, v10)
    _176 = torch.contiguous(torch.transpose(x31, 1, 2))
    x32 = torch.view(_176, [n_batch10, -1, 512])
    _177 = torch.add(torch.matmul(x32, CONSTANTS.c87), CONSTANTS.c88)
    x33 = torch.add(x29, _177)
    # Feed-forward part and residual.
    x34 = torch.layer_norm(x33, [512], CONSTANTS.c89, CONSTANTS.c90)
    _178 = torch.add(torch.matmul(x34, CONSTANTS.c91), CONSTANTS.c92)
    _179 = torch.matmul(torch.silu(_178), CONSTANTS.c93)
    _180 = torch.mul(torch.add(_179, CONSTANTS.c94), 1.)
    x35 = torch.add(x33, _180)
    # ---- Final layer norm (c95, c96); the mask is returned unchanged ---------
    xs0 = torch.layer_norm(x35, [512], CONSTANTS.c95, CONSTANTS.c96)
    return (xs0, masks)
