

chenyangkang commented on August 15, 2024

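After some testing, I think I will stick with tabular (pandas) queries instead of spatial (geopandas) queries, for efficiency.

Two processes, transformation and query, seem to add overhead to model training and prediction (points 1 and 2), and there is one more consideration (point 3):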

  1. stemflow applies jitter and rotation operations to induce random griddings. Running these operations on geopandas (shapely) objects is less efficient than on a plain pandas DataFrame, and the gap is likely to grow as data volume increases.

  2. The spatial query was less efficient than the tabular query in all my tests. This is probably because shapely's Polygon class is designed for arbitrary, possibly highly irregular shapes, while our grid cells are squares. Querying by cell boundaries works better in our case.

  3. Using pandas queries may allow flexibility for future multiprocessing, for example with the swifter package.
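To illustrate point 2, here is a minimal pure-Python sketch (not shapely's or stemflow's code) contrasting a general point-in-polygon test, even-odd ray casting, which has to walk every edge of the polygon, with the four comparisons that suffice for an axis-aligned square cell. Both function names are hypothetical. Boundary handling also differs: the cell test uses half-open bounds like the pandas query, while ray casting treats edge points according to edge orientation.

```python
def point_in_polygon(x, y, vertices):
    """Even-odd ray casting: works for any simple polygon given as [(x, y), ...]."""
    inside = False
    n = len(vertices)
    for i in range(n):
        x1, y1 = vertices[i]
        x2, y2 = vertices[(i + 1) % n]
        # does the horizontal ray from (x, y) towards +x cross edge (x1, y1)-(x2, y2)?
        if (y1 > y) != (y2 > y):
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def point_in_cell(x, y, x0, x1, y0, y1):
    """Axis-aligned grid cell: four comparisons, half-open like the pandas query."""
    return x0 <= x < x1 and y0 <= y < y1


square = [(0, 0), (0, 10), (10, 10), (10, 0)]
for px, py in [(5, 5), (15, 5), (5, -1)]:
    print(point_in_polygon(px, py, square), point_in_cell(px, py, 0, 10, 0, 10))
```

For interior and exterior points both tests agree; the cell test simply skips the per-edge work that a general polygon needs.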

These three arguments assume that stemflow is adapted for big-data modeling, as mentioned in our paper. Data volume grows exponentially each year, so speed is undoubtedly one of our priorities.

I also used lprun (line_profiler) to check whether the query is the bottleneck in training and prediction. In training, the query takes less than 10% of the time, while base-model training takes 91%. In prediction, however, more than half of the time is spent querying the corresponding grids and models. Therefore, a tabular query may save up to 50% of the prediction time, and prediction is the real bottleneck for AdaSTEM.
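To put these shares into perspective, Amdahl's law gives the run time that remains after a part taking a fraction p of the total is made s times faster: (1 - p) + p / s. The short sketch below (the helper name is hypothetical) applies it to the shares above: with p = 0.5, a 2x faster query leaves 75% of the prediction time and removing the query cost entirely leaves 50%, consistent with the "up to 50%" estimate, while for training (p < 0.1) the saving is at most about 10%.

```python
def remaining_time_fraction(p, s):
    """Amdahl's law: fraction of the original run time that remains when a part
    taking fraction p of the run is made s times faster (s=inf removes it)."""
    return (1 - p) + p / s


# prediction: the query takes roughly half of the run time
pred_2x = remaining_time_fraction(0.5, 2)               # query made 2x faster
pred_free = remaining_time_fraction(0.5, float("inf"))  # query cost removed entirely
# training: the query takes less than 10% of the run time
train_free = remaining_time_fraction(0.1, float("inf"))
```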

Still, I will test whether the tabular query is also better for spherical coordinates.


--- Update Jan 23, 2024:

I finished implementing the spherical coordinates.

  • For training, base model fitting takes 80% of the time and the query <20%.
  • For prediction, base model prediction takes 30% of the time and the query ~65%.

This is already optimized using vectorization to determine whether a point falls in a triangle. I expect that using shapely objects would cause performance problems here.
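A vectorized point-in-triangle test on the sphere can be sketched as follows. This is an illustration under stated assumptions, not stemflow's implementation: points and vertices are unit vectors (converted from degrees by a hypothetical helper), and the triangle is assumed smaller than a hemisphere. A point is inside when it lies on the inner side of all three great circles through the triangle's edges, which reduces to three matrix-vector products for all points at once.

```python
import numpy as np


def lonlat_to_xyz(lon_deg, lat_deg):
    """Hypothetical helper: degrees -> unit vectors on the sphere, shape (..., 3)."""
    lon, lat = np.radians(lon_deg), np.radians(lat_deg)
    return np.stack([np.cos(lat) * np.cos(lon),
                     np.cos(lat) * np.sin(lon),
                     np.sin(lat)], axis=-1)


def in_spherical_triangle(points, a, b, c):
    """Boolean mask of which unit vectors in `points` (shape (n, 3)) lie inside the
    spherical triangle with unit-vector vertices a, b, c (each shape (3,)).
    Assumes the triangle is smaller than a hemisphere; edge points count as inside."""
    # Orient the vertices so that the triple product a . (b x c) is positive.
    if np.dot(a, np.cross(b, c)) < 0:
        b, c = c, b
    # Inside = on the inner side of all three edge great circles (one product per edge).
    return ((points @ np.cross(a, b) >= 0) &
            (points @ np.cross(b, c) >= 0) &
            (points @ np.cross(c, a) >= 0))


# triangle with vertices at (lon, lat) = (0, 0), (10, 0), (0, 10)
a, b, c = lonlat_to_xyz(np.array([0.0, 10.0, 0.0]), np.array([0.0, 0.0, 10.0]))
pts = lonlat_to_xyz(np.array([2.0, 20.0, -178.0]), np.array([2.0, 20.0, -2.0]))
mask = in_spherical_triangle(pts, a, b, c)
```

Here the first point lies inside, the second is outside, and the third is the antipode of the first; the orientation step is what keeps antipodal points from passing the test.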




Additionally, I checked the relative time consumption of the transformation (jitter, rotation) and the query.

For run_normal_query + run_linalg_transform

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     2                                           def run_normal_query(query, tabular_df, transform_func):
     3                                               # query = transform_pandas_to_geopandas(_query)
     4         5          4.0      0.8      0.0      res = []
     5         5        551.0    110.2      0.0      unique_start_time = sorted(tabular_df['start_time'].unique())
     6        50         18.0      0.4      0.0      for start_time in unique_start_time:
     7        45      12135.0    269.7      0.3          tmp_tabular_df = tabular_df[tabular_df['start_time']==start_time]
     8        90      17215.0    191.3      0.4          tmp_query = query[(query['time']>=start_time) &\
     9        45       7250.0    161.1      0.2                  (query['time']<tmp_tabular_df['end_time'].iloc[0])]
    10                                                   
    11        45      45467.0   1010.4      1.0          tmp_query = transform_func(tmp_query)
    12                                                   
    13      5045     250150.0     49.6      5.3          for index,line in tmp_tabular_df.iterrows():
    14     10000     556157.0     55.6     11.8              tmp = tmp_query[
    15     20000    2040232.0    102.0     43.1                  (tmp_query['lng']>=line['x0']) &\
    16      5000     538869.0    107.8     11.4                  (tmp_query['lng']<line['x1']) &\
    17      5000     536038.0    107.2     11.3                  (tmp_query['lat']>=line['y0']) &\
    18      5000     532655.0    106.5     11.3                  (tmp_query['lat']<line['y1'])
    19                                                       ]
    20      5000       2379.0      0.5      0.1              res.append(tmp)
    21                                                       
    22         5     190984.0  38196.8      4.0      res = pd.concat(res, axis=0)
    23         5         18.0      3.6      0.0      return res.shape[0]

For run_geo_query + run_linalg_transform

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    28                                           def run_geo_query(query, tabular_df, transform_func):
    29        10      12434.0   1243.4      0.1      tabular_df_ = gpd.GeoDataFrame(
    30        10      52634.0   5263.4      0.2          tabular_df, geometry=[Polygon([(a,c),(a,d),(b,d),(b,c)]) for a,b,c,d in zip(
    31         5        405.0     81.0      0.0              tabular_df['x0'], tabular_df['x1'], 
    32         5         65.0     13.0      0.0              tabular_df['y0'], tabular_df['y1']
    33                                                   )]
    34                                               )
    35         5    1406609.0 281321.8      5.7      query_ = transform_pandas_to_geopandas(query)
    36                                               
    37         5          5.0      1.0      0.0      res_list = []
    38         5        859.0    171.8      0.0      unique_start_time = sorted(tabular_df_['start_time'].unique())
    39        50         22.0      0.4      0.0      for start_time in unique_start_time:
    40        45      68435.0   1520.8      0.3          tmp_tabular_df_ = tabular_df_[tabular_df_['start_time']==start_time]
    41        45       5169.0    114.9      0.0          end_time = tmp_tabular_df_['end_time'].iloc[0]
    42        45     554697.0  12326.6      2.3          tmp_query_ = query_[(query_['time']>=start_time) & (query_['time']<end_time)]
    43        45   13999324.0 311096.1     57.0          tmp_query_ = transform_func(tmp_query_)
    44        45    8457689.0 187948.6     34.4          res = tmp_query_.sjoin(tmp_tabular_df_)
    45        45         54.0      1.2      0.0          res_list.append(res)
    46                                                   
    47         5      10330.0   2066.0      0.0      res_list = pd.concat(res_list, axis=0)
    48         5         10.0      2.0      0.0      return res_list.shape[0]

For run_geo_query + run_geo_transform

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    28                                           def run_geo_query(query, tabular_df, transform_func):
    29        10      11298.0   1129.8      0.0      tabular_df_ = gpd.GeoDataFrame(
    30        10      50757.0   5075.7      0.1          tabular_df, geometry=[Polygon([(a,c),(a,d),(b,d),(b,c)]) for a,b,c,d in zip(
    31         5        167.0     33.4      0.0              tabular_df['x0'], tabular_df['x1'], 
    32         5         62.0     12.4      0.0              tabular_df['y0'], tabular_df['y1']
    33                                                   )]
    34                                               )
    35         5    1519317.0 303863.4      3.7      query_ = transform_pandas_to_geopandas(query)
    36                                               
    37         5          5.0      1.0      0.0      res_list = []
    38         5        706.0    141.2      0.0      unique_start_time = sorted(tabular_df_['start_time'].unique())
    39        50         30.0      0.6      0.0      for start_time in unique_start_time:
    40        45      65283.0   1450.7      0.2          tmp_tabular_df_ = tabular_df_[tabular_df_['start_time']==start_time]
    41        45       2608.0     58.0      0.0          end_time = tmp_tabular_df_['end_time'].iloc[0]
    42        45     731824.0  16262.8      1.8          tmp_query_ = query_[(query_['time']>=start_time) & (query_['time']<end_time)]
    43        45   27501438.0 611143.1     67.1          tmp_query_ = transform_func(tmp_query_)
    44        45   11072522.0 246056.0     27.0          res = tmp_query_.sjoin(tmp_tabular_df_)
    45        45         57.0      1.3      0.0          res_list.append(res)
    46                                                   
    47         5      10622.0   2124.4      0.0      res_list = pd.concat(res_list, axis=0)
    48         5         12.0      2.4      0.0      return res_list.shape[0]

Transforming geo-objects is very time-consuming: 57.0% and 67.1% of the run time in the two geopandas profiles. In contrast, the linear-algebra transform on a plain pandas DataFrame is almost negligible (1.0%).

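The pandas query is also very efficient. Per-run query time (the query lines of each profile summed, divided by the 5 runs, in profiler time units):
query time 1 (pandas, lines 13-18): (250150 + 556157 + 2040232 + 538869 + 536038 + 532655)/5 = 890820.2
query time 2 (sjoin, line 44, linalg transform): 8457689.0/5 = 1691537.8
query time 3 (sjoin, line 44, geo transform): 11072522.0/5 = 2214504.4

  • The pandas query is about 1.9 to 2.5 times faster than sjoin.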
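The per-run query times can be recomputed directly from the profiles above: the pandas query is lines 13 to 18 of the first profile, and the spatial join is line 44 of the two geopandas profiles, each summed over 5 runs. A small sketch with the numbers copied from the tables; the ratios come out at about 1.9 and 2.5.

```python
N_RUNS = 5  # each benchmark configuration was profiled over 5 runs

# line_profiler "Time" column values copied from the tables above
pandas_query = [250150.0, 556157.0, 2040232.0, 538869.0, 536038.0, 532655.0]  # lines 13-18
sjoin_linalg = 8457689.0   # line 44, run_geo_query + run_linalg_transform
sjoin_geo = 11072522.0     # line 44, run_geo_query + run_geo_transform

t_pandas = sum(pandas_query) / N_RUNS
t_sjoin_linalg = sjoin_linalg / N_RUNS
t_sjoin_geo = sjoin_geo / N_RUNS

print(f"pandas: {t_pandas:.1f}")
print(f"sjoin (linalg transform): {t_sjoin_linalg:.1f}, {t_sjoin_linalg / t_pandas:.2f}x pandas")
print(f"sjoin (geo transform): {t_sjoin_geo:.1f}, {t_sjoin_geo / t_pandas:.2f}x pandas")
```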

Code to regenerate this:

Generate pseudo-samples:

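import time

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from shapely.geometry import Point, Polygon
from tqdm import tqdm

def get_samples(sample_size=1000, n_grids=1000):
    # n_grids square grid cells (width x height) with random lower-left corners,
    # each active for time_step time units from a random integer start_time in 1..9
    width = height = 10
    time_step = 10
    cali_point_x = np.random.uniform(-100, 100, n_grids)
    cali_point_y = np.random.uniform(-100, 100, n_grids)
    start_time = np.random.uniform(1, 10, n_grids).astype(int)

    tabular_df = pd.DataFrame({
        'start_time': start_time,
        'end_time': start_time + time_step,
        'x0': cali_point_x,
        'x1': cali_point_x + width,
        'y0': cali_point_y,
        'y1': cali_point_y + height
    })

    # sample_size random query points with coordinates and a time stamp
    query = pd.DataFrame({
        'lng': np.random.uniform(-100, 100, sample_size),
        'lat': np.random.uniform(-100, 100, sample_size),
        'time': np.random.uniform(1, 10, sample_size),
    })

    return tabular_df, query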

Transformation function (jitter and rotation):

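# helper: wrap a pandas DataFrame with lng/lat columns into a GeoDataFrame of Points
def transform_pandas_to_geopandas(query):
    query_ = gpd.GeoDataFrame(
        query, geometry=[Point(a, b) for a, b in zip(
            query['lng'], query['lat']
            )]
    )
    return query_

# linalg transform
# (JitterRotator is stemflow's jitter/rotation utility; its import is not shown in this thread)
def run_linalg_transform(query_):

    if isinstance(query_, gpd.GeoDataFrame):
        # read the coordinates back from the Point geometries
        query_['lng'] = query_['geometry'].x
        query_['lat'] = query_['geometry'].y

    # rotate and jitter the coordinates with plain array arithmetic
    a, b = JitterRotator.rotate_jitter(
        query_['lng'],
        query_['lat'],
        0, 50, 50)

    query_.loc[:, 'lng'] = np.array(a)
    query_.loc[:, 'lat'] = np.array(b)

    if isinstance(query_, gpd.GeoDataFrame):
        # rebuild the geometry; passing the index keeps it aligned with the
        # time-filtered rows instead of a fresh 0..n-1 index
        query_ = gpd.GeoDataFrame(
            query_, geometry=gpd.GeoSeries.from_xy(a, b, index=query_.index))

    return query_

# geo transform
def run_geo_transform(query_):
    # jitter + rotation applied directly to the GeoDataFrame by stemflow's JitterRotator
    query_ = JitterRotator.rotate_jitter_gpd(
        query_,
        0, 50, 50)

    return query_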
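For readers without stemflow at hand, the following hypothetical stand-in makes the linear-algebra transform concrete: it rotates points counter-clockwise by an angle in degrees about the origin and then shifts them by an (x, y) jitter, all in one matrix product. The argument order mirrors the call JitterRotator.rotate_jitter(lng, lat, 0, 50, 50) above, but that correspondence is an assumption, and stemflow's rotation center and sign conventions may differ.

```python
import numpy as np


def rotate_jitter_sketch(x, y, rotation_angle, x_jitter, y_jitter):
    """Hypothetical stand-in (NOT stemflow's JitterRotator): rotate (x, y)
    counter-clockwise by `rotation_angle` degrees about the origin, then
    translate by (x_jitter, y_jitter). Returns two NumPy arrays."""
    theta = np.radians(rotation_angle)
    rot = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])
    xy = np.column_stack([np.asarray(x, dtype=float), np.asarray(y, dtype=float)])
    # one matrix product transforms all points at once
    out = xy @ rot.T + np.array([x_jitter, y_jitter])
    return out[:, 0], out[:, 1]


lng, lat = rotate_jitter_sketch([1.0, 0.0], [0.0, 1.0], 90, 50, 50)
```

Because the whole transform is one small matrix product plus an offset, its cost is tiny next to rebuilding shapely geometries, which matches the 1.0% share in the first profile.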
    
    

Query functions:

# normal query: select the points inside each grid cell with pandas boolean masks
def run_normal_query(query, tabular_df, transform_func):
    res = []
    unique_start_time = sorted(tabular_df['start_time'].unique())
    for start_time in unique_start_time:
        # grid cells and query points belonging to this time window
        tmp_tabular_df = tabular_df[tabular_df['start_time'] == start_time]
        end_time = tmp_tabular_df['end_time'].iloc[0]
        tmp_query = query[(query['time'] >= start_time) & (query['time'] < end_time)]

        tmp_query = transform_func(tmp_query)

        # half-open bounds check per cell: x0 <= lng < x1 and y0 <= lat < y1
        for _, line in tmp_tabular_df.iterrows():
            tmp = tmp_query[
                (tmp_query['lng'] >= line['x0']) &
                (tmp_query['lng'] < line['x1']) &
                (tmp_query['lat'] >= line['y0']) &
                (tmp_query['lat'] < line['y1'])
            ]
            res.append(tmp)

    res = pd.concat(res, axis=0)
    return res.shape[0]

    
# geo query: one Polygon per grid cell, matched to points with a spatial join
from shapely.geometry import Point, Polygon
def run_geo_query(query, tabular_df, transform_func):
    # rectangle corners (x0, y0), (x0, y1), (x1, y1), (x1, y0)
    tabular_df_ = gpd.GeoDataFrame(
        tabular_df, geometry=[Polygon([(a, c), (a, d), (b, d), (b, c)]) for a, b, c, d in zip(
            tabular_df['x0'], tabular_df['x1'],
            tabular_df['y0'], tabular_df['y1']
        )]
    )
    query_ = transform_pandas_to_geopandas(query)

    res_list = []
    unique_start_time = sorted(tabular_df_['start_time'].unique())
    for start_time in unique_start_time:
        tmp_tabular_df_ = tabular_df_[tabular_df_['start_time'] == start_time]
        end_time = tmp_tabular_df_['end_time'].iloc[0]
        tmp_query_ = query_[(query_['time'] >= start_time) & (query_['time'] < end_time)]
        tmp_query_ = transform_func(tmp_query_)
        # default sjoin: inner join with predicate "intersects", so points lying exactly
        # on a cell edge also match (the pandas query uses half-open bounds)
        res = tmp_query_.sjoin(tmp_tabular_df_)
        res_list.append(res)

    res_list = pd.concat(res_list, axis=0)
    return res_list.shape[0]
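The first profile shows that the per-cell iterrows() loop (lines 13 to 18) accounts for about 94% of run_normal_query's time. One possible, unbenchmarked direction is to pull the columns out as NumPy arrays once and test all cells against all points with broadcasting. The sketch below is hypothetical (not part of stemflow) and returns a count like run_normal_query does; its memory use grows as n_cells × n_points, so large inputs would need chunking.

```python
import numpy as np


def count_points_in_cells(lng, lat, x0, x1, y0, y1):
    """Count (point, cell) pairs with x0 <= lng < x1 and y0 <= lat < y1.

    Same half-open bounds as run_normal_query, but all cells are tested at
    once via broadcasting: rows are cells, columns are points.
    """
    lng = np.asarray(lng, dtype=float)[None, :]
    lat = np.asarray(lat, dtype=float)[None, :]
    x0, x1, y0, y1 = (np.asarray(v, dtype=float)[:, None] for v in (x0, x1, y0, y1))
    inside = (lng >= x0) & (lng < x1) & (lat >= y0) & (lat < y1)
    return int(inside.sum())
```

Within one time window it could be called as count_points_in_cells(tmp_query['lng'], tmp_query['lat'], tmp_tabular_df['x0'], tmp_tabular_df['x1'], tmp_tabular_df['y0'], tmp_tabular_df['y1']); since it returns a count rather than the selected rows, it replaces only the counting part of the benchmark.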

Execution:

def get_time(query, tabular_df, query_func, transform_func, n_repeat=5):
    # mean wall-clock time of query_func over n_repeat runs
    time_list = []

    for _ in range(n_repeat):
        t0 = time.perf_counter()
        query_result_shape = query_func(query, tabular_df, transform_func)
        time_list.append(time.perf_counter() - t0)
        # number of matched (point, cell) rows, printed outside the timed region
        print(query_result_shape)

    return np.mean(time_list)
    
    

res_list = []
for sample_size in tqdm(np.logspace(2,4,8)):
    sample_size = int(sample_size)
    tabular_df, query = get_samples(sample_size = sample_size)
    
    # 1. run_normal_query + run_linalg_transform
    # 2. run_geo_query + run_linalg_transform
    # 3. run_geo_query + run_geo_transform
    
    time1 = get_time(query, tabular_df, run_normal_query, run_linalg_transform)
    time2 = get_time(query, tabular_df, run_geo_query, run_linalg_transform)
    time3 = get_time(query, tabular_df, run_geo_query, run_geo_transform)
    
    res_list.append({
        'sample_size':sample_size,
        'run_normal_query + run_linalg_transform':time1,
        'run_geo_query + run_linalg_transform':time2,
        'run_geo_query + run_geo_transform':time3
    })

res_list = pd.DataFrame(res_list)
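As an alternative to get_time (not what produced the plot below), the standard library's timeit.repeat can run each configuration several times; reporting the minimum alongside the mean makes the comparison less sensitive to background load. The helper name get_time_stats is hypothetical.

```python
import statistics
import timeit


def get_time_stats(func, *args, repeat=5):
    """Call func(*args) once per repetition, `repeat` times, and return
    (mean, min) of the wall-clock durations in seconds."""
    durations = timeit.repeat(lambda: func(*args), repeat=repeat, number=1)
    return statistics.mean(durations), min(durations)


# e.g. get_time_stats(run_normal_query, query, tabular_df, run_linalg_transform)
mean_s, min_s = get_time_stats(sorted, list(range(100_000, 0, -1)))
```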

Plot results:

for var_ in res_list.columns[1:]:
    plt.plot(res_list['sample_size'], res_list[var_], '-o', label=var_,)

plt.legend()
plt.ylabel('Time use (s)')
plt.xlabel('sample size')
plt.show()


from stemflow.

chenyangkang commented on August 15, 2024

Closing this.

