5.12. Manual sharding
PopRT manual sharding divides a model into different subgraphs at sharding points provided by the user, in order to achieve model parallelism and pipeline parallelism.
5.12.2. Pipelining and pipeline parallelism
PopRT supports sharding the ONNX graph across different pipeline stages based on the sharding points provided by users to achieve pipeline parallelism and improve throughput.
For more information, refer to the sections on pipelining in the technical note and in the IPU Programmer’s Guide.
Note
To use pipeline parallelism, it is necessary to enable model parallelism and set the PopRT backend options as follows:
options.enable_pipelining = True
options.batches_per_step = integer multiple of the number of pipeline stages
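For reference, a minimal sketch of these backend options using poprt.compiler.CompilerOptions, as set in the full example in Section 5.12.5 (which uses four pipeline stages):

from poprt.compiler import CompilerOptions

options = CompilerOptions()
# Enable model parallelism with manually assigned subgraphs, and pipeline parallelism
options.virtual_graph_mode = "manual"
options.enable_pipelining = True
# Must be an integer multiple of the number of pipeline stages (4 assumed here)
options.batches_per_step = 16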
5.12.3. Manual sharding process
PopRT manual sharding shards the ONNX graph based on ONNX nodes, and a sharding point can be any ONNX node.
The nodes in the ONNX graph are arranged in topological order. PopRT manual sharding first topologically sorts the sharding points set by the user.
Each sharding point is then processed in turn. Starting from the sharding point, the ONNX graph is traversed in the direction of its inputs, and all traversed ONNX nodes are put into a subgraph. Traversal of a branch stops when there is no further input node, or when a node already has sharding information set.
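As an illustration only (not PopRT's actual implementation), the following sketch shows this backwards traversal on a plain onnx.ModelProto, where already_sharded is assumed to be the set of node names that already have sharding information:

import onnx

def collect_subgraph(model: onnx.ModelProto, sharding_point: str, already_sharded: set) -> set:
    # Map every tensor name to the node that produces it, so the graph can be walked towards its inputs.
    producers = {out: node for node in model.graph.node for out in node.output}
    nodes_by_name = {node.name: node for node in model.graph.node}

    subgraph = set()
    stack = [nodes_by_name[sharding_point]]
    while stack:
        node = stack.pop()
        if node.name in subgraph or node.name in already_sharded:
            continue
        subgraph.add(node.name)
        for tensor in node.input:
            producer = producers.get(tensor)
            # Stop this branch when the tensor has no producing node (graph input or initializer).
            if producer is not None:
                stack.append(producer)
    return subgraph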
After the traversal is completed, the subgraph is obtained. The sharding information of the subgraph is set using ONNX attributes:
__ipu_number specifies the device serial number corresponding to each subgraph in model parallelism.
__pipeline_stage specifies the pipeline stage corresponding to each subgraph in pipeline parallelism.
Note
Different sharding points can have the same device serial number and pipeline stage. For example, if there are two parallel branches starting from different sharding points and you want to place them on a single device, then these two sharding points will have the same device serial number.
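For example, using the sharding_info and pipelining_info format described in Section 5.12.4, two hypothetical sharding points branch_a_end and branch_b_end on parallel branches could be placed on the same device and pipeline stage:

sharding_info = {
    "branch_a_end": 1,  # hypothetical node name
    "branch_b_end": 1,  # same device serial number as branch_a_end
}
pipelining_info = {
    "branch_a_end": 1,
    "branch_b_end": 1,  # same pipeline stage as branch_a_end
}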
After the sharding information is set based on the sharding points, the remaining nodes without sharding information are automatically set:
__ipu_number will be set to the currently set maximum device serial number + 1.
__pipeline_stage will be set to the currently set maximum pipeline stage + 1.
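The resulting assignments can be inspected on the sharded ONNX graph. A minimal sketch, assuming the two attributes are stored as integer node attributes on a sharded onnx.ModelProto:

import onnx

def print_sharding_info(sharded_model: onnx.ModelProto) -> None:
    # Print the device and pipeline stage assigned to each node, if present.
    for node in sharded_model.graph.node:
        attrs = {a.name: a.i for a in node.attribute
                 if a.name in ("__ipu_number", "__pipeline_stage")}
        if attrs:
            print(f"{node.name}: ipu={attrs.get('__ipu_number')}, stage={attrs.get('__pipeline_stage')}")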
5.12.4. Configuring manual sharding
There are two methods for configuring manual sharding:
with the PopRT CLI
with the poprt.converter.Sharder class.
Configuring manual sharding with the PopRT CLI
Specify the sharding point name, device serial number and pipeline stage in a YAML file:
- node: resnetv17_stage1__plus0
  device: 0
  stage: 0
- node: resnetv17_stage4_batchnorm2_fwd
  device: 1
  stage: 1
- node: resnetv17_stage4__plus0
  device: 2
  stage: 2
Configuring sharding information with --manual_sharding_config in the PopRT CLI:
poprt \
--input_model model.onnx \
--manual_sharding_config shard.yaml
Use --only_manual_sharding in the PopRT CLI to determine whether manual sharding is performed only on input_model; this option is not set by default.
Not setting --only_manual_sharding means that manual sharding is performed on input_model after the Convert phase optimisation.
Setting --only_manual_sharding means that only manual sharding is performed on input_model. In this case, only --input_model, --output_model, --output_dir and --manual_sharding_config are supported; other parameters are invalid.
poprt \
--input_model model.onnx \
--manual_sharding_config shard.yaml \
--only_manual_sharding
Configuring manual sharding with the Python API
You can use poprt.converter.Sharder to configure manual sharding:
import poprt.converter

sharding_info = {
    "resnetv17_stage1__plus0": 0,
    "resnetv17_stage4_batchnorm2_fwd": 1,
    "resnetv17_stage4__plus0": 2,
}
pipelining_info = {
    "resnetv17_stage1__plus0": 0,
    "resnetv17_stage4_batchnorm2_fwd": 1,
    "resnetv17_stage4__plus0": 2,
}
sharded_model = poprt.converter.Sharder(
    sharding_info=sharding_info,
    pipelining_info=pipelining_info,
).run(converted_model)
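In this snippet, converted_model is the ONNX model to be sharded, for example the model produced by the Convert phase optimisation.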
Note
When using the PopRT CLI with --only_manual_sharding set, or when using the poprt.converter.Sharder API, every node in the ONNX graph must have a unique name.
When using the PopRT CLI without --only_manual_sharding set, node names do not need to be unique, because the Convert optimisation process will guarantee that every node has a unique name.
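If you call the Sharder API directly, a simple check such as the following sketch (assuming a plain onnx.ModelProto) can confirm that node names are unique before sharding:

import onnx
from collections import Counter

def check_unique_node_names(model: onnx.ModelProto) -> None:
    # Empty or duplicated node names both break the uniqueness requirement.
    counts = Counter(node.name for node in model.graph.node)
    problems = [name for name, n in counts.items() if n > 1 or name == ""]
    if problems:
        raise ValueError(f"Nodes without unique names: {problems}")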
5.12.5. Example
The following is a simple example of manual sharding, using ResNet50:
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import numpy as np
import onnx
import requests

from poprt import runtime
from poprt.compiler import Compiler, CompilerOptions
from poprt.converter import Sharder


def load_model():
    # Download model
    url = 'https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v1-7.onnx'
    response = requests.get(url)
    if response.status_code == 200:
        model = onnx.load_model_from_string(response.content)
    else:
        raise Exception(
            f"Failed to download model with status_code {response.status_code}"
        )
    return model


def manual_sharding(model):
    # Fix the batch size to 1
    model.graph.input[0].type.tensor_type.shape.dim[0].dim_value = 1

    # Sharding and pipelining info
    sharding_info = {
        "resnetv17_stage1__plus0": 0,
        "resnetv17_stage4_batchnorm2_fwd": 1,
        "resnetv17_stage4__plus0": 2,
    }
    pipelining_info = {
        "resnetv17_stage1__plus0": 0,
        "resnetv17_stage4_batchnorm2_fwd": 1,
        "resnetv17_stage4__plus0": 2,
    }
    model = Sharder(sharding_info=sharding_info, pipelining_info=pipelining_info).run(
        model
    )

    return model


def compile(model):
    # Compile the model with backend options
    model_bytes = model.SerializeToString()
    outputs = [o.name for o in model.graph.output]

    options = CompilerOptions()
    options.ipu_version = runtime.DeviceManager().ipu_hardware_version()
    # Sharding into 4 IPUs
    options.num_ipus = 4
    # Enable Sharding and Pipelining
    options.enable_pipelining = True
    options.virtual_graph_mode = "manual"
    options.batches_per_step = 16

    executable = Compiler.compile(model_bytes, outputs, options)
    runner_config = runtime.RuntimeConfig()
    runner_config.timeout_ns = 0
    runner = runtime.Runner(executable, runner_config)
    return runner


def run(runner):
    inputs_info = runner.get_execute_inputs()
    outputs_info = runner.get_execute_outputs()

    inputs = {}
    for i in inputs_info:
        inputs[i.name] = np.ones(i.shape, dtype=i.numpy_data_type())

    outputs = {}
    for o in outputs_info:
        outputs[o.name] = np.zeros(o.shape, dtype=o.numpy_data_type())

    runner.execute(inputs, outputs)


if __name__ == '__main__':
    model = load_model()
    model = manual_sharding(model)
    runner = compile(model)
    run(runner)
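In this example, the three user-specified sharding points are assigned to IPUs 0, 1 and 2, and the remaining nodes are automatically assigned to IPU 3, so num_ipus is set to 4. With four pipeline stages, batches_per_step = 16 satisfies the requirement of being an integer multiple of the number of pipeline stages.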