5.12. Manual sharding
PopRT manual sharding divides the model into subgraphs at user-specified sharding points to achieve model parallelism and pipeline parallelism.
5.12.2. Pipelining and pipeline parallelism
PopRT supports sharding the ONNX graph across different pipeline stages based on the sharding points provided by users to achieve pipeline parallelism and improve throughput.
For more information, refer to the sections on pipelining in the technical note and in the IPU Programmer’s Guide.
Note
To use pipeline parallelism, it is necessary to enable model parallelism and set the PopRT backend options as follows:
options.enable_pipelining = True
options.batches_per_step = an integer multiple of the number of pipeline stages
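For example, a minimal sketch of these backend options for a model with four pipeline stages (16 is an arbitrary integer multiple of 4; any multiple works):

from poprt.compiler import CompilerOptions

options = CompilerOptions()
# Enable pipeline parallelism.
options.enable_pipelining = True
# batches_per_step must be an integer multiple of the number of
# pipeline stages: here 4 stages, so 16 = 4 x 4 is valid.
options.batches_per_step = 16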
5.12.3. Manual sharding process
PopRT manual sharding splits the ONNX graph at the level of ONNX nodes, and any ONNX node can be used as a sharding point.
The nodes in the ONNX graph are arranged in topological order. PopRT manual sharding first performs a topological sort of the sharding points set by the user.
Each sharding point is then processed in turn. Starting from the sharding point, the ONNX graph is traversed towards its inputs, and all traversed ONNX nodes are placed into one subgraph. A branch of the traversal stops when it reaches a node with no inputs or a node whose sharding information has already been set.
After the traversal is completed, the subgraph is obtained. The sharding information of the subgraph is set with ONNX attributes:
__ipu_number specifies the device serial number corresponding to each subgraph in model parallelism.
__pipeline_stage specifies the pipeline stage corresponding to each subgraph in pipeline parallelism.
Note
Different sharding points can have the same device serial number and pipeline stage. For example, if two parallel branches start from different sharding points and should be placed on a single device, then these two sharding points are given the same device serial number.
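As a sketch with hypothetical node names (these names are for illustration only and do not come from a real model), the sharding configuration for two parallel branches placed on the same device might look like this:

# Hypothetical node names, for illustration only.
sharding_info = {
    "branch_a_last_node": 0,  # first parallel branch, device 0
    "branch_b_last_node": 0,  # second parallel branch, same device 0
    "merge_node": 1,          # downstream subgraph on device 1
}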
After the sharding information has been set from the sharding points, it is filled in automatically for the remaining nodes:
__ipu_number is set to the currently set maximum device serial number + 1.
__pipeline_stage is set to the currently set maximum pipeline stage + 1.
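To inspect the result, here is a small sketch using the standard onnx API (the file name sharded_model.onnx is an assumption) that reads back the __ipu_number and __pipeline_stage attributes of each node:

import onnx

model = onnx.load("sharded_model.onnx")  # assumed output of manual sharding
for node in model.graph.node:
    # Integer ONNX attributes store their value in the .i field.
    attrs = {
        a.name: a.i
        for a in node.attribute
        if a.name in ("__ipu_number", "__pipeline_stage")
    }
    if attrs:
        print(node.name, attrs)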
5.12.4. Configuring manual sharding
There are two methods for configuring manual sharding:
with the PopRT CLI
with the poprt.converter.Sharder class.
Configuring manual sharding with the PopRT CLI
Specify the sharding point name, device serial number and pipeline stage in a YAML file:
- node: resnetv17_stage1__plus0
  device: 0
  stage: 0
- node: resnetv17_stage4_batchnorm2_fwd
  device: 1
  stage: 1
- node: resnetv17_stage4__plus0
  device: 2
  stage: 2
Configure the sharding information with --manual_sharding_config in the PopRT CLI:
poprt \
--input_model model.onnx \
--manual_sharding_config shard.yaml
The --only_manual_sharding option in the PopRT CLI determines whether only manual sharding is performed on input_model. It is not set by default.
Not setting --only_manual_sharding means that manual sharding is performed on input_model after the Convert phase optimisation.
Setting --only_manual_sharding means that only manual sharding is performed on input_model. In this mode, only --input_model, --output_model, --output_dir and --manual_sharding_config are supported; other parameters are invalid.
poprt \
--input_model model.onnx \
--manual_sharding_config shard.yaml \
--only_manual_sharding
Configuring manual sharding with the Python API
You can use poprt.converter.Sharder to configure manual sharding.
sharding_info = {
    "resnetv17_stage1__plus0": 0,
    "resnetv17_stage4_batchnorm2_fwd": 1,
    "resnetv17_stage4__plus0": 2,
}
pipelining_info = {
    "resnetv17_stage1__plus0": 0,
    "resnetv17_stage4_batchnorm2_fwd": 1,
    "resnetv17_stage4__plus0": 2,
}
sharded_model = poprt.converter.Sharder(
    sharding_info=sharding_info,
    pipelining_info=pipelining_info,
).run(converted_model)
Note
The PopRT CLI with --only_manual_sharding set, and the poprt.converter.Sharder API, both require every node in the ONNX graph to have a unique name.
The PopRT CLI without --only_manual_sharding set does not require unique node names, because the Convert optimisation process guarantees that every node has a unique name.
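A quick sketch, using the standard onnx API, for checking this precondition before calling Sharder or the CLI with --only_manual_sharding:

import onnx

model = onnx.load("model.onnx")
names = [n.name for n in model.graph.node]
# Every node needs a non-empty name, and no name may repeat.
assert all(names), "found a node with an empty name"
assert len(names) == len(set(names)), "node names are not unique"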
5.12.5. Example
The following is a simple example of manual sharding, using ResNet50:
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import numpy as np
import onnx
import requests

from poprt import runtime
from poprt.compiler import Compiler, CompilerOptions
from poprt.converter import Sharder


def load_model():
    # Download model
    url = 'https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v1-7.onnx'
    response = requests.get(url)
    if response.status_code == 200:
        model = onnx.load_model_from_string(response.content)
    else:
        raise Exception(
            f"Failed to download model with status_code {response.status_code}"
        )
    return model


def manual_sharding(model):
    # Fix the batch size to 1
    model.graph.input[0].type.tensor_type.shape.dim[0].dim_value = 1

    # Sharding and pipelining info
    sharding_info = {
        "resnetv17_stage1__plus0": 0,
        "resnetv17_stage4_batchnorm2_fwd": 1,
        "resnetv17_stage4__plus0": 2,
    }
    pipelining_info = {
        "resnetv17_stage1__plus0": 0,
        "resnetv17_stage4_batchnorm2_fwd": 1,
        "resnetv17_stage4__plus0": 2,
    }
    model = Sharder(sharding_info=sharding_info, pipelining_info=pipelining_info).run(
        model
    )

    return model


def compile(model):
    # Compile the model with backend options
    model_bytes = model.SerializeToString()
    outputs = [o.name for o in model.graph.output]

    options = CompilerOptions()
    options.ipu_version = runtime.DeviceManager().ipu_hardware_version()
    # Shard the model across 4 IPUs
    options.num_ipus = 4
    # Enable sharding and pipelining
    options.enable_pipelining = True
    options.virtual_graph_mode = "manual"
    options.batches_per_step = 16

    executable = Compiler.compile(model_bytes, outputs, options)
    runner_config = runtime.RuntimeConfig()
    runner_config.timeout_ns = 0
    runner = runtime.Runner(executable, runner_config)
    return runner


def run(runner):
    inputs_info = runner.get_execute_inputs()
    outputs_info = runner.get_execute_outputs()

    inputs = {}
    for i in inputs_info:
        inputs[i.name] = np.ones(i.shape, dtype=i.numpy_data_type())

    outputs = {}
    for o in outputs_info:
        outputs[o.name] = np.zeros(o.shape, dtype=o.numpy_data_type())

    runner.execute(inputs, outputs)


if __name__ == '__main__':
    model = load_model()
    model = manual_sharding(model)
    runner = compile(model)
    run(runner)