通过summarize_clusters函数构建每个聚类的protein['cluster_profile']和protein['cluster_deletion_mean']特征。目的是把extra_msa信息反映到msa中。  

集成函数数据处理流程: sample_msa ->make_masked_msa -> nearest_neighbor_clusters -> summarize_clusters-> ...  

主要函数 tf.math.unsorted_segment_sum:用于沿指定轴对数据进行分段求和。 tf.math.unsorted_segment_sum(data, segment_ids, num_segments, name=None)

data: 输入张量,包含待求和的数据。segment_ids: 用于指定每个元素属于哪个段的一维整数张量。num_segments: 整数,表示分段的总数。name: 可选参数,用于指定操作的名称。

import tensorflow as tf

import pickle

def shape_list(x):

  """Return list of dimensions of a tensor, statically where possible.

  Like `x.shape.as_list()` but with tensors instead of `None`s.

  Args:

    x: A tensor.

  Returns:

    A list with length equal to the rank of the tensor. The n-th element of the

    list is an integer when that dimension is statically known otherwise it is

    the n-th element of `tf.shape(x)`.

  """

  x = tf.convert_to_tensor(x)

  # If unknown rank, return dynamic shape

  if x.get_shape().dims is None:

    return tf.shape(x)

  static = x.get_shape().as_list()

  shape = tf.shape(x)

  ret = []

  for i in range(len(static)):

    dim = static[i]

    if dim is None:

      dim = shape[i]

    ret.append(dim)

  return ret

def data_transforms_curry1(f):

  """Supply all arguments but the first."""

  def fc(*args, **kwargs):

    return lambda x: f(x, *args, **kwargs)

  return fc

@data_transforms_curry1

def summarize_clusters(protein):

  """Produce profile and deletion_matrix_mean within each cluster."""

  num_seq = shape_list(protein['msa'])[0]

  def csum(x):

    return tf.math.unsorted_segment_sum(

        x, protein['extra_cluster_assignment'], num_seq)

  mask = protein['extra_msa_mask']

  mask_counts = 1e-6 + protein['msa_mask'] + csum(mask)  # Include center

  

  # 结果张量[num_seq, num_resi],第一行表示和msa中的0号序列是最近邻序列的extr_msa之和,以此类推

  msa_sum = csum(mask[:, :, None] * tf.one_hot(protein['extra_msa'], 23))

  msa_sum += tf.one_hot(protein['msa'], 23)  # Original sequence

  protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]

  del msa_sum

  # 每条msa序列的最近邻序列的extr_msa,在不同位置deletion数统计

  # del_sum [num_seq, num_resi],第一行表示和msa中的0号序列是最近邻序列的extr_msa,不同位置deletion数,以此类推

  del_sum = csum(mask * protein['extra_deletion_matrix'])

  del_sum += protein['deletion_matrix']  # Original sequence

  protein['cluster_deletion_mean'] = del_sum / mask_counts

  del del_sum

  return protein

with open('Human_HBB_tensor_dict_nnclusted.pkl','rb') as f:

    protein = pickle.load(f)

print(protein.keys())

protein = summarize_clusters()(protein)

print(protein.keys())

print(protein['cluster_profile'].shape)

print(protein['cluster_profile'])

print(protein['cluster_deletion_mean'].shape)

print(protein['cluster_deletion_mean'])

好文推荐

评论可见,请评论后查看内容,谢谢!!!
 您阅读本篇文章共花了: